What this is

This guide walks through DPO (Direct Preference Optimization) training using the Fireworks Training SDK. DPO learns from preference pairs (chosen vs. rejected responses) without a separate reward model.

How DPO differs from GRPO

| | DPO | GRPO |
| --- | --- | --- |
| Trainer jobs | 2 (policy + frozen reference) | 2 (policy + frozen reference) |
| Data | Preference pairs (chosen/rejected) | Prompts + reward function |
| Reference logprobs | Cached once at initialization from frozen reference | Computed every step via frozen reference trainer |
| Loss | -log(sigmoid(β × margin)) | Advantage-weighted policy gradient + KL |
DPO uses two RLOR trainer jobs — a policy trainer and a frozen reference trainer (with --forward-only). Reference logprobs are computed once at initialization and cached for the entire training run.

Architecture

Step 1: Provision trainers

import os
import tinker
from concurrent.futures import ThreadPoolExecutor
from fireworks.training.sdk import TrainerJobManager, DeploymentManager, WeightSyncer
from training.utils import (
    InfraConfig, DeployConfig, ReconnectableClient,
    create_trainer_job, setup_deployment,
)

api_key = os.environ["FIREWORKS_API_KEY"]
account_id = os.environ.get("FIREWORKS_ACCOUNT_ID", "")
base_url = os.environ.get("FIREWORKS_BASE_URL", "https://api.fireworks.ai")

rlor_mgr = TrainerJobManager(api_key=api_key, account_id=account_id, base_url=base_url)
deploy_mgr = DeploymentManager(api_key=api_key, account_id=account_id, base_url=base_url)

base_model = "accounts/fireworks/models/qwen3-8b"
infra = InfraConfig()

# Create both trainers in parallel
with ThreadPoolExecutor(max_workers=2) as pool:
    pol_fut = pool.submit(
        create_trainer_job, rlor_mgr,
        base_model=base_model, infra=infra, lora_rank=0,
        display_name="dpo-policy",
        hot_load_deployment_id="dpo-serving",
    )
    ref_fut = pool.submit(
        create_trainer_job, rlor_mgr,
        base_model=base_model, infra=infra, lora_rank=0,
        display_name="dpo-reference", forward_only=True,
    )
    policy_ep = pol_fut.result()
    reference_ep = ref_fut.result()

# ReconnectableClient wraps the SDK training client with auto-reconnect
policy_client = ReconnectableClient(rlor_mgr, policy_ep.job_id, base_model, lora_rank=0, fw_api_key=api_key)
reference_client = ReconnectableClient(rlor_mgr, reference_ep.job_id, base_model, lora_rank=0, fw_api_key=api_key)

# Set up WeightSyncer for checkpoint + hotload
weight_syncer = WeightSyncer(
    policy_client=policy_client.inner,
    deploy_mgr=deploy_mgr,
    deployment_id="dpo-serving",
    base_model=base_model,
    hotload_timeout=600,
    first_checkpoint_type="base",
)

Step 2: Prepare preference data

DPO expects pairs of (chosen, rejected) responses to the same prompt. The starter script supports multiple dataset formats, including these two JSONL record shapes:

{"chosen": {"messages": [{"role": "user", "content": "..."}, {"role": "assistant", "content": "good response"}]},
 "rejected": {"messages": [{"role": "user", "content": "..."}, {"role": "assistant", "content": "bad response"}]}}

{"input": {"messages": [{"role": "user", "content": "..."}]},
 "preferred_output": [{"role": "assistant", "content": "good"}],
 "non_preferred_output": [{"role": "assistant", "content": "bad"}]}
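
The starter script's parser isn't reproduced here, but a minimal normalizer for the two record shapes above might look like this (the helper name `normalize_pair` is ours, not the SDK's):

```python
def normalize_pair(record):
    """Normalize either supported JSONL record shape into
    (chosen_messages, rejected_messages) chat-message lists."""
    if "chosen" in record and "rejected" in record:
        # Shape 1: full chats under "chosen"/"rejected"
        return record["chosen"]["messages"], record["rejected"]["messages"]
    # Shape 2: shared "input" messages plus preferred/non-preferred outputs
    prompt = record["input"]["messages"]
    return (prompt + record["preferred_output"],
            prompt + record["non_preferred_output"])
```

Both shapes end up as two complete chats that share the same prompt prefix, which is what the tokenization step below relies on.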

Tokenize and build datums

The shared prompt is detected by finding the longest common token prefix between the chosen and rejected sequences; that prefix length (prompt_len) is then used to zero out prompt tokens in the loss weights when building datums:
import torch
from tinker_cookbook.supervised.common import datum_from_tokens_weights

def build_dpo_datums(chosen_tokens, rejected_tokens, prompt_len, max_seq_len):
    chosen_weights = torch.zeros(len(chosen_tokens), dtype=torch.float32)
    chosen_weights[prompt_len:] = 1.0  # Response tokens only
    rejected_weights = torch.zeros(len(rejected_tokens), dtype=torch.float32)
    rejected_weights[prompt_len:] = 1.0

    chosen_datum = datum_from_tokens_weights(
        torch.tensor(chosen_tokens, dtype=torch.long), chosen_weights, max_length=max_seq_len,
    )
    rejected_datum = datum_from_tokens_weights(
        torch.tensor(rejected_tokens, dtype=torch.long), rejected_weights, max_length=max_seq_len,
    )
    return chosen_datum, rejected_datum
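
The prefix detection itself can be sketched as a simple scan over the two token sequences (the helper name `shared_prompt_len` is illustrative):

```python
def shared_prompt_len(chosen_tokens, rejected_tokens):
    """Length of the longest common token prefix, used as prompt_len.

    Both sequences were tokenized from chats sharing the same prompt,
    so the first point of divergence marks where the responses begin.
    """
    n = 0
    for a, b in zip(chosen_tokens, rejected_tokens):
        if a != b:
            break
        n += 1
    return n
```

Note this can slightly overcount if the two responses happen to start with the same tokens; in practice that only shifts a few shared response tokens out of the loss.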

Step 3: Cache reference logprobs

Before training starts, compute reference logprobs from the frozen reference trainer:
ref_cache = {}

for i, (chosen_tokens, rejected_tokens, prompt_len) in enumerate(dataset):
    chosen_datum, rejected_datum = build_dpo_datums(
        chosen_tokens, rejected_tokens, prompt_len, max_seq_len=4096,
    )

    # Forward pass on the frozen reference trainer (no backward, no training)
    fwd = reference_client.forward([chosen_datum, rejected_datum], "cross_entropy")

    ref_cache[i] = {
        "ref_chosen": fwd.loss_fn_outputs[0]["logprobs"].data,
        "ref_rejected": fwd.loss_fn_outputs[1]["logprobs"].data,
        "chosen_tokens": chosen_tokens,
        "rejected_tokens": rejected_tokens,
        "prompt_len": prompt_len,
    }
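
The sequential loop above is easy to follow but slow for large datasets; the recipe's ref_cache_concurrency setting suggests warm-up is parallelized. A generic sketch of bounded-concurrency warm-up, where `warm_cache_concurrently` is our illustrative helper and `forward_fn` stands in for the reference forward pass:

```python
from concurrent.futures import ThreadPoolExecutor

def warm_cache_concurrently(items, forward_fn, max_workers=16):
    """Map forward_fn over indexed items with bounded concurrency.

    items:      iterable of per-example inputs
    forward_fn: callable(item) -> cache entry (e.g. a reference forward
                pass returning logprobs)
    Returns {index: cache_entry}, matching the ref_cache layout above.
    """
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        futures = {i: pool.submit(forward_fn, item)
                   for i, item in enumerate(items)}
        # Collect in index order; .result() re-raises any worker exception
        return {i: fut.result() for i, fut in futures.items()}
```

Because forward passes are network-bound RPCs, threads (rather than processes) are enough to overlap them.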

Step 4: DPO loss function

The loss function computes response-only log-probability sums using the weights from each datum, then applies the DPO margin:
import torch.nn.functional as F

def make_dpo_loss_fn(ref_chosen_logprobs, ref_rejected_logprobs, beta=0.1):
    ref_chosen_t = torch.tensor(ref_chosen_logprobs, dtype=torch.float32)
    ref_rejected_t = torch.tensor(ref_rejected_logprobs, dtype=torch.float32)

    def loss_fn(data, logprobs_list):
        pi_chosen, pi_rejected = logprobs_list[0], logprobs_list[1]
        chosen_weights = torch.tensor(data[0].loss_fn_inputs["weights"].data, dtype=torch.float32)
        rejected_weights = torch.tensor(data[1].loss_fn_inputs["weights"].data, dtype=torch.float32)

        # Weighted log-probability sums (response tokens only)
        pi_chosen_sum = torch.dot(pi_chosen.float(), chosen_weights)
        pi_rejected_sum = torch.dot(pi_rejected.float(), rejected_weights)
        ref_chosen_sum = torch.dot(ref_chosen_t.float(), chosen_weights)
        ref_rejected_sum = torch.dot(ref_rejected_t.float(), rejected_weights)

        # DPO margin and loss
        margin = (pi_chosen_sum - ref_chosen_sum) - (pi_rejected_sum - ref_rejected_sum)
        dpo_loss = -F.logsigmoid(beta * margin)

        with torch.no_grad():
            accuracy = 1.0 if margin.item() > 0 else 0.0

        return dpo_loss, {
            "dpo_loss": dpo_loss.item(),
            "margin": margin.item(),
            "accuracy": accuracy,
        }

    return loss_fn
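
As a sanity check on the arithmetic (pure Python, no SDK or torch required): with β = 0.1 and a margin of 0 the loss is log 2 ≈ 0.693, and it shrinks as the policy widens the chosen-vs-rejected gap relative to the reference. The scalar helper below is ours, using summed response log-probabilities in place of the weighted dot products above:

```python
import math

def dpo_loss_scalar(pi_chosen, pi_rejected, ref_chosen, ref_rejected, beta=0.1):
    """Scalar DPO loss from summed response log-probabilities."""
    margin = (pi_chosen - ref_chosen) - (pi_rejected - ref_rejected)
    # -log(sigmoid(x)) == log(1 + exp(-x))
    return math.log1p(math.exp(-beta * margin))

# Policy identical to reference -> zero margin -> loss = log(2)
print(dpo_loss_scalar(-10.0, -12.0, -10.0, -12.0))  # ≈ 0.693
```

Watching this quantity fall per pair is exactly what the "margin" and "dpo_loss" metrics report during training.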

Step 5: Training loop

step = 0
accum_count = 0
grad_accum = 4

for idx in ref_cache:
    cached = ref_cache[idx]
    chosen_datum, rejected_datum = build_dpo_datums(
        cached["chosen_tokens"], cached["rejected_tokens"],
        cached["prompt_len"], max_seq_len=4096,
    )

    loss_fn = make_dpo_loss_fn(
        ref_chosen_logprobs=cached["ref_chosen"],
        ref_rejected_logprobs=cached["ref_rejected"],
        beta=0.1,
    )

    result = policy_client.forward_backward_custom(
        [chosen_datum, rejected_datum], loss_fn,
    )
    accum_count += 1

    if accum_count >= grad_accum:
        policy_client.optim_step(
            tinker.AdamParams(learning_rate=1e-5, beta1=0.9, beta2=0.999, eps=1e-8, weight_decay=0.01)
        )
        step += 1
        accum_count = 0
        print(f"Step {step}: {result.metrics}")

# Flush a final partial accumulation window so trailing gradients are not dropped
if accum_count > 0:
    policy_client.optim_step(
        tinker.AdamParams(learning_rate=1e-5, beta1=0.9, beta2=0.999, eps=1e-8, weight_decay=0.01)
    )
    step += 1

Step 6: Save final checkpoint

weight_syncer.save_and_hotload(f"dpo-final-step-{step}")

Cookbook recipe entrypoint

Use the current cookbook recipe API (Config + main) for runnable DPO loops:
from training.recipes.dpo_loop import Config, main
from training.utils import DeployConfig, WandBConfig

cfg = Config(
    log_path="./dpo_logs",
    base_model="accounts/fireworks/models/qwen3-8b",
    dataset="/path/to/preference_data.jsonl",
    tokenizer_model="Qwen/Qwen3-8B",
    beta=0.1,
    epochs=1,
    grad_accum=4,
    max_seq_len=4096,
    deployment=DeployConfig(
        deployment_id="dpo-serving",
    ),
    wandb=WandBConfig(entity="my-team", project="dpo-experiment"),
)

main(cfg)

Config fields reference

| Field | Type | Default | Description |
| --- | --- | --- | --- |
| log_path | str | — (required) | Directory for checkpoints.jsonl and logs |
| base_model | str | "accounts/fireworks/models/qwen3-8b" | Base model name |
| dataset | str | "" | Path or URL to preference JSONL |
| tokenizer_model | str | "" | HuggingFace model name for client-side tokenization |
| beta | float | 0.1 | DPO beta (controls preference sharpness) |
| learning_rate | float | 1e-5 | Learning rate |
| epochs | int | 1 | Training epochs |
| grad_accum | int | 4 | Gradient accumulation steps |
| max_seq_len | int or None | None | Max sequence length (auto-populated from training shape if not set) |
| max_pairs | int or None | None | Max preference pairs to use from dataset |
| lora_rank | int | 0 | 0 for full-parameter, >0 for LoRA |
| ref_cache_concurrency | int | 16 | Max concurrent reference forward passes during cache warm-up |
| init_from_checkpoint | str or None | None | Load pretrained DCP weights (supports "job_id:checkpoint_name") |
Sub-configs: infra (InfraConfig), deployment (DeployConfig), hotload (HotloadConfig, defaults to hot_load_interval=0), wandb (WandBConfig).

  • DPO defaults hotload.hot_load_interval=0 (no hotloading by default), unlike GRPO, which enables it by default.
  • If deployment_id is set, the existing deployment is used; otherwise setup_deployment auto-creates a new one.
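
For example, to turn periodic hotloading back on, the hotload sub-config can be overridden; this assumes HotloadConfig is importable from training.utils like the other sub-configs, which you should verify against your SDK version:

```python
from training.recipes.dpo_loop import Config
from training.utils import HotloadConfig

cfg = Config(
    ...,  # other fields as in the example above
    hotload=HotloadConfig(hot_load_interval=10),  # hotload every 10 steps
)
```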

Operational guidance

  • DPO uses 2 RLOR jobs — a policy trainer and a frozen reference trainer (with --forward-only). Reference logprobs are cached at init.
  • Service mode supports both full-parameter and LoRA tuning.
  • Keep a versioned reference cache tied to tokenizer + base model revision. If the base model changes, recompute reference logprobs.
  • Monitor margin statistics: increasing margins indicate the policy is learning the preference signal. Flat or decreasing margins suggest issues.
  • Resume is handled by checkpoint_utils.resolve_resume() — it reads checkpoints.jsonl from log_path and restores the last saved state automatically on startup.
  • The recipe's ReconnectableClient wrapper tolerates transient trainer preemption by reconnecting automatically; keep it in place rather than calling the raw training client.

Common pitfalls

  • Mismatched formatting between chosen/rejected sequences corrupts preference signals — ensure identical prompt prefixes.
  • Stale reference cache: If you warm-start from a different model, the cached reference logprobs are invalid.
  • Stale evaluation prompts: if the evaluation set is never refreshed, the policy can overfit to checks it has already saturated, hiding regressions.
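
The first pitfall can be caught cheaply before caching reference logprobs; a small validation sketch (helper name is ours):

```python
def check_shared_prefix(chosen_tokens, rejected_tokens, prompt_len):
    """Raise if the two sequences diverge inside the supposed prompt span."""
    if chosen_tokens[:prompt_len] != rejected_tokens[:prompt_len]:
        raise ValueError(
            "chosen/rejected prompts diverge before prompt_len; "
            "check chat-template formatting and tokenization"
        )
```

Running this on every pair at load time is far cheaper than diagnosing a corrupted preference signal after training.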