python · tensorflow · keras · pytorch · dimensions

Keras Upsampling2d vs PyTorch Upsampling


I am trying to convert a Keras model to PyTorch. It involves the UpSampling2D layer from Keras. When I use torch.nn.UpsamplingNearest2d in PyTorch (since the default interpolation of UpSampling2D in Keras is nearest), I get inconsistent results. The example is as follows:

Keras behaviour

In [3]: t1 = tf.random_normal([32, 8, 8, 512]) # as we have channels last in keras                                  

In [4]: u_s = tf.keras.layers.UpSampling2D(2)(t1)                               

In [5]: u_s.shape                                                               
Out[5]: TensorShape([Dimension(32), Dimension(16), Dimension(16), Dimension(512)])

So the output shape is (32,16,16,512). Now let's do the same thing with PyTorch.

PyTorch Behaviour

In [2]: t1 = torch.randn([32,512,8,8]) # as channels first in pytorch

In [3]: u_s = torch.nn.UpsamplingNearest2d(2)(t1)

In [4]: u_s.shape
Out[4]: torch.Size([32, 512, 2, 2])

Here the output shape is (32, 512, 2, 2), whereas the Keras result corresponds to (32, 512, 16, 16) in channels-first layout.

So how do I get results in PyTorch equivalent to Keras? Thanks.


Solution

  • In Keras, upsampling is specified by a scaling factor. SOURCE.

    tf.keras.layers.UpSampling2D(size, interpolation='nearest')
    

    size: Int, or tuple of 2 integers. The upsampling factors for rows and columns.

    PyTorch, on the other hand, supports both a direct output size and a scaling factor. SOURCE.

    torch.nn.UpsamplingNearest2d(size=None, scale_factor=None)
    

    It takes either size or scale_factor as its constructor argument. Note that a bare positional argument is interpreted as size, which is why UpsamplingNearest2d(2) produced a 2×2 output above.


    So, in your case

    # scaling factor in keras 
    t1 = tf.random.normal([32, 8, 8, 512])
    tf.keras.layers.UpSampling2D(2)(t1).shape
    TensorShape([32, 16, 16, 512])
    
    # direct output size in pytorch 
    t1 = torch.randn([32,512,8,8]) # as channels first in pytorch
    torch.nn.UpsamplingNearest2d(size=(16, 16))(t1).shape
    # or torch.nn.UpsamplingNearest2d(size=16)(t1).shape
    torch.Size([32, 512, 16, 16])
    
    # scaling factor in pytorch.
    torch.nn.UpsamplingNearest2d(scale_factor=2)(t1).shape
    torch.Size([32, 512, 16, 16])
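
    If you want to verify numerically that the PyTorch result matches what Keras would produce, you can permute the channels-last tensor to channels-first around the upsampling op. This is a minimal sketch (using a small tensor for brevity, and only PyTorch, since nearest-neighbour upsampling is the same operation in both frameworks):

    ```python
    import torch

    # Channels-last tensor, as a Keras layer would produce it (NHWC)
    t_nhwc = torch.randn(2, 8, 8, 3)

    # Permute to channels-first (NCHW) for PyTorch
    t_nchw = t_nhwc.permute(0, 3, 1, 2)

    # Upsample by a factor of 2 with nearest-neighbour interpolation
    up = torch.nn.UpsamplingNearest2d(scale_factor=2)(t_nchw)

    # Permute back to channels-last to compare against the Keras layout
    out_nhwc = up.permute(0, 2, 3, 1)
    print(out_nhwc.shape)  # torch.Size([2, 16, 16, 3])

    # Nearest upsampling by 2 repeats each pixel in a 2x2 block, so the
    # top-left element of each block equals the original pixel:
    assert torch.equal(out_nhwc[:, ::2, ::2, :], t_nhwc)
    ```

    The same result can be obtained with torch.nn.functional.interpolate(t_nchw, scale_factor=2, mode='nearest'), which is the more general API covering the Upsampling* modules.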