What this is

This is the default lifecycle for research loops: bootstrap a trainer and deployment, run iterative updates, export checkpoints, sync weights to the deployment, then sample through it for realistic evaluation.

Workflow

  1. Create resources: a deployment (DeploymentManager) and a service-mode trainer (TrainerJobManager).
  2. Connect a training client from your Python loop.
  3. Run train steps: forward_backward_custom + optim_step in a loop.
  4. Save checkpoints at regular intervals using the base/delta pattern.
  5. Weight-sync the checkpoint to your serving deployment.
  6. Sample and evaluate through the deployment endpoint.
  7. Record metrics and decide whether to continue or branch experiments.

End-to-end example

The only training-shape input you choose below is the shape ID. The SDK resolves the versioned reference for you before launch.

1. Bootstrap

import os
import tinker
from fireworks.training.sdk import (
    FiretitanServiceClient,
    TrainerJobManager,
    TrainerJobConfig,
    DeploymentManager,
    DeploymentConfig,
    WeightSyncer,
)

api_key = os.environ["FIREWORKS_API_KEY"]
account_id = os.environ.get("FIREWORKS_ACCOUNT_ID", "")
base_url = os.environ.get("FIREWORKS_BASE_URL", "https://api.fireworks.ai")
shape_id = "accounts/fireworks/trainingShapes/qwen3-8b-128k-h200"

rlor_mgr = TrainerJobManager(api_key=api_key, account_id=account_id, base_url=base_url)
deploy_mgr = DeploymentManager(api_key=api_key, account_id=account_id, base_url=base_url)

# Create deployment for sampling/hotload
deploy_mgr.create_or_get(DeploymentConfig(
    deployment_id="research-serving",
    base_model="accounts/fireworks/models/qwen3-8b",
    min_replica_count=0,
    max_replica_count=1,
))
deploy_mgr.wait_for_ready("research-serving")

# This is the only shape-specific value you choose
profile = rlor_mgr.resolve_training_profile(shape_id)

# Create trainer (polls until healthy)
endpoint = rlor_mgr.create_and_wait(TrainerJobConfig(
    base_model="accounts/fireworks/models/qwen3-8b",
    training_shape_ref=profile.training_shape_version,
    lora_rank=0,
    learning_rate=1e-5,
    gradient_accumulation_steps=4,
    hot_load_deployment_id="research-serving",
))

# Connect client (FiretitanServiceClient provides checkpoint_type + session ID)
service = FiretitanServiceClient(base_url=endpoint.base_url, api_key=api_key)
training_client = service.create_training_client(
    base_model="accounts/fireworks/models/qwen3-8b", lora_rank=0,
)

2. Train step with custom objective

def objective(data, logprobs_list):
    loss = compute_objective(data=data, logprobs_list=logprobs_list)
    return loss, {"loss": float(loss.item())}

for step in range(total_steps):
    batch = build_batch(step)  # your data pipeline
    # Queue the forward/backward pass with the custom objective, then block on it
    training_client.forward_backward_custom(batch, objective).result()
    # Apply the optimizer update with explicit Adam hyperparameters
    training_client.optim_step(
        tinker.AdamParams(learning_rate=1e-5, beta1=0.9, beta2=0.999, eps=1e-8, weight_decay=0.01)
    ).result()
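The compute_objective helper referenced above is user-supplied and not part of the SDK. As a minimal, hypothetical sketch of the math it might implement — an advantage-weighted, mask-filtered token objective, written here over plain Python lists rather than the tensors a real backend would use — the field names ("advantages", "mask") are assumptions:

```python
def compute_objective(data, logprobs_list):
    """Hypothetical token-level policy-gradient objective.

    data: list of examples, each a dict with per-token "advantages"
          and a 0/1 "mask" marking which tokens are scored.
    logprobs_list: matching per-token log-probabilities from the model.
    Returns the mean of -advantage * logprob over masked-in tokens.
    """
    total, count = 0.0, 0
    for example, logprobs in zip(data, logprobs_list):
        for adv, m, lp in zip(example["advantages"], example["mask"], logprobs):
            if m:  # only score tokens the mask covers
                total += -adv * lp
                count += 1
    return total / max(count, 1)
```

In the real loop, the equivalent computation would be vectorized over tensors so the returned loss supports backpropagation and `.item()`.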

3. Checkpoint, weight sync, and evaluate

import asyncio

from transformers import AutoTokenizer
from fireworks.training.sdk import DeploymentSampler

# Set up WeightSyncer for automatic delta-chain management
tracker = WeightSyncer(
    policy_client=training_client,
    deploy_mgr=deploy_mgr,
    deployment_id="research-serving",
    base_model="accounts/fireworks/models/qwen3-8b",
    hotload_timeout=600,
    first_checkpoint_type="base",
)

# Create the tokenizer and sampler once, outside the training loop
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B", trust_remote_code=True)
sampler = DeploymentSampler(
    inference_url=deploy_mgr.inference_url,
    model=f"accounts/{account_id}/deployments/research-serving",
    api_key=api_key,
    tokenizer=tokenizer,
)

# Inside the training loop from step 2:
if step % eval_interval == 0:
    # WeightSyncer auto-selects base (first) or delta (subsequent)
    tracker.save_and_hotload(f"step_{step:05d}")

    # Sample via the deployment endpoint for evaluation
    completions = asyncio.run(
        sampler.sample_with_tokens(messages=eval_prompts, n=1)
    )
    score = evaluate_responses(completions)
    print({"step": step, "eval_score": score})

Operational guidance

  • Service mode supports both full-parameter and LoRA tuning. Set lora_rank=0 for full-parameter or a positive integer (e.g. 16, 64) for LoRA, and match create_training_client(lora_rank=...) accordingly.
  • Use checkpoint_type="base" for the first checkpoint, then "delta" for subsequent ones to reduce save/transfer time.
  • DeploymentSampler.sample_with_tokens() is async — use await in async code or asyncio.run(...) from synchronous scripts.
  • Keep checkpoint intervals predictable so evaluation comparisons are stable.
  • Store the exact prompt set used for each evaluation sweep for reproducibility.
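The base-then-delta rule from the first bullet is handled automatically by WeightSyncer, but if you manage checkpoints by hand the selection logic reduces to a one-shot flag. A minimal sketch (this helper class is illustrative, not part of the SDK):

```python
class CheckpointTypeTracker:
    """Tracks whether the next save should be a full base or a delta.

    Mirrors the base-then-delta convention above: the first checkpoint
    is a full "base"; every later one is a smaller "delta" against it.
    """

    def __init__(self):
        self._saved_base = False

    def next_type(self) -> str:
        if not self._saved_base:
            self._saved_base = True
            return "base"
        return "delta"
```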

Common pitfalls

  • Sampling from trainer internals instead of deployment endpoints can skew results — always evaluate through the serving path.
  • Missing checkpoint-to-deployment traceability makes rollback risky — log checkpoint names alongside metrics.
  • Stale deployments: always verify that the weight-synced checkpoint identity matches what you expect before sampling.
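One way to get the checkpoint-to-metrics traceability the second pitfall asks for is to append one JSON record per evaluation, keyed by the checkpoint name you passed to save_and_hotload. A minimal sketch (the file path and field names are assumptions, not an SDK convention):

```python
import json

def log_eval(path, step, checkpoint_name, eval_score):
    """Append one JSON line tying an eval score to the exact checkpoint."""
    record = {"step": step, "checkpoint": checkpoint_name, "eval_score": eval_score}
    with open(path, "a") as f:
        f.write(json.dumps(record) + "\n")
    return record
```

Reading the file back gives an ordered audit trail, so rolling back means hot-loading the named checkpoint next to the best recorded score.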