3 changes: 2 additions & 1 deletion xtuner/v1/rl/base/__init__.py
@@ -1,6 +1,6 @@
from .controller import TrainingController, TrainingControllerProxy
from .loss import BaseRLLossConfig, RLLossContextInputItem
from .worker import TrainingWorker, TrainingWorkerClass, TrainingWorkerProxy, WorkerConfig
from .worker import TrainingWorker, TrainingWorkerClass, TrainingWorkerProxy, WorkerConfig, WorkerLogItem


__all__ = [
@@ -12,4 +12,5 @@
"WorkerConfig",
"BaseRLLossConfig",
"RLLossContextInputItem",
"WorkerLogItem",
]
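
With WorkerLogItem re-exported from the package root, downstream code can import it alongside the existing worker types. A minimal usage sketch (the helper function below is hypothetical):

from xtuner.v1.rl.base import WorkerLogItem

def train_entropies(log_items: list[WorkerLogItem]) -> list[float]:
    # WorkerLogItem is a TypedDict, so each item is a plain dict at runtime.
    return [item["train_entropy"] for item in log_items]
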
7 changes: 4 additions & 3 deletions xtuner/v1/rl/base/controller.py
@@ -10,7 +10,7 @@
from xtuner.v1.train.trainer import LoadCheckpointConfig
from xtuner.v1.utils import ray_method

from .worker import TrainingWorker
from .worker import TrainingWorker, WorkerLogItem


class ColateItem(TypedDict):
@@ -165,7 +165,7 @@ def _grouped_by_max_length(self, packed_data_batches):
return sorted(packed_data_batches, key=lambda x: x["seq_ctx"].max_length_q, reverse=True)

@ray_method
def fit(self, data_batches: list[ColateItem], pack_max_length: int, rollout_idx: int):
def fit(self, data_batches: list[ColateItem], pack_max_length: int, rollout_idx: int) -> list[WorkerLogItem]:
has_rollout_routed_experts = False
language_cfg = None
if data_batches[0]["seq_ctx"].rollout_routed_experts is not None:
@@ -256,7 +256,8 @@ def fit(self, data_batches: list[ColateItem], pack_max_length: int, rollout_idx:
rollout_idx=rollout_idx,
)
)
ray.get(handles)
log_infos = ray.get(handles)
return log_infos

@ray_method
def offload(self, target: Literal["model", "optimizer", "all"] = "all"):
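TrainingController.fit now propagates the per-worker logs back to the caller instead of discarding the ray.get result. A minimal sketch of how a trainer might consume the returned list (the helper name and call pattern are illustrative, assuming a local TrainingController rather than its Ray proxy):

from xtuner.v1.rl.base import TrainingController, WorkerLogItem

def fit_and_report(controller: TrainingController, batches, pack_max_length: int, rollout_idx: int) -> float:
    # fit() returns one WorkerLogItem per training worker.
    log_infos: list[WorkerLogItem] = controller.fit(batches, pack_max_length, rollout_idx)
    # Average the per-worker training entropy into a single scalar to report.
    return sum(item["train_entropy"] for item in log_infos) / len(log_infos)
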
80 changes: 61 additions & 19 deletions xtuner/v1/rl/base/worker.py
@@ -14,11 +14,12 @@
from ray.actor import ActorClass, ActorProxy
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
from torch.distributed.tensor import DTensor
from typing_extensions import NotRequired

from xtuner.v1.config.fsdp import FSDPConfig
from xtuner.v1.config.optim import LRConfig, OptimConfig
from xtuner.v1.data_proto.sequence_context import SequenceContext
from xtuner.v1.engine.train_engine import TrainEngine
from xtuner.v1.engine.train_engine import LossLog, OtherLog, TrainEngine
from xtuner.v1.engine.vision_compose_train_engine import (
VisionComposeTrainEngine,
)
@@ -141,6 +142,29 @@ class WorkerInputItem(TypedDict):
rollout_logprobs: torch.Tensor | None


class RLOtherLog(TypedDict):
maxvio: NotRequired[float]
step_consumed_tokens: int
step_consumed_img_tokens: NotRequired[float]
efficient_attn_ratio: float
max_ratio: NotRequired[float]
loss: NotRequired[float]
grad_norm: NotRequired[float]


class WorkerTrainLogItem(TypedDict):
loss_log: LossLog
rl_other_log: RLOtherLog


class WorkerLogItem(TypedDict):
train_entropy: float
rollout_entropy: NotRequired[float]
mismatch_metrics: NotRequired[dict[str, float]]
rollout_is_metrics: NotRequired[dict[str, float]]
train_metrics: List[WorkerTrainLogItem]


class TrainingWorker(SingleAcceleratorWorker):
_SAVE_OPTIMIZER_DIR = "optimizer"
_SAVE_MODEL_DIR = "model"
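
The three TypedDicts above define the structure fit hands back: one WorkerLogItem per worker, with one WorkerTrainLogItem appended per optimizer step under train_metrics. A small sketch of reading one item (the formatting helper is hypothetical):

from xtuner.v1.rl.base import WorkerLogItem

def format_worker_log(item: WorkerLogItem) -> str:
    # Required keys can be indexed directly; NotRequired keys must be checked first.
    parts = [f"train_entropy={item['train_entropy']:.4f}"]
    if "rollout_entropy" in item:
        parts.append(f"rollout_entropy={item['rollout_entropy']:.4f}")
    # One WorkerTrainLogItem per optimizer step taken in fit().
    parts.append(f"optimizer_steps={len(item['train_metrics'])}")
    return ", ".join(parts)
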
@@ -299,20 +323,32 @@ def compute_ref_logprobs(
self._ref_model.to_device("cpu")
return loss_ctx_input_list

def _update_other_log(self, other_log: dict):
def _get_rl_other_log(self, other_log: OtherLog) -> RLOtherLog:
from xtuner.v1.model.utils import ModelForwardExtraLogInfo

extra_info = other_log.get("extra_info", {})
extra_info: ModelForwardExtraLogInfo | dict = other_log.get("extra_info", {})
if isinstance(extra_info, ModelForwardExtraLogInfo):
extra_info_dict = extra_info.get()
else:
extra_info_updated = ModelForwardExtraLogInfo(extra_info)
extra_info_dict = extra_info_updated.get()
other_log["extra_info"] = extra_info_dict
return other_log

for k, v in extra_info_dict.items():
if isinstance(v, torch.Tensor):
extra_info_dict[k] = v.item()

rl_other_log: RLOtherLog = {
"maxvio": other_log.get("maxvio", 0.0),
"step_consumed_tokens": other_log["step_consumed_tokens"],
"step_consumed_img_tokens": float(other_log.get("step_consumed_img_tokens", 0.0)),
"efficient_attn_ratio": other_log["efficient_attn_ratio"],
"max_ratio": extra_info_dict.get("max_ratio", 0.0),
"loss": extra_info_dict.get("loss", 0.0),
}
return rl_other_log

@ray_method
def fit(self, data_batches: list[WorkerInputItem], rollout_idx: int):
def fit(self, data_batches: list[WorkerInputItem], rollout_idx: int) -> WorkerLogItem:
# NOTE: sglang clears the logger handler, so recreate it here
self.logger = get_logger(log_dir=self.log_dir, tag="TrainingWorker")
loss_cfg = self.config.loss_cfg
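
_get_rl_other_log normalizes the engine's extra_info so every value is a plain Python number before it is packed into RLOtherLog and shipped back through Ray. The conversion step in isolation (a standalone sketch, not the actual method):

import torch

def tensors_to_floats(extra_info_dict: dict) -> dict:
    # Scalar tensors become Python floats so the resulting log dict is
    # serializable and carries no device references.
    return {k: (v.item() if isinstance(v, torch.Tensor) else v) for k, v in extra_info_dict.items()}
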
@@ -455,27 +491,36 @@ def fit(self, data_batches: list[WorkerInputItem], rollout_idx: int):
all_rollout_is_metrics.append(rollout_is_metrics)
all_mismatch_metrics.append(mismatch_metrics)

worker_log_item: WorkerLogItem = {
"train_entropy": 0.0,
"train_metrics": [],
}
logger_msg = f"Rollout {rollout_idx}: "

sum_entropy = cast(torch.Tensor, sum_entropy)
dist.all_reduce(sum_entropy, op=dist.ReduceOp.SUM)
avg_sum_entropy = sum_entropy / global_grad_tokens if global_grad_tokens > 0 else 0
avg_sum_entropy = sum_entropy / global_grad_tokens if global_grad_tokens > 0 else torch.tensor(0.0)
worker_log_item["train_entropy"] = avg_sum_entropy.item()
logger_msg += f"avg entropy: {avg_sum_entropy:.4f}"

if sum_rollout_entropy is not None:
sum_rollout_entropy = cast(torch.Tensor, sum_rollout_entropy)
dist.all_reduce(sum_rollout_entropy, op=dist.ReduceOp.SUM)
avg_rollout_entropy = sum_rollout_entropy / global_grad_tokens if global_grad_tokens > 0 else 0
avg_rollout_entropy = (
sum_rollout_entropy / global_grad_tokens if global_grad_tokens > 0 else torch.tensor(0.0)
)
worker_log_item["rollout_entropy"] = avg_rollout_entropy.item()
logger_msg += f", avg rollout entropy: {avg_rollout_entropy:.4f}"

if len(all_mismatch_metrics) > 0:
mismatch_metrics = merge_rollout_is_metrics(all_mismatch_metrics, DEVICE)
if len(mismatch_metrics) > 0:
worker_log_item["mismatch_metrics"] = mismatch_metrics
logger_msg += f"\n rollout mismatch metrics:\n{json.dumps(mismatch_metrics, indent=4)}"

if len(all_rollout_is_metrics) > 0:
rollout_is_metrics = merge_rollout_is_metrics(all_rollout_is_metrics, DEVICE)
if len(rollout_is_metrics) > 0:
worker_log_item["rollout_is_metrics"] = rollout_is_metrics
logger_msg += f"\n rollout importance sampling metrics:\n{json.dumps(rollout_is_metrics, indent=4)}"

if self.rank == 0:
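
The entropy bookkeeping above sums token-level entropy locally, all-reduces the sums across ranks, and divides by the global number of gradient tokens. The same reduction in isolation (a standalone sketch with hypothetical names, assuming torch.distributed is already initialized):

import torch
import torch.distributed as dist

def reduce_mean_entropy(sum_entropy: torch.Tensor, global_grad_tokens: int) -> float:
    # Sum the per-rank entropy totals, then normalize by the global token count;
    # fall back to 0.0 when no tokens contributed gradients.
    dist.all_reduce(sum_entropy, op=dist.ReduceOp.SUM)
    if global_grad_tokens > 0:
        return (sum_entropy / global_grad_tokens).item()
    return 0.0
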
@@ -522,25 +567,22 @@ def fit(self, data_batches: list[WorkerInputItem], rollout_idx: int):
loss_log, other_log = self._engine.train_step(
data_batches=engine_input,
)
other_log = self._update_other_log(other_log) # type: ignore[arg-type]
grad_norm = self._engine.clip_grad_norm()
self._engine.step_optimizer(grad_norm)
log_info = dict() # type: ignore[var-annotated]
log_info.update(loss_log)
for k, v in other_log.items():
if k == "extra_info":
for extra_k, extra_v in v.items():
log_info[extra_k] = extra_v.item() if isinstance(extra_v, torch.Tensor) else extra_v
else:
log_info[k] = v.item() if isinstance(v, torch.Tensor) else v
log_info["grad_norm"] = grad_norm.item()
rl_other_log = self._get_rl_other_log(other_log) # type: ignore[arg-type]
rl_other_log["grad_norm"] = grad_norm.item()
worker_log_item["train_metrics"].append(WorkerTrainLogItem(loss_log=loss_log, rl_other_log=rl_other_log))

log_info = {**loss_log, **rl_other_log}
log_str = ", ".join(
f"{key}={value:.4f}" if isinstance(value, float) else f"{key}={value}"
for key, value in log_info.items()
)
log_str = f"Rollout {rollout_idx} Step {i}: " + log_str
self.logger.info(log_str)

return worker_log_item

@ray_method
def save_hf(self, hf_dir: str, save_dtype: torch.dtype = torch.bfloat16):
self._engine.save_hf(hf_dir, save_dtype)