python pytorch astropy fits

Loading FITS images with PyTorch


I'm trying to create a CNN using PyTorch but my images need importing from the FITS format rather than conventional .png or .jpeg etc.

Is there an easy way to do this with torch.utils.data.DataLoader, or is there a place in the source code where I could add a clause to handle FITS files as they are loaded?

I have looked in the documentation and the most relevant thing I've found is the ToPILImage transformer which converts a tensor or ndarray into a PIL Image.

Currently I'm using an image loading routine as follows:

import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision

batch_size = 4

transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
])

trainset = dset.ImageFolder(root="Documents/Image_data", transform=transform)
train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)

Astropy: http://www.astropy.org/

PyTorch: https://pytorch.org/

torch.utils: https://pytorch.org/docs/master/data.html

UPDATE: Perhaps using torchvision.datasets.DatasetFolder instead of ImageFolder, and inserting my own FITS handler, would work?

When trying to use this class I get the following error:

AttributeError: module 'torchvision.datasets' has no attribute 'DatasetFolder'

Is DatasetFolder actually supported by torchvision at this point in time?


Solution

  • From reading some combination of the docs and the code, I don't think you necessarily want to be using ImageFolder since it doesn't know anything about FITS.

    Instead you should try using the more generic DatasetFolder class (which is in fact the parent class of ImageFolder). You would pass it a list of the extensions it should handle (e.g. ['.fits']) and a "loader" function that takes a FITS filename and, it seems, should return a PIL.Image.

    You could even make your own subclass following the example of ImageFolder. E.g.

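    from astropy.io import fits
    from PIL import Image
    from torchvision.datasets import DatasetFolder

    class FitsFolder(DatasetFolder):
        """A DatasetFolder that reads .fits files laid out as root/class_name/file.fits."""

        EXTENSIONS = ['.fits']

        def __init__(self, root, transform=None, target_transform=None,
                     loader=None):
            # Fall back to the built-in FITS loader unless the caller supplies one.
            if loader is None:
                loader = self.__fits_loader

            super(FitsFolder, self).__init__(root, loader, self.EXTENSIONS,
                                             transform=transform,
                                             target_transform=target_transform)

        @staticmethod
        def __fits_loader(filename):
            # Read the first image array in the file and wrap it as a PIL image.
            data = fits.getdata(filename)
            return Image.fromarray(data)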
    

    The exact details of __fits_loader will depend on your FITS files. This basic example just uses the high-level fits.getdata() function, which returns the first image array in the FITS file. Some FITS files have many extensions holding several images, or hold tables and so on, so that part is up to you.