Tags: python, pytorch, dataset, sampling

Is it efficient to pass a model into a custom dataset to run inference during training for a sampling strategy?


I'm trying to design a training flow that samples data points during training.

My data looks like this:

defaultdict(list,
        {'C1629836-28004480': [0, 5, 6, 12, 17, 19, 28],
         'C0021846-28004480': [1, 7, 15],
         'C0162832-28004480': [2, 9],
         'C0025929-28004480': [3, 10, 30],
         'C1515655-28004480': [4],
         ...
        })

where each key is a label and each value is a list of data indices.
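
For reference, a mapping like this can be built with collections.defaultdict; a minimal sketch, where the labels list is a hypothetical stand-in for however the per-sample labels are actually stored:

from collections import defaultdict

# hypothetical per-sample labels: labels[i] is the label of sample i
labels = ['C1629836-28004480', 'C0021846-28004480', 'C1629836-28004480']

data = defaultdict(list)
for i, label in enumerate(labels):
    data[label].append(i)

# data['C1629836-28004480'] == [0, 2]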

I wrote a custom dataset class whose __getitem__(self, idx) method needs to calculate the distance between a randomly chosen anchor and the other data points of the same label. It looks like this:

def __getitem__(self, idx):
    item_label = self.labels[idx]      # e.g. 'C1629836-28004480'
    item_data = self.data[item_label]  # e.g. [0, 5, 6, 12, 17, 19, 28]

    # random.sample returns a list, so take the single element
    anchor_index = random.sample(item_data, 1)[0]
    mention_indices = [i for i in item_data if i != anchor_index]

    with torch.no_grad():
        self.model.eval()
        anchor_input = ...
        anchor_embedding = self.model.mention_encoder(anchor_input)

        for i in mention_indices:
        ...

An alternative that avoids passing the model into the custom dataset is to run the inference inside the training_step function during training.

But I read somewhere that using a Dataset and DataLoader to prepare the data fed into the model can reduce training time, since the DataLoader can load batches in parallel worker processes.

However, I need to compute these distances based on the latest state of my model's weights during training. Does this parallel mechanism ensure that? Python variables are references rather than values, so I'd expect the dataset's model to point to the same object, but I'm not sure that holds across processes.
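
As an aside, this can be checked empirically. The hypothetical ProbeDataset below returns whatever self.version holds at fetch time; with num_workers > 0, each worker process keeps its own copy of the dataset object, so a mutation made in the main process mid-epoch never shows up in the fetched items:

import torch
from torch.utils.data import Dataset, DataLoader

class ProbeDataset(Dataset):
    """Returns whatever self.version is at fetch time."""
    def __init__(self):
        self.version = 0

    def __len__(self):
        return 4

    def __getitem__(self, idx):
        return self.version

if __name__ == "__main__":  # guard needed on spawn-based platforms
    ds = ProbeDataset()
    loader = DataLoader(ds, num_workers=2)
    for item in loader:
        ds.version = 99     # mutate mid-epoch in the main process
        print(item.item())  # prints 0 every time: workers keep their own copy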

So which way is more professional and correct?

Edit:

I tried both approaches, and the second approach was much faster than the first.


Solution

  • It sounds like what you want is to have these embeddings computed on the fly during training. The best approach for this would be to move the model computation outside of the __getitem__ function and into the training loop.

    The __getitem__ method should be reserved for per-sample work that is disk bound or CPU bound. Computing the embeddings is GPU bound and should be done in batches. Moreover, when the DataLoader uses worker processes (num_workers > 0), the dataset (and any model stored on it) is copied into each worker, so a model called inside __getitem__ would not reflect the weight updates made in the main process.

    Best practice would be to do something like:

    1. use the __getitem__ method to return the necessary data to compute anchor_embedding and other quantities used later on
    2. use the collate_fn of your DataLoader to batch the inputs for computing anchor_embedding
    3. in your training loop, use the model to compute anchor_embedding and the other quantities in batch
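
    A minimal sketch of these three steps, assuming the label-to-indices layout from the question; the feature tensor, the field names, and the Linear stand-in for mention_encoder are hypothetical:

    import random
    import torch
    from torch.utils.data import Dataset, DataLoader

    class AnchorMentionDataset(Dataset):
        """Step 1: __getitem__ returns raw inputs only, no model calls."""
        def __init__(self, label_to_indices, features):
            self.labels = list(label_to_indices)  # label strings
            self.data = label_to_indices          # label -> [data indices]
            self.features = features              # (num_samples, dim) raw inputs

        def __len__(self):
            return len(self.labels)

        def __getitem__(self, idx):
            item_data = self.data[self.labels[idx]]
            anchor_index = random.choice(item_data)
            mention_indices = [i for i in item_data if i != anchor_index]
            return {
                "anchor": self.features[anchor_index],
                "mentions": self.features[mention_indices],
            }

    def collate_fn(items):
        """Step 2: batch the raw inputs. Mention counts vary per label, so
        concatenate them into one flat tensor plus per-item counts."""
        anchors = torch.stack([it["anchor"] for it in items])
        mentions = torch.cat([it["mentions"] for it in items])
        counts = [len(it["mentions"]) for it in items]
        return anchors, mentions, counts

    # Step 3: every model call happens in the training loop, in batch,
    # so the embeddings always reflect the current weights.
    features = torch.randn(32, 16)
    label_to_indices = {"A": [0, 5, 6], "B": [1, 7], "C": [2, 9, 10, 11]}
    loader = DataLoader(AnchorMentionDataset(label_to_indices, features),
                        batch_size=2, collate_fn=collate_fn)

    encoder = torch.nn.Linear(16, 8)  # hypothetical stand-in for mention_encoder
    for anchors, mentions, counts in loader:
        anchor_emb = encoder(anchors)    # (batch, 8)
        mention_emb = encoder(mentions)  # (sum(counts), 8)
        for a, m in zip(anchor_emb, mention_emb.split(counts)):
            distances = torch.cdist(a.unsqueeze(0), m)  # (1, count_i)
            # ... accumulate a loss from these distances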