
PyTorch custom transformation with additional argument in __call__


I have a custom dataset that I want to train a neural network on. A sample from the dataset might be [1,2,3,4], with a corresponding time axis of, for example, [0, 0.2, 0.4, 0.6].

This time axis is different for every sample in the dataset and is needed for certain transformations.

I only want to train the neural network on the actual data ([1,2,3,4]). Therefore, my custom transformation needs an additional time list that is used only by that transformation. However, I have not found any example of how to accomplish this.

I have read https://pytorch.org/tutorials/beginner/data_loading_tutorial.html, but in their transformations __call__ always takes only the "sample" as input, like this:

def __call__(self, sample):

I could pass the time axis as part of the sample, but then wouldn't the neural network also train on the time axis? I do not want that.

How can I accomplish passing the time axis to the call function for a custom PyTorch transformation without training on the time data?
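To make the concern concrete: a model only trains on the values the training loop actually passes to it, so a tutorial-style sample dict can carry the time axis for the transform without the model ever seeing it. The sketch below assumes NumPy arrays and uses a hypothetical `ResampleUniform` transform (not part of PyTorch) that interpolates the data onto an evenly spaced time grid, just as an example of a step that needs the time axis.

```python
import numpy as np


class ResampleUniform:
    """Hypothetical time-dependent transform: interpolate the data onto
    n_points evenly spaced times spanning the original time axis."""

    def __init__(self, n_points):
        self.n_points = n_points

    def __call__(self, sample):
        # sample is a dict, in the style of the PyTorch data loading tutorial
        data, time = sample["data"], sample["time"]
        new_time = np.linspace(time[0], time[-1], self.n_points)
        new_data = np.interp(new_time, time, data)
        return {"data": new_data, "time": new_time}


sample = {"data": np.array([1.0, 2.0, 3.0, 4.0]),
          "time": np.array([0.0, 0.2, 0.4, 0.6])}
out = ResampleUniform(7)(sample)

# Only out["data"] would be fed to the model; out["time"] is just metadata
# that the transform needed and the training loop ignores.
model_input = out["data"]
```

Whether the time axis is "trained on" is decided by what you hand to the model (for example `model(batch["data"])`), not by what the sample happens to contain.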


Solution

  • You control how the transformations are called in the dataset, so if you write your own dataset you can transform your sample with whatever extra data you want directly in __getitem__.

    If you want to follow the model of separating your transforms from your dataset (which is probably good practice), you can write your dataset to expect transforms that take both the sample and the time axis. Since torchvision's built-in transforms don't expect the time axis, you can write a wrapper that applies them only to the sample. One caveat: if we want to keep using torchvision's Compose transform, each transform must take a single argument. We could write a custom compose fairly easily, but IMO it's a bit easier to just pack all the arguments into a single tuple argument.

    An incomplete example (you need to fill in the ... sections) might look something like this:

    from torch.utils.data import Dataset, DataLoader
    from torchvision import transforms
    
    
    class TransformWrapper:
        """ Wraps a transform that operates on only the sample """
        def __init__(self, t):
            self.t = t
    
        def __call__(self, data):
            """
                data: tuple containing both sample and time_axis
                returns a tuple containing the transformed sample and original time_axis
            """
            sample, time_axis = data
            return self.t(sample), time_axis
    
    
    class CustomTransform:
        """ a custom transform dependent on time axis """
        def __init__(self, ...):
            ...
    
        def __call__(self, data):
            sample, time_axis = data
            new_sample = ... # some function of sample and time_axis
            return new_sample, time_axis
    
    
    class MyDataset(Dataset):
        def __init__(self, root, transform=None):
            """
                root: ...
                transform: A transform that operates on a tuple containing sample and time_axis
            """
            ... # init dataset
            self.transform = transform
    
        def __getitem__(self, index):
            sample, time_axis = self.get_data(index)
            if self.transform is not None:
                # transform operates on a tuple containing both sample and time_axis
                sample, time_axis = self.transform((sample, time_axis))
    
            # dataset doesn't need to return time_axis
            return sample
    
        def get_data(self, index):
            ... # load and return sample and time_axis at index
    
        def __len__(self):
            ... # returns length of data
    
    
    # example of how to compose wrapped transforms
    dataset = MyDataset(
        root=...,
        transform=transforms.Compose([
            TransformWrapper(transforms.Resize(256)),  # torchvision has Resize; Rescale is the tutorial's own class
            TransformWrapper(transforms.RandomCrop(224)),
            CustomTransform(...),
            TransformWrapper(transforms.ToTensor())
        ]))
    
    loader = DataLoader(dataset, ...)
    
    # train loop
    for samples in loader:
        ...
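The first option in the answer, doing the time-dependent work directly in `__getitem__`, can be sketched without any transform classes at all. This sketch assumes the raw samples and their time axes are held in two in-memory lists, and it reuses a hypothetical resampling step (interpolation onto an evenly spaced grid) purely as an example. In real code the class would usually subclass `torch.utils.data.Dataset`; it is left as a plain class here only to keep the sketch dependency-free, since a map-style dataset just needs `__getitem__` and `__len__`.

```python
import numpy as np


class TimeSeriesDataset:
    """Map-style dataset that uses each sample's time axis inside
    __getitem__ but returns only the data."""

    def __init__(self, samples, time_axes, n_points):
        assert len(samples) == len(time_axes)
        self.samples = samples
        self.time_axes = time_axes
        self.n_points = n_points  # hypothetical fixed output length

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        data = np.asarray(self.samples[index], dtype=np.float32)
        time = np.asarray(self.time_axes[index], dtype=np.float32)
        # Time-dependent step: resample onto an evenly spaced grid
        grid = np.linspace(time[0], time[-1], self.n_points)
        resampled = np.interp(grid, time, data).astype(np.float32)
        # Only the data is returned, so the model never sees the time axis
        return resampled


dataset = TimeSeriesDataset(
    samples=[[1, 2, 3, 4], [10, 20, 30]],
    time_axes=[[0, 0.2, 0.4, 0.6], [0, 0.5, 2.0]],
    n_points=5,
)
x = dataset[1]
```

Passing this to `DataLoader(dataset, batch_size=...)` should work with the default collation, which stacks the returned arrays into a tensor; that requires every sample to have the same length, which the fixed `n_points` guarantees here.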
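The answer notes that a custom compose would avoid packing the arguments into a tuple. A minimal sketch of that alternative, assuming every step in the pipeline takes and returns `(sample, time_axis)` as two separate values; `PairCompose`, `SampleOnly`, `UniformResample` and `standardize` are hypothetical names for illustration, not torchvision classes.

```python
import numpy as np


class PairCompose:
    """Like torchvision's Compose, but threads two values through each step."""

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, sample, time_axis):
        for t in self.transforms:
            sample, time_axis = t(sample, time_axis)
        return sample, time_axis


class SampleOnly:
    """Adapts a one-argument transform: applies it to the sample and
    passes the time axis through unchanged."""

    def __init__(self, t):
        self.t = t

    def __call__(self, sample, time_axis):
        return self.t(sample), time_axis


class UniformResample:
    """Hypothetical time-dependent step: interpolate onto an even time grid."""

    def __init__(self, n_points):
        self.n_points = n_points

    def __call__(self, sample, time_axis):
        grid = np.linspace(time_axis[0], time_axis[-1], self.n_points)
        return np.interp(grid, time_axis, sample), grid


def standardize(x):
    # Plain one-argument function standing in for a sample-only transform
    return (x - x.mean()) / x.std()


pipeline = PairCompose([
    UniformResample(7),
    SampleOnly(standardize),
])
data, time = pipeline(np.array([1.0, 2.0, 3.0, 4.0]),
                      np.array([0.0, 0.2, 0.4, 0.6]))
```

Inside `__getitem__` you would then call `self.transform(sample, time_axis)` instead of `self.transform((sample, time_axis))`. The trade-off is that torchvision's `Compose` can no longer be used directly, which is why the tuple approach is a bit simpler.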