
Argmax of a multidimensional array along a subset of dimensions in Matlab


Say Y is a 7-dimensional array, and I need an efficient way to maximize it along the last 3 dimensions that also works on the GPU. As a result, I need a 4-dimensional array with the maximal values of Y and three 4-dimensional arrays with the indices of these values along the last three dimensions. I can do

[Y7, X7] = max(Y , [], 7);
[Y6, X6] = max(Y7, [], 6);
[Y5, X5] = max(Y6, [], 5);

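Then I have already found the values (Y5) and the indices along the 5th dimension (X5). But I still need the indices along the 6th and 7th dimensions: X6 and X7 are 5- and 6-dimensional (they depend on the position along the 5th, and the 5th and 6th, dimensions), so they cannot be used directly.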
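
For comparison, here is one way the sequential approach could be finished (a sketch, not part of the original answer): X6 is indexed by the first five dimensions of Y7 and X7 by the first six, so once X5 is known, the missing indices can be read off with column-major linear indexing. The names `lead`, `base`, `I5`, `I6` and `I7` are chosen here for illustration. If the maximum is tied, the indices found this way may differ from those of the collapsed approach in the answer, although the maximal values agree.

```matlab
% Example 7-D array; sizes chosen arbitrarily for illustration
Y = rand(2,3,2,3,2,3,2);

% Sequential maximization, as in the question
[Y7, X7] = max(Y,  [], 7);   % X7: index along dim 7, size n1 x ... x n6
[Y6, X6] = max(Y7, [], 6);   % X6: index along dim 6, size n1 x ... x n5
[Y5, X5] = max(Y6, [], 5);   % X5: index along dim 5, size n1 x ... x n4

n5   = size(Y, 5);               % size(Y,k) returns 1 for trailing singleton dims
lead = numel(Y5);                % number of positions in the 4-D result
base = reshape(1:lead, size(Y5)); % linear index of each 4-D position

% X6 depends on (i1..i4, i5): take its entry at i5 = X5
I5 = X5;
I6 = X6(base + lead*(I5 - 1));
% X7 depends on (i1..i4, i5, i6): take its entry at i5 = I5, i6 = I6
I7 = X7(base + lead*(I5 - 1) + lead*n5*(I6 - 1));
```

Every operation here is elementwise indexing, so it should carry over to gpuArray inputs without changes to the logic.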


Solution

  • Here's a way to do it. Let N denote the number of dimensions along which to maximize.

    1. Reshape Y to collapse the last N dimensions into one.
    2. Maximize along the collapsed dimension. This gives the argmax as a linear index over those N dimensions.
    3. Unroll the linear index into N subindices, one for each dimension.
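
Steps 1 and 3 rely on MATLAB's column-major ordering: after the reshape, position k along the collapsed dimension holds the element whose subscripts are what ind2sub returns for k. A small check on a 5-D array with two leading and three collapsed dimensions (sizes arbitrary, chosen for illustration):

```matlab
A = rand(2,2,2,3,2);                          % 2 leading dims, 3 trailing dims to collapse
szA = size(A);
Ar = reshape(A, [szA(1:2) prod(szA(3:5))]);   % step 1: now 2 x 2 x 12

% Subscripts (j3,j4,j5) over the collapsed dims map to one linear position k
j3 = 2; j4 = 3; j5 = 1;
k = sub2ind(szA(3:5), j3, j4, j5)             % k = j3 + 2*(j4-1) + 6*(j5-1) = 6
isequal(Ar(:,:,k), A(:,:,j3,j4,j5))           % true: reshape keeps element order

% Step 3: ind2sub undoes the mapping
[s3, s4, s5] = ind2sub(szA(3:5), k)           % s3 = 2, s4 = 3, s5 = 1
```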

    The following code works for any number of dimensions (not necessarily 7 and 3 as in your example). To achieve that, it handles the size of Y generically and uses a comma-separated list generated from a cell array to obtain N outputs from ind2sub.
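
In case the comma-separated-list trick is unfamiliar: writing `[c{:}] = f(...)` with a 1-by-N cell `c` requests N outputs from `f`, so N can be a variable instead of being hard-coded. A minimal illustration with two linear indices over a 2x3x2 block:

```matlab
N = 3;
c = cell(1, N);                      % one cell per requested output
[c{:}] = ind2sub([2 3 2], [4 12]);   % same as [c{1}, c{2}, c{3}] = ind2sub(...)
% c{1} = [2 2], c{2} = [2 3], c{3} = [1 2]
% i.e. index 4 -> (2,2,1) and index 12 -> (2,3,2)
```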

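    Y = rand(2,3,2,3,2,3,2); % example 7-dimensional array
    N = 3; % number of trailing dimensions along which to maximize
    D = ndims(Y); % note: ndims ignores trailing singleton dimensions; if the last dimensions can have size 1, set D explicitly and pad sz with ones
    sz = size(Y);
    % Steps 1 and 2: collapse the last N dimensions into one and maximize along it
    [Ymax, ind] = max(reshape(Y, [sz(1:D-N) prod(sz(D-N+1:end))]), [], D-N+1);
    % Step 3: unroll the linear index into N subindices (sub{k} is the index along dimension D-N+k)
    sub = cell(1,N);
    [sub{:}] = ind2sub(sz(D-N+1:D), ind);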
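
The first output of the collapsed `max` call is the 4-D array of maximal values that the question also asks for. As a self-contained sketch (`Ymax` is an arbitrary name), it can be compared with the question's chained `max` calls; the values should match exactly, since both only select elements of Y.

```matlab
Y = rand(2,3,2,3,2,3,2);   % example 7-D array
N = 3;
D = ndims(Y);
sz = size(Y);
% Keep the maxima as well as the linear argmax over the collapsed dimension
[Ymax, ind] = max(reshape(Y, [sz(1:D-N) prod(sz(D-N+1:end))]), [], D-N+1);

% Same maxima via the chained approach from the question
Y5 = max(max(max(Y, [], 7), [], 6), [], 5);
isequal(Ymax, Y5)          % true
```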
    

    As a check, after running the above code, observe for example Y(2,3,1,2,:) (shown as a row vector for convenience):

    >> reshape(Y(2,3,1,2,:), 1, [])
    ans =
        0.5621    0.4352    0.3672    0.9011    0.0332    0.5044    0.3416    0.6996    0.0610    0.2638    0.5586    0.3766
    

    The maximum is seen to be 0.9011, which occurs at the 4th position (where "position" is defined along the N=3 collapsed dimensions). In fact,

    >> ind(2,3,1,2)
    ans =
         4
    >> Y(2,3,1,2,ind(2,3,1,2))
    ans =
        0.9011
    

    or, in terms of the N=3 subindices,

    >> Y(2,3,1,2,sub{1}(2,3,1,2),sub{2}(2,3,1,2),sub{3}(2,3,1,2))
    ans =
        0.9011
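
The check above inspects a single position. A sketch of checking every position at once, with the same example sizes: in column-major order, leading position p and collapsed position k correspond to linear index p + L*(k-1) in Y, where L is the number of leading positions. The names `lead`, `base` and `lin` are chosen here for illustration.

```matlab
Y = rand(2,3,2,3,2,3,2);
N = 3;
D = ndims(Y);
sz = size(Y);
[Ymax, ind] = max(reshape(Y, [sz(1:D-N) prod(sz(D-N+1:end))]), [], D-N+1);
sub = cell(1, N);
[sub{:}] = ind2sub(sz(D-N+1:D), ind);

lead = prod(sz(1:D-N));                   % number of leading positions
base = reshape(1:lead, [sz(1:D-N) 1]);    % linear index of each leading position
lin  = base + lead*(ind - 1);             % linear index into Y of each argmax
isequal(Y(lin), Ymax)                     % true at every position

% The N subindices recombine into the same collapsed index
isequal(sub2ind(sz(D-N+1:D), sub{:}), ind) % true
```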