Hi, is there any way to apply a transformation to only certain batches?
Specifically, I want to apply a transformation to just the last batch of every epoch.
Here is what I tried:
import torch

class test(torch.utils.data.Dataset):
    def __init__(self):
        self.source = [i for i in range(10)]

    def __len__(self):
        return len(self.source)

    def __getitem__(self, idx):
        print(idx)
        return self.source[idx]

ds = test()
dl = torch.utils.data.DataLoader(dataset=ds, batch_size=3,
                                 shuffle=False, num_workers=5)
for i in dl:
    print(i)
I thought that if I could get the idx number, I could apply the transformation to specific batches.
However, with num_workers set, the output is:
0
1
2
3
964
57
8
tensor([0, 1, 2])
tensor([3, 4, 5])
tensor([6, 7, 8])
tensor([9])
which is not what I expected.
Without num_workers, the output is:
0
1
2
tensor([0, 1, 2])
3
4
5
tensor([3, 4, 5])
6
7
8
tensor([6, 7, 8])
9
tensor([9])
So the question is
When you have num_workers > 1, multiple subprocesses load data in parallel. What is likely happening is a race between their print calls, so the order you see in the output depends on which subprocess gets there first each time. The batches themselves still come out in order, as your output shows.
For most transforms, you can apply them on a specific batch simply by calling the transform after the batch has been loaded. To do this just for the last batch, you could do something like:
for batch_idx, batch_data in enumerate(dl):
    # check if batch is the last batch of the epoch;
    # len(dl) is the number of batches per epoch
    if batch_idx == len(dl) - 1:
        batch_data = transform(batch_data)
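To make the idea concrete, here is a minimal self-contained sketch using the same toy data as in the question. The names RangeDataset, transform, and run_epoch are made up for illustration, and the "transform" just negates the batch so the effect is easy to see. It assumes a map-style dataset, so that len(dl) is defined.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class RangeDataset(Dataset):
    """Toy dataset holding the integers 0..n-1 (stand-in for real data)."""

    def __init__(self, n=10):
        self.source = list(range(n))

    def __len__(self):
        return len(self.source)

    def __getitem__(self, idx):
        return self.source[idx]


def transform(batch):
    # Hypothetical transform, used only for illustration: negate the batch.
    return -batch


def run_epoch(dl):
    """Iterate once over dl, transforming only the final batch."""
    last_idx = len(dl) - 1  # len(dl) is the number of batches per epoch
    out = []
    for batch_idx, batch in enumerate(dl):
        if batch_idx == last_idx:
            batch = transform(batch)
        out.append(batch.tolist())
    return out


# num_workers=0 keeps the sketch simple to run anywhere.
dl = DataLoader(RangeDataset(), batch_size=3, shuffle=False, num_workers=0)
batches = run_epoch(dl)
print(batches)  # [[0, 1, 2], [3, 4, 5], [6, 7, 8], [-9]]
```

Because the check happens in the training loop rather than inside __getitem__, it does not depend on which worker loaded which sample. Comparing against len(dl) - 1 also avoids having to know batch_size separately.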