
How to use torcheval.metrics.FrechetInceptionDistance in PyTorch for the MNIST dataset?


I defined a GAN model and I want to evaluate it using the FID score. My images are single-channel (the MNIST dataset), but this metric expects 3-channel images. How can I solve this problem?


Solution

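  • Repeat the single grayscale channel three times before evaluating, so each batch has shape (N, 3, H, W), and scale the pixel values to [0, 1]. FID compares the distribution of real images with that of generated images, so both sets need the same conversion.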

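    import torch
    import torchvision
    from torcheval.metrics import FrechetInceptionDistance
    
    # Load the MNIST dataset; .data is a uint8 tensor of shape (N, 28, 28)
    mnist_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True)
    
    # The metric expects float images of shape (N, 3, H, W) with values in [0, 1]:
    # scale to [0, 1], add a channel dim, then repeat the gray channel 3 times
    real_images = mnist_dataset.data[:1000].float().div(255).unsqueeze(1).repeat(1, 3, 1, 1)
    
    # Generated images need the same conversion.
    # `generator` and `latent_dim` are placeholders for your own GAN; if its
    # output is in [-1, 1] (tanh), rescale with (x + 1) / 2 first.
    with torch.no_grad():
        fake = generator(torch.randn(1000, latent_dim))  # (N, 1, 28, 28)
    fake_images = fake.clamp(0, 1).repeat(1, 3, 1, 1)
    
    # FID compares two distributions, so feed both real and generated images.
    # Update in small batches: Inception upsamples every image to 299x299.
    fid = FrechetInceptionDistance()
    for i in range(0, 1000, 100):
        fid.update(real_images[i:i + 100], is_real=True)
        fid.update(fake_images[i:i + 100], is_real=False)
    
    # Calculate the FID score
    print('FID score:', fid.compute().item())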
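If you evaluate often, or switch between raw MNIST tensors and generator outputs, the conversion can be wrapped in a small helper. The sketch below is one way to do it, not part of torcheval: `to_fid_images` and `update_in_batches` are hypothetical names. It assumes torcheval's `update(images, is_real)` signature and a generator that outputs single-channel images either in [0, 1] or in [-1, 1] (tanh).

```python
import torch


def to_fid_images(images: torch.Tensor, value_range: str = "uint8") -> torch.Tensor:
    """Hypothetical helper: convert grayscale images to the (N, 3, H, W)
    float layout in [0, 1] that FrechetInceptionDistance.update expects.

    value_range: "uint8" for raw MNIST pixels in [0, 255],
                 "tanh"  for generator outputs in [-1, 1],
                 "unit"  for tensors already in [0, 1].
    """
    if images.dim() == 3:  # (N, H, W) -> (N, 1, H, W)
        images = images.unsqueeze(1)
    if images.dim() != 4 or images.shape[1] != 1:
        raise ValueError(f"expected (N, H, W) or (N, 1, H, W), got {tuple(images.shape)}")

    images = images.float()
    if value_range == "uint8":
        images = images / 255.0
    elif value_range == "tanh":
        images = (images + 1.0) / 2.0
    elif value_range != "unit":
        raise ValueError(f"unknown value_range: {value_range!r}")

    # Clamp small numerical overshoot, then copy the single channel 3 times
    return images.clamp(0.0, 1.0).repeat(1, 3, 1, 1)


def update_in_batches(metric, images: torch.Tensor, is_real: bool, batch_size: int = 100) -> None:
    """Hypothetical helper: feed images to metric.update in chunks so the
    Inception forward pass does not run out of memory on a large set."""
    for start in range(0, images.shape[0], batch_size):
        metric.update(images[start:start + batch_size], is_real=is_real)
```

With these helpers, the evaluation reduces to creating `FrechetInceptionDistance()`, calling `update_in_batches(fid, to_fid_images(mnist_dataset.data), is_real=True)` for the real images and `update_in_batches(fid, to_fid_images(generated, value_range="tanh"), is_real=False)` for your generated batch, then `fid.compute()`.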