I am using PyTorch (1.8). Is there a clever way to take an element-wise max over all data points with the same output index?
Let's say I have a data tensor of size (N, M) and an index tensor of size (N,) containing indices in [0, K). I want to bin the data tensor into a tensor of size (K, M) according to the index values, but if two or more data points are binned into the same slot, I want to keep an element-wise max.
I've seen the naive approach below, but it doesn't give the element-wise max; it just stores whatever happens to be binned last.
import torch

N, M, K = 100, 4, 10  # example sizes
data = torch.randn((N, M))
index = torch.randint(K, (N,))
output = torch.zeros((K, M))
output[index] = data  # duplicate indices overwrite; no max is taken
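To make the failure concrete, here is a runnable illustration using the small example from the edit below. With duplicate indices, advanced-indexing assignment keeps only one of the competing writes (which one is not guaranteed, especially on CUDA), rather than their element-wise max:

```python
import torch

data = torch.tensor([[10., 1.], [9., 2.], [8., 3.], [7., 4.], [6., 5.]])
index = torch.tensor([2, 1, 0, 1, 2])
K = 3

output = torch.zeros((K, 2))
output[index] = data  # duplicate indices: one write wins per slot

# Bin 1 receives rows 1 and 3; output[1] ends up as one of
# [9, 2] or [7, 4], not their element-wise max [9, 4].
print(output)
```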
At the moment I am implementing a custom cuda kernel to solve this issue, but would like to know if this can be solved with standard PyTorch.
Edit: Minimal example:
data = torch.tensor([[10,1],[9,2],[8,3],[7,4],[6,5]])
index = torch.tensor([2,1,0,1,2], dtype=torch.long)
# something happens
# expected output:
# [[8, 3], [9, 4], [10, 5]]
PyTorch doesn't seem to have a native implementation for this yet, but there is a library that does exactly this: PyTorch Scatter. What I was describing corresponds to scatter_max.
from torch_scatter import scatter_max

out, argmax = scatter_max(data, index, dim=0)
# out: [[8, 3], [9, 4], [10, 5]]
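For readers on newer PyTorch: since version 1.12 (so not in 1.8, which the question targets), this can also be done without external libraries via Tensor.scatter_reduce_ with reduce="amax". A sketch on the minimal example, noting that scatter_reduce_ requires the index tensor to have the same shape as the source, so the (N,) index must be expanded across the M columns:

```python
import torch

data = torch.tensor([[10., 1.], [9., 2.], [8., 3.], [7., 4.], [6., 5.]])
index = torch.tensor([2, 1, 0, 1, 2])
K = 3

# Broadcast the (N,) index to (N, M) to match the source shape.
out = torch.zeros((K, data.size(1)))
out.scatter_reduce_(0, index.unsqueeze(1).expand_as(data), data,
                    reduce="amax", include_self=False)
# out: [[8., 3.], [9., 4.], [10., 5.]]
```

With include_self=False the initial zeros do not participate in the max; bins that receive no data points simply keep their initial value.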