I have a custom dataset that I want to train a neural network on. A sample from the dataset might be [1,2,3,4], with a corresponding time axis of, for example, [0, 0.2, 0.4, 0.6].
This time axis is different for every sample in the dataset and is needed for certain transformations.
I only want to train the neural network on the actual data ([1,2,3,4]). Therefore, my custom transformation needs an additional time list that is used only by that transformation. However, I have not found any example of how to do this.
I have read https://pytorch.org/tutorials/beginner/data_loading_tutorial.html, but in their transformations __call__ always takes only the "sample" as input, like this:
def __call__(self, sample):
I could pass the time axis as part of the sample, but then wouldn't the neural network also train on the time axis, which I do not want?
How can I accomplish passing the time axis to the call function for a custom PyTorch transformation without training on the time data?
You control how the transformations are called in the dataset, so if you write your own dataset you can transform your sample with whatever extra data you want directly in __getitem__.
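For the direct approach, a minimal sketch might look like the following. It assumes the samples and time axes are held in memory as Python lists; the class name `TimeSeriesDataset` and the time-dependent step (resampling onto a uniform grid with `np.interp`) are hypothetical stand-ins for your own transformation. Only the transformed sample is returned, so the time axis never reaches the model.

```python
import numpy as np
import torch
from torch.utils.data import Dataset


class TimeSeriesDataset(Dataset):
    """Hypothetical dataset that uses each sample's time axis inside __getitem__."""

    def __init__(self, samples, time_axes, n_points=4):
        self.samples = samples        # list of value sequences
        self.time_axes = time_axes    # list of matching time sequences
        self.n_points = n_points      # length of the resampled output (assumption)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        sample = np.asarray(self.samples[index], dtype=np.float64)
        time_axis = np.asarray(self.time_axes[index], dtype=np.float64)
        # Example time-dependent transform: linearly resample the values
        # onto a uniform grid spanning the same time range.
        uniform_t = np.linspace(time_axis[0], time_axis[-1], num=self.n_points)
        resampled = np.interp(uniform_t, time_axis, sample)
        # Only the transformed sample is returned; time_axis stays here.
        return torch.from_numpy(resampled).float()
```

For example, `TimeSeriesDataset([[1, 2, 3, 4]], [[0, 0.1, 0.4, 0.6]])[0]` yields a 4-element float tensor sampled at t = 0, 0.2, 0.4, 0.6.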
If you want to follow the model of separating your transforms from your dataset (which is probably good practice), then you can write your dataset to expect transforms that take both the sample and the time axis. Since torchvision's built-in transforms don't expect the time axis, you can write a wrapper that applies them only to the sample. One caveat is that if we want to keep using torchvision's Compose transform, our transforms need to take a single argument, because Compose calls each transform with exactly one input. We could write a custom compose pretty easily, but IMO it's a bit easier to just pack all the arguments into a single tuple argument.
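For comparison, the custom compose mentioned above could be sketched like this, assuming every transform takes `(sample, time_axis)` as two positional arguments and returns both. `PairCompose`, `shift_time_to_zero` and `scale_by_duration` are hypothetical names for illustration, not part of torchvision.

```python
class PairCompose:
    """Hypothetical two-argument alternative to torchvision's Compose."""

    def __init__(self, transforms):
        self.transforms = list(transforms)

    def __call__(self, sample, time_axis):
        # Each transform receives and returns both sample and time_axis.
        for t in self.transforms:
            sample, time_axis = t(sample, time_axis)
        return sample, time_axis


def shift_time_to_zero(sample, time_axis):
    """Example transform that only changes the time axis."""
    t0 = time_axis[0]
    return sample, [t - t0 for t in time_axis]


def scale_by_duration(sample, time_axis):
    """Example transform that uses the time axis to rescale the sample."""
    duration = time_axis[-1] - time_axis[0]
    return [x / duration for x in sample], time_axis


pipeline = PairCompose([shift_time_to_zero, scale_by_duration])
```

With this variant the dataset would call `self.transform(sample, time_axis)` instead of passing a tuple, and torchvision's own transforms would still need a wrapper that leaves the time axis untouched.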
An incomplete example (you need to fill in the ... sections) might look something like this:
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
class TransformWrapper:
    """ Wraps a transform that operates on only the sample """
    def __init__(self, t):
        self.t = t

    def __call__(self, data):
        """
        data: tuple containing both sample and time_axis
        returns a tuple containing the transformed sample and original time_axis
        """
        sample, time_axis = data
        return self.t(sample), time_axis

class CustomTransform:
    """ a custom transform dependent on time axis """
    def __init__(self, ...):
        ...

    def __call__(self, data):
        sample, time_axis = data
        new_sample = ...  # some function of sample and time_axis
        return new_sample, time_axis

class MyDataset(Dataset):
    def __init__(self, root, transform=None):
        """
        root: ...
        transform: A transform that operates on a tuple containing sample and time_axis
        """
        ...  # init dataset
        self.transform = transform

    def __getitem__(self, index):
        sample, time_axis = self.get_data(index)
        if self.transform is not None:
            # transform operates on a tuple containing both sample and time_axis
            sample, time_axis = self.transform((sample, time_axis))
        # dataset doesn't need to return time_axis
        return sample

    def get_data(self, index):
        ...  # load and return sample and time_axis at index

    def __len__(self):
        ...  # returns length of data

# example of how to compose wrapped transforms
dataset = MyDataset(
    root=...,
    transform=transforms.Compose([
        # torchvision's Resize (the tutorial's Rescale is its own custom class)
        TransformWrapper(transforms.Resize(256)),
        TransformWrapper(transforms.RandomCrop(224)),
        CustomTransform(...),
        TransformWrapper(transforms.ToTensor())
    ]))

loader = DataLoader(dataset, ...)

# train loop
for samples in loader:
    ...