
Comparing Matrix Elements pairwise without looping in Matlab


I encountered the following problem in MATLAB:

I have a matrix A, and I want to compare its elements pairwise with the elements of a large number of other matrices (B1, B2, B3, etc.). Finally, I want to sum the larger elements, weighted by weights (w1, w2, w3, etc.).

I am trying to calculate the following sum:

S= w1 .* (max(A,B1)) + w2 .* (max(A,B2)) + w3 .* (max(A,B3)) ... etc.

I wrote a loop over the indices of B and w, but this approach is very slow.

Is there a nice way to write faster code for this problem without using loops?

Thanks for your help


Solution

  • Given:

    w = rand(1,10);
    A = rand(5,5);
    B1 = rand(5,5);
    B2 = rand(5,5);
    ...
    B10 = rand(5,5);
    

    The first thing to do is to put all of the B matrices into a 3-dimensional array:

    B = cat(3, B1, B2, ..., B10);
    

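    Since max is vectorized, and from R2016b onward implicitly expands A along the 3rd dimension of B, the element-wise max between A and each individual B is:

    S = max(A,B);   % before R2016b: S = bsxfun(@max, A, B);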
    

    To apply the weights, we have to reshape them along the 3rd dimension so that each weight multiplies the matching slice of S:

    w = reshape(w, 1, 1, 10);   % or permute(w, [1 3 2]);
    S = w.*S;                   % before R2016b: S = bsxfun(@times, w, S);
    

    Then, to take the sum, we just tell sum to operate along the 3rd dimension, which leaves a result the same size as A:

    S = sum(S, 3);
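
As a sanity check, the steps above can be put together and compared against the explicit loop from the question. This is only a sketch under a few assumptions that are not in the original post: the B matrices are held in a hypothetical cell array `Bs` rather than in separate variables, the names `S_loop`, `S_vec`, `S_bsx`, `S_mat` and `w3` are made up for illustration, and the matrix-product variant at the end is an alternative formulation, not part of the answer above.

```matlab
% Hypothetical setup: the B matrices live in a cell array Bs (an assumption;
% the original post names them B1, B2, ...).
n  = 10;                         % number of B matrices
A  = rand(5, 5);
w  = rand(1, n);
Bs = arrayfun(@(k) rand(5, 5), 1:n, 'UniformOutput', false);

% Reference result: an explicit loop like the one described in the question.
S_loop = zeros(size(A));
for k = 1:n
    S_loop = S_loop + w(k) .* max(A, Bs{k});
end

% Vectorized version: stack along dimension 3, then max, weight, and sum.
B  = cat(3, Bs{:});              % 5-by-5-by-n
w3 = reshape(w, 1, 1, []);       % 1-by-1-by-n
S_vec = sum(w3 .* max(A, B), 3); % relies on implicit expansion (R2016b or later)

% Same computation written with bsxfun, for releases without implicit expansion.
S_bsx = sum(bsxfun(@times, w3, bsxfun(@max, A, B)), 3);

% Variant: flatten each 5-by-5 page into a column, then use a matrix-vector product.
M = reshape(bsxfun(@max, A, B), [], n);   % numel(A)-by-n
S_mat = reshape(M * w(:), size(A));

% Largest absolute deviation of each variant from the loop result.
err_vec = max(abs(S_vec(:) - S_loop(:)));
err_bsx = max(abs(S_bsx(:) - S_loop(:)));
err_mat = max(abs(S_mat(:) - S_loop(:)));
fprintf('max deviation: %g (implicit), %g (bsxfun), %g (matrix product)\n', ...
        err_vec, err_bsx, err_mat);
```

The deviations should be at the level of floating-point rounding. Note that all vectorized variants build the full 5-by-5-by-n array of maxima in memory, so for very large A or very many B matrices the loop may still be the more practical choice; timing both on the real data is the only reliable way to tell.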