
What this solves

When you call forward_backward multiple times before a single optim_step (gradient accumulation), the accumulated gradients scale with the number of micro-batches. Without normalization, doubling grad_accum doubles the effective gradient magnitude, changing training dynamics. Gradient accumulation normalization divides the accumulated gradients by a count (tokens or sequences) at optim_step time, making the optimizer step equivalent regardless of how the batch was partitioned.
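A toy sketch of the scaling problem (plain Python, not the SDK): with a fixed micro-batch size, doubling `grad_accum` doubles the accumulated gradient sum, and dividing by the total token count restores invariance. The model and gradient here are illustrative.

```python
# Toy model: per-token loss (w*x)^2 / 2, so the per-token gradient is w*x*x.
mb = [1.0, 2.0, 3.0, 4.0]   # one micro-batch of 4 token values
w = 0.5

def grad_sum(tokens):
    # Summed (unnormalized) gradient over a micro-batch.
    return sum(w * x * x for x in tokens)

g1 = grad_sum(mb)                  # grad_accum=1: one micro-batch
g2 = grad_sum(mb) + grad_sum(mb)   # grad_accum=2: same micro-batch twice
assert g2 == 2 * g1                # unnormalized magnitude doubles
assert g1 / 4 == g2 / 8            # dividing by token count makes steps match
```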

Normalization modes

Pass grad_accumulation_normalization to optim_step:
```python
client.optim_step(
    adam_params,
    grad_accumulation_normalization="num_loss_tokens",  # default
)
```

| Mode | Divides by | Best for |
|---|---|---|
| `"num_loss_tokens"` | Total non-zero-grad tokens across all micro-batches | SFT, GRPO (per-token mean) |
| `"num_sequences"` | Total sequences with at least one non-zero-grad token | DPO, ORPO (per-sequence mean) |
| `"none"` | Nothing; gradients are used as-is | Manual normalization in the loss function |
The SDK defaults to "num_loss_tokens" so gradient accumulation is batch-size invariant out of the box.
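As a rough sketch of what each mode counts, here is a hypothetical `normalization_divisor` helper (illustrative only, not the SDK's implementation). Per-token weights of `0.0` mark tokens that contribute no gradient, such as prompt tokens:

```python
def normalization_divisor(micro_batches, mode):
    """Illustrative divisor for each normalization mode.

    micro_batches: list of micro-batches; each micro-batch is a list of
    sequences; each sequence is a list of per-token loss weights, where
    0.0 means the token contributes no gradient.
    """
    seqs = [seq for mb in micro_batches for seq in mb]
    if mode == "num_loss_tokens":
        return sum(1 for seq in seqs for w in seq if w != 0.0)
    if mode == "num_sequences":
        return sum(1 for seq in seqs if any(w != 0.0 for w in seq))
    if mode == "none":
        return 1
    raise ValueError(f"unknown mode: {mode}")

mb1 = [[0.0, 1.0, 1.0], [0.0, 0.0, 0.0]]   # 2nd sequence has no loss tokens
mb2 = [[1.0, 1.0, 1.0]]
print(normalization_divisor([mb1, mb2], "num_loss_tokens"))  # 5
print(normalization_divisor([mb1, mb2], "num_sequences"))    # 2
```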

Interaction with loss functions: raw_sum

Server-side normalization and client-side loss normalization must not both divide by the same count. The raw_sum parameter on cookbook loss functions controls this:

| `raw_sum` | Loss returns | Use with |
|---|---|---|
| `True` | Raw sum of token losses | `"num_loss_tokens"` or `"num_sequences"` (server normalizes) |
| `False` | Per-token mean | `"none"` (client already normalized) |
Using raw_sum=False with "num_loss_tokens" causes double-normalization — the client divides by token count, then the server divides again. Always use raw_sum=True when server-side normalization is active.
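A quick numeric illustration of the double-normalization bug (toy numbers, plain Python, uniform per-token loss of 1.0):

```python
# Two micro-batches with 3 and 5 loss tokens; every token's loss is 1.0.
token_counts = [3, 5]
total_tokens = sum(token_counts)

# Correct: raw_sum=True -> client returns raw sums; server divides once.
client_sums = [n * 1.0 for n in token_counts]
correct = sum(client_sums) / total_tokens    # 1.0, the true per-token mean

# Wrong: raw_sum=False -> client returns per-token means; server divides again.
client_means = [n * 1.0 / n for n in token_counts]
double = sum(client_means) / total_tokens    # 0.25, silently scaled down 4x
assert correct == 1.0 and double == 0.25
```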
The cookbook recipes handle this automatically:
```python
# From sft_loop.py -- raw_sum is derived from the normalization mode
use_raw_sum = cfg.grad_accumulation_normalization != "none"
loss_fn = make_batch_weighted_sft_loss_fn(raw_sum=use_raw_sum)
```

verl equivalents

For teams migrating from verl, here is how Fireworks modes map to verl’s loss_agg_mode:

| Fireworks (`raw_sum` + server mode) | verl `loss_agg_mode` | Formula |
|---|---|---|
| `raw_sum=True` + `"num_loss_tokens"` | `token-mean` | `sum(token_losses) / total_tokens` |
| `raw_sum=True` + `"num_sequences"` | `seq-mean-token-sum` | `sum(per_seq_token_sum) / num_sequences` |
| `raw_sum=False` + `"num_sequences"` | `seq-mean-token-mean` | `sum(per_seq_token_mean) / num_sequences` |
| `raw_sum=False` + `"none"` | N/A (client handles everything) | `mean(token_losses)` per micro-batch |
The key difference from verl: in Fireworks, the loss computation is split between client (loss function) and server (normalization at optim_step). verl computes both in one agg_loss() call.
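To make the Formula column concrete, here is a plain-Python sketch of the three aggregations over ragged per-token losses (illustrative only, not verl's actual `agg_loss()`):

```python
# Two sequences with 2 and 3 loss tokens respectively.
seq_losses = [[2.0, 4.0], [1.0, 1.0, 1.0]]

flat = [l for seq in seq_losses for l in seq]

# token-mean: every token weighted equally across the whole batch.
token_mean = sum(flat) / len(flat)
# seq-mean-token-sum: sum tokens within a sequence, average over sequences.
seq_mean_token_sum = sum(sum(s) for s in seq_losses) / len(seq_losses)
# seq-mean-token-mean: average tokens within a sequence, then over sequences.
seq_mean_token_mean = sum(sum(s) / len(s) for s in seq_losses) / len(seq_losses)

print(token_mean, seq_mean_token_sum, seq_mean_token_mean)  # 1.8 4.5 2.0
```

Note how the three disagree whenever sequence lengths differ: `token-mean` favors long sequences, while the `seq-mean-*` modes weight each sequence equally.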

Recipe defaults

| Recipe | Mode | `raw_sum` | Rationale |
|---|---|---|---|
| SFT | `"num_loss_tokens"` | `True` | Per-token mean; all tokens weighted equally |
| GRPO/RL | `"num_loss_tokens"` | `True` | Per-token policy gradient |
| DPO | `"num_sequences"` | `True` | Per-pair mean; each preference pair weighted equally |
| ORPO | `"num_sequences"` | `True` | Per-pair mean |

Example: SFT with gradient accumulation

```python
from training.recipes.sft_loop import Config, main

config = Config(
    base_model="accounts/fireworks/models/qwen3-8b",
    dataset="data/my_sft_data.jsonl",
    batch_size=4,
    grad_accum=4,              # 4 micro-batches per optimizer step
    learning_rate=1e-5,
    # Per-token normalization (default) -- training is identical whether
    # grad_accum=1 with batch_size=16, or grad_accum=4 with batch_size=4
    grad_accumulation_normalization="num_loss_tokens",
)
```
With grad_accum=4 and batch_size=4, each optimizer step processes 16 samples. The server counts all non-zero-grad tokens across the 4 micro-batches and divides the accumulated gradient by that total, producing the same result as a single batch of 16.
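The equivalence can be checked numerically with a toy one-parameter model (plain Python sketch, not the SDK): because summed gradients are additive across micro-batches, four accumulated chunks of 4 divided by 16 match one batch of 16 divided by 16.

```python
import random

# Toy model: per-token loss (w*x - y)^2; d/dw = 2*x*(w*x - y).
random.seed(0)
w = 0.3
xs = [random.uniform(-1, 1) for _ in range(16)]
ys = [random.uniform(-1, 1) for _ in range(16)]

def grad_sum(x_chunk, y_chunk):
    # Summed (unnormalized) gradient over one micro-batch.
    return sum(2 * x * (w * x - y) for x, y in zip(x_chunk, y_chunk))

# One optimizer step from a single batch of 16, per-token mean gradient:
one_batch = grad_sum(xs, ys) / 16
# Same data as 4 accumulated micro-batches of 4, divided by total tokens:
accum = sum(grad_sum(xs[i:i + 4], ys[i:i + 4]) for i in range(0, 16, 4)) / 16
assert abs(one_batch - accum) < 1e-12
```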