python, pytorch, huggingface-transformers

How to display the reconstructed image from huggingface ViTMAEModel?


I am using the following code example:

Using the autoencoder, I want to display the reconstructed image. How can I display it?

from transformers import AutoImageProcessor, ViTMAEModel
from PIL import Image
import requests

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

image_processor = AutoImageProcessor.from_pretrained("facebook/vit-mae-base")
model = ViTMAEModel.from_pretrained("facebook/vit-mae-base")

inputs = image_processor(images=image, return_tensors="pt")
outputs = model(**inputs)
last_hidden_states = outputs.last_hidden_state

Solution

  • I encountered the same issue. ViTMAEModel is only the encoder, so its output contains hidden states rather than reconstructed pixels; to get a reconstruction you need ViTMAEForPreTraining, which adds the decoder and returns the predicted patches as logits. The official ViTMAE documentation points to ViT_MAE_visualization_demo.ipynb, which the code below follows.

    import torch
    import numpy as np
    import matplotlib.pyplot as plt
    
    from transformers import ViTFeatureExtractor, ViTMAEForPreTraining  # ViTFeatureExtractor is deprecated in newer transformers; ViTImageProcessor works the same way here
    import requests
    from PIL import Image
    
    feature_extractor = ViTFeatureExtractor.from_pretrained("facebook/vit-mae-base")
    imagenet_mean = np.array(feature_extractor.image_mean)
    imagenet_std = np.array(feature_extractor.image_std)
    
    def show_image(image, title=''):
        # image is [H, W, 3]
        assert image.shape[2] == 3
        plt.imshow(torch.clip((image * imagenet_std + imagenet_mean) * 255, 0, 255).int())
        plt.title(title, fontsize=16)
        plt.axis('off')
        return
    
    def visualize(pixel_values, model):
        # forward pass
        outputs = model(pixel_values)
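        # unpatchify() folds the flattened patches predicted by the decoder
        # (N, num_patches, patch_size**2 * 3) back into an image tensor (N, 3, H, W)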
        y = model.unpatchify(outputs.logits)
        y = torch.einsum('nchw->nhwc', y).detach().cpu()
        
        # visualize the mask
        mask = outputs.mask.detach()
        mask = mask.unsqueeze(-1).repeat(1, 1, model.config.patch_size**2 * 3)  # (N, H*W, p*p*3)
        mask = model.unpatchify(mask)  # 1 is removing, 0 is keeping
        mask = torch.einsum('nchw->nhwc', mask).detach().cpu()
        
        x = torch.einsum('nchw->nhwc', pixel_values)
    
        # masked image
        im_masked = x * (1 - mask)
    
        # MAE reconstruction pasted with visible patches
        im_paste = x * (1 - mask) + y * mask
    
        # make the plt figure larger
        plt.rcParams['figure.figsize'] = [24, 24]
    
        plt.subplot(1, 4, 1)
        show_image(x[0], "original")
    
        plt.subplot(1, 4, 2)
        show_image(im_masked[0], "masked")
    
        plt.subplot(1, 4, 3)
        show_image(y[0], "reconstruction")
    
        plt.subplot(1, 4, 4)
        show_image(im_paste[0], "reconstruction + visible")
    
        plt.show()
    
    
    url = "https://user-images.githubusercontent.com/11435359/147738734-196fd92f-9260-48d5-ba7e-bf103d29364d.jpg"
    image = Image.open(requests.get(url, stream=True).raw)
    
    pixel_values = feature_extractor(image, return_tensors="pt").pixel_values  # (1, 3, 224, 224), normalized with the ImageNet mean/std
    
    # make random mask reproducible (comment out to make it change)
    torch.manual_seed(2)
    
    model = ViTMAEForPreTraining.from_pretrained("facebook/vit-mae-base")
    
    visualize(pixel_values, model)
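
    If you just want the reconstructed image on its own rather than the four-panel plot, a minimal sketch (repeating the forward pass from visualize() and reusing imagenet_mean/imagenet_std plus the numpy and PIL imports from the snippet above) is to undo the ImageNet normalization and save the result with PIL:

    outputs = model(pixel_values)
    y = model.unpatchify(outputs.logits)               # (1, 3, 224, 224)
    y = torch.einsum('nchw->nhwc', y).detach().cpu()

    recon = y[0].numpy()                                   # [H, W, 3], still normalized
    recon = (recon * imagenet_std + imagenet_mean) * 255   # undo ImageNet normalization
    recon = np.clip(recon, 0, 255).astype(np.uint8)
    Image.fromarray(recon).save("reconstruction.png")

    Note that this is the raw decoder output for every patch; if you want to keep the visible patches from the original image, combine it with outputs.mask as done in visualize().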