What this is

This is the default lifecycle for research loops: bootstrap a trainer, run iterative updates, export checkpoints, then sample through deployment endpoints for realistic evaluation.

Key APIs

  • FiretitanServiceClient: Connect your local loop to the trainer service
  • TrainingClient.forward_backward_custom: Compute gradients with your custom objective
  • TrainingClient.forward: Forward-only pass (e.g. for reference logprobs)
  • TrainingClient.optim_step: Apply optimization update
  • FiretitanTrainingClient.save_weights_for_sampler_ext: Export a serving-compatible checkpoint (returns SaveSamplerResult)
  • DeploymentManager + WeightSyncer: Hotload checkpoints and track base/delta state
  • DeploymentSampler: Client-side tokenized sampling from deployment endpoints

Workflow

  1. Create resources: a deployment (DeploymentManager) and a service-mode trainer (TrainerJobManager).
  2. Connect a Tinker training client from your Python loop.
  3. Run train steps: forward_backward_custom + optim_step in a loop.
  4. Save checkpoints at regular intervals using the base/delta pattern.
  5. Hotload the checkpoint onto your serving deployment.
  6. Sample and evaluate through the deployment endpoint (typically via DeploymentSampler).
  7. Record metrics and decide whether to continue or branch experiments.

End-to-end example

1. Bootstrap

import os
import tinker
from fireworks.training.sdk import (
    FiretitanServiceClient,
    TrainerJobManager,
    TrainerJobConfig,
    DeploymentManager,
    DeploymentConfig,
    WeightSyncer,
)

api_key = os.environ["FIREWORKS_API_KEY"]
account_id = os.environ.get("FIREWORKS_ACCOUNT_ID", "")
base_url = os.environ.get("FIREWORKS_BASE_URL", "https://api.fireworks.ai")

trainer_mgr = TrainerJobManager(api_key=api_key, account_id=account_id, base_url=base_url)
deploy_mgr = DeploymentManager(api_key=api_key, account_id=account_id, base_url=base_url)

# Create deployment for sampling/hotload
deploy_mgr.create_or_get(DeploymentConfig(
    deployment_id="research-serving",
    base_model="accounts/fireworks/models/qwen3-8b",
    min_replica_count=0,
    max_replica_count=1,
))
deploy_mgr.wait_for_ready("research-serving")

# Create trainer in service mode (polls until healthy)
endpoint = trainer_mgr.create_and_wait(TrainerJobConfig(
    base_model="accounts/fireworks/models/qwen3-8b",
    lora_rank=0,  # 0 = full-parameter tuning; a positive rank enables LoRA
    max_context_length=4096,
    learning_rate=1e-5,
    gradient_accumulation_steps=4,
    hot_load_deployment_id="research-serving",
))

# Connect client (FiretitanServiceClient provides checkpoint_type + session ID)
service = FiretitanServiceClient(base_url=endpoint.base_url, api_key=api_key)
training_client = service.create_training_client(
    base_model="accounts/fireworks/models/qwen3-8b", lora_rank=0,
)

2. Train step with custom objective

def objective(data, logprobs_list):
    # compute_objective is your own loss function (see the sketch below);
    # return the loss tensor plus a dict of metrics to log.
    loss = compute_objective(data=data, logprobs_list=logprobs_list)
    return loss, {"loss": float(loss.item())}

for step in range(total_steps):
    batch = build_batch(step)  # user-defined batch construction
    training_client.forward_backward_custom(batch, objective).result()
    training_client.optim_step(
        tinker.AdamParams(learning_rate=1e-5, beta1=0.9, beta2=0.999, eps=1e-8, weight_decay=0.01)
    ).result()
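
What compute_objective looks like depends on your algorithm. As a minimal sketch only: assuming each entry of logprobs_list is a 1-D tensor of per-token logprobs for one sequence, and that data carries one scalar advantage per sequence (both field layouts are assumptions, not SDK contracts), a REINFORCE-style objective could be:

import torch

def compute_objective(data, logprobs_list):
    # Hypothetical layout: data["advantages"][i] is a scalar advantage for
    # sequence i; logprobs_list[i] holds that sequence's per-token logprobs.
    losses = [-(lp.sum() * adv) for lp, adv in zip(logprobs_list, data["advantages"])]
    return torch.stack(losses).mean()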

3. Checkpoint, hotload, and evaluate

from transformers import AutoTokenizer
from fireworks.training.sdk import DeploymentSampler

# Set up WeightSyncer for automatic delta-chain management
tracker = WeightSyncer(
    policy_client=training_client,
    deploy_mgr=deploy_mgr,       # DeploymentManager instance
    deployment_id="research-serving",
    base_model="accounts/fireworks/models/qwen3-8b",
    hotload_timeout=600,
    first_checkpoint_type="base",
)

# Build the evaluation sampler once, outside the training loop
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B", trust_remote_code=True)
sampler = DeploymentSampler(
    inference_url=deploy_mgr.inference_url,
    model=f"accounts/{account_id}/deployments/research-serving",
    api_key=api_key,
    tokenizer=tokenizer,  # HuggingFace AutoTokenizer
)

# Inside the training loop: checkpoint and evaluate every eval_interval steps
if step % eval_interval == 0:
    # WeightSyncer auto-selects base (first checkpoint) or delta (subsequent)
    tracker.save_and_hotload(f"step_{step:05d}")

    # Sample via the deployment endpoint for evaluation
    completions = sampler.sample_with_tokens(messages=eval_prompts, n=1)
    score = evaluate_responses(completions)  # user-defined scoring
    print({"step": step, "eval_score": score})

Sampling with token IDs (for training)

For training scripts that need token IDs and logprobs (e.g. GRPO, DPO), use DeploymentSampler, which handles client-side tokenization via a HuggingFace tokenizer and returns structured SampledCompletion objects:
from transformers import AutoTokenizer
from fireworks.training.sdk import DeploymentSampler

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B", trust_remote_code=True)
sampler = DeploymentSampler(
    inference_url="https://api.fireworks.ai",
    model=f"accounts/{account_id}/deployments/{deployment_id}",
    api_key=api_key,
    tokenizer=tokenizer,
)

completions = sampler.sample_with_tokens(
    messages=[{"role": "user", "content": "Solve: 2+2="}],
    n=4,
    max_tokens=1024,
    temperature=0.7,
)
for c in completions:
    print(c.full_tokens)       # prompt + completion token IDs
    print(c.prompt_len)        # number of prompt tokens
    print(c.completion_len)    # number of completion tokens
    print(c.text)              # decoded completion text
    print(c.finish_reason)     # "stop", "length", etc.

SampledCompletion fields

  • text (str): Decoded completion text
  • full_tokens (List[int]): Prompt + completion token IDs
  • prompt_len (int): Number of prompt tokens
  • completion_len (int): Number of completion tokens
  • finish_reason (str): "stop", "length", etc.
  • inference_logprobs (List[float] | None): Per-token logprobs (when logprobs=True is passed)
  • logprobs_echoed (bool): True when echo=True was used; logprobs are training-aligned (P+C-1 entries)
  • routing_matrices (List[str] | None): Base64-encoded per-token routing matrices for MoE Router Replay (R3)

To retrieve inference logprobs (needed for GRPO importance sampling), pass logprobs=True:
completions = sampler.sample_with_tokens(
    messages=[{"role": "user", "content": "Solve: 2+2="}],
    n=4,
    logprobs=True,
    top_logprobs=1,
)
for c in completions:
    print(c.inference_logprobs)  # List[float] or None
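
When logprobs_echoed is True, the logprob list covers the whole echoed sequence: with P prompt tokens and C completion tokens there are P+C-1 entries, one per next-token prediction. A sketch for slicing out the completion-aligned entries under that assumption:

def completion_logprobs(c):
    # With echo, entry i scores token i+1, so the list has
    # prompt_len + completion_len - 1 entries, and the last
    # completion_len of them score the completion tokens.
    assert c.logprobs_echoed
    assert len(c.inference_logprobs) == c.prompt_len + c.completion_len - 1
    return c.inference_logprobs[-c.completion_len:]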

Sequence length filtering

sample_with_tokens supports max_seq_len for automatic filtering:
completions = sampler.sample_with_tokens(
    messages=input_messages,
    n=4,
    max_tokens=1024,
    max_seq_len=8192,  # filter out sequences exceeding this length
)
Two levels of filtering are applied:
  1. Prompt pre-filter: If the tokenized prompt already meets or exceeds max_seq_len, the method returns an empty list immediately — no inference call is made.
  2. Completion post-filter: After sampling, any completion whose full token sequence (prompt + completion) exceeds max_seq_len is silently dropped.
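
In other words, the behavior is roughly the following (an illustrative sketch, not the SDK's actual implementation; apply_chat_template is the standard HuggingFace tokenizer call):

def sample_filtered(sampler, tokenizer, messages, max_seq_len, **kwargs):
    # Prompt pre-filter: skip the inference call entirely if the
    # tokenized prompt alone meets or exceeds max_seq_len.
    prompt_ids = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
    if len(prompt_ids) >= max_seq_len:
        return []
    completions = sampler.sample_with_tokens(messages=messages, **kwargs)
    # Completion post-filter: drop any full sequence that is too long.
    return [c for c in completions
            if c.prompt_len + c.completion_len <= max_seq_len]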

Operational guidance

  • Service mode supports both full-parameter and LoRA tuning. Set lora_rank=0 for full-parameter tuning or a positive integer (e.g. 16, 64) for LoRA, and match create_training_client(lora_rank=...) accordingly (see the sketch after this list).
  • Use checkpoint_type="base" for the first checkpoint, then "delta" for subsequent ones to reduce save/transfer time.
  • Keep checkpoint intervals predictable so evaluation comparisons are stable.
  • Store the exact prompt set used for each evaluation sweep for reproducibility.
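
For example, a LoRA run keeps the two lora_rank values in sync (rank 16 and the learning rate here are illustrative):

endpoint = trainer_mgr.create_and_wait(TrainerJobConfig(
    base_model="accounts/fireworks/models/qwen3-8b",
    lora_rank=16,  # positive rank enables LoRA
    max_context_length=4096,
    learning_rate=1e-4,
    hot_load_deployment_id="research-serving",
))
training_client = service.create_training_client(
    base_model="accounts/fireworks/models/qwen3-8b",
    lora_rank=16,  # must match the trainer config
)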

Common pitfalls

  • Sampling from trainer internals instead of deployment endpoints can skew results — always evaluate through the serving path.
  • Missing checkpoint-to-deployment traceability makes rollback risky — log checkpoint names alongside metrics.
  • Stale deployments: Always verify the hotloaded checkpoint identity matches what you expect before sampling.
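
A lightweight way to get that checkpoint-to-metric traceability is to write the checkpoint name next to every metric record; a minimal sketch in plain Python:

import json
import time

def log_eval(step, checkpoint_name, score, path="eval_log.jsonl"):
    # One JSON line per evaluation, so every score maps back to a checkpoint.
    record = {"step": step, "checkpoint": checkpoint_name,
              "eval_score": score, "ts": time.time()}
    with open(path, "a") as f:
        f.write(json.dumps(record) + "\n")

log_eval(step, f"step_{step:05d}", score)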