
TorchVision v2 transform - Can't get transform to work (or just plot...?)


I am working on training a DeepLabV3+ model to segment CT images, and as a test I want to visualise the transformations being applied. However, I cannot get the transformed slices to plot/save.

The slice has been transformed to [3, 512, 512], while the multilabel stack is [21, 512, 512]. My current transform block is as follows (I have excluded the multilabel stack to simplify the problem):

if self.augment:
    # Random flips and a small rotation as augmentation
    transform = v2.Compose([
        v2.RandomHorizontalFlip(p=0.5),
        v2.RandomVerticalFlip(p=0.5),
        v2.RandomRotation(degrees=(0, 15))
    ])

    self.scan_slice = transform(self.scan_slice)

And I use the following block to plot and save the slice combined with the segmentations, where n_samples is the batch size.

fig, ax = plt.subplots(n_samples, 1, figsize=(10, 10 * n_samples))
for i in range(n_samples):
    scan = sample['scan'][i].numpy()
    sub_structures = sample['structures'][i].numpy()
    # Reshape channels to the last dimension
    scan = np.moveaxis(scan, 0, -1).astype(np.int16)
    ax[i].imshow(scan, cmap='gray', alpha=1)

    # Mark background pixels (below 1.0) as NaN
    for j in range(sub_structures.shape[0]):
        sub_structures[j, :, :] = np.where(sub_structures[j, :, :] < 1.0, np.nan, i)

    # Collapse the label channels into a single 2-D overlay (NaNs count as zero)
    sub_structures = np.nansum(sub_structures, axis=0)
    ax[i].imshow(sub_structures, cmap='rainbow', alpha=0.5)
    ax[i].axis('off')
plt.savefig('test.png', bbox_inches='tight', pad_inches=0)

I have worked through numerous suggested solutions but cannot pinpoint my mistake. Simply setting the self.scan_slice pixels to 1000 with NumPy confirms that the transform block is reached; however, the TorchVision v2 transforms themselves don't seem to take effect.
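
For reference, here is a minimal sketch of what I am observing, using a random float32 array as a stand-in for the slice (the array comes back from the transform untouched):

import numpy as np
from torchvision.transforms import v2

# Stand-in for the CT slice: a random [3, 512, 512] float32 array
arr = np.random.rand(3, 512, 512).astype(np.float32)

# p=1.0 so the flip should always fire
flip = v2.RandomHorizontalFlip(p=1.0)
out = flip(arr)

print(np.array_equal(out, arr))  # prints True: the array is returned unchanged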

Could someone point me in the right direction?


Solution

  • The solution came to me a couple of days later and, to be honest, should have been clear from the start. The v2 transformations (mostly) work on tensors and PIL Images, and silently pass other input types, such as NumPy arrays, through unchanged. So I changed the images from NumPy arrays to tensors and the transforms work.

    img = torch.from_numpy(img)
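
    Since the multilabel stack must receive exactly the same random flip/rotation as the slice, the sketch below shows how I would wire this up with the v2 API. It assumes torchvision >= 0.16, where the tv_tensors wrappers are available (older releases expose them as torchvision.datapoints), and scan_slice_np / structures_np are hypothetical stand-ins for my arrays:

    import torch
    from torchvision import tv_tensors
    from torchvision.transforms import v2

    transform = v2.Compose([
        v2.RandomHorizontalFlip(p=0.5),
        v2.RandomVerticalFlip(p=0.5),
        v2.RandomRotation(degrees=(0, 15))
    ])

    # Wrapping the inputs tells v2 how to treat each one; a single call then
    # applies the same random parameters to the image and the masks.
    scan_slice = tv_tensors.Image(torch.from_numpy(scan_slice_np))   # [3, 512, 512]
    structures = tv_tensors.Mask(torch.from_numpy(structures_np))    # [21, 512, 512]

    scan_slice, structures = transform(scan_slice, structures)

    As far as I can tell, v2 also resamples Mask inputs with nearest-neighbour interpolation during the rotation, so the label values are not blended at structure boundaries.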