I'm trying to depth-wise concatenate a one-hot vector onto an image input (as in the PyTorch implementation of StarGAN), say:
input_img = Input(shape = (row, col, chann))
one_hot = Input(shape = (7, ))
I ran into the same problem before (with class indexes), and there I used RepeatVector + Reshape followed by Concatenate. But I found that RepeatVector does not work when you want to repeat a 3D tensor into a 4D one (counting batch_num).
How do I implement this method in Keras? I found that UpSampling2D could do the job, but I don't know whether it keeps the one-hot vector structure during the upsampling process.
I found an idea in "How to use tile function in Keras?": you can use tile, but you first need to reshape your one_hot to have the same number of dimensions as input_img.
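from keras import backend as K
from keras.layers import Concatenate, Lambda, Reshape
one_hot = Reshape((1, 1, 7))(one_hot)  # (batch, 7) -> (batch, 1, 1, 7)
# Tile over rows and columns only; the batch and channel axes are repeated once.
one_hot = Lambda(K.tile, arguments = {'n' : (1, row, col, 1)})(one_hot)
model_input = Concatenate()([input_img, one_hot])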
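
To check that the reshape, tile, and concatenate steps keep the one-hot vector intact at every pixel, the same operations can be tried on plain NumPy arrays. This is only a sketch of the shape arithmetic, not Keras code: the sizes (batch, row, col, chann, n_classes) and the example labels are made up for illustration, and np.tile / np.concatenate stand in for the Keras tile and Concatenate steps.

```python
import numpy as np

# Hypothetical sizes, chosen only for illustration.
batch, row, col, chann, n_classes = 2, 4, 5, 3, 7

input_img = np.zeros((batch, row, col, chann), dtype=np.float32)
labels = np.array([1, 5])                               # one class index per sample
one_hot = np.eye(n_classes, dtype=np.float32)[labels]   # shape (batch, 7)

# Same steps as Reshape -> tile -> Concatenate, done on plain arrays.
one_hot_map = one_hot.reshape(batch, 1, 1, n_classes)   # (batch, 1, 1, 7)
one_hot_map = np.tile(one_hot_map, (1, row, col, 1))    # (batch, row, col, 7)
model_input = np.concatenate([input_img, one_hot_map], axis=-1)

print(model_input.shape)  # (2, 4, 5, 10)

# Every pixel of a sample carries that sample's full one-hot vector.
print(np.array_equal(model_input[0, 2, 3, chann:], one_hot[0]))  # True
```

The last chann + 7 channels of each pixel are the image channels followed by the sample's one-hot vector, so the structure is preserved as long as the tile multiples are 1 on the batch and channel axes.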