Tags: matlab, vectorization, weighted-average

Triple weighted sum


I am trying to vectorize a certain weighted sum but can't figure out how to do it. I have created a minimal working example below. I suspect the solution involves either bsxfun or reshape and Kronecker products, but I have not managed to get it working.

rng(1);
N = 200;
T1 = 5;
T2 = 7;
T3 = 10;


A = rand(N,T1,T2,T3);
w1 = rand(T1,1);
w2 = rand(T2,1);
w3 = rand(T3,1);

B = zeros(N,1);

for i = 1:N
 for j1=1:T1
  for j2=1:T2
   for j3=1:T3
    B(i) = B(i) + w1(j1) * w2(j2) * w3(j3) * A(i,j1,j2,j3);
   end
  end
 end
end

% B(i) now holds the triple weighted sum for row i of A
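The loop computes the sum below. Writing \tilde{A} for reshape(A, N, []), whose column index k runs over (j1, j2, j3) with j1 varying fastest (MATLAB's column-major order), turns the triple sum into a single matrix-vector product with a Kronecker product of the weights. This is a sketch of the index bookkeeping, not part of the original question:

```latex
\begin{aligned}
B_i &= \sum_{j_1=1}^{T_1}\sum_{j_2=1}^{T_2}\sum_{j_3=1}^{T_3} w_1(j_1)\,w_2(j_2)\,w_3(j_3)\,A(i,j_1,j_2,j_3), \qquad i = 1,\dots,N,\\
k &= j_1 + T_1 (j_2 - 1) + T_1 T_2 (j_3 - 1), \qquad \tilde{A}(i,k) = A(i,j_1,j_2,j_3),\\
(w_3 \otimes w_2 \otimes w_1)_k &= w_3(j_3)\,w_2(j_2)\,w_1(j_1),\\
B &= \tilde{A}\,(w_3 \otimes w_2 \otimes w_1).
\end{aligned}
```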

For the two-dimensional case there is a smart answer here.
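For reference, the two-dimensional version of the trick can be sketched as follows. This is an illustrative reconstruction, not necessarily the code from the linked answer; the array A2, the loop check and the variable names are hypothetical, and the sizes are borrowed from the example above.

```matlab
% Two-dimensional analogue: B2(i) = sum over j1, j2 of w1(j1) * w2(j2) * A2(i, j1, j2)
N = 200; T1 = 5; T2 = 7;
A2 = rand(N, T1, T2);          % hypothetical input for the 2-D weighting case
w1 = rand(T1, 1);
w2 = rand(T2, 1);

% w1 * w2.' is the T1-by-T2 grid of weight products; W2(:) lists it with j1
% varying fastest, which matches the column order of reshape(A2, N, [])
W2 = w1 * w2.';
B2 = reshape(A2, N, []) * W2(:);

% Reference result from explicit loops
B2loop = zeros(N, 1);
for i = 1:N
  for j1 = 1:T1
    for j2 = 1:T2
      B2loop(i) = B2loop(i) + w1(j1) * w2(j2) * A2(i, j1, j2);
    end
  end
end
disp(max(abs(B2 - B2loop)))    % difference should be at round-off level
```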


Solution

  • You can extend the w1 * w2' grid from the two-dimensional answer with one more outer product, this time with w3. A single matrix multiplication with a "flattened" version of A (N rows, T1*T2*T3 columns) then computes all N weighted sums at once.

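    % (T1*T2)-by-T3 grid: W(j1 + T1*(j2-1), j3) = w1(j1) * w2(j2) * w3(j3)
    W = reshape(w1 * w2.', [], 1) * w3.';
    % Flatten A to N-by-(T1*T2*T3); its column order matches W(:)
    B = reshape(A, size(A, 1), []) * W(:);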
    

    You could wrap the creation of the weights in its own function and generalize this to any number of weight vectors. Since this uses recursion, the number of weight vectors is limited by MATLAB's recursion limit (500 by default).

    function W = createWeights(W, varargin)
        % Append each remaining weight vector as one more outer-product dimension
        if numel(varargin) > 0
            W = createWeights(W(:) * varargin{1}(:).', varargin{2:end});
        end
    end
    

    And use it with:

    W = createWeights(w1, w2, w3);
    B = reshape(A, size(A, 1), []) * W(:);
    

    Update

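    Using part of @CKT's very good suggestion to use kron, we can simplify createWeights slightly. The new vector goes first in kron(varargin{1}, W) so that the earlier indices vary fastest, which matches the column order of reshape(A, size(A, 1), []).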

    function W = createWeights(W, varargin)
        if numel(varargin) > 0
            W = createWeights(kron(varargin{1}, W), varargin{2:end});
        end
    end
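A quick way to check that these weight orderings are right is to compare them against the loop from the question. The sketch below is an editor's assumption rather than part of the original answer: it rebuilds the kron weights with a plain loop over a cell array (so the recursion limit does not apply), and it adds a broadcasting version in the spirit of the bsxfun idea from the question, which relies on implicit expansion (MATLAB R2016b or later, or Octave). Variable names such as Bloop, Bkron and ws are illustrative.

```matlab
N = 200; T1 = 5; T2 = 7; T3 = 10;
A = rand(N, T1, T2, T3);
w1 = rand(T1, 1); w2 = rand(T2, 1); w3 = rand(T3, 1);

% Reference: the nested loops from the question
Bloop = zeros(N, 1);
for i = 1:N
  for j1 = 1:T1
    for j2 = 1:T2
      for j3 = 1:T3
        Bloop(i) = Bloop(i) + w1(j1) * w2(j2) * w3(j3) * A(i, j1, j2, j3);
      end
    end
  end
end

% Outer-product weights, as in the first part of the answer
W = reshape(w1 * w2.', [], 1) * w3.';
Bouter = reshape(A, size(A, 1), []) * W(:);

% Kronecker weights built with a loop instead of recursion (assumed alternative)
ws = {w1, w2, w3};
Wk = ws{1}(:);
for k = 2:numel(ws)
  Wk = kron(ws{k}(:), Wk);     % new vector first, so earlier indices vary fastest
end
Bkron = reshape(A, size(A, 1), []) * Wk;

% Implicit expansion: scale A along dims 2-4, then sum those dims away
Bexp = sum(sum(sum(A .* reshape(w1, 1, []) .* reshape(w2, 1, 1, []) ...
                     .* reshape(w3, 1, 1, 1, []), 2), 3), 4);

% All three differences should be at round-off level
fprintf('max differences: %g %g %g\n', max(abs(Bouter - Bloop)), ...
        max(abs(Bkron - Bloop)), max(abs(Bexp - Bloop)));
```

The matrix-product versions hand the whole reduction to one BLAS call, while the broadcasting version forms a temporary array the size of A, so it uses more memory for large inputs.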