🚀 Feature Request
Description
The current distributed_async_checkpoint_recipe covers the basic usage of dcp.async_save and the Pinned Memory optimization. However, it does not cover the fully asynchronous staging capabilities introduced in PyTorch 2.9 via DefaultStager.
Even with async_save, the Device-to-Host (D2H) copy (the staging phase) typically runs on the main thread, which can block the training loop.
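For context, the flow in the current recipe looks roughly like the sketch below (placeholder names such as model, optimizer, and CHECKPOINT_DIR are illustrative, not the recipe's exact code):

```python
import torch.distributed.checkpoint as dcp

# Sketch of the current recipe's flow: async_save performs the staging (D2H copy)
# on the calling thread before returning, then persists to storage in the background.
state_dict = {"model": model.state_dict(), "optimizer": optimizer.state_dict()}
checkpoint_future = dcp.async_save(state_dict, checkpoint_id=CHECKPOINT_DIR)

# ... training continues while the persist phase runs in the background ...

checkpoint_future.result()  # wait for the save to finish before issuing the next one
```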
Proposal
I would like to update the tutorial to include a new section on "Fully Asynchronous Staging with DefaultStager".
This update will demonstrate:
- How to use the async_stager=DefaultStager() argument.
- How to correctly synchronize staging to achieve full overlap between the D2H copy and the Forward + Backward pass of the next step (see the sketch after this list).
- A timeline comparison between standard async save and stager-based async save.
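To make the proposal concrete, here is a rough sketch of the pattern the new section would walk through. It assumes torch >= 2.9, that DefaultStager is importable from torch.distributed.checkpoint.staging, and that synchronize_staging() (from the AsyncStager protocol) is the right way to wait for the D2H copy; model, optimizer, dataloader, SAVE_EVERY, and CHECKPOINT_DIR are placeholders, and the prepared tutorial content would pin down the exact API details.

```python
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.staging import DefaultStager

stager = DefaultStager()  # background D2H staging into pinned host memory
staging_in_flight = False

for step, batch in enumerate(dataloader):
    loss = model(batch).sum()
    loss.backward()

    if staging_in_flight:
        # The D2H copy of the previous checkpoint overlaps with the forward/backward
        # above. Wait for it to finish before optimizer.step() mutates the parameters
        # being staged. (Assumed synchronization point via the AsyncStager protocol.)
        stager.synchronize_staging()
        staging_in_flight = False

    optimizer.step()
    optimizer.zero_grad()

    if step % SAVE_EVERY == 0:
        state_dict = {"model": model.state_dict(), "optimizer": optimizer.state_dict()}
        checkpoint_future = dcp.async_save(
            state_dict,
            checkpoint_id=f"{CHECKPOINT_DIR}/step_{step}",
            async_stager=stager,  # staging runs off the main thread
        )
        staging_in_flight = True
```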
I have already prepared the content and a code example.