
PyTorch for-loop inefficiency


I have an efficiency problem with a for loop over tensors.

I'm extracting features from the last layer of a CNN, feeding images through a data loader with batch size 8. For each batch I compute the Euclidean distance between the batch's feature tensor and a table of previously stored features.
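For reference, `torch.cdist` compares every row of the batch with every row of the table and returns one distance per pair. The sketch below uses random stand-in tensors (8 batch rows, 3 stored rows, feature size 16 are assumptions; the real size depends on the CNN) and checks the result against an explicit broadcasting formula.

```python
import torch

# Hypothetical stand-ins for the CNN outputs and the stored feature table.
torch.manual_seed(0)
outputs = torch.randn(8, 16)        # batch of 8 feature vectors
features_list = torch.randn(3, 16)  # table with 3 stored feature vectors

# dist[i, j] is the Euclidean distance between outputs[i] and features_list[j].
dist = torch.cdist(outputs, features_list, p=2.0)
print(dist.shape)  # torch.Size([8, 3])

# The same matrix written out with broadcasting, to show what cdist computes.
manual = (outputs[:, None, :] - features_list[None, :, :]).pow(2).sum(-1).sqrt()
print(torch.allclose(dist, manual, atol=1e-5))
```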

I want to add a batch tensor to the table whenever its distances to all the tensors already in the table are above a threshold. My code runs correctly, but the loop I use is not efficient, and I'm wondering how to do the same thing with something more efficient than this sequential approach.
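The rule "add a row only if every distance in its row of the distance matrix exceeds the threshold" is a row-wise `all` over a boolean comparison. A minimal sketch with a hand-written distance matrix; the `threshold` value of 1.0 is an illustrative assumption standing in for `AVG`:

```python
import torch

# Hand-written distances: 3 batch rows (candidates) x 2 stored features.
dist = torch.tensor([[2.0, 3.0],
                     [0.5, 4.0],
                     [1.5, 1.2]])
threshold = 1.0  # plays the role of AVG; value chosen for illustration

above = dist > threshold   # boolean matrix with the same shape as dist
keep = above.all(dim=1)    # True where every distance in the row is above the threshold
print(keep.tolist())       # [True, False, True]

# Counting True entries per row and comparing with the number of stored
# features gives the same mask; this is the role of `counter` in the loop.
keep_by_count = above.sum(dim=1) == dist.shape[1]
print(torch.equal(keep, keep_by_count))  # True
```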

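for i, data in enumerate(dataloader, 0):
  input, label = data
  input, label = input.to(device), label.to(device)
  n, c, h, w = input.size()
  outputs = model(input)  # features from the last CNN layer, shape (batch, d)
  if (i == 0):
    # Seed the table with the first feature vector of the first batch
    features_list = torch.cat( (features_list, outputs[0].view(1,-1)), 0)
  # Pairwise Euclidean distances, shape (batch, rows in the table)
  dist_tensores = torch.cdist(outputs, features_list, p=2.0)
  # activation[x, j] is 1 where output x is farther than AVG from stored feature j
  activation = torch.gt(dist_tensores, AVG, out=torch.cuda.FloatTensor(len(outputs), len(features_list)))
  counter = len(features_list)
  activation_list = torch.sum(activation, dim=0)
  # Append output x only if it is farther than AVG from every stored feature
  for x in range(len(activation)):
    if (torch.sum(activation[x], dim=0) == counter):
      features_list = torch.cat( (features_list, outputs[x].view(1,-1)), 0)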
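The inner loop can be separated from the model and data loader to make its behavior easier to check. This is a sketch: the helper name `append_loop`, the toy 2-D features and the threshold of 1.0 are illustrative assumptions, and `activation` is built as a boolean tensor instead of through `out=torch.cuda.FloatTensor(...)` so that it runs on CPU. Because `activation` and `counter` are computed once per batch, a row appended inside the loop is never compared with the later rows of the same batch, as the third toy row shows.

```python
import torch

def append_loop(features_list: torch.Tensor, outputs: torch.Tensor, threshold: float) -> torch.Tensor:
    """Hypothetical helper reproducing the question's inner loop, one row at a time."""
    dist = torch.cdist(outputs, features_list, p=2.0)  # (n, m) distances
    activation = dist > threshold                      # boolean (n, m)
    counter = len(features_list)                       # fixed before the loop, as in the question
    for x in range(len(activation)):
        # Append row x only if it is farther than the threshold from every stored row.
        if torch.sum(activation[x], dim=0) == counter:
            features_list = torch.cat((features_list, outputs[x].view(1, -1)), 0)
    return features_list

features = torch.tensor([[0.0, 0.0]])
batch = torch.tensor([[3.0, 0.0],   # distance 3 from the stored row: appended
                      [0.5, 0.0],   # distance 0.5: skipped
                      [3.0, 0.1]])  # distance ~3.0: appended, although it is close to batch row 0
result = append_loop(features, batch, threshold=1.0)
print(result)  # rows [0, 0], [3, 0], [3, 0.1]
```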

The last loop is the part I want to change, but I don't know how to select and append the tensors I want without a loop that lets me control which tensor gets added.


Solution

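  • Build a boolean mask of the batch rows whose activations are all 1 (every distance exceeds AVG), then append those rows in one indexing operation:

    idx = activation.sum(1) == counter
    features_list = torch.cat((features_list, outputs[idx]), 0)

    This replaces the Python loop with vectorized tensor operations: instead of one comparison and one torch.cat per batch element, the mask is computed once and all selected rows are concatenated in a single call. It appends the same rows, in the same order, as the original loop.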
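A self-contained version of the vectorized update, as a sketch: the helper name `append_vectorized` and the toy tensors are assumptions, and `(dist > threshold).all(dim=1)` is used as an equivalent way of writing `activation.sum(1) == counter` for a boolean comparison.

```python
import torch

def append_vectorized(features_list: torch.Tensor, outputs: torch.Tensor, threshold: float) -> torch.Tensor:
    """Hypothetical helper: append every batch row whose distance to all stored rows exceeds threshold."""
    dist = torch.cdist(outputs, features_list, p=2.0)  # (n, m) distances
    idx = (dist > threshold).all(dim=1)                # (n,) boolean mask, one entry per batch row
    # Boolean-mask indexing returns the selected rows as a (k, d) tensor
    # (k may be 0), so one torch.cat replaces the per-row Python loop.
    return torch.cat((features_list, outputs[idx]), 0)

features = torch.tensor([[0.0, 0.0]])
batch = torch.tensor([[3.0, 0.0], [0.5, 0.0], [3.0, 0.1]])
print(append_vectorized(features, batch, threshold=1.0))  # rows [0, 0], [3, 0], [3, 0.1]
```

Like the original loop, this compares batch rows only against the table as it was before the batch. If rows from the same batch should also be checked against each other, that is a different, order-dependent rule that this mask does not implement.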