import requests
import json
import os
from fireworks import LLM
# --- 1. SETUP: Define API keys, server URLs, and the model to use ---
# IMPORTANT: Make sure your FIREWORKS_API_KEY is set as an environment variable.
# You can get one from https://fireworks.ai
if os.environ.get("FIREWORKS_API_KEY") is None:
    print("FATAL: FIREWORKS_API_KEY environment variable not set.")
    # If not set, you can hardcode it here for testing, but this is not recommended:
    # os.environ["FIREWORKS_API_KEY"] = "YOUR_API_KEY_HERE"

# The "base" model that will be asked to generate SQL.
LLM_MODEL = "accounts/fireworks/models/llama-v3p1-8b-instruct"
llm = LLM(model=LLM_MODEL, deployment_type="auto", api_key=os.environ.get("FIREWORKS_API_KEY"))

# Base URL of your running MCP server.
MCP_SERVER_URL = None  # PUT MCP SERVER URL HERE without the /mcp/ suffix at the end
# --- 2. LOAD THE EXAMPLE DATA ---
# Pull one row (ROW_INDEX_TO_TEST) out of the JSONL dataset. Any failure
# (missing file, bad JSON, index out of range) is reported as a warning and
# the script falls back to the hardcoded example further down.
DATASET_FILE_PATH = "data/final_rft_sql_train_data.jsonl"
ROW_INDEX_TO_TEST = 0  # 0 is the first row, 1 is the second row, etc.

EXAMPLE_DATA = None
try:
    # Single pass over the file: stop as soon as the requested row is found.
    # If we fall off the end, `line_count` already holds the total number of
    # rows for the error message (previously the file was re-opened and fully
    # re-read just to produce this count).
    line_count = 0
    with open(DATASET_FILE_PATH, 'r') as f:
        for line_count, raw_line in enumerate(f, start=1):
            if line_count - 1 == ROW_INDEX_TO_TEST:
                EXAMPLE_DATA = json.loads(raw_line)
                break
    if EXAMPLE_DATA is None:
        raise IndexError(f"row index {ROW_INDEX_TO_TEST} is out of bounds for file with {line_count} rows.")
    print(f"Successfully loaded row {ROW_INDEX_TO_TEST} from '{DATASET_FILE_PATH}'.\n")
    print(EXAMPLE_DATA)
    print()
except Exception as e:
    print(f"Warning: Could not load from file. Reason: {e}")
