What this is
The Training SDK supports two ways to compute loss:
- Built-in losses via `forward_backward` with a string identifier (e.g. `"cross_entropy"`): fastest, no extra forward pass needed.
- Custom losses via `forward_backward_custom` with an arbitrary Python function: flexible, supports any differentiable objective at the cost of an additional forward pass.
Built-in loss: cross_entropy
For supervised fine-tuning, use the built-in `cross_entropy` loss via `forward_backward`:

```python
result = training_client.forward_backward(datums, "cross_entropy").result()
```
This computes standard next-token prediction loss on the server side — no extra forward pass or local loss computation needed.
Built-in `cross_entropy` requires datums with `target_tokens` in `loss_fn_inputs`. Datums built with `datum_from_tokens_weights` (weight-based) will fail with `"missing required field 'target_tokens'"`. Use the target-token datum format shown in Building datums, or use `forward_backward_custom` with weight-based datums instead.
For a forward-only pass (e.g. to compute reference logprobs without updating weights):

```python
result = training_client.forward(datums, "cross_entropy").result()
ref_logprobs = [result.loss_fn_outputs[i]["logprobs"].data for i in range(len(datums))]
```
Custom losses: forward_backward_custom
forward_backward_custom lets you implement any objective function in Python. You provide the loss computation; the SDK handles the forward pass on remote GPUs, passes logprobs back to your function, then sends the computed gradients back for the backward pass.
How it works
1. You call `training_client.forward_backward_custom(datums, loss_fn)`.
2. The trainer runs a forward pass on the GPU and returns per-token logprobs.
3. The logprobs are converted to PyTorch tensors with `requires_grad=True`.
4. Your `loss_fn` is called with the datums and logprobs.
5. The SDK calls `loss.backward()` to compute d_loss/d_logprob gradients.
6. Gradients are sent back to the trainer GPU for the model backward pass.
Your loss function runs locally (on your machine), while the forward and backward passes run on remote GPUs.
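The client-side portion of this loop can be sketched in plain PyTorch. This is a simplified illustration of the mechanism, not the SDK's actual internals, and in reality the forward pass runs server-side; `run_custom_loss_locally` and `nll` are hypothetical names:

```python
import torch

def run_custom_loss_locally(loss_fn, data, raw_logprobs):
    """Mimic the SDK's client-side handling of a custom loss (simplified sketch)."""
    # 1. Wrap server-returned logprobs as tensors that track gradients
    logprobs_list = [
        torch.tensor(lp, dtype=torch.float32, requires_grad=True)
        for lp in raw_logprobs
    ]
    # 2. Call the user's loss function
    loss, metrics = loss_fn(data, logprobs_list)
    # 3. Backprop to get d_loss/d_logprob for each token
    loss.backward()
    grads = [lp.grad for lp in logprobs_list]
    # 4. These per-token gradients would be shipped back to the trainer GPUs
    return grads, metrics

# Example: negative weighted log-likelihood on one fake sequence
def nll(data, logprobs_list):
    loss = -torch.dot(logprobs_list[0], torch.tensor([0.0, 1.0, 1.0]))
    return loss, {"nll": loss.item()}

grads, metrics = run_custom_loss_locally(nll, data=None, raw_logprobs=[[-1.0, -2.0, -3.0]])
# grads[0] equals minus the weights: [0, -1, -1]
```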
`forward_backward_custom` does an extra forward pass compared to `forward_backward`, requiring ~1.5x the FLOPs and up to ~3x the wall time per step.
Loss function signature
```python
def loss_fn(
    data: list[tinker.Datum],
    logprobs_list: list[torch.Tensor],
) -> tuple[torch.Tensor, dict[str, float]]:
    """
    Args:
        data: The same datums you passed to forward_backward_custom.
            Access token weights via data[i].loss_fn_inputs["weights"].data
        logprobs_list: Per-token log-probabilities from the forward pass.
            Each tensor has requires_grad=True. Shape: (seq_len,) per sequence.

    Returns:
        loss: A scalar tensor. Must be differentiable w.r.t. logprobs_list entries.
        metrics: A dict of float values for logging (not used for training).
    """
```
Key rules
- `logprobs_list[i]` has `requires_grad=True`; your loss must be differentiable through it.
- Use `torch.dot()` to compute weighted sums; this correctly propagates gradients through the logprobs.
- Return a scalar tensor as the loss and a `dict[str, float]` as metrics.
- Access token weights via `data[i].loss_fn_inputs["weights"].data`; these are 0 for prompt tokens and 1 for response tokens.
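A standalone check of the `torch.dot` rule, independent of the SDK: both `torch.dot` and an elementwise product-and-sum propagate gradients through the logprobs; what matters is staying in differentiable tensor ops end to end.

```python
import torch

weights = torch.tensor([0.0, 0.0, 1.0, 1.0])   # 0 = prompt, 1 = response

lp1 = torch.tensor([-0.5, -1.0, -1.5, -2.0], requires_grad=True)
loss1 = -torch.dot(lp1, weights)
loss1.backward()

lp2 = torch.tensor([-0.5, -1.0, -1.5, -2.0], requires_grad=True)
loss2 = -(lp2 * weights).sum()                 # equivalent, also differentiable
loss2.backward()

# Both paths give the same per-token gradient: -weights
assert torch.equal(lp1.grad, lp2.grad)
# Prompt tokens (weight 0) receive zero gradient
assert lp1.grad[0] == 0.0 and lp1.grad[1] == 0.0
```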
Building datums
Using tinker_cookbook (weight-based)
`datum_from_tokens_weights` constructs datums with explicit token weights:

```python
import torch
from tinker_cookbook.supervised.common import datum_from_tokens_weights

tokens = torch.tensor([101, 2054, 2003, ...], dtype=torch.long)
weights = torch.zeros(len(tokens), dtype=torch.float32)
weights[prompt_len:] = 1.0  # Only train on response tokens

datum = datum_from_tokens_weights(tokens, weights, max_length=8192)
```
Using tinker.Datum directly (target-token-based)
For RL-style objectives where you need per-completion control (e.g. routing matrices, custom `loss_fn_inputs`), construct datums directly:

```python
import tinker

model_input_len = len(tokens) - 1
datum = tinker.Datum(
    model_input=tinker.ModelInput.from_ints(tokens[:-1]),
    loss_fn_inputs={
        "target_tokens": tinker.TensorData(
            data=tokens[1:], dtype="int64", shape=[model_input_len],
        ),
    },
)
```
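The input/target shift above is the standard next-token alignment: position `i` of `model_input` predicts `target_tokens[i]`. A quick standalone check of the offset (plain Python, no SDK needed; the token ids are arbitrary examples):

```python
tokens = [101, 2054, 2003, 2023, 102]  # example token ids

model_input = tokens[:-1]      # what the model sees
target_tokens = tokens[1:]     # what it should predict

# Both sequences are one shorter than the original
assert len(model_input) == len(target_tokens) == len(tokens) - 1
# Each input position predicts the token that follows it
for i in range(len(model_input)):
    assert target_tokens[i] == tokens[i + 1]
```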
Example: simple cross-entropy
```python
def cross_entropy_loss(data, logprobs_list):
    total_loss = torch.tensor(0.0)
    for i, logprobs in enumerate(logprobs_list):
        weights = torch.tensor(data[i].loss_fn_inputs["weights"].data, dtype=torch.float32)
        min_len = min(len(logprobs), len(weights))
        weighted_sum = torch.dot(logprobs[:min_len].float(), weights[:min_len])
        total_loss = total_loss - weighted_sum  # Negative log-likelihood
    loss = total_loss / len(logprobs_list)
    return loss, {"cross_entropy": loss.item()}

result = training_client.forward_backward_custom(datums, cross_entropy_loss).result()
```
Example: GRPO with KL penalty
```python
def make_grpo_loss(rewards, ref_logprobs, kl_beta=0.001):
    advantages = compute_advantages(rewards)  # advantage estimator, defined elsewhere
    ref_tensors = [torch.tensor(lp, dtype=torch.float32) for lp in ref_logprobs]

    def loss_fn(data, logprobs_list):
        total_loss = torch.tensor(0.0)
        for i in range(len(logprobs_list)):
            weights = torch.tensor(data[i].loss_fn_inputs["weights"].data, dtype=torch.float32)
            pi = logprobs_list[i][:len(weights)]
            ref = ref_tensors[i][:len(weights)]
            pg_loss = -advantages[i] * torch.dot(pi.float(), weights)  # policy-gradient term
            kl_term = torch.dot((pi - ref).float(), weights)           # KL-to-reference estimate
            total_loss = total_loss + pg_loss + kl_beta * kl_term
        loss = total_loss / len(logprobs_list)
        return loss, {"loss": loss.item()}

    return loss_fn
```
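`compute_advantages` is user-supplied. One common choice for GRPO is group-normalized rewards; a minimal sketch, where the normalization scheme and epsilon are assumptions rather than something the SDK prescribes:

```python
import torch

def compute_advantages(rewards, eps=1e-6):
    """Group-normalized advantages: (r - mean) / (std + eps)."""
    r = torch.tensor(rewards, dtype=torch.float32)
    return (r - r.mean()) / (r.std() + eps)

# Rewards within one sampled group; advantages are zero-mean by construction
adv = compute_advantages([1.0, 0.0, 1.0, 0.0])
```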
Example: DPO margin loss
```python
import torch.nn.functional as F

def make_dpo_loss(ref_chosen, ref_rejected, beta=0.1):
    ref_c = torch.tensor(ref_chosen, dtype=torch.float32)
    ref_r = torch.tensor(ref_rejected, dtype=torch.float32)

    def loss_fn(data, logprobs_list):
        # Expects exactly two datums: [chosen, rejected]
        pi_c, pi_r = logprobs_list[0], logprobs_list[1]
        w_c = torch.tensor(data[0].loss_fn_inputs["weights"].data, dtype=torch.float32)
        w_r = torch.tensor(data[1].loss_fn_inputs["weights"].data, dtype=torch.float32)
        margin = (torch.dot(pi_c.float(), w_c) - torch.dot(ref_c, w_c)) \
               - (torch.dot(pi_r.float(), w_r) - torch.dot(ref_r, w_r))
        return -F.logsigmoid(beta * margin), {"margin": margin.item()}

    return loss_fn
```
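Because the loss function runs locally, you can sanity-check it without any server round-trip. A quick test using stand-in datum objects (`SimpleNamespace` mimics the `.loss_fn_inputs["weights"].data` access path; the assumption is that the loss touches nothing else on the datum): when policy and reference logprobs are identical, the margin is 0 and the loss is exactly log 2.

```python
import math
import types

import torch
import torch.nn.functional as F

def make_dpo_loss(ref_chosen, ref_rejected, beta=0.1):
    # Same factory as in the DPO example above
    ref_c = torch.tensor(ref_chosen, dtype=torch.float32)
    ref_r = torch.tensor(ref_rejected, dtype=torch.float32)

    def loss_fn(data, logprobs_list):
        pi_c, pi_r = logprobs_list[0], logprobs_list[1]
        w_c = torch.tensor(data[0].loss_fn_inputs["weights"].data, dtype=torch.float32)
        w_r = torch.tensor(data[1].loss_fn_inputs["weights"].data, dtype=torch.float32)
        margin = (torch.dot(pi_c.float(), w_c) - torch.dot(ref_c, w_c)) \
               - (torch.dot(pi_r.float(), w_r) - torch.dot(ref_r, w_r))
        return -F.logsigmoid(beta * margin), {"margin": margin.item()}

    return loss_fn

def fake_datum(weights):
    # Stand-in for tinker.Datum exposing only loss_fn_inputs["weights"].data
    return types.SimpleNamespace(
        loss_fn_inputs={"weights": types.SimpleNamespace(data=weights)}
    )

ref = [-0.5, -1.0, -2.0]
loss_fn = make_dpo_loss(ref_chosen=ref, ref_rejected=ref)
data = [fake_datum([0.0, 1.0, 1.0]), fake_datum([0.0, 1.0, 1.0])]
logprobs = [torch.tensor(ref, requires_grad=True), torch.tensor(ref, requires_grad=True)]

loss, metrics = loss_fn(data, logprobs)
assert abs(metrics["margin"]) < 1e-6
assert abs(loss.item() - math.log(2)) < 1e-6  # -logsigmoid(0) == log 2
```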
Applying the optimizer step
After `forward_backward_custom`, call `optim_step` to update weights:

```python
training_client.forward_backward_custom(datums, loss_fn).result()
training_client.optim_step(
    tinker.AdamParams(
        learning_rate=1e-5,
        beta1=0.9,
        beta2=0.999,
        eps=1e-8,
        weight_decay=0.01,
    )
).result()
```
For gradient accumulation, call `forward_backward_custom` multiple times before calling `optim_step`:

```python
for micro_batch in micro_batches:
    training_client.forward_backward_custom(micro_batch, loss_fn).result()

# One optimizer step after accumulating gradients
training_client.optim_step(tinker.AdamParams(learning_rate=1e-5, ...)).result()
```
Gradient accumulation normalization
When you accumulate multiple micro-batches before `optim_step`, normalization can happen in two places:
- Inside your loss function
- Server-side, via `optim_step(..., grad_accumulation_normalization=...)`
Use only one normalization path. If your loss already returns a mean, leave server-side normalization unset. If your loss returns a raw sum, choose the matching server-side normalization mode:
```python
training_client.forward_backward_custom(datums, loss_fn).result()
training_client.optim_step(
    tinker.AdamParams(learning_rate=1e-5, beta1=0.9, beta2=0.999, eps=1e-8, weight_decay=0.01),
    grad_accumulation_normalization="num_loss_tokens",
).result()
```
| Mode | Divides by | Best for |
|---|---|---|
| `"num_loss_tokens"` | Total non-zero-grad tokens across accumulated micro-batches | Raw-sum token-level losses, such as RL / GRPO-style objectives |
| `"num_sequences"` | Total sequences with at least one non-zero-grad token | Raw-sum sequence-level objectives |
| `None` | Nothing | Losses that already return per-token or per-sequence means, such as SFT, DPO, and ORPO |
Choosing the right mode
- If your loss function returns a raw sum over tokens, use `"num_loss_tokens"`.
- If your loss function returns a raw sum over sequences, use `"num_sequences"`.
- If your loss function already returns a mean, leave `grad_accumulation_normalization` unset.
Do not normalize in both places. If your loss function already divides by tokens or sequences, adding server-side normalization will double-normalize the gradients.
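A quick numeric illustration of the double-normalization failure, in plain PyTorch with the server-side division simulated locally:

```python
import torch

logprobs = torch.tensor([-1.0, -2.0, -3.0, -4.0], requires_grad=True)
weights = torch.ones(4)
n_tokens = int(weights.sum())

# Loss already returns a per-token mean (as in SFT):
loss = -torch.dot(logprobs, weights) / n_tokens
loss.backward()
grad_correct = logprobs.grad.clone()    # -1/4 per token

# Simulating server-side "num_loss_tokens" normalization on top divides again:
grad_double = grad_correct / n_tokens   # -1/16 per token: 4x too small here
```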
Recipe defaults
| Recipe | Default | Rationale |
|---|---|---|
| SFT | `None` | The SFT loss is already normalized client-side. |
| GRPO / RL | `"num_loss_tokens"` | RL losses use server-side per-token normalization by default. |
| DPO | `None` | The DPO loss is already normalized client-side. |
| ORPO | `None` | The ORPO loss is already normalized client-side. |
Common pitfalls
- Token-weight misalignment can silently break objective semantics. Always verify that `len(logprobs)` and `len(weights)` are compatible (truncate to `min_len`).
- Ignoring per-step diagnostics makes instability hard to attribute. Log metrics from every train step.
- Forgetting `.result()`: all Tinker API calls return futures. Without `.result()`, errors are silently swallowed.
- Non-differentiable loss: if your loss doesn't depend on `logprobs_list` entries through differentiable ops, gradients will be zero.
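The last pitfall is easy to reproduce locally: converting logprobs to Python floats (via `.item()`, `float()`, or NumPy) detaches them from the autograd graph, so no gradient can flow back:

```python
import torch

logprobs = torch.tensor([-1.0, -2.0], requires_grad=True)
weights = torch.tensor([1.0, 1.0])

# Differentiable: gradients flow back through torch.dot
loss = -torch.dot(logprobs, weights)
loss.backward()
assert logprobs.grad is not None            # per-token grads: -weights

# Non-differentiable: .item() yields plain floats with no graph attached
bad_loss = -sum(lp.item() * w.item() for lp, w in zip(logprobs, weights))
assert isinstance(bad_loss, float)          # no .backward(); gradients are lost
```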