How to compute the sum of squares of outer products of two matrices minus a common matrix in Matlab?


Suppose there are three n-by-n matrices X, Y, and S. How can the following scalar b be computed quickly?

b = 0;
for i = 1:n
  b = b  + sum(sum((X(i,:)' * Y(i,:) - S).^2));
end

The computational cost is O(n^3). There is a fast way to compute the sum of the row-wise outer products of two matrices. Specifically, the matrix C

C = zeros(n);
for i = 1:n
  C = C + X(i,:)' * Y(i,:);
end

can be calculated without a for loop as C = X.'*Y, a single matrix product that is much faster in practice than the loop. Is there a similarly fast way to compute b?
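
A minimal sketch that checks the identity the question relies on: accumulating the row-wise outer products in a loop gives the same matrix as the single product X.'*Y. The size n = 4, the random inputs, and the names C_loop, C_prod and max_diff are arbitrary choices for illustration.

```matlab
n = 4;                      % small arbitrary size, for illustration only
X = rand(n);
Y = rand(n);

% Accumulate the outer products row by row, as in the loop above.
C_loop = zeros(n);
for i = 1:n
  C_loop = C_loop + X(i,:).' * Y(i,:);
end

% The same matrix from a single matrix product.
C_prod = X.' * Y;

% Largest absolute difference; expected to be at round-off level.
max_diff = max(abs(C_loop(:) - C_prod(:)));
disp(max_diff)
```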


Solution

  • You can use:

    X2 = X.^2;
    Y2 = Y.^2;
    S2 = S.^2;
    b = sum(sum(X2.' * Y2 - 2 * (X.' * Y ) .* S + n * S2));
    

    Given your example

    b=0;
    for i = 1:n
       b = b  + sum(sum((X(i,:).' * Y(i,:) - S).^2));
    end
    

    We can first bring the summation out of the loop:

    b=0;
    for i = 1:n
      b = b  + (X(i,:).' * Y(i,:) - S).^2;
    end
    b=sum(b(:))
    

    Knowing that we can write (a - b)^2 as a^2 - 2*a*b + b^2, we can expand the square element-wise:

    b=0;
    for i = 1:n
      b = b  + (X(i,:).' * Y(i,:)).^2 - 2.* (X(i,:).' * Y(i,:)) .*S + S.^2;
    end
    b=sum(b(:))
    

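    Each entry of an outer product is a single product x(j)*y(k), so squaring it element-wise gives the outer product of the squared vectors, i.e. (x.' * y).^2 is the same as (x.^2).' * (y.^2):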

    X2 = X.^2;
    Y2 = Y.^2;
    S2 = S.^2;
    b=0;
    for i = 1:n
      b = b  + (X2(i,:).' * Y2(i,:)) - 2.* (X(i,:).' * Y(i,:)) .*S + S2;
    end
    b=sum(b(:))
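
The squaring step can be checked on a tiny example. This is only a sketch; the vectors x and y are hypothetical integer inputs, chosen so that the comparison is exact.

```matlab
x = [1 2 3];                % hypothetical row vectors
y = [4 5 6];

lhs = (x.' * y).^2;         % square every entry of the outer product
rhs = (x.^2).' * (y.^2);    % outer product of the element-wise squares

disp(isequal(lhs, rhs))     % prints 1
```

Note that this holds for outer products only. For a general matrix product, (A * B).^2 is not equal to (A.^2) * (B.^2), because each entry is then a sum of several products.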
    

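    Now each term can be summed over i separately: the sum of the outer products X2(i,:).' * Y2(i,:) is X2.' * Y2, the sum of X(i,:).' * Y(i,:) is X.' * Y, and S2 is simply added n times: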

     b = sum(sum(X2.' * Y2 - 2 * (X.' * Y ) .* S + n * S2));
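
As a possible further refinement (not part of the original answer, so treat it as a sketch), the first and third terms do not need a full matrix at all: the sum of all entries of X2.' * Y2 equals the dot product of the row sums of X2 and Y2, and the last term is n times the sum of S2. Only the middle term still requires a matrix product. The names t1, t2, t3, b_ref and b_alt are hypothetical.

```matlab
n = 5;                      % small arbitrary size, for illustration only
X = rand(n); Y = rand(n); S = rand(n);
X2 = X.^2; Y2 = Y.^2; S2 = S.^2;

% Reference: the vectorized expression from the answer.
b_ref = sum(sum(X2.' * Y2 - 2 * (X.' * Y) .* S + n * S2));

% Term 1: sum of all entries of X2.'*Y2, via the row sums (no matrix product).
t1 = sum(X2, 2).' * sum(Y2, 2);
% Term 2: cross term, still needs one matrix product.
t2 = sum(sum((X.' * Y) .* S));
% Term 3: S2 is counted once per row index i, i.e. n times.
t3 = n * sum(S2(:));

b_alt = t1 - 2 * t2 + t3;

% Relative difference; expected to be at round-off level.
rel_err = abs(b_alt - b_ref) / abs(b_ref);
disp(rel_err)
```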
    

    Here are the results of a test in Octave comparing my method, two other methods provided by @AndrasDeak, and the original loop-based solution for inputs of size 500-by-500:

    ===rahnema1 (B)===
    Elapsed time is 0.0984299 seconds.
    
    ===Andras Deak (B2)===
    Elapsed time is 7.86407 seconds.
    
    ===Andras Deak (B3)===
    Elapsed time is 2.99158 seconds.
    
    ===Loop solution===
    Elapsed time is 2.20357 seconds.
    

    The script used for the test:

    n=500;
    X= rand(n);
    Y= rand(n);
    S= rand(n);
    
    disp('===rahnema1 (B)===')
    tic
        X2 = X.^2;
        Y2 = Y.^2;
        S2 = S.^2;
        b=sum(sum(X2.' * Y2 - 2 * (X.' * Y ) .* S + n * S2));
    toc
    disp('===Andras Deak (B2)===')
    tic
        b2 = sum(reshape((permute(reshape(X, [n, 1, n]).*Y, [3,2,1]) - S).^2, 1, []));
    toc
    disp('===Andras Deak (B3)===')
    tic
        b3 = sum(reshape((reshape(X, [n, 1, n]).*Y - reshape(S.', [1, n, n])).^2, 1, []));
    toc
    disp('===Loop solution===')
    tic
        b=0;
        for i = 1:n
          b = b  + sum(sum((X(i,:)' * Y(i,:) - S).^2));
        end
    toc
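
The timing script measures speed but does not compare the results. A minimal sketch of such a check follows, assuming a smaller size (n = 50) so that the loop runs quickly; the names b_vec, b_loop and rel_err are hypothetical.

```matlab
n = 50;                     % smaller than the benchmark, so the loop is quick
X = rand(n); Y = rand(n); S = rand(n);

% Vectorized expression from the answer.
b_vec = sum(sum((X.^2).' * (Y.^2) - 2 * (X.' * Y) .* S + n * S.^2));

% Original loop-based computation.
b_loop = 0;
for i = 1:n
  b_loop = b_loop + sum(sum((X(i,:).' * Y(i,:) - S).^2));
end

% The two values should agree up to floating-point round-off.
rel_err = abs(b_vec - b_loop) / abs(b_loop);
fprintf('relative difference: %g\n', rel_err);
```

The two results are not expected to be bit-identical, since the operations are carried out in a different order, so a relative tolerance is a safer comparison than isequal.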