# Fall back to this hardcoded example row whenever loading from disk failed.
if EXAMPLE_DATA is None:
    print("Using hardcoded fallback EXAMPLE_DATA.\n")
    EXAMPLE_DATA = {
        "messages": [
            {"role": "system", "content": "\nYou are an expert SQL data analyst. Your task is to write a single, valid DuckDB SQL query to answer the user's question, based on the provided database schema. Do not provide any explanation or text other than the SQL query itself.\n\n**Database Schema:**\n| database | schema | name | column_names | column_types | temporary |\n|:----------------------|:---------|:----------|:---------------------------------------------------------------------------|:----------------------------------------------------------------------|:------------|\n| synthetic_openflights | main | airlines | ['airline_id' 'name' 'alias' 'iata' 'icao' 'callsign' 'country' 'active'] | ['BIGINT' 'VARCHAR' 'VARCHAR' 'VARCHAR' 'VARCHAR' 'VARCHAR' 'VARCHAR' | False |\n| | | | | 'VARCHAR'] | |\n| synthetic_openflights | main | airports | ['airport_id' 'name' 'city' 'country' 'iata' 'icao' 'latitude' 'longitude' | ['BIGINT' 'VARCHAR' 'VARCHAR' 'VARCHAR' 'VARCHAR' 'VARCHAR' 'DOUBLE' | False |\n| | | | 'altitude' 'timezone' 'dst' 'tz_db' 'type' 'source'] | 'DOUBLE' 'BIGINT' 'DOUBLE' 'VARCHAR' 'VARCHAR' 'VARCHAR' 'VARCHAR'] | |\n| synthetic_openflights | main | countries | ['name' 'iso_code' 'dafif_code'] | ['VARCHAR' 'VARCHAR' 'VARCHAR'] | False |\n| synthetic_openflights | main | planes | ['name' 'iata' 'icao'] | ['VARCHAR' 'VARCHAR' 'VARCHAR'] | False |\n| synthetic_openflights | main | routes | ['airline' 'airline_id' 'source_airport' 'source_airport_id' | ['VARCHAR' 'BIGINT' 'VARCHAR' 'BIGINT' 'VARCHAR' 'BIGINT' 'VARCHAR' | False |\n| | | | 'destination_airport' 'destination_airport_id' 'codeshare' 'stops' | 'BIGINT' 'VARCHAR'] | |\n| | | | 'equipment'] | | |\n"},
            {"role": "user", "content": "Which countries have the most airlines, and how many airlines are there in each country, listed in descending order by the number of airlines and then alphabetically by country name?"},
            {"role": "assistant", "content": "SELECT country, COUNT(*) AS airline_count FROM airlines GROUP BY country ORDER BY airline_count DESC, country ASC"},
        ],
        "ground_truth": [{"country": "Canada", "airline_count": 10}, {"country": "Sweden", "airline_count": 10}, {"country": "Kenya", "airline_count": 9}, {"country": "United States", "airline_count": 9}, {"country": "Australia", "airline_count": 8}, {"country": "Spain", "airline_count": 6}, {"country": "Italy", "airline_count": 4}, {"country": "Switzerland", "airline_count": 4}, {"country": "Finland", "airline_count": 3}, {"country": "France", "airline_count": 3}, {"country": "Mexico", "airline_count": 3}, {"country": "Costa Rica", "airline_count": 2}, {"country": "Germany", "airline_count": 2}, {"country": "Iceland", "airline_count": 2}, {"country": "Ireland", "airline_count": 2}, {"country": "Japan", "airline_count": 2}, {"country": "Norway", "airline_count": 2}, {"country": "Singapore", "airline_count": 2}, {"country": "United Kingdom", "airline_count": 2}, {"country": "Argentina", "airline_count": 1}, {"country": "Brazil", "airline_count": 1}, {"country": "China", "airline_count": 1}, {"country": "Egypt", "airline_count": 1}, {"country": "Fiji", "airline_count": 1}, {"country": "Greece", "airline_count": 1}, {"country": "India", "airline_count": 1}, {"country": "Jordan", "airline_count": 1}, {"country": "Netherlands", "airline_count": 1}, {"country": "New Zealand", "airline_count": 1}, {"country": "Portugal", "airline_count": 1}, {"country": "Saudi Arabia", "airline_count": 1}, {"country": "South Africa", "airline_count": 1}, {"country": "Thailand", "airline_count": 1}, {"country": "United Arab Emirates", "airline_count": 1}],
    }

# Pull the prompts and the expected result rows out of the example record.
_messages = EXAMPLE_DATA["messages"]
system_prompt = _messages[0]["content"]
user_prompt = _messages[1]["content"]
GROUND_TRUTH_ROWS = EXAMPLE_DATA["ground_truth"]
# --- 3. HELPER FUNCTION: To parse the server's ASCII table response ---
def _coerce_cell(value_str: str):
    """Best-effort scalar coercion: float if the text contains '.', else int, else the raw string."""
    try:
        return float(value_str) if '.' in value_str else int(value_str)
    except (ValueError, TypeError):
        return value_str


def parse_duckdb_ascii_table(table_string: str) -> list[dict]:
    """
    Parse a DuckDB-style ASCII table string into a list of row dictionaries.

    Expected input shape:

        +---------+-------+
        | country | count |
        +---------+-------+
        | Canada  | 10    |
        +---------+-------+

    Behaviour:
      * The first non-border line is the header; '+---+' border lines are ignored.
      * If the line directly under the header consists entirely of upper-case
        cells (e.g. 'VARCHAR', 'BIGINT'), it is treated as a type row and
        skipped. NOTE(review): an all-caps *data* row would also be skipped by
        this heuristic — confirm the server always emits a type row.
      * 'NULL' (any case) and empty cells become None.
      * Other cells are coerced via int/float, falling back to str.
      * Lines whose cell count does not match the header are silently dropped.

    Returns an empty list when there is no header plus at least one data line.
    """
    lines = table_string.strip().split('\n')
    content_lines = [line for line in lines if line.strip() and not line.startswith('+')]
    if len(content_lines) < 2:
        return []

    header_raw = [h.strip() for h in content_lines[0].split('|')[1:-1]]
    data_lines = content_lines[1:]

    # Skip an optional all-caps type row directly under the header.
    # (The original wrapped this in try/except IndexError, but slicing a list
    # never raises IndexError, so that handler was dead code.)
    first_data_values = [v.strip() for v in data_lines[0].split('|')[1:-1]]
    if len(first_data_values) == len(header_raw) and all(v.isupper() for v in first_data_values):
        data_lines = data_lines[1:]

    rows = []
    for line in data_lines:
        values_raw = [v.strip() for v in line.split('|')[1:-1]]
        if len(values_raw) != len(header_raw):
            continue  # malformed or wrapped line — drop it, as before
        row_dict = {}
        for header, value_str in zip(header_raw, values_raw):
            if value_str == '' or value_str.upper() == 'NULL':
                row_dict[header] = None
            else:
                row_dict[header] = _coerce_cell(value_str)
        rows.append(row_dict)
    return rows
