What this is

This guide walks through DPO (Direct Preference Optimization) training using the cookbook. DPO learns from preference pairs (chosen vs. rejected responses) without a separate reward model.

How DPO differs from GRPO

                     DPO                                   GRPO
Trainer jobs         2 (policy + frozen reference)         2 (policy + frozen reference)
Data                 Preference pairs (chosen/rejected)    Prompts + reward function
Reference logprobs   Cached once at initialization         Computed every step
Loss                 -log(sigmoid(beta * margin))          Advantage-weighted policy gradient + KL
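
For intuition, the per-pair loss is -log(sigmoid(beta * margin)), where the margin measures how much more the policy prefers the chosen response over the rejected one relative to the frozen reference. A quick numeric illustration (values are made up):

import math

beta, margin = 0.1, 2.0  # illustrative: a positive margin means the policy already favors the chosen response
loss = -math.log(1 / (1 + math.exp(-beta * margin)))  # -log(sigmoid(beta * margin))
print(round(loss, 3))  # ~0.598; the loss falls toward 0 as beta * margin grows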

Architecture

Using the recipe

from training.recipes.dpo_loop import Config, main
from training.utils import DeployConfig, InfraConfig, WandBConfig

cfg = Config(
    log_path="./dpo_logs",
    base_model="accounts/fireworks/models/qwen3-8b",
    dataset="/path/to/preference_data.jsonl",
    tokenizer_model="Qwen/Qwen3-8B",
    beta=0.1,         # scales the preference margin inside the DPO loss
    epochs=1,
    grad_accum=4,     # preference pairs accumulated per optimizer step
    max_seq_len=4096,
    infra=InfraConfig(
        training_shape_id="accounts/fireworks/trainingShapes/qwen3-8b-128k-h200",
        ref_training_shape_id="accounts/fireworks/trainingShapes/qwen3-8b-128k-h200-forward",  # forward-only reference trainer
    ),
    deployment=DeployConfig(
        deployment_id="dpo-serving",
    ),
    wandb=WandBConfig(entity="my-team", project="dpo-experiment"),
)

main(cfg)
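
The SDK-level walkthrough below reads Fireworks credentials from environment variables; if you run the recipe the same way, set them before calling main(cfg) (values here are placeholders):

import os

os.environ["FIREWORKS_API_KEY"] = "..."                         # required
os.environ["FIREWORKS_ACCOUNT_ID"] = "..."                      # optional; defaults to "" in the walkthrough below
os.environ["FIREWORKS_BASE_URL"] = "https://api.fireworks.ai"   # optional; this is the default below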

Dataset format

DPO expects preference pairs. Two formats are supported.

Format 1 — chosen/rejected messages:
{
  "chosen": {"messages": [{"role": "user", "content": "..."}, {"role": "assistant", "content": "good response"}]},
  "rejected": {"messages": [{"role": "user", "content": "..."}, {"role": "assistant", "content": "bad response"}]}
}
Format 2 — input/output split:
{
  "input": {"messages": [{"role": "user", "content": "..."}]},
  "preferred_output": [{"role": "assistant", "content": "good"}],
  "non_preferred_output": [{"role": "assistant", "content": "bad"}]
}
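
The SDK-level steps below iterate over (chosen_tokens, rejected_tokens, prompt_len) tuples. One way to produce them from Format 1 records with a Hugging Face tokenizer is sketched here — load_preference_dataset is an illustrative name, and exact prompt-length bookkeeping depends on your chat template:

import json
from transformers import AutoTokenizer

def load_preference_dataset(path, tokenizer_model="Qwen/Qwen3-8B"):
    """Illustrative loader: yield (chosen_tokens, rejected_tokens, prompt_len) from Format 1 JSONL."""
    tok = AutoTokenizer.from_pretrained(tokenizer_model)
    with open(path) as f:
        for line in f:
            rec = json.loads(line)
            chosen_msgs = rec["chosen"]["messages"]
            rejected_msgs = rec["rejected"]["messages"]
            prompt_msgs = chosen_msgs[:-1]  # everything before the final assistant turn; both sequences must share this prefix
            prompt_tokens = tok.apply_chat_template(prompt_msgs, add_generation_prompt=True)
            chosen_tokens = tok.apply_chat_template(chosen_msgs)
            rejected_tokens = tok.apply_chat_template(rejected_msgs)
            yield chosen_tokens, rejected_tokens, len(prompt_tokens)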

Step-by-step (SDK-level)

Provision trainers

import os
import tinker
from concurrent.futures import ThreadPoolExecutor
from fireworks.training.sdk import TrainerJobManager, DeploymentManager, WeightSyncer
from training.utils import InfraConfig, DeployConfig, ReconnectableClient, create_trainer_job, setup_deployment

api_key = os.environ["FIREWORKS_API_KEY"]
account_id = os.environ.get("FIREWORKS_ACCOUNT_ID", "")
base_url = os.environ.get("FIREWORKS_BASE_URL", "https://api.fireworks.ai")

rlor_mgr = TrainerJobManager(api_key=api_key, account_id=account_id, base_url=base_url)
deploy_mgr = DeploymentManager(api_key=api_key, account_id=account_id, base_url=base_url)

base_model = "accounts/fireworks/models/qwen3-8b"
infra = InfraConfig(
    training_shape_id="accounts/fireworks/trainingShapes/qwen3-8b-128k-h200",
    ref_training_shape_id="accounts/fireworks/trainingShapes/qwen3-8b-128k-h200-forward",
)

with ThreadPoolExecutor(max_workers=2) as pool:
    pol_fut = pool.submit(
        create_trainer_job, rlor_mgr,
        base_model=base_model, infra=infra, lora_rank=0,
        display_name="dpo-policy", hot_load_deployment_id="dpo-serving",
    )
    ref_fut = pool.submit(
        create_trainer_job, rlor_mgr,
        base_model=base_model, infra=infra, lora_rank=0,
        display_name="dpo-reference", forward_only=True,  # frozen reference: forward passes only, never updated
    )
    policy_ep = pol_fut.result()
    reference_ep = ref_fut.result()

policy_client = ReconnectableClient(rlor_mgr, policy_ep.job_id, base_model, lora_rank=0, fw_api_key=api_key)
reference_client = ReconnectableClient(rlor_mgr, reference_ep.job_id, base_model, lora_rank=0, fw_api_key=api_key)

Cache reference logprobs

Reference logprobs are computed once at initialization and reused throughout training:
ref_cache = {}
for i, (chosen_tokens, rejected_tokens, prompt_len) in enumerate(dataset):
    chosen_datum, rejected_datum = build_dpo_datums(
        chosen_tokens, rejected_tokens, prompt_len, max_seq_len=4096,
    )
    # one forward pass per pair on the frozen reference; "cross_entropy" returns per-token logprobs
    fwd = reference_client.forward([chosen_datum, rejected_datum], "cross_entropy")
    ref_cache[i] = {
        "ref_chosen": fwd.loss_fn_outputs[0]["logprobs"].data,
        "ref_rejected": fwd.loss_fn_outputs[1]["logprobs"].data,
        "chosen_tokens": chosen_tokens,
        "rejected_tokens": rejected_tokens,
        "prompt_len": prompt_len,
    }
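
build_dpo_datums is assumed to come from the cookbook's DPO utilities. The convention it encodes is a per-token weight mask that zeroes out prompt tokens, so only response tokens contribute to the logprob sums in the loss below. A rough sketch of that masking (not the actual helper):

def sketch_response_weights(tokens, prompt_len, max_seq_len=4096):
    """Illustrative only: truncate to max_seq_len and weight response tokens 1.0, prompt tokens 0.0."""
    tokens = tokens[:max_seq_len]
    n_prompt = min(prompt_len, len(tokens))
    weights = [0.0] * n_prompt + [1.0] * (len(tokens) - n_prompt)
    return tokens, weights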

DPO loss function

import torch
import torch.nn.functional as F

def make_dpo_loss_fn(ref_chosen_logprobs, ref_rejected_logprobs, beta=0.1):
    """Per-pair DPO loss: -logsigmoid(beta * margin), where the margin compares policy-vs-reference
    summed logprobs on the chosen response against the rejected one."""
    ref_chosen_t = torch.tensor(ref_chosen_logprobs, dtype=torch.float32)
    ref_rejected_t = torch.tensor(ref_rejected_logprobs, dtype=torch.float32)

    def loss_fn(data, logprobs_list):
        pi_chosen, pi_rejected = logprobs_list[0], logprobs_list[1]
        chosen_weights = torch.tensor(data[0].loss_fn_inputs["weights"].data, dtype=torch.float32)
        rejected_weights = torch.tensor(data[1].loss_fn_inputs["weights"].data, dtype=torch.float32)

        pi_chosen_sum = torch.dot(pi_chosen.float(), chosen_weights)
        pi_rejected_sum = torch.dot(pi_rejected.float(), rejected_weights)
        ref_chosen_sum = torch.dot(ref_chosen_t.float(), chosen_weights)
        ref_rejected_sum = torch.dot(ref_rejected_t.float(), rejected_weights)

        margin = (pi_chosen_sum - ref_chosen_sum) - (pi_rejected_sum - ref_rejected_sum)
        dpo_loss = -F.logsigmoid(beta * margin)

        with torch.no_grad():
            accuracy = 1.0 if margin.item() > 0 else 0.0

        return dpo_loss, {"dpo_loss": dpo_loss.item(), "margin": margin.item(), "accuracy": accuracy}

    return loss_fn
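
A quick offline sanity check of the loss function with dummy values. The SimpleNamespace wrappers only mimic the .loss_fn_inputs["weights"].data access pattern used above; they are not part of the SDK:

from types import SimpleNamespace

def _fake_datum(weights):
    return SimpleNamespace(loss_fn_inputs={"weights": SimpleNamespace(data=weights)})

loss_fn = make_dpo_loss_fn(ref_chosen_logprobs=[-1.0, -1.2], ref_rejected_logprobs=[-1.1, -1.3], beta=0.1)
policy_logprobs = [torch.tensor([-0.8, -0.9]), torch.tensor([-1.5, -1.6])]  # chosen, rejected
loss, metrics = loss_fn([_fake_datum([1.0, 1.0]), _fake_datum([1.0, 1.0])], policy_logprobs)
print(metrics)  # margin = 1.2 here, so accuracy == 1.0 and the loss is below log(2)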

Training loop

step = 0
accum_count = 0
grad_accum = 4

for idx in ref_cache:
    cached = ref_cache[idx]
    chosen_datum, rejected_datum = build_dpo_datums(
        cached["chosen_tokens"], cached["rejected_tokens"],
        cached["prompt_len"], max_seq_len=4096,
    )
    loss_fn = make_dpo_loss_fn(
        ref_chosen_logprobs=cached["ref_chosen"],
        ref_rejected_logprobs=cached["ref_rejected"],
        beta=0.1,
    )
    result = policy_client.forward_backward_custom([chosen_datum, rejected_datum], loss_fn)
    accum_count += 1

    if accum_count >= grad_accum:
        policy_client.optim_step(
            tinker.AdamParams(learning_rate=1e-5, beta1=0.9, beta2=0.999, eps=1e-8, weight_decay=0.01)
        )
        step += 1
        accum_count = 0
        print(f"Step {step}: {result.metrics}")

Operational guidance

  • Set infra.training_shape_id and infra.ref_training_shape_id — DPO launches 2 RLOR jobs: a policy trainer and a frozen reference trainer.
  • DPO defaults weight_sync_interval=0 (no weight sync by default), unlike GRPO.
  • Keep a versioned reference cache tied to tokenizer + base model revision. If the base model changes, recompute reference logprobs.
  • Monitor margin statistics: increasing margins indicate the policy is learning preferences.

Common pitfalls

  • Mismatched formatting between chosen/rejected sequences corrupts preference signals — ensure identical prompt prefixes (a minimal check is sketched below).
  • Stale reference cache: if you warm-start from a different model, the cached reference logprobs are invalid — recompute them.
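
A minimal prefix check for a tokenized pair, using the (chosen_tokens, rejected_tokens, prompt_len) tuples from the walkthrough above:

def assert_shared_prompt(chosen_tokens, rejected_tokens, prompt_len):
    """Raise if the chosen and rejected sequences do not start with the same prompt tokens."""
    if chosen_tokens[:prompt_len] != rejected_tokens[:prompt_len]:
        raise ValueError("chosen/rejected pair does not share an identical prompt prefix")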