python, pytorch

How can I convert a PyTorch image tensor of shape (batch, channels, height, width) to (batch, height, width)?


I want to turn a colored image into a grayscale image by removing the color channels. Is this even possible?


Solution

  • You can use the Grayscale transform

    https://pytorch.org/vision/stable/generated/torchvision.transforms.Grayscale.html

    from torchvision.transforms import Grayscale
    
    grayscale_batch = Grayscale()(color_batch)
    

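    This results in a tensor of shape (batch_size, 1, H, W).

    To remove the channel dimension, apply torch.squeeze() with dim=1, e.g. grayscale_batch.squeeze(1). Passing the dimension explicitly matters: a bare squeeze() removes every dimension of size 1, so it would also drop the batch dimension when batch_size is 1.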