SAC Device Axis Mismatch When Loading Checkpoint #659
Open
Description
In SAC train(), the workflow when loading a checkpoint is as follows:
- Training state is initialized: training_state = _init_training_state(...)
- If restore_checkpoint_path is not None, load params from the path and replace them in training_state
Issue: jax.device_put_replicated(training_state, jax.local_devices()[:local_devices_to_use]) is called at the bottom of _init_training_state(), i.e. before we load the checkpoint params. The loaded params therefore replace replicated params without having the leading per-device axis themselves. When we later call _unpmap((training_state.normalizer_params, training_state.policy_params)), we get an IndexError since the params are not per-device:
IndexError: Too many indices: 0-dimensional array indexed with 1 regular index
The error is raised at jax.tree_util.tree_map(lambda x: x[0], v).
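The failure mode can be sketched outside of Brax as follows. This is a minimal illustration, not the actual SAC code: replicate is a hypothetical stand-in that only mimics the leading device axis that jax.device_put_replicated adds, unpmap mirrors the tree_map(lambda x: x[0], v) call quoted above, and the parameter names and shapes are made up.

```python
import jax
import jax.numpy as jnp


def replicate(tree, n_devices):
    # Hypothetical stand-in for jax.device_put_replicated: it only adds a
    # leading axis of size n_devices to every leaf, without placing anything
    # on specific devices.
    return jax.tree_util.tree_map(
        lambda x: jnp.broadcast_to(x, (n_devices,) + x.shape), tree
    )


def unpmap(tree):
    # Same indexing as the call quoted in the issue: take the first
    # per-device copy of every leaf.
    return jax.tree_util.tree_map(lambda x: x[0], tree)


n_devices = 2

# Made-up training state; "scale" is a scalar (0-dimensional) leaf.
state = {"policy_params": {"w": jnp.ones((3,)), "scale": jnp.array(1.0)}}

# Order described in the issue: replicate first ...
state = replicate(state, n_devices)

# ... then overwrite with checkpoint params that have no device axis.
loaded = {"w": jnp.zeros((3,)), "scale": jnp.array(2.0)}
state["policy_params"] = loaded

try:
    unpmap(state["policy_params"])
    error_message = None
except IndexError as e:
    # Indexing the 0-dimensional "scale" leaf with [0] fails.
    error_message = str(e)

print(error_message)
```

Note that leaves with at least one dimension (such as w here) would not raise; indexing them with [0] would silently drop a real parameter axis instead, so the scalar leaf is what surfaces the problem.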
For reference, the PPO implementation calls jax.device_put_replicated after the checkpoint params have been replaced in the training state.
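The PPO-style ordering can be sketched the same way: replace the params first, then replicate the whole state once. Again this is only an illustration under assumptions, not a patch for the SAC code: replicate is a hypothetical stand-in for jax.device_put_replicated that just adds the leading device axis, and the state layout is made up.

```python
import jax
import jax.numpy as jnp


def replicate(tree, n_devices):
    # Hypothetical stand-in for jax.device_put_replicated: adds a leading
    # axis of size n_devices to every leaf.
    return jax.tree_util.tree_map(
        lambda x: jnp.broadcast_to(x, (n_devices,) + x.shape), tree
    )


def unpmap(tree):
    # Take the first per-device copy of every leaf.
    return jax.tree_util.tree_map(lambda x: x[0], tree)


n_devices = 2

# Made-up, not-yet-replicated training state.
state = {
    "normalizer_params": {"mean": jnp.array(0.0)},
    "policy_params": {"w": jnp.ones((3,))},
}

# Made-up checkpoint contents, stored without a device axis.
loaded_normalizer = {"mean": jnp.array(0.5)}
loaded_policy = {"w": jnp.zeros((3,))}

# 1. Replace the params while the state is still unreplicated.
state = dict(
    state, normalizer_params=loaded_normalizer, policy_params=loaded_policy
)

# 2. Replicate afterwards, so every leaf gets the per-device axis.
state = replicate(state, n_devices)

# 3. Stripping the device axis now works for all leaves, scalars included.
restored = unpmap((state["normalizer_params"], state["policy_params"]))
print(jax.tree_util.tree_map(lambda x: x.shape, restored))
```

An alternative with the same effect would be to replicate only the loaded params before inserting them into an already replicated state; which of the two fits the SAC code better is left open here.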