Tags: python, pytorch, fastapi, starlette

How do I convert a torch tensor to an image to be returned by FastAPI?


I have a torch tensor that I need to convert to a bytes object so that I can pass it to Starlette's StreamingResponse, which should return an image reconstructed from those bytes. I am trying to convert the tensor and return it like so:

import io

import torch
from starlette.responses import StreamingResponse

def some_unimportant_function(params):
    return_image = io.BytesIO()
    torch.save(some_img, return_image)
    return_image.seek(0)
    return_img = return_image.read()

    return StreamingResponse(content=return_img, media_type="image/jpeg")
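A quick way to see what goes wrong in the snippet above: torch.save writes PyTorch's own pickle-based serialization format, not an encoded image. The sketch below serializes a random stand-in tensor the same way and checks for the JPEG Start Of Image marker (the bytes 0xFF 0xD8). The names torch_save_bytes and looks_like_jpeg, and the random some_img tensor, are hypothetical and only for illustration.

```python
import io

import torch


def torch_save_bytes(tensor: torch.Tensor) -> bytes:
    # Same steps as the snippet above: serialize the tensor into memory with torch.save.
    buffer = io.BytesIO()
    torch.save(tensor, buffer)
    return buffer.getvalue()


def looks_like_jpeg(data: bytes) -> bool:
    # Every JPEG file begins with the Start Of Image marker 0xFF 0xD8.
    return data[:2] == b"\xff\xd8"


# Hypothetical stand-in for the question's some_img tensor.
some_img = torch.rand(3, 64, 64)
saved = torch_save_bytes(some_img)
print(looks_like_jpeg(saved))  # False
```

The bytes round-trip through torch.load, but a browser or image viewer cannot decode them as a JPEG, even if the response is labelled image/jpeg.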

The following works fine with regular bytes objects, and my API returns the reconstructed image:

import io

from PIL import Image
from starlette.responses import StreamingResponse

def some_unimportant_function(params):
    image = Image.open(io.BytesIO(some_image))

    return_image = io.BytesIO()
    image.save(return_image, "JPEG")
    return_image.seek(0)
    return StreamingResponse(content=return_image, media_type="image/jpeg")

I am using the PIL library for this.

What am I doing wrong here?


Solution

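  • torch.save writes PyTorch's own serialization format (a pickle-based archive of the tensor), not an encoded image, so the bytes it produces are not a JPEG that a client can display. Instead, convert the tensor to a PIL Image with torchvision.transforms.ToPILImage() and then treat it exactly like the image in your second function. Here is an example: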

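    import io

    import torchvision.transforms as T
    from starlette.responses import StreamingResponse

    def some_unimportant_function(params):
        # load_tensor is a hypothetical placeholder: read the tensor from disk or wherever
        tensor = load_tensor(params)
        # ToPILImage expects a C x H x W (or H x W) tensor; float values are assumed to lie in [0, 1]
        image = T.ToPILImage()(tensor)
        return_image = io.BytesIO()
        image.convert("RGB").save(return_image, "JPEG")  # JPEG cannot store an alpha channel
        return_image.seek(0)
        return StreamingResponse(content=return_image, media_type="image/jpeg")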