What this is
`forward_backward_custom` lets you implement any objective function in Python. You provide the loss computation; the Tinker SDK runs the forward pass on remote GPUs, passes the resulting logprobs back to your function, and then sends the gradients you compute back for the backward pass.
How it works
- You call `training_client.forward_backward_custom(datums, loss_fn)`.
- The trainer runs a forward pass on the GPU and returns per-token logprobs.
- The logprobs are converted to PyTorch tensors with `requires_grad=True`.
- Your `loss_fn` is called with the datums and logprobs.
- The SDK calls `loss.backward()` to compute `d_loss/d_logprob` gradients.
- Gradients are sent back to the trainer GPU for the model backward pass.
Loss function signature
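The callback's expected shape can be sketched as follows. The parameter and return types mirror the rules listed under "Key rules"; the minimal weighted-NLL body is illustrative, not the SDK's own implementation:

```python
import torch

# loss_fn receives the datums you submitted plus one logprob tensor per datum.
# Each logprobs_list[i] is a 1-D float tensor with requires_grad=True.
def loss_fn(
    data: list,  # list[tinker.Datum]
    logprobs_list: list[torch.Tensor],
) -> tuple[torch.Tensor, dict[str, float]]:
    total = torch.tensor(0.0)
    for datum, logprobs in zip(data, logprobs_list):
        weights = torch.as_tensor(
            datum.loss_fn_inputs["weights"].data, dtype=logprobs.dtype
        )
        # torch.dot keeps the computation differentiable through logprobs.
        total = total - torch.dot(weights, logprobs)  # weighted NLL
    return total, {"nll": total.item()}
```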
Key rules
- `logprobs_list[i]` has `requires_grad=True` — your loss must be differentiable through it.
- Use `torch.dot()` to compute weighted sums — this correctly propagates gradients through the logprobs.
- Return a scalar tensor as the loss, and a `dict[str, float]` as metrics.
- Access token weights via `data[i].loss_fn_inputs["weights"].data` — these are `0` for prompt tokens and `1` for response tokens (set when building datums).
Building datums
Using tinker_cookbook (weight-based)
`datum_from_tokens_weights` from `tinker_cookbook` constructs datums with explicit token weights and handles the internal token shifting for you:
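For illustration, a runnable sketch of the inputs `datum_from_tokens_weights` expects. The token ids are made up, and the final call is shown commented because it requires `tinker_cookbook`:

```python
# Hypothetical token ids, just to show the weight layout.
prompt_tokens = [101, 2054, 2003, 102]   # e.g. "Q: What is 2+2?\nA:"
response_tokens = [1018, 102]            # e.g. " 4"
tokens = prompt_tokens + response_tokens

# 0.0 masks prompt tokens out of the loss; 1.0 keeps response tokens.
weights = [0.0] * len(prompt_tokens) + [1.0] * len(response_tokens)

# datum = datum_from_tokens_weights(tokens, weights)  # builds the tinker.Datum
```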
Using tinker.Datum directly (target-token-based)
For RL-style objectives where you need per-completion control (e.g. routing matrices, custom `loss_fn_inputs`), construct datums directly with `target_tokens`. This is used internally by the cookbook's RL recipe:
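A minimal sketch of the shift-by-one layout such a datum encodes: the model input is the sequence minus its last token, and the targets are the sequence shifted left by one. The `tinker.Datum` construction is shown commented, and its exact constructor arguments are an assumption to verify against your SDK version:

```python
tokens = [5, 17, 42, 99, 7]    # hypothetical token ids: prompt + sampled completion

input_tokens = tokens[:-1]     # what the model conditions on
target_tokens = tokens[1:]     # what each position should predict

# datum = tinker.Datum(
#     model_input=tinker.ModelInput.from_ints(input_tokens),
#     loss_fn_inputs={"target_tokens": target_tokens},
# )
```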
In your loss function, read `data[i].loss_fn_inputs["target_tokens"]` instead of `"weights"`. The logprobs correspond to these target tokens.
Example: simple cross-entropy
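A minimal sketch of a weighted cross-entropy objective following the key rules above. The helper name `cross_entropy_loss` and the mean-per-token normalization are choices of this sketch:

```python
import torch

def cross_entropy_loss(data, logprobs_list):
    total_nll = torch.tensor(0.0)
    total_tokens = 0.0
    for datum, logprobs in zip(data, logprobs_list):
        weights = torch.as_tensor(
            datum.loss_fn_inputs["weights"].data, dtype=logprobs.dtype
        )
        n = min(len(logprobs), len(weights))          # guard against length mismatch
        total_nll = total_nll - torch.dot(weights[:n], logprobs[:n])
        total_tokens += weights[:n].sum().item()
    loss = total_nll / max(total_tokens, 1.0)          # mean NLL over weighted tokens
    return loss, {"nll_per_token": loss.item(), "num_tokens": total_tokens}
```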
Example: GRPO with KL penalty
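A sketch of a GRPO-style objective. It assumes each datum carries group-normalized advantages and reference-model logprobs in `loss_fn_inputs`; the field names `advantages` and `ref_logprobs`, the coefficient `KL_COEF`, and the choice of the k3 KL estimator are all assumptions of this sketch, not fixed by the SDK:

```python
import torch

KL_COEF = 0.1  # KL penalty strength (assumed hyperparameter)

def grpo_loss(data, logprobs_list):
    pg_terms, kl_terms = [], []
    for datum, logprobs in zip(data, logprobs_list):
        adv = torch.as_tensor(
            datum.loss_fn_inputs["advantages"].data, dtype=logprobs.dtype
        )
        ref = torch.as_tensor(
            datum.loss_fn_inputs["ref_logprobs"].data, dtype=logprobs.dtype
        )
        # Policy-gradient term: maximize advantage-weighted logprob.
        pg_terms.append(-torch.dot(adv, logprobs))
        # k3 KL estimator: exp(d) - d - 1 with d = ref - logp; non-negative
        # and differentiable through logprobs.
        delta = ref - logprobs
        kl_terms.append((torch.exp(delta) - delta - 1.0).sum())
    pg = torch.stack(pg_terms).sum()
    kl = torch.stack(kl_terms).sum()
    loss = pg + KL_COEF * kl
    return loss, {"pg_loss": pg.item(), "kl": kl.item(), "loss": loss.item()}
```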
Example: DPO margin loss
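A sketch of a DPO margin loss. It assumes datums arrive as alternating (chosen, rejected) pairs, each with `weights` masking response tokens and per-token reference logprobs under a `ref_logprobs` key; that pairing convention, the field name, and `BETA` are assumptions of this sketch:

```python
import torch

BETA = 0.1  # DPO inverse-temperature (assumed hyperparameter)

def dpo_loss(data, logprobs_list):
    # Assumes datums are ordered as alternating (chosen, rejected) pairs.
    losses, margins = [], []
    for i in range(0, len(data), 2):
        seq_logratios = []
        for datum, logprobs in (
            (data[i], logprobs_list[i]),
            (data[i + 1], logprobs_list[i + 1]),
        ):
            w = torch.as_tensor(
                datum.loss_fn_inputs["weights"].data, dtype=logprobs.dtype
            )
            ref = torch.as_tensor(
                datum.loss_fn_inputs["ref_logprobs"].data, dtype=logprobs.dtype
            )
            # Sequence log-ratio vs. the reference model over response tokens.
            seq_logratios.append(torch.dot(w, logprobs) - torch.dot(w, ref))
        margin = seq_logratios[0] - seq_logratios[1]     # chosen minus rejected
        losses.append(-torch.nn.functional.logsigmoid(BETA * margin))
        margins.append(margin.item())
    loss = torch.stack(losses).mean()
    return loss, {"dpo_loss": loss.item(), "mean_margin": sum(margins) / len(margins)}
```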
Applying the optimizer step
After `forward_backward_custom`, call `optim_step` to apply the accumulated gradients and update the weights.
To train on a larger effective batch, you can also call `forward_backward_custom` multiple times before calling `optim_step`; gradients accumulate across calls.
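The sequence can be illustrated with a runnable stub in place of the real training client. `StubTrainingClient` only records call order; the real client comes from the tinker SDK, and the real `optim_step` takes Adam hyperparameters rather than a plain dict:

```python
class _Future:
    """Mimics the future returned by Tinker API calls."""
    def __init__(self, value=None):
        self._value = value
    def result(self):
        return self._value

class StubTrainingClient:
    """Stand-in for tinker's training client; records the call sequence."""
    def __init__(self):
        self.calls = []
    def forward_backward_custom(self, datums, loss_fn):
        self.calls.append("forward_backward_custom")
        return _Future()
    def optim_step(self, adam_params):
        self.calls.append("optim_step")
        return _Future()

training_client = StubTrainingClient()
microbatches = [["d0"], ["d1"], ["d2"], ["d3"]]  # placeholder datum batches
loss_fn = lambda data, logprobs_list: None       # placeholder loss function

# Accumulate gradients over several microbatches, then apply one optimizer step.
# Always call .result() so remote errors surface.
for batch in microbatches:
    training_client.forward_backward_custom(batch, loss_fn).result()
training_client.optim_step({"learning_rate": 1e-5}).result()
```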
Common pitfalls
- Token-weight misalignment can silently break objective semantics — always verify that `len(logprobs)` and `len(weights)` are compatible (truncate to `min_len`).
- Ignoring per-step diagnostics makes instability hard to attribute — log metrics from every train step.
- Forgetting `.result()` — all Tinker API calls return futures. Without `.result()`, errors are silently swallowed.
- Non-differentiable loss: if your loss doesn't depend on the `logprobs_list` entries through differentiable ops, the gradients will be zero.
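The length-alignment guard from the first pitfall can be factored into a small helper (the helper name is mine):

```python
import torch

def safe_weighted_nll(logprobs: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
    # Truncate both tensors to the shared length before torch.dot to avoid
    # shape errors and silent misalignment at the tail.
    n = min(len(logprobs), len(weights))
    return -torch.dot(weights[:n], logprobs[:n])
```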