Overview

DeploymentSampler handles client-side tokenization via a HuggingFace tokenizer and returns structured SampledCompletion objects with token IDs, logprobs, and completion metadata. Use it in training scripts that need token-level outputs (e.g. GRPO, DPO).
from fireworks.training.sdk import DeploymentSampler

Constructor

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B", trust_remote_code=True)
sampler = DeploymentSampler(
    inference_url="https://api.fireworks.ai",
    model=f"accounts/{account_id}/deployments/{deployment_id}",
    api_key="<FIREWORKS_API_KEY>",
    tokenizer=tokenizer,
)
| Parameter | Type | Description |
| --- | --- | --- |
| inference_url | str | Gateway URL for inference completions |
| model | str | Deployment model path (accounts/&lt;id&gt;/deployments/&lt;id&gt;) |
| api_key | str | Fireworks API key |
| tokenizer | PreTrainedTokenizer | HuggingFace tokenizer matching the base model |
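In practice the credentials and model path are usually assembled from environment variables rather than hardcoded. A minimal sketch; the environment variable names here are illustrative conventions, not part of the SDK:

```python
import os

# Illustrative: read IDs and the API key from the environment.
account_id = os.environ.get("FIREWORKS_ACCOUNT_ID", "my-account")
deployment_id = os.environ.get("FIREWORKS_DEPLOYMENT_ID", "my-deployment")
api_key = os.environ.get("FIREWORKS_API_KEY", "")

# Deployment model path in the accounts/<id>/deployments/<id> form
# expected by the model parameter.
model_path = f"accounts/{account_id}/deployments/{deployment_id}"
print(model_path)
```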

sample_with_tokens(...)

Sample completions and return structured results with token IDs. This method is async, so call it with await or wrap it with asyncio.run(...) from synchronous code:
import asyncio

async def main():
    completions = await sampler.sample_with_tokens(
        messages=[{"role": "user", "content": "Solve: 2+2="}],
        n=4,
        max_tokens=1024,
        temperature=0.7,
    )
    for c in completions:
        print(c.full_tokens)       # prompt + completion token IDs
        print(c.prompt_len)        # number of prompt tokens
        print(c.completion_len)    # number of completion tokens
        print(c.text)              # decoded completion text
        print(c.finish_reason)     # "stop", "length", etc.

asyncio.run(main())
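Since full_tokens holds the prompt tokens followed by the completion tokens, the completion's token IDs can be recovered by slicing at prompt_len. A sketch using a stand-in dataclass in place of a real SampledCompletion (the token IDs are made up for illustration):

```python
from dataclasses import dataclass
from typing import List

@dataclass
class FakeCompletion:
    # Stand-in for SampledCompletion, with the fields used below.
    full_tokens: List[int]
    prompt_len: int
    completion_len: int

c = FakeCompletion(full_tokens=[101, 7592, 2088, 102, 999],
                   prompt_len=3, completion_len=2)

# Prompt tokens come first; completion tokens are the remainder.
prompt_tokens = c.full_tokens[:c.prompt_len]
completion_tokens = c.full_tokens[c.prompt_len:]
assert len(completion_tokens) == c.completion_len
print(completion_tokens)
```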

Retrieving inference logprobs

For GRPO importance sampling, pass logprobs=True:
import asyncio

async def main():
    completions = await sampler.sample_with_tokens(
        messages=[{"role": "user", "content": "Solve: 2+2="}],
        n=4,
        logprobs=True,
        top_logprobs=1,
    )
    for c in completions:
        print(c.inference_logprobs)  # List[float] or None

asyncio.run(main())
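A common use of inference_logprobs in GRPO-style training is to form per-token importance ratios between the trainer's policy and the inference engine that produced the samples. A generic sketch of that calculation, not SDK code; the trainer-side logprobs here are made-up values standing in for a forward pass:

```python
import math

# Per-token logprobs returned by inference (c.inference_logprobs) and
# the trainer's recomputed logprobs for the same completion tokens.
inference_logprobs = [-0.5, -1.2, -0.3]
trainer_logprobs = [-0.6, -1.0, -0.4]

# Token-level importance ratios pi_trainer(t) / pi_inference(t),
# computed in log space for numerical stability.
ratios = [math.exp(t - i)
          for t, i in zip(trainer_logprobs, inference_logprobs)]
print(ratios)
```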

Sequence length filtering

sample_with_tokens supports max_seq_len for automatic filtering:
import asyncio

completions = asyncio.run(
    sampler.sample_with_tokens(
        messages=input_messages,
        n=4,
        max_tokens=1024,
        max_seq_len=8192,  # filter out sequences exceeding this length
    )
)
Two levels of filtering are applied:
  1. Prompt pre-filter: If the tokenized prompt already meets or exceeds max_seq_len, the method returns an empty list immediately — no inference call is made.
  2. Completion post-filter: After sampling, any completion whose full token sequence (prompt + completion) exceeds max_seq_len is silently dropped.
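The two levels can be illustrated with a small re-implementation of the filtering logic; this is a sketch of the behavior described above, not the SDK's actual code:

```python
def filter_by_seq_len(prompt_tokens, completions, max_seq_len):
    """Illustrative two-level max_seq_len filter."""
    # 1. Prompt pre-filter: a too-long prompt yields no results at all
    #    (in the SDK, no inference call would be made).
    if len(prompt_tokens) >= max_seq_len:
        return []
    # 2. Completion post-filter: drop any completion whose full
    #    sequence (prompt + completion) exceeds max_seq_len.
    return [c for c in completions
            if len(prompt_tokens) + len(c) <= max_seq_len]

prompt = list(range(5))               # 5 prompt tokens
sampled = [[1, 2], [1, 2, 3, 4, 5, 6]]  # completions of length 2 and 6
kept = filter_by_seq_len(prompt, sampled, max_seq_len=8)
print(kept)
```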

SampledCompletion

Each completion returned by sample_with_tokens is a SampledCompletion with the following fields:
| Field | Type | Description |
| --- | --- | --- |
| text | str | Decoded completion text |
| full_tokens | List[int] | Prompt + completion token IDs |
| prompt_len | int | Number of prompt tokens |
| completion_len | int | Number of completion tokens |
| finish_reason | str | "stop", "length", etc. |
| inference_logprobs | List[float] \| None | Per-token logprobs (when logprobs=True is passed) |
| logprobs_echoed | bool | True when echo=True was used; logprobs are training-aligned (P+C-1 entries) |
| routing_matrices | List[str] \| None | Base64-encoded per-token routing matrices for MoE Router Replay (R3) |
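The P+C-1 count for echoed logprobs follows from the first token having no preceding context to be scored against: with P prompt tokens and C completion tokens, there is one entry per token after the first. Assuming that alignment (entry i scores token i+1 of full_tokens; this indexing is an assumption, not confirmed SDK behavior), the completion tokens' logprobs would be a slice starting at prompt_len - 1:

```python
# Assumed alignment: P + C - 1 entries, one per token after the first,
# so entry i corresponds to full_tokens[i + 1].
P, C = 3, 2
echoed_logprobs = [-0.1, -0.2, -0.3, -0.4]  # P + C - 1 = 4 entries
assert len(echoed_logprobs) == P + C - 1

# Completion tokens are full_tokens[P:], i.e. entries P - 1 onward.
completion_logprobs = echoed_logprobs[P - 1:]
print(completion_logprobs)
```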