Given data and mask tensors, is there a PyTorch way to obtain masked aggregations of the data (mean, max, min, etc.)?
import torch

x = torch.tensor([
[1, 2, -1, -1],
[10, 20, 30, -1]
])
mask = torch.tensor([
[True, True, False, False],
[True, True, True, False]
])
To compute a masked mean I can do the following, but is there a PyTorch built-in or a commonly used package that does this?
n_mask = torch.sum(mask, axis=1)
x_mean = torch.sum(x * mask, axis=1) / n_mask
print(x_mean)
> tensor([ 1.50, 20.00])
If you don't want to use torch.masked because it is still in the prototype stage, you can use scatter_reduce to aggregate with the sum, prod, mean, amax, or amin reductions.
import torch

x = torch.tensor([
[1, 2, -1, -1],
[10, 20, 30, -1]
]).float() # note you'll need to cast to float for this to work
mask = torch.tensor([
[True, True, False, False],
[True, True, True, False]
])
# row and column indices of the unmasked (True) entries
rows, cols = mask.nonzero().T
for reduction in ['mean', 'sum', 'prod', 'amax', 'amin']:
    # one output slot per row; include_self=False ignores these initial zeros
    output = torch.zeros(x.shape[0], device=x.device, dtype=x.dtype)
    output = output.scatter_reduce(0, rows, x[rows, cols], reduce=reduction, include_self=False)
    print(f"{reduction}\t{output}")
# printed output:
# mean tensor([ 1.5000, 20.0000])
# sum tensor([ 3., 60.])
# prod tensor([2.0000e+00, 6.0000e+03])
# amax tensor([ 2., 30.])
# amin tensor([ 1., 10.])
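For comparison, here is a minimal sketch of an alternative that avoids both torch.masked and scatter_reduce by filling masked positions with a neutral value before an ordinary reduction. The helper name masked_reduce is hypothetical (it is not a PyTorch function), and the handling of fully masked rows is an assumption of this sketch: the mean is returned as 0 and amax/amin as -inf/+inf.

```python
import torch

def masked_reduce(x, mask, reduction, dim=1):
    """Hypothetical helper: reduce x along dim, ignoring entries where mask is False."""
    x = x.float()  # floating point is needed for the mean and for the +/-inf fill values
    if reduction == "sum":
        return x.masked_fill(~mask, 0.0).sum(dim=dim)
    if reduction == "mean":
        # clamp avoids division by zero when a row has no valid entries
        count = mask.sum(dim=dim).clamp(min=1)
        return x.masked_fill(~mask, 0.0).sum(dim=dim) / count
    if reduction == "amax":
        # -inf can never win a max, so masked entries are ignored
        return x.masked_fill(~mask, float("-inf")).amax(dim=dim)
    if reduction == "amin":
        # +inf can never win a min, so masked entries are ignored
        return x.masked_fill(~mask, float("inf")).amin(dim=dim)
    raise ValueError(f"unsupported reduction: {reduction}")

x = torch.tensor([
    [1, 2, -1, -1],
    [10, 20, 30, -1]
])
mask = torch.tensor([
    [True, True, False, False],
    [True, True, True, False]
])

for reduction in ["mean", "sum", "amax", "amin"]:
    print(f"{reduction}\t{masked_reduce(x, mask, reduction)}")
```

On the example tensors this should give the same values as the scatter_reduce version for mean, sum, amax, and amin. The trade-off is that each reduction needs its own neutral fill value, whereas scatter_reduce handles all of them through a single call.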