Tags: tensorflow, keras, pytorch, model

Keras compatible code for shrink and perturb deep model training


I am referring to the paper "On Warm-Starting Neural Network Training" (https://proceedings.neurips.cc/paper/2020/file/288cd2567953f06e460a33951f55daaf-Paper.pdf). The authors propose a shrink-and-perturb technique for retraining models on newly arriving data. In a warm start, the model is initialized with the weights it learned on the old data and is then retrained on the new data. In the proposed technique, the weights and biases of the existing model are first shrunk towards zero and then perturbed with random noise. To shrink a weight, it is multiplied by a value between 0 and 1, typically about 0.5. The official PyTorch code is available at https://github.com/JordanAsh/warm_start/blob/main/run.py. A simple explanation of the paper is given at https://pureai.com/articles/2021/02/01/warm-start-ml.aspx, where the writer gives a simple PyTorch function that shrinks and perturbs an existing model, as shown below:

import torch as T  # T is the alias for torch used in this snippet

def shrink_perturb(model, lamda=0.5, sigma=0.01):
  for (name, param) in model.named_parameters():
    if 'weight' in name:   # just weights
      nc = param.shape[0]  # cols
      nr = param.shape[1]  # rows
      for i in range(nr):
        for j in range(nc):
          param.data[j][i] = \
            (lamda * param.data[j][i]) + \
            T.normal(0.0, sigma, size=(1,1))
  return
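As an aside, the nested loops above assume that every weight tensor is two-dimensional, so they would not handle convolutional kernels as written. A minimal vectorized sketch of the same update is given below. It assumes that applying the update to whole tensors is acceptable, and the name shrink_perturb_vectorized is hypothetical, not taken from the article or the paper's repository.

```python
import torch


def shrink_perturb_vectorized(model, lamda=0.5, sigma=0.01):
    """Shrink each 'weight' parameter by lamda and add N(0, sigma^2) noise.

    Hypothetical helper: same filter as the article's function (names
    containing 'weight'), but applied to whole tensors of any rank.
    """
    with torch.no_grad():  # in-place edits must not be tracked by autograd
        for name, param in model.named_parameters():
            if 'weight' in name:  # biases are left untouched, as in the original
                noise = sigma * torch.randn_like(param)
                param.mul_(lamda).add_(noise)
    return model
```

It is called the same way as the original, for example shrink_perturb_vectorized(net, lamda=0.5, sigma=0.01).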

With the defined function, a prediction model can be initialized with the shrink-perturb technique using code like this:

net = Net().to(device)
fn = ".\\Models\\employee_model_first_100.pth"
net.load_state_dict(T.load(fn))
shrink_perturb(net, lamda=0.5, sigma=0.01)
# now train net as usual

Is there a Keras-compatible version of this function that shrinks the weights of an existing model and adds random Gaussian noise, so that it can be used like this?

from tensorflow.keras.models import load_model

model = load_model('weights/model.h5')
model.summary()
shrunk_model = shrink_perturb(model, lamda=0.5, sigma=0.01)
shrunk_model.summary()

Solution

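  • Maybe something like this. Note that tf.random.normal defaults to a standard deviation of 1.0, so stddev has to be set explicitly to get small noise (0.01 here, matching sigma in the question):

    import tensorflow as tf

    ws = [w * 0.5 + tf.random.normal(w.shape, stddev=0.01) for w in model.get_weights()]
    model.set_weights(ws)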
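The same idea can be wrapped into a function with the signature the question asks for. What follows is only a sketch under a few assumptions: the model exposes Keras-style get_weights() and set_weights(), the noise is drawn with NumPy rather than TensorFlow, and the seed parameter is an addition for reproducibility that does not appear in the question. Unlike the PyTorch function in the question, which only touches parameters whose name contains 'weight', this version modifies every floating-point array returned by get_weights(), which includes biases and non-trainable state such as BatchNormalization moving statistics.

```python
import numpy as np


def shrink_perturb(model, lamda=0.5, sigma=0.01, seed=None):
    """Shrink all float weights of a Keras-style model and add Gaussian noise.

    Assumes `model` provides get_weights() -> list of NumPy arrays and
    set_weights(list). `seed` is a hypothetical extra for reproducibility.
    """
    rng = np.random.default_rng(seed)
    new_weights = []
    for w in model.get_weights():
        w = np.asarray(w)
        if not np.issubdtype(w.dtype, np.floating):
            # Leave integer state (e.g. counters) unchanged.
            new_weights.append(w)
            continue
        # Draw N(0, sigma^2) noise and match the weight's dtype (e.g. float32).
        noise = rng.normal(0.0, sigma, size=w.shape).astype(w.dtype)
        new_weights.append((lamda * w + noise).astype(w.dtype))
    model.set_weights(new_weights)
    return model  # same object, modified in place
```

With this in place, the call from the question becomes shrunk_model = shrink_perturb(model, lamda=0.5, sigma=0.01). The model is modified in place and returned, so shrunk_model and model refer to the same object; clone the model first if the original weights are still needed.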