python, opencv, torch

Torch CycleGAN model loaded with OpenCV doesn't output the desired image


I'm trying to load a pretrained Torch (.t7) model in OpenCV. The model is a CycleGAN that converts horse images into zebra images. The model can be found here: https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/models/

I've tried both the horse2zebra.t7 and horse2zebra_cpu.t7 models, but both return a tiled black and white image instead of a zebra image.

This is a sample input image:

[image]

And this is the output:

[image]

Code:

import cv2
import numpy as np

model = cv2.dnn.readNetFromTorch('./cyclegan_horse2zebra_cpu.t7')

image = cv2.imread('./images/1.jpg')

blob = cv2.dnn.blobFromImage(image, 1, (256, 256))
model.setInput(blob)

out = model.forward()[0,:,:,:]
out = np.reshape(out, (256, 256, 3))

cv2.imshow('image', out)
cv2.waitKey(0)

cv2.imwrite('out.png', out)

Solution

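  • Two things are missing. First, model.forward()[0,:,:,:] returns a 3x256x256 blob: a planar image with the channels in the first dimension (the second dimension of the full 1x3x256x256 output). OpenCV works with interleaved images, so you need to permute the dimensions with np.transpose to get 256x256x3. np.reshape does not do this: it keeps the values in the same memory order and only relabels the shape, so the channel planes get mixed into the pixels.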

    The range of the output values also matters. For the image from the question, I got values in [-0.832621, 0.891473]. They need to be normalized to [0, 255] and converted to uint8. Both steps can be done in one call to cv::normalize (cv2.normalize in Python).

    import cv2
    import numpy as np
    
    model = cv2.dnn.readNetFromTorch('./horse2zebra_cpu.t7')
    
    image = cv2.imread('./images/1.jpg')
    
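    # Build a 1x3x256x256 input blob (scale factor 1, resized to 256x256)
    blob = cv2.dnn.blobFromImage(image, 1, (256, 256))
    model.setInput(blob)
    
    # Take the only item in the batch: a planar 3x256x256 array
    out = model.forward()[0,:,:,:]
    # Stretch the value range to [0, 255] and convert to uint8
    out = cv2.normalize(out, dst=None, alpha=0, beta=255, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_8U)
    
    # Planar (channels first) to interleaved (channels last): 256x256x3
    out = np.transpose(out, (1, 2, 0))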
    
    cv2.imshow('image', out)
    cv2.waitKey(0)
    
    cv2.imwrite('out.png', out)
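
    The difference between np.reshape and np.transpose described above can be seen on a tiny array. This is only an illustrative sketch: the 3x2x2 array and the variable names are made up for the example, and it uses NumPy alone rather than the model output.

```python
import numpy as np

# Tiny planar (channels-first) "image": 3 channels, 2x2 pixels.
# Channel c is filled with the constant value c, so every pixel
# of the correct interleaved image should read (0, 1, 2).
chw = np.stack([np.full((2, 2), c, dtype=np.float32) for c in range(3)])

# reshape keeps the memory order and only relabels the shape
reshaped = np.reshape(chw, (2, 2, 3))

# transpose moves the channel axis to the end
transposed = np.transpose(chw, (1, 2, 0))

print(reshaped[0, 0])    # three values taken from channel 0 only
print(transposed[0, 0])  # one value from each channel
```

    With reshape, the first "pixel" is built from three neighbouring values of the same channel, which is consistent with the tiled, grayscale-looking output in the question.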
    

    Please note that the original framework may use a different postprocessing procedure. It would be good to compare the outputs from OpenCV and Torch. This is the image I got from the script above:

    [image]
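
    As a starting point for comparing postprocessing choices, here is a hedged sketch of two ways to turn the float output into uint8. The helper names minmax_to_uint8 and tanh_to_uint8 are hypothetical. The first roughly mirrors what min-max normalization does; the second assumes the generator ends in a tanh layer, so that values lie in a fixed [-1, 1] range. That assumption is not confirmed here and should be checked against the Torch code.

```python
import numpy as np

def minmax_to_uint8(x):
    """Stretch the observed [min, max] of x to [0, 255] (min-max normalization)."""
    lo, hi = float(x.min()), float(x.max())
    if hi == lo:
        # Constant input: nothing to stretch
        return np.zeros(x.shape, dtype=np.uint8)
    scaled = (x - lo) / (hi - lo) * 255.0
    return np.clip(np.rint(scaled), 0, 255).astype(np.uint8)

def tanh_to_uint8(x):
    """Map a fixed [-1, 1] range to [0, 255]. Assumes a tanh output layer."""
    scaled = (x + 1.0) / 2.0 * 255.0
    return np.clip(np.rint(scaled), 0, 255).astype(np.uint8)

# The extreme values reported above, plus zero for reference
sample = np.array([-0.832621, 0.0, 0.891473], dtype=np.float32)
a = minmax_to_uint8(sample)
b = tanh_to_uint8(sample)
print(a, b)
```

    Min-max scaling always uses the full [0, 255] range, so contrast depends on each image's own extremes, while the fixed mapping gives the same brightness scale for every image. Which one matches the original Torch output is something to verify by comparison.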