Installation

Install the Fireworks Python package with training extensions:
pip install --pre "fireworks-ai[training]"
Set your credentials:
export FIREWORKS_API_KEY="your-api-key"
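Step 1 reads the key with os.environ["FIREWORKS_API_KEY"], which fails with a bare KeyError if the variable is missing. If you want a clearer message, a fail-fast check like the one below can run before anything else. `require_env` is a hypothetical helper for illustration, not part of the Fireworks SDK.

```python
import os


def require_env(name: str) -> str:
    """Return an environment variable's value, or fail with a readable error.

    Hypothetical helper; not part of the Fireworks SDK. Whitespace-only
    values are treated as unset.
    """
    value = os.environ.get(name, "").strip()
    if not value:
        raise RuntimeError(
            f"{name} is not set; run `export {name}=...` in the shell "
            "that launches this script."
        )
    return value
```

In a script, `api_key = require_env("FIREWORKS_API_KEY")` can replace the direct os.environ lookup.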
If you want ready-to-run recipes instead of writing a loop from scratch, see The Cookbook for config-driven GRPO, DPO, and SFT training.

Your first training loop

This quickstart walks through a minimal supervised fine-tuning (SFT) loop built from scratch with only the API. To launch a trainer, the only shape-specific input you provide is the training shape ID. In most cases, use the full shared path accounts/fireworks/trainingShapes/<shape>; the fireworks account hosts the public shared catalog of training shapes. The SDK-managed service client then resolves the pinned version, creates the trainer or reattaches to an existing one, and returns a Tinker-compatible training client.
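If your own tooling accepts either a bare shape name or a full path, a small normalizer can keep training_shape_id in the full form. This is a sketch: `to_shape_path` is a hypothetical helper (not part of the SDK), and it assumes any value that already starts with accounts/ is a complete resource path.

```python
SHARED_SHAPE_PREFIX = "accounts/fireworks/trainingShapes/"


def to_shape_path(shape: str) -> str:
    """Expand a bare shape name into the full shared-catalog path.

    Hypothetical helper: values starting with "accounts/" are assumed to be
    full resource paths already and are returned unchanged.
    """
    shape = shape.strip().strip("/")
    if not shape:
        raise ValueError("shape must be non-empty")
    if shape.startswith("accounts/"):
        return shape
    if "/" in shape:
        # A partial path is ambiguous; refuse to guess which account it belongs to.
        raise ValueError(f"unexpected shape format: {shape!r}")
    return SHARED_SHAPE_PREFIX + shape
```

For example, `to_shape_path("qwen3-8b-128k-h200")` yields the same string as shape_id in Step 1.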

Step 1: Create the managed service

import os
from fireworks.training.sdk import FiretitanServiceClient

api_key = os.environ["FIREWORKS_API_KEY"]
base_url = os.environ.get("FIREWORKS_BASE_URL", "https://api.fireworks.ai")

base_model = "accounts/fireworks/models/qwen3-8b"
shape_id = "accounts/fireworks/trainingShapes/qwen3-8b-128k-h200"

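service = FiretitanServiceClient.from_firetitan_config(
    api_key=api_key,
    base_url=base_url,
    base_model=base_model,
    tokenizer_model="Qwen/Qwen3-8B",  # Hugging Face tokenizer matching base_model
    lora_rank=0,
    training_shape_id=shape_id,       # full shared shape path
    learning_rate=1e-5,
    create_deployment=False,          # no serving deployment for this quickstart
    cleanup_trainer_on_close=True,    # tear down the trainer on service.close()
)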

Step 2: Create the training client

training_client = service.create_training_client(
    base_model=base_model,
    lora_rank=0,
)
print(f"Trainer job: {service.trainer_job_id}")

Step 3: Build training data

Each training example is a Datum: a tokenized sequence with per-token weights that mark which tokens contribute to the loss.
import tinker
import torch
import transformers
from tinker_cookbook.supervised.common import datum_from_model_input_weights

tokenizer = transformers.AutoTokenizer.from_pretrained(
    "Qwen/Qwen3-8B", trust_remote_code=True,
)

conversation = [
    {"role": "user", "content": "What is the capital of France?"},
    {"role": "assistant", "content": "The capital of France is Paris."},
]

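# Full conversation (prompt + response) as one token sequence.
full_text = tokenizer.apply_chat_template(conversation, tokenize=False)
full_tokens = tokenizer.encode(full_text, add_special_tokens=False)

# Prompt only, ending with the assistant header so the header tokens get weight 0.
prompt_only = tokenizer.apply_chat_template(
    conversation[:1], tokenize=False, add_generation_prompt=True,
)
prompt_len = len(tokenizer.encode(prompt_only, add_special_tokens=False))

# Train only on the assistant response: 0 for prompt tokens, 1 afterwards.
weights = torch.zeros(len(full_tokens), dtype=torch.float32)
weights[prompt_len:] = 1.0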

datum = datum_from_model_input_weights(
    tinker.ModelInput.from_ints(full_tokens),
    weights,
    max_length=4096,
)
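The mask above is only correct when the prompt-only tokenization is an exact prefix of the full conversation; otherwise prompt_len points at the wrong token. The sketch below makes that assumption explicit, using plain Python lists instead of torch tensors. `build_sft_weights` and the token IDs are hypothetical, for illustration only.

```python
def build_sft_weights(prompt_tokens: list[int], full_tokens: list[int]) -> list[float]:
    """Per-token SFT weights: 0.0 on the prompt, 1.0 on the completion.

    Raises if the prompt is not a prefix of the full sequence (the mask
    would be misaligned) or if there is nothing left to train on.
    """
    n = len(prompt_tokens)
    if full_tokens[:n] != prompt_tokens:
        raise ValueError("prompt tokens are not a prefix of the full sequence")
    if n >= len(full_tokens):
        raise ValueError("no completion tokens to train on")
    return [0.0] * n + [1.0] * (len(full_tokens) - n)


# Made-up token IDs: 3 prompt tokens followed by 2 completion tokens.
toy_prompt = [101, 7, 42]
toy_full = [101, 7, 42, 9, 102]
toy_weights = build_sft_weights(toy_prompt, toy_full)
print(toy_weights)  # [0.0, 0.0, 0.0, 1.0, 1.0]
```

With real data, pass the encoded prompt_only and full_tokens lists, then convert the result with torch.tensor(toy_weights, dtype=torch.float32) before building the Datum.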

Step 4: Write a loss function and train

import tinker
import torch

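def sft_loss(data, logprobs_list):
    """Weighted mean negative log-likelihood over all weighted tokens in the batch."""
    total_loss = torch.tensor(0.0)
    n_tokens = 0
    for i, logprobs in enumerate(logprobs_list):
        # Per-token weights stored on the Datum (0 = ignore, 1 = train).
        weights = torch.tensor(
            data[i].loss_fn_inputs["weights"].data, dtype=torch.float32,
        )
        # Guard against length differences between logprobs and weights.
        min_len = min(len(logprobs), len(weights))
        total_loss = total_loss - torch.dot(
            logprobs[:min_len].float(), weights[:min_len],
        )
        n_tokens += weights[:min_len].sum().item()
    # Normalize by the number of weighted tokens (at least 1 to avoid dividing by zero).
    loss = total_loss / max(n_tokens, 1)
    return loss, {"sft_loss": loss.item(), "n_tokens": n_tokens}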

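batch = [datum]
for step in range(10):
    # Forward and backward pass with the custom loss; .result() waits for completion.
    result = training_client.forward_backward_custom(batch, sft_loss).result()
    # Apply one Adam update using the gradients from the pass above.
    training_client.optim_step(
        tinker.AdamParams(learning_rate=1e-5, beta1=0.9, beta2=0.999, eps=1e-8, weight_decay=0.01)
    ).result()
    print(f"Step {step}: {result.metrics}")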
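Because sft_loss only needs per-token log-probabilities and weights, its arithmetic can be checked offline without a trainer. The torch-free mirror below computes loss = -sum(w_t * logp_t) / max(sum(w_t), 1) over every example in the batch, truncating each example to the shorter of its two sequences just as sft_loss does. `sft_loss_reference` is a hypothetical helper; it is not differentiable and is only meant for sanity-checking values.

```python
def sft_loss_reference(
    logprobs_list: list[list[float]], weights_list: list[list[float]]
) -> tuple[float, float]:
    """Plain-Python mirror of sft_loss: weighted mean negative log-likelihood.

    Returns (loss, n_tokens), where n_tokens is the total weight.
    """
    total = 0.0
    n_tokens = 0.0
    for logprobs, weights in zip(logprobs_list, weights_list):
        m = min(len(logprobs), len(weights))
        # Subtract the weighted log-probabilities (they are <= 0, so total grows).
        total -= sum(lp * w for lp, w in zip(logprobs[:m], weights[:m]))
        n_tokens += sum(weights[:m])
    return total / max(n_tokens, 1), n_tokens


# One example: the first (prompt) token is masked out, two tokens are trained on.
loss, n_tokens = sft_loss_reference([[-0.5, -1.0, -2.0]], [[0.0, 1.0, 1.0]])
print(loss, n_tokens)  # 1.5 2.0
```

The masked token's log-probability of -0.5 has no effect; only -1.0 and -2.0 count, averaged over two weighted tokens.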

Step 5: Save and promote

saved = training_client.save_weights_for_sampler(
    "sft-final",
    checkpoint_type="base",
).result()
print(f"Checkpoint saved: {saved.path}")

# Promote the checkpoint to a deployable Fireworks model. `list_checkpoints`
# returns the full 4-segment checkpoint resource name that promotion expects.
entry = next(
    row for row in service.list_checkpoints(service.trainer_job_id)
    if row["name"].endswith(f"/checkpoints/{saved.path}")
)
model = service.promote_checkpoint(
    name=entry["name"],
    output_model_id="my-sft-model",
    base_model=base_model,
)

service.close()
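The next(...) lookup in Step 5 raises a bare StopIteration if no checkpoint matches. A sketch that fails with a readable error, and also catches the unlikely case of several matches, is shown below. `find_checkpoint_name` is hypothetical, and the rows are placeholders: real names come from service.list_checkpoints(...), and their exact format is whatever that call returns.

```python
def find_checkpoint_name(rows: list[dict], saved_path: str) -> str:
    """Return the single checkpoint resource name ending in /checkpoints/<saved_path>.

    Same suffix match as Step 5, with explicit errors instead of StopIteration.
    """
    suffix = f"/checkpoints/{saved_path}"
    matches = [row["name"] for row in rows if row["name"].endswith(suffix)]
    if not matches:
        raise LookupError(f"no checkpoint ending in {suffix!r}")
    if len(matches) > 1:
        raise LookupError(f"ambiguous checkpoint suffix {suffix!r}: {matches}")
    return matches[0]


# Placeholder rows shaped like the dicts Step 5 reads (only "name" is used).
example_rows = [
    {"name": "placeholder/job-123/checkpoints/step-5"},
    {"name": "placeholder/job-123/checkpoints/sft-final"},
]
print(find_checkpoint_name(example_rows, "sft-final"))
```

Including the leading slash in the suffix keeps a name like my-sft-final from matching sft-final.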
For production scripts, put the training code in a try/finally block and call service.close() in the finally clause, so SDK-managed trainers are cleaned up on exit, including when an exception is raised. See Cleanup and Teardown.
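The pattern looks like this. So the sketch runs offline, `StubService` is a hypothetical stand-in for the object returned by FiretitanServiceClient.from_firetitan_config(...); in a real script you would pass that service instead, and train_fn would hold Steps 2 through 5.

```python
class StubService:
    """Offline stand-in for the managed service client (hypothetical)."""

    def __init__(self) -> None:
        self.closed = False

    def close(self) -> None:
        self.closed = True


def run_training(service, train_fn) -> None:
    """Run train_fn(service), then close the service even if train_fn raises."""
    try:
        train_fn(service)
    finally:
        service.close()


def failing_train(service) -> None:
    raise RuntimeError("simulated failure mid-training")


svc = StubService()
try:
    run_training(svc, failing_train)
except RuntimeError:
    pass  # the error still propagates, but only after close() has run
print(svc.closed)  # True
```

Since the service exposes close(), contextlib.closing(service) gives the same guarantee inside a with statement.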

Next steps