I'm trying to design a training flow that samples data points during training.
My data look like this:
defaultdict(list,
{'C1629836-28004480': [0, 5, 6, 12, 17, 19, 28],
'C0021846-28004480': [1, 7, 15],
'C0162832-28004480': [2, 9],
'C0025929-28004480': [3, 10, 30],
'C1515655-28004480': [4],
...
}
where the key is a label and the value is a list of data indices.
I have a custom dataset class whose __getitem__(self, idx)
function needs to calculate the distance between an anchor (chosen randomly) and the other data points. It looks like this:
def __getitem__(self, idx):
    item_label = self.labels[idx]      # e.g. 'C1629836-28004480'
    item_data = self.data[item_label]  # e.g. [0, 5, 6, 12, 17, 19, 28]
    # random.sample(item_data, 1) returns a one-element list, which would
    # never compare equal to an int index; use random.choice instead
    anchor_index = random.choice(item_data)
    mention_indices = [i for i in item_data if i != anchor_index]
    with torch.no_grad():
        self.model.eval()
        anchor_input = ...
        anchor_embedding = self.model.mention_encoder(anchor_input)
        for i in mention_indices:
            ...
Another way to avoid passing the model into the custom dataset is to run inference inside the training_step
function during training.
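A minimal sketch of that second approach, with the distance computation inside the training step instead of the dataset (the `Encoder` class, the input shapes, and the `compute_distances` helper are assumptions for illustration; only `mention_encoder` comes from the question):

```python
import torch
import torch.nn as nn

class Encoder(nn.Module):
    # stand-in for the real model; only mention_encoder matters here
    def __init__(self, dim=8):
        super().__init__()
        self.mention_encoder = nn.Linear(dim, dim)

def compute_distances(model, anchor_input, mention_inputs):
    """Embed the anchor and its mentions with the model's *current* weights
    and return anchor-to-mention distances (no gradients needed here)."""
    was_training = model.training
    with torch.no_grad():
        model.eval()
        anchor_embedding = model.mention_encoder(anchor_input)      # (1, dim)
        mention_embeddings = model.mention_encoder(mention_inputs)  # (n, dim)
        distances = torch.cdist(anchor_embedding, mention_embeddings)  # (1, n)
    if was_training:
        model.train()  # restore training mode for the rest of the step
    return distances.squeeze(0)
```

Because this runs in the training loop, the embeddings are always computed from the weights as they are at that step.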
But I read somewhere that using a Dataset and DataLoader to prepare the data fed into the model can save training time, since the DataLoader can load batches in parallel.
But in fact, I need to compute these distances from the latest state of the model's weights during training. Does this parallel mechanism guarantee that, given that Python variables hold references rather than values?
So which way is more professional and correct?
Edit:
I tried both approaches, and the second approach was much faster than the first.
It sounds like what you want is to have these embeddings computed on the fly during training. The best approach for this would be to move the model computation out of the __getitem__
function and into the training loop.
The __getitem__
method should be reserved for per-sample tasks that are disk-bound or CPU-bound; computing the embeddings is GPU-bound and should be done in batches.
Best practice would be to do something like:
- use the __getitem__ method to return the data needed to compute anchor_embedding and the other quantities used later on
- use the collate_fn of your DataLoader to batch those inputs for computing anchor_embedding
- use the model to compute anchor_embedding and the other quantities in batch
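A sketch of that structure, assuming the label-to-indices dict from the question (the flattened-mentions collate layout with an `owners` vector mapping each mention back to its anchor is one assumed scheme, not something the question specifies):

```python
import random
import torch
from torch.utils.data import Dataset, DataLoader

class MentionDataset(Dataset):
    """Only CPU-bound work here: pick an anchor and return raw indices.
    No model and no embeddings inside __getitem__."""
    def __init__(self, data):
        self.data = data                  # {label: [data indices]}
        self.labels = list(data.keys())

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        item_data = self.data[self.labels[idx]]
        anchor_index = random.choice(item_data)
        mention_indices = [i for i in item_data if i != anchor_index]
        return anchor_index, mention_indices

def collate_fn(samples):
    # flatten the variable-length mention lists into one batch,
    # remembering which anchor each mention belongs to
    anchors = torch.tensor([a for a, _ in samples])
    mentions = torch.tensor([m for _, ms in samples for m in ms])
    owners = torch.tensor([k for k, (_, ms) in enumerate(samples) for _ in ms])
    return anchors, mentions, owners
```

Inside the training loop you would then look up the inputs for these indices and run model.mention_encoder on the whole batch at once. Since the model now lives only in the training loop, DataLoader workers never hold a stale copy of its weights.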