Skip to content

gfx1250 moe#336

Closed
XingerZhu wants to merge 130 commits intomainfrom
mxfp4_gfx1250_moe
Closed

gfx1250 moe#336
XingerZhu wants to merge 130 commits intomainfrom
mxfp4_gfx1250_moe

Conversation

@XingerZhu
Copy link
Copy Markdown
Collaborator

Motivation

Technical Details

Test Plan

Test Result

Submission Checklist

sjfeng1999 and others added 30 commits March 3, 2026 08:49
- Fix Python version compatibility in meta.py: add support for Python < 3.11
  by checking for positions attribute availability
- Replace hardcoded MLIR library paths in executor.py with environment variable
  MLIR_PATH, with clear error message when not set
- Update LLVM commit hash and enable ROCM runner in build script
* [FLYDSL]:add copy_atom right_inverse

* [FLYDSL]: right_inverse dynamic process bugfix

* [FLYDSL]:Python refactoring and adaptation

* [FLYDSL]:rm example 05
* Migrate Python bindings to PyConcreteType<> and fix TypeID ODR violation

- FlyExtension.cpp / FlyROCDLExtension.cpp: migrate from legacy
  mlir_type_subclass() to PyConcreteType<> CRTP pattern (required by
  new MLIR Python binding API). Types are defined inside
  namespace mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::fly, using
  ::mlir:: global qualifiers to avoid the mlir::python::mlir namespace
  collision when NB_DOMAIN=mlir.

- CMakeLists.txt: remove MLIRFlyDialect / MLIRFlyROCDLDialect from
  _fly.so / _fly_rocdl.so PRIVATE_LINK_LIBS. These static archives
  were being linked into both the extension modules AND FlyPythonCAPI.so
  (via EMBED_CAPI_LINK_LIBS → MLIRCPIFly), creating duplicate TypeID
  static variables. The dialect registered under FlyPythonCAPI.so's
  TypeIDs but _fly.so looked up types with its own copy, causing
  "storage uniquer isn't initialized" at runtime. Now all symbols are
  resolved from FlyPythonCAPI.so.

- FlyToROCDL.cpp: use string-based type matching for MmaAtomCDNA3_MFMA
  to work around the same TypeID ODR issue in the conversion pass, and
  fix ROCDL MFMA intrinsic call to use I32Attr attributes instead of
  Value operands for cbsz/abid/blgp control parameters.

* Fix pass registry ODR violation: register Fly passes via CAPI

- PRIVATE_LINK_LIBS MLIRFlyToROCDL in _mlirRegisterEverything pulled in a
local copy of MLIRPass, causing registerFlyPasses() to register into a
LOCAL pass registry inside _mlirRegisterEverything.so while
PassManager.parse() queried the GLOBAL registry in FlyPythonCAPI.so.

- Fix by introducing CAPI functions (mlirRegisterFlyPasses,
mlirRegisterFlyToROCDLConversionPass) in the CAPI libraries so pass
registration happens inside FlyPythonCAPI.so's single global registry.

- update cmake/llvm-hash.txt to keep same with triton llvm hash.

* Sync build_llvm.sh with pre_bumpupllvm and add ROCM runner

Align script with pre_bumpupllvm branch: full clone, buildmlir dir,
NVPTX target, NB_DOMAIN=mlir, package install by default. Keep
reading LLVM commit from cmake/llvm-hash.txt. Add
MLIR_ENABLE_ROCM_RUNNER=ON for GPU kernel execution support.

Co-authored-by: Cursor <cursoragent@cursor.com>

---------

Co-authored-by: Cursor <cursoragent@cursor.com>
- Rename C++ binding structs with Py prefix (e.g. IntTupleType -> PyIntTupleType) for consistency
- Add __all__ exports to typing, primitive, and gpu modules
- Add Int4 numeric type
- Fix frameInfo.positions compatibility for older Python versions
- Fix dialect import order to ensure _Dialect is properly exported
- Add fly_rocdl ops/enum gen copy rules in CMake
- Improve build_llvm.sh with configurable parallel jobs and --no-install flag
- Clean up redundant comments and formatting

Co-authored-by: Cursor <cursoragent@cursor.com>
gemm test ready
* [FLYDSL]: add recast_layout op

* [FLYDSL]: refactor

* [FLYDSL]: add detail namespace

* [FLYDSL]: add upcast assert

* [FLYDSL]: rm bits number

* [FLYDSL]: rm redundant code

* [FLYDSL]: bits number only support static value

* [FLYDSL]: change APIntAttr to I32Attr

* [FLYDSL]: rm notes
* fix run error

* port all  gemm from main

* fuix cudagraph hack

* add int4 version

* change flymemref convert

* test ok

* add build script

* fix graph2

* add files

* fix flops

* fix path

* fix local test

* fix

* clean

* update readme
* add compile only and dumpir
- Add fly-opt tool (tools/fly-opt/) for MLIR IR transformations,
  registering Fly/FlyROCDL dialects and all custom passes
