python, pytorch, tensor

Applying torch.all to every dimension except the first


I am calculating my accuracy like this:

(outputs.round() == targets).all(dim=2).all(dim=1).sum().item() / outputs.shape[0]

where outputs and targets both have shape NxAxB. N is the batch size. The remaining AxB part of each sample is the prediction (or ground-truth) value, and I want to check whether prediction and truth are identical for each sample.
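A minimal sketch of the one-liner above on made-up NxAxB data (N=4, A=2, B=3; the values are invented for illustration), assuming outputs are float scores and targets are float 0/1 labels. Exactly one element of sample 0 is wrong after rounding, so 3 of 4 samples count as correct:

```python
import torch

# Hypothetical data: N=4 samples, each an A x B = 2 x 3 block of 0/1 labels
targets = torch.zeros(4, 2, 3)
outputs = targets + 0.2          # every score rounds to 0, i.e. matches
outputs[0, 1, 2] = 0.9           # rounds to 1, so sample 0 has one mismatch

eq = outputs.round() == targets  # boolean tensor, shape (4, 2, 3)
# .all(dim=2) -> (4, 2), then .all(dim=1) -> (4,): one flag per sample
accuracy = eq.all(dim=2).all(dim=1).sum().item() / outputs.shape[0]
print(accuracy)  # 0.75
```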

Currently I am using .all(dim=2).all(dim=1). The problem is that with a different model the shapes differ: outputs and targets are NxA, so my current approach fails because there is no dim=2.

(outputs.round() == targets).all(dim=1).sum().item() / outputs.shape[0]

This would work, but again only for the second model.
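A small sketch of the mismatch, assuming hypothetical NxA data (N=4, A=5): on a 2-D boolean tensor, PyTorch rejects `.all(dim=2)` with a "dimension out of range" error, while the single `.all(dim=1)` works. Both exception types are caught here as a precaution:

```python
import torch

# Hypothetical NxA data: N=4 samples, A=5 values each
targets = torch.zeros(4, 5)
outputs = targets + 0.2   # rounds to 0, matching the targets
outputs[2, 0] = 0.9       # sample 2 mismatches after rounding

eq = outputs.round() == targets  # shape (4, 5), no dim=2

try:
    eq.all(dim=2).all(dim=1)     # the 3-D version of the expression
    chained_failed = False
except (IndexError, RuntimeError):
    chained_failed = True        # dim=2 does not exist on a 2-D tensor

# The 2-D-only version reduces over dim 1 and works here
accuracy = eq.all(dim=1).sum().item() / outputs.shape[0]
print(chained_failed, accuracy)  # True 0.75
```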

Ideally I want to apply .all over every dimension except the first (batch) dimension. How would I do this?
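One straightforward way to express "all over every dimension except dim 0" is to reduce the trailing dimensions one at a time in a loop. This is a sketch of an alternative, not the approach from the answer below; the helper name `all_but_first` is hypothetical:

```python
import torch

def all_but_first(mask: torch.Tensor) -> torch.Tensor:
    """Reduce a boolean tensor with `all` over every dim except dim 0."""
    # Walk from the last dim down to dim 1. Removing the highest dim first
    # means the lower-numbered dims keep their indices on each iteration.
    for d in range(mask.dim() - 1, 0, -1):
        mask = mask.all(dim=d)
    return mask

# Works regardless of how many trailing dims there are
mask3 = torch.ones(4, 2, 3, dtype=torch.bool)
mask3[1, 0, 2] = False
mask2 = torch.ones(4, 5, dtype=torch.bool)
mask2[3, 4] = False

print(all_but_first(mask3))  # tensor([ True, False,  True,  True])
print(all_but_first(mask2))  # tensor([ True,  True,  True, False])
```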


Solution

  • To generalize to any number of dimensions, you could flatten the boolean tensor from dim=1 onward using torch.flatten, then apply all over the single remaining non-batch dimension and take the mean:

    >>> (outputs.round() == targets).flatten(1).all(1).float().mean()
    

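    Note: flatten(1) is shorthand for torch.flatten(input, start_dim=1). With the default end_dim=-1, it merges every dimension from dim=1 through the last one into a single dimension, leaving the batch dimension untouched.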
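The answer's expression can be wrapped in a helper and checked on both shapes from the question. This is a sketch with a hypothetical function name and invented data; `.item()` is added so it returns a plain Python float like the original `sum().item() / shape[0]` version. Taking the mean of the per-sample flags is equivalent to counting the correct samples and dividing by N:

```python
import torch

def exact_match_accuracy(outputs: torch.Tensor, targets: torch.Tensor) -> float:
    """Fraction of samples whose rounded outputs match targets exactly."""
    correct = outputs.round() == targets     # shape (N, ...)
    # Merge dims 1..-1 into one: shape (N, A*B*...), then reduce it
    per_sample = correct.flatten(1).all(1)   # shape (N,)
    return per_sample.float().mean().item()

# NxAxB case (hypothetical data): sample 0 has one wrong element
t3 = torch.zeros(4, 2, 3)
o3 = t3 + 0.2
o3[0, 1, 2] = 0.9

# NxA case (hypothetical data): sample 2 has one wrong element
t2 = torch.zeros(4, 5)
o2 = t2 + 0.2
o2[2, 0] = 0.9

print(exact_match_accuracy(o3, t3))  # 0.75
print(exact_match_accuracy(o2, t2))  # 0.75
```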