
TF depth_to_space not same as Torch's PixelShuffle when output channels > 1?


I noticed something interesting when trying to port a PyTorch-trained model to TensorFlow (TF). When the number of output channels of a PixelShuffle operation is greater than one, the corresponding depth_to_space function in TF is not equivalent (note: I convert the TF input to NHWC and the output back to NCHW). Is this expected behavior, or am I misunderstanding something?

Specifically,

# Torch
torch_out = nn.PixelShuffle(2)(input)

and

# TF/Keras
input = np.transpose(input, (0, 2, 3, 1))  # Convert to NHWC
keras_input = keras.layers.Input(shape=input.shape[1:])
keras_d2s = keras.layers.Lambda(lambda x: tf.nn.depth_to_space(x, 2))(keras_input)
...
keras_out = np.transpose(keras_d2s, (0, 3, 1, 2))  # Convert back to NCHW

and

keras_out != torch_out
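To see why the two results differ, it may help to write out which input channel each op reads for a given output pixel. The sketch below is plain Python; the helper names torch_src_channel, tf_src_channel and mismatches are hypothetical, and the formulas assume PyTorch's documented PixelShuffle layout (NCHW) and TF's default NHWC depth_to_space layout.

```python
# Hypothetical helpers: for output channel c at sub-pixel offset (dy, dx)
# inside each r x r block, return the input channel that supplies the value.

def torch_src_channel(c, dy, dx, r):
    # PixelShuffle (NCHW) reads the input channel axis as (oc, r, r):
    # the output channel c is the slowest-varying part of the index.
    return c * r * r + dy * r + dx


def tf_src_channel(c, dy, dx, r, oc):
    # tf.nn.depth_to_space (NHWC) reads the channel axis as (r, r, oc):
    # the output channel c is the fastest-varying part of the index.
    return (dy * r + dx) * oc + c


def mismatches(oc, r):
    # All (c, dy, dx) positions where the two ops read different input channels.
    return [
        (c, dy, dx)
        for c in range(oc)
        for dy in range(r)
        for dx in range(r)
        if torch_src_channel(c, dy, dx, r) != tf_src_channel(c, dy, dx, r, oc)
    ]


print(mismatches(1, 2))       # [] : with one output channel the layouts coincide
print(len(mismatches(2, 2)))  # 6 : the testbench case (ic = 8, s = 2, so oc = 2)
```

For oc = 1 both formulas reduce to dy * r + dx, which is why the discrepancy only appears when the number of output channels is greater than one.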

Here is a testbench:

import numpy as np

import torch
import tensorflow as tf

from torch import nn
from tensorflow import keras

class Shuffle(nn.Module):
    def __init__(self, s, k, ic):
        super(Shuffle, self).__init__()
        self.shuffle = nn.PixelShuffle(s)

    def forward(self, inputs):
        return self.shuffle(inputs)

def main():
    sz = 4

    h = 3
    w = 3
    k = 3
    ic = 8
    s = 2

    # Input of shape (1, ic, h, w) with distinct values, so any reordering is visible
    input = np.arange(0, ic * h * w, dtype=np.float32).reshape(ic, h, w)
    input = input[np.newaxis]
    torch_input = torch.from_numpy(input)

    shuffle_model = Shuffle(s, k, ic)
    shuffle_out = shuffle_model(torch_input).detach().numpy()
    print('Shuffle out:', shuffle_out.shape)
    print(shuffle_out)

    input = np.transpose(input, (0, 2, 3, 1))  # NCHW -> NHWC for TF

    keras_input = keras.layers.Input(shape=input.shape[1:])
    keras_d2s = keras.layers.Lambda(lambda x: tf.nn.depth_to_space(x, s))(keras_input)
    keras_model = keras.Model(keras_input, keras_d2s)
    keras_out = keras_model.predict(input)
    
    keras_out = np.transpose(keras_out, (0, 3, 1, 2))  # NHWC -> NCHW to compare with Torch
    
    print('Keras out:', keras_out.shape)
    print(keras_out)
    equal = np.allclose(shuffle_out, keras_out)
    print('Equal?', equal)

if __name__ == '__main__':
    main()

Solution

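  • They are indeed different. PixelShuffle splits the input channels as (oc, s, s), with the output channel varying slowest, while depth_to_space splits them as (s, s, oc), with the output channel varying fastest; the two only coincide when oc == 1. If you want them to match, you need to shuffle the channels of one of the inputs. Alternatively, if the pixelshuffle/depth_to_space layer follows a convolution layer, you can shuffle the output channels of the convolution's weights (and bias) instead. Specifically, if oc is the number of output channels and s is the block_size, the permutation is [i + oc * j for i in range(oc) for j in range(s ** 2)] (for oc = 2 and s = 2 this yields [0, 2, 4, 6, 1, 3, 5, 7]). Entry k of this list is the TF channel that PyTorch channel k must be moved to; if you reorder with NumPy-style indexing (w[perm]), use its inverse, np.argsort(perm), instead.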
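A minimal NumPy sketch of the fix, so it runs without torch or TF. It assumes that the reference implementations pixel_shuffle_nchw and depth_to_space_nhwc (hypothetical stand-ins, not library functions) reproduce the two ops. It builds the list from the answer, uses its inverse as a gather index on the input channels, and then applies the same reordering to the output channels of a preceding 1x1 convolution, modeled here as a plain channel-mixing matrix.

```python
import numpy as np


def pixel_shuffle_nchw(x, r):
    # Reference PixelShuffle: channel axis split as (oc, r, r).
    n, cin, h, w = x.shape
    oc = cin // (r * r)
    x = x.reshape(n, oc, r, r, h, w)
    x = x.transpose(0, 1, 4, 2, 5, 3)  # n, oc, h, dy, w, dx
    return x.reshape(n, oc, h * r, w * r)


def depth_to_space_nhwc(x, r):
    # Reference depth_to_space: channel axis split as (r, r, oc).
    n, h, w, cin = x.shape
    oc = cin // (r * r)
    x = x.reshape(n, h, w, r, r, oc)
    x = x.transpose(0, 1, 3, 2, 4, 5)  # n, h, dy, w, dx, oc
    return x.reshape(n, h * r, w * r, oc)


oc, r = 2, 2
s2 = r * r
# List from the answer: scatter[k_torch] = k_tf (where Torch channel k goes).
scatter = [i + oc * j for i in range(oc) for j in range(s2)]
# Its inverse works as a gather index: x_tf_order = x_torch_order[:, gather]
gather = np.argsort(scatter)

x = np.arange(oc * s2 * 3 * 3, dtype=np.float32).reshape(1, oc * s2, 3, 3)
torch_out = pixel_shuffle_nchw(x, r)

# Naive port: only transpose NCHW <-> NHWC around depth_to_space.
naive = depth_to_space_nhwc(x.transpose(0, 2, 3, 1), r).transpose(0, 3, 1, 2)
# Fixed port: reorder the channels first, then transpose.
fixed = depth_to_space_nhwc(x[:, gather].transpose(0, 2, 3, 1), r).transpose(0, 3, 1, 2)

print(np.array_equal(naive, torch_out))  # False
print(np.array_equal(fixed, torch_out))  # True

# Same idea for a preceding conv: reorder its *output* channels (rows of w).
rng = np.random.default_rng(0)
w = rng.standard_normal((oc * s2, 5)).astype(np.float32)  # (out_ch, in_ch)
feat = rng.standard_normal((1, 5, 3, 3)).astype(np.float32)


def conv1x1(weight, f):
    # 1x1 convolution as a channel-mixing matrix (simplifying assumption).
    return np.einsum("oi,nihw->nohw", weight, f)


ref = pixel_shuffle_nchw(conv1x1(w, feat), r)
ported = depth_to_space_nhwc(
    conv1x1(w[gather], feat).transpose(0, 2, 3, 1), r
).transpose(0, 3, 1, 2)
print(np.allclose(ported, ref))  # True
```

With real frameworks, the gather would presumably be applied along axis 0 of the PyTorch conv weight (shape (out, in, kh, kw)) and to its bias, i.e. w[gather] and b[gather], before transposing the weight to the Keras Conv2D kernel layout (kh, kw, in, out).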