
fix: convert weight scale shape [N, K/128] to [ceil(N/128), ceil(K/128)] #1353

Draft
JiaoliangYu wants to merge 1 commit into ROCm:main from JiaoliangYu:fix/post_training_fix


@JiaoliangYu

Contributor

Motivation

per_block_fp8 online quantization in ATOM can produce degenerate generation (all-f outputs or immediate EOS) because the weight-scale layout emitted by the online quant path does not match the layout expected by the existing linear FP8 blockscale GEMM backend.

Technical Details

In LinearBase.online_quantize_weight for online_quant_type == QuantType.per_1x128, we no longer
use the generic per-group quant output directly (weight_scale shaped like [N, K/128]). Instead,
we explicitly quantize weights into 128x128 blockscale format and emit:

  • q_weight: shape [N, K] (same logical weight shape),
  • weight_scale: shape [ceil(N/128), ceil(K/128)].

This aligns online-quantized weights with the current linear compute path
(gemm_a8w8_blockscale*), which consumes 128x128 block scales.
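
To make the layout change concrete, here is a minimal pure-Python sketch of how a `[ceil(N/128), ceil(K/128)]` scale grid can be derived from an `[N, K]` weight. It is not the code in this PR: `blockscale_shape` and `block_scales` are hypothetical helper names, the amax-divided-by-FP8-max scale rule is an assumption about the quantization scheme, and the real path operates on tensors rather than nested lists.

```python
import math

# Largest finite value of float8_e4m3fn. The fnuz variant tops out at 240.0,
# so pass fp8_max=240.0 to model float8_e4m3fnuz instead.
FP8_E4M3FN_MAX = 448.0


def blockscale_shape(n, k, block=128):
    """Shape of the scale grid for an [n, k] weight with block x block tiles."""
    return (math.ceil(n / block), math.ceil(k / block))


def block_scales(weight, block=128, fp8_max=FP8_E4M3FN_MAX):
    """Compute one scale per block x block tile of `weight` (a list of rows).

    Assumed rule (illustrative only): scale = amax(tile) / fp8_max.
    Tiles on the bottom/right edge may be partial when n or k is not a
    multiple of `block`; they still get their own scale entry.
    """
    n, k = len(weight), len(weight[0])
    rows, cols = blockscale_shape(n, k, block)
    scales = [[0.0] * cols for _ in range(rows)]
    for bi in range(rows):
        for bj in range(cols):
            amax = 0.0
            for i in range(bi * block, min((bi + 1) * block, n)):
                for j in range(bj * block, min((bj + 1) * block, k)):
                    amax = max(amax, abs(weight[i][j]))
            # Clamp so an all-zero tile does not produce a zero scale.
            scales[bi][bj] = max(amax, 1e-12) / fp8_max
    return scales
```

For example, `blockscale_shape(300, 512)` gives a 3 x 4 grid, whereas a per-group `[N, K/128]` layout for the same weight would have 300 x 4 entries, which is the mismatch described above.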

Additional fix included:
The online quant dtype checks now also accept float8_e4m3fnuz, and the redundant normalize
marking is skipped when the quant output is already fnuz (relevant on gfx942).

Test Plan

Test Result

Submission Checklist

@zufayu zufayu requested review from valarLip and removed request for valarLip June 26, 2026 06:12
