My data has several conditions A, B, and C. During training, I would like each batch to contain one sample from every condition plus a random sample from the full dataset, i.e. one batch would look like

[condition_A, condition_B, condition_C, random_sample]
I have created a dictionary of per-condition loaders, plus a loader over the full dataset:

```python
loader_dict = {
    cond_A: DataLoader(...Subset Magic...),
    cond_B: DataLoader(...Subset Magic...),
    cond_C: DataLoader(...Subset Magic...),
}

train_loader = DataLoader(...full dataset...)
```
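For reference, the "Subset Magic" is roughly of the following shape (a minimal sketch; `condition_labels`, the helper name, and the batch size are placeholders for my actual setup):

```python
from torch.utils.data import DataLoader, Subset

# sketch: condition_labels[i] holds the condition of sample i
def condition_loader(full_dataset, condition_labels, cond, batch_size=32):
    indices = [i for i, c in enumerate(condition_labels) if c == cond]
    return DataLoader(Subset(full_dataset, indices),
                      batch_size=batch_size, shuffle=True)
```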
Now, during each epoch, I would like to

1. draw one batch from each conditional loader alongside every batch of the full loader, and
2. concatenate those conditional samples with the full-dataset batch for the training step.

Currently, I am a bit stuck on the 1st point.
My approach so far is

```python
from tqdm import tqdm

# get a list of iterators of the form [iter_A, iter_B, iter_C]
# (iter() is needed, since a DataLoader itself is not an iterator)
train_loaders = [iter(loader) for loader in loader_dict.values()]

for batch_idx, batch in enumerate(tqdm(train_loader)):
    condit_sample = [next(loader) for loader in train_loaders]
    # do something with torch.cat([batch, *condit_sample])
```
Now I am not sure: will the next() call always just return the first batch of the condition loaders (not desired), or will it actually iterate through the batches of the conditions?
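To make the two behaviors I am asking about concrete (a quick sketch against one of the loaders above):

```python
it = iter(loader_dict[cond_A])
b1 = next(it)  # first batch
b2 = next(it)  # second batch: a persistent iterator advances

# by contrast, re-creating the iterator on every call restarts it:
b = next(iter(loader_dict[cond_A]))  # always a (freshly shuffled) first batch
```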
Also, my data is split roughly 50% condition_A, 35% condition_B, and 15% condition_C. So I wonder: would my code run through, e.g., all 100 batches of the full dataset and repeat condition_A twice, condition_B nearly 3 times, and condition_C between 6 and 7 times? Or will the loop just run through all samples of condition_C once and then break down (with a StopIteration)?
For now, simply cycling through the conditional samples multiple times would suffice.
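If the loaders turn out not to cycle on their own, a small generator that restarts the iterator on exhaustion would give me exactly that behavior (a sketch, not part of my code yet):

```python
def infinite_batches(loader):
    # restart the underlying iterator whenever it is exhausted;
    # if the loader shuffles, every pass gets a fresh shuffle
    while True:
        for batch in loader:
            yield batch

train_loaders = [infinite_batches(loader) for loader in loader_dict.values()]
```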
For later purposes, I would like to consider the following:
Edit: I ran the experiment myself. It behaves like itertools.cycle().
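For anyone reading later, the explicit equivalent would be the following; note that itertools.cycle() caches the batches from its first pass, so a shuffled batch order is replayed on later passes (and all batches stay in memory):

```python
import itertools

train_loaders = [itertools.cycle(loader) for loader in loader_dict.values()]
```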