What this is

This runbook demonstrates an on-policy GRPO loop where sampling uses a deployment that is periodically hotloaded with the latest policy checkpoint.

Why this approach

  • On-policy sampling keeps trajectories close to the current policy weights, reducing the mismatch between what the deployment generates and what the trainer optimizes.
  • A KL penalty against a frozen reference model stabilizes optimization while preserving exploration.

How to use these APIs

  • Fireworks.reinforcement_fine_tuning_steps.create: Provision policy and reference trainer services.
  • TrainingClient.forward_backward_custom: Apply GRPO objective with reward and KL components.
  • TrainingClient.save_weights_for_sampler: Export checkpoints for deployment hotload.

Workflow

  1. Provision policy trainer (trainable) and reference trainer (frozen).
  2. Sample completions through deployment.
  3. Compute rewards and build token-weighted GRPO batches.
  4. Run custom loss update and optimizer step.
  5. Checkpoint and hotload deployment on cadence (the full loop is sketched below).
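
The five steps compose into a single loop. A minimal sketch, assuming the policy and reference trainer clients plus the helpers (sample_prompts, sample_with_deployment, score_completions, build_grpo_batch, make_grpo_loss_fn, hotload_deployment) are defined as in the end-to-end examples below; NUM_STEPS and HOTLOAD_EVERY are illustrative constants:

import tinker

NUM_STEPS = 200
HOTLOAD_EVERY = 10

for step in range(NUM_STEPS):
    # Step 2: sample on-policy completions through the serving deployment.
    prompts = sample_prompts(batch_size=8)
    completions = sample_with_deployment(prompts)

    # Step 3: score completions and build token-weighted GRPO batches.
    rewards = score_completions(completions)
    batch = build_grpo_batch(prompts, completions)

    # Step 4: custom GRPO loss against frozen reference logprobs, then optimizer step.
    ref_logprobs = reference.forward(batch).result().logprobs_list
    grpo_loss = make_grpo_loss_fn(rewards=rewards, ref_logprobs_list=ref_logprobs, kl_beta=0.01)
    policy.forward_backward_custom(batch, grpo_loss).result()
    policy.optim_step(
        tinker.AdamParams(learning_rate=1e-5, beta1=0.9, beta2=0.999, eps=1e-8, weight_decay=0.01)
    ).result()

    # Step 5: periodically export weights and hotload the serving deployment.
    if step % HOTLOAD_EVERY == 0:
        checkpoint = policy.save_weights_for_sampler(f"grpo-step-{step:05d}").result()
        hotload_deployment(checkpoint.path)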

End-to-end examples

Provision policy and reference trainers

# Policy trainer: trainable, full-parameter (lora_rank=0 is required in service mode).
policy_job = fw.reinforcement_fine_tuning_steps.create(
    training_config={
        "base_model": "accounts/fireworks/models/qwen3-8b",
        "lora_rank": 0,
        "max_context_length": 4096,
        "learning_rate": 1e-5,
        "gradient_accumulation_steps": 4,
    },
    extra_body={"serviceMode": True, "keepAlive": False},
)
# Reference trainer: identical config but frozen (learning_rate=0), used only for KL logprobs.
reference_job = fw.reinforcement_fine_tuning_steps.create(
    training_config={
        "base_model": "accounts/fireworks/models/qwen3-8b",
        "lora_rank": 0,
        "max_context_length": 4096,
        "learning_rate": 0,
        "gradient_accumulation_steps": 4,
    },
    extra_body={"serviceMode": True, "keepAlive": False},
)
# Wrap the provisioned service-mode jobs in TrainingClients for the update loop.
policy = make_training_client(policy_job)
reference = make_training_client(reference_job)

Single GRPO update iteration

import tinker

# Sample a fresh on-policy batch through the serving deployment, then score it.
prompts = sample_prompts(batch_size=8)
completions = sample_with_deployment(prompts)
rewards = score_completions(completions)
batch = build_grpo_batch(prompts, completions)

# Frozen reference logprobs feed the KL term of the GRPO loss.
ref_logprobs = reference.forward(batch).result().logprobs_list
grpo_loss = make_grpo_loss_fn(rewards=rewards, ref_logprobs_list=ref_logprobs, kl_beta=0.01)
policy.forward_backward_custom(batch, grpo_loss).result()
policy.optim_step(
    tinker.AdamParams(learning_rate=1e-5, beta1=0.9, beta2=0.999, eps=1e-8, weight_decay=0.01)
).result()
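
The exact callback signature expected by forward_backward_custom depends on the SDK version, so make_grpo_loss_fn is left abstract above. The math such a loss function needs to implement is shown below as a standalone, hypothetical PyTorch helper: group-relative advantage normalization plus a k3-style KL penalty against the frozen reference, averaged over tokens.

import torch

def grpo_token_loss(policy_logprobs, ref_logprobs, rewards, group_size, kl_beta=0.01):
    """Hypothetical sketch of the per-token GRPO objective.

    policy_logprobs, ref_logprobs: lists of 1-D tensors, one per sampled completion.
    rewards: 1-D tensor of scalar rewards, ordered so each consecutive block of
    `group_size` completions (group_size >= 2) shares the same prompt.
    """
    # Group-relative advantages: normalize rewards within each prompt group.
    grouped = rewards.view(-1, group_size)
    adv = (grouped - grouped.mean(dim=1, keepdim=True)) / (grouped.std(dim=1, keepdim=True) + 1e-6)
    adv = adv.view(-1)

    total, n_tokens = 0.0, 0
    for a, lp, ref_lp in zip(adv, policy_logprobs, ref_logprobs):
        # k3 estimator of KL(policy || reference), computed per token.
        log_ratio = ref_lp - lp
        kl = torch.exp(log_ratio) - log_ratio - 1.0
        # Negative of (advantage-weighted logprob minus KL penalty) is this sequence's loss contribution.
        total = total - (a * lp - kl_beta * kl).sum()
        n_tokens += lp.numel()
    return total / n_tokens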

Checkpoint and hotload serving

# Every 10 steps: export weights, hotload the serving deployment, and run a quick eval against it.
if step % 10 == 0:
    checkpoint = policy.save_weights_for_sampler(f"grpo-step-{step:05d}").result()
    hotload_deployment(checkpoint.path)
    eval_responses = sample_with_deployment(eval_prompts)
    eval_score = evaluate_responses(eval_responses)
    print({"step": step, "eval_score": eval_score})

Operational guidance

  • Service-mode trainer jobs currently support full-parameter tuning only. Keep lora_rank=0 for both policy and reference trainers.
  • Track reward distributions and KL terms every step to catch objective drift early; a logging sketch follows this list.
  • Align hotload interval with evaluation cadence to keep metrics meaningful.
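
A minimal logging sketch for the tracking bullet above, assuming rewards is the list of scalar rewards for the step and kl_values holds per-sequence KL estimates exported from the loss function (both names are illustrative):

import statistics

def log_step_metrics(step, rewards, kl_values):
    """Print simple per-step summaries; swap in your metrics backend as needed."""
    print({
        "step": step,
        "reward_mean": statistics.mean(rewards),
        "reward_std": statistics.pstdev(rewards),
        "reward_min": min(rewards),
        "reward_max": max(rewards),
        "kl_mean": statistics.mean(kl_values),
    })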

Common pitfalls

  • Reward normalization bugs can destabilize GRPO updates quickly (see the guard sketched after this list).
  • Reference and policy tokenizer mismatch invalidates KL estimates.
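
The most common normalization bug is a prompt group whose completions all receive the same reward: the within-group standard deviation collapses to zero and naive normalization produces infinities or NaNs. A hypothetical guard:

import math

def safe_group_advantages(rewards, group_size, eps=1e-6):
    """Group-relative advantages with a guard for zero-variance groups."""
    advantages = []
    for start in range(0, len(rewards), group_size):
        group = rewards[start:start + group_size]
        mean = sum(group) / len(group)
        std = math.sqrt(sum((r - mean) ** 2 for r in group) / len(group))
        if std < eps:
            # Identical rewards carry no learning signal; zero out the group instead of dividing by ~0.
            advantages.extend([0.0] * len(group))
        else:
            advantages.extend([(r - mean) / std for r in group])
    return advantages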