
TensorFlow version of PyTorch Transforms


I have the following code that I use to prepare images before performing inference in a model:

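from PIL import Image
from torchvision import transforms

def image_loader(transform, image_name):
    image = Image.open(image_name)
    # transform: PIL image -> normalized float tensor of shape (C, H, W)
    image = transform(image).float()
    # add a batch dimension: (1, C, H, W)
    image = image.unsqueeze(0)
    return image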

data_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

I've converted the model into a TensorFlow model; however, I'm unsure how to apply similar transformations to images before inference, since TensorFlow doesn't seem to have a direct equivalent of these transforms. Any advice?


Solution

  • Here are some pointers. In PyTorch, you have:

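    from torchvision import transforms
    from PIL import Image
    import numpy as np
    import matplotlib.pyplot as plt

    def image_loader(transform, image_name):
        # convert('RGB') drops any alpha channel so Normalize sees exactly 3 channels
        image = Image.open(image_name).convert('RGB')
        # (C, H, W) float tensor, scaled to [0, 1] and normalized per channel
        image = transform(image).float()
        # add a batch dimension: (1, C, H, W)
        image = image.unsqueeze(0)
        return image

    data_transforms = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    # check: visualize
    i = image_loader(data_transforms, '/content/1.png')
    print(i.shape)  # torch.Size([1, 3, H, W])

    plt.figure(figsize=(25, 10))
    plt.subplot(121)
    # matplotlib wants (H, W, C); normalized values fall outside [0, 1], so imshow clips them
    plt.imshow(i[0].numpy().transpose(1, 2, 0))
    plt.show()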
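To see exactly what the two torchvision transforms compute, here is a minimal NumPy-only sketch, assuming an 8-bit RGB image already loaded as an `(H, W, 3)` `uint8` array. `ToTensor` scales the pixels to [0, 1] and moves channels first (CHW); `Normalize` then subtracts the per-channel mean and divides by the per-channel std. The helper name `torchvision_like_preprocess` is hypothetical and used only for illustration.

```python
import numpy as np

# ImageNet statistics used in the transforms.Normalize call above
MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def torchvision_like_preprocess(rgb_uint8: np.ndarray) -> np.ndarray:
    """Hypothetical helper mimicking ToTensor() + Normalize() for a uint8 HWC image.

    Returns a float32 array of shape (1, 3, H, W), i.e. NCHW like the PyTorch loader.
    """
    # ToTensor part 1: scale 0..255 to 0..1
    x = rgb_uint8.astype(np.float32) / 255.0
    # Normalize: per-channel (x - mean) / std, broadcast over the last (channel) axis
    x = (x - MEAN) / STD
    # ToTensor part 2: move channels first, HWC -> CHW
    x = np.transpose(x, (2, 0, 1))
    # unsqueeze(0): add a batch dimension
    return x[np.newaxis, ...]


# Usage on a tiny synthetic 2x2 image with one white pixel
img = np.zeros((2, 2, 3), dtype=np.uint8)
img[0, 0] = [255, 255, 255]
out = torchvision_like_preprocess(img)
print(out.shape)  # (1, 3, 2, 2)
```

Doing the normalization before the transpose gives the same values as the torchvision order; only the final axis order matters for the model.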
    

    And in TensorFlow, you can achieve the same thing as follows:

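    from PIL import Image
    import numpy as np
    import tensorflow as tf
    import matplotlib.pyplot as plt

    def transform(image, mean, std):
        # per-channel (x - mean) / std; broadcasting applies it along the last (channel) axis
        return (image - np.asarray(mean)) / np.asarray(std)


    def image_loader(image_name):
        image = Image.open(image_name).convert('RGB')
        # np.array gives (H, W, 3) uint8; dividing by 255 scales to [0, 1] like ToTensor
        image = transform(np.array(image) / 255, mean=[0.485, 0.456, 0.406],
                          std=[0.229, 0.224, 0.225])
        image = tf.cast(image, tf.float32)
        # add a batch dimension: (1, H, W, 3), channels-last
        image = tf.expand_dims(image, 0)
        return image

    # check: visualize
    i = image_loader('/content/1.png')
    print(i.shape)  # (1, H, W, 3)

    plt.figure(figsize=(25, 10))
    plt.subplot(121)
    # already (H, W, C), so no transpose is needed; out-of-range values are clipped
    plt.imshow(i[0].numpy())
    plt.show()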
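If you would rather keep the whole pipeline in TensorFlow (for example inside a `tf.data` pipeline) instead of going through PIL and NumPy, a minimal sketch could use `tf.io.read_file`, `tf.io.decode_image` and `tf.image.convert_image_dtype`. This assumes TensorFlow 2.x in eager mode; the function name `tf_image_loader` and the temporary PNG written for the demo are hypothetical.

```python
import os
import tempfile

import tensorflow as tf

MEAN = tf.constant([0.485, 0.456, 0.406], dtype=tf.float32)
STD = tf.constant([0.229, 0.224, 0.225], dtype=tf.float32)


def tf_image_loader(path):
    """Hypothetical TF-only loader: returns a (1, H, W, 3) float32 tensor (NHWC)."""
    raw = tf.io.read_file(path)
    # channels=3 forces RGB (like PIL's .convert('RGB'));
    # expand_animations=False guarantees a 3-D (H, W, C) tensor even for GIFs
    image = tf.io.decode_image(raw, channels=3, expand_animations=False)
    # uint8 0..255 -> float32 0..1 (same scaling as ToTensor)
    image = tf.image.convert_image_dtype(image, tf.float32)
    # per-channel normalization, broadcast over the last axis
    image = (image - MEAN) / STD
    # add a batch dimension
    return tf.expand_dims(image, 0)


# Usage: write a tiny synthetic PNG (1 row, 2 pixels: white, black) and load it back
pixels = tf.constant([[[255, 255, 255], [0, 0, 0]]], dtype=tf.uint8)  # (1, 2, 3)
path = os.path.join(tempfile.mkdtemp(), "sample.png")
tf.io.write_file(path, tf.io.encode_png(pixels))
batch = tf_image_loader(path)
print(batch.shape)  # (1, 1, 2, 3)
```

Because everything here is a TensorFlow op, the same function could in principle be mapped over a `tf.data.Dataset` of file paths.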
    

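    The TensorFlow version should produce the same values as the PyTorch one, up to floating-point rounding. The one difference is the layout: TensorFlow gives (1, H, W, 3) (channels-last, the default layout for TensorFlow models) instead of PyTorch's (1, 3, H, W). Note that in the second case we write the transform function by hand rather than using a built-in op; that is fine here, but you can also check tf.keras...Normalization (see this answer for details).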
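The Keras `Normalization` layer mentioned above can take fixed statistics directly, which lets you bake the normalization into the model itself. A minimal sketch, assuming TensorFlow 2.6 or later, where the layer is available as `tf.keras.layers.Normalization` (older 2.x releases exposed it under `tf.keras.layers.experimental.preprocessing`). The layer expects a *variance*, not a std, so the std values are squared; the random input is only a stand-in for images already scaled to [0, 1].

```python
import numpy as np
import tensorflow as tf

mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

# Normalization computes (x - mean) / sqrt(variance), so pass std**2 as the variance.
# axis=-1 (the default) keeps one statistic per channel for an NHWC tensor.
normalize = tf.keras.layers.Normalization(
    axis=-1, mean=mean, variance=[s ** 2 for s in std]
)

# Stand-in for a batch of images already scaled to [0, 1], shape (1, H, W, 3)
x = np.random.default_rng(0).random((1, 4, 4, 3)).astype("float32")
y = normalize(x)

# Should match the hand-written transform up to float32 rounding
expected = (x - np.array(mean, dtype="float32")) / np.array(std, dtype="float32")
print(np.allclose(np.asarray(y), expected, atol=1e-5))  # True
```

Placing such a layer at the front of the converted model means callers only need to scale pixels to [0, 1]; the trade-off is that the statistics become part of the saved model rather than the input pipeline.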