From cf53e4e30bb7f5247ba12ef2ca433d808a2ab17e Mon Sep 17 00:00:00 2001
From: meichangsu1 <1484603386@qq.com>
Date: Tue, 12 May 2026 18:54:19 +0800
Subject: [PATCH 01/11] feat(template): add agent template support with
ReAct-style tool calling
Introduce BaseAgentTemplate and DeepSeekV4AgentTemplate for agent-based interactions. Add ReactCompatMixin for parsing and formatting ReAct-style tool calls, including Action/Action Input/Observation keywords. Implement ToolDesc and AgentKeyword dataclasses to support structured tool descriptions and agent keywords.
---
src/twinkle/template/__init__.py | 3 +-
src/twinkle/template/base.py | 159 +++++++++++++++++++++++++++-
src/twinkle/template/deepseek_v4.py | 138 ++++++++++++++++++++++++
src/twinkle/template/utils.py | 55 +++++++++-
4 files changed, 351 insertions(+), 4 deletions(-)
create mode 100644 src/twinkle/template/deepseek_v4.py
diff --git a/src/twinkle/template/__init__.py b/src/twinkle/template/__init__.py
index 324ce7ac..4b8c27e6 100644
--- a/src/twinkle/template/__init__.py
+++ b/src/twinkle/template/__init__.py
@@ -1,3 +1,4 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
-from .base import Template
+from .base import BaseAgentTemplate, Template
+from .deepseek_v4 import DeepSeekV4AgentTemplate
from .qwen3_5_vl import Qwen3_5Template
diff --git a/src/twinkle/template/base.py b/src/twinkle/template/base.py
index 5784ddae..b3a95ed6 100644
--- a/src/twinkle/template/base.py
+++ b/src/twinkle/template/base.py
@@ -1,15 +1,19 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
+import ast
import inspect
+import json
import numpy as np
import os
+from abc import ABC, abstractmethod
from collections.abc import Mapping
from copy import copy, deepcopy
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Union
+from dataclasses import asdict, dataclass
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Tuple, Union
from twinkle.data_format import InputFeature, Message, Trajectory
from twinkle.hub import HubOperation
from twinkle.utils import load_image, to_device
-from .utils import TokenizeByRound, transfer_to_standard_message
+from .utils import Function, Prompt, TokenizeByRound, split_str_parts_by, transfer_to_standard_message
if TYPE_CHECKING:
import torch
@@ -21,6 +25,157 @@
AudioInput = Union[str, np.ndarray, 'torch.Tensor']
+@dataclass
+class AgentKeyword:
+ action: str = 'Action:'
+ action_input: str = 'Action Input:'
+ observation: str = 'Observation:'
+
+
+@dataclass
+class ToolDesc:
+ name_for_model: str
+ name_for_human: str
+ description_for_model: str
+ parameters: str
+ args_format: str
+
+
+class ReactCompatMixin:
+ """ReAct-style tool call parsing and formatting compatibility."""
+
+ keyword = AgentKeyword()
+
+ @staticmethod
+ def _split_action_action_input(response: str, keyword: AgentKeyword) -> List[Function]:
+ agent_parts = split_str_parts_by(response, list(asdict(keyword).values()))
+ functions = []
+ action_content = None
+
+ for part in agent_parts:
+ key, content = part['key'].lower(), part['content']
+ if action_content is None and key == keyword.action.lower():
+ action_content = content
+ elif action_content is not None and key == keyword.action_input.lower():
+ functions.append(Function(name=action_content, arguments=content))
+ action_content = None
+
+ return functions
+
+ def get_toolcall(self, response: str) -> List[Function]:
+ functions = self._split_action_action_input(response, self.keyword)
+ if len(functions) == 0 and self.keyword != ReactCompatMixin.keyword:
+ functions = self._split_action_action_input(response, ReactCompatMixin.keyword)
+ return functions
+
+ def _format_tool_responses(
+ self,
+ assistant_content: str,
+ tool_messages,
+ ) -> Tuple[str, Prompt]:
+ assert len(tool_messages) > 0
+ with_action = self.keyword.action in assistant_content and self.keyword.action_input in assistant_content
+ if with_action:
+ if not assistant_content.endswith(self.keyword.observation):
+ if not assistant_content.endswith('\n'):
+ assistant_content += '\n'
+ assistant_content += self.keyword.observation
+ res = []
+ for i, tool_message in enumerate(tool_messages):
+ if i > 0:
+ res.append(self.keyword.observation)
+ tool_content = tool_message['content']
+ res.append(tool_content)
+ if not tool_content.endswith('\n'):
+ res.append('\n')
+ else:
+ res = [tool_message['content'] for tool_message in tool_messages]
+ return assistant_content, res
+
+ @staticmethod
+ def _parse_tool_call(content) -> Dict[str, Any]:
+ obj = BaseAgentTemplate._parse_json(content)
+ name = obj['name']
+ arguments = obj.get('arguments')
+ if arguments is None:
+ arguments = obj.get('parameters')
+ arguments = BaseAgentTemplate._parse_json(arguments)
+ assert arguments is not None, f'content: {content}'
+ return {'name': name, 'arguments': arguments}
+
+ def _format_tool_calls(self, tool_call_messages) -> str:
+ tool_calls = []
+ for message in tool_call_messages:
+ tool_call = self._parse_tool_call(message['content'])
+ tool_calls.append(f'{self.keyword.action} {tool_call["name"]}\n'
+ f'{self.keyword.action_input} {tool_call["arguments"]}\n')
+ tool_calls.append(self.keyword.observation)
+ return ''.join(tool_calls)
+
+
+class BaseAgentTemplate(ReactCompatMixin, ABC):
+ """Base class for agent templates that support tool calling."""
+
+ @staticmethod
+ def _get_tool_name(tool):
+ return tool.get('name_for_model') or tool.get('name')
+
+ @staticmethod
+ def unwrap_tool(tool):
+ assert isinstance(tool, dict), f'tool: {tool}'
+ if 'type' in tool and 'function' in tool:
+ tool = tool['function']
+ return tool
+
+ @staticmethod
+ def wrap_tool(tool):
+ assert isinstance(tool, dict), f'tool: {tool}'
+ if 'type' not in tool and 'function' not in tool:
+ tool = {'type': 'function', 'function': tool}
+ return tool
+
+ @staticmethod
+ def _parse_tool(tool, lang: Literal['zh', 'en']) -> ToolDesc:
+ tool = BaseAgentTemplate.unwrap_tool(tool)
+ name_for_model = BaseAgentTemplate._get_tool_name(tool)
+ name_for_human = tool.get('name_for_human') or name_for_model
+
+ description = tool.get('description')
+ if description is None:
+ description = tool.get('description_for_model')
+ parameters = tool.get('parameters') or {}
+ parameters = parameters if isinstance(parameters, str) else json.dumps(parameters, ensure_ascii=False)
+ args_format = '此工具的输入应为JSON对象。' if lang == 'zh' else 'Format the arguments as a JSON object.'
+ tool_desc = ToolDesc(
+ name_for_model=name_for_model,
+ name_for_human=name_for_human,
+ description_for_model=description,
+ parameters=parameters,
+ args_format=args_format)
+ assert name_for_model is not None and description is not None, f'tool_desc: {tool_desc}'
+ return tool_desc
+
+ @staticmethod
+ def _parse_json(json_str: str) -> Optional[Any]:
+ if not isinstance(json_str, str):
+ return json_str
+ try:
+ res = json.loads(json_str)
+ except json.JSONDecodeError:
+ try:
+ res = ast.literal_eval(json_str)
+ except Exception:
+ return
+ return res
+
+ @abstractmethod
+ def _format_tools(self,
+ tools: List[Union[str, dict]],
+ system: Optional[str] = None,
+ user_message: Optional[dict] = None) -> str:
+ pass
+
+
class Template:
# Placeholder tokens in user text
diff --git a/src/twinkle/template/deepseek_v4.py b/src/twinkle/template/deepseek_v4.py
new file mode 100644
index 00000000..97dcf8e4
--- /dev/null
+++ b/src/twinkle/template/deepseek_v4.py
@@ -0,0 +1,138 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+import json
+import re
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+from .utils import Function
+from .utils import Prompt
+from .base import BaseAgentTemplate
+
+DSML_TOKEN = '|DSML|'
+
+TOOLS_TEMPLATE = """## Tools
+
+You have access to a set of tools to help answer the user's question. \
+You can invoke tools by writing a "<{dsml_token}tool_calls>" block like the following:
+
+<{dsml_token}tool_calls>
+<{dsml_token}invoke name="$TOOL_NAME">
+<{dsml_token}parameter name="$PARAMETER_NAME" string="true|false">$PARAMETER_VALUE{dsml_token}parameter>
+...
+{dsml_token}invoke>
+<{dsml_token}invoke name="$TOOL_NAME2">
+...
+{dsml_token}invoke>
+{dsml_token}tool_calls>
+
+String parameters should be specified as is and set `string="true"`. \
+For all other types (numbers, booleans, arrays, objects), \
+pass the value in JSON format and set `string="false"`.
+
+If thinking_mode is enabled (triggered by ), \
+you MUST output your complete reasoning inside ... BEFORE any tool calls or final response.
+
+Otherwise, output directly after with tool calls or final response.
+
+### Available Tool Schemas
+
+{tool_schemas}
+
+You MUST strictly follow the above defined tool name and parameter schemas to invoke tool calls.
+"""
+
+
+def _to_json(value: Any) -> str:
+ try:
+ return json.dumps(value, ensure_ascii=False)
+ except Exception:
+ return json.dumps(value, ensure_ascii=True)
+
+
+def _encode_arguments_to_dsml(arguments: Dict[str, Any]) -> str:
+ """Encode tool call arguments dict into DSML parameter lines."""
+ lines = []
+ for k, v in arguments.items():
+ is_str = 'true' if isinstance(v, str) else 'false'
+ val = v if isinstance(v, str) else _to_json(v)
+ lines.append(f'<{DSML_TOKEN}parameter name="{k}" string="{is_str}">{val}{DSML_TOKEN}parameter>')
+ return '\n'.join(lines)
+
+
+class DeepSeekV4AgentTemplate(BaseAgentTemplate):
+
+ def get_toolcall(self, response: str) -> List[Function]:
+ # Parse DSML tool calls from model output
+ # Pattern: <|DSML|invoke name="tool_name">...params...|DSML|invoke>
+ invoke_pattern = re.compile(
+ rf'<{re.escape(DSML_TOKEN)}invoke\s+name="([^"]+)">\s*(.*?)\s*{re.escape(DSML_TOKEN)}invoke>', re.DOTALL)
+ param_pattern = re.compile(
+ rf'<{re.escape(DSML_TOKEN)}parameter\s+name="([^"]+)"\s+string="(true|false)">'
+ rf'(.*?){re.escape(DSML_TOKEN)}parameter>', re.DOTALL)
+
+ functions = []
+ for match in invoke_pattern.finditer(response):
+ tool_name = match.group(1)
+ params_block = match.group(2)
+ arguments = {}
+ for pm in param_pattern.finditer(params_block):
+ param_name = pm.group(1)
+ is_string = pm.group(2)
+ param_value = pm.group(3)
+ if is_string == 'false':
+ try:
+ param_value = json.loads(param_value)
+ except json.JSONDecodeError:
+ pass
+ arguments[param_name] = param_value
+ functions.append(Function(name=tool_name, arguments=json.dumps(arguments, ensure_ascii=False)))
+
+ if len(functions) == 0:
+ # Fallback to ReAct format
+ return super().get_toolcall(response)
+ return functions
+
+ def _get_tool_responses(self, tool_messages):
+ return ''.join(f'{tool_message["content"]}' for tool_message in tool_messages)
+
+ def _format_tool_responses(
+ self,
+ assistant_content: str,
+ tool_messages,
+ ) -> Tuple[str, 'Prompt']:
+ with_action = self.keyword.action in assistant_content and self.keyword.action_input in assistant_content
+ if with_action:
+ return super()._format_tool_responses(assistant_content, tool_messages)
+ res = [
+ '<|end▁of▁sentence|><|User|>',
+ self._get_tool_responses(tool_messages),
+ '<|Assistant|>',
+ ]
+ return assistant_content, res
+
+ def _format_tools(self, tools: List[Union[str, dict]], system: Optional[str] = None, user_message=None) -> str:
+ tool_schemas = []
+ for tool in tools:
+ tool = self.unwrap_tool(tool)
+ tool_schemas.append(_to_json(tool))
+
+ tools_section = TOOLS_TEMPLATE.format(
+ tool_schemas='\n'.join(tool_schemas),
+ dsml_token=DSML_TOKEN,
+ )
+
+ system = system or ''
+ return f'{system}\n\n{tools_section}' if system else tools_section
+
+ def _format_tool_calls(self, tool_call_messages) -> str:
+ invocations = []
+ for message in tool_call_messages:
+ tool_call = self._parse_tool_call(message['content'])
+ name = tool_call['name']
+ arguments = tool_call['arguments']
+ if isinstance(arguments, str):
+ arguments = json.loads(arguments)
+ dsml_args = _encode_arguments_to_dsml(arguments)
+ invocations.append(f'<{DSML_TOKEN}invoke name="{name}">\n{dsml_args}\n{DSML_TOKEN}invoke>')
+
+ tool_calls_str = '\n'.join(invocations)
+ return f'<{DSML_TOKEN}tool_calls>\n{tool_calls_str}\n{DSML_TOKEN}tool_calls>'
diff --git a/src/twinkle/template/utils.py b/src/twinkle/template/utils.py
index 72975d78..0da0417d 100644
--- a/src/twinkle/template/utils.py
+++ b/src/twinkle/template/utils.py
@@ -1,16 +1,57 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
import inspect
+import json
+import re
from copy import copy, deepcopy
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, TypeVar
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union
from twinkle.data_format import Message, Trajectory
from twinkle.utils import to_device
+from dataclasses import dataclass
if TYPE_CHECKING:
from transformers import PreTrainedTokenizer
_T = TypeVar('_T')
+Prompt = List[Union[str, List[int], List[str]]]
+
+
+def _split_str_by_regex(text: str, regex_delimiters: List[str]) -> List[str]:
+ combined_pattern = '|'.join(f'({pattern})' for pattern in regex_delimiters)
+ parts = re.split(combined_pattern, text, flags=re.DOTALL)
+ parts = [part for part in parts if part is not None]
+ if parts[0] == '':
+ parts.pop(0)
+ else:
+ parts.insert(0, '')
+ assert len(parts) % 2 == 0, f'result: {parts}'
+ assert ''.join(parts) == text, f'split_result: {parts}, text: {text}'
+ return parts
+
+
+def split_str_parts_by(text: str, delimiters: List[str], regex_mode: bool = False) -> List[Dict[str, str]]:
+ """Split text into keyed delimiter/content parts."""
+ assert isinstance(text, str), f'text: {text}'
+ delimiters_origin = delimiters
+ if not regex_mode:
+ delimiters = [re.escape(delimiter) for delimiter in delimiters]
+ parts = _split_str_by_regex(text, delimiters) if delimiters else ['', text]
+ res = []
+ if regex_mode:
+ parts = [part for part in parts if part]
+ for part in parts:
+ for delimiter, delimiter_origin in zip(delimiters, delimiters_origin):
+ if re.match(delimiter, part, re.DOTALL):
+ break
+ else:
+ delimiter_origin = ''
+ res.append({'key': delimiter_origin, 'content': part})
+ else:
+ for key, content in zip(parts[::2], parts[1::2]):
+ res.append({'key': key, 'content': content})
+ return res
+
def _convert_to_vlm_format(messages: List[Dict]) -> List[Dict]:
converted = []
@@ -368,3 +409,15 @@ def tokenize_with_assistant_labels(
labels[i] = full_ids[i]
return full_ids, labels, encoded
+
+
+@dataclass
+class Function:
+ name: str
+ arguments: Optional[Union[str, Any]]
+
+ def __post_init__(self):
+ if not isinstance(self.arguments, str):
+ self.arguments = json.dumps(self.arguments, ensure_ascii=False)
+ self.name = self.name.strip()
+ self.arguments = self.arguments.strip()
From cc0a7ef439a09858ff018e4715d369bac8cdafdb Mon Sep 17 00:00:00 2001
From: meichangsu1 <1484603386@qq.com>
Date: Tue, 12 May 2026 19:50:52 +0800
Subject: [PATCH 02/11] fix: enhance FSDP decoder layer detection with
no_split_modules support
- Update `_get_decoder_layers` to first search for modules matching `_no_split_modules` names
- Add `_get_no_split_module_names` helper to collect no-split module names from model hierarchy
- Add `_normalize_no_split_modules` utility for consistent set conversion
- Change return type hint from `nn.ModuleList` to `List[nn.Module]` for flexibility
---
cookbook/transformers/deepseek_v4_flash.py | 139 ++++++++++++++++++
.../transformers/strategy/native_fsdp.py | 61 ++++++--
2 files changed, 186 insertions(+), 14 deletions(-)
create mode 100644 cookbook/transformers/deepseek_v4_flash.py
diff --git a/cookbook/transformers/deepseek_v4_flash.py b/cookbook/transformers/deepseek_v4_flash.py
new file mode 100644
index 00000000..a7a3d73a
--- /dev/null
+++ b/cookbook/transformers/deepseek_v4_flash.py
@@ -0,0 +1,139 @@
+import os
+
+import twinkle
+from peft import LoraConfig
+from transformers import AutoConfig
+from twinkle import DeviceMesh, Platform, get_device_placement, get_logger
+from twinkle.dataloader import DataLoader
+from twinkle.dataset import Dataset, DatasetMeta
+from twinkle.model import TransformersModel
+from twinkle.preprocessor import SelfCognitionProcessor
+
+logger = get_logger()
+
+MODEL_ID = os.environ.get('MODEL_ID', 'ms://deepseek-ai/DeepSeek-V4-flash-bfa16')
+DATASET_ID = os.environ.get('DATASET_ID', 'ms://swift/self-cognition')
+TEMPLATE_ID = os.environ.get('TEMPLATE_ID', 'Template')
+OUTPUT_DIR = os.environ.get('OUTPUT_DIR', './output')
+
+NUM_LAYERS = int(os.environ.get('NUM_LAYERS', '4'))
+
+BATCH_SIZE = int(os.environ.get('BATCH_SIZE', '4'))
+GRAD_ACCUM_STEPS = int(os.environ.get('GRAD_ACCUM_STEPS', '2'))
+LR = float(os.environ.get('LR', '1e-4'))
+MAX_STEPS = int(os.environ.get('MAX_STEPS', '0'))
+SAVE_STEPS = int(os.environ.get('SAVE_STEPS', '50'))
+RESHARD_AFTER_FORWARD = os.environ.get('RESHARD_AFTER_FORWARD', '1') == '1'
+GRADIENT_CHECKPOINTING = True
+IGNORE_MISMATCHED_SIZES = False
+LORA_TARGET_MODULES = 'all-linear'
+ADAPTER_NAME = 'default'
+
+device_mesh = DeviceMesh.from_sizes(
+ fsdp_size=4,
+ dp_size=1,
+ device_type=Platform.get_platform().device_prefix(),
+)
+
+twinkle.initialize(mode='local', global_device_mesh=device_mesh)
+
+
+def create_dataset(data_slice=None):
+ dataset = Dataset(dataset_meta=DatasetMeta(DATASET_ID, data_slice=data_slice or range(1000)))
+ dataset.set_template(TEMPLATE_ID, model_id=MODEL_ID)
+ dataset.map(SelfCognitionProcessor('twinkle大模型', 'ModelScope社区'))
+ dataset.encode(batched=True)
+ return dataset
+
+
+def eval(model):
+ dataset = create_dataset(data_slice=range(100))
+ dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, device_mesh=device_mesh)
+ for _, batch in enumerate(dataloader):
+ if callable(batch):
+ batch = batch()
+ model.forward_only(inputs=batch, adapter_name=ADAPTER_NAME)
+ model.calculate_loss(adapter_name=ADAPTER_NAME)
+ return model.calculate_metric(is_training=False, adapter_name=ADAPTER_NAME)
+
+
+def train():
+ dataset = create_dataset()
+ dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, device_mesh=device_mesh)
+
+ config = AutoConfig.from_pretrained(MODEL_ID, trust_remote_code=True)
+ if NUM_LAYERS is not None and hasattr(config, 'num_hidden_layers'):
+ config.num_hidden_layers = NUM_LAYERS
+ if hasattr(config, 'use_cache'):
+ config.use_cache = False
+
+ model = TransformersModel(
+ model_id=MODEL_ID,
+ config=config,
+ device_mesh=device_mesh,
+ strategy='native_fsdp',
+ memory_efficient_init=True,
+ ignore_mismatched_sizes=IGNORE_MISMATCHED_SIZES,
+ fsdp_config={
+ 'reshard_after_forward': RESHARD_AFTER_FORWARD,
+ },
+ )
+
+ lora_config = LoraConfig(r=8, lora_alpha=32, target_modules=LORA_TARGET_MODULES)
+ model.add_adapter_to_model(ADAPTER_NAME, lora_config, gradient_accumulation_steps=GRAD_ACCUM_STEPS)
+
+ if not GRADIENT_CHECKPOINTING:
+ model.model.gradient_checkpointing_disable()
+
+ model.set_template(TEMPLATE_ID, model_id=MODEL_ID, adapter_name=ADAPTER_NAME)
+ model.set_optimizer('AdamW', lr=LR, foreach=False, adapter_name=ADAPTER_NAME)
+ model.set_lr_scheduler(
+ scheduler_cls='CosineWarmupScheduler',
+ num_warmup_steps=5,
+ num_training_steps=len(dataloader),
+ adapter_name=ADAPTER_NAME,
+ )
+
+ logger.info(get_device_placement())
+ logger.info(model.get_train_configs(adapter_name=ADAPTER_NAME))
+ logger.info(
+ f'Total steps: {len(dataloader)}, batch_size={BATCH_SIZE}, '
+ f'grad_accum={GRAD_ACCUM_STEPS}, lr={LR:.2e}, '
+ f'num_layers={NUM_LAYERS}, ignore_mismatched_sizes={IGNORE_MISMATCHED_SIZES}, '
+ f'gradient_checkpointing={GRADIENT_CHECKPOINTING}, '
+ f'reshard_after_forward={RESHARD_AFTER_FORWARD}, '
+ f'lora_target_modules={LORA_TARGET_MODULES}')
+
+ best_loss = float('inf')
+ for step, batch in enumerate(dataloader):
+ if MAX_STEPS and step >= MAX_STEPS:
+ break
+ if callable(batch):
+ batch = batch()
+ model.forward_backward(
+ inputs=batch,
+ adapter_name=ADAPTER_NAME,
+ gradient_accumulation_steps=GRAD_ACCUM_STEPS,
+ )
+ model.clip_grad_and_step(
+ adapter_name=ADAPTER_NAME,
+ gradient_accumulation_steps=GRAD_ACCUM_STEPS,
+ )
+
+ if step % 20 == 0:
+ metric = model.calculate_metric(is_training=True, adapter_name=ADAPTER_NAME)
+ logger.info(f'Current is step {step} of {len(dataloader)}, metric: {metric}')
+
+ if step > 0 and step % SAVE_STEPS == 0:
+ metrics = eval(model)
+ logger.info(f'Eval metric: {metrics}')
+ loss = float(metrics['loss'])
+ if loss < best_loss:
+ model.save(name=f'checkpoint-{step}', output_dir=OUTPUT_DIR, adapter_name=ADAPTER_NAME)
+ best_loss = loss
+
+ model.save(name='last-checkpoint', output_dir=OUTPUT_DIR, adapter_name=ADAPTER_NAME)
+
+
+if __name__ == '__main__':
+ train()
diff --git a/src/twinkle/model/transformers/strategy/native_fsdp.py b/src/twinkle/model/transformers/strategy/native_fsdp.py
index ad675006..ce30e043 100644
--- a/src/twinkle/model/transformers/strategy/native_fsdp.py
+++ b/src/twinkle/model/transformers/strategy/native_fsdp.py
@@ -4,7 +4,7 @@
from torch import nn
from torch.distributed.device_mesh import DeviceMesh as TorchDeviceMesh
from torch.distributed.fsdp import fully_shard
-from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Set
+from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Set
from twinkle.utils import DeviceMesh, Platform, torch_util
from .load_context import fsdp_pretrained_load_context
@@ -30,7 +30,13 @@ def __init__(self,
self.ep_fsdp_device_mesh = self._build_ep_fsdp_device_mesh(ep_size) if enable_ep else None
def pretrained_load_context(self):
- return fsdp_pretrained_load_context(self._memory_efficient_init and self.device_mesh is not None)
+ # Native FSDP loads pretrained weights via rank0 broadcast during wrap_model().
+ # Avoid Transformers' FSDP loading env here; some versions can hang non-rank0
+ # ranks in from_pretrained barriers.
+ return fsdp_pretrained_load_context(False)
+
+ def use_rank0_pretrained_broadcast(self) -> bool:
+ return self._memory_efficient_init and self.device_mesh is not None
def _build_ep_fsdp_device_mesh(self, ep_size: Optional[int] = None) -> Optional[TorchDeviceMesh]:
if self.device_mesh is None:
@@ -59,17 +65,18 @@ def wrap_model(self, model, optimizer=None):
if optimizer is not None:
_unbind_optimizer_params(optimizer)
- # EP path requires experts on a real device, incompatible with meta-device flow.
- use_meta = self._memory_efficient_init and not ep_enabled
+ use_meta = self.use_rank0_pretrained_broadcast() and not ep_enabled
original_sd = None
saved_buffers = None
if use_meta:
- original_sd = model.state_dict()
- saved_buffers = _get_non_persistent_buffers(model)
- model = model.to(torch.device('meta'))
- if hasattr(model, 'tie_weights'):
- model.tie_weights()
+ is_rank0 = (dist.get_rank() == 0)
+ original_sd = model.state_dict() if is_rank0 else {}
+ saved_buffers = _get_non_persistent_buffers(model) if is_rank0 else {}
+ if is_rank0:
+ model = model.to(torch.device('meta'))
+ if hasattr(model, 'tie_weights'):
+ model.tie_weights()
if ep_enabled:
_ensure_moe_patched_if_needed(model, self.ep_fsdp_device_mesh)
@@ -129,14 +136,13 @@ def wrap_model(self, model, optimizer=None):
if use_meta:
device_type = self.device_mesh.device_type or 'cuda'
- is_rank0 = (dist.get_rank() == 0)
_broadcast_sharded_state_dict(
model,
- original_sd if is_rank0 else {},
+ original_sd,
device_type=device_type,
)
target_device = torch.device(device_type)
- _restore_non_persistent_buffers(model, saved_buffers, device=target_device)
+ _broadcast_non_persistent_buffers(model, saved_buffers or {}, device=target_device)
if hasattr(model, 'tie_weights'):
model.tie_weights()
@@ -322,16 +328,43 @@ def _build_fsdp_mesh(device_mesh: DeviceMesh) -> Optional[TorchDeviceMesh]:
return TorchDeviceMesh(device_mesh.device_type, flat_mesh, mesh_dim_names=('fsdp', ))
-def _get_decoder_layers(model: nn.Module) -> Optional[nn.ModuleList]:
+def _get_decoder_layers(model: nn.Module) -> Optional[List[nn.Module]]:
+ no_split_modules = _get_no_split_module_names(model)
+ if no_split_modules:
+ layers = [
+ module for module in model.modules()
+ if module is not model and module.__class__.__name__ in no_split_modules
+ ]
+ if layers:
+ return layers
+
inner_model = getattr(model, 'model', None)
if inner_model is not None:
inner_layers = getattr(inner_model, 'layers', None)
if isinstance(inner_layers, nn.ModuleList):
- return inner_layers
+ return list(inner_layers)
return None
+def _get_no_split_module_names(model: nn.Module) -> Set[str]:
+ names = _normalize_no_split_modules(getattr(model, '_no_split_modules', None))
+ if names:
+ return names
+
+ for module in model.modules():
+ names.update(_normalize_no_split_modules(getattr(module, '_no_split_modules', None)))
+ return names
+
+
+def _normalize_no_split_modules(value) -> Set[str]:
+ if value is None:
+ return set()
+ if isinstance(value, str):
+ return {value}
+ return set(value)
+
+
def _collect_expert_params(model: nn.Module) -> Optional[Set[nn.Parameter]]:
ignored: Set[nn.Parameter] = set()
ep_patched = False
From 400acfb38c07b1b1f76f0a85f0e39a2db32fc2ff Mon Sep 17 00:00:00 2001
From: meichangsu1 <1484603386@qq.com>
Date: Tue, 12 May 2026 20:16:06 +0800
Subject: [PATCH 03/11] fix: broadcast non-persistent buffers and validate
state dict shapes in FSDP strategy
- Broadcast non-persistent buffers from rank 0 to all ranks instead of only restoring on rank 0
- Add source metadata validation to ensure shape/dtype consistency before distributing tensors
- Fix tie_weights() call to execute on all ranks instead of only rank 0
- Improve error handling with explicit KeyError and RuntimeError for state dict mismatches
---
.../transformers/strategy/native_fsdp.py | 73 ++++++++++++++-----
.../model/transformers/transformers.py | 20 +++++
src/twinkle/template/deepseek_v4.py | 3 +-
src/twinkle/template/utils.py | 2 +-
src/twinkle/utils/transformers_utils.py | 29 +++++++-
5 files changed, 105 insertions(+), 22 deletions(-)
diff --git a/src/twinkle/model/transformers/strategy/native_fsdp.py b/src/twinkle/model/transformers/strategy/native_fsdp.py
index ce30e043..ae651fc3 100644
--- a/src/twinkle/model/transformers/strategy/native_fsdp.py
+++ b/src/twinkle/model/transformers/strategy/native_fsdp.py
@@ -73,10 +73,9 @@ def wrap_model(self, model, optimizer=None):
is_rank0 = (dist.get_rank() == 0)
original_sd = model.state_dict() if is_rank0 else {}
saved_buffers = _get_non_persistent_buffers(model) if is_rank0 else {}
- if is_rank0:
- model = model.to(torch.device('meta'))
- if hasattr(model, 'tie_weights'):
- model.tie_weights()
+ model = model.to(torch.device('meta'))
+ if hasattr(model, 'tie_weights'):
+ model.tie_weights()
if ep_enabled:
_ensure_moe_patched_if_needed(model, self.ep_fsdp_device_mesh)
@@ -511,31 +510,57 @@ def _broadcast_sharded_state_dict(
full_sd: dict,
device_type: str = 'cuda',
) -> None:
- """Broadcast full state dict from rank 0 and materialise local shards via distribute_tensor."""
+ """Distribute rank0 full state dict into local FSDP2 shards."""
from torch.distributed.tensor import DTensor, distribute_tensor
meta_sharded_sd = model.state_dict()
sharded_sd = {}
is_rank0 = (dist.get_rank() == 0)
+ source_metadata = None
+ if is_rank0:
+ source_metadata = {
+ name: (tuple(tensor.shape), tensor.dtype)
+ for name, tensor in full_sd.items() if hasattr(tensor, 'shape') and hasattr(tensor, 'dtype')
+ }
+ metadata_holder = [source_metadata]
+ dist.broadcast_object_list(metadata_holder, src=0)
+ source_metadata = metadata_holder[0] or {}
for param_name, sharded_param in meta_sharded_sd.items():
shape = sharded_param.size()
- dtype = sharded_param.dtype
+ if param_name not in source_metadata:
+ raise KeyError(f"Missing source metadata for parameter '{param_name}'.")
+ source_shape, source_dtype = source_metadata[param_name]
if is_rank0:
+ if param_name not in full_sd:
+ raise KeyError(
+ f"Parameter '{param_name}' found in sharded model state dict but missing from full state dict.")
full_param = full_sd[param_name]
- full_tensor = full_param.detach().to(device_type)
+ full_tensor = full_param.detach()
if isinstance(full_tensor, DTensor):
full_tensor = full_tensor.to_local()
+ full_tensor = full_tensor.to(device_type)
+ if tuple(full_tensor.shape) != tuple(source_shape) or full_tensor.dtype != source_dtype:
+ raise RuntimeError(f"Source metadata mismatch for '{param_name}': "
+ f'actual shape={tuple(full_tensor.shape)} dtype={full_tensor.dtype}, '
+ f'expected shape={source_shape} dtype={source_dtype}.')
else:
- full_tensor = torch.empty(shape, device=device_type, dtype=dtype)
-
- dist.broadcast(full_tensor, src=0)
+ full_tensor = torch.empty(source_shape, device=device_type, dtype=source_dtype)
+
+ if tuple(shape) != tuple(source_shape):
+ raise RuntimeError(f"Parameter '{param_name}' shape mismatch before broadcast: "
+ f'sharded logical shape={tuple(shape)}, source shape={source_shape}.')
+ if isinstance(sharded_param, DTensor):
+ sharded_tensor = distribute_tensor(
+ full_tensor,
+ sharded_param.device_mesh,
+ sharded_param.placements,
+ )
+ else:
+ dist.broadcast(full_tensor, src=0)
+ sharded_tensor = full_tensor
torch_util.synchronize()
-
- device_mesh = sharded_param.device_mesh
- placements = sharded_param.placements
- sharded_tensor = distribute_tensor(full_tensor, device_mesh, placements)
del full_tensor
sharded_sd[param_name] = sharded_tensor
@@ -562,14 +587,26 @@ def _unbind_optimizer_params(optimizer: torch.optim.Optimizer) -> None:
group['params'][i] = torch.empty(1, dtype=param.dtype, device=param.device)
-def _restore_non_persistent_buffers(
+def _broadcast_non_persistent_buffers(
model: nn.Module,
saved_buffers: Dict[str, torch.Tensor],
device: torch.device,
) -> None:
- """Re-register non-persistent buffers saved before to('meta')."""
- for fqn, buf_tensor in saved_buffers.items():
- buf_tensor = buf_tensor.to(device)
+ """Broadcast rank0 non-persistent buffers and re-register them on all ranks."""
+ is_rank0 = (dist.get_rank() == 0)
+ metadata = None
+ if is_rank0:
+ metadata = [(name, tuple(tensor.shape), tensor.dtype) for name, tensor in saved_buffers.items()]
+ metadata_holder = [metadata]
+ dist.broadcast_object_list(metadata_holder, src=0)
+ metadata = metadata_holder[0] or []
+
+ for fqn, shape, dtype in metadata:
+ if is_rank0:
+ buf_tensor = saved_buffers[fqn].to(device)
+ else:
+ buf_tensor = torch.empty(shape, device=device, dtype=dtype)
+ dist.broadcast(buf_tensor, src=0)
if '.' in fqn:
parent_fqn, local_name = fqn.rsplit('.', 1)
parent = model.get_submodule(parent_fqn)
diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py
index 56cbec17..9f80817d 100644
--- a/src/twinkle/model/transformers/transformers.py
+++ b/src/twinkle/model/transformers/transformers.py
@@ -42,6 +42,7 @@
from twinkle.utils import construct_class, get_logger, selective_log_softmax, torch_util
from twinkle.utils.framework import Torch
from twinkle.utils.grad_clip import normalize_and_clip_grad_norm
+from twinkle.utils.transformers_utils import filter_from_config_kwargs
logger = get_logger()
@@ -191,6 +192,8 @@ def __init__(
model_cls = getattr(transformers, model_cls)
if model_id is None:
self.model = model_cls.from_config(self.hf_config, **kwargs)
+ elif self._should_init_empty_pretrained_model_on_this_rank():
+ self.model = self._init_empty_model_from_config(model_cls, **kwargs)
else:
# Trigger transformers' FSDP-aware loading: meta-device init + rank-0-only weight load.
with self.strategy.pretrained_load_context():
@@ -204,6 +207,23 @@ def __init__(
self.optimizer_group[_default_adapter_name].adapter_name = _default_adapter_name
self.active_group = _default_adapter_name
+ def _should_init_empty_pretrained_model_on_this_rank(self) -> bool:
+ use_rank0_broadcast = getattr(self.strategy, 'use_rank0_pretrained_broadcast', lambda: False)
+ return bool(use_rank0_broadcast() and dist.is_available() and dist.is_initialized() and dist.get_rank() != 0)
+
+ def _init_empty_model_from_config(self, model_cls, **kwargs):
+ from accelerate import init_empty_weights
+
+ config_kwargs = filter_from_config_kwargs(kwargs)
+ with init_empty_weights(include_buffers=False):
+ if hasattr(model_cls, 'from_config'):
+ model = model_cls.from_config(self.hf_config, **config_kwargs)
+ else:
+ model = model_cls._from_config(self.hf_config, **config_kwargs)
+ if hasattr(model, 'tie_weights'):
+ model.tie_weights()
+ return model
+
def _decide_strategy(self, strategy: Literal['accelerate', 'native_fsdp']):
self._expert_parallel_config = self._fsdp_config.pop('expert_parallel', None)
self._enable_expert_parallel = self._should_enable_expert_parallel(self._expert_parallel_config,
diff --git a/src/twinkle/template/deepseek_v4.py b/src/twinkle/template/deepseek_v4.py
index 97dcf8e4..e9d55fac 100644
--- a/src/twinkle/template/deepseek_v4.py
+++ b/src/twinkle/template/deepseek_v4.py
@@ -3,9 +3,8 @@
import re
from typing import Any, Dict, List, Optional, Tuple, Union
-from .utils import Function
-from .utils import Prompt
from .base import BaseAgentTemplate
+from .utils import Function, Prompt
DSML_TOKEN = '|DSML|'
diff --git a/src/twinkle/template/utils.py b/src/twinkle/template/utils.py
index 0da0417d..597c8517 100644
--- a/src/twinkle/template/utils.py
+++ b/src/twinkle/template/utils.py
@@ -3,12 +3,12 @@
import json
import re
from copy import copy, deepcopy
+from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union
from twinkle.data_format import Message, Trajectory
from twinkle.utils import to_device
-from dataclasses import dataclass
if TYPE_CHECKING:
from transformers import PreTrainedTokenizer
diff --git a/src/twinkle/utils/transformers_utils.py b/src/twinkle/utils/transformers_utils.py
index 036f7538..3d4ba362 100644
--- a/src/twinkle/utils/transformers_utils.py
+++ b/src/twinkle/utils/transformers_utils.py
@@ -1,6 +1,6 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
import re
-from typing import TYPE_CHECKING, Callable, List, Optional
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
from .utils import deep_getattr
@@ -8,6 +8,33 @@
import torch.nn as nn
+def filter_from_config_kwargs(kwargs: Dict[str, Any]) -> Dict[str, Any]:
+ load_only_keys = {
+ 'cache_dir',
+ 'device_map',
+ 'force_download',
+ 'ignore_mismatched_sizes',
+ 'local_files_only',
+ 'low_cpu_mem_usage',
+ 'max_memory',
+ 'offload_buffers',
+ 'offload_folder',
+ 'offload_state_dict',
+ 'output_loading_info',
+ 'proxies',
+ 'resume_download',
+ 'revision',
+ 'state_dict',
+ 'subfolder',
+ 'token',
+ 'tokenizer_id',
+ 'trust_remote_code',
+ 'use_safetensors',
+ 'weights_only',
+ }
+ return {key: value for key, value in kwargs.items() if key not in load_only_keys}
+
+
def find_layers(
model: 'nn.Module',
cond: Callable[[str, 'nn.Module'], bool],
From a8dbf74eeb1b6e128d87e5406e4e742478b5efe3 Mon Sep 17 00:00:00 2001
From: meichangsu1 <1484603386@qq.com>
Date: Tue, 12 May 2026 20:46:47 +0800
Subject: [PATCH 04/11] fix(template): rename DeepSeekV4AgentTemplate to
DeepseekV4Template and update references
- Rename class and module exports for consistency with naming conventions
- Update default TEMPLATE_ID in deepseek_v4_flash.py to use new template name
- Refactor encoding/decoding logic for chat message processing with improved tool call handling
---
cookbook/transformers/deepseek_v4_flash.py | 2 +-
src/twinkle/template/__init__.py | 2 +-
src/twinkle/template/deepseek_v4.py | 263 ++++-----
src/twinkle/template/deepseek_v4_encoding.py | 585 +++++++++++++++++++
4 files changed, 716 insertions(+), 136 deletions(-)
create mode 100644 src/twinkle/template/deepseek_v4_encoding.py
diff --git a/cookbook/transformers/deepseek_v4_flash.py b/cookbook/transformers/deepseek_v4_flash.py
index a7a3d73a..e426c2a0 100644
--- a/cookbook/transformers/deepseek_v4_flash.py
+++ b/cookbook/transformers/deepseek_v4_flash.py
@@ -13,7 +13,7 @@
MODEL_ID = os.environ.get('MODEL_ID', 'ms://deepseek-ai/DeepSeek-V4-flash-bfa16')
DATASET_ID = os.environ.get('DATASET_ID', 'ms://swift/self-cognition')
-TEMPLATE_ID = os.environ.get('TEMPLATE_ID', 'Template')
+TEMPLATE_ID = os.environ.get('TEMPLATE_ID', 'DeepseekV4Template')
OUTPUT_DIR = os.environ.get('OUTPUT_DIR', './output')
NUM_LAYERS = int(os.environ.get('NUM_LAYERS', '4'))
diff --git a/src/twinkle/template/__init__.py b/src/twinkle/template/__init__.py
index 4b8c27e6..4386fc1e 100644
--- a/src/twinkle/template/__init__.py
+++ b/src/twinkle/template/__init__.py
@@ -1,4 +1,4 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
from .base import BaseAgentTemplate, Template
-from .deepseek_v4 import DeepSeekV4AgentTemplate
+from .deepseek_v4 import DeepseekV4Template
from .qwen3_5_vl import Qwen3_5Template
diff --git a/src/twinkle/template/deepseek_v4.py b/src/twinkle/template/deepseek_v4.py
index e9d55fac..f488d5a6 100644
--- a/src/twinkle/template/deepseek_v4.py
+++ b/src/twinkle/template/deepseek_v4.py
@@ -1,137 +1,132 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
-import json
-import re
-from typing import Any, Dict, List, Optional, Tuple, Union
-
-from .base import BaseAgentTemplate
-from .utils import Function, Prompt
-
-DSML_TOKEN = '|DSML|'
-
-TOOLS_TEMPLATE = """## Tools
-
-You have access to a set of tools to help answer the user's question. \
-You can invoke tools by writing a "<{dsml_token}tool_calls>" block like the following:
-
-<{dsml_token}tool_calls>
-<{dsml_token}invoke name="$TOOL_NAME">
-<{dsml_token}parameter name="$PARAMETER_NAME" string="true|false">$PARAMETER_VALUE{dsml_token}parameter>
-...
-{dsml_token}invoke>
-<{dsml_token}invoke name="$TOOL_NAME2">
-...
-{dsml_token}invoke>
-{dsml_token}tool_calls>
-
-String parameters should be specified as is and set `string="true"`. \
-For all other types (numbers, booleans, arrays, objects), \
-pass the value in JSON format and set `string="false"`.
-
-If thinking_mode is enabled (triggered by ), \
-you MUST output your complete reasoning inside ... BEFORE any tool calls or final response.
-
-Otherwise, output directly after with tool calls or final response.
-
-### Available Tool Schemas
-
-{tool_schemas}
-
-You MUST strictly follow the above defined tool name and parameter schemas to invoke tool calls.
-"""
-
-
-def _to_json(value: Any) -> str:
- try:
- return json.dumps(value, ensure_ascii=False)
- except Exception:
- return json.dumps(value, ensure_ascii=True)
-
-
-def _encode_arguments_to_dsml(arguments: Dict[str, Any]) -> str:
- """Encode tool call arguments dict into DSML parameter lines."""
- lines = []
- for k, v in arguments.items():
- is_str = 'true' if isinstance(v, str) else 'false'
- val = v if isinstance(v, str) else _to_json(v)
- lines.append(f'<{DSML_TOKEN}parameter name="{k}" string="{is_str}">{val}{DSML_TOKEN}parameter>')
- return '\n'.join(lines)
-
-
-class DeepSeekV4AgentTemplate(BaseAgentTemplate):
-
- def get_toolcall(self, response: str) -> List[Function]:
- # Parse DSML tool calls from model output
- # Pattern: <|DSML|invoke name="tool_name">...params...|DSML|invoke>
- invoke_pattern = re.compile(
- rf'<{re.escape(DSML_TOKEN)}invoke\s+name="([^"]+)">\s*(.*?)\s*{re.escape(DSML_TOKEN)}invoke>', re.DOTALL)
- param_pattern = re.compile(
- rf'<{re.escape(DSML_TOKEN)}parameter\s+name="([^"]+)"\s+string="(true|false)">'
- rf'(.*?){re.escape(DSML_TOKEN)}parameter>', re.DOTALL)
-
- functions = []
- for match in invoke_pattern.finditer(response):
- tool_name = match.group(1)
- params_block = match.group(2)
- arguments = {}
- for pm in param_pattern.finditer(params_block):
- param_name = pm.group(1)
- is_string = pm.group(2)
- param_value = pm.group(3)
- if is_string == 'false':
- try:
- param_value = json.loads(param_value)
- except json.JSONDecodeError:
- pass
- arguments[param_name] = param_value
- functions.append(Function(name=tool_name, arguments=json.dumps(arguments, ensure_ascii=False)))
-
- if len(functions) == 0:
- # Fallback to ReAct format
- return super().get_toolcall(response)
- return functions
-
- def _get_tool_responses(self, tool_messages):
- return ''.join(f'{tool_message["content"]}' for tool_message in tool_messages)
-
- def _format_tool_responses(
+import copy
+import torch
+from transformers import AutoConfig, PreTrainedTokenizerFast
+from typing import Any, Dict, List, Literal, Optional
+
+from twinkle.hub import HubOperation
+from .base import Template
+from .deepseek_v4_encoding import encode_messages
+
+
+def get_deepseek_v4_tokenizer(tokenizer):
+ """Wrap a HF tokenizer with DeepSeek V4's custom chat-template encoder."""
+ dsv4_tokenizer = copy.copy(tokenizer)
+
+ added_vocab = tokenizer.get_added_vocab()
+ added_vocab_size = len(added_vocab)
+ tokenizer_vocab_size = tokenizer.vocab_size
+
+ class _DeepseekV4Tokenizer(tokenizer.__class__): # type: ignore[misc, valid-type]
+
+ def apply_chat_template(
+ self,
+ messages,
+ tools: Optional[List[Dict[str, Any]]] = None,
+ **kwargs,
+ ):
+ thinking = kwargs.get('thinking', False)
+ enable_thinking = kwargs.get('enable_thinking', False)
+ thinking = thinking or enable_thinking
+ thinking_mode = 'thinking' if thinking else 'chat'
+
+ conversation = kwargs.get('conversation', messages)
+ messages = conversation.copy()
+ if tools:
+ messages.insert(0, {'role': 'system'})
+ messages[0]['tools'] = tools
+
+ reasoning_effort = kwargs.get('reasoning_effort')
+ if reasoning_effort not in ('max', 'high'):
+ reasoning_effort = None
+
+ prompt_str = encode_messages(
+ messages,
+ thinking_mode=thinking_mode,
+ drop_thinking=kwargs.get('drop_thinking', True),
+ reasoning_effort=reasoning_effort,
+ )
+
+ tokenize = kwargs.get('tokenize', True)
+ return_dict = kwargs.get('return_dict', False)
+ return_tensors = kwargs.get('return_tensors')
+
+ if not tokenize:
+ return {'prompt': prompt_str} if return_dict else prompt_str
+
+ tokenizer_kwargs = {key: kwargs[key] for key in ('truncation', 'max_length') if key in kwargs}
+ input_ids = self.encode(
+ prompt_str,
+ add_special_tokens=False,
+ **tokenizer_kwargs,
+ )
+
+ if not return_dict and return_tensors is None:
+ return input_ids
+
+ attention_mask = [1] * len(input_ids)
+ if return_tensors == 'pt':
+ input_ids = torch.tensor([input_ids], dtype=torch.long)
+ attention_mask = torch.tensor([attention_mask], dtype=torch.long)
+ elif return_tensors is not None:
+ raise ValueError(f'Unsupported return_tensors: {return_tensors}')
+
+ encoded = {
+ 'input_ids': input_ids,
+ 'attention_mask': attention_mask,
+ }
+ if kwargs.get('return_assistant_tokens_mask', False):
+ # Fall back to round-by-round labeling in Template by omitting
+ # assistant_masks support from this custom tokenizer wrapper.
+ pass
+ if return_dict:
+ return encoded
+ return encoded['input_ids']
+
+ def num_special_tokens_to_add(self) -> int:
+ return len(self.encode(''))
+
+ def __len__(self) -> int:
+ return tokenizer_vocab_size + added_vocab_size
+
+ def get_added_vocab(self) -> dict[str, int]:
+ return added_vocab.copy()
+
+ _DeepseekV4Tokenizer.__name__ = f'DSV4{tokenizer.__class__.__name__}'
+ dsv4_tokenizer.__class__ = _DeepseekV4Tokenizer
+ return dsv4_tokenizer
+
+
+class DeepseekV4Template(Template):
+
+ def __init__(
self,
- assistant_content: str,
- tool_messages,
- ) -> Tuple[str, 'Prompt']:
- with_action = self.keyword.action in assistant_content and self.keyword.action_input in assistant_content
- if with_action:
- return super()._format_tool_responses(assistant_content, tool_messages)
- res = [
- '<|end▁of▁sentence|><|User|>',
- self._get_tool_responses(tool_messages),
- '<|Assistant|>',
+ model_id: str,
+ use_chat_template: bool = True,
+ max_length: Optional[int] = 8192,
+ truncation_strategy: Literal['raise', 'left', 'right', 'split'] = 'raise',
+ default_system: Optional[str] = None,
+ enable_thinking: bool = True,
+ **kwargs,
+ ):
+ model_id = HubOperation.download_model(model_id, ignore_model=True)
+ base_tokenizer = PreTrainedTokenizerFast.from_pretrained(model_id, **kwargs)
+ self.processor = get_deepseek_v4_tokenizer(base_tokenizer)
+ self.config = AutoConfig.from_pretrained(model_id, **kwargs)
+
+ self.use_chat_template = use_chat_template
+ self.max_length = max_length
+ self.enable_thinking = enable_thinking
+ self.truncation_strategy = truncation_strategy
+ self.default_system = default_system
+ self._test_support_assistant_tokens_mask()
+ self.pre_pipeline = [
+ self._add_default_system,
+ self._to_standard_reasoning_content,
+ self._build_standard_messages,
+ ]
+ self.post_pipeline = [
+ self._check_max_length,
+ self._add_attention_fields,
+ self._roll_labels,
]
- return assistant_content, res
-
- def _format_tools(self, tools: List[Union[str, dict]], system: Optional[str] = None, user_message=None) -> str:
- tool_schemas = []
- for tool in tools:
- tool = self.unwrap_tool(tool)
- tool_schemas.append(_to_json(tool))
-
- tools_section = TOOLS_TEMPLATE.format(
- tool_schemas='\n'.join(tool_schemas),
- dsml_token=DSML_TOKEN,
- )
-
- system = system or ''
- return f'{system}\n\n{tools_section}' if system else tools_section
-
- def _format_tool_calls(self, tool_call_messages) -> str:
- invocations = []
- for message in tool_call_messages:
- tool_call = self._parse_tool_call(message['content'])
- name = tool_call['name']
- arguments = tool_call['arguments']
- if isinstance(arguments, str):
- arguments = json.loads(arguments)
- dsml_args = _encode_arguments_to_dsml(arguments)
- invocations.append(f'<{DSML_TOKEN}invoke name="{name}">\n{dsml_args}\n{DSML_TOKEN}invoke>')
-
- tool_calls_str = '\n'.join(invocations)
- return f'<{DSML_TOKEN}tool_calls>\n{tool_calls_str}\n{DSML_TOKEN}tool_calls>'
diff --git a/src/twinkle/template/deepseek_v4_encoding.py b/src/twinkle/template/deepseek_v4_encoding.py
new file mode 100644
index 00000000..f10088ce
--- /dev/null
+++ b/src/twinkle/template/deepseek_v4_encoding.py
@@ -0,0 +1,585 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+# ruff: noqa
+# fmt: off
+
+"""
+DeepSeek-V4 Encoding
+
+A self-contained implementation for encoding/decoding DeepSeek-V4 chat messages
+with tool calling, thinking mode, and quick instruction task support.
+"""
+
+import copy
+import json
+import regex as re
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+bos_token: str = '<|begin▁of▁sentence|>'
+eos_token: str = '<|end▁of▁sentence|>'
+thinking_start_token: str = ''
+thinking_end_token: str = ''
+dsml_token: str = '|DSML|'
+
+USER_SP_TOKEN = '<|User|>'
+ASSISTANT_SP_TOKEN = '<|Assistant|>'
+LATEST_REMINDER_SP_TOKEN = '<|latest_reminder|>'
+
+DS_TASK_SP_TOKENS = {
+ 'action': '<|action|>',
+ 'query': '<|query|>',
+ 'authority': '<|authority|>',
+ 'domain': '<|domain|>',
+ 'title': '<|title|>',
+ 'read_url': '<|read_url|>',
+}
+VALID_TASKS = set(DS_TASK_SP_TOKENS.keys())
+
+system_msg_template: str = '{content}'
+user_msg_template: str = '{content}'
+latest_reminder_msg_template: str = '{content}'
+assistant_msg_template: str = '{reasoning}{content}{tool_calls}' + eos_token
+assistant_msg_wo_eos_template: str = '{reasoning}{content}{tool_calls}'
+thinking_template: str = '{reasoning}'
+
+response_format_template: str = (
+ '## Response Format:\n\nYou MUST strictly adhere to the following schema to reply:\n{schema}'
+)
+tool_call_template: str = (
+ "<{dsml_token}invoke name=\"{name}\">\n{arguments}\n{dsml_token}invoke>"
+)
+tool_calls_template = (
+ '<{dsml_token}{tc_block_name}>\n{tool_calls}\n{dsml_token}{tc_block_name}>'
+)
+tool_calls_block_name: str = 'tool_calls'
+
+tool_output_template: str = (
+ '{content}'
+)
+
+REASONING_EFFORT_MAX = (
+ 'Reasoning Effort: Absolute maximum with no shortcuts permitted.\n'
+ 'You MUST be very thorough in your thinking and comprehensively decompose the problem to resolve the root '
+ 'cause, rigorously stress-testing your logic against all potential paths, edge cases, and adversarial scenarios.\n'
+ 'Explicitly write out your entire deliberation process, documenting every intermediate step, considered '
+ 'alternative, and rejected hypothesis to ensure absolutely no assumption is left unchecked.\n\n'
+)
+
+TOOLS_TEMPLATE = """## Tools
+
+You have access to a set of tools to help answer the user's question. You can invoke tools by writing a
+"<{dsml_token}tool_calls>" block like the following:
+
+<{dsml_token}tool_calls>
+<{dsml_token}invoke name="$TOOL_NAME">
+<{dsml_token}parameter name="$PARAMETER_NAME" string="true|false">$PARAMETER_VALUE{dsml_token}parameter>
+...
+{dsml_token}invoke>
+<{dsml_token}invoke name="$TOOL_NAME2">
+...
+{dsml_token}invoke>
+{dsml_token}tool_calls>
+
+String parameters should be specified as is and set `string="true"`. For all other types (numbers, booleans, arrays,
+objects), pass the value in JSON format and set `string="false"`.
+
+If thinking_mode is enabled (triggered by {thinking_start_token}), you MUST output your complete reasoning inside
+{thinking_start_token}...{thinking_end_token} BEFORE any tool calls or final response.
+
+Otherwise, output directly after {thinking_end_token} with tool calls or final response.
+
+### Available Tool Schemas
+
+{tool_schemas}
+
+You MUST strictly follow the above defined tool name and parameter schemas to invoke tool calls.
+"""
+
+
+def to_json(value: Any) -> str:
+ try:
+ return json.dumps(value, ensure_ascii=False)
+ except Exception:
+ return json.dumps(value, ensure_ascii=True)
+
+
+def tools_from_openai_format(tools):
+ return [tool['function'] for tool in tools]
+
+
+def tool_calls_from_openai_format(tool_calls):
+ return [
+ {
+ 'name': tool_call['function']['name'],
+ 'arguments': tool_call['function']['arguments'],
+ }
+ for tool_call in tool_calls
+ ]
+
+
+def tool_calls_to_openai_format(tool_calls):
+ return [
+ {
+ 'type': 'function',
+ 'function': {
+ 'name': tool_call['name'],
+ 'arguments': tool_call['arguments'],
+ }
+ }
+ for tool_call in tool_calls
+ ]
+
+
+def encode_arguments_to_dsml(tool_call: Dict[str, Any]) -> str:
+ p_dsml_template = '<{dsml_token}parameter name="{key}" string="{is_str}">{value}{dsml_token}parameter>'
+ p_dsml_strs = []
+
+ if isinstance(tool_call['arguments'], str):
+ arguments = json.loads(tool_call['arguments'])
+ else:
+ arguments = tool_call['arguments']
+
+ for k, v in arguments.items():
+ p_dsml_str = p_dsml_template.format(
+ dsml_token=dsml_token,
+ key=k,
+ is_str='true' if isinstance(v, str) else 'false',
+ value=v if isinstance(v, str) else to_json(v),
+ )
+ p_dsml_strs.append(p_dsml_str)
+
+ return '\n'.join(p_dsml_strs)
+
+
+def decode_dsml_to_arguments(tool_name: str, tool_args: Dict[str, Tuple[str, str]]) -> Dict[str, str]:
+ def _decode_value(key: str, value: str, string: str):
+ if string == 'true':
+ value = to_json(value)
+ return f'{to_json(key)}: {value}'
+
+ tool_args_json = '{' + ', '.join([_decode_value(k, v, string=is_str) for k, (v, is_str) in tool_args.items()]) + '}'
+ return dict(name=tool_name, arguments=tool_args_json)
+
+
+def render_tools(tools: List[Dict[str, Union[str, Dict[str, Any]]]]) -> str:
+ tools_json = [to_json(t) for t in tools]
+
+ return TOOLS_TEMPLATE.format(
+ tool_schemas='\n'.join(tools_json),
+ dsml_token=dsml_token,
+ thinking_start_token=thinking_start_token,
+ thinking_end_token=thinking_end_token,
+ )
+
+
+def find_last_user_index(messages: List[Dict[str, Any]]) -> int:
+ last_user_index = -1
+ for idx in range(len(messages) - 1, -1, -1):
+ if messages[idx].get('role') in ['user', 'developer']:
+ last_user_index = idx
+ break
+ return last_user_index
+
+
+def render_message(
+ index: int,
+ messages: List[Dict[str, Any]],
+ thinking_mode: str,
+ drop_thinking: bool = True,
+ reasoning_effort: Optional[str] = None,
+) -> str:
+ assert 0 <= index < len(messages)
+ assert thinking_mode in ['chat', 'thinking'], f'Invalid thinking_mode `{thinking_mode}`'
+
+ prompt = ''
+ msg = messages[index]
+ last_user_idx = find_last_user_index(messages)
+
+ role = msg.get('role')
+ content = msg.get('content')
+ tools = msg.get('tools')
+ response_format = msg.get('response_format')
+ tool_calls = msg.get('tool_calls')
+ reasoning = msg.get('reasoning')
+ wo_eos = msg.get('wo_eos', False)
+
+ if tools:
+ tools = tools_from_openai_format(tools)
+ if tool_calls:
+ tool_calls = tool_calls_from_openai_format(tool_calls)
+
+ assert reasoning_effort in ['max', None, 'high'], f'Invalid reasoning effort: {reasoning_effort}'
+ if index == 0 and thinking_mode == 'thinking' and reasoning_effort == 'max':
+ prompt += REASONING_EFFORT_MAX
+
+ if role == 'system':
+ prompt += system_msg_template.format(content=content or '')
+ if tools:
+ prompt += '\n\n' + render_tools(tools)
+ if response_format:
+ prompt += '\n\n' + response_format_template.format(schema=to_json(response_format))
+
+ elif role == 'developer':
+ assert content, f'Invalid message for role `{role}`: {msg}'
+
+ content_developer = USER_SP_TOKEN
+ content_developer += content
+
+ if tools:
+ content_developer += '\n\n' + render_tools(tools)
+ if response_format:
+ content_developer += '\n\n' + response_format_template.format(schema=to_json(response_format))
+
+ prompt += user_msg_template.format(content=content_developer)
+
+ elif role == 'user':
+ prompt += USER_SP_TOKEN
+
+ content_blocks = msg.get('content_blocks')
+ if content_blocks:
+ parts = []
+ for block in content_blocks:
+ block_type = block.get('type')
+ if block_type == 'text':
+ parts.append(block.get('text', ''))
+ elif block_type == 'tool_result':
+ tool_content = block.get('content', '')
+ if isinstance(tool_content, list):
+ text_parts = []
+ for b in tool_content:
+ if b.get('type') == 'text':
+ text_parts.append(b.get('text', ''))
+ else:
+ text_parts.append(f"[Unsupported {b.get('type')}]")
+ tool_content = '\n\n'.join(text_parts)
+ parts.append(tool_output_template.format(content=tool_content))
+ else:
+ parts.append(f'[Unsupported {block_type}]')
+ prompt += '\n\n'.join(parts)
+ else:
+ prompt += content or ''
+
+ elif role == 'latest_reminder':
+ prompt += LATEST_REMINDER_SP_TOKEN + latest_reminder_msg_template.format(content=content)
+
+ elif role == 'tool':
+ raise NotImplementedError(
+ 'deepseek_v4 merges tool messages into user; please preprocess with merge_tool_messages()')
+
+ elif role == 'assistant':
+ thinking_part = ''
+ tc_content = ''
+
+ if tool_calls:
+ tc_list = [
+ tool_call_template.format(
+ dsml_token=dsml_token,
+ name=tc.get('name'),
+ arguments=encode_arguments_to_dsml(tc)
+ )
+ for tc in tool_calls
+ ]
+ tc_content += '\n\n' + tool_calls_template.format(
+ dsml_token=dsml_token,
+ tool_calls='\n'.join(tc_list),
+ tc_block_name=tool_calls_block_name,
+ )
+
+ summary_content = content or ''
+ reasoning = reasoning or ''
+
+ prev_has_task = index - 1 >= 0 and messages[index - 1].get('task') is not None
+
+ if thinking_mode == 'thinking' and not prev_has_task:
+ if not drop_thinking or index > last_user_idx:
+ thinking_part = thinking_template.format(reasoning=reasoning) + thinking_end_token
+ else:
+ thinking_part = ''
+
+ if wo_eos:
+ prompt += assistant_msg_wo_eos_template.format(
+ reasoning=thinking_part,
+ content=summary_content,
+ tool_calls=tc_content,
+ )
+ else:
+ prompt += assistant_msg_template.format(
+ reasoning=thinking_part,
+ content=summary_content,
+ tool_calls=tc_content,
+ )
+ else:
+ raise NotImplementedError(f'Unknown role: {role}')
+
+ if index + 1 < len(messages) and messages[index + 1].get('role') not in ['assistant', 'latest_reminder']:
+ return prompt
+
+ task = messages[index].get('task')
+ if task is not None:
+ assert task in VALID_TASKS, f"Invalid task: '{task}'. Valid tasks are: {list(VALID_TASKS)}"
+ task_sp_token = DS_TASK_SP_TOKENS[task]
+
+ if task != 'action':
+ prompt += task_sp_token
+ else:
+ prompt += ASSISTANT_SP_TOKEN
+ prompt += thinking_end_token if thinking_mode != 'thinking' else thinking_start_token
+ prompt += task_sp_token
+
+ elif messages[index].get('role') in ['user', 'developer']:
+ prompt += ASSISTANT_SP_TOKEN
+ if not drop_thinking and thinking_mode == 'thinking':
+ prompt += thinking_start_token
+ elif drop_thinking and thinking_mode == 'thinking' and index >= last_user_idx:
+ prompt += thinking_start_token
+ else:
+ prompt += thinking_end_token
+
+ return prompt
+
+
+def merge_tool_messages(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+ merged: List[Dict[str, Any]] = []
+
+ for msg in messages:
+ msg = copy.deepcopy(msg)
+ role = msg.get('role')
+
+ if role == 'tool':
+ tool_block = {
+ 'type': 'tool_result',
+ 'tool_use_id': msg.get('tool_call_id', ''),
+ 'content': msg.get('content', ''),
+ }
+ if merged and merged[-1].get('role') == 'user' and 'content_blocks' in merged[-1]:
+ merged[-1]['content_blocks'].append(tool_block)
+ else:
+ merged.append({
+ 'role': 'user',
+ 'content_blocks': [tool_block],
+ })
+ elif role == 'user':
+ text_block = {'type': 'text', 'text': msg.get('content', '')}
+ if (merged and merged[-1].get('role') == 'user' and 'content_blocks' in merged[-1]
+ and merged[-1].get('task') is None):
+ merged[-1]['content_blocks'].append(text_block)
+ else:
+ new_msg = {
+ 'role': 'user',
+ 'content': msg.get('content', ''),
+ 'content_blocks': [text_block],
+ }
+ for key in ('task', 'wo_eos', 'mask'):
+ if key in msg:
+ new_msg[key] = msg[key]
+ merged.append(new_msg)
+ else:
+ merged.append(msg)
+
+ return merged
+
+
+def sort_tool_results_by_call_order(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+ last_tool_call_order: Dict[str, int] = {}
+
+ for msg in messages:
+ role = msg.get('role')
+ if role == 'assistant' and msg.get('tool_calls'):
+ last_tool_call_order = {}
+ for idx, tc in enumerate(msg['tool_calls']):
+ tc_id = tc.get('id') or tc.get('function', {}).get('id', '')
+ if tc_id:
+ last_tool_call_order[tc_id] = idx
+
+ elif role == 'user' and msg.get('content_blocks'):
+ tool_blocks = [b for b in msg['content_blocks'] if b.get('type') == 'tool_result']
+ if len(tool_blocks) > 1 and last_tool_call_order:
+ sorted_blocks = sorted(
+ tool_blocks,
+ key=lambda b: last_tool_call_order.get(b.get('tool_use_id', ''), 0)
+ )
+ sorted_idx = 0
+ new_blocks = []
+ for block in msg['content_blocks']:
+ if block.get('type') == 'tool_result':
+ new_blocks.append(sorted_blocks[sorted_idx])
+ sorted_idx += 1
+ else:
+ new_blocks.append(block)
+ msg['content_blocks'] = new_blocks
+
+ return messages
+
+
+def encode_messages(
+ messages: List[Dict[str, Any]],
+ thinking_mode: str,
+ context: Optional[List[Dict[str, Any]]] = None,
+ drop_thinking: bool = True,
+ add_default_bos_token: bool = True,
+ reasoning_effort: Optional[str] = None,
+) -> str:
+ context = context if context else []
+
+ messages = merge_tool_messages(messages)
+ messages = sort_tool_results_by_call_order(context + messages)[len(context):]
+ if context:
+ context = merge_tool_messages(context)
+ context = sort_tool_results_by_call_order(context)
+
+ full_messages = context + messages
+
+ prompt = bos_token if add_default_bos_token and len(context) == 0 else ''
+
+ effective_drop_thinking = drop_thinking
+ if any(m.get('tools') for m in full_messages):
+ effective_drop_thinking = False
+
+ if thinking_mode == 'thinking' and effective_drop_thinking:
+ full_messages = _drop_thinking_messages(full_messages)
+ num_to_render = len(full_messages) - len(_drop_thinking_messages(context))
+ context_len = len(full_messages) - num_to_render
+ else:
+ num_to_render = len(messages)
+ context_len = len(context)
+
+ for idx in range(num_to_render):
+ prompt += render_message(
+ idx + context_len,
+ full_messages,
+ thinking_mode=thinking_mode,
+ drop_thinking=effective_drop_thinking,
+ reasoning_effort=reasoning_effort,
+ )
+
+ return prompt
+
+
+def _drop_thinking_messages(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+ last_user_idx = find_last_user_index(messages)
+ result = []
+ keep_roles = {'user', 'system', 'tool', 'latest_reminder', 'direct_search_results'}
+
+ for idx, msg in enumerate(messages):
+ role = msg.get('role')
+ if role in keep_roles or idx >= last_user_idx:
+ result.append(msg)
+ elif role == 'assistant':
+ msg = copy.copy(msg)
+ msg.pop('reasoning', None)
+ result.append(msg)
+
+ return result
+
+
+def _read_until_stop(index: int, text: str, stop: List[str]) -> Tuple[int, str, Optional[str]]:
+ min_pos = len(text)
+ matched_stop = None
+
+ for s in stop:
+ pos = text.find(s, index)
+ if pos != -1 and pos < min_pos:
+ min_pos = pos
+ matched_stop = s
+
+ if matched_stop:
+ content = text[index:min_pos]
+ return min_pos + len(matched_stop), content, matched_stop
+ else:
+ content = text[index:]
+ return len(text), content, None
+
+
+def parse_tool_calls(index: int, text: str) -> Tuple[int, Optional[str], List[Dict[str, str]]]:
+ tool_calls: List[Dict[str, Any]] = []
+ stop_token = None
+ tool_calls_end_token = f'{dsml_token}{tool_calls_block_name}>'
+
+ while index < len(text):
+ index, content_before, stop_token = _read_until_stop(
+ index, text, [f'<{dsml_token}invoke', tool_calls_end_token])
+ if content_before != '>\n':
+ raise ValueError(f"Tool call format error: expected '>\\n' but got '{content_before}'")
+
+ if stop_token == tool_calls_end_token:
+ break
+
+ if stop_token is None:
+ raise ValueError('Missing special token in tool calls')
+
+ index, tool_name_content, stop_token = _read_until_stop(
+ index, text, [f'<{dsml_token}parameter', f'{dsml_token}invoke'])
+
+ p_tool_name = re.findall(r'^\s*name="(.*?)">\n$', tool_name_content, flags=re.DOTALL)
+ if len(p_tool_name) != 1:
+ raise ValueError(f"Tool name format error: '{tool_name_content}'")
+ tool_name = p_tool_name[0]
+
+ tool_args: Dict[str, Tuple[str, str]] = {}
+ while stop_token == f'<{dsml_token}parameter':
+ index, param_content, stop_token = _read_until_stop(index, text, [f'/{dsml_token}parameter'])
+
+ param_kv = re.findall(r'^ name="(.*?)" string="(true|false)">(.*?)<$', param_content, flags=re.DOTALL)
+ if len(param_kv) != 1:
+ raise ValueError(f"Parameter format error: '{param_content}'")
+ param_name, string, param_value = param_kv[0]
+
+ if param_name in tool_args:
+ raise ValueError(f"Duplicate parameter name: '{param_name}'")
+ tool_args[param_name] = (param_value, string)
+
+ index, content, stop_token = _read_until_stop(
+ index, text, [f'<{dsml_token}parameter', f'{dsml_token}invoke'])
+ if content != '>\n':
+ raise ValueError(f"Parameter format error: expected '>\\n' but got '{content}'")
+
+ tool_call = decode_dsml_to_arguments(tool_name=tool_name, tool_args=tool_args)
+ tool_calls.append(tool_call)
+
+ return index, stop_token, tool_calls
+
+
+def parse_message_from_completion_text(text: str, thinking_mode: str) -> Dict[str, Any]:
+ summary_content, reasoning = '', ''
+ tool_calls: List[Dict[str, str]] = []
+ index, stop_token = 0, None
+ tool_calls_start_token = f'\n\n<{dsml_token}{tool_calls_block_name}'
+
+ is_thinking = thinking_mode == 'thinking'
+ is_tool_calling = False
+
+ if is_thinking:
+ index, content_delta, stop_token = _read_until_stop(index, text, [thinking_end_token, tool_calls_start_token])
+ reasoning = content_delta
+ if stop_token != thinking_end_token:
+ raise ValueError('Invalid thinking format: missing ')
+
+ index, content_delta, stop_token = _read_until_stop(index, text, [eos_token, tool_calls_start_token])
+ summary_content = content_delta
+ if stop_token == tool_calls_start_token:
+ is_tool_calling = True
+ else:
+ if stop_token != eos_token:
+ raise ValueError('Invalid format: missing EOS token')
+
+ if is_tool_calling:
+ index, stop_token, tool_calls = parse_tool_calls(index, text)
+
+ index, tool_ends_text, stop_token = _read_until_stop(index, text, [eos_token])
+ if tool_ends_text:
+ raise ValueError('Unexpected content after tool calls')
+
+ if len(text) != index or stop_token not in [eos_token, None]:
+ raise ValueError('Unexpected content at end')
+
+ for sp_token in [bos_token, eos_token, thinking_start_token, thinking_end_token, dsml_token]:
+ if sp_token in summary_content or sp_token in reasoning:
+ raise ValueError(f"Unexpected special token '{sp_token}' in content")
+
+ return {
+ 'role': 'assistant',
+ 'content': summary_content,
+ 'reasoning': reasoning,
+ 'tool_calls': tool_calls_to_openai_format(tool_calls)
+ }
+
+# fmt: on
From 09879aec12ad786faf66db7c8284a181d709a697 Mon Sep 17 00:00:00 2001
From: meichangsu1 <1484603386@qq.com>
Date: Tue, 12 May 2026 21:56:58 +0800
Subject: [PATCH 05/11] fix: update LoRA target modules and NPU dtype alignment
in deepseek_v4 example
- Change LORA_TARGET_MODULES from 'all-linear' to specific module list for DeepSeek V4
- Remove gradient_accumulation_steps parameter from forward_backward call
- Fix NPU device dtype alignment to convert all parameters to base dtype when on NPU
---
cookbook/transformers/deepseek_v4_flash.py | 11 +++++++++--
src/twinkle/model/transformers/transformers.py | 10 +++++++---
2 files changed, 16 insertions(+), 5 deletions(-)
diff --git a/cookbook/transformers/deepseek_v4_flash.py b/cookbook/transformers/deepseek_v4_flash.py
index e426c2a0..69ee095a 100644
--- a/cookbook/transformers/deepseek_v4_flash.py
+++ b/cookbook/transformers/deepseek_v4_flash.py
@@ -26,7 +26,15 @@
RESHARD_AFTER_FORWARD = os.environ.get('RESHARD_AFTER_FORWARD', '1') == '1'
GRADIENT_CHECKPOINTING = True
IGNORE_MISMATCHED_SIZES = False
-LORA_TARGET_MODULES = 'all-linear'
+LORA_TARGET_MODULES = [
+ 'wq_a',
+ 'wq_b',
+ 'wkv',
+ 'wgate',
+ 'gate_proj',
+ 'up_proj',
+ 'down_proj',
+]
ADAPTER_NAME = 'default'
device_mesh = DeviceMesh.from_sizes(
@@ -113,7 +121,6 @@ def train():
model.forward_backward(
inputs=batch,
adapter_name=ADAPTER_NAME,
- gradient_accumulation_steps=GRAD_ACCUM_STEPS,
)
model.clip_grad_and_step(
adapter_name=ADAPTER_NAME,
diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py
index 9f80817d..d6d1939b 100644
--- a/src/twinkle/model/transformers/transformers.py
+++ b/src/twinkle/model/transformers/transformers.py
@@ -1054,17 +1054,21 @@ def _load_optimizer(self, checkpoint_dir, **kwargs):
def _ensure_lora_dtype(self, model):
"""Force LoRA parameters to use the same dtype as base model for FSDP2 compatibility."""
base_dtype = None
+ is_npu_device = False
for param in model.parameters():
- if param.dtype in (torch.float16, torch.bfloat16, torch.float32):
+ if param.device.type == 'npu':
+ is_npu_device = True
+ if base_dtype is None and param.dtype in (torch.float16, torch.bfloat16, torch.float32):
base_dtype = param.dtype
+ if base_dtype is not None and is_npu_device:
break
if base_dtype is None:
return
- # Convert all LoRA parameters to the base model dtype
+ # Temporary workaround: NPU requires all parameters to align with the base dtype.
with torch.no_grad():
for name, param in model.named_parameters():
- if 'lora_' in name.lower() and param.dtype != base_dtype:
+ if (is_npu_device or 'lora_' in name.lower()) and param.dtype != base_dtype:
param.data = param.data.to(base_dtype)
def _load_scaler_state(self, scaler_path, **kwargs):
From 00c41ddbbf88941281fe68c21c805745c1f1b3fa Mon Sep 17 00:00:00 2001
From: meichangsu1 <1484603386@qq.com>
Date: Tue, 12 May 2026 22:56:21 +0800
Subject: [PATCH 06/11] fix: update DeepSeek-V4-Flash model ID and LoRA target
modules
- Update default model ID from DeepSeek-V4-flash-bfa16 to DeepSeek-V4-Flash
- Add comment explaining FP4/FP8 weight conversion requirement
- Fix LoRA target module names to match actual model architecture
---
cookbook/transformers/deepseek_v4_flash.py | 15 +++++++++------
cookbook/transformers/deepseek_v4_flash.sh | 6 ++++++
2 files changed, 15 insertions(+), 6 deletions(-)
create mode 100644 cookbook/transformers/deepseek_v4_flash.sh
diff --git a/cookbook/transformers/deepseek_v4_flash.py b/cookbook/transformers/deepseek_v4_flash.py
index 69ee095a..3e8e6376 100644
--- a/cookbook/transformers/deepseek_v4_flash.py
+++ b/cookbook/transformers/deepseek_v4_flash.py
@@ -10,8 +10,11 @@
from twinkle.preprocessor import SelfCognitionProcessor
logger = get_logger()
-
-MODEL_ID = os.environ.get('MODEL_ID', 'ms://deepseek-ai/DeepSeek-V4-flash-bfa16')
+# `deepseek-ai/DeepSeek-V4-Flash` uses mixed FP4/FP8 weights.
+# Convert the checkpoint before training by following:
+# https://gitcode.com/cann/cann-recipes-train/blob/master/llm_pretrain/deepseekv4/README.md#%E6%A8%A1%E5%9E%8B%E6%9D%83%E9%87%8D%E5%87%86%E5%A4%87
+# Install `transformers==5.8.0` before running this cookbook.
+MODEL_ID = os.environ.get('MODEL_ID', 'ms://deepseek-ai/DeepSeek-V4-Flash')
DATASET_ID = os.environ.get('DATASET_ID', 'ms://swift/self-cognition')
TEMPLATE_ID = os.environ.get('TEMPLATE_ID', 'DeepseekV4Template')
OUTPUT_DIR = os.environ.get('OUTPUT_DIR', './output')
@@ -27,10 +30,10 @@
GRADIENT_CHECKPOINTING = True
IGNORE_MISMATCHED_SIZES = False
LORA_TARGET_MODULES = [
- 'wq_a',
- 'wq_b',
- 'wkv',
- 'wgate',
+ 'q_a_proj',
+ 'q_b_proj',
+ 'kv_proj',
+ 'o_b_proj',
'gate_proj',
'up_proj',
'down_proj',
diff --git a/cookbook/transformers/deepseek_v4_flash.sh b/cookbook/transformers/deepseek_v4_flash.sh
new file mode 100644
index 00000000..991e60eb
--- /dev/null
+++ b/cookbook/transformers/deepseek_v4_flash.sh
@@ -0,0 +1,6 @@
+# `deepseek-ai/DeepSeek-V4-Flash` uses mixed FP4/FP8 weights.
+# Convert the checkpoint before training by following:
+# https://gitcode.com/cann/cann-recipes-train/blob/master/llm_pretrain/deepseekv4/README.md#%E6%A8%A1%E5%9E%8B%E6%9D%83%E9%87%8D%E5%87%86%E5%A4%87
+# Install `transformers==5.8.0` before running this cookbook.
+
+CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 cookbook/transformers/deepseek_v4_flash.py
From 6cecbba4c1bb5b92784a5bd2acf7d78f90dfc917 Mon Sep 17 00:00:00 2001
From: meichangsu1 <1484603386@qq.com>
Date: Tue, 12 May 2026 23:38:10 +0800
Subject: [PATCH 07/11] refactor(native_fsdp): replace custom sharded state
dict broadcast with DCP
Replace the custom `_broadcast_sharded_state_dict` function with `_load_rank0_full_state_dict` that leverages `torch.distributed.checkpoint`'s `set_model_state_dict` with `broadcast_from_rank0=True`. This simplifies the code by using the official DCP API for distributing rank0 full state dict to FSDP2 shards, removing manual tensor distribution and metadata broadcasting logic.
---
.../transformers/strategy/native_fsdp.py | 77 +++----------------
1 file changed, 12 insertions(+), 65 deletions(-)
diff --git a/src/twinkle/model/transformers/strategy/native_fsdp.py b/src/twinkle/model/transformers/strategy/native_fsdp.py
index ae651fc3..3cbaf141 100644
--- a/src/twinkle/model/transformers/strategy/native_fsdp.py
+++ b/src/twinkle/model/transformers/strategy/native_fsdp.py
@@ -135,11 +135,7 @@ def wrap_model(self, model, optimizer=None):
if use_meta:
device_type = self.device_mesh.device_type or 'cuda'
- _broadcast_sharded_state_dict(
- model,
- original_sd,
- device_type=device_type,
- )
+ _load_rank0_full_state_dict(model, original_sd or {})
target_device = torch.device(device_type)
_broadcast_non_persistent_buffers(model, saved_buffers or {}, device=target_device)
if hasattr(model, 'tie_weights'):
@@ -505,67 +501,18 @@ def _rebind_optimizer(optimizer: torch.optim.Optimizer, model: nn.Module) -> tor
return optimizer
-def _broadcast_sharded_state_dict(
- model: nn.Module,
- full_sd: dict,
- device_type: str = 'cuda',
-) -> None:
- """Distribute rank0 full state dict into local FSDP2 shards."""
- from torch.distributed.tensor import DTensor, distribute_tensor
-
- meta_sharded_sd = model.state_dict()
- sharded_sd = {}
- is_rank0 = (dist.get_rank() == 0)
- source_metadata = None
- if is_rank0:
- source_metadata = {
- name: (tuple(tensor.shape), tensor.dtype)
- for name, tensor in full_sd.items() if hasattr(tensor, 'shape') and hasattr(tensor, 'dtype')
- }
- metadata_holder = [source_metadata]
- dist.broadcast_object_list(metadata_holder, src=0)
- source_metadata = metadata_holder[0] or {}
-
- for param_name, sharded_param in meta_sharded_sd.items():
- shape = sharded_param.size()
- if param_name not in source_metadata:
- raise KeyError(f"Missing source metadata for parameter '{param_name}'.")
- source_shape, source_dtype = source_metadata[param_name]
-
- if is_rank0:
- if param_name not in full_sd:
- raise KeyError(
- f"Parameter '{param_name}' found in sharded model state dict but missing from full state dict.")
- full_param = full_sd[param_name]
- full_tensor = full_param.detach()
- if isinstance(full_tensor, DTensor):
- full_tensor = full_tensor.to_local()
- full_tensor = full_tensor.to(device_type)
- if tuple(full_tensor.shape) != tuple(source_shape) or full_tensor.dtype != source_dtype:
- raise RuntimeError(f"Source metadata mismatch for '{param_name}': "
- f'actual shape={tuple(full_tensor.shape)} dtype={full_tensor.dtype}, '
- f'expected shape={source_shape} dtype={source_dtype}.')
- else:
- full_tensor = torch.empty(source_shape, device=device_type, dtype=source_dtype)
-
- if tuple(shape) != tuple(source_shape):
- raise RuntimeError(f"Parameter '{param_name}' shape mismatch before broadcast: "
- f'sharded logical shape={tuple(shape)}, source shape={source_shape}.')
- if isinstance(sharded_param, DTensor):
- sharded_tensor = distribute_tensor(
- full_tensor,
- sharded_param.device_mesh,
- sharded_param.placements,
- )
- else:
- dist.broadcast(full_tensor, src=0)
- sharded_tensor = full_tensor
- torch_util.synchronize()
- del full_tensor
-
- sharded_sd[param_name] = sharded_tensor
+def _load_rank0_full_state_dict(model: nn.Module, full_sd: dict) -> None:
+ """Load rank0 full weights into a sharded FSDP2 model via DCP broadcast."""
+ from torch.distributed.checkpoint.state_dict import StateDictOptions, set_model_state_dict
- model.load_state_dict(sharded_sd, assign=True)
+ set_model_state_dict(
+ model=model,
+ model_state_dict=full_sd,
+ options=StateDictOptions(
+ full_state_dict=True,
+ broadcast_from_rank0=True,
+ ),
+ )
def _get_non_persistent_buffers(model: nn.Module) -> Dict[str, torch.Tensor]:
From d292147ff4248e66fdecf842949dd8c5653234c3 Mon Sep 17 00:00:00 2001
From: meichangsu1 <1484603386@qq.com>
Date: Wed, 13 May 2026 09:30:35 +0800
Subject: [PATCH 08/11] fix: correct is_npu_device detection in
_ensure_lora_dtype
Use Platform.device_prefix() instead of iterating over model parameters to check for NPU device, improving efficiency and correctness.
---
src/twinkle/model/transformers/transformers.py | 4 +---
1 file changed, 1 insertion(+), 3 deletions(-)
diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py
index d6d1939b..46484792 100644
--- a/src/twinkle/model/transformers/transformers.py
+++ b/src/twinkle/model/transformers/transformers.py
@@ -1054,10 +1054,8 @@ def _load_optimizer(self, checkpoint_dir, **kwargs):
def _ensure_lora_dtype(self, model):
"""Force LoRA parameters to use the same dtype as base model for FSDP2 compatibility."""
base_dtype = None
- is_npu_device = False
+ is_npu_device = Platform.device_prefix() == 'npu'
for param in model.parameters():
- if param.device.type == 'npu':
- is_npu_device = True
if base_dtype is None and param.dtype in (torch.float16, torch.bfloat16, torch.float32):
base_dtype = param.dtype
if base_dtype is not None and is_npu_device:
From a4cd83e96e35021989ba3b95dd0d814c53ad8d24 Mon Sep 17 00:00:00 2001
From: meichangsu1 <1484603386@qq.com>
Date: Mon, 18 May 2026 17:05:04 +0800
Subject: [PATCH 09/11] feat(moe): support hash routing in ep
---
.../model/transformers/moe/ep_utils.py | 49 ++++++++---------
.../model/transformers/moe/expert_parallel.py | 54 ++++++++++++++-----
2 files changed, 64 insertions(+), 39 deletions(-)
diff --git a/src/twinkle/model/transformers/moe/ep_utils.py b/src/twinkle/model/transformers/moe/ep_utils.py
index 94118448..f64c1796 100644
--- a/src/twinkle/model/transformers/moe/ep_utils.py
+++ b/src/twinkle/model/transformers/moe/ep_utils.py
@@ -96,24 +96,21 @@ def all_to_all_async(group, input, output_split_size, input_split_size):
# ========================== moe_utils ==========================
-def permute(tokens: torch.Tensor, routing_map: torch.Tensor):
+def permute(tokens: torch.Tensor, expert_mask: torch.Tensor):
"""
Permutes the tokens according to the routing map.
Args:
tokens (torch.Tensor): The input token tensor, [num_tokens, hidden_dim].
- routing_map (torch.Tensor): The sparse token to expert mapping, [num_experts, tokens].
+ expert_mask (torch.Tensor): The sparse token to expert mapping, [num_experts, top_k, num_tokens].
"""
num_tokens, _ = tokens.shape
- num_experts = routing_map.shape[0]
-
- # mask [num_tokens, num_experts] -> [num_experts, num_tokens]
- routing_map = routing_map.bool()
+ expert_mask = expert_mask.bool()
# Create a dense expert-to-token mapping from the sparse token-to-expert mapping
- token_indices = torch.arange(num_tokens, device=routing_map.device).unsqueeze(0).expand(num_experts, -1)
- sorted_indices = token_indices.masked_select(routing_map)
+ token_indices = torch.arange(num_tokens, device=expert_mask.device).view(1, 1, num_tokens).expand_as(expert_mask)
+ sorted_indices = token_indices.masked_select(expert_mask)
# use the mapping to permute the tokens
permuted_input = tokens.index_select(0, sorted_indices)
@@ -226,6 +223,7 @@ def preprocess(
def token_pre_all2all(
hidden_states: torch.Tensor,
expert_mask: torch.Tensor,
+ routing_weights: torch.Tensor,
num_experts: int,
input_splits: torch.Tensor,
output_splits: torch.Tensor,
@@ -235,9 +233,9 @@ def token_pre_all2all(
hidden_dim = hidden_states.size(-1)
hidden_states = hidden_states.reshape(-1, hidden_dim)
org_hidden_states_shape = hidden_states.shape
- routing_map = expert_mask.sum(dim=1)
- local_permuted_hidden_states, local_input_permutation_mapping = permute(hidden_states, routing_map)
+ local_permuted_hidden_states, local_input_permutation_mapping = permute(hidden_states, expert_mask)
+ local_assignment_weights = routing_weights.T.contiguous().masked_select(expert_mask.bool())
global_permuted_hidden_states = all_to_all(ep_group, local_permuted_hidden_states, output_splits, input_splits)
@@ -250,18 +248,21 @@ def token_pre_all2all(
permute_order,
)
- return global_permuted_hidden_states, routing_map, local_input_permutation_mapping, org_hidden_states_shape
+ return (
+ global_permuted_hidden_states,
+ local_input_permutation_mapping,
+ local_assignment_weights,
+ org_hidden_states_shape,
+ )
def tokens_post_all2all(
expert_outputs: torch.Tensor,
- routing_weights: torch.Tensor,
- selected_experts: torch.Tensor,
+ local_assignment_weights: torch.Tensor,
num_experts: int,
input_splits: torch.Tensor,
output_splits: torch.Tensor,
num_global_tokens_per_local_expert: torch.Tensor,
- routing_map: torch.Tensor,
local_input_permutation_mapping: torch.Tensor,
org_hidden_states_shape: torch.Size,
ep_group: Optional[dist.ProcessGroup] = None,
@@ -276,16 +277,12 @@ def tokens_post_all2all(
)
unpermute_outputs = all_to_all(ep_group, expert_outputs, input_splits, output_splits)
-
- # [tokens, experts]
- weights_idx = generate_weights_idx(routing_weights, selected_experts, num_experts)
-
- unpermute_outputs = unpermute(
- unpermute_outputs,
- weights_idx,
- org_hidden_states_shape,
- local_input_permutation_mapping,
- routing_map,
+ weighted_outputs = unpermute_outputs * local_assignment_weights.unsqueeze(-1)
+ hidden_dim = org_hidden_states_shape[-1]
+ final_outputs = torch.zeros(org_hidden_states_shape, device=weighted_outputs.device, dtype=weighted_outputs.dtype)
+ final_outputs.scatter_add_(
+ 0,
+ local_input_permutation_mapping.unsqueeze(1).expand(-1, hidden_dim),
+ weighted_outputs,
)
-
- return unpermute_outputs
+ return final_outputs
diff --git a/src/twinkle/model/transformers/moe/expert_parallel.py b/src/twinkle/model/transformers/moe/expert_parallel.py
index c9ee0fa0..282f49c0 100644
--- a/src/twinkle/model/transformers/moe/expert_parallel.py
+++ b/src/twinkle/model/transformers/moe/expert_parallel.py
@@ -1,6 +1,7 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
from __future__ import annotations
+import inspect
import torch
import torch.distributed as dist
import torch.nn.functional as F
@@ -187,6 +188,11 @@ def patch_forward(
raise ValueError('MoE block must define top_k/num_experts_per_tok.')
orig_forward = block.forward
+ return_annotation = inspect.signature(orig_forward).return_annotation
+ returns_router_logits = return_annotation in (
+ tuple,
+ Tuple[torch.Tensor, torch.Tensor | None],
+ )
num_experts = block._ep_num_experts
experts_per_rank = block._ep_experts_per_rank
is_tensor_experts = block._ep_tensor_experts
@@ -198,8 +204,8 @@ def patch_forward(
_install_ep_forward(block.experts, experts_per_rank)
def forward(hidden_states: torch.Tensor, *args, **kwargs):
- if args or kwargs:
- raise RuntimeError('Expert parallel patch only supports forward(hidden_states).')
+ if args:
+ raise RuntimeError('Expert parallel patch only supports keyword-only extra args for MoE blocks.')
orig_shape = hidden_states.shape
if hidden_states.ndim == 3:
@@ -218,11 +224,11 @@ def forward(hidden_states: torch.Tensor, *args, **kwargs):
top_k=top_k,
router_dtype=_get_router_dtype(cfg.router_dtype, hidden_states_2d.dtype),
norm_topk_prob=getattr(block, 'norm_topk_prob', False),
+ **kwargs,
)
# Keep routing weights in activation dtype before unpermute weighting.
if routing_weights.dtype != hidden_states_2d.dtype:
routing_weights = routing_weights.to(hidden_states_2d.dtype)
-
# Build expert_mask: [num_experts, top_k, num_tokens]
expert_mask = torch.nn.functional.one_hot(
selected_experts, num_classes=num_experts).permute(2, 1, 0) # [num_experts, top_k, num_tokens]
@@ -238,12 +244,13 @@ def forward(hidden_states: torch.Tensor, *args, **kwargs):
# 2. token_pre_all2all: permute → all_to_all → sort_chunks
(
global_permuted_hidden_states,
- routing_map,
local_input_permutation_mapping,
+ local_assignment_weights,
org_hidden_states_shape,
) = token_pre_all2all(
hidden_states_2d,
expert_mask,
+ routing_weights,
num_experts,
input_splits,
output_splits,
@@ -272,13 +279,11 @@ def forward(hidden_states: torch.Tensor, *args, **kwargs):
# 4. tokens_post_all2all: sort_chunks → all_to_all → unpermute (with routing weight)
final_hidden = tokens_post_all2all(
expert_outputs,
- routing_weights,
- selected_experts,
+ local_assignment_weights,
num_experts,
input_splits,
output_splits,
num_global_tokens_per_local_expert,
- routing_map,
local_input_permutation_mapping,
org_hidden_states_shape,
ep_group,
@@ -291,7 +296,7 @@ def forward(hidden_states: torch.Tensor, *args, **kwargs):
if len(orig_shape) == 3:
final_hidden = final_hidden.view(batch_size, seq_len, hidden_dim)
- if cfg.keep_router_logits:
+ if cfg.keep_router_logits and returns_router_logits:
return final_hidden, router_logits
return final_hidden
@@ -311,7 +316,10 @@ def ep_forward(
experts_per_rank: int,
) -> torch.Tensor:
if permuted_tokens.numel() == 0:
- return torch.empty_like(permuted_tokens)
+ # Preserve the autograd edge to token_pre_all2all. Returning a new
+ # empty tensor can make this rank skip the matching backward
+ # all-to-all, causing EP collective order divergence.
+ return permuted_tokens
input_dtype = permuted_tokens.dtype
@@ -333,8 +341,12 @@ def ep_forward(
compute_dtype = gate_up.dtype
if expert_in.dtype != compute_dtype:
expert_in = expert_in.to(compute_dtype)
- gate, up = F.linear(expert_in, gate_up).chunk(2, dim=-1)
- out = self.act_fn(gate) * up
+ gate_up_out = F.linear(expert_in, gate_up)
+ if hasattr(self, '_apply_gate'):
+ out = self._apply_gate(gate_up_out)
+ else:
+ gate, up = gate_up_out.chunk(2, dim=-1)
+ out = self.act_fn(gate) * up
out = F.linear(out, down)
if out.dtype != input_dtype:
@@ -399,6 +411,8 @@ def _maybe_run_shared_expert(block: nn.Module, hidden_states_2d: torch.Tensor, c
if cfg.ignore_shared_experts:
return None
shared = getattr(block, 'shared_expert', None)
+ if shared is None:
+ shared = getattr(block, 'shared_experts', None)
if shared is None:
return None
return _run_module_with_casting(shared, hidden_states_2d)
@@ -431,7 +445,9 @@ def _run_local_experts(
that happens in unpermute.
"""
if permuted_tokens.numel() == 0:
- return torch.empty_like(permuted_tokens)
+ # Keep the backward path through token_pre_all2all even when this EP
+ # rank owns no routed tokens for the current block.
+ return permuted_tokens
input_dtype = permuted_tokens.dtype
experts = block.experts
@@ -487,8 +503,12 @@ def _run_router(
top_k: int,
router_dtype: torch.dtype,
norm_topk_prob: bool,
+ **kwargs,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
- gate_out = gate(hidden_states)
+ gate_kwargs = {}
+ if 'input_ids' in kwargs and _module_forward_accepts_kwarg(gate, 'input_ids'):
+ gate_kwargs['input_ids'] = kwargs['input_ids']
+ gate_out = gate(hidden_states, **gate_kwargs)
if isinstance(gate_out, tuple) and len(gate_out) >= 3:
router_logits, routing_weights, selected_experts = gate_out[:3]
return router_logits, routing_weights, selected_experts
@@ -499,3 +519,11 @@ def _run_router(
if norm_topk_prob:
routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
return router_logits, routing_weights, selected_experts
+
+
+def _module_forward_accepts_kwarg(module: nn.Module, kwarg: str) -> bool:
+ signature = inspect.signature(module.forward)
+ for param in signature.parameters.values():
+ if param.kind == inspect.Parameter.VAR_KEYWORD:
+ return True
+ return kwarg in signature.parameters
From 092e55f8154b2cf2a2bb8c512fd3ba94368a2928 Mon Sep 17 00:00:00 2001
From: meichangsu1 <1484603386@qq.com>
Date: Mon, 18 May 2026 17:56:57 +0800
Subject: [PATCH 10/11] feat(transformers): add expert parallelism support to
NativeFSDP strategy
- Enable expert parallelism (EP) with configurable router dtype and logits
- Implement rank0 pre-EP full state dict for efficient weight loading
- Add EP-aware state dict broadcasting and expert shard specifications
- Update DeepSeek V4 flash cookbook with EP configuration example
---
cookbook/transformers/deepseek_v4_flash.py | 6 +
.../transformers/strategy/native_fsdp.py | 216 +++++++++++++++++-
.../model/transformers/transformers.py | 15 ++
3 files changed, 227 insertions(+), 10 deletions(-)
diff --git a/cookbook/transformers/deepseek_v4_flash.py b/cookbook/transformers/deepseek_v4_flash.py
index 3e8e6376..eacb8c84 100644
--- a/cookbook/transformers/deepseek_v4_flash.py
+++ b/cookbook/transformers/deepseek_v4_flash.py
@@ -43,6 +43,7 @@
device_mesh = DeviceMesh.from_sizes(
fsdp_size=4,
dp_size=1,
+ ep_size=4,
device_type=Platform.get_platform().device_prefix(),
)
@@ -87,6 +88,11 @@ def train():
ignore_mismatched_sizes=IGNORE_MISMATCHED_SIZES,
fsdp_config={
'reshard_after_forward': RESHARD_AFTER_FORWARD,
+ 'expert_parallel': {
+ 'enabled': True,
+ 'router_dtype': 'fp32',
+ 'keep_router_logits': False,
+ },
},
)
diff --git a/src/twinkle/model/transformers/strategy/native_fsdp.py b/src/twinkle/model/transformers/strategy/native_fsdp.py
index 3cbaf141..9e3bbad7 100644
--- a/src/twinkle/model/transformers/strategy/native_fsdp.py
+++ b/src/twinkle/model/transformers/strategy/native_fsdp.py
@@ -28,6 +28,7 @@ def __init__(self,
self._memory_efficient_init = memory_efficient_init
self.enable_ep = enable_ep
self.ep_fsdp_device_mesh = self._build_ep_fsdp_device_mesh(ep_size) if enable_ep else None
+ self._rank0_pre_ep_full_state_dict = None
def pretrained_load_context(self):
# Native FSDP loads pretrained weights via rank0 broadcast during wrap_model().
@@ -38,6 +39,9 @@ def pretrained_load_context(self):
def use_rank0_pretrained_broadcast(self) -> bool:
return self._memory_efficient_init and self.device_mesh is not None
+ def set_rank0_pre_ep_full_state_dict(self, state_dict: Dict[str, torch.Tensor]) -> None:
+ self._rank0_pre_ep_full_state_dict = state_dict
+
def _build_ep_fsdp_device_mesh(self, ep_size: Optional[int] = None) -> Optional[TorchDeviceMesh]:
if self.device_mesh is None:
return None
@@ -65,17 +69,21 @@ def wrap_model(self, model, optimizer=None):
if optimizer is not None:
_unbind_optimizer_params(optimizer)
- use_meta = self.use_rank0_pretrained_broadcast() and not ep_enabled
+ use_meta = self.use_rank0_pretrained_broadcast()
original_sd = None
saved_buffers = None
if use_meta:
is_rank0 = (dist.get_rank() == 0)
- original_sd = model.state_dict() if is_rank0 else {}
+ if ep_enabled and self._rank0_pre_ep_full_state_dict is not None:
+ original_sd = self._rank0_pre_ep_full_state_dict if is_rank0 else {}
+ else:
+ original_sd = model.state_dict() if is_rank0 else {}
saved_buffers = _get_non_persistent_buffers(model) if is_rank0 else {}
- model = model.to(torch.device('meta'))
- if hasattr(model, 'tie_weights'):
- model.tie_weights()
+ if is_rank0:
+ model = model.to(torch.device('meta'))
+ if hasattr(model, 'tie_weights'):
+ model.tie_weights()
if ep_enabled:
_ensure_moe_patched_if_needed(model, self.ep_fsdp_device_mesh)
@@ -102,7 +110,7 @@ def wrap_model(self, model, optimizer=None):
for layer_mod, experts_mod in layer_pairs:
layer_mod._fsdp_modules = []
- if experts_mod is not None and ep_fsdp_mesh_1d is not None:
+ if experts_mod is not None and ep_fsdp_mesh_1d is not None and ep_fsdp_mesh_1d.size() > 1:
from torch.distributed.tensor import Shard
ep_mp_policy = _build_ep_mp_policy(mp_policy)
@@ -113,7 +121,8 @@ def wrap_model(self, model, optimizer=None):
mp_policy=ep_mp_policy,
shard_placement_fn=lambda param: Shard(1),
)
- experts_mod.set_gradient_divide_factor(world_size)
+ if hasattr(experts_mod, 'set_gradient_divide_factor'):
+ experts_mod.set_gradient_divide_factor(world_size)
layer_mod._fsdp_modules.append(experts_mod)
fully_shard(
@@ -135,7 +144,16 @@ def wrap_model(self, model, optimizer=None):
if use_meta:
device_type = self.device_mesh.device_type or 'cuda'
- _load_rank0_full_state_dict(model, original_sd or {})
+ if ep_enabled:
+ _broadcast_sharded_state_dict(
+ model,
+ original_sd or {},
+ device_type=device_type,
+ expert_shard_specs=_collect_ep_expert_shard_specs(model),
+ rank_to_ep_rank=_build_rank_to_ep_rank(self.ep_fsdp_device_mesh),
+ )
+ else:
+ _load_rank0_full_state_dict(model, original_sd or {})
target_device = torch.device(device_type)
_broadcast_non_persistent_buffers(model, saved_buffers or {}, device=target_device)
if hasattr(model, 'tie_weights'):
@@ -376,6 +394,8 @@ def _collect_expert_params(model: nn.Module) -> Optional[Set[nn.Parameter]]:
if getattr(module, '_ep_ignore_shared_experts', False) and getattr(module, '_ep_patched', False):
ep_patched = True
shared = getattr(module, 'shared_expert', None)
+ if shared is None:
+ shared = getattr(module, 'shared_experts', None)
if shared is not None:
ignored.update(shared.parameters())
@@ -411,6 +431,37 @@ def _collect_ep_experts_map(model: nn.Module) -> Dict[str, nn.Module]:
return experts_map
+def _collect_ep_expert_shard_specs(model: nn.Module) -> Dict[str, Dict[str, int]]:
+ """Collect state-dict names for expert tensors sharded by EP."""
+ specs = {}
+ for fqn, module in model.named_modules():
+ if not getattr(module, '_ep_patched', False):
+ continue
+ experts = getattr(module, 'experts', None)
+ if experts is None:
+ continue
+ experts_prefix = f'{fqn}.experts.' if fqn else 'experts.'
+ for pname, _ in experts.named_parameters():
+ specs[experts_prefix + pname] = {
+ 'num_experts': int(module._ep_num_experts),
+ 'experts_per_rank': int(module._ep_experts_per_rank),
+ }
+ return specs
+
+
+def _build_rank_to_ep_rank(ep_fsdp_device_mesh: Optional[TorchDeviceMesh]) -> Dict[int, int]:
+ if ep_fsdp_device_mesh is None:
+ return {}
+ mesh = ep_fsdp_device_mesh.mesh
+ if hasattr(mesh, 'detach'):
+ mesh = mesh.detach().cpu().numpy()
+ rank_to_ep_rank = {}
+ for ep_rank in range(mesh.shape[0]):
+ for rank in mesh[ep_rank].flatten().tolist():
+ rank_to_ep_rank[int(rank)] = int(ep_rank)
+ return rank_to_ep_rank
+
+
def _find_experts_in_layer(layer_mod: nn.Module, experts_map: Dict[str, nn.Module]) -> Optional[nn.Module]:
"""Find the experts module inside a decoder layer, if any."""
for module in layer_mod.modules():
@@ -444,11 +495,22 @@ def _place_ep_experts_on_local_device(model: nn.Module, ep_fsdp_device_mesh: Opt
continue
experts = getattr(module, 'experts', None)
if experts is not None:
- experts.to(local_device)
+ _move_ep_module_to_device(experts, local_device)
if getattr(module, '_ep_ignore_shared_experts', False):
shared = getattr(module, 'shared_expert', None)
+ if shared is None:
+ shared = getattr(module, 'shared_experts', None)
if shared is not None:
- shared.to(local_device)
+ _move_ep_module_to_device(shared, local_device)
+
+
+def _move_ep_module_to_device(module: nn.Module, device: torch.device) -> None:
+ has_meta_tensor = any(param.is_meta for param in module.parameters(recurse=True))
+ has_meta_tensor = has_meta_tensor or any(buffer.is_meta for buffer in module.buffers(recurse=True))
+ if has_meta_tensor:
+ module.to_empty(device=device)
+ else:
+ module.to(device)
def _ensure_moe_patched_if_needed(model: nn.Module, ep_fsdp_device_mesh: Optional[TorchDeviceMesh]) -> None:
@@ -515,6 +577,140 @@ def _load_rank0_full_state_dict(model: nn.Module, full_sd: dict) -> None:
)
+def _broadcast_sharded_state_dict(
+ model: nn.Module,
+ full_sd: dict,
+ device_type: str = 'cuda',
+ expert_shard_specs: Optional[Dict[str, Dict[str, int]]] = None,
+ rank_to_ep_rank: Optional[Dict[int, int]] = None,
+) -> None:
+ """Broadcast rank0 full state dict and materialize local FSDP2/EP shards."""
+ from torch.distributed.tensor import DTensor, Partial, Replicate, Shard
+
+ meta_sharded_sd = model.state_dict()
+ sharded_sd = {}
+ is_rank0 = (dist.get_rank() == 0)
+ expert_shard_specs = expert_shard_specs or {}
+ rank_to_ep_rank = rank_to_ep_rank or {}
+
+ source_metadata = None
+ if is_rank0:
+ source_metadata = {
+ name: (tuple(tensor.shape), tensor.dtype)
+ for name, tensor in full_sd.items() if hasattr(tensor, 'shape') and hasattr(tensor, 'dtype')
+ }
+ metadata_holder = [source_metadata]
+ dist.broadcast_object_list(metadata_holder, src=0)
+ source_metadata = metadata_holder[0] or {}
+
+ def _dtensor_from_replicated_full_tensor(full_tensor, device_mesh, placements):
+ local_tensor = full_tensor
+ for mesh_dim, placement in enumerate(placements):
+ if isinstance(placement, Shard):
+ local_tensor = placement._shard_tensor(
+ local_tensor,
+ device_mesh,
+ mesh_dim,
+ src_data_rank=None,
+ )
+ local_tensor = local_tensor.contiguous().clone()
+ elif isinstance(placement, Replicate):
+ continue
+ elif isinstance(placement, Partial):
+ raise NotImplementedError('Native FSDP2 full-state loading does not support Partial placements.')
+ else:
+ raise NotImplementedError(f'Unsupported DTensor placement: {placement}')
+ return DTensor.from_local(
+ local_tensor,
+ device_mesh=device_mesh,
+ placements=placements,
+ run_check=False,
+ shape=full_tensor.shape,
+ stride=full_tensor.stride(),
+ )
+
+ def _scatter_ep_expert_tensor(param_name: str, full_tensor, sharded_param):
+ spec = expert_shard_specs[param_name]
+ experts_per_rank = spec['experts_per_rank']
+ num_experts = spec['num_experts']
+ local_shape = tuple(sharded_param.size())
+ if param_name not in source_metadata:
+ raise KeyError(f"Missing source metadata for EP expert parameter '{param_name}'.")
+ _, source_dtype = source_metadata[param_name]
+ local_tensor = torch.empty(local_shape, device=device_type, dtype=source_dtype)
+
+ if is_rank0:
+ if full_tensor.size(0) != num_experts:
+ raise RuntimeError(f"EP expert parameter '{param_name}' expects {num_experts} experts, "
+ f'but source state has shape {tuple(full_tensor.shape)}. '
+ 'Rank0 must capture the full pre-EP state_dict before apply_expert_parallel().')
+ world_size = dist.get_world_size()
+ for rank in range(world_size):
+ if rank not in rank_to_ep_rank:
+ raise RuntimeError(f'Missing EP rank mapping for global rank {rank}.')
+ ep_rank = rank_to_ep_rank[rank]
+ start = ep_rank * experts_per_rank
+ end = start + experts_per_rank
+ chunk = full_tensor[start:end].contiguous()
+ chunk_gpu = chunk.to(device_type)
+ if rank == 0:
+ local_tensor.copy_(chunk_gpu)
+ else:
+ dist.send(chunk_gpu, dst=rank)
+ else:
+ dist.recv(local_tensor, src=0)
+
+ return local_tensor
+
+ for param_name, sharded_param in meta_sharded_sd.items():
+ is_ep_expert_param = param_name in expert_shard_specs
+ if param_name not in source_metadata:
+ raise KeyError(f"Missing source metadata for parameter '{param_name}'.")
+ source_shape, source_dtype = source_metadata[param_name]
+
+ if is_rank0:
+ if param_name not in full_sd:
+ raise KeyError(
+ f"Parameter '{param_name}' found in sharded model state dict but missing from full state dict.")
+ full_param = full_sd[param_name]
+ full_tensor = full_param.detach()
+ if isinstance(full_tensor, DTensor):
+ full_tensor = full_tensor.to_local()
+ if not is_ep_expert_param:
+ full_tensor = full_tensor.to(device_type)
+ if tuple(full_tensor.shape) != tuple(source_shape) or full_tensor.dtype != source_dtype:
+ raise RuntimeError(f"Source metadata mismatch for '{param_name}': "
+ f'actual shape={tuple(full_tensor.shape)} dtype={full_tensor.dtype}, '
+ f'expected shape={source_shape} dtype={source_dtype}.')
+ else:
+ full_tensor = None if is_ep_expert_param else torch.empty(
+ source_shape, device=device_type, dtype=source_dtype)
+
+ if is_ep_expert_param:
+ full_tensor = _scatter_ep_expert_tensor(param_name, full_tensor, sharded_param)
+ else:
+ if tuple(sharded_param.size()) != tuple(source_shape):
+ raise RuntimeError(f"Parameter '{param_name}' shape mismatch before broadcast: "
+ f'sharded logical shape={tuple(sharded_param.size())}, '
+ f'source shape={source_shape}.')
+ dist.broadcast(full_tensor, src=0)
+ torch_util.synchronize()
+
+ if isinstance(sharded_param, DTensor):
+ sharded_tensor = _dtensor_from_replicated_full_tensor(
+ full_tensor,
+ sharded_param.device_mesh,
+ sharded_param.placements,
+ )
+ else:
+ sharded_tensor = full_tensor
+ del full_tensor
+
+ sharded_sd[param_name] = sharded_tensor
+
+ model.load_state_dict(sharded_sd, assign=True)
+
+
def _get_non_persistent_buffers(model: nn.Module) -> Dict[str, torch.Tensor]:
"""Return {fqn: tensor} for non-persistent buffers (lost on to('meta'))."""
non_persistent_fqns: Set[str] = set()
diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py
index 46484792..ebb35cab 100644
--- a/src/twinkle/model/transformers/transformers.py
+++ b/src/twinkle/model/transformers/transformers.py
@@ -47,6 +47,16 @@
logger = get_logger()
+def _clone_state_dict_to_cpu(state_dict: Dict[str, Any]) -> Dict[str, Any]:
+ cloned = {}
+ for key, value in state_dict.items():
+ if hasattr(value, 'detach'):
+ cloned[key] = value.detach().cpu().clone()
+ else:
+ cloned[key] = value
+ return cloned
+
+
@dataclass
class OptimizerGroup(BaseOptimizerGroup):
"""Optimizer group for Transformers training."""
@@ -286,6 +296,11 @@ def _not_encoded(inputs):
def _lazy_wrap_model(self):
if not self._model_wrapped:
optimizer_groups = [og for og in self.optimizer_group.values() if og.optimizer is not None]
+ use_rank0_broadcast = getattr(self.strategy, 'use_rank0_pretrained_broadcast', lambda: False)
+ set_pre_ep_state = getattr(self.strategy, 'set_rank0_pre_ep_full_state_dict', None)
+ if self._enable_expert_parallel and use_rank0_broadcast() and set_pre_ep_state is not None:
+ is_rank0 = dist.is_available() and dist.is_initialized() and dist.get_rank() == 0
+ set_pre_ep_state(_clone_state_dict_to_cpu(self.model.state_dict()) if is_rank0 else {})
self._maybe_apply_expert_parallel()
self._ensure_sp_strategy()
if self.sp_strategy is not None:
From a87dbf9b752466a34e8fc3c88c7a0914d487db54 Mon Sep 17 00:00:00 2001
From: meichangsu1 <1484603386@qq.com>
Date: Mon, 18 May 2026 18:46:18 +0800
Subject: [PATCH 11/11] update cookbook
---
cookbook/transformers/deepseek_v4_flash.py | 10 +++++-----
cookbook/transformers/deepseek_v4_flash.sh | 2 +-
2 files changed, 6 insertions(+), 6 deletions(-)
diff --git a/cookbook/transformers/deepseek_v4_flash.py b/cookbook/transformers/deepseek_v4_flash.py
index eacb8c84..869f4cc8 100644
--- a/cookbook/transformers/deepseek_v4_flash.py
+++ b/cookbook/transformers/deepseek_v4_flash.py
@@ -41,9 +41,9 @@
ADAPTER_NAME = 'default'
device_mesh = DeviceMesh.from_sizes(
- fsdp_size=4,
+ fsdp_size=8,
dp_size=1,
- ep_size=4,
+ ep_size=8,
device_type=Platform.get_platform().device_prefix(),
)
@@ -89,9 +89,9 @@ def train():
fsdp_config={
'reshard_after_forward': RESHARD_AFTER_FORWARD,
'expert_parallel': {
- 'enabled': True,
- 'router_dtype': 'fp32',
- 'keep_router_logits': False,
+ 'enabled': True,
+ 'router_dtype': 'fp32',
+ 'keep_router_logits': False,
},
},
)
diff --git a/cookbook/transformers/deepseek_v4_flash.sh b/cookbook/transformers/deepseek_v4_flash.sh
index 991e60eb..bbdb58ff 100644
--- a/cookbook/transformers/deepseek_v4_flash.sh
+++ b/cookbook/transformers/deepseek_v4_flash.sh
@@ -3,4 +3,4 @@
# https://gitcode.com/cann/cann-recipes-train/blob/master/llm_pretrain/deepseekv4/README.md#%E6%A8%A1%E5%9E%8B%E6%9D%83%E9%87%8D%E5%87%86%E5%A4%87
# Install `transformers==5.8.0` before running this cookbook.
-CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 cookbook/transformers/deepseek_v4_flash.py
+CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nproc_per_node=8 cookbook/transformers/deepseek_v4_flash.py