TorchVision v2 transform - Can't get transform to work (or just plot...?)

I am training a DeepLabV3+ model to segment CT images. As a test, I want to visualise the transformations being applied. However, I cannot get the transformations to show up in the plotted/saved images.

The slice has been converted to shape [3, 512, 512], while the multilabel stack has shape [21, 512, 512]. My current transform block is as follows (I left out the multilabel stack to simplify the problem):

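    from torchvision.transforms import v2

    if self.augment:
        transform = v2.Compose([
            # ... (the list of v2 transforms is not shown in the original post)
        ])
        self.scan_slice = transform(self.scan_slice)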

And I use the following block to plot and save the slice overlaid with the segmentations (n_samples is the batch size):

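    import matplotlib.pyplot as plt
    import numpy as np

    # squeeze=False keeps `ax` 2-D, so ax[i, 0] also works when n_samples == 1
    fig, ax = plt.subplots(n_samples, 1, figsize=(10, 10 * n_samples), squeeze=False)
    for i in range(n_samples):
        scan = sample['scan'][i].numpy()
        # Float copy, so NaN can be written into the masks below
        sub_structures = sample['structures'][i].numpy().astype(np.float32)

        # Reshape channels to last dimension: [3, H, W] -> [H, W, 3]
        scan = np.moveaxis(scan, 0, -1).astype(np.float32)
        # RGB input to imshow must be floats in [0, 1] (ints are clipped to [0, 255]);
        # a cmap has no effect on 3-channel data
        scan = (scan - scan.min()) / (scan.max() - scan.min() + 1e-8)
        ax[i, 0].imshow(scan)

        # Label structure j as j + 1 (was `i`, the sample index) and mark background as NaN
        for j in range(sub_structures.shape[0]):
            sub_structures[j] = np.where(sub_structures[j] < 1.0, np.nan, j + 1)
        overlay = np.nansum(sub_structures, axis=0)  # all-NaN (background) pixels become 0
        overlay = np.ma.masked_equal(overlay, 0)     # make the background transparent
        ax[i, 0].imshow(overlay, cmap='rainbow', alpha=0.5)

    plt.savefig('test.png', bbox_inches='tight', pad_inches=0)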

I have worked through numerous solutions but cannot pinpoint my mistake. Simply setting the self.scan_slice pixels to 1000 with NumPy shows that my transform block is executed. However, the TorchVision v2 transforms don't seem to have any effect.

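  • The solution came to me a couple of days later and, to be honest, should have been clear from the start: the v2 transforms (mostly) operate on tensors (and PIL images). Transforms such as flips and rotations simply pass a plain NumPy array through unchanged, without raising an error. So I converted the images from NumPy arrays to tensors before applying the transforms, and now they work: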

    import torch

    # Convert the NumPy image to a torch.Tensor before calling the v2 transform
    img = torch.from_numpy(img)