feat: add bf16_loss training argument for VRAM-efficient QLoRA by keeping loss in BF16 during training to save ~1.4 GB VRAM #45769
Conversation
Keeps the cross-entropy loss computation in BF16 instead of upcasting the logits to FP32, saving ~600 MB–1.4 GB of VRAM per logit tensor. `--bf16_loss` is an opt-in flag that requires `--bf16` and is validated at startup. Precision impact is negligible: QLoRA already quantizes the weights to 4-bit, which makes the FP32 upcast a disproportionately expensive safeguard.
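For illustration only (this is not the PR's actual diff), the change amounts to whether the logits are upcast before the cross-entropy call; a minimal PyTorch sketch:

```python
import torch
import torch.nn.functional as F

def loss_fp32(logits_bf16: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # Standard path: .float() materializes a second, full-size FP32 copy of the logits.
    return F.cross_entropy(logits_bf16.float(), labels)

def loss_bf16(logits_bf16: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # Proposed path: compute cross-entropy directly on the BF16 logits, skipping the FP32 copy.
    return F.cross_entropy(logits_bf16, labels)

# Tiny smoke test with dummy data (shapes are illustrative, not from the PR).
if torch.cuda.is_available():
    logits = torch.randn(4, 32_000, dtype=torch.bfloat16, device="cuda")  # (tokens, vocab)
    labels = torch.randint(0, 32_000, (4,), device="cuda")
    print(loss_fp32(logits, labels).item(), loss_bf16(logits, labels).item())
```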
Thanks for the PR @butterwecksolutions. Do you have a small script that can be used to measure this difference?
OK, this request came out of a debug marathon on my training pipeline. I am currently building a reproducer script with all my fixes and will update this thread within the next few days. Please hold off until I have results from further investigation.
Thanks for working on the reproducer. Something very simple, even with dummy data, would be absolutely fine.
**Systematic Isolation Complete: No Measurable Effect at Current Scale**

Same 14-test reproducer (7 patches × 2 modes, fresh GPU per test). bf16_loss training arg (#45769): 0.0 GB effect.

The concept is sound: keeping the loss in BF16 avoids an FP32 upcast of the logits. But the dominant VRAM leak (7.0 GB) traces entirely to async CUDA streams, not to the loss computation. This PR isn't wrong; it just doesn't move the needle for the specific workload investigated here.

Full reproducer, HTML report, raw training logs:
Closing this PR. Thanks for the review.
@butterwecksolutions Thanks for double-checking. If you find a situation where this leads to a significant reduction in memory usage, feel free to re-open the PR.
Problem
During QLoRA training with `--bf16`, logits are upcast from BF16 to FP32 during loss computation, allocating 600 MB–1.4 GB per logit tensor at model forward time. This is disproportionately expensive when the model weights are already quantized to 4-bit precision.
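To make that range concrete, a back-of-the-envelope calculation. The shapes below are illustrative assumptions, not numbers from this PR; vocabularies around 150k are typical for Qwen-family models:

```python
# Size of one logit tensor of shape (batch, seq_len, vocab_size).
batch, seq_len, vocab_size = 1, 2048, 152_000   # assumed values for illustration
fp32_bytes = batch * seq_len * vocab_size * 4   # ~1.25 GB for the upcast FP32 copy
bf16_bytes = batch * seq_len * vocab_size * 2   # ~0.62 GB if the logits stay in BF16
print(f"FP32 copy: {fp32_bytes / 1e9:.2f} GB, BF16 copy: {bf16_bytes / 1e9:.2f} GB")
```

The extra FP32 copy, allocated on top of the BF16 logits the model already holds, is what lands in the 600 MB–1.4 GB range quoted above, depending on batch size and sequence length.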
Fix
New `--bf16_loss` flag (default: `False`, opt-in). When set, it:
- requires `--bf16` (validated at startup)
- wraps `model.forward()` in `torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)`, as sketched below
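A rough sketch of how such a flag could be wired. The helper names are hypothetical; only `--bf16_loss`, `--bf16`, and the autocast call above come from this PR:

```python
import torch

def validate_args(args) -> None:
    # Startup validation: the BF16 loss path only makes sense when BF16 training is enabled.
    if getattr(args, "bf16_loss", False) and not args.bf16:
        raise ValueError("--bf16_loss requires --bf16")

def forward_maybe_bf16_loss(model, batch: dict, bf16_loss: bool):
    # When the flag is set, run the forward pass under BF16 autocast so the logits
    # (and therefore the cross-entropy loss) stay in BF16 instead of being upcast to FP32.
    if bf16_loss:
        with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
            return model(**batch)
    return model(**batch)
```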
Trade-off
Negligible precision loss for most workloads. QLoRA already quantizes the weights to 4-bit; BF16 logits preserve more signal than the quantization boundaries discard.
Verification
Tested on Qwen3-VL-9B QLoRA training (RTX 3090 24 GB):
Combined with bitsandbytes#1935 and TRL#5694, peak VRAM drops
from ~22.5 GB to ~12.5 GB in full QLoRA training.
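A minimal way to reproduce such a comparison (illustrative only; `train_step` is a placeholder for one forward/backward pass of the actual training loop, not part of this PR):

```python
import torch

def peak_vram_gb(step_fn) -> float:
    # Reset the allocator's peak counter, run one step, and report the new peak.
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    step_fn()
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated() / 1e9

# Usage sketch:
# baseline = peak_vram_gb(lambda: train_step(bf16_loss=False))
# patched  = peak_vram_gb(lambda: train_step(bf16_loss=True))
# print(f"peak VRAM saved: {baseline - patched:.2f} GB")
```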