
[DCP] Add DefaultStager example to distributed async checkpoint recipe #3710

@niyunsheng


🚀 Feature Request

Description
The current distributed_async_checkpoint_recipe covers the basic usage of dcp.async_save and the pinned-memory optimization. However, it does not cover the fully asynchronous staging capability introduced in PyTorch 2.9 via DefaultStager.

Even with async_save, the Device-to-Host (D2H) copy (the staging phase) still runs on the main thread, which can block the training loop.

Proposal
I would like to update the tutorial to include a new section on "Fully Asynchronous Staging with DefaultStager".

This update will demonstrate:

  1. How to pass the async_stager=DefaultStager() argument to dcp.async_save.
  2. How to synchronize staging correctly so the D2H copy fully overlaps with the forward and backward passes of the next step.
  3. A timeline comparison between the standard async save and the stager-based async save.

I have already prepared the content and a code example.
