
What this is

forward_backward_custom lets you implement any objective function in Python. You provide the loss computation; the Tinker SDK runs the forward pass on remote GPUs, hands the resulting logprobs to your function, and sends your computed gradients back to the trainer to complete the backward pass.

How it works

  1. You call training_client.forward_backward_custom(datums, loss_fn).
  2. The trainer runs a forward pass on the GPU and returns per-token logprobs.
  3. The logprobs are converted to PyTorch tensors with requires_grad=True.
  4. Your loss_fn is called with the datums and logprobs.
  5. The SDK calls loss.backward() to compute d_loss/d_logprob gradients.
  6. Gradients are sent back to the trainer GPU for the model backward pass.
Your loss function runs locally (on your machine), while the forward and backward passes run on remote GPUs.
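The local half of this loop (steps 3–5) can be sketched with plain PyTorch, using a toy loss in place of your loss_fn and no Tinker calls at all:

```python
import torch

# Step 3: logprobs arrive as tensors with requires_grad=True
logprobs = torch.tensor([-1.2, -0.5, -2.0], requires_grad=True)

# Step 4: your loss_fn computes a scalar loss from them
loss = -logprobs.sum() / 3

# Step 5: the SDK calls loss.backward() to get d_loss/d_logprob
loss.backward()

# These per-token gradients are what gets shipped back to the trainer GPU
print(logprobs.grad)  # tensor([-0.3333, -0.3333, -0.3333])
```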

Loss function signature

def loss_fn(
    data: list[tinker.Datum],
    logprobs_list: list[torch.Tensor],
) -> tuple[torch.Tensor, dict[str, float]]:
    """
    Args:
        data: The same datums you passed to forward_backward_custom.
              Access token weights via data[i].loss_fn_inputs["weights"].data
        logprobs_list: Per-token log-probabilities from the forward pass.
              Each tensor has requires_grad=True. Shape: (seq_len,) per sequence.

    Returns:
        loss: A scalar tensor. Must be differentiable w.r.t. logprobs_list entries.
        metrics: A dict of float values for logging (not used for training).
    """

Key rules

  • logprobs_list[i] has requires_grad=True — your loss must be differentiable through it.
  • Use torch.dot() to compute weighted sums — this correctly propagates gradients through the logprobs.
  • Return a scalar tensor as the loss, and a dict[str, float] as metrics.
  • Access token weights via data[i].loss_fn_inputs["weights"].data — these are 0 for prompt tokens and 1 for response tokens (set when building datums).
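A quick standalone check (not part of the SDK) confirms that torch.dot masks out prompt tokens as intended:

```python
import torch

logprobs = torch.tensor([-0.5, -1.0, -1.5, -2.0], requires_grad=True)
weights = torch.tensor([0.0, 0.0, 1.0, 1.0])  # 0 = prompt, 1 = response

loss = -torch.dot(logprobs, weights)
loss.backward()

# grad is -weights: zero on prompt tokens, -1 on response tokens,
# so prompt tokens contribute nothing to the model update
print(logprobs.grad)
```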

Building datums

Using tinker_cookbook (weight-based)

datum_from_tokens_weights from tinker_cookbook constructs datums with explicit token weights and handles the input/target token shift internally:
import torch
from tinker_cookbook.supervised.common import datum_from_tokens_weights

tokens = torch.tensor([101, 2054, 2003, ...], dtype=torch.long)
weights = torch.zeros(len(tokens), dtype=torch.float32)
weights[prompt_len:] = 1.0  # prompt_len = number of prompt tokens; train only on the response

datum = datum_from_tokens_weights(tokens, weights, max_length=8192)

Using tinker.Datum directly (target-token-based)

For RL-style objectives where you need per-completion control (e.g. routing matrices, custom loss_fn_inputs), construct datums directly with target_tokens. This is used internally by the cookbook’s RL recipe:
import tinker

model_input_len = len(tokens) - 1
datum = tinker.Datum(
    model_input=tinker.ModelInput.from_ints(tokens[:-1]),
    loss_fn_inputs={
        "target_tokens": tinker.TensorData(
            data=tokens[1:], dtype="int64", shape=[model_input_len],
        ),
    },
)
With target-token datums, access data[i].loss_fn_inputs["target_tokens"] in your loss function instead of "weights". The logprobs correspond to these target tokens.
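Since each logprobs tensor then covers exactly the target tokens, no weight mask is needed; a mean-NLL loss for this layout might look like the following sketch (the per-datum normalization is a choice, not something the SDK requires):

```python
import torch

def nll_loss(data, logprobs_list):
    # Each logprobs tensor already corresponds 1:1 to the target tokens,
    # so a plain mean over tokens suffices -- no weight mask needed.
    losses = [-lp.float().mean() for lp in logprobs_list]
    loss = torch.stack(losses).mean()
    return loss, {"nll": loss.item()}
```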

Example: simple cross-entropy

def cross_entropy_loss(data, logprobs_list):
    total_loss = torch.tensor(0.0)
    for i, logprobs in enumerate(logprobs_list):
        weights = torch.tensor(data[i].loss_fn_inputs["weights"].data, dtype=torch.float32)
        min_len = min(len(logprobs), len(weights))
        weighted_sum = torch.dot(logprobs[:min_len].float(), weights[:min_len])
        total_loss = total_loss - weighted_sum  # Negative log-likelihood
    loss = total_loss / len(logprobs_list)
    return loss, {"cross_entropy": loss.item()}

result = training_client.forward_backward_custom(datums, cross_entropy_loss).result()
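Because the loss function is plain Python, the same logic can be unit-tested locally before touching the API. The mock_datum helper below is a hypothetical stand-in for tinker.Datum, exposing only the field the loss reads:

```python
import torch
from types import SimpleNamespace

def mock_datum(weights):
    # Stand-in exposing loss_fn_inputs["weights"].data, like a real Datum
    return SimpleNamespace(loss_fn_inputs={"weights": SimpleNamespace(data=weights)})

def loss_fn(data, logprobs_list):
    # Same weighted-NLL logic as cross_entropy_loss above
    total = torch.tensor(0.0)
    for d, lp in zip(data, logprobs_list):
        w = torch.tensor(d.loss_fn_inputs["weights"].data, dtype=torch.float32)
        n = min(len(lp), len(w))
        total = total - torch.dot(lp[:n].float(), w[:n])
    return total / len(logprobs_list)

logprobs = [torch.tensor([-1.0, -2.0, -3.0], requires_grad=True)]
loss = loss_fn([mock_datum([0.0, 1.0, 1.0])], logprobs)
print(loss.item())  # 5.0 -- only the two response tokens contribute
```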

Example: GRPO with KL penalty

def make_grpo_loss(rewards, ref_logprobs, kl_beta=0.001):
    advantages = compute_advantages(rewards)
    ref_tensors = [torch.tensor(lp, dtype=torch.float32) for lp in ref_logprobs]

    def loss_fn(data, logprobs_list):
        total_loss = torch.tensor(0.0)
        for i in range(len(logprobs_list)):
            weights = torch.tensor(data[i].loss_fn_inputs["weights"].data, dtype=torch.float32)
            pi = logprobs_list[i][:len(weights)]
            ref = ref_tensors[i][:len(weights)]

            pg_loss = -advantages[i] * torch.dot(pi.float(), weights)
            kl_term = torch.dot((pi - ref).float(), weights)
            total_loss = total_loss + pg_loss + kl_beta * kl_term

        loss = total_loss / len(logprobs_list)
        return loss, {"loss": loss.item()}

    return loss_fn
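compute_advantages above is left to the caller. In GRPO it is typically the group-relative advantage: each reward centered by the group mean and normalized by the group std. One hypothetical implementation, assuming all rewards belong to a single group:

```python
import torch

def compute_advantages(rewards, eps=1e-6):
    # Group-relative advantages: subtract the group mean and divide by
    # the group std; eps guards against zero variance within the group.
    r = torch.tensor(rewards, dtype=torch.float32)
    return (r - r.mean()) / (r.std() + eps)
```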

Example: DPO margin loss

import torch.nn.functional as F

def make_dpo_loss(ref_chosen, ref_rejected, beta=0.1):
    ref_c = torch.tensor(ref_chosen, dtype=torch.float32)
    ref_r = torch.tensor(ref_rejected, dtype=torch.float32)

    def loss_fn(data, logprobs_list):
        pi_c, pi_r = logprobs_list[0], logprobs_list[1]  # assumes datums were passed as [chosen, rejected]
        w_c = torch.tensor(data[0].loss_fn_inputs["weights"].data, dtype=torch.float32)
        w_r = torch.tensor(data[1].loss_fn_inputs["weights"].data, dtype=torch.float32)

        margin = (torch.dot(pi_c.float(), w_c) - torch.dot(ref_c, w_c)) - \
                 (torch.dot(pi_r.float(), w_r) - torch.dot(ref_r, w_r))

        return -F.logsigmoid(beta * margin), {"margin": margin.item()}

    return loss_fn
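The margin logic is easy to sanity-check locally with synthetic logprobs. The snippet below inlines the same margin computation as make_dpo_loss and verifies that when the policy prefers the chosen completion more than the reference does, the loss drops below log(2):

```python
import torch
import torch.nn.functional as F

# Same margin logic as make_dpo_loss above, inlined for a standalone check
def dpo_margin(pi_c, pi_r, ref_c, ref_r, w_c, w_r):
    return (torch.dot(pi_c, w_c) - torch.dot(ref_c, w_c)) - \
           (torch.dot(pi_r, w_r) - torch.dot(ref_r, w_r))

w = torch.tensor([0.0, 1.0, 1.0])           # mask: train on response tokens
ref = torch.tensor([-1.0, -1.0, -1.0])      # shared reference logprobs
pi_up = torch.tensor([-1.0, -0.5, -0.5])    # policy likes these tokens more
pi_down = torch.tensor([-1.0, -1.5, -1.5])  # policy likes these tokens less

margin = dpo_margin(pi_up, pi_down, ref, ref, w, w)
loss = -F.logsigmoid(0.1 * margin)
print(margin.item())  # 2.0 -- positive margin, so loss < log(2)
```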

Applying the optimizer step

After forward_backward_custom, call optim_step to update weights:
training_client.forward_backward_custom(datums, loss_fn).result()
training_client.optim_step(
    tinker.AdamParams(
        learning_rate=1e-5,
        beta1=0.9,
        beta2=0.999,
        eps=1e-8,
        weight_decay=0.01,
    )
).result()
For gradient accumulation, call forward_backward_custom multiple times before calling optim_step:
for micro_batch in micro_batches:
    training_client.forward_backward_custom(micro_batch, loss_fn).result()

# One optimizer step after accumulating gradients
training_client.optim_step(tinker.AdamParams(learning_rate=1e-5, ...)).result()

Common pitfalls

  • Token-weight misalignment can silently break objective semantics — always verify that len(logprobs) and len(weights) are compatible (truncate to min_len).
  • Ignoring per-step diagnostics makes instability hard to attribute — log metrics from every train step.
  • Forgetting .result() — all Tinker API calls return futures. Without .result(), errors are silently swallowed.
  • Non-differentiable loss: If your loss doesn’t depend on logprobs_list entries through differentiable ops, gradients will be zero.