When using powerful hardware, especially TPUs, it is often preferable to run multiple training steps inside a single compiled call instead of dispatching one step at a time. For example, this is possible in TensorFlow:
# NOTE(review): the original indentation was lost; the model, optimizer and
# metrics are assumed to belong inside the strategy scope -- confirm.
with strategy.scope():
    model = create_model()
    # The optimizers are wrapped innermost-first: AdamW -> SWA -> Lookahead.
    optimizer_inner = AdamW(weight_decay=1e-6)
    optimizer_middle = SWA(optimizer_inner)
    optimizer = Lookahead(optimizer_middle)
    training_loss = tf.keras.metrics.Mean('training_loss', dtype=tf.float32)
    training_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
        'training_accuracy', dtype=tf.float32)

# Global batch size and the number of steps derived from it.
# NOTE(review): 60000 / 10000 are presumably the train / validation example
# counts -- confirm against `get_dataset`.
actual_batch_size = 128
gradient_accumulation_step = 1
batch_size = actual_batch_size * gradient_accumulation_step
steps_per_epoch = 60000 // batch_size
validation_steps = 10000 // batch_size

# Calculate per replica batch size, and distribute the `tf.data.Dataset`s
# on each TPU worker. The dataset is built only once, through the
# distribution function; the previous extra `get_dataset(batch_size, ...)`
# call was overwritten immediately and has been dropped.
per_replica_batch_size = batch_size // strategy.num_replicas_in_sync
train_dataset = strategy.distribute_datasets_from_function(
    lambda _: get_dataset(per_replica_batch_size, is_training=True))
@tf.function(jit_compile=True)
def train_multiple_steps(iterator, steps):
    """Run `steps` consecutive training steps inside one traced call.

    Args:
      iterator: iterator over the distributed training dataset; one element
        is pulled from it per step.
      steps: number of training steps to run, looped over with `tf.range`.
    """

    def replica_step(batch):
        """Forward pass, backward pass and optimizer update for one batch."""
        features, targets = batch
        with tf.GradientTape() as tape:
            predictions = model(features, training=True)
            per_example_loss = tf.keras.losses.sparse_categorical_crossentropy(
                targets, predictions, from_logits=True)
            scaled_loss = tf.nn.compute_average_loss(
                per_example_loss, global_batch_size=batch_size)

        variables = model.trainable_variables
        gradients = tape.gradient(scaled_loss, variables)
        optimizer.apply_gradients(list(zip(gradients, variables)))

        # The loss was averaged over the global batch size, so it is scaled
        # back up by the replica count before being recorded.
        training_loss.update_state(
            scaled_loss * strategy.num_replicas_in_sync)
        training_accuracy.update_state(targets, predictions)

    for _ in tf.range(steps):
        strategy.run(replica_step, args=(next(iterator),))
train_iterator = iter(train_dataset)

# Pass the step count as a `tf.Tensor` so the `tf.function` won't get
# retraced if the value changes.
steps_tensor = tf.convert_to_tensor(steps_per_epoch)

# Each call below runs a whole epoch worth of steps in one go.
for epoch in range(10):
    progress_line = 'Epoch: {}/10'.format(epoch)
    print(progress_line)
    train_multiple_steps(train_iterator, steps_tensor)
In JAX or Flax, however, I haven't seen a complete working example of doing so. I guess it would be something like:
@jax.jit
def train_for_n_steps(train_state, batches):
    """Apply `train_step_fn` once per entry of `batches` and return the result.

    The state returned by each step is fed into the next one; the state
    produced by the last step is returned.
    """
    state = train_state
    for current_batch in batches:
        state = train_step_fn(state, current_batch)
    return state
However, when I try to apply this to a complete example, I am not sure how to create the multiple batches to pass in. Here is a working example that runs on a GPU and trains one step at a time. The relevant code is probably this:
for step, batch in enumerate(train_ds.as_numpy_iterator()):
    # One optimization step, then fold this batch into the running metrics
    # carried on the train state.
    state = train_step(state, batch)
    state = compute_metrics(state=state, batch=batch)

    # Everything below runs only on the last step of a training epoch.
    if (step + 1) % num_steps_per_epoch != 0:
        continue

    # Record the training metrics for the finished epoch, then reset them
    # for the next one.
    for metric, value in state.metrics.compute().items():
        metrics_history[f'train_{metric}'].append(value)
    state = state.replace(metrics=state.metrics.empty())

    # Compute metrics on the test set on a separate state variable, so the
    # freshly reset `state` is left untouched.
    test_state = state
    for test_batch in test_ds.as_numpy_iterator():
        test_state = compute_metrics(state=test_state, batch=test_batch)
    for metric, value in test_state.metrics.compute().items():
        metrics_history[f'test_{metric}'].append(value)

    print(
        f"train epoch: {(step+1) // num_steps_per_epoch}, "
        f"loss: {metrics_history['train_loss'][-1]}, "
        f"accuracy: {metrics_history['train_accuracy'][-1] * 100}"
    )
    print(
        f"test epoch: {(step+1) // num_steps_per_epoch}, "
        f"loss: {metrics_history['test_loss'][-1]}, "
        f"accuracy: {metrics_history['test_accuracy'][-1] * 100}"
    )
My goal is to unroll 5 training steps into each call when training.
Any suggestions are welcome.
You could use more_itertools.chunked
to get something like this:
# Group the batch stream into lists of 5 so that each call to `five_steps`
# runs 5 training steps. `chunked` needs the chunk size as its second
# argument and yields bare lists, so `enumerate` supplies the counter;
# `step` therefore counts chunks, not individual batches.
# NOTE(review): if the number of batches is not a multiple of 5 the final
# chunk is shorter, which presumably makes the jitted function trace again
# for that length -- confirm or drop the remainder.
for step, five_batches in enumerate(chunked(train_ds.as_numpy_iterator(), 5)):
    state = five_steps(state, five_batches)
Then do the unrolling
@jax.jit
def five_steps(state, batches):
    """Run `train_step` once for every entry of `batches`, in order.

    The loop is an ordinary Python loop over `batches`, so it executes once
    per entry; each step receives the state returned by the previous one.
    Returns the state produced by the final step.
    """
    current_state = state
    for single_batch in batches:
        current_state = train_step(current_state, single_batch)
    return current_state
The reason this works is that batches
has a length that isn't data dependent, so the loop will just get executed 5 times during tracing.
This will likely make jitting take much longer than you want, so the preferred but more difficult way is to pack the batches into [N x batch_size x ...]
tensors, then use scan (`jax.lax.scan`) to loop your update function over the inputs.