What this is
This guide walks through GRPO (Group Relative Policy Optimization) training using the cookbook's rl_loop recipe. GRPO samples multiple completions per prompt, scores them with a reward function, and uses group reward statistics for policy gradient updates.
Architecture
The RL recipe always uses a policy trainer plus an inference deployment. Add a reference trainer when your setup needs reference logprobs:

| Component | Role |
|---|---|
| Policy trainer | Trainable model — runs forward_backward_custom + optim_step |
| Reference trainer | Optional frozen copy — provides KL/reference logprobs (--forward-only) when infra.ref_training_shape_id is set |
| Deployment | Samples completions via DeploymentSampler (client-side tokenization) |
Using the recipe
The simplest way to run GRPO is via the cookbook's Config + main:
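A minimal sketch of that entrypoint. Only the names documented in this guide (Config, main, policy_loss, weight_sync_interval, deployment.tokenizer_model, infra.training_shape_id, infra.ref_training_shape_id) are confirmed; the values and the exact constructor shape are assumptions.

```python
# Hedged sketch: field names come from this guide; the values and
# constructor shape are placeholders, not verified defaults.
from training.recipes.rl_loop import Config, main

config = Config(
    policy_loss="grpo",        # see the variants table below
    weight_sync_interval=1,    # sync after every step; no pipeline overlap
)
config.deployment.tokenizer_model = "my-org/base-model"  # required, else ValueError
config.infra.training_shape_id = "policy-shape"          # launch path for the policy trainer
config.infra.ref_training_shape_id = "ref-shape"         # optional: provisions a reference trainer

main(config)
```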
Policy loss variants
| policy_loss | Description |
|---|---|
| "grpo" | REINFORCE + KL penalty (default) |
| "importance_sampling" | Off-policy ratio weighting with optional clipping |
| "reinforce" | Vanilla REINFORCE |
| "dapo" | Dynamic advantage with asymmetric PPO clipping |
| "dro" | Distributionally robust off-policy objective |
| "gspo" | Sequence-level clipped PPO |
| "cispo" | Clipped importance sampling policy optimization |
Step-by-step (API-level)
For teams that need full control beyond what the recipe provides, here is the API-level flow.

Provision resources with setup_infra
training.utils.rl.setup_infra is the cookbook’s single entrypoint for shape
resolution, parallel trainer + deployment provisioning, LoRA shared-reference
branching, and re-attach. Recipes pass a config + two booleans
(needs_reference, needs_inference) and get back an Infra bundle of wired
trainer clients. Teams that fork training/recipes/rl_loop.py should reuse
setup_infra rather than re-wiring the lower-level helpers below.
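A sketch of that call, assuming the contract described above. The config + two booleans signature is documented here; the Infra attribute names (policy, reference, deployment) are illustrative guesses, not documented fields.

```python
from training.utils.rl import setup_infra

# needs_reference mirrors whether infra.ref_training_shape_id is set;
# needs_inference is True for any sampling-based RL loop.
infra = setup_infra(
    config,
    needs_reference=config.infra.ref_training_shape_id is not None,
    needs_inference=True,
)
policy = infra.policy        # trainable: forward_backward_custom + optim_step
reference = infra.reference  # frozen --forward-only trainer, or None
sampler = infra.deployment   # DeploymentSampler for completions
```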
Weight sync scope is selected via WeightSyncScope: PER_TRAINER (the default) or PER_DEPLOYMENT. For the full setup_infra contract, lower-level building blocks, and implementation rationale, see the cookbook's dev skill: skills/dev/.
Training loop
The loop is built from the make_grpo_loss_fn and build_grpo_datums implementations; a sketch of one iteration follows.
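This is a hypothetical shape for one iteration, assuming signatures for make_grpo_loss_fn, build_grpo_datums, sample_with_tokens, and the trainer client; reward_fn, group_size, and prompt_batches are user-side names. The real loop lives in training/recipes/rl_loop.py.

```python
import asyncio

# policy and sampler come from the setup_infra step above.
loss_fn = make_grpo_loss_fn()  # assumption: builds the "grpo" policy loss

for step, prompt_batch in enumerate(prompt_batches):
    # Sample a group of completions per prompt (async under the hood).
    samples = asyncio.run(sampler.sample_with_tokens(prompt_batch, n=group_size))
    rewards = [reward_fn(s) for s in samples]        # user-supplied reward function
    datums = build_grpo_datums(samples, rewards)     # assumption: packs tokens + advantages
    policy.forward_backward_custom(datums, loss_fn)  # accumulate gradients
    policy.optim_step()                              # apply the update
```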
Pipeline overlap
Sampling and training overlap within policy windows controlled by weight_sync_interval. All prompts in a window sample concurrently; results train as they arrive. At window boundaries the pipeline drains, weights sync to the deployment, and the next window samples against the updated weights.
| weight_sync_interval | Behavior |
|---|---|
| 1 (default) | No overlap — sample, train, sync, repeat |
| N > 1 | N-step windows with overlap inside, sync at boundaries |
| 0 | No syncs — the deployment keeps the base weights for the entire run. Useful for debugging or ablations, not standard RL training. |
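The boundary logic, in pseudocode terms; pipeline.drain() and policy.sync_weights_to() are invented names that only illustrate the window semantics from the table above.

```python
# Invented names: this illustrates behavior, not the recipe's actual API.
if config.weight_sync_interval > 0 and step % config.weight_sync_interval == 0:
    pipeline.drain()                 # finish training all in-flight samples
    policy.sync_weights_to(sampler)  # next window samples against updated weights
# weight_sync_interval == 0: never drain or sync; the deployment keeps base weights.
```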
Operational guidance
- deployment.tokenizer_model is required; the API raises ValueError if it is not set.
- Set infra.training_shape_id; training shapes are the launch path for cookbook trainers.
- Set infra.ref_training_shape_id when you want a reference trainer; if it is unset, the recipe skips reference-model provisioning entirely.
- Skip prompts with uniform rewards (all correct or all wrong); they provide no learning signal.
- Track reward distributions and KL every step to catch objective drift early.
- When configured, the reference trainer uses --forward-only; never call optim_step on it.
- Sampling is async under the hood: DeploymentSampler.sample_with_tokens() issues n concurrent n=1 requests, so synchronous scripts should wrap it with asyncio.run(...) (see the sketch after this list).
- DCP checkpoints are disabled by default (dcp_save_interval=0). If you need to resume training from a checkpoint, explicitly set dcp_save_interval to a positive value in your WeightSyncConfig.
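The async-sampling note in code form. Only sample_with_tokens and asyncio.run come from this guide; the argument names and value of n are assumptions.

```python
import asyncio

# sample_with_tokens() is a coroutine that fans out n concurrent n=1
# requests; a synchronous script must drive the event loop itself.
samples = asyncio.run(sampler.sample_with_tokens(prompt_tokens, n=8))
```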
Common pitfalls
- Reward normalization bugs can destabilize GRPO updates quickly — verify advantage computation.
- Reference/policy tokenizer mismatch invalidates KL estimates; always use the same base_model.
- Logprob alignment: the trainer returns N-1 logprobs for N tokens. Inference returns N logprobs where the first is None. Use inference[1:] to align (see the sketch after this list).
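The alignment pitfall in code. The attribute names on the trainer and inference outputs are assumptions; the off-by-one logic is exactly the one described above.

```python
# Trainer: N-1 logprobs for an N-token sequence (no logprob for token 0).
# Inference: N logprobs where index 0 is None.
trainer_lps = trainer_output.logprobs  # assumed field name, length N-1
inference_lps = sample.logprobs[1:]    # drop the leading None -> length N-1
assert len(trainer_lps) == len(inference_lps)
log_ratios = [t - i for t, i in zip(trainer_lps, inference_lps)]  # per-token log-ratios
```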
Related guides
- Cookbook DPO — preference optimization
- Cookbook Reference — all config classes
- Loss Functions — API-level loss function details