diff --git a/tests/engine/test_dense_lora_train_engine.py b/tests/engine/test_dense_lora_train_engine.py new file mode 100644 index 000000000..726d17997 --- /dev/null +++ b/tests/engine/test_dense_lora_train_engine.py @@ -0,0 +1,213 @@ +import os +import tempfile +import shutil +import time +import parametrize +import torch +import torch.distributed as dist +from xtuner._testing import DeterministicDDPTestCase +from transformers import AutoTokenizer + +from xtuner.v1.model.moe.moe import SequenceContext +from xtuner.v1.model.dense.qwen3 import Qwen3Dense8BConfig +from xtuner.v1.model.base import ModelItem +from xtuner.v1.loss.ce_loss import CELossConfig, CELossContextInputItem +from xtuner.v1.config import FSDPConfig, LRConfig, AdamWConfig +from xtuner.v1.engine.train_engine import TrainEngine +from torch.optim.lr_scheduler import LambdaLR +from xtuner.v1.utils import pad_to_max_length +from xtuner.v1.utils.device import get_device +from xtuner.v1.utils.test_utils import init_data_mesh +from xtuner.v1.model.adapter.lora import LoraConfig + + +# Qwen3 8B +QWEN3_PATH = os.environ["QWEN3_PATH"] +DEVICE = get_device() + + +class TestDenseEngine(DeterministicDDPTestCase): + @parametrize.parametrize( + "device,tp_size,sp_size", + [ + ("cuda", 1, 1), + ("cuda", 1, 2), + ], + ) + def test_dense_engine_train(self, device, tp_size, sp_size): + pg = self.create_pg(device) + + dense_cfg = Qwen3Dense8BConfig() + optim_cfg: AdamWConfig = AdamWConfig() + lr_cfg: LRConfig = LRConfig() + fsdp_cfg: FSDPConfig = FSDPConfig( + torch_compile=True, + cpu_offload=False, + tp_size=tp_size, + # hsdp_sharding_size=hsdp_sharding_size, + ) + + adapter_cfg = LoraConfig( + r=8, + lora_alpha=8, + lora_dropout=0, + target_modules=[ + "q_proj", + "k_proj", + "v_proj", + "o_proj", + "gate_proj", + "up_proj", + "down_proj", + ], + bias="none", + modules_to_save=["lm_head"], + ) + engine = TrainEngine( + model_cfg=dense_cfg, + optim_cfg=optim_cfg, + fsdp_cfg=fsdp_cfg, + adapter_cfg=adapter_cfg, + ) + engine.from_hf(hf_path=QWEN3_PATH) + + loss_cfg = CELossConfig() + + total_steps = 1000 + warmup_steps = total_steps * lr_cfg.warmup_ratio + + def warmup_fn(x): + return x / warmup_steps if x < warmup_steps else 1 + + lr_scheduler = LambdaLR(engine.optimizer, warmup_fn) + + tok = AutoTokenizer.from_pretrained(QWEN3_PATH) + txt = "根据国际地球自转和参考系服务机构的数据,今年夏天是自2020年以来第六次地球自转加速。7月9日将成为有史以来最短的一天,比平时短1.3到1.6毫秒。 " + input_ids = tok.encode(txt, return_tensors="pt").view(1, -1) + labels = input_ids.clone() + input_ids = input_ids[:, :-1] + labels = labels[:, 1:] + pack_len = 8192 - input_ids.shape[1] + input_ids = pad_to_max_length(input_ids, 0, max_length=8192) + labels = pad_to_max_length(labels, -100, max_length=8192) + losses = [] + + data_mesh = None + if sp_size > 1: + data_mesh = init_data_mesh(str(DEVICE), sp_size) + + for _ in range(10): + seq_ctx = SequenceContext.from_input_ids((input_ids,), device=DEVICE) + labels = labels.to(DEVICE) + seq_ctx.num_padding = pack_len + seq_ctx_list = [seq_ctx] + loss_ctx_input_list: list[CELossContextInputItem] = [ + CELossContextInputItem(shifted_labels=labels) + ] + LossContext = loss_cfg.loss_ctx_cls + batches_loss_kwargs = LossContext.build_batches_loss_kwargs( + loss_ctx_input_list, + loss_cfg, + ) + loss_kwargs = batches_loss_kwargs[0] + loss_ctx = LossContext(loss_cfg, loss_kwargs) + seq_ctx = seq_ctx_list[0] + engine_input = [ModelItem(seq_ctx=seq_ctx, loss_ctx=loss_ctx)] + loss_log, _ = engine.train_step(engine_input) + grad_norm = engine.clip_grad_norm() + 
engine.step_optimizer(grad_norm) + lr_scheduler.step() + losses.append(loss_log["reduced_llm_loss"]) + losses_ref = [2.57, 2.57, 2.57, 2.57, 2.57, 2.57, 2.56, 2.56, 2.54, 2.53] + for loss, loss_ref in zip(losses, losses_ref): + self.assertTrue( + abs(loss - loss_ref) < 0.02, + f"loss={loss}, loss_ref={loss_ref}, diff={abs(loss - loss_ref)}", + ) + + torch.cuda.empty_cache() + try: + dist.destroy_process_group(pg) + except: + pass + + @parametrize.parametrize( + "device,tp_size,hsdp_sharding_size", + [ + ("cuda", 1, 8), # todo: test ep8 and hsdp, OOM in 8 gpus + ], + ) + def test_save_and_load(self, device, tp_size, hsdp_sharding_size): + pg = self.create_pg(device) + + temp_dir = tempfile.mkdtemp() + if dist.get_rank() == 0: + temp_dir = [temp_dir] + else: + temp_dir = [None] + dist.broadcast_object_list(temp_dir, src=0) + temp_dir = temp_dir[0] + moe_cfg = Qwen3Dense8BConfig() + optim_cfg: AdamWConfig = AdamWConfig() + fsdp_cfg: FSDPConfig = FSDPConfig( + torch_compile=True, + cpu_offload=False, + tp_size=tp_size, + hsdp_sharding_size=hsdp_sharding_size, + ) + adapter_cfg = LoraConfig( + r=4, + lora_alpha=16, + lora_dropout=0, + target_modules=["q_proj", "v_proj"], + bias="none", + modules_to_save=["lm_head"], + ) + engine = TrainEngine( + model_cfg=moe_cfg, + optim_cfg=optim_cfg, + fsdp_cfg=fsdp_cfg, + adapter_cfg=adapter_cfg, + ) + + engine.from_hf(hf_path=QWEN3_PATH) + engine.save_hf( + hf_dir=temp_dir, + save_dtype=torch.bfloat16, + ) + + dist.barrier() + time.sleep(1) + + # engine2 = TrainEngine( + # model_cfg=moe_cfg, + # optim_cfg=optim_cfg, + # fsdp_cfg=fsdp_cfg, + # ) + # engine2.from_hf(hf_path=temp_dir) + + # state_dict = engine.model.state_dict() + # state_dict2 = engine2.model.state_dict() + # for key, val in state_dict.items(): + # val2 = state_dict2[key] + # val = val.full_tensor().bfloat16() + # val2 = val2.full_tensor().bfloat16() + # self.assertTrue(torch.equal(val, val2[:val.shape[0]]), + # f"Mismatch in {key} between bf16 and fp8, {val} and {val2[:val.shape[0]]}") + + if dist.get_rank() == 0: + shutil.rmtree(temp_dir) + + torch.cuda.empty_cache() + try: + dist.destroy_process_group(pg) + except: + pass + + @property + def world_size(self) -> int: + return int(os.getenv("XTUNER_TEST_WORLD_SIZE", "8")) + + @property + def destroy_pg_upon_exit(self) -> bool: + return False diff --git a/tests/ray/test_update_weight.py b/tests/ray/test_update_weight.py index fdb112315..dfc3668ba 100644 --- a/tests/ray/test_update_weight.py +++ b/tests/ray/test_update_weight.py @@ -72,7 +72,8 @@ def init_config(self): if hasattr(model_cfg, 'balancing_loss_cfg'): model_cfg.balancing_loss_cfg = BalancingLossConfig() optim_cfg: AdamWConfig = AdamWConfig(lr=5e-7, foreach=False) - fsdp_cfg: FSDPConfig = FSDPConfig() + fsdp_cfg: FSDPConfig = FSDPConfig(ep_size=4) + model_cfg.ep_size = fsdp_cfg.ep_size lr_cfg = LRConfig(lr_type="constant", warmup_ratio=0, lr_min=5e-7) self.worker_cfg: WorkerConfig = WorkerConfig( model_cfg=model_cfg, @@ -84,7 +85,7 @@ def init_config(self): loss_type="vanilla", ), ignore_idx=-100, - use_kl_loss=True, + use_kl_loss=False, kl_loss_coef=0.001, kl_loss_type="low_var_kl", mode="eager"), diff --git a/xtuner/v1/engine/train_engine.py b/xtuner/v1/engine/train_engine.py index 1d23e226b..a2802116d 100644 --- a/xtuner/v1/engine/train_engine.py +++ b/xtuner/v1/engine/train_engine.py @@ -27,6 +27,7 @@ from xtuner.v1.config import FSDPConfig, OptimConfig from xtuner.v1.data_proto.sequence_context import SequenceContext from xtuner.v1.float8.float8_handler import Float8Handler 
+from xtuner.v1.model.adapter.lora import LoraConfig
 from xtuner.v1.model.base import BaseModel, ModelItem, TransformerConfig
 from xtuner.v1.model.utils import ModelForwardExtraLogInfo
 from xtuner.v1.module.router import NoAuxRouterConfig
@@ -145,10 +146,12 @@ def __init__(
         optim_cfg: OptimConfig,
         fsdp_cfg: FSDPConfig,
         intra_layer_micro_batch: int = 1,
+        adapter_cfg: LoraConfig | None = None,
     ) -> None:
         self.model_cfg = model_cfg
         self.optim_cfg = optim_cfg
         self.fsdp_cfg = fsdp_cfg
+        self.adapter_cfg = adapter_cfg
         self.model = self.build_model()
         self.optimizer = self.build_optimizer(optim_cfg)
         self.intra_layer_micro_batch = intra_layer_micro_batch
@@ -166,6 +169,8 @@ def __has_freeze_params(self) -> bool:
     def build_model(self) -> BaseModel:
         with torch.device("meta"):
             model = self.model_cfg.build()
+            if self.adapter_cfg:
+                model = self.adapter_cfg.build(model)
         self.float8_handler = None
         if self.model_cfg.float8_cfg is not None and self.model_cfg.float8_cfg.enable_float8:
diff --git a/xtuner/v1/model/adapter/lora.py b/xtuner/v1/model/adapter/lora.py
new file mode 100644
index 000000000..7ef31aed7
--- /dev/null
+++ b/xtuner/v1/model/adapter/lora.py
@@ -0,0 +1,372 @@
+import functools
+import json
+import os
+from pathlib import Path
+from typing import List, Optional, Union
+
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+from pydantic import BaseModel as PydanticBaseModel
+from torch.distributed.tensor import DTensor
+from tqdm import tqdm
+
+from xtuner.v1.model.base import _save_file
+from xtuner.v1.module.grouped_linear.moe_group_linear import GroupedLinear
+from xtuner.v1.module.lora_linear.lora_grouped_linear import LoraGroupedLinear
+from xtuner.v1.module.lora_linear.lora_linear import LoraLinear
+from xtuner.v1.utils import get_device, get_torch_device_module, profile_time_and_memory
+from xtuner.v1.utils.load_spec import LoadSpec
+
+
+DEVICE_MODULE = get_torch_device_module()
+DEVICE = get_device()
+
+
+class LoraConfig(PydanticBaseModel):
+    # Fields kept aligned with peft's LoraConfig
+    r: int = 8
+    target_modules: Optional[Union[List[str], str]] = None
+    lora_alpha: int = 8
+    lora_dropout: float = 0.0
+    bias: str = "none"  # "none" | "all" | "lora_only"
+    modules_to_save: Optional[List[str]] = None
+    init_lora_weights: bool = True
+    layers_to_transform: Optional[Union[List[int], int]] = None  # not used yet
+    layers_pattern: Optional[str] = None  # not used yet
+    base_model_name_or_path: Optional[str] = None
+
+    def build(self, base_model):
+        return LoraModel(base_model, self)
+
+    def save_hf(self, hf_path: str | Path):
+        """Save the configuration to a HuggingFace-compatible format.
+
+        Args:
+            hf_path (str | Path): Path where the configuration should be saved.
+        """
+        hf_config = {
+            "alpha_pattern": {},
+            "auto_mapping": None,
+            "base_model_name_or_path": self.base_model_name_or_path,
+            "bias": self.bias,
+            "corda_config": None,
+            "eva_config": None,
+            "exclude_modules": None,
+            "fan_in_fan_out": False,
+            "inference_mode": True,
+            "init_lora_weights": self.init_lora_weights,
+            "layer_replication": None,
+            "layers_pattern": self.layers_pattern,
+            "layers_to_transform": self.layers_to_transform,
+            "loftq_config": {},
+            "lora_alpha": self.lora_alpha,
+            "lora_bias": False,
+            "lora_dropout": self.lora_dropout,
+            "megatron_config": None,
+            "megatron_core": "megatron.core",
+            "modules_to_save": self.modules_to_save,
+            "peft_type": "LORA",
+            "qalora_group_size": 16,
+            "r": self.r,
+            "rank_pattern": {},
+            "revision": None,
+            "target_modules": self.target_modules,
+            "target_parameters": None,
+            "task_type": "CAUSAL_LM",
+            "trainable_token_indices": None,
+            "use_dora": False,
+            "use_qalora": False,
+            "use_rslora": False,
+        }
+        with open(os.path.join(hf_path, "adapter_config.json"), "w") as f:
+            json.dump(hf_config, f, indent=2)
+
+
+def wrap_to_hf_key_list(obj):
+    orig = getattr(obj, "to_hf_key_list")
+
+    @functools.wraps(orig)
+    def new_to_hf_key_list(key: str):
+        if "base_layer." in key:
+            key = key.replace("base_layer.", "")
+        out = orig(key)
+        # if ".base_layer." in out[0]:
+        #     out[0] = out[0].replace(".base_layer.", ".")
+        return out
+
+    setattr(obj, "to_hf_key_list", new_to_hf_key_list)
+
+
+class LoraModel(nn.Module):
+    def __init__(self, model: nn.Module, lora_config: LoraConfig):
+        super().__init__()
+        self.base_model = model
+        self.lora_config = lora_config
+
+        # apply lora to base model
+        self._apply_lora()
+
+        # 5. modify the hf key mapping of the base model
+        wrap_to_hf_key_list(self.base_model)
+        self.base_model._init_load_spec()
+
+    def _apply_lora(self):
+        # 1. freeze the whole base model
+        for p in self.base_model.parameters():
+            p.requires_grad = False
+
+        # 2. inject LoRA layers
+        self._replace_linear_layers(self.base_model, prefix="")
+
+        # 3. set the requires_grad of biases according to config.bias
+        self._apply_bias_setting()
+
+        # 4. keep parameters of modules listed in modules_to_save trainable (e.g. lm_head)
+        self._apply_modules_to_save()
+
+    # def __getattr__(self, name):
+    #     try:
+    #         return super().__getattr__(name)
+    #     except AttributeError:
+    #         return getattr(self.model, name)
+
+    def _match_target(self, module_name: str) -> bool:
+        """Match target modules with peft-like semantics:
+
+        - target_modules=None: apply LoRA to every nn.Linear
+        - target_modules=str: modules whose name contains the substring
+        - target_modules=List[str]: modules whose name contains any of the substrings
+        """
+        target = self.lora_config.target_modules
+        if target is None:
+            return True
+        if isinstance(target, str):
+            return target in module_name
+        return any(t in module_name for t in target)
+
+    def _replace_linear_layers(self, module: nn.Module, prefix: str):
+        for name, child in list(module.named_children()):
+            full_name = f"{prefix}.{name}" if prefix else name
+
+            if isinstance(child, (LoraLinear, LoraGroupedLinear)):
+                # avoid wrapping a layer twice
+                continue
+
+            if self._match_target(full_name) and isinstance(child, nn.Linear):
+                lora_layer = LoraLinear(
+                    base_layer=child,
+                    rank=self.lora_config.r,
+                    alpha=self.lora_config.lora_alpha,
+                    lora_dropout=self.lora_config.lora_dropout,
+                    init_lora_weights=self.lora_config.init_lora_weights,
+                )
+                setattr(module, name, lora_layer)
+            elif self._match_target(full_name) and isinstance(child, GroupedLinear):
+                lora_layer = LoraGroupedLinear(
+                    base_layer=child,
+                    rank=self.lora_config.r,
+                    alpha=self.lora_config.lora_alpha,
+                    lora_dropout=self.lora_config.lora_dropout,
+                    init_lora_weights=self.lora_config.init_lora_weights,
+                )
+                setattr(module, name, lora_layer)
+            else:
+                self._replace_linear_layers(child, prefix=full_name)
+
+    def _apply_bias_setting(self):
+        """
+        - "none": all biases stay frozen (default)
+        - "all": all biases are trainable
+        - "lora_only": only biases of Linear layers wrapped with LoRA are trainable
+        """
+        bias_mode = self.lora_config.bias
+
+        if bias_mode == "none":
+            # everything was already frozen in init, nothing more to do
+            return
+
+        if bias_mode == "all":
+            for module in self.base_model.modules():
+                if isinstance(module, (nn.Linear, GroupedLinear)) and module.bias is not None:
+                    module.bias.requires_grad = True
+            return
+
+        if bias_mode == "lora_only":
+            for module in self.base_model.modules():
+                if isinstance(module, (LoraLinear, LoraGroupedLinear)) and module.base_layer.bias is not None:
+                    module.base_layer.bias.requires_grad = True
+            return
+
+        raise ValueError(f"Unknown bias mode: {bias_mode}, expected one of ['none', 'all', 'lora_only']")
+
+    def _apply_modules_to_save(self):
+        """Modules listed in modules_to_save are trained directly even when LoRA is used,
+        typically to keep lm_head, classification heads, etc. trainable."""
+        modules_to_save = self.lora_config.modules_to_save
+        if not modules_to_save:
+            return
+
+        # match by module name
+        for module_name, module in self.base_model.named_modules():
+            if module_name in modules_to_save:
+                for p in module.parameters():
+                    p.requires_grad = True
+
+    def forward(self, *args, **kwargs):
+        return self.base_model(*args, **kwargs)
+
+    @torch.no_grad()
+    def merge_lora(self):
+        for module in tqdm(self.base_model.modules(), desc="[Merge LoRA]"):
+            if isinstance(module, LoraLinear):
+                module.merge_lora()
+
+    @torch.no_grad()
+    def unmerge_lora(self):
+        for module in tqdm(self.base_model.modules(), desc="[Unmerge LoRA]"):
+            if isinstance(module, LoraLinear):
+                module.unmerge_lora()
+
+    def trainable_parameters(self):
+        params = [(name, param) for name, param in self.named_parameters() if param.requires_grad]
+        return params
+
+    def print_trainable_parameters(self):
+        trainable_params = 0
+        all_params = 0
+        for _, param in self.named_parameters():
+            num = param.numel()
+            all_params += num
+            if param.requires_grad:
+                trainable_params
+= num + + print( + f"trainable params: {trainable_params} || " + f"all params: {all_params} || " + f"trainable: {100 * trainable_params / all_params:.4f}%" + ) + + def fully_shard(self, *args, **kwargs): + self.base_model = self.base_model.fully_shard(*args, **kwargs) + return self + + @property + def device(self) -> torch.device: + return self.base_model.device + + def to_device(self, device: torch.device | str): + self.base_model.to_device(device) + + def scale_and_reduce_grad(self): + self.base_model.scale_and_reduce_grad() + + def set_hf(self, hf_path: str | Path): + self.lora_config.base_model_name_or_path = str(hf_path) + self.base_model.set_hf(hf_path) + + def save_hf(self, hf_dir: Path | str, save_dtype: torch.dtype = torch.bfloat16): + with profile_time_and_memory(f"[Saving HF to {hf_dir} cost]"): + self._save_hf(hf_dir=hf_dir, save_dtype=save_dtype) + + def _save_hf(self, hf_dir: Path | str, save_dtype: torch.dtype = torch.bfloat16): + """Save the hf model to the given directory. + + Args: + hf_dir (str): The directory to save the model. + save_dtype (torch.dtype): The dtype to save the model parameters, bfloat16 or float8. + """ + # if self._hf_path is None and self.config.hf_config is None: + # raise NotImplementedError( + # "The model is not loaded from Huggingface, and the `hf_config` property is not implemented, so it cannot be saved in Huggingface format." + # ) + + if isinstance(hf_dir, str): + hf_dir = Path(hf_dir) + hf_dir.mkdir(parents=True, exist_ok=True) + + DEVICE_MODULE.empty_cache() + assert save_dtype in [torch.float8_e4m3fn, torch.bfloat16], f"save_dtype {save_dtype} is not supported" + + lora_param = {} + modules_to_save = set(self.lora_config.modules_to_save or []) + + for name, param in self.state_dict().items(): + if "lora_A." in name or "lora_B." in name: + lora_param[name] = param + continue + + if self.lora_config.bias != "none" and name.endswith("bias"): + if ".base_layer.bias" in name: + lora_param[name] = param + continue + + if modules_to_save: + for mod in modules_to_save: + if mod in name: + lora_param[name] = param + + tensor_list: list[torch.Tensor] = [] + load_spec_list: list[LoadSpec] = [] + name_list: list[str] = [] + for name, param in lora_param.items(): + local_tensor = param._local_tensor if isinstance(param, DTensor) else param + local_tensor = local_tensor.bfloat16() + base_name = name[11:] # remove "base_model." prefix + if "lora_A." in name or "lora_B." in name: + load_spec = self.base_model.load_spec_mapping.get(base_name) + base_layer_name = base_name.replace("lora_A.", "base_layer.").replace("lora_B.", "base_layer.") + base_layer_load_spec = self.base_model.load_spec_mapping.get(base_layer_name) + hf_name = "base_model." + base_layer_load_spec.hf_keys[0][:-6] + "lora_" + name.split("lora_")[-1] + else: + load_spec = self.base_model.load_spec_mapping.get(base_name) + hf_name = "base_model." 
+ load_spec.hf_keys[0] + tensor_list.append(local_tensor) + name_list.append(hf_name) + load_spec_list.append(load_spec) + + if self.base_model.fsdp_mesh is not None: + gathered_tensor_list = self.base_model._fsdp_foreach_allgather(tensor_list, load_spec_list) + else: + gathered_tensor_list = tensor_list + gathered_tensor_list = [ + self.base_model.param_to_safetensor(safetensor, name) + for safetensor, name in zip(gathered_tensor_list, name_list) + ] + + # Sepreately save fused parameters and others to make sure each saving rank will not save + # dupilicated keys + # + weight_map = {} + + safetensor_name = "adapter_model.safetensors" + + if not dist.is_initialized() or dist.get_rank() == 0: + # for tie_word_embeddings, we need to make sure each key is only saved once + unique_name_list = [] + unique_hf_tensor_list = [] + for name, hf_tensor in zip(name_list, gathered_tensor_list): + if name not in weight_map: + unique_name_list.append(name) + unique_hf_tensor_list.append(hf_tensor) + weight_map[name] = safetensor_name + + _save_file( + dict(zip(unique_name_list, unique_hf_tensor_list)), + hf_dir / safetensor_name, + ) + self.lora_config.save_hf(hf_dir) + + if dist.is_initialized(): + torch.distributed.barrier() + + def from_hf(self, hf_path: str | Path, strict: bool = True) -> tuple: + # TODO: load lora weight + loaded_keys, unloaded_keys, missing_keys = self.base_model.from_hf(hf_path, strict=False) + for unloaded_key in unloaded_keys: + assert "lora_A." in unloaded_key or "lora_B." in unloaded_key, ( + f"unloaded key {unloaded_key} is not a lora key" + ) + for missing_key in missing_keys: + assert "lora_A." in missing_key or "lora_B." in missing_key, f"missing key {missing_key} is not a lora key" + return loaded_keys, unloaded_keys, missing_keys diff --git a/xtuner/v1/model/base.py b/xtuner/v1/model/base.py index a8088e9e0..c9ae4e38d 100644 --- a/xtuner/v1/model/base.py +++ b/xtuner/v1/model/base.py @@ -481,7 +481,7 @@ def _get_fused_hf_param( dtype: torch.dtype, device="cpu", bucket_size=None, - return_full_key_per_rank: bool = False, + update_weights_for_rl: bool = False, ) -> Generator[tuple[list[str], list[torch.Tensor]], None, None]: if not params: return @@ -506,63 +506,58 @@ def _get_hf_params( for load_spec, fsdp_unshared_tensor in zip(spec_list, fsdp_unshard_tensor_list): hf_keys = load_spec.hf_keys - if load_spec.group is not None: - all_hf_keys_list: list[None] | list[list[str]] = [None for _ in range(load_spec.group.size())] - dist.all_gather_object(all_hf_keys_list, hf_keys, group=load_spec.group) - all_hf_keys_list = cast(list[list[str]], all_hf_keys_list) - all_hf_keys = list(chain(*all_hf_keys_list)) + if update_weights_for_rl: + hf_keys_list.append(hf_keys) + saved_fused_tensor_list.append(fsdp_unshared_tensor) else: - all_hf_keys = hf_keys - - current_rank = dist.get_rank() - fused_save_ranks = self._get_ranks_to_save_fused_tensor(len(all_hf_keys)) - key_per_rank = len(all_hf_keys) / len(fused_save_ranks) - assert key_per_rank.is_integer(), ( - f"XTuner Internal Error, size of all_hf_keys: {len(all_hf_keys)}, " - f"size of `fused_save_ranks` {len(fused_save_ranks)}" - ) + if load_spec.group is not None: + all_hf_keys_list: list[None] | list[list[str]] = [None for _ in range(load_spec.group.size())] + dist.all_gather_object(all_hf_keys_list, hf_keys, group=load_spec.group) + all_hf_keys_list = cast(list[list[str]], all_hf_keys_list) + all_hf_keys = list(chain(*all_hf_keys_list)) + else: + all_hf_keys = hf_keys + + current_rank = dist.get_rank() + fused_save_ranks = 
self._get_ranks_to_save_fused_tensor(len(all_hf_keys)) + key_per_rank = len(all_hf_keys) / len(fused_save_ranks) + assert key_per_rank.is_integer(), ( + f"XTuner Internal Error, size of all_hf_keys: {len(all_hf_keys)}, " + f"size of `fused_save_ranks` {len(fused_save_ranks)}" + ) - # 1. When return_full_key_per_rank is False, we intends to save hf models across ranks, - # each rank only saves part of hf keys and tensors - # 2. When return_full_key_per_rank is True, we intends to generate full tensors on each - # rank for ipc updating weights in RL training. - if not return_full_key_per_rank: start = int(current_rank * key_per_rank) end = int(start + key_per_rank) - else: - start = 0 - end = len(all_hf_keys) - _hf_key_list = all_hf_keys[start:end] + _hf_key_list = all_hf_keys[start:end] - if not _hf_key_list: - continue + if not _hf_key_list: + continue - hf_keys_list.append(_hf_key_list) + hf_keys_list.append(_hf_key_list) - assert load_spec.dim is not None - if load_spec.group is not None: assert load_spec.dim is not None - _gathered_tensor_list = [ - torch.zeros_like(fsdp_unshared_tensor) for _ in range(load_spec.group.size()) - ] - dist.all_gather(_gathered_tensor_list, fsdp_unshared_tensor, group=load_spec.group) - _gathered_tensor = torch.cat(_gathered_tensor_list, dim=load_spec.dim) - else: - _gathered_tensor = fsdp_unshared_tensor - - hf_tensor_size = _gathered_tensor.shape[load_spec.dim] / len(all_hf_keys) - _saved_fused_tensor = torch.index_select( - _gathered_tensor, - dim=load_spec.dim, - index=torch.arange( - int(start * hf_tensor_size), - int(end * hf_tensor_size), - dtype=torch.int64, - device=_gathered_tensor.device, - ), - ) - saved_fused_tensor_list.append(_saved_fused_tensor) + if load_spec.group is not None: + assert load_spec.dim is not None + _gathered_tensor_list = [ + torch.zeros_like(fsdp_unshared_tensor) for _ in range(load_spec.group.size()) + ] + dist.all_gather(_gathered_tensor_list, fsdp_unshared_tensor, group=load_spec.group) + _gathered_tensor = torch.cat(_gathered_tensor_list, dim=load_spec.dim) + else: + _gathered_tensor = fsdp_unshared_tensor + hf_tensor_size = _gathered_tensor.shape[load_spec.dim] / len(all_hf_keys) + _saved_fused_tensor = torch.index_select( + _gathered_tensor, + dim=load_spec.dim, + index=torch.arange( + int(start * hf_tensor_size), + int(end * hf_tensor_size), + dtype=torch.int64, + device=_gathered_tensor.device, + ), + ) + saved_fused_tensor_list.append(_saved_fused_tensor) # Split the fused tensor into hf tensors hf_tensor_list: list[torch.Tensor] = [] @@ -1141,6 +1136,14 @@ def _fsdp_foreach_allgather( # Concatenate the tensors along the FSDP shard dim for tensors, size in zip(_fsdp_unsharded_tensor_list, origin_fsdp_size): + # special case for partition of tensors are contiguous + fused_tensor = self.fuse_contiguous_chunks_without_alloc(tensors) + if fused_tensor is not None and fused_tensor.shape[self.FSDP_SHARD_DIM] == size: + fsdp_unsharded_tensor_list.append(fused_tensor) + continue + elif fused_tensor is not None: + # free memory ASAP + del fused_tensor tensor = torch.cat(tensors, dim=self.FSDP_SHARD_DIM) cat_tensor = torch.index_select( tensor, @@ -1157,6 +1160,48 @@ def _fsdp_foreach_allgather( return fsdp_unsharded_tensor_list + @staticmethod + def fuse_contiguous_chunks_without_alloc(tensors: list[torch.Tensor]) -> torch.Tensor | None: + """Fuse contiguous chunks without extra memory allocation. + + Return None if not possible. 
+ """ + if not tensors: + return None + base = tensors[0] + storage = base.untyped_storage() + dtype = base.dtype + device = base.device + stride = base.stride() + + inner_stride = stride[1:] + inner_elems = math.prod(base.shape[1:]) if base.dim() > 1 else 1 + + chunks = [] + for t in tensors: + if ( + t.untyped_storage().data_ptr() != storage.data_ptr() + or t.dtype != dtype + or t.device != device + or t.stride()[1:] != inner_stride + ): + return None + chunks.append((t.storage_offset(), t.shape[0], t)) + chunks.sort(key=lambda x: x[0]) + + expected_offset = chunks[0][0] + total_rows = 0 + for offset, rows, _ in chunks: + if offset != expected_offset: + return None + expected_offset += rows * inner_elems + total_rows += rows + + size = (total_rows, *base.shape[1:]) + flat = torch.empty(0, dtype=dtype, device=device) + flat.set_(storage, chunks[0][0], size, stride) + return flat + def _maybe_compile_layers(self): if self.fsdp_config is not None: if self.fsdp_config.torch_compile: diff --git a/xtuner/v1/model/moe/moe.py b/xtuner/v1/model/moe/moe.py index 41b35bc3d..e623e6aec 100644 --- a/xtuner/v1/model/moe/moe.py +++ b/xtuner/v1/model/moe/moe.py @@ -20,7 +20,7 @@ MixedPrecisionPolicy, fully_shard, ) -from torch.distributed.tensor import DTensor, Replicate, distribute_tensor +from torch.distributed.tensor import DTensor, Replicate, distribute_tensor, Shard from tqdm import tqdm from typing_extensions import NotRequired, overload, override @@ -47,6 +47,8 @@ ) from xtuner.v1.utils.activation_offload import async_save_on_cpu from xtuner.v1.utils.compile import maybe_compile +from xtuner.v1.float8.float8_gmm_tile_wise import TileWiseFloat8GroupedLinear +from xtuner.v1.module.grouped_linear.moe_group_linear import GroupedLinear DEVICE = get_device() @@ -651,10 +653,22 @@ def fully_shard( self._init_device_mesh(fsdp_config) self._maybe_compile_layers() + def rebuild_dtensor(module): + for name, child in list(module.named_children()): + if hasattr(child, "ep_mesh"): + child.ep_mesh = self.ep_mesh + if isinstance(child, (GroupedLinear, TileWiseFloat8GroupedLinear)) and self.ep_mesh is not None and self.ep_mesh.size() > 1: + print(f"rebuild DTensor in {name}") + weight = child.weight.to_local() if isinstance(child.weight, DTensor) else child.weight + child.weight = nn.Parameter(distribute_tensor(weight, self.ep_mesh, [Shard(0)])) + else: + rebuild_dtensor(child) + # TODO: 一定不能少,因为在模型 init 时候会构建一套 ep_mesh,如果不重新构建,fsdp_mesh 和 ep_mesh 会没有任何联系 # fully_shard 时候会出现: AssertionError: FSDP requires the DP and TP mesh to have the same parent mesh with torch.device("meta"): - self.layers = self.build_layers(self.config) + # self.layers = self.build_layers(self.config) + rebuild_dtensor(self.layers) if float8_handler is not None: # As we modify the shape of the model's parameters, @@ -686,6 +700,10 @@ def fully_shard( mp_policy = MixedPrecisionPolicy( param_dtype=self.fsdp_config.param_dtype, reduce_dtype=fsdp_config.reduce_dtype ) + if self.fsdp_config.lm_head_fp32: + lm_head_mp_policy = MixedPrecisionPolicy(param_dtype=torch.float32, reduce_dtype=torch.float32) + else: + lm_head_mp_policy = mp_policy num_recompute_layers = int(self.config.num_hidden_layers * self.fsdp_config.recompute_ratio) for layer_idx, layer in tqdm(self.layers.items(), desc="[FSDP Sharding]"): @@ -731,7 +749,7 @@ def fully_shard( fully_shard( self.lm_head, mesh=self.fsdp_mesh if self.hsdp_mesh is None else self.hsdp_mesh, - mp_policy=mp_policy, + mp_policy=lm_head_mp_policy, reshard_after_forward=self.fsdp_config.reshard_after_forward, 
offload_policy=CPUOffloadPolicy() if self.fsdp_config.cpu_offload else None, ) diff --git a/xtuner/v1/module/lora_linear/lora_grouped_linear.py b/xtuner/v1/module/lora_linear/lora_grouped_linear.py new file mode 100644 index 000000000..e33dca598 --- /dev/null +++ b/xtuner/v1/module/lora_linear/lora_grouped_linear.py @@ -0,0 +1,97 @@ +import math + +import torch +import torch.nn as nn +from torch.distributed.tensor import DTensor + +from ..grouped_linear.moe_group_linear import GroupedLinear, build_grouped_linear + + +class LoraGroupedLinear(nn.Module): + def __init__( + self, + base_layer: GroupedLinear, + rank: int, + alpha: int, + lora_dropout: float = 0.0, + init_lora_weights: bool = True, + ): + super().__init__() + + self.base_layer = base_layer + self.in_features = base_layer.in_features + self.out_features = base_layer.out_features + self.num_routed_experts = base_layer.num_routed_experts + self.ep_mesh = base_layer.ep_mesh + self.rank = rank + self.alpha = alpha + self.scale = alpha / rank + self.merged = False + + self.lora_A = build_grouped_linear(self.in_features, self.rank, self.num_routed_experts, ep_mesh=self.ep_mesh) + self.lora_B = build_grouped_linear(self.rank, self.out_features, self.num_routed_experts, ep_mesh=self.ep_mesh) + + self.lora_dropout = nn.Dropout(p=lora_dropout) if lora_dropout > 0.0 else nn.Identity() + + for p in self.base_layer.parameters(): + p.requires_grad = False + + if init_lora_weights: + self.reset_parameters() + + def reset_parameters(self): + if isinstance(self.lora_A.weight, DTensor): + # TODO: init DTensor + raise NotImplementedError + else: + nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5)) + nn.init.zeros_(self.lora_B.weight) + + def forward(self, x: torch.Tensor, tokens_per_expert: torch.Tensor, decoding: bool = False) -> torch.Tensor: + if self.merged: + return self.base_layer(x, tokens_per_expert, decoding) + original_out = self.base_layer(x, tokens_per_expert, decoding) + # lora_out = self.lora_a_naive_forward(x, tokens_per_expert, decoding) + lora_out = self.lora_A(x, tokens_per_expert, decoding) + # lora_out = self.lora_b_naive_forward(lora_out, tokens_per_expert, decoding) + lora_out = self.lora_B(lora_out, tokens_per_expert, decoding) + return original_out + lora_out * self.scale + + def lora_a_naive_forward( + self, x: torch.Tensor, tokens_per_expert: torch.Tensor, decoding: bool = False + ) -> torch.Tensor: + weight = self.lora_A.weight.view(-1, self.rank, self.in_features) + batch_sizes = tokens_per_expert.cpu().numpy() + + out = [] + start = 0 + for i, size in enumerate(batch_sizes): + rhs = weight[i, :, :].t() + out.append(x[start : start + size, :] @ rhs) + start += size + return torch.cat(out) + + def lora_b_naive_forward( + self, x: torch.Tensor, tokens_per_expert: torch.Tensor, decoding: bool = False + ) -> torch.Tensor: + weight = self.lora_B.weight.view(-1, self.out_features, self.rank) + batch_sizes = tokens_per_expert.cpu().numpy() + + out = [] + start = 0 + for i, size in enumerate(batch_sizes): + rhs = weight[i, :, :].t() + out.append(x[start : start + size, :] @ rhs) + start += size + return torch.cat(out) + + @torch.no_grad() + def merge_lora(self): + raise NotImplementedError + + @torch.no_grad() + def unmerge_lora(self): + raise NotImplementedError + + def __repr__(self) -> str: + return "lora." 
+ super().__repr__()
diff --git a/xtuner/v1/module/lora_linear/lora_linear.py b/xtuner/v1/module/lora_linear/lora_linear.py
new file mode 100644
index 000000000..a50045471
--- /dev/null
+++ b/xtuner/v1/module/lora_linear/lora_linear.py
@@ -0,0 +1,100 @@
+import math
+
+import torch
+import torch.nn as nn
+
+from ..linear.linear import build_linear
+
+
+class LoraLinear(nn.Module):
+    def __init__(
+        self,
+        base_layer: nn.Linear,
+        rank: int,
+        alpha: int,
+        lora_dropout: float = 0.0,
+        init_lora_weights: bool = True,
+    ):
+        super().__init__()
+
+        self.base_layer = base_layer
+        self.in_features = base_layer.in_features
+        self.out_features = base_layer.out_features
+        self.rank = rank
+        self.alpha = alpha
+        self.scale = alpha / rank
+        self.merged = False
+
+        weight = base_layer.weight
+        dtype = weight.dtype
+        device = weight.device
+
+        # A: (in_features -> r)
+        self.lora_A = build_linear(self.in_features, rank, bias=False, device=device, dtype=dtype)
+        # B: (r -> out_features)
+        self.lora_B = build_linear(rank, self.out_features, bias=False, device=device, dtype=dtype)
+
+        self.lora_dropout = nn.Dropout(p=lora_dropout) if lora_dropout > 0.0 else nn.Identity()
+
+        for p in self.base_layer.parameters():
+            p.requires_grad = False
+
+        if init_lora_weights:
+            self.reset_parameters()
+
+    def reset_parameters(self):
+        nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5))
+        nn.init.zeros_(self.lora_B.weight)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if self.merged:
+            return self.base_layer(x)
+
+        original_out = self.base_layer(x)
+        lora_out = self.lora_B(self.lora_A(self.lora_dropout(x))) * self.scale
+        return original_out + lora_out
+
+    @torch.no_grad()
+    def merge_lora(self):
+        """Merge the LoRA weights into base_layer.weight."""
+        if self.merged:
+            return
+
+        # delta_W = B @ A  shape: [out, in]
+        delta_w = torch.matmul(self.lora_B.weight, self.lora_A.weight)
+        self.base_layer.weight += delta_w * self.scale
+
+        self.merged = True
+        # after merging, the LoRA parameters no longer need to be trained
+        for p in self.lora_A.parameters():
+            p.requires_grad = False
+        for p in self.lora_B.parameters():
+            p.requires_grad = False
+
+    @torch.no_grad()
+    def unmerge_lora(self):
+        """Subtract the LoRA delta from base_layer.weight (if it was merged before)."""
+        if not self.merged:
+            return
+
+        delta_w = torch.matmul(self.lora_B.weight, self.lora_A.weight)
+        self.base_layer.weight -= delta_w * self.scale
+
+        self.merged = False
+        for p in self.lora_A.parameters():
+            p.requires_grad = True
+        for p in self.lora_B.parameters():
+            p.requires_grad = True
+
+    def extra_repr(self) -> str:
+        return (
+            f"in_features={self.in_features}, "
+            f"out_features={self.out_features}, "
+            f"r={self.rank}, "
+            f"lora_alpha={self.alpha}, "
+            f"scale={self.scale}, "
+            f"merged={self.merged}"
+        )
+
+    def __repr__(self) -> str:
+        return "lora."
+ super().__repr__() diff --git a/xtuner/v1/ray/config/worker.py b/xtuner/v1/ray/config/worker.py index 7b75da30d..5dd8b55e8 100644 --- a/xtuner/v1/ray/config/worker.py +++ b/xtuner/v1/ray/config/worker.py @@ -170,6 +170,10 @@ class RolloutConfig(BaseModel): help="Whether to enable returning routed experts for the rollout worker.", ), ] = False + update_weight_bucket_size_in_gb: Annotated[ + float, + Parameter(group=infer_group, help="Bucket size in GB for updating weight."), + ] = 0.5 # 512MB launch_server_method: Annotated[ Literal["ray", "multiprocessing"], Parameter( diff --git a/xtuner/v1/rl/base/rollout_is.py b/xtuner/v1/rl/base/rollout_is.py index 6a93ecdf3..bfb3190ab 100644 --- a/xtuner/v1/rl/base/rollout_is.py +++ b/xtuner/v1/rl/base/rollout_is.py @@ -139,7 +139,7 @@ def compute_rollout_importance_weights( metrics: Dict of IS and mismatch metrics, all scalars with "mismatch/" prefix """ if rollout_is_threshold is None: - return None, response_mask, {} + return None, response_mask, compute_mismatch_metrics(old_log_prob, rollout_log_prob, response_mask) assert rollout_is_mode in ["truncate", "mask", "both"], ( f"Invalid rollout_is_mode: {rollout_is_mode}. Must be 'truncate', 'mask', or 'both'." diff --git a/xtuner/v1/rl/base/worker.py b/xtuner/v1/rl/base/worker.py index 85d1b37b5..38d7ba631 100644 --- a/xtuner/v1/rl/base/worker.py +++ b/xtuner/v1/rl/base/worker.py @@ -27,6 +27,7 @@ from xtuner.v1.model.base import BaseModel as XtunerBaseModel from xtuner.v1.model.base import ModelItem, TransformerConfig from xtuner.v1.model.compose.qwen3_vl import Qwen3VLForConditionalGeneration +from xtuner.v1.model.adapter.lora import LoraConfig, LoraModel from xtuner.v1.ray.base import SingleAcceleratorWorker from xtuner.v1.ray.config import RolloutConfig from xtuner.v1.rl.utils import gather_logprobs @@ -118,6 +119,7 @@ class WorkerConfig(BaseModel): model_config = ConfigDict(title="Worker config", extra="forbid", arbitrary_types_allowed=True) model_cfg: TransformerConfig | VisionComposeConfigProtocol + adapter_cfg: LoraConfig | None = None optim_cfg: OptimConfig loss_cfg: BaseRLLossConfig lr_cfg: LRConfig @@ -194,6 +196,7 @@ def _build_engine(self, worker_cfg: WorkerConfig) -> TrainEngine | VisionCompose optim_cfg=worker_cfg.optim_cfg, fsdp_cfg=worker_cfg.fsdp_cfg, model_cfg=worker_cfg.model_cfg, + adapter_cfg=worker_cfg.adapter_cfg, ) if worker_cfg.load_from is not None: @@ -621,6 +624,7 @@ def update_rollout_info( self.rollout_cfg_info["tp"] = tp self.rollout_cfg_info["ep"] = ep self.rollout_cfg_info["api_key"] = rollout_config.api_key + self.rollout_cfg_info["update_weight_bucket_size_in_gb"] = rollout_config.update_weight_bucket_size_in_gb if os.environ.get("XTUNER_USE_SGLANG", "0") == "1": self.rollout_cfg_info["backend"] = "sglang" else: @@ -641,6 +645,13 @@ def _update_weights_hf_generator(self): assert self.rollout_device_mesh is not None model = self._engine.model + lora_model = None + if isinstance(model, LoraModel): + lora_model = model + lora_model.merge_lora() + model = lora_model.base_model + + DEVICE_MODULE.empty_cache() if isinstance(model.config, VisionComposeConfigProtocol): @@ -651,23 +662,63 @@ def _update_weights_hf_generator(self): else: dtype = torch.bfloat16 - same_gen = model._get_same_hf_param(model._group_param_by_load_spec(LoadEnum.SAME), dtype=dtype, device=DEVICE) + bucket_size = int(self.rollout_cfg_info["update_weight_bucket_size_in_gb"] * 1024**3) + same_gen = model._get_same_hf_param( + model._group_param_by_load_spec(LoadEnum.SAME), dtype=dtype, device=DEVICE, 
bucket_size=bucket_size + ) fused_gen = model._get_fused_hf_param( model._group_param_by_load_spec(LoadEnum.FUSED), dtype=dtype, device=DEVICE, - return_full_key_per_rank=True, + bucket_size=bucket_size, + update_weights_for_rl=True, ) shard_gen = model._get_shard_hf_param( - model._group_param_by_load_spec(LoadEnum.SHARD), dtype=dtype, device=DEVICE + model._group_param_by_load_spec(LoadEnum.SHARD), dtype=dtype, device=DEVICE, bucket_size=bucket_size ) - for name_list, param_list in chain(same_gen, fused_gen, shard_gen): + + for name_list, fused_param_list in fused_gen: + state_dict = {name: param.detach() for name, param in zip(name_list, fused_param_list)} + for key in list(state_dict.keys()): + if "lora_" in key: + del state_dict[key] + if model.fsdp_config.ep_size > 1: + ep_mesh: DeviceMesh = model.ep_mesh + ep_group = ep_mesh.get_group() + ep_rank = dist.get_rank(group=ep_group) + for src_global_rank in dist.get_process_group_ranks(ep_group): + broadcast_state_dict = dict() + for key, tensor in state_dict.items(): + obj_to_broadcast = [key, tensor.to("meta")] if ep_rank == src_global_rank else [None, None] + dist.broadcast_object_list(obj_to_broadcast, src=src_global_rank, group=ep_group) + real_key, meta_tensor = obj_to_broadcast + buffer = ( + state_dict[real_key] + if ep_rank == src_global_rank + else torch.empty_like(meta_tensor, device=DEVICE) + ) + dist.broadcast(buffer, src=src_global_rank, group=ep_group) + broadcast_state_dict[real_key] = buffer + self.request_update_params(broadcast_state_dict, finished=False) + del broadcast_state_dict, buffer + else: + self.request_update_params(state_dict, finished=False) + del state_dict, name_list, fused_param_list + + for name_list, param_list in chain(same_gen, shard_gen): state_dict = {name: param.detach() for name, param in zip(name_list, param_list)} + for key in list(state_dict.keys()): + if "lora_" in key: + del state_dict[key] self.request_update_params(state_dict, finished=False) + del state_dict, name_list, param_list if self.rollout_cfg_info["backend"] == "pytorch": self.request_update_params({}, finished=True) + if lora_model: + lora_model.unmerge_lora() + dist.barrier() DEVICE_MODULE.empty_cache() return @@ -678,6 +729,11 @@ def _update_weights_by_layer(self): assert self.rollout_device_mesh is not None model = self._engine.model + lora_model = None + if isinstance(model, LoraModel): + lora_model = model + lora_model.merge_lora() + model = lora_model.base_model DEVICE_MODULE.empty_cache() if isinstance(model.config, VisionComposeConfigProtocol): @@ -773,6 +829,9 @@ def get_params(tensor_list, name_list, save_dtype): if self.rollout_cfg_info["backend"] == "pytorch": self.request_update_params({}, finished=True) + if lora_model: + lora_model.unmerge_lora() + dist.barrier() DEVICE_MODULE.empty_cache() return diff --git a/xtuner/v1/train/rl_trainer.py b/xtuner/v1/train/rl_trainer.py index c8da31fa1..4bd76c313 100644 --- a/xtuner/v1/train/rl_trainer.py +++ b/xtuner/v1/train/rl_trainer.py @@ -298,7 +298,18 @@ def __init__( * total_epochs ) bind_train_rollout(train_controller=self._train_controller, env_controller=self._rollout_env_controller) - ray.get(self._train_controller.offload.remote(target="all")) + # update weights if rollout_config.skip_load_weights == True + if rollout_config.skip_load_weights: + self.logger.info("Rollout workers skip load weights, update weights from train workers.") + ray.get(self._train_controller.offload.remote(target="optimizer")) + ray.get(self._rollout_env_controller.offload.remote()) + 
ray.get(self._rollout_env_controller.onload_weights.remote()) + ray.get(self._train_controller.update_weights.remote()) + ray.get(self._train_controller.offload.remote(target="model")) + ray.get(self._rollout_env_controller.onload_kvcache.remote()) + self.logger.info("Rollout workers has updated weights from train workers.") + else: + ray.get(self._train_controller.offload.remote(target="all")) self._train_worker_cfg = train_worker_cfg diff --git a/xtuner/v1/train/trainer.py b/xtuner/v1/train/trainer.py index 4aa124808..9419e87e9 100644 --- a/xtuner/v1/train/trainer.py +++ b/xtuner/v1/train/trainer.py @@ -34,6 +34,7 @@ from xtuner.v1.engine.vision_compose_train_engine import VisionComposeConfigProtocol, VisionComposeTrainEngine from xtuner.v1.loss import CELossConfig from xtuner.v1.loss.ce_loss import CELossContextInputItem +from xtuner.v1.model.adapter.lora import LoraConfig from xtuner.v1.model.base import ModelItem, TransformerConfig from xtuner.v1.model.utils import ModelForwardExtraLogInfo from xtuner.v1.patch import patch_default_save_plan @@ -295,6 +296,7 @@ class TrainerConfig(BaseModel): lr_cfg: LRConfig loss_cfg: CELossConfig = CELossConfig() fsdp_cfg: FSDPConfig | None = None + adapter_cfg: LoraConfig | None = None global_batch_size: int | None work_dir: Path | str | None = None log_dir: Path | str | None = None @@ -408,6 +410,7 @@ def __init__( model_cfg: TransformerConfig | VisionComposeConfigProtocol, optim_cfg: OptimConfig, fsdp_cfg: FSDPConfig | None = FSDPConfig(), + adapter_cfg: LoraConfig | None = None, dataset_cfg: DatasetConfigList | None = None, # TODO: Removed in version 1.1.0 dataloader_cfg: DataloaderConfig, loss_cfg: CELossConfig | None = CELossConfig(), @@ -557,6 +560,7 @@ def __init__( model_config=model_cfg, optim_config=optim_cfg, fsdp_config=fsdp_cfg, + adapter_config=adapter_cfg, resume_cfg=resume_cfg, strict=strict_load, intra_layer_micro_batch=intra_layer_micro_batch, @@ -603,6 +607,7 @@ def from_config(cls, config: TrainerConfig) -> Self: model_cfg=config.model_cfg, optim_cfg=config.optim_cfg, fsdp_cfg=config.fsdp_cfg, + adapter_cfg=config.adapter_cfg, dataset_cfg=config.dataset_cfg, dataloader_cfg=config.dataloader_cfg, loss_cfg=config.loss_cfg, @@ -887,6 +892,7 @@ def build_engine( model_config: TransformerConfig | VisionComposeConfigProtocol, optim_config: OptimConfig, fsdp_config: FSDPConfig, + adapter_config: LoraConfig, resume_cfg: ResumeConfig, intra_layer_micro_batch: int = 1, strict: bool = True, @@ -910,6 +916,7 @@ def build_engine( optim_cfg=optim_config, fsdp_cfg=fsdp_config, model_cfg=model_config, + adapter_cfg=adapter_config, intra_layer_micro_batch=intra_layer_micro_batch, ) else: @@ -917,6 +924,7 @@ def build_engine( optim_cfg=optim_config, fsdp_cfg=fsdp_config, model_cfg=model_config, + adapter_cfg=adapter_config, intra_layer_micro_batch=intra_layer_micro_batch, ) if model_path is not None and (model_config.dcp_ignore_frozen_params or resume_cfg.resume_from is None):