Overview

FiretitanServiceClient is the recommended direct SDK entry point. In the managed path, it creates or reattaches the FireTitan trainer, optional reference trainer, and optional inference deployment, then returns Tinker-compatible training and sampling clients. For most direct SDK code, create it with FiretitanServiceClient.from_firetitan_config(...). The bare constructor is still useful when you already have a trainer endpoint URL, but that is a lower-level compatibility path.
from fireworks.training.sdk import FiretitanServiceClient, GradAccNormalization

FiretitanServiceClient

from_firetitan_config(...)

Create a lazy SDK-managed service. The trainer and deployment are provisioned on the first client call, usually create_training_client(...):
service = FiretitanServiceClient.from_firetitan_config(
    api_key="<FIREWORKS_API_KEY>",
    base_url="https://api.fireworks.ai",
    base_model="accounts/fireworks/models/qwen3-8b",
    tokenizer_model="Qwen/Qwen3-8B",
    lora_rank=0,
    training_shape_id="accounts/fireworks/trainingShapes/qwen3-8b-128k-h200",
    deployment_id="research-serving",   # set create_deployment=False for trainer-only flows
    learning_rate=1e-5,
    replica_count=1,                     # deployment replicas
    cleanup_trainer_on_close=True,
    cleanup_deployment_on_close="scale_to_zero",
)

training_client = service.create_training_client(
    base_model="accounts/fireworks/models/qwen3-8b",
    lora_rank=0,
)
Core managed config fields:

| Field | Type | Default | Description |
| --- | --- | --- | --- |
| api_key | str \| None | FIREWORKS_API_KEY | Fireworks API key. |
| base_url | str \| None | https://api.fireworks.ai | Control-plane URL. |
| inference_url | str \| None | None | Optional inference gateway URL. |
| base_model | str | | Fireworks base model resource name. |
| tokenizer_model | str \| None | None | HuggingFace tokenizer name used by get_tokenizer() and sampler setup. |
| lora_rank | int | 0 | 0 for full-parameter training; positive value for LoRA. |
| training_shape_id | str \| None | None | User-facing training shape ID. The SDK resolves the pinned version. |
| reference_training_shape_id | str \| None | None | Optional separate forward-only reference trainer shape. |
| trainer_job_id | str \| None | None | Reattach to an existing trainer instead of creating one. |
| reference_trainer_job_id | str \| None | None | Reattach to an existing reference trainer. |
| create_deployment | bool | True | Whether to create or reattach an inference deployment. Set False for trainer-only SFT/DPO-style loops. |
| deployment_id | str \| None | None | Create or reattach an inference deployment for sampling and weight sync. |
| deployment_shape | str \| None | Linked shape | Optional deployment shape override. Usually inherited from the training shape. |
| trainer_replica_count | int \| None | None | Data-parallel HSDP replicas for the trainer. |
| replica_count | int | 1 | Inference deployment replicas. |
| cleanup_trainer_on_close | bool | False | Delete the SDK-managed policy trainer when service.close() runs. |
| cleanup_reference_trainer_on_close | bool | True | Delete SDK-managed separate reference trainers when released/closed. |
| cleanup_deployment_on_close | "scale_to_zero" \| "delete" \| None | None | Optional deployment cleanup action on close. |

The managed service exposes resolved metadata after provisioning:
print(service.trainer_job_id)
print(service.deployment_id)
print(service.max_context_length)
print(service.reference_trainer_job_id)  # None when the reference is shared
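
The cleanup_*_on_close fields only take effect when service.close() runs, so it can help to guarantee that call even when training raises. The following is a minimal sketch, assuming only that the service object exposes close() as the field descriptions imply; run_with_cleanup and train_fn are hypothetical names for illustration, not SDK APIs:

```python
from contextlib import closing


def run_with_cleanup(service, train_fn):
    """Run train_fn(service) and always call service.close() afterwards.

    `service` is assumed to be a FiretitanServiceClient (or any object with a
    close() method); `train_fn` is a hypothetical user-supplied callable.
    """
    # contextlib.closing calls service.close() when the block exits,
    # including when train_fn raises an exception.
    with closing(service):
        return train_fn(service)
```

Using the standard-library closing helper avoids relying on the service object being a context manager itself, which the text above does not state.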

Bare constructor

service = FiretitanServiceClient(
    base_url=endpoint.base_url,  # From TrainerJobManager.create_and_wait(...)
    api_key="<FIREWORKS_API_KEY>",
)
base_url is the trainer endpoint URL from TrainerServiceEndpoint.base_url. Use this path only when you intentionally manage the trainer lifecycle yourself. New user code should use from_firetitan_config(...).

create_training_client(base_model, lora_rank, user_metadata)

Creates a FiretitanTrainingClient for training operations:
training_client = service.create_training_client(
    base_model="accounts/fireworks/models/qwen3-8b",
    lora_rank=0,  # Must match lora_rank from job creation
)

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| base_model | str | | Must match the trainer job’s base_model |
| lora_rank | int | 0 | Must match trainer creation config (0 for full-parameter) |
| user_metadata | dict[str, str] \| None | None | Optional run metadata |

A ValueError is raised if you attempt to create a second training client with the same (base_model, lora_rank) on the same FiretitanServiceClient instance. Create a new FiretitanServiceClient for a separate trainer.
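
Code that requests a training client from several places can hit this ValueError by accident. One way to avoid it is to cache the client per (base_model, lora_rank) pair. The sketch below assumes only the create_training_client(base_model=..., lora_rank=...) call shown above; TrainingClientCache is a hypothetical helper, not part of the SDK:

```python
class TrainingClientCache:
    """Hand out one training client per (base_model, lora_rank) pair."""

    def __init__(self, service):
        # Assumed to be a FiretitanServiceClient or a compatible object.
        self._service = service
        self._clients = {}

    def get(self, base_model, lora_rank=0):
        key = (base_model, lora_rank)
        if key not in self._clients:
            # First request for this pair: create the client exactly once.
            self._clients[key] = self._service.create_training_client(
                base_model=base_model,
                lora_rank=lora_rank,
            )
        # Later requests reuse the cached client instead of creating a second one.
        return self._clients[key]
```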

Connecting to an existing trainer

If you already have a running trainer (e.g. from a previous session), connect directly by URL:
service = FiretitanServiceClient(
    base_url="https://<existing-trainer-url>",
    api_key="<FIREWORKS_API_KEY>",
)
training_client = service.create_training_client(
    base_model="accounts/fireworks/models/qwen3-8b",
    lora_rank=0,
)

create_base_training_client(base_model, user_metadata=None)

Creates a base-only client on the same trainer session. Use this as a frozen reference for LoRA KL/reference logprobs without launching a separate forward-only trainer:
reference_client = service.create_base_training_client(base_model=base_model)
ref = reference_client.forward(datums, "cross_entropy").result()
Do not call forward_backward, forward_backward_custom, or optim_step on this client; it is for reference forward passes only.

create_reference_client(base_model, lora_rank=0, user_metadata=None)

Create a frozen reference client for KL/DPO baseline logprobs:
reference_client = service.create_reference_client(base_model, lora_rank=0)
ref = reference_client.forward(datums, "cross_entropy").result()
The SDK chooses the backing automatically. LoRA policies without an explicit reference shape reuse the policy trainer with the adapter disabled. Full-parameter policies, explicit reference_training_shape_id, or explicit reference_trainer_job_id use a separate forward-only reference trainer owned by the service.
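
Once policy and reference logprobs are available for the same sampled tokens, a KL penalty can be estimated from their difference. This is a generic sketch of the simple Monte Carlo estimator (mean log-ratio over tokens sampled from the policy), not something the SDK provides; estimate_kl is a hypothetical helper, and it assumes the logprobs have already been converted to plain Python floats:

```python
def estimate_kl(policy_logprobs, reference_logprobs):
    """Monte Carlo estimate of KL(policy || reference) from per-token logprobs.

    Both arguments are equal-length lists of floats for the same tokens,
    where the tokens were sampled from the policy.
    """
    if len(policy_logprobs) != len(reference_logprobs):
        raise ValueError("logprob sequences must have the same length")
    if len(policy_logprobs) == 0:
        return 0.0
    # Average of log p_policy(token) - log p_ref(token) over the sampled tokens.
    diffs = [p - r for p, r in zip(policy_logprobs, reference_logprobs)]
    return sum(diffs) / len(diffs)
```

The estimate is zero when the two models agree on every token, and individual samples can be negative even though the true KL divergence is not.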

create_sampling_client(model_path=None, ...)

Return a Tinker-shaped sampling client backed by the SDK-managed deployment. When model_path is provided, the SDK first syncs that sampler snapshot to the deployment:
saved = training_client.save_weights_for_sampler("step-100").result()
sampler = service.create_sampling_client(model_path=saved.path)
This is the replacement for calling a standalone weight-sync helper in user code. The SDK tracks the base/delta chain and builds the weight-sync metadata internally.
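
In an on-policy loop, the save-then-sync pair above is typically repeated every few steps. A minimal sketch of wrapping it in one function, assuming only the two calls shown above; refresh_sampler and the "step-N" naming are illustrative choices, not SDK conventions:

```python
def refresh_sampler(service, training_client, step):
    """Save a sampler snapshot for `step` and return a synced sampling client."""
    # save_weights_for_sampler returns a future; .result() blocks until saved.
    saved = training_client.save_weights_for_sampler(f"step-{step}").result()
    # Passing model_path asks the SDK to sync that snapshot to the deployment
    # before it hands back the sampling client.
    return service.create_sampling_client(model_path=saved.path)
```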

create_deployment_sampler(model_path=None, tokenizer=None, concurrency_controller=None)

Return the FireTitan-native DeploymentSampler directly. Use this when you need tokenized completions, inference logprobs, routing matrices, or adaptive concurrency:
sampler = service.create_deployment_sampler(
    model_path=saved.path,
    tokenizer=tokenizer,
    concurrency_controller=controller,
)

hotload_sampler_snapshot(model_path)

Low-level method for syncing a previously saved sampler snapshot into the SDK-managed deployment without constructing a sampler:
service.hotload_sampler_snapshot(saved.path)

FiretitanTrainingClient

The training client returned by create_training_client(). Core training RPCs like forward(...), forward_backward_custom(...), optim_step(...), save_state(...), and load_state_with_optimizer(...) return futures. Fireworks convenience helpers like save_weights_for_sampler_ext(...), list_checkpoints(), and resolve_checkpoint_path(...) return concrete values directly.

forward(datums, loss_type)

Forward-only pass (no gradient computation). Useful for computing reference logprobs in GRPO/DPO:
result = training_client.forward(datums, "cross_entropy").result()
logprobs = result.loss_fn_outputs[0]["logprobs"].data
Built-in loss types like "cross_entropy" require datums with target_tokens in loss_fn_inputs. Datums built with datum_from_model_input_weights will fail. Use the target-token tinker.Datum example in Loss Functions for built-in losses, or use forward_backward_custom with the weight-based format in Building datums and the custom-loss pattern in Example: simple cross-entropy.
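
The snippet above reads logprobs for the first datum only. To collect them for a whole batch, a small helper along these lines may be useful. It is a sketch that assumes loss_fn_outputs holds one entry per input datum in input order, which the text above does not state explicitly; collect_logprobs is a hypothetical name:

```python
def collect_logprobs(client, datums, loss_type="cross_entropy"):
    """Return one logprob payload per datum from a forward-only pass."""
    # forward returns a future; .result() blocks until the pass completes.
    result = client.forward(datums, loss_type).result()
    # Assumption: one output per input datum, in the same order, each
    # exposing ["logprobs"].data as in the single-datum snippet.
    return [output["logprobs"].data for output in result.loss_fn_outputs]
```

The same helper works with a reference client, since those clients expose the same forward call.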

forward_backward_custom(datums, loss_fn)

Forward + backward with your custom loss function. See Loss Functions for details:
def my_loss(data, logprobs_list):
    loss = compute_loss(data, logprobs_list)
    return loss, {"loss": float(loss.item())}

result = training_client.forward_backward_custom(datums, my_loss).result()
print(result.metrics)  # {"loss": 0.42}
For embedding-space objectives, pass output="embedding" and choose pooling="mean" or "last"; your loss function then receives pooled embedding tensors instead of logprobs:
result = training_client.forward_backward_custom(
    datums,
    embedding_loss,
    output="embedding",
    pooling="mean",
).result()

optim_step(adam_params, grad_accumulation_normalization=None)

Apply optimizer update after accumulating gradients:
import tinker

training_client.optim_step(
    tinker.AdamParams(
        learning_rate=1e-5,
        beta1=0.9,
        beta2=0.999,
        eps=1e-8,
        weight_decay=0.01,
    )
).result()
optim_step accepts grad_accumulation_normalization to control how accumulated gradients are normalized. Pass GradAccNormalization.NUM_LOSS_TOKENS, GradAccNormalization.NUM_SEQUENCES, or GradAccNormalization.NONE rather than raw strings. See Loss Functions for when to use each mode.
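
Gradient accumulation then amounts to several forward/backward calls followed by one optim_step. The sketch below assumes that each forward_backward_custom call adds to the gradients consumed by the next optim_step, as the description above implies; accumulate_and_step is a hypothetical helper, and normalization is expected to be a GradAccNormalization member or None:

```python
def accumulate_and_step(training_client, micro_batches, loss_fn, adam_params,
                        normalization=None):
    """Run forward/backward on each micro-batch, then one optimizer update."""
    metrics = []
    for datums in micro_batches:
        # Assumed behaviour: gradients from each call accumulate on the trainer.
        result = training_client.forward_backward_custom(datums, loss_fn).result()
        metrics.append(result.metrics)
    # A single optim_step applies the update for all accumulated micro-batches.
    training_client.optim_step(
        adam_params,
        grad_accumulation_normalization=normalization,
    ).result()
    return metrics
```

Returning the per-micro-batch metrics keeps logging decisions in the caller.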

save_weights_for_sampler(name, ttl_seconds=None, checkpoint_type=None)

Save serving-compatible sampler weights and return a future. This is the normal Tinker-shaped API:
saved = training_client.save_weights_for_sampler(
    "step-100",
    checkpoint_type="base",  # optional: "base" or "delta"
).result()
print(saved.path)  # Snapshot identity for create_sampling_client(model_path=...)
Full-parameter training saves a base checkpoint first and deltas after that by default. LoRA training always saves base checkpoints. The returned path is a public snapshot identity, not a raw storage URI.
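
The default described above can be written down as a small decision function. This only models the stated behaviour for illustration; it is not the SDK's code, and default_checkpoint_type and previous_saves are hypothetical names:

```python
def default_checkpoint_type(lora_rank, previous_saves):
    """Model the default checkpoint kind when checkpoint_type is not given.

    lora_rank: 0 for full-parameter training, positive for LoRA.
    previous_saves: number of sampler snapshots already saved in this session.
    """
    if lora_rank > 0:
        # LoRA training always saves base checkpoints.
        return "base"
    # Full-parameter training: base on the first save, deltas afterwards.
    return "base" if previous_saves == 0 else "delta"
```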

save_weights_for_sampler_ext(name, checkpoint_type, ttl_seconds)

Fireworks-specific extension that returns a concrete SaveSamplerResult instead of a future:
result = training_client.save_weights_for_sampler_ext(
    "step-100",
    checkpoint_type="base",  # "base" for full weights, "delta" for incremental
)
print(result.snapshot_name)  # Session-qualified name for weight sync

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| name | str | | Checkpoint name (auto-suffixed with session ID) |
| checkpoint_type | str \| None | None | "base" for full weights, "delta" for incremental |
| ttl_seconds | int \| None | None | Auto-delete checkpoint after this many seconds |

On full-parameter training, only checkpoint_type="base" produces a promotable blob; "delta" cannot be promoted. LoRA is always promotable. See Checkpoint kinds for the full promotability matrix.
save_weights_for_sampler_ext saves the snapshot only; it does not mutate a deployment. To serve the snapshot, pass result.snapshot_name to the managed service weight-sync path, or use create_sampling_client(model_path=...) / create_deployment_sampler(model_path=...), which sync and return a sampler.

save_state(name, ttl_seconds=None, timeout=None)

Save full train state (weights + optimizer) for resume:
training_client.save_state("train_state_step_100").result()

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| name | str | | Checkpoint name |
| ttl_seconds | int \| None | None | Auto-delete checkpoint after this many seconds |
| timeout | float \| None | None | If set, block until the save completes or the timeout expires |

load_state_with_optimizer(name)

Restore full train state (weights + optimizer) from a checkpoint:
training_client.load_state_with_optimizer("train_state_step_100").result()

load_state(name)

Load model weights from a checkpoint without restoring optimizer state. The optimizer is reset so the next optim_step starts fresh:
training_client.load_state("train_state_step_100").result()

load_adapter(adapter_path)

Load Hugging Face PEFT adapter weights into the current LoRA session. This is a weights-only warm start; it does not restore optimizer state, scheduler state, or data cursor.
training_client.load_adapter("gs://my-bucket/adapters/run-42").result()

list_checkpoints()

List available DCP checkpoints from the trainer. Returns a list[str]:
checkpoint_names = training_client.list_checkpoints()
print(checkpoint_names)  # e.g. ["step-2", "step-4"]
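
A common use of this list is resuming from the newest checkpoint at startup. The sketch below assumes checkpoint names end in a step number, as in the "step-2" / "step-4" example, and that a listed name can be passed to load_state_with_optimizer; both are assumptions, and latest_checkpoint and resume_if_possible are hypothetical helpers:

```python
import re


def latest_checkpoint(names):
    """Return the name with the highest trailing step number, or None."""
    best_name, best_step = None, -1
    for name in names:
        # Compare numerically so that "step-10" sorts after "step-4".
        match = re.search(r"(\d+)$", name)
        if match and int(match.group(1)) > best_step:
            best_name, best_step = name, int(match.group(1))
    return best_name


def resume_if_possible(training_client):
    """Load the newest checkpoint (weights + optimizer); return its name or None."""
    name = latest_checkpoint(training_client.list_checkpoints())
    if name is not None:
        training_client.load_state_with_optimizer(name).result()
    return name
```

Names without a trailing number are ignored, so a run with no numbered checkpoints starts fresh.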

resolve_checkpoint_path(checkpoint_name, source_job_id)

Resolve a checkpoint path for cross-job resume:
checkpoint_ref = training_client.resolve_checkpoint_path(
    "step-4",
    source_job_id="previous-job-id",
)
training_client.load_state_with_optimizer(checkpoint_ref).result()

SaveSamplerResult

Returned by save_weights_for_sampler_ext:

| Field | Type | Description |
| --- | --- | --- |
| path | str | Snapshot name from trainer |
| snapshot_name | str | Session-qualified name for weight sync operations |

GradAccNormalization

Enum for optim_step’s grad_accumulation_normalization parameter:

| Enum | Wire value | Description |
| --- | --- | --- |
| GradAccNormalization.NUM_LOSS_TOKENS | "num_loss_tokens" | Normalize by total loss tokens across accumulated micro-batches |
| GradAccNormalization.NUM_SEQUENCES | "num_sequences" | Normalize by total sequences across accumulated micro-batches |
| GradAccNormalization.NONE | "none" | No normalization (raw gradient sum) |
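
To make the three modes concrete, the following sketch applies each one to a single summed gradient value. It is an interpretation of the descriptions above (divide the raw sum by the stated total), not the trainer's actual implementation; normalize_accumulated and the example numbers are hypothetical:

```python
def normalize_accumulated(grad_sum, mode, num_loss_tokens, num_sequences):
    """Scale a summed gradient value according to a normalization wire value."""
    if mode == "num_loss_tokens":
        return grad_sum / num_loss_tokens
    if mode == "num_sequences":
        return grad_sum / num_sequences
    if mode == "none":
        return grad_sum
    raise ValueError(f"unknown normalization mode: {mode!r}")


# Two micro-batches: 3 sequences with 40 loss tokens, then 1 sequence with 10.
grad_sum = 12.0 + 3.0        # raw gradient sum for one parameter
total_tokens = 40 + 10       # total loss tokens across micro-batches
total_sequences = 3 + 1      # total sequences across micro-batches

per_token = normalize_accumulated(grad_sum, "num_loss_tokens", total_tokens, total_sequences)
per_sequence = normalize_accumulated(grad_sum, "num_sequences", total_tokens, total_sequences)
raw = normalize_accumulated(grad_sum, "none", total_tokens, total_sequences)
```

Under this reading, token normalization weights every loss token equally regardless of which sequence it belongs to, while sequence normalization weights every sequence equally regardless of its length.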