Overview
FiretitanServiceClient connects to a trainer endpoint and creates a FiretitanTrainingClient for training operations. It extends Tinker’s ServiceClient with Fireworks-specific features like checkpoint_type support and session-scoped snapshot naming.
```python
from fireworks.training.sdk import FiretitanServiceClient
```
FiretitanServiceClient
Constructor
```python
service = FiretitanServiceClient(
    base_url=endpoint.base_url,  # From TrainerJobManager.create_and_wait(...)
    api_key="<FIREWORKS_API_KEY>",
)
```
base_url is the trainer endpoint URL from TrainerServiceEndpoint.base_url.
create_training_client(base_model, lora_rank=0, user_metadata=None)
Creates a FiretitanTrainingClient for training operations:
```python
training_client = service.create_training_client(
    base_model="accounts/fireworks/models/qwen3-8b",
    lora_rank=0,  # Must match lora_rank from job creation
)
```
| Parameter | Type | Default | Description |
|---|---|---|---|
| base_model | str | — | Must match the trainer job’s base_model |
| lora_rank | int | 0 | Must match trainer creation config (0 for full-parameter) |
| user_metadata | dict[str, str] \| None | None | Optional run metadata |
A ValueError is raised if you attempt to create a second training client with the same (base_model, lora_rank) on the same FiretitanServiceClient instance. Create a new FiretitanServiceClient for a separate trainer.
Connecting to an existing trainer
If you already have a running trainer (e.g. from a previous session), connect directly by URL:
```python
service = FiretitanServiceClient(
    base_url="https://<existing-trainer-url>",
    api_key="<FIREWORKS_API_KEY>",
)
training_client = service.create_training_client(
    base_model="accounts/fireworks/models/qwen3-8b",
    lora_rank=0,
)
```
FiretitanTrainingClient
The training client returned by create_training_client(). Core training RPCs like forward(...), forward_backward_custom(...), optim_step(...), save_state(...), and load_state_with_optimizer(...) return futures. Fireworks convenience helpers like save_weights_for_sampler_ext(...), list_checkpoints(), and resolve_checkpoint_path(...) return concrete values directly.
forward(datums, loss_type)
Forward-only pass (no gradient computation). Useful for computing reference logprobs in GRPO/DPO:
```python
result = training_client.forward(datums, "cross_entropy").result()
logprobs = result.loss_fn_outputs[0]["logprobs"].data
```
Built-in loss types like "cross_entropy" require datums with target_tokens in loss_fn_inputs. Datums built with datum_from_tokens_weights will fail. See Loss Functions — Building datums for the correct format, or use forward_backward_custom with weight-based datums instead.
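Because the failure mode is a server-side error after a full RPC round-trip, it can be worth validating datums locally first. The sketch below is illustrative only and operates on plain dicts; the actual Datum type and its loss_fn_inputs layout are described in Loss Functions — Building datums.

```python
def check_datums_for_builtin_loss(datums):
    """Raise early if any datum lacks the target_tokens field that
    built-in loss types such as "cross_entropy" require."""
    for i, datum in enumerate(datums):
        inputs = datum.get("loss_fn_inputs", {})
        if "target_tokens" not in inputs:
            raise ValueError(
                f'datum {i} has loss_fn_inputs keys {sorted(inputs)}; '
                'built-in losses need "target_tokens" -- use '
                "forward_backward_custom for weight-based datums"
            )

# A target-token datum passes; a weight-based one would raise.
check_datums_for_builtin_loss([{"loss_fn_inputs": {"target_tokens": [1, 2, 3]}}])
```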
forward_backward_custom(datums, loss_fn)
Forward + backward with your custom loss function. See Loss Functions for details:
```python
def my_loss(data, logprobs_list):
    loss = compute_loss(data, logprobs_list)
    return loss, {"loss": float(loss.item())}

result = training_client.forward_backward_custom(datums, my_loss).result()
print(result.metrics)  # {"loss": 0.42}
```
optim_step(params)
Apply optimizer update after accumulating gradients:
```python
import tinker

training_client.optim_step(
    tinker.AdamParams(
        learning_rate=1e-5,
        beta1=0.9,
        beta2=0.999,
        eps=1e-8,
        weight_decay=0.01,
    )
).result()
```
optim_step also accepts a grad_accumulation_normalization parameter that controls how accumulated gradients are normalized before the update. The available modes are listed under GradAccNormalization below; see the Cookbook Reference for details.
save_weights_for_sampler_ext(name, checkpoint_type, ttl_seconds)
Export serving-compatible checkpoint with session-scoped naming. Returns a SaveSamplerResult:
```python
result = training_client.save_weights_for_sampler_ext(
    "step-100",
    checkpoint_type="base",  # "base" for full weights, "delta" for incremental
)
print(result.snapshot_name)  # Session-qualified name for hotloading
```
| Parameter | Type | Default | Description |
|---|---|---|---|
| name | str | — | Checkpoint name (auto-suffixed with session ID) |
| checkpoint_type | str \| None | None | "base" for full weights, "delta" for incremental |
| ttl_seconds | int \| None | None | Auto-delete checkpoint after this many seconds |
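Since ttl_seconds is a plain integer, deriving it from a timedelta keeps the intent readable at call sites. A small sketch under the signature shown above (the save call itself is shown commented because it needs a live trainer):

```python
from datetime import timedelta

# Keep intermediate sampler checkpoints for six hours, then let the
# trainer delete them automatically via ttl_seconds.
ttl = int(timedelta(hours=6).total_seconds())

# result = training_client.save_weights_for_sampler_ext(
#     "step-100",
#     checkpoint_type="delta",
#     ttl_seconds=ttl,
# )
print(ttl)  # 21600
```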
save_state(name, ttl_seconds=None)
Save full train state (weights + optimizer) for resume:
```python
training_client.save_state("train_state_step_100").result()
```
| Parameter | Type | Default | Description |
|---|---|---|---|
| name | str | — | Checkpoint name |
| ttl_seconds | int \| None | None | Auto-delete checkpoint after this many seconds |
load_state_with_optimizer(name)
Restore train state from a checkpoint:
```python
training_client.load_state_with_optimizer("train_state_step_100").result()
```
list_checkpoints()
List available DCP checkpoints from the trainer. Returns a list[str]:
```python
checkpoint_names = training_client.list_checkpoints()
print(checkpoint_names)  # e.g. ["step-2", "step-4"]
```
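If checkpoint names follow the step-N convention shown in the example, resuming from the most recent one means comparing step numbers numerically rather than lexically (so "step-10" sorts after "step-2"). A minimal helper, assuming that naming convention:

```python
def latest_checkpoint(names, prefix="step-"):
    """Return the checkpoint name with the highest step number,
    or None if no name matches the prefix."""
    steps = [n for n in names if n.startswith(prefix) and n[len(prefix):].isdigit()]
    if not steps:
        return None
    return max(steps, key=lambda n: int(n[len(prefix):]))

print(latest_checkpoint(["step-2", "step-10", "step-4"]))  # step-10
```

Feeding the result of list_checkpoints() through this helper and into load_state_with_optimizer(...) gives a simple resume-from-latest flow.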
resolve_checkpoint_path(name, source_job_id)
Resolve a checkpoint path for cross-job resume:
```python
checkpoint_ref = training_client.resolve_checkpoint_path(
    "step-4",
    source_job_id="previous-job-id",
)
training_client.load_state_with_optimizer(checkpoint_ref).result()
```
SaveSamplerResult
Returned by save_weights_for_sampler_ext:
| Field | Type | Description |
|---|---|---|
| path | str | Snapshot name from trainer |
| snapshot_name | str | Session-qualified name for weight sync operations |
GradAccNormalization
Enum for optim_step’s grad_accumulation_normalization parameter:
| Value | Description |
|---|---|
| "num_loss_tokens" | Normalize by total loss tokens across accumulated micro-batches |
| "num_sequences" | Normalize by total sequences across accumulated micro-batches |
| "none" | No normalization (raw gradient sum) |