SAC Device Axis Mismatch When Loading Checkpoint #659

@MichaelCSI

Description

In SAC train(), the relevant part of the workflow when loading a checkpoint is:

  1. The training state is initialized: training_state = _init_training_state(...)
  2. If restore_checkpoint_path is not None, params are loaded from that path and replace the corresponding params in training_state

Issue: jax.device_put_replicated(training_state, jax.local_devices()[:local_devices_to_use]) is called at the bottom of _init_training_state(), i.e. before the checkpoint params are loaded. The params restored in step 2 therefore overwrite replicated values with arrays that have no leading device axis. When _unpmap((training_state.normalizer_params, training_state.policy_params)) is called later, it fails at jax.tree_util.tree_map(lambda x: x[0], v) because the params are not per-device:

IndexError: Too many indices: 0-dimensional array indexed with 1 regular index
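The failure can be reproduced outside of SAC with a few lines of JAX. This is only a minimal sketch: the parameter tree, the `replicate` helper (a stand-in that adds a leading device axis the way jax.device_put_replicated does, without placing anything on devices), and the `unpmap` helper (which copies the indexing from the traceback) are all hypothetical and not taken from the SAC code.

```python
import jax
import jax.numpy as jnp


def replicate(tree, n_devices):
    """Stand-in for jax.device_put_replicated: adds a leading device axis."""
    return jax.tree_util.tree_map(
        lambda x: jnp.broadcast_to(x, (n_devices,) + x.shape), tree
    )


def unpmap(tree):
    """Same indexing as in the traceback: keep the first device's copy."""
    return jax.tree_util.tree_map(lambda x: x[0], tree)


# Hypothetical params: one scalar leaf (e.g. a counter) and one matrix leaf.
params = {"count": jnp.zeros(()), "w": jnp.ones((3, 2))}

# Replicated params survive the round trip with their original shapes.
roundtrip = unpmap(replicate(params, n_devices=2))

# Params without a device axis fail on the scalar leaf.
try:
    unpmap(params)
    raised = False
except IndexError:
    raised = True

# A non-scalar leaf does not fail; it silently loses its first dimension.
sliced = unpmap({"w": params["w"]})
```

Note that only 0-dimensional leaves raise. Leaves with at least one dimension are indexed without error, so in this sketch `sliced["w"]` ends up with shape (2,) instead of (3, 2), which would be a silent bug rather than a crash.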

For reference, the PPO implementation calls jax.device_put_replicated after the checkpoint params have been replaced in the training state.
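A possible fix along the lines of the PPO ordering is sketched below: build the un-replicated state, swap in the checkpoint params, and only then add the device axis. Everything here is an assumption for illustration: `TrainingState` is a simplified NamedTuple rather than the real SAC state, `init_training_state` and `build_training_state` are hypothetical names, and `replicate` is a stand-in for jax.device_put_replicated that only adds the leading axis.

```python
from typing import Any, NamedTuple

import jax
import jax.numpy as jnp


class TrainingState(NamedTuple):
    """Simplified, hypothetical stand-in for the SAC training state."""
    normalizer_params: Any
    policy_params: Any


def replicate(tree, n_devices):
    """Stand-in for jax.device_put_replicated: adds a leading device axis."""
    return jax.tree_util.tree_map(
        lambda x: jnp.broadcast_to(x, (n_devices,) + x.shape), tree
    )


def unpmap(tree):
    """Keep the first device's copy of every leaf."""
    return jax.tree_util.tree_map(lambda x: x[0], tree)


def init_training_state():
    """Returns an UN-replicated state; replication happens in the caller."""
    return TrainingState(
        normalizer_params={"count": jnp.zeros(())},
        policy_params={"w": jnp.zeros((3, 2))},
    )


def build_training_state(n_devices, restored=None):
    state = init_training_state()
    if restored is not None:
        normalizer_params, policy_params = restored
        # Replace params while the state is still un-replicated.
        state = state._replace(
            normalizer_params=normalizer_params,
            policy_params=policy_params,
        )
    # Replicate once, after any checkpoint params are in place.
    return replicate(state, n_devices)


# Hypothetical checkpoint contents.
checkpoint = ({"count": jnp.ones(())}, {"w": jnp.ones((3, 2))})
state = build_training_state(n_devices=2, restored=checkpoint)
normalizer_params, policy_params = unpmap(
    (state.normalizer_params, state.policy_params)
)
```

An alternative that leaves _init_training_state() unchanged would be to replicate the loaded params themselves before replacing them in the already replicated state; which of the two fits the SAC code better is a judgment call for the maintainers.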
