Skip to main content
This guide walks through the simplest possible custom training loop: supervised fine-tuning (SFT) with a cross-entropy loss. By the end, you’ll understand the full workflow and be ready for more advanced objectives like GRPO or DPO.
New to the training SDK? Read Core Concepts first for background on how the architecture works.

Prerequisites

pip install fireworks-ai tinker-sdk tinker-cookbook transformers torch
Set your API key:
export FIREWORKS_API_KEY="your-api-key"

Step 1: Provision a trainer

Create a service-mode RLOR trainer job. This allocates GPUs and loads the base model.
# --- Step 1: provision a remote trainer job --------------------------------
import os
from fireworks.training.sdk import (
    FiretitanServiceClient,
    TrainerJobManager,
    TrainerJobConfig,
)

# Required: raises KeyError immediately if the API key is missing.
api_key = os.environ["FIREWORKS_API_KEY"]
# Optional overrides; empty account id and the public endpoint are defaults.
account_id = os.environ.get("FIREWORKS_ACCOUNT_ID", "")
base_url = os.environ.get("FIREWORKS_BASE_URL", "https://api.fireworks.ai")

base_model = "accounts/fireworks/models/qwen3-8b"

# TrainerJobManager handles job creation and polling
rlor_mgr = TrainerJobManager(api_key=api_key, account_id=account_id, base_url=base_url)

# create_and_wait blocks until the trainer is healthy
# (GPUs allocated and the base model loaded).
endpoint = rlor_mgr.create_and_wait(TrainerJobConfig(
    base_model=base_model,
    lora_rank=0,             # 0 = full-parameter tuning
    max_context_length=4096,
    learning_rate=1e-5,
    gradient_accumulation_steps=1,
    display_name="sft-quickstart",
))
print(f"Trainer ready at: {endpoint.base_url}")

Step 2: Connect the training client

FiretitanServiceClient connects your local Python process to the remote trainer.
# Connect the local process to the trainer endpoint provisioned in Step 1.
service = FiretitanServiceClient(
    base_url=endpoint.base_url,
    api_key=api_key,
)
training_client = service.create_training_client(
    base_model=base_model,
    lora_rank=0,  # Must match the trainer job's lora_rank
)

Step 3: Build training data

Each training example is wrapped in a Datum — a tokenized sequence with per-token weights that tell the loss function which tokens to train on.
import torch
import transformers
from tinker_cookbook.supervised.common import datum_from_tokens_weights

# Tokenizer must correspond to the base model loaded on the trainer.
tokenizer = transformers.AutoTokenizer.from_pretrained(
    "Qwen/Qwen3-8B", trust_remote_code=True,
)

# Example: a simple prompt → response pair
conversation = [
    {"role": "user", "content": "What is the capital of France?"},
    {"role": "assistant", "content": "The capital of France is Paris."},
]

# Tokenize the full conversation
full_text = tokenizer.apply_chat_template(conversation, tokenize=False)
full_tokens = tokenizer.encode(full_text)

# Tokenize just the prompt to find where the response starts
# NOTE(review): this assumes the templated prompt is a token-level prefix of
# the templated full conversation; some chat templates insert extra markers
# (e.g. a generation prompt) that break this — confirm for your template.
prompt_only = tokenizer.apply_chat_template(conversation[:1], tokenize=False)
prompt_tokens = tokenizer.encode(prompt_only)
prompt_len = len(prompt_tokens)

# Build token weights: 0 for prompt, 1 for response
weights = torch.zeros(len(full_tokens), dtype=torch.float32)
weights[prompt_len:] = 1.0

# Datum bundles the token ids with per-token loss weights.
# NOTE(review): max_length presumably truncates/validates the sequence —
# confirm against datum_from_tokens_weights in the cookbook.
datum = datum_from_tokens_weights(
    torch.tensor(full_tokens, dtype=torch.long),
    weights,
    max_length=4096,
)

Step 4: Write a loss function

