
Merge images by mask


I'm trying to merge two images using a mask: wherever the mask is 1, the result should take the value from the first image, and everywhere else it should take the value from the second image. Does anyone know how to do this in PyTorch? In NumPy, it can be done with:

>>> import numpy as np
>>> img1 = np.random.rand(100,100,3)
>>> img2 = np.random.rand(100,100,3)
>>> mask = np.random.rand(100,100)>.5
>>> res = img2.copy()
>>> res[mask] = img1[mask]
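The same boolean-indexing approach carries over to PyTorch tensors almost unchanged. Below is a minimal sketch, assuming float tensors shaped H x W x 3 and a boolean H x W mask, mirroring the NumPy snippet. `clone()` plays the role of NumPy's `copy()`:

```python
import torch

# Random H x W x 3 images and an H x W boolean mask, as in the NumPy example
img1 = torch.rand(100, 100, 3)
img2 = torch.rand(100, 100, 3)
mask = torch.rand(100, 100) > 0.5

# Start from a copy of img2 so the original tensor is left untouched
res = img2.clone()
# A boolean mask over the first two dims selects whole pixels (N x 3),
# so every channel of a masked pixel is replaced by img1's values
res[mask] = img1[mask]
```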

Solution

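  • You are looking for np.where. Because mask has shape (100, 100) while the images are (100, 100, 3), give the mask a trailing axis so it broadcasts across the color channels:

    res = np.where(mask[..., None], img1, img2)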
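Since the question asks about PyTorch, the direct counterpart is `torch.where(condition, input, other)`, which broadcasts its arguments the same way `np.where` does. Here is a minimal sketch, assuming the same shapes as the question. It also shows that with PyTorch's common C x H x W layout (an assumption about how your images are stored), an H x W mask already lines up with the trailing dimensions and needs no extra axis:

```python
import torch

img1 = torch.rand(100, 100, 3)     # H x W x C, as in the question
img2 = torch.rand(100, 100, 3)
mask = torch.rand(100, 100) > 0.5  # H x W boolean mask

# H x W x C layout: add a trailing singleton dim so the mask
# broadcasts over the channel axis
res = torch.where(mask.unsqueeze(-1), img1, img2)

# C x H x W layout: the H x W mask broadcasts over the leading
# channel dim as-is, so no unsqueeze is needed
img1_chw = img1.permute(2, 0, 1)
img2_chw = img2.permute(2, 0, 1)
res_chw = torch.where(mask, img1_chw, img2_chw)
```

Unlike the clone-and-assign version, `torch.where` builds the result in a single out-of-place operation, which keeps it straightforward to use inside code that relies on autograd.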