
3d matrix: how to use (row, column) pairs with 3rd dimension wildcard in MATLAB?


I have a three-dimensional matrix and a list of (row, column) pairs. I would like to extract the two-dimensional matrix formed by the elements at those positions, taken through the full depth (third dimension) of the matrix. For instance, suppose:

>> a = rand(4, 3, 2)
a(:,:,1) =
    0.5234    0.7057    0.0282
    0.6173    0.2980    0.9041
    0.7337    0.9380    0.9639
    0.0591    0.8765    0.1693
a(:,:,2) =
    0.8803    0.2094    0.5841
    0.7151    0.9174    0.6203
    0.7914    0.7674    0.6194
    0.2009    0.2542    0.3600
>> rows = [1 4 2 1];
>> cols = [1 2 1 3];

What I'd like to get is:

0.5234    0.8765    0.6173    0.0282
0.8803    0.2542    0.7151    0.5841

maybe with some permutation of dimensions. Also, although this example has the wildcard in the last dimension, I also have cases where it's in the first or second.

I naively tried a(rows, cols, :) and got a 4-by-4-by-2 array (numel(rows)-by-numel(cols)-by-size(a,3)) whose per-page diagonals hold the values I want. I also found sub2ind, which extracts the desired elements from the a(:,:,1) plane. I could build on either of these to get what I want, but I'm wondering whether there is a more canonical, elegant, or efficient method that I'm missing.
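A minimal sketch of the two workarounds mentioned above, for comparison. It hard-codes the 4-decimal values printed in the question (instead of calling rand) so the output is reproducible; the variable names b, n, d, pageOffsets, diagResult, firstPage and firstRow are illustrative only. Note that the first approach materializes numel(rows)^2 entries per page only to keep the diagonal, so its memory cost grows quadratically with the number of pairs.

```matlab
% Example data from the question (values as printed, 4 decimals)
a = cat(3, [0.5234 0.7057 0.0282; 0.6173 0.2980 0.9041; ...
            0.7337 0.9380 0.9639; 0.0591 0.8765 0.1693], ...
           [0.8803 0.2094 0.5841; 0.7151 0.9174 0.6203; ...
            0.7914 0.7674 0.6194; 0.2009 0.2542 0.3600]);
rows = [1 4 2 1];
cols = [1 2 1 3];

% Workaround 1: a(rows, cols, :) is numel(rows)-by-numel(cols)-by-size(a,3);
% the wanted values sit on the diagonal of each page.
b = a(rows, cols, :);
n = numel(rows);
d = (1:n) + (0:n-1) * n;                    % linear indices of b(i,i) within one page
pageOffsets = (0:size(b, 3) - 1).' * n^2;   % start of each page of b
diagResult = b(bsxfun(@plus, d, pageOffsets));   % size(a,3)-by-n

% Workaround 2: sub2ind on a single 2D page gives only one row of the answer.
firstPage = a(:, :, 1);
firstRow = firstPage(sub2ind(size(firstPage), rows, cols));   % 1-by-n
```

bsxfun is used here for compatibility with older releases; from R2016b onward, implicit expansion lets you write d + pageOffsets directly.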

Update

This is the solution I used, based on the answer posted below:

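sz = size(a);
% One column per element to extract: row, column, and page subscripts
subs = [repmat(rows, [1, sz(3)]);
        repmat(cols, [1, sz(3)]);
        repelem(1:sz(3), numel(rows))];
result = a(sub2ind(sz, subs(1,:), subs(2,:), subs(3,:)));
% result is a 1-by-(numel(rows)*sz(3)) row vector ordered page by page;
% reshape(result, numel(rows), []).' gives one row per page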
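For reference, the linear index that sub2ind(sz, r, c, k) returns for a 3D array in MATLAB's column-major order is given below, with R = sz(1) and C = sz(2); the second line works it through for the element at row 4, column 2 on page 2 of the example a.

```latex
\mathrm{idx}(r, c, k) = r + (c - 1)\,R + (k - 1)\,R\,C, \qquad R = 4,\; C = 3
\\
\mathrm{idx}(4, 2, 2) = 4 + (2 - 1)\cdot 4 + (2 - 1)\cdot 4 \cdot 3 = 20, \qquad a(20) = a(4, 2, 2) = 0.2542
```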

Solution

  • sub2ind is essentially what you need here to convert your subscripts into linear indices (short of computing the linear indices manually). The following converts rows and cols to linear indices within a single 2D slice, then adds an offset of one slice's worth of elements (prod(sz(1:2))) for each page, so that every element along the third dimension is sampled.

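    sz = size(a);
    % Linear indices of the (row, column) pairs within one 2D slice
    inds = sub2ind(sz(1:2), rows, cols);
    % Add one slice (prod(sz(1:2)) elements) per page: size(a,3)-by-numel(rows)
    inds = bsxfun(@plus, inds, (0:(sz(3)-1)).' * prod(sz(1:2)));
    result = a(inds);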
    

    And to compute the linear indices yourself:

    % Column-major linear index within one slice: (col - 1) * nRows + row
    inds = (cols - 1) * sz(1) + rows;
    inds = bsxfun(@plus, inds, (0:(sz(3) - 1)).' * prod(sz(1:2)));
    result = a(inds);
    

    Another option is to permute the original matrix so that the third dimension comes first, reshape it into a 2D matrix, and then use each pair's linear index within a slice as the column subscript:

    % Temporary 2D view: size(a,3)-by-(size(a,1)*size(a,2)), one column per (row, column) position
    anew = reshape(permute(a, [3, 1, 2]), size(a, 3), []);
    
    % Grab all rows (the 3rd dimension) and compute the columns to grab
    result = anew(:, (cols - 1) * size(a, 1) + rows);
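The question also mentions cases where the wildcard is in the first or second dimension. One way to handle that, sketched here under the assumption that a is exactly three-dimensional, is to generalize the permute-and-reshape approach: move the wildcard dimension to the front, flatten the other two, and pick columns with sub2ind. The names wildcardDim, i1, i2, others, flat and idx are hypothetical, and i1/i2 are assumed to index the two remaining dimensions in increasing order. The example data again hard-codes the values printed in the question.

```matlab
% Example data from the question (values as printed, 4 decimals)
a = cat(3, [0.5234 0.7057 0.0282; 0.6173 0.2980 0.9041; ...
            0.7337 0.9380 0.9639; 0.0591 0.8765 0.1693], ...
           [0.8803 0.2094 0.5841; 0.7151 0.9174 0.6203; ...
            0.7914 0.7674 0.6194; 0.2009 0.2542 0.3600]);

% Hypothetical generalization: wildcardDim is the dimension to keep whole;
% i1 and i2 index the two remaining dimensions, in increasing order.
wildcardDim = 2;
i1 = [4 1];      % subscripts into dimension 1 (rows)
i2 = [2 1];      % subscripts into dimension 3 (pages)

others = setdiff(1:3, wildcardDim);   % the two fixed dimensions, sorted
% Bring the wildcard dimension to the front, then flatten the other two
flat = reshape(permute(a, [wildcardDim, others]), size(a, wildcardDim), []);
% Column of flat holding each (i1, i2) pair (column-major over the fixed dims)
idx = sub2ind([size(a, others(1)), size(a, others(2))], i1, i2);
result = flat(:, idx);   % size(a, wildcardDim)-by-numel(i1); here a(4,:,2).' and a(1,:,1).'
```

With wildcardDim = 3, i1 = rows and i2 = cols, this reduces to the permute/reshape approach above and returns the 2-by-4 result from the question. With wildcardDim = 1, i1 indexes columns and i2 indexes pages, so i1 = [2 3], i2 = [1 2] returns a(:,2,1) and a(:,3,2) as the two columns of result.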