
How service-mode training works

With Fireworks service-mode training, you write a Python loop on your local machine that controls the training process, while the actual model computation runs on remote GPUs managed by Fireworks. The two sides communicate over an HTTP API. This split means:
  • You don’t need GPUs locally — a laptop is enough to drive training.
  • You have full control over the objective — your loss function runs in your Python process.
  • The platform handles distributed training, checkpointing, and serving on the GPU side.

Datums: how you send data to the trainer

A Datum is the unit of training data sent to the remote GPU. It wraps tokenized input and metadata that your loss function needs. Think of it as one sequence in a batch.
import torch
from tinker_cookbook.supervised.common import datum_from_tokens_weights

tokens = tokenizer.encode("What is 2+2? The answer is 4.")
# Assumes the prompt encodes to a prefix of the full token sequence
prompt_len = len(tokenizer.encode("What is 2+2? "))

weights = torch.zeros(len(tokens), dtype=torch.float32)
weights[prompt_len:] = 1.0  # Train on response tokens only

datum = datum_from_tokens_weights(
    torch.tensor(tokens, dtype=torch.long),
    weights,
    max_length=4096,
)

Token weights

Token weights tell the loss function which tokens matter for training:
  • 0.0 = prompt token (don’t train on this)
  • 1.0 = response token (train on this)
This is how you teach the model to generate good responses without memorizing prompts. You access them in your loss function via data[i].loss_fn_inputs["weights"].data.
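To make the masking concrete, here is a minimal standalone sketch (toy numbers, not real model output) of how zero weights exclude prompt tokens from a weighted loss:

```python
import torch

# Toy per-token logprobs for a 5-token sequence (not real model output)
logprobs = torch.tensor([-1.0, -2.0, -0.5, -0.25, -0.125])

# First two tokens are prompt (weight 0.0), rest are response (weight 1.0)
weights = torch.tensor([0.0, 0.0, 1.0, 1.0, 1.0])

# Weighted negative log-likelihood: prompt tokens contribute nothing
loss = -(logprobs * weights).sum() / weights.sum()
print(loss.item())  # 0.875 / 3, about 0.2917
```

The two prompt logprobs (-1.0 and -2.0) are multiplied by zero, so only the three response tokens drive the loss.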

Logprobs: what the GPU sends back

When you call forward_backward_custom, the GPU runs a forward pass and returns per-token log-probabilities (logprobs) — the model’s estimate of how likely each token is given the tokens before it. These arrive as PyTorch tensors with requires_grad=True. Your loss function receives these logprobs and computes a scalar loss from them. The SDK then calls loss.backward() to compute gradients, which are sent back to the GPU for the model backward pass.
# logprobs_list[i] is a 1-D tensor of shape (seq_len,)
# Each value is log P(token_j | tokens_before_j)
# requires_grad=True so your loss can backpropagate through them

forward_backward_custom: the core training primitive

This is the function that ties everything together. Here’s what happens step by step:
  1. You call training_client.forward_backward_custom(datums, loss_fn)
  2. The datums are sent to the remote GPU
  3. The GPU runs a forward pass and returns per-token logprobs
  4. Your loss function runs locally with those logprobs (as autograd tensors)
  5. The SDK calls loss.backward() to compute d_loss/d_logprobs
  6. Those gradients are sent back to the GPU for the model backward pass
def my_loss_fn(data, logprobs_list):
    """
    Args:
        data: list of Datum objects (same ones you passed in)
        logprobs_list: list of torch.Tensor with requires_grad=True

    Returns:
        loss: scalar tensor (must be differentiable w.r.t. logprobs)
        metrics: dict of floats for logging
    """
    loss = compute_something(logprobs_list)
    return loss, {"loss": loss.item()}

result = training_client.forward_backward_custom(datums, my_loss_fn).result()
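The `compute_something` placeholder above can be filled in with, for example, a weighted negative log-likelihood. A minimal sketch, using `SimpleNamespace` as a stand-in for the SDK's Datum objects (the real ones expose weights via `loss_fn_inputs["weights"].data`, as described earlier):

```python
import torch
from types import SimpleNamespace

def weighted_nll_loss(data, logprobs_list):
    """Mean negative log-likelihood over response tokens only.

    Matches the (data, logprobs_list) -> (loss, metrics) contract:
    a differentiable scalar loss plus a dict of float metrics.
    """
    total = torch.tensor(0.0)
    count = torch.tensor(0.0)
    for datum, logprobs in zip(data, logprobs_list):
        weights = datum.loss_fn_inputs["weights"].data  # 0.0 prompt / 1.0 response
        total = total - (logprobs * weights).sum()
        count = count + weights.sum()
    loss = total / count
    return loss, {"loss": loss.item()}

# Stand-in datum (the real Datum comes from the SDK); weights mask the prompt
fake_datum = SimpleNamespace(
    loss_fn_inputs={"weights": SimpleNamespace(data=torch.tensor([0.0, 1.0, 1.0]))}
)
logprobs = torch.tensor([-2.0, -0.5, -0.25], requires_grad=True)

loss, metrics = weighted_nll_loss([fake_datum], [logprobs])
loss.backward()  # fills logprobs.grad, as the SDK does after your loss returns
```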

Why .result()?

All training client API calls return futures (async handles). Call .result() to block until the operation completes and get the return value. Without .result(), errors are silently swallowed.
# This queues the work but doesn't wait:
future = training_client.forward_backward_custom(datums, loss_fn)

# This blocks until it completes:
result = future.result()

optim_step: applying weight updates

After forward_backward_custom accumulates gradients, call optim_step to apply the optimizer update:
import tinker

training_client.optim_step(
    tinker.AdamParams(
        learning_rate=1e-5,
        beta1=0.9,
        beta2=0.999,
        eps=1e-8,
        weight_decay=0.01,
    )
).result()
For gradient accumulation, call forward_backward_custom multiple times before calling optim_step:
for micro_batch in micro_batches:
    training_client.forward_backward_custom(micro_batch, loss_fn).result()

# Gradients accumulated — now update weights
training_client.optim_step(tinker.AdamParams(learning_rate=1e-5, ...)).result()
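To split a large batch into micro-batches, a plain chunking helper is enough; the helper name and the stand-in list below are illustrative, not part of the SDK:

```python
def chunk(items, size):
    """Split a list into consecutive micro-batches of at most `size` items."""
    return [items[i:i + size] for i in range(0, len(items), size)]

# Stand-in for a list of 10 datums, split into micro-batches of 4
micro_batches = chunk(list(range(10)), 4)
print(micro_batches)  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
```

The effective batch size of the optimizer update is the total number of datums across all micro-batches, since gradients accumulate until optim_step.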

Checkpointing: base and delta

After training, you export checkpoints for serving:
  • Base checkpoint: Full model weights (~16 GB for an 8B model). Use for the first checkpoint.
  • Delta checkpoint: Only the diff from the previous base (~10x smaller). Use for subsequent checkpoints.
# First checkpoint — must be base
result = training_client.save_weights_for_sampler_ext("step-0001", checkpoint_type="base")

# Later checkpoints — delta is faster
result = training_client.save_weights_for_sampler_ext("step-0010", checkpoint_type="delta")

# result.snapshot_name is the session-qualified name for hotloading
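One simple policy that follows the guidance above: save a base checkpoint first (and optionally at a fixed interval), and deltas otherwise. This is a sketch of a cadence choice, not an SDK requirement, and the function name is illustrative:

```python
def checkpoint_type_for(step, base_every=100):
    """Base for the first checkpoint and every `base_every` steps; delta otherwise."""
    return "base" if step % base_every == 0 else "delta"

# First checkpoint is base; intermediate ones are ~10x smaller deltas
print(checkpoint_type_for(0))    # base
print(checkpoint_type_for(10))   # delta
print(checkpoint_type_for(100))  # base
```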

Hotloading: updating a live deployment

Hotloading pushes a checkpoint onto a running inference deployment without restarting it. This lets you evaluate the model under serving conditions during training.

Putting it all together

A complete training loop follows this pattern:
1. Provision a trainer job (GPU allocation)
2. Connect a training client
3. For each step:
   a. Build datums from your data
   b. Call forward_backward_custom with your loss function
   c. Call optim_step to update weights
   d. Periodically: save checkpoint → hotload → evaluate
4. Clean up resources
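Steps 3a–3d can be sketched as a driver function. `run_training`, `make_optim_params`, and `checkpoint_every` are illustrative names (not SDK API); the client calls are the ones introduced on this page, and building datums and the hotload/evaluate step are left as comments:

```python
def run_training(training_client, batches, loss_fn, make_optim_params,
                 checkpoint_every=10):
    """Drive training from the local process against a connected client.

    batches: iterable of steps, each a list of micro-batches (lists of datums).
    make_optim_params: zero-arg factory for optimizer params (e.g. tinker.AdamParams).
    """
    for step, micro_batches in enumerate(batches):
        # (b) forward/backward on each micro-batch accumulates gradients remotely
        for datums in micro_batches:
            training_client.forward_backward_custom(datums, loss_fn).result()
        # (c) apply the accumulated optimizer update
        training_client.optim_step(make_optim_params()).result()
        # (d) periodic checkpoint: base first, cheap deltas afterwards
        if step % checkpoint_every == 0:
            ctype = "base" if step == 0 else "delta"
            training_client.save_weights_for_sampler_ext(
                f"step-{step:04d}", checkpoint_type=ctype
            ).result()
            # ...then hotload result.snapshot_name and evaluate
```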
See the Quickstart for a minimal working example, or the GRPO and DPO guides for full algorithm implementations.