The loss function receives the datums and per-token logprobs (as autograd tensors from the GPU). For SFT, we compute negative log-likelihood over response tokens.
def sft_loss(data, logprobs_list):
    """Supervised fine-tuning loss: mean negative log-likelihood over response tokens.

    Args:
        data: sequence of Datum objects; each carries per-token weights in
            ``loss_fn_inputs["weights"]`` (0 = prompt token, 1 = response token).
        logprobs_list: per-datum 1-D tensors of token log-probabilities.
            These require grad, so everything below must stay differentiable
            through them.

    Returns:
        (loss, metrics): a scalar tensor and a plain dict of floats.
    """
    total_loss = torch.tensor(0.0)
    n_tokens = 0

    for datum, logprobs in zip(data, logprobs_list):
        # Token weights: 0 = prompt (ignore), 1 = response (train).
        # Build them on the same device as the logprobs so torch.dot does not
        # fail with a device mismatch when the trainer hands back GPU tensors.
        weights = torch.tensor(
            datum.loss_fn_inputs["weights"].data,
            dtype=torch.float32,
            device=logprobs.device,
        )
        # Guard against small length mismatches between logprobs and weights.
        min_len = min(len(logprobs), len(weights))

        # Weighted sum of log-probs over response tokens.
        # Negative because we maximize log-probability (minimize NLL).
        total_loss = total_loss - torch.dot(
            logprobs[:min_len].float(), weights[:min_len],
        )
        n_tokens += weights[:min_len].sum().item()

    # Normalize by trained-token count; max(..., 1) avoids dividing by zero
    # on an all-prompt batch.
    loss = total_loss / max(n_tokens, 1)
    return loss, {"sft_loss": loss.item(), "n_tokens": n_tokens}
Key points:
  • logprobs_list[i] has requires_grad=True — your loss must be differentiable through it
  • Use torch.dot() for weighted sums — it correctly propagates gradients
  • Return (scalar_loss, metrics_dict)

Step 5: Train

import tinker

batch = [datum]  # In practice, use multiple datums per batch

# Forward + backward: sends data to GPU, gets logprobs, runs your loss, sends gradients back
result = training_client.forward_backward_custom(batch, sft_loss).result()
print(f"Loss: {result.metrics}")

# Apply the weight update
# optim_step performs one Adam update with the hyperparameters below;
# .result() blocks until the remote update completes.
training_client.optim_step(
    tinker.AdamParams(
        learning_rate=1e-5,
        beta1=0.9,
        beta2=0.999,
        eps=1e-8,
        weight_decay=0.01,
    )
).result()
That’s one training step. A real loop iterates over your dataset:
# Full loop: one forward/backward + one optimizer step per batch.
for step, batch in enumerate(data_loader):
    result = training_client.forward_backward_custom(batch, sft_loss).result()
    training_client.optim_step(
        tinker.AdamParams(learning_rate=1e-5, beta1=0.9, beta2=0.999, eps=1e-8, weight_decay=0.01)
    ).result()
    print(f"Step {step}: {result.metrics}")

Step 6: Save a checkpoint

Export the trained weights for serving:
# Export trained weights under the name "sft-final" for serving.
# NOTE(review): checkpoint_type="base" appears to select a full (non-delta)
# checkpoint — confirm against the Checkpointing and Hotload docs.
result = training_client.save_weights_for_sampler_ext(
    "sft-final",
    checkpoint_type="base",
)
print(f"Checkpoint saved: {result.snapshot_name}")

Step 7: Clean up

# Delete the trainer job to release its GPUs (billing stops here).
rlor_mgr.delete(endpoint.job_id)

Full script

# --- Full script: setup (provision trainer, connect client, load tokenizer) --
import os
import torch
import tinker
import transformers
from tinker_cookbook.supervised.common import datum_from_tokens_weights
from fireworks.training.sdk import (
    FiretitanServiceClient,
    TrainerJobManager,
    TrainerJobConfig,
)

# Credentials and endpoint; FIREWORKS_API_KEY is required.
api_key = os.environ["FIREWORKS_API_KEY"]
account_id = os.environ.get("FIREWORKS_ACCOUNT_ID", "")
base_url = os.environ.get("FIREWORKS_BASE_URL", "https://api.fireworks.ai")
base_model = "accounts/fireworks/models/qwen3-8b"

# Provision the trainer (blocks until healthy), then connect a training client.
rlor_mgr = TrainerJobManager(api_key=api_key, account_id=account_id, base_url=base_url)
endpoint = rlor_mgr.create_and_wait(TrainerJobConfig(
    base_model=base_model,
    lora_rank=0,
    max_context_length=4096,
    learning_rate=1e-5,
    gradient_accumulation_steps=1,
    display_name="sft-quickstart",
))

