Tags: matlab, performance, matrix, optimization, matrix-indexing

Select entries of matrix based on given locations


I have the following M×N matrix (where M ≤ N):

0.8147    0.9134    0.2785    0.9649
0.9058    0.6324    0.5469    0.1576
0.1270    0.0975    0.9575    0.9706

From each row, I want to keep exactly one entry, taken from the following columns respectively:

idx = [ 3  1  4 ];

This means we keep the elements at (1,3), (2,1) and (3,4), and set the rest of the array to zero.

For the example above, the desired output is:

     0         0    0.2785         0
0.9058         0         0         0
     0         0         0    0.9706
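
Since MATLAB stores matrices in column-major order, each kept pair (r, c) in a matrix with M rows corresponds to the linear index r + (c-1)*M. The short sketch below works this out for the example and cross-checks it against the built-in `sub2ind`; the names `rows` and `lin` are illustrative only.

```matlab
% Example data from the question: M = 3 rows, N = 4 columns
m = [0.8147 0.9134 0.2785 0.9649;
     0.9058 0.6324 0.5469 0.1576;
     0.1270 0.0975 0.9575 0.9706];
idx = [3 1 4];                        % column to keep in each row

rows = 1:numel(idx);                  % row r keeps column idx(r)
% Column-major linear index of (r, c) in a matrix with size(m,1) rows
lin = rows + (idx - 1) * size(m, 1);  % lin is [7 2 12]

% The built-in sub2ind computes the same indices
assert(isequal(lin, sub2ind(size(m), rows, idx)))
disp(m(lin))                          % the three entries to keep
```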

I currently build this with a loop, which becomes slow as the matrix grows.

Can anyone suggest a more performant approach?


Solution

  • There is some discussion in the other answers and comments about performance. This is one of those situations where a simple, well-constructed for loop does the job nicely, with essentially no performance penalty.

    % For some original matrix 'm', and column indexing array 'idx':
    x = zeros( size(m) ); % Initialise output of zeros
    for ii = 1:numel(idx) % Loop over indices
        % Assign the value at the column index for this row
        x( ii, idx(ii) ) = m( ii, idx(ii) );     
    end
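
    As a quick sanity check, here is the same loop run on the example matrix from the question and compared against the expected output given there (a self-contained script for verification, not part of the original answer):

```matlab
% Example data from the question
m = [0.8147 0.9134 0.2785 0.9649;
     0.9058 0.6324 0.5469 0.1576;
     0.1270 0.0975 0.9575 0.9706];
idx = [3 1 4];

x = zeros(size(m));                   % preallocate the all-zero output
for ii = 1:numel(idx)                 % one iteration per row
    x(ii, idx(ii)) = m(ii, idx(ii));  % copy only the selected entry
end

expected = [0      0 0.2785 0;
            0.9058 0 0      0;
            0      0 0      0.9706];
assert(isequal(x, expected))          % matches the desired output
disp(x)
```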
    

    This code is readable and quick. To back up "quick", I wrote the benchmarking code below for the methods from all four current answers and ran it on MATLAB R2017b. Here are the resulting plots.

    • For "small" matrices, up to 2^5 columns and 2^4 rows:

      [Plot: timings for small matrices]

    • For "large" matrices, up to 2^15 columns and 2^14 rows (shown both with and without the bsxfun solution, because its times dwarf the others and ruin the y-axis scaling):

      [Plot: timings for large matrices, with bsxfun]

      [Plot: timings for large matrices, without bsxfun]

    The first plot is perhaps slightly misleading. The ranking is consistent (slowest to fastest: bsxfun, sub2ind, manual linear indices, loop), but the times are on the order of 10^(-5) seconds, so in practice it makes no difference which method you use!

    The large-matrix plots show that the methods are essentially equivalent, except for bsxfun, which is far slower (and, although not shown here, also needs much more memory).
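
    To put the memory remark in rough numbers: f_bsxfun builds an M×N logical mask (1 byte per element) before multiplying, whereas the indexing methods only need one index per row. A back-of-the-envelope sketch for the largest benchmark size, ignoring the M×N double output that every method allocates and any MATLAB array overhead:

```matlab
% Largest "large" benchmark size: 2^14 rows, 2^15 columns
M = 2^14;
N = 2^15;

maskBytes   = M * N * 1;   % logical mask from bsxfun(@eq, ...): 1 byte per element
linIdxBytes = M * 8;       % linear indices: one double (8 bytes) per row

fprintf('bsxfun mask:    %g MiB\n', maskBytes / 2^20);    % prints 512 MiB
fprintf('linear indices: %g KiB\n', linIdxBytes / 2^10);  % prints 128 KiB
```

    So the mask alone is roughly 4000 times larger than the index vector at this size.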

    I'd opt for the clearer loop: it gives you more flexibility, and you'll still know exactly what it does when you reread your code two years from now.


    Benchmarking code:

    function benchie() 
        K = 5;                      % Max loop variable (5 for the small-matrix plot, 15 for the large ones)
        T = zeros( K, 4 );          % Timing results
        for k = 1:K
            M = 2^(k-1); N = 2^k;   % size of matrix
            m = rand( M, N );       % random matrix
            idx = randperm( N, M ); % column indices
    
            % Define anonymous functions with no inputs for timeit, and run
            f1 = @() f_sub2ind( m, idx ); T(k,1) = timeit(f1);
            f2 = @() f_linear( m, idx );  T(k,2) = timeit(f2);
            f3 = @() f_loop( m, idx );    T(k,3) = timeit(f3);   
            f4 = @() f_bsxfun( m, idx );  T(k,4) = timeit(f4);   
        end
        % Plot results
        plot( (1:K)', T, 'linewidth', 2 );
        legend( {'sub2ind', 'linear', 'loop', 'bsxfun'} );
        xlabel( 'k, where matrix had 2^{(k-1)} rows and 2^k columns' );
        ylabel( 'function time (s)' )
    end
    
    function x = f_sub2ind( m, idx )
        % Using the built-in sub2ind to generate linear indices, then indexing
        lin_idx = sub2ind( size(m), 1:numel(idx), idx );
        x = zeros( size(m) );
        x( lin_idx ) = m( lin_idx );
    end
    function x = f_linear( m, idx )
        % Manually calculating linear indices (column-major: row + (col-1)*nRows), then indexing
        lin_idx = (1:numel(idx)) + (idx-1)*size(m,1);
        x = zeros( size(m) );
        x( lin_idx ) = m( lin_idx );
    end
    function x = f_loop( m, idx )
        % Directly indexing in a simple loop
        x = zeros( size(m) );
        for ii = 1:numel(idx)
            x( ii, idx(ii) ) = m( ii, idx(ii) );
        end
    end
    function x = f_bsxfun( m, idx )
        % Using bsxfun to create a logical matrix of desired elements, then masking
        % Since R2016b, can use 'x = ( (1:size(m,2)) == idx(:) ) .* m;'
        x = bsxfun(@eq, 1:size(m,2), idx(:)).*m;
    end
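
    The benchmark times the four methods but does not check that they agree. A minimal sketch that runs each method's body inline on a random matrix and asserts identical results (the variable names x1 to x5 are illustrative; the last variant is the R2016b implicit-expansion form mentioned in the f_bsxfun comment):

```matlab
% Random test case with M <= N, as in the question
M = 6;
N = 10;
m = rand(M, N);
idx = randperm(N, M);                 % one column index per row

% sub2ind
lin_idx = sub2ind(size(m), 1:numel(idx), idx);
x1 = zeros(size(m));
x1(lin_idx) = m(lin_idx);

% Manual linear indices
lin_idx = (1:numel(idx)) + (idx - 1) * size(m, 1);
x2 = zeros(size(m));
x2(lin_idx) = m(lin_idx);

% Loop
x3 = zeros(size(m));
for ii = 1:numel(idx)
    x3(ii, idx(ii)) = m(ii, idx(ii));
end

% bsxfun mask
x4 = bsxfun(@eq, 1:size(m, 2), idx(:)) .* m;

% Implicit expansion (R2016b and later)
x5 = ((1:size(m, 2)) == idx(:)) .* m;

assert(isequal(x1, x2, x3, x4, x5))   % all methods give the same matrix
```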