[PyTorch] Add grouped linear op and experimental fusion for grouped MLP #2622
base: main
Conversation
Refactor fusion functions to remove index bookkeeping. Refactor fused ops to use consistent operation order. Signed-off-by: Tim Moon <[email protected]>
The test is too permissive, since it should still be failing: the weights are not properly interleaved yet. Signed-off-by: Tim Moon <[email protected]>
/te-ci pytorch L1
Greptile Overview

Greptile Summary

This PR adds support for grouped linear operations and an experimental fused grouped MLP for Mixture-of-Experts models.

Confidence Score: 4/5

Important Files Changed
Sequence Diagram

sequenceDiagram
    participant User
    participant Sequential
    participant GroupedLinear1 as GroupedLinear (FC1)
    participant ScaledSwiGLU
    participant GroupedLinear2 as GroupedLinear (FC2)
    participant FusedOp as ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8
    User->>Sequential: forward(input, split_sizes, scales, split_sizes)
    alt Fusion Available (SM100+, MXFP8, no bias)
        Sequential->>FusedOp: fuser_forward(input)
        FusedOp->>FusedOp: Split and quantize input per group
        FusedOp->>FusedOp: Quantize FC1 weights per group
        FusedOp->>FusedOp: Pack MXFP8 data/scales for CuTe kernel
        FusedOp->>FusedOp: grouped_gemm_swiglu_kernel()<br/>(FC1 GEMM + SwiGLU + post-scale)
        FusedOp->>FusedOp: Unpack MXFP8 output for FC2
        FusedOp->>FusedOp: Construct MXFP8 tensors with rowwise/columnwise scales
        FusedOp->>FusedOp: general_grouped_gemm(FC2)
        FusedOp-->>Sequential: output
    else Standard Path
        Sequential->>GroupedLinear1: forward(input, split_sizes)
        GroupedLinear1->>GroupedLinear1: Split input by split_sizes
        GroupedLinear1->>GroupedLinear1: Quantize inputs and weights (if FP8)
        GroupedLinear1->>GroupedLinear1: general_grouped_gemm(weights, inputs)
        GroupedLinear1-->>Sequential: fc1_output
        Sequential->>ScaledSwiGLU: forward(fc1_output, scales)
        ScaledSwiGLU->>ScaledSwiGLU: Remove gate interleaving (if enabled)
        ScaledSwiGLU->>ScaledSwiGLU: Compute SwiGLU activation
        ScaledSwiGLU->>ScaledSwiGLU: Apply post-scaling: output * scales
        ScaledSwiGLU-->>Sequential: swiglu_output
        Sequential->>GroupedLinear2: forward(swiglu_output, split_sizes)
        GroupedLinear2->>GroupedLinear2: Split input by split_sizes
        GroupedLinear2->>GroupedLinear2: Quantize inputs and weights (if FP8)
        GroupedLinear2->>GroupedLinear2: general_grouped_gemm(weights, inputs)
        GroupedLinear2-->>Sequential: output
    end
    Sequential-->>User: final_output
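For readers who want the math behind the standard path, here is a rough plain-PyTorch reference of what the diagram computes per expert group. The function name, argument layout, and the non-interleaved gate/up split are illustrative assumptions, not the actual GroupedLinear / ScaledSwiGLU API from this PR.

```python
import torch
import torch.nn.functional as F


def reference_grouped_mlp(x, split_sizes, probs, fc1_weights, fc2_weights):
    """Unfused reference of the grouped MLP (sketch only).

    x: [total_tokens, hidden], split_sizes: tokens per expert,
    probs: [total_tokens] routing scales, fc?_weights: per-expert weight lists.
    """
    outputs = []
    for x_i, p_i, w1, w2 in zip(
        torch.split(x, split_sizes),
        torch.split(probs, split_sizes),
        fc1_weights,
        fc2_weights,
    ):
        h = x_i @ w1.t()               # FC1: [tokens_i, 2 * ffn]
        gate, up = h.chunk(2, dim=-1)  # assumes gate/up are not interleaved
        h = F.silu(gate) * up          # SwiGLU activation
        h = h * p_i.unsqueeze(-1)      # post-scale by routing probabilities
        outputs.append(h @ w2.t())     # FC2: [tokens_i, hidden]
    return torch.cat(outputs)
```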
quantizer.optimize_for_gemm = True
fc1_xs = tex.split_quantize(fc1_x, split_sizes_cpu, fc1_input_quantizers)

# Pack data tensors
Maybe a silly question: is this packing and unpacking code just for verification, or will it be in the final version?
I'm working on getting rid of the concatenations, but the permutes are no-ops. The kernel API expects tensors with non-contiguous dims: https://github.com/NVIDIA/cudnn-frontend/blob/main/python/cudnn/grouped_gemm/grouped_gemm_swiglu/api.py#L240-L245
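As a small aside on the "permutes are no-ops" point, the snippet below (independent of the PR code) shows that torch.Tensor.permute only swaps strides on a view of the same storage, so handing the kernel a non-contiguous layout involves no data movement.

```python
import torch

x = torch.randn(128, 256)
y = x.permute(1, 0)  # view with swapped strides, no copy

print(y.is_contiguous())              # False: only the strides changed
print(y.data_ptr() == x.data_ptr())   # True: same underlying storage
```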
)

# Fused kernel for FC1 + SwiGLU + post-scale
fc1_kernel_out = self.grouped_gemm_swiglu_kernel()(
After SwiGLU, the output usually needs to be multiplied by permuted_probs. Is this weighted SwiGLU supported?
Yes, the probs are passed into the kernel here: https://github.com/timmoon10/TransformerEngine/blob/46294be478f6551e2cf251283adc7529ddb2964e/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py#L264
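For context, the weighted (post-scaled) SwiGLU under discussion amounts to the following reference computation. This is only the math being fused, not the CuTe kernel; `permuted_probs` is assumed to hold one routing probability per token, already permuted into expert order.

```python
import torch
import torch.nn.functional as F


def scaled_swiglu(fc1_out: torch.Tensor, permuted_probs: torch.Tensor) -> torch.Tensor:
    # Assumes gate/up channels are concatenated (not interleaved) along the last dim.
    gate, up = fc1_out.chunk(2, dim=-1)
    return F.silu(gate) * up * permuted_probs.unsqueeze(-1)
```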
Review suggestions from @greptile-apps Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: Tim Moon <[email protected]>
Review suggestion from @greptile-apps Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: Tim Moon <[email protected]>
3 files reviewed, 1 comment
accumulate_into_main_grad = not getattr(
    weight_param, "overwrite_main_grad", False
)
accumulate_into_main_grad is reassigned on every loop iteration, so the last group's setting applies to all groups in the GEMM call on line 576. If different weight groups have different overwrite_main_grad settings, this causes incorrect gradient-accumulation behavior. Either check consistency across groups or use per-group flags.
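One possible way to handle the consistency check, sketched here under the assumption that the per-group weights are available as a list named `weight_params` (a hypothetical name, not the PR's actual variable), is to collect the flag per group and require it to be uniform before the shared grouped-GEMM call:

```python
# Sketch only: enforce a single consistent setting across all weight groups.
flags = [
    not getattr(weight_param, "overwrite_main_grad", False)
    for weight_param in weight_params  # hypothetical list of per-group weights
]
if len(set(flags)) > 1:
    raise ValueError(
        "overwrite_main_grad must be set consistently across all weight groups"
    )
accumulate_into_main_grad = flags[0]
```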
3 files reviewed, no comments
3 files reviewed, no comments
Description
This PR adds a grouped linear op, which can be used in the grouped MLP block in Mixture-of-Experts models. It also adds an experimental fused operation for a grouped MLP block, using a CuTe DSL kernel that computes an MXFP8 grouped GEMM and SwiGLU.
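A usage sketch, based only on the op names appearing in the sequence diagram above; the module paths, constructor arguments, and forward signature are assumptions for illustration rather than the finalized API from this PR.

```python
import torch
import transformer_engine.pytorch as te  # assumed import path

num_experts, hidden, ffn = 8, 4096, 14336

# Assumed composition: FC1 -> scaled SwiGLU -> FC2, one weight set per expert.
grouped_mlp = te.ops.Sequential(
    te.ops.GroupedLinear(num_experts, hidden, 2 * ffn, bias=False),  # FC1
    te.ops.ScaledSwiGLU(),
    te.ops.GroupedLinear(num_experts, ffn, hidden, bias=False),      # FC2
)

x = torch.randn(1024, hidden, device="cuda", dtype=torch.bfloat16)
split_sizes = [128] * num_experts        # tokens routed to each expert
probs = torch.rand(1024, device="cuda")  # per-token routing probabilities

# Extra forward args are assumed to be routed to the matching ops,
# mirroring forward(input, split_sizes, scales, split_sizes) in the diagram.
out = grouped_mlp(x, split_sizes, probs, split_sizes)
```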