Search code examples

How to get the filenames of misclassified images in the test data set - pytorch?

I am trying to use AlexNet to classify spectrogram images generated for 3s audio segments. I have successfully trained my dataset and am trying to identify which images the model misclassified.

I am able to obtain the filename of an image by calling iterator.dataset.data_df.adressfname; however, I am unsure how to use this statement in the for loop in the get_predictions function. If I try to retrieve the filenames using iterator.dataset.data_df.adressfname[i], I get the following error: expected Tensor as element 0 in argument 0, but got str

Ultimately, I want to create a dataframe that contains the filename, actual label and predicted label. Does anyone have any suggestions?

class CustomDataset(Dataset):
    """Image dataset whose samples are listed in a CSV file.

    The CSV is read with ``pd.read_csv`` and must contain the columns
    ``spectrogramSegFilename`` (path of the image to load) and ``dx``
    (class name, one of the keys of ``class_map``).

    Args:
        img_path: Directory of the images. It is stored on the instance but
            not used for loading; images are read from the paths in the CSV.
        csv_file: Path of the CSV file describing the samples.
        transforms: Callable applied to the image tensor in ``__getitem__``.
        return_path: When True, ``__getitem__`` additionally returns the
            image path so predictions can be traced back to their files.
            Defaults to False, which keeps the two-element sample.
    """

    def __init__(self, img_path, csv_file, transforms, return_path=False):
        self.imgs_path = img_path
        self.csv_train_file = csv_file
        self.data_df = pd.read_csv(self.csv_train_file)
        self.transforms = transforms
        self.return_path = return_path

        # One [image path, class name] pair per CSV row, in CSV order.
        self.data = [
            [path, label]
            for path, label in zip(
                self.data_df['spectrogramSegFilename'], self.data_df['dx']
            )
        ]
        self.class_map = {"ProbableAD": 0, "Control": 1}
        self.img_dim = (256, 256)

    def __len__(self):
        """Return the number of samples listed in the CSV."""
        return len(self.data)

    def __getitem__(self, idx):
        """Load, resize and transform the sample at position ``idx``.

        Returns:
            ``(data, class_id)``, or ``(data, class_id, img_path)`` when the
            dataset was created with ``return_path=True``. ``class_id`` is a
            one-element tensor holding the integer from ``class_map``.

        Raises:
            FileNotFoundError: If the image could not be read.
            KeyError: If the class name is not a key of ``class_map``.
        """
        img_path, class_name = self.data[idx]
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising;
            # fail here with the offending path rather than inside cv2.resize.
            raise FileNotFoundError(f"Could not read image: {img_path}")
        img = cv2.resize(img, self.img_dim)
        class_id = self.class_map[class_name]
        # Move the last axis to the front (presumably HWC -> CHW) before the
        # transforms are applied.
        img_tensor = torch.from_numpy(img).permute(2, 0, 1)
        data = self.transforms(img_tensor)
        class_id = torch.tensor([class_id])
        if self.return_path:
            return data, class_id, img_path
        return data, class_id

if __name__ == "__main__":
    # Preprocessing applied to every image tensor produced by the dataset.
    transformations = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize(256),
        transforms.CenterCrop(256),
        transforms.ToTensor(),
        transforms.Normalize((0.49966475, 0.1840554, 0.34930056), (0.35317238, 0.17343724, 0.1894943)),
    ])

    # Build each dataset once, passing the required transforms argument.
    # NOTE(review): confirm these paths are the right ones for your machine.
    train_dataset = CustomDataset("/spectrogram_images/spectrogram_train/", "train_features_segmented.csv", transformations)
    test_dataset = CustomDataset("/spectrogram_images/spectrogram_test/", "test_features_segmented.csv", transformations)

    train_data_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
    # Evaluation does not need shuffling; a fixed order keeps predictions
    # aligned with the rows of the test CSV.
    test_data_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

def get_predictions(model, iterator, device):
    """Run ``model`` over every batch of ``iterator`` and collect the results.

    Args:
        model: A ``torch.nn.Module``. It is switched to eval mode before
            inference.
        iterator: Iterable of batches. Each batch is ``(x, y)`` or
            ``(x, y, ids)``, where ``ids`` holds one identifier per sample
            (for example a sequence of file paths or a tensor of indices).
        device: Device the inputs are moved to before calling the model.

    Returns:
        A tuple ``(images, labels, probs, participant_ids)``. The first three
        are CPU tensors concatenated along dim 0. ``participant_ids`` is a
        plain Python list in the same sample order; it is empty when the
        batches carry no identifiers. It is a list rather than a tensor
        because ``torch.cat`` only accepts tensors and fails on strings.
    """
    model.eval()

    images = []
    labels = []
    probs = []
    participant_ids = []

    with torch.no_grad():
        for batch in iterator:
            # Accept both two- and three-element batches.
            x, y, *extras = batch

            x = x.to(device)
            y_pred = model(x)
            y_prob = F.softmax(y_pred, dim=-1)

            images.append(x.cpu())
            labels.append(y.cpu())
            probs.append(y_prob.cpu())

            if extras:
                batch_ids = extras[0]
                if torch.is_tensor(batch_ids):
                    batch_ids = batch_ids.tolist()
                participant_ids.extend(batch_ids)

    if not images:
        # torch.cat raises on an empty list; return empty tensors instead.
        return torch.empty(0), torch.empty(0), torch.empty(0), participant_ids

    images = torch.cat(images, dim=0)
    labels = torch.cat(labels, dim=0)
    probs = torch.cat(probs, dim=0)

    return images, labels, probs, participant_ids


  • Just add a third field to the return signature of __getitem__():

    def __getitem__(self,idx): 
        return data,class_id,img_path

    Then, when you call the dataloader:

    for i, (x,y,img_paths) in enumerate(iterator):
        ... call model ...
        ... compare outputs to labels ...
        ... identify incorrect batch indices
        mislabeled_files = [img_paths[idx] for idx in incorrect_batch_indices]

    And there you have it. Note that in your current code block, i is the batch index, which has nothing to do with the dataset object wrapped in the dataloader: the dataloader collates multiple dataset elements into a single batch (because it has a batch size), and it also shuffles the elements of the dataset. You should therefore not use this index (i) to reference the dataset object. If you want to reference the underlying data items in the dataset object, you can instead simply return the data idx in __getitem__():

    def __getitem__(self,idx): 
        return data,class_id,idx

    You can then use this idx to reference the original dataset data and get the file names that way.