Masked aggregations in pytorch


Given data and mask tensors, is there a PyTorch way to obtain masked aggregations of the data (mean, max, min, etc.)?

import torch

x = torch.tensor([
    [1, 2, -1, -1],
    [10, 20, 30, -1]
])

mask = torch.tensor([
    [True, True, False, False],
    [True, True, True, False]
])

To compute a masked mean I can do the following, but is there a PyTorch built-in or a commonly used package that does this?

n_mask = torch.sum(mask, dim=1)               # number of unmasked entries per row
x_mean = torch.sum(x * mask, dim=1) / n_mask  # masked entries are multiplied by False (0)

print(x_mean)
# tensor([ 1.5000, 20.0000])

Solution

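  • If you would rather not use torch.masked because it is still in the prototype stage, you can use scatter_reduce, which supports the sum, prod, mean, amax and amin reductions. Select the unmasked elements, then scatter them into an output tensor indexed by their row.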

    x = torch.tensor([
        [1, 2, -1, -1],
        [10, 20, 30, -1]
    ]).float() # note you'll need to cast to float for this to work
    
    mask = torch.tensor([
        [True, True, False, False],
        [True, True, True, False]
    ])
    
    rows, cols = mask.nonzero().T
    
    for reduction in ['mean', 'sum', 'prod', 'amax', 'amin']:
        output = torch.zeros(x.shape[0], device=x.device, dtype=x.dtype)
        output = output.scatter_reduce(0, rows, x[rows, cols], reduce=reduction, include_self=False)
        print(f"{reduction}\t{output}")

    # printed output:
    # mean  tensor([ 1.5000, 20.0000])
    # sum   tensor([ 3., 60.])
    # prod  tensor([2.0000e+00, 6.0000e+03])
    # amax  tensor([ 2., 30.])
    # amin  tensor([ 1., 10.])
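
For comparison, the same five reductions can also be sketched with plain tensor ops by overwriting the masked-out entries with a value that is neutral for each reduction (0 for sum, 1 for prod, -inf for amax, +inf for amin). This is only a minimal sketch: masked_reduce is a hypothetical helper name, not a PyTorch function, and the handling of fully masked rows (mean is 0, amax is -inf, amin is +inf) is an assumption rather than anything the answer above specifies.

```python
import torch


def masked_reduce(x, mask, reduction="mean", dim=1):
    """Hypothetical helper: reduce x along dim, ignoring entries where mask is False."""
    x = x.float()  # -inf/+inf fill values require a floating-point dtype
    if reduction == "sum":
        return x.masked_fill(~mask, 0.0).sum(dim=dim)
    if reduction == "mean":
        # Assumption: clamp the count so a fully masked row yields 0 instead of NaN.
        count = mask.sum(dim=dim).clamp(min=1)
        return x.masked_fill(~mask, 0.0).sum(dim=dim) / count
    if reduction == "prod":
        return x.masked_fill(~mask, 1.0).prod(dim=dim)
    if reduction == "amax":
        return x.masked_fill(~mask, float("-inf")).amax(dim=dim)
    if reduction == "amin":
        return x.masked_fill(~mask, float("inf")).amin(dim=dim)
    raise ValueError(f"unsupported reduction: {reduction}")


x = torch.tensor([
    [1, 2, -1, -1],
    [10, 20, 30, -1]
])

mask = torch.tensor([
    [True, True, False, False],
    [True, True, True, False]
])

for reduction in ["mean", "sum", "prod", "amax", "amin"]:
    print(f"{reduction}\t{masked_reduce(x, mask, reduction)}")
```

Unlike the scatter_reduce version, this keeps the input's shape until the final reduction, so it extends to other dims without computing indices, at the cost of materializing a filled copy of x.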