python, tensorflow, keras, keras-layer

How to prevent backpropagation after compiling a keras model?


I have a multi-output model like this:

       input
         |
      hidden
         |
        / \
       /   \
output1    output2

I can train this model with model.train_on_batch(input, [output1, output2]), but at a particular stage in training I want to train only one branch (output2) and prevent backpropagation from output1. I initially tried passing None for the first output, model.train_on_batch(input, [None, output2]), but that raises

AttributeError: 'NoneType' object has no attribute 'shape'

Then I tried passing a NaN array with output1's shape, model.train_on_batch(input, [Nan_array, output2]), but then the loss becomes NaN. How can I train only one branch of a multi-output Keras model and prevent backpropagation in the other?

Edit

While looking for a solution to this problem I came across the K.stop_gradient function. I tried to stop backpropagation in a single-output model like this:

import numpy as np
from keras.models import Sequential
from keras.layers import Dense, Dropout
import keras.backend as K

# custom loss that tries to block backpropagation by cutting the gradient
def loss(y_true, y_pred):
    return K.stop_gradient(y_pred)

# Generate dummy data
x_train = np.random.random((10, 20))
y_train = np.random.randint(2, size=(10, 1))
x_test = np.random.random((10, 20))
y_test = np.random.randint(2, size=(10, 1))

model = Sequential()
model.add(Dense(64, input_dim=20, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(64, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(1, activation='sigmoid'))

model.compile(loss=loss,
              optimizer='rmsprop',
              metrics=['accuracy'])

model.fit(x_train, y_train,
          epochs=1,
          batch_size=128)

score = model.evaluate(x_test, y_test, batch_size=128)

But it fails with this error:

ValueError: Tried to convert 'x' to a tensor and failed. Error: None values not supported.
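
The error happens because K.stop_gradient severs the gradient path entirely, so Keras gets None where it expects gradient tensors when building the update ops. A loss whose gradient is defined but identically zero avoids the crash; the following is only a sketch of that idea (zero_loss is an illustrative name, not from the original post):

import keras.backend as K

def zero_loss(y_true, y_pred):
    # 0 * mean(y_pred) has a real (zero) gradient, so compilation succeeds,
    # and the zero gradient leaves that branch's weights unchanged
    return 0.0 * K.mean(y_pred)

# in the multi-output case, pair it with a real loss for output2, e.g.:
# model.compile(loss=[zero_loss, 'mse'], optimizer='rmsprop')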


Solution

  • You can create two Model objects that share weights. The first model optimizes on [output1, output2], while the second model contains only the output2 branch. If you call train_on_batch on the second model, the weights in branch 1 will not be updated.

    For example,

    import numpy as np
    from keras.layers import Input, Dense
    from keras.models import Model

    x = Input(shape=(32,))
    hidden = Dense(32)(x)
    output1 = Dense(1)(hidden)
    output2 = Dense(1)(hidden)
    
    model = Model(x, [output1, output2])
    model.compile(loss='mse', optimizer='adam')
    
    model_only2 = Model(x, output2)
    model_only2.compile(loss='mse', optimizer='adam')
    
    X = np.random.rand(2, 32)
    y1 = np.random.rand(2)
    y2 = np.random.rand(2)
    
    # verify: all the weights will change if we train on `model`
    w0 = model.get_weights()
    model.train_on_batch(X, [y1, y2])
    w1 = model.get_weights()
    print([np.allclose(x, y) for x, y in zip(w0, w1)])
    # => [False, False, False, False, False, False]
    
    # verify: branch 1 will not change if we train on `model_only2`
    model_only2.train_on_batch(X, y2)
    w2 = model.get_weights()
    print([np.allclose(x, y) for x, y in zip(w1, w2)])
    # => [False, False, True, True, False, False]
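
    To train only one branch at a particular stage, just switch which model's train_on_batch you call; a minimal sketch reusing X, y1, and y2 from above (the step counts are arbitrary):

    # train both branches first, then freeze branch 1 by switching models
    for step in range(100):
        model.train_on_batch(X, [y1, y2])   # updates all weights
    for step in range(100):
        model_only2.train_on_batch(X, y2)   # branch 1 weights stay fixed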