I am trying to vectorize a weighted sum but can't figure out how to do it. Below is a minimal working example. I suspect the solution involves either bsxfun, or reshape and Kronecker products, but I have not managed to get it working.
rng(1);
N = 200;
T1 = 5;
T2 = 7;
T3 = 10;
A = rand(N,T1,T2,T3);
w1 = rand(T1,1);
w2 = rand(T2,1);
w3 = rand(T3,1);
% B(i) = sum over j1, j2, j3 of w1(j1)*w2(j2)*w3(j3)*A(i,j1,j2,j3)
B = zeros(N,1);
for i = 1:N
    for j1 = 1:T1
        for j2 = 1:T2
            for j3 = 1:T3
                B(i) = B(i) + w1(j1) * w2(j2) * w3(j3) * A(i,j1,j2,j3);
            end
        end
    end
end
For the two-dimensional case there is a clever answer here.
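Written as a formula, the loop computes the following for each row i. Because the weight is a product of one factor per index, the whole computation should collapse to a single matrix-vector product. In this sketch, \otimes is the Kronecker product and \tilde{A} denotes A with dimensions 2 to 4 flattened in MATLAB's column-major order (that notation is introduced here, not taken from the question):

```latex
\begin{aligned}
B_i &= \sum_{j_1=1}^{T_1}\sum_{j_2=1}^{T_2}\sum_{j_3=1}^{T_3} w_1(j_1)\,w_2(j_2)\,w_3(j_3)\,A(i,j_1,j_2,j_3) \\
B &= \tilde{A}\,(w_3 \otimes w_2 \otimes w_1), \qquad
\tilde{A}\bigl(i,\; j_1 + T_1(j_2-1) + T_1T_2(j_3-1)\bigr) = A(i,j_1,j_2,j_3)
\end{aligned}
```

The entry of w_3 \otimes w_2 \otimes w_1 at position j_1 + T_1(j_2-1) + T_1T_2(j_3-1) is exactly w_1(j_1)\,w_2(j_2)\,w_3(j_3), which is why the two orderings line up.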
You can use one more multiplication to extend the w1 * w2' grid from the previous answer so that it also includes w3. A single matrix multiplication with a "flattened" version of A then gives all the weighted sums at once.
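% (T1*T2) x T3 grid of the products w1(j1)*w2(j2)*w3(j3)
W = reshape(w1 * w2.', [], 1) * w3.';
% N x (T1*T2*T3) matrix times the column-major flattened weights
B = reshape(A, size(A, 1), []) * W(:);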
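A quick way to convince yourself that the flattened product matches the original loop is to run both on the same data. This sketch rebuilds the question's setup; the names Bloop and Bvec are illustrative only.

```matlab
% Rebuild the question's test data
rng(1);
N = 200; T1 = 5; T2 = 7; T3 = 10;
A  = rand(N, T1, T2, T3);
w1 = rand(T1, 1);
w2 = rand(T2, 1);
w3 = rand(T3, 1);

% Reference result: the original nested loops
Bloop = zeros(N, 1);
for i = 1:N
    for j1 = 1:T1
        for j2 = 1:T2
            for j3 = 1:T3
                Bloop(i) = Bloop(i) + w1(j1) * w2(j2) * w3(j3) * A(i, j1, j2, j3);
            end
        end
    end
end

% Vectorized result: (T1*T2) x T3 weight grid, flattened column-major
W = reshape(w1 * w2.', [], 1) * w3.';
Bvec = reshape(A, size(A, 1), []) * W(:);

% The two should agree up to floating-point rounding
fprintf('max abs difference: %g\n', max(abs(Bloop - Bvec)));
```

Only rounding differences are expected, because the loop and the matrix product add the same terms in a different order.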
You could wrap the creation of the weights in its own function and generalize this to any number of weight vectors. Since this uses recursion, the number of weight vectors is limited by MATLAB's recursion limit (500 by default).
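function W = createWeights(W, varargin)
    % Fold the remaining weight vectors into W one at a time
    if numel(varargin) > 0
        % Outer product with the next vector; earlier indices vary fastest
        W = createWeights(W(:) * varargin{1}(:).', varargin{2:end});
    end
end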
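Because the recursion depth equals the number of weight vectors, a plain loop is one way to avoid the recursion limit altogether. The following is a minimal non-recursive sketch, not part of the original answer; the cell array weights and the random test vectors are illustrative assumptions. It builds the same column-major weight vector by folding in one vector per iteration.

```matlab
% Illustrative weight vectors; any number of them can go in the cell array
rng(2);
w1 = rand(5, 1);
w2 = rand(7, 1);
w3 = rand(10, 1);
weights = {w1, w2, w3};

% Start from a scalar and fold in one vector per iteration (no recursion)
Witer = 1;
for k = 1:numel(weights)
    % Outer product with the next vector, flattened column-major so that
    % the earlier vectors' indices keep varying fastest
    Witer = reshape(Witer * weights{k}(:).', [], 1);
end

% Reference: the three-vector construction from the answer
Wref = reshape(w1 * w2.', [], 1) * w3.';
Wref = Wref(:);
```

Witer can then be used exactly like W(:), e.g. B = reshape(A, size(A, 1), []) * Witer.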
And use it with:
W = createWeights(w1, w2, w3);
B = reshape(A, size(A, 1), []) * W(:);
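The question also mentioned bsxfun. As an alternative to building W, each weight vector can be reshaped to lie along its own dimension of A, multiplied in with singleton expansion, and then summed. This is a sketch under the question's setup; w1d, w2d, w3d, C, Bbsx and Bmat are illustrative names.

```matlab
rng(1);
N = 200; T1 = 5; T2 = 7; T3 = 10;
A  = rand(N, T1, T2, T3);
w1 = rand(T1, 1); w2 = rand(T2, 1); w3 = rand(T3, 1);

% Place each weight vector along its own dimension of A
w1d = reshape(w1, 1, T1);          % 1 x T1
w2d = reshape(w2, 1, 1, T2);       % 1 x 1 x T2
w3d = reshape(w3, 1, 1, 1, T3);    % 1 x 1 x 1 x T3

% Elementwise weighting with singleton expansion, then sum over dims 2..4
C = bsxfun(@times, bsxfun(@times, bsxfun(@times, A, w1d), w2d), w3d);
Bbsx = sum(reshape(C, N, []), 2);

% Matrix-product version for comparison
W = reshape(w1 * w2.', [], 1) * w3.';
Bmat = reshape(A, N, []) * W(:);
```

This materializes a weighted copy of A (N*T1*T2*T3 elements), so the matrix-product version is usually lighter on memory. On R2016b and later, plain .* with implicit expansion can replace the bsxfun calls.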
Update
Using part of @CKT's very good suggestion to use kron, we can modify createWeights just a little bit:
function W = createWeights(W, varargin)
    if numel(varargin) > 0
        % kron(next, W) keeps the earlier vectors' indices varying fastest
        W = createWeights(kron(varargin{1}, W), varargin{2:end});
    end
end
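The argument order in kron(varargin{1}, W) matters: in kron(a, b) the index of b varies fastest, so putting the accumulated W second keeps w1's index fastest, matching the column-major flattening of A. A small check with hand-picked integer vectors (chosen here for illustration so that equality is exact):

```matlab
w1 = [1; 2];
w2 = [10; 20; 30];
w3 = [100; 1000];

% Outer-product construction from the first version of the answer
Wouter = reshape(w1 * w2.', [], 1) * w3.';
Wouter = Wouter(:);

% kron construction from the update: newest vector on the left
Wkron = kron(w3, kron(w2, w1));

% Swapping the arguments changes the ordering (w3 would vary fastest)
Wswap = kron(kron(w1, w2), w3);

disp([Wouter, Wkron, Wswap]);
```

The first two columns should be identical; the third contains the same products in a different order and would pair the wrong weights with the columns of reshape(A, size(A, 1), []).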