
PyTorch custom transformation with additional argument in __call__

I have a custom dataset that I want to train a neural network on. A sample from the dataset might be [1,2,3,4], and the corresponding time axis might then be, for example, [0, 0.2, 0.4, 0.6].

This time axis is different for every sample in the dataset and is needed for certain transformations.

I only want to train the neural network on the actual data ([1,2,3,4]). Therefore, my custom transformation needs an additional time list that is used only by that transformation. However, I have not found any example of how to accomplish this.

I have read … but in their transformations __call__ always takes only the "sample" as input, like this:

def __call__(self, sample):

I could pass the time axis as part of the sample, but wouldn't the neural network then also train on the time axis? I do not want that.

How can I pass the time axis to the __call__ method of a custom PyTorch transformation without training on the time data?


  • You control how the transformations are called in the dataset, so if you write your own dataset you can transform your sample with whatever extra data you want directly in __getitem__.

    If you want to follow the model of separating your transforms from your dataset (which is probably good practice), you can write your dataset to expect transforms that take both the sample and the time axis. Since torchvision's built-in transforms don't expect a time axis, you can write a wrapper that applies them only to the sample. One caveat: if we want to keep using torchvision's Compose transform, our transforms need to take a single argument. We could write a custom compose pretty easily, but IMO it's a bit easier to just pack all the arguments into a single tuple argument.

    An incomplete example (you need to fill in the ... sections) might look something like this:

    from torch.utils.data import Dataset, DataLoader
    from torchvision import transforms, utils
    class TransformWrapper:
        """ Wraps a transform that operates on only the sample """
        def __init__(self, t):
            self.t = t
        def __call__(self, data):
            """
            data: tuple containing both sample and time_axis
            returns a tuple containing the transformed sample and original time_axis
            """
            sample, time_axis = data
            return self.t(sample), time_axis
    class CustomTransform:
        """ a custom transform dependent on time axis """
        def __init__(self, *args, **kwargs):
            ...  # store whatever parameters the transform needs
        def __call__(self, data):
            sample, time_axis = data
            new_sample = ... # some function of sample and time_axis
            return new_sample, time_axis
    class MyDataset(Dataset):
        def __init__(self, root, transform=None):
            """
            root: ...
            transform: A transform that operates on a tuple containing sample and time_axis
            """
            ... # init dataset
            self.transform = transform
        def __getitem__(self, index):
            sample, time_axis = self.get_data(index)
            if self.transform is not None:
                # transform operates on a tuple containing both sample and time_axis
                sample, time_axis = self.transform((sample, time_axis))
            # dataset doesn't need to return time_axis
            return sample
        def get_data(self, index):
            ... # load and return sample and time_axis at index
        def __len__(self):
            ... # returns length of data
    # example of how to compose wrapped transforms
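    dataset = MyDataset(
        ...,  # root
        transform=transforms.Compose([
            TransformWrapper(...),  # wrap any sample-only (e.g. torchvision) transform
            CustomTransform(...),
        ]))
    loader = DataLoader(dataset, ...)
    # train loop
    for samples in loader:
        ...  # forward/backward pass on samples only; time_axis never reaches the model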
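For a version that actually runs, here is a minimal sketch that fills in the placeholders with assumptions. ResampleUniform is a hypothetical time-dependent transform that linearly interpolates each sample onto n evenly spaced points of its own time axis (assumed strictly increasing, with at least two points); double is a hypothetical stand-in for a sample-only transform; InMemoryDataset is a hypothetical dataset holding the question's example data in plain lists. To run with only the standard library, it drops the torch.utils.data.Dataset base class and uses a small Compose class that, like torchvision's, applies each transform in order to a single argument.

```python
# Dependency-free sketch of the tuple-packing pattern.
# Assumptions: Compose mimics torchvision.transforms.Compose; ResampleUniform,
# double and InMemoryDataset are hypothetical names for illustration only.


class Compose:
    """Stand-in for torchvision.transforms.Compose: applies each transform
    in order to a single argument (here a (sample, time_axis) tuple)."""

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, data):
        for t in self.transforms:
            data = t(data)
        return data


class TransformWrapper:
    """Wraps a transform that operates on only the sample."""

    def __init__(self, t):
        self.t = t

    def __call__(self, data):
        sample, time_axis = data
        return self.t(sample), time_axis


class ResampleUniform:
    """Hypothetical time-dependent transform: linearly interpolates the sample
    onto n evenly spaced points spanning its own (strictly increasing) time axis."""

    def __init__(self, n):
        self.n = n  # number of output points, assumed >= 2

    def __call__(self, data):
        sample, time_axis = data
        t0, t1 = time_axis[0], time_axis[-1]
        step = (t1 - t0) / (self.n - 1)
        new_times = [t0 + i * step for i in range(self.n)]
        new_sample = []
        j = 0
        for t in new_times:
            # Move j forward until t lies in [time_axis[j], time_axis[j + 1]].
            while j < len(time_axis) - 2 and time_axis[j + 1] < t:
                j += 1
            w = (t - time_axis[j]) / (time_axis[j + 1] - time_axis[j])
            new_sample.append(sample[j] + w * (sample[j + 1] - sample[j]))
        return new_sample, new_times


def double(sample):
    """Hypothetical stand-in for a sample-only transform."""
    return [2 * v for v in sample]


class InMemoryDataset:
    """Map-style dataset that uses the time axis inside its transform but
    returns only the sample, so a model never sees the time axis."""

    def __init__(self, samples, time_axes, transform=None):
        self.samples = samples
        self.time_axes = time_axes
        self.transform = transform

    def __getitem__(self, index):
        sample, time_axis = self.samples[index], self.time_axes[index]
        if self.transform is not None:
            sample, time_axis = self.transform((sample, time_axis))
        return sample  # time_axis is intentionally not returned

    def __len__(self):
        return len(self.samples)


dataset = InMemoryDataset(
    samples=[[1, 2, 3, 4], [0, 4, 8]],
    time_axes=[[0, 0.2, 0.4, 0.6], [0, 1, 3]],
    transform=Compose([ResampleUniform(n=4), TransformWrapper(double)]),
)
print(dataset[0])  # approximately [2, 4, 6, 8], up to floating-point rounding
print(dataset[1])  # [0.0, 8.0, 12.0, 16.0]
```

The time axis is consumed by ResampleUniform and then discarded in __getitem__, so only the transformed sample ever leaves the dataset. In real code you would subclass torch.utils.data.Dataset and use torchvision.transforms.Compose, as in the answer's example.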