Tags: matlab, matrix, linear-algebra, matrix-multiplication

Why is this matrix multiplication so slow?


I'm trying to implement a Householder tridiagonalization routine to tridiagonalize small (n < 10) Hermitian matrices in MATLAB. This is what I have so far:

H = 1.0e-10 * [ ...
    0.1386 + 0.0000i   0.0974 - 0.0260i   0.0434 + 0.0094i   0.0722 + 0.0670i   0.1128 + 0.1269i;
    0.0974 + 0.0260i   0.0751 + 0.0000i   0.0288 + 0.0149i   0.0388 + 0.0616i   0.0557 + 0.1112i;
    0.0434 - 0.0094i   0.0288 - 0.0149i   0.0146 + 0.0000i   0.0274 + 0.0164i   0.0444 + 0.0323i;
    0.0722 - 0.0670i   0.0388 - 0.0616i   0.0274 - 0.0164i   0.0719 + 0.0000i   0.1216 + 0.0116i;
    0.1128 - 0.1269i   0.0557 - 0.1112i   0.0444 - 0.0323i   0.1216 - 0.0116i   0.2105 + 0.0000i];


function [T, U] = tridi(H)
% Calculates a unitarily equivalent matrix T and unitary matrix U for a 
% given Hermitian matrix H, such that U'*H*U = T and T is of tridiagonal
% form.

% Setting size and starting point of iteration
n = size(H,1); T = H; U = eye(n);

% Loop over the first n-1 columns of H

for k = 1 : n-1
    % Householder transformation on column vector T(k+1:n,k)
    x = T(k+1:n,k);
    
    % Calculating secondary diagonal entry
    normValue = norm(x);
    
    % Initializing e1
    e1 = zeros(length(x),1);
    e1(1) = 1;

    % phase of x1
    phase = sign(x(1));

    % Compute the normalized Householder vector u
    u = x + norm(x)*phase*e1;
    u = u / norm(u);
    
    % Updating T and U:
    T(k+1:n,k+1:n) = (eye(n-k) - 2*(u*u'))*T(k+1:n,k+1:n)*(eye(n-k) - 2*(u*u'));
    U(2:n, k+1:n) = -phase*(U(2:n, k+1:n) - 2 * (U(2:n, k+1:n) * u) * u');

    % Setting secondary diagonal entry of T
    T(k+1,k) = normValue;
    T(k,k+1) = normValue;
    
    % Setting appropriate row and column of T to zero
    T(k+2:n,k) = 0;   
    T(k,k+2:n) = 0;
end
% Ensure that T is real
T = real(T);
end

However, the line T(k+1:n,k+1:n) = (eye(n-k) - 2*(u*u'))*T(k+1:n,k+1:n)*(eye(n-k) - 2*(u*u')) accounts for about 5 seconds when the routine is called around 150,000 times for my data set, which is an extreme bottleneck for the rest of my calculations.

I've already tried replacing the matrix-matrix multiplications with matrix-vector multiplications:

    % Updating T and U:
    
    % Calculating P*T implicitly
    T(k+1:n, k+1:n) = T(k+1:n, k+1:n) - 2 * u * (u' * T(k+1:n, k+1:n));
    
    % Calculating T*P and U*P implicitly
    T(k+1:n, k+1:n) = T(k+1:n, k+1:n) - 2 * (T(k+1:n, k+1:n) * u) * u';
    U(2:n, k+1:n) = -phase*(U(2:n, k+1:n) - 2 * (U(2:n, k+1:n) * u) * u');

which hasn't proven any faster, probably because MATLAB's matrix multiplication is already so well optimized.

So I'm now struggling to get the runtime down any further (especially compared to MATLAB's built-in function hess.m, which only takes a few tenths of a second).
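For reference, the hess.m comparison I mean is roughly the following; [P, Hh] = hess(H) returns a unitary P and a Hessenberg matrix Hh with H = P*Hh*P', and for a Hermitian input Hh is already tridiagonal (P and Hh are just illustrative names here):

    % Built-in reference: Hessenberg reduction of a Hermitian matrix
    [P, Hh] = hess(H);          % H = P*Hh*P', Hh is tridiagonal here
    disp(norm(P*Hh*P' - H));    % sanity check, should be on the order of round-off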

One of my ideas, which this post originally focused on, is to exploit the fact that T is Hermitian at every iteration: only the lower triangular part and the diagonal actually need to be computed, and the lower triangular part can then be copied (conjugated) into the upper part.

It would probably also be smart to store T only through its main diagonal alf and secondary diagonal bet, but I'm not really concerned about memory when working with matrices of this size.
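Just to spell that out, a throwaway sketch of the diagonal-only storage (alf and bet as above; T_full is only an illustrative name):

    % After tridi, T is real and tridiagonal, so two vectors suffice
    alf = diag(T);        % main diagonal
    bet = diag(T, -1);    % secondary (sub-)diagonal
    % ... and if the full matrix is ever needed again:
    T_full = diag(alf) + diag(bet, -1) + diag(bet, 1);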

Suppose you have two matrices A and B and you know (by some algebraic relation) that the product has to be Hermitian - in my case A and B are symmetric and commute.

Obviously I don't want to form the product as A*B, since then both the upper and the lower triangular parts are calculated, even though it would be enough to calculate e.g. the lower triangular part and then store its complex conjugate in the upper triangular part.

I'm familiar with the functions tril and triu, but then, if I understand correctly how MATLAB handles multiplication with zero rows, the multiplications for the upper triangular part are still carried out, even though I know the result has to be zero.

I'm only working with matrices of size 10 or smaller, so using sparse matrices probably isn't the way to go.

I thought about just writing a loop so that all elements below the diagonal are calculated separately, stored in the lower part of the matrix and then copied to the upper part, but given how much the MATLAB documentation insists on vectorizing code and avoiding loops, I don't really know how else to approach it.
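Roughly, the loop I have in mind would look like this (just a sketch; A and B stand for two matrices whose product is known to be Hermitian, and C, n, j are only illustrative names):

    % Compute only the lower triangle of C = A*B, column by column,
    % then mirror it using the (assumed) Hermitian symmetry of the product.
    n = size(A, 1);
    C = zeros(n, 'like', A);              % keep the complex class of A
    for j = 1:n
        C(j:n, j) = A(j:n, :) * B(:, j);  % rows j..n of column j only
    end
    C = C + tril(C, -1)';                 % fill the strict upper triangle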


Solution

  • The low-hanging fruit is to reduce the repeated computation in your hot loop: you form the reflector eye(n-k) - 2*(u*u') three times per iteration (twice explicitly in the T update, once implicitly in the U update). With a little mathematical manipulation, you can achieve a noticeable speedup.

    indices = k+1:n;
    
    % Updating T and U:
    I = eye(n - k);
    two_uu = 2 * (u * u');
    I_minus_two_uu = I - two_uu;
    T(indices, indices) = I_minus_two_uu * T(indices, indices) * I_minus_two_uu;
    U(2:n, indices) = -phase * U(2:n, indices) * I_minus_two_uu;
    

    Profiler

    I created a short script with your code to profile the results, together with a check to make sure the changes don't alter what the function computes. Due to the different order of the floating-point operations the values will not be bit-for-bit identical, but they agree to within a small tolerance:

    tol = sqrt(eps); % 1.4901e-08
    for i = 1:2e5
        [T1, U1] = tridi(H);
        [T2, U2] = tridi_faster(H);
        % Ensure the results agree to within the tolerance
        assert(all(ismembertol(T1, T2, tol), "all"));
        assert(all(ismembertol(real(U1), real(U2), tol), "all"));
        assert(all(ismembertol(imag(U1), imag(U2), tol), "all"));
    end
    

    Full script

    There are two functions here: tridi() (original), and tridi_faster() (improved).

    H = 1.0e-10 * [ ...
        0.1386 + 0.0000i   0.0974 - 0.0260i   0.0434 + 0.0094i   0.0722 + 0.0670i   0.1128 + 0.1269i;
        0.0974 + 0.0260i   0.0751 + 0.0000i   0.0288 + 0.0149i   0.0388 + 0.0616i   0.0557 + 0.1112i;
        0.0434 - 0.0094i   0.0288 - 0.0149i   0.0146 + 0.0000i   0.0274 + 0.0164i   0.0444 + 0.0323i;
        0.0722 - 0.0670i   0.0388 - 0.0616i   0.0274 - 0.0164i   0.0719 + 0.0000i   0.1216 + 0.0116i;
        0.1128 - 0.1269i   0.0557 - 0.1112i   0.0444 - 0.0323i   0.1216 - 0.0116i   0.2105 + 0.0000i];
    
    % Compare the two implementations (script code has to come before the
    % local function definitions at the end of the file)
    tol = sqrt(eps); % 1.4901e-08
    for i = 1:2e5
        [T1, U1] = tridi(H);
        [T2, U2] = tridi_faster(H);
        % Ensure the results agree to within the tolerance
        assert(all(ismembertol(T1, T2, tol), "all"));
        assert(all(ismembertol(real(U1), real(U2), tol), "all"));
        assert(all(ismembertol(imag(U1), imag(U2), tol), "all"));
    end
    
    
    function [T, U] = tridi(H)
    % Calculates a unitarily equivalent matrix T and unitary matrix U for a 
    % given Hermitian matrix H, such that U'*H*U = T and T is of tridiagonal
    % form.
    
    % Setting size and starting point of iteration
    n = size(H,1); T = H; U = eye(n);
    
    % Loop over the first n-1 columns of H
    
    for k = 1 : n-1
        % Householder transformation on column vector T(k+1:n,k)
        x = T(k+1:n,k);
        
        % Calculating secondary diagonal entry
        normValue = norm(x);
        
        % Initializing e1
        e1 = zeros(length(x),1);
        e1(1) = 1;
    
        % phase of x1
        phase = sign(x(1));
    
        % Compute the normalized Householder vector u
        u = x + norm(x)*phase*e1;
        u = u / norm(u);
        
        % Updating T and U:
        T(k+1:n,k+1:n) = (eye(n-k) - 2*(u*u'))*T(k+1:n,k+1:n)*(eye(n-k) - 2*(u*u'));
        U(2:n, k+1:n) = -phase*(U(2:n, k+1:n) - 2 * (U(2:n, k+1:n) * u) * u');
    
        % Setting secondary diagonal entry of T
        T(k+1,k) = normValue;
        T(k,k+1) = normValue;
        
        % Setting appropriate row and column of T to zero
        T(k+2:n,k) = 0;   
        T(k,k+2:n) = 0;
    end
    % Ensure that T is real
    T = real(T);
    end
    
    function [T, U] = tridi_faster(H)
    % Calculates a unitarily equivalent matrix T and unitary matrix U for a 
    % given Hermitian matrix H, such that U'*H*U = T and T is of tridiagonal
    % form.
    
    % Setting size and starting point of iteration
    n = size(H,1); T = H; U = eye(n);
    
    % Loop over the first n-1 columns of H
    
    for k = 1 : n-1
        indices = k+1:n;
        % Householder transformation on column vector T(indices,k)
        x = T(indices,k);
        
        % Calculating secondary diagonal entry
        normValue = norm(x);
        
        % Initializing e1
        e1 = zeros(length(x),1);
        e1(1) = 1;
    
        % phase of x1
        phase = sign(x(1));
    
        % Compute the normalized Householder vector u
        u = x + norm(x)*phase*e1;
        u = u / norm(u);
        
        % Updating T and U:
        I = eye(n - k);
        two_uu = 2 * (u * u');
        I_minus_two_uu = I - two_uu;
        T(indices,indices) = I_minus_two_uu * T(indices,indices) * I_minus_two_uu;
        U(2:n, indices) = -phase * U(2:n, indices) * I_minus_two_uu;
    
        % Setting secondary diagonal entry of T
        T(k+1,k) = normValue;
        T(k,k+1) = normValue;
        
        % Setting appropriate row and column of T to zero
        T(k+2:n,k) = 0;   
        T(k,k+2:n) = 0;
    end
    % Ensure that T is real
    T = real(T);
    end
    

    Results

    Using the built-in MATLAB profiler yielded the following results for 200,000 iterations. Keep in mind that the exact times vary quite a bit between runs, but the relative comparison was consistent.
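
    For completeness, a sketch of one way to reproduce this kind of measurement (line-by-line detail via the profiler, or a simple per-call estimate via timeit; t_orig and t_fast are just illustrative names):

    % Line-by-line timings with the built-in profiler (2e5 calls, as above)
    profile on
    for i = 1:2e5
        tridi(H);
        tridi_faster(H);
    end
    profile off
    profile viewer              % opens the line-by-line report

    % Or a quick per-call estimate without the profiler:
    t_orig = timeit(@() tridi(H));
    t_fast = timeit(@() tridi_faster(H));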

    Profiler result

    [Screenshot: original line-by-line profiler results]

    [Screenshot: improved line-by-line profiler results]

    Summary

    Function    Calls      Time
    original    200,000    4.872 s
    improved    200,000    3.726 s

    Overall there is a noticeable speedup, but for code that already runs in about 5 seconds, the time spent trying to optimize it will far exceed the code's runtime.

    If you really need to achieve an order of magnitude speedup, you will need to use a different algorithm. Perhaps there is a way to remove the for loop in the function, or use a built-in MATLAB function, but that requires understanding the mathematics of what the function is actually doing.
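
    For example, one standard route that exploits the Hermitian structure is the symmetric rank-two update used in textbook Householder tridiagonalization. A sketch of how the T/U update inside tridi_faster could look (p and w are just temporaries; this relies on u being normalized, as it is in the code above):

    % Since T(indices,indices) is Hermitian and u is normalized,
    %   P*T*P = T - u*w' - w*u'   with   p = 2*T*u   and   w = p - (u'*p)*u,
    % so the update needs only matrix-vector products and a rank-two update.
    p = 2 * (T(indices, indices) * u);
    w = p - (u' * p) * u;
    T(indices, indices) = T(indices, indices) - u * w' - w * u';
    U(2:n, indices) = -phase * (U(2:n, indices) - 2 * (U(2:n, indices) * u) * u');

    Whether this actually beats the explicit I_minus_two_uu version for such small matrices would need to be measured; the point is that it is the mathematical structure, not the MATLAB syntax, that lets you skip work.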