
What this is

This guide walks through GRPO (Group Relative Policy Optimization) training using the Fireworks Training SDK. GRPO samples multiple completions per prompt, scores them with a reward function, and uses the group’s reward statistics to compute per-token advantages for policy gradient updates.
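
For intuition, here is the group normalization applied to a toy reward vector (the values are hypothetical; this mirrors the advantage computation in Step 3):
import torch

rewards = torch.tensor([1.0, 0.0, 0.0, 1.0])  # 4 completions for one prompt
advantages = (rewards - rewards.mean()) / rewards.std()
# mean = 0.5, std ≈ 0.577 → advantages ≈ [+0.87, -0.87, -0.87, +0.87]
# Correct completions get a positive advantage, incorrect ones a negative one.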

Architecture

GRPO requires two RLOR trainer jobs:

  • Policy trainer: the trainable model. Runs forward_backward_custom + optim_step; configured with learning_rate > 0 and linked to the serving deployment via hot_load_deployment_id.
  • Reference trainer: a frozen copy of the initial model that provides KL reference logprobs. Created with the --forward-only flag; never call optim_step on it.

Plus a deployment for sampling completions via DeploymentSampler (client-side tokenization, token-in/token-out).

Step 1: Provision resources

The cookbook helpers create_trainer_job and setup_deployment wrap the SDK’s TrainerJobManager and DeploymentManager with config-driven defaults.
import os
import tinker
import transformers
from concurrent.futures import ThreadPoolExecutor
from fireworks.training.sdk import (
    TrainerJobManager,
    DeploymentManager,
    DeploymentSampler,
    WeightSyncer,
)
from training.utils import (
    InfraConfig,
    DeployConfig,
    ReconnectableClient,
    create_trainer_job,
    setup_deployment,
)

api_key = os.environ["FIREWORKS_API_KEY"]
account_id = os.environ.get("FIREWORKS_ACCOUNT_ID", "")
base_url = os.environ.get("FIREWORKS_BASE_URL", "https://api.fireworks.ai")

rlor_mgr = TrainerJobManager(api_key=api_key, account_id=account_id, base_url=base_url)
deploy_mgr = DeploymentManager(api_key=api_key, account_id=account_id, base_url=base_url)

base_model = "accounts/fireworks/models/qwen3-8b"
infra = InfraConfig()
deploy_cfg = DeployConfig(deployment_id="grpo-serving",
                          tokenizer_model="Qwen/Qwen3-8B")

# Create all resources in parallel (deployment + 2 trainers)
with ThreadPoolExecutor(max_workers=3) as pool:
    dep_fut = pool.submit(setup_deployment, deploy_mgr, deploy_cfg, base_model, infra)
    pol_fut = pool.submit(
        create_trainer_job, rlor_mgr,
        base_model=base_model, infra=infra, lora_rank=0,
        display_name="grpo-policy",
        hot_load_deployment_id="grpo-serving",
    )
    ref_fut = pool.submit(
        create_trainer_job, rlor_mgr,
        base_model=base_model, infra=infra, lora_rank=0,
        display_name="grpo-reference", forward_only=True,
    )
    dep_info = dep_fut.result()
    policy_ep = pol_fut.result()
    reference_ep = ref_fut.result()

# ReconnectableClient wraps the SDK training client with auto-reconnect
policy = ReconnectableClient(rlor_mgr, policy_ep.job_id, base_model, lora_rank=0, fw_api_key=api_key)
reference = ReconnectableClient(rlor_mgr, reference_ep.job_id, base_model, lora_rank=0, fw_api_key=api_key)

# Set up client-side tokenized sampling from deployment
inference_model = dep_info.inference_model if dep_info else base_model
tokenizer = transformers.AutoTokenizer.from_pretrained("Qwen/Qwen3-8B", trust_remote_code=True)
sampler = DeploymentSampler(
    inference_url=deploy_mgr.inference_url,
    model=inference_model,
    api_key=api_key,
    tokenizer=tokenizer,
)

Step 2: Build training data

GRPO uses tinker.Datum objects with token weights that distinguish prompt tokens (weight=0) from response tokens (weight=1). The tinker_cookbook helper handles internal token shifting:
import torch
from tinker_cookbook.supervised.common import datum_from_tokens_weights

def build_grpo_datums(sampled_completions, max_seq_len=8192):
    datums = []
    for completion in sampled_completions:
        full_tokens = completion.full_tokens  # prompt + response token IDs
        weights = torch.zeros(len(full_tokens), dtype=torch.float32)
        weights[completion.prompt_len:] = 1.0  # Only response tokens contribute to loss

        datum = datum_from_tokens_weights(
            torch.tensor(full_tokens, dtype=torch.long),
            weights,
            max_length=max_seq_len,
        )
        datums.append(datum)
    return datums
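
A quick illustration with a faked completion object (the token IDs are made up; real objects come from DeploymentSampler.sample_with_tokens in Step 4, which exposes full_tokens and prompt_len):
from types import SimpleNamespace

fake = SimpleNamespace(full_tokens=[11, 42, 7, 99, 13], prompt_len=3)  # 3 prompt + 2 response tokens
[datum] = build_grpo_datums([fake])
# Only the response positions carry weight 1.0; datum_from_tokens_weights
# applies the internal token shift.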

Step 3: GRPO loss function

The loss function receives data (list of Datum) and logprobs_list (list of torch.Tensor with requires_grad=True), and returns (loss, metrics_dict):
def make_grpo_loss_fn(rewards, ref_logprobs_list, kl_beta=0.001, eps=1e-8):
    K = len(rewards)
    rewards_t = torch.tensor(rewards, dtype=torch.float32)
    mean_r, std_r = rewards_t.mean(), rewards_t.std()
    if std_r < eps:
        std_r = torch.tensor(1.0)
    advantages = ((rewards_t - mean_r) / (std_r + eps)).tolist()
    ref_tensors = [torch.tensor(lp, dtype=torch.float32) for lp in ref_logprobs_list]

    def loss_fn(data, logprobs_list):
        total_loss = torch.tensor(0.0)
        total_kl = 0.0

        for i in range(K):
            pi_lp = logprobs_list[i]
            ref_lp = ref_tensors[i]
            weights = torch.tensor(data[i].loss_fn_inputs["weights"].data, dtype=torch.float32)

            min_len = min(len(pi_lp), len(ref_lp), len(weights))
            pi_lp_t, ref_lp_t, weights_t = pi_lp[:min_len], ref_lp[:min_len], weights[:min_len]

            pi_sum = torch.dot(pi_lp_t.float(), weights_t)
            kl_term = torch.dot((pi_lp_t - ref_lp_t).float(), weights_t)

            total_loss = total_loss + (-advantages[i] * pi_sum) + (kl_beta * kl_term)

            with torch.no_grad():
                total_kl += kl_term.item()

        loss = total_loss / K
        return loss, {"grpo_loss": loss.item(), "mean_kl": total_kl / K}

    return loss_fn
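
A minimal smoke test that fakes only the Datum surface the loss function actually touches (loss_fn_inputs["weights"].data); all values are hypothetical:
import torch
from types import SimpleNamespace

fake_datum = SimpleNamespace(
    loss_fn_inputs={"weights": SimpleNamespace(data=[0.0, 1.0, 1.0, 1.0])}
)
logprobs_list = [torch.randn(4, requires_grad=True) for _ in range(2)]
ref_logprobs = [[-1.0, -1.2, -0.9, -1.1], [-1.3, -0.8, -1.0, -1.2]]

loss_fn = make_grpo_loss_fn(rewards=[1.0, 0.0], ref_logprobs_list=ref_logprobs)
loss, metrics = loss_fn([fake_datum, fake_datum], logprobs_list)
loss.backward()  # gradients flow back into logprobs_list
print(metrics)   # {'grpo_loss': ..., 'mean_kl': ...}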

Step 4: Training loop (prompt-batched)

# Set up WeightSyncer for automatic delta-chain management
tracker = WeightSyncer(
    policy_client=policy.inner,   # WeightSyncer takes the raw SDK client
    deploy_mgr=deploy_mgr,
    deployment_id="grpo-serving",
    base_model="accounts/fireworks/models/qwen3-8b",
    hotload_timeout=600,
    first_checkpoint_type="base",
)

prompt_groups_per_step = 1  # recipe default
prompt_buffer = []
global_step = 0

for row in dataset:
    input_messages = [m for m in row["messages"] if m.get("role") != "assistant"]

    # 1) Sample with client-side tokenization
    completions = sampler.sample_with_tokens(messages=input_messages, n=4, max_tokens=512)
    rewards = [score(c) for c in completions]  # score() is a task-specific reward fn (sketch after this loop)
    if len(set(rewards)) == 1:
        continue  # uniform rewards (all correct or all wrong) carry no learning signal

    # 2) Build one prompt-group (datums + advantages + reference logprobs)
    datums = build_grpo_datums(completions)
    ref_fwd = reference.forward(datums, "cross_entropy")
    ref_logprobs = [list(x["logprobs"].data) for x in ref_fwd.loss_fn_outputs]
    prompt_buffer.append((datums, rewards, ref_logprobs))

    # 3) Train once buffer reaches prompt_groups_per_step
    if len(prompt_buffer) < prompt_groups_per_step:
        continue

    # Concatenate all buffered prompts into one training batch
    combined_datums, combined_rewards, combined_ref = [], [], []
    for d, r, ref in prompt_buffer:
        combined_datums.extend(d)
        combined_rewards.extend(r)
        combined_ref.extend(ref)

    loss_fn = make_grpo_loss_fn(combined_rewards, combined_ref, kl_beta=0.001)
    result = policy.forward_backward_custom(combined_datums, loss_fn)
    policy.optim_step(
        tinker.AdamParams(learning_rate=1e-5, beta1=0.9, beta2=0.999, eps=1e-8, weight_decay=0.01)
    )
    global_step += 1
    prompt_buffer = []

    # 4) Hotload updated weights to deployment
    tracker.save_and_hotload(f"step-{global_step:05d}")
    print(f"Step {global_step}: {result.metrics}")

Cookbook recipe entrypoint

The cookbook provides a streaming RL loop (rl_loop.py) that handles rollout scheduling, greedy batching, and metrics automatically. It supports multiple policy loss variants via a single policy_loss string:
from training.recipes.rl_loop import Config, main
from training.utils import DeployConfig, HotloadConfig, WandBConfig

cfg = Config(
    log_path="./grpo_logs",
    base_model="accounts/fireworks/models/qwen3-8b",
    dataset="/path/to/gsm8k.jsonl",
    max_rows=200,
    epochs=1,
    completions_per_prompt=4,
    max_completion_tokens=1024,
    temperature=1.0,
    max_seq_len=4096,
    policy_loss="grpo",  # or "dapo", "gspo", "cispo"
    deployment=DeployConfig(
        deployment_id="grpo-serving",
        tokenizer_model="Qwen/Qwen3-8B",
    ),
    hotload=HotloadConfig(hot_load_interval=1),
    wandb=WandBConfig(entity="my-team", project="grpo-experiment"),
)

main(cfg)

Config fields reference

Field                   Type        Default                               Description
log_path                str         — (required)                          Directory for checkpoints.jsonl and logs
base_model              str         "accounts/fireworks/models/qwen3-8b"  Base model name
dataset                 str         GSM8K sample URL                      Dataset path or URL (JSONL)
learning_rate           float       1e-5                                  Learning rate
kl_beta                 float       0.001                                 KL penalty coefficient (set to 0 to skip the reference model)
completions_per_prompt  int         4                                     Number of completions per prompt
max_completion_tokens   int         1024                                  Max tokens per sampled completion
temperature             float       1.0                                   Sampling temperature
epochs                  int         1                                     Training epochs over the dataset
max_rows                int         100                                   Number of dataset prompts to use
max_seq_len             int | None  None                                  Max sequence length (auto-populated from training shape if not set)
lora_rank               int         0                                     0 for full-parameter, >0 for LoRA
prompt_groups_per_step  int         1                                     Prompt groups per optimizer step
max_concurrent          int         32                                    Cap on concurrent in-flight sampling requests
policy_loss             str         "grpo"                                Policy loss variant: "grpo", "dapo", "gspo", or "cispo"
router_replay           bool        False                                 Enable Router Replay (R3) for MoE models
init_from_checkpoint    str | None  None                                  Load pretrained DCP weights (supports "job_id:checkpoint_name")

Sub-configs: infra (InfraConfig), deployment (DeployConfig), hotload (HotloadConfig), wandb (WandBConfig), is_correction (ISConfig), dapo (DAPOConfig), gspo (GSPOConfig), cispo (CISPOConfig).

Operational guidance

  • Service mode supports both full-parameter and LoRA tuning for both policy and reference trainers.
  • deployment.tokenizer_model is required for GRPO — the SDK raises ValueError if not set. Use the HuggingFace model name (e.g. "Qwen/Qwen3-8B").
  • Tokenization is client-side: GRPO uses AutoTokenizer + DeploymentSampler.sample_with_tokens(...) (token-in/token-out path).
  • Streaming rollout scheduling: Sampling runs concurrently (capped by max_concurrent). prompt_groups_per_step controls how many groups form one optimizer step.
  • Policy loss variants: "grpo" (REINFORCE + KL), "dapo" (asymmetric PPO clipping), "gspo" (sequence-level clipped PPO), "cispo" (clipped importance-sampling weights). Set via cfg.policy_loss.
  • Reference trainer uses --forward-only and --no-compile flags (never call optim_step on it).
  • Track reward distributions and KL every step to catch objective drift early (see the logging sketch after this list).
  • Skip prompts with uniform rewards (all correct or all wrong) — they provide no learning signal. The training loop above does this with its len(set(rewards)) == 1 check.
  • W&B logging: set cfg.wandb with WandBConfig(entity=..., project=..., run_name=...).

Common pitfalls

  • Reward normalization bugs can destabilize GRPO updates quickly — verify advantage computation.
  • Reference/policy tokenizer mismatch invalidates KL estimates — always use the same base_model.
  • Logprob alignment: The trainer returns N-1 logprobs for N tokens (shifted), while inference returns N logprobs where the first is None. Use inference[1:] to align with the trainer (see the sketch below).
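
A sketch of that alignment (variable names are illustrative):
def align_logprobs(inference_logprobs, trainer_logprobs):
    # Inference returns N entries with a leading None; the trainer returns
    # N-1 shifted entries. Dropping the first inference entry aligns them.
    aligned = inference_logprobs[1:]
    assert len(aligned) == len(trainer_logprobs)
    return aligned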