Tags: matlab, loops, optimization, memory, vectorization

Is there a way to optimize/vectorize these loops over elements in a 3D array, without requiring significantly more memory?


I am looking for some help to speed up some Matlab code.

A minimal example is shown below. The code is doing some calculations on a 3D matrix, defined on [x,y,z] coordinates. When using the profiler, the inner loop over ind is the time-consuming part, and so I am wondering if this loop can either be optimized, or removed/vectorized completely.

Nx = 8; % Number of grid points
Ny = 6;
Nz = 4;
Ntot = Nx*Ny*Nz;

xvals = rand(1,Nx); % Create grid vectors
yvals = rand(1,Ny);
zvals = rand(1,Nz);

input_vec = rand(Ny,Nx,Nz); % Generate a dummy 3D matrix ( meshgrid convention, [y,x,z] )
input_vec = reshape( permute(input_vec,[3,1,2]) , [Ntot 1]); % Unwrap to 1D, so z cycles fastest, then y, then x

C1 = 5; % Loop counters
C2 = 6;
C3 = 7;

output_vec = zeros(Ntot,1); % Preallocate
temp_vec = zeros(Ntot,1);

for cnt1 = 1:C1
    for cnt2 = 1:C2
        for cnt3 = 1:C3
            
            factor1 = xvals*cnt1; % Calculate some vectors which depend on cnt variables
            factor2 = yvals*cnt2;
            factor3 = zvals*cnt3;
            
            for ind = 1:Ntot % Loop over every grid point
                j1 = floor( floor((ind-1) / Nz) / Ny) + 1; % +1 and -1 's account for Matlab's [1] indexing
                j2 = mod( floor((ind-1)/Nz) , Ny ) + 1;
                j3 = mod( (ind-1), Nz ) + 1;
                temp_vec(ind) = input_vec(ind) * factor1(j1)*factor2(j2)*factor3(j3);
            end
            output_vec = output_vec + temp_vec;
        end
    end
end
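The index arithmetic in the inner loop converts a linear index into (x, y, z) subscripts of an Nz-by-Ny-by-Nx array, with z cycling fastest. Here is a small check, assuming the built-in `ind2sub` is acceptable to use, that the hand-written formulas and `ind2sub` agree on the dummy sizes above:

```matlab
% Check that the hand-written j1, j2, j3 formulas match ind2sub for an
% array of size [Nz, Ny, Nx] (z cycles fastest, then y, then x).
Nx = 8; Ny = 6; Nz = 4;
Ntot = Nx*Ny*Nz;

ind = (1:Ntot).';                          % all linear indices, as a column
j1 = floor(floor((ind-1)/Nz)/Ny) + 1;      % x index, as in the question
j2 = mod(floor((ind-1)/Nz), Ny) + 1;       % y index
j3 = mod(ind-1, Nz) + 1;                   % z index

% Built-in conversion; outputs are ordered like the dims (z, y, x).
[k3, k2, k1] = ind2sub([Nz, Ny, Nx], ind);
ok = isequal([j1 j2 j3], [k1 k2 k3]);
```

This also shows that `reshape(input_vec, [Nz, Ny, Nx])` views the unwrapped vector as a 3D array without changing the element order.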

In my real application, the grid is more like 1024x1024x512 points, so I have tried to avoid using many meshgrid-formatted variables (which contain a lot of repeated information) to keep the memory requirements down. This is why the 3D array has been unwrapped to 1D in the code above. For example, one solution might be to precalculate all the j1, j2, j3 values like so:

j1 = 1:Nx;
j2 = 1:Ny;
j3 = 1:Nz;
[J1,J2,J3] = meshgrid(j1,j2,j3);
J1 = reshape( permute(J1,[3,1,2]) , [Ntot 1]); % Unwrap to 1D, so z cycles fastest, then y, then x 
J2 = reshape( permute(J2,[3,1,2]) , [Ntot 1]); 
J3 = reshape( permute(J3,[3,1,2]) , [Ntot 1]); 

but this requires much more RAM than computing each j value on the fly from ind, because J1, J2 and J3 are each full Ntot-element arrays, as large as the input itself.
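Here are the numbers for the real grid size, assuming everything is stored as 8-byte `double` (MATLAB's default). Each full-size array takes 4 GiB, so the three precomputed index vectors add 12 GiB on top of the 4 GiB `input_vec`. The arithmetic as a small script:

```matlab
% Memory footprint of full-size arrays for the real grid (doubles are 8 bytes).
Nx = 1024; Ny = 1024; Nz = 512;
Ntot = Nx*Ny*Nz;                          % number of grid points

bytes_per_double = 8;
gib = 2^30;                               % bytes per GiB

one_array_gib  = Ntot*bytes_per_double/gib;   % one Ntot-by-1 double vector
index_vecs_gib = 3*one_array_gib;             % J1, J2 and J3 together

fprintf('grid points      : %d\n', Ntot);
fprintf('one double array : %g GiB\n', one_array_gib);
fprintf('J1, J2, J3       : %g GiB\n', index_vecs_gib);
```

The peak during the precomputation is likely higher still, since `meshgrid` and `permute` each produce full-size arrays before the reshaped copies are kept.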

Can anyone help with a better/faster (but still memory efficient) way to do this? Thank you.


Solution

  • The following:

    factor1 = xvals*cnt1; % Calculate some vectors which depend on cnt variables
    factor2 = yvals*cnt2;
    factor3 = zvals*cnt3;
    for ind = 1:Ntot % Loop over every grid point
       j1 = floor( floor((ind-1) / Nz) / Ny) + 1; % +1 and -1 's account for Matlab's [1] indexing
       j2 = mod( floor((ind-1)/Nz) , Ny ) + 1;
       j3 = mod( (ind-1), Nz ) + 1;
       temp_vec(ind) = input_vec(ind) * factor1(j1)*factor2(j2)*factor3(j3);
    end
    

    can be written as (not tested):

    factor1 = xvals*cnt1; % Calculate some vectors which depend on cnt variables
    factor2 = yvals*cnt2;
    factor3 = zvals*cnt3;
    ind = 1:Ntot;
    j1 = floor( floor((ind-1) / Nz) / Ny) + 1; % +1 and -1 's account for Matlab's [1] indexing
    j2 = mod( floor((ind-1)/Nz) , Ny ) + 1;
    j3 = mod( (ind-1), Nz ) + 1;
    temp_vec = input_vec .* (factor1(j1).*factor2(j2).*factor3(j3)).'; % transpose the row product so temp_vec stays Ntot-by-1
    

    (In particular, not indexing with ind, i.e. using input_vec directly instead of input_vec(ind), could make a big difference, though I think this is a special case that has been optimized in recent versions of MATLAB.)

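    But this still creates large intermediate arrays: ind, j1, j2 and j3 each have Ntot elements. You should be able to avoid them with implicit expansion (MATLAB R2016b or later) by keeping the input as a 3D [y,x,z] array rather than unwrapping it (again, not tested):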

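    % Here input_vec must be the Ny-by-Nx-by-Nz array from rand(Ny,Nx,Nz), before the permute/reshape
    factor1 = xvals*cnt1;                  % horizontal array, 1-by-Nx (x along dim 2)
    factor2 = (yvals*cnt2).';              % vertical array, Ny-by-1 (y along dim 1)
    factor3 = permute(zvals*cnt3,[1,3,2]); % array along 3rd dimension, 1-by-1-by-Nz (z)
    temp_vec = input_vec .* factor1 .* factor2 .* factor3; % implicit expansion: result is Ny-by-Nx-by-Nz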
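The sketch below combines the broadcasting idea with the question's minimal example. It assumes MATLAB R2016b or later (for implicit expansion) and keeps the input as a 3D array in the meshgrid [y,x,z] convention. The names `input_arr`, `output_arr`, `output_ref` and `max_err` are introduced here for illustration. The last lines compare the result with the original element-by-element loop.

```matlab
% Sketch: accumulate with implicit expansion on a 3D [y,x,z] array,
% then check against the original scalar loop on the unwrapped vector.
Nx = 8; Ny = 6; Nz = 4;
Ntot = Nx*Ny*Nz;
xvals = rand(1,Nx); yvals = rand(1,Ny); zvals = rand(1,Nz);
C1 = 5; C2 = 6; C3 = 7;

input_arr = rand(Ny,Nx,Nz);                                 % meshgrid convention [y,x,z]
input_vec = reshape(permute(input_arr,[3,1,2]), [Ntot 1]);  % question's unwrapped order

% Vectorized version: no per-element loop and no Ntot-sized index arrays.
output_arr = zeros(Ny,Nx,Nz);
for cnt1 = 1:C1
    factor1 = xvals*cnt1;                            % 1-by-Nx, along dim 2 (x)
    for cnt2 = 1:C2
        factor2 = (yvals*cnt2).';                    % Ny-by-1, along dim 1 (y)
        for cnt3 = 1:C3
            factor3 = permute(zvals*cnt3,[1,3,2]);   % 1-by-1-by-Nz, along dim 3 (z)
            output_arr = output_arr + input_arr .* factor1 .* factor2 .* factor3;
        end
    end
end

% Reference: the question's original loop.
output_ref = zeros(Ntot,1);
temp_vec = zeros(Ntot,1);
for cnt1 = 1:C1
    for cnt2 = 1:C2
        for cnt3 = 1:C3
            factor1 = xvals*cnt1; factor2 = yvals*cnt2; factor3 = zvals*cnt3;
            for ind = 1:Ntot
                j1 = floor(floor((ind-1)/Nz)/Ny) + 1;
                j2 = mod(floor((ind-1)/Nz), Ny) + 1;
                j3 = mod(ind-1, Nz) + 1;
                temp_vec(ind) = input_vec(ind)*factor1(j1)*factor2(j2)*factor3(j3);
            end
            output_ref = output_ref + temp_vec;
        end
    end
end

% Unwrap the vectorized result the same way and compare (up to rounding).
output_vec = reshape(permute(output_arr,[3,1,2]), [Ntot 1]);
max_err = max(abs(output_vec - output_ref));
```

Each iteration still creates full-size temporaries for the element-wise products, but none of the Ntot-element index arrays are needed.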