# --- 4. GENERATE SQL QUERY USING THE LLM ---
# Ask the base model for a SQL query; any API failure leaves
# `model_generated_sql` empty so the next stage is skipped.
print("="*20)
print("LLM QUERY GENERATION")
print("="*20)

model_generated_sql = ""
try:
    print(f"User prompt: {user_prompt}")
    print(f"Ground truth: {GROUND_TRUTH_ROWS}")
    print(f"Calling model '{LLM_MODEL}' to generate SQL query...")
    response = llm.chat.completions.create(
        model=LLM_MODEL,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        temperature=0.0,  # Set to 0 for deterministic output
    )
    model_generated_sql = response.choices[0].message.content.strip()
    print("\nModel Generated SQL Query:")
    print(model_generated_sql)
except Exception as e:
    print(f"\nAN ERROR OCCURRED during LLM call: {e}")
# --- 5. EXECUTE GENERATED QUERY ON MCP SERVER ---
# Send the generated SQL to the MCP server as a JSON-RPC `tools/call` request
# and parse the streamed (SSE) response into `predicted_rows`.
predicted_rows = []
if model_generated_sql:
    try:
        print("\n" + "="*20)
        print("MCP SERVER EXECUTION")
        print("="*20)
        print("Sending query to MCP server...")
        # Fail fast with a clear message when the placeholder was never filled
        # in; otherwise requests raises a cryptic "Invalid URL 'None/mcp/'".
        if MCP_SERVER_URL is None:
            raise RuntimeError("MCP_SERVER_URL is not set. Edit the script and point it at your MCP server.")
        headers = {"Content-Type": "application/json", "Accept": "application/json, text/event-stream"}
        payload = {
            "id": "eval-query-1",
            "jsonrpc": "2.0",
            "method": "tools/call",
            "params": {
                "session": {"id": "stateless-eval-session"},
                "name": "query",
                "arguments": {"query": model_generated_sql},
            },
        }
        with requests.post(f"{MCP_SERVER_URL}/mcp/", headers=headers, json=payload, timeout=20, stream=True) as response:
            response.raise_for_status()
            # The server replies as a text/event-stream; take the first 'data:' event.
            response_data = None
            for line in response.iter_lines():
                if line and line.decode('utf-8').startswith('data:'):
                    json_part = line.decode('utf-8')[len('data:'):].strip()
                    if json_part:
                        response_data = json.loads(json_part)
                        break
            if response_data is None:
                raise RuntimeError("No JSON data in event stream.")
            if "error" in response_data:
                raise RuntimeError(f"SQL Error: {response_data['error'].get('message', 'Unknown')}")
            ascii_table = response_data['result']['content'][0]['text']
            predicted_rows = parse_duckdb_ascii_table(ascii_table)
            print("\nParsed Result from Server:")
            print(json.dumps(predicted_rows, indent=2))
    except Exception as e:
        print(f"\nAN ERROR OCCURRED during MCP call: {e}")
# --- 6. FINAL COMPARISON ---
# Order-insensitive comparison: each row is reduced to a sorted list of its
# stringified values (column names are ignored), then the row lists themselves
# are sorted. Two result sets match iff they hold the same multiset of rows.
print("\n" + "="*20)
print("COMPARISON")
print("="*20)
if not predicted_rows:
    print("Skipping comparison: no rows returned from query or an error occurred.")
else:
    gt_values = sorted([sorted(map(str, row.values())) for row in GROUND_TRUTH_ROWS])
    predicted_values = sorted([sorted(map(str, row.values())) for row in predicted_rows])
    if gt_values == predicted_values:
        # NOTE(review): this literal was split across two lines (with a
        # mojibake'd emoji) in the original, which is not valid Python;
        # restored as a single string.
        print("\n✅ GOOD RESULT: The base model generated SQL that produced the correct data.\n")
    else:
        print("\n❌ BAD RESULT: The base model's SQL produced different data than expected.\n")
        print("This is often the intended outcome when testing a base model, as it highlights what fine-tuning needs to correct.")