Tags: python, pytorch, torchvision, pytorch-dataloader

Get image name of pytorch dataset


I am using a custom dataset for image segmentation. While visualizing some of the images and masks, I found an error. My problem now is how to find the name of the affected image. The code I use to create the PyTorch dataset is:

import numpy as np
import skimage.io
import torch
from torch.utils.data import Dataset

class SegmentationDataset(Dataset):

  def __init__(self, df, augmentations):
    self.df = df
    self.augmentations = augmentations

  def __len__(self):
    return len(self.df)

  def __getitem__(self, idx):
    row = self.df.iloc[idx]

    image_path = DATA_DIR + row.images
    mask_path = DATA_DIR + row.masks

    image = skimage.io.imread(image_path)

    mask = skimage.io.imread(mask_path) 
    mask = np.expand_dims(mask, axis = -1) 

    if self.augmentations:
      data = self.augmentations(image = image, mask = mask)
      image = data['image'] 
      mask = data['mask']

    image = np.transpose(image, (2, 0, 1)).astype(np.float32)
    mask = np.transpose(mask, (2, 0, 1)).astype(np.float32)

    image = torch.Tensor(image) / 255.0
    mask = torch.round(torch.Tensor(mask) / 255.0)

    return image, mask


trainset = SegmentationDataset(train_df, get_train_augs())
validset = SegmentationDataset(valid_df, get_valid_augs())

When I then display one specific sample, I see that the mask is missing or wrong:

idx = 9
print('Drawn sample ID:', idx)

image, mask = validset[idx]
show_image(image, mask)

How do I get the image name for idx = 9?


Solution

  • Printing any of the following directly below the line image = skimage.io.imread(image_path) in __getitem__ should lead you to the answer:

    print(row)
    print(row.images)
    print(image_path)
    

    To get the file name once you have the fully qualified path from above:

    my_str = '/my/data/path/images/wallpaper.jpg'
    
    result = my_str.rsplit('/', 1)[1]
    print(result)  # 'wallpaper.jpg'
    
    with_slash = '/' + my_str.rsplit('/', 1)[1]
    print(with_slash)  # '/wallpaper.jpg'
    
    # rsplit('/', 1) splits once from the right, at the last slash
    print(my_str.rsplit('/', 1))  # ['/my/data/path/images', 'wallpaper.jpg']
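
If you would rather not edit __getitem__, the name can probably be read straight from the DataFrame the dataset already holds, since __getitem__ looks up its row with self.df.iloc[idx]. The following is a minimal sketch under some assumptions: the small valid_df built here is a hypothetical stand-in for the real one, sample_file_name is an illustrative helper that is not part of PyTorch or pandas, and the "images" and "masks" columns are assumed to hold relative paths as in the question. It uses os.path.basename in place of the manual rsplit.

```python
import os

import pandas as pd

# Hypothetical stand-in for valid_df; the real frame comes from the question
# and is assumed to store relative file paths in "images" and "masks".
valid_df = pd.DataFrame({
    "images": [f"images/img_{i:03d}.png" for i in range(12)],
    "masks": [f"masks/img_{i:03d}.png" for i in range(12)],
})


def sample_file_name(df, idx, column="images"):
    """Return the bare file name stored in `column` at positional index `idx`."""
    relative_path = df.iloc[idx][column]  # same positional lookup as __getitem__
    return os.path.basename(relative_path)


idx = 9
print(sample_file_name(valid_df, idx))           # img_009.png
print(sample_file_name(valid_df, idx, "masks"))  # img_009.png
```

With the dataset from the question, the call would be sample_file_name(validset.df, idx). This only matches when you index the dataset directly, as in validset[idx]; positions inside a shuffled DataLoader batch do not correspond to idx, in which case returning image_path from __getitem__ alongside the image and mask may be the simpler route.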