213 changes: 213 additions & 0 deletions tests/engine/test_dense_lora_train_engine.py
@@ -0,0 +1,213 @@
import os
import tempfile
import shutil
import time
import parametrize
import torch
import torch.distributed as dist
from xtuner._testing import DeterministicDDPTestCase
from transformers import AutoTokenizer

from xtuner.v1.data_proto.sequence_context import SequenceContext
from xtuner.v1.model.dense.qwen3 import Qwen3Dense8BConfig
from xtuner.v1.model.base import ModelItem
from xtuner.v1.loss.ce_loss import CELossConfig, CELossContextInputItem
from xtuner.v1.config import FSDPConfig, LRConfig, AdamWConfig
from xtuner.v1.engine.train_engine import TrainEngine
from torch.optim.lr_scheduler import LambdaLR
from xtuner.v1.utils import pad_to_max_length
from xtuner.v1.utils.device import get_device
from xtuner.v1.utils.test_utils import init_data_mesh
from xtuner.v1.model.adapter.lora import LoraConfig


# Qwen3 8B
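# QWEN3_PATH must point to a local HuggingFace Qwen3-8B checkpoint; reading
# os.environ directly raises KeyError if it is unset.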
QWEN3_PATH = os.environ["QWEN3_PATH"]
DEVICE = get_device()


class TestDenseEngine(DeterministicDDPTestCase):
    @parametrize.parametrize(
        "device,tp_size,sp_size",
        [
            ("cuda", 1, 1),
            ("cuda", 1, 2),
        ],
    )
    def test_dense_engine_train(self, device, tp_size, sp_size):
        pg = self.create_pg(device)

        dense_cfg = Qwen3Dense8BConfig()
        optim_cfg: AdamWConfig = AdamWConfig()
        lr_cfg: LRConfig = LRConfig()
        fsdp_cfg: FSDPConfig = FSDPConfig(
            torch_compile=True,
            cpu_offload=False,
            tp_size=tp_size,
            # hsdp_sharding_size=hsdp_sharding_size,
        )

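        # LoRA adapters on every attention and MLP projection; lm_head is excluded
        # from LoRA but kept fully trainable via modules_to_save.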
        adapter_cfg = LoraConfig(
            r=8,
            lora_alpha=8,
            lora_dropout=0,
            target_modules=[
                "q_proj",
                "k_proj",
                "v_proj",
                "o_proj",
                "gate_proj",
                "up_proj",
                "down_proj",
            ],
            bias="none",
            modules_to_save=["lm_head"],
        )
        engine = TrainEngine(
            model_cfg=dense_cfg,
            optim_cfg=optim_cfg,
            fsdp_cfg=fsdp_cfg,
            adapter_cfg=adapter_cfg,
        )
        engine.from_hf(hf_path=QWEN3_PATH)

        loss_cfg = CELossConfig()

        total_steps = 1000
        warmup_steps = total_steps * lr_cfg.warmup_ratio

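        # Linear warmup to the base LR over the first warmup_ratio of total_steps,
        # constant afterwards.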
        def warmup_fn(x):
            return x / warmup_steps if x < warmup_steps else 1

        lr_scheduler = LambdaLR(engine.optimizer, warmup_fn)

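        # The Chinese news snippet below is deterministic test input; the reference
        # losses depend on these exact tokens, so it is left untranslated. Roughly:
        # "According to IERS data, this summer marks the sixth acceleration of Earth's
        # rotation since 2020; July 9 will be the shortest day on record, 1.3 to
        # 1.6 ms shorter than usual."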
        tok = AutoTokenizer.from_pretrained(QWEN3_PATH)
        txt = "根据国际地球自转和参考系服务机构的数据,今年夏天是自2020年以来第六次地球自转加速。7月9日将成为有史以来最短的一天,比平时短1.3到1.6毫秒。 "
        input_ids = tok.encode(txt, return_tensors="pt").view(1, -1)
        labels = input_ids.clone()
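        # Standard next-token shift: inputs drop the last token, labels drop the first.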
        input_ids = input_ids[:, :-1]
        labels = labels[:, 1:]
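        # Right-pad both tensors to the 8192-token packing window; -100 labels are
        # ignored by the CE loss.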
        pack_len = 8192 - input_ids.shape[1]
        input_ids = pad_to_max_length(input_ids, 0, max_length=8192)
        labels = pad_to_max_length(labels, -100, max_length=8192)
        losses = []

        data_mesh = None
        if sp_size > 1:
            data_mesh = init_data_mesh(str(DEVICE), sp_size)

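        # Ten optimizer steps on the same packed batch; the loss should drift slowly
        # downward as the LoRA weights update.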
        for _ in range(10):
            seq_ctx = SequenceContext.from_input_ids((input_ids,), device=DEVICE)
            labels = labels.to(DEVICE)
            seq_ctx.num_padding = pack_len
            seq_ctx_list = [seq_ctx]
            loss_ctx_input_list: list[CELossContextInputItem] = [
                CELossContextInputItem(shifted_labels=labels)
            ]
            LossContext = loss_cfg.loss_ctx_cls
            batches_loss_kwargs = LossContext.build_batches_loss_kwargs(
                loss_ctx_input_list,
                loss_cfg,
            )
            loss_kwargs = batches_loss_kwargs[0]
            loss_ctx = LossContext(loss_cfg, loss_kwargs)
            seq_ctx = seq_ctx_list[0]
            engine_input = [ModelItem(seq_ctx=seq_ctx, loss_ctx=loss_ctx)]
            loss_log, _ = engine.train_step(engine_input)
            grad_norm = engine.clip_grad_norm()
            engine.step_optimizer(grad_norm)
            lr_scheduler.step()
            losses.append(loss_log["reduced_llm_loss"])
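        # Reference losses recorded from a known-good run; the 0.02 absolute
        # tolerance absorbs nondeterministic kernel noise.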
        losses_ref = [2.57, 2.57, 2.57, 2.57, 2.57, 2.57, 2.56, 2.56, 2.54, 2.53]
        for loss, loss_ref in zip(losses, losses_ref):
            self.assertTrue(
                abs(loss - loss_ref) < 0.02,
                f"loss={loss}, loss_ref={loss_ref}, diff={abs(loss - loss_ref)}",
            )

        torch.cuda.empty_cache()
        try:
            dist.destroy_process_group(pg)
        except Exception:
            pass

    @parametrize.parametrize(
        "device,tp_size,hsdp_sharding_size",
        [
            ("cuda", 1, 8),  # TODO: test ep8 together with hsdp; currently OOMs on 8 GPUs
        ],
    )
    def test_save_and_load(self, device, tp_size, hsdp_sharding_size):
        pg = self.create_pg(device)

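        # Rank 0 creates the scratch directory and broadcasts its path so every
        # rank saves to the same location.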
        if dist.get_rank() == 0:
            temp_dir = [tempfile.mkdtemp()]
        else:
            temp_dir = [None]
        dist.broadcast_object_list(temp_dir, src=0)
        temp_dir = temp_dir[0]
        dense_cfg = Qwen3Dense8BConfig()
        optim_cfg: AdamWConfig = AdamWConfig()
        fsdp_cfg: FSDPConfig = FSDPConfig(
            torch_compile=True,
            cpu_offload=False,
            tp_size=tp_size,
            hsdp_sharding_size=hsdp_sharding_size,
        )
        adapter_cfg = LoraConfig(
            r=4,
            lora_alpha=16,
            lora_dropout=0,
            target_modules=["q_proj", "v_proj"],
            bias="none",
            modules_to_save=["lm_head"],
        )
        engine = TrainEngine(
            model_cfg=dense_cfg,
            optim_cfg=optim_cfg,
            fsdp_cfg=fsdp_cfg,
            adapter_cfg=adapter_cfg,
        )

        engine.from_hf(hf_path=QWEN3_PATH)
        engine.save_hf(
            hf_dir=temp_dir,
            save_dtype=torch.bfloat16,
        )

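        # Give every rank time to finish writing shards before any comparison or cleanup.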
        dist.barrier()
        time.sleep(1)

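        # Round-trip check, currently disabled: reload the saved checkpoint into a
        # second engine and compare state dicts.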
        # engine2 = TrainEngine(
        #     model_cfg=dense_cfg,
        #     optim_cfg=optim_cfg,
        #     fsdp_cfg=fsdp_cfg,
        # )
        # engine2.from_hf(hf_path=temp_dir)

        # state_dict = engine.model.state_dict()
        # state_dict2 = engine2.model.state_dict()
        # for key, val in state_dict.items():
        #     val2 = state_dict2[key]
        #     val = val.full_tensor().bfloat16()
        #     val2 = val2.full_tensor().bfloat16()
        #     self.assertTrue(torch.equal(val, val2[:val.shape[0]]),
        #                     f"Mismatch in {key} between bf16 and fp8, {val} and {val2[:val.shape[0]]}")

        if dist.get_rank() == 0:
            shutil.rmtree(temp_dir)

        torch.cuda.empty_cache()
        try:
            dist.destroy_process_group(pg)
        except Exception:
            pass

    @property
    def world_size(self) -> int:
        return int(os.getenv("XTUNER_TEST_WORLD_SIZE", "8"))

    @property
    def destroy_pg_upon_exit(self) -> bool:
        return False
5 changes: 3 additions & 2 deletions tests/ray/test_update_weight.py
@@ -72,7 +72,8 @@ def init_config(self):
         if hasattr(model_cfg, 'balancing_loss_cfg'):
             model_cfg.balancing_loss_cfg = BalancingLossConfig()
         optim_cfg: AdamWConfig = AdamWConfig(lr=5e-7, foreach=False)
-        fsdp_cfg: FSDPConfig = FSDPConfig()
+        fsdp_cfg: FSDPConfig = FSDPConfig(ep_size=4)
+        model_cfg.ep_size = fsdp_cfg.ep_size
         lr_cfg = LRConfig(lr_type="constant", warmup_ratio=0, lr_min=5e-7)
         self.worker_cfg: WorkerConfig = WorkerConfig(
             model_cfg=model_cfg,
@@ -84,7 +85,7 @@
                 loss_type="vanilla",
             ),
             ignore_idx=-100,
-            use_kl_loss=True,
+            use_kl_loss=False,
             kl_loss_coef=0.001,
             kl_loss_type="low_var_kl",
             mode="eager"),
5 changes: 5 additions & 0 deletions xtuner/v1/engine/train_engine.py
@@ -27,6 +27,7 @@
 from xtuner.v1.config import FSDPConfig, OptimConfig
 from xtuner.v1.data_proto.sequence_context import SequenceContext
 from xtuner.v1.float8.float8_handler import Float8Handler
+from xtuner.v1.model.adapter.lora import LoraConfig
 from xtuner.v1.model.base import BaseModel, ModelItem, TransformerConfig
 from xtuner.v1.model.utils import ModelForwardExtraLogInfo
 from xtuner.v1.module.router import NoAuxRouterConfig
@@ -145,10 +146,12 @@ def __init__(
         optim_cfg: OptimConfig,
         fsdp_cfg: FSDPConfig,
         intra_layer_micro_batch: int = 1,
+        adapter_cfg: LoraConfig | None = None,
     ) -> None:
         self.model_cfg = model_cfg
         self.optim_cfg = optim_cfg
         self.fsdp_cfg = fsdp_cfg
+        self.adapter_cfg = adapter_cfg
         self.model = self.build_model()
         self.optimizer = self.build_optimizer(optim_cfg)
         self.intra_layer_micro_batch = intra_layer_micro_batch
@@ -166,6 +169,8 @@ def __has_freeze_params(self) -> bool:
     def build_model(self) -> BaseModel:
         with torch.device("meta"):
             model = self.model_cfg.build()
+        if self.adapter_cfg:
+            model = self.adapter_cfg.build(model)

         self.float8_handler = None
         if self.model_cfg.float8_cfg is not None and self.model_cfg.float8_cfg.enable_float8: