I noticed something interesting while porting a torch-trained model to TensorFlow (TF). When a PixelShuffle operation has more than one output channel, the corresponding depth_to_space function in TF does not give an equivalent result. (Note: I convert the input to NHWC for TF and the output back to NCHW.) Is this expected behavior, or am I misunderstanding something?
Specifically,
# Torch
torch_out = nn.PixelShuffle(2)(input)
and
# TF/Keras
input = np.transpose(input, (0, 2, 3, 1))  # Convert to NHWC
keras_input = keras.layers.Input(shape=input.shape[1:])
keras_d2s = keras.layers.Lambda(lambda x: tf.nn.depth_to_space(x, 2))(keras_input)
...
keras_out = np.transpose(keras_d2s, (0, 3, 1, 2))  # Convert back to NCHW
and
keras_out != torch_out
Here is a testbench:
import numpy as np
import torch
import tensorflow as tf
from torch import nn
from tensorflow import keras


class Shuffle(nn.Module):
    def __init__(self, s, k, ic):
        super(Shuffle, self).__init__()
        self.shuffle = nn.PixelShuffle(s)

    def forward(self, inputs):
        return self.shuffle(inputs)


def main():
    sz = 4
    h = 3
    w = 3
    k = 3
    ic = 8
    s = 2

    input = np.arange(0, ic * h * w, dtype=np.float32).reshape(ic, h, w)
    input = input[np.newaxis]

    torch_input = torch.from_numpy(input)
    shuffle_model = Shuffle(s, k, ic)
    shuffle_out = shuffle_model(torch_input).detach().numpy()
    print('Shuffle out:', shuffle_out.shape)
    print(shuffle_out)

    input = np.transpose(input, (0, 2, 3, 1))
    keras_input = keras.layers.Input(shape=input.shape[1:])
    keras_d2s = keras.layers.Lambda(lambda x: tf.nn.depth_to_space(x, s))(keras_input)
    keras_model = keras.Model(keras_input, keras_d2s)
    keras_out = keras_model.predict(input)
    keras_out = np.transpose(keras_out, (0, 3, 1, 2))
    print('Keras out:', keras_out.shape)
    print(keras_out)

    equal = np.allclose(shuffle_out, keras_out)
    print('Equal?', equal)


if __name__ == '__main__':
    main()
They are indeed different. To make them match, you need to shuffle the channels of one of the inputs. Alternatively, if the PixelShuffle/depth_to_space layer follows a convolution layer, you can shuffle the channels of that convolution's weights instead. Specifically, if oc is the number of output channels and s is the block_size, then you need to permute the channels of the convolution's weights in TF using [i + oc * j for i in range(oc) for j in range(s ** 2)] (yields something like [0, 2, 4, 1, 3, 5]).
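
Here is a NumPy-only sketch of the channel reordering described above, in case it helps to see it run. The helper names (pixel_shuffle_nchw, depth_to_space_nhwc, torch_to_tf_channel_order) are made up for this illustration. The two reference functions are reshape/transpose re-implementations that assume the documented channel layouts of nn.PixelShuffle and of tf.nn.depth_to_space in NHWC; they are not calls into either library. Note that the index list here is applied as a gather on torch-ordered channels (x[..., perm]), which works out to the inverse of the list quoted above. Which of the two you need depends on which side you reorder and on whether you index with the list or assign through it.

```python
import numpy as np


def pixel_shuffle_nchw(x, r):
    """NumPy stand-in for nn.PixelShuffle(r) on an NCHW array (assumed layout)."""
    n, c, h, w = x.shape
    oc = c // (r * r)
    x = x.reshape(n, oc, r, r, h, w)
    x = x.transpose(0, 1, 4, 2, 5, 3)  # -> (n, oc, h, r, w, r)
    return x.reshape(n, oc, h * r, w * r)


def depth_to_space_nhwc(x, r):
    """NumPy stand-in for tf.nn.depth_to_space(x, r) on an NHWC array (assumed layout)."""
    n, h, w, c = x.shape
    oc = c // (r * r)
    x = x.reshape(n, h, w, r, r, oc)
    x = x.transpose(0, 1, 3, 2, 4, 5)  # -> (n, h, r, w, r, oc)
    return x.reshape(n, h * r, w * r, oc)


def torch_to_tf_channel_order(oc, s):
    """Hypothetical helper: gather indices taking torch channel order to TF order."""
    return [i * s ** 2 + j for j in range(s ** 2) for i in range(oc)]


# Same sizes as the testbench in the question: 8 input channels, block size 2.
s, oc, h, w = 2, 2, 3, 3
ic = oc * s ** 2
x = np.arange(ic * h * w, dtype=np.float32).reshape(1, ic, h, w)

expected = pixel_shuffle_nchw(x, s)

# Plain NCHW -> NHWC -> depth_to_space -> NCHW, as in the question.
nhwc = np.transpose(x, (0, 2, 3, 1))
naive = np.transpose(depth_to_space_nhwc(nhwc, s), (0, 3, 1, 2))

# Same pipeline, but with the channels reordered before depth_to_space.
perm = torch_to_tf_channel_order(oc, s)
fixed = np.transpose(depth_to_space_nhwc(nhwc[..., perm], s), (0, 3, 1, 2))

print('perm:', perm)                                # [0, 4, 1, 5, 2, 6, 3, 7]
print('naive equal?', np.allclose(expected, naive)) # False
print('fixed equal?', np.allclose(expected, fixed)) # True
```

If the layer follows a convolution, the same indices could presumably be applied once to the output-channel axis of the converted kernel and to the bias (for a Keras Conv2D kernel that is the last axis, e.g. kernel[..., perm] and bias[perm]) instead of reordering activations at run time.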