Creating Your First Reward Function

This step-by-step tutorial will guide you through the process of creating, testing, and deploying your first reward function using the Reward Kit.

Prerequisites

Before starting this tutorial, make sure you have:

  1. Python 3.8+ installed on your system
  2. Reward Kit installed: pip install reward-kit
  3. Fireworks API credentials (for deployment)
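
To confirm the installation, you can check the package metadata with pip:

pip show reward-kit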

Step 1: Set Up Your Project

First, let’s create a directory structure for our project:

mkdir -p my_reward_function/metrics/relevance
cd my_reward_function
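
By the end of the tutorial, the project will contain the following files:

my_reward_function/
├── metrics/
│   └── relevance/
│       └── main.py      # reward function (Step 2)
├── samples.jsonl        # sample conversations (Step 3)
├── test_reward.py       # local test script (Step 4)
└── deploy_relevance.py  # deployment script (Step 8)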

Step 2: Create a Basic Reward Function

Let’s create a simple reward function that evaluates the relevance of a response to a user’s query.

Create a file at metrics/relevance/main.py:

"""
Relevance Metric: Evaluates how well a response addresses the user's query.
"""

import string
from typing import List, Optional

from reward_kit import reward_function, EvaluateResult, MetricResult, Message

@reward_function
def evaluate(messages: List[Message], original_messages: Optional[List[Message]] = None, **kwargs) -> EvaluateResult:
    """
    Evaluate the relevance of a response to the user's query.
    
    Args:
        messages: List of conversation messages
        original_messages: Original messages (context)
        **kwargs: Additional parameters
        
    Returns:
        EvaluateResult with score and metrics
    """
    # Validate input
    if not messages or len(messages) < 2:
        return EvaluateResult(
            score=0.0,
            reason="Insufficient messages",
            metrics={}
        )
    
    # Find the user query (most recent user message)
    user_query = None
    for msg in reversed(messages[:-1]):
        if msg.role == "user":
            user_query = msg.content
            break
    
    if not user_query:
        return EvaluateResult(
            score=0.0,
            reason="No user query found",
            metrics={}
        )
    
    # The last message must be the assistant's response
    if messages[-1].role != "assistant":
        return EvaluateResult(
            score=0.0,
            reason="Last message is not from assistant",
            metrics={}
        )
    response = messages[-1].content or ""
    
    # Tokenize into lowercase keywords, stripping surrounding punctuation
    # so that e.g. "learning?" matches "learning"
    def extract_keywords(text: str) -> set:
        return {word.strip(string.punctuation) for word in text.lower().split()} - {""}
    
    user_keywords = extract_keywords(user_query)
    response_keywords = extract_keywords(response)
    
    # Remove common stop words
    stop_words = {"a", "an", "the", "and", "or", "but", "is", "are", "on", "in", "at", "to", "for", "with", "by", "of"}
    user_keywords = user_keywords - stop_words
    response_keywords = response_keywords - stop_words
    
    if not user_keywords:
        return EvaluateResult(
            score=0.5,
            reason="No significant keywords in user query",
            metrics={}
        )
    
    # Calculate overlap
    common_keywords = user_keywords.intersection(response_keywords)
    overlap_ratio = len(common_keywords) / len(user_keywords)
    
    # Calculate basic relevance score
    relevance_score = min(overlap_ratio, 1.0)
    
    # Check for direct answer patterns
    direct_answer_patterns = [
        "the answer is",
        "to answer your question",
        "in response to your question",
        "regarding your question",
        "to address your query"
    ]
    
    has_direct_answer = any(pattern in response.lower() for pattern in direct_answer_patterns)
    if has_direct_answer:
        relevance_score = min(relevance_score + 0.2, 1.0)  # Boost for direct answers
        
    # Create metrics dictionary
    metrics = {
        "keyword_overlap": MetricResult(
            score=overlap_ratio,
            reason=f"Keyword overlap ratio: {overlap_ratio:.2f} ({len(common_keywords)}/{len(user_keywords)} keywords)"
        ),
        "direct_answer": MetricResult(
            score=1.0 if has_direct_answer else 0.0,
            reason="Response explicitly addresses the query" if has_direct_answer else "No explicit answer indicators"
        )
    }
    
    return EvaluateResult(
        score=relevance_score,
        reason=f"Relevance score: {relevance_score:.2f}",
        metrics=metrics
    )
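
For a quick sanity check before building a full test script, you can call the function directly from a Python shell in the project root. This minimal sketch passes plain message dictionaries, which is also how the test script in Step 4 invokes the function:

from metrics.relevance.main import evaluate

result = evaluate(messages=[
    {"role": "user", "content": "What is machine learning?"},
    {"role": "assistant", "content": "Machine learning lets computers learn patterns from data."}
])
print(result.score, result.reason)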

Step 3: Create Sample Conversations

Let’s create some sample conversations to test our reward function.

Create a file named samples.jsonl:

{"messages": [{"role": "user", "content": "What is machine learning?"}, {"role": "assistant", "content": "Machine learning is a subset of artificial intelligence that enables computers to learn from data and improve from experience without being explicitly programmed. It focuses on developing algorithms that can access data, learn from it, and make predictions or decisions."}]}
{"messages": [{"role": "user", "content": "What is machine learning?"}, {"role": "assistant", "content": "Artificial intelligence has many applications in today's world, including image recognition, natural language processing, and recommendation systems."}]}
{"messages": [{"role": "user", "content": "How can I improve my cooking skills?"}, {"role": "assistant", "content": "To answer your question about improving cooking skills, I recommend: 1) Start with basic recipes and gradually increase complexity, 2) Learn fundamental techniques like knife skills, 3) Practice consistently, 4) Use quality ingredients, and 5) Take a cooking class or watch tutorial videos."}]}

Step 4: Create a Test Script

Let’s create a script to test our reward function locally.

Create a file named test_reward.py:

"""
Test script for the relevance reward function.
"""

import json
import sys
import os
from pathlib import Path

# Ensure the project root is on the path so the metrics package can be imported
sys.path.insert(0, os.path.abspath(os.path.dirname(__file__)))

from metrics.relevance.main import evaluate

def load_samples(file_path):
    """Load sample conversations from a JSONL file."""
    samples = []
    with open(file_path, 'r') as f:
        for line in f:
            if line.strip():
                samples.append(json.loads(line))
    return samples

def main():
    # Load sample conversations
    samples_path = Path("samples.jsonl")
    if not samples_path.exists():
        print(f"Error: Sample file {samples_path} not found.")
        return
    
    samples = load_samples(samples_path)
    print(f"Loaded {len(samples)} sample conversations.")
    
    # Test each sample
    for i, sample in enumerate(samples):
        print(f"\nSample {i+1}:")
        messages = sample.get("messages", [])
        
        # Find user and assistant messages
        user_msg = next((msg for msg in messages if msg.get("role") == "user"), None)
        assistant_msg = next((msg for msg in messages if msg.get("role") == "assistant"), None)
        
        if user_msg and assistant_msg:
            print(f"User: {user_msg.get('content', '')[:50]}...")
            print(f"Assistant: {assistant_msg.get('content', '')[:50]}...")
            
            # Evaluate the sample
            result = evaluate(messages=messages)
            
            # Print the results
            print(f"Overall Score: {result.score:.2f}")
            print(f"Reason: {result.reason}")
            print("Metrics:")
            for name, metric in result.metrics.items():
                print(f"  {name}: {metric.score:.2f} - {metric.reason}")
        else:
            print("Invalid sample: Missing user or assistant message.")

if __name__ == "__main__":
    main()

Step 5: Run Local Tests

Run your test script to see how your reward function performs:

python test_reward.py

You should see output similar to the following (exact scores depend on the details of tokenization and stop-word filtering):

Loaded 3 sample conversations.

Sample 1:
User: What is machine learning?...
Assistant: Machine learning is a subset of artificial intel...
Overall Score: 0.67
Reason: Relevance score: 0.67
Metrics:
  keyword_overlap: 0.67 - Keyword overlap ratio: 0.67 (2/3 keywords)
  direct_answer: 0.00 - No explicit answer indicators

Sample 2:
User: What is machine learning?...
Assistant: Artificial intelligence has many applications in...
Overall Score: 0.33
Reason: Relevance score: 0.33
Metrics:
  keyword_overlap: 0.33 - Keyword overlap ratio: 0.33 (1/3 keywords)
  direct_answer: 0.00 - No explicit answer indicators

Sample 3:
User: How can I improve my cooking skills?...
Assistant: To answer your question about improving cooking ...
Overall Score: 1.00
Reason: Relevance score: 1.00
Metrics:
  keyword_overlap: 0.80 - Keyword overlap ratio: 0.80 (4/5 keywords)
  direct_answer: 1.00 - Response explicitly addresses the query

Step 6: Preview Using the CLI

Now let’s use the Reward Kit CLI to preview our evaluation with the sample data:

# Make sure your authentication is set up
export FIREWORKS_API_KEY=your_api_key

# Run the preview
reward-kit preview \
  --metrics-folders "relevance=./metrics/relevance" \
  --samples ./samples.jsonl

You should see preview results from the Fireworks API.

Step 7: Deploy Your Reward Function

Once you’re satisfied with your reward function, deploy it to make it available for training workflows:

# Deploy using the CLI
reward-kit deploy \
  --id relevance-evaluator \
  --metrics-folders "relevance=./metrics/relevance" \
  --display-name "Response Relevance Evaluator" \
  --description "Evaluates how well a response addresses the user's query" \
  --force

You should see output confirming that your evaluator was successfully deployed.

Step 8: Create a Deployment Script

For more control over deployment, create a deployment script:

# deploy_relevance.py
from pathlib import Path

from reward_kit.evaluation import create_evaluation

def main():
    """Deploy the relevance evaluator to Fireworks."""
    print("Deploying relevance evaluator...")
    
    try:
        # Deploy the evaluator
        metrics_path = Path("metrics/relevance").absolute()
        
        if not metrics_path.exists():
            print(f"Error: Metrics folder {metrics_path} not found.")
            return
        
        result = create_evaluation(
            evaluator_id="relevance-evaluator",
            metric_folders=[f"relevance={metrics_path}"],
            display_name="Response Relevance Evaluator",
            description="Evaluates how well a response addresses the user's query",
            force=True  # Update if it already exists
        )
        
        print(f"Successfully deployed evaluator: {result['name']}")
        print(f"Use this ID for training jobs: {result['name'].split('/')[-1]}")
        
    except Exception as e:
        print(f"Error deploying evaluator: {str(e)}")
        print("Make sure your API credentials are set up correctly.")

if __name__ == "__main__":
    main()

Run the deployment script:

python deploy_relevance.py

Step 9: Use Your Reward Function in Training

Finally, you can use your deployed reward function in an RL training job:

firectl create rl-job \
  --reward-endpoint "https://api.fireworks.ai/v1/evaluations/relevance-evaluator" \
  --model-id "accounts/fireworks/models/llama-v3-8b-instruct" \
  --dataset-id "my-training-dataset"

Improving Your Reward Function

Now that you have a basic reward function, consider these improvements:

  1. Better Keyword Matching: Use techniques like TF-IDF or word embeddings (see the sketch after this list)
  2. Context Understanding: Consider the full conversation context
  3. Question Understanding: Detect question types and verify answer formats
  4. Domain-Specific Knowledge: Add domain knowledge for specialized topics
  5. Multi-Component Scoring: Add metrics for informativeness, accuracy, etc.
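
As an example of the first suggestion, here is a minimal sketch that replaces the raw keyword overlap with TF-IDF cosine similarity. It assumes scikit-learn is installed (pip install scikit-learn); the helper name tfidf_relevance is illustrative and not part of the Reward Kit API:

from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity

def tfidf_relevance(user_query: str, response: str) -> float:
    """Score relevance as the TF-IDF cosine similarity between query and response."""
    vectorizer = TfidfVectorizer(stop_words="english")
    tfidf = vectorizer.fit_transform([user_query, response])
    # Similarity between the two documents; for TF-IDF vectors this lies in [0, 1]
    return float(cosine_similarity(tfidf[0:1], tfidf[1:2])[0, 0])

You could call tfidf_relevance(user_query, response) in place of the keyword-overlap computation inside evaluate() and feed the result into the same MetricResult structure.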

Next Steps

You’ve successfully created your first reward function! To continue your journey:

  1. Learn about Advanced Reward Functions
  2. Explore Core Data Types for more flexibility
  3. Try integrating Multiple Metrics into a single evaluator