What this is

This guide walks through GRPO (Group Relative Policy Optimization) training using the Fireworks Training SDK. GRPO samples multiple completions per prompt, scores them with a reward function, and uses the group’s reward statistics to compute per-token advantages for policy gradient updates. Two variants are covered:
  • On-policy GRPO: Hotload after every optimizer step so the sampling policy always matches the training policy. No importance sampling needed.
  • Off-policy GRPO: Hotload at intervals (e.g. every 10 steps) and use importance sampling to correct for the stale behavior policy.
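Both variants score each group of completions and normalize the rewards into per-completion advantages within that group. A minimal, SDK-independent sketch of that statistic (`group_advantages` is a hypothetical helper; it uses the sample standard deviation, matching `torch.std`, plus the degenerate-group guard used later in Step 3):

```python
# Group-relative advantages: normalize each reward by its group's mean/std.
def group_advantages(rewards, eps=1e-8):
    n = len(rewards)  # assumes n >= 2 (GRPO always samples multiple completions)
    mean = sum(rewards) / n
    std = (sum((r - mean) ** 2 for r in rewards) / (n - 1)) ** 0.5  # sample std
    if std < eps:
        std = 1.0  # degenerate group: all rewards equal, advantages go to zero
    return [(r - mean) / (std + eps) for r in rewards]

# Example: 4 completions, two correct (reward 1) and two wrong (reward 0)
advs = group_advantages([1.0, 0.0, 1.0, 0.0])  # roughly [0.87, -0.87, 0.87, -0.87]
```

Completions that beat their group's average get positive advantages; the rest get negative ones, which is what makes the signal "group relative".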

Architecture

GRPO requires two RLOR trainer jobs:
Job               | Role                                                              | Configuration
Policy trainer    | Trainable model — runs forward_backward_custom + optim_step       | learning_rate > 0, linked to deployment via hot_load_deployment_id
Reference trainer | Frozen copy of the initial model — provides KL reference logprobs | --forward-only flag; never call optim_step
Plus a deployment for sampling completions via DeploymentSampler (client-side tokenized, token-in/token-out).

Step 1: Provision resources

import os
import tinker
import transformers
from concurrent.futures import ThreadPoolExecutor
from fireworks.training.sdk import (
    FiretitanServiceClient,
    TrainerJobManager,
    TrainerJobConfig,
    DeploymentManager,
    DeploymentSampler,
    WeightSyncer,
)
from training.utils import (
    InfraConfig,
    DeployConfig,
    create_trainer_job,
    setup_deployment,
)

api_key = os.environ["FIREWORKS_API_KEY"]
account_id = os.environ.get("FIREWORKS_ACCOUNT_ID", "")
base_url = os.environ.get("FIREWORKS_BASE_URL", "https://api.fireworks.ai")

rlor_mgr = TrainerJobManager(api_key=api_key, account_id=account_id, base_url=base_url)
deploy_mgr = DeploymentManager(api_key=api_key, account_id=account_id, base_url=base_url)

base_model = "accounts/fireworks/models/qwen3-8b"
infra = InfraConfig()
deploy_cfg = DeployConfig(deployment_id="grpo-serving",
                          tokenizer_model="Qwen/Qwen3-8B")

# Create all resources in parallel (deployment + 2 trainers)
with ThreadPoolExecutor(max_workers=3) as pool:
    dep_fut = pool.submit(setup_deployment, deploy_mgr, deploy_cfg, base_model, infra)
    pol_fut = pool.submit(
        create_trainer_job, rlor_mgr,
        base_model=base_model, infra=infra, lora_rank=0,
        display_name="grpo-policy",
        hot_load_deployment_id="grpo-serving",
    )
    ref_fut = pool.submit(
        create_trainer_job, rlor_mgr,
        base_model=base_model, infra=infra, lora_rank=0,
        display_name="grpo-reference",
        extra_args=["--forward-only", "--no-compile"],
    )
    dep_info = dep_fut.result()
    policy_ep = pol_fut.result()
    reference_ep = ref_fut.result()

# Connect training clients (FiretitanServiceClient adds checkpoint_type + session ID)
policy_svc = FiretitanServiceClient(base_url=policy_ep.base_url, api_key=api_key)
policy_client = policy_svc.create_training_client(base_model=base_model, lora_rank=0)

ref_svc = FiretitanServiceClient(base_url=reference_ep.base_url, api_key=api_key)
reference_client = ref_svc.create_training_client(base_model=base_model, lora_rank=0)

# Set up client-side tokenized sampling from deployment
inference_model = dep_info.inference_model if dep_info else base_model
tokenizer = transformers.AutoTokenizer.from_pretrained("Qwen/Qwen3-8B", trust_remote_code=True)
sampler = DeploymentSampler(
    inference_url=deploy_mgr.inference_url,
    model=inference_model,
    api_key=api_key,
    tokenizer=tokenizer,
)

Step 2: Build training data

GRPO uses tinker.Datum objects with token weights that distinguish prompt tokens (weight=0) from response tokens (weight=1). The tinker_cookbook helper handles internal token shifting:
import torch
from tinker_cookbook.supervised.common import datum_from_tokens_weights

def build_grpo_datums(sampled_completions, max_seq_len=8192):
    datums = []
    for completion in sampled_completions:
        full_tokens = completion.full_tokens  # prompt + response token IDs
        weights = torch.zeros(len(full_tokens), dtype=torch.float32)
        weights[completion.prompt_len:] = 1.0  # Only response tokens contribute to loss

        datum = datum_from_tokens_weights(
            torch.tensor(full_tokens, dtype=torch.long),
            weights,
            max_length=max_seq_len,
        )
        datums.append(datum)
    return datums

Step 3: GRPO loss function

The loss function receives data (list of Datum) and logprobs_list (list of torch.Tensor with requires_grad=True), and returns (loss, metrics_dict):
def make_grpo_loss_fn(rewards, ref_logprobs_list, kl_beta=0.001, eps=1e-8):
    K = len(rewards)
    rewards_t = torch.tensor(rewards, dtype=torch.float32)
    mean_r, std_r = rewards_t.mean(), rewards_t.std()
    if std_r < eps:
        std_r = torch.tensor(1.0)
    advantages = ((rewards_t - mean_r) / (std_r + eps)).tolist()
    ref_tensors = [torch.tensor(lp, dtype=torch.float32) for lp in ref_logprobs_list]

    def loss_fn(data, logprobs_list):
        total_loss = torch.tensor(0.0)
        total_kl = 0.0

        for i in range(K):
            pi_lp = logprobs_list[i]
            ref_lp = ref_tensors[i]
            weights = torch.tensor(data[i].loss_fn_inputs["weights"].data, dtype=torch.float32)

            min_len = min(len(pi_lp), len(ref_lp), len(weights))
            pi_lp_t, ref_lp_t, weights_t = pi_lp[:min_len], ref_lp[:min_len], weights[:min_len]

            pi_sum = torch.dot(pi_lp_t.float(), weights_t)
            kl_term = torch.dot((pi_lp_t - ref_lp_t).float(), weights_t)

            total_loss = total_loss + (-advantages[i] * pi_sum) + (kl_beta * kl_term)

            with torch.no_grad():
                total_kl += kl_term.item()

        loss = total_loss / K
        return loss, {"grpo_loss": loss.item(), "mean_kl": total_kl / K}

    return loss_fn

Step 4: On-policy training loop (prompt-batched)

# Set up WeightSyncer for automatic delta-chain management
tracker = WeightSyncer(
    policy_client=policy_client,
    deploy_mgr=deploy_mgr,       # DeploymentManager instance
    deployment_id="grpo-serving",
    base_model="accounts/fireworks/models/qwen3-8b",
    hotload_timeout=600,
    first_checkpoint_type="base",
)

prompt_groups_per_step = 1  # recipe default
prompt_buffer = []
global_step = 0

for row in dataset:
    input_messages = [m for m in row["messages"] if m.get("role") != "assistant"]

    # 1) Sample with client-side tokenization
    completions = sampler.sample_with_tokens(messages=input_messages, n=4, max_tokens=512)
    rewards = [score(c) for c in completions]  # score() is your task-specific reward function
    if len(set(rewards)) == 1:
        continue  # uniform rewards yield all-zero advantages: no learning signal

    # 2) Build one prompt-group (datums + advantages + reference logprobs)
    datums = build_grpo_datums(completions)
    ref_fwd = reference_client.forward(datums, "cross_entropy").result()
    ref_logprobs = [list(x["logprobs"].data) for x in ref_fwd.loss_fn_outputs]
    prompt_buffer.append((datums, rewards, ref_logprobs))

    # 3) Train once buffer reaches prompt_groups_per_step
    if len(prompt_buffer) < prompt_groups_per_step:
        continue

    # Concatenate all buffered prompts into one training batch
    combined_datums, combined_rewards, combined_ref = [], [], []
    for d, r, ref in prompt_buffer:
        combined_datums.extend(d)
        combined_rewards.extend(r)
        combined_ref.extend(ref)

    loss_fn = make_grpo_loss_fn(combined_rewards, combined_ref, kl_beta=0.001)
    result = policy_client.forward_backward_custom(combined_datums, loss_fn).result()
    policy_client.optim_step(
        tinker.AdamParams(learning_rate=1e-5, beta1=0.9, beta2=0.999, eps=1e-8, weight_decay=0.01)
    ).result()
    global_step += 1
    prompt_buffer = []

    # 4) On-policy hotload cadence
    tracker.save_and_hotload(f"step-{global_step:05d}")
    print(f"Step {global_step}: {result.metrics}")

Off-policy variant

Off-policy GRPO hotloads at intervals and adds importance sampling to correct for the stale sampling policy:
# Key differences from on-policy:
# 1. Hotload every N steps instead of every step
# 2. Compute behavior policy logprobs from deployment (prefill)
# 3. Loss includes importance ratio: rho = exp(pi_current - pi_behavior)

def make_offpolicy_grpo_loss_fn(rewards, ref_logprobs_list, behavior_logprobs_list,
                                 kl_beta=0.001, clip_rho=10.0, eps=1e-8):
    # Advantage computation is identical to the on-policy version (Step 3)
    K = len(rewards)
    rewards_t = torch.tensor(rewards, dtype=torch.float32)
    std_r = rewards_t.std()
    std_r = torch.tensor(1.0) if std_r < eps else std_r
    advantages = ((rewards_t - rewards_t.mean()) / (std_r + eps)).tolist()
    ref_tensors = [torch.tensor(lp, dtype=torch.float32) for lp in ref_logprobs_list]
    behavior_tensors = [torch.tensor(lp, dtype=torch.float32) for lp in behavior_logprobs_list]

    def loss_fn(data, logprobs_list):
        total_loss, total_kl = torch.tensor(0.0), 0.0
        for i in range(K):
            pi_lp = logprobs_list[i]           # Current policy (requires_grad)
            ref_lp = ref_tensors[i]            # Reference (frozen)
            behavior_lp = behavior_tensors[i]  # Behavior policy at sampling time
            weights = torch.tensor(data[i].loss_fn_inputs["weights"].data, dtype=torch.float32)
            min_len = min(len(pi_lp), len(ref_lp), len(behavior_lp), len(weights))
            pi_lp, ref_lp, behavior_lp, weights = (
                pi_lp[:min_len], ref_lp[:min_len], behavior_lp[:min_len], weights[:min_len])

            pi_sum = torch.dot(pi_lp.float(), weights)
            kl_term = torch.dot((pi_lp - ref_lp).float(), weights)
            n_response_tokens = weights.sum().clamp(min=1.0)

            # Importance ratio corrects for stale samples: rho = exp(pi_current - pi_behavior)
            log_rho = torch.dot((pi_lp - behavior_lp).float(), weights)
            rho = torch.clamp(torch.exp(log_rho / n_response_tokens), max=clip_rho)

            # Weighted policy gradient with importance sampling
            total_loss = total_loss + rho * (-advantages[i] * pi_sum + kl_beta * kl_term)
            with torch.no_grad():
                total_kl += kl_term.item()
        loss = total_loss / K
        return loss, {"grpo_loss": loss.item(), "mean_kl": total_kl / K}

    return loss_fn
To get behavior logprobs, run a prefill on the deployment:
# Uses /v1/completions with echo=True to get logprobs for input tokens
import httpx

resp = httpx.post(
    "https://api.fireworks.ai/inference/v1/completions",
    headers={"Authorization": f"Bearer {api_key}"},
    json={
        "model": f"accounts/{account_id}/deployments/{deployment_id}",
        "prompt": token_ids,   # Pass token IDs directly
        "max_tokens": 1,
        "echo": True,
        "logprobs": 1,
        "prompt_cache_max_len": 0,
    },
)
behavior_logprobs = resp.json()["choices"][0]["logprobs"]["token_logprobs"]
# Alignment: inference logprobs[1:] matches trainer logprobs[0:]
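The off-by-one alignment noted above can be applied with a one-line slice. A standalone sketch (`align_behavior_logprobs` is a hypothetical helper name):

```python
# Align echo'd inference logprobs with trainer logprobs.
# Inference returns one logprob per input token, with None for the first
# token; the trainer returns N-1 logprobs (predictions for tokens 1..N-1).
def align_behavior_logprobs(token_logprobs):
    assert token_logprobs[0] is None, "first echo'd token has no logprob"
    return token_logprobs[1:]  # index j now matches trainer logprob j

# Example with 4 tokens: inference gives [None, lp1, lp2, lp3];
# the aligned list has length 3, matching the trainer's N-1 logprobs.
aligned = align_behavior_logprobs([None, -0.1, -0.2, -0.3])
```

Run this on `behavior_logprobs` before passing it into the off-policy loss so every dot product compares logprobs for the same predicted token.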

Cookbook recipe entrypoint

The cookbook provides a streaming RL loop (rl_loop.py) that handles rollout scheduling, greedy batching, and metrics automatically. It supports multiple policy loss variants via a single policy_loss string:
from training.recipes.rl_loop import Config, main
from training.utils import DeployConfig, HotloadConfig, WandBConfig

cfg = Config(
    base_model="accounts/fireworks/models/qwen3-8b",
    dataset="/path/to/gsm8k.jsonl",
    max_rows=200,
    epochs=1,
    completions_per_prompt=4,
    max_completion_tokens=1024,
    temperature=1.0,
    max_seq_len=4096,
    policy_loss="grpo",  # or "dapo", "gspo", "cispo"
    deployment=DeployConfig(
        deployment_id="grpo-serving",
        tokenizer_model="Qwen/Qwen3-8B",
    ),
    hotload=HotloadConfig(hot_load_interval=1),  # on-policy
    wandb=WandBConfig(entity="my-team", project="grpo-experiment"),
)

main(cfg)

Config fields reference

Field                   | Type        | Default                              | Description
base_model              | str         | "accounts/fireworks/models/qwen3-8b" | Base model name
dataset                 | str         | GSM8K sample URL                     | Dataset path or URL (JSONL)
learning_rate           | float       | 1e-5                                 | Learning rate
kl_beta                 | float       | 0.001                                | KL penalty coefficient (GRPO only)
completions_per_prompt  | int         | 4                                    | Number of completions per prompt
max_completion_tokens   | int         | 1024                                 | Max tokens per sampled completion
temperature             | float       | 1.0                                  | Sampling temperature
epochs                  | int         | 1                                    | Training epochs over the dataset
max_rows                | int         | 100                                  | Number of dataset prompts to use
max_seq_len             | int         | 4096                                 | Max sequence length for training
lora_rank               | int         | 0                                    | 0 for full-parameter, >0 for LoRA
prompt_groups_per_step  | int         | 1                                    | Prompt groups per optimizer step
min_samples_per_fwd_bwd | int or None | None (= max)                         | Min samples to trigger a fwd_bwd call; controls greedy batching granularity
max_samples_per_fwd_bwd | int         | 256                                  | Cap on samples per fwd_bwd call
max_concurrent          | int         | 32                                   | Cap on concurrent in-flight sampling requests
policy_loss             | str         | "grpo"                               | Policy loss variant: "grpo", "dapo", "gspo", or "cispo"
tis_enabled             | bool        | False                                | Enable Truncated Importance Sampling (composes with any loss)
router_replay           | bool        | False                                | Enable Router Replay (R3) for MoE models
Sub-configs: infra (InfraConfig), deployment (DeployConfig), hotload (HotloadConfig), wandb (WandBConfig), resume (ResumeConfig), tis (ISConfig), dapo (DAPOConfig), gspo (GSPOConfig), cispo (CISPOConfig).

Operational guidance

  • Service mode supports both full-parameter and LoRA tuning for both policy and reference trainers.
  • deployment.tokenizer_model is required for GRPO — the SDK raises ValueError if not set. Use the HuggingFace model name (e.g. "Qwen/Qwen3-8B").
  • Tokenizer is client-side: GRPO uses AutoTokenizer + DeploymentSampler.sample_with_tokens(...) (token-in/token-out path).
  • Streaming rollout scheduling: Sampling runs concurrently (capped by max_concurrent), and fwd_bwd calls fire greedily as soon as min_samples_per_fwd_bwd samples accumulate. prompt_groups_per_step controls how many groups form one optimizer step.
  • Policy loss variants: "grpo" (REINFORCE + KL), "dapo" (asymmetric PPO clipping), "gspo" (sequence-level clipped PPO), "cispo" (importance-sampling masking). Set via cfg.policy_loss.
  • TIS (Truncated Importance Sampling) is orthogonal — enable with cfg.tis_enabled=True on top of any policy loss.
  • Reference trainer uses --forward-only and --no-compile flags (never call optim_step on it).
  • Track reward distributions and KL every step to catch objective drift early.
  • Skip prompts with uniform rewards (all correct or all wrong) — they provide no learning signal.
  • W&B logging: set cfg.wandb with WandBConfig(entity=..., project=..., run_name=...).
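The streaming-rollout bullet above can be illustrated with a simplified, SDK-independent sketch of greedy batching (`greedy_batches` is a hypothetical helper, not the SDK's implementation): buffer finished samples and flush a batch whenever at least min_samples are ready, never exceeding max_samples per fwd_bwd call.

```python
def greedy_batches(sample_stream, min_samples, max_samples):
    # Yield a batch as soon as min_samples accumulate, capped at max_samples.
    buf = []
    for sample in sample_stream:
        buf.append(sample)
        while len(buf) >= min_samples:
            batch, buf = buf[:max_samples], buf[max_samples:]
            yield batch
    if buf:  # flush whatever remains at end of stream
        yield buf

# 10 finished samples, min = max = 4: two full batches plus a remainder
batches = list(greedy_batches(range(10), min_samples=4, max_samples=4))
```

A smaller min_samples_per_fwd_bwd fires fwd_bwd calls earlier (better pipeline overlap, smaller batches); setting it equal to the max waits for full batches.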

Common pitfalls

  • Reward normalization bugs can destabilize GRPO updates quickly — verify advantage computation.
  • Reference/policy tokenizer mismatch invalidates KL estimates — always use the same base_model.
  • Off-policy ratio explosion: If rho grows too large, clamp it with clip_rho.
  • Logprob alignment: Trainer returns N-1 logprobs for N tokens (shifted). Inference returns N logprobs where the first is None. Use inference[1:] to align with trainer.
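For the reward-normalization pitfall, one cheap runtime guard is to assert that advantages are finite and zero-mean before each update. A hedged sketch (`assert_valid_advantages` is a hypothetical helper mirroring Step 3's computation):

```python
import torch

def assert_valid_advantages(adv: torch.Tensor):
    # Catch NaN/inf from a missing degenerate-std guard, plus mean drift.
    assert torch.isfinite(adv).all(), "non-finite advantage (check the std guard)"
    assert abs(adv.mean().item()) < 1e-4, "advantages are not zero-mean"

# Without the std guard from Step 3, a uniform-reward group produces NaN:
rewards = torch.tensor([1.0, 1.0, 1.0, 1.0])
buggy = (rewards - rewards.mean()) / rewards.std()  # std == 0, every entry is NaN
```

Calling `assert_valid_advantages(buggy)` would fail immediately, whereas the guarded computation in Step 3 maps the same group to all-zero advantages.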