import torch
import tinker
import transformers
from tinker_cookbook.supervised.common import datum_from_tokens_weights
# Connect to the training service and build a LoRA training client plus the
# HF processor used for chat templating / tokenization.
# Fix: `FiretitanServiceClient` is not defined or imported anywhere in this
# file; `tinker` IS imported above, so the service client must come from it.
# NOTE(review): `endpoint`, `api_key`, and `base_model` are not defined in
# this file either — presumably supplied by surrounding notebook/script
# context; verify before running standalone.
service = tinker.ServiceClient(base_url=endpoint.base_url, api_key=api_key)
training_client = service.create_training_client(
    base_model=base_model,
    # Fix: was lora_rank=0 — a rank-0 LoRA adapter has no trainable
    # parameters, so optimization would be a no-op. 32 is a common default.
    lora_rank=32,
)
processor = transformers.AutoProcessor.from_pretrained(
    "Qwen/Qwen3-VL-8B-Instruct", trust_remote_code=True,
)
# One-turn supervised example: a user turn carrying text + an inline
# base64-encoded image, followed by the assistant reply we train on.
user_turn = {
    "role": "user",
    "content": [
        {"type": "text", "text": "What is in this image?"},
        {
            "type": "image_url",
            "image_url": {"url": "data:image/jpeg;base64,/9j/..."},
        },
    ],
}
assistant_turn = {
    "role": "assistant",
    "content": "The image shows a sunset over the ocean.",
}
conversation = [user_turn, assistant_turn]
# Render the full conversation (user turn + assistant reply) through the chat
# template and tokenize it.
# NOTE(review): only the template *text* is tokenized here — the image bytes
# never pass through the processor, so the sequence contains just the
# template's image placeholder tokens. Confirm the training backend expects
# that for VL inputs.
text = processor.apply_chat_template(conversation, tokenize=False)
full_tokens = processor.tokenizer.encode(text)

# Prompt length for loss masking.
# Fix: add_generation_prompt=True appends the assistant header to the prompt
# rendering, making it an exact prefix of the full rendering above. Without
# it, the assistant-header tokens land *after* prompt_len and receive weight
# 1, i.e. the model would be trained to predict the template header.
prompt_text = processor.apply_chat_template(
    conversation[:1], tokenize=False, add_generation_prompt=True
)
prompt_len = len(processor.tokenizer.encode(prompt_text))

# Standard SFT masking: weight 0 on prompt tokens, weight 1 on the assistant
# completion (including the trailing end-of-turn template tokens).
weights = torch.zeros(len(full_tokens), dtype=torch.float32)
weights[prompt_len:] = 1.0

# Package tokens + per-token loss weights into a training datum, truncated to
# the model context budget.
datum = datum_from_tokens_weights(
    torch.tensor(full_tokens, dtype=torch.long),
    weights,
    max_length=4096,
)
def sft_loss(data, logprobs_list):
    """Token-weighted negative log-likelihood, averaged over supervised tokens.

    For each datum, dots its per-token log-probabilities against the per-token
    loss weights stored in ``loss_fn_inputs["weights"]`` (truncating both to
    their common length), sums the negated result across data, and divides by
    the total weight mass (floored at 1 to avoid division by zero).

    Returns the scalar loss tensor and a metrics dict with its float value.
    """
    loss_sum = torch.tensor(0.0)
    weight_mass = 0
    for idx, lp in enumerate(logprobs_list):
        wts = torch.tensor(
            data[idx].loss_fn_inputs["weights"].data, dtype=torch.float32
        )
        # Guard against a length mismatch between logprobs and weights.
        k = min(len(lp), len(wts))
        loss_sum = loss_sum - torch.dot(lp[:k].float(), wts[:k])
        weight_mass += wts[:k].sum().item()
    mean_nll = loss_sum / max(weight_mass, 1)
    return mean_nll, {"sft_loss": mean_nll.item()}
# Fixed-budget training loop: each iteration runs one custom forward/backward
# pass over the single datum, then applies an Adam update. The optimizer
# hyperparameters are loop-invariant, so the params object is built once.
adam_params = tinker.AdamParams(
    learning_rate=1e-5, beta1=0.9, beta2=0.999, eps=1e-8, weight_decay=0.01
)
for _ in range(100):
    training_client.forward_backward_custom([datum], sft_loss).result()
    training_client.optim_step(adam_params).result()