
How to change DataLoader in PyTorch to read one image for prediction?


I have a pre-trained model whose code uses a DataLoader to read batches of images:

self.data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False, 
   num_workers=1, pin_memory=True)

...

model.eval()
for step, inputs in enumerate(test_loader.data_loader):
   outputs = model(torch.cat([inputs], 1))

...

I want to make predictions on images as they arrive from a queue. The code should resemble something that reads a single image and runs the model on it, along the following lines:

from PIL import Image

new_input = Image.open(image_path)
model.eval()
outputs = model(torch.cat([new_input ], 1))

Could you show me how to do this while applying the same transformations as in the DataLoader?


Solution

  • You can do it with IterableDataset:

    from torch.utils.data import IterableDataset
    
    class MyDataset(IterableDataset):
        def __init__(self, image_queue):
            self.queue = image_queue
    
        def read_next_image(self):
            # Yield queued images until the queue is empty, then stop
            while self.queue.qsize() > 0:
                # you can add transform here
                yield self.queue.get()
    
        def __iter__(self):
            return self.read_next_image()
    

    Then use it in a DataLoader with batch_size=1:

    import queue
    import torch
    import torchvision.transforms.functional as TF
    from PIL import Image
    
    buffer = queue.Queue()
    new_input = Image.open(image_path)
    buffer.put(TF.to_tensor(new_input))
    # ... Populate queue here
    
    dataset = MyDataset(buffer)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1)
    model.eval()
    with torch.no_grad():  # inference only, so skip gradient tracking
        for data in dataloader:
            # data is a one-image batch of shape [1, 3, H, W], where 3 is the number of color channels
            outputs = model(data)
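
Note that the `qsize() > 0` loop above ends as soon as the queue is momentarily empty, so it suits a queue that is filled before iteration starts. If images keep arriving while the model runs, one possible variation is a blocking `get()` combined with a sentinel value that marks the end of the stream. The sketch below uses only the standard library to show the pattern; `iter_queue`, `_STOP`, and `producer` are hypothetical names introduced for illustration, and plain strings stand in for image tensors.

```python
import queue
import threading

# Hypothetical sentinel: the producer puts this on the queue to signal "no more images"
_STOP = object()

def iter_queue(q, transform=None, sentinel=_STOP):
    """Yield items from q as they arrive, stopping when the sentinel is received."""
    while True:
        item = q.get()  # blocks until the producer supplies an item
        if item is sentinel:
            return
        # Apply the optional per-item transform (e.g. the same preprocessing used in training)
        yield transform(item) if transform is not None else item

def producer(q, items):
    """Stand-in for whatever feeds the queue; strings play the role of images here."""
    for item in items:
        q.put(item)
    q.put(_STOP)

buffer = queue.Queue()
worker = threading.Thread(target=producer, args=(buffer, ["img0", "img1", "img2"]))
worker.start()

# str.upper is only a placeholder for a real image transform
received = list(iter_queue(buffer, transform=str.upper))
worker.join()
print(received)
```

A generator like this could be returned from `__iter__` in `MyDataset` in place of `read_next_image()`. This assumes the DataLoader's default `num_workers=0`, so the dataset is iterated in the same process as the thread that fills the queue.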