groupby aggregate mean in pytorch


I have a 2D tensor:

samples = torch.Tensor([
    [0.1, 0.1],    #-> group / class 1
    [0.2, 0.2],    #-> group / class 2
    [0.4, 0.4],    #-> group / class 2
    [0.0, 0.0]     #-> group / class 0
])

and a label for each sample corresponding to a class:

labels = torch.LongTensor([1, 2, 2, 0])

so len(samples) == len(labels). Now I want to calculate the mean for each class / label. Because there are 3 classes (0, 1 and 2), the final tensor should have shape [n_classes, samples.shape[1]]. The expected solution is:

result == torch.Tensor([
    [0.1, 0.1],
    [0.3, 0.3], # -> mean of [0.2, 0.2] and [0.4, 0.4]
    [0.0, 0.0]
])

Question: How can this be done in pure pytorch (i.e. no numpy so that I can autograd) and ideally without for loops?


Solution

  • All you need to do is form an m×n matrix (m = number of classes, n = number of samples) whose rows select the samples belonging to each class and weight them so that they average. Multiplying this matrix by the samples matrix then gives the per-class means.

    Given your labels, your matrix should be (each row is a class, each column a sample, and each entry the weight of that sample in that class's mean):

    [[0.0000, 0.0000, 0.0000, 1.0000],
     [1.0000, 0.0000, 0.0000, 0.0000],
     [0.0000, 0.5000, 0.5000, 0.0000]]
    

    Which you can form as follows:

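    # One row per class, one column per sample
    M = torch.zeros(labels.max()+1, len(samples))
    # Put a 1 at (class of sample i, i)
    M[labels, torch.arange(len(samples))] = 1
    # Divide each row by its sum, so each class's weights add up to 1
    M = torch.nn.functional.normalize(M, p=1, dim=1)
    torch.mm(M, samples)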
    

    Output:

    tensor([[0.0000, 0.0000],
            [0.1000, 0.1000],
            [0.3000, 0.3000]])
    

    Note that the output means are sorted in class order: row i holds the mean of class i.
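
For reuse, the steps above can be wrapped in a small function. This is only a sketch: the name `groupby_mean` is made up for illustration, and it assumes the labels are non-negative integers. The last lines check that gradients flow back to the samples, since autograd was the reason for staying in pure PyTorch.

```python
import torch

def groupby_mean(samples: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Return the per-class mean of the rows of `samples` (row i = class i)."""
    n_classes = int(labels.max()) + 1
    n_samples = samples.shape[0]
    # Selection matrix: M[c, i] = 1 if sample i belongs to class c
    M = torch.zeros(n_classes, n_samples, dtype=samples.dtype)
    M[labels, torch.arange(n_samples)] = 1
    # L1-normalize each row so it averages instead of summing
    M = torch.nn.functional.normalize(M, p=1, dim=1)
    return M @ samples

samples = torch.tensor(
    [[0.1, 0.1], [0.2, 0.2], [0.4, 0.4], [0.0, 0.0]], requires_grad=True
)
labels = torch.tensor([1, 2, 2, 0])

means = groupby_mean(samples, labels)
means.sum().backward()
# Each sample's gradient is 1 / (number of samples in its class)
print(samples.grad)
```

Two properties of this approach to keep in mind: M is a dense n_classes × n_samples matrix, so memory grows with both, and a class id that has no samples gets an all-zero row in the result rather than NaN, because `normalize` clamps the row norm to a small epsilon before dividing.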

    Why does M[labels, torch.arange(len(samples))] = 1 work?

    This is indexing with two index tensors: labels and the sample positions are broadcast against each other (here they already have the same length) and paired element by element. Essentially, we are generating a 2D index for every element in labels: the first specifies which of the m classes it belongs to, and the second simply specifies its index position (from 0 to n-1). Another way would be to explicitly generate all the 2D indices:

    twoD_indices = []
    for count, label in enumerate(labels):
      twoD_indices.append((label, count))
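
To check that the explicit (class, position) pairs describe the same entries as the one-line indexed assignment, the two can be compared directly. This is a minimal sketch; the names `pairs`, `M_indexed` and `M_explicit` are chosen here for illustration.

```python
import torch

labels = torch.tensor([1, 2, 2, 0])
n_classes, n_samples = int(labels.max()) + 1, len(labels)

# One-line form: the two index tensors are paired element by element
M_indexed = torch.zeros(n_classes, n_samples)
M_indexed[labels, torch.arange(n_samples)] = 1

# Explicit form: build every (class, position) pair, then set each entry
pairs = [(int(label), position) for position, label in enumerate(labels)]
M_explicit = torch.zeros(n_classes, n_samples)
for row, col in pairs:
    M_explicit[row, col] = 1

print(pairs)                               # [(1, 0), (2, 1), (2, 2), (0, 3)]
print(torch.equal(M_indexed, M_explicit))  # True
```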