service = FiretitanServiceClient(base_url=endpoint.base_url, api_key=api_key)
training_client = service.create_training_client(base_model=base_model, lora_rank=0)

# Tokenizer must correspond to the base model loaded on the trainer.
tokenizer = transformers.AutoTokenizer.from_pretrained("Qwen/Qwen3-8B", trust_remote_code=True)

# Toy dataset: two single-turn conversations.
training_data = [
    [
        {"role": "user", "content": "What is the capital of France?"},
        {"role": "assistant", "content": "The capital of France is Paris."},
    ],
    [
        {"role": "user", "content": "What is 2 + 2?"},
        {"role": "assistant", "content": "2 + 2 equals 4."},
    ],
]

def tokenize_conversation(conversation):
    """Turn one chat conversation into a Datum with response-only loss weights.

    Prompt tokens get weight 0 (excluded from the loss); every token after the
    prompt gets weight 1. Assumes the templated prompt is a token-level prefix
    of the templated full conversation — verify for your chat template.
    """
    rendered = tokenizer.apply_chat_template(conversation, tokenize=False)
    token_ids = tokenizer.encode(rendered)

    # Re-render just the first (user) turn to locate where the response begins.
    rendered_prompt = tokenizer.apply_chat_template(conversation[:1], tokenize=False)
    n_prompt = len(tokenizer.encode(rendered_prompt))

    # 0.0 over the prompt span, 1.0 over everything after it.
    token_weights = torch.zeros(len(token_ids), dtype=torch.float32)
    token_weights[n_prompt:] = 1.0

    return datum_from_tokens_weights(
        torch.tensor(token_ids, dtype=torch.long), token_weights, max_length=4096,
    )

def sft_loss(data, logprobs_list):
    """Mean negative log-likelihood over response tokens (weight == 1).

    ``data[i].loss_fn_inputs["weights"]`` marks prompt tokens with 0 and
    response tokens with 1; ``logprobs_list[i]`` is a differentiable 1-D
    tensor of per-token log-probabilities. Returns (scalar loss, metrics).
    """
    total_loss = torch.tensor(0.0)
    n_tokens = 0
    for datum, logprobs in zip(data, logprobs_list):
        # Build weights on the logprobs' device so torch.dot does not fail
        # with a device mismatch when the trainer hands back GPU tensors.
        weights = torch.tensor(
            datum.loss_fn_inputs["weights"].data,
            dtype=torch.float32,
            device=logprobs.device,
        )
        min_len = min(len(logprobs), len(weights))  # guard length mismatches
        # Negative weighted sum of log-probs == NLL over response tokens.
        total_loss = total_loss - torch.dot(logprobs[:min_len].float(), weights[:min_len])
        n_tokens += weights[:min_len].sum().item()
    loss = total_loss / max(n_tokens, 1)  # max(..., 1): no divide-by-zero
    # Report n_tokens too, matching the Step 4 version of this loss function.
    return loss, {"sft_loss": loss.item(), "n_tokens": n_tokens}

# Tokenize once up front; the same (tiny) batch is reused every step.
datums = [tokenize_conversation(conv) for conv in training_data]
for step in range(10):
    result = training_client.forward_backward_custom(datums, sft_loss).result()
    training_client.optim_step(
        tinker.AdamParams(learning_rate=1e-5, beta1=0.9, beta2=0.999, eps=1e-8, weight_decay=0.01)
    ).result()
    print(f"Step {step}: {result.metrics}")

# Export the final weights, then free the trainer's GPUs.
result = training_client.save_weights_for_sampler_ext("sft-final", checkpoint_type="base")
print(f"Checkpoint: {result.snapshot_name}")

rlor_mgr.delete(endpoint.job_id)

Next steps

  • Core Concepts — deeper explanation of the architecture and abstractions
  • Custom Train Step — detailed API for forward_backward_custom, datum construction, and gradient accumulation
  • GRPO Example — on-policy and off-policy reinforcement learning
  • DPO Example — preference optimization with pairwise data
  • Checkpointing and Hotload — base/delta checkpoints and live deployment updates