Overview
FiretitanServiceClient is the recommended direct SDK entry point. In the managed path, it creates or reattaches the FireTitan trainer, optional reference trainer, and optional inference deployment, then returns Tinker-compatible training and sampling clients.
For most direct SDK code, create it with FiretitanServiceClient.from_firetitan_config(...). The bare constructor is still useful when you already have a trainer endpoint URL, but that is a lower-level compatibility path.
from fireworks.training.sdk import FiretitanServiceClient, GradAccNormalization
FiretitanServiceClient
from_firetitan_config(...)
Create a lazy SDK-managed service. The trainer and deployment are provisioned on the first client call, usually create_training_client(...):
service = FiretitanServiceClient.from_firetitan_config(
    api_key="<FIREWORKS_API_KEY>",
    base_url="https://api.fireworks.ai",
    base_model="accounts/fireworks/models/qwen3-8b",
    tokenizer_model="Qwen/Qwen3-8B",
    lora_rank=0,
    training_shape_id="accounts/fireworks/trainingShapes/qwen3-8b-128k-h200",
    deployment_id="research-serving",  # set create_deployment=False for trainer-only flows
    learning_rate=1e-5,
    replica_count=1,  # deployment replicas
    cleanup_trainer_on_close=True,
    cleanup_deployment_on_close="scale_to_zero",
)

training_client = service.create_training_client(
    base_model="accounts/fireworks/models/qwen3-8b",
    lora_rank=0,
)
Core managed config fields:
| Field | Type | Default | Description |
|---|---|---|---|
| api_key | str \| None | FIREWORKS_API_KEY | Fireworks API key. |
| base_url | str \| None | https://api.fireworks.ai | Control-plane URL. |
| inference_url | str \| None | None | Optional inference gateway URL. |
| base_model | str | required | Fireworks base model resource name. |
| tokenizer_model | str \| None | None | HuggingFace tokenizer name used by get_tokenizer() and sampler setup. |
| lora_rank | int | 0 | 0 for full-parameter training; positive value for LoRA. |
| training_shape_id | str \| None | None | User-facing training shape ID. The SDK resolves the pinned version. |
| reference_training_shape_id | str \| None | None | Optional separate forward-only reference trainer shape. |
| trainer_job_id | str \| None | None | Reattach to an existing trainer instead of creating one. |
| reference_trainer_job_id | str \| None | None | Reattach to an existing reference trainer. |
| create_deployment | bool | True | Whether to create or reattach an inference deployment. Set False for trainer-only SFT/DPO-style loops. |
| deployment_id | str \| None | None | Create or reattach an inference deployment for sampling and weight sync. |
| deployment_shape | str \| None | Linked shape | Optional deployment shape override. Usually inherited from the training shape. |
| trainer_replica_count | int \| None | None | Data-parallel HSDP replicas for the trainer. |
| replica_count | int | 1 | Inference deployment replicas. |
| cleanup_trainer_on_close | bool | False | Delete the SDK-managed policy trainer when service.close() runs. |
| cleanup_reference_trainer_on_close | bool | True | Delete SDK-managed separate reference trainers when released/closed. |
| cleanup_deployment_on_close | "scale_to_zero" \| "delete" \| None | None | Optional deployment cleanup action on close. |
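The cleanup_*_on_close fields only take effect when service.close() runs, so it helps to guarantee that call even when training fails. The sketch below shows one way to do that with the standard library's contextlib.closing. FakeService and train are hypothetical stand-ins so the snippet runs without credentials; the only SDK behavior assumed is the close() method mentioned in the table.

```python
from contextlib import closing


class FakeService:
    """Hypothetical stand-in exposing only the documented close() method."""

    def __init__(self):
        self.closed = False

    def close(self):
        # The real service would apply its cleanup_*_on_close settings here.
        self.closed = True


def train(service):
    # Placeholder for a training loop that fails part-way through.
    raise RuntimeError("simulated training failure")


service = FakeService()
failure = None
try:
    # closing() calls service.close() on exit, even when the body raises.
    with closing(service):
        train(service)
except RuntimeError as exc:
    failure = str(exc)
```

With a real service, the same with-block would wrap the training loop, and the configured cleanup actions would run when the block exits.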
The managed service exposes resolved metadata after provisioning:
print(service.trainer_job_id)
print(service.deployment_id)
print(service.max_context_length)
print(service.reference_trainer_job_id) # None when the reference is shared
Bare constructor
service = FiretitanServiceClient(
    base_url=endpoint.base_url,  # From TrainerJobManager.create_and_wait(...)
    api_key="<FIREWORKS_API_KEY>",
)
Here base_url is the trainer endpoint URL from TrainerServiceEndpoint.base_url. Use this path only when you intentionally manage the trainer lifecycle yourself. New user code should use from_firetitan_config(...).
create_training_client(base_model, lora_rank=0, user_metadata=None)
Create a FiretitanTrainingClient for training operations:
training_client = service.create_training_client(
    base_model="accounts/fireworks/models/qwen3-8b",
    lora_rank=0,  # Must match lora_rank from job creation
)
| Parameter | Type | Default | Description |
|---|---|---|---|
| base_model | str | required | Must match the trainer job’s base_model |
| lora_rank | int | 0 | Must match trainer creation config (0 for full-parameter) |
| user_metadata | dict[str, str] \| None | None | Optional run metadata |
A ValueError is raised if you attempt to create a second training client with the same (base_model, lora_rank) on the same FiretitanServiceClient instance. Create a new FiretitanServiceClient for a separate trainer.
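One way to avoid tripping that check is to keep a small cache keyed by (base_model, lora_rank) and reuse the client you already created. The sketch below is illustrative only: FakeService is a hypothetical stand-in that imitates the documented ValueError, and get_training_client is a helper invented for this example, not part of the SDK.

```python
class FakeService:
    """Hypothetical stand-in that imitates the documented duplicate check."""

    def __init__(self):
        self._seen = set()

    def create_training_client(self, base_model, lora_rank=0):
        key = (base_model, lora_rank)
        if key in self._seen:
            raise ValueError(f"training client already exists for {key}")
        self._seen.add(key)
        return {"base_model": base_model, "lora_rank": lora_rank}


def get_training_client(service, cache, base_model, lora_rank=0):
    """Return the cached client for this pair, creating it at most once."""
    key = (base_model, lora_rank)
    if key not in cache:
        cache[key] = service.create_training_client(
            base_model=base_model, lora_rank=lora_rank
        )
    return cache[key]


service = FakeService()
cache = {}
model = "accounts/fireworks/models/qwen3-8b"
first = get_training_client(service, cache, model)
second = get_training_client(service, cache, model)  # reused, no ValueError
```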
Connecting to an existing trainer
If you already have a running trainer (e.g. from a previous session), connect directly by URL:
service = FiretitanServiceClient(
    base_url="https://<existing-trainer-url>",
    api_key="<FIREWORKS_API_KEY>",
)

training_client = service.create_training_client(
    base_model="accounts/fireworks/models/qwen3-8b",
    lora_rank=0,
)
create_base_training_client(base_model)
Create a base-only client on the same trainer session. Use it as a frozen reference for LoRA KL/reference logprobs without launching a separate forward-only trainer:
reference_client = service.create_base_training_client(base_model=base_model)
ref = reference_client.forward(datums, "cross_entropy").result()
Do not call forward_backward, forward_backward_custom, or optim_step on this client; it is for reference forward passes only.
create_reference_client(base_model, lora_rank)
Create a frozen reference client for KL/DPO baseline logprobs:
reference_client = service.create_reference_client(base_model, lora_rank=0)
ref = reference_client.forward(datums, "cross_entropy").result()
The SDK chooses the backing automatically. LoRA policies without an explicit reference shape reuse the policy trainer with the adapter disabled. Full-parameter policies, explicit reference_training_shape_id, or explicit reference_trainer_job_id use a separate forward-only reference trainer owned by the service.
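The selection rule in the paragraph above can be restated as a small decision function. This is a sketch of the documented rule, not the SDK's implementation; reference_backing and its return labels are names chosen for this example.

```python
def reference_backing(lora_rank, reference_training_shape_id=None,
                      reference_trainer_job_id=None):
    """Return which trainer would back the reference client under the stated rule."""
    explicit_reference = (
        reference_training_shape_id is not None
        or reference_trainer_job_id is not None
    )
    # LoRA policy with no explicit reference: reuse the policy trainer
    # with the adapter disabled.
    if lora_rank > 0 and not explicit_reference:
        return "shared_policy_trainer"
    # Full-parameter policies and explicit reference settings get a
    # separate forward-only reference trainer.
    return "separate_reference_trainer"


lora_default = reference_backing(lora_rank=16)
full_param = reference_backing(lora_rank=0)
lora_explicit = reference_backing(
    lora_rank=16,
    reference_training_shape_id="accounts/fireworks/trainingShapes/qwen3-8b-128k-h200",
)
```

The shared case corresponds to service.reference_trainer_job_id being None after provisioning.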
create_sampling_client(model_path=None, ...)
Return a Tinker-shaped sampling client backed by the SDK-managed deployment. When model_path is provided, the SDK first syncs that sampler snapshot to the deployment:
saved = training_client.save_weights_for_sampler("step-100").result()
sampler = service.create_sampling_client(model_path=saved.path)
This is the replacement for calling a standalone weight-sync helper in user code. The SDK tracks the base/delta chain and builds the weight-sync metadata internally.
create_deployment_sampler(model_path=None, tokenizer=None, concurrency_controller=None)
Return the FireTitan-native DeploymentSampler directly. Use this when you need tokenized completions, inference logprobs, routing matrices, or adaptive concurrency:
sampler = service.create_deployment_sampler(
    model_path=saved.path,
    tokenizer=tokenizer,
    concurrency_controller=controller,
)
hotload_sampler_snapshot(model_path)
Low-level method for syncing a previously saved sampler snapshot into the SDK-managed deployment without constructing a sampler:
service.hotload_sampler_snapshot(saved.path)
FiretitanTrainingClient
The training client returned by create_training_client(). Core training RPCs like forward(...), forward_backward_custom(...), optim_step(...), save_state(...), and load_state_with_optimizer(...) return futures. Fireworks convenience helpers like save_weights_for_sampler_ext(...), list_checkpoints(), and resolve_checkpoint_path(...) return concrete values directly.
forward(datums, loss_type)
Forward-only pass (no gradient computation). Useful for computing reference logprobs in GRPO/DPO:
result = training_client.forward(datums, "cross_entropy").result()
logprobs = result.loss_fn_outputs[0]["logprobs"].data
Built-in loss types such as "cross_entropy" require datums with target_tokens in loss_fn_inputs, so datums built with datum_from_model_input_weights will fail. For built-in losses, use the target-token tinker.Datum example in Loss Functions. Otherwise, use forward_backward_custom with the weight-based format described in Building datums, following the custom-loss pattern in Example: simple cross-entropy.
forward_backward_custom(datums, loss_fn)
Forward + backward with your custom loss function. See Loss Functions for details:
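def my_loss(data, logprobs_list):
    # compute_loss is a placeholder for your own loss computation.
    loss = compute_loss(data, logprobs_list)
    # Return the scalar loss plus a dict of metrics to report.
    return loss, {"loss": float(loss.item())}

result = training_client.forward_backward_custom(datums, my_loss).result()
print(result.metrics)  # e.g. {"loss": 0.42}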
For embedding-space objectives, pass output="embedding" and choose pooling="mean" or "last"; your loss function then receives pooled embedding tensors instead of logprobs:
result = training_client.forward_backward_custom(
    datums,
    embedding_loss,
    output="embedding",
    pooling="mean",
).result()
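To make the two pooling choices concrete, here is a minimal pure-Python sketch of what "mean" and "last" pooling conventionally compute over a sequence of token embeddings. It assumes the usual definitions (average over tokens, and the final token's vector); how the trainer handles padding or masking is not covered here, and pool is a hypothetical helper, not an SDK function.

```python
def pool(token_embeddings, pooling="mean"):
    """Reduce a [num_tokens][dim] list of embeddings to a single [dim] vector."""
    if not token_embeddings:
        raise ValueError("need at least one token embedding")
    if pooling == "last":
        # Use the final token's embedding as the sequence representation.
        return list(token_embeddings[-1])
    if pooling == "mean":
        # Average each embedding dimension across all tokens.
        count = len(token_embeddings)
        return [sum(column) / count for column in zip(*token_embeddings)]
    raise ValueError(f"unknown pooling mode: {pooling!r}")


tokens = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
mean_vector = pool(tokens, "mean")
last_vector = pool(tokens, "last")
```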
optim_step(adam_params, grad_accumulation_normalization=None)
Apply optimizer update after accumulating gradients:
import tinker

training_client.optim_step(
    tinker.AdamParams(
        learning_rate=1e-5,
        beta1=0.9,
        beta2=0.999,
        eps=1e-8,
        weight_decay=0.01,
    )
).result()
optim_step also accepts grad_accumulation_normalization, which controls how accumulated gradients are normalized. Pass GradAccNormalization.NUM_LOSS_TOKENS, GradAccNormalization.NUM_SEQUENCES, or GradAccNormalization.NONE rather than raw strings. See Loss Functions for when to use each mode.
save_weights_for_sampler(name, ttl_seconds=None, checkpoint_type=None)
Save serving-compatible sampler weights and return a future. This is the normal Tinker-shaped API:
saved = training_client.save_weights_for_sampler(
    "step-100",
    checkpoint_type="base",  # optional: "base" or "delta"
).result()
print(saved.path)  # Snapshot identity for create_sampling_client(model_path=...)
Full-parameter training saves a base checkpoint first and deltas after that by default. LoRA training always saves base checkpoints. The returned path is a public snapshot identity, not a raw storage URI.
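The default described above can be written out as a small function. This is a sketch of the stated rule only, assuming lora_rank > 0 means LoRA as elsewhere on this page; default_checkpoint_type and simulate_saves are hypothetical names, and passing checkpoint_type explicitly overrides the default.

```python
def default_checkpoint_type(lora_rank, base_already_saved):
    """Return the checkpoint type used when checkpoint_type is left unset."""
    if lora_rank > 0:
        # LoRA training always saves base checkpoints.
        return "base"
    # Full-parameter training: one base first, then deltas.
    return "delta" if base_already_saved else "base"


def simulate_saves(lora_rank, num_saves):
    """List the default checkpoint types for a run of consecutive saves."""
    kinds = []
    for _ in range(num_saves):
        kinds.append(default_checkpoint_type(lora_rank, bool(kinds)))
    return kinds


full_param_kinds = simulate_saves(lora_rank=0, num_saves=3)
lora_kinds = simulate_saves(lora_rank=16, num_saves=3)
```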
save_weights_for_sampler_ext(name, checkpoint_type=None, ttl_seconds=None)
Fireworks-specific extension that returns a concrete SaveSamplerResult instead of a future:
result = training_client.save_weights_for_sampler_ext(
    "step-100",
    checkpoint_type="base",  # "base" for full weights, "delta" for incremental
)
print(result.snapshot_name)  # Session-qualified name for weight sync
| Parameter | Type | Default | Description |
|---|---|---|---|
| name | str | required | Checkpoint name (auto-suffixed with session ID) |
| checkpoint_type | str \| None | None | "base" for full weights, "delta" for incremental |
| ttl_seconds | int \| None | None | Auto-delete checkpoint after this many seconds |
On full-parameter training, only checkpoint_type="base" produces a promotable blob; "delta" cannot be promoted. LoRA is always promotable. See Checkpoint kinds for the full promotability matrix.
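If you track your own list of saved snapshots, the promotability rule above can be applied as a simple filter. This sketch covers only the two cases stated here, not the full matrix in Checkpoint kinds; promotable_snapshots and the (name, checkpoint_type) record format are assumptions made for the example.

```python
def promotable_snapshots(snapshots, lora_rank):
    """Return the names of snapshots that can be promoted.

    snapshots is a list of (name, checkpoint_type) pairs.
    """
    if lora_rank > 0:
        # LoRA snapshots are always promotable.
        return [name for name, _ in snapshots]
    # Full-parameter: only base checkpoints produce a promotable blob.
    return [name for name, kind in snapshots if kind == "base"]


saved = [("step-0", "base"), ("step-100", "delta"), ("step-200", "delta")]
full_param_promotable = promotable_snapshots(saved, lora_rank=0)
lora_promotable = promotable_snapshots(
    [("step-0", "base"), ("step-100", "base")], lora_rank=16
)
```

In a full-parameter run this suggests saving the final snapshot with checkpoint_type="base" if you intend to promote it.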
save_weights_for_sampler_ext saves the snapshot only; it does not mutate a deployment. To serve the snapshot, pass result.snapshot_name to the managed service weight-sync path, or use create_sampling_client(model_path=...) / create_deployment_sampler(model_path=...), which sync and return a sampler.
save_state(name, ttl_seconds=None, timeout=None)
Save full train state (weights + optimizer) for resume:
training_client.save_state("train_state_step_100").result()
| Parameter | Type | Default | Description |
|---|---|---|---|
| name | str | required | Checkpoint name |
| ttl_seconds | int \| None | None | Auto-delete checkpoint after this many seconds |
| timeout | float \| None | None | If set, block until the save completes or the timeout expires |
load_state_with_optimizer(name)
Restore full train state (weights + optimizer) from a checkpoint:
training_client.load_state_with_optimizer("train_state_step_100").result()
load_state(name)
Load model weights from a checkpoint without restoring optimizer state. The optimizer is reset so the next optim_step starts fresh:
training_client.load_state("train_state_step_100").result()
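The choice between the two load methods comes down to whether the optimizer state should survive the restore. The sketch below wraps that choice in a helper; restore is a hypothetical name, and FakeTrainingClient and FakeFuture are stand-ins that only record which documented method was called so the snippet runs without a trainer.

```python
class FakeFuture:
    """Hypothetical future exposing result(), like the documented RPC futures."""

    def __init__(self, value=None):
        self._value = value

    def result(self):
        return self._value


class FakeTrainingClient:
    """Hypothetical stand-in that records which load method was called."""

    def __init__(self):
        self.calls = []

    def load_state(self, name):
        self.calls.append(("load_state", name))
        return FakeFuture()

    def load_state_with_optimizer(self, name):
        self.calls.append(("load_state_with_optimizer", name))
        return FakeFuture()


def restore(training_client, name, keep_optimizer):
    """Resume a run (keep optimizer) or warm-start from weights only."""
    if keep_optimizer:
        future = training_client.load_state_with_optimizer(name)
    else:
        future = training_client.load_state(name)
    return future.result()


client = FakeTrainingClient()
restore(client, "train_state_step_100", keep_optimizer=True)   # exact resume
restore(client, "train_state_step_100", keep_optimizer=False)  # fresh optimizer
```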
load_adapter(adapter_path)
Load Hugging Face PEFT adapter weights into the current LoRA session. This is a weights-only warm start; it does not restore optimizer state, scheduler state, or data cursor.
training_client.load_adapter("gs://my-bucket/adapters/run-42").result()
list_checkpoints()
List available DCP checkpoints from the trainer. Returns a list[str]:
checkpoint_names = training_client.list_checkpoints()
print(checkpoint_names) # e.g. ["step-2", "step-4"]
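When resuming, you often want the most recent checkpoint from that list. The sketch below picks the highest-numbered name, assuming checkpoints follow the "step-N" pattern shown in the example output; that naming is a convention of this example rather than something the trainer enforces, and latest_step_checkpoint is a hypothetical helper.

```python
import re

STEP_PATTERN = re.compile(r"^step-(\d+)$")


def latest_step_checkpoint(names):
    """Return the "step-N" name with the largest N, or None if none match."""
    numbered = []
    for name in names:
        match = STEP_PATTERN.match(name)
        if match:
            numbered.append((int(match.group(1)), name))
    if not numbered:
        return None
    # Compare numerically so "step-10" sorts after "step-4".
    return max(numbered)[1]


latest = latest_step_checkpoint(["step-2", "step-10", "step-4", "final"])
nothing = latest_step_checkpoint(["final"])
```

The returned name can then be passed to load_state_with_optimizer(...) or load_state(...).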
resolve_checkpoint_path(checkpoint_name, source_job_id)
Resolve a checkpoint path for cross-job resume:
checkpoint_ref = training_client.resolve_checkpoint_path(
    "step-4",
    source_job_id="previous-job-id",
)
training_client.load_state_with_optimizer(checkpoint_ref).result()
SaveSamplerResult
Returned by save_weights_for_sampler_ext:
| Field | Type | Description |
|---|---|---|
| path | str | Snapshot name from trainer |
| snapshot_name | str | Session-qualified name for weight sync operations |
GradAccNormalization
Enum for optim_step’s grad_accumulation_normalization parameter:
| Enum | Wire value | Description |
|---|---|---|
| GradAccNormalization.NUM_LOSS_TOKENS | "num_loss_tokens" | Normalize by total loss tokens across accumulated micro-batches |
| GradAccNormalization.NUM_SEQUENCES | "num_sequences" | Normalize by total sequences across accumulated micro-batches |
| GradAccNormalization.NONE | "none" | No normalization (raw gradient sum) |
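A small worked example may help show how the three modes differ. The sketch below reproduces the arithmetic described in the table on scalar stand-ins for gradients. It is not the trainer's implementation: Normalization is a hypothetical mirror of the wire values, and normalized_gradient is a helper written for this example.

```python
from enum import Enum


class Normalization(Enum):
    """Hypothetical mirror of the wire values listed in the table."""

    NUM_LOSS_TOKENS = "num_loss_tokens"
    NUM_SEQUENCES = "num_sequences"
    NONE = "none"


def normalized_gradient(micro_batches, mode):
    """Combine per-micro-batch gradient sums under the chosen mode.

    micro_batches is a list of (grad_sum, num_loss_tokens, num_sequences).
    A scalar stands in for the gradient tensor to keep the arithmetic visible.
    """
    total_grad = sum(grad for grad, _, _ in micro_batches)
    if mode is Normalization.NUM_LOSS_TOKENS:
        return total_grad / sum(tokens for _, tokens, _ in micro_batches)
    if mode is Normalization.NUM_SEQUENCES:
        return total_grad / sum(seqs for _, _, seqs in micro_batches)
    return total_grad  # NONE: raw gradient sum


# Two micro-batches: 3 loss tokens in 1 sequence, then 1 loss token in 1 sequence.
micro_batches = [(6.0, 3, 1), (2.0, 1, 1)]
by_tokens = normalized_gradient(micro_batches, Normalization.NUM_LOSS_TOKENS)
by_sequences = normalized_gradient(micro_batches, Normalization.NUM_SEQUENCES)
raw = normalized_gradient(micro_batches, Normalization.NONE)
```

Here the summed gradient of 8.0 becomes 2.0 when divided by 4 loss tokens, 4.0 when divided by 2 sequences, and stays 8.0 with no normalization.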