
Model inference using batches


I am trying to run inference on images using the Pix2Struct vision transformer model (DePlot). Currently, I process one image at a time, using the code below.

from PIL import Image
import torch
from transformers import Pix2StructProcessor, Pix2StructForConditionalGeneration

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

processor = Pix2StructProcessor.from_pretrained(
    "google/deplot", is_vqa=True
)
model = Pix2StructForConditionalGeneration.from_pretrained(
    "google/deplot", is_vqa=True
).to(device)

with open('./data/test_imgs/test.png', "rb") as f:
    image = Image.open(f).convert("RGB")

    # Render the prompt onto the image and prepare the model inputs
    inputs = processor(
        images=image,
        text="Generate underlying data table of the figure below:",
        return_tensors="pt",
    ).to(device)
    predictions = model.generate(**inputs, max_new_tokens=512)
    deplot_result = processor.decode(
        predictions[0], skip_special_tokens=True
    )

    print(deplot_result)

However, inference takes about 45 seconds per image, which is not viable for our project. Is there a way to convert this code to use batches so that I can generate multiple predictions at the same time?


Solution

  • I wrote and tested the following code, and it works. The batch consists of two images; you should be able to generalize it to iterate over the entire dataset. I suggest implementing a PyTorch Dataset, but I don't know if that is possible in your case.

    If you want to test the code, remember to change the image file paths in image_file_paths.

    from PIL import Image
    import torch
    from transformers import Pix2StructProcessor, Pix2StructForConditionalGeneration
    
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    # Initialize the processor and model
    processor = Pix2StructProcessor.from_pretrained("google/deplot", is_vqa=True)
    model = Pix2StructForConditionalGeneration.from_pretrained("google/deplot", is_vqa=True).to(device)
    
    # load images in a list
    image_file_paths = ['./data/cat.jpg', './data/dog.jpg']
    images = [Image.open(f).convert("RGB") for f in image_file_paths]
    
    # Tokenize and prepare the inputs as a batch
    inputs = processor(
        images=images,
        text=["Generate underlying data table of the figure below:"] * len(image_file_paths),  # Use the same text for all images
        return_tensors="pt",
    ).to(device)
    
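    # Generate predictions for the batch and decode them.
    # max_new_tokens matches the single-image version; without it, generation
    # stops at the model's default maximum length, which may cut the table short.
    predictions = model.generate(**inputs, max_new_tokens=512)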
    for i, prediction in enumerate(predictions):
        deplot_result = processor.decode(prediction, skip_special_tokens=True)
        print(f"Result for image {i+1}: {deplot_result}")
    

    I tested it on these two images (img_1 and img_2) and got the following outputs:

    Result for image 1: TITLE |  <0x0A>  | Frequency<0x0A>(MHz) <0x0A> Unnamed: 1
    Result for image 2: Entity | Value <0x0A> rv | 4 <0x0A> t | 2 <0x0A>
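The answer suggests wrapping the images in a PyTorch Dataset to iterate over a whole folder. Below is a minimal sketch of that idea, assuming you have a list of image file paths and that torch and Pillow are installed. `ImageFolderDataset`, `collate_pil`, `run_all` and `make_deplot_runner` are hypothetical helper names, not part of transformers. A custom `collate_fn` is needed because PyTorch's default collate function cannot batch PIL images, while the processor happily accepts a plain list of them.

```python
from typing import Callable, Dict, List, Tuple

from PIL import Image
from torch.utils.data import DataLoader, Dataset


class ImageFolderDataset(Dataset):
    """Hypothetical dataset: loads each file in `image_paths` as an RGB PIL image."""

    def __init__(self, image_paths: List[str]):
        self.image_paths = list(image_paths)

    def __len__(self) -> int:
        return len(self.image_paths)

    def __getitem__(self, idx: int) -> Tuple[str, Image.Image]:
        path = self.image_paths[idx]
        # convert() returns a new, fully loaded image, so the file can be closed here
        with Image.open(path) as img:
            image = img.convert("RGB")
        return path, image


def collate_pil(batch):
    # Keep the images as a plain list instead of stacking them into tensors
    paths, images = zip(*batch)
    return list(paths), list(images)


def run_all(
    image_paths: List[str],
    run_batch: Callable[[List[Image.Image]], List[str]],
    batch_size: int = 8,
    num_workers: int = 0,
) -> Dict[str, str]:
    """Feed the images to `run_batch` one batch at a time; map each path to its result."""
    loader = DataLoader(
        ImageFolderDataset(image_paths),
        batch_size=batch_size,
        shuffle=False,  # keep the original order of the paths
        num_workers=num_workers,
        collate_fn=collate_pil,
    )
    results: Dict[str, str] = {}
    for paths, images in loader:
        for path, text in zip(paths, run_batch(images)):
            results[path] = text
    return results


def make_deplot_runner(
    processor,
    model,
    device,
    prompt: str = "Generate underlying data table of the figure below:",
    max_new_tokens: int = 512,
) -> Callable[[List[Image.Image]], List[str]]:
    """Wrap a Pix2Struct processor/model pair as a batch function for run_all."""

    def run_batch(images: List[Image.Image]) -> List[str]:
        inputs = processor(
            images=images,
            text=[prompt] * len(images),  # same prompt for every image
            return_tensors="pt",
        ).to(device)
        predictions = model.generate(**inputs, max_new_tokens=max_new_tokens)
        # Decode the whole batch in one call
        return processor.batch_decode(predictions, skip_special_tokens=True)

    return run_batch
```

With the `processor`, `model` and `device` from the code above, `run_all(paths, make_deplot_runner(processor, model, device), batch_size=8)` would return a dict mapping each path to its decoded table. The batch size that fits depends on your GPU memory, so treat 8 as a starting point to tune rather than a recommendation; setting `num_workers` above 0 lets image loading overlap with generation.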