Tags: matlab, math, optimization, matrix, least-squares

Optimizing repetitive estimation (currently a loop) in MATLAB


I've found myself needing to do a least-squares (or similar matrix-based) operation for every pixel in an image. Every pixel has a set of numbers associated with it, so the whole data set can be arranged as a 3D matrix.

(This next bit can be skipped)

Quick explanation of what I mean by least-squares estimation:

Let's say we have some quadratic system that is modeled by Y = Ax^2 + Bx + C and we're looking for those A,B,C coefficients. With a few samples (at least 3) of X and the corresponding Y, we can estimate them by:

  1. Arrange the (let's say 10) X samples into a matrix like X = [x(:).^2 x(:) ones(10,1)];
  2. Arrange the Y samples into a column vector: Y = y(:);
  3. Estimate the coefficients A,B,C by solving: coeffs = (X'*X)^(-1)*X'*Y;
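Step 3 is the closed-form solution of the least-squares normal equations. Here is a short derivation, writing c for the coefficient vector coeffs = [A; B; C]:

$$
\min_{c}\ \lVert Xc - Y\rVert^2,
\qquad
\nabla_c \lVert Xc - Y\rVert^2 = 2X^\top (Xc - Y) = 0
\;\Longrightarrow\;
X^\top X\,c = X^\top Y
\;\Longrightarrow\;
c = (X^\top X)^{-1} X^\top Y
$$

This assumes X'*X is invertible. That holds when the columns of X are linearly independent, which for this quadratic model requires at least 3 distinct x values.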

Try this on your own if you want:

A = 5; B = 2; C = 1;
x = 1:10;
y = A*x(:).^2 + B*x(:) + C + .25*randn(10,1); % added some noise here
X = [x(:).^2 x(:) ones(10,1)];
Y = y(:);
coeffs = (X'*X)^-1*X'*Y

coeffs =

  5.0040
  1.9818
  0.9241
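As a cross-check, the same coefficients can be obtained in two other ways. One is MATLAB's backslash operator, which for a rectangular X solves the least-squares problem through a QR factorization instead of forming (X'*X)^(-1) explicitly. The other is polyfit, which returns the coefficients in descending powers. The sketch below regenerates the same synthetic data as above; the variable names coeffs_normal, coeffs_qr and coeffs_poly are introduced here for illustration:

```matlab
% Same synthetic quadratic as above (noise differs from run to run)
A = 5; B = 2; C = 1;
x = 1:10;
y = A*x(:).^2 + B*x(:) + C + .25*randn(10,1);

X = [x(:).^2 x(:) ones(10,1)];   % design matrix, one row per sample
Y = y(:);

coeffs_normal = (X'*X)\(X'*Y);          % normal equations, no explicit inverse
coeffs_qr     = X\Y;                    % backslash: least squares via QR
coeffs_poly   = polyfit(x(:), y, 2).';  % polyfit returns [A B C] as a row

disp([coeffs_normal coeffs_qr coeffs_poly])   % the three columns should agree
```

All three columns should agree to round-off. For ill-conditioned design matrices, backslash is generally the safer choice because it avoids squaring the condition number.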

START PAYING ATTENTION AGAIN IF I LOST YOU THERE

*MAJOR REWRITE* I've modified the example to bring it as close as possible to my real problem while still keeping it a minimal working example.

Problem Setup

%// Setup
xdim = 500; 
ydim = 500; 
ncoils = 8; 
nshots = 4; 
%// matrix size for each pixel is ncoils x nshots (an overdetermined system)

%// each pixel has a matrix stored in the 3rd and 4th dimensions
regressor = randn(xdim,ydim, ncoils,nshots); 
regressand = randn(xdim, ydim,ncoils); 

So my problem is that I have to do a (X'*X)^-1*X'*Y (least-squares or similar) operation for every pixel in an image. While that operation itself is vectorized, the only way I have found to do it for every pixel is a for loop, like:

Original code style

%// Actual work
tic 
estimate = zeros(xdim,ydim);
for col=1:size(regressor,2)
    for row=1:size(regressor,1)

        X = squeeze(regressor(row,col,:,:));
        Y = squeeze(regressand(row,col,:));

        B = X\Y; 
        % B = (X'*X)^(-1)*X'*Y; %// equivalently

        estimate(row,col) = B(1);
    end
end
toc

Elapsed time = 27.6 seconds

EDITS in response to comments and other ideas
I tried some things:
1. Reshaped into a long vector and removed the double for loop. This saved some time.
2. Removed the squeeze (and in-line transposing) by calling permute on the arrays beforehand. This saved a lot more time.
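The permute/reshape step in item 2 puts each pixel's ncoils-by-nshots matrix into one page of a 3D array, with pixels numbered in MATLAB's column-major order. The sketch below uses a tiny array with arbitrary sizes. It checks that page sub2ind([xdim ydim], row, col) holds the same matrix that the original loop extracted with squeeze. The flag ok is a name introduced here:

```matlab
% Tiny sizes so the check runs instantly (values are arbitrary)
xdim = 3; ydim = 4; ncoils = 8; nshots = 4;
regressor = randn(xdim, ydim, ncoils, nshots);

% Move the per-pixel matrix dimensions to the front, then flatten the pixels
regressor_mod = permute(regressor, [3 4 1 2]);   % ncoils x nshots x xdim x ydim
regressor_mod = reshape(regressor_mod, [ncoils, nshots, xdim*ydim]);

% Page ind must equal the matrix the original double loop extracted
ok = true;
for col = 1:ydim
    for row = 1:xdim
        ind = sub2ind([xdim ydim], row, col);    % = row + (col-1)*xdim
        ok = ok && isequal(regressor_mod(:,:,ind), ...
                           squeeze(regressor(row,col,:,:)));
    end
end
disp(ok)
```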

Current example:

%// Actual work
tic 
estimate2 = zeros(xdim*ydim,1);
regressor_mod = permute(regressor,[3 4 1 2]);
regressor_mod = reshape(regressor_mod,[ncoils,nshots,xdim*ydim]);
regressand_mod = permute(regressand,[3 1 2]);
regressand_mod = reshape(regressand_mod,[ncoils,xdim*ydim]);

for ind=1:size(regressor_mod,3) % for every pixel

    X = regressor_mod(:,:,ind);
    Y = regressand_mod(:,ind);

    B = X\Y;

    estimate2(ind) = B(1);

end
estimate2 = reshape(estimate2,[xdim,ydim]);
toc

Elapsed time = 2.30 seconds (avg of 10)
isequal(estimate2,estimate)   % returns 1 (true)

Rody Oldenhuis's way

N  = xdim*ydim*ncoils;  %// number of rows
M  = xdim*ydim*nshots;  %// number of columns

ii = repmat(reshape(1:N,[ncoils,xdim*ydim]),[nshots 1]); %// row indices
jj = repmat(1:M,[ncoils 1]); %// column indices

X = sparse(ii(:),jj(:),regressor_mod(:));
Y = regressand_mod(:);

B = X\Y;

B = reshape(B(1:nshots:end),[xdim ydim]);

Elapsed time = 2.26 seconds (avg of 10) 
            or 2.18 seconds (if you don't include the definition of N,M,ii,jj)
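The sketch below checks this sparse block-diagonal formulation against the per-pixel backslash loop on a tiny random problem. The sizes and the names P and ref are introduced here for illustration. The index construction follows the code above, with the matrix size passed to sparse explicitly:

```matlab
% Tiny problem so the per-pixel loop and the sparse solve can be compared
xdim = 3; ydim = 2; ncoils = 8; nshots = 4;
P = xdim*ydim;                                   % number of pixels
regressor_mod  = randn(ncoils, nshots, P);       % one ncoils x nshots matrix per pixel
regressand_mod = randn(ncoils, P);               % one ncoils x 1 vector per pixel

% Reference: solve each pixel separately
ref = zeros(nshots, P);
for ind = 1:P
    ref(:,ind) = regressor_mod(:,:,ind) \ regressand_mod(:,ind);
end

% Block-diagonal sparse system: entry (c,s) of pixel p goes to
% row (p-1)*ncoils + c and column (p-1)*nshots + s
N  = P*ncoils;                                        % number of rows
M  = P*nshots;                                        % number of columns
ii = repmat(reshape(1:N, [ncoils, P]), [nshots 1]);   % row indices
jj = repmat(1:M, [ncoils 1]);                         % column indices
X  = sparse(ii(:), jj(:), regressor_mod(:), N, M);
B  = X \ regressand_mod(:);      % sparse least squares for all pixels at once

disp(max(abs(B - ref(:))))       % should be at round-off level
```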

SO THE QUESTION IS:
Is there an (even) faster way?

(I don't think so.)


Solution

    Results

    I sped up your original version, since your edit 3 was actually not working (and also did something different).

    So, on my PC:

    Your (original) version: 8.428473 seconds.
    My obfuscated one-liner given below: 0.964589 seconds.

    First, for no other reason than to impress, I'll give it as I wrote it:

    %%// Some example data
    xdim = 500; 
    ydim = 500; 
    n_timepoints = 10; % for example
    estimate = zeros(xdim,ydim); %// initialization with explicit size
    
    picture = randn(xdim,ydim,n_timepoints);
    
    
    %%// Your original solution
    %// (slightly altered to make my version's results agree with yours)
    
    tic
    
    Y = randn(n_timepoints,xdim*ydim);
    ii = 1;
    for x = 1:xdim
        for y = 1:ydim
    
            X = squeeze(picture(x,y,:)); %// or similar creation of X matrix
    
            B = (X'*X)^(-1)*X' * Y(:,ii);
            ii = ii+1;
    
            %// sometimes you keep everything and do
            %// estimate(x,y,:) = B(:);
            %// sometimes just the first element is important and you do
            estimate(x,y) = B(1);
    
        end
    end
    
    toc
    
    
    %%// My version 
    
    tic
    
    %// UNLEASH THE FURY!!
    estimate2 = reshape(sparse(1:xdim*ydim*n_timepoints, ...
        builtin('_paren', ones(n_timepoints,1)*(1:xdim*ydim),:), ...
        builtin('_paren', permute(picture, [3 2 1]),:))\Y(:), ydim,xdim).';  %'
    
    toc
    
    %%// Check for equality
    
    max(abs(estimate(:)-estimate2(:)))  % (always less than ~1e-14)
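In this particular example each pixel has only one unknown, so X'*X is a scalar and the estimate has the closed form b = (x'*y)/(x'*x). That allows a fully vectorized alternative that needs neither sparse nor a loop. This is a different technique from the sparse solve, and it only applies to the one-regressor case. In the sketch below the sizes are reduced so it can be compared with a per-pixel loop; Yr and estimate3 are hypothetical names:

```matlab
% Small sizes for the check; the same lines work for 500 x 500 x 10
xdim = 5; ydim = 7; n_timepoints = 10;
picture = randn(xdim, ydim, n_timepoints);
Y = randn(n_timepoints, xdim*ydim);

% Reference: per-pixel least squares, visiting y fastest as in the loop above
estimate = zeros(xdim, ydim);
k = 1;
for x = 1:xdim
    for y = 1:ydim
        X = squeeze(picture(x,y,:));
        estimate(x,y) = X \ Y(:,k);
        k = k + 1;
    end
end

% Column k = (x-1)*ydim + y of Y belongs to pixel (x,y); rearrange to match
Yr = permute(reshape(Y, [n_timepoints, ydim, xdim]), [3 2 1]);  % xdim x ydim x n_timepoints

% One unknown per pixel: (X'*X)^(-1)*X'*Y reduces to sum(x.*y)/sum(x.^2)
estimate3 = sum(picture .* Yr, 3) ./ sum(picture.^2, 3);

disp(max(abs(estimate(:) - estimate3(:))))
```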
    

    Breakdown

    First, here's the version that you should actually use:

    %// Construct sparse block-diagonal matrix
    %// (Type "help sparse" for more information)
    N  = xdim*ydim;      %// number of columns
    M  = N*n_timepoints; %// number of rows
    ii = 1:M;                          %// row indices: one row per sample
    jj = ones(n_timepoints,1)*(1:N);   %// column indices: pixel k repeated n_timepoints times
    s  = permute(picture, [3 2 1]);
    X  = sparse(ii,jj(:), s(:));
    
    %// Compute ALL the estimates at once
    estimates = X\Y(:);
    
    %// You loop through the *second* dimension first, so to make everything
    %// agree, we have to extract elements in the "wrong" order, and transpose:
    estimate2 = reshape(estimates, ydim,xdim).';  %'
    

    Here's an example of what picture and the corresponding matrix X look like for xdim = ydim = n_timepoints = 2:

    >> clc, picture, full(X)
    
    picture(:,:,1) =
       -0.5643   -2.0504
       -0.1656    0.4497
    picture(:,:,2) =
        0.6397    0.7782
        0.5830   -0.3138
    
    ans =
       -0.5643         0         0         0
        0.6397         0         0         0
             0   -2.0504         0         0
             0    0.7782         0         0
             0         0   -0.1656         0
             0         0    0.5830         0
             0         0         0    0.4497
             0         0         0   -0.3138
    

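    You can see why sparse is necessary: X is mostly zeros, but it grows quickly with the image size. For the 500-by-500-by-10 example, a full X would be 2,500,000-by-250,000 (about 5 TB as doubles) and would quickly consume all your RAM. The sparse one stores only the 2,500,000 nonzeros and takes only a small multiple of the memory of the original picture array.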
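Those sizes can be checked directly. The sketch below uses the 500-by-500-by-10 example and builds X the same way as in the version above. It then uses whos to report memory. Exact byte counts depend on the MATLAB version and platform, so no output is listed here. The names dense_bytes, info_X and info_p are introduced for illustration:

```matlab
xdim = 500; ydim = 500; n_timepoints = 10;
N = xdim*ydim;            % columns (one unknown per pixel)
M = N*n_timepoints;       % rows (one per sample)

dense_bytes = M*N*8;      % a full double matrix stores every element
fprintf('dense X:  %.3g bytes\n', dense_bytes);

% Build the sparse X as in the answer and ask MATLAB how much memory it uses
picture = randn(xdim, ydim, n_timepoints);
s  = permute(picture, [3 2 1]);
jj = ones(n_timepoints,1)*(1:N);
X  = sparse((1:M).', jj(:), s(:), M, N);

info_X = whos('X');
info_p = whos('picture');
fprintf('sparse X: %d bytes (picture: %d bytes)\n', info_X.bytes, info_p.bytes);
```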

    With this matrix X, the new problem

    X·b = Y
    

    now contains all the problems

    X1 · b1 = Y1
    X2 · b2 = Y2
    ...
    

    where

    b = [b1; b2; b3; ...]
    Y = [Y1; Y2; Y3; ...]
    

    so, the single command

    X\Y
    

    will solve all your systems at once.
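Why one solve gives the same answer as the loop: for a block-diagonal X, the least-squares normal equations decouple block by block. This assumes each X_k has full column rank:

$$
X = \begin{bmatrix} X_1 & & \\ & X_2 & \\ & & \ddots \end{bmatrix}
\;\Longrightarrow\;
X^\top X = \begin{bmatrix} X_1^\top X_1 & & \\ & X_2^\top X_2 & \\ & & \ddots \end{bmatrix},
\qquad
X^\top Y = \begin{bmatrix} X_1^\top Y_1 \\ X_2^\top Y_2 \\ \vdots \end{bmatrix}
$$

$$
X^\top X\, b = X^\top Y
\;\Longleftrightarrow\;
X_k^\top X_k\, b_k = X_k^\top Y_k \ \text{ for every } k
\;\Longleftrightarrow\;
b_k = (X_k^\top X_k)^{-1} X_k^\top Y_k
$$

This is exactly the per-pixel estimate computed in the loop.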

    This offloads all the hard work to highly specialized, optimized-in-every-way algorithms compiled to machine-specific code. The alternative is interpreted, generic MATLAB loops that are always a couple of steps away from the hardware.

    It should be straightforward to convert this to a version where each pixel's X is a matrix rather than a single column. You'll end up with a block-diagonal matrix like the one blkdiag builds, which mldivide can solve in exactly the same way as above.
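Here is a minimal sketch of that generalization on a tiny example, where each pixel has its own ncoils-by-nshots matrix. The sizes and the names Xp, Yp, blocks and ref are hypothetical. Each block is converted to sparse so that blkdiag returns a sparse matrix; the sketch then solves once and compares with a per-pixel loop. For a full image, a cell array with one entry per pixel may be slow to build. Constructing the index vectors directly, as in the question's sparse version, is likely the better route there:

```matlab
% Tiny example: P pixels, each with its own ncoils x nshots matrix
ncoils = 8; nshots = 4; P = 5;
Xp = randn(ncoils, nshots, P);
Yp = randn(ncoils, P);

% Collect the per-pixel matrices as sparse blocks and stack them diagonally
blocks = cell(1, P);
for p = 1:P
    blocks{p} = sparse(Xp(:,:,p));
end
X = blkdiag(blocks{:});          % (P*ncoils) x (P*nshots), sparse
b = X \ Yp(:);                   % one least-squares solve for all pixels
b = reshape(b, nshots, P);       % column p holds the coefficients of pixel p

% Compare with solving each pixel separately
ref = zeros(nshots, P);
for p = 1:P
    ref(:,p) = Xp(:,:,p) \ Yp(:,p);
end
disp(max(abs(b(:) - ref(:))))    % should be at round-off level
```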