I encountered the following problem in MATLAB:
I have a matrix A, and I want to compare its elements pairwise with the elements of a large number of other matrices (B1, B2, B3, etc.). Finally, I want to sum the larger elements, weighted by weights (w1, w2, w3, etc.).
I am trying to calculate the following sum:
S = w1 .* max(A,B1) + w2 .* max(A,B2) + w3 .* max(A,B3) + ...
I wrote a loop over the index numbers of B and w, but this approach is very slow.
Is there a nice way to write faster code for this problem without using loops?
Thanks for your help
Given:
w = rand(1,10);
A = rand(5,5);
B1 = rand(5,5);
B2 = rand(5,5);
...
B10 = rand(5,5);
The first thing to do is to put all of the B matrices into a 3-dimensional array:
B = cat(3, B1, B2, ..., B10);
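If the matrices are held in a cell array rather than in separately named variables, the stack can be built without typing each name. This is a minimal sketch; the question does not say how the matrices are stored, so the cell array `Bcell` and the count `n` are hypothetical names:

```matlab
% Hypothetical setup: the B matrices live in a cell array Bcell
n = 10;
Bcell = cell(1, n);
for k = 1:n
    Bcell{k} = rand(5, 5);
end

% Bcell{:} expands to a comma-separated list of its contents,
% so every matrix is passed to cat as a separate argument
B = cat(3, Bcell{:});

size(B)   % 5 5 10
```

Each slice B(:,:,k) is then the k-th matrix from the cell array.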
Since max is vectorized, the max between A and the individual B's is:
S = max(A,B);
To apply the weights, we have to reshape them along the 3rd dimension so that each weight multiplies the matching slice of the max:
w = reshape(w, 1, 1, 10); % or permute(w, [1 3 2]);
S = w.*S;
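Note that calling max and .* on arrays of different sizes (5x5 against 5x5x10, and 1x1x10 against 5x5x10) relies on implicit expansion, which MATLAB introduced in R2016b. On older releases, the same two steps can probably be written with bsxfun. A sketch, using random data as in the setup above:

```matlab
w = rand(1, 10);
A = rand(5, 5);
B = rand(5, 5, 10);

% Element-wise max of A against every slice of B (result is 5x5x10)
S = bsxfun(@max, A, B);

% Move the weights to the 3rd dimension and scale each slice by its weight
w3 = reshape(w, 1, 1, []);
S = bsxfun(@times, w3, S);
```

Here `w3` is just an illustrative name; passing [] as the last size lets reshape infer the number of weights.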
Then to take the sum, we just tell sum to operate in the 3rd dimension:
S = sum(S, 3);
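To check the result, the steps can be combined into one expression and compared against a plain loop. This is a sketch assuming R2016b or later (for implicit expansion); `S_loop` is an illustrative name, and the loop is only a guess at what the original slow version looked like:

```matlab
w = rand(1, 10);
A = rand(5, 5);
B = rand(5, 5, 10);

% Vectorized: max against every slice, weight each slice, sum over slices
S = sum(reshape(w, 1, 1, []) .* max(A, B), 3);

% Reference loop over the slices of B
S_loop = zeros(size(A));
for k = 1:numel(w)
    S_loop = S_loop + w(k) .* max(A, B(:,:,k));
end

% Largest absolute difference; should be at floating-point round-off level
max(abs(S(:) - S_loop(:)))
```

The two results may differ in the last few bits because the additions happen in a different order, so comparing with a small tolerance is safer than isequal.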