- Add lit.cfg.py with fly-opt/FileCheck configuration
- Test using 'lit -v tests/' to test basic lowering tests
- Add LayoutAlgebra tests: construction, size/cosize, coordinate,
  composition, product, divide, int_tuple operations
- Add Transforms tests: canonicalize, layout_lowering
- Add Conversion tests for convert-fly-to-rocdl pass, split by category:
  type_conversion, memref_alloca, memref_ops, pointer_ops,
  mma_atom, gpu_ops
…gration

- Enable LLVM_BUILD_TOOLS so fly-opt is built with the default ninja target
- Add MLIR lit test section to scripts/run_tests.sh
- Update test/lit.cfg.py to use FLY_BUILD_DIR env var (default: build-fly)
aoli26 and others added 17 commits March 27, 2026 07:44
Introduce the new gfx1250 MoE kernel implementation and its dedicated test harness so mxfp4 bring-up can validate the path in-tree and track regressions on the new architecture.

Made-with: Cursor
Port the preshuffled FP8/FP4 WMMA_SCALE flow into the gfx1250 MoE stage1/stage2 kernels and extend the MoE harness so A8W4 can be validated through standalone and end-to-end bring-up paths.

Made-with: Cursor
- fp16 kernel: replace TDM copy_b_to_lds with software transpose to
  correctly produce [K, N+pad] LDS layout for lds_transpose_load;
  add max_total_warps=8 constraint to prevent VGPR overuse on M/L sizes
- fp8/fp4 tests: use per-1x32 block-scale quantization instead of
  per-token quant to match gfx1250 WMMA_SCALE requirements
- fp4 tests: generate quantized data from controlled FP32 inputs via
  per_1x32_f4_quant to avoid numerical overflow
- fp16 tests: skip shuffle_weight for native gfx1250 kernel path
- compiler: add compile progress logging and include opt_level in
  cache key; update native lib glob patterns

Made-with: Cursor
- Fix pack_as_to_lds A-scale LDS index collision when m_warp > 1:
  use warp-local wm_idx instead of global row/WMMA_M to avoid
  different warps writing to the same LDS positions.
- Replace TDM-based B-scale loading with software copy in fp8
  stage1/stage2 kernels to fix incorrect stride when wmma_n_rep > 1.
- Add max_total_warps=8 for fp8/a8w4 kernel shape selection to
  prevent VGPR overflow on larger tile configurations.
- Skip out_f32 tests for gfx1250 native-dtype kernels (fp4/fp8/a8w4/fp16).
- Use torch.zeros instead of torch.empty for stage1 output buffer.

Made-with: Cursor
…ersion

bf16 inputs were falling through to the generic MFMA compilation path,
causing SIGABRT on gfx1250 (which only supports WMMA). Route bf16 through
the existing fp16 WMMA kernel with a Python-level wrapper that converts
bf16 tensors to fp16 before launch. Also fix weight shuffle mismatch in
tests where bf16 was incorrectly using MFMA-style shuffle_weight.

Made-with: Cursor
Sync the branch with latest main changes and resolve merge conflicts while preserving gfx1250 MOE GEMM updates.

Made-with: Cursor
Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Adds gfx1250-specific MoE 2-stage GEMM kernels and plumbing for an “expert scheduling mode” compile hint, along with new gfx1250-focused kernel tests and a unit test intended to validate ISA emission.

Changes:

  • Introduce kernels/moe_gemm_2stage_gfx1250.py with gfx1250 MoE stage1/stage2 kernel compilation wrappers and single-kernel paths.
  • Add a new compile-hint (expert_scheduling_mode) that injects an AMDGPU passthrough attribute on emitted gpu.func.
  • Add gfx1250 MoE test harness + a unit test for expert scheduling ISA, and tweak existing gfx1250 GEMM scale preshuffle API.

Reviewed changes

Copilot reviewed 6 out of 8 changed files in this pull request and generated 6 comments.

Show a summary per file
File Description
kernels/moe_gemm_2stage_gfx1250.py New gfx1250 MoE 2-stage kernel implementations/wrappers (missing SPDX header; contains redundant duplicate assignments).
python/flydsl/compiler/kernel_function.py Adds expert_scheduling_mode handling by writing a passthrough attribute on gpu.func (currently overwrites any existing passthrough).
tests/unit/test_jit_compile_hints.py New unit test intended to validate expert scheduling ISA emission (currently calls flyc.compile() with no args, so compilation won’t run).
tests/kernels/test_moe_gemm_gfx1250.py Large gfx1250 MoE correctness/perf harness + pytest cases (contains an accidental parametrize decorator on a helper).
python/flydsl/compiler/backends/rocm.py Updates toolchain fingerprint globs for native libs (likely drops hashing of _mlirDialectsFly*.so per current CMake module names).
python/flydsl/__init__.py Adds unused imports (ctypes, os).
tests/kernels/test_gemm_fp8fp4_gfx1250.py Extends preshuffle_e8m0_scale with optional byte_swap.
kernels/gemm_fp8fp4_gfx1250.py Adds LDS_PAD_B_BYTES = 0 constant.
Comments suppressed due to low confidence (1)

