72 changes: 67 additions & 5 deletions deepspeed/module_inject/auto_tp.py
@@ -437,37 +437,96 @@ def _replace_with_config(self, child, name):
else:
return self._create_column_parallel_layer(child, spec, name)

def _get_tp_shape_config(self):
"""Return the config object that exposes TP-related shape metadata."""
config = getattr(self.module, "config", None)
if config is not None and hasattr(config, "text_config"):
return config.text_config
return config

def _resolve_spec_shape(self, spec: TPLayerSpec, module, name: str):
"""Resolve a static or config-derived logical shape for a matched spec."""
if spec.shape is not None:
return spec.shape

resolver = getattr(spec, "shape_resolver", None)
if resolver is None:
return None

if resolver == "qwen3_5_linear_attn_qkv":
config = self._get_tp_shape_config()
required_fields = [
"linear_num_key_heads",
"linear_key_head_dim",
"linear_num_value_heads",
"linear_value_head_dim",
]
missing = [field for field in required_fields if config is None or not hasattr(config, field)]
if missing:
raise ValueError(f"AutoTP layer '{name}' requires config fields {required_fields} "
f"to resolve qwen3_5 linear_attn.in_proj_qkv, missing {missing}.")

q_size = int(config.linear_num_key_heads) * int(config.linear_key_head_dim)
k_size = q_size
v_size = int(config.linear_num_value_heads) * int(config.linear_value_head_dim)
if q_size + k_size + v_size != module.weight.shape[0]:
raise ValueError(f"AutoTP layer '{name}' resolved fused QKV sizes "
f"({q_size}, {k_size}, {v_size}) do not match weight output dim "
f"{module.weight.shape[0]}.")
return ((q_size, k_size, v_size), -1)

raise ValueError(f"Unknown AutoTP shape_resolver '{resolver}' for layer '{name}'.")

def _can_replace_qwen35_linear_attn(self, module, name: str) -> bool:
Collaborator:
It feels like this module is looking at the 'fingerprint' of qwen35. Eventually, do we need to extend fingerprint probing to other new models as well?

if self.partition_config is None:
return False
required_attrs = [
"in_proj_qkv", "in_proj_z", "in_proj_a", "in_proj_b", "out_proj", "conv1d", "dt_bias", "A_log",
"num_k_heads", "num_v_heads", "head_k_dim", "head_v_dim", "key_dim", "value_dim", "conv_dim"
]
if not all(hasattr(module, attr) for attr in required_attrs):
return False

model_type = self._get_model_type()
spec = self.partition_config.find_matching_spec(f"{name}.in_proj_qkv.weight", model_type)
return spec is not None and getattr(spec, "shape_resolver", None) == "qwen3_5_linear_attn_qkv"

def _replace_qwen35_linear_attn(self, module, name: str):
return Qwen35LinearAttentionLayer(module, self.mp_group, name=name)

def _create_row_parallel_layer(self, module, spec: TPLayerSpec, name: str):
"""Create row-parallel layer (AllReduce after forward)."""
resolved_shape = self._resolve_spec_shape(spec, module, name)
if self.conv_linear_layer:
return Conv_LinearALlreduce(module, self.mp_group, name=name)
# Check for lm_head / embed_out
if name == "lm_head" or name == 'embed_out':
return LmHeadLinearAllreduce(module, self.mp_group)

if spec.shape is not None:
if resolved_shape is not None:
return SubParamLinearAllreduce(
module,
self.mp_group,
shape=spec.shape,
shape=resolved_shape,
partition_dim=spec.get_partition_dim(),
name=name,
)
return LinearAllreduce(module, self.mp_group, name=name)

def _create_column_parallel_layer(self, module, spec: TPLayerSpec, name: str):
"""Create column-parallel layer (AllReduce in backward)."""
resolved_shape = self._resolve_spec_shape(spec, module, name)
if self.conv_linear_layer:
return conv_LinearLayer(module, self.mp_group, name=name)
# Only use fused-QKV heuristics when no partition_config is provided.
elif self.partition_config is None and require_tp_fused_qkvw(name, self.mp_size):
# Check and handle fused qkv for TP
return fused_LinearLayer(module, self.mp_group, fused_module=self.module)
if spec.shape is not None:
if resolved_shape is not None:
return SubParamLinearLayer(
module,
self.mp_group,
shape=spec.shape,
shape=resolved_shape,
partition_dim=spec.get_partition_dim(),
name=name,
)
@@ -517,7 +576,8 @@ def update_mp_params(self, child):
param_list = [
"n_heads", "inner_dim", "num_heads", "num_kv", "num_attention_heads", "num_attn_heads", "all_head_size",
"embed_dim", "hidden_size", "num_key_value_heads", "num_kv_heads", "kv_n_heads", "d_model",
"num_attention_heads_per_partition", "num_multi_query_groups_per_partition", "hidden_size_per_partition"
"num_attention_heads_per_partition", "num_multi_query_groups_per_partition", "hidden_size_per_partition",
"linear_num_key_heads", "linear_num_value_heads"
]
for param in param_list:
if "Yuan" in str(child) and 'embed_dim' in param_list:
@@ -575,6 +635,8 @@ def _replace_module(self, r_module, prev_name='', prev_class_name=''):
if new_child is not None:
setattr(r_module, name, new_child)
# If no pattern matched or skip, leave embedding unchanged
elif self._can_replace_qwen35_linear_attn(child, full_name):
setattr(r_module, name, self._replace_qwen35_linear_attn(child, full_name))
elif hasattr(child, "weight") and getattr(child.weight, "dim", lambda: 0)() == 2:
new_child = self._replace_with_config(child, full_name)
if new_child is not None:
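
Note on the resolver arithmetic above: a minimal, hedged sketch of how the "qwen3_5_linear_attn_qkv" resolver derives the unequal fused-QKV split from config fields. The linear_* values below are illustrative placeholders, not real Qwen 3.5 hyperparameters.

# Sketch of the config-derived fused QKV split used by _resolve_spec_shape.
# Placeholder values; real models supply these via model.config(.text_config).
from types import SimpleNamespace

config = SimpleNamespace(
    linear_num_key_heads=16,
    linear_key_head_dim=128,
    linear_num_value_heads=32,
    linear_value_head_dim=128,
)

# q and k share the key-head geometry; v uses the value-head geometry.
q_size = config.linear_num_key_heads * config.linear_key_head_dim      # 2048
k_size = q_size                                                        # 2048
v_size = config.linear_num_value_heads * config.linear_value_head_dim  # 4096

# The resolved spec shape splits the fused output dim into (q, k, v)
# sub-parameters and leaves the input dim (-1) untouched, matching the
# column-parallel split on partition_dim=0 handled by SubParamLinearLayer.
resolved_shape = ((q_size, k_size, v_size), -1)
assert sum(resolved_shape[0]) == 8192  # must equal module.weight.shape[0]
print(resolved_shape)  # ((2048, 2048, 4096), -1)
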
59 changes: 59 additions & 0 deletions deepspeed/module_inject/autotp_config.py
@@ -110,6 +110,11 @@ class TPLayerSpec:
# (n_experts, -1, hidden) -> MoE reshape
shape: Optional[Tuple[Union[int, Tuple[int, ...]], ...]] = None

# Optional: resolver name for dynamic shape inference when a static shape
# cannot be expressed in a preset. This is intended for built-in presets
# that need model-config-derived sub-parameter sizes.
shape_resolver: Optional[str] = None

# Which dimension to partition (after optional reshape)
# Default: 0 for COLUMN, 1 for ROW (standard 2D weight matrix)
partition_dim: Optional[int] = None
@@ -286,6 +291,7 @@ def from_dict(cls, config_dict: dict) -> "AutoTPConfig":
patterns=spec_dict.get("patterns", []),
partition_type=partition_type,
shape=shape,
shape_resolver=spec_dict.get("shape_resolver"),
partition_dim=spec_dict.get("partition_dim"),
model_types=spec_dict.get("model_types"),
))
@@ -507,6 +513,58 @@ def qwen2() -> AutoTPConfig:
),
], )

@staticmethod
def qwen3_5() -> AutoTPConfig:
"""Qwen 3.5 dense model with gated linear-attention coverage.

This preset covers:
- standard self_attn projections in full-attention layers
- mlp projections in every decoder layer
- linear_attn.in_proj_qkv via config-derived unequal fused QKV splits
- linear_attn.in_proj_z, linear_attn.in_proj_a, linear_attn.in_proj_b,
and linear_attn.out_proj

The built-in AutoTP replacement also shards the remaining local-head
state inside the gated linear-attention block, including conv1d,
dt_bias, and A_log.
"""
return AutoTPConfig(layer_specs=[
TPLayerSpec(
patterns=[r".*\.self_attn\.o_proj\.weight$"],
partition_type=PartitionType.ROW,
),
TPLayerSpec(
patterns=[r".*\.self_attn\.[qkv]_proj\.weight$"],
partition_type=PartitionType.COLUMN,
),
TPLayerSpec(
patterns=[r".*\.linear_attn\.out_proj\.weight$"],
partition_type=PartitionType.ROW,
),
TPLayerSpec(
patterns=[r".*\.linear_attn\.in_proj_qkv\.weight$"],
partition_type=PartitionType.COLUMN,
shape_resolver="qwen3_5_linear_attn_qkv",
partition_dim=0,
),
TPLayerSpec(
patterns=[r".*\.linear_attn\.in_proj_z\.weight$"],
partition_type=PartitionType.COLUMN,
),
TPLayerSpec(
patterns=[r".*\.linear_attn\.in_proj_[ab]\.weight$"],
partition_type=PartitionType.COLUMN,
),
TPLayerSpec(
patterns=[r".*\.mlp\.down_proj\.weight$"],
partition_type=PartitionType.ROW,
),
TPLayerSpec(
patterns=[r".*\.mlp\.(up|gate)_proj\.weight$"],
partition_type=PartitionType.COLUMN,
),
], )

@staticmethod
def phi3() -> AutoTPConfig:
"""Phi3 model with fused QKV and chunked MLP."""
@@ -546,6 +604,7 @@ def get_preset(model_type: str) -> Optional[AutoTPConfig]:
"mixtral": AutoTPPresets.mixtral,
"deepseek_v2": AutoTPPresets.deepseek_v2,
"qwen2": AutoTPPresets.qwen2,
"qwen3_5": AutoTPPresets.qwen3_5,
"phi3": AutoTPPresets.phi3,
}
preset_fn = presets.get(model_type.lower())
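
Usage note: a hedged sketch of how the new preset could be looked up and checked. It assumes AutoTPPresets is importable from deepspeed.module_inject.autotp_config and that find_matching_spec is available on AutoTPConfig as it is used in auto_tp.py above; the layer name is a hypothetical Qwen 3.5 module path.

# Illustrative only: fetch the built-in qwen3_5 preset and confirm the gated
# linear-attention QKV weight maps to the dynamic shape resolver.
from deepspeed.module_inject.autotp_config import AutoTPPresets

tp_config = AutoTPPresets.qwen3_5()
spec = tp_config.find_matching_spec(
    "model.layers.0.linear_attn.in_proj_qkv.weight", "qwen3_5")

# No static shape is stored; auto_tp.py resolves it from the model config via
# the "qwen3_5_linear_attn_qkv" resolver at replacement time.
assert spec is not None
assert spec.shape is None
assert spec.shape_resolver == "qwen3_5_linear_attn_qkv"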