From cf53e4e30bb7f5247ba12ef2ca433d808a2ab17e Mon Sep 17 00:00:00 2001 From: meichangsu1 <1484603386@qq.com> Date: Tue, 12 May 2026 18:54:19 +0800 Subject: [PATCH 01/11] feat(template): add agent template support with ReAct-style tool calling Introduce BaseAgentTemplate and DeepSeekV4AgentTemplate for agent-based interactions. Add ReactCompatMixin for parsing and formatting ReAct-style tool calls, including Action/Action Input/Observation keywords. Implement ToolDesc and AgentKeyword dataclasses to support structured tool descriptions and agent keywords. --- src/twinkle/template/__init__.py | 3 +- src/twinkle/template/base.py | 159 +++++++++++++++++++++++++++- src/twinkle/template/deepseek_v4.py | 138 ++++++++++++++++++++++++ src/twinkle/template/utils.py | 55 +++++++++- 4 files changed, 351 insertions(+), 4 deletions(-) create mode 100644 src/twinkle/template/deepseek_v4.py diff --git a/src/twinkle/template/__init__.py b/src/twinkle/template/__init__.py index 324ce7ac..4b8c27e6 100644 --- a/src/twinkle/template/__init__.py +++ b/src/twinkle/template/__init__.py @@ -1,3 +1,4 @@ # Copyright (c) ModelScope Contributors. All rights reserved. -from .base import Template +from .base import BaseAgentTemplate, Template +from .deepseek_v4 import DeepSeekV4AgentTemplate from .qwen3_5_vl import Qwen3_5Template diff --git a/src/twinkle/template/base.py b/src/twinkle/template/base.py index 5784ddae..b3a95ed6 100644 --- a/src/twinkle/template/base.py +++ b/src/twinkle/template/base.py @@ -1,15 +1,19 @@ # Copyright (c) ModelScope Contributors. All rights reserved. +import ast import inspect +import json import numpy as np import os +from abc import ABC, abstractmethod from collections.abc import Mapping from copy import copy, deepcopy -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Union +from dataclasses import asdict, dataclass +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Tuple, Union from twinkle.data_format import InputFeature, Message, Trajectory from twinkle.hub import HubOperation from twinkle.utils import load_image, to_device -from .utils import TokenizeByRound, transfer_to_standard_message +from .utils import Function, Prompt, TokenizeByRound, split_str_parts_by, transfer_to_standard_message if TYPE_CHECKING: import torch @@ -21,6 +25,157 @@ AudioInput = Union[str, np.ndarray, 'torch.Tensor'] +@dataclass +class AgentKeyword: + action: str = 'Action:' + action_input: str = 'Action Input:' + observation: str = 'Observation:' + + +@dataclass +class ToolDesc: + name_for_model: str + name_for_human: str + description_for_model: str + parameters: str + args_format: str + + +class ReactCompatMixin: + """ReAct-style tool call parsing and formatting compatibility.""" + + keyword = AgentKeyword() + + @staticmethod + def _split_action_action_input(response: str, keyword: AgentKeyword) -> List[Function]: + agent_parts = split_str_parts_by(response, list(asdict(keyword).values())) + functions = [] + action_content = None + + for part in agent_parts: + key, content = part['key'].lower(), part['content'] + if action_content is None and key == keyword.action.lower(): + action_content = content + elif action_content is not None and key == keyword.action_input.lower(): + functions.append(Function(name=action_content, arguments=content)) + action_content = None + + return functions + + def get_toolcall(self, response: str) -> List[Function]: + functions = self._split_action_action_input(response, self.keyword) + if len(functions) == 0 and self.keyword != ReactCompatMixin.keyword: + functions = self._split_action_action_input(response, ReactCompatMixin.keyword) + return functions + + def _format_tool_responses( + self, + assistant_content: str, + tool_messages, + ) -> Tuple[str, Prompt]: + assert len(tool_messages) > 0 + with_action = self.keyword.action in assistant_content and self.keyword.action_input in assistant_content + if with_action: + if not assistant_content.endswith(self.keyword.observation): + if not assistant_content.endswith('\n'): + assistant_content += '\n' + assistant_content += self.keyword.observation + res = [] + for i, tool_message in enumerate(tool_messages): + if i > 0: + res.append(self.keyword.observation) + tool_content = tool_message['content'] + res.append(tool_content) + if not tool_content.endswith('\n'): + res.append('\n') + else: + res = [tool_message['content'] for tool_message in tool_messages] + return assistant_content, res + + @staticmethod + def _parse_tool_call(content) -> Dict[str, Any]: + obj = BaseAgentTemplate._parse_json(content) + name = obj['name'] + arguments = obj.get('arguments') + if arguments is None: + arguments = obj.get('parameters') + arguments = BaseAgentTemplate._parse_json(arguments) + assert arguments is not None, f'content: {content}' + return {'name': name, 'arguments': arguments} + + def _format_tool_calls(self, tool_call_messages) -> str: + tool_calls = [] + for message in tool_call_messages: + tool_call = self._parse_tool_call(message['content']) + tool_calls.append(f'{self.keyword.action} {tool_call["name"]}\n' + f'{self.keyword.action_input} {tool_call["arguments"]}\n') + tool_calls.append(self.keyword.observation) + return ''.join(tool_calls) + + +class BaseAgentTemplate(ReactCompatMixin, ABC): + """Base class for agent templates that support tool calling.""" + + @staticmethod + def _get_tool_name(tool): + return tool.get('name_for_model') or tool.get('name') + + @staticmethod + def unwrap_tool(tool): + assert isinstance(tool, dict), f'tool: {tool}' + if 'type' in tool and 'function' in tool: + tool = tool['function'] + return tool + + @staticmethod + def wrap_tool(tool): + assert isinstance(tool, dict), f'tool: {tool}' + if 'type' not in tool and 'function' not in tool: + tool = {'type': 'function', 'function': tool} + return tool + + @staticmethod + def _parse_tool(tool, lang: Literal['zh', 'en']) -> ToolDesc: + tool = BaseAgentTemplate.unwrap_tool(tool) + name_for_model = BaseAgentTemplate._get_tool_name(tool) + name_for_human = tool.get('name_for_human') or name_for_model + + description = tool.get('description') + if description is None: + description = tool.get('description_for_model') + parameters = tool.get('parameters') or {} + parameters = parameters if isinstance(parameters, str) else json.dumps(parameters, ensure_ascii=False) + args_format = '此工具的输入应为JSON对象。' if lang == 'zh' else 'Format the arguments as a JSON object.' + tool_desc = ToolDesc( + name_for_model=name_for_model, + name_for_human=name_for_human, + description_for_model=description, + parameters=parameters, + args_format=args_format) + assert name_for_model is not None and description is not None, f'tool_desc: {tool_desc}' + return tool_desc + + @staticmethod + def _parse_json(json_str: str) -> Optional[Any]: + if not isinstance(json_str, str): + return json_str + try: + res = json.loads(json_str) + except json.JSONDecodeError: + try: + res = ast.literal_eval(json_str) + except Exception: + return + return res + + @abstractmethod + def _format_tools(self, + tools: List[Union[str, dict]], + system: Optional[str] = None, + user_message: Optional[dict] = None) -> str: + pass + + class Template: # Placeholder tokens in user text diff --git a/src/twinkle/template/deepseek_v4.py b/src/twinkle/template/deepseek_v4.py new file mode 100644 index 00000000..97dcf8e4 --- /dev/null +++ b/src/twinkle/template/deepseek_v4.py @@ -0,0 +1,138 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +import json +import re +from typing import Any, Dict, List, Optional, Tuple, Union + +from .utils import Function +from .utils import Prompt +from .base import BaseAgentTemplate + +DSML_TOKEN = '|DSML|' + +TOOLS_TEMPLATE = """## Tools + +You have access to a set of tools to help answer the user's question. \ +You can invoke tools by writing a "<{dsml_token}tool_calls>" block like the following: + +<{dsml_token}tool_calls> +<{dsml_token}invoke name="$TOOL_NAME"> +<{dsml_token}parameter name="$PARAMETER_NAME" string="true|false">$PARAMETER_VALUE +... + +<{dsml_token}invoke name="$TOOL_NAME2"> +... + + + +String parameters should be specified as is and set `string="true"`. \ +For all other types (numbers, booleans, arrays, objects), \ +pass the value in JSON format and set `string="false"`. + +If thinking_mode is enabled (triggered by ), \ +you MUST output your complete reasoning inside ... BEFORE any tool calls or final response. + +Otherwise, output directly after with tool calls or final response. + +### Available Tool Schemas + +{tool_schemas} + +You MUST strictly follow the above defined tool name and parameter schemas to invoke tool calls. +""" + + +def _to_json(value: Any) -> str: + try: + return json.dumps(value, ensure_ascii=False) + except Exception: + return json.dumps(value, ensure_ascii=True) + + +def _encode_arguments_to_dsml(arguments: Dict[str, Any]) -> str: + """Encode tool call arguments dict into DSML parameter lines.""" + lines = [] + for k, v in arguments.items(): + is_str = 'true' if isinstance(v, str) else 'false' + val = v if isinstance(v, str) else _to_json(v) + lines.append(f'<{DSML_TOKEN}parameter name="{k}" string="{is_str}">{val}') + return '\n'.join(lines) + + +class DeepSeekV4AgentTemplate(BaseAgentTemplate): + + def get_toolcall(self, response: str) -> List[Function]: + # Parse DSML tool calls from model output + # Pattern: <|DSML|invoke name="tool_name">...params... + invoke_pattern = re.compile( + rf'<{re.escape(DSML_TOKEN)}invoke\s+name="([^"]+)">\s*(.*?)\s*', re.DOTALL) + param_pattern = re.compile( + rf'<{re.escape(DSML_TOKEN)}parameter\s+name="([^"]+)"\s+string="(true|false)">' + rf'(.*?)', re.DOTALL) + + functions = [] + for match in invoke_pattern.finditer(response): + tool_name = match.group(1) + params_block = match.group(2) + arguments = {} + for pm in param_pattern.finditer(params_block): + param_name = pm.group(1) + is_string = pm.group(2) + param_value = pm.group(3) + if is_string == 'false': + try: + param_value = json.loads(param_value) + except json.JSONDecodeError: + pass + arguments[param_name] = param_value + functions.append(Function(name=tool_name, arguments=json.dumps(arguments, ensure_ascii=False))) + + if len(functions) == 0: + # Fallback to ReAct format + return super().get_toolcall(response) + return functions + + def _get_tool_responses(self, tool_messages): + return ''.join(f'{tool_message["content"]}' for tool_message in tool_messages) + + def _format_tool_responses( + self, + assistant_content: str, + tool_messages, + ) -> Tuple[str, 'Prompt']: + with_action = self.keyword.action in assistant_content and self.keyword.action_input in assistant_content + if with_action: + return super()._format_tool_responses(assistant_content, tool_messages) + res = [ + '<|end▁of▁sentence|><|User|>', + self._get_tool_responses(tool_messages), + '<|Assistant|>', + ] + return assistant_content, res + + def _format_tools(self, tools: List[Union[str, dict]], system: Optional[str] = None, user_message=None) -> str: + tool_schemas = [] + for tool in tools: + tool = self.unwrap_tool(tool) + tool_schemas.append(_to_json(tool)) + + tools_section = TOOLS_TEMPLATE.format( + tool_schemas='\n'.join(tool_schemas), + dsml_token=DSML_TOKEN, + ) + + system = system or '' + return f'{system}\n\n{tools_section}' if system else tools_section + + def _format_tool_calls(self, tool_call_messages) -> str: + invocations = [] + for message in tool_call_messages: + tool_call = self._parse_tool_call(message['content']) + name = tool_call['name'] + arguments = tool_call['arguments'] + if isinstance(arguments, str): + arguments = json.loads(arguments) + dsml_args = _encode_arguments_to_dsml(arguments) + invocations.append(f'<{DSML_TOKEN}invoke name="{name}">\n{dsml_args}\n') + + tool_calls_str = '\n'.join(invocations) + return f'<{DSML_TOKEN}tool_calls>\n{tool_calls_str}\n' diff --git a/src/twinkle/template/utils.py b/src/twinkle/template/utils.py index 72975d78..0da0417d 100644 --- a/src/twinkle/template/utils.py +++ b/src/twinkle/template/utils.py @@ -1,16 +1,57 @@ # Copyright (c) ModelScope Contributors. All rights reserved. import inspect +import json +import re from copy import copy, deepcopy -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, TypeVar +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union from twinkle.data_format import Message, Trajectory from twinkle.utils import to_device +from dataclasses import dataclass if TYPE_CHECKING: from transformers import PreTrainedTokenizer _T = TypeVar('_T') +Prompt = List[Union[str, List[int], List[str]]] + + +def _split_str_by_regex(text: str, regex_delimiters: List[str]) -> List[str]: + combined_pattern = '|'.join(f'({pattern})' for pattern in regex_delimiters) + parts = re.split(combined_pattern, text, flags=re.DOTALL) + parts = [part for part in parts if part is not None] + if parts[0] == '': + parts.pop(0) + else: + parts.insert(0, '') + assert len(parts) % 2 == 0, f'result: {parts}' + assert ''.join(parts) == text, f'split_result: {parts}, text: {text}' + return parts + + +def split_str_parts_by(text: str, delimiters: List[str], regex_mode: bool = False) -> List[Dict[str, str]]: + """Split text into keyed delimiter/content parts.""" + assert isinstance(text, str), f'text: {text}' + delimiters_origin = delimiters + if not regex_mode: + delimiters = [re.escape(delimiter) for delimiter in delimiters] + parts = _split_str_by_regex(text, delimiters) if delimiters else ['', text] + res = [] + if regex_mode: + parts = [part for part in parts if part] + for part in parts: + for delimiter, delimiter_origin in zip(delimiters, delimiters_origin): + if re.match(delimiter, part, re.DOTALL): + break + else: + delimiter_origin = '' + res.append({'key': delimiter_origin, 'content': part}) + else: + for key, content in zip(parts[::2], parts[1::2]): + res.append({'key': key, 'content': content}) + return res + def _convert_to_vlm_format(messages: List[Dict]) -> List[Dict]: converted = [] @@ -368,3 +409,15 @@ def tokenize_with_assistant_labels( labels[i] = full_ids[i] return full_ids, labels, encoded + + +@dataclass +class Function: + name: str + arguments: Optional[Union[str, Any]] + + def __post_init__(self): + if not isinstance(self.arguments, str): + self.arguments = json.dumps(self.arguments, ensure_ascii=False) + self.name = self.name.strip() + self.arguments = self.arguments.strip() From cc0a7ef439a09858ff018e4715d369bac8cdafdb Mon Sep 17 00:00:00 2001 From: meichangsu1 <1484603386@qq.com> Date: Tue, 12 May 2026 19:50:52 +0800 Subject: [PATCH 02/11] fix: enhance FSDP decoder layer detection with no_split_modules support - Update `_get_decoder_layers` to first search for modules matching `_no_split_modules` names - Add `_get_no_split_module_names` helper to collect no-split module names from model hierarchy - Add `_normalize_no_split_modules` utility for consistent set conversion - Change return type hint from `nn.ModuleList` to `List[nn.Module]` for flexibility --- cookbook/transformers/deepseek_v4_flash.py | 139 ++++++++++++++++++ .../transformers/strategy/native_fsdp.py | 61 ++++++-- 2 files changed, 186 insertions(+), 14 deletions(-) create mode 100644 cookbook/transformers/deepseek_v4_flash.py diff --git a/cookbook/transformers/deepseek_v4_flash.py b/cookbook/transformers/deepseek_v4_flash.py new file mode 100644 index 00000000..a7a3d73a --- /dev/null +++ b/cookbook/transformers/deepseek_v4_flash.py @@ -0,0 +1,139 @@ +import os + +import twinkle +from peft import LoraConfig +from transformers import AutoConfig +from twinkle import DeviceMesh, Platform, get_device_placement, get_logger +from twinkle.dataloader import DataLoader +from twinkle.dataset import Dataset, DatasetMeta +from twinkle.model import TransformersModel +from twinkle.preprocessor import SelfCognitionProcessor + +logger = get_logger() + +MODEL_ID = os.environ.get('MODEL_ID', 'ms://deepseek-ai/DeepSeek-V4-flash-bfa16') +DATASET_ID = os.environ.get('DATASET_ID', 'ms://swift/self-cognition') +TEMPLATE_ID = os.environ.get('TEMPLATE_ID', 'Template') +OUTPUT_DIR = os.environ.get('OUTPUT_DIR', './output') + +NUM_LAYERS = int(os.environ.get('NUM_LAYERS', '4')) + +BATCH_SIZE = int(os.environ.get('BATCH_SIZE', '4')) +GRAD_ACCUM_STEPS = int(os.environ.get('GRAD_ACCUM_STEPS', '2')) +LR = float(os.environ.get('LR', '1e-4')) +MAX_STEPS = int(os.environ.get('MAX_STEPS', '0')) +SAVE_STEPS = int(os.environ.get('SAVE_STEPS', '50')) +RESHARD_AFTER_FORWARD = os.environ.get('RESHARD_AFTER_FORWARD', '1') == '1' +GRADIENT_CHECKPOINTING = True +IGNORE_MISMATCHED_SIZES = False +LORA_TARGET_MODULES = 'all-linear' +ADAPTER_NAME = 'default' + +device_mesh = DeviceMesh.from_sizes( + fsdp_size=4, + dp_size=1, + device_type=Platform.get_platform().device_prefix(), +) + +twinkle.initialize(mode='local', global_device_mesh=device_mesh) + + +def create_dataset(data_slice=None): + dataset = Dataset(dataset_meta=DatasetMeta(DATASET_ID, data_slice=data_slice or range(1000))) + dataset.set_template(TEMPLATE_ID, model_id=MODEL_ID) + dataset.map(SelfCognitionProcessor('twinkle大模型', 'ModelScope社区')) + dataset.encode(batched=True) + return dataset + + +def eval(model): + dataset = create_dataset(data_slice=range(100)) + dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, device_mesh=device_mesh) + for _, batch in enumerate(dataloader): + if callable(batch): + batch = batch() + model.forward_only(inputs=batch, adapter_name=ADAPTER_NAME) + model.calculate_loss(adapter_name=ADAPTER_NAME) + return model.calculate_metric(is_training=False, adapter_name=ADAPTER_NAME) + + +def train(): + dataset = create_dataset() + dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, device_mesh=device_mesh) + + config = AutoConfig.from_pretrained(MODEL_ID, trust_remote_code=True) + if NUM_LAYERS is not None and hasattr(config, 'num_hidden_layers'): + config.num_hidden_layers = NUM_LAYERS + if hasattr(config, 'use_cache'): + config.use_cache = False + + model = TransformersModel( + model_id=MODEL_ID, + config=config, + device_mesh=device_mesh, + strategy='native_fsdp', + memory_efficient_init=True, + ignore_mismatched_sizes=IGNORE_MISMATCHED_SIZES, + fsdp_config={ + 'reshard_after_forward': RESHARD_AFTER_FORWARD, + }, + ) + + lora_config = LoraConfig(r=8, lora_alpha=32, target_modules=LORA_TARGET_MODULES) + model.add_adapter_to_model(ADAPTER_NAME, lora_config, gradient_accumulation_steps=GRAD_ACCUM_STEPS) + + if not GRADIENT_CHECKPOINTING: + model.model.gradient_checkpointing_disable() + + model.set_template(TEMPLATE_ID, model_id=MODEL_ID, adapter_name=ADAPTER_NAME) + model.set_optimizer('AdamW', lr=LR, foreach=False, adapter_name=ADAPTER_NAME) + model.set_lr_scheduler( + scheduler_cls='CosineWarmupScheduler', + num_warmup_steps=5, + num_training_steps=len(dataloader), + adapter_name=ADAPTER_NAME, + ) + + logger.info(get_device_placement()) + logger.info(model.get_train_configs(adapter_name=ADAPTER_NAME)) + logger.info( + f'Total steps: {len(dataloader)}, batch_size={BATCH_SIZE}, ' + f'grad_accum={GRAD_ACCUM_STEPS}, lr={LR:.2e}, ' + f'num_layers={NUM_LAYERS}, ignore_mismatched_sizes={IGNORE_MISMATCHED_SIZES}, ' + f'gradient_checkpointing={GRADIENT_CHECKPOINTING}, ' + f'reshard_after_forward={RESHARD_AFTER_FORWARD}, ' + f'lora_target_modules={LORA_TARGET_MODULES}') + + best_loss = float('inf') + for step, batch in enumerate(dataloader): + if MAX_STEPS and step >= MAX_STEPS: + break + if callable(batch): + batch = batch() + model.forward_backward( + inputs=batch, + adapter_name=ADAPTER_NAME, + gradient_accumulation_steps=GRAD_ACCUM_STEPS, + ) + model.clip_grad_and_step( + adapter_name=ADAPTER_NAME, + gradient_accumulation_steps=GRAD_ACCUM_STEPS, + ) + + if step % 20 == 0: + metric = model.calculate_metric(is_training=True, adapter_name=ADAPTER_NAME) + logger.info(f'Current is step {step} of {len(dataloader)}, metric: {metric}') + + if step > 0 and step % SAVE_STEPS == 0: + metrics = eval(model) + logger.info(f'Eval metric: {metrics}') + loss = float(metrics['loss']) + if loss < best_loss: + model.save(name=f'checkpoint-{step}', output_dir=OUTPUT_DIR, adapter_name=ADAPTER_NAME) + best_loss = loss + + model.save(name='last-checkpoint', output_dir=OUTPUT_DIR, adapter_name=ADAPTER_NAME) + + +if __name__ == '__main__': + train() diff --git a/src/twinkle/model/transformers/strategy/native_fsdp.py b/src/twinkle/model/transformers/strategy/native_fsdp.py index ad675006..ce30e043 100644 --- a/src/twinkle/model/transformers/strategy/native_fsdp.py +++ b/src/twinkle/model/transformers/strategy/native_fsdp.py @@ -4,7 +4,7 @@ from torch import nn from torch.distributed.device_mesh import DeviceMesh as TorchDeviceMesh from torch.distributed.fsdp import fully_shard -from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Set +from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Set from twinkle.utils import DeviceMesh, Platform, torch_util from .load_context import fsdp_pretrained_load_context @@ -30,7 +30,13 @@ def __init__(self, self.ep_fsdp_device_mesh = self._build_ep_fsdp_device_mesh(ep_size) if enable_ep else None def pretrained_load_context(self): - return fsdp_pretrained_load_context(self._memory_efficient_init and self.device_mesh is not None) + # Native FSDP loads pretrained weights via rank0 broadcast during wrap_model(). + # Avoid Transformers' FSDP loading env here; some versions can hang non-rank0 + # ranks in from_pretrained barriers. + return fsdp_pretrained_load_context(False) + + def use_rank0_pretrained_broadcast(self) -> bool: + return self._memory_efficient_init and self.device_mesh is not None def _build_ep_fsdp_device_mesh(self, ep_size: Optional[int] = None) -> Optional[TorchDeviceMesh]: if self.device_mesh is None: @@ -59,17 +65,18 @@ def wrap_model(self, model, optimizer=None): if optimizer is not None: _unbind_optimizer_params(optimizer) - # EP path requires experts on a real device, incompatible with meta-device flow. - use_meta = self._memory_efficient_init and not ep_enabled + use_meta = self.use_rank0_pretrained_broadcast() and not ep_enabled original_sd = None saved_buffers = None if use_meta: - original_sd = model.state_dict() - saved_buffers = _get_non_persistent_buffers(model) - model = model.to(torch.device('meta')) - if hasattr(model, 'tie_weights'): - model.tie_weights() + is_rank0 = (dist.get_rank() == 0) + original_sd = model.state_dict() if is_rank0 else {} + saved_buffers = _get_non_persistent_buffers(model) if is_rank0 else {} + if is_rank0: + model = model.to(torch.device('meta')) + if hasattr(model, 'tie_weights'): + model.tie_weights() if ep_enabled: _ensure_moe_patched_if_needed(model, self.ep_fsdp_device_mesh) @@ -129,14 +136,13 @@ def wrap_model(self, model, optimizer=None): if use_meta: device_type = self.device_mesh.device_type or 'cuda' - is_rank0 = (dist.get_rank() == 0) _broadcast_sharded_state_dict( model, - original_sd if is_rank0 else {}, + original_sd, device_type=device_type, ) target_device = torch.device(device_type) - _restore_non_persistent_buffers(model, saved_buffers, device=target_device) + _broadcast_non_persistent_buffers(model, saved_buffers or {}, device=target_device) if hasattr(model, 'tie_weights'): model.tie_weights() @@ -322,16 +328,43 @@ def _build_fsdp_mesh(device_mesh: DeviceMesh) -> Optional[TorchDeviceMesh]: return TorchDeviceMesh(device_mesh.device_type, flat_mesh, mesh_dim_names=('fsdp', )) -def _get_decoder_layers(model: nn.Module) -> Optional[nn.ModuleList]: +def _get_decoder_layers(model: nn.Module) -> Optional[List[nn.Module]]: + no_split_modules = _get_no_split_module_names(model) + if no_split_modules: + layers = [ + module for module in model.modules() + if module is not model and module.__class__.__name__ in no_split_modules + ] + if layers: + return layers + inner_model = getattr(model, 'model', None) if inner_model is not None: inner_layers = getattr(inner_model, 'layers', None) if isinstance(inner_layers, nn.ModuleList): - return inner_layers + return list(inner_layers) return None +def _get_no_split_module_names(model: nn.Module) -> Set[str]: + names = _normalize_no_split_modules(getattr(model, '_no_split_modules', None)) + if names: + return names + + for module in model.modules(): + names.update(_normalize_no_split_modules(getattr(module, '_no_split_modules', None))) + return names + + +def _normalize_no_split_modules(value) -> Set[str]: + if value is None: + return set() + if isinstance(value, str): + return {value} + return set(value) + + def _collect_expert_params(model: nn.Module) -> Optional[Set[nn.Parameter]]: ignored: Set[nn.Parameter] = set() ep_patched = False From 400acfb38c07b1b1f76f0a85f0e39a2db32fc2ff Mon Sep 17 00:00:00 2001 From: meichangsu1 <1484603386@qq.com> Date: Tue, 12 May 2026 20:16:06 +0800 Subject: [PATCH 03/11] fix: broadcast non-persistent buffers and validate state dict shapes in FSDP strategy - Broadcast non-persistent buffers from rank 0 to all ranks instead of only restoring on rank 0 - Add source metadata validation to ensure shape/dtype consistency before distributing tensors - Fix tie_weights() call to execute on all ranks instead of only rank 0 - Improve error handling with explicit KeyError and RuntimeError for state dict mismatches --- .../transformers/strategy/native_fsdp.py | 73 ++++++++++++++----- .../model/transformers/transformers.py | 20 +++++ src/twinkle/template/deepseek_v4.py | 3 +- src/twinkle/template/utils.py | 2 +- src/twinkle/utils/transformers_utils.py | 29 +++++++- 5 files changed, 105 insertions(+), 22 deletions(-) diff --git a/src/twinkle/model/transformers/strategy/native_fsdp.py b/src/twinkle/model/transformers/strategy/native_fsdp.py index ce30e043..ae651fc3 100644 --- a/src/twinkle/model/transformers/strategy/native_fsdp.py +++ b/src/twinkle/model/transformers/strategy/native_fsdp.py @@ -73,10 +73,9 @@ def wrap_model(self, model, optimizer=None): is_rank0 = (dist.get_rank() == 0) original_sd = model.state_dict() if is_rank0 else {} saved_buffers = _get_non_persistent_buffers(model) if is_rank0 else {} - if is_rank0: - model = model.to(torch.device('meta')) - if hasattr(model, 'tie_weights'): - model.tie_weights() + model = model.to(torch.device('meta')) + if hasattr(model, 'tie_weights'): + model.tie_weights() if ep_enabled: _ensure_moe_patched_if_needed(model, self.ep_fsdp_device_mesh) @@ -511,31 +510,57 @@ def _broadcast_sharded_state_dict( full_sd: dict, device_type: str = 'cuda', ) -> None: - """Broadcast full state dict from rank 0 and materialise local shards via distribute_tensor.""" + """Distribute rank0 full state dict into local FSDP2 shards.""" from torch.distributed.tensor import DTensor, distribute_tensor meta_sharded_sd = model.state_dict() sharded_sd = {} is_rank0 = (dist.get_rank() == 0) + source_metadata = None + if is_rank0: + source_metadata = { + name: (tuple(tensor.shape), tensor.dtype) + for name, tensor in full_sd.items() if hasattr(tensor, 'shape') and hasattr(tensor, 'dtype') + } + metadata_holder = [source_metadata] + dist.broadcast_object_list(metadata_holder, src=0) + source_metadata = metadata_holder[0] or {} for param_name, sharded_param in meta_sharded_sd.items(): shape = sharded_param.size() - dtype = sharded_param.dtype + if param_name not in source_metadata: + raise KeyError(f"Missing source metadata for parameter '{param_name}'.") + source_shape, source_dtype = source_metadata[param_name] if is_rank0: + if param_name not in full_sd: + raise KeyError( + f"Parameter '{param_name}' found in sharded model state dict but missing from full state dict.") full_param = full_sd[param_name] - full_tensor = full_param.detach().to(device_type) + full_tensor = full_param.detach() if isinstance(full_tensor, DTensor): full_tensor = full_tensor.to_local() + full_tensor = full_tensor.to(device_type) + if tuple(full_tensor.shape) != tuple(source_shape) or full_tensor.dtype != source_dtype: + raise RuntimeError(f"Source metadata mismatch for '{param_name}': " + f'actual shape={tuple(full_tensor.shape)} dtype={full_tensor.dtype}, ' + f'expected shape={source_shape} dtype={source_dtype}.') else: - full_tensor = torch.empty(shape, device=device_type, dtype=dtype) - - dist.broadcast(full_tensor, src=0) + full_tensor = torch.empty(source_shape, device=device_type, dtype=source_dtype) + + if tuple(shape) != tuple(source_shape): + raise RuntimeError(f"Parameter '{param_name}' shape mismatch before broadcast: " + f'sharded logical shape={tuple(shape)}, source shape={source_shape}.') + if isinstance(sharded_param, DTensor): + sharded_tensor = distribute_tensor( + full_tensor, + sharded_param.device_mesh, + sharded_param.placements, + ) + else: + dist.broadcast(full_tensor, src=0) + sharded_tensor = full_tensor torch_util.synchronize() - - device_mesh = sharded_param.device_mesh - placements = sharded_param.placements - sharded_tensor = distribute_tensor(full_tensor, device_mesh, placements) del full_tensor sharded_sd[param_name] = sharded_tensor @@ -562,14 +587,26 @@ def _unbind_optimizer_params(optimizer: torch.optim.Optimizer) -> None: group['params'][i] = torch.empty(1, dtype=param.dtype, device=param.device) -def _restore_non_persistent_buffers( +def _broadcast_non_persistent_buffers( model: nn.Module, saved_buffers: Dict[str, torch.Tensor], device: torch.device, ) -> None: - """Re-register non-persistent buffers saved before to('meta').""" - for fqn, buf_tensor in saved_buffers.items(): - buf_tensor = buf_tensor.to(device) + """Broadcast rank0 non-persistent buffers and re-register them on all ranks.""" + is_rank0 = (dist.get_rank() == 0) + metadata = None + if is_rank0: + metadata = [(name, tuple(tensor.shape), tensor.dtype) for name, tensor in saved_buffers.items()] + metadata_holder = [metadata] + dist.broadcast_object_list(metadata_holder, src=0) + metadata = metadata_holder[0] or [] + + for fqn, shape, dtype in metadata: + if is_rank0: + buf_tensor = saved_buffers[fqn].to(device) + else: + buf_tensor = torch.empty(shape, device=device, dtype=dtype) + dist.broadcast(buf_tensor, src=0) if '.' in fqn: parent_fqn, local_name = fqn.rsplit('.', 1) parent = model.get_submodule(parent_fqn) diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py index 56cbec17..9f80817d 100644 --- a/src/twinkle/model/transformers/transformers.py +++ b/src/twinkle/model/transformers/transformers.py @@ -42,6 +42,7 @@ from twinkle.utils import construct_class, get_logger, selective_log_softmax, torch_util from twinkle.utils.framework import Torch from twinkle.utils.grad_clip import normalize_and_clip_grad_norm +from twinkle.utils.transformers_utils import filter_from_config_kwargs logger = get_logger() @@ -191,6 +192,8 @@ def __init__( model_cls = getattr(transformers, model_cls) if model_id is None: self.model = model_cls.from_config(self.hf_config, **kwargs) + elif self._should_init_empty_pretrained_model_on_this_rank(): + self.model = self._init_empty_model_from_config(model_cls, **kwargs) else: # Trigger transformers' FSDP-aware loading: meta-device init + rank-0-only weight load. with self.strategy.pretrained_load_context(): @@ -204,6 +207,23 @@ def __init__( self.optimizer_group[_default_adapter_name].adapter_name = _default_adapter_name self.active_group = _default_adapter_name + def _should_init_empty_pretrained_model_on_this_rank(self) -> bool: + use_rank0_broadcast = getattr(self.strategy, 'use_rank0_pretrained_broadcast', lambda: False) + return bool(use_rank0_broadcast() and dist.is_available() and dist.is_initialized() and dist.get_rank() != 0) + + def _init_empty_model_from_config(self, model_cls, **kwargs): + from accelerate import init_empty_weights + + config_kwargs = filter_from_config_kwargs(kwargs) + with init_empty_weights(include_buffers=False): + if hasattr(model_cls, 'from_config'): + model = model_cls.from_config(self.hf_config, **config_kwargs) + else: + model = model_cls._from_config(self.hf_config, **config_kwargs) + if hasattr(model, 'tie_weights'): + model.tie_weights() + return model + def _decide_strategy(self, strategy: Literal['accelerate', 'native_fsdp']): self._expert_parallel_config = self._fsdp_config.pop('expert_parallel', None) self._enable_expert_parallel = self._should_enable_expert_parallel(self._expert_parallel_config, diff --git a/src/twinkle/template/deepseek_v4.py b/src/twinkle/template/deepseek_v4.py index 97dcf8e4..e9d55fac 100644 --- a/src/twinkle/template/deepseek_v4.py +++ b/src/twinkle/template/deepseek_v4.py @@ -3,9 +3,8 @@ import re from typing import Any, Dict, List, Optional, Tuple, Union -from .utils import Function -from .utils import Prompt from .base import BaseAgentTemplate +from .utils import Function, Prompt DSML_TOKEN = '|DSML|' diff --git a/src/twinkle/template/utils.py b/src/twinkle/template/utils.py index 0da0417d..597c8517 100644 --- a/src/twinkle/template/utils.py +++ b/src/twinkle/template/utils.py @@ -3,12 +3,12 @@ import json import re from copy import copy, deepcopy +from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union from twinkle.data_format import Message, Trajectory from twinkle.utils import to_device -from dataclasses import dataclass if TYPE_CHECKING: from transformers import PreTrainedTokenizer diff --git a/src/twinkle/utils/transformers_utils.py b/src/twinkle/utils/transformers_utils.py index 036f7538..3d4ba362 100644 --- a/src/twinkle/utils/transformers_utils.py +++ b/src/twinkle/utils/transformers_utils.py @@ -1,6 +1,6 @@ # Copyright (c) ModelScope Contributors. All rights reserved. import re -from typing import TYPE_CHECKING, Callable, List, Optional +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional from .utils import deep_getattr @@ -8,6 +8,33 @@ import torch.nn as nn +def filter_from_config_kwargs(kwargs: Dict[str, Any]) -> Dict[str, Any]: + load_only_keys = { + 'cache_dir', + 'device_map', + 'force_download', + 'ignore_mismatched_sizes', + 'local_files_only', + 'low_cpu_mem_usage', + 'max_memory', + 'offload_buffers', + 'offload_folder', + 'offload_state_dict', + 'output_loading_info', + 'proxies', + 'resume_download', + 'revision', + 'state_dict', + 'subfolder', + 'token', + 'tokenizer_id', + 'trust_remote_code', + 'use_safetensors', + 'weights_only', + } + return {key: value for key, value in kwargs.items() if key not in load_only_keys} + + def find_layers( model: 'nn.Module', cond: Callable[[str, 'nn.Module'], bool], From a8dbf74eeb1b6e128d87e5406e4e742478b5efe3 Mon Sep 17 00:00:00 2001 From: meichangsu1 <1484603386@qq.com> Date: Tue, 12 May 2026 20:46:47 +0800 Subject: [PATCH 04/11] fix(template): rename DeepSeekV4AgentTemplate to DeepseekV4Template and update references - Rename class and module exports for consistency with naming conventions - Update default TEMPLATE_ID in deepseek_v4_flash.py to use new template name - Refactor encoding/decoding logic for chat message processing with improved tool call handling --- cookbook/transformers/deepseek_v4_flash.py | 2 +- src/twinkle/template/__init__.py | 2 +- src/twinkle/template/deepseek_v4.py | 263 ++++----- src/twinkle/template/deepseek_v4_encoding.py | 585 +++++++++++++++++++ 4 files changed, 716 insertions(+), 136 deletions(-) create mode 100644 src/twinkle/template/deepseek_v4_encoding.py diff --git a/cookbook/transformers/deepseek_v4_flash.py b/cookbook/transformers/deepseek_v4_flash.py index a7a3d73a..e426c2a0 100644 --- a/cookbook/transformers/deepseek_v4_flash.py +++ b/cookbook/transformers/deepseek_v4_flash.py @@ -13,7 +13,7 @@ MODEL_ID = os.environ.get('MODEL_ID', 'ms://deepseek-ai/DeepSeek-V4-flash-bfa16') DATASET_ID = os.environ.get('DATASET_ID', 'ms://swift/self-cognition') -TEMPLATE_ID = os.environ.get('TEMPLATE_ID', 'Template') +TEMPLATE_ID = os.environ.get('TEMPLATE_ID', 'DeepseekV4Template') OUTPUT_DIR = os.environ.get('OUTPUT_DIR', './output') NUM_LAYERS = int(os.environ.get('NUM_LAYERS', '4')) diff --git a/src/twinkle/template/__init__.py b/src/twinkle/template/__init__.py index 4b8c27e6..4386fc1e 100644 --- a/src/twinkle/template/__init__.py +++ b/src/twinkle/template/__init__.py @@ -1,4 +1,4 @@ # Copyright (c) ModelScope Contributors. All rights reserved. from .base import BaseAgentTemplate, Template -from .deepseek_v4 import DeepSeekV4AgentTemplate +from .deepseek_v4 import DeepseekV4Template from .qwen3_5_vl import Qwen3_5Template diff --git a/src/twinkle/template/deepseek_v4.py b/src/twinkle/template/deepseek_v4.py index e9d55fac..f488d5a6 100644 --- a/src/twinkle/template/deepseek_v4.py +++ b/src/twinkle/template/deepseek_v4.py @@ -1,137 +1,132 @@ # Copyright (c) ModelScope Contributors. All rights reserved. -import json -import re -from typing import Any, Dict, List, Optional, Tuple, Union - -from .base import BaseAgentTemplate -from .utils import Function, Prompt - -DSML_TOKEN = '|DSML|' - -TOOLS_TEMPLATE = """## Tools - -You have access to a set of tools to help answer the user's question. \ -You can invoke tools by writing a "<{dsml_token}tool_calls>" block like the following: - -<{dsml_token}tool_calls> -<{dsml_token}invoke name="$TOOL_NAME"> -<{dsml_token}parameter name="$PARAMETER_NAME" string="true|false">$PARAMETER_VALUE -... - -<{dsml_token}invoke name="$TOOL_NAME2"> -... - - - -String parameters should be specified as is and set `string="true"`. \ -For all other types (numbers, booleans, arrays, objects), \ -pass the value in JSON format and set `string="false"`. - -If thinking_mode is enabled (triggered by ), \ -you MUST output your complete reasoning inside ... BEFORE any tool calls or final response. - -Otherwise, output directly after with tool calls or final response. - -### Available Tool Schemas - -{tool_schemas} - -You MUST strictly follow the above defined tool name and parameter schemas to invoke tool calls. -""" - - -def _to_json(value: Any) -> str: - try: - return json.dumps(value, ensure_ascii=False) - except Exception: - return json.dumps(value, ensure_ascii=True) - - -def _encode_arguments_to_dsml(arguments: Dict[str, Any]) -> str: - """Encode tool call arguments dict into DSML parameter lines.""" - lines = [] - for k, v in arguments.items(): - is_str = 'true' if isinstance(v, str) else 'false' - val = v if isinstance(v, str) else _to_json(v) - lines.append(f'<{DSML_TOKEN}parameter name="{k}" string="{is_str}">{val}') - return '\n'.join(lines) - - -class DeepSeekV4AgentTemplate(BaseAgentTemplate): - - def get_toolcall(self, response: str) -> List[Function]: - # Parse DSML tool calls from model output - # Pattern: <|DSML|invoke name="tool_name">...params... - invoke_pattern = re.compile( - rf'<{re.escape(DSML_TOKEN)}invoke\s+name="([^"]+)">\s*(.*?)\s*', re.DOTALL) - param_pattern = re.compile( - rf'<{re.escape(DSML_TOKEN)}parameter\s+name="([^"]+)"\s+string="(true|false)">' - rf'(.*?)', re.DOTALL) - - functions = [] - for match in invoke_pattern.finditer(response): - tool_name = match.group(1) - params_block = match.group(2) - arguments = {} - for pm in param_pattern.finditer(params_block): - param_name = pm.group(1) - is_string = pm.group(2) - param_value = pm.group(3) - if is_string == 'false': - try: - param_value = json.loads(param_value) - except json.JSONDecodeError: - pass - arguments[param_name] = param_value - functions.append(Function(name=tool_name, arguments=json.dumps(arguments, ensure_ascii=False))) - - if len(functions) == 0: - # Fallback to ReAct format - return super().get_toolcall(response) - return functions - - def _get_tool_responses(self, tool_messages): - return ''.join(f'{tool_message["content"]}' for tool_message in tool_messages) - - def _format_tool_responses( +import copy +import torch +from transformers import AutoConfig, PreTrainedTokenizerFast +from typing import Any, Dict, List, Literal, Optional + +from twinkle.hub import HubOperation +from .base import Template +from .deepseek_v4_encoding import encode_messages + + +def get_deepseek_v4_tokenizer(tokenizer): + """Wrap a HF tokenizer with DeepSeek V4's custom chat-template encoder.""" + dsv4_tokenizer = copy.copy(tokenizer) + + added_vocab = tokenizer.get_added_vocab() + added_vocab_size = len(added_vocab) + tokenizer_vocab_size = tokenizer.vocab_size + + class _DeepseekV4Tokenizer(tokenizer.__class__): # type: ignore[misc, valid-type] + + def apply_chat_template( + self, + messages, + tools: Optional[List[Dict[str, Any]]] = None, + **kwargs, + ): + thinking = kwargs.get('thinking', False) + enable_thinking = kwargs.get('enable_thinking', False) + thinking = thinking or enable_thinking + thinking_mode = 'thinking' if thinking else 'chat' + + conversation = kwargs.get('conversation', messages) + messages = conversation.copy() + if tools: + messages.insert(0, {'role': 'system'}) + messages[0]['tools'] = tools + + reasoning_effort = kwargs.get('reasoning_effort') + if reasoning_effort not in ('max', 'high'): + reasoning_effort = None + + prompt_str = encode_messages( + messages, + thinking_mode=thinking_mode, + drop_thinking=kwargs.get('drop_thinking', True), + reasoning_effort=reasoning_effort, + ) + + tokenize = kwargs.get('tokenize', True) + return_dict = kwargs.get('return_dict', False) + return_tensors = kwargs.get('return_tensors') + + if not tokenize: + return {'prompt': prompt_str} if return_dict else prompt_str + + tokenizer_kwargs = {key: kwargs[key] for key in ('truncation', 'max_length') if key in kwargs} + input_ids = self.encode( + prompt_str, + add_special_tokens=False, + **tokenizer_kwargs, + ) + + if not return_dict and return_tensors is None: + return input_ids + + attention_mask = [1] * len(input_ids) + if return_tensors == 'pt': + input_ids = torch.tensor([input_ids], dtype=torch.long) + attention_mask = torch.tensor([attention_mask], dtype=torch.long) + elif return_tensors is not None: + raise ValueError(f'Unsupported return_tensors: {return_tensors}') + + encoded = { + 'input_ids': input_ids, + 'attention_mask': attention_mask, + } + if kwargs.get('return_assistant_tokens_mask', False): + # Fall back to round-by-round labeling in Template by omitting + # assistant_masks support from this custom tokenizer wrapper. + pass + if return_dict: + return encoded + return encoded['input_ids'] + + def num_special_tokens_to_add(self) -> int: + return len(self.encode('')) + + def __len__(self) -> int: + return tokenizer_vocab_size + added_vocab_size + + def get_added_vocab(self) -> dict[str, int]: + return added_vocab.copy() + + _DeepseekV4Tokenizer.__name__ = f'DSV4{tokenizer.__class__.__name__}' + dsv4_tokenizer.__class__ = _DeepseekV4Tokenizer + return dsv4_tokenizer + + +class DeepseekV4Template(Template): + + def __init__( self, - assistant_content: str, - tool_messages, - ) -> Tuple[str, 'Prompt']: - with_action = self.keyword.action in assistant_content and self.keyword.action_input in assistant_content - if with_action: - return super()._format_tool_responses(assistant_content, tool_messages) - res = [ - '<|end▁of▁sentence|><|User|>', - self._get_tool_responses(tool_messages), - '<|Assistant|>', + model_id: str, + use_chat_template: bool = True, + max_length: Optional[int] = 8192, + truncation_strategy: Literal['raise', 'left', 'right', 'split'] = 'raise', + default_system: Optional[str] = None, + enable_thinking: bool = True, + **kwargs, + ): + model_id = HubOperation.download_model(model_id, ignore_model=True) + base_tokenizer = PreTrainedTokenizerFast.from_pretrained(model_id, **kwargs) + self.processor = get_deepseek_v4_tokenizer(base_tokenizer) + self.config = AutoConfig.from_pretrained(model_id, **kwargs) + + self.use_chat_template = use_chat_template + self.max_length = max_length + self.enable_thinking = enable_thinking + self.truncation_strategy = truncation_strategy + self.default_system = default_system + self._test_support_assistant_tokens_mask() + self.pre_pipeline = [ + self._add_default_system, + self._to_standard_reasoning_content, + self._build_standard_messages, + ] + self.post_pipeline = [ + self._check_max_length, + self._add_attention_fields, + self._roll_labels, ] - return assistant_content, res - - def _format_tools(self, tools: List[Union[str, dict]], system: Optional[str] = None, user_message=None) -> str: - tool_schemas = [] - for tool in tools: - tool = self.unwrap_tool(tool) - tool_schemas.append(_to_json(tool)) - - tools_section = TOOLS_TEMPLATE.format( - tool_schemas='\n'.join(tool_schemas), - dsml_token=DSML_TOKEN, - ) - - system = system or '' - return f'{system}\n\n{tools_section}' if system else tools_section - - def _format_tool_calls(self, tool_call_messages) -> str: - invocations = [] - for message in tool_call_messages: - tool_call = self._parse_tool_call(message['content']) - name = tool_call['name'] - arguments = tool_call['arguments'] - if isinstance(arguments, str): - arguments = json.loads(arguments) - dsml_args = _encode_arguments_to_dsml(arguments) - invocations.append(f'<{DSML_TOKEN}invoke name="{name}">\n{dsml_args}\n') - - tool_calls_str = '\n'.join(invocations) - return f'<{DSML_TOKEN}tool_calls>\n{tool_calls_str}\n' diff --git a/src/twinkle/template/deepseek_v4_encoding.py b/src/twinkle/template/deepseek_v4_encoding.py new file mode 100644 index 00000000..f10088ce --- /dev/null +++ b/src/twinkle/template/deepseek_v4_encoding.py @@ -0,0 +1,585 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +# ruff: noqa +# fmt: off + +""" +DeepSeek-V4 Encoding + +A self-contained implementation for encoding/decoding DeepSeek-V4 chat messages +with tool calling, thinking mode, and quick instruction task support. +""" + +import copy +import json +import regex as re +from typing import Any, Dict, List, Optional, Tuple, Union + +bos_token: str = '<|begin▁of▁sentence|>' +eos_token: str = '<|end▁of▁sentence|>' +thinking_start_token: str = '' +thinking_end_token: str = '' +dsml_token: str = '|DSML|' + +USER_SP_TOKEN = '<|User|>' +ASSISTANT_SP_TOKEN = '<|Assistant|>' +LATEST_REMINDER_SP_TOKEN = '<|latest_reminder|>' + +DS_TASK_SP_TOKENS = { + 'action': '<|action|>', + 'query': '<|query|>', + 'authority': '<|authority|>', + 'domain': '<|domain|>', + 'title': '<|title|>', + 'read_url': '<|read_url|>', +} +VALID_TASKS = set(DS_TASK_SP_TOKENS.keys()) + +system_msg_template: str = '{content}' +user_msg_template: str = '{content}' +latest_reminder_msg_template: str = '{content}' +assistant_msg_template: str = '{reasoning}{content}{tool_calls}' + eos_token +assistant_msg_wo_eos_template: str = '{reasoning}{content}{tool_calls}' +thinking_template: str = '{reasoning}' + +response_format_template: str = ( + '## Response Format:\n\nYou MUST strictly adhere to the following schema to reply:\n{schema}' +) +tool_call_template: str = ( + "<{dsml_token}invoke name=\"{name}\">\n{arguments}\n" +) +tool_calls_template = ( + '<{dsml_token}{tc_block_name}>\n{tool_calls}\n' +) +tool_calls_block_name: str = 'tool_calls' + +tool_output_template: str = ( + '{content}' +) + +REASONING_EFFORT_MAX = ( + 'Reasoning Effort: Absolute maximum with no shortcuts permitted.\n' + 'You MUST be very thorough in your thinking and comprehensively decompose the problem to resolve the root ' + 'cause, rigorously stress-testing your logic against all potential paths, edge cases, and adversarial scenarios.\n' + 'Explicitly write out your entire deliberation process, documenting every intermediate step, considered ' + 'alternative, and rejected hypothesis to ensure absolutely no assumption is left unchecked.\n\n' +) + +TOOLS_TEMPLATE = """## Tools + +You have access to a set of tools to help answer the user's question. You can invoke tools by writing a +"<{dsml_token}tool_calls>" block like the following: + +<{dsml_token}tool_calls> +<{dsml_token}invoke name="$TOOL_NAME"> +<{dsml_token}parameter name="$PARAMETER_NAME" string="true|false">$PARAMETER_VALUE +... + +<{dsml_token}invoke name="$TOOL_NAME2"> +... + + + +String parameters should be specified as is and set `string="true"`. For all other types (numbers, booleans, arrays, +objects), pass the value in JSON format and set `string="false"`. + +If thinking_mode is enabled (triggered by {thinking_start_token}), you MUST output your complete reasoning inside +{thinking_start_token}...{thinking_end_token} BEFORE any tool calls or final response. + +Otherwise, output directly after {thinking_end_token} with tool calls or final response. + +### Available Tool Schemas + +{tool_schemas} + +You MUST strictly follow the above defined tool name and parameter schemas to invoke tool calls. +""" + + +def to_json(value: Any) -> str: + try: + return json.dumps(value, ensure_ascii=False) + except Exception: + return json.dumps(value, ensure_ascii=True) + + +def tools_from_openai_format(tools): + return [tool['function'] for tool in tools] + + +def tool_calls_from_openai_format(tool_calls): + return [ + { + 'name': tool_call['function']['name'], + 'arguments': tool_call['function']['arguments'], + } + for tool_call in tool_calls + ] + + +def tool_calls_to_openai_format(tool_calls): + return [ + { + 'type': 'function', + 'function': { + 'name': tool_call['name'], + 'arguments': tool_call['arguments'], + } + } + for tool_call in tool_calls + ] + + +def encode_arguments_to_dsml(tool_call: Dict[str, Any]) -> str: + p_dsml_template = '<{dsml_token}parameter name="{key}" string="{is_str}">{value}' + p_dsml_strs = [] + + if isinstance(tool_call['arguments'], str): + arguments = json.loads(tool_call['arguments']) + else: + arguments = tool_call['arguments'] + + for k, v in arguments.items(): + p_dsml_str = p_dsml_template.format( + dsml_token=dsml_token, + key=k, + is_str='true' if isinstance(v, str) else 'false', + value=v if isinstance(v, str) else to_json(v), + ) + p_dsml_strs.append(p_dsml_str) + + return '\n'.join(p_dsml_strs) + + +def decode_dsml_to_arguments(tool_name: str, tool_args: Dict[str, Tuple[str, str]]) -> Dict[str, str]: + def _decode_value(key: str, value: str, string: str): + if string == 'true': + value = to_json(value) + return f'{to_json(key)}: {value}' + + tool_args_json = '{' + ', '.join([_decode_value(k, v, string=is_str) for k, (v, is_str) in tool_args.items()]) + '}' + return dict(name=tool_name, arguments=tool_args_json) + + +def render_tools(tools: List[Dict[str, Union[str, Dict[str, Any]]]]) -> str: + tools_json = [to_json(t) for t in tools] + + return TOOLS_TEMPLATE.format( + tool_schemas='\n'.join(tools_json), + dsml_token=dsml_token, + thinking_start_token=thinking_start_token, + thinking_end_token=thinking_end_token, + ) + + +def find_last_user_index(messages: List[Dict[str, Any]]) -> int: + last_user_index = -1 + for idx in range(len(messages) - 1, -1, -1): + if messages[idx].get('role') in ['user', 'developer']: + last_user_index = idx + break + return last_user_index + + +def render_message( + index: int, + messages: List[Dict[str, Any]], + thinking_mode: str, + drop_thinking: bool = True, + reasoning_effort: Optional[str] = None, +) -> str: + assert 0 <= index < len(messages) + assert thinking_mode in ['chat', 'thinking'], f'Invalid thinking_mode `{thinking_mode}`' + + prompt = '' + msg = messages[index] + last_user_idx = find_last_user_index(messages) + + role = msg.get('role') + content = msg.get('content') + tools = msg.get('tools') + response_format = msg.get('response_format') + tool_calls = msg.get('tool_calls') + reasoning = msg.get('reasoning') + wo_eos = msg.get('wo_eos', False) + + if tools: + tools = tools_from_openai_format(tools) + if tool_calls: + tool_calls = tool_calls_from_openai_format(tool_calls) + + assert reasoning_effort in ['max', None, 'high'], f'Invalid reasoning effort: {reasoning_effort}' + if index == 0 and thinking_mode == 'thinking' and reasoning_effort == 'max': + prompt += REASONING_EFFORT_MAX + + if role == 'system': + prompt += system_msg_template.format(content=content or '') + if tools: + prompt += '\n\n' + render_tools(tools) + if response_format: + prompt += '\n\n' + response_format_template.format(schema=to_json(response_format)) + + elif role == 'developer': + assert content, f'Invalid message for role `{role}`: {msg}' + + content_developer = USER_SP_TOKEN + content_developer += content + + if tools: + content_developer += '\n\n' + render_tools(tools) + if response_format: + content_developer += '\n\n' + response_format_template.format(schema=to_json(response_format)) + + prompt += user_msg_template.format(content=content_developer) + + elif role == 'user': + prompt += USER_SP_TOKEN + + content_blocks = msg.get('content_blocks') + if content_blocks: + parts = [] + for block in content_blocks: + block_type = block.get('type') + if block_type == 'text': + parts.append(block.get('text', '')) + elif block_type == 'tool_result': + tool_content = block.get('content', '') + if isinstance(tool_content, list): + text_parts = [] + for b in tool_content: + if b.get('type') == 'text': + text_parts.append(b.get('text', '')) + else: + text_parts.append(f"[Unsupported {b.get('type')}]") + tool_content = '\n\n'.join(text_parts) + parts.append(tool_output_template.format(content=tool_content)) + else: + parts.append(f'[Unsupported {block_type}]') + prompt += '\n\n'.join(parts) + else: + prompt += content or '' + + elif role == 'latest_reminder': + prompt += LATEST_REMINDER_SP_TOKEN + latest_reminder_msg_template.format(content=content) + + elif role == 'tool': + raise NotImplementedError( + 'deepseek_v4 merges tool messages into user; please preprocess with merge_tool_messages()') + + elif role == 'assistant': + thinking_part = '' + tc_content = '' + + if tool_calls: + tc_list = [ + tool_call_template.format( + dsml_token=dsml_token, + name=tc.get('name'), + arguments=encode_arguments_to_dsml(tc) + ) + for tc in tool_calls + ] + tc_content += '\n\n' + tool_calls_template.format( + dsml_token=dsml_token, + tool_calls='\n'.join(tc_list), + tc_block_name=tool_calls_block_name, + ) + + summary_content = content or '' + reasoning = reasoning or '' + + prev_has_task = index - 1 >= 0 and messages[index - 1].get('task') is not None + + if thinking_mode == 'thinking' and not prev_has_task: + if not drop_thinking or index > last_user_idx: + thinking_part = thinking_template.format(reasoning=reasoning) + thinking_end_token + else: + thinking_part = '' + + if wo_eos: + prompt += assistant_msg_wo_eos_template.format( + reasoning=thinking_part, + content=summary_content, + tool_calls=tc_content, + ) + else: + prompt += assistant_msg_template.format( + reasoning=thinking_part, + content=summary_content, + tool_calls=tc_content, + ) + else: + raise NotImplementedError(f'Unknown role: {role}') + + if index + 1 < len(messages) and messages[index + 1].get('role') not in ['assistant', 'latest_reminder']: + return prompt + + task = messages[index].get('task') + if task is not None: + assert task in VALID_TASKS, f"Invalid task: '{task}'. Valid tasks are: {list(VALID_TASKS)}" + task_sp_token = DS_TASK_SP_TOKENS[task] + + if task != 'action': + prompt += task_sp_token + else: + prompt += ASSISTANT_SP_TOKEN + prompt += thinking_end_token if thinking_mode != 'thinking' else thinking_start_token + prompt += task_sp_token + + elif messages[index].get('role') in ['user', 'developer']: + prompt += ASSISTANT_SP_TOKEN + if not drop_thinking and thinking_mode == 'thinking': + prompt += thinking_start_token + elif drop_thinking and thinking_mode == 'thinking' and index >= last_user_idx: + prompt += thinking_start_token + else: + prompt += thinking_end_token + + return prompt + + +def merge_tool_messages(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + merged: List[Dict[str, Any]] = [] + + for msg in messages: + msg = copy.deepcopy(msg) + role = msg.get('role') + + if role == 'tool': + tool_block = { + 'type': 'tool_result', + 'tool_use_id': msg.get('tool_call_id', ''), + 'content': msg.get('content', ''), + } + if merged and merged[-1].get('role') == 'user' and 'content_blocks' in merged[-1]: + merged[-1]['content_blocks'].append(tool_block) + else: + merged.append({ + 'role': 'user', + 'content_blocks': [tool_block], + }) + elif role == 'user': + text_block = {'type': 'text', 'text': msg.get('content', '')} + if (merged and merged[-1].get('role') == 'user' and 'content_blocks' in merged[-1] + and merged[-1].get('task') is None): + merged[-1]['content_blocks'].append(text_block) + else: + new_msg = { + 'role': 'user', + 'content': msg.get('content', ''), + 'content_blocks': [text_block], + } + for key in ('task', 'wo_eos', 'mask'): + if key in msg: + new_msg[key] = msg[key] + merged.append(new_msg) + else: + merged.append(msg) + + return merged + + +def sort_tool_results_by_call_order(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + last_tool_call_order: Dict[str, int] = {} + + for msg in messages: + role = msg.get('role') + if role == 'assistant' and msg.get('tool_calls'): + last_tool_call_order = {} + for idx, tc in enumerate(msg['tool_calls']): + tc_id = tc.get('id') or tc.get('function', {}).get('id', '') + if tc_id: + last_tool_call_order[tc_id] = idx + + elif role == 'user' and msg.get('content_blocks'): + tool_blocks = [b for b in msg['content_blocks'] if b.get('type') == 'tool_result'] + if len(tool_blocks) > 1 and last_tool_call_order: + sorted_blocks = sorted( + tool_blocks, + key=lambda b: last_tool_call_order.get(b.get('tool_use_id', ''), 0) + ) + sorted_idx = 0 + new_blocks = [] + for block in msg['content_blocks']: + if block.get('type') == 'tool_result': + new_blocks.append(sorted_blocks[sorted_idx]) + sorted_idx += 1 + else: + new_blocks.append(block) + msg['content_blocks'] = new_blocks + + return messages + + +def encode_messages( + messages: List[Dict[str, Any]], + thinking_mode: str, + context: Optional[List[Dict[str, Any]]] = None, + drop_thinking: bool = True, + add_default_bos_token: bool = True, + reasoning_effort: Optional[str] = None, +) -> str: + context = context if context else [] + + messages = merge_tool_messages(messages) + messages = sort_tool_results_by_call_order(context + messages)[len(context):] + if context: + context = merge_tool_messages(context) + context = sort_tool_results_by_call_order(context) + + full_messages = context + messages + + prompt = bos_token if add_default_bos_token and len(context) == 0 else '' + + effective_drop_thinking = drop_thinking + if any(m.get('tools') for m in full_messages): + effective_drop_thinking = False + + if thinking_mode == 'thinking' and effective_drop_thinking: + full_messages = _drop_thinking_messages(full_messages) + num_to_render = len(full_messages) - len(_drop_thinking_messages(context)) + context_len = len(full_messages) - num_to_render + else: + num_to_render = len(messages) + context_len = len(context) + + for idx in range(num_to_render): + prompt += render_message( + idx + context_len, + full_messages, + thinking_mode=thinking_mode, + drop_thinking=effective_drop_thinking, + reasoning_effort=reasoning_effort, + ) + + return prompt + + +def _drop_thinking_messages(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + last_user_idx = find_last_user_index(messages) + result = [] + keep_roles = {'user', 'system', 'tool', 'latest_reminder', 'direct_search_results'} + + for idx, msg in enumerate(messages): + role = msg.get('role') + if role in keep_roles or idx >= last_user_idx: + result.append(msg) + elif role == 'assistant': + msg = copy.copy(msg) + msg.pop('reasoning', None) + result.append(msg) + + return result + + +def _read_until_stop(index: int, text: str, stop: List[str]) -> Tuple[int, str, Optional[str]]: + min_pos = len(text) + matched_stop = None + + for s in stop: + pos = text.find(s, index) + if pos != -1 and pos < min_pos: + min_pos = pos + matched_stop = s + + if matched_stop: + content = text[index:min_pos] + return min_pos + len(matched_stop), content, matched_stop + else: + content = text[index:] + return len(text), content, None + + +def parse_tool_calls(index: int, text: str) -> Tuple[int, Optional[str], List[Dict[str, str]]]: + tool_calls: List[Dict[str, Any]] = [] + stop_token = None + tool_calls_end_token = f'' + + while index < len(text): + index, content_before, stop_token = _read_until_stop( + index, text, [f'<{dsml_token}invoke', tool_calls_end_token]) + if content_before != '>\n': + raise ValueError(f"Tool call format error: expected '>\\n' but got '{content_before}'") + + if stop_token == tool_calls_end_token: + break + + if stop_token is None: + raise ValueError('Missing special token in tool calls') + + index, tool_name_content, stop_token = _read_until_stop( + index, text, [f'<{dsml_token}parameter', f'\n$', tool_name_content, flags=re.DOTALL) + if len(p_tool_name) != 1: + raise ValueError(f"Tool name format error: '{tool_name_content}'") + tool_name = p_tool_name[0] + + tool_args: Dict[str, Tuple[str, str]] = {} + while stop_token == f'<{dsml_token}parameter': + index, param_content, stop_token = _read_until_stop(index, text, [f'/{dsml_token}parameter']) + + param_kv = re.findall(r'^ name="(.*?)" string="(true|false)">(.*?)<$', param_content, flags=re.DOTALL) + if len(param_kv) != 1: + raise ValueError(f"Parameter format error: '{param_content}'") + param_name, string, param_value = param_kv[0] + + if param_name in tool_args: + raise ValueError(f"Duplicate parameter name: '{param_name}'") + tool_args[param_name] = (param_value, string) + + index, content, stop_token = _read_until_stop( + index, text, [f'<{dsml_token}parameter', f'\n': + raise ValueError(f"Parameter format error: expected '>\\n' but got '{content}'") + + tool_call = decode_dsml_to_arguments(tool_name=tool_name, tool_args=tool_args) + tool_calls.append(tool_call) + + return index, stop_token, tool_calls + + +def parse_message_from_completion_text(text: str, thinking_mode: str) -> Dict[str, Any]: + summary_content, reasoning = '', '' + tool_calls: List[Dict[str, str]] = [] + index, stop_token = 0, None + tool_calls_start_token = f'\n\n<{dsml_token}{tool_calls_block_name}' + + is_thinking = thinking_mode == 'thinking' + is_tool_calling = False + + if is_thinking: + index, content_delta, stop_token = _read_until_stop(index, text, [thinking_end_token, tool_calls_start_token]) + reasoning = content_delta + if stop_token != thinking_end_token: + raise ValueError('Invalid thinking format: missing ') + + index, content_delta, stop_token = _read_until_stop(index, text, [eos_token, tool_calls_start_token]) + summary_content = content_delta + if stop_token == tool_calls_start_token: + is_tool_calling = True + else: + if stop_token != eos_token: + raise ValueError('Invalid format: missing EOS token') + + if is_tool_calling: + index, stop_token, tool_calls = parse_tool_calls(index, text) + + index, tool_ends_text, stop_token = _read_until_stop(index, text, [eos_token]) + if tool_ends_text: + raise ValueError('Unexpected content after tool calls') + + if len(text) != index or stop_token not in [eos_token, None]: + raise ValueError('Unexpected content at end') + + for sp_token in [bos_token, eos_token, thinking_start_token, thinking_end_token, dsml_token]: + if sp_token in summary_content or sp_token in reasoning: + raise ValueError(f"Unexpected special token '{sp_token}' in content") + + return { + 'role': 'assistant', + 'content': summary_content, + 'reasoning': reasoning, + 'tool_calls': tool_calls_to_openai_format(tool_calls) + } + +# fmt: on From 09879aec12ad786faf66db7c8284a181d709a697 Mon Sep 17 00:00:00 2001 From: meichangsu1 <1484603386@qq.com> Date: Tue, 12 May 2026 21:56:58 +0800 Subject: [PATCH 05/11] fix: update LoRA target modules and NPU dtype alignment in deepseek_v4 example - Change LORA_TARGET_MODULES from 'all-linear' to specific module list for DeepSeek V4 - Remove gradient_accumulation_steps parameter from forward_backward call - Fix NPU device dtype alignment to convert all parameters to base dtype when on NPU --- cookbook/transformers/deepseek_v4_flash.py | 11 +++++++++-- src/twinkle/model/transformers/transformers.py | 10 +++++++--- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/cookbook/transformers/deepseek_v4_flash.py b/cookbook/transformers/deepseek_v4_flash.py index e426c2a0..69ee095a 100644 --- a/cookbook/transformers/deepseek_v4_flash.py +++ b/cookbook/transformers/deepseek_v4_flash.py @@ -26,7 +26,15 @@ RESHARD_AFTER_FORWARD = os.environ.get('RESHARD_AFTER_FORWARD', '1') == '1' GRADIENT_CHECKPOINTING = True IGNORE_MISMATCHED_SIZES = False -LORA_TARGET_MODULES = 'all-linear' +LORA_TARGET_MODULES = [ + 'wq_a', + 'wq_b', + 'wkv', + 'wgate', + 'gate_proj', + 'up_proj', + 'down_proj', +] ADAPTER_NAME = 'default' device_mesh = DeviceMesh.from_sizes( @@ -113,7 +121,6 @@ def train(): model.forward_backward( inputs=batch, adapter_name=ADAPTER_NAME, - gradient_accumulation_steps=GRAD_ACCUM_STEPS, ) model.clip_grad_and_step( adapter_name=ADAPTER_NAME, diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py index 9f80817d..d6d1939b 100644 --- a/src/twinkle/model/transformers/transformers.py +++ b/src/twinkle/model/transformers/transformers.py @@ -1054,17 +1054,21 @@ def _load_optimizer(self, checkpoint_dir, **kwargs): def _ensure_lora_dtype(self, model): """Force LoRA parameters to use the same dtype as base model for FSDP2 compatibility.""" base_dtype = None + is_npu_device = False for param in model.parameters(): - if param.dtype in (torch.float16, torch.bfloat16, torch.float32): + if param.device.type == 'npu': + is_npu_device = True + if base_dtype is None and param.dtype in (torch.float16, torch.bfloat16, torch.float32): base_dtype = param.dtype + if base_dtype is not None and is_npu_device: break if base_dtype is None: return - # Convert all LoRA parameters to the base model dtype + # Temporary workaround: NPU requires all parameters to align with the base dtype. with torch.no_grad(): for name, param in model.named_parameters(): - if 'lora_' in name.lower() and param.dtype != base_dtype: + if (is_npu_device or 'lora_' in name.lower()) and param.dtype != base_dtype: param.data = param.data.to(base_dtype) def _load_scaler_state(self, scaler_path, **kwargs): From 00c41ddbbf88941281fe68c21c805745c1f1b3fa Mon Sep 17 00:00:00 2001 From: meichangsu1 <1484603386@qq.com> Date: Tue, 12 May 2026 22:56:21 +0800 Subject: [PATCH 06/11] fix: update DeepSeek-V4-Flash model ID and LoRA target modules - Update default model ID from DeepSeek-V4-flash-bfa16 to DeepSeek-V4-Flash - Add comment explaining FP4/FP8 weight conversion requirement - Fix LoRA target module names to match actual model architecture --- cookbook/transformers/deepseek_v4_flash.py | 15 +++++++++------ cookbook/transformers/deepseek_v4_flash.sh | 6 ++++++ 2 files changed, 15 insertions(+), 6 deletions(-) create mode 100644 cookbook/transformers/deepseek_v4_flash.sh diff --git a/cookbook/transformers/deepseek_v4_flash.py b/cookbook/transformers/deepseek_v4_flash.py index 69ee095a..3e8e6376 100644 --- a/cookbook/transformers/deepseek_v4_flash.py +++ b/cookbook/transformers/deepseek_v4_flash.py @@ -10,8 +10,11 @@ from twinkle.preprocessor import SelfCognitionProcessor logger = get_logger() - -MODEL_ID = os.environ.get('MODEL_ID', 'ms://deepseek-ai/DeepSeek-V4-flash-bfa16') +# `deepseek-ai/DeepSeek-V4-Flash` uses mixed FP4/FP8 weights. +# Convert the checkpoint before training by following: +# https://gitcode.com/cann/cann-recipes-train/blob/master/llm_pretrain/deepseekv4/README.md#%E6%A8%A1%E5%9E%8B%E6%9D%83%E9%87%8D%E5%87%86%E5%A4%87 +# Install `transformers==5.8.0` before running this cookbook. +MODEL_ID = os.environ.get('MODEL_ID', 'ms://deepseek-ai/DeepSeek-V4-Flash') DATASET_ID = os.environ.get('DATASET_ID', 'ms://swift/self-cognition') TEMPLATE_ID = os.environ.get('TEMPLATE_ID', 'DeepseekV4Template') OUTPUT_DIR = os.environ.get('OUTPUT_DIR', './output') @@ -27,10 +30,10 @@ GRADIENT_CHECKPOINTING = True IGNORE_MISMATCHED_SIZES = False LORA_TARGET_MODULES = [ - 'wq_a', - 'wq_b', - 'wkv', - 'wgate', + 'q_a_proj', + 'q_b_proj', + 'kv_proj', + 'o_b_proj', 'gate_proj', 'up_proj', 'down_proj', diff --git a/cookbook/transformers/deepseek_v4_flash.sh b/cookbook/transformers/deepseek_v4_flash.sh new file mode 100644 index 00000000..991e60eb --- /dev/null +++ b/cookbook/transformers/deepseek_v4_flash.sh @@ -0,0 +1,6 @@ +# `deepseek-ai/DeepSeek-V4-Flash` uses mixed FP4/FP8 weights. +# Convert the checkpoint before training by following: +# https://gitcode.com/cann/cann-recipes-train/blob/master/llm_pretrain/deepseekv4/README.md#%E6%A8%A1%E5%9E%8B%E6%9D%83%E9%87%8D%E5%87%86%E5%A4%87 +# Install `transformers==5.8.0` before running this cookbook. + +CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 cookbook/transformers/deepseek_v4_flash.py From 6cecbba4c1bb5b92784a5bd2acf7d78f90dfc917 Mon Sep 17 00:00:00 2001 From: meichangsu1 <1484603386@qq.com> Date: Tue, 12 May 2026 23:38:10 +0800 Subject: [PATCH 07/11] refactor(native_fsdp): replace custom sharded state dict broadcast with DCP Replace the custom `_broadcast_sharded_state_dict` function with `_load_rank0_full_state_dict` that leverages `torch.distributed.checkpoint`'s `set_model_state_dict` with `broadcast_from_rank0=True`. This simplifies the code by using the official DCP API for distributing rank0 full state dict to FSDP2 shards, removing manual tensor distribution and metadata broadcasting logic. --- .../transformers/strategy/native_fsdp.py | 77 +++---------------- 1 file changed, 12 insertions(+), 65 deletions(-) diff --git a/src/twinkle/model/transformers/strategy/native_fsdp.py b/src/twinkle/model/transformers/strategy/native_fsdp.py index ae651fc3..3cbaf141 100644 --- a/src/twinkle/model/transformers/strategy/native_fsdp.py +++ b/src/twinkle/model/transformers/strategy/native_fsdp.py @@ -135,11 +135,7 @@ def wrap_model(self, model, optimizer=None): if use_meta: device_type = self.device_mesh.device_type or 'cuda' - _broadcast_sharded_state_dict( - model, - original_sd, - device_type=device_type, - ) + _load_rank0_full_state_dict(model, original_sd or {}) target_device = torch.device(device_type) _broadcast_non_persistent_buffers(model, saved_buffers or {}, device=target_device) if hasattr(model, 'tie_weights'): @@ -505,67 +501,18 @@ def _rebind_optimizer(optimizer: torch.optim.Optimizer, model: nn.Module) -> tor return optimizer -def _broadcast_sharded_state_dict( - model: nn.Module, - full_sd: dict, - device_type: str = 'cuda', -) -> None: - """Distribute rank0 full state dict into local FSDP2 shards.""" - from torch.distributed.tensor import DTensor, distribute_tensor - - meta_sharded_sd = model.state_dict() - sharded_sd = {} - is_rank0 = (dist.get_rank() == 0) - source_metadata = None - if is_rank0: - source_metadata = { - name: (tuple(tensor.shape), tensor.dtype) - for name, tensor in full_sd.items() if hasattr(tensor, 'shape') and hasattr(tensor, 'dtype') - } - metadata_holder = [source_metadata] - dist.broadcast_object_list(metadata_holder, src=0) - source_metadata = metadata_holder[0] or {} - - for param_name, sharded_param in meta_sharded_sd.items(): - shape = sharded_param.size() - if param_name not in source_metadata: - raise KeyError(f"Missing source metadata for parameter '{param_name}'.") - source_shape, source_dtype = source_metadata[param_name] - - if is_rank0: - if param_name not in full_sd: - raise KeyError( - f"Parameter '{param_name}' found in sharded model state dict but missing from full state dict.") - full_param = full_sd[param_name] - full_tensor = full_param.detach() - if isinstance(full_tensor, DTensor): - full_tensor = full_tensor.to_local() - full_tensor = full_tensor.to(device_type) - if tuple(full_tensor.shape) != tuple(source_shape) or full_tensor.dtype != source_dtype: - raise RuntimeError(f"Source metadata mismatch for '{param_name}': " - f'actual shape={tuple(full_tensor.shape)} dtype={full_tensor.dtype}, ' - f'expected shape={source_shape} dtype={source_dtype}.') - else: - full_tensor = torch.empty(source_shape, device=device_type, dtype=source_dtype) - - if tuple(shape) != tuple(source_shape): - raise RuntimeError(f"Parameter '{param_name}' shape mismatch before broadcast: " - f'sharded logical shape={tuple(shape)}, source shape={source_shape}.') - if isinstance(sharded_param, DTensor): - sharded_tensor = distribute_tensor( - full_tensor, - sharded_param.device_mesh, - sharded_param.placements, - ) - else: - dist.broadcast(full_tensor, src=0) - sharded_tensor = full_tensor - torch_util.synchronize() - del full_tensor - - sharded_sd[param_name] = sharded_tensor +def _load_rank0_full_state_dict(model: nn.Module, full_sd: dict) -> None: + """Load rank0 full weights into a sharded FSDP2 model via DCP broadcast.""" + from torch.distributed.checkpoint.state_dict import StateDictOptions, set_model_state_dict - model.load_state_dict(sharded_sd, assign=True) + set_model_state_dict( + model=model, + model_state_dict=full_sd, + options=StateDictOptions( + full_state_dict=True, + broadcast_from_rank0=True, + ), + ) def _get_non_persistent_buffers(model: nn.Module) -> Dict[str, torch.Tensor]: From d292147ff4248e66fdecf842949dd8c5653234c3 Mon Sep 17 00:00:00 2001 From: meichangsu1 <1484603386@qq.com> Date: Wed, 13 May 2026 09:30:35 +0800 Subject: [PATCH 08/11] fix: correct is_npu_device detection in _ensure_lora_dtype Use Platform.device_prefix() instead of iterating over model parameters to check for NPU device, improving efficiency and correctness. --- src/twinkle/model/transformers/transformers.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py index d6d1939b..46484792 100644 --- a/src/twinkle/model/transformers/transformers.py +++ b/src/twinkle/model/transformers/transformers.py @@ -1054,10 +1054,8 @@ def _load_optimizer(self, checkpoint_dir, **kwargs): def _ensure_lora_dtype(self, model): """Force LoRA parameters to use the same dtype as base model for FSDP2 compatibility.""" base_dtype = None - is_npu_device = False + is_npu_device = Platform.device_prefix() == 'npu' for param in model.parameters(): - if param.device.type == 'npu': - is_npu_device = True if base_dtype is None and param.dtype in (torch.float16, torch.bfloat16, torch.float32): base_dtype = param.dtype if base_dtype is not None and is_npu_device: From a4cd83e96e35021989ba3b95dd0d814c53ad8d24 Mon Sep 17 00:00:00 2001 From: meichangsu1 <1484603386@qq.com> Date: Mon, 18 May 2026 17:05:04 +0800 Subject: [PATCH 09/11] feat(moe): support hash routing in ep --- .../model/transformers/moe/ep_utils.py | 49 ++++++++--------- .../model/transformers/moe/expert_parallel.py | 54 ++++++++++++++----- 2 files changed, 64 insertions(+), 39 deletions(-) diff --git a/src/twinkle/model/transformers/moe/ep_utils.py b/src/twinkle/model/transformers/moe/ep_utils.py index 94118448..f64c1796 100644 --- a/src/twinkle/model/transformers/moe/ep_utils.py +++ b/src/twinkle/model/transformers/moe/ep_utils.py @@ -96,24 +96,21 @@ def all_to_all_async(group, input, output_split_size, input_split_size): # ========================== moe_utils ========================== -def permute(tokens: torch.Tensor, routing_map: torch.Tensor): +def permute(tokens: torch.Tensor, expert_mask: torch.Tensor): """ Permutes the tokens according to the routing map. Args: tokens (torch.Tensor): The input token tensor, [num_tokens, hidden_dim]. - routing_map (torch.Tensor): The sparse token to expert mapping, [num_experts, tokens]. + expert_mask (torch.Tensor): The sparse token to expert mapping, [num_experts, top_k, num_tokens]. """ num_tokens, _ = tokens.shape - num_experts = routing_map.shape[0] - - # mask [num_tokens, num_experts] -> [num_experts, num_tokens] - routing_map = routing_map.bool() + expert_mask = expert_mask.bool() # Create a dense expert-to-token mapping from the sparse token-to-expert mapping - token_indices = torch.arange(num_tokens, device=routing_map.device).unsqueeze(0).expand(num_experts, -1) - sorted_indices = token_indices.masked_select(routing_map) + token_indices = torch.arange(num_tokens, device=expert_mask.device).view(1, 1, num_tokens).expand_as(expert_mask) + sorted_indices = token_indices.masked_select(expert_mask) # use the mapping to permute the tokens permuted_input = tokens.index_select(0, sorted_indices) @@ -226,6 +223,7 @@ def preprocess( def token_pre_all2all( hidden_states: torch.Tensor, expert_mask: torch.Tensor, + routing_weights: torch.Tensor, num_experts: int, input_splits: torch.Tensor, output_splits: torch.Tensor, @@ -235,9 +233,9 @@ def token_pre_all2all( hidden_dim = hidden_states.size(-1) hidden_states = hidden_states.reshape(-1, hidden_dim) org_hidden_states_shape = hidden_states.shape - routing_map = expert_mask.sum(dim=1) - local_permuted_hidden_states, local_input_permutation_mapping = permute(hidden_states, routing_map) + local_permuted_hidden_states, local_input_permutation_mapping = permute(hidden_states, expert_mask) + local_assignment_weights = routing_weights.T.contiguous().masked_select(expert_mask.bool()) global_permuted_hidden_states = all_to_all(ep_group, local_permuted_hidden_states, output_splits, input_splits) @@ -250,18 +248,21 @@ def token_pre_all2all( permute_order, ) - return global_permuted_hidden_states, routing_map, local_input_permutation_mapping, org_hidden_states_shape + return ( + global_permuted_hidden_states, + local_input_permutation_mapping, + local_assignment_weights, + org_hidden_states_shape, + ) def tokens_post_all2all( expert_outputs: torch.Tensor, - routing_weights: torch.Tensor, - selected_experts: torch.Tensor, + local_assignment_weights: torch.Tensor, num_experts: int, input_splits: torch.Tensor, output_splits: torch.Tensor, num_global_tokens_per_local_expert: torch.Tensor, - routing_map: torch.Tensor, local_input_permutation_mapping: torch.Tensor, org_hidden_states_shape: torch.Size, ep_group: Optional[dist.ProcessGroup] = None, @@ -276,16 +277,12 @@ def tokens_post_all2all( ) unpermute_outputs = all_to_all(ep_group, expert_outputs, input_splits, output_splits) - - # [tokens, experts] - weights_idx = generate_weights_idx(routing_weights, selected_experts, num_experts) - - unpermute_outputs = unpermute( - unpermute_outputs, - weights_idx, - org_hidden_states_shape, - local_input_permutation_mapping, - routing_map, + weighted_outputs = unpermute_outputs * local_assignment_weights.unsqueeze(-1) + hidden_dim = org_hidden_states_shape[-1] + final_outputs = torch.zeros(org_hidden_states_shape, device=weighted_outputs.device, dtype=weighted_outputs.dtype) + final_outputs.scatter_add_( + 0, + local_input_permutation_mapping.unsqueeze(1).expand(-1, hidden_dim), + weighted_outputs, ) - - return unpermute_outputs + return final_outputs diff --git a/src/twinkle/model/transformers/moe/expert_parallel.py b/src/twinkle/model/transformers/moe/expert_parallel.py index c9ee0fa0..282f49c0 100644 --- a/src/twinkle/model/transformers/moe/expert_parallel.py +++ b/src/twinkle/model/transformers/moe/expert_parallel.py @@ -1,6 +1,7 @@ # Copyright (c) ModelScope Contributors. All rights reserved. from __future__ import annotations +import inspect import torch import torch.distributed as dist import torch.nn.functional as F @@ -187,6 +188,11 @@ def patch_forward( raise ValueError('MoE block must define top_k/num_experts_per_tok.') orig_forward = block.forward + return_annotation = inspect.signature(orig_forward).return_annotation + returns_router_logits = return_annotation in ( + tuple, + Tuple[torch.Tensor, torch.Tensor | None], + ) num_experts = block._ep_num_experts experts_per_rank = block._ep_experts_per_rank is_tensor_experts = block._ep_tensor_experts @@ -198,8 +204,8 @@ def patch_forward( _install_ep_forward(block.experts, experts_per_rank) def forward(hidden_states: torch.Tensor, *args, **kwargs): - if args or kwargs: - raise RuntimeError('Expert parallel patch only supports forward(hidden_states).') + if args: + raise RuntimeError('Expert parallel patch only supports keyword-only extra args for MoE blocks.') orig_shape = hidden_states.shape if hidden_states.ndim == 3: @@ -218,11 +224,11 @@ def forward(hidden_states: torch.Tensor, *args, **kwargs): top_k=top_k, router_dtype=_get_router_dtype(cfg.router_dtype, hidden_states_2d.dtype), norm_topk_prob=getattr(block, 'norm_topk_prob', False), + **kwargs, ) # Keep routing weights in activation dtype before unpermute weighting. if routing_weights.dtype != hidden_states_2d.dtype: routing_weights = routing_weights.to(hidden_states_2d.dtype) - # Build expert_mask: [num_experts, top_k, num_tokens] expert_mask = torch.nn.functional.one_hot( selected_experts, num_classes=num_experts).permute(2, 1, 0) # [num_experts, top_k, num_tokens] @@ -238,12 +244,13 @@ def forward(hidden_states: torch.Tensor, *args, **kwargs): # 2. token_pre_all2all: permute → all_to_all → sort_chunks ( global_permuted_hidden_states, - routing_map, local_input_permutation_mapping, + local_assignment_weights, org_hidden_states_shape, ) = token_pre_all2all( hidden_states_2d, expert_mask, + routing_weights, num_experts, input_splits, output_splits, @@ -272,13 +279,11 @@ def forward(hidden_states: torch.Tensor, *args, **kwargs): # 4. tokens_post_all2all: sort_chunks → all_to_all → unpermute (with routing weight) final_hidden = tokens_post_all2all( expert_outputs, - routing_weights, - selected_experts, + local_assignment_weights, num_experts, input_splits, output_splits, num_global_tokens_per_local_expert, - routing_map, local_input_permutation_mapping, org_hidden_states_shape, ep_group, @@ -291,7 +296,7 @@ def forward(hidden_states: torch.Tensor, *args, **kwargs): if len(orig_shape) == 3: final_hidden = final_hidden.view(batch_size, seq_len, hidden_dim) - if cfg.keep_router_logits: + if cfg.keep_router_logits and returns_router_logits: return final_hidden, router_logits return final_hidden @@ -311,7 +316,10 @@ def ep_forward( experts_per_rank: int, ) -> torch.Tensor: if permuted_tokens.numel() == 0: - return torch.empty_like(permuted_tokens) + # Preserve the autograd edge to token_pre_all2all. Returning a new + # empty tensor can make this rank skip the matching backward + # all-to-all, causing EP collective order divergence. + return permuted_tokens input_dtype = permuted_tokens.dtype @@ -333,8 +341,12 @@ def ep_forward( compute_dtype = gate_up.dtype if expert_in.dtype != compute_dtype: expert_in = expert_in.to(compute_dtype) - gate, up = F.linear(expert_in, gate_up).chunk(2, dim=-1) - out = self.act_fn(gate) * up + gate_up_out = F.linear(expert_in, gate_up) + if hasattr(self, '_apply_gate'): + out = self._apply_gate(gate_up_out) + else: + gate, up = gate_up_out.chunk(2, dim=-1) + out = self.act_fn(gate) * up out = F.linear(out, down) if out.dtype != input_dtype: @@ -399,6 +411,8 @@ def _maybe_run_shared_expert(block: nn.Module, hidden_states_2d: torch.Tensor, c if cfg.ignore_shared_experts: return None shared = getattr(block, 'shared_expert', None) + if shared is None: + shared = getattr(block, 'shared_experts', None) if shared is None: return None return _run_module_with_casting(shared, hidden_states_2d) @@ -431,7 +445,9 @@ def _run_local_experts( that happens in unpermute. """ if permuted_tokens.numel() == 0: - return torch.empty_like(permuted_tokens) + # Keep the backward path through token_pre_all2all even when this EP + # rank owns no routed tokens for the current block. + return permuted_tokens input_dtype = permuted_tokens.dtype experts = block.experts @@ -487,8 +503,12 @@ def _run_router( top_k: int, router_dtype: torch.dtype, norm_topk_prob: bool, + **kwargs, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - gate_out = gate(hidden_states) + gate_kwargs = {} + if 'input_ids' in kwargs and _module_forward_accepts_kwarg(gate, 'input_ids'): + gate_kwargs['input_ids'] = kwargs['input_ids'] + gate_out = gate(hidden_states, **gate_kwargs) if isinstance(gate_out, tuple) and len(gate_out) >= 3: router_logits, routing_weights, selected_experts = gate_out[:3] return router_logits, routing_weights, selected_experts @@ -499,3 +519,11 @@ def _run_router( if norm_topk_prob: routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True) return router_logits, routing_weights, selected_experts + + +def _module_forward_accepts_kwarg(module: nn.Module, kwarg: str) -> bool: + signature = inspect.signature(module.forward) + for param in signature.parameters.values(): + if param.kind == inspect.Parameter.VAR_KEYWORD: + return True + return kwarg in signature.parameters From 092e55f8154b2cf2a2bb8c512fd3ba94368a2928 Mon Sep 17 00:00:00 2001 From: meichangsu1 <1484603386@qq.com> Date: Mon, 18 May 2026 17:56:57 +0800 Subject: [PATCH 10/11] feat(transformers): add expert parallelism support to NativeFSDP strategy - Enable expert parallelism (EP) with configurable router dtype and logits - Implement rank0 pre-EP full state dict for efficient weight loading - Add EP-aware state dict broadcasting and expert shard specifications - Update DeepSeek V4 flash cookbook with EP configuration example --- cookbook/transformers/deepseek_v4_flash.py | 6 + .../transformers/strategy/native_fsdp.py | 216 +++++++++++++++++- .../model/transformers/transformers.py | 15 ++ 3 files changed, 227 insertions(+), 10 deletions(-) diff --git a/cookbook/transformers/deepseek_v4_flash.py b/cookbook/transformers/deepseek_v4_flash.py index 3e8e6376..eacb8c84 100644 --- a/cookbook/transformers/deepseek_v4_flash.py +++ b/cookbook/transformers/deepseek_v4_flash.py @@ -43,6 +43,7 @@ device_mesh = DeviceMesh.from_sizes( fsdp_size=4, dp_size=1, + ep_size=4, device_type=Platform.get_platform().device_prefix(), ) @@ -87,6 +88,11 @@ def train(): ignore_mismatched_sizes=IGNORE_MISMATCHED_SIZES, fsdp_config={ 'reshard_after_forward': RESHARD_AFTER_FORWARD, + 'expert_parallel': { + 'enabled': True, + 'router_dtype': 'fp32', + 'keep_router_logits': False, + }, }, ) diff --git a/src/twinkle/model/transformers/strategy/native_fsdp.py b/src/twinkle/model/transformers/strategy/native_fsdp.py index 3cbaf141..9e3bbad7 100644 --- a/src/twinkle/model/transformers/strategy/native_fsdp.py +++ b/src/twinkle/model/transformers/strategy/native_fsdp.py @@ -28,6 +28,7 @@ def __init__(self, self._memory_efficient_init = memory_efficient_init self.enable_ep = enable_ep self.ep_fsdp_device_mesh = self._build_ep_fsdp_device_mesh(ep_size) if enable_ep else None + self._rank0_pre_ep_full_state_dict = None def pretrained_load_context(self): # Native FSDP loads pretrained weights via rank0 broadcast during wrap_model(). @@ -38,6 +39,9 @@ def pretrained_load_context(self): def use_rank0_pretrained_broadcast(self) -> bool: return self._memory_efficient_init and self.device_mesh is not None + def set_rank0_pre_ep_full_state_dict(self, state_dict: Dict[str, torch.Tensor]) -> None: + self._rank0_pre_ep_full_state_dict = state_dict + def _build_ep_fsdp_device_mesh(self, ep_size: Optional[int] = None) -> Optional[TorchDeviceMesh]: if self.device_mesh is None: return None @@ -65,17 +69,21 @@ def wrap_model(self, model, optimizer=None): if optimizer is not None: _unbind_optimizer_params(optimizer) - use_meta = self.use_rank0_pretrained_broadcast() and not ep_enabled + use_meta = self.use_rank0_pretrained_broadcast() original_sd = None saved_buffers = None if use_meta: is_rank0 = (dist.get_rank() == 0) - original_sd = model.state_dict() if is_rank0 else {} + if ep_enabled and self._rank0_pre_ep_full_state_dict is not None: + original_sd = self._rank0_pre_ep_full_state_dict if is_rank0 else {} + else: + original_sd = model.state_dict() if is_rank0 else {} saved_buffers = _get_non_persistent_buffers(model) if is_rank0 else {} - model = model.to(torch.device('meta')) - if hasattr(model, 'tie_weights'): - model.tie_weights() + if is_rank0: + model = model.to(torch.device('meta')) + if hasattr(model, 'tie_weights'): + model.tie_weights() if ep_enabled: _ensure_moe_patched_if_needed(model, self.ep_fsdp_device_mesh) @@ -102,7 +110,7 @@ def wrap_model(self, model, optimizer=None): for layer_mod, experts_mod in layer_pairs: layer_mod._fsdp_modules = [] - if experts_mod is not None and ep_fsdp_mesh_1d is not None: + if experts_mod is not None and ep_fsdp_mesh_1d is not None and ep_fsdp_mesh_1d.size() > 1: from torch.distributed.tensor import Shard ep_mp_policy = _build_ep_mp_policy(mp_policy) @@ -113,7 +121,8 @@ def wrap_model(self, model, optimizer=None): mp_policy=ep_mp_policy, shard_placement_fn=lambda param: Shard(1), ) - experts_mod.set_gradient_divide_factor(world_size) + if hasattr(experts_mod, 'set_gradient_divide_factor'): + experts_mod.set_gradient_divide_factor(world_size) layer_mod._fsdp_modules.append(experts_mod) fully_shard( @@ -135,7 +144,16 @@ def wrap_model(self, model, optimizer=None): if use_meta: device_type = self.device_mesh.device_type or 'cuda' - _load_rank0_full_state_dict(model, original_sd or {}) + if ep_enabled: + _broadcast_sharded_state_dict( + model, + original_sd or {}, + device_type=device_type, + expert_shard_specs=_collect_ep_expert_shard_specs(model), + rank_to_ep_rank=_build_rank_to_ep_rank(self.ep_fsdp_device_mesh), + ) + else: + _load_rank0_full_state_dict(model, original_sd or {}) target_device = torch.device(device_type) _broadcast_non_persistent_buffers(model, saved_buffers or {}, device=target_device) if hasattr(model, 'tie_weights'): @@ -376,6 +394,8 @@ def _collect_expert_params(model: nn.Module) -> Optional[Set[nn.Parameter]]: if getattr(module, '_ep_ignore_shared_experts', False) and getattr(module, '_ep_patched', False): ep_patched = True shared = getattr(module, 'shared_expert', None) + if shared is None: + shared = getattr(module, 'shared_experts', None) if shared is not None: ignored.update(shared.parameters()) @@ -411,6 +431,37 @@ def _collect_ep_experts_map(model: nn.Module) -> Dict[str, nn.Module]: return experts_map +def _collect_ep_expert_shard_specs(model: nn.Module) -> Dict[str, Dict[str, int]]: + """Collect state-dict names for expert tensors sharded by EP.""" + specs = {} + for fqn, module in model.named_modules(): + if not getattr(module, '_ep_patched', False): + continue + experts = getattr(module, 'experts', None) + if experts is None: + continue + experts_prefix = f'{fqn}.experts.' if fqn else 'experts.' + for pname, _ in experts.named_parameters(): + specs[experts_prefix + pname] = { + 'num_experts': int(module._ep_num_experts), + 'experts_per_rank': int(module._ep_experts_per_rank), + } + return specs + + +def _build_rank_to_ep_rank(ep_fsdp_device_mesh: Optional[TorchDeviceMesh]) -> Dict[int, int]: + if ep_fsdp_device_mesh is None: + return {} + mesh = ep_fsdp_device_mesh.mesh + if hasattr(mesh, 'detach'): + mesh = mesh.detach().cpu().numpy() + rank_to_ep_rank = {} + for ep_rank in range(mesh.shape[0]): + for rank in mesh[ep_rank].flatten().tolist(): + rank_to_ep_rank[int(rank)] = int(ep_rank) + return rank_to_ep_rank + + def _find_experts_in_layer(layer_mod: nn.Module, experts_map: Dict[str, nn.Module]) -> Optional[nn.Module]: """Find the experts module inside a decoder layer, if any.""" for module in layer_mod.modules(): @@ -444,11 +495,22 @@ def _place_ep_experts_on_local_device(model: nn.Module, ep_fsdp_device_mesh: Opt continue experts = getattr(module, 'experts', None) if experts is not None: - experts.to(local_device) + _move_ep_module_to_device(experts, local_device) if getattr(module, '_ep_ignore_shared_experts', False): shared = getattr(module, 'shared_expert', None) + if shared is None: + shared = getattr(module, 'shared_experts', None) if shared is not None: - shared.to(local_device) + _move_ep_module_to_device(shared, local_device) + + +def _move_ep_module_to_device(module: nn.Module, device: torch.device) -> None: + has_meta_tensor = any(param.is_meta for param in module.parameters(recurse=True)) + has_meta_tensor = has_meta_tensor or any(buffer.is_meta for buffer in module.buffers(recurse=True)) + if has_meta_tensor: + module.to_empty(device=device) + else: + module.to(device) def _ensure_moe_patched_if_needed(model: nn.Module, ep_fsdp_device_mesh: Optional[TorchDeviceMesh]) -> None: @@ -515,6 +577,140 @@ def _load_rank0_full_state_dict(model: nn.Module, full_sd: dict) -> None: ) +def _broadcast_sharded_state_dict( + model: nn.Module, + full_sd: dict, + device_type: str = 'cuda', + expert_shard_specs: Optional[Dict[str, Dict[str, int]]] = None, + rank_to_ep_rank: Optional[Dict[int, int]] = None, +) -> None: + """Broadcast rank0 full state dict and materialize local FSDP2/EP shards.""" + from torch.distributed.tensor import DTensor, Partial, Replicate, Shard + + meta_sharded_sd = model.state_dict() + sharded_sd = {} + is_rank0 = (dist.get_rank() == 0) + expert_shard_specs = expert_shard_specs or {} + rank_to_ep_rank = rank_to_ep_rank or {} + + source_metadata = None + if is_rank0: + source_metadata = { + name: (tuple(tensor.shape), tensor.dtype) + for name, tensor in full_sd.items() if hasattr(tensor, 'shape') and hasattr(tensor, 'dtype') + } + metadata_holder = [source_metadata] + dist.broadcast_object_list(metadata_holder, src=0) + source_metadata = metadata_holder[0] or {} + + def _dtensor_from_replicated_full_tensor(full_tensor, device_mesh, placements): + local_tensor = full_tensor + for mesh_dim, placement in enumerate(placements): + if isinstance(placement, Shard): + local_tensor = placement._shard_tensor( + local_tensor, + device_mesh, + mesh_dim, + src_data_rank=None, + ) + local_tensor = local_tensor.contiguous().clone() + elif isinstance(placement, Replicate): + continue + elif isinstance(placement, Partial): + raise NotImplementedError('Native FSDP2 full-state loading does not support Partial placements.') + else: + raise NotImplementedError(f'Unsupported DTensor placement: {placement}') + return DTensor.from_local( + local_tensor, + device_mesh=device_mesh, + placements=placements, + run_check=False, + shape=full_tensor.shape, + stride=full_tensor.stride(), + ) + + def _scatter_ep_expert_tensor(param_name: str, full_tensor, sharded_param): + spec = expert_shard_specs[param_name] + experts_per_rank = spec['experts_per_rank'] + num_experts = spec['num_experts'] + local_shape = tuple(sharded_param.size()) + if param_name not in source_metadata: + raise KeyError(f"Missing source metadata for EP expert parameter '{param_name}'.") + _, source_dtype = source_metadata[param_name] + local_tensor = torch.empty(local_shape, device=device_type, dtype=source_dtype) + + if is_rank0: + if full_tensor.size(0) != num_experts: + raise RuntimeError(f"EP expert parameter '{param_name}' expects {num_experts} experts, " + f'but source state has shape {tuple(full_tensor.shape)}. ' + 'Rank0 must capture the full pre-EP state_dict before apply_expert_parallel().') + world_size = dist.get_world_size() + for rank in range(world_size): + if rank not in rank_to_ep_rank: + raise RuntimeError(f'Missing EP rank mapping for global rank {rank}.') + ep_rank = rank_to_ep_rank[rank] + start = ep_rank * experts_per_rank + end = start + experts_per_rank + chunk = full_tensor[start:end].contiguous() + chunk_gpu = chunk.to(device_type) + if rank == 0: + local_tensor.copy_(chunk_gpu) + else: + dist.send(chunk_gpu, dst=rank) + else: + dist.recv(local_tensor, src=0) + + return local_tensor + + for param_name, sharded_param in meta_sharded_sd.items(): + is_ep_expert_param = param_name in expert_shard_specs + if param_name not in source_metadata: + raise KeyError(f"Missing source metadata for parameter '{param_name}'.") + source_shape, source_dtype = source_metadata[param_name] + + if is_rank0: + if param_name not in full_sd: + raise KeyError( + f"Parameter '{param_name}' found in sharded model state dict but missing from full state dict.") + full_param = full_sd[param_name] + full_tensor = full_param.detach() + if isinstance(full_tensor, DTensor): + full_tensor = full_tensor.to_local() + if not is_ep_expert_param: + full_tensor = full_tensor.to(device_type) + if tuple(full_tensor.shape) != tuple(source_shape) or full_tensor.dtype != source_dtype: + raise RuntimeError(f"Source metadata mismatch for '{param_name}': " + f'actual shape={tuple(full_tensor.shape)} dtype={full_tensor.dtype}, ' + f'expected shape={source_shape} dtype={source_dtype}.') + else: + full_tensor = None if is_ep_expert_param else torch.empty( + source_shape, device=device_type, dtype=source_dtype) + + if is_ep_expert_param: + full_tensor = _scatter_ep_expert_tensor(param_name, full_tensor, sharded_param) + else: + if tuple(sharded_param.size()) != tuple(source_shape): + raise RuntimeError(f"Parameter '{param_name}' shape mismatch before broadcast: " + f'sharded logical shape={tuple(sharded_param.size())}, ' + f'source shape={source_shape}.') + dist.broadcast(full_tensor, src=0) + torch_util.synchronize() + + if isinstance(sharded_param, DTensor): + sharded_tensor = _dtensor_from_replicated_full_tensor( + full_tensor, + sharded_param.device_mesh, + sharded_param.placements, + ) + else: + sharded_tensor = full_tensor + del full_tensor + + sharded_sd[param_name] = sharded_tensor + + model.load_state_dict(sharded_sd, assign=True) + + def _get_non_persistent_buffers(model: nn.Module) -> Dict[str, torch.Tensor]: """Return {fqn: tensor} for non-persistent buffers (lost on to('meta')).""" non_persistent_fqns: Set[str] = set() diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py index 46484792..ebb35cab 100644 --- a/src/twinkle/model/transformers/transformers.py +++ b/src/twinkle/model/transformers/transformers.py @@ -47,6 +47,16 @@ logger = get_logger() +def _clone_state_dict_to_cpu(state_dict: Dict[str, Any]) -> Dict[str, Any]: + cloned = {} + for key, value in state_dict.items(): + if hasattr(value, 'detach'): + cloned[key] = value.detach().cpu().clone() + else: + cloned[key] = value + return cloned + + @dataclass class OptimizerGroup(BaseOptimizerGroup): """Optimizer group for Transformers training.""" @@ -286,6 +296,11 @@ def _not_encoded(inputs): def _lazy_wrap_model(self): if not self._model_wrapped: optimizer_groups = [og for og in self.optimizer_group.values() if og.optimizer is not None] + use_rank0_broadcast = getattr(self.strategy, 'use_rank0_pretrained_broadcast', lambda: False) + set_pre_ep_state = getattr(self.strategy, 'set_rank0_pre_ep_full_state_dict', None) + if self._enable_expert_parallel and use_rank0_broadcast() and set_pre_ep_state is not None: + is_rank0 = dist.is_available() and dist.is_initialized() and dist.get_rank() == 0 + set_pre_ep_state(_clone_state_dict_to_cpu(self.model.state_dict()) if is_rank0 else {}) self._maybe_apply_expert_parallel() self._ensure_sp_strategy() if self.sp_strategy is not None: From a87dbf9b752466a34e8fc3c88c7a0914d487db54 Mon Sep 17 00:00:00 2001 From: meichangsu1 <1484603386@qq.com> Date: Mon, 18 May 2026 18:46:18 +0800 Subject: [PATCH 11/11] update cookbook --- cookbook/transformers/deepseek_v4_flash.py | 10 +++++----- cookbook/transformers/deepseek_v4_flash.sh | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/cookbook/transformers/deepseek_v4_flash.py b/cookbook/transformers/deepseek_v4_flash.py index eacb8c84..869f4cc8 100644 --- a/cookbook/transformers/deepseek_v4_flash.py +++ b/cookbook/transformers/deepseek_v4_flash.py @@ -41,9 +41,9 @@ ADAPTER_NAME = 'default' device_mesh = DeviceMesh.from_sizes( - fsdp_size=4, + fsdp_size=8, dp_size=1, - ep_size=4, + ep_size=8, device_type=Platform.get_platform().device_prefix(), ) @@ -89,9 +89,9 @@ def train(): fsdp_config={ 'reshard_after_forward': RESHARD_AFTER_FORWARD, 'expert_parallel': { - 'enabled': True, - 'router_dtype': 'fp32', - 'keep_router_logits': False, + 'enabled': True, + 'router_dtype': 'fp32', + 'keep_router_logits': False, }, }, ) diff --git a/cookbook/transformers/deepseek_v4_flash.sh b/cookbook/transformers/deepseek_v4_flash.sh index 991e60eb..bbdb58ff 100644 --- a/cookbook/transformers/deepseek_v4_flash.sh +++ b/cookbook/transformers/deepseek_v4_flash.sh @@ -3,4 +3,4 @@ # https://gitcode.com/cann/cann-recipes-train/blob/master/llm_pretrain/deepseekv4/README.md#%E6%A8%A1%E5%9E%8B%E6%9D%83%E9%87%8D%E5%87%86%E5%A4%87 # Install `transformers==5.8.0` before running this cookbook. -CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 cookbook/transformers/deepseek_v4_flash.py +CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nproc_per_node=8 cookbook/transformers/deepseek_v4_flash.py