
Gemma 4: move some computations to BF16 #21451

Closed
pwilkin wants to merge 4 commits into ggml-org:master from pwilkin:gemma-bf16

Conversation

@pwilkin
Member

@pwilkin pwilkin commented Apr 5, 2026

Overview

Putting this here as a draft because frankly I'm not sure what to do with this.

In my personal benchmarks, Gemma 4 has been losing coherence at long contexts. I ran the conversion verification scripts and the result NMSE is hovering around 1e-2 even for the smallest model, where usually I can get below 1e-4. I've tried to pin this down and I've traced it mostly to the divergence between the Transformers code natively computing in BF16 and the Llama.cpp code upcasting to F32. This is especially visible in the input scale, since the upcasting there introduces a systematic error with respect to the BF16 reference that propagates everywhere.
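To illustrate the kind of systematic error involved, here is a minimal sketch of a BF16 round-trip on a sqrt(hidden_size)-style embedding scale (the hidden size 2560 below is a hypothetical placeholder, not Gemma 4's actual config):

```python
import struct

def bf16_round(x: float) -> float:
    """Round a float32 value to bfloat16 (round-to-nearest-even),
    then widen back to float32 -- one BF16 compute step."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bias = 0x7FFF + ((bits >> 16) & 1)  # round-to-nearest-even bias
    out = (bits + bias) & 0xFFFF0000    # keep only the top 16 bits
    return struct.unpack("<f", struct.pack("<I", out))[0]

# Hypothetical sqrt(hidden_size) embedding scale for hidden_size = 2560:
scale_f32  = 2560 ** 0.5             # ~50.596, what an F32 graph computes
scale_bf16 = bf16_round(scale_f32)   # 50.5, what a BF16-trained model saw
print(scale_f32, scale_bf16)
```

A constant off by ~0.2% like this multiplies every embedding, which is why an F32 graph can diverge systematically from a BF16 reference even when each layer is individually "within error range".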

Additional information

With the changes in this PR, I've managed to cut the conversion's reported NMSE roughly in half (from ~2.5e-2 to around 1e-2). Still, I would appreciate it if someone who can actually run these models properly could do some systematic comparative tests (I had to do all runs with Q8 KV quants, and with the F32 version I can't even hit 100k context). FWIW, the version from this PR passed NIAH tests at 80k, 90k and 100k context, 3/3 each.

Requirements

  • I have read and agree with the contributing guidelines
  • AI usage disclosure: YES, it helped with the kernel modifications

@github-actions github-actions bot added the model, Nvidia GPU, examples, python and ggml labels Apr 5, 2026
@JohannesGaessler
Contributor

To me it seems very unlikely that a model would produce higher quality outputs by truncating the mantissa. Making the outputs more similar to the Transformers reference does not necessarily make them more correct.

@CISC
Member

CISC commented Apr 5, 2026

I'm guessing something causes denormalization, and BF16 masks that?

@pwilkin
Member Author

pwilkin commented Apr 5, 2026

@JohannesGaessler not sure tbh. If the model was originally trained in BF16, then the rounded scale might be the expected behavior. I know I'm grasping at straws here, but there have just been consistent problems with model outputs reported by a lot of people, and I haven't been able to get good results at longer contexts either. Maybe it's just an inherent problem with the model's iSWA architecture that it falls apart around 60-70k on lower quants, but that's why I'd appreciate some more tests.

@JohannesGaessler
Contributor

Quite frankly, if Gemma 4 breaks from upcasting BF16 to FP32 it will break from just about anything else where we don't do the exact same operations in the exact same order as the Transformers implementation.

@SharkWipf

I can't comment on the specifics, but I have noticed Gemma 4 31B (Q8 K_XL Unsloth, no KV quant, CUDA, single device) performing significantly worse than I'd expect in long-context agentic tasks, ending up in loops over nothing on tasks I'd expect this model to handle easily, given the performance of other similarly sized models. I have no real "ground truth" to compare against, but my gut feeling is that something in llama.cpp's Gemma 4 inference is still wrong.

@tarruda

tarruda commented Apr 5, 2026

@pwilkin I noticed the amount of VRAM used for context varies greatly between gemma 4 versions and flags.

  • Gemma 4 31B Q8 with 64k context uses 90GB VRAM with --swa-full, 40GB VRAM without it.
  • Gemma 4 26B Q8 with 64k context uses 42GB VRAM with --swa-full, 30GB VRAM without it.

Is it possible that the long context problems are caused by not passing swa?

I haven't really done any long context tests with gemma 4, but if you give me an example I can try. One issue I noticed when gemma 4 is used with pi harness is that it never seems to send thinking traces back to the client. Instead, it creates some thinking traces in the comments of the code it is writing.

@SharkWipf

SharkWipf commented Apr 5, 2026

Aieee that's a lot of VRAM.
Very anecdotally, it does seem better with --swa-full though, despite me having to drop to KV q4_1 and various degrees of offloading. It's difficult to really quantify for me atm as it's much slower running this way.
I'll have to patiently throw it at a few real tasks to see how it performs now.

EDIT: Yeah no, I can confirm it still gets loopy with --swa-full. In a way that's a relief.

@pwilkin
Member Author

pwilkin commented Apr 5, 2026

@SharkWipf yes, this is exactly my observation. Compared to Qwen3.5, Gemma 4 breaks quite badly at around 60k-80k on agentic tasks, with the behavior you mentioned (looping, incoherent responses, random stops etc.). However, I lack the resources to compare this to e.g. vLLM behavior on similar contexts, so I cannot determine whether it is just model behavior or in fact the llama.cpp implementation.

I've spent quite a few hours now comparing the inference. The layers are within error range, there are no obvious mismatches, and RoPE patterns match - but even on the smallest models the result is a pretty big difference in final logits for a BF16-to-BF16 comparison. An NMSE of 2.5e-2 against the Transformers reference is consistent with exactly the kinds of problems being reported - but I have no idea whether it's just to be expected as a BF16-vs-F32 computation difference while the model simply struggles at long context, or whether the model actually expects the BF16 computation and the divergence is what causes it to misbehave at long contexts.
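For reference, NMSE here is the usual normalized mean squared error against the Transformers logits. A sketch (the verification script's exact implementation may differ):

```python
def nmse(pred, ref):
    """Normalized MSE: squared error divided by the reference signal power."""
    num = sum((p - r) ** 2 for p, r in zip(pred, ref))
    den = sum(r ** 2 for r in ref)
    return num / den

# A uniform 1% per-logit relative error gives NMSE ~ 1e-4;
# values around 1e-2 correspond to roughly 10% relative error.
```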

@ggerganov
Member

ggerganov commented Apr 5, 2026

@pwilkin Can you paste the exact commands and logs that you are using?

On my end the NMSE is in the normal range, so most likely you are doing something wrong.

# gemma-4-E4B

export MODEL_NAME=gemma-4-e4b
export MODEL_PATH=google/gemma-4-E4B
export CONVERTED_MODEL=llama.cpp/models/gemma-4-e4b.gguf

make causal-verify-logits DEVICE=cpu

PyTorch logits  : data/pytorch-gemma-4-E4B.bin
llama.cpp logits: data/llamacpp-gemma-4-e4b.bin
Top 10 PyTorch logits: [21.25  21.25  21.    20.625 20.625 20.5   20.5   20.375 20.375 20.25 ]
Top 10 llama.cpp logits: [21.143845 21.064665 20.938082 20.501064 20.469553 20.438086 20.425018
 20.390272 20.370918 20.226387]
Max absolute difference: 0.3095

✅ RESULT: PASS (NMSE = 2.75e-04)
---

# gemma-4-26B-A4B

export MODEL_NAME=gemma-4-26b-a4b
export MODEL_PATH=google/gemma-4-26B-A4B
export CONVERTED_MODEL=llama.cpp/models/gemma-4-26b-a4b.gguf

make causal-verify-logits DEVICE=cpu

PyTorch logits  : data/pytorch-gemma-4-26B-A4B.bin
llama.cpp logits: data/llamacpp-gemma-4-26b-a4b.bin
Top 10 PyTorch logits: [19.25  18.75  18.625 18.625 18.625 18.5   18.25  18.25  18.125 18.125]
Top 10 llama.cpp logits: [19.20818  18.634228 18.5834   18.577785 18.548178 18.4755   18.286913
 18.27792  18.214775 18.208567]
Max absolute difference: 0.1798

✅ RESULT: PASS (NMSE = 3.88e-05)

---

# gemma-4-31B

export MODEL_NAME=gemma-4-31b
export MODEL_PATH=google/gemma-4-31B
export CONVERTED_MODEL=llama.cpp/models/gemma-4-31b.gguf

make causal-verify-logits DEVICE=cpu

PyTorch logits  : data/pytorch-gemma-4-31B.bin
llama.cpp logits: data/llamacpp-gemma-4-31b.bin
Top 10 PyTorch logits: [22.75  22.    21.875 21.875 21.875 21.75  21.625 21.625 21.625 21.625]
Top 10 llama.cpp logits: [22.814135 22.09926  22.051968 22.008593 21.99162  21.792183 21.73987
 21.717546 21.666687 21.654068]
Max absolute difference: 0.2435

✅ RESULT: PASS (NMSE = 2.71e-04)

In my personal benchmarks, Gemma 4 has been losing coherence at long contexts.

Without any systematic way to demonstrate that coherence loss, there is not much that can be done. There is a very high chance that you or your harness is doing something wrong.

Is it possible that the long context problems are caused by not passing swa?

@tarruda No. Generally, --swa-full is needed in some advanced use cases and normally there is no need to use that flag. It will just increase the memory usage without any benefit.

In any case, memory usage for Gemma 4 will be optimized next week to properly reduce the global KV cache by x2.
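The VRAM gap reported above is roughly what an iSWA cache layout predicts: SWA layers only need a sliding window of KV cells unless --swa-full forces every layer to hold the full context. A back-of-the-envelope sketch (all hyperparameters below are hypothetical placeholders, not Gemma 4's real config):

```python
def kv_cache_bytes(n_ctx, swa_full, *, n_layers=48, swa_ratio=5,
                   window=1024, n_kv_heads=8, head_dim=256, bytes_per_el=2):
    """Estimate KV-cache size for a hypothetical iSWA model where
    swa_ratio out of every (swa_ratio + 1) layers use a sliding window."""
    per_cell = 2 * n_kv_heads * head_dim * bytes_per_el  # K + V per token/layer
    n_swa = n_layers * swa_ratio // (swa_ratio + 1)      # e.g. 40 of 48 layers
    n_global = n_layers - n_swa
    if swa_full:
        return n_layers * n_ctx * per_cell               # every layer full-context
    return (n_global * n_ctx + n_swa * min(window, n_ctx)) * per_cell

full = kv_cache_bytes(65536, swa_full=True)
swa  = kv_cache_bytes(65536, swa_full=False)
print(full / 2**30, swa / 2**30)  # --swa-full costs several times more
```

With these placeholder numbers the full cache is ~5-6x the SWA layout at 64k context, which is in the same ballpark as the ratios reported in this thread.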

@pwilkin
Member Author

pwilkin commented Apr 5, 2026

@ggerganov yeah, I'll extract the exact dumps from OpenCode into OpenAI message histories to try to reproduce.

@pwilkin
Member Author

pwilkin commented Apr 5, 2026

Also, as for the check - on my system DEVICE=cpu has no effect there, but running:

$ CUDA_VISIBLE_DEVICES=-1 MODEL_PATH=/mnt/win/h/models/gemma/gemma-4-E2B-it CONVERTED_MODEL=/mnt/win/h/models/gemma/gemma-4-E2B-it-bf16.gguf make -C examples/model-conversion causal-verify-logits

gives me:

❌ RESULT: NEEDS REVIEW (NMSE = 1.89e-02)

The same command without CUDA disabled gives:

❌ RESULT: NEEDS REVIEW (NMSE = 2.48e-02)

On this branch on CUDA I get half that.

I haven't checked the base models FWIW, I've only tried the instruction tuned ones.

@pwilkin
Member Author

pwilkin commented Apr 5, 2026

@ggerganov Update: ran

CUDA_VISIBLE_DEVICES=-1 MODEL_PATH=/opt/models/gemma-4-26B-A4B-it CONVERTED_MODEL=/opt/models/gemma-4-26B-A4B-it-bf16.gguf make -C examples/model-conversion causal-verify-logits

on Johannes' server. The result:
❌ RESULT: NEEDS REVIEW (NMSE = 2.03e-02)

and, in line with the other results in this thread, trying F32 makes things worse:

❌ RESULT: NEEDS REVIEW (NMSE = 3.26e-02)

@ggerganov
Member

You have to run the base models. The logits of the instruction tuned models without a chat template are heavily distorted towards a single token, so it is expected to have higher error.

@pwilkin
Member Author

pwilkin commented Apr 5, 2026

Aight, that explains stuff. Gemma 4 is extremely strongly instruction tuned; I haven't seen this with most other models.

Going to try to dump those conversations to isolate.

@malformed-c

malformed-c commented Apr 6, 2026

I can confirm that Gemma 4 performs normally only for the first few messages, but its quality degrades significantly beyond ~5k context (in my case), where it begins to produce typos, grammatical errors and an overall loss of intelligence.
It begins to misuse possessives (adding "'s" incorrectly) and confuses pronouns like "you" and "I". It misuses "'m" (as in "I'm") in the wrong places, e.g. it can write "you'm". It also occasionally merges words together, omitting spaces.

spiritbuun added a commit to spiritbuun/buun-llama-cpp that referenced this pull request Apr 6, 2026
…org#21451)

Gemma 4 was trained in BF16 on TPUs. F32 computation diverges from
training-time BF16 rounding, especially at embedding scale, MoE router
norm, and per-layer embedding scale. Cast to BF16 at these 3 points,
apply BF16-rounded constants, then cast back to F32.

Added BF16 CUDA kernels for scale and rms_norm, BF16 dispatch paths
in binbcast. Fixed upstream PR's bug (router used wrong tensor for
rms_norm input). q8_0 baseline PPL: 444→326 (-27%) on Gemma 4 26B.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@Arvamer

Arvamer commented Apr 6, 2026

Does losing coherence at long context only apply to smaller models? I ran MRCR v2 at 128k with MoE, and the results without thinking look close enough to those from AI Studio (Q4 scored ~0.1 with both ROCm and Vulkan, Q8 0.126, AI Studio 0.112). However, with the thinking enabled, requests were timing out after 10 minutes, and I don’t know whether it was because of looping or the model was simply thinking that much.

@github-actions github-actions bot added the testing label Apr 6, 2026
@pwilkin
Member Author

pwilkin commented Apr 7, 2026

Update: I'll do a bisect tomorrow, but this (rebased on top of all the template fixes, tokenizer fixes and the accumulator change) now passes coherence tests on the Q4_K_M GGML quant at 100k context (as in, there are occasional weird stoppages at high context, but it acts coherently and generates coherent outputs all the way up to 100k with Q8 KV quants).

stephencox-ict pushed a commit to stephencox-ict/llama.cpp that referenced this pull request Apr 7, 2026
Audio encoder fixes:
- Fix swapped conv norm weight mapping in tensor_mapping.py
  (A_ENC_CONV_NORM and A_ENC_NORM_CONV had their gemma4 entries inverted,
  causing the conv pre-norm and internal norm weights to be swapped in GGUF.
  This produced 0.67 encoder cosine vs PyTorch; now 0.9999)
- Fix causal mask off-by-one: add (gq - gk) < max_past to match PyTorch's
  dist < left_window_size (was attending to 13 past tokens instead of 12)
- Use -1e9 instead of -INFINITY for masked positions to match PyTorch's
  attention_invalid_logits_value and avoid NaN in padded attention weights

LM fixes:
- Disable attention logit softcapping for Gemma4 (unlike Gemma2, Gemma4's
  text model does not use attn softcapping; was incorrectly hardcoded)
- Use BF16-rounded embedding scale constants to match PyTorch's native
  BF16 training precision (ref: PR ggml-org#21451). Fixes long-context coherence
  on CPU/Vulkan backends.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@stephencox-ict
Contributor

@pwilkin I think you are heading in the right direction. This work helped me fix the CPU/Vulkan case for the audio side.

@pwilkin
Member Author

pwilkin commented Apr 7, 2026

Closing as superseded by fix in #21566

@pwilkin pwilkin closed this Apr 7, 2026
stephencox-ict pushed a commit to stephencox-ict/llama.cpp that referenced this pull request Apr 8, 2026
stephencox-ict pushed a commit to stephencox-ict/llama.cpp that referenced this pull request Apr 9, 2026
