Tags: python, jax, flax

Using Orbax to checkpoint flax `TrainState` with new `CheckpointManager` API


Context

The Flax docs describe how to checkpoint a flax.training.train_state.TrainState with Orbax. In a nutshell, you set up an orbax.checkpoint.CheckpointManager, which keeps track of checkpoints, and then use that CheckpointManager to save the state to disk. Summarising the code snippets from the Flax docs:

import orbax.checkpoint
from flax.training import orbax_utils

# <-- Code building an empty (abstract) and a full checkpoint. -->
abstract_ckpt = ...
ckpt = ...

orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
save_args = orbax_utils.save_args_from_target(ckpt)

options = orbax.checkpoint.CheckpointManagerOptions(max_to_keep=2, create=True)
checkpoint_manager = orbax.checkpoint.CheckpointManager(
    '/tmp/flax_ckpt/orbax/managed', orbax_checkpointer, options)

# Save and restore a checkpoint (legacy API).
step = 1
checkpoint_manager.save(step, ckpt, save_kwargs={'save_args': save_args})
checkpoint_manager.restore(step, items=abstract_ckpt)

The notebook provided by the Flax docs does what I want: it periodically checkpoints the TrainState, which can then be restored. However, running the code from the Flax docs produces a warning that this Orbax checkpoint API is deprecated:

WARNING:absl:Configured CheckpointManager using deprecated legacy API. Please follow the instructions at https://orbax.readthedocs.io/en/latest/api_refactor.html to migrate by May 1st, 2024.

The link in the warning gives some pointers on how to use the new orbax.checkpoint.CheckpointManager.

Question

How do I save and restore a Flax TrainState with the new orbax.checkpoint.CheckpointManager API?

Here is my failed attempt (based on the Orbax migration instructions) at saving and restoring a trivial flax.training.train_state.TrainState:

import orbax.checkpoint as obc
from flax.training.train_state import TrainState

abstract_ckpt = TrainState(step=0, apply_fn=lambda _: None, params={}, tx={}, opt_state={})
ckpt = abstract_ckpt.replace(step=1)

# Set up the checkpointer.
options = obc.CheckpointManagerOptions(max_to_keep=2, create=True)
checkpoint_dir = obc.test_utils.create_empty('/tmp/checkpoint_manager')
checkpoint_manager = obc.CheckpointManager(checkpoint_dir, options=options)
save_args = obc.args.StandardSave(abstract_ckpt)

# Do actual checkpointing.
checkpoint_manager.save(1, ckpt, args=save_args)

# Restore checkpoint.
restore_args = obc.args.StandardRestore(abstract_ckpt)
restored_ckpt = checkpoint_manager.restore(1, args=restore_args)

# Verify if it is correctly restored.
assert ckpt.step == restored_ckpt.step  # AssertionError

My guess is that the problem relates to save_args, but I haven't managed to pinpoint it or figure out a fix. Any suggestions on how to correctly restore the checkpoint using the new CheckpointManager API?


Solution

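  • You created save_args = obc.args.StandardSave(abstract_ckpt) instead of save_args = obc.args.StandardSave(ckpt), so you're just saving the wrong thing. The object wrapped in StandardSave is what gets written to disk, so the checkpoint contains abstract_ckpt (step=0) and restoring gives you back step 0 instead of 1.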

    Also note that checkpoint_dir = obc.test_utils.create_empty('/tmp/checkpoint_manager') is a bit unnecessary: it's just a test utility for deleting a directory if it already exists, which makes running our colabs a bit easier. You probably shouldn't need it in real life, as the create option in CheckpointManagerOptions will create the directory for you.