diff --git a/deepspeed/module_inject/auto_tp.py b/deepspeed/module_inject/auto_tp.py
index 852c492f8b8e..45183234c55c 100755
--- a/deepspeed/module_inject/auto_tp.py
+++ b/deepspeed/module_inject/auto_tp.py
@@ -437,19 +437,77 @@ def _replace_with_config(self, child, name):
         else:
             return self._create_column_parallel_layer(child, spec, name)
 
+    def _get_tp_shape_config(self):
+        """Return the config object that exposes TP-related shape metadata."""
+        config = getattr(self.module, "config", None)
+        if config is not None and hasattr(config, "text_config"):
+            return config.text_config
+        return config
+
+    def _resolve_spec_shape(self, spec: TPLayerSpec, module, name: str):
+        """Resolve a static or config-derived logical shape for a matched spec."""
+        if spec.shape is not None:
+            return spec.shape
+
+        resolver = getattr(spec, "shape_resolver", None)
+        if resolver is None:
+            return None
+
+        if resolver == "qwen3_5_linear_attn_qkv":
+            config = self._get_tp_shape_config()
+            required_fields = [
+                "linear_num_key_heads",
+                "linear_key_head_dim",
+                "linear_num_value_heads",
+                "linear_value_head_dim",
+            ]
+            missing = [field for field in required_fields if config is None or not hasattr(config, field)]
+            if missing:
+                raise ValueError(f"AutoTP layer '{name}' requires config fields {required_fields} "
+                                 f"to resolve qwen3_5 linear_attn.in_proj_qkv, missing {missing}.")
+
+            q_size = int(config.linear_num_key_heads) * int(config.linear_key_head_dim)
+            k_size = q_size
+            v_size = int(config.linear_num_value_heads) * int(config.linear_value_head_dim)
+            if q_size + k_size + v_size != module.weight.shape[0]:
+                raise ValueError(f"AutoTP layer '{name}' resolved fused QKV sizes "
+                                 f"({q_size}, {k_size}, {v_size}) do not match weight output dim "
+                                 f"{module.weight.shape[0]}.")
+            return ((q_size, k_size, v_size), -1)
+
+        raise ValueError(f"Unknown AutoTP shape_resolver '{resolver}' for layer '{name}'.")
+
+    def _can_replace_qwen35_linear_attn(self, module, name: str) -> bool:
+        if self.partition_config is None:
+            return False
+        required_attrs = [
+            "in_proj_qkv", "in_proj_z", "in_proj_a", "in_proj_b", "out_proj", "conv1d", "dt_bias", "A_log",
+            "num_k_heads", "num_v_heads", "head_k_dim", "head_v_dim", "key_dim", "value_dim", "conv_dim"
+        ]
+        if not all(hasattr(module, attr) for attr in required_attrs):
+            return False
+
+        model_type = self._get_model_type()
+        spec = self.partition_config.find_matching_spec(f"{name}.in_proj_qkv.weight", model_type)
+        return spec is not None and getattr(spec, "shape_resolver", None) == "qwen3_5_linear_attn_qkv"
+
+    def _replace_qwen35_linear_attn(self, module, name: str):
+        return Qwen35LinearAttentionLayer(module, self.mp_group, name=name)
+
     def _create_row_parallel_layer(self, module, spec: TPLayerSpec, name: str):
         """Create row-parallel layer (AllReduce after forward)."""
+        resolved_shape = self._resolve_spec_shape(spec, module, name)
         if self.conv_linear_layer:
             return Conv_LinearALlreduce(module, self.mp_group, name=name)
 
         # Check for lm_head / embed_out
         if name == "lm_head" or name == 'embed_out':
             return LmHeadLinearAllreduce(module, self.mp_group)
 
-        if spec.shape is not None:
+        if resolved_shape is not None:
             return SubParamLinearAllreduce(
                 module,
                 self.mp_group,
-                shape=spec.shape,
+                shape=resolved_shape,
                 partition_dim=spec.get_partition_dim(),
                 name=name,
             )
@@ -457,17 +515,18 @@ def _create_row_parallel_layer(self, module, spec: TPLayerSpec, name: str):
 
     def _create_column_parallel_layer(self, module, spec: TPLayerSpec, name: str):
         """Create column-parallel layer (AllReduce in backward)."""
+        resolved_shape = self._resolve_spec_shape(spec, module, name)
         if self.conv_linear_layer:
             return conv_LinearLayer(module, self.mp_group, name=name)
         # Only use fused-QKV heuristics when no partition_config is provided.
         elif self.partition_config is None and require_tp_fused_qkvw(name, self.mp_size):
             # Check and handle fused qkv for TP
             return fused_LinearLayer(module, self.mp_group, fused_module=self.module)
 
-        if spec.shape is not None:
+        if resolved_shape is not None:
             return SubParamLinearLayer(
                 module,
                 self.mp_group,
-                shape=spec.shape,
+                shape=resolved_shape,
                 partition_dim=spec.get_partition_dim(),
                 name=name,
             )
@@ -517,7 +576,8 @@ def update_mp_params(self, child):
         param_list = [
             "n_heads", "inner_dim", "num_heads", "num_kv", "num_attention_heads", "num_attn_heads", "all_head_size",
             "embed_dim", "hidden_size", "num_key_value_heads", "num_kv_heads", "kv_n_heads", "d_model",
-            "num_attention_heads_per_partition", "num_multi_query_groups_per_partition", "hidden_size_per_partition"
+            "num_attention_heads_per_partition", "num_multi_query_groups_per_partition", "hidden_size_per_partition",
+            "linear_num_key_heads", "linear_num_value_heads"
         ]
         for param in param_list:
             if "Yuan" in str(child) and 'embed_dim' in param_list:
@@ -575,6 +635,8 @@ def _replace_module(self, r_module, prev_name='', prev_class_name=''):
                 if new_child is not None:
                     setattr(r_module, name, new_child)
                 # If no pattern matched or skip, leave embedding unchanged
+            elif self._can_replace_qwen35_linear_attn(child, full_name):
+                setattr(r_module, name, self._replace_qwen35_linear_attn(child, full_name))
             elif hasattr(child, "weight") and getattr(child.weight, "dim", lambda: 0)() == 2:
                 new_child = self._replace_with_config(child, full_name)
                 if new_child is not None:
diff --git a/deepspeed/module_inject/autotp_config.py b/deepspeed/module_inject/autotp_config.py
index 4bafea806829..bac5d72ef292 100644
--- a/deepspeed/module_inject/autotp_config.py
+++ b/deepspeed/module_inject/autotp_config.py
@@ -110,6 +110,11 @@ class TPLayerSpec:
     #   (n_experts, -1, hidden) -> MoE reshape
     shape: Optional[Tuple[Union[int, Tuple[int, ...]], ...]] = None
 
+    # Optional: resolver name for dynamic shape inference when a static shape
+    # cannot be expressed in a preset. This is intended for built-in presets
+    # that need model-config-derived sub-parameter sizes.
+    shape_resolver: Optional[str] = None
+
     # Which dimension to partition (after optional reshape)
     # Default: 0 for COLUMN, 1 for ROW (standard 2D weight matrix)
     partition_dim: Optional[int] = None
@@ -286,6 +291,7 @@ def from_dict(cls, config_dict: dict) -> "AutoTPConfig":
                     patterns=spec_dict.get("patterns", []),
                     partition_type=partition_type,
                     shape=shape,
+                    shape_resolver=spec_dict.get("shape_resolver"),
                     partition_dim=spec_dict.get("partition_dim"),
                     model_types=spec_dict.get("model_types"),
                 ))
@@ -507,6 +513,58 @@ def qwen2() -> AutoTPConfig:
             ),
         ],
     )
 
+    @staticmethod
+    def qwen3_5() -> AutoTPConfig:
+        """Qwen 3.5 dense model with gated linear-attention coverage.
+
+        This preset covers:
+        - standard self_attn projections in full-attention layers
+        - mlp projections in every decoder layer
+        - linear_attn.in_proj_qkv via config-derived unequal fused QKV splits
+        - linear_attn.in_proj_z, linear_attn.in_proj_a, linear_attn.in_proj_b,
+          and linear_attn.out_proj
+
+        The built-in AutoTP replacement also shards the remaining local-head
+        state inside the gated linear-attention block, including conv1d,
+        dt_bias, and A_log.
+        """
+        return AutoTPConfig(layer_specs=[
+            TPLayerSpec(
+                patterns=[r".*\.self_attn\.o_proj\.weight$"],
+                partition_type=PartitionType.ROW,
+            ),
+            TPLayerSpec(
+                patterns=[r".*\.self_attn\.[qkv]_proj\.weight$"],
+                partition_type=PartitionType.COLUMN,
+            ),
+            TPLayerSpec(
+                patterns=[r".*\.linear_attn\.out_proj\.weight$"],
+                partition_type=PartitionType.ROW,
+            ),
+            TPLayerSpec(
+                patterns=[r".*\.linear_attn\.in_proj_qkv\.weight$"],
+                partition_type=PartitionType.COLUMN,
+                shape_resolver="qwen3_5_linear_attn_qkv",
+                partition_dim=0,
+            ),
+            TPLayerSpec(
+                patterns=[r".*\.linear_attn\.in_proj_z\.weight$"],
+                partition_type=PartitionType.COLUMN,
+            ),
+            TPLayerSpec(
+                patterns=[r".*\.linear_attn\.in_proj_[ab]\.weight$"],
+                partition_type=PartitionType.COLUMN,
+            ),
+            TPLayerSpec(
+                patterns=[r".*\.mlp\.down_proj\.weight$"],
+                partition_type=PartitionType.ROW,
+            ),
+            TPLayerSpec(
+                patterns=[r".*\.mlp\.(up|gate)_proj\.weight$"],
+                partition_type=PartitionType.COLUMN,
+            ),
+        ], )
+
     @staticmethod
     def phi3() -> AutoTPConfig:
         """Phi3 model with fused QKV and chunked MLP."""
@@ -546,6 +604,7 @@ def get_preset(model_type: str) -> Optional[AutoTPConfig]:
             "mixtral": AutoTPPresets.mixtral,
             "deepseek_v2": AutoTPPresets.deepseek_v2,
             "qwen2": AutoTPPresets.qwen2,
+            "qwen3_5": AutoTPPresets.qwen3_5,
             "phi3": AutoTPPresets.phi3,
         }
         preset_fn = presets.get(model_type.lower())
diff --git a/deepspeed/module_inject/layers.py b/deepspeed/module_inject/layers.py
index 0a8cda1d1daa..92e064e38399 100644
--- a/deepspeed/module_inject/layers.py
+++ b/deepspeed/module_inject/layers.py
@@ -23,7 +23,7 @@
 __all__ = [
     "TensorParallel_Layer", "LinearAllreduce", "LinearLayer", "LmHeadLinearAllreduce", "Yuan_LinearAllreduce",
     "Yuan_LinearLayer", "GateUpPack_LinearLayer", "Conv_LinearALlreduce", "fused_LinearLayer", "conv_LinearLayer",
-    "SubParamLinearLayer", "SubParamLinearAllreduce"
+    "SubParamLinearLayer", "SubParamLinearAllreduce", "DepthwiseConv1dLayer", "Qwen35LinearAttentionLayer"
 ]
 
 DEEPSPEED_AUTOTP_MODE = AUTOTP_MODE.INFERENCE
@@ -1433,6 +1433,316 @@ def _mark_uc_metadata(self):
                                 replicated=True)
 
 
+class DepthwiseConv1dLayer(TensorParallel_Layer):
+    """Tensor-parallel depthwise Conv1d used by Qwen 3.5 linear attention."""
+
+    def __init__(self, module, mp_group, **kwargs):
+        super(DepthwiseConv1dLayer, self).__init__(mp_group, **kwargs)
+        if module.groups != module.in_channels or module.groups != module.out_channels or module.weight.shape[1] != 1:
+            raise ValueError(f"AutoTP layer '{self.name}' only supports depthwise Conv1d with one weight channel.")
+
+        self.weight = module.weight
+        self.bias = module.bias
+        self.stride = module.stride
+        self.padding = module.padding
+        self.dilation = module.dilation
+        self.padding_mode = module.padding_mode
+        self._reversed_padding_repeated_twice = getattr(module, "_reversed_padding_repeated_twice", None)
+        self.kernel_size = module.kernel_size
+
+        self._orig_weight_shape = tuple(module.weight.shape)
+        self._orig_bias_shape = tuple(module.bias.shape) if module.bias is not None else None
+        self._logical_weight_shape = tuple(module.weight.shape)
+        self.sub_param_sizes = kwargs.get('sub_param_sizes', None)
+
+        if self._should_materialize_tp_partition():
+            self._tp_partition([self.weight, self.bias])
+
+        self.in_channels = self.weight.shape[0]
+        self.out_channels = self.weight.shape[0]
+        self.groups = self.weight.shape[0]
+
+        self.support_training = True
+        self.config_tp_params(self.weight)
+        if self.bias is not None:
+            self.config_tp_params(self.bias)
+        self._mark_uc_metadata()
+
+    def forward(self, input):
+        if self.padding_mode != 'zeros':
+            return F.conv1d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode), self.weight,
+                            self.bias, self.stride, (0, ), self.dilation, self.groups)
+        return F.conv1d(input, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
+
+    @torch.no_grad()
+    def gather_params(self, params_list):
+        for idx, param in enumerate(params_list):
+            if param is None:
+                continue
+            params_list[idx].data_partition = param.data
+            logical_shape = self._logical_weight_shape if idx == 0 else self._orig_bias_shape
+            full_view = _gather_logical_tensor(param,
+                                               logical_shape,
+                                               0,
+                                               self.mp_group,
+                                               self.tp_world_size,
+                                               name=self.name,
+                                               subparam_sizes=self.sub_param_sizes)
+            params_list[idx].data = full_view.reshape(self._orig_weight_shape if idx == 0 else self._orig_bias_shape)
+
+    @torch.no_grad()
+    def _tp_partition(self, params_list):
+        for idx, param in enumerate(params_list):
+            if param is None:
+                continue
+            logical_shape = self._logical_weight_shape if idx == 0 else self._orig_bias_shape
+            view = param.reshape(logical_shape)
+            partitioned = _partition_logical_tensor(view,
+                                                    0,
+                                                    self.tp_world_size,
+                                                    self.tp_index,
+                                                    name=self.name,
+                                                    subparam_sizes=self.sub_param_sizes)
+            target = partitioned.reshape(-1) if idx == 1 else partitioned
+            params_list[idx].data = self.move(target).detach()
+
+    def _mark_uc_metadata(self):
+        self._set_param_uc_meta(self.weight,
+                                partition_type='column',
+                                partition_dim=0,
+                                logical_shape=self._logical_weight_shape,
+                                output_shape=(self._orig_weight_shape[0], ),
+                                original_shape=self._orig_weight_shape,
+                                target_partition_shape=self.weight.shape)
+        if self.bias is not None:
+            self._set_param_uc_meta(self.bias,
+                                    partition_type='column',
+                                    partition_dim=0,
+                                    logical_shape=self._orig_bias_shape,
+                                    output_shape=self._orig_bias_shape,
+                                    original_shape=self._orig_bias_shape,
+                                    target_partition_shape=self.bias.shape,
+                                    is_bias=True)
+
+
+class Qwen35LinearAttentionLayer(nn.Module):
+    """Local-head tensor-parallel wrapper for Qwen 3.5 gated linear attention."""
+
+    def __init__(self, module, mp_group, **kwargs):
+        super().__init__()
+        self.mp_group = mp_group
+        self.tp_world_size = dist.get_world_size(mp_group) if mp_group is not None else 1
+        self.tp_index = dist.get_rank(mp_group) if mp_group is not None else 0
+        self.name = kwargs.get('name', 'linear_attn')
+
+        self.hidden_size = module.hidden_size
+        self.head_k_dim = module.head_k_dim
+        self.head_v_dim = module.head_v_dim
+        self.conv_kernel_size = module.conv_kernel_size
+        self.layer_idx = module.layer_idx
+        self.activation = module.activation
+        self.act = module.act
+        self.layer_norm_epsilon = module.layer_norm_epsilon
+        self.causal_conv1d_fn = module.causal_conv1d_fn
+        self.causal_conv1d_update = module.causal_conv1d_update
+        self.chunk_gated_delta_rule = module.chunk_gated_delta_rule
+        self.recurrent_gated_delta_rule = module.recurrent_gated_delta_rule
+
+        self._orig_num_k_heads = int(module.num_k_heads)
+        self._orig_num_v_heads = int(module.num_v_heads)
+        self._orig_key_dim = int(module.key_dim)
+        self._orig_value_dim = int(module.value_dim)
+        self._orig_conv_dim = int(module.conv_dim)
+
+        self.key_dim = get_shard_size(self._orig_key_dim, self.tp_world_size, f"{self.name}.in_proj_qkv")
+        self.value_dim = get_shard_size(self._orig_value_dim, self.tp_world_size, f"{self.name}.in_proj_z")
+        if self.key_dim % self.head_k_dim != 0:
+            raise ValueError(
+                f"AutoTP layer '{self.name}' resolved local key_dim={self.key_dim} not divisible by head_k_dim={self.head_k_dim}."
+            )
+        if self.value_dim % self.head_v_dim != 0:
+            raise ValueError(
+                f"AutoTP layer '{self.name}' resolved local value_dim={self.value_dim} not divisible by head_v_dim={self.head_v_dim}."
+            )
+        self.num_k_heads = self.key_dim // self.head_k_dim
+        self.num_v_heads = self.value_dim // self.head_v_dim
+        self.conv_dim = self.key_dim * 2 + self.value_dim
+
+        self.in_proj_qkv = SubParamLinearLayer(module.in_proj_qkv,
+                                               mp_group,
+                                               shape=((self._orig_key_dim, self._orig_key_dim, self._orig_value_dim),
+                                                      -1),
+                                               partition_dim=0,
+                                               name=f"{self.name}.in_proj_qkv")
+        self.conv1d = DepthwiseConv1dLayer(module.conv1d,
+                                           mp_group,
+                                           name=f"{self.name}.conv1d",
+                                           sub_param_sizes=(self._orig_key_dim, self._orig_key_dim,
+                                                            self._orig_value_dim))
+        self.in_proj_z = LinearLayer(module.in_proj_z, mp_group, name=f"{self.name}.in_proj_z")
+        self.in_proj_a = LinearLayer(module.in_proj_a, mp_group, name=f"{self.name}.in_proj_a")
+        self.in_proj_b = LinearLayer(module.in_proj_b, mp_group, name=f"{self.name}.in_proj_b")
+        self.norm = module.norm
+        self.out_proj = LinearAllreduce(module.out_proj, mp_group, name=f"{self.name}.out_proj")
+
+        self.dt_bias = module.dt_bias
+        self.A_log = module.A_log
+        self._sharded_param_shapes = {
+            "dt_bias": tuple(module.dt_bias.shape),
+            "A_log": tuple(module.A_log.shape),
+        }
+        self._configure_sharded_vector_param(self.dt_bias, "dt_bias")
+        self._configure_sharded_vector_param(self.A_log, "A_log")
+        if self._should_materialize_tp_partition():
+            self._partition_sharded_vector_params([self.dt_bias, self.A_log])
+        self._set_sharded_vector_param_meta(self.dt_bias, "dt_bias")
+        self._set_sharded_vector_param_meta(self.A_log, "A_log")
+
+    def _should_materialize_tp_partition(self):
+        return self.mp_group is not None
+
+    def _move_tensor(self, tensor):
+        if tensor.is_meta:
+            return tensor
+        device = 'cpu' if TensorParallel_Layer.keep_module_on_host else get_accelerator().current_device_name()
+        return tensor.to(device, copy=not TensorParallel_Layer.keep_module_on_host)
+
+    def _configure_sharded_vector_param(self, param, param_name):
+        param_name_full = f"{self.name}.{param_name}"
+        if is_autotp_training_mode():
+            param.requires_grad = True
+        else:
+            param.requires_grad = False
+        param.gather_params = self._gather_sharded_vector_params
+        param._tp_partition = self._partition_sharded_vector_params
+        setattr(param, DS_TENSOR_MODEL_PARALLEL, True)
+        setattr(param, DS_IS_REPLACED_MODULE, True)
+        setattr(param, "_ds_autotp_shard_name", param_name_full)
+
+    def _set_sharded_vector_param_meta(self, param, param_name):
+        setattr(
+            param, DS_AUTOTP_UC_META,
+            _build_param_uc_restore_meta(partition_type='column',
+                                         partition_dim=0,
+                                         logical_shape=self._sharded_param_shapes[param_name],
+                                         output_shape=self._sharded_param_shapes[param_name],
+                                         target_partition_shape=tuple(param.shape),
+                                         original_shape=self._sharded_param_shapes[param_name],
+                                         is_bias=False,
+                                         replicated=False))
+
+    @torch.no_grad()
+    def _gather_sharded_vector_params(self, params_list):
+        for param in params_list:
+            if param is None:
+                continue
+            param.data_partition = param.data
+            full_view = _gather_logical_tensor(param,
+                                               self._sharded_param_shapes[param._ds_autotp_shard_name.rsplit('.',
+                                                                                                             1)[-1]],
+                                               0,
+                                               self.mp_group,
+                                               self.tp_world_size,
+                                               name=param._ds_autotp_shard_name)
+            param.data = full_view.reshape(self._sharded_param_shapes[param._ds_autotp_shard_name.rsplit('.', 1)[-1]])
+
+    @torch.no_grad()
+    def _partition_sharded_vector_params(self, params_list):
+        for param in params_list:
+            if param is None:
+                continue
+            shape_key = param._ds_autotp_shard_name.rsplit('.', 1)[-1]
+            logical_shape = self._sharded_param_shapes[shape_key]
+            partitioned = _partition_logical_tensor(param.reshape(logical_shape),
+                                                    0,
+                                                    self.tp_world_size,
+                                                    self.tp_index,
+                                                    name=param._ds_autotp_shard_name)
+            param.data = self._move_tensor(partitioned.reshape(-1)).detach()
+
+    def forward(self, hidden_states, cache_params=None, attention_mask=None):
+        hidden_states = hidden_states if attention_mask is None or attention_mask.shape[
+            1] <= 1 or attention_mask.shape[0] <= 1 else (hidden_states *
+                                                          attention_mask[:, :, None]).to(hidden_states.dtype)
+
+        batch_size, seq_len, _ = hidden_states.shape
+        use_precomputed_states = cache_params is not None and cache_params.has_previous_state(
+            self.layer_idx) and seq_len == 1
+
+        if use_precomputed_states:
+            conv_state = cache_params.layers[self.layer_idx].conv_states
+            recurrent_state = cache_params.layers[self.layer_idx].recurrent_states
+
+        mixed_qkv = self.in_proj_qkv(hidden_states).transpose(1, 2)
+
+        z = self.in_proj_z(hidden_states)
+        z = z.reshape(batch_size, seq_len, -1, self.head_v_dim)
+
+        b = self.in_proj_b(hidden_states)
+        a = self.in_proj_a(hidden_states)
+
+        if use_precomputed_states:
+            mixed_qkv = self.causal_conv1d_update(mixed_qkv, conv_state, self.conv1d.weight.squeeze(1),
+                                                  self.conv1d.bias, self.activation)
+        else:
+            if cache_params is not None:
+                conv_state = F.pad(mixed_qkv, (self.conv_kernel_size - mixed_qkv.shape[-1], 0))
+                conv_state = cache_params.update_conv_state(conv_state, self.layer_idx)
+            if self.causal_conv1d_fn is not None:
+                mixed_qkv = self.causal_conv1d_fn(x=mixed_qkv,
+                                                  weight=self.conv1d.weight.squeeze(1),
+                                                  bias=self.conv1d.bias,
+                                                  activation=self.activation,
+                                                  seq_idx=None)
+            else:
+                mixed_qkv = F.silu(self.conv1d(mixed_qkv)[:, :, :seq_len])
+
+        mixed_qkv = mixed_qkv.transpose(1, 2)
+        query, key, value = torch.split(mixed_qkv, [self.key_dim, self.key_dim, self.value_dim], dim=-1)
+
+        query = query.reshape(batch_size, seq_len, -1, self.head_k_dim)
+        key = key.reshape(batch_size, seq_len, -1, self.head_k_dim)
+        value = value.reshape(batch_size, seq_len, -1, self.head_v_dim)
+
+        beta = b.sigmoid()
+        g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)
+        if self.num_v_heads // self.num_k_heads > 1:
+            repeat_factor = self.num_v_heads // self.num_k_heads
+            query = query.repeat_interleave(repeat_factor, dim=2)
+            key = key.repeat_interleave(repeat_factor, dim=2)
+
+        if not use_precomputed_states:
+            core_attn_out, last_recurrent_state = self.chunk_gated_delta_rule(query,
+                                                                              key,
+                                                                              value,
+                                                                              g=g,
+                                                                              beta=beta,
+                                                                              initial_state=None,
+                                                                              output_final_state=cache_params
+                                                                              is not None,
+                                                                              use_qk_l2norm_in_kernel=True)
+        else:
+            core_attn_out, last_recurrent_state = self.recurrent_gated_delta_rule(query,
+                                                                                  key,
+                                                                                  value,
+                                                                                  g=g,
+                                                                                  beta=beta,
+                                                                                  initial_state=recurrent_state,
+                                                                                  output_final_state=cache_params
+                                                                                  is not None,
+                                                                                  use_qk_l2norm_in_kernel=True)
+
+        if cache_params is not None:
+            cache_params.update_recurrent_state(last_recurrent_state, self.layer_idx)
+
+        core_attn_out = core_attn_out.reshape(-1, self.head_v_dim)
+        z = z.reshape(-1, self.head_v_dim)
+        core_attn_out = self.norm(core_attn_out, z)
+        core_attn_out = core_attn_out.reshape(batch_size, seq_len, -1)
+
+        return self.out_proj(core_attn_out)
+
+
 class RMSNormalize(nn.Module):
 
     def __init__(self, dim=None, dtype=torch.float, eps=1e-5, weight=None):
diff --git a/deepspeed/runtime/tensor_parallel/config.py b/deepspeed/runtime/tensor_parallel/config.py
index 56abf5868d5d..8b23d0fa8baf 100644
--- a/deepspeed/runtime/tensor_parallel/config.py
+++ b/deepspeed/runtime/tensor_parallel/config.py
@@ -89,7 +89,7 @@ class TPTrainingConfig(DeepSpeedConfigModel):
     preset_model: Optional[str] = None
     """
     Use a built-in preset for common model architectures.
-    Available presets: "llama", "bloom", "chatglm", "mixtral", "deepseek_v2", "qwen2", "phi3"
+    Available presets: "llama", "bloom", "chatglm", "mixtral", "deepseek_v2", "qwen2", "qwen3_5", "phi3"
     """
 
     #The following parameters are required by autoTP parser.
diff --git a/docs/_pages/config-json.md b/docs/_pages/config-json.md
index f8209c8d8068..17493d083322 100755
--- a/docs/_pages/config-json.md
+++ b/docs/_pages/config-json.md
@@ -753,7 +753,7 @@ Configuring the asynchronous I/O module for offloading parameter and optimizer s
 ### Tensor Parallel (AutoTP)
 Configure AutoTP tensor parallelism for training via the DeepSpeed config and hybrid TP + ZeRO. AutoTP supports ZeRO stages 0, 1, and 2 (stage 3 is not supported). `deepspeed.tp_model_init()` remains supported for backward compatibility but is not required when `tensor_parallel` is set in the config.
 
-When a HuggingFace model provides a built-in `tp_plan` (via `model.config.base_model_tp_plan`), DeepSpeed automatically detects and uses it. In this case, neither `preset_model` nor `partition_config` is required -- just set `autotp_size`. If `partition_config` is also provided, it takes precedence over the model's `tp_plan`.
+When a HuggingFace model provides a built-in `tp_plan` (via `model.config.base_model_tp_plan`), DeepSpeed automatically detects and uses it. In this case, neither `preset_model` nor `partition_config` is required -- just set `autotp_size`. If `preset_model` or `partition_config` is also provided, that explicit manual AutoTP config takes precedence over the model's `tp_plan`. If both are provided, `partition_config` still overrides or extends the preset according to `use_default_specs`.
 ```json
 "tensor_parallel": {
     "autotp_size": 4,
@@ -786,7 +786,10 @@ When a HuggingFace model provides a built-in `tp_plan` (via `model.config.base_m
 
 | Description | Default |
 | ----------------------------------------------------------------------------------------------------- | ------- |
-| Built-in model presets: `llama`, `bloom`, `chatglm`, `mixtral`, `deepseek_v2`, `qwen2`, `phi3`. | `null` |
+| Built-in model presets: `llama`, `bloom`, `chatglm`, `mixtral`, `deepseek_v2`, `qwen2`, `qwen3_5`, `phi3`. | `null` |
+
+> **`qwen3_5` coverage note:** This preset targets dense Qwen 3.5 decoder layers. It covers `mlp.{gate,up,down}_proj` in every decoder layer, `self_attn.{q,k,v,o}_proj` in full-attention layers, and `linear_attn.{in_proj_qkv,in_proj_z,in_proj_a,in_proj_b,out_proj}` in hybrid linear-attention layers, with the fused `linear_attn.in_proj_qkv` split sizes derived from the model config. `linear_attn.conv1d`, `linear_attn.dt_bias`, and `linear_attn.A_log` are not matched by preset patterns; the built-in gated linear-attention module replacement shards them instead, so `strict_mode` is compatible with dense hybrid Qwen 3.5 models. Qwen 3.5 MoE weights are not covered.
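+
+A minimal sketch of enabling this preset through the DeepSpeed config (same layout as the example above; the `autotp_size` value is illustrative):
+
+```json
+"tensor_parallel": {
+    "autotp_size": 2,
+    "preset_model": "qwen3_5"
+}
+```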
+
 
 ***tp_overlap_comm***: [boolean]
 
diff --git a/tests/unit/model_parallelism/test_autotp_preset_qwen35.py b/tests/unit/model_parallelism/test_autotp_preset_qwen35.py
new file mode 100644
index 000000000000..d930c0bdc754
--- /dev/null
+++ b/tests/unit/model_parallelism/test_autotp_preset_qwen35.py
@@ -0,0 +1,627 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+from copy import deepcopy
+from types import SimpleNamespace
+
+import deepspeed
+import deepspeed.comm as dist
+import pytest
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from unit.common import DistributedTest, preferred_dtype
+from deepspeed.accelerator import get_accelerator
+from deepspeed.module_inject.auto_tp import AutoTP
+from deepspeed.module_inject.autotp_config import AutoTPConfig, AutoTPPresets, PartitionType
+from deepspeed.module_inject.layers import DepthwiseConv1dLayer, LinearAllreduce, LinearLayer, SubParamLinearLayer
+from deepspeed.runtime.tensor_parallel.config import TPTrainingConfig
+from deepspeed.utils import groups
+
+
+def skip_on_device():
+    if get_accelerator().device_name() == "xpu":
+        pytest.skip("XPU requires a higher version for test")
+
+
+def assert_close_for_preferred_dtype(actual, expected):
+    atol = 5e-3
+    rtol = 2e-2
+    if preferred_dtype() is torch.float32:
+        atol = 1e-5
+        rtol = 1e-5
+    torch.testing.assert_close(actual, expected, atol=atol, rtol=rtol)
+
+
+def make_mock_qwen35_config():
+    return SimpleNamespace(
+        model_type="qwen3_5_text",
+        linear_num_key_heads=2,
+        linear_key_head_dim=4,
+        linear_num_value_heads=2,
+        linear_value_head_dim=4,
+    )
+
+
+def load_real_qwen35_classes():
+    try:
+        from transformers import Qwen3_5ForCausalLM, Qwen3_5TextConfig
+    except ImportError:
+        pytest.skip("transformers with Qwen3.5 support is required")
+    return Qwen3_5TextConfig, Qwen3_5ForCausalLM
+
+
+def make_real_qwen35_text_config():
+    qwen35_text_config, _ = load_real_qwen35_classes()
+    return qwen35_text_config(
+        vocab_size=128,
+        hidden_size=64,
+        intermediate_size=256,
+        num_hidden_layers=4,
+        num_attention_heads=4,
+        num_key_value_heads=2,
+        head_dim=16,
+        linear_num_key_heads=4,
+        linear_key_head_dim=8,
+        linear_num_value_heads=4,
+        linear_value_head_dim=8,
+        attention_dropout=0.0,
+        use_cache=False,
+        layer_types=["linear_attention", "linear_attention", "linear_attention", "full_attention"],
+    )
+
+
+class MockLinearAttention(nn.Module):
+    """GatedDeltaNet-style linear-attention submodule."""
+
+    def __init__(self, hidden_dim):
+        super().__init__()
+        config = make_mock_qwen35_config()
+        self.hidden_size = hidden_dim
+        self.num_k_heads = config.linear_num_key_heads
+        self.num_v_heads = config.linear_num_value_heads
+        self.head_k_dim = config.linear_key_head_dim
+        self.head_v_dim = config.linear_value_head_dim
+        self.key_dim = self.num_k_heads * self.head_k_dim
+        self.value_dim = self.num_v_heads * self.head_v_dim
+        self.conv_dim = self.key_dim * 2 + self.value_dim
+        self.conv_kernel_size = 4
+        self.layer_idx = 0
+        self.activation = "silu"
+        self.act = F.silu
+        self.layer_norm_epsilon = 1e-6
+
+        self.conv1d = nn.Conv1d(self.conv_dim,
+                                self.conv_dim,
+                                bias=False,
+                                kernel_size=self.conv_kernel_size,
+                                groups=self.conv_dim,
+                                padding=self.conv_kernel_size - 1)
+        self.dt_bias = nn.Parameter(torch.ones(self.num_v_heads))
+        self.A_log = nn.Parameter(torch.log(torch.linspace(1.0, 2.0, self.num_v_heads)))
+        self.norm = MockRMSNormGated(self.head_v_dim, eps=self.layer_norm_epsilon)
+        self.in_proj_qkv = nn.Linear(hidden_dim, self.conv_dim, bias=False)
+        self.in_proj_z = nn.Linear(hidden_dim, self.value_dim, bias=False)
+        self.in_proj_b = nn.Linear(hidden_dim, self.num_v_heads, bias=False)
+        self.in_proj_a = nn.Linear(hidden_dim, self.num_v_heads, bias=False)
+        self.out_proj = nn.Linear(self.value_dim, hidden_dim, bias=False)
+        self.causal_conv1d_fn = None
+        self.causal_conv1d_update = mock_torch_causal_conv1d_update
+        self.chunk_gated_delta_rule = mock_chunk_gated_delta_rule
+        self.recurrent_gated_delta_rule = mock_recurrent_gated_delta_rule
+
+    def forward(self, x):
+        batch_size, seq_len, _ = x.shape
+
+        mixed_qkv = self.in_proj_qkv(x).transpose(1, 2)
+        mixed_qkv = F.silu(self.conv1d(mixed_qkv)[:, :, :seq_len]).transpose(1, 2)
+
+        z = self.in_proj_z(x).reshape(batch_size, seq_len, -1, self.head_v_dim)
+        b = self.in_proj_b(x)
+        a = self.in_proj_a(x)
+
+        query, key, value = torch.split(mixed_qkv, [self.key_dim, self.key_dim, self.value_dim], dim=-1)
+        query = query.reshape(batch_size, seq_len, -1, self.head_k_dim)
+        key = key.reshape(batch_size, seq_len, -1, self.head_k_dim)
+        value = value.reshape(batch_size, seq_len, -1, self.head_v_dim)
+
+        beta = b.sigmoid()
+        g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)
+        if self.num_v_heads // self.num_k_heads > 1:
+            repeat_factor = self.num_v_heads // self.num_k_heads
+            query = query.repeat_interleave(repeat_factor, dim=2)
+            key = key.repeat_interleave(repeat_factor, dim=2)
+
+        core_attn_out, _ = self.chunk_gated_delta_rule(query,
+                                                       key,
+                                                       value,
+                                                       g=g,
+                                                       beta=beta,
+                                                       initial_state=None,
+                                                       output_final_state=False,
+                                                       use_qk_l2norm_in_kernel=True)
+        core_attn_out = self.norm(core_attn_out.reshape(-1, self.head_v_dim), z.reshape(-1, self.head_v_dim))
+        return self.out_proj(core_attn_out.reshape(batch_size, seq_len, -1))
+
+
+class MockRMSNormGated(nn.Module):
+
+    def __init__(self, hidden_size, eps=1e-6):
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(hidden_size))
+        self.variance_epsilon = eps
+
+    def forward(self, hidden_states, gate):
+        input_dtype = hidden_states.dtype
+        hidden_states = hidden_states.to(torch.float32)
+        variance = hidden_states.pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+        hidden_states = self.weight * hidden_states.to(input_dtype)
+        hidden_states = hidden_states * F.silu(gate.to(torch.float32))
+        return hidden_states.to(input_dtype)
+
+
+def mock_torch_causal_conv1d_update(hidden_states, conv_state, weight, bias=None, activation=None):
+    _, hidden_size, seq_len = hidden_states.shape
+    state_len = conv_state.shape[-1]
+
+    hidden_states_new = torch.cat([conv_state, hidden_states], dim=-1).to(weight.dtype)
+    conv_state.copy_(hidden_states_new[:, :, -state_len:])
+    out = F.conv1d(hidden_states_new, weight.unsqueeze(1), bias, padding=0, groups=hidden_size)
+    out = F.silu(out[:, :, -seq_len:])
+    return out.to(hidden_states.dtype)
+
+
+def mock_chunk_gated_delta_rule(query,
+                                key,
+                                value,
+                                g,
+                                beta,
+                                initial_state=None,
+                                output_final_state=False,
+                                use_qk_l2norm_in_kernel=True):
+    del initial_state, use_qk_l2norm_in_kernel
+    qk_mix = query + key
+    core_attn_out = value + beta.unsqueeze(-1) * qk_mix + g.to(value.dtype).unsqueeze(-1)
+    last_recurrent_state = core_attn_out[:, :, -1].contiguous() if output_final_state else None
+    return core_attn_out, last_recurrent_state
+
+
+def mock_recurrent_gated_delta_rule(query,
+                                    key,
+                                    value,
+                                    g,
+                                    beta,
+                                    initial_state=None,
+                                    output_final_state=False,
+                                    use_qk_l2norm_in_kernel=True):
+    del initial_state
+    return mock_chunk_gated_delta_rule(query,
+                                       key,
+                                       value,
+                                       g=g,
+                                       beta=beta,
+                                       initial_state=None,
+                                       output_final_state=output_final_state,
+                                       use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel)
+
+
+class MockFullAttention(nn.Module):
+    """Standard multi-head attention submodule."""
+
+    def __init__(self, hidden_dim):
+        super().__init__()
+        self.q_proj = nn.Linear(hidden_dim, hidden_dim, bias=False)
+        self.k_proj = nn.Linear(hidden_dim, hidden_dim, bias=False)
+        self.v_proj = nn.Linear(hidden_dim, hidden_dim, bias=False)
+        self.o_proj = nn.Linear(hidden_dim, hidden_dim, bias=False)
+
+    def forward(self, x):
+        return self.o_proj(self.q_proj(x) + self.k_proj(x) + self.v_proj(x))
+
+
+class MockGatedQProjAttention(nn.Module):
+    """Full-attention block with a widened q_proj output."""
+
+    def __init__(self, hidden_dim):
+        super().__init__()
+        self.q_proj = nn.Linear(hidden_dim, hidden_dim * 2, bias=False)
+        self.o_proj = nn.Linear(hidden_dim * 2, hidden_dim, bias=False)
+
+    def forward(self, x):
+        return self.o_proj(self.q_proj(x))
+
+
+class MockMLP(nn.Module):
+
+    def __init__(self, hidden_dim):
+        super().__init__()
+        self.gate_proj = nn.Linear(hidden_dim, hidden_dim * 4, bias=False)
+        self.up_proj = nn.Linear(hidden_dim, hidden_dim * 4, bias=False)
+        self.down_proj = nn.Linear(hidden_dim * 4, hidden_dim, bias=False)
+
+    def forward(self, x):
+        return self.down_proj(self.gate_proj(x) * self.up_proj(x))
+
+
+class MockLinearAttnDecoderLayer(nn.Module):
+
+    def __init__(self, hidden_dim):
+        super().__init__()
+        self.linear_attn = MockLinearAttention(hidden_dim)
+        self.mlp = MockMLP(hidden_dim)
+
+    def forward(self, x):
+        return self.mlp(self.linear_attn(x))
+
+
+class MockFullAttnDecoderLayer(nn.Module):
+
+    def __init__(self, hidden_dim):
+        super().__init__()
+        self.self_attn = MockFullAttention(hidden_dim)
+        self.mlp = MockMLP(hidden_dim)
+
+    def forward(self, x):
+        return self.mlp(self.self_attn(x))
+
+
+class MockGatedQProjLayer(nn.Module):
+
+    def __init__(self, hidden_dim):
+        super().__init__()
+        self.self_attn = MockGatedQProjAttention(hidden_dim)
+
+    def forward(self, x):
+        return self.self_attn(x)
+
+
+class MockQwen35HybridModel(nn.Module):
+    """4-layer hybrid model with full attention every fourth layer."""
+
+    def __init__(self, hidden_dim):
+        super().__init__()
+        self.config = make_mock_qwen35_config()
+        self.layers = nn.ModuleList([
+            MockLinearAttnDecoderLayer(hidden_dim),
+            MockLinearAttnDecoderLayer(hidden_dim),
+            MockLinearAttnDecoderLayer(hidden_dim),
+            MockFullAttnDecoderLayer(hidden_dim),
+        ])
+
+    def forward(self, x):
+        for layer in self.layers:
+            x = layer(x)
+        return x
+
+
+class MockGatedQProjModel(nn.Module):
+    """Single-layer model for widened q_proj sharding checks."""
+
+    def __init__(self, hidden_dim):
+        super().__init__()
+        self.config = make_mock_qwen35_config()
+        self.layers = nn.ModuleList([MockGatedQProjLayer(hidden_dim)])
+
+    def forward(self, x):
+        for layer in self.layers:
+            x = layer(x)
+        return x
+
+
+class TestQwen35PresetPatterns:
+    """Verify qwen3_5 preset resolution and pattern coverage."""
+
+    def test_get_preset_returns_config(self):
+        config = AutoTPPresets.get_preset("qwen3_5")
+        assert config is not None
+        assert isinstance(config, AutoTPConfig)
+
+    def test_preset_has_eight_layer_specs(self):
+        config = AutoTPPresets.get_preset("qwen3_5")
+        assert len(config.layer_specs) == 8
+
+    def test_self_attn_column_parallel(self):
+        config = AutoTPPresets.get_preset("qwen3_5")
+        for proj in ["q_proj", "k_proj", "v_proj"]:
+            spec = config.find_matching_spec(f"model.layers.3.self_attn.{proj}.weight")
+            assert spec is not None, f"self_attn.{proj} should match"
+            assert spec.partition_type == PartitionType.COLUMN
+
+    def test_self_attn_row_parallel(self):
+        config = AutoTPPresets.get_preset("qwen3_5")
+        spec = config.find_matching_spec("model.layers.3.self_attn.o_proj.weight")
+        assert spec is not None
+        assert spec.partition_type == PartitionType.ROW
+
+    def test_mlp_matches_in_linear_attn_layer(self):
+        config = AutoTPPresets.get_preset("qwen3_5")
+
+        spec = config.find_matching_spec("model.layers.0.mlp.gate_proj.weight")
+        assert spec is not None
+        assert spec.partition_type == PartitionType.COLUMN
+
+        spec = config.find_matching_spec("model.layers.0.mlp.up_proj.weight")
+        assert spec is not None
+        assert spec.partition_type == PartitionType.COLUMN
+
+        spec = config.find_matching_spec("model.layers.0.mlp.down_proj.weight")
+        assert spec is not None
+        assert spec.partition_type == PartitionType.ROW
+
+    def test_linear_attn_supported_linears_match(self):
+        config = AutoTPPresets.get_preset("qwen3_5")
+
+        qkv_spec = config.find_matching_spec("model.layers.0.linear_attn.in_proj_qkv.weight")
+        assert qkv_spec is not None
+        assert qkv_spec.partition_type == PartitionType.COLUMN
+        assert qkv_spec.shape_resolver == "qwen3_5_linear_attn_qkv"
+
+        z_spec = config.find_matching_spec("model.layers.0.linear_attn.in_proj_z.weight")
+        assert z_spec is not None
+        assert z_spec.partition_type == PartitionType.COLUMN
+
+        a_spec = config.find_matching_spec("model.layers.0.linear_attn.in_proj_a.weight")
+        assert a_spec is not None
+        assert a_spec.partition_type == PartitionType.COLUMN
+
+        b_spec = config.find_matching_spec("model.layers.0.linear_attn.in_proj_b.weight")
+        assert b_spec is not None
+        assert b_spec.partition_type == PartitionType.COLUMN
+
+        out_spec = config.find_matching_spec("model.layers.0.linear_attn.out_proj.weight")
+        assert out_spec is not None
+        assert out_spec.partition_type == PartitionType.ROW
+
+    def test_linear_attn_remaining_weights_not_matched(self):
+        config = AutoTPPresets.get_preset("qwen3_5")
+        unmatched_names = [
+            "model.layers.0.linear_attn.conv1d.weight",
+            "model.layers.0.linear_attn.dt_bias",
+            "model.layers.0.linear_attn.A_log",
+        ]
+        for name in unmatched_names:
+            assert config.find_matching_spec(name) is None, f"Preset should not match {name}"
+
+    def test_moe_names_not_matched(self):
+        config = AutoTPPresets.get_preset("qwen3_5")
+        moe_names = [
+            "model.layers.0.mlp.experts.gate_up_proj.weight",
+            "model.layers.0.mlp.experts.down_proj.weight",
+            "model.layers.0.mlp.shared_expert.gate_proj.weight",
+            "model.layers.0.mlp.shared_expert.up_proj.weight",
+            "model.layers.0.mlp.shared_expert.down_proj.weight",
+        ]
+        for name in moe_names:
+            assert config.find_matching_spec(name) is None, f"Dense preset should not match {name}"
+
+    def test_multimodal_prefix_still_matches(self):
+        config = AutoTPPresets.get_preset("qwen3_5")
+
+        spec = config.find_matching_spec("model.language_model.layers.3.self_attn.q_proj.weight")
+        assert spec is not None
+        assert spec.partition_type == PartitionType.COLUMN
+
+        spec = config.find_matching_spec("model.language_model.layers.0.mlp.down_proj.weight")
+        assert spec is not None
+        assert spec.partition_type == PartitionType.ROW
+
+    def test_preset_via_get_partition_config_object(self):
+        tp_config = TPTrainingConfig(autotp_size=2, preset_model="qwen3_5")
+        config = tp_config.get_partition_config_object()
+        assert config is not None
+        assert config.tp_size == 2
+        assert len(config.layer_specs) == 8
+
+
+class TestQwen35MockHybridModel(DistributedTest):
+    world_size = 2
+    reuse_dist_env = False
+
+    def _apply_preset(self, model, strict_mode=False):
+        groups._init_tp_mesh_device(tensor_model_parallel_size=2)
+        config = AutoTPPresets.get_preset("qwen3_5")
+        config.strict_mode = strict_mode
+        autotp = AutoTP(
+            module=model,
+            all_reduce_linears=[],
+            prefix="",
+            state_dict=None,
+            linear_layer_setting=None,
+            orig_layer_impl=None,
+            keep_module_on_host=False,
+            partition_config=config,
+        )
+        autotp.set_tensor_parallel_config(2, groups.get_tensor_model_parallel_group())
+        autotp.update_linear_policies()
+        autotp._replace_module(model)
+        return model
+
+    def test_preset_replaces_supported_layers(self):
+        skip_on_device()
+        model = MockQwen35HybridModel(hidden_dim=16)
+        model = self._apply_preset(model)
+
+        assert isinstance(model.layers[3].self_attn.q_proj, LinearLayer)
+        assert isinstance(model.layers[3].self_attn.k_proj, LinearLayer)
+        assert isinstance(model.layers[3].self_attn.v_proj, LinearLayer)
+        assert isinstance(model.layers[3].self_attn.o_proj, LinearAllreduce)
+
+        assert isinstance(model.layers[3].mlp.gate_proj, LinearLayer)
+        assert isinstance(model.layers[3].mlp.up_proj, LinearLayer)
+        assert isinstance(model.layers[3].mlp.down_proj, LinearAllreduce)
+
+        for i in range(3):
+            assert isinstance(model.layers[i].mlp.gate_proj, LinearLayer)
+            assert isinstance(model.layers[i].mlp.up_proj, LinearLayer)
+            assert isinstance(model.layers[i].mlp.down_proj, LinearAllreduce)
+
+            assert isinstance(model.layers[i].linear_attn.in_proj_qkv, SubParamLinearLayer)
+            assert isinstance(model.layers[i].linear_attn.conv1d, DepthwiseConv1dLayer)
+            assert isinstance(model.layers[i].linear_attn.in_proj_z, LinearLayer)
+            assert isinstance(model.layers[i].linear_attn.in_proj_a, LinearLayer)
+            assert isinstance(model.layers[i].linear_attn.in_proj_b, LinearLayer)
+            assert isinstance(model.layers[i].linear_attn.out_proj, LinearAllreduce)
+            assert model.layers[i].linear_attn.num_k_heads == 1
+            assert model.layers[i].linear_attn.num_v_heads == 1
+            assert model.layers[i].linear_attn.key_dim == 4
+            assert model.layers[i].linear_attn.value_dim == 4
+            assert model.layers[i].linear_attn.conv_dim == 12
+            assert model.layers[i].linear_attn.in_proj_qkv.weight.shape == (12, 16)
+            assert model.layers[i].linear_attn.conv1d.weight.shape == (12, 1, 4)
+            assert model.layers[i].linear_attn.in_proj_z.weight.shape == (4, 16)
+            assert model.layers[i].linear_attn.in_proj_a.weight.shape == (1, 16)
+            assert model.layers[i].linear_attn.in_proj_b.weight.shape == (1, 16)
+            assert model.layers[i].linear_attn.dt_bias.shape == (1, )
+            assert model.layers[i].linear_attn.A_log.shape == (1, )
+
+    def test_strict_mode_accepts_full_linear_attn_block(self):
+        skip_on_device()
+        model = MockQwen35HybridModel(hidden_dim=16)
+        model = self._apply_preset(model, strict_mode=True)
+        assert isinstance(model.layers[0].linear_attn.in_proj_a, LinearLayer)
+        assert isinstance(model.layers[0].linear_attn.conv1d, DepthwiseConv1dLayer)
+
+    def test_linear_attn_supported_weights_shard_cleanly(self):
+        skip_on_device()
+        hidden_dim = 16
+        device = get_accelerator().current_device_name()
+
+        torch.manual_seed(1234)
+        model = MockQwen35HybridModel(hidden_dim).to(device=device, dtype=preferred_dtype())
+        baseline = deepcopy(model)
+        model = self._apply_preset(model)
+
+        torch.manual_seed(4321)
+        inputs = torch.randn(2, 3, hidden_dim, dtype=preferred_dtype(), device=device)
+        full_output = baseline(inputs)
+        tp_output = model(inputs)
+        assert_close_for_preferred_dtype(tp_output, full_output)
+
+    def test_gated_q_proj_output_width_still_shards_cleanly(self):
+        skip_on_device()
+        hidden_dim = 16
+        device = get_accelerator().current_device_name()
+
+        torch.manual_seed(1234)
+        model = MockGatedQProjModel(hidden_dim).to(device=device, dtype=preferred_dtype())
+        baseline = deepcopy(model)
+        model = self._apply_preset(model)
+
+        q_proj = model.layers[0].self_attn.q_proj
+        o_proj = model.layers[0].self_attn.o_proj
+
+        assert isinstance(q_proj, LinearLayer)
+        assert isinstance(o_proj, LinearAllreduce)
+        assert q_proj.weight.shape == (hidden_dim, hidden_dim)
+
+        torch.manual_seed(4321)
+        inputs = torch.randn(2, hidden_dim, dtype=preferred_dtype(), device=device)
+        full_output = baseline(inputs)
+        tp_output = model(inputs)
+        assert_close_for_preferred_dtype(tp_output, full_output)
+
+
+class TestQwen35RealHFModel(DistributedTest):
+    world_size = 2
+    reuse_dist_env = False
+
+    def _make_tp_config(self):
+        config = {
+            "train_micro_batch_size_per_gpu": 1,
+            "tensor_parallel": {
+                "autotp_size": 2,
+                "preset_model": "qwen3_5",
+            },
+            "optimizer": {
+                "type": "AdamW",
+                "params": {
+                    "lr": 1e-4,
+                },
+            },
+            "zero_optimization": {
+                "stage": 0,
+            },
+            "steps_per_print": 1,
+        }
+        if preferred_dtype() == torch.float16:
+            config["fp16"] = {"enabled": True}
+        elif preferred_dtype() == torch.bfloat16:
+            config["bf16"] = {"enabled": True}
+        return config
+
+    def _seed_all(self, seed):
+        torch.manual_seed(seed)
+        get_accelerator().manual_seed_all(seed)
+
+    def _build_real_baseline_and_engine(self):
+        skip_on_device()
+        _, qwen35_for_causal_lm = load_real_qwen35_classes()
+
+        config = make_real_qwen35_text_config()
+        device = get_accelerator().current_device_name()
+
+        self._seed_all(1234)
+        baseline = qwen35_for_causal_lm(config).to(device=device, dtype=preferred_dtype())
+        baseline.eval()
+
+        model = deepcopy(baseline)
+        engine, _, _, _ = deepspeed.initialize(
+            model=model,
+            model_parameters=model.parameters(),
+            config=self._make_tp_config(),
+        )
+        engine.eval()
+        return baseline, engine, config, device
+
+    def test_real_qwen35_preset_replaces_supported_layers(self):
+        _, engine, _, _ = self._build_real_baseline_and_engine()
+
+        assert engine.autotp_size() == 2
+        assert isinstance(engine.module.model.layers[0].linear_attn.in_proj_qkv, SubParamLinearLayer)
+        assert isinstance(engine.module.model.layers[0].linear_attn.conv1d, DepthwiseConv1dLayer)
+        assert isinstance(engine.module.model.layers[0].linear_attn.in_proj_z, LinearLayer)
+        assert isinstance(engine.module.model.layers[0].linear_attn.in_proj_a, LinearLayer)
+        assert isinstance(engine.module.model.layers[0].linear_attn.in_proj_b, LinearLayer)
+        assert isinstance(engine.module.model.layers[0].linear_attn.out_proj, LinearAllreduce)
+        assert isinstance(engine.module.model.layers[3].self_attn.q_proj, LinearLayer)
+        assert isinstance(engine.module.model.layers[3].self_attn.o_proj, LinearAllreduce)
+        assert engine.module.model.layers[0].linear_attn.num_k_heads == 2
+        assert engine.module.model.layers[0].linear_attn.num_v_heads == 2
+        assert engine.module.model.layers[0].linear_attn.key_dim == 16
+        assert engine.module.model.layers[0].linear_attn.value_dim == 16
+        assert engine.module.model.layers[0].linear_attn.conv_dim == 48
+        assert engine.module.model.layers[0].linear_attn.dt_bias.shape == (2, )
+        assert engine.module.model.layers[0].linear_attn.A_log.shape == (2, )
+
+    def test_real_qwen35_first_forward_matches_baseline(self):
+        baseline, engine, config, device = self._build_real_baseline_and_engine()
+
+        self._seed_all(4321)
+        input_ids = torch.randint(0, config.vocab_size, (1, 8), device=device)
+        dist.broadcast(
+            input_ids,
+            src=groups.get_tensor_model_parallel_src_rank(),
+            group=groups.get_tensor_model_parallel_group(),
+        )
+        attention_mask = torch.ones_like(input_ids)
+
+        with torch.no_grad():
+            baseline_output = baseline(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                labels=input_ids,
+                use_cache=False,
+            )
+            tp_output = engine(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                labels=input_ids,
+                use_cache=False,
+            )
+
+        assert_close_for_preferred_dtype(tp_output.loss, baseline_output.loss)
+        assert_close_for_preferred_dtype(tp_output.logits, baseline_output.logits)