What this is
This guide walks through GRPO (Group Relative Policy Optimization) training using the Fireworks Training SDK. GRPO samples multiple completions per prompt, scores them with a reward function, and uses the group’s reward statistics to compute per-token advantages for policy gradient updates. Two variants are covered:

- On-policy GRPO: Hotload after every optimizer step so the sampling policy always matches the training policy. No importance sampling needed.
- Off-policy GRPO: Hotload at intervals (e.g. every 10 steps) and use importance sampling to correct for the stale behavior policy.
Architecture
GRPO requires two RLOR trainer jobs:

| Job | Role | Configuration |
|---|---|---|
| Policy trainer | Trainable model — runs `forward_backward_custom` + `optim_step` | `learning_rate > 0`, linked to deployment via `hot_load_deployment_id` |
| Reference trainer | Frozen copy of the initial model — provides KL reference logprobs | `--forward-only` flag, never call `optim_step` |
Sampling goes through `DeploymentSampler` (client-side tokenization, token-in/token-out).
Step 1: Provision resources
Step 2: Build training data
GRPO uses `tinker.Datum` objects with token weights that distinguish prompt tokens (weight=0) from response tokens (weight=1). The `tinker_cookbook` helper handles internal token shifting:
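The weight layout itself can be illustrated without the SDK. A minimal plain-Python sketch — `build_weighted_tokens` is a hypothetical helper, and the real `tinker_cookbook` helper additionally performs the internal token shift:

```python
def build_weighted_tokens(prompt_tokens, response_tokens):
    """Concatenate prompt and response tokens with per-token loss weights.

    Prompt tokens get weight 0 (no gradient signal); response tokens get
    weight 1. GRPO later scales these weights by the group-relative
    advantage when forming the policy-gradient loss.
    """
    tokens = list(prompt_tokens) + list(response_tokens)
    weights = [0.0] * len(prompt_tokens) + [1.0] * len(response_tokens)
    return tokens, weights

# Example: a 3-token prompt followed by a 2-token response.
tokens, weights = build_weighted_tokens([101, 2054, 102], [318, 42])
# weights -> [0.0, 0.0, 0.0, 1.0, 1.0]
```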
Step 3: GRPO loss function
The loss function receives `data` (a list of `Datum`) and `logprobs_list` (a list of `torch.Tensor` with `requires_grad=True`), and returns `(loss, metrics_dict)`:
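The core arithmetic can be sketched in plain Python. This is an illustrative stand-in, not the SDK implementation: `group_advantages` and `grpo_token_loss` are hypothetical names, and the KL penalty is written with the common k3 estimator as an assumption about its form:

```python
import math
import statistics

def group_advantages(rewards):
    """Group-relative advantages: (r - mean) / std within one prompt's group.

    If every completion in the group has the same reward (std == 0), the
    group carries no learning signal, so all advantages are zero.
    """
    mean = statistics.fmean(rewards)
    std = statistics.pstdev(rewards)
    if std == 0.0:
        return [0.0] * len(rewards)
    return [(r - mean) / std for r in rewards]

def grpo_token_loss(logprob, ref_logprob, advantage, kl_beta=0.001):
    """Per-token GRPO objective: REINFORCE term plus a KL penalty.

    KL is approximated with the k3 estimator exp(d) - d - 1 where
    d = ref_logprob - logprob; it is non-negative and low-variance.
    """
    diff = ref_logprob - logprob
    kl = math.exp(diff) - diff - 1.0
    return -advantage * logprob + kl_beta * kl

# Two correct and two incorrect completions in a group of four:
advs = group_advantages([1.0, 0.0, 0.0, 1.0])
# advs -> [1.0, -1.0, -1.0, 1.0]
```

In the real loss function the same advantage is broadcast over that completion's response tokens (weight=1 positions) and multiplied by the token logprobs, so gradients flow only through the response.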
Step 4: On-policy training loop (prompt-batched)
Off-policy variant
Off-policy GRPO hotloads at intervals and adds importance sampling to correct for the stale sampling policy.

Cookbook recipe entrypoint
The cookbook provides a streaming RL loop (`rl_loop.py`) that handles rollout scheduling, greedy batching, and metrics automatically. It supports multiple policy loss variants via a single `policy_loss` string.
Config fields reference
| Field | Type | Default | Description |
|---|---|---|---|
| `base_model` | str | `"accounts/fireworks/models/qwen3-8b"` | Base model name |
| `dataset` | str | GSM8K sample URL | Dataset path or URL (JSONL) |
| `learning_rate` | float | `1e-5` | Learning rate |
| `kl_beta` | float | `0.001` | KL penalty coefficient (GRPO only) |
| `completions_per_prompt` | int | `4` | Number of completions per prompt |
| `max_completion_tokens` | int | `1024` | Max tokens per sampled completion |
| `temperature` | float | `1.0` | Sampling temperature |
| `epochs` | int | `1` | Training epochs over the dataset |
| `max_rows` | int | `100` | Number of dataset prompts to use |
| `max_seq_len` | int | `4096` | Max sequence length for training |
| `lora_rank` | int | `0` | `0` for full-parameter, `>0` for LoRA |
| `prompt_groups_per_step` | int | `1` | Prompt groups per optimizer step |
| `min_samples_per_fwd_bwd` | int \| None | `None` (= max) | Min samples to trigger a `fwd_bwd` call; controls greedy batching granularity |
| `max_samples_per_fwd_bwd` | int | `256` | Cap on samples per `fwd_bwd` call |
| `max_concurrent` | int | `32` | Cap on concurrent in-flight sampling requests |
| `policy_loss` | str | `"grpo"` | Policy loss variant: `"grpo"`, `"dapo"`, `"gspo"`, or `"cispo"` |
| `tis_enabled` | bool | `False` | Enable Truncated Importance Sampling (composes with any loss) |
| `router_replay` | bool | `False` | Enable Router Replay (R3) for MoE models |
Nested configs: `infra` (`InfraConfig`), `deployment` (`DeployConfig`), `hotload` (`HotloadConfig`), `wandb` (`WandBConfig`), `resume` (`ResumeConfig`), `tis` (`ISConfig`), `dapo` (`DAPOConfig`), `gspo` (`GSPOConfig`), `cispo` (`CISPOConfig`).
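As a quick sanity check on the flat fields above, they can be mirrored as a dataclass. The class name `GRPOConfig` and this standalone definition are illustrative only, not the SDK's actual config class:

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class GRPOConfig:
    """Illustrative mirror of the flat config fields (not the SDK class)."""
    base_model: str = "accounts/fireworks/models/qwen3-8b"
    learning_rate: float = 1e-5
    kl_beta: float = 0.001
    completions_per_prompt: int = 4
    max_completion_tokens: int = 1024
    temperature: float = 1.0
    epochs: int = 1
    max_rows: int = 100
    max_seq_len: int = 4096
    lora_rank: int = 0                              # 0 = full-parameter, >0 = LoRA
    prompt_groups_per_step: int = 1
    min_samples_per_fwd_bwd: Optional[int] = None   # None = max_samples_per_fwd_bwd
    max_samples_per_fwd_bwd: int = 256
    max_concurrent: int = 32
    policy_loss: str = "grpo"                       # "grpo" | "dapo" | "gspo" | "cispo"
    tis_enabled: bool = False
    router_replay: bool = False

# Override only what differs from the defaults:
cfg = GRPOConfig(policy_loss="dapo", kl_beta=0.0)
```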
Operational guidance
- Service mode supports both full-parameter and LoRA tuning for both policy and reference trainers.
- `deployment.tokenizer_model` is required for GRPO — the SDK raises `ValueError` if it is not set. Use the HuggingFace model name (e.g. `"Qwen/Qwen3-8B"`).
- Tokenization is client-side: GRPO uses `AutoTokenizer` + `DeploymentSampler.sample_with_tokens(...)` (token-in/token-out path).
- Streaming rollout scheduling: sampling runs concurrently (capped by `max_concurrent`), and `fwd_bwd` calls fire greedily as soon as `min_samples_per_fwd_bwd` samples accumulate. `prompt_groups_per_step` controls how many groups form one optimizer step.
- Policy loss variants: `"grpo"` (REINFORCE + KL), `"dapo"` (asymmetric PPO clipping), `"gspo"` (sequence-level clipped PPO), `"cispo"` (importance-sampling masking). Set via `cfg.policy_loss`.
- TIS (Truncated Importance Sampling) is orthogonal — enable it with `cfg.tis_enabled=True` on top of any policy loss.
- The reference trainer uses the `--forward-only` and `--no-compile` flags (never call `optim_step` on it).
- Track reward distributions and KL every step to catch objective drift early.
- Skip prompts with uniform rewards (all correct or all wrong) — they provide no learning signal.
- W&B logging: set `cfg.wandb` with `WandBConfig(entity=..., project=..., run_name=...)`.
Common pitfalls
- Reward normalization bugs can destabilize GRPO updates quickly — verify advantage computation.
- Reference/policy tokenizer mismatch invalidates KL estimates — always use the same `base_model` for both trainers.
- Off-policy ratio explosion: if `rho` grows too large, clamp it with `clip_rho`.
- Logprob alignment: the trainer returns N-1 logprobs for N tokens (shifted); inference returns N logprobs where the first is `None`. Use `inference[1:]` to align with the trainer.
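The last two pitfalls can be written out as small helpers. A plain-Python sketch with hypothetical names (`clipped_ratio`, `align_inference_logprobs`):

```python
import math

def clipped_ratio(logprob_new, logprob_old, clip_rho=2.0):
    """Importance ratio rho = exp(logpi_new - logpi_old), clamped at clip_rho
    so a stale behavior policy cannot blow up the off-policy update."""
    rho = math.exp(logprob_new - logprob_old)
    return min(rho, clip_rho)

def align_inference_logprobs(inference_logprobs):
    """Inference returns N logprobs with a leading None; the trainer returns
    N-1 shifted logprobs. Dropping the first inference entry aligns them."""
    return inference_logprobs[1:]

rho = clipped_ratio(0.0, -5.0)                        # exp(5) ~ 148, clamped to 2.0
aligned = align_inference_logprobs([None, -0.1, -0.2])  # -> [-0.1, -0.2]
```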