What this is

The Training SDK supports two ways to compute loss:
  1. Built-in losses via forward_backward with a string identifier (e.g. "cross_entropy") — fastest, no extra forward pass needed.
  2. Custom losses via forward_backward_custom with an arbitrary Python function — flexible, supports any differentiable objective at the cost of an additional forward pass.

Built-in loss: cross_entropy

For supervised fine-tuning, use the built-in cross_entropy loss via forward_backward:
result = training_client.forward_backward(datums, "cross_entropy").result()
This computes standard next-token prediction loss on the server side — no extra forward pass or local loss computation needed.
Built-in cross_entropy requires datums with target_tokens in loss_fn_inputs. Datums built with datum_from_tokens_weights (weight-based) will fail with "missing required field 'target_tokens'". Use the target-token datum format shown in Building datums, or use forward_backward_custom with weight-based datums instead.
For a forward-only pass (e.g. to compute reference logprobs without updating weights):
result = training_client.forward(datums, "cross_entropy").result()
ref_logprobs = [result.loss_fn_outputs[i]["logprobs"].data for i in range(len(datums))]

Custom losses: forward_backward_custom

forward_backward_custom lets you implement any objective function in Python. You provide the loss computation; the SDK handles the forward pass on remote GPUs, passes logprobs back to your function, then sends the computed gradients back for the backward pass.

How it works

  1. You call training_client.forward_backward_custom(datums, loss_fn).
  2. The trainer runs a forward pass on the GPU and returns per-token logprobs.
  3. The logprobs are converted to PyTorch tensors with requires_grad=True.
  4. Your loss_fn is called with the datums and logprobs.
  5. The SDK calls loss.backward() to compute d_loss/d_logprob gradients.
  6. Gradients are sent back to the trainer GPU for the model backward pass.
Your loss function runs locally (on your machine), while the forward and backward passes run on remote GPUs.
forward_backward_custom does an extra forward pass compared to forward_backward, requiring ~1.5x FLOPs and up to ~3x wall time per step.
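Steps 3-5 of this loop can be sketched with plain PyTorch. The remote forward pass is replaced here by a hypothetical logprob tensor; the gradient that would be shipped back to the trainer is just logprobs.grad:

```python
import torch

# Step 3 (sketch): logprobs from the remote forward pass, made differentiable
logprobs = torch.tensor([-0.5, -1.2, -2.0], requires_grad=True)
weights = torch.tensor([0.0, 1.0, 1.0])  # hypothetical prompt/response mask

# Step 4 (sketch): a user loss_fn, here a weighted NLL over response tokens
loss = -torch.dot(logprobs, weights)

# Step 5: d_loss/d_logprob gradients, computed locally
loss.backward()
grad = logprobs.grad  # equals -weights; this is what gets sent back to the GPU
```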

Loss function signature

def loss_fn(
    data: list[tinker.Datum],
    logprobs_list: list[torch.Tensor],
) -> tuple[torch.Tensor, dict[str, float]]:
    """
    Args:
        data: The same datums you passed to forward_backward_custom.
              Access token weights via data[i].loss_fn_inputs["weights"].data
        logprobs_list: Per-token log-probabilities from the forward pass.
              Each tensor has requires_grad=True. Shape: (seq_len,) per sequence.

    Returns:
        loss: A scalar tensor. Must be differentiable w.r.t. logprobs_list entries.
        metrics: A dict of float values for logging (not used for training).
    """

Key rules

  • logprobs_list[i] has requires_grad=True — your loss must be differentiable through it.
  • Build the loss from differentiable ops (e.g. torch.dot() for weighted sums) so gradients propagate through the logprobs.
  • Return a scalar tensor as the loss, and a dict[str, float] as metrics.
  • Access token weights via data[i].loss_fn_inputs["weights"].data — these are 0 for prompt tokens and 1 for response tokens.

Building datums

Using tinker_cookbook (weight-based)

datum_from_tokens_weights constructs datums with explicit token weights:
import torch
from tinker_cookbook.supervised.common import datum_from_tokens_weights

tokens = torch.tensor([101, 2054, 2003, ...], dtype=torch.long)
weights = torch.zeros(len(tokens), dtype=torch.float32)
weights[prompt_len:] = 1.0  # Only train on response tokens

datum = datum_from_tokens_weights(tokens, weights, max_length=8192)
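As a concrete sketch (plain Python, with hypothetical token IDs), the weight mask zeroes prompt positions so only response tokens contribute to the loss:

```python
prompt_tokens = [101, 2054, 2003]     # hypothetical prompt token IDs
response_tokens = [2023, 1037, 3231]  # hypothetical response token IDs

tokens = prompt_tokens + response_tokens
weights = [0.0] * len(prompt_tokens) + [1.0] * len(response_tokens)

# Only positions with weight 1.0 contribute to the training loss
contributing = [t for t, w in zip(tokens, weights) if w > 0]
```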

Using tinker.Datum directly (target-token-based)

For RL-style objectives where you need per-completion control (e.g. routing matrices, custom loss_fn_inputs), construct datums directly:
import tinker

model_input_len = len(tokens) - 1
datum = tinker.Datum(
    model_input=tinker.ModelInput.from_ints(tokens[:-1]),
    loss_fn_inputs={
        "target_tokens": tinker.TensorData(
            data=tokens[1:], dtype="int64", shape=[model_input_len],
        ),
    },
)
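The [:-1] / [1:] slicing implements the next-token shift: the model reads position t and is scored on the token at position t+1, so inputs and targets are each one shorter than the raw sequence. A plain-Python sketch with hypothetical IDs:

```python
tokens = [101, 2054, 2003, 2023, 102]  # hypothetical token IDs

model_input = tokens[:-1]    # what the model sees: all but the last token
target_tokens = tokens[1:]   # what it must predict: all but the first token

# Both are len(tokens) - 1, matching model_input_len above
assert len(model_input) == len(target_tokens) == len(tokens) - 1
```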

Example: simple cross-entropy

def cross_entropy_loss(data, logprobs_list):
    total_loss = torch.tensor(0.0)
    for i, logprobs in enumerate(logprobs_list):
        weights = torch.tensor(data[i].loss_fn_inputs["weights"].data, dtype=torch.float32)
        min_len = min(len(logprobs), len(weights))
        weighted_sum = torch.dot(logprobs[:min_len].float(), weights[:min_len])
        total_loss = total_loss - weighted_sum  # Negative log-likelihood
    loss = total_loss / len(logprobs_list)
    return loss, {"cross_entropy": loss.item()}

result = training_client.forward_backward_custom(datums, cross_entropy_loss).result()
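This loss can be sanity-checked locally without the SDK. The sketch below mocks the datum structure with SimpleNamespace (an assumption for illustration; real code passes tinker.Datum objects) and verifies both the loss value and the d_loss/d_logprob gradients:

```python
import torch
from types import SimpleNamespace

def cross_entropy_loss(data, logprobs_list):
    # Same loss as above: negative weighted sum of logprobs, averaged over sequences
    total_loss = torch.tensor(0.0)
    for i, logprobs in enumerate(logprobs_list):
        weights = torch.tensor(data[i].loss_fn_inputs["weights"].data, dtype=torch.float32)
        min_len = min(len(logprobs), len(weights))
        weighted_sum = torch.dot(logprobs[:min_len].float(), weights[:min_len])
        total_loss = total_loss - weighted_sum
    loss = total_loss / len(logprobs_list)
    return loss, {"cross_entropy": loss.item()}

# Stand-in for tinker.Datum / tinker.TensorData (attribute access only)
datum = SimpleNamespace(loss_fn_inputs={"weights": SimpleNamespace(data=[0.0, 1.0, 1.0])})
logprobs = torch.tensor([-1.0, -2.0, -0.5], requires_grad=True)

loss, metrics = cross_entropy_loss([datum], [logprobs])
loss.backward()
# loss = -(-2.0 + -0.5) = 2.5; d_loss/d_logprob = -weights
```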

Example: GRPO with KL penalty

def make_grpo_loss(rewards, ref_logprobs, kl_beta=0.001):
    advantages = compute_advantages(rewards)
    ref_tensors = [torch.tensor(lp, dtype=torch.float32) for lp in ref_logprobs]

    def loss_fn(data, logprobs_list):
        total_loss = torch.tensor(0.0)
        for i in range(len(logprobs_list)):
            weights = torch.tensor(data[i].loss_fn_inputs["weights"].data, dtype=torch.float32)
            pi = logprobs_list[i][:len(weights)]
            ref = ref_tensors[i][:len(weights)]

            pg_loss = -advantages[i] * torch.dot(pi.float(), weights)
            kl_term = torch.dot((pi - ref).float(), weights)
            total_loss = total_loss + pg_loss + kl_beta * kl_term

        mean_loss = total_loss / len(logprobs_list)
        return mean_loss, {"loss": mean_loss.item()}

    return loss_fn
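compute_advantages is left abstract above. One common group-relative choice (an assumption here, not an SDK-provided function) centers and scales rewards within the sampled group:

```python
import torch

def compute_advantages(rewards, eps=1e-6):
    # Group-relative advantages: subtract the group mean and divide by the
    # group standard deviation, so advantages sum to (approximately) zero.
    r = torch.tensor(rewards, dtype=torch.float32)
    return (r - r.mean()) / (r.std() + eps)

adv = compute_advantages([1.0, 0.0, 1.0, 0.0])
```

Whether to divide by the standard deviation (versus only centering) varies across GRPO implementations; pick one and keep it consistent across steps.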

Example: DPO margin loss

import torch.nn.functional as F

def make_dpo_loss(ref_chosen, ref_rejected, beta=0.1):
    ref_c = torch.tensor(ref_chosen, dtype=torch.float32)
    ref_r = torch.tensor(ref_rejected, dtype=torch.float32)

    def loss_fn(data, logprobs_list):
        pi_c, pi_r = logprobs_list[0], logprobs_list[1]
        w_c = torch.tensor(data[0].loss_fn_inputs["weights"].data, dtype=torch.float32)
        w_r = torch.tensor(data[1].loss_fn_inputs["weights"].data, dtype=torch.float32)

        margin = (torch.dot(pi_c.float(), w_c) - torch.dot(ref_c, w_c)) - \
                 (torch.dot(pi_r.float(), w_r) - torch.dot(ref_r, w_r))

        return -F.logsigmoid(beta * margin), {"margin": margin.item()}

    return loss_fn
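A quick numeric check of the loss shape (plain PyTorch, hypothetical margins): at zero margin the loss is log 2, and a positive margin, meaning the policy prefers the chosen completion, drives the loss lower:

```python
import torch
import torch.nn.functional as F

beta = 0.1

# Zero margin: chosen and rejected equally likely, loss = log 2
loss_zero = -F.logsigmoid(beta * torch.tensor(0.0))

# Positive margin: chosen preferred, loss drops below log 2
loss_pos = -F.logsigmoid(beta * torch.tensor(10.0))
```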

Applying the optimizer step

After forward_backward_custom, call optim_step to update weights:
training_client.forward_backward_custom(datums, loss_fn).result()
training_client.optim_step(
    tinker.AdamParams(
        learning_rate=1e-5,
        beta1=0.9,
        beta2=0.999,
        eps=1e-8,
        weight_decay=0.01,
    )
).result()
For gradient accumulation, call forward_backward_custom multiple times before calling optim_step:
for micro_batch in micro_batches:
    training_client.forward_backward_custom(micro_batch, loss_fn).result()

# One optimizer step after accumulating gradients
training_client.optim_step(tinker.AdamParams(learning_rate=1e-5, ...)).result()

Gradient accumulation normalization

When you accumulate multiple micro-batches before optim_step, you have two places where normalization can happen:
  1. Inside your loss function
  2. Server-side inside optim_step(..., grad_accumulation_normalization=...)
Use only one normalization path. If your loss already returns a mean, leave server-side normalization unset. If your loss returns a raw sum, choose the matching server-side normalization mode:
training_client.forward_backward_custom(datums, loss_fn).result()
training_client.optim_step(
    tinker.AdamParams(learning_rate=1e-5, beta1=0.9, beta2=0.999, eps=1e-8, weight_decay=0.01),
    grad_accumulation_normalization="num_loss_tokens",
).result()
  • "num_loss_tokens": divides by the total number of non-zero-grad tokens across the accumulated micro-batches. Best for raw-sum token-level losses, such as RL / GRPO-style objectives.
  • "num_sequences": divides by the total number of sequences with at least one non-zero-grad token. Best for raw-sum sequence-level objectives.
  • None: no server-side division. Best for losses that already return per-token or per-sequence means, such as SFT, DPO, and ORPO.

Choosing the right mode

  • If your loss function returns a raw sum over tokens, use "num_loss_tokens".
  • If your loss function returns a raw sum over sequences, use "num_sequences".
  • If your loss function already returns a mean, leave grad_accumulation_normalization unset.
Do not normalize in both places. If your loss function already divides by tokens or sequences, adding server-side normalization will double-normalize the gradients.
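The effect of double normalization can be seen with plain arithmetic (hypothetical numbers):

```python
# Two micro-batches with raw-sum losses over different token counts
sums = [12.0, 8.0]   # raw per-micro-batch loss sums (hypothetical)
tokens = [6, 4]      # non-zero-weight tokens per micro-batch

# Correct: normalize once, by total tokens (what "num_loss_tokens" does)
once = sum(sums) / sum(tokens)  # 20 / 10 = 2.0

# Wrong: mean inside the loss AND server-side normalization
client_means = [s / t for s, t in zip(sums, tokens)]  # [2.0, 2.0]
twice = sum(client_means) / sum(tokens)  # 4 / 10 = 0.4, gradients shrunk 5x
```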

Recipe defaults

  • SFT: None. The SFT loss is already normalized client-side.
  • GRPO / RL: "num_loss_tokens". RL losses use server-side per-token normalization by default.
  • DPO: None. The DPO loss is already normalized client-side.
  • ORPO: None. The ORPO loss is already normalized client-side.

Common pitfalls

  • Token-weight misalignment can silently break objective semantics — always verify that len(logprobs) and len(weights) are compatible (truncate to min_len).
  • Ignoring per-step diagnostics makes instability hard to attribute — log metrics from every train step.
  • Forgetting .result() — all Tinker API calls return futures. Without .result(), errors are silently swallowed.
  • Non-differentiable loss: If your loss doesn’t depend on logprobs_list entries through differentiable ops, gradients will be zero.
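The last pitfall is easy to reproduce: any op that detaches from the autograd graph (.detach(), .item(), converting to a Python list) silently yields a loss with no gradient path. A minimal check:

```python
import torch

logprobs = torch.tensor([-1.0, -2.0], requires_grad=True)
weights = torch.tensor([1.0, 1.0])

# Correct: the loss stays connected to logprobs through differentiable ops
good = -torch.dot(logprobs, weights)
assert good.requires_grad  # gradients will flow

# Broken: .detach() cuts the graph, so no gradients reach the model
bad = -torch.dot(logprobs.detach(), weights)
assert not bad.requires_grad  # backward() through this yields no gradients
```

Checking loss.requires_grad before returning from your loss function is a cheap guard against this class of bug.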