Tags: matlab, vectorization, matrix-multiplication

Vectorization and Nested Matrix Multiplication


Here is the original code:

K = zeros(N*N);
for a=1:N
    for i=1:I
        for j=1:J
            M = kron(X(:,:,a).',Y(:,:,a,i,j));

            % A function that essentially adds M to K.
        end
    end
end

The goal is to vectorize the Kronecker product (kron) calls. My intuition is to think of X and Y as containers of matrices (for reference, the slices of X and Y being fed to kron are square matrices of order 7x7). Under this scheme, X appears as a 1-D container and Y as a 3-D container. My next guess was to reshape Y into a 2-D, or better yet a 1-D, container and then do element-wise multiplication of X and Y. The questions are: how would one do this reshaping in a way that preserves the trace of M, and can MATLAB even handle this container idea, or do the containers need to be further reshaped to expose the inner matrix elements?
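As an aside, a single kron call can itself be reproduced with permute, bsxfun and reshape, which is one way to make the "expose the inner matrix elements" idea concrete. A minimal sketch (A and B here are just stand-ins for individual slices of X and Y):

A = rand(7,7);
B = rand(7,7);

% T(p,i,q,j) = A(i,j)*B(p,q), built by broadcasting over singleton dims
T = bsxfun(@times, permute(A,[3,1,4,2]), permute(B,[1,3,2,4]));

% Collapsing (p,i) into rows and (q,j) into columns recovers kron(A,B)
K_ref  = kron(A,B);
K_perm = reshape(T, size(A,1)*size(B,1), size(A,2)*size(B,2));
max(abs(K_ref(:)-K_perm(:)))   % should be 0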


Solution

  • Approach #1: Matrix multiplication with 6D permute

    % Get sizes
    [m1,m2,~] =  size(X);
    [n1,n2,N,n4,n5] =  size(Y);
    
    % Lose the third dim from X and Y with matrix-multiplication
    parte1 = reshape(permute(Y,[1,2,4,5,3]),[],N)*reshape(X,[],N).';
    
    % Rearrange the leftover dims to bring in the kron format
    parte2 = reshape(parte1,[n1,n2,n4,n5,m1,m2]);
    
    % Lose the dims corresponding to the last two dims coming in from Y,
    % i.e. the iterative summation over i and j from the original loops
    out = reshape(permute(sum(sum(parte2,3),4),[1,6,2,5,3,4]),m1*n1,m2*n2);
    

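    The matrix multiplication forming parte1 is what removes the explicit loop over a: once the slices of Y and X are unrolled into columns of two 2D matrices, a single product sums over the shared third dim. Here's a minimal sketch of just that collapse (the small sizes and the names Yc, Xc are made up for illustration) -

    % Columns play the role of vectorized slices; 4 plays the role of the shared dim N
    Yc = rand(6,4);
    Xc = rand(3,4);
    
    % Loop version: accumulate one outer product per slice
    P_loop = zeros(6,3);
    for a = 1:4
        P_loop = P_loop + Yc(:,a)*Xc(:,a).';
    end
    
    % Single matrix multiplication performs the same summation at once
    P_mm = Yc*Xc.';
    
    max(abs(P_loop(:)-P_mm(:)))   % should be ~0
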
    Approach #2: Simple 7D permute

    % Get sizes
    [m1,m2,~] =  size(X);
    [n1,n2,N,n4,n5] =  size(Y);
    
    % Perform kron-format elementwise multiplication between the first two dims
    % of X and Y, keeping the third dim aligned and "pushing out" leftover dims
    % from Y to the back
    mults = bsxfun(@times,permute(X,[4,2,5,1,3]),permute(Y,[1,6,2,7,3,4,5]));
    
    % Lose the two dims with summation reduction for final output
    out = sum(reshape(mults,m1*n1,m2*n2,[]),3);
    

    Verification

    Here's a setup for running the original and the proposed approaches -

    % Setup inputs
    X = rand(10,10,10);
    Y = rand(10,10,10,10,10);
    
    % Original approach
    [n1,n2,N,I,J] =  size(Y);
    K = zeros(100);
    for a=1:N
        for i=1:I
            for j=1:J
                M = kron(X(:,:,a).',Y(:,:,a,i,j));
                K = K + M;
            end
        end
    end
    
    % Proposed approach with BSXFUN (Approach #2)
    [m1,m2,~] =  size(X);
    [n1,n2,N,n4,n5] =  size(Y);
    mults = bsxfun(@times,permute(X,[4,2,5,1,3]),permute(Y,[1,6,2,7,3,4,5]));
    out1 = sum(reshape(mults,m1*n1,m2*n2,[]),3);
    
    % Proposed approach with matrix multiplication (Approach #1)
    [m1,m2,~] =  size(X);
    [n1,n2,N,n4,n5] =  size(Y);
    parte1 = reshape(permute(Y,[1,2,4,5,3]),[],N)*reshape(X,[],N).';
    parte2 = reshape(parte1,[n1,n2,n4,n5,m1,m2]);
    out2 = reshape(permute(sum(sum(parte2,3),4),[1,6,2,5,3,4]),m1*n1,m2*n2);
    

    After running, we see the max. absolute deviation with the proposed approaches against the original one -

    >> error_app1 = max(abs(K(:)-out1(:)))
    error_app1 =
       1.1369e-12
    >> error_app2 = max(abs(K(:)-out2(:)))
    error_app2 =
       1.1937e-12
    

    Values look good to me!


    Benchmarking

    Timing these three approaches using the same big dataset as used for verification, we get something like this -

    ----------------------------- With Loop
    Elapsed time is 1.541443 seconds.
    ----------------------------- With BSXFUN
    Elapsed time is 1.283935 seconds.
    ----------------------------- With MATRIX-MULTIPLICATION
    Elapsed time is 0.164312 seconds.
    

    Seems like matrix-multiplication is doing fairly well for datasets of these sizes!
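
    For anyone who wants to reproduce numbers like these, a simple tic/toc harness along the following lines should do; only the matrix-multiplication version is shown, and the loop and BSXFUN blocks from the verification section can be wrapped the same way -

    % Same inputs as in the verification section
    X = rand(10,10,10);
    Y = rand(10,10,10,10,10);
    
    disp('----------------------------- With MATRIX-MULTIPLICATION')
    tic
    [m1,m2,~] = size(X);
    [n1,n2,N,n4,n5] = size(Y);
    parte1 = reshape(permute(Y,[1,2,4,5,3]),[],N)*reshape(X,[],N).';
    parte2 = reshape(parte1,[n1,n2,n4,n5,m1,m2]);
    out2 = reshape(permute(sum(sum(parte2,3),4),[1,6,2,5,3,4]),m1*n1,m2*n2);
    toc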