Suppose there are three n * n matrices X, Y, and S. How can I quickly compute the following scalar b?
b = 0;
for i = 1:n
b = b + sum(sum((X(i,:)' * Y(i,:) - S).^2));
end
The computational cost is O(n^3). There is a fast way to compute a sum of outer products of matching rows of two matrices. Specifically, the matrix C
for i = 1:n
C = C + X(i,:)' * Y(i,:);
end
can be calculated without a for loop:
C = X.' * Y;
which is only O(n^2). Is there a faster way to compute b?
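The equivalence between the accumulated outer products and a single matrix product can be checked numerically. Below is a minimal Octave/MATLAB sketch; the test size n = 4, the random inputs, and the names C_loop, C_mat and max_diff are assumptions made here for illustration only.

```octave
n = 4;                       % small illustrative size (assumption)
X = rand(n);
Y = rand(n);

% Accumulate the outer products of matching rows, as in the loop above.
C_loop = zeros(n);
for i = 1:n
    C_loop = C_loop + X(i,:).' * Y(i,:);
end

% The same matrix from a single matrix product.
C_mat = X.' * Y;

% Largest elementwise difference; expected to be at rounding-error level.
max_diff = max(abs(C_loop(:) - C_mat(:)));
disp(max_diff)
```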
You can use:
X2 = X.^2;
Y2 = Y.^2;
S2 = S.^2;
b = sum(sum(X2.' * Y2 - 2 * (X.' * Y ) .* S + n * S2));
Given your example:
b=0;
for i = 1:n
b = b + sum(sum((X(i,:).' * Y(i,:) - S).^2));
end
We can first bring the summation out of the loop:
b=0;
for i = 1:n
b = b + (X(i,:).' * Y(i,:) - S).^2;
end
b=sum(b(:))
Since (a - b)^2 can be written as a^2 - 2*a*b + b^2, we can expand the square:
b=0;
for i = 1:n
b = b + (X(i,:).' * Y(i,:)).^2 - 2.* (X(i,:).' * Y(i,:)) .*S + S.^2;
end
b=sum(b(:))
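And since each entry of an outer product is a plain product of two scalars, (a * b)^2 is the same as a^2 * b^2, so the elementwise square of the outer product equals the outer product of the squared rows: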
X2 = X.^2;
Y2 = Y.^2;
S2 = S.^2;
b=0;
for i = 1:n
b = b + (X2(i,:).' * Y2(i,:)) - 2.* (X(i,:).' * Y(i,:)) .*S + S2;
end
b=sum(b(:))
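The squaring step can be verified on a tiny case. This is only a sketch: the vectors x and y and the names lhs, rhs and same are hypothetical, chosen as small integers so that the comparison is exact.

```octave
x = [1 2 3];                 % illustrative row vectors (assumption)
y = [4 5 6];

lhs = (x.' * y).^2;          % square each entry of the outer product
rhs = (x.^2).' * (y.^2);     % outer product of the squared vectors

% Exact comparison is safe here because all values are small integers.
same = isequal(lhs, rhs);
disp(same)
```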
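Now each term can be summed over i separately: the outer products add up to the matrix products X2.' * Y2 and X.' * Y, and S2 is simply added n times: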
b = sum(sum(X2.' * Y2 - 2 * (X.' * Y ) .* S + n * S2));
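The same steps can be written out in index notation. In this sketch X_{ij} denotes the entry of X in row i and column j, and X2, Y2, S2 are the elementwise squares, as in the code above.

```latex
\begin{aligned}
b &= \sum_{i=1}^{n} \sum_{j=1}^{n} \sum_{k=1}^{n} \left( X_{ij} Y_{ik} - S_{jk} \right)^2 \\
  &= \sum_{j=1}^{n} \sum_{k=1}^{n} \left[ \sum_{i=1}^{n} X_{ij}^2 Y_{ik}^2
     \;-\; 2\, S_{jk} \sum_{i=1}^{n} X_{ij} Y_{ik} \;+\; n\, S_{jk}^2 \right] \\
  &= \sum_{j=1}^{n} \sum_{k=1}^{n} \left[ \left( X2^{\mathsf{T}}\, Y2 \right)_{jk}
     \;-\; 2 \left( X^{\mathsf{T}} Y \right)_{jk} S_{jk} \;+\; n\, (S2)_{jk} \right]
\end{aligned}
```

The last line is the expression inside sum(sum(...)): two matrix products, one elementwise product with S, and n copies of S2.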
Here is the result of a test in Octave that compares my method, two other methods provided by @AndrasDeak, and the original loop-based solution for inputs of size 500*500:
===rahnema1 (B)===
Elapsed time is 0.0984299 seconds.
===Andras Deak (B2)===
Elapsed time is 7.86407 seconds.
===Andras Deak (B3)===
Elapsed time is 2.99158 seconds.
===Loop solution===
Elapsed time is 2.20357 seconds
n=500;
X= rand(n);
Y= rand(n);
S= rand(n);
disp('===rahnema1 (B)===')
tic
X2 = X.^2;
Y2 = Y.^2;
S2 = S.^2;
b=sum(sum(X2.' * Y2 - 2 * (X.' * Y ) .* S + n * S2));
toc
disp('===Andras Deak (B2)===')
tic
b2 = sum(reshape((permute(reshape(X, [n, 1, n]).*Y, [3,2,1]) - S).^2, 1, []));
toc
disp('===Andras Deak (B3)===')
tic
b3 = sum(reshape((reshape(X, [n, 1, n]).*Y - reshape(S.', [1, n, n])).^2, 1, []));
toc
disp('===Loop solution===')
tic
b=0;
for i = 1:n
b = b + sum(sum((X(i,:)' * Y(i,:) - S).^2));
end
toc
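The benchmark above only measures time, so it may be worth confirming that the vectorized expression and the loop agree. Here is a minimal sketch, assuming a smaller size n = 50 so that it runs quickly; the names b_vec, b_loop and rel_err are introduced here for illustration.

```octave
n = 50;                      % smaller than the benchmark size (assumption)
X = rand(n);
Y = rand(n);
S = rand(n);

% Vectorized formula from the answer.
b_vec = sum(sum((X.^2).' * (Y.^2) - 2 * (X.' * Y) .* S + n * S.^2));

% Reference value from the original loop.
b_loop = 0;
for i = 1:n
    b_loop = b_loop + sum(sum((X(i,:).' * Y(i,:) - S).^2));
end

% Relative difference; the two values should agree up to rounding error.
rel_err = abs(b_vec - b_loop) / abs(b_loop);
disp(rel_err)
```

A relative rather than absolute difference is used because b grows roughly like n^3, so the absolute rounding error grows with the input size.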