python/flydsl/compiler/backends/rocm.py:103

  • native_lib_patterns() no longer includes the _mlirDialectsFly*.so / _mlirDialectsFlyROCDL*.so extension module filenames, but the build configuration still declares those MODULE_NAMEs (see python/mlir_flydsl/CMakeLists.txt). If the output .so files are still named _mlirDialectsFly*.so, the toolchain fingerprint _flydsl_key will stop hashing them, risking stale disk-cache hits after native code changes. Consider keeping the old globs (or including both old and new patterns) unless the extension module filenames have been renamed accordingly.
    def native_lib_patterns(self) -> List[str]:
        return [
            "_fly*.so",
            "_fly_rocdl*.so",
            "libFly*.so",
            "libfly_jit_runtime.so",
            "libmlir_rocm_runtime.so",
            "_mlirRegisterEverything*.so",
        ]

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment on lines +489 to +494
gpu_func.attributes["passthrough"] = ir.ArrayAttr.get([
ir.ArrayAttr.get([
ir.StringAttr.get("amdgpu-expert-scheduling-mode"),
ir.StringAttr.get("true")
])
])
Copy link

Copilot AI Apr 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

gpu_func.attributes["passthrough"] = ... unconditionally overwrites any existing passthrough entries on the same gpu.func. Other code paths (e.g. kernels that set fast-math passthrough attributes) also populate this attribute, so this can silently drop previously-added passthrough settings when both are used together. Consider appending to the existing passthrough ArrayAttr (if present) instead of replacing it.

Suggested change
gpu_func.attributes["passthrough"] = ir.ArrayAttr.get([
ir.ArrayAttr.get([
ir.StringAttr.get("amdgpu-expert-scheduling-mode"),
ir.StringAttr.get("true")
])
])
# Preserve any existing passthrough attributes by appending
existing_passthrough = gpu_func.attributes.get("passthrough")
new_entry = ir.ArrayAttr.get([
ir.StringAttr.get("amdgpu-expert-scheduling-mode"),
ir.StringAttr.get("true"),
])
if isinstance(existing_passthrough, ir.ArrayAttr):
elements = list(existing_passthrough)
elements.append(new_entry)
gpu_func.attributes["passthrough"] = ir.ArrayAttr.get(elements)
else:
gpu_func.attributes["passthrough"] = ir.ArrayAttr.get([new_entry])

Copilot uses AI. Check for mistakes.
Comment on lines +488 to +494
if CompilationContext.get_compile_hints().get("expert_scheduling_mode"):
gpu_func.attributes["passthrough"] = ir.ArrayAttr.get([
ir.ArrayAttr.get([
ir.StringAttr.get("amdgpu-expert-scheduling-mode"),
ir.StringAttr.get("true")
])
])
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pls remove this block because main already has more comprehensive support

Comment on lines +39 to +52
WMMA_DIM: int = 16,
byte_swap: bool = False) -> torch.Tensor:
"""Preshuffle E8M0 scale: optional byte swap + interleave for ds_load_b128.

Args:
byte_swap: True for FP4 (reorder [0,2,1,3]), False for FP8 (identity).
"""
_, K_scale = scale.shape
assert K_scale % 4 == 0, f"K_scale must be divisible by 4, got {K_scale}"

if byte_swap:
grouped = scale.view(-1, K_scale // 4, 4)
scale = grouped[:, :, [0, 2, 1, 3]].contiguous().view(-1, K_scale)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

byte_swap is deprecated; align directly with main here.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no longer needed now, it's fine to remove this file.

Reapply the edits from the locally amended merge commit that was dropped during branch rebase alignment.

Made-with: Cursor
Reapply the edits from the locally amended merge commit that was dropped during branch rebase alignment.

Made-with: Cursor
return [
"_mlirDialectsFly*.so",
"_fly*.so",
"_fly_rocdl*.so",
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why change this?

den = 1.0 + emu
sig = rocdl.rcp(T.f32, den)
return x * sig

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

kernel style much different from gemm 1250?

…ough

Align ROCm backend native library pattern matching with current Fly dialect shared object names and avoid setting expert scheduling passthrough twice in kernel function lowering.

Made-with: Cursor
@coderfeli
Copy link
Copy Markdown
Collaborator

coderfeli commented Apr 10, 2026

Is it old codes? shall we close here? Try to reuse more codes. Currently 3000 lines of test and 3000 kernels @XingerZhu

@coderfeli coderfeli closed this Apr 10, 2026
@XingerZhu
Copy link
Copy Markdown
Collaborator Author

It is old codes, just close it

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

8 participants