Skip to content

[ROCm] Allow mixed (F8E4M3FNUZ, F8E5M2FNUZ) in Triton dot gate#897

Open
Ruturaj4 wants to merge 1 commit into
mainfrom
ruvaidya/fnuz-mixed-fp8-triton-dot
Open

[ROCm] Allow mixed (F8E4M3FNUZ, F8E5M2FNUZ) in Triton dot gate#897
Ruturaj4 wants to merge 1 commit into
mainfrom
ruvaidya/fnuz-mixed-fp8-triton-dot

Conversation

@Ruturaj4

Copy link
Copy Markdown

The mixed-FP8 in IsTritonSupportedDot only listed the OCP pair (F8E5M2, F8E4M3FN). The ROCm-native FNUZ pair was rejected even though the rest of the file already accepts FNUZ FP8 inputs on ROCm.

This blocks TransformerEngine FP8 GEMM on MI300 (gfx94X), which lowers dgrad to dot_general(F8E4M3FNUZ, F8E5M2FNUZ) and gets routed to a __triton_nested_gemm_fusion. The gate then refuses it at codegen time with "INTERNAL: ... Dot operation only supports same types for lhs and rhs."

Mirror the existing OCP allowance under gpu_version.IsRocm() so the FNUZ pair passes the same check.

Submission Checklist

Comment thread xla/backends/gpu/codegen/triton/support.cc
@Ruturaj4 Ruturaj4 force-pushed the ruvaidya/fnuz-mixed-fp8-triton-dot branch 2 times, most recently from 076f8b3 to 30fd358 Compare June 1, 2026 16:07

EXPECT_TRUE(IsTritonSupportedInstruction(
ti.Instruction(),
se::GpuComputeCapability(se::RocmComputeCapability("gfx950"))));

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fnuz is not natively supported on gfx950. Why does this pass?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch. The test doesn't run anything on the GPU, but just checks whether the compiler's support predicate accepts the dot. FNUZ is allowed on any ROCm target (pre-existing behavior; when there's no native FNUZ MFMA it just upcasts to f16), so it passes regardless of the chip. But regardless gfx950 was a confusing thing to assert on, so I moved it to gfx942 where FNUZ is native and which is what the ticket's about.

@Ruturaj4 Ruturaj4 force-pushed the ruvaidya/fnuz-mixed-fp8-triton-dot branch 2 times, most recently from ed67d54 to 254a2a7 Compare June 4, 2026 15:38
The mixed-FP8 carve-out in IsTritonSupportedDot only listed the OCP
pair (F8E5M2, F8E4M3FN). The ROCm-native FNUZ pair was rejected even
though the rest of the file already accepts FNUZ FP8 inputs on ROCm.

This blocks TransformerEngine FP8 GEMM on MI300 (gfx94X), which lowers
dgrad to dot_general(F8E4M3FNUZ, F8E5M2FNUZ) and gets routed to a
__triton_nested_gemm_fusion. The gate then refuses it at codegen time
with "INTERNAL: ... Dot operation only supports same types for lhs and
rhs."

Add a focused support test asserting the FNUZ mixed pair is accepted
on ROCm and rejected on CUDA via RunSupportTest's dual-contract check.

Addresses TODO(b/393299275).
@Ruturaj4 Ruturaj4 force-pushed the ruvaidya/fnuz-mixed-fp8-triton-dot branch from 254a2a7 to b3c6b4a Compare June 4, 2026 16:18
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants