## What this solves
When you call forward_backward multiple times before a single optim_step (gradient accumulation), the accumulated gradients scale with the number of micro-batches. Without normalization, doubling grad_accum doubles the effective gradient magnitude, changing training dynamics.
Gradient accumulation normalization divides the accumulated gradients by a count (tokens or sequences) at optim_step time, making the optimizer step equivalent regardless of how the batch was partitioned.
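The effect is easy to see with toy numbers (this is an illustration, not the SDK): raw-sum gradients accumulated over micro-batches scale with the number of micro-batches, while dividing by the total token count at step time removes that dependence.

```python
# One summed ("raw_sum") gradient per 4-token micro-batch -- toy values.
micro_grads = [3.0, 5.0, 4.0, 4.0]

unnormalized_2 = sum(micro_grads[:2])   # grad_accum=2 -> 8.0
unnormalized_4 = sum(micro_grads)       # grad_accum=4 -> 16.0, twice as large

# Dividing by the total token count (4 tokens per micro-batch) makes the
# effective gradient independent of how the batch was partitioned.
normalized_2 = unnormalized_2 / (2 * 4)  # -> 1.0
normalized_4 = unnormalized_4 / (4 * 4)  # -> 1.0

assert unnormalized_4 == 2 * unnormalized_2
assert normalized_2 == normalized_4
```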
## Normalization modes
Pass grad_accumulation_normalization to optim_step:
```python
client.optim_step(
    adam_params,
    grad_accumulation_normalization="num_loss_tokens",  # default
)
```
| Mode | Divides by | Best for |
|---|---|---|
| "num_loss_tokens" | Total non-zero-grad tokens across all micro-batches | SFT, GRPO (per-token mean) |
| "num_sequences" | Total sequences with at least one non-zero-grad token | DPO, ORPO (per-sequence mean) |
| "none" | Nothing; gradients are used as-is | Manual normalization in the loss function |
The SDK defaults to "num_loss_tokens" so gradient accumulation is batch-size invariant out of the box.
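The difference between the two counts comes down to how a per-token loss mask is reduced. A small sketch, assuming a weight mask where 0 marks prompt/padding tokens (the array and names here are illustrative, not SDK API):

```python
import numpy as np

# 3 sequences x 5 positions; 0 = token excluded from the loss
loss_weights = np.array([
    [0, 0, 1, 1, 1],   # 3 loss tokens
    [0, 1, 1, 0, 0],   # 2 loss tokens
    [0, 0, 0, 0, 0],   # fully masked -> contributes to neither count
])

# "num_loss_tokens": every non-zero-grad token counts individually
num_loss_tokens = int((loss_weights != 0).sum())

# "num_sequences": a sequence counts once if any token carries loss
num_sequences = int((loss_weights != 0).any(axis=1).sum())

print(num_loss_tokens, num_sequences)  # 5 2
```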
## Interaction with loss functions: raw_sum
Server-side normalization and client-side loss normalization must not both divide by the same count. The raw_sum parameter on cookbook loss functions controls this:
| raw_sum | Loss returns | Use with |
|---|---|---|
| True | Raw sum of token losses | "num_loss_tokens" or "num_sequences" (server normalizes) |
| False | Per-token mean | "none" (client already normalized) |
Using raw_sum=False with "num_loss_tokens" causes double-normalization — the client divides by token count, then the server divides again. Always use raw_sum=True when server-side normalization is active.
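The arithmetic of the bug, with made-up numbers: a per-token mean that the server then divides again ends up smaller by an extra factor of the token count.

```python
# One micro-batch with 3 loss tokens -- toy values, not SDK output.
token_losses = [4.0, 2.0, 6.0]
n = len(token_losses)

# raw_sum=True + "num_loss_tokens": server divides the sum exactly once.
correct = sum(token_losses) / n            # -> 4.0

# raw_sum=False + "num_loss_tokens": client takes the mean, server divides
# again -- the effective loss (and gradient) is n times too small.
double = (sum(token_losses) / n) / n       # -> 4/3

assert double == correct / n
```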
The cookbook recipes handle this automatically:
```python
# From sft_loop.py -- raw_sum is derived from the normalization mode
use_raw_sum = cfg.grad_accumulation_normalization != "none"
loss_fn = make_batch_weighted_sft_loss_fn(raw_sum=use_raw_sum)
```
## verl equivalents
For teams migrating from verl, here is how Fireworks modes map to verl’s loss_agg_mode:
| Fireworks (raw_sum + server mode) | verl loss_agg_mode | Formula |
|---|---|---|
| raw_sum=True + "num_loss_tokens" | token-mean | sum(token_losses) / total_tokens |
| raw_sum=True + "num_sequences" | seq-mean-token-sum | sum(per_seq_token_sum) / num_sequences |
| raw_sum=False + "num_sequences" | seq-mean-token-mean | sum(per_seq_token_mean) / num_sequences |
| raw_sum=False + "none" | N/A (client handles everything) | mean(token_losses) per micro-batch |
The key difference from verl: in Fireworks, the loss computation is split between client (loss function) and server (normalization at optim_step). verl computes both in one agg_loss() call.
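The three formulas in the table can be written out directly. A pure-Python sketch over two toy sequences (the loss mask is assumed to be already applied; this is not verl's agg_loss implementation):

```python
# Per-token losses for two sequences of different lengths -- toy values.
seq_losses = [[1.0, 3.0], [2.0, 4.0, 6.0]]

total_tokens = sum(len(s) for s in seq_losses)   # 5
num_seqs = len(seq_losses)                       # 2

# token-mean: one flat mean over every loss token
token_mean = sum(sum(s) for s in seq_losses) / total_tokens

# seq-mean-token-sum: sum tokens within each sequence, then average sequences
seq_mean_token_sum = sum(sum(s) for s in seq_losses) / num_seqs

# seq-mean-token-mean: mean tokens within each sequence, then average sequences
seq_mean_token_mean = sum(sum(s) / len(s) for s in seq_losses) / num_seqs

print(token_mean, seq_mean_token_sum, seq_mean_token_mean)  # 3.2 8.0 3.0
```

Note that token-mean weights long sequences more heavily, while both seq-mean variants give each sequence equal weight, which is why the per-pair recipes (DPO, ORPO) use the "num_sequences" family.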
## Recipe defaults
| Recipe | Mode | raw_sum | Rationale |
|---|---|---|---|
| SFT | "num_loss_tokens" | True | Per-token mean, all tokens weighted equally |
| GRPO/RL | "num_loss_tokens" | True | Per-token policy gradient |
| DPO | "num_sequences" | True | Per-pair mean, each preference pair weighted equally |
| ORPO | "num_sequences" | True | Per-pair mean |
## Example: SFT with gradient accumulation
```python
from training.recipes.sft_loop import Config, main

config = Config(
    base_model="accounts/fireworks/models/qwen3-8b",
    dataset="data/my_sft_data.jsonl",
    batch_size=4,
    grad_accum=4,  # 4 micro-batches per optimizer step
    learning_rate=1e-5,
    # Per-token normalization (default) -- training is identical whether
    # grad_accum=1 with batch_size=16, or grad_accum=4 with batch_size=4
    grad_accumulation_normalization="num_loss_tokens",
)
```
With grad_accum=4 and batch_size=4, each optimizer step processes 16 samples. The server counts all non-zero-grad tokens across the 4 micro-batches and divides the accumulated gradient by that total, producing the same result as a single batch of 16.
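This equivalence can be checked numerically with a toy linear model in NumPy, standing in for the client loss (raw sum) and server normalization (divide by the total count at step time); none of this is SDK code:

```python
import numpy as np

rng = np.random.default_rng(42)
x = rng.normal(size=(16, 3))   # 16 "tokens" of features
y = rng.normal(size=16)
w = np.zeros(3)

def raw_sum_grad(w, xb, yb):
    # Gradient of the unnormalized squared-error sum: d/dw sum((xb @ w - yb)^2)
    return 2 * xb.T @ (xb @ w - yb)

# batch_size=16, grad_accum=1: one backward, divide by 16 tokens
g_single = raw_sum_grad(w, x, y) / 16

# batch_size=4, grad_accum=4: accumulate, then divide by the same 16 tokens
g_accum = np.zeros(3)
for i in range(0, 16, 4):
    g_accum += raw_sum_grad(w, x[i:i+4], y[i:i+4])
g_accum /= 16

assert np.allclose(g_single, g_accum)  # identical optimizer step
```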