tensorflow, keras, reshape, flatten

Flatten along with the batch axis in TensorFlow / Keras


In a Sequential model, I'm trying to go from a layer output of shape (None, 300) to something like (1, 1, None*300) so that I can apply an AveragePooling layer. In other words, I want to flatten everything, including the batch axis, but both the Flatten and Reshape layers always leave the batch axis untouched. Any ideas?
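A minimal sketch of the behaviour described above, assuming TensorFlow 2.x with eager execution (the batch size of 4 is arbitrary). The layers are called directly on a concrete tensor so the resulting shapes can be inspected: Flatten and Reshape keep the leading batch axis, while a raw tf.reshape has no notion of a batch axis and can merge it.

```python
import numpy as np
import tensorflow as tf

# A dummy batch of 4 samples, each with 300 features: shape (4, 300).
x = tf.constant(np.arange(4 * 300, dtype="float32").reshape(4, 300))

# Flatten and Reshape only act on the non-batch axes, so axis 0 survives.
flat = tf.keras.layers.Flatten()(x)
resh = tf.keras.layers.Reshape((1, 300))(x)

# A plain tensor reshape treats axis 0 like any other axis and merges it too.
merged = tf.reshape(x, (1, 1, -1))

print(flat.shape)    # (4, 300)
print(resh.shape)    # (4, 1, 300)
print(merged.shape)  # (1, 1, 1200)
```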


Solution

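  • You can wrap K.reshape from the Keras backend in a Lambda layer. Unlike the Reshape layer, it does not preserve the batch axis, so it can fold that axis into the others: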

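    from keras import backend as K
    from keras.layers import Input, Lambda

    inp = Input(shape=(300,))  # symbolic shape: (None, 300)
    # Fold the batch axis into the last axis: (batch, 300) -> (1, 1, batch * 300)
    out = Lambda(lambda x: K.reshape(x, (1, 1, -1)))(inp)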
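The same Lambda trick can be checked on a concrete batch. This sketch swaps in tf.reshape for K.reshape, on the assumption that you may be running Keras 3 (TensorFlow 2.16+), where, as far as I know, the old backend reshape function is no longer exposed; the batch size of 3 is arbitrary. It also confirms that the samples end up concatenated in row-major order.

```python
import numpy as np
import tensorflow as tf

# Lambda layer that merges the batch axis into the last axis.
# tf.reshape is used here instead of K.reshape (assumed unavailable in Keras 3).
merge_batch = tf.keras.layers.Lambda(lambda t: tf.reshape(t, (1, 1, -1)))

# A dummy batch of 3 samples with 300 features each.
rng = np.random.default_rng(0)
x = tf.constant(rng.random((3, 300), dtype=np.float32))
out = merge_batch(x)

print(out.shape)  # (1, 1, 900)

# Row-major order: sample 0 fills the first 300 slots, then sample 1, and so on.
np.testing.assert_allclose(out.numpy()[0, 0, :300], x.numpy()[0])
np.testing.assert_allclose(out.numpy()[0, 0, 300:600], x.numpy()[1])
```

Keep in mind that the output batch dimension is then always 1, regardless of how many samples go in, so any targets passed to fit() would need to be shaped to match.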