
## Overview

`FiretitanServiceClient` connects to a trainer endpoint and creates a `FiretitanTrainingClient` for training operations. It extends Tinker's `ServiceClient` with Fireworks-specific features such as `checkpoint_type` support and session-scoped snapshot naming.

```python
from fireworks.training.sdk import FiretitanServiceClient
```

## FiretitanServiceClient

### Constructor

```python
service = FiretitanServiceClient(
    base_url=endpoint.base_url,  # From TrainerJobManager.create_and_wait(...)
    api_key="<FIREWORKS_API_KEY>",
)
```

`base_url` is the trainer endpoint URL from `TrainerServiceEndpoint.base_url`.

### create_training_client(base_model, lora_rank, user_metadata)

Creates a `FiretitanTrainingClient` for training operations:

```python
training_client = service.create_training_client(
    base_model="accounts/fireworks/models/qwen3-8b",
    lora_rank=0,  # Must match lora_rank from job creation
)
```

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| `base_model` | `str` | required | Must match the trainer job's `base_model` |
| `lora_rank` | `int` | `0` | Must match the trainer creation config (`0` for full-parameter) |
| `user_metadata` | `dict[str, str] \| None` | `None` | Optional run metadata |

A `ValueError` is raised if you attempt to create a second training client with the same `(base_model, lora_rank)` on the same `FiretitanServiceClient` instance. Create a new `FiretitanServiceClient` for each separate trainer.

### Connecting to an existing trainer

If you already have a running trainer (e.g. from a previous session), connect directly by URL:

```python
service = FiretitanServiceClient(
    base_url="https://<existing-trainer-url>",
    api_key="<FIREWORKS_API_KEY>",
)
training_client = service.create_training_client(
    base_model="accounts/fireworks/models/qwen3-8b",
    lora_rank=0,
)
```

## FiretitanTrainingClient

The training client returned by `create_training_client()`. Core training RPCs (`forward(...)`, `forward_backward_custom(...)`, `optim_step(...)`, `save_state(...)`, and `load_state_with_optimizer(...)`) return futures; call `.result()` to block on completion. Fireworks convenience helpers such as `save_weights_for_sampler_ext(...)`, `list_checkpoints()`, and `resolve_checkpoint_path(...)` return concrete values directly.
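Because the core RPCs return futures, you can submit several calls before resolving any of them, which keeps the trainer busy while your client-side code runs. A minimal illustration of the submit-then-resolve pattern using the standard library's `concurrent.futures` as a stand-in for the SDK's future objects (which expose the same `.result()` call):

```python
from concurrent.futures import ThreadPoolExecutor

# Stand-in for an RPC: in the real SDK, forward_backward_custom(...)
# returns a future immediately while the trainer works in the background.
def fake_rpc(batch_id: int) -> dict:
    return {"batch": batch_id, "loss": 0.1 * batch_id}

with ThreadPoolExecutor() as pool:
    # Submit all micro-batches first (non-blocking)...
    futures = [pool.submit(fake_rpc, i) for i in range(3)]
    # ...then resolve them in order, exactly like calling .result()
    # on the futures returned by forward() / forward_backward_custom().
    results = [f.result() for f in futures]

print([r["batch"] for r in results])  # [0, 1, 2]
```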

### forward(datums, loss_type)

Forward-only pass with no gradient computation. Useful for computing reference logprobs in GRPO/DPO:

```python
result = training_client.forward(datums, "cross_entropy").result()
logprobs = result.loss_fn_outputs[0]["logprobs"].data
```

Built-in loss types like `"cross_entropy"` require datums with `target_tokens` in `loss_fn_inputs`; datums built with `datum_from_tokens_weights` will fail. See "Building datums" in Loss Functions for the correct format, or use `forward_backward_custom` with weight-based datums instead.
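Built-in losses expect next-token targets. A sketch of that shifted-targets layout using plain lists (the helper name and the `input_tokens` key are illustrative; follow Tinker's datum-building docs for the actual `Datum` construction, where `target_tokens` and `weights` go in `loss_fn_inputs`):

```python
def build_loss_fn_inputs(tokens: list[int], prompt_len: int) -> dict:
    """Shift tokens by one so position i predicts token i+1, and
    zero out weights on the prompt so only completion tokens
    contribute to the loss. Illustrative helper, not part of the SDK."""
    input_tokens = tokens[:-1]
    target_tokens = tokens[1:]
    weights = [0.0 if i < prompt_len - 1 else 1.0
               for i in range(len(target_tokens))]
    return {
        "input_tokens": input_tokens,    # fed to the model
        "target_tokens": target_tokens,  # required by built-in losses
        "weights": weights,              # masks prompt positions
    }

ex = build_loss_fn_inputs([1, 2, 3, 4, 5], prompt_len=3)
print(ex["target_tokens"])  # [2, 3, 4, 5]
print(ex["weights"])        # [0.0, 0.0, 1.0, 1.0]
```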

### forward_backward_custom(datums, loss_fn)

Forward + backward with your custom loss function. See Loss Functions for details:

```python
def my_loss(data, logprobs_list):
    loss = compute_loss(data, logprobs_list)
    return loss, {"loss": float(loss.item())}

result = training_client.forward_backward_custom(datums, my_loss).result()
print(result.metrics)  # {"loss": 0.42}
```
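For intuition, here is the arithmetic a typical custom loss performs: a token-weighted negative log-likelihood. This pure-Python sketch only shows the math; a real `loss_fn` receives tensors from the trainer and must return a differentiable scalar plus a metrics dict, as in the example above:

```python
def weighted_nll(logprobs: list[float], weights: list[float]) -> float:
    """Token-weighted negative log-likelihood: per-token logprobs are
    masked by weights, then averaged over the tokens that count."""
    total = sum(-lp * w for lp, w in zip(logprobs, weights))
    n = sum(weights)
    return total / n if n > 0 else 0.0

# Two completion tokens count (weight 1.0); the prompt token does not.
loss = weighted_nll([-0.5, -1.0, -2.0], [0.0, 1.0, 1.0])
print(round(loss, 2))  # 1.5
```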

### optim_step(params)

Apply an optimizer update after accumulating gradients:

```python
import tinker

training_client.optim_step(
    tinker.AdamParams(
        learning_rate=1e-5,
        beta1=0.9,
        beta2=0.999,
        eps=1e-8,
        weight_decay=0.01,
    )
).result()
```

`optim_step` also supports a `grad_accumulation_normalization` parameter that controls how accumulated gradients are normalized; see the Cookbook Reference for the available modes.

### save_weights_for_sampler_ext(name, checkpoint_type, ttl_seconds)

Exports a serving-compatible checkpoint with session-scoped naming. Returns a `SaveSamplerResult`:

```python
result = training_client.save_weights_for_sampler_ext(
    "step-100",
    checkpoint_type="base",  # "base" for full weights, "delta" for incremental
)
print(result.snapshot_name)  # Session-qualified name for hotloading
```

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| `name` | `str` | required | Checkpoint name (auto-suffixed with the session ID) |
| `checkpoint_type` | `str \| None` | `None` | `"base"` for full weights, `"delta"` for incremental |
| `ttl_seconds` | `int \| None` | `None` | Auto-delete the checkpoint after this many seconds |

### save_state(name, ttl_seconds=None)

Save the full training state (weights + optimizer) for resume:

```python
training_client.save_state("train_state_step_100").result()
```

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| `name` | `str` | required | Checkpoint name |
| `ttl_seconds` | `int \| None` | `None` | Auto-delete the checkpoint after this many seconds |

### load_state_with_optimizer(name)

Restore training state from a checkpoint:

```python
training_client.load_state_with_optimizer("train_state_step_100").result()
```

### list_checkpoints()

List the DCP checkpoints available on the trainer. Returns a `list[str]`:

```python
checkpoint_names = training_client.list_checkpoints()
print(checkpoint_names)  # e.g. ["step-2", "step-4"]
```

### resolve_checkpoint_path(name, source_job_id)

Resolve a checkpoint path for cross-job resume:

```python
checkpoint_ref = training_client.resolve_checkpoint_path(
    "step-4",
    source_job_id="previous-job-id",
)
training_client.load_state_with_optimizer(checkpoint_ref).result()
```

## SaveSamplerResult

Returned by `save_weights_for_sampler_ext`:

| Field | Type | Description |
| --- | --- | --- |
| `path` | `str` | Snapshot name from the trainer |
| `snapshot_name` | `str` | Session-qualified name for weight-sync operations |

## GradAccNormalization

Enum for `optim_step`'s `grad_accumulation_normalization` parameter:

| Value | Description |
| --- | --- |
| `"num_loss_tokens"` | Normalize by the total number of loss tokens across accumulated micro-batches |
| `"num_sequences"` | Normalize by the total number of sequences across accumulated micro-batches |
| `"none"` | No normalization (raw gradient sum) |
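To make the three modes concrete, here is the divisor each one applies to the summed gradients. This is illustrative arithmetic only; the actual normalization happens inside the trainer:

```python
def grad_divisor(mode: str, batches: list[tuple[int, int]]) -> float:
    """Divisor applied to gradients summed over accumulated
    micro-batches. Each batch is (num_sequences, num_loss_tokens)."""
    if mode == "num_loss_tokens":
        return float(sum(t for _, t in batches))  # per-token average
    if mode == "num_sequences":
        return float(sum(s for s, _ in batches))  # per-sequence average
    if mode == "none":
        return 1.0                                # raw gradient sum
    raise ValueError(f"unknown mode: {mode}")

# Two micro-batches: 4 sequences / 128 loss tokens, then 4 / 96.
batches = [(4, 128), (4, 96)]
print(grad_divisor("num_loss_tokens", batches))  # 224.0
print(grad_divisor("num_sequences", batches))    # 8.0
print(grad_divisor("none", batches))             # 1.0
```

`"num_loss_tokens"` keeps the per-token loss scale stable when micro-batches have uneven lengths, which is usually what you want for variable-length sequence data.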