I am looking for help speeding up some MATLAB code.
A minimal example is shown below. The code does some calculations on a 3D matrix defined on [x,y,z] coordinates. According to the profiler, the inner loop over ind is the time-consuming part, so I am wondering whether this loop can be optimized, or removed/vectorized completely.
Nx = 8; % Number of grid points
Ny = 6;
Nz = 4;
Ntot = Nx*Ny*Nz;
xvals = rand(1,Nx); % Create grid vectors
yvals = rand(1,Ny);
zvals = rand(1,Nz);
input_vec = rand(Ny,Nx,Nz); % Generate a dummy 3D matrix ( meshgrid convention, [y,x,z] )
input_vec = reshape( permute(input_vec,[3,1,2]) , [Ntot 1]); % Unwrap to 1D, so z cycles fastest, then y, then x
C1 = 5; % Loop counters
C2 = 6;
C3 = 7;
output_vec = zeros(Ntot,1); % Preallocate
temp_vec = zeros(Ntot,1);
for cnt1 = 1:C1
    for cnt2 = 1:C2
        for cnt3 = 1:C3
            factor1 = xvals*cnt1; % Calculate some vectors which depend on cnt variables
            factor2 = yvals*cnt2;
            factor3 = zvals*cnt3;
            for ind = 1:Ntot % Loop over every grid point
                j1 = floor( floor((ind-1) / Nz) / Ny) + 1; % +1 and -1 's account for Matlab's [1] indexing
                j2 = mod( floor((ind-1)/Nz) , Ny ) + 1;
                j3 = mod( (ind-1), Nz ) + 1;
                temp_vec(ind) = input_vec(ind) * factor1(j1)*factor2(j2)*factor3(j3);
            end
            output_vec = output_vec + temp_vec;
        end
    end
end
In my real application, the grid is more like 1024x1024x512 points, so I have tried to avoid using lots of meshgrid-formatted variables (which contain a lot of repeated information) in order to keep the memory requirements down. This is the reason that the 3D array has been unwrapped to 1D in the code above. For example, one solution might be to precalculate all the j1, j2, j3 values like so:
j1 = 1:Nx;
j2 = 1:Ny;
j3 = 1:Nz;
[J1,J2,J3] = meshgrid(j1,j2,j3);
J1 = reshape( permute(J1,[3,1,2]) , [Ntot 1]); % Unwrap to 1D, so z cycles fastest, then y, then x
J2 = reshape( permute(J2,[3,1,2]) , [Ntot 1]);
J3 = reshape( permute(J3,[3,1,2]) , [Ntot 1]);
but this requires much more RAM than calculating a single value of j each time depending on the value of ind.
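To put rough numbers on this, here is a back-of-the-envelope sketch, assuming every array is stored as double (MATLAB's default numeric class). The variable names bytes_per_double, mem_one_array and mem_index_arrays are only for illustration:

```matlab
% Rough memory estimate for the full-size grid (sizes taken from the text above)
Nx = 1024; Ny = 1024; Nz = 512;
Ntot = Nx*Ny*Nz;                              % 2^29 grid points
bytes_per_double = 8;                         % assumption: everything is stored as double
GiB = 2^30;
mem_one_array    = Ntot*bytes_per_double/GiB; % one Ntot-by-1 vector: 4 GiB
mem_index_arrays = 3*mem_one_array;           % J1, J2 and J3 together: 12 GiB
fprintf('One array: %g GiB, three index arrays: %g GiB\n', mem_one_array, mem_index_arrays);
```

So the three precalculated index arrays alone would take as much memory as input_vec, output_vec and temp_vec combined.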
Can anyone help with a better/faster (but still memory efficient) way to do this? Thank you.
The following:
factor1 = xvals*cnt1; % Calculate some vectors which depend on cnt variables
factor2 = yvals*cnt2;
factor3 = zvals*cnt3;
for ind = 1:Ntot % Loop over every grid point
    j1 = floor( floor((ind-1) / Nz) / Ny) + 1; % +1 and -1 's account for Matlab's [1] indexing
    j2 = mod( floor((ind-1)/Nz) , Ny ) + 1;
    j3 = mod( (ind-1), Nz ) + 1;
    temp_vec(ind) = input_vec(ind) * factor1(j1)*factor2(j2)*factor3(j3);
end
can be written as (not tested):
factor1 = xvals*cnt1; % Calculate some vectors which depend on cnt variables
factor2 = yvals*cnt2;
factor3 = zvals*cnt3;
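ind = 1:Ntot; % all linear indices at once (row vector)
j1 = floor( floor((ind-1) / Nz) / Ny) + 1; % +1 and -1 's account for Matlab's [1] indexing
j2 = mod( floor((ind-1)/Nz) , Ny ) + 1;
j3 = mod( (ind-1), Nz ) + 1;
% The indexed factors are row vectors; transpose their product so that temp_vec
% stays Ntot-by-1 and output_vec + temp_vec does not expand to Ntot-by-Ntot
temp_vec = input_vec .* ( factor1(j1).*factor2(j2).*factor3(j3) ).';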
(Not indexing with ind, in particular, could make a big difference, though I think this is a special case that has been optimized in recent versions of MATLAB.)
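As a side note, the index arithmetic is the same mapping that ind2sub performs for an Nz-by-Ny-by-Nx array, which gives a quick way to check it. The sketch below uses the small grid from the question; the names k1, k2, k3 are only for illustration:

```matlab
Nx = 8; Ny = 6; Nz = 4;                     % small grid from the question
Ntot = Nx*Ny*Nz;
ind = (1:Ntot).';

% Index arithmetic from the question
j1 = floor( floor((ind-1)/Nz) / Ny ) + 1;   % x index, varies slowest
j2 = mod( floor((ind-1)/Nz), Ny ) + 1;      % y index
j3 = mod( ind-1, Nz ) + 1;                  % z index, varies fastest

% The same subscripts from ind2sub on an Nz-by-Ny-by-Nx array
[k3,k2,k1] = ind2sub([Nz,Ny,Nx], ind);

assert(isequal(j1,k1) && isequal(j2,k2) && isequal(j3,k3));
```

Since j1, j2 and j3 do not depend on cnt1, cnt2 or cnt3, they only need to be computed once, before the three outer loops, rather than on every pass.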
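But here we're still creating large index arrays. With implicit expansion (R2016b or later) you should be able to avoid them: view the unwrapped vector as an Nz-by-Ny-by-Nx array (z cycles fastest, then y, then x) and orient each factor along its own dimension (again, not tested):
factor1 = permute(xvals*cnt1,[1,3,2]); % array along 3rd dimension (x varies slowest)
factor2 = yvals*cnt2; % horizontal array (y is the 2nd dimension)
factor3 = (zvals*cnt3).'; % vertical array (z varies fastest)
temp_vec = reshape(input_vec,[Nz,Ny,Nx]) .* factor1 .* factor2 .* factor3;
temp_vec = temp_vec(:); % back to an Ntot-by-1 column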
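Since none of this is tested, it is worth comparing a broadcasting version against the original loop on the small grid before scaling up. The following is a minimal sketch, assuming implicit expansion is available (R2016b or later) and that input_vec is unwrapped with z cycling fastest, then y, then x, as in the question. The names ref_vec, in3, f1, f2, f3 and max_err, and the counter values, are only for illustration:

```matlab
Nx = 8; Ny = 6; Nz = 4; Ntot = Nx*Ny*Nz;
xvals = rand(1,Nx); yvals = rand(1,Ny); zvals = rand(1,Nz);
input_vec = rand(Ntot,1);          % already unwrapped: z fastest, then y, then x
cnt1 = 2; cnt2 = 3; cnt3 = 4;      % arbitrary counter values for the check

factor1 = xvals*cnt1; factor2 = yvals*cnt2; factor3 = zvals*cnt3;

% Reference: the original element-by-element loop
ref_vec = zeros(Ntot,1);
for ind = 1:Ntot
    j1 = floor( floor((ind-1)/Nz) / Ny ) + 1;
    j2 = mod( floor((ind-1)/Nz), Ny ) + 1;
    j3 = mod( ind-1, Nz ) + 1;
    ref_vec(ind) = input_vec(ind) * factor1(j1)*factor2(j2)*factor3(j3);
end

% Broadcasting version: view the vector as an Nz-by-Ny-by-Nx array
in3 = reshape(input_vec, [Nz,Ny,Nx]);
f3 = factor3(:);                   % Nz-by-1, along dimension 1 (z)
f2 = factor2(:).';                 % 1-by-Ny, along dimension 2 (y)
f1 = reshape(factor1, 1, 1, Nx);   % 1-by-1-by-Nx, along dimension 3 (x)
temp_vec = in3 .* f3 .* f2 .* f1;
temp_vec = temp_vec(:);            % back to Ntot-by-1

% The two versions multiply in a different order, so compare with a tolerance
max_err = max(abs(temp_vec - ref_vec));
fprintf('max abs difference: %g\n', max_err);
```

If the printed difference is at rounding level, the orientation of the three factors matches the unwrapping order.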