import os
import tinker
from concurrent.futures import ThreadPoolExecutor
from fireworks.training.sdk import TrainerJobManager, DeploymentManager, WeightSyncer
from training.utils import InfraConfig, DeployConfig, ReconnectableClient, create_trainer_job, setup_deployment
# --- Connection settings, read from the environment -------------------------
# FIREWORKS_API_KEY is mandatory (indexing os.environ raises KeyError when it
# is unset); the account id and base URL fall back to the defaults below.
api_key = os.environ["FIREWORKS_API_KEY"]
account_id = os.environ.get("FIREWORKS_ACCOUNT_ID", "")
base_url = os.environ.get("FIREWORKS_BASE_URL", "https://api.fireworks.ai")

# The trainer-job manager and the deployment manager are built with the very
# same credentials/endpoint, so assemble the keyword arguments once.
_manager_kwargs = {
    "api_key": api_key,
    "account_id": account_id,
    "base_url": base_url,
}
rlor_mgr = TrainerJobManager(**_manager_kwargs)
deploy_mgr = DeploymentManager(**_manager_kwargs)

# --- Model and trainer hardware shapes --------------------------------------
# NOTE(review): `ref_training_shape_id` ("-forward" shape) is presumably what
# the forward-only reference trainer runs on — confirm in training.utils.
base_model = "accounts/fireworks/models/qwen3-8b"
infra = InfraConfig(
    training_shape_id="accounts/fireworks/trainingShapes/qwen3-8b-128k-h200",
    ref_training_shape_id="accounts/fireworks/trainingShapes/qwen3-8b-128k-h200-forward",
)
def _collect_trainer_endpoints(labelled_futures):
    """Wait on every ``(label, future)`` pair and return the results in order.

    Every future is waited on, even after an earlier one has failed. If any
    launch failed, the ``job_id`` of each launch that DID succeed is printed
    before the first failure is re-raised. Those jobs were created but would
    otherwise be dropped without a trace; nothing here deletes them.
    """
    succeeded = []
    first_error = None
    for label, future in labelled_futures:
        try:
            succeeded.append((label, future.result()))
        except Exception as exc:  # deferred: re-raised once all launches are accounted for
            if first_error is None:
                first_error = exc
    if first_error is not None:
        for label, endpoint in succeeded:
            print(
                f"WARNING: {label} trainer job {endpoint.job_id} was created, "
                "but another trainer launch failed; delete it manually if it "
                "is no longer needed."
            )
        raise first_error
    return [endpoint for _, endpoint in succeeded]


# Launch the policy trainer and the forward-only reference trainer concurrently.
# NOTE(review): create_trainer_job presumably blocks until the job is ready, which
# is why both launches are overlapped in a two-worker pool — confirm.
with ThreadPoolExecutor(max_workers=2) as pool:
    pol_fut = pool.submit(
        create_trainer_job, rlor_mgr,
        base_model=base_model, infra=infra, lora_rank=0,
        display_name="dpo-policy", hot_load_deployment_id="dpo-serving",
    )
    ref_fut = pool.submit(
        create_trainer_job, rlor_mgr,
        base_model=base_model, infra=infra, lora_rank=0,
        display_name="dpo-reference", forward_only=True,
    )

# Leaving the `with` block has already waited for both launches. Collect the two
# results together so that a failure in one does not silently orphan the other.
policy_ep, reference_ep = _collect_trainer_endpoints(
    [("policy", pol_fut), ("reference", ref_fut)]
)

# Attach a ReconnectableClient to each trainer job, using the same base model
# and API key. NOTE(review): lora_rank=0 presumably means full-parameter
# training rather than LoRA — confirm against training.utils.
policy_client = ReconnectableClient(rlor_mgr, policy_ep.job_id, base_model, lora_rank=0, fw_api_key=api_key)
reference_client = ReconnectableClient(rlor_mgr, reference_ep.job_id, base_model, lora_rank=0, fw_api_key=api_key)