import os
import tinker
from fireworks.training.sdk import (
FiretitanServiceClient,
TrainerJobManager,
TrainerJobConfig,
DeploymentManager,
DeploymentConfig,
WeightSyncer,
)
# Credentials and endpoint configuration come from the environment.
# FIREWORKS_API_KEY is required (a missing key raises KeyError here);
# FIREWORKS_ACCOUNT_ID and FIREWORKS_BASE_URL fall back to the defaults below.
api_key = os.environ["FIREWORKS_API_KEY"]
account_id = os.environ.get("FIREWORKS_ACCOUNT_ID", "")
base_url = os.environ.get("FIREWORKS_BASE_URL", "https://api.fireworks.ai")

# This is the only shape-specific value you choose; the trainer config below
# consumes the training profile resolved from it.
shape_id = "accounts/fireworks/trainingShapes/qwen3-8b-128k-h200"

# Values that must agree between the serving deployment, the trainer job and
# the training client. Defining each exactly once keeps them from drifting
# apart (e.g. hot-loading into a deployment other than the one created, or
# building a client for a different base model than the trainer uses).
BASE_MODEL = "accounts/fireworks/models/qwen3-8b"
DEPLOYMENT_ID = "research-serving"
# Passed to both the trainer job and the training client, so they must match.
# NOTE(review): presumably 0 means full-parameter training (no LoRA adapter)
# — confirm against the SDK documentation.
LORA_RANK = 0

rlor_mgr = TrainerJobManager(api_key=api_key, account_id=account_id, base_url=base_url)
deploy_mgr = DeploymentManager(api_key=api_key, account_id=account_id, base_url=base_url)

# Create (or reuse) the deployment used for sampling/hotload, then wait for
# it to report ready before starting the trainer that hot-loads into it.
deploy_mgr.create_or_get(DeploymentConfig(
    deployment_id=DEPLOYMENT_ID,
    base_model=BASE_MODEL,
    min_replica_count=0,
    max_replica_count=1,
))
deploy_mgr.wait_for_ready(DEPLOYMENT_ID)

# Resolve the chosen training shape; its version reference feeds the
# trainer job config.
profile = rlor_mgr.resolve_training_profile(shape_id)

# Create trainer (polls until healthy)
endpoint = rlor_mgr.create_and_wait(TrainerJobConfig(
    base_model=BASE_MODEL,
    training_shape_ref=profile.training_shape_version,
    lora_rank=LORA_RANK,
    learning_rate=1e-5,
    gradient_accumulation_steps=4,
    hot_load_deployment_id=DEPLOYMENT_ID,
))

# Connect client (FiretitanServiceClient provides checkpoint_type + session ID)
service = FiretitanServiceClient(base_url=endpoint.base_url, api_key=api_key)
training_client = service.create_training_client(
    base_model=BASE_MODEL, lora_rank=LORA_RANK,
)