python, pytorch, histogram, tensor

How to get a Histogram of PyTorch tensors in batches?


Is there a way to get the histograms of torch tensors in batches?

For example, x is a tensor of shape (64, 224, 224):

# x will have shape (64, 256): one 256-bin histogram per batch element
x = batch_histogram(x, bins=256, min=0, max=255)
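For reference, the non-batched way to get that (64, 256) result is one torch.histc call per batch element. This is a minimal sketch: `loop_histogram` is a hypothetical helper name, and random uint8 data stands in for real images.

```python
import torch

def loop_histogram(x, bins=256, min=0, max=255):
    # Hypothetical helper: one torch.histc call per batch element.
    # torch.histc treats its input as flat, so each (224, 224) slice gives one histogram.
    # histc needs a floating-point input, hence .float().
    return torch.stack([torch.histc(sample.float(), bins=bins, min=min, max=max) for sample in x])

x = torch.randint(0, 256, (64, 224, 224), dtype=torch.uint8)
h = loop_histogram(x)
print(h.shape)  # torch.Size([64, 256])
```

This gives the right answer but runs a Python loop with one torch.histc call per sample, which is what a batched version avoids.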

Solution

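  • As suggested in PyTorch issue #99719, you can do this with torch.Tensor.scatter_add_. It is more memory-efficient than an approach based on torch.nn.functional.one_hot, which materializes an extra dimension of size num_classes before summing (for the (64, 224, 224) example with 256 classes, that is about 822 million int64 entries, roughly 6.6 GB).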

    Similar to @user118967's answer:

    # https://github.com/pytorch/pytorch/issues/99719#issuecomment-1664135524
    import torch

    def batch_histogram(data_tensor, num_classes=-1):
        """
        Computes histograms, even if in batches (as opposed to torch.histc and torch.histogram).
        Arguments:
            data_tensor: a D1 x ... x D_n integer tensor (e.g. torch.long or torch.uint8)
                         whose values lie in [0, num_classes)
            num_classes (optional): the number of classes present in data.
                                    If not provided, tensor.max() + 1 is used (an error is thrown if tensor is empty).
        Returns:
            A D1 x ... x D_{n-1} x num_classes 'result' torch.LongTensor,
            containing histograms of the last dimension D_n of tensor,
            that is, result[d_1,...,d_{n-1}, c] = number of times c appears in tensor[d_1,...,d_{n-1}].
        """
        # Work in int64: scatter_add_ needs int64 indices, and uint8 arithmetic
        # (e.g. 255 + 1) or uint8 counts would silently wrap around.
        index = data_tensor.long()
        nc = int(index.max()) + 1 if num_classes <= 0 else num_classes
        hist = torch.zeros((*index.shape[:-1], nc), dtype=torch.long, device=index.device)
        # A single 1 broadcast to the input's shape, without allocating a full tensor of ones.
        ones = torch.tensor(1, dtype=hist.dtype, device=hist.device).expand(index.shape)
        # hist[..., index[..., j]] += 1 for every position j along the last dimension.
        hist.scatter_add_(-1, index, ones)
        return hist
    

    with test cases in a Google Colab notebook.
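The answer's `batch_histogram` only counts along the last dimension, so a (64, 224, 224) input yields a (64, 224, num_classes) result. To get the (64, 256) shape from the question, the non-batch dimensions have to be flattened first, for example with `x.flatten(1)`. The question also asks for `bins`/`min`/`max` arguments rather than integer class labels. Below is a minimal sketch that combines both, assuming torch.histc's binning convention: `bins` equal-width bins over [min, max], values equal to `max` counted in the last bin, and values outside the range ignored. `batch_histc` is a hypothetical name, not a PyTorch API, and `max > min` is assumed.

```python
import torch

def batch_histc(x, bins=256, min=0.0, max=255.0):
    """Hypothetical batched analogue of torch.histc: one histogram per x[i].

    Sketch only; assumes max > min and mimics torch.histc's binning.
    """
    # (B, N): pool all non-batch dimensions into one row per sample.
    flat = x.reshape(x.shape[0], -1).float()
    # Bin index floor((v - min) / (max - min) * bins); the clamp puts v == max
    # into the last bin and keeps out-of-range indices valid for scatter_add_.
    idx = ((flat - min) / (max - min) * bins).floor().long().clamp_(0, bins - 1)
    # Out-of-range (and NaN) values get weight 0, so they are not counted.
    weights = ((flat >= min) & (flat <= max)).long()
    hist = torch.zeros(x.shape[0], bins, dtype=torch.long, device=x.device)
    hist.scatter_add_(1, idx, weights)
    return hist

x = torch.randint(0, 256, (64, 224, 224), dtype=torch.uint8)
h = batch_histc(x, bins=256, min=0, max=255)
print(h.shape)  # torch.Size([64, 256])
```

Using 0/1 weights instead of boolean-masking out-of-range values keeps every row the same length, so the whole batch is still handled by a single scatter_add_ call.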