What this is

This guide walks through GRPO (Group Relative Policy Optimization) training using the cookbook’s rl_loop recipe. GRPO samples multiple completions per prompt, scores them with a reward function, and uses group reward statistics for policy gradient updates.
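Concretely, each completion's advantage is its reward normalized against its group's mean and standard deviation. A minimal sketch in plain Python (illustrative only, not the cookbook's implementation):

```python
from statistics import mean, pstdev

def group_advantages(rewards, eps=1e-6):
    """Normalize each reward against its own group's statistics (GRPO-style)."""
    mu = mean(rewards)
    sigma = pstdev(rewards)
    return [(r - mu) / (sigma + eps) for r in rewards]

# Four completions for one prompt, scored 0/1: correct ones get positive
# advantage, incorrect ones negative, and the group sums to ~0.
advs = group_advantages([1.0, 0.0, 0.0, 1.0])
```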

Architecture

The RL recipe always uses a policy trainer plus an inference deployment. Add a reference trainer when your setup needs reference logprobs:
| Component | Role |
| --- | --- |
| Policy trainer | Trainable model — runs forward_backward_custom + optim_step |
| Reference trainer | Optional frozen copy — provides KL/reference logprobs (--forward-only) when infra.ref_training_shape_id is set |
| Deployment | Samples completions via DeploymentSampler (client-side tokenization) |

Using the recipe

The simplest way to run GRPO is via the cookbook’s Config + main:
from training.recipes.rl_loop import Config, main
from training.utils import DeployConfig, InfraConfig, WeightSyncConfig, WandBConfig

cfg = Config(
    log_path="./grpo_logs",
    base_model="accounts/fireworks/models/qwen3-8b",
    dataset="/path/to/gsm8k.jsonl",
    max_rows=200,
    epochs=1,
    completions_per_prompt=4,
    max_completion_tokens=1024,
    temperature=1.0,
    max_seq_len=4096,
    policy_loss="grpo",  # or "importance_sampling", "dapo", "dro", "gspo", "cispo"
    infra=InfraConfig(
        training_shape_id="accounts/fireworks/trainingShapes/qwen3-8b-128k-h200",
        ref_training_shape_id="accounts/fireworks/trainingShapes/qwen3-8b-128k-h200-forward",
    ),
    deployment=DeployConfig(
        deployment_id="grpo-serving",
        tokenizer_model="Qwen/Qwen3-8B",
    ),
    weight_sync=WeightSyncConfig(weight_sync_interval=1),
    wandb=WandBConfig(entity="my-team", project="grpo-experiment"),
)

main(cfg)
The recipe handles resource provisioning, rollout scheduling, reference logprobs, checkpointing, and cleanup automatically.

Policy loss variants

| policy_loss | Description |
| --- | --- |
| "grpo" | REINFORCE + KL penalty (default) |
| "importance_sampling" | Off-policy ratio weighting with optional clipping |
| "dapo" | Dynamic advantage with asymmetric PPO clipping |
| "dro" | Distributionally robust off-policy objective |
| "gspo" | Sequence-level clipped PPO |
| "cispo" | Clipped importance sampling policy optimization |

Step-by-step (SDK-level)

For teams that need full control beyond what the recipe provides, here is the SDK-level flow.

Provision resources

import os
import tinker
import transformers
from concurrent.futures import ThreadPoolExecutor
from fireworks.training.sdk import (
    TrainerJobManager, DeploymentManager, DeploymentSampler, WeightSyncer,
)
from training.utils import (
    InfraConfig, DeployConfig, ReconnectableClient,
    create_trainer_job, setup_deployment,
)

api_key = os.environ["FIREWORKS_API_KEY"]
account_id = os.environ.get("FIREWORKS_ACCOUNT_ID", "")
base_url = os.environ.get("FIREWORKS_BASE_URL", "https://api.fireworks.ai")

rlor_mgr = TrainerJobManager(api_key=api_key, account_id=account_id, base_url=base_url)
deploy_mgr = DeploymentManager(api_key=api_key, account_id=account_id, base_url=base_url)

base_model = "accounts/fireworks/models/qwen3-8b"
infra = InfraConfig(
    training_shape_id="accounts/fireworks/trainingShapes/qwen3-8b-128k-h200",
    ref_training_shape_id="accounts/fireworks/trainingShapes/qwen3-8b-128k-h200-forward",
)
deploy_cfg = DeployConfig(deployment_id="grpo-serving", tokenizer_model="Qwen/Qwen3-8B")

with ThreadPoolExecutor(max_workers=3) as pool:
    dep_fut = pool.submit(setup_deployment, deploy_mgr, deploy_cfg, base_model, infra)
    pol_fut = pool.submit(
        create_trainer_job, rlor_mgr,
        base_model=base_model, infra=infra, lora_rank=0,
        display_name="grpo-policy", hot_load_deployment_id="grpo-serving",
    )
    ref_fut = pool.submit(
        create_trainer_job, rlor_mgr,
        base_model=base_model, infra=infra, lora_rank=0,
        display_name="grpo-reference", forward_only=True,
    )
    dep_info = dep_fut.result()
    policy_ep = pol_fut.result()
    reference_ep = ref_fut.result()

policy = ReconnectableClient(rlor_mgr, policy_ep.job_id, base_model, lora_rank=0, fw_api_key=api_key)
reference = ReconnectableClient(rlor_mgr, reference_ep.job_id, base_model, lora_rank=0, fw_api_key=api_key)

tokenizer = transformers.AutoTokenizer.from_pretrained("Qwen/Qwen3-8B", trust_remote_code=True)
sampler = DeploymentSampler(
    inference_url=deploy_mgr.inference_url,
    model=dep_info.inference_model if dep_info else base_model,
    api_key=api_key,
    tokenizer=tokenizer,
)

Training loop

import asyncio

tracker = WeightSyncer(
    policy_client=policy.inner,
    deploy_mgr=deploy_mgr,
    deployment_id="grpo-serving",
    base_model=base_model,
    hotload_timeout=600,
    first_checkpoint_type="base",
)

for step, row in enumerate(dataset):
    # Drop assistant turns so the policy generates its own completions
    input_messages = [m for m in row["messages"] if m.get("role") != "assistant"]
    completions = asyncio.run(
        sampler.sample_with_tokens(messages=input_messages, n=4, max_tokens=512)
    )
    rewards = [score(c) for c in completions]  # score(): your task-specific reward function
    if len(set(rewards)) == 1:
        continue  # uniform rewards give zero advantage — no learning signal

    datums = build_grpo_datums(completions)
    ref_fwd = reference.forward(datums, "cross_entropy")
    ref_logprobs = [list(x["logprobs"].data) for x in ref_fwd.loss_fn_outputs]

    loss_fn = make_grpo_loss_fn(rewards, ref_logprobs, kl_beta=0.001)
    policy.forward_backward_custom(datums, loss_fn)
    policy.optim_step(
        tinker.AdamParams(learning_rate=1e-5, beta1=0.9, beta2=0.999, eps=1e-8, weight_decay=0.01)
    )

    tracker.save_and_hotload(f"step-{step:05d}")
See Loss Functions for make_grpo_loss_fn and build_grpo_datums implementations.
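As a rough orientation before reading those implementations, the per-token GRPO objective combines a REINFORCE term with a KL penalty toward the reference. The sketch below is hypothetical (the k1 KL estimator and the function name are my simplifications, not the cookbook's make_grpo_loss_fn):

```python
def grpo_token_loss(policy_lp, ref_lp, advantage, kl_beta=0.001):
    """One token's GRPO loss (to minimize): REINFORCE term plus a KL
    penalty toward the reference model, using the simple k1 estimator."""
    pg = -advantage * policy_lp          # policy-gradient (REINFORCE) term
    kl = kl_beta * (policy_lp - ref_lp)  # k1 estimate of KL(policy || ref)
    return pg + kl
```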

Operational guidance

  • deployment.tokenizer_model is required — the SDK raises ValueError if not set.
  • Set infra.training_shape_id — training shapes are the launch path for cookbook trainers.
  • Set infra.ref_training_shape_id when you want a reference trainer — if it is unset, the recipe skips reference-model provisioning entirely.
  • Skip prompts with uniform rewards (all correct or all wrong) — they provide no learning signal.
  • Track reward distributions and KL every step to catch objective drift early.
  • When configured, the reference trainer uses --forward-only — never call optim_step on it.
  • Sampling is async under the hood: DeploymentSampler.sample_with_tokens() issues n concurrent n=1 requests, so synchronous scripts should wrap it with asyncio.run(...).
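Since sampling is already async per prompt, rollouts over a batch of prompts can be overlapped as well. A sketch using asyncio.gather (sample_batch and its stub-friendly signature are illustrative, not part of the SDK):

```python
import asyncio

async def sample_batch(sampler, prompt_batches, n=4, max_tokens=512):
    """Fan sample_with_tokens out over many prompts concurrently."""
    tasks = [
        sampler.sample_with_tokens(messages=msgs, n=n, max_tokens=max_tokens)
        for msgs in prompt_batches
    ]
    # gather preserves input order, so results line up with prompt_batches.
    return await asyncio.gather(*tasks)
```

From a synchronous script, call it as `asyncio.run(sample_batch(sampler, batches))`.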

Common pitfalls

  • Reward normalization bugs can destabilize GRPO updates quickly — verify advantage computation.
  • Reference/policy tokenizer mismatch invalidates KL estimates — always use the same base_model.
  • Logprob alignment: Trainer returns N-1 logprobs for N tokens. Inference returns N logprobs where the first is None. Use inference[1:] to align.
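The alignment rule in that last bullet can be sketched directly with toy values (the helper name is mine, not an SDK function):

```python
def align_logprobs(trainer_lps, inference_lps):
    """Trainer emits N-1 logprobs for N tokens; inference emits N logprobs
    with a leading None. Dropping the inference head aligns them
    position-by-position."""
    aligned = inference_lps[1:]
    assert len(aligned) == len(trainer_lps), "length mismatch after alignment"
    return list(zip(trainer_lps, aligned))

# 4 tokens: trainer gives 3 logprobs, inference gives 4 (first is None).
pairs = align_logprobs([-0.5, -1.2, -0.1], [None, -0.6, -1.1, -0.2])
```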