I am working with a very memory-demanding CNN model for a classification task. This severely limits the batch size that I can use during training.
One solution is to accumulate the gradients during training, meaning that the weights of the model are not updated after every single batch. Instead, the same weights are used for several batches, while the gradients from each batch are accumulated and then averaged for a single weight-update action.
I'm using Keras with a TensorFlow backend, and I'm pretty sure that Keras has no off-the-shelf function/method to achieve this.
How can it be done for a Keras/tensorflow model?
As mentioned in the question, there is no off-the-shelf function/method to achieve this with Keras/TensorFlow. However, it can be done by writing a custom optimizer for Keras.
The main idea is to use a flag to determine whether to update the weights during each batch.
The following implementation is based on this github post by "alexeydevederkin" and it is an accumulating Adam optimizer:
import keras.backend as K
from keras.legacy import interfaces
from keras.optimizers import Optimizer
class AdamAccumulate(Optimizer):
    """Adam optimizer that accumulates gradients over several batches.

    Gradients from `accum_iters` consecutive batches are summed; on the last
    batch of each group the sum is divided by `accum_iters` and a single Adam
    step is applied. On all other batches the parameters and the Adam moment
    estimates are left unchanged.

    # Arguments
        lr: float >= 0. Learning rate.
        beta_1: float. Decay rate of the first-moment estimate.
        beta_2: float. Decay rate of the second-moment estimate.
        epsilon: float or None. Fuzz factor added to the denominator;
            `K.epsilon()` is used when None.
        decay: float >= 0. Learning-rate decay, applied once per completed
            weight update (not once per batch).
        amsgrad: boolean. Whether to use the running maximum of the
            second-moment estimate in the denominator.
        accum_iters: int >= 1. Number of batches whose gradients are
            accumulated before each weight update.

    # Raises
        ValueError: if `accum_iters` is smaller than 1.
    """

    def __init__(self, lr=0.001, beta_1=0.9, beta_2=0.999,
                 epsilon=None, decay=0., amsgrad=False, accum_iters=1, **kwargs):
        if accum_iters < 1:
            raise ValueError('accum_iters must be >= 1')
        super(AdamAccumulate, self).__init__(**kwargs)
        with K.name_scope(self.__class__.__name__):
            self.iterations = K.variable(0, dtype='int64', name='iterations')
            self.lr = K.variable(lr, name='lr')
            self.beta_1 = K.variable(beta_1, name='beta_1')
            self.beta_2 = K.variable(beta_2, name='beta_2')
            self.decay = K.variable(decay, name='decay')
        if epsilon is None:
            epsilon = K.epsilon()
        self.epsilon = epsilon
        self.initial_decay = decay
        self.amsgrad = amsgrad
        # Kept in the same dtype as `iterations` so the integer division and
        # modulo in `get_updates` operate on matching types.
        self.accum_iters = K.variable(accum_iters, K.dtype(self.iterations))
        # Float copy used to average the summed gradients.
        self.accum_iters_float = K.cast(self.accum_iters, K.floatx())

    @interfaces.legacy_get_updates_support
    def get_updates(self, loss, params):
        """Build the list of update ops for one training batch.

        # Arguments
            loss: loss tensor to differentiate.
            params: list of parameter variables to optimise.

        # Returns
            The list of update ops (also stored in `self.updates`).
        """
        grads = self.get_gradients(loss, params)
        self.updates = [K.update_add(self.iterations, 1)]

        lr = self.lr

        # Number of weight updates applied so far. The overloaded `//`
        # operator is used instead of `K.tf.floordiv` so the code does not
        # depend on the backend module exposing a `tf` attribute.
        completed_updates = K.cast(self.iterations // self.accum_iters, K.floatx())

        if self.initial_decay > 0:
            lr = lr * (1. / (1. + self.decay * completed_updates))

        # Adam's bias correction is driven by the update count, not the
        # batch count.
        t = completed_updates + 1

        lr_t = lr * (K.sqrt(1. - K.pow(self.beta_2, t)) / (1. - K.pow(self.beta_1, t)))

        # self.iterations is incremented after a batch has been processed:
        #   batch:            1 2 3 4 5 6 7 8 9
        #   self.iterations:  0 1 2 3 4 5 6 7 8
        #   update_switch=1:        x       x     (accum_iters=4)
        update_switch = K.equal((self.iterations + 1) % self.accum_iters, 0)
        update_switch = K.cast(update_switch, K.floatx())

        ms = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
        vs = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
        # Running sums of the gradients since the last weight update.
        gs = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]

        if self.amsgrad:
            vhats = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
        else:
            # Placeholders so the zip below has one entry per parameter.
            vhats = [K.zeros(1) for _ in params]

        self.weights = [self.iterations] + ms + vs + vhats

        for p, g, m, v, vhat, tg in zip(params, grads, ms, vs, vhats, gs):
            sum_grad = tg + g
            avg_grad = sum_grad / self.accum_iters_float

            m_t = (self.beta_1 * m) + (1. - self.beta_1) * avg_grad
            v_t = (self.beta_2 * v) + (1. - self.beta_2) * K.square(avg_grad)

            if self.amsgrad:
                vhat_t = K.maximum(vhat, v_t)
                p_t = p - lr_t * m_t / (K.sqrt(vhat_t) + self.epsilon)
                self.updates.append(K.update(vhat, (1 - update_switch) * vhat + update_switch * vhat_t))
            else:
                p_t = p - lr_t * m_t / (K.sqrt(v_t) + self.epsilon)

            # Each state variable keeps its old value unless update_switch
            # is 1, in which case it takes the freshly computed value.
            self.updates.append(K.update(m, (1 - update_switch) * m + update_switch * m_t))
            self.updates.append(K.update(v, (1 - update_switch) * v + update_switch * v_t))
            # The accumulator is reset to zero on an update batch.
            self.updates.append(K.update(tg, (1 - update_switch) * sum_grad))
            new_p = p_t

            # Apply constraints.
            if getattr(p, 'constraint', None) is not None:
                new_p = p.constraint(new_p)

            self.updates.append(K.update(p, (1 - update_switch) * p + update_switch * new_p))
        return self.updates

    def get_config(self):
        """Return the optimizer configuration as a plain dict.

        `accum_iters` is included so that an optimizer rebuilt from this
        config keeps the same accumulation length instead of falling back
        to the constructor default.
        """
        config = {'lr': float(K.get_value(self.lr)),
                  'beta_1': float(K.get_value(self.beta_1)),
                  'beta_2': float(K.get_value(self.beta_2)),
                  'decay': float(K.get_value(self.decay)),
                  'epsilon': self.epsilon,
                  'amsgrad': self.amsgrad,
                  'accum_iters': int(K.get_value(self.accum_iters))}
        base_config = super(AdamAccumulate, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
It can be used in the following way:
# Accumulate gradients over 5 batches before each weight update.
opt = AdamAccumulate(lr=0.001, decay=1e-5, accum_iters=5)

model.compile(
    optimizer=opt,                    # optimization technique
    loss='categorical_crossentropy',  # loss function
    metrics=['accuracy'],             # accuracy metric
)

# Each batch holds 10 samples.
model.fit(X_train, y_train, batch_size=10)
In this example, the model processes 10 samples in every iteration ("batch_size"), but the update to the weights only happens after accumulating 5 such batches ("accum_iters"). So the actual batch size for updating the weights is 50.