Context
The Flax docs describe how to checkpoint a flax.training.train_state.TrainState with Orbax. In a nutshell, you set up an orbax.checkpoint.CheckpointManager, which keeps track of checkpoints, and then use that CheckpointManager to save the state to disk. Summarising the code snippets from the Flax docs:
import orbax.checkpoint
from flax.training import orbax_utils

# <-- Code building an empty and a full ckpt. -->
abstract_ckpt = ...
ckpt = ...

orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
save_args = orbax_utils.save_args_from_target(ckpt)
options = orbax.checkpoint.CheckpointManagerOptions(max_to_keep=2, create=True)
checkpoint_manager = orbax.checkpoint.CheckpointManager(
    '/tmp/flax_ckpt/orbax/managed', orbax_checkpointer, options)

# Save and restore a checkpoint.
step = 1  # hypothetical placeholder for the current training step
checkpoint_manager.save(step, ckpt, save_kwargs={'save_args': save_args})
checkpoint_manager.restore(step, items=abstract_ckpt)
The notebook provided by the Flax docs does what I want: it periodically checkpoints the TrainState, which can then be restored.
However, running the code from the Flax docs produces a warning that this Orbax checkpoint API is deprecated:
WARNING:absl:Configured CheckpointManager using deprecated legacy API. Please follow the instructions at https://orbax.readthedocs.io/en/latest/api_refactor.html to migrate by May 1st, 2024.
The link in the warning gives some pointers on how to use the new orbax.checkpoint.CheckpointManager API.
Question
How do I save and restore a Flax TrainState with the new orbax.checkpoint.CheckpointManager API?
Here is my failed attempt (based on the Orbax migration instructions) at saving and restoring a trivial flax.training.train_state.TrainState:
import orbax.checkpoint as obc
from flax.training.train_state import TrainState
abstract_ckpt = TrainState(step=0, apply_fn=lambda _: None, params={}, tx={}, opt_state={})
ckpt = abstract_ckpt.replace(step=1)
# Set up the checkpointer.
options = obc.CheckpointManagerOptions(max_to_keep=2, create=True)
checkpoint_dir = obc.test_utils.create_empty('/tmp/checkpoint_manager')
checkpoint_manager = obc.CheckpointManager(checkpoint_dir, options=options)
save_args = obc.args.StandardSave(abstract_ckpt)
# Do actual checkpointing.
checkpoint_manager.save(1, ckpt, args=save_args)
# Restore checkpoint.
restore_args = obc.args.StandardRestore(abstract_ckpt)
restored_ckpt = checkpoint_manager.restore(1, args=restore_args)
# Verify if it is correctly restored.
assert ckpt.step == restored_ckpt.step # AssertionError
My guess is that the problem relates to save_args, but I haven't managed to pinpoint it or figure out a fix. Any suggestions on how to correctly restore the checkpoint using the new CheckpointManager API?
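Answer
You created save_args = obc.args.StandardSave(abstract_ckpt) instead of save_args = obc.args.StandardSave(ckpt), so you're simply saving the wrong thing: the template with step=0 is written to disk, and that is exactly what the restore gives back.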
Also note that checkpoint_dir = obc.test_utils.create_empty('/tmp/checkpoint_manager') is a bit unnecessary. It's just a test utility that deletes the directory if it already exists, which makes our Colabs easier to rerun. You probably shouldn't need it in real code, because the create=True option in CheckpointManagerOptions will create the directory for you.