[Feat] add tensorboard for RL trainer #1396
Conversation
Force-pushed from f4ab23e to b87d9bf.
Pull request overview
This PR adds comprehensive TensorBoard logging support to the RLTrainer, enabling better monitoring and visualization of RL training metrics. The changes include refactoring the main training loop for better modularity and replacing NumPy operations with PyTorch for consistency.
Key changes:
- Integrated TensorBoard writer to log training metrics, response statistics, evaluation scores, and timing information for each training step
- Refactored the monolithic `fit()` method into smaller, focused helper methods (`_initial_evaluate`, `_rollout_step`, `_train_step`, `_sync_weights_and_save`, `_evaluate_step`)
- Replaced NumPy tensor operations with PyTorch in data preparation and trajectory saving for consistency
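The refactored loop described above could look roughly like the following sketch. This is an assumption about the shape of the real code: the per-step logging pattern and helper names follow the PR description, but the `RLTrainerSketch` class, the faked metrics, and the tag names are illustrative, and `writer` is any object exposing `add_scalar(tag, value, step)` (such as `torch.utils.tensorboard.SummaryWriter`).

```python
class RLTrainerSketch:
    """Hypothetical skeleton of the refactored fit() loop.

    `writer` is duck-typed: anything with add_scalar(tag, value, step),
    e.g. torch.utils.tensorboard.SummaryWriter or a test stub.
    """

    def __init__(self, writer):
        self._writer = writer

    def _rollout_step(self, step):
        # Placeholder: the real method generates rollouts; we fake a reward stat.
        return {"reward_mean": 0.5 + 0.01 * step}

    def _train_step(self, step, rollout_info):
        # Placeholder: the real method runs optimization; we fake a loss.
        return {"loss": 1.0 / (step + 1)}

    def fit(self, steps):
        for step in range(steps):
            rollout_info = self._rollout_step(step)
            train_info = self._train_step(step, rollout_info)
            # Per-step TensorBoard logging, the pattern this PR introduces.
            for key, value in {**rollout_info, **train_info}.items():
                self._writer.add_scalar(f"train/{key}", value, step)
```

Because the writer is duck-typed, the loop structure can be unit-tested without TensorBoard installed.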
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| xtuner/v1/utils/profile.py | Extended the timer context manager to optionally log timing metrics to TensorBoard |
| xtuner/v1/train/rl_trainer.py | Added TensorboardWriter initialization, refactored fit() method into helper methods, added tensorboard logging throughout, replaced numpy with torch tensors, added debug_rollout mode and rollout_steps parameter |
| xtuner/v1/rl/base/worker.py | Modified fit() to return structured logging information (entropy, mismatch, rollout_is, training metrics) for TensorBoard |
| xtuner/v1/rl/base/controller.py | Updated fit() to return log_infos from workers instead of discarding them |
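The `profile.py` change in the table above (a timer context manager that optionally logs to TensorBoard) could be sketched as follows. The function name, signature, and `time/` tag prefix are assumptions, not the actual xtuner API; the only claim taken from the PR is that the timer gained an optional TensorBoard-logging path.

```python
import time
from contextlib import contextmanager


@contextmanager
def timer(name, logger=print, writer=None, step=None):
    """Hypothetical sketch of the extended timer context manager.

    Times the wrapped block, logs the elapsed time to the console as
    before, and, when a TensorBoard-style writer is supplied, also
    records the elapsed seconds as a scalar under "time/<name>".
    """
    start = time.perf_counter()
    try:
        yield
    finally:
        elapsed = time.perf_counter() - start
        logger(f"{name} took {elapsed:.4f}s")
        if writer is not None:
            writer.add_scalar(f"time/{name}", elapsed, step)
```

Keeping the writer optional means existing call sites that only want console timing keep working unchanged.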
Comments suppressed due to low confidence (1)
xtuner/v1/train/rl_trainer.py:632
- The data_info dictionary from _prepare_train_data contains useful training metrics (advantages, prompt_len, etc.) that are only logged to console. For consistency with other metrics being logged to tensorboard in this PR, consider also logging these to tensorboard using self._writer.add_scalars().
```python
def _log_data_info(self, rollout_idx: int, data_info: dict):
    """Formats and logs the data statistics dictionary."""
    log_lines = [f"Rollout {rollout_idx} data statistics:"]
    for key, value in data_info.items():
        if isinstance(value, float):
            log_lines.append(f"  - {key:<20}: {value:.4f}")
        else:
            log_lines.append(f"  - {key:<20}: {value}")
    self.logger.info("\n".join(log_lines))
```
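Following the reviewer's suggestion, the TensorBoard half could be sketched like this. The helper name and the `"data_info"` main tag are illustrative; `add_scalars` is the `SummaryWriter` method the review comment proposes, and it requires numeric values, so non-numeric entries are filtered out first.

```python
def log_data_info_to_tensorboard(writer, rollout_idx, data_info):
    """Hypothetical companion to _log_data_info: mirror the console
    statistics to TensorBoard via writer.add_scalars().

    Skips non-numeric entries (e.g. shape strings), which add_scalars
    cannot serialize; bool is excluded because it subclasses int.
    """
    numeric = {
        key: value
        for key, value in data_info.items()
        if isinstance(value, (int, float)) and not isinstance(value, bool)
    }
    if numeric:
        # Groups all statistics under one main tag, keyed by rollout step.
        writer.add_scalars("data_info", numeric, rollout_idx)
```

Grouping under a single main tag keeps the advantage/length statistics on one comparison chart in the TensorBoard UI rather than scattering them across top-level tags.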
/gemini review
Force-pushed from b87d9bf to 4e15498.
Force-pushed from 9a7397a to 7abdb3c.