tensorflow keras conv-neural-network keras-layer

Controlling information flow and gating factor in Keras layers


I have a CNN architecture (architecture image) in which the information flow from one layer to the next is controlled by a gating factor. A fraction 'g' of the information is sent to the immediate next layer, and the remaining '1-g' is sent to one of the later layers (like a skip connection).

How can I implement such an architecture in Keras? Thanks in advance.


Solution

  • Use the functional API Model.

    For gates (the fraction g is learned automatically):

    from keras.models import Model
    from keras.layers import *
    
    inputTensor = Input(someInputShape)
    
    #the actual value
    valueTensor = CreateSomeLayer(parameters)(inputTensor)
    
    #the gate - holds 'g', between 0 and 1 (sigmoid); its shape must match valueTensor
    gateTensor = AnotherLayer(matchingParameters, activation='sigmoid')(inputTensor)
    
    #value * gate = the fraction g of the value
    fractionG = Lambda(lambda x: x[0]*x[1])([valueTensor,gateTensor])
    
    #value - fraction = the remaining (1 - g) of the value
    complement = Lambda(lambda x: x[0] - x[1])([valueTensor,fractionG])
    
    #each tensor may go into individual layers and follow individual paths:
    immediateNextOutput = ImmediateNextLayer(params)(fractionG)
    oneOfTheForwardOutputs = OneOfTheForwardLayers(params)(complement)
    
    #keep going, make one or more outputs, and create your model:
    model = Model(inputs=inputTensor, outputs=outputTensorOrListOfOutputTensors)    
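The placeholders above can be filled in many ways. As a minimal runnable sketch, assuming `tensorflow.keras`, a 32x32 RGB input, `Conv2D` layers with 16 filters, and an `Add` merge where the skip path rejoins (none of these come from the question), the built-in `Multiply` and `Subtract` layers can stand in for the two `Lambda` layers:

```python
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers

# Assumed input shape and filter counts, chosen only for illustration.
inputs = keras.Input(shape=(32, 32, 3))

# The value that will be split between the two paths.
value = layers.Conv2D(16, 3, padding="same", activation="relu")(inputs)

# Learned gate g in (0, 1); same shape as `value` so they multiply elementwise.
gate = layers.Conv2D(16, 3, padding="same", activation="sigmoid")(inputs)

# g * value: the part sent to the immediate next layer.
fraction_g = layers.Multiply()([value, gate])

# value - g * value = (1 - g) * value: the part sent to a later layer.
complement = layers.Subtract()([value, fraction_g])

# The immediate next layer consumes the gated fraction.
next_out = layers.Conv2D(16, 3, padding="same", activation="relu")(fraction_g)

# A later layer receives the skip path, merged here by addition (an assumption).
merged = layers.Add()([next_out, complement])
outputs = layers.Conv2D(8, 3, padding="same")(merged)

model = keras.Model(inputs=inputs, outputs=outputs)
```

Because the gate is produced by a layer with trainable weights, g here is learned per position and per channel rather than being a single scalar.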
    

    To feed two tensors into the same layer, first merge them into one by concatenating, adding, multiplying, etc.

    #concat
    joinedTensor = Concatenate(axis=optionalAxis)([input1,input2])
    
    #add
    joinedTensor = Add()([input1,input2])
    
    #etc.....
    
    nextLayerOut = TheLayer(parameters)(joinedTensor)
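The choice of merge affects the shape seen by the next layer. A small sketch, assuming `tensorflow.keras` and two illustrative 8x8 feature maps with 4 channels each (shapes are made up for the example):

```python
from tensorflow import keras
from tensorflow.keras import layers

# Two tensors with identical shapes, e.g. a main path and a skip path.
a = keras.Input(shape=(8, 8, 4))
b = keras.Input(shape=(8, 8, 4))

# Concatenate stacks the channels: the last dimension grows from 4 to 8.
concat = layers.Concatenate(axis=-1)([a, b])

# Add and Multiply work elementwise: the shape stays unchanged,
# so both inputs must already have the same shape.
added = layers.Add()([a, b])
multiplied = layers.Multiply()([a, b])

concat_channels = concat.shape[-1]
added_channels = added.shape[-1]
multiplied_channels = multiplied.shape[-1]
```

If the skip path and the main path end up with different shapes, the elementwise merges need a reshaping or projection layer first, while concatenation only requires the non-concatenated axes to agree.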
    

    If you want to control 'g' manually:

    In this case, all we have to do is replace gateTensor with a user-defined one:

    import keras.backend as K
    
    gateTensor = Input(tensor=K.variable([g]))
    

    Pass this tensor as an input when creating the model. (Since it's a tensor input that already carries its own value, you don't pass data for it, and the way you use the fit methods doesn't change.)

    model = Model(inputs=[inputTensor,gateTensor], outputs=outputTensorOrListOfOutputTensors)
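The `Input(tensor=...)` approach above follows the older standalone Keras API. If it is not available in your version, one possible alternative is a small custom layer that stores g as a non-trainable weight. This is only a sketch: `FixedGate` is a hypothetical helper name, and the input shape and the value 0.7 are arbitrary.

```python
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers


class FixedGate(layers.Layer):
    """Hypothetical helper: splits its input into g * x and (1 - g) * x."""

    def __init__(self, g, **kwargs):
        super().__init__(**kwargs)
        self.initial_g = float(g)

    def build(self, input_shape):
        # Scalar, non-trainable weight, so the optimizer never updates g.
        self.g = self.add_weight(
            name="g",
            shape=(),
            initializer=keras.initializers.Constant(self.initial_g),
            trainable=False,
        )

    def call(self, x):
        # First output goes to the next layer, second to a later layer.
        return self.g * x, (1.0 - self.g) * x


# Illustrative model that only exposes the two gated paths.
inputs = keras.Input(shape=(4,))
gate_layer = FixedGate(0.7)
fraction_g, complement = gate_layer(inputs)
model = keras.Model(inputs=inputs, outputs=[fraction_g, complement])
```

With this design the model keeps a single data input, and g can be changed between training runs with `gate_layer.g.assign(new_value)`.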