import os
import tinker
import transformers
from concurrent.futures import ThreadPoolExecutor
from fireworks.training.sdk import (
TrainerJobManager, DeploymentManager, DeploymentSampler, WeightSyncer,
)
from training.utils import (
InfraConfig, DeployConfig, ReconnectableClient,
create_trainer_job, setup_deployment,
)
# Fireworks connection settings, read from the environment. Only the API key is
# mandatory; the account id defaults to "" and the base URL to the public endpoint.
api_key = os.environ.get("FIREWORKS_API_KEY", "")
if not api_key:
    # Fail fast with an actionable message instead of a bare KeyError (variable
    # unset) or silently handing an empty key to the managers below (set but empty).
    raise RuntimeError(
        "FIREWORKS_API_KEY must be set to a non-empty Fireworks API key"
    )
account_id = os.environ.get("FIREWORKS_ACCOUNT_ID", "")
base_url = os.environ.get("FIREWORKS_BASE_URL", "https://api.fireworks.ai")

# Both managers target the same account and endpoint: one manages trainer jobs,
# the other the serving deployment.
rlor_mgr = TrainerJobManager(api_key=api_key, account_id=account_id, base_url=base_url)
deploy_mgr = DeploymentManager(api_key=api_key, account_id=account_id, base_url=base_url)
# Base model shared by the trainer jobs, the deployment and the sampler fallback.
base_model = "accounts/fireworks/models/qwen3-8b"

# Training shapes for the run. The `ref_` shape is presumably the one consumed by
# the forward-only reference trainer (TODO confirm inside create_trainer_job).
infra = InfraConfig(
    ref_training_shape_id="accounts/fireworks/trainingShapes/qwen3-8b-128k-h200-forward",
    training_shape_id="accounts/fireworks/trainingShapes/qwen3-8b-128k-h200",
)

# Serving deployment (the policy trainer hot-loads into it) and the Hugging Face
# tokenizer repo associated with the base model.
deploy_cfg = DeployConfig(
    deployment_id="grpo-serving",
    tokenizer_model="Qwen/Qwen3-8B",
)
# Provision the serving deployment and both trainer jobs concurrently, one worker
# per call. Future.result() re-raises any exception raised inside a worker.
# NOTE(review): if one call fails, leaving the `with` still waits for the others
# (executor shutdown(wait=True)), and whatever they created is not cleaned up here.
with ThreadPoolExecutor(max_workers=3) as pool:
    dep_fut = pool.submit(setup_deployment, deploy_mgr, deploy_cfg, base_model, infra)

    # Policy trainer. Its hot-load target is read from deploy_cfg rather than a
    # repeated literal, so it always names the deployment set up above.
    pol_fut = pool.submit(
        create_trainer_job,
        rlor_mgr,
        base_model=base_model,
        infra=infra,
        lora_rank=0,
        display_name="grpo-policy",
        hot_load_deployment_id=deploy_cfg.deployment_id,
    )

    # Reference trainer: created with forward_only=True.
    ref_fut = pool.submit(
        create_trainer_job,
        rlor_mgr,
        base_model=base_model,
        infra=infra,
        lora_rank=0,
        display_name="grpo-reference",
        forward_only=True,
    )

    # Block until each resource is ready (or propagate its failure).
    dep_info = dep_fut.result()
    policy_ep = pol_fut.result()
    reference_ep = ref_fut.result()
# Training clients bound to the two trainer jobs by job id.
policy = ReconnectableClient(rlor_mgr, policy_ep.job_id, base_model, lora_rank=0, fw_api_key=api_key)
reference = ReconnectableClient(rlor_mgr, reference_ep.job_id, base_model, lora_rank=0, fw_api_key=api_key)

# Load the tokenizer named in deploy_cfg rather than repeating the repo id, so
# client-side tokenization cannot drift from the deployment's configured tokenizer.
# NOTE(review): trust_remote_code=True lets the hub repo execute code at load time;
# confirm it is still required for this tokenizer.
tokenizer = transformers.AutoTokenizer.from_pretrained(
    deploy_cfg.tokenizer_model, trust_remote_code=True
)

# Sampler against the deployment's inference endpoint. Use the deployment's model
# name when setup_deployment returned info; otherwise fall back to the base model.
sampler = DeploymentSampler(
    inference_url=deploy_mgr.inference_url,
    model=dep_info.inference_model if dep_info else base_model,
    api_key=api_key,
    tokenizer=tokenizer,
)