
PyTorch torch.max over multiple dimensions


I have a tensor with x.shape = [3, 2, 2]:

import torch

x = torch.tensor([
    [[-0.3000, -0.2926],[-0.2705, -0.2632]],
    [[-0.1821, -0.1747],[-0.1526, -0.1453]],
    [[-0.0642, -0.0568],[-0.0347, -0.0274]]
])

I need to take .max() over the 2nd and 3rd dimensions. I expect something like [-0.2632, -0.1453, -0.0274] as output. I tried x.max(dim=(1, 2)), but this raises an error.


Solution

  • Now, you can do this. The PR was merged (Aug 28 2020) and it is now available in the nightly release.

    Simply use torch.amax():

    import torch
    
    x = torch.tensor([
        [[-0.3000, -0.2926],[-0.2705, -0.2632]],
        [[-0.1821, -0.1747],[-0.1526, -0.1453]],
        [[-0.0642, -0.0568],[-0.0347, -0.0274]]
    ])
    
    print(torch.amax(x, dim=(1, 2)))
    
    # Output:
    # >>> tensor([-0.2632, -0.1453, -0.0274])
    
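A couple of related points worth knowing: torch.amin() reduces over multiple dimensions the same way, and both functions accept keepdim to preserve the reduced axes as size-1 dimensions. A quick sketch:

```python
import torch

x = torch.tensor([
    [[-0.3000, -0.2926], [-0.2705, -0.2632]],
    [[-0.1821, -0.1747], [-0.1526, -0.1453]],
    [[-0.0642, -0.0568], [-0.0347, -0.0274]]
])

# amin is the minimum counterpart of amax
print(torch.amin(x, dim=(1, 2)))  # tensor([-0.3000, -0.1821, -0.0642])

# keepdim=True keeps the reduced dimensions as size 1
print(torch.amax(x, dim=(1, 2), keepdim=True).shape)  # torch.Size([3, 1, 1])
```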

    Original Answer

    As of today (April 11, 2020), there is no way to do .min() or .max() over multiple dimensions in PyTorch. There is an open issue about it that you can follow and see if it ever gets implemented. A workaround in your case would be:

    import torch
    
    x = torch.tensor([
        [[-0.3000, -0.2926],[-0.2705, -0.2632]],
        [[-0.1821, -0.1747],[-0.1526, -0.1453]],
        [[-0.0642, -0.0568],[-0.0347, -0.0274]]
    ])
    
    print(x.view(x.size(0), -1).max(dim=-1))
    
    # output:
    # >>> torch.return_types.max(
    # >>> values=tensor([-0.2632, -0.1453, -0.0274]),
    # >>> indices=tensor([3, 3, 3]))
    

    So, if you need only the values: x.view(x.size(0), -1).max(dim=-1).values.
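If you also need the positions of the maxima, the flattened indices from the workaround can be converted back into (row, col) coordinates of the last two dimensions with floor division and modulo. A sketch:

```python
import torch

x = torch.tensor([
    [[-0.3000, -0.2926], [-0.2705, -0.2632]],
    [[-0.1821, -0.1747], [-0.1526, -0.1453]],
    [[-0.0642, -0.0568], [-0.0347, -0.0274]]
])

# Flatten dims 1 and 2, then take the max over the flattened axis
values, flat_idx = x.view(x.size(0), -1).max(dim=-1)

# Recover (row, col) within each 2x2 slice from the flat index
rows = torch.div(flat_idx, x.size(2), rounding_mode='floor')
cols = flat_idx % x.size(2)
print(rows, cols)  # tensor([1, 1, 1]) tensor([1, 1, 1])
```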

    If x is not a contiguous tensor, then .view() will fail. In this case, you should use .reshape() instead.
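For example, a permuted tensor is typically non-contiguous, so .view() raises a RuntimeError while .reshape() (which copies the data if needed) still works. A sketch:

```python
import torch

x = torch.randn(3, 2, 2)
y = x.permute(1, 0, 2)  # a non-contiguous view of x
assert not y.is_contiguous()

try:
    y.view(y.size(0), -1)  # fails: view requires compatible strides
except RuntimeError as e:
    print("view failed:", e)

# reshape falls back to a copy when a view is impossible
print(y.reshape(y.size(0), -1).max(dim=-1).values)
```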


    Update August 26, 2020

    This feature is being implemented in PR#43092 and the functions will be called amin and amax. They will return only the values. This is probably being merged soon, so you might be able to access these functions on the nightly build by the time you're reading this :) Have fun.