
Depth-wise concatenate in Keras


I'm trying to concatenate a one-hot vector depth-wise (i.e. along the channel axis) onto an image input, as in the PyTorch implementation of StarGAN. Say:

input_img = Input(shape = (row, col, chann))
one_hot = Input(shape = (7, ))

I ran into a similar problem before (with class indices), and solved it with RepeatVector + Reshape followed by Concatenate. However, RepeatVector only takes a 2D input (batch, features) and returns a 3D output, so it cannot repeat a 3D tensor into a 4D one (counting the batch dimension).

How can I implement this in Keras? I found that UpSampling2D might do the job, but I don't know whether it keeps the one-hot vector structure intact during upsampling.
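On the UpSampling2D question: the Keras documentation describes UpSampling2D as repeating the rows and columns of its input, and its default interpolation is `'nearest'`. Applied to a one-hot vector reshaped into a 1×1 map, that repetition should copy the vector unchanged into every pixel. The NumPy sketch below mimics nearest-neighbour upsampling with `np.repeat` (a stand-in assumed to match the layer's behaviour in `channels_last` layout, not the Keras layer itself) and checks that each pixel still holds the original one-hot vector. The sizes and the `nearest_upsample` helper are hypothetical.

```python
import numpy as np

# Hypothetical sizes: batch of 2, 4x5 spatial map, 7 classes
batch, row, col, n_classes = 2, 4, 5, 7

# One-hot labels for class 3 and class 6, shape (batch, 7)
labels = np.array([3, 6])
one_hot = np.eye(n_classes, dtype=np.float32)[labels]

# Reshape to a 1x1 "image": (batch, 1, 1, 7)
one_hot_map = one_hot.reshape(batch, 1, 1, n_classes)

def nearest_upsample(x, size):
    """Hypothetical helper: repeat rows then columns, like nearest-neighbour upsampling."""
    return np.repeat(np.repeat(x, size[0], axis=1), size[1], axis=2)

upsampled = nearest_upsample(one_hot_map, (row, col))
print(upsampled.shape)  # (2, 4, 5, 7)

# Every pixel of sample i still holds exactly the one-hot vector of sample i
for i in range(batch):
    expected = np.broadcast_to(one_hot[i], (row, col, n_classes))
    assert np.array_equal(upsampled[i], expected)
```

Because a 1×1 input has only one source pixel, nearest-neighbour repetition here is equivalent to tiling; no values are blended, so the one-hot structure survives.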


Solution

  • Based on an idea from "How to use tile function in Keras?", you can use tile, but you first need to reshape one_hot so that it has the same number of dimensions as input_img (4D, including the batch axis):

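    from keras import backend as K
    from keras.layers import Reshape, Lambda, Concatenate
    one_hot = Reshape((1, 1, 7))(one_hot)  # (batch, 7) -> (batch, 1, 1, 7)
    # Multiples per axis; batch stays 1 so samples are not repeated
    one_hot = Lambda(K.tile, arguments={'n': (1, row, col, 1)})(one_hot)  # -> (batch, row, col, 7)
    model_input = Concatenate()([input_img, one_hot])  # -> (batch, row, col, chann + 7)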
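To see what the Reshape → tile → Concatenate pipeline computes, here is a NumPy sketch of the same steps on concrete arrays, assuming channels-last layout. `np.tile` plays the role of `K.tile` and `np.concatenate(..., axis=-1)` the role of Keras `Concatenate()` (whose default axis is -1). The sizes, labels and random image are hypothetical, chosen only to make the shapes visible.

```python
import numpy as np

# Hypothetical sizes: batch of 2, 4x5 RGB images, 7 classes
batch, row, col, chann, n_classes = 2, 4, 5, 3, 7

rng = np.random.default_rng(0)
input_img = rng.random((batch, row, col, chann)).astype(np.float32)
one_hot = np.eye(n_classes, dtype=np.float32)[[1, 4]]  # labels 1 and 4, shape (batch, 7)

# Reshape (batch, 7) -> (batch, 1, 1, 7) so it has the same rank as input_img
one_hot_map = one_hot.reshape(batch, 1, 1, n_classes)

# Tile 1x over batch, row x over height, col x over width, 1x over channels
one_hot_map = np.tile(one_hot_map, (1, row, col, 1))

# Depth-wise (channel-axis) concatenation
model_input = np.concatenate([input_img, one_hot_map], axis=-1)

print(model_input.shape)  # (2, 4, 5, 10)
```

The first `chann` channels of `model_input` are the untouched image, and the remaining 7 channels carry the sample's one-hot label at every pixel.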