How service-mode training works
With Fireworks service-mode training, you write a Python loop on your local machine that controls the training process, while the actual model computation runs on remote GPUs managed by Fireworks. The two sides communicate over an HTTP API. This split means:
- You don’t need GPUs locally — a laptop is enough to drive training.
- You have full control over the objective — your loss function runs in your Python process.
- The platform handles distributed training, checkpointing, and serving on the GPU side.
Datums: how you send data to the trainer
A Datum is the unit of training data sent to the remote GPU. It wraps tokenized input and the metadata your loss function needs. Think of it as one sequence in a batch.
Token weights
Token weights tell the loss function which tokens matter for training:
- 0.0 = prompt token (don’t train on this)
- 1.0 = response token (train on this)
The weights for each sequence are available at data[i].loss_fn_inputs["weights"].data.
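As a concrete sketch, here is how such a weight mask can be built for a prompt/response pair. The helper name and token IDs below are illustrative, not part of the SDK:

```python
def build_token_weights(prompt_tokens, response_tokens):
    # 0.0 masks prompt tokens out of the loss; 1.0 trains on response tokens.
    return [0.0] * len(prompt_tokens) + [1.0] * len(response_tokens)

# Example: a 3-token prompt followed by a 2-token response.
weights = build_token_weights([101, 2023, 2003], [1037, 102])
# weights == [0.0, 0.0, 0.0, 1.0, 1.0]
```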
Logprobs: what the GPU sends back
When you call forward_backward_custom, the GPU runs a forward pass and returns per-token log-probabilities (logprobs) — the model’s estimate of how likely each token is. These arrive as PyTorch tensors with requires_grad=True.
Your loss function receives these logprobs and computes a scalar loss from them. The SDK then calls loss.backward() to compute gradients, which are sent back to the GPU for the model backward pass.
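A minimal loss function over those logprobs might look like the following sketch. For readability it uses plain Python lists; with the real SDK the logprobs and weights are torch tensors, and the same arithmetic applies elementwise:

```python
def weighted_nll(logprobs, weights):
    # Negative log-likelihood over tokens marked with weight 1.0;
    # prompt tokens (weight 0.0) contribute nothing to the loss.
    total = sum(-lp * w for lp, w in zip(logprobs, weights))
    n_trained = sum(weights)
    return total / n_trained if n_trained else 0.0

# Two prompt tokens (masked out) and two response tokens.
loss = weighted_nll([-0.1, -0.5, -1.0, -3.0], [0.0, 0.0, 1.0, 1.0])
# loss == 2.0  (average of 1.0 and 3.0)
```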
forward_backward_custom: the core training primitive
This is the function that ties everything together. Here’s what happens step by step:
- You call training_client.forward_backward_custom(datums, loss_fn)
- The datums are sent to the remote GPU
- The GPU runs a forward pass and returns per-token logprobs
- Your loss function runs locally with those logprobs (as autograd tensors)
- The SDK calls loss.backward() to compute d_loss/d_logprobs
- Those gradients are sent back to the GPU for the model backward pass
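The steps above can be wrapped in a small helper. This is a sketch: `forward_backward_custom` and `.result()` come from the text, while the helper name and the shape of the return value are assumptions:

```python
def forward_backward(training_client, datums, loss_fn):
    # Kicks off the remote forward pass; loss_fn runs locally on the
    # returned logprobs, then gradients flow back for the model backward.
    future = training_client.forward_backward_custom(datums, loss_fn)
    return future.result()  # block until the round trip completes
```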
Why .result()?
All training client API calls return futures (async handles). Call .result() to block until the operation completes and get the return value. Without .result(), errors are silently swallowed.
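The semantics match Python’s standard futures, which a stdlib-only illustration makes concrete (this uses concurrent.futures, not the Fireworks SDK):

```python
from concurrent.futures import ThreadPoolExecutor

# An exception raised inside a task is stored in the future and only
# surfaces when .result() is called.
def failing_task():
    raise RuntimeError("remote training error")

with ThreadPoolExecutor(max_workers=1) as pool:
    future = pool.submit(failing_task)
    try:
        future.result()  # blocks until done; stored errors are raised here
        caught = None
    except RuntimeError as err:
        caught = str(err)
```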
optim_step: applying weight updates
After forward_backward_custom accumulates gradients, call optim_step to apply the optimizer update. To use gradient accumulation, call forward_backward_custom multiple times before calling optim_step.
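A gradient-accumulation loop under those rules might be sketched like this; only `forward_backward_custom`, `optim_step`, and `.result()` are taken from the text, and the rest of the names are illustrative:

```python
def accumulate_and_step(training_client, micro_batches, loss_fn):
    # Gradients from each forward_backward_custom call accumulate on the GPU
    # until optim_step applies them as a single optimizer update.
    for datums in micro_batches:
        training_client.forward_backward_custom(datums, loss_fn).result()
    training_client.optim_step().result()
```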
Checkpointing: base and delta
After training, you export checkpoints for serving:
- Base checkpoint: Full model weights (~16 GB for an 8B model). Use for the first checkpoint.
- Delta checkpoint: Only the diff from the previous base (~10x smaller). Use for subsequent checkpoints.
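To make the size trade-off concrete, here is a rough storage estimate using the figures above (16 GB base, deltas ~10x smaller). These are back-of-the-envelope numbers, not guarantees:

```python
BASE_GB = 16.0           # full weights for an ~8B model (from the text)
DELTA_GB = BASE_GB / 10  # deltas are roughly 10x smaller

def series_storage_gb(num_checkpoints):
    # One full base checkpoint, then deltas for every later checkpoint.
    return BASE_GB + max(num_checkpoints - 1, 0) * DELTA_GB

# 5 checkpoints: 16 + 4 * 1.6 = 22.4 GB instead of 5 * 16 = 80 GB.
```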