
Conversation

timmoon10 (Collaborator) commented Jan 24, 2026

Description

This PR adds a grouped linear op, which can be used in the grouped MLP block in Mixture-of-Experts models. It also adds an experimental fused operation for a grouped MLP block, using a CuTe DSL kernel that computes an MXFP8 grouped GEMM and SwiGLU.
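
For context, the new ops are meant to compose with the existing te.ops.Sequential API. The following is a rough usage sketch, not the exact API from this PR: the constructor arguments and the extra forward() inputs (per-group split sizes and post-scale factors) are assumptions based on the summary further below.

import torch
import transformer_engine.pytorch as te

# Rough sketch of a grouped MLP built from the new ops. Constructor and
# forward() signatures here are assumptions, not the PR's exact API.
num_experts, hidden_size, ffn_size = 8, 1024, 4096
grouped_mlp = te.ops.Sequential(
    te.ops.GroupedLinear(num_experts, hidden_size, 2 * ffn_size, bias=False),  # FC1
    te.ops.ScaledSwiGLU(),                                                      # SwiGLU + post-scale
    te.ops.GroupedLinear(num_experts, ffn_size, hidden_size, bias=False),       # FC2
)

tokens = torch.randn(4096, hidden_size, device="cuda", dtype=torch.bfloat16)
split_sizes = [512] * num_experts                                  # tokens routed to each expert
probs = torch.rand(4096, 1, device="cuda", dtype=torch.bfloat16)   # router probabilities
out = grouped_mlp(tokens, split_sizes, probs, split_sizes)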

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Add a grouped linear operation
  • Add a post-scaled SwiGLU op and add support for interleaving SwiGLU gate and linear units
  • Add a fused operation for grouped MLP

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

timmoon10 and others added 30 commits January 7, 2026 00:15
Refactor fusion functions to remove index bookkeeping. Refactor fused ops to use consistent operation order.

Test is too permissive since the test should still be failing. The weights are not properly interleaved yet.

timmoon10 mentioned this pull request Jan 25, 2026

timmoon10 changed the title from "[PyTorch] Prototype of fused operation for grouped MLP" to "[PyTorch] Add grouped linear op and experimental fusion for grouped MLP" Jan 25, 2026

timmoon10 marked this pull request as ready for review January 25, 2026 01:00
timmoon10 (Collaborator Author) commented:

/te-ci pytorch L1

greptile-apps bot (Contributor) commented Jan 25, 2026

Greptile Overview

Greptile Summary

This PR adds support for grouped linear operations and an experimental fused grouped MLP for Mixture-of-Experts models.

Key changes:

  • Added GroupedLinear operation that applies multiple linear transformations by splitting the input along the first dimension, applying a separate transformation to each split, and concatenating the results
  • Added ScaledSwiGLU operation that performs SwiGLU with post-scaling, supporting optional gate interleaving for better memory access patterns (the unfused semantics of both ops are sketched in plain PyTorch after this list)
  • Refactored SwiGLU operations from activation.py into separate swiglu.py module
  • Added experimental ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8 fused operation using CuTe DSL kernel for SM100+ GPUs that fuses GroupedLinear + ScaledSwiGLU + GroupedLinear into a single kernel call
  • Enhanced noop_cat function to handle edge case where split_quantize creates subviews with different storage objects
  • Added comprehensive test coverage for grouped linear, scaled SwiGLU, and grouped MLP operations
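
As mentioned in the list above, the unfused semantics of GroupedLinear and ScaledSwiGLU can be written down in a few lines of plain PyTorch. This is a sketch of the math only, under assumed conventions (no quantization, no gate interleaving, and the choice of which half of the FC1 output is the gate is an assumption), not the implementation in this PR:

import torch
import torch.nn.functional as F

def grouped_linear_ref(x, split_sizes, weights):
    """Split x along dim 0, apply one linear transform per group, concatenate."""
    xs = torch.split(x, split_sizes, dim=0)
    return torch.cat([xi @ wi.t() for xi, wi in zip(xs, weights)], dim=0)

def scaled_swiglu_ref(x, scales):
    """SwiGLU over the last dim followed by a post-scale (e.g. router probs)."""
    gate, linear = x.chunk(2, dim=-1)   # gate-first layout assumed here
    return F.silu(gate) * linear * scales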

Architecture:
The fused operation performs MXFP8 quantization, packs tensors into non-contiguous logical dimensions expected by the CuTe kernel, executes the fused GEMM+SwiGLU computation, and unpacks results for the second linear layer. The standard path uses separate operations that can still benefit from FP8 quantization.

Confidence Score: 4/5

  • This PR is safe to merge, with minor consideration for a gradient accumulation edge case
  • The implementation is well-tested with comprehensive coverage of various configurations, and the previous review issues have been addressed. There is one potential edge case with gradient accumulation when different weight groups have different overwrite_main_grad settings, but this is likely rare in practice and would require Megatron-LM-specific configuration
  • Pay attention to the gradient accumulation logic in transformer_engine/pytorch/ops/basic/grouped_linear.py if using Megatron-LM with mixed overwrite_main_grad settings across groups

Important Files Changed

  • transformer_engine/pytorch/ops/basic/grouped_linear.py — New grouped linear operation supporting multiple linear transformations with per-group weights and biases, including FP8 quantization support and Megatron-LM gradient accumulation
  • transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py — Experimental fused operation for grouped MLP using a CuTe DSL kernel for MXFP8 GEMM + SwiGLU on SM100+ hardware
  • transformer_engine/pytorch/ops/basic/swiglu.py — Added ScaledSwiGLU operation with post-scaling and support for gate interleaving (illustrated in the sketch below); refactored SwiGLU ops into a separate module
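
On the gate-interleaving support mentioned for swiglu.py: as a rough illustration (the interleaved layout assumed here may not match the PR's actual convention), de-interleaving alternating gate/linear columns back into contiguous halves is just a reshape:

import torch

def deinterleave_gates(x):
    """[g0, l0, g1, l1, ...] on the last dim -> [g0, g1, ..., l0, l1, ...].
    Assumed layout, for illustration only."""
    *lead, d = x.shape
    pairs = x.reshape(*lead, d // 2, 2)       # group consecutive (gate, linear) pairs
    gate, linear = pairs.unbind(dim=-1)       # gate = x[..., 0::2], linear = x[..., 1::2]
    return torch.cat([gate, linear], dim=-1)  # contiguous gate and linear halves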

Sequence Diagram

sequenceDiagram
    participant User
    participant Sequential
    participant GroupedLinear1 as GroupedLinear (FC1)
    participant ScaledSwiGLU
    participant GroupedLinear2 as GroupedLinear (FC2)
    participant FusedOp as ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8

    User->>Sequential: forward(input, split_sizes, scales, split_sizes)
    
    alt Fusion Available (SM100+, MXFP8, no bias)
        Sequential->>FusedOp: fuser_forward(input)
        FusedOp->>FusedOp: Split and quantize input per group
        FusedOp->>FusedOp: Quantize FC1 weights per group
        FusedOp->>FusedOp: Pack MXFP8 data/scales for CuTe kernel
        FusedOp->>FusedOp: grouped_gemm_swiglu_kernel()<br/>(FC1 GEMM + SwiGLU + post-scale)
        FusedOp->>FusedOp: Unpack MXFP8 output for FC2
        FusedOp->>FusedOp: Construct MXFP8 tensors with rowwise/columnwise scales
        FusedOp->>FusedOp: general_grouped_gemm(FC2)
        FusedOp-->>Sequential: output
    else Standard Path
        Sequential->>GroupedLinear1: forward(input, split_sizes)
        GroupedLinear1->>GroupedLinear1: Split input by split_sizes
        GroupedLinear1->>GroupedLinear1: Quantize inputs and weights (if FP8)
        GroupedLinear1->>GroupedLinear1: general_grouped_gemm(weights, inputs)
        GroupedLinear1-->>Sequential: fc1_output
        
        Sequential->>ScaledSwiGLU: forward(fc1_output, scales)
        ScaledSwiGLU->>ScaledSwiGLU: Remove gate interleaving (if enabled)
        ScaledSwiGLU->>ScaledSwiGLU: Compute SwiGLU activation
        ScaledSwiGLU->>ScaledSwiGLU: Apply post-scaling: output * scales
        ScaledSwiGLU-->>Sequential: swiglu_output
        
        Sequential->>GroupedLinear2: forward(swiglu_output, split_sizes)
        GroupedLinear2->>GroupedLinear2: Split input by split_sizes
        GroupedLinear2->>GroupedLinear2: Quantize inputs and weights (if FP8)
        GroupedLinear2->>GroupedLinear2: general_grouped_gemm(weights, inputs)
        GroupedLinear2-->>Sequential: output
    end
    
    Sequential-->>User: final_output


A review thread on the following code:

quantizer.optimize_for_gemm = True
fc1_xs = tex.split_quantize(fc1_x, split_sizes_cpu, fc1_input_quantizers)

# Pack data tensors

Member: Maybe a silly question: is this packing and unpacking code just for verification, or will it be in the final version?

timmoon10 (Collaborator Author): I'm working on getting rid of the concatenations, but the permutes are no-ops. The kernel API expects tensors with non-contiguous dims: https://github.com/NVIDIA/cudnn-frontend/blob/main/python/cudnn/grouped_gemm/grouped_gemm_swiglu/api.py#L240-L245
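
For readers unfamiliar with the distinction: a permute in PyTorch only rewrites strides, so it yields a non-contiguous view over the same storage with no data movement, whereas a concatenation allocates and copies. A small standalone example (not code from this PR):

import torch

x = torch.randn(4, 8, 16)
y = x.permute(1, 0, 2)               # reorders logical dims by changing strides only
print(y.is_contiguous())             # False: layout no longer matches the new shape
print(y.data_ptr() == x.data_ptr())  # True: same underlying storage, no copy
z = y.contiguous()                   # an explicit copy is needed to densify the view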

Another review thread, on the code that calls the fused kernel:

)

# Fused kernel for FC1 + SwiGLU + post-scale
fc1_kernel_out = self.grouped_gemm_swiglu_kernel()(

Contributor: After SwiGLU, the output usually needs to be multiplied by permuted_probs. Is this weighted SwiGLU supported?

greptile-apps bot (Contributor) left a comment: 3 files reviewed, 1 comment

Comment on lines +537 to +539:

accumulate_into_main_grad = not getattr(
    weight_param, "overwrite_main_grad", False
)

greptile-apps bot: accumulate_into_main_grad is reassigned in the loop, so the last group's setting applies to all groups in the GEMM call on line 576. If different weight groups have different overwrite_main_grad settings, this causes incorrect gradient accumulation behavior. Either check consistency across groups or use per-group flags.
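
A hedged sketch of one possible fix (hypothetical variable names, not the PR's code): check that the flag agrees across all weight groups before using a single value for the grouped GEMM call.

# Hypothetical sketch: ensure all weight groups agree on overwrite_main_grad
# before passing a single accumulate flag to the grouped GEMM.
flags = [
    not getattr(weight_param, "overwrite_main_grad", False)
    for weight_param in weight_params  # hypothetical list of per-group weight params
]
if len(set(flags)) > 1:
    raise ValueError(
        "Mixed overwrite_main_grad settings across weight groups are not "
        "supported by the single grouped GEMM call; make them consistent."
    )
accumulate_into_main_grad = flags[0]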

greptile-apps bot (Contributor) left a comment: 3 files reviewed, no comments

greptile-apps bot (Contributor) left a comment: 3 files reviewed, no comments


Labels

performance (Performance issues)


3 participants