✈️ Natural-Language → SQL with Reinforcement Fine-Tuning (RFT)

🎯 What You’ll Build

Welcome! This tutorial shows you how to fine-tune a 7B-parameter model to answer natural language (NL) questions by writing SQL that executes against your database, without using any real production data in the fine-tuning process. You will end up with a model that accurately translates natural language into SQL you can execute against a database, for workflows like:
👤 User asks question → 🤖 Model generates SQL → 📊 Database returns results

🚀 Performance Benefits of RFT

Before diving into the tutorial, here’s a summary of the accuracy we achieved across various models, using the OpenFlights dataset as a base:
| Model | Accuracy on Test Set | Size | Speed |
| --- | --- | --- | --- |
| Qwen 2.5 7B (base) | 23.91% | Small | Fast |
| DeepSeek V3 | 27.17% | Large | Slow |
| Kimi K2 Instruct | 28.26% | Large | Slow |
| OpenAI GPT-4o | 23.91% | Large | Slow |
| Anthropic Claude Sonnet 4 | 29.35% | Large | Slow |
| Qwen3 Coder 480B | 34.78% | Large | Slow |
| Our Fine-Tuned Qwen 2.5 7B | 56.52% ✨ | Small | Fast |
Note on methodology: to compare accuracy across the above models, we created a synthetic dataset that mirrors the OpenFlights schema, an initial set of synthetic SQL queries written by Qwen3 Coder 480B, and a set of synthetic natural language questions (also written by Qwen3 Coder 480B) corresponding to those queries. The task is for the LLM to translate each natural language question into SQL and then execute that SQL against the synthetic dataset. Accuracy is the percentage of questions whose query returns the correct result (N = 92), where “correct” means the query returns the same result on the synthetic dataset as the corresponding initial synthetic query did. The relative performance between these models is therefore more meaningful than the absolute numbers. More details on the data and evaluation process can be found throughout the tutorial below.
Key takeaway: Our fine-tuned 7B model outperforms models that are 50-200x larger, while being faster and cheaper to run.
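For intuition, the exact-match check behind these numbers compares executed result sets rather than SQL text. Here is a minimal sketch of that idea (illustrative only; the rows_match helper below is hypothetical and is not the evaluator built later in the tutorial):

import duckdb

def rows_match(candidate_sql: str, ground_truth_rows: list, db_path: str) -> bool:
    """Hypothetical helper: execute the model's SQL and compare rows to the expected result."""
    try:
        with duckdb.connect(db_path, read_only=True) as con:
            candidate_rows = con.sql(candidate_sql).df().to_dict("records")
    except Exception:
        return False  # invalid SQL earns no credit
    return candidate_rows == ground_truth_rows  # exact match on the returned rows

Accuracy in the table above is then simply the fraction of test questions for which this kind of check passes.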

💡 Why Reinforcement Fine-Tuning?

The Problem with Supervised Fine-Tuning (SFT) for this use-case

SFT teaches models to mimic exact SQL syntax by showing them question-SQL pairs. But the key insight is that we care more about the result of the query than about the exact SQL syntax. With SFT, the model is penalized for generating a SQL query that differs from the training example, even if both queries are perfectly correct. This can lead to:
  • ❌ Overfitting to specific SQL patterns
  • ❌ Poorer generalization to new questions
  • ❌ Need for thousands of perfectly-matched examples

The RFT Solution

Reinforcement Fine-Tuning (RFT) takes a fundamentally different approach:
  • ✅ Rewards correct results, regardless of SQL syntax
  • ✅ Explores multiple solution paths during training
  • ✅ Works with just hundreds of examples instead of thousands
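As a concrete (hypothetical) illustration, both of the queries below answer “How many airports are in France?” and return identical rows, so an execution-based reward scores them equally, whereas SFT would only credit the one that matches the reference answer:

# Two syntactically different but semantically equivalent queries against the airports table.
query_a = "SELECT COUNT(*) AS airport_count FROM airports WHERE country = 'France';"
query_b = "SELECT COUNT(*) AS airport_count FROM airports WHERE country IN ('France');"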

🔄 The Process: From Schema to Expert Model

Here’s the complete pipeline you’ll implement:
  1. Start with just your schema (no real data needed!): Extract table structures and relationships
  2. Generate synthetic data: Use LLMs to create realistic fake data that maintains referential integrity
  3. Create SQL queries: Use historical logs or generate diverse query patterns
  4. Execute for ground truth: Run queries against synthetic data to get expected results
  5. Generate natural language questions: Convert SQL to questions users would actually ask
  6. Train with RFT: Model learns through trial and error, rewarded for correct results
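By the end of this pipeline, each training example is a JSONL record that pairs the system prompt and a natural language question with the ground-truth rows (the real builder appears in section 8). A single record looks roughly like this, with illustrative values:

# Illustrative shape of one RFT training record (values are made up; section 8 builds the real ones).
rft_entry = {
    "messages": [
        {"role": "system", "content": "You are an expert SQL data analyst... (schema goes here)"},
        {"role": "user", "content": "Which countries have the most airlines?"},
        {"role": "assistant", "content": "SELECT country, COUNT(*) AS airline_count FROM airlines GROUP BY country ORDER BY airline_count DESC;"},
    ],
    "ground_truth": [{"country": "Finland", "airline_count": 3}],  # rows returned by the reference query
}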

Why this matters

Off-the-shelf LLM copilots often guess column names, ignore schema quirks, or hallucinate tables. Reinforcement Fine-Tuning (RFT) fixes this by teaching the model the shape of your data and the patterns in your queries, boosting exact-match accuracy.

What this tutorial will cover

| You’ll practice … | … and walk away with |
| --- | --- |
| ✅ Generate a synthetic DuckDB that mirrors your schema | synthetic_openflights.db (<20 MB) served via an MCP endpoint |
| ✅ Create a MECE query set & compute ground-truth rows | generated_queries.json & ground_truth_results.jsonl |
| ✅ Build NL ↔ SQL result pairs for fine-tuning and eval | final_rft_sql_train_data.jsonl & final_rft_sql_test_data.jsonl |
| ✅ Run an RFT job on Fireworks AI | A tuned Qwen 2.5-7B checkpoint |
| ✅ Benchmark baseline vs. tuned model and a larger baseline | > 30% exact-match improvement over the Qwen 2.5-7B base model and > 20% over SoTA base models |

Agenda

  1. 🛠️ Development Environment Setup
  2. 🗄️ Simulate the “Production” Database
  3. 📋 Acquire the Schema (No Real Data!)
  4. 🧪 Create the Synthetic Training Sandbox with an LLM
  5. ✅ Validate the Sandbox
  6. 📝 Generate Example SQL Queries
  7. ♻️ Query-Aware Augmentation of the Synthetic Sandbox
  8. 🎯 Execute Queries to Get Ground-Truth Answers
  9. 💬 Generate Natural Language Questions for Final RFT Training Data
  10. 🛰️ Deploy an MCP Server for the Synthetic Data
  11. ☁️ Set Up Google Cloud CLI & .gcloudignore
  12. 📦 Containerize & Deploy the MCP Server
  13. 🔍 Define an evaluation function for RFT
  14. 🧪 Test English → SQL of a base model without fine-tuning
  15. 🚀 Launch the Fine-Tuning Job & Deploy via the UI
  16. ⚖️ Evaluate Model Performance
  17. ✨ Cleanup & Conclusion
Demo vs Real World 🌍
Look for these call-outs to see the difference between the self-contained demo steps in this notebook and the equivalent actions you’d perform on your own private schema, logs, and query store.

0. 🛠️ Development Environment Setup

Complete these steps once in your terminal, outside this notebook.
  1. Get a Fireworks AI API Key
    • Go to fireworks.ai and sign up.
    • Create an API key from your settings page.
    • Create a file named .env in your project directory and add your key:
      FIREWORKS_API_KEY="YOUR_API_KEY_HERE"
      
  2. Install uv
    • uv is a fast Python package manager from Astral. Follow the official installation instructions at docs.astral.sh/uv/.
    • It’s significantly faster than pip and handles dependency resolution more reliably.
  3. Create a Virtual Environment and Install Packages
    • Once uv is installed, initialize a project.
    # Run this in your terminal
    uv init --python 3.12
    
    • Install all required packages using uv add.
    # Run this in your terminal
    uv add duckdb tabulate pandas pyarrow requests \
           pydantic python-dotenv \
           jsonlines fireworks-ai \
           mcp-server-motherduck ipykernel
    
    • Create and activate a virtual environment
    # Run this in your terminal
    uv sync
    source .venv/bin/activate
    
After running these commands, your environment is ready. You can proceed with the tutorial.
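As an optional sanity check (assuming you created the .env file in step 1), confirm that Python can see your API key before moving on:

# Run inside the activated virtual environment
import os
from dotenv import load_dotenv

load_dotenv()  # reads FIREWORKS_API_KEY from the .env file in the current directory
assert os.getenv("FIREWORKS_API_KEY"), "FIREWORKS_API_KEY not found - check your .env file"
print("Fireworks API key loaded.")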

1. 🗄️ Simulate the “Production” Database

First, we’ll create a database that represents your real, populated production database. We’ll download the public OpenFlights dataset and load it into a DuckDB file.

What is DuckDB?

DuckDB is an in-process SQL OLAP database management system. Think of it as “SQLite for analytics”. It’s perfect for this tutorial because:
  • It’s embedded (no server setup required)
  • It’s fast for analytical queries
  • It has excellent SQL compatibility
  • The entire database is just a single file
  • It has an existing MCP server we can use (mcp-server-motherduck)
Real World 🌍: You already have this! It’s your live production database (or a replica). You would skip this entire step.
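If DuckDB is new to you, here is a tiny, self-contained taste of the embedded workflow (in-memory here; passing a file path, as the cell below does, persists everything to a single file):

import duckdb

# No server process and no setup: connect() with no arguments gives an in-memory database.
con = duckdb.connect()
con.execute("CREATE TABLE demo (iata VARCHAR, city VARCHAR)")
con.execute("INSERT INTO demo VALUES ('CDG', 'Paris'), ('JFK', 'New York')")
print(con.sql("SELECT city FROM demo WHERE iata = 'CDG'").fetchall())  # [('Paris',)]
con.close()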
import urllib.request
import pathlib
import pandas as pd
import duckdb

# --- Download the raw data files ---
DATA_DIR = pathlib.Path("data")
DATA_DIR.mkdir(exist_ok=True)
BASE_URL = "https://raw.githubusercontent.com/jpatokal/openflights/master/data/"
FILES_TO_DOWNLOAD = {
    "airports": "airports.dat",
    "airlines": "airlines.dat",
    "routes": "routes.dat",
    "countries": "countries.dat",
    "planes": "planes.dat"
}
# Define column names as the files don't have headers
COLUMN_NAMES = {
    "airports": ["airport_id", "name", "city", "country", "iata", "icao", "latitude", "longitude", "altitude", "timezone", "dst", "tz_db", "type", "source"],
    "airlines": ["airline_id", "name", "alias", "iata", "icao", "callsign", "country", "active"],
    "routes": ["airline", "airline_id", "source_airport", "source_airport_id", "destination_airport", "destination_airport_id", "codeshare", "stops", "equipment"],
    "countries": ["name", "iso_code", "dafif_code"],
    "planes": ["name", "iata", "icao"]
}

PROD_DB_PATH = "data/prod_openflights.db"

# --- Load the real data into our "production" DuckDB ---
with duckdb.connect(PROD_DB_PATH) as con:
    for name, filename in FILES_TO_DOWNLOAD.items():
        url = f"{BASE_URL}{filename}"
        path = DATA_DIR / filename
        if not path.exists():
            urllib.request.urlretrieve(url, path)
            print(f"βœ… Downloaded: {path}")

        # Load data using pandas to handle missing headers and null values
        df = pd.read_csv(path, header=None, names=COLUMN_NAMES[name], na_values=["\\N"])
        con.execute(f"CREATE OR REPLACE TABLE {name} AS SELECT * FROM df")

    print(f"\nβœ… 'Production' database simulated at: {PROD_DB_PATH}")
    print("Tables created:", con.sql("SHOW TABLES;").fetchall())
✅ 'Production' database simulated at: data/prod_openflights.db
Tables created: [('airlines',), ('airports',), ('countries',), ('planes',), ('routes',)]

2. 📋 Acquire the Schema (No Real Data!)

This is a critical step. We connect to our “production” database and extract only its schema (the table structure, column names, and data types). We do not touch or read any of the data rows. This schema is the only artifact we need from the production environment.

Why Schema-Only?

This approach is powerful because:
  • Privacy: No actual customer data leaves your production environment
  • Security: No risk of exposing sensitive data during fine-tuning
  • Efficiency: Schema information is tiny compared to actual data
The DESCRIBE command in DuckDB gives us comprehensive schema information without accessing any rows.
Real World 🌍: You would connect to your production database and run the DESCRIBE command shown below, thus obtaining the schema information for all its tables.
import duckdb

# Connect to the "production" database we just created
with duckdb.connect(PROD_DB_PATH, read_only=True) as con:
    # The DESCRIBE command gives us the schema information for all tables
    schema_df = con.sql("DESCRIBE;").df()

print("βœ… Schema successfully extracted from 'production' database:")
print(schema_df.to_markdown(index=False))

# We can also store this for later use in prompts
schema_for_prompt = schema_df.to_markdown(index=False)
✅ Schema successfully extracted from 'production' database:

| database | schema | name | column_names | column_types | temporary |
| --- | --- | --- | --- | --- | --- |
| prod_openflights | main | airlines | ['airline_id' 'name' 'alias' 'iata' 'icao' 'callsign' 'country' 'active'] | ['BIGINT' 'VARCHAR' 'VARCHAR' 'VARCHAR' 'VARCHAR' 'VARCHAR' 'VARCHAR' 'VARCHAR'] | False |
| prod_openflights | main | airports | ['airport_id' 'name' 'city' 'country' 'iata' 'icao' 'latitude' 'longitude' 'altitude' 'timezone' 'dst' 'tz_db' 'type' 'source'] | ['BIGINT' 'VARCHAR' 'VARCHAR' 'VARCHAR' 'VARCHAR' 'VARCHAR' 'DOUBLE' 'DOUBLE' 'BIGINT' 'DOUBLE' 'VARCHAR' 'VARCHAR' 'VARCHAR' 'VARCHAR'] | False |
| prod_openflights | main | countries | ['name' 'iso_code' 'dafif_code'] | ['VARCHAR' 'VARCHAR' 'VARCHAR'] | False |
| prod_openflights | main | planes | ['name' 'iata' 'icao'] | ['VARCHAR' 'VARCHAR' 'VARCHAR'] | False |
| prod_openflights | main | routes | ['airline' 'airline_id' 'source_airport' 'source_airport_id' 'destination_airport' 'destination_airport_id' 'codeshare' 'stops' 'equipment'] | ['VARCHAR' 'DOUBLE' 'VARCHAR' 'DOUBLE' 'VARCHAR' 'DOUBLE' 'VARCHAR' 'BIGINT' 'VARCHAR'] | False |

3. 🧪 Create the Synthetic Training Sandbox with an LLM

Now that we have the schema, we will use a large language model to generate a complete, contextually-aware synthetic dataset.

Key Concepts in This Step:

Dynamic Pydantic Model Generation: We dynamically create Pydantic models based on your database schema. This ensures the LLM’s output is structured and parseable, adapting to any database schema automatically.
Chunked Generation Strategy: Instead of asking for all data at once (which could overwhelm the LLM or hit token limits), we generate data in small chunks of 2 rows per API call. This approach:
  • Ensures high-quality, coherent data
  • Avoids token limit issues
Contextual Awareness: Each generation request includes previously generated data as context, preventing duplicates and ensuring variety. To fine-tune our model with RFT, we will only interact with this synthetic database.
Real World 🌍: This pattern is directly applicable. You would use the same approach with your production schema to generate synthetic data that maintains the structure and relationships of your real data without exposing any actual records.
import pandas as pd
import os
from pydantic import create_model, BaseModel
from fireworks import LLM
import duckdb
import json
from dotenv import load_dotenv
from typing import List, Optional, Any, Dict, Type
import datetime
import decimal
import uuid
import math
import time


TARGET_ROW_COUNT = 100  # The number of rows to generate for each table.

# --- 1. Dynamically Create Pydantic Models from the SQL Schema ---
def map_sql_type_to_python(sql_type: str) -> Type:
    """Maps SQL data types to Python types for Pydantic models."""
    sql_type_upper = str(sql_type).upper()
    if 'DECIMAL' in sql_type_upper: return decimal.Decimal
    if 'DOUBLE' in sql_type_upper or 'FLOAT' in sql_type_upper or 'REAL' in sql_type_upper: return float
    if 'BIGINT' in sql_type_upper or 'INT' in sql_type_upper: return int
    if 'VARCHAR' in sql_type_upper or 'TEXT' in sql_type_upper or 'STRING' in sql_type_upper: return str
    if 'TIMESTAMP' in sql_type_upper: return datetime.datetime
    if 'DATE' in sql_type_upper: return datetime.date
    if 'TIME' in sql_type_upper: return datetime.time
    if 'BOOLEAN' in sql_type_upper: return bool
    if 'BLOB' in sql_type_upper or 'BYTEA' in sql_type_upper: return bytes
    if 'UUID' in sql_type_upper: return uuid.UUID
    return object

pydantic_models: Dict[str, Type[BaseModel]] = {}
table_names = schema_df['name'].unique()

for table_name in table_names:
    table_schema = schema_df[schema_df['name'] == table_name].iloc[0]
    fields: Dict[str, Any] = {}
    col_names = table_schema['column_names']
    col_types = table_schema['column_types']
    for i, col_name in enumerate(col_names):
        python_type = map_sql_type_to_python(col_types[i])
        fields[col_name] = (Optional[python_type], None)
    model_name = table_name.capitalize() + "Model"
    pydantic_models[table_name] = create_model(model_name, **fields)

dataset_fields: Dict[str, Any] = {
    table_name: (List[model], ...) for table_name, model in pydantic_models.items()
}
SyntheticDataset = create_model('SyntheticDataset', **dataset_fields)
print("βœ… Dynamically created Pydantic models for all tables.")


# --- 2. Define Total Row Counts and Chunking Strategy ---
TOTAL_ROW_COUNTS = {name: TARGET_ROW_COUNT for name in table_names}
ROWS_PER_API_CALL = 2 # Ask for data in small, safe chunks
print("\nβœ… Data Generation Plan:")
print(f" - Target rows per table: {list(TOTAL_ROW_COUNTS.values())[0]}")
print(f" - Will make API calls asking for {ROWS_PER_API_CALL} rows/call until target is met.")


# --- 3. Setup LLM and Loop to Generate Data in Chunks ---
SYNTHETIC_DB_PATH = "data/synthetic_openflights.db"
load_dotenv()
llm = LLM(model="accounts/fireworks/models/deepseek-v3", deployment_type="serverless", api_key=os.getenv("FIREWORKS_API_KEY"))

all_synthetic_data: Dict[str, List[Dict]] = {name: [] for name in table_names}
chunk_row_counts = {name: ROWS_PER_API_CALL for name in table_names}

base_generation_prompt = f"""
You are a highly intelligent AI data generator. Your task is to create a realistic, synthetic dataset based on the provided database schema.
The data you generate must be internally consistent. For example, an `airline_id` in a `routes` table must correspond to an existing `airline_id` in an `airlines` table within this same generated chunk.
This applies to any schema you might be working with, not just airline-related data.
You must generate a single JSON object that strictly adheres to the provided JSON schema.

The database schema is as follows:
{schema_for_prompt}
"""

call_count = 0
# Loop until all tables have at least the desired number of rows
while not all(len(rows) >= TOTAL_ROW_COUNTS[name] for name, rows in all_synthetic_data.items()):
    call_count += 1
    print(f"\nπŸ“ž --- Generating data chunk #{call_count} ---")
    
    # --- Create a summary of existing data to guide the LLM ---
    existing_data_summary = ""
    if any(len(rows) > 0 for rows in all_synthetic_data.values()):
        summary_parts = ["\nYou have already generated the following data. Do NOT generate rows that are substantially similar to these examples. Create new, unique data.\n"]
        for table_name, rows in all_synthetic_data.items():
            if rows:
                summary_parts.append(f"\n--- Existing data in '{table_name}' table ---")
                sample_rows = rows[-100:] if len(rows) > 100 else rows  # sample the last 100 rows
                df = pd.DataFrame(sample_rows)
                if len(df.columns) > 10:
                    df = df.iloc[:, :10]
                markdown_summary = df.to_markdown(index=False, tablefmt="grid")
                if markdown_summary:
                    summary_parts.append(markdown_summary)
        existing_data_summary = "\n".join(summary_parts)


    # --- Construct the final prompt for this iteration ---
    final_prompt = (
        base_generation_prompt +
        existing_data_summary +
        f"\n\nNow, generate a NEW JSON object with a key for each table. The number of new rows for each table should be:\n" +
        json.dumps(chunk_row_counts, indent=2)
    )

    response = llm.chat.completions.create(
        messages=[{"role": "user", "content": final_prompt}],
        response_format={"type": "json_schema", "json_schema": {"name": "SyntheticDataset", "schema": SyntheticDataset.model_json_schema()}},
        temperature=0.7
    )

    choice = response.choices[0]
    response_content = choice.message.content

    if choice.finish_reason == "length":
        print(f"⚠️ WARNING: Chunk #{call_count} was truncated. Skipping.")
        continue
    if not response_content:
        print(f"⚠️ WARNING: Received empty content for chunk #{call_count}. Skipping.")
        continue

    try:
        chunk_data = json.loads(response_content)
        print(f"βœ… Received and parsed chunk #{call_count}.")
        for table_name, rows in chunk_data.items():
            if table_name in all_synthetic_data and rows:
                all_synthetic_data[table_name].extend(rows)
        # Log progress
        for name, rows in all_synthetic_data.items():
             print(f"   - '{name}': {len(rows)} / {TOTAL_ROW_COUNTS[name]} rows")
    except json.JSONDecodeError as e:
        print(f"❌ ERROR: Failed to parse JSON for chunk #{call_count}. Reason: {e}. Skipping.")
    
    time.sleep(1)

# --- 4. Deduplicate and Write to DB ---
print("\n✨ Data generation complete. Aggregating, deduplicating, and saving to database...")

synthetic_data = all_synthetic_data
print("\n--- Deduplicating generated data ---")
for table_name, rows in synthetic_data.items():
    if not rows: continue
    initial_count = len(rows)
    df = pd.DataFrame(rows).drop_duplicates()
    final_count = len(df)
    synthetic_data[table_name] = df.to_dict('records')
    print(f" - Table '{table_name}': Removed {initial_count - final_count} duplicates ({initial_count} -> {final_count}).")

# Final trim to ensure exact counts
for table_name, total_rows_needed in TOTAL_ROW_COUNTS.items():
    if table_name in synthetic_data:
        synthetic_data[table_name] = synthetic_data[table_name][:total_rows_needed]

with duckdb.connect(SYNTHETIC_DB_PATH) as con:
    for table_name, rows in synthetic_data.items():
        if rows:
            df = pd.DataFrame(rows)
            schema_cols = schema_df[schema_df['name'] == table_name].iloc[0]['column_names']
            for col in schema_cols:
                if col not in df.columns: df[col] = None
            df = df[schema_cols]
            con.execute(f"CREATE OR REPLACE TABLE {table_name} AS SELECT * FROM df")
    
    print(f"\nβœ… Synthetic training sandbox created at: {SYNTHETIC_DB_PATH}")
    print("Tables created:", con.sql("SHOW TABLES;").fetchall())

4. ✅ Validate the Sandbox

Let’s run a few queries against our new synthetic database to ensure the LLM did a good job generating plausible, interconnected data. We expect to see non-empty, realistic-looking data that follows the schema constraints.
import duckdb
from tabulate import tabulate

SYNTHETIC_DB_PATH = "data/synthetic_openflights.db"

# Connect to the synthetic database
with duckdb.connect(SYNTHETIC_DB_PATH, read_only=True) as con:
    
    # Get the list of all tables created
    all_tables = [table[0] for table in con.sql("SHOW TABLES;").fetchall()]
    
    # Select the first 3 tables to display (or all if fewer than 3)
    tables_to_validate = all_tables[:3]

    print("--- Validating the first few tables in the synthetic sandbox ---\n")

    # Execute and print results for the selected tables
    for table_name in tables_to_validate:
        print(f"--- SELECT * FROM {table_name} LIMIT 3; ---")
        try:
            result_df = con.sql(f"SELECT * FROM {table_name} LIMIT 3;").df()
            if not result_df.empty:
                print(tabulate(result_df, headers='keys', tablefmt='psql'))
            else:
                print(f"(Table '{table_name}' is empty)")
        except Exception as e:
            print(f"Query failed for table '{table_name}': {e}")
        print("\n")
--- Validating the first few tables in the synthetic sandbox ---

--- SELECT * FROM airlines LIMIT 3; ---

| airline_id | name | alias | iata | icao | callsign | country | active |
| --- | --- | --- | --- | --- | --- | --- | --- |
| 58 | Nordic Eagle Airlines | NEA | NE | NEA | NORDIC EAGLE | Finland | Y |
| 70 | Sapphire Sky Airlines | SSA | SS | SSA | SAPPHIRESKY | South Africa | Y |
| 86 | Polar Air | PA | PL | POL | POLARAIR | Malaysia | Y |

--- SELECT * FROM airports LIMIT 3; ---

| airport_id | name | city | country | iata | icao | latitude | longitude | altitude | timezone | dst | tz_db | type | source |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 17 | Rainbow Paris Airport | Paris | France | RPA | RPA | 48.8566 | 2.3522 | 35 | 1 | E | Europe/Paris | airport | OurAirports |
| 32 | Orbit Paris Airport | Paris | France | ORP | ORPA | 48.8566 | 2.3522 | 35 | 1 | E | Europe/Paris | airport | OurAirports |
| 77 | Red Star Moscow Airport | Moscow | Russia | RSM | RSMA | 55.7558 | 37.6173 | 15 | | | | | |

--- SELECT * FROM countries LIMIT 3; ---

| name | iso_code | dafif_code |
| --- | --- | --- |
| Norway | NO | NOR |
| Italy | IT | ITA |
| Saudi Arabia | SA | SAU |

5. 📝 Generate Example SQL Queries

With our synthetic database in place, the next step is to create a set of synthetic SQL queries. These SQL queries will be executed against our database of synthetic data to get the ground truth labels for RFT. Furthermore, these same SQL queries will be used as input to an LLM to generate queries in natural language. This will enable us to form our final RFT dataset, which pairs natural language queries with ground truth results from the database.

Query Generation Strategy:

  • Diversity: We want queries covering different SQL features (JOINs, GROUP BY, aggregates)
  • Complexity Range: From simple SELECT statements to complex multi-table joins
  • Deterministic Results: Queries include ORDER BY clauses where necessary to break ties and ensure consistent results
  • MECE Principle: Mutually Exclusive, Collectively Exhaustive - covering all major query patterns
Real World 🌍: You would use a historical log of real SQL queries that have been run against your production database; aim for ~500 unique SQL queries. These logs are the most valuable source of training data because they represent the actual way your users query your data.
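To make the strategy concrete, here are a few hand-written queries in the spirit of what we want the LLM to generate (illustrative only; they are not taken from generated_queries.json):

# Illustrative examples spanning simple filters, aggregates, and joins.
# Each includes an ORDER BY so repeated executions return rows in the same order.
example_queries = [
    "SELECT name, country FROM airlines WHERE active = 'Y' ORDER BY name",
    "SELECT country, COUNT(*) AS airport_count FROM airports GROUP BY country ORDER BY airport_count DESC, country",
    """SELECT a.name AS airline_name, COUNT(*) AS route_count
       FROM routes r JOIN airlines a ON r.airline_id = a.airline_id
       GROUP BY a.name
       ORDER BY route_count DESC, airline_name""",
]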
import pandas as pd
import json
import time
from pydantic import BaseModel, Field
from typing import List
from fireworks import LLM
import os
import duckdb
from dotenv import load_dotenv

load_dotenv()

# --- 1. Define Generation Parameters and Pydantic Model ---
llm = LLM(model="accounts/fireworks/models/qwen3-coder-480b-a35b-instruct", deployment_type="serverless", api_key=os.getenv("FIREWORKS_API_KEY"))  # Use Qwen3-coder for SQL queries
TOTAL_QUERIES_TO_GENERATE = 1000  # Note, some of these queries will likely be duplicates or invalid, reducing the final number used for fine-tuning
QUERIES_PER_API_CALL = 30

class SqlQueryBatch(BaseModel):
    queries: List[str] = Field(description=f"A list of exactly {QUERIES_PER_API_CALL} unique and diverse SQL queries.")

print(f"🎯 Goal: Generate {TOTAL_QUERIES_TO_GENERATE} unique queries in batches of {QUERIES_PER_API_CALL}.")

# --- 2. Get Clean Schema From Synthetic DB ---
with duckdb.connect(SYNTHETIC_DB_PATH, read_only=True) as con:
    schema_df = con.sql("DESCRIBE;").df()
    schema_for_prompt = schema_df.to_markdown(index=False)

# --- 3. Setup Base Prompt and Generation Loop ---
base_query_generation_prompt = f"""
You are an expert SQL data analyst. Your task is to generate unique and diverse SQL queries based on the database schema provided.
The queries should be realistic and cover a range of complexities and SQL features (JOINS, GROUP BY, aggregates, etc.).
Ensure you break ties with ORDER BY clauses so that the same queries produce the same results when executed against the database.
Write only the raw SQL query text and nothing else (i.e., no markdown formatting); your output should be a directly executable valid SQL query.
Make sure your queries do not return duplicate rows (i.e., GROUP BY all columns that are not aggregate functions).
Ensure the generated SQL is valid for DuckDB.

**Database Schema:**
{schema_for_prompt}
""".strip()

all_generated_queries = []
# Loop until we have enough queries
while len(all_generated_queries) < TOTAL_QUERIES_TO_GENERATE:
    print(f"\nπŸ“ž --- Generating batch #{len(all_generated_queries) // QUERIES_PER_API_CALL + 1} ---")

    # Create a summary of queries generated so far to prevent duplicates
    existing_queries_summary = ""
    if all_generated_queries:
        summary_parts = ["\nYou have already generated the following queries:\n"]
        for i, q in enumerate(all_generated_queries):
            summary_parts.append(f"{i+1}. {q}")
        existing_queries_summary = "\n".join(summary_parts)

    # Construct the final prompt for this iteration
    final_prompt = (
        base_query_generation_prompt +
        existing_queries_summary +
        f"\n\nNow, generate {QUERIES_PER_API_CALL} new, unique SQL queries, which cover different analytic scenarios and are not already in the list above. Return your response as a single JSON object adhering to the specified schema."
    )

    response = llm.chat.completions.create(
        messages=[{"role": "user", "content": final_prompt}],
        response_format={"type": "json_schema", "json_schema": {"name": "SqlQueryBatch", "schema": SqlQueryBatch.model_json_schema()}},
        temperature=0.8
    )

    response_content = response.choices[0].message.content
    if response_content:
        try:
            new_queries = json.loads(response_content).get("queries", [])
            all_generated_queries.extend(new_queries)
            print(f"   - Received {len(new_queries)} new queries. Total now: {len(all_generated_queries)} / {TOTAL_QUERIES_TO_GENERATE}")
        except json.JSONDecodeError as e:
            print(f"❌ ERROR: Failed to parse generated queries in this batch: {e}")
    
    time.sleep(1) # Be nice to the API

# --- 4. Deduplicate, Trim, and Save --- 
print("\n✨ Generation complete. Deduplicating and saving...")
initial_count = len(all_generated_queries)
# Simple, fast deduplication preserving order
unique_queries = list(dict.fromkeys(all_generated_queries))
final_count = len(unique_queries)
print(f" - Removed {initial_count - final_count} duplicates ({initial_count} -> {final_count}).")

# Trim to the exact number we need
final_queries = unique_queries[:TOTAL_QUERIES_TO_GENERATE]

# Save the final list to a file
QUERIES_FILE_PATH = "data/generated_queries.json"
with open(QUERIES_FILE_PATH, 'w') as f:
    json.dump({"queries": final_queries}, f, indent=2)

print(f"\nβœ… Successfully saved {len(final_queries)} unique queries to `{QUERIES_FILE_PATH}`.")
print("\n--- Here are a few examples: ---")
for query in final_queries[:5]:
    print(f"- {query}")
# Check the proportion of generated SQL queries that return zero rows

import json
import duckdb

try:
    SYNTHETIC_DB_PATH
except NameError:
    SYNTHETIC_DB_PATH = "data/synthetic_openflights.db"

try:
    QUERIES_FILE_PATH
except NameError:
    QUERIES_FILE_PATH = "data/generated_queries.json"

with open(QUERIES_FILE_PATH, "r") as f:
    queries = json.load(f).get("queries", [])

total = len(queries)
zero_rows = 0
success = 0
failed = 0

with duckdb.connect(SYNTHETIC_DB_PATH, read_only=True) as con:
    for q in queries:
        try:
            cnt = con.sql(f"SELECT COUNT(*) AS c FROM ({q}) AS t").fetchone()[0]
            zero_rows += (cnt == 0)
            success += 1
        except Exception:
            failed += 1

print(f"Total queries: {total}")
print(f"Executed successfully: {success}")
print(f"Execution errors: {failed}")
print(f"Zero-row results: {zero_rows}")
if total > 0:
    print(f"Proportion zero-row (overall): {zero_rows/total:.3f}")
if success > 0:
    print(f"Proportion zero-row (successful only): {zero_rows/success:.3f}")
Total queries: 549
Executed successfully: 516
Execution errors: 33
Zero-row results: 211
Proportion zero-row (overall): 0.384
Proportion zero-row (successful only): 0.409

6. ♻️ Query-Aware Augmentation of the Synthetic Sandbox

To reduce empty results when executing our SQL queries, we augment the synthetic data to be “query-aware.” We identify queries that return zero rows and generate minimal, natural-looking data that satisfies their conditions.

Why this matters

  • Higher coverage: More queries produce non-empty results, improving label quality for RFT
  • Minimal changes: Only adds the data needed to satisfy queries
  • Natural data: Generated rows look realistic and maintain referential integrity

How it works

  1. Execute all queries and identify those returning zero rows
  2. Process zero-result queries in batches of 10, grouped by involved tables
  3. Use the LLM to generate 1-2 new rows per table that satisfy the query conditions
  4. Insert rows, check which queries were fixed, and repeat until ≤10% return zero results
  5. Remove any duplicate rows
The process tracks which queries have been attempted to avoid redundant processing and can retry stubborn queries up to 2 times.
Real World 🌍: Run the cell below against your synthetic data and real queries. The code is domain-agnostic and will work with any SQL database schema.
import os, json, re, time
import pandas as pd
import duckdb
from typing import List, Optional, Any, Dict, Type, Set
from pydantic import BaseModel, create_model
from fireworks import LLM
import datetime, decimal, uuid
from collections import defaultdict
from dotenv import load_dotenv

# --- Config ---
SYNTHETIC_DB_PATH = "data/synthetic_openflights.db"
QUERIES_FILE_PATH = "data/generated_queries.json"

# --- Setup LLM ---
load_dotenv()
llm = LLM(
    model="accounts/fireworks/models/qwen3-coder-480b-a35b-instruct",
    deployment_type="serverless",
    api_key=os.getenv("FIREWORKS_API_KEY"),
)

# --- Schema helpers ---
def map_sql_type_to_python(sql_type: str) -> Type:
    s = str(sql_type).upper()
    if "DECIMAL" in s: return decimal.Decimal
    if any(k in s for k in ("DOUBLE","FLOAT","REAL")): return float
    if any(k in s for k in ("BIGINT","INT")): return int
    if any(k in s for k in ("VARCHAR","TEXT","STRING")): return str
    if "TIMESTAMP" in s: return datetime.datetime
    if "DATE" in s: return datetime.date
    if "TIME" in s: return datetime.time
    if "BOOLEAN" in s: return bool
    if any(k in s for k in ("BLOB","BYTEA")): return bytes
    if "UUID" in s: return uuid.UUID
    return object

with duckdb.connect(SYNTHETIC_DB_PATH, read_only=True) as con_ro:
    schema_df = con_ro.sql("DESCRIBE;").df()
schema_for_prompt = schema_df.to_markdown(index=False)

def extract_tables_from_query(sql: str) -> Set[str]:
    """Extract table names from SQL query."""
    # Remove comments
    sql = re.sub(r'--.*$', '', sql, flags=re.MULTILINE)
    sql = re.sub(r'/\*.*?\*/', '', sql, flags=re.DOTALL)
    
    tables = set()
    # Match FROM and JOIN clauses
    patterns = [
        r'(?:FROM|JOIN)\s+([a-zA-Z_][a-zA-Z0-9_]*)',
        r'(?:FROM|JOIN)\s+"([^"]+)"',
        r'(?:FROM|JOIN)\s+`([^`]+)`',
    ]
    
    for pattern in patterns:
        matches = re.findall(pattern, sql, re.IGNORECASE)
        tables.update(matches)
    
    # Filter out SQL keywords that might be captured
    sql_keywords = {'select', 'where', 'group', 'order', 'having', 'limit', 'as', 'on', 'and', 'or', 'not', 'in', 'exists'}
    tables = {t for t in tables if t.lower() not in sql_keywords}
    
    return tables

def table_columns(name: str) -> List[str]:
    row = schema_df[schema_df["name"] == name]
    if row.empty:
        return []
    return list(row.iloc[0]["column_names"])

def table_types(name: str) -> List[str]:
    row = schema_df[schema_df["name"] == name]
    if row.empty:
        return []
    return list(row.iloc[0]["column_types"])

def build_rows_payload_model(tables: List[str]) -> Type[BaseModel]:
    per_table_row_models: Dict[str, Type[BaseModel]] = {}
    for t in tables:
        cols = table_columns(t)
        types = table_types(t)
        if not cols:
            continue
        fields: Dict[str, Any] = {}
        for c, ty in zip(cols, types):
            fields[c] = (Optional[map_sql_type_to_python(ty)], None)
        per_table_row_models[t] = create_model(f"{t.capitalize()}Row", **fields)
    
    payload_fields: Dict[str, Any] = {}
    for t, row_model in per_table_row_models.items():
        payload_fields[t] = (List[row_model], [])
    
    if not payload_fields:
        return create_model("RowsPayloadFallback", rows=(List[dict], []))
    return create_model("RowsPayload", **payload_fields)

def count_rows(con, sql: str) -> int:
    try:
        return con.sql(f"SELECT COUNT(*) AS c FROM ({sql}) AS t").fetchone()[0]
    except Exception:
        return -1

def get_sample_data(con, table: str, limit: int = 3) -> List[dict]:
    """Get sample rows from a table for context."""
    try:
        df = con.sql(f'SELECT * FROM "{table}" LIMIT {limit}').df()
        return df.to_dict("records")
    except Exception:
        return []

# --- Load queries ---
with open(QUERIES_FILE_PATH, "r") as f:
    queries = json.load(f).get("queries", [])

total_q = len(queries)
print(f"Total queries: {total_q}")

# --- Parameters ---
TARGET_MAX_ZERO_PERCENT = 10
MAX_ITERATIONS = 20
BATCH_SIZE = 10  # Process up to 10 queries at once
MAX_ROWS_PER_TABLE_PER_BATCH = 2  # Generate at most 2 rows per table per batch

# Initial assessment
with duckdb.connect(SYNTHETIC_DB_PATH, read_only=True) as con_check:
    zero_indices = [i for i, q in enumerate(queries) if count_rows(con_check, q) == 0]
    initial_zero_count = len(zero_indices)
    print(f"Initial zero-result queries: {initial_zero_count}/{total_q} ({initial_zero_count/total_q*100:.1f}%)")

# Track which queries have been attempted
processed_query_indices = set()
retry_count = 0  # Track how many times we've cycled through all queries

iteration = 0
with duckdb.connect(SYNTHETIC_DB_PATH) as con:
    while iteration < MAX_ITERATIONS:
        iteration += 1
        
        # Get ALL current zero-result queries
        all_zero_indices = [i for i, q in enumerate(queries) if count_rows(con, q) == 0]
        zero_count = len(all_zero_indices)
        zero_percent = (zero_count / total_q * 100) if total_q else 0
        
        print(f"\n[Iteration {iteration}] Zero-result: {zero_count}/{total_q} ({zero_percent:.1f}%)")
        
        if zero_percent <= TARGET_MAX_ZERO_PERCENT or zero_count == 0:
            print(f"βœ… Target achieved! {zero_percent:.1f}% <= {TARGET_MAX_ZERO_PERCENT}%")
            break
        
        # Get unprocessed zero-result queries
        unprocessed_zero_indices = [i for i in all_zero_indices if i not in processed_query_indices]
        
        # If we've processed all zero-result queries, reset to try stubborn ones again
        if not unprocessed_zero_indices and all_zero_indices:
            retry_count += 1
            print(f"  All zero-result queries have been attempted. Starting retry cycle #{retry_count}")
            processed_query_indices.clear()
            unprocessed_zero_indices = all_zero_indices
            
            # If we've done multiple retry cycles with no progress, stop
            if retry_count > 2:
                print(f"  Stopping after {retry_count} retry cycles")
                break
        
        # Take a batch of unprocessed queries
        batch_indices = unprocessed_zero_indices[:BATCH_SIZE]
        if not batch_indices:
            print(f"[Iteration {iteration}] No queries to process.")
            break
            
        batch_queries = [queries[i] for i in batch_indices]
        processed_query_indices.update(batch_indices)
        
        print(f"  Processing batch: queries {batch_indices[:3]}{'...' if len(batch_indices) > 3 else ''} ({len(batch_indices)} total)")
        
        # Group queries by their involved tables for efficient processing
        query_tables_map = defaultdict(list)
        for idx, q in zip(batch_indices, batch_queries):
            tables = extract_tables_from_query(q)
            if tables:
                # Create a key from sorted table names
                key = tuple(sorted(tables))
                query_tables_map[key].append((idx, q))
        
        if not query_tables_map:
            print(f"  No tables found in batch. Moving to next batch.")
            continue
        
        total_fixed = 0
        
        # Process each group of queries that share the same tables
        for table_tuple, query_list in query_tables_map.items():
            tables = list(table_tuple)
            query_indices = [idx for idx, _ in query_list]
            query_texts = [q for _, q in query_list]
            
            print(f"    Processing {len(query_list)} queries involving tables: {tables}")
            
            # Build Pydantic model
            RowsPayload = build_rows_payload_model(tables)
            rows_schema = RowsPayload.model_json_schema()
            
            # Limit rows per table
            props = rows_schema.get("properties", {})
            for t in list(props.keys()):
                spec = props.get(t, {})
                if isinstance(spec, dict) and spec.get("type") == "array":
                    spec["maxItems"] = min(MAX_ROWS_PER_TABLE_PER_BATCH, len(query_list))
            
            # Get sample data for context
            existing_samples = {}
            for t in tables:
                samples = get_sample_data(con, t, limit=5)
                if samples:
                    existing_samples[t] = samples
            
            # Analyze query patterns to understand what's needed
            query_analysis = []
            for q in query_texts[:3]:  # Analyze first 3 queries for brevity
                query_analysis.append(f"- {q}")
            
            system_prompt = """You are an expert at generating natural, realistic database records.
Generate minimal new rows that will make the provided SQL queries return results.
The data should be diverse, realistic, and consistent with the domain implied by the schema."""
            
            user_prompt = f"""
Given this DuckDB schema and SQL queries that return zero rows, generate the minimum number
of natural, realistic rows needed to make these queries return results.

**Database Schema:**
{schema_for_prompt}

**Tables to populate:** {json.dumps(tables)}

**Example existing data (for style/format reference):**
{json.dumps(existing_samples, indent=2, default=str) if existing_samples else "No existing samples available"}

**Queries that need to return results ({len(query_texts)} total):**
{chr(10).join(query_analysis)}
{f"... and {len(query_texts) - 3} more similar queries" if len(query_texts) > 3 else ""}

**Requirements:**
1. Generate at most {MAX_ROWS_PER_TABLE_PER_BATCH} rows per table
2. Make the data satisfy the WHERE conditions and JOINs in the queries
3. Use natural, realistic values appropriate for the domain
4. Maintain referential integrity between tables
5. Avoid duplicating existing IDs - generate new unique IDs
6. Create diverse data that satisfies multiple queries if possible

Return ONLY the JSON object with the new rows. No explanations.
""".strip()
            
            try:
                resp = llm.chat.completions.create(
                    messages=[
                        {"role": "system", "content": system_prompt},
                        {"role": "user", "content": user_prompt}
                    ],
                    response_format={
                        "type": "json_schema",
                        "json_schema": {"name": "RowsPayload", "schema": rows_schema}
                    },
                    temperature=0.5,  # Some randomness for natural data
                )
                
                content = resp.choices[0].message.content
                if not content:
                    print(f"      Empty response from LLM")
                    continue
                
                payload = json.loads(content)
                
                # Insert generated rows
                rows_inserted = 0
                for t in tables:
                    rows = payload.get(t, [])
                    if not rows or not isinstance(rows, list):
                        continue
                    
                    # Convert to DataFrame
                    df = pd.DataFrame(rows[:MAX_ROWS_PER_TABLE_PER_BATCH])
                    cols = table_columns(t)
                    if not cols:
                        continue
                    
                    # Align columns
                    for c in cols:
                        if c not in df.columns:
                            df[c] = None
                    df = df[cols]
                    
                    # Insert new rows (avoiding exact duplicates)
                    try:
                        con.register("new_rows_df", df)
                        con.execute(f'INSERT INTO "{t}" SELECT * FROM new_rows_df EXCEPT SELECT * FROM "{t}"')
                        con.unregister("new_rows_df")
                        rows_inserted += len(df)
                    except Exception as e:
                        print(f"      Warning: Failed to insert into {t}: {e}")
                
                # Check how many queries were fixed
                fixed_in_group = 0
                for idx in query_indices:
                    if count_rows(con, queries[idx]) > 0:
                        fixed_in_group += 1
                
                print(f"      Inserted {rows_inserted} rows, fixed {fixed_in_group}/{len(query_list)} queries")
                total_fixed += fixed_in_group
                
            except Exception as e:
                print(f"      Error processing group: {e}")
                continue
            
            time.sleep(0.5)  # Rate limiting
        
        print(f"  [Iteration {iteration}] Total fixed in this iteration: {total_fixed} queries")
        
        # If we're in a retry cycle and made no progress, stop
        if retry_count > 0 and total_fixed == 0:
            print(f"  No progress made in retry cycle. Stopping.")
            break

# Final cleanup - remove exact duplicates
print("\n--- Final deduplication ---")
with duckdb.connect(SYNTHETIC_DB_PATH) as con:
    tables = con.sql("""
        SELECT table_name 
        FROM information_schema.tables 
        WHERE table_type = 'BASE TABLE' 
        AND table_schema = 'main'
    """).fetchall()
    
    for (table_name,) in tables:
        before = con.sql(f'SELECT COUNT(*) FROM "{table_name}"').fetchone()[0]
        con.execute(f'CREATE OR REPLACE TABLE "{table_name}" AS SELECT DISTINCT * FROM "{table_name}"')
        after = con.sql(f'SELECT COUNT(*) FROM "{table_name}"').fetchone()[0]
        if before != after:
            print(f"  {table_name}: removed {before - after} duplicate rows")

# Final report
with duckdb.connect(SYNTHETIC_DB_PATH, read_only=True) as con_read:
    final_zero_indices = [i for i, q in enumerate(queries) if count_rows(con_read, q) == 0]
    final_zero_count = len(final_zero_indices)
    final_zero_percent = (final_zero_count / total_q * 100) if total_q else 0
    
    print(f"\n{'='*60}")
    print(f"AUGMENTATION COMPLETE")
    print(f"{'='*60}")
    print(f"Initial zero-result: {initial_zero_count}/{total_q} ({initial_zero_count/total_q*100:.1f}%)")
    print(f"Final zero-result:   {final_zero_count}/{total_q} ({final_zero_percent:.1f}%)")
    print(f"Improvement:         {initial_zero_count - final_zero_count} queries now return results")
    
    if final_zero_percent <= TARGET_MAX_ZERO_PERCENT:
        print(f"\nβœ… SUCCESS: Achieved target of ≀{TARGET_MAX_ZERO_PERCENT}% zero-result queries")
    else:
        print(f"\n⚠️  Partial success. Consider running again or adjusting parameters.")
        
    # Show a few examples of remaining zero-result queries if any
    if final_zero_count > 0 and final_zero_count <= 5:
        print(f"\nRemaining zero-result queries:")
        for idx in final_zero_indices[:5]:
            print(f"  [{idx}]: {queries[idx][:100]}...")

7. 🎯 Execute Queries to Get Ground-Truth Answers

Now we will act as the “system” and run the queries we just generated against our synthetic sandbox. The output of each query is the ground-truth result. During Reinforcement Fine-Tuning, our model will be rewarded if the SQL it writes produces this exact same result.

Why RFT is a good choice for a text-to-SQL use-case

In RFT, the model explores the space of possible SQL queries during fine-tuning; the reward signal comes from comparing the result of executing the model’s output SQL queries against the ground truth expected results. This is fundamentally different from SFT, where the model learns to mimic the exact SQL syntax. With RFT:
  • Multiple SQL queries can be “correct” if they produce the same result
  • The model learns to reason about the problem rather than memorize solutions
  • Edge cases and query optimization patterns can emerge naturally
Real World 🌍: You would run your real historical queries against the synthetic database we previously created. The correctness of the data is not a concern here, as our aim is to see what a correct query would have generated, so we can compare it to our LLM’s generations during the RFT process.
import duckdb
import json
import pandas as pd

# --- 1. Define File Paths ---
SYNTHETIC_DB_PATH = "data/synthetic_openflights.db"
QUERIES_FILE_PATH = "data/generated_queries.json"
GROUND_TRUTH_FILE_PATH = "data/ground_truth_results.jsonl"

# Thresholds to drop overly large results
MAX_RESULT_ROWS = 1000          # skip if result has more than this many rows
MAX_RESULT_BYTES = 100_000      # ~0.1MB; skip if JSON payload exceeds this

# --- 2. Load Generated Queries ---
with open(QUERIES_FILE_PATH, 'r') as f:
    queries_data = json.load(f)
    queries_to_execute = queries_data.get("queries", [])

print(f"Loaded {len(queries_to_execute)} queries to execute.")

# --- 3. Execute Queries and Store Results ---
ground_truth_results = []
successful_executions = 0
failed_executions = 0
oversized_skipped = 0

print("Executing queries against the synthetic database...")
with duckdb.connect(SYNTHETIC_DB_PATH, read_only=True) as con:
    for query in queries_to_execute:
        try:
            result_df = con.sql(query).df()

            # Replace any NaN/NaT values with None, which serializes to JSON `null`
            result_df = result_df.astype(object).where(pd.notna(result_df), None)
            result_records = result_df.to_dict('records')

            # Skip examples that are "much too wide"
            if len(result_records) > MAX_RESULT_ROWS:
                oversized_skipped += 1
                continue
            # Size-based guard (in bytes)
            payload_bytes = len(json.dumps(result_records, ensure_ascii=False).encode("utf-8"))
            if payload_bytes > MAX_RESULT_BYTES:
                oversized_skipped += 1
                continue

            ground_truth_results.append({
                "query": query,
                "result": result_records
            })
            successful_executions += 1
        except Exception as e:
            print(f"⚠️  Skipping query due to execution error: {query}\n   Error: {e}\n")
            failed_executions += 1

print(f"\nExecution complete. Success: {successful_executions}, Skipped (oversized): {oversized_skipped}, Failed: {failed_executions}.")

# --- 4. Save the Ground-Truth Data ---
with open(GROUND_TRUTH_FILE_PATH, 'w') as f:
    for entry in ground_truth_results:
        f.write(json.dumps(entry) + '\n')

print(f"\nβœ… Successfully saved {len(ground_truth_results)} ground-truth results to `{GROUND_TRUTH_FILE_PATH}`.")

# --- 5. Print an Example ---
if ground_truth_results:
    print("\n--- Example ground_truth_results dataset entry ---")
    print(json.dumps(ground_truth_results[0], indent=2))

8. 💬 Generate Natural Language Questions for Final RFT Training Data

We now have pairs of (SQL Query, Ground-Truth Result). The final piece missing from our training data is the user’s input: a question in natural language. Our goal is to use RFT to tune an LLM to map a natural language question to a SQL query, with the reward signal being the actual result of executing that query rather than the query text itself. This matters because there are many ways to write a SQL query that yield the same, correct result.

Thus, the complete training loop will look like this:

  1. User asks: “Which countries have the most airlines?”
  2. Model generates: SQL query
  3. System executes: Query against database
  4. Reward calculation: Does result match ground truth?
  5. Model update: Reinforce successful strategies
We will therefore use an LLM once again to translate our “historical” SQL queries into plausible questions a business user might ask, each corresponding to one query. This yields our final training dataset in the format: (Natural Language Question, SQL Query, Ground-Truth Result). Note that the SQL queries themselves will not be used as part of the RFT job itself, but they are useful for debugging our evaluation function (more details in a later section).
Real World 🌍: You might not need this step! If you have logs that already link user questions to the queries they ran (e.g., from a BI tool’s search bar), you can use those directly. If not, this LLM-based translation is a powerful technique to bootstrap your training data.
import json
import time
import jsonlines
from typing import List
import random
from fireworks import LLM
import os

# --- 1. Define File Paths and Parameters ---
llm = LLM(model="accounts/fireworks/models/qwen3-coder-480b-a35b-instruct", deployment_type="serverless", api_key=os.getenv("FIREWORKS_API_KEY"))
GROUND_TRUTH_FILE_PATH = "data/ground_truth_results.jsonl"
FINAL_TRAINING_DATA_PATH = "data/final_rft_sql_train_data.jsonl"
FINAL_TEST_DATA_PATH = "data/final_rft_sql_test_data.jsonl"

# --- 2. Load Ground-Truth Data ---
query_result_pairs = []
with jsonlines.open(GROUND_TRUTH_FILE_PATH) as reader:
    for obj in reader:
        query_result_pairs.append(obj)

print(f"Loaded {len(query_result_pairs)} query-result pairs.")

# --- 3. Use LLM to Generate Natural Language Questions ---
nl_generation_prompt_template = """
You are an expert data analyst who is great at translating SQL queries into plain English.
Based on the database schema and the provided SQL query, what is a natural language question a business user would ask to get this information?
Ensure that the question is precise enough to accurately map to the corresponding SQL query.

**Database Schema:**
{schema_for_prompt}

**SQL Query:**
{query}

Provide only the user's question, without any preamble or explanation.
""".strip()

# The system prompt that will be included in the final training data for the RFT job.
# It gives the model its instructions at inference time.
rft_system_prompt = f"""
You are an expert SQL data analyst.
Your task is to write a single, valid DuckDB SQL query to answer the user's question, based on the provided database schema.
Write only the raw SQL query text and nothing else (i.e., no markdown formatting); your output should be a directly executable valid SQL query.
Make sure your queries do not return duplicate rows (i.e., GROUP BY all columns that are not aggregate functions).
Ensure the generated SQL is valid for DuckDB.

**Database Schema:**
{schema_for_prompt}
""".strip()

final_generated_data = []
print(f"Generating natural language questions and formatting for RFT for {len(query_result_pairs)} queries...")

for i, pair in enumerate(query_result_pairs):
    print(f" - Processing query {i+1}/{len(query_result_pairs)}...")
    query = pair['query']
    ground_truth = pair['result']
    nl_generation_prompt = nl_generation_prompt_template.format(schema_for_prompt=schema_for_prompt, query=query)
    
    response = llm.chat.completions.create(
        messages=[{"role": "user", "content": nl_generation_prompt}],
        temperature=0.5
    )
    
    nl_question = response.choices[0].message.content
    if nl_question:  # Only include the entry if the LLM generated a question
        # Assemble the final data structure
        rft_entry = {
            "messages": [
                {"role": "system", "content": rft_system_prompt},
                {"role": "user", "content": nl_question.strip()},
                {"role": "assistant", "content": query}
            ],
            "ground_truth": ground_truth  # The ground-truth result for the evaluator
        }
        final_generated_data.append(rft_entry)
    
    time.sleep(0.5) # Be nice to the API

# --- 4. Shuffle and Split the Dataset ---
print(f"\nGenerated {len(final_generated_data)} total examples.")

# Filter out entries where 'ground_truth' is an empty list
original_count = len(final_generated_data)
final_generated_data = [entry for entry in final_generated_data if entry.get("ground_truth")]
print(f"Filtered out {original_count - len(final_generated_data)} examples with empty ground truth.")

print(f"Now splitting the remaining {len(final_generated_data)} examples into train and test sets.")

random.seed(42)
random.shuffle(final_generated_data)

split_index = int(len(final_generated_data) * 0.8)
train_data = final_generated_data[:split_index]
test_data = final_generated_data[split_index:]

print(f"Train set size: {len(train_data)}")
print(f"Test set size: {len(test_data)}")

# --- 5. Save the Final RFT-Ready Datasets ---
with jsonlines.open(FINAL_TRAINING_DATA_PATH, mode='w') as writer:
    writer.write_all(train_data)
print(f"\nβœ… Successfully saved training dataset to `{FINAL_TRAINING_DATA_PATH}`.")

with jsonlines.open(FINAL_TEST_DATA_PATH, mode='w') as writer:
    writer.write_all(test_data)
print(f"βœ… Successfully saved test dataset to `{FINAL_TEST_DATA_PATH}`.")

# --- 6. Print an Example ---
if train_data:
    print("\n--- Example RFT training entry ---")
    print(json.dumps(train_data[0], indent=2))

9. 🛰️ Deploy an MCP Server for the Synthetic Data

Now, we’ll start a remote server that speaks the Model Context Protocol (MCP). This server will wrap our synthetic DuckDB database, providing a standardized way for any external tool (in our case, the Fireworks RFT evaluator) to interact with it.

What is MCP?

The Model Context Protocol is an open standard that standardizes how applications provide context to LLMs. Think of MCP like a USB-C port for AI applications. Just as USB-C provides a standardized way to connect your devices to various peripherals, MCP provides a standardized way to connect AI models to various data sources and tools. Key benefits:
  • Flexibility: Works with any data source or tool
  • Standardization: One protocol for all integrations instead of custom APIs for each tool; MCP servers for many applications are readily available
Real World 🌍: This pattern is directly applicable. You would run a similar MCP server to provide a secure, read-only interface to a production database replica or a data warehouse, allowing the fine-tuning process to happen without granting direct database credentials to the training environment.
  1. a) Create a server script in this project’s root directory (run_mcp_server.py). This Python script starts our database server. It is configured to be read-only.
    import os, contextlib, uvicorn
    from starlette.applications import Starlette
    from starlette.routing import Mount
    from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
    from mcp_server_motherduck import build_application

    DB = "data/synthetic_openflights.db"          # ← path from previous steps
    PORT = int(os.environ.get("PORT", 8080))        # Cloud Run injects $PORT

    # 1️⃣ Build the core SQL-aware MCP server (read-only for safety).
    server, _ = build_application(db_path=DB, read_only=True)

    # 2️⃣ Wrap it so HTTP clients can talk to it (ASGI handler).
    sess = StreamableHTTPSessionManager(app=server, event_store=None, stateless=True)

    async def handler(scope, receive, send):
        await sess.handle_request(scope, receive, send)

    @contextlib.asynccontextmanager
    async def lifespan(app):
        async with sess.run():
            yield                                        # keep sessions alive

    # 3️⃣ Starlette turns that handler into a full ASGI app Uvicorn can serve.
    app = Starlette(routes=[Mount("/mcp", app=handler)], lifespan=lifespan)

    if __name__ == "__main__":
        print(f"πŸ”₯ MCP endpoint β†’ http://0.0.0.0:{PORT}/mcp")
        uvicorn.run(app, host="0.0.0.0", port=PORT)
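  2. b) (Optional) Smoke-test the server locally before deploying. With run_mcp_server.py running in another terminal, the short Python sketch below mirrors the tools/list check used later in step 11 e) ii; it assumes the default port 8080 and is purely a sanity check.
    import requests, json

    # Ask the locally running MCP server to list its tools (it replies over a streamable-HTTP event stream).
    payload = {
        "id": "local-test-1",
        "jsonrpc": "2.0",
        "method": "tools/list",
        "params": {"session": {"id": "local-smoke-test"}},
    }
    headers = {"Content-Type": "application/json", "Accept": "application/json, text/event-stream"}
    with requests.post("http://localhost:8080/mcp/", headers=headers, json=payload, timeout=10, stream=True) as r:
        r.raise_for_status()
        for line in r.iter_lines():
            if line and line.decode("utf-8").startswith("data:"):
                print(json.loads(line.decode("utf-8")[len("data:"):].strip()))
                break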

10. ☁️ Set Up Google Cloud CLI & .gcloudignore

We’ll first set up the Google Cloud CLI and authenticate. Google Cloud Run provides an easy way to deploy containerized applications without managing infrastructure.
Real World 🌍
You would follow along here in the same way. Cloud Run is ideal for MCP servers because it auto-scales based on demand (down to zero when not in use, thus charging only for actual usage).
  1. a) Install the SDK (macOS/Linux):
    curl -sSL https://sdk.cloud.google.com | bash
    exec -l $SHELL  # reload shell so 'gcloud' is available
    
  2. b) Log in (creates local access token):
    gcloud auth login
    
  3. c) Set your desired gcloud project as the active project:
    gcloud config set project < YOUR_PROJECT_ID >  # set up project in gcloud console before running this if not already done
    

11. πŸ“¦ Containerize & Deploy the MCP Server

We’ll build a Docker image and push it straight to Cloud Run.
Remember to replace YOUR_PROJECT_ID with the project you actually want to bill.
Real World 🌍
You would follow along in the same way here.
  1. a) Create mcp_requirements.txt containing the following:
mcp
mcp-server-motherduck
duckdb
uvicorn
starlette
  2. b) Create a Dockerfile (no extension) containing the following:
# Base image
FROM python:3.11-slim
WORKDIR /app

COPY mcp_requirements.txt .
RUN pip install --no-cache-dir -r mcp_requirements.txt

COPY run_mcp_server.py .
COPY data/synthetic_openflights.db ./data/

EXPOSE 8080

CMD ["python", "run_mcp_server.py"]
  3. c) Create a .gcloudignore file in your root dir (to only deploy files needed for MCP server) containing:
# .gcloudignore

# 1. Ignore EVERYTHING in the directory by default.
*

# 2. Now, create exceptions for ONLY the files needed by the Dockerfile.
# The "!" character means "do not ignore this file".

# The Dockerfile itself is needed for the build process.
!Dockerfile

# The files explicitly copied by your Dockerfile:
!mcp_requirements.txt
!run_mcp_server.py

# 3. To include a specific file in a subdirectory, use this
#    three-line pattern to un-ignore the directory, re-ignore its
#    contents, and then un-ignore the specific file.
!data/
data/*
!data/synthetic_openflights.db
  4. d) Deploy your MCP server as a Cloud Run app by running (from your project root):
gcloud run deploy mcp-sql-rft-server \
  --source . \
  --region < YOUR_GCP_REGION > \
  --project < YOUR_GCP_PROJECT_ID > \
  --allow-unauthenticated \
  --port 8080
Note this will create a default Docker repository called cloud-run-source-deploy; press Y to continue when prompted.
  5. e) Test that your MCP server is working as expected by running the following from your terminal:
    i. To get your MCP server’s URL:
gcloud run services describe mcp-sql-rft-server \
--project < YOUR_GCP_PROJECT_ID > \
--region < YOUR_GCP_REGION > \
--format="value(status.url)"
    ii. (optional) To check the names of the MCP server’s available tools:
curl -X POST "< YOUR_MCP_SERVER_URL_FROM_STEP_i >/mcp/" \
-H "Content-Type: application/json" \
-H "Accept: application/json, text/event-stream" \
-d '{
    "id": "list-tools-1",
    "jsonrpc": "2.0",
    "method": "tools/list",
    "params": {
        "session": {"id": "test-from-my-laptop"}
    }
}'
Note that the above is a generally useful way to check an MCP server’s tools. In this case, the tool of interest is the β€œquery” tool.
    iii. To send a test request to the MCP server:
curl -X POST "< YOUR_MCP_SERVER_URL_FROM_STEP_i >/mcp/" \
-H "Content-Type: application/json" \
-H "Accept: application/json, text/event-stream" \
-d '{
    "id": "query-1",
    "jsonrpc": "2.0",
    "method": "tools/call",
    "params": {
        "session": {"id": "test-from-my-laptop"},
        "name": "query",
        "arguments": {
            "query": "SELECT COUNT(*) FROM airlines;"
        }
    }
}'

12. πŸ” Define an evaluation function for RFT

Here, we define an evaluate function for RFT, which will interface with our MCP server. Note that you will not directly execute the function here, but will use it as part of the Fireworks Evaluations UI.

Understanding the Evaluation Function:

The evaluation function is the heart of RFT. It:
  1. Receives the model’s generated SQL query
  2. Executes it against the real database (via MCP)
  3. Compares the result with ground truth
  4. Returns a reward score (0 or 1)
This binary reward signal drives the reinforcement learning process. The model learns through trial and error which SQL patterns lead to correct results. Key design decisions:
  • Exact match comparison: We normalize values and sort rows to handle different but equivalent result orderings (see the toy example after this list)
  • Robust error handling: SQL syntax errors or execution failures return a score of 0
  • Detailed reasoning: The function returns explanatory messages for debugging
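To make the comparison rule concrete, here is a small standalone illustration (toy data, not part of the evaluator) showing that two results with different row order, key order, and int/float formatting compare as equal after normalizing values and sorting rows:

# Toy illustration of the exact-match comparison used by the evaluator below.
ground_truth_rows = [{"country": "Canada", "n": 10}, {"country": "Sweden", "n": 10}]
predicted_rows = [{"n": 10, "country": "Sweden"}, {"n": 10.0, "country": "Canada"}]

def normalize(v):
    # Treat 10 and 10.0 as the same value before comparing as strings.
    if isinstance(v, float) and v == int(v):
        v = int(v)
    return str(v)

gt = sorted(sorted(map(normalize, row.values())) for row in ground_truth_rows)
pred = sorted(sorted(map(normalize, row.values())) for row in predicted_rows)
print(gt == pred)  # True: row order, column order, and int/float differences are ignored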
Ensure that you set MCP_SERVER_URL to be your actual MCP server URL from step 11. e) i.
Real World 🌍
You would follow along in the same way here. The evaluation function could also be further customized, with, for example:
  • Partial credit for near-correct answers (a sketch follows this list)
  • Performance-based rewards (faster queries get higher scores)
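As an illustration of the first idea, a partial-credit variant could reward the fraction of expected rows that the generated query actually returned, rather than an all-or-nothing score. The helper below is a hypothetical sketch (partial_credit_score is not part of this tutorial's evaluator):

def partial_credit_score(predicted_rows: list[dict], ground_truth_rows: list[dict]) -> float:
    """Hypothetical reward: fraction of ground-truth rows that appear in the prediction."""
    def normalize_row(row: dict) -> tuple:
        return tuple(sorted(str(v) for v in row.values()))

    expected = {normalize_row(row) for row in ground_truth_rows}
    returned = {normalize_row(row) for row in predicted_rows}
    if not expected:
        return 1.0 if not returned else 0.0
    return len(expected & returned) / len(expected)

The evaluator used in this tutorial keeps the simpler binary reward and is defined below.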
import requests
import json
import math

MCP_SERVER_URL = None  # <--- PUT MCP SERVER URL HERE without the /mcp/ suffix at the end

def evaluate(messages: list[dict], ground_truth: list[dict], **kwargs) -> dict:
    """
    Evaluates the model's generated SQL query by executing it against a live
    MCP server and comparing the result with the ground_truth.
    """
    
    def parse_duckdb_ascii_table(table_string: str) -> list[dict]:
        """
        Parses a DuckDB-style ASCII table string into a list of dictionaries.
        This version robustly handles 'NULL' values and empty strings.
        """
        lines = table_string.strip().split('\n')
        content_lines = [line for line in lines if line.strip() and not line.startswith('+')]
        if len(content_lines) < 2:
            return []
        
        header_raw = [h.strip() for h in content_lines[0].split('|')[1:-1]]
        data_lines = content_lines[1:]
        
        if len(data_lines) > 0:
            try:
                first_data_values = [v.strip() for v in data_lines[0].split('|')[1:-1]]
                if len(first_data_values) == len(header_raw) and all(v.isupper() for v in first_data_values):
                    data_lines = data_lines[1:]
            except IndexError:
                pass

        rows = []
        for line in data_lines:
            try:
                values_raw = [v.strip() for v in line.split('|')[1:-1]]
                if len(values_raw) == len(header_raw):
                    row_dict = {}
                    for i, header in enumerate(header_raw):
                        value_str = values_raw[i]
                        if value_str.upper() == 'NULL' or value_str == '':
                            row_dict[header] = None
                            continue
                        
                        try:
                            if '.' in value_str:
                                row_dict[header] = float(value_str)
                            else:
                                row_dict[header] = int(value_str)
                        except (ValueError, TypeError):
                            row_dict[header] = value_str
                    rows.append(row_dict)
            except IndexError:
                continue
        return rows

    # --- 1. Check that the MCP Server URL has been set ---
    mcp_server_url = MCP_SERVER_URL
    if not mcp_server_url:
        return {"score": 0, "is_score_valid": False, "reason": "FATAL: MCP_SERVER_URL is not set."}

    # --- 2. Get the SQL query from the model's response ---
    sql_query = messages[-1]['content'].strip()
    if not sql_query:
        return {"score": 0, "reason": "Model returned an empty response."}

    # --- 3. Execute the Query against the MCP Server ---
    headers = {
        "Content-Type": "application/json",
        "Accept": "application/json, text/event-stream"
    }
    payload = {
        "id": "eval-query-1", "jsonrpc": "2.0", "method": "tools/call",
        "params": {"session": {"id": "stateless-eval-session"}, "name": "query", "arguments": {"query": sql_query}}
    }
    try:
        with requests.post(f"{mcp_server_url}/mcp/", headers=headers, json=payload, timeout=15, stream=True) as response:
            response.raise_for_status()
            response_data = None
            for line in response.iter_lines():
                if line:
                    decoded_line = line.decode('utf-8')
                    if decoded_line.startswith('data:'):
                        json_part = decoded_line[len('data:'):].strip()
                        if json_part:
                            response_data = json.loads(json_part)
                            break
            if response_data is None:
                return {"score": 0, "reason": "Could not find JSON data in event stream response from MCP server."}

        if "error" in response_data:
            return {"score": 0, "reason": f"SQL execution failed. Error: {response_data['error'].get('message', 'Unknown')}"}

        ascii_table = response_data['result']['content'][0]['text']
        predicted_rows = parse_duckdb_ascii_table(ascii_table)

    except requests.exceptions.RequestException as e:
        return {"score": 0, "reason": f"Network error calling MCP server: {e}"}
    except json.JSONDecodeError as e:
        return {"score": 0, "reason": f"JSON decode error from server response: {e}"}
    except (KeyError, IndexError):
        return {"score": 0, "reason": f"Failed to parse predicted result from MCP server response structure. Data found: {json.dumps(response_data)}"}
    except Exception as e:
        return {"score": 0, "reason": f"An unexpected error occurred during query execution: {e}"}

    # --- 4. Process Ground Truth ---
    if not isinstance(ground_truth, list):
        return {"score": 0, "is_score_valid": False, "reason": f"FATAL: ground_truth is not a list as expected. Got type: {type(ground_truth)}"}
    ground_truth_rows = ground_truth


    # --- 5. Comparison Logic ---
    def normalize_and_stringify(v):
        """
        Normalizes numbers and None before string conversion.
        """
        if v is None:
            return str(v)
        
        if isinstance(v, float) and not math.isinf(v) and not math.isnan(v) and v == int(v):
            v = int(v)
        return str(v)

    try:
        gt_values = sorted([sorted(map(normalize_and_stringify, row.values())) for row in ground_truth_rows])
        predicted_values = sorted([sorted(map(normalize_and_stringify, row.values())) for row in predicted_rows])

        if gt_values == predicted_values:
            score = 1
            reason = "Success: The SQL query produced the exact expected result."
        else:
            score = 0
            gt_json = json.dumps(ground_truth_rows)
            pred_json = json.dumps(predicted_rows)
            reason = f"Incorrect result. Expected (from ground_truth): {gt_json}. Got (from query): {pred_json}."
    
    except Exception as e:
        return {"score": 0, "reason": f"Error during result comparison: {e}"}

    return {"score": score, "reason": reason}
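
If you want to sanity-check the function before pasting it into the evaluator UI, you can run it locally once MCP_SERVER_URL is filled in. A minimal sketch, assuming the training file generated earlier exists; because the assistant message in each entry holds a known-correct query, the returned score should be 1.

import jsonlines

# Optional local smoke test: replay a known-correct training entry through evaluate().
with jsonlines.open("data/final_rft_sql_train_data.jsonl") as reader:
    entry = next(iter(reader))

result = evaluate(messages=entry["messages"], ground_truth=entry["ground_truth"])
print(result)  # Expect score 1, since the assistant message holds the ground-truth SQL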

13. πŸ§ͺ Test English -> SQL with a base model (no fine-tuning)

Here, we test a base model’s ability to generate SQL from a natural language question on a single example from our training data. This is a quick sanity check that:
  1. Verifies your MCP server is working: Ensures the server is accessible and can execute queries
  2. Tests the full pipeline: Confirms that the flow from natural language β†’ SQL generation β†’ execution β†’ result parsing works end-to-end
  3. Shows a concrete example: Demonstrates what happens when an off-the-shelf model tries to answer a question about your specific database
The test process:
  1. Load one example from your training data (by default, the first row)
  2. Feed the natural language question to a base model (e.g., Llama 3.1 8B)
  3. Execute whatever SQL the model generates against your MCP server
  4. Compare the result to the ground truth
  5. Print whether it succeeded or failed
What to expect:
  • The base model might get it right! Simple queries often work.
  • Or, you’ll see some kind of failure: wrong column names, missing aliases, incorrect syntax, etc.
  • Either outcome is fine; this is just a quick test to see the model in action before fine-tuning.
To try different examples, change ROW_INDEX_TO_TEST to test other rows from your dataset. Ensure that you set MCP_SERVER_URL to be your actual MCP server URL from step 11. e) i.
Real World 🌍
You can follow along in the same way here. This single-example test is just a quick way to verify everything is wired up correctly before launching the more expensive fine-tuning job.
import requests
import json
import os
from fireworks import LLM

# --- 1. SETUP: Define API keys, server URLs, and the model to use ---

# IMPORTANT: Make sure your FIREWORKS_API_KEY is set as an environment variable.
# You can get one from https://fireworks.ai
if "FIREWORKS_API_KEY" not in os.environ:
    print("FATAL: FIREWORKS_API_KEY environment variable not set.")
    # If not set, you can hardcode it here for testing, but this is not recommended:
    # os.environ["FIREWORKS_API_KEY"] = "YOUR_API_KEY_HERE"

# The model we'll use to generate the SQL. This acts as our "base" model.
LLM_MODEL = "accounts/fireworks/models/llama-v3p1-8b-instruct"
llm = LLM(model=LLM_MODEL, deployment_type="auto", api_key=os.getenv("FIREWORKS_API_KEY"))

# The URL for your running MCP server.
MCP_SERVER_URL = None  # PUT MCP SERVER URL HERE without the /mcp/ suffix at the end


# --- 2. LOAD THE EXAMPLE DATA ---

# Load one example row from the training dataset generated earlier (with a hardcoded fallback below).
DATASET_FILE_PATH = "data/final_rft_sql_train_data.jsonl"
ROW_INDEX_TO_TEST = 0  # 0 is the first row, 1 is the second row, etc.

EXAMPLE_DATA = None
try:
    with open(DATASET_FILE_PATH, 'r') as f:
        for i, line in enumerate(f):
            if i == ROW_INDEX_TO_TEST:
                EXAMPLE_DATA = json.loads(line)
                break
    
    if EXAMPLE_DATA is None:
        with open(DATASET_FILE_PATH, 'r') as f:
            line_count = sum(1 for line in f)
        raise IndexError(f"row index {ROW_INDEX_TO_TEST} is out of bounds for file with {line_count} rows.")

    print(f"Successfully loaded row {ROW_INDEX_TO_TEST} from '{DATASET_FILE_PATH}'.\n")
    print(EXAMPLE_DATA)
    print()

except Exception as e:
    print(f"Warning: Could not load from file. Reason: {e}")

# If loading from file failed for any reason, use the hardcoded fallback data.
if EXAMPLE_DATA is None:
    print("Using hardcoded fallback EXAMPLE_DATA.\n")
    EXAMPLE_DATA = {
        "messages": [
            {"role": "system", "content": "\nYou are an expert SQL data analyst. Your task is to write a single, valid DuckDB SQL query to answer the user's question, based on the provided database schema. Do not provide any explanation or text other than the SQL query itself.\n\n**Database Schema:**\n| database              | schema   | name      | column_names                                                               | column_types                                                          | temporary   |\n|:----------------------|:---------|:----------|:---------------------------------------------------------------------------|:----------------------------------------------------------------------|:------------|\n| synthetic_openflights | main     | airlines  | ['airline_id' 'name' 'alias' 'iata' 'icao' 'callsign' 'country' 'active']  | ['BIGINT' 'VARCHAR' 'VARCHAR' 'VARCHAR' 'VARCHAR' 'VARCHAR' 'VARCHAR' | False       |\n|                       |          |           |                                                                            |  'VARCHAR']                                                           |             |\n| synthetic_openflights | main     | airports  | ['airport_id' 'name' 'city' 'country' 'iata' 'icao' 'latitude' 'longitude' | ['BIGINT' 'VARCHAR' 'VARCHAR' 'VARCHAR' 'VARCHAR' 'VARCHAR' 'DOUBLE'  | False       |\n|                       |          |           |  'altitude' 'timezone' 'dst' 'tz_db' 'type' 'source']                      |  'DOUBLE' 'BIGINT' 'DOUBLE' 'VARCHAR' 'VARCHAR' 'VARCHAR' 'VARCHAR']  |             |\n| synthetic_openflights | main     | countries | ['name' 'iso_code' 'dafif_code']                                           | ['VARCHAR' 'VARCHAR' 'VARCHAR']                                       | False       |\n| synthetic_openflights | main     | planes    | ['name' 'iata' 'icao']                                                     | ['VARCHAR' 'VARCHAR' 'VARCHAR']                                       | False       |\n| synthetic_openflights | main     | routes    | ['airline' 'airline_id' 'source_airport' 'source_airport_id'               | ['VARCHAR' 'BIGINT' 'VARCHAR' 'BIGINT' 'VARCHAR' 'BIGINT' 'VARCHAR'   | False       |\n|                       |          |           |  'destination_airport' 'destination_airport_id' 'codeshare' 'stops'        |  'BIGINT' 'VARCHAR']                                                  |             |\n|                       |          |           |  'equipment']                                                              |                                                                       |             |\n"},
            {"role": "user", "content": "Which countries have the most airlines, and how many airlines are there in each country, listed in descending order by the number of airlines and then alphabetically by country name?"},
            {"role": "assistant", "content": "SELECT country, COUNT(*) AS airline_count FROM airlines GROUP BY country ORDER BY airline_count DESC, country ASC"}
        ],
        "ground_truth": [{"country": "Canada", "airline_count": 10}, {"country": "Sweden", "airline_count": 10}, {"country": "Kenya", "airline_count": 9}, {"country": "United States", "airline_count": 9}, {"country": "Australia", "airline_count": 8}, {"country": "Spain", "airline_count": 6}, {"country": "Italy", "airline_count": 4}, {"country": "Switzerland", "airline_count": 4}, {"country": "Finland", "airline_count": 3}, {"country": "France", "airline_count": 3}, {"country": "Mexico", "airline_count": 3}, {"country": "Costa Rica", "airline_count": 2}, {"country": "Germany", "airline_count": 2}, {"country": "Iceland", "airline_count": 2}, {"country": "Ireland", "airline_count": 2}, {"country": "Japan", "airline_count": 2}, {"country": "Norway", "airline_count": 2}, {"country": "Singapore", "airline_count": 2}, {"country": "United Kingdom", "airline_count": 2}, {"country": "Argentina", "airline_count": 1}, {"country": "Brazil", "airline_count": 1}, {"country": "China", "airline_count": 1}, {"country": "Egypt", "airline_count": 1}, {"country": "Fiji", "airline_count": 1}, {"country": "Greece", "airline_count": 1}, {"country": "India", "airline_count": 1}, {"country": "Jordan", "airline_count": 1}, {"country": "Netherlands", "airline_count": 1}, {"country": "New Zealand", "airline_count": 1}, {"country": "Portugal", "airline_count": 1}, {"country": "Saudi Arabia", "airline_count": 1}, {"country": "South Africa", "airline_count": 1}, {"country": "Thailand", "airline_count": 1}, {"country": "United Arab Emirates", "airline_count": 1}]
    }

# Extract the prompts and ground truth from the data
system_prompt = EXAMPLE_DATA["messages"][0]["content"]
user_prompt = EXAMPLE_DATA["messages"][1]["content"]
GROUND_TRUTH_ROWS = EXAMPLE_DATA["ground_truth"]

# --- 3. HELPER FUNCTION: To parse the server's ASCII table response ---

def parse_duckdb_ascii_table(table_string: str) -> list[dict]:
    """
    Parses a DuckDB-style ASCII table string into a list of dictionaries.
    This version robustly handles 'NULL' values and empty strings.
    """
    lines = table_string.strip().split('\n')
    content_lines = [line for line in lines if line.strip() and not line.startswith('+')]
    if len(content_lines) < 2:
        return []
    
    header_raw = [h.strip() for h in content_lines[0].split('|')[1:-1]]
    data_lines = content_lines[1:]
    
    if len(data_lines) > 0:
        try:
            first_data_values = [v.strip() for v in data_lines[0].split('|')[1:-1]]
            if len(first_data_values) == len(header_raw) and all(v.isupper() for v in first_data_values):
                data_lines = data_lines[1:]
        except IndexError:
            pass

    rows = []
    for line in data_lines:
        try:
            values_raw = [v.strip() for v in line.split('|')[1:-1]]
            if len(values_raw) == len(header_raw):
                row_dict = {}
                for i, header in enumerate(header_raw):
                    value_str = values_raw[i]
                    if value_str.upper() == 'NULL' or value_str == '':
                        row_dict[header] = None
                        continue
                    
                    try:
                        if '.' in value_str:
                            row_dict[header] = float(value_str)
                        else:
                            row_dict[header] = int(value_str)
                    except (ValueError, TypeError):
                        row_dict[header] = value_str
                rows.append(row_dict)
        except IndexError:
            continue
    return rows

# --- 4. GENERATE SQL QUERY USING THE LLM ---

print("="*20)
print("LLM QUERY GENERATION")
print("="*20)

model_generated_sql = ""
try:
    print(f"User prompt: {user_prompt}")
    print(f"Ground truth: {GROUND_TRUTH_ROWS}")
    print(f"Calling model '{LLM_MODEL}' to generate SQL query...")
    
    messages_for_llm = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt}
    ]
    
    response = llm.chat.completions.create(
        model=LLM_MODEL,
        messages=messages_for_llm,
        temperature=0.0  # Set to 0 for deterministic output
    )
    
    model_generated_sql = response.choices[0].message.content.strip()
    print("\nModel Generated SQL Query:")
    print(model_generated_sql)
    
except Exception as e:
    print(f"\nAN ERROR OCCURRED during LLM call: {e}")


# --- 5. EXECUTE GENERATED QUERY ON MCP SERVER ---

predicted_rows = []
if model_generated_sql:
    try:
        print("\n" + "="*20)
        print("MCP SERVER EXECUTION")
        print("="*20)
        print(f"Sending query to MCP server...")
        
        headers = {"Content-Type": "application/json", "Accept": "application/json, text/event-stream"}
        payload = {
            "id": "eval-query-1", "jsonrpc": "2.0", "method": "tools/call",
            "params": {"session": {"id": "stateless-eval-session"}, "name": "query", "arguments": {"query": model_generated_sql}}
        }

        with requests.post(f"{MCP_SERVER_URL}/mcp/", headers=headers, json=payload, timeout=20, stream=True) as response:
            response.raise_for_status()
            response_data = None
            for line in response.iter_lines():
                if line and line.decode('utf-8').startswith('data:'):
                    json_part = line.decode('utf-8')[len('data:'):].strip()
                    if json_part:
                        response_data = json.loads(json_part)
                        break
            
            if response_data is None: raise RuntimeError("No JSON data in event stream.")
            if "error" in response_data: raise RuntimeError(f"SQL Error: {response_data['error'].get('message', 'Unknown')}")

            ascii_table = response_data['result']['content'][0]['text']
            predicted_rows = parse_duckdb_ascii_table(ascii_table)
            print("\nParsed Result from Server:")
            print(json.dumps(predicted_rows, indent=2))

    except Exception as e:
        print(f"\nAN ERROR OCCURRED during MCP call: {e}")

# --- 6. FINAL COMPARISON ---
print("\n" + "="*20)
print("COMPARISON")
print("="*20)

if not predicted_rows:
    print("Skipping comparison: no rows returned from query or an error occurred.")
else:
    gt_values = sorted([sorted(map(str, row.values())) for row in GROUND_TRUTH_ROWS])
    predicted_values = sorted([sorted(map(str, row.values())) for row in predicted_rows])

    if gt_values == predicted_values:
        print("\nβœ… GOOD RESULT: The base model generated SQL that produced the correct data.\n")
    else:
        print("\n❌ BAD RESULT: The base model's SQL produced different data than expected.\n")
        print("This is often the intended outcome when testing a base model, as it highlights what fine-tuning needs to correct.")
    Successfully loaded row 0 from 'data/final_rft_sql_train_data.jsonl'.
    
    {'messages': [{'role': 'system', 'content': "You are an expert SQL data analyst.\nYour task is to write a single, valid DuckDB SQL query to answer the user's question, based on the provided database schema.\nWrite only the raw SQL query text and nothing else (i.e., no markdown formatting); your output should be a directly executable valid SQL query.\nMake sure your queries do not return duplicate rows (i.e., GROUP BY all columns that are not aggregate functions).\nEnsure the generated SQL is valid for DuckDB.\n\n**Database Schema:**\n| database              | schema   | name      | column_names                                                               | column_types                                                          | temporary   |\n|:----------------------|:---------|:----------|:---------------------------------------------------------------------------|:----------------------------------------------------------------------|:------------|\n| synthetic_openflights | main     | airlines  | ['airline_id' 'name' 'alias' 'iata' 'icao' 'callsign' 'country' 'active']  | ['BIGINT' 'VARCHAR' 'VARCHAR' 'VARCHAR' 'VARCHAR' 'VARCHAR' 'VARCHAR' | False       |\n|                       |          |           |                                                                            |  'VARCHAR']                                                           |             |\n| synthetic_openflights | main     | airports  | ['airport_id' 'name' 'city' 'country' 'iata' 'icao' 'latitude' 'longitude' | ['BIGINT' 'VARCHAR' 'VARCHAR' 'VARCHAR' 'VARCHAR' 'VARCHAR' 'DOUBLE'  | False       |\n|                       |          |           |  'altitude' 'timezone' 'dst' 'tz_db' 'type' 'source']                      |  'DOUBLE' 'BIGINT' 'DOUBLE' 'VARCHAR' 'VARCHAR' 'VARCHAR' 'VARCHAR']  |             |\n| synthetic_openflights | main     | countries | ['name' 'iso_code' 'dafif_code']                                           | ['VARCHAR' 'VARCHAR' 'VARCHAR']                                       | False       |\n| synthetic_openflights | main     | planes    | ['name' 'iata' 'icao']                                                     | ['VARCHAR' 'VARCHAR' 'VARCHAR']                                       | False       |\n| synthetic_openflights | main     | routes    | ['airline' 'airline_id' 'source_airport' 'source_airport_id'               | ['VARCHAR' 'BIGINT' 'VARCHAR' 'BIGINT' 'VARCHAR' 'BIGINT' 'VARCHAR'   | False       |\n|                       |          |           |  'destination_airport' 'destination_airport_id' 'codeshare' 'stops'        |  'BIGINT' 'VARCHAR']                                                  |             |\n|                       |          |           |  'equipment']                                                              |                                                                       |             |"}, {'role': 'user', 'content': 'Which active airlines in India have the most routes, and how many routes does each operate?'}, {'role': 'assistant', 'content': "SELECT a.name AS airline_name, COUNT(r.airline_id) AS route_count FROM airlines a JOIN routes r ON a.airline_id = r.airline_id WHERE a.active = 'Y' AND a.country = 'India' GROUP BY a.name ORDER BY route_count DESC, airline_name"}], 'ground_truth': [{'airline_name': 'Alliance Air', 'route_count': 50}, {'airline_name': 'IndiGo Airlines', 'route_count': 50}, {'airline_name': 'Air India', 'route_count': 29}, {'airline_name': 'Crimson Air', 'route_count': 1}, {'airline_name': 'SkyLark Airways', 
'route_count': 1}]}
    
    ====================
    LLM QUERY GENERATION
    ====================
    User prompt: Which active airlines in India have the most routes, and how many routes does each operate?
    Ground truth: [{'airline_name': 'Alliance Air', 'route_count': 50}, {'airline_name': 'IndiGo Airlines', 'route_count': 50}, {'airline_name': 'Air India', 'route_count': 29}, {'airline_name': 'Crimson Air', 'route_count': 1}, {'airline_name': 'SkyLark Airways', 'route_count': 1}]
    Calling model 'accounts/fireworks/models/llama-v3p1-8b-instruct' to generate SQL query...
    
    Model Generated SQL Query:
    SELECT T1.name, COUNT(T2.airline_id) FROM airlines AS T1 INNER JOIN routes AS T2 ON T1.airline_id = T2.airline_id WHERE T1.country = 'India' AND T1.active = 1 GROUP BY T1.name ORDER BY COUNT(T2.airline_id) DESC LIMIT 1
    
    ====================
    MCP SERVER EXECUTION
    ====================
    Sending query to MCP server...
    
    Parsed Result from Server:
    [
      {}
    ]
    
    ====================
    COMPARISON
    ====================
    
    ❌ BAD RESULT: The base model's SQL produced different data than expected.
    
    This is often the intended outcome when testing a base model, as it highlights what fine-tuning needs to correct.

14. πŸš€ Launch the Fine-Tuning Job & Deploy via the UI

Now we’ll use the Fireworks AI web interface to take our prepared dataset and fine-tune a model. This process uses your custom evaluate function to teach a base model how to generate SQL correctly.

RFT vs Traditional Fine-Tuning:

Traditional supervised fine-tuning (SFT) would:
  • Require thousands of examples
  • Teach the model to mimic exact SQL syntax
  • Often overfit to specific query patterns
Reinforcement fine-tuning (RFT) instead:
  • Works with just hundreds of examples
  • Rewards correct results regardless of SQL syntax
  • Discovers novel solutions through exploration
  • Generalizes better to unseen queries
Real World 🌍
This is the core of the RFT process. You’re teaching a general-purpose model a very specific and valuable new skill using a powerful, UI-driven workflow. You may follow along as described below.
As described in the Fireworks RFT documentation, the process involves uploading your data, creating an evaluator, running the job, and deploying.
14. a) Upload Your Dataset
  1. Navigate to the Datasets tab in your https://app.fireworks.ai dashboard.
  2. Click β€œCreate Dataset”.
  3. Upload your training file: data/final_rft_sql_train_data.jsonl.
  4. Give it a memorable name, like rft-sql-train-data-v1, and save it.
14. b) Create the Evaluator
  1. Navigate to the Evaluations tab in the dashboard.
  2. Click β€œCreate Evaluator”. This will open the web IDE.
  3. In the editor on the left, replace the template code with your full evaluate function from step 12 above. This function already contains the logic to connect to your MCP server and compare the results. You just need to add your MCP server URL to the MCP_SERVER_URL line.
  4. Save the evaluator with a name like rft-sql-mcp-evaluator-v1.
14. c) Launch the Fine-Tuning Job
  1. Navigate to the Fine-Tuning tab.
  2. Click β€œFine-Tune a Model” and select Reinforcement.
  3. Configure the job:
    • Model Selection: Select a model, for example qwen2p5-7b (may appear as Qwen2.5 7B).
    • Dataset: Select the rft-sql-train-data-v1 you uploaded.
    • Evaluator: Select the rft-sql-mcp-evaluator-v1 you just created.
    • Rollout: You can leave these as the default values.
    • Optional Settings: You can leave the Model Output Name blank and get the default name, or enter a name of your choosing.
  4. You can leave most other hyperparameters as their defaults, though fine-tuning for 32 epochs (i.e., setting Epochs to 32) is recommended due to the complexity of the task.
  5. Click β€œCreate Job”.
14. d) Monitor and Deploy
  1. You can monitor the progress of your job in the Fine-Tuning tab. In this example, we trained for 32 epochs and got the following plot: RFT Training Progress
  2. Once the job status is Completed, you can deploy your model. To deploy, click β€œDeploy” on the top right of your fine-tuning job’s page and then β€œDeploy LoRA Model”. Please note:
    • The Model under β€œSelect base model*” should be the one from your Reinforcement Fine-Tuning job (this should be populated automatically)
    • Speculative decoding is an advanced technique that can improve latency, but is not needed for this use-case
    • Feel free to make the other selections (Performance, Scaling, and Metadata) as needed; enabling autoscaling is recommended to reduce costs
  3. Find this new model and click the Deploy button to create an API endpoint.
14. e) Test Your New Model! Once deployed, copy your new model’s ID and paste it into the LLM_MODEL variable in the testing cell (step #13) to make sure it works as expected, along with your MCP server URL (i.e., LLM_MODEL = "accounts/<your-account-id>/models/<your-model-id>" and MCP_SERVER_URL = "<your-mcp-server-url>").

15. βš–οΈ Evaluate Model Performance

Now for the moment of truth. We will systematically compare the performance of the original base model against our newly fine-tuned model, as well as a much larger base model, to quantify the improvement and general accuracy. We’ll run all three models against every entry in our test dataset (final_rft_sql_test_data.jsonl). For each entry, we will:
  1. Provide the same system and user prompt to each of the three models (base, large base, and fine-tuned).
  2. Capture the SQL query generated by each.
  3. Execute each query against our live MCP server.
  4. Compare the query result to the ground_truth from our dataset.
  5. Keep a running score for each model.
This process will give us a clear, data-driven view of how much more accurate our model became after reinforcement fine-tuning.
Real World 🌍 This is a critical step in any MLOps loop. Evaluating a model on a consistent test set is the only way to prove that your efforts have resulted in a tangible improvement. In production, you’d also want to:
  • Track latency and cost metrics (see the sketch after this list)
  • Monitor for drift over time
  • A/B test against your current solution
  • Collect user feedback on query quality
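For example, latency tracking can be layered onto the evaluation loop below with a simple timer. This is a hypothetical sketch (the timed_call helper and the latencies dict are not part of the tutorial code):

import time

def timed_call(fn, *args, **kwargs):
    """Hypothetical helper: run any callable and return (result, elapsed_seconds)."""
    start = time.perf_counter()
    result = fn(*args, **kwargs)
    return result, time.perf_counter() - start

# Example usage inside the loop below, keeping per-model latencies alongside the scores:
# score, seconds = timed_call(get_sql_and_evaluate, base_model_llm, system_prompt, user_prompt, ground_truth)
# latencies.setdefault(BASE_MODEL_ID, []).append(seconds)

The evaluation script itself follows.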
import requests
import json
import os
import time
from fireworks import LLM
from tqdm.auto import tqdm
from dotenv import load_dotenv

load_dotenv()

# --- 1. SETUP: Define the models to compare, server URL, and dataset path ---

# IMPORTANT: Make sure your FIREWORKS_API_KEY is set as an environment variable.
if "FIREWORKS_API_KEY" not in os.environ:
    print("FATAL: FIREWORKS_API_KEY environment variable not set.")

# The base model you used for the fine-tuning job.
BASE_MODEL_ID = "accounts/fireworks/models/qwen2p5-7b"  # <--- Replace if you used a different base model
LARGE_BASE_MODEL_ID = "accounts/fireworks/models/qwen3-coder-480b-a35b-instruct"

# IMPORTANT: Replace this with the model ID of your new fine-tuned model.
FINE_TUNED_MODEL_ID = "accounts/<your-account-id>/models/<your-base-model-id>"  # <--- Replace with your fine-tuned model ID

MCP_SERVER_URL = None  # <--- PUT MCP SERVER URL HERE without the /mcp/ suffix at the end
DATASET_FILE_PATH = "data/final_rft_sql_test_data.jsonl"

# --- 2. Create LLM Objects ---
base_model_llm = None
large_base_model_llm = None
fine_tuned_model_llm = None
try:
    base_model_llm = LLM(model=BASE_MODEL_ID, deployment_type="auto", api_key=os.getenv("FIREWORKS_API_KEY"))
    large_base_model_llm = LLM(model=LARGE_BASE_MODEL_ID, deployment_type="auto", api_key=os.getenv("FIREWORKS_API_KEY"))
    fine_tuned_model_llm = LLM(model=FINE_TUNED_MODEL_ID, deployment_type="auto", api_key=os.getenv("FIREWORKS_API_KEY"))
    print("LLM objects for all three models created successfully.")
except Exception as e:
    print(f"FATAL: Could not create LLM objects. Error: {e}")

# --- 3. Load Dataset ---
dataset = []
if all([base_model_llm, large_base_model_llm, fine_tuned_model_llm]):
    try:
        with open(DATASET_FILE_PATH, 'r') as f:
            dataset = [json.loads(line) for line in f]
        print(f"Loaded {len(dataset)} evaluation examples from '{DATASET_FILE_PATH}'.")
    except Exception as e:
        print(f"FATAL: Could not load dataset. Error: {e}")
        dataset = []

# --- 4. HELPER AND EVALUATION FUNCTIONS ---

def parse_duckdb_ascii_table(table_string: str) -> list[dict]:
    """
    Parses a DuckDB-style ASCII table string into a list of dictionaries.
    This version robustly handles 'NULL' values and empty strings.
    """
    lines = table_string.strip().split('\n')
    content_lines = [line for line in lines if line.strip() and not line.startswith('+')]
    if len(content_lines) < 2:
        return []
    
    header_raw = [h.strip() for h in content_lines[0].split('|')[1:-1]]
    data_lines = content_lines[1:]
    
    if len(data_lines) > 0:
        try:
            first_data_values = [v.strip() for v in data_lines[0].split('|')[1:-1]]
            if len(first_data_values) == len(header_raw) and all(v.isupper() for v in first_data_values):
                data_lines = data_lines[1:]
        except IndexError:
            pass

    rows = []
    for line in data_lines:
        try:
            values_raw = [v.strip() for v in line.split('|')[1:-1]]
            if len(values_raw) == len(header_raw):
                row_dict = {}
                for i, header in enumerate(header_raw):
                    value_str = values_raw[i]
                    if value_str.upper() == 'NULL' or value_str == '':
                        row_dict[header] = None
                        continue
                    
                    try:
                        if '.' in value_str:
                            row_dict[header] = float(value_str)
                        else:
                            row_dict[header] = int(value_str)
                    except (ValueError, TypeError):
                        row_dict[header] = value_str
                rows.append(row_dict)
        except IndexError:
            continue
    return rows

def are_results_equal(predicted_rows: list[dict], ground_truth_rows: list[dict]) -> bool:
    """
    Compares datasets by converting all values to strings and sorting them,
    which ignores row order, column order, and data types (e.g., int vs float).
    """
    try:
        gt_values = sorted([sorted(map(str, row.values())) for row in ground_truth_rows])
        predicted_values = sorted([sorted(map(str, row.values())) for row in predicted_rows])
        return gt_values == predicted_values
    except Exception:
        return False

def get_sql_and_evaluate(llm_obj, system_prompt: str, user_prompt: str, ground_truth: list[dict]) -> int:
    """
    Calls a pre-configured LLM object to get a SQL query, executes it, and compares to ground truth.
    Returns 1 for a correct result, 0 for an incorrect one.
    """
    try:
        # Step 1: Get SQL from the model
        messages = [{"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}]
        response = llm_obj.chat.completions.create(messages=messages, temperature=0.0)
        sql_query = response.choices[0].message.content.strip()

        # Step 2: Execute SQL on MCP server
        headers = {"Content-Type": "application/json", "Accept": "application/json, text/event-stream"}
        payload = {"id": "eval-query-1", "jsonrpc": "2.0", "method": "tools/call", "params": {"session": {"id": "full-eval-session"}, "name": "query", "arguments": {"query": sql_query}}}

        response_data = None
        with requests.post(f"{MCP_SERVER_URL}/mcp/", headers=headers, json=payload, timeout=30, stream=True) as mcp_response:
            mcp_response.raise_for_status()
            for line in mcp_response.iter_lines():
                if line and line.decode('utf-8').startswith('data:'):
                    json_part = line.decode('utf-8')[len('data:'):].strip()
                    if json_part:
                        response_data = json.loads(json_part)
                        break

        if response_data is None or "error" in response_data:
            return 0

        # Step 3: Parse and compare results
        ascii_table = response_data['result']['content'][0]['text']
        predicted_rows = parse_duckdb_ascii_table(ascii_table)

        return 1 if are_results_equal(predicted_rows, ground_truth) else 0
    except Exception as e:
        print(f"--> Error during evaluation for model {llm_obj.model}: {e}")
        return 0

# --- 5. RUN THE FULL EVALUATION ---

base_model_score = 0
large_base_model_score = 0
fine_tuned_model_score = 0

if dataset:
    print("\nStarting evaluation...")
    for item in tqdm(dataset, desc="Evaluating models"):
        system_prompt = item["messages"][0]["content"]
        user_prompt = item["messages"][1]["content"]
        ground_truth = item["ground_truth"]

        # Evaluate base model
        base_model_score += get_sql_and_evaluate(base_model_llm, system_prompt, user_prompt, ground_truth)
        time.sleep(1)  # Be nice to the API

        # Evaluate large base model
        large_base_model_score += get_sql_and_evaluate(large_base_model_llm, system_prompt, user_prompt, ground_truth)
        time.sleep(1)

        # Evaluate fine-tuned model
        fine_tuned_model_score += get_sql_and_evaluate(fine_tuned_model_llm, system_prompt, user_prompt, ground_truth)
        time.sleep(1)

# --- 6. REPORT RESULTS ---

if dataset:
    total = len(dataset)
    base_accuracy = (base_model_score / total) * 100
    large_base_accuracy = (large_base_model_score / total) * 100
    tuned_accuracy = (fine_tuned_model_score / total) * 100

    print("\n" + "="*25)
    print("  EVALUATION COMPLETE")
    print("="*25)
    print(f"Total Examples: {total}\n")
    print("--- BASE MODEL ---")
    print(f"Model ID: {BASE_MODEL_ID}")
    print(f"Correct: {base_model_score}/{total}")
    print(f"Accuracy: {base_accuracy:.2f}%\n")

    print("--- LARGE BASE MODEL ---")
    print(f"Model ID: {LARGE_BASE_MODEL_ID}")
    print(f"Correct: {large_base_model_score}/{total}")
    print(f"Accuracy: {large_base_accuracy:.2f}%\n")

    print("--- FINE-TUNED MODEL ---")
    print(f"Model ID: {FINE_TUNED_MODEL_ID}")
    print(f"Correct: {fine_tuned_model_score}/{total}")
    print(f"Accuracy: {tuned_accuracy:.2f}%\n")
    
    print("="*25)
    print("  PERFORMANCE LIFT")
    print("="*25)
    print(f"Fine-Tuned vs. Base: {tuned_accuracy - base_accuracy:+.2f}%")
    print(f"Fine-Tuned vs. Large Base: {tuned_accuracy - large_base_accuracy:+.2f}%")

else:
    print("Evaluation skipped because the dataset or LLM objects could not be loaded.")
    LLM objects for all three models created successfully.
    Loaded 92 evaluation examples from 'data/final_rft_sql_test_data.jsonl'.
    
    Starting evaluation...


    Evaluating models:   0%|          | 0/92 [00:00<?, ?it/s]
    Evaluating models: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 92/92 [14:41<00:00,  9.59s/it]  

    
    =========================
      EVALUATION COMPLETE
    =========================
    Total Examples: 92
    
    --- BASE MODEL ---
    Model ID: accounts/fireworks/models/qwen2p5-7b
    Correct: 22/92
    Accuracy: 23.91%
    
    --- LARGE BASE MODEL ---
    Model ID: accounts/fireworks/models/qwen3-coder-480b-a35b-instruct
    Correct: 32/92
    Accuracy: 34.78%
    
    --- FINE-TUNED MODEL ---
    Model ID: accounts/<account-id>/models/<base-model-id>
    Correct: 52/92
    Accuracy: 56.52%
    
    =========================
      PERFORMANCE LIFT
    =========================
    Fine-Tuned vs. Base: +32.61%
    Fine-Tuned vs. Large Base: +21.74%

16. ✨ Cleanup & Conclusion

Congratulations! You’ve successfully completed the entire Reinforcement Fine-Tuning loop. You started with just a database schema and ended with a highly specialized, performant, and data-aware AI model.

Cleanup

Cloud resources and model deployments can incur costs, so it’s good practice to clean up any resources you no longer need.
  • Check your Deployments: Navigate to the Deployments tab in your Fireworks AI dashboard. Here you can monitor and manage all your deployed models.
  • Delete Unneeded Models: Feel free to delete any deployments you no longer need. For example, you might have deployed the base or large-base models during the evaluation step to compare against your fine-tuned model. These can now be safely removed to save costs.
  • Delete Cloud Run service and container image: Feel free to delete your MCP server Cloud Run service and container image to avoid stray storage costs.
You can, of course, continue using your new fine-tuned SQL generation model for any application you see fit!

Conclusions

The evaluation results from the previous step highlight the power of this approach.
  • Performance surpassing much larger models: Our fine-tuned 7B parameter model performs better than a much larger model like qwen3-coder-480b-a35b-instruct on this specific dataset. This is because it has been fine-tuned to understand the data schema via real query generation and execution.
  • Efficiency Gains: A 7B model is significantly faster and cheaper to run than a 480B one, offering production-grade performance at a fraction of the cost and latency.
  • High-Level Capability on Complex Tasks: The queries in this dataset are relatively complex, which is reflected in the final accuracy score of around 57%. This is a strong result, demonstrating that for a specialized domain, a smaller model can be tuned to achieve a level of performance that exceeds larger models like qwen3-coder-480b-a35b-instruct. Specifically, the final accuracy scores we measured for this dataset were:
    • Qwen2.5 7B (base): 23.91% accuracy (22/92 correct on the held-out test set)
    • Qwen3 Coder 480B A35B Instruct (base): 34.78% accuracy (32/92 correct on the held-out test set)
    • Qwen2.5 7B (RFT tuned): 56.52% accuracy (52/92 correct on the held-out test set)

Throughout this tutorial, we demonstrated a complete, end-to-end workflow for creating a fine-tuned text-to-SQL model. We began with the absolute minimum requirement, a database schema, and used a series of LLM-driven steps to generate a safe, synthetic data sandbox. From there, we generated a rich dataset of queries and answers, which we used to fine-tune a model using the Fireworks RFT platform. The final result is a small, efficient model that can accurately query data it has never seen, a task that was previously only possible with vastly larger and more expensive models. This pattern of schema β†’ synthetic data β†’ RFT is a secure, effective, and repeatable methodology for teaching language models to become expert users of your private data and custom APIs, without ever exposing the underlying sensitive information.

Appendix

Testing more open models on Fireworks

import requests
import json
import os
import time
import datetime
from fireworks import LLM
from tqdm.auto import tqdm
from dotenv import load_dotenv

load_dotenv()

# --- 1. SETUP: Define the models to compare, server URL, and dataset path ---

# IMPORTANT: Make sure your FIREWORKS_API_KEY is set as an environment variable.
if "FIREWORKS_API_KEY" not in os.environ:
    print("FATAL: FIREWORKS_API_KEY environment variable not set.")

# IMPORTANT: Replace these with the model IDs of the models you want to evaluate.
MODEL_1_ID = "accounts/fireworks/models/qwen3-coder-480b-a35b-instruct"  # <--- Replace with your model ID
MODEL_2_ID = "accounts/fireworks/models/kimi-k2-instruct" # <--- Replace with your model ID
MODEL_3_ID = "accounts/fireworks/models/deepseek-v3" # <--- Replace with your model ID

MCP_SERVER_URL = None  # <--- PUT MCP SERVER URL HERE without the /mcp/ suffix at the end
DATASET_FILE_PATH = "data/final_rft_sql_test_data.jsonl"

# --- 2. Create LLM Objects ---
model_1_llm = None
model_2_llm = None
model_3_llm = None
try:
    model_1_llm = LLM(model=MODEL_1_ID, deployment_type="auto", api_key=os.getenv("FIREWORKS_API_KEY"))
    model_2_llm = LLM(model=MODEL_2_ID, deployment_type="auto", api_key=os.getenv("FIREWORKS_API_KEY"))
    model_3_llm = LLM(model=MODEL_3_ID, deployment_type="auto", api_key=os.getenv("FIREWORKS_API_KEY"))
    print("LLM objects for all three models created successfully.")
except Exception as e:
    print(f"FATAL: Could not create LLM objects. Error: {e}")

# --- 3. Load Dataset ---
dataset = []
if all([model_1_llm, model_2_llm, model_3_llm]):
    try:
        with open(DATASET_FILE_PATH, 'r') as f:
            dataset = [json.loads(line) for line in f]
        print(f"Loaded {len(dataset)} evaluation examples from '{DATASET_FILE_PATH}'.")
    except Exception as e:
        print(f"FATAL: Could not load dataset. Error: {e}")
        dataset = []

# --- 4. SETUP for Manual Logging ---
MANUAL_LOG_FILE = "evaluation_manual_open.log"

# Initialize the log file by clearing it and writing a header
with open(MANUAL_LOG_FILE, 'w', encoding='utf-8') as f:
    f.write(f"--- Evaluation Log ---\n")
    f.write(f"Log started at {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n")

def write_to_log(message: str):
    """Appends a timestamped message to the manual log file."""
    timestamp = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
    try:
        with open(MANUAL_LOG_FILE, 'a', encoding='utf-8') as f:
            f.write(f"{timestamp} - {message}\n")
    except Exception as e:
        # If logging fails, print an error to the console
        print(f"CRITICAL: Failed to write to log file '{MANUAL_LOG_FILE}'. Error: {e}")

print(f"Manual logging configured. Log file is '{MANUAL_LOG_FILE}'.")


# --- 5. HELPER AND EVALUATION FUNCTIONS ---

def parse_duckdb_ascii_table(table_string: str) -> list[dict]:
    """
    Parses a DuckDB-style ASCII table string into a list of dictionaries.
    This version robustly handles 'NULL' values and empty strings.
    """
    lines = table_string.strip().split('\n')
    content_lines = [line for line in lines if line.strip() and not line.startswith('+')]
    if len(content_lines) < 2:
        return []
    
    header_raw = [h.strip() for h in content_lines[0].split('|')[1:-1]]
    data_lines = content_lines[1:]
    
    if len(data_lines) > 0:
        try:
            first_data_values = [v.strip() for v in data_lines[0].split('|')[1:-1]]
            if len(first_data_values) == len(header_raw) and all(v.isupper() for v in first_data_values):
                data_lines = data_lines[1:]
        except IndexError:
            pass

    rows = []
    for line in data_lines:
        try:
            values_raw = [v.strip() for v in line.split('|')[1:-1]]
            if len(values_raw) == len(header_raw):
                row_dict = {}
                for i, header in enumerate(header_raw):
                    value_str = values_raw[i]
                    if value_str.upper() == 'NULL' or value_str == '':
                        row_dict[header] = None
                        continue
                    
                    try:
                        if '.' in value_str:
                            row_dict[header] = float(value_str)
                        else:
                            row_dict[header] = int(value_str)
                    except (ValueError, TypeError):
                        row_dict[header] = value_str
                rows.append(row_dict)
        except IndexError:
            continue
    return rows

def are_results_equal(predicted_rows: list[dict], ground_truth_rows: list[dict]) -> bool:
    """
    Compares datasets by converting all values to strings and sorting them,
    which ignores row order, column order, and data types (e.g., int vs float).
    """
    try:
        gt_values = sorted([sorted(map(str, row.values())) for row in ground_truth_rows])
        predicted_values = sorted([sorted(map(str, row.values())) for row in predicted_rows])
        return gt_values == predicted_values
    except Exception:
        return False


def get_sql_and_evaluate(llm_obj, system_prompt: str, user_prompt: str, ground_truth: list[dict]) -> int:
    """
    Calls a pre-configured LLM object to get a SQL query, executes it, logs the process,
    and compares to ground truth. Returns 1 for a correct result, 0 for an incorrect one.
    """
    model_id = llm_obj.model
    write_to_log(f"--- Fireworks Evaluation Start: Model {model_id} ---")
    write_to_log(f"System Prompt: {system_prompt}")
    write_to_log(f"User Prompt: {user_prompt}")
    try:
        # Step 1: Get SQL from the model
        messages = [{"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}]
        response = llm_obj.chat.completions.create(messages=messages, temperature=0.0)
        raw_output = response.choices[0].message.content.strip()
        sql_query = raw_output
        write_to_log(f"Raw LLM Output: {raw_output}")
        write_to_log(f"Extracted SQL: {sql_query}")

        # Step 2: Execute SQL on MCP server
        headers = {"Content-Type": "application/json", "Accept": "application/json, text/event-stream"}
        payload = {"id": "eval-query-1", "jsonrpc": "2.0", "method": "tools/call", "params": {"session": {"id": "full-eval-session"}, "name": "query", "arguments": {"query": sql_query}}}

        response_data = None
        with requests.post(f"{MCP_SERVER_URL}/mcp/", headers=headers, json=payload, timeout=30, stream=True) as mcp_response:
            mcp_response.raise_for_status()
            for line in mcp_response.iter_lines():
                if line and line.decode('utf-8').startswith('data:'):
                    json_part = line.decode('utf-8')[len('data:'):].strip()
                    if json_part:
                        response_data = json.loads(json_part)
                        break

        if response_data is None or "error" in response_data:
            error_message = response_data.get('error', 'No response data') if response_data else 'No response data'
            write_to_log(f"ERROR: Failed to execute query. Response: {error_message}")
            write_to_log("Result: Incorrect (Query Execution Failed)")
            write_to_log(f"--- Fireworks Evaluation End: Model {model_id} ---\n")
            return 0

        # Step 3: Parse and compare results
        ascii_table = response_data['result']['content'][0]['text']
        predicted_rows = parse_duckdb_ascii_table(ascii_table)
        write_to_log(f"Execution Result (predicted): {predicted_rows}")
        write_to_log(f"Ground Truth: {ground_truth}")

        is_correct = are_results_equal(predicted_rows, ground_truth)
        write_to_log(f"Result: {'Correct' if is_correct else 'Incorrect'}")
        write_to_log(f"--- Fireworks Evaluation End: Model {model_id} ---\n")
        return 1 if is_correct else 0
    except Exception as e:
        error_msg = f"ERROR: Exception during evaluation for model {model_id}: {e}"
        print(f"--> {error_msg}")
        write_to_log(error_msg)
        write_to_log(f"--- Fireworks Evaluation End: Model {model_id} ---\n")
        return 0
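
# For reference: the MCP server streams its JSON-RPC reply as Server-Sent Events, which is why
# the loop above scans for the first "data: {...}" line and then reads result.content[0].text.
# The snippet below is only an illustrative sketch of that extraction on a hard-coded line;
# the exact payload your server returns may differ.
_example_sse_line = 'data: {"jsonrpc": "2.0", "id": "eval-query-1", "result": {"content": [{"type": "text", "text": "+----+\\n| n  |\\n+----+\\n| 42 |\\n+----+"}]}}'
_example_payload = json.loads(_example_sse_line[len("data:"):].strip())
print(_example_payload["result"]["content"][0]["text"])  # the ASCII table passed to the parser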

# --- 6. RUN THE FULL EVALUATION ---

model_1_score = 0
model_2_score = 0
model_3_score = 0

if dataset:
    msg = "\nStarting evaluation..."
    print(msg)
    write_to_log(msg.strip())
    for item in tqdm(dataset, desc="Evaluating models"):
        system_prompt = item["messages"][0]["content"]
        user_prompt = item["messages"][1]["content"]
        ground_truth = item["ground_truth"]

        # Evaluate model 1
        model_1_score += get_sql_and_evaluate(model_1_llm, system_prompt, user_prompt, ground_truth)
        time.sleep(1)  # Be nice to the API

        # Evaluate model 2
        model_2_score += get_sql_and_evaluate(model_2_llm, system_prompt, user_prompt, ground_truth)
        time.sleep(1)

        # Evaluate model 3
        model_3_score += get_sql_and_evaluate(model_3_llm, system_prompt, user_prompt, ground_truth)
        time.sleep(1)

# --- 7. REPORT RESULTS ---

if dataset:
    total = len(dataset)
    model_1_accuracy = (model_1_score / total) * 100
    model_2_accuracy = (model_2_score / total) * 100
    model_3_accuracy = (model_3_score / total) * 100

    print("\n" + "="*25)
    print("  EVALUATION COMPLETE")
    print("="*25)
    print(f"Total Examples: {total}\n")
    
    print("--- MODEL 1 ---")
    print(f"Model ID: {MODEL_1_ID}")
    print(f"Correct: {model_1_score}/{total}")
    print(f"Accuracy: {model_1_accuracy:.2f}%\n")

    print("--- MODEL 2 ---")
    print(f"Model ID: {MODEL_2_ID}")
    print(f"Correct: {model_2_score}/{total}")
    print(f"Accuracy: {model_2_accuracy:.2f}%\n")

    print("--- MODEL 3 ---")
    print(f"Model ID: {MODEL_3_ID}")
    print(f"Correct: {model_3_score}/{total}")
    print(f"Accuracy: {model_3_accuracy:.2f}%\n")

    # Write final summary to log file
    write_to_log("="*25)
    write_to_log("  EVALUATION COMPLETE")
    write_to_log("="*25)
    write_to_log(f"Total Examples: {total}")
    write_to_log(f"--- MODEL 1 --- Model ID: {MODEL_1_ID} | Correct: {model_1_score}/{total} | Accuracy: {model_1_accuracy:.2f}%")
    write_to_log(f"--- MODEL 2 --- Model ID: {MODEL_2_ID} | Correct: {model_2_score}/{total} | Accuracy: {model_2_accuracy:.2f}%")
    write_to_log(f"--- MODEL 3 --- Model ID: {MODEL_3_ID} | Correct: {model_3_score}/{total} | Accuracy: {model_3_accuracy:.2f}%")

else:
    msg = "Evaluation skipped because the dataset could not be loaded"
    print(msg)
    write_to_log(f"WARNING: {msg}")

    LLM objects for all three models created successfully.
    Loaded 92 evaluation examples from 'data/final_rft_sql_test_data.jsonl'.
    Manual logging configured. Log file is 'evaluation_manual_open.log'.
    
    Starting evaluation...


    Evaluating models: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 92/92 [16:59<00:00, 11.08s/it]

    
    =========================
      EVALUATION COMPLETE
    =========================
    Total Examples: 92
    
    --- MODEL 1 ---
    Model ID: accounts/fireworks/models/qwen3-coder-480b-a35b-instruct
    Correct: 32/92
    Accuracy: 34.78%
    
    --- MODEL 2 ---
    Model ID: accounts/fireworks/models/kimi-k2-instruct
    Correct: 26/92
    Accuracy: 28.26%
    
    --- MODEL 3 ---
    Model ID: accounts/fireworks/models/deepseek-v3
    Correct: 25/92
    Accuracy: 27.17%

Testing proprietary models

# Cell to evaluate OpenAI and Anthropic models (with manual logging) - GPT-4o and Claude Sonnet 4

import openai
import anthropic
import os
import requests
import json
import time
import re
from tqdm.auto import tqdm
from dotenv import load_dotenv

load_dotenv()

# --- 1. SETUP for Environment and Data ---

MCP_SERVER_URL = None  # <--- PUT MCP SERVER URL HERE without the /mcp/ suffix at the end
DATASET_FILE_PATH = "data/final_rft_sql_test_data.jsonl"

# --- Load Dataset ---
dataset = []
try:
    with open(DATASET_FILE_PATH, 'r') as f:
        dataset = [json.loads(line) for line in f]
    print(f"Loaded {len(dataset)} evaluation examples from '{DATASET_FILE_PATH}'.")
except Exception as e:
    print(f"FATAL: Could not load dataset. Error: {e}")
    dataset = []

# --- 2. SETUP for OpenAI and Anthropic ---

# IMPORTANT: Make sure your API keys are set as environment variables.
if "OPENAI_API_KEY" not in os.environ:
    print("FATAL: OPENAI_API_KEY environment variable not set.")
if "ANTHROPIC_API_KEY" not in os.environ:
    print("FATAL: ANTHROPIC_API_KEY environment variable not set.")

# Define Model IDs
GPT_MODEL_ID = "gpt-4o"
CLAUDE_MODEL_ID = "claude-sonnet-4-20250514"

# --- Create API Clients ---
openai_client = None
anthropic_client = None
try:
    openai_client = openai.OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
    anthropic_client = anthropic.Anthropic(api_key=os.getenv("ANTHROPIC_API_KEY"))
    print("OpenAI and Anthropic clients created successfully.")
except Exception as e:
    print(f"FATAL: Could not create clients. Error: {e}")

# --- 3. HELPER FUNCTIONS ---

def clean_sql_output(sql_text: str) -> str:
    """
    Cleans SQL output by removing markdown code blocks and extracting pure SQL.
    Handles various formats like ```sql, ```SQL, or just ```
    """
    # Strip leading/trailing whitespace
    sql_text = sql_text.strip()
    
    # Pattern to match markdown code blocks with optional language specification
    # This will match ```sql, ```SQL, ```postgres, or just ```
    pattern = r'^```(?:sql|SQL|postgres|POSTGRES)?\s*\n?(.*?)```$'
    match = re.match(pattern, sql_text, re.DOTALL)
    
    if match:
        # Extract the SQL from within the code block
        sql_text = match.group(1).strip()
    else:
        # If no full code block match, just remove any leading ```sql or trailing ```
        # Remove leading markdown
        sql_text = re.sub(r'^```(?:sql|SQL|postgres|POSTGRES)?\s*\n?', '', sql_text)
        # Remove trailing markdown
        sql_text = re.sub(r'\n?```$', '', sql_text)
    
    return sql_text.strip()

def parse_duckdb_ascii_table(table_string: str) -> list[dict]:
    """
    Parses a DuckDB-style ASCII table string into a list of dictionaries.
    This version robustly handles 'NULL' values and empty strings.
    """
    lines = table_string.strip().split('\n')
    content_lines = [line for line in lines if line.strip() and not line.startswith('+')]
    if len(content_lines) < 2:
        return []
    
    header_raw = [h.strip() for h in content_lines[0].split('|')[1:-1]]
    data_lines = content_lines[1:]
    
    # Skip the column-type row (e.g. VARCHAR, BIGINT) that DuckDB prints under the header,
    # detected here as a first data row whose values are all uppercase.
    if len(data_lines) > 0:
        try:
            first_data_values = [v.strip() for v in data_lines[0].split('|')[1:-1]]
            if len(first_data_values) == len(header_raw) and all(v.isupper() for v in first_data_values):
                data_lines = data_lines[1:]
        except IndexError:
            pass

    rows = []
    for line in data_lines:
        try:
            values_raw = [v.strip() for v in line.split('|')[1:-1]]
            if len(values_raw) == len(header_raw):
                row_dict = {}
                for i, header in enumerate(header_raw):
                    value_str = values_raw[i]
                    if value_str.upper() == 'NULL' or value_str == '':
                        row_dict[header] = None
                        continue
                    
                    try:
                        if '.' in value_str:
                            row_dict[header] = float(value_str)
                        else:
                            row_dict[header] = int(value_str)
                    except (ValueError, TypeError):
                        row_dict[header] = value_str
                rows.append(row_dict)
        except IndexError:
            continue
    return rows
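
# Illustrative example (made-up values) of what the parser produces from the kind of ASCII
# table the MCP server returns: border lines are dropped, the uppercase type row is skipped,
# numeric-looking values are coerced, and NULL becomes None.
_example_table = """
+---------+---------+
| airline | flights |
+---------+---------+
| VARCHAR | BIGINT  |
| Delta   | 12      |
| United  | NULL    |
+---------+---------+
"""
print(parse_duckdb_ascii_table(_example_table))
# -> [{'airline': 'Delta', 'flights': 12}, {'airline': 'United', 'flights': None}]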

def are_results_equal(predicted_rows: list[dict], ground_truth_rows: list[dict]) -> bool:
    """
    Compares result sets by reducing each row to a sorted list of stringified values and
    then sorting the rows, which makes the check insensitive to row order and column order.
    """
    try:
        gt_values = sorted([sorted(map(str, row.values())) for row in ground_truth_rows])
        predicted_values = sorted([sorted(map(str, row.values())) for row in predicted_rows])
        return gt_values == predicted_values
    except Exception:
        return False

# --- 4. EVALUATION FUNCTIONS for OpenAI and Anthropic ---

def get_sql_and_evaluate_openai(client, model_id: str, system_prompt: str, user_prompt: str, ground_truth: list[dict]) -> int:
    """
    Calls OpenAI API, executes SQL, and compares to ground truth.
    """
    try:
        messages = [{"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}]
        response = client.chat.completions.create(model=model_id, messages=messages, temperature=0.0)
        raw_output = response.choices[0].message.content.strip()
        
        # Clean the SQL output
        sql_query = clean_sql_output(raw_output)

        headers = {"Content-Type": "application/json", "Accept": "application/json, text/event-stream"}
        payload = {"id": "eval-query-1", "jsonrpc": "2.0", "method": "tools/call", "params": {"session": {"id": "full-eval-session"}, "name": "query", "arguments": {"query": sql_query}}}
        response_data = None
        with requests.post(f"{MCP_SERVER_URL}/mcp/", headers=headers, json=payload, timeout=30, stream=True) as mcp_response:
            mcp_response.raise_for_status()
            for line in mcp_response.iter_lines():
                if line and line.decode('utf-8').startswith('data:'):
                    json_part = line.decode('utf-8')[len('data:'):].strip()
                    if json_part:
                        response_data = json.loads(json_part)
                        break
        
        if response_data is None or "error" in response_data:
            return 0

        ascii_table = response_data['result']['content'][0]['text']
        predicted_rows = parse_duckdb_ascii_table(ascii_table)
        
        is_correct = are_results_equal(predicted_rows, ground_truth)
        return 1 if is_correct else 0
    except Exception as e:
        print(f"--> Error during evaluation for model {model_id}: {e}")
        return 0

def get_sql_and_evaluate_anthropic(client, model_id: str, system_prompt: str, user_prompt: str, ground_truth: list[dict]) -> int:
    """
    Calls Anthropic API, executes SQL, and compares to ground truth.
    """
    try:
        response = client.messages.create(model=model_id, system=system_prompt, messages=[{"role": "user", "content": user_prompt}], temperature=0.0, max_tokens=2048)
        raw_output = response.content[0].text.strip()
        
        # Clean the SQL output
        sql_query = clean_sql_output(raw_output)

        headers = {"Content-Type": "application/json", "Accept": "application/json, text/event-stream"}
        payload = {"id": "eval-query-1", "jsonrpc": "2.0", "method": "tools/call", "params": {"session": {"id": "full-eval-session"}, "name": "query", "arguments": {"query": sql_query}}}
        response_data = None
        with requests.post(f"{MCP_SERVER_URL}/mcp/", headers=headers, json=payload, timeout=30, stream=True) as mcp_response:
            mcp_response.raise_for_status()
            for line in mcp_response.iter_lines():
                if line and line.decode('utf-8').startswith('data:'):
                    json_part = line.decode('utf-8')[len('data:'):].strip()
                    if json_part:
                        response_data = json.loads(json_part)
                        break

        if response_data is None or "error" in response_data:
            return 0

        ascii_table = response_data['result']['content'][0]['text']
        predicted_rows = parse_duckdb_ascii_table(ascii_table)
        
        is_correct = are_results_equal(predicted_rows, ground_truth)
        return 1 if is_correct else 0
    except Exception as e:
        print(f"--> Error during evaluation for model {model_id}: {e}")
        return 0
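
# Optional refactor (a sketch, not used by the cells in this tutorial): the two evaluation
# functions above differ only in how they call the model; the MCP execution and comparison
# steps are identical. If you prefer less duplication, that shared tail could be factored
# into a hypothetical helper like execute_and_score(), which reuses the exact payload shape
# and parsing logic shown above.
def execute_and_score(sql_query: str, ground_truth: list[dict]) -> int:
    headers = {"Content-Type": "application/json", "Accept": "application/json, text/event-stream"}
    payload = {"id": "eval-query-1", "jsonrpc": "2.0", "method": "tools/call",
               "params": {"session": {"id": "full-eval-session"}, "name": "query",
                          "arguments": {"query": sql_query}}}
    response_data = None
    with requests.post(f"{MCP_SERVER_URL}/mcp/", headers=headers, json=payload, timeout=30, stream=True) as mcp_response:
        mcp_response.raise_for_status()
        for line in mcp_response.iter_lines():
            if line and line.decode('utf-8').startswith('data:'):
                json_part = line.decode('utf-8')[len('data:'):].strip()
                if json_part:
                    response_data = json.loads(json_part)
                    break
    if response_data is None or "error" in response_data:
        return 0
    predicted_rows = parse_duckdb_ascii_table(response_data['result']['content'][0]['text'])
    return 1 if are_results_equal(predicted_rows, ground_truth) else 0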

# --- 5. RUN THE EVALUATION for OpenAI and Anthropic ---

gpt_model_score = 0
claude_model_score = 0

# --- Pre-evaluation Check ---
print("\n--- Pre-evaluation Check ---")
dataset_exists = 'dataset' in locals()
clients_exist = all([openai_client, anthropic_client])
print(f"1. 'dataset' variable exists: {dataset_exists}")
if dataset_exists:
    print(f"2. 'dataset' is not empty: {bool(dataset)}")
else:
    print("2. 'dataset' is not empty: N/A (Does not exist)")
print(f"3. OpenAI client created: {bool(openai_client)}")
print(f"4. Anthropic client created: {bool(anthropic_client)}")
print("--------------------------\n")

if dataset_exists and dataset and clients_exist:
    print("\nStarting evaluation for OpenAI and Anthropic models...")
    for item in tqdm(dataset, desc="Evaluating OpenAI/Anthropic"):
        system_prompt = item["messages"][0]["content"]
        user_prompt = item["messages"][1]["content"]
        ground_truth = item["ground_truth"]

        gpt_model_score += get_sql_and_evaluate_openai(openai_client, GPT_MODEL_ID, system_prompt, user_prompt, ground_truth)
        time.sleep(0.5)

        claude_model_score += get_sql_and_evaluate_anthropic(anthropic_client, CLAUDE_MODEL_ID, system_prompt, user_prompt, ground_truth)
        time.sleep(0.5)

    # --- 6. REPORT RESULTS for OpenAI and Anthropic ---
    total = len(dataset)
    gpt_accuracy = (gpt_model_score / total) * 100
    claude_accuracy = (claude_model_score / total) * 100

    print("\n" + "="*25)
    print("  EVALUATION COMPLETE")
    print("="*25)
    print(f"Total Examples: {total}\n")
    print("--- GPT MODEL (OpenAI) ---")
    print(f"Model ID: {GPT_MODEL_ID}")
    print(f"Correct: {gpt_model_score}/{total}")
    print(f"Accuracy: {gpt_accuracy:.2f}%\n")
    print("--- CLAUDE MODEL (Anthropic) ---")
    print(f"Model ID: {CLAUDE_MODEL_ID}")
    print(f"Correct: {claude_model_score}/{total}")
    print(f"Accuracy: {claude_accuracy:.2f}%\n")

else:
    print("\nEvaluation for OpenAI and Anthropic models skipped.")
    print("One of the pre-evaluation checks failed. Please check the output above.")

    Loaded 92 evaluation examples from 'data/final_rft_sql_test_data.jsonl'.
    Manual logging configured. Log file is 'evaluation_manual_proprietary.log'.
    OpenAI and Anthropic clients created successfully.
    
    --- Pre-evaluation Check ---
    1. 'dataset' variable exists: True
    2. 'dataset' is not empty: True
    3. OpenAI client created: True
    4. Anthropic client created: True
    --------------------------
    
    
    Starting evaluation for OpenAI and Anthropic models...


    Evaluating OpenAI/Anthropic: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 92/92 [08:06<00:00,  5.28s/it]

    
    =========================
      EVALUATION COMPLETE
    =========================
    Total Examples: 92
    
    --- GPT MODEL (OpenAI) ---
    Model ID: gpt-4o
    Correct: 22/92
    Accuracy: 23.91%
    
    --- CLAUDE MODEL (Anthropic) ---
    Model ID: claude-sonnet-4-20250514
    Correct: 27/92
    Accuracy: 29.35%

Note that Claude Sonnet 4 and GPT-4o sometimes wrap their SQL in markdown code fences like ```sql <query_here>```, so we added the clean_sql_output helper to strip that formatting before executing the query.
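
To see what that cleanup does, here is a small, illustrative example (the wrapped query is made up); clean_sql_output strips the fence and returns bare SQL ready to execute:

wrapped_output = "```sql\nSELECT name FROM airlines LIMIT 5;\n```"
print(clean_sql_output(wrapped_output))  # -> SELECT name FROM airlines LIMIT 5;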