Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1978,7 +1978,14 @@ def compute_loss(
if num_items_in_batch is not None:
kwargs["num_items_in_batch"] = num_items_in_batch
inputs = {**inputs, **kwargs}
outputs = model(**inputs)
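# BF16 loss: when `bf16_loss` is enabled, run the forward pass under bfloat16 autocast so the
# logits stay in BF16, saving ~600 MB–1.4 GB of VRAM per forward pass.
# Precision impact is negligible for QLoRA-style training, which already quantizes to 4-bit.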
if getattr(self.args, "bf16_loss", False) and self.args.bf16:
import torch
with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
outputs = model(**inputs)
else:
outputs = model(**inputs)

# User-defined compute_loss function
if self.compute_loss_func is not None:
Expand Down
13 changes: 13 additions & 0 deletions src/transformers/training_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -893,6 +893,16 @@ class TrainingArguments:
"help": "Use full BF16 precision for evaluation (not just mixed precision). Faster and saves memory."
},
)
bf16_loss: bool = field(
default=False,
metadata={
"help": (
"Keep cross-entropy loss computation in BF16 instead of upcasting to FP32. "
"Saves ~600 MB–1.4 GB VRAM per logit tensor during training. "
"Negligible precision impact for most workloads — QLoRA already quantizes to 4-bit."
)
},
)
fp16_full_eval: bool = field(
default=False,
metadata={
Expand Down Expand Up @@ -1742,6 +1752,9 @@ def _validate_args(self):
if self.fp16_full_eval and self.bf16_full_eval:
raise ValueError("At most one of fp16 and bf16 can be True for full eval, but not both")

if self.bf16_loss and not self.bf16:
raise ValueError("`bf16_loss=True` requires `bf16=True`. BF16 loss avoids the FP32 upcast.")

if self.lr_scheduler_type == SchedulerType.REDUCE_ON_PLATEAU:
if self.eval_strategy == IntervalStrategy.NO:
raise ValueError("lr_scheduler_type reduce_lr_on_plateau requires an eval strategy")
Expand Down
Loading