from reward_kit import reward_function
from reward_kit.rewards import accuracy
import numpy as np
@reward_function
def multi_metric_reward(response: str, expected_response: str, **kwargs) -> float:
"""
Advanced reward function combining accuracy, length, and format compliance.
"""
# Base accuracy score
acc_score = accuracy(response, expected_response)
# Length appropriateness (prefer responses between 50-200 chars)
len_score = length_appropriateness(response, min_len=50, max_len=200)
# Format compliance (if response should follow a pattern)
format_score = check_format_compliance(response)
# Weighted combination
weights = [0.6, 0.2, 0.2] # accuracy, length, format
scores = [acc_score, len_score, format_score]
return np.average(scores, weights=weights)
def length_appropriateness(response: str, min_len: int, max_len: int) -> float:
"""Helper function to score length appropriateness."""
length = len(response)
if min_len <= length <= max_len:
return 1.0
elif length < min_len:
return max(0.0, length / min_len)
else:
return max(0.0, 1.0 - (length - max_len) / max_len)
def check_format_compliance(response: str) -> float:
"""Helper function to check format compliance."""
# Example: Check if response follows expected structure
if response.startswith("Answer:") and response.endswith("."):
return 1.0
    return 0.5  # Partial credit when the expected pattern is absent
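
# Quick, illustrative sanity check of the two helpers above (the sample string is
# hypothetical; the decorated multi_metric_reward itself is normally invoked by the
# reward_kit evaluation harness rather than called by hand):
_sample = "Answer: The capital of France is Paris, which sits on the Seine."
assert length_appropriateness(_sample, min_len=50, max_len=200) == 1.0
assert check_format_compliance(_sample) == 1.0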
@reward_function
def context_aware_reward(response: str, expected_response: str, context: dict | None = None) -> float:
"""
Reward function that considers context information.
"""
base_score = accuracy(response, expected_response)
if context:
# Adjust score based on difficulty
difficulty = context.get('difficulty', 'medium')
        if difficulty == 'hard' and base_score > 0.8:
            base_score *= 1.2  # Bonus for strong answers to hard questions
        elif difficulty == 'easy' and base_score < 0.5:
            base_score *= 0.8  # Penalty for weak answers to easy questions
# Consider response time if available
response_time = context.get('response_time_seconds', 0)
if response_time > 0:
            # Small bonus (up to +0.1) for accurate responses returned in under 10 seconds
            time_bonus = max(0, (10 - response_time) / 100)
base_score += time_bonus
return min(base_score, 1.0)
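
# Shape of the context dictionary read above (these two keys are the only ones the
# function consults; anything else a caller includes is simply ignored):
example_context = {
    "difficulty": "hard",          # one of 'easy', 'medium', 'hard'
    "response_time_seconds": 4.2,  # elapsed time, if the caller tracks it
}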
@reward_function
def code_quality_reward(response: str, expected_response: str, **kwargs) -> float:
"""
Evaluates code responses considering multiple quality factors.
"""
import ast
score = 0.0
# Check if code is syntactically valid
try:
ast.parse(response)
score += 0.3 # Syntax correctness
except SyntaxError:
return 0.0 # Invalid syntax gets zero score
# Check for best practices
if "def " in response: # Function definition
score += 0.2
if "# " in response or '"""' in response: # Comments/docstrings
score += 0.1
# Check for specific patterns
if "import " in response and "from " in response:
score += 0.1 # Good import practices
# Length consideration (not too short, not too long)
lines = response.split('\n')
if 5 <= len(lines) <= 50:
score += 0.1
# Functional correctness (if test cases available)
test_cases = kwargs.get('test_cases', [])
if test_cases:
correctness_score = evaluate_code_correctness(response, test_cases)
score += 0.2 * correctness_score
return min(score, 1.0)
def evaluate_code_correctness(code: str, test_cases: list) -> float:
"""Helper to evaluate code correctness against test cases."""
    # This would implement actual code execution and testing.
    # For safety, this version only returns a fixed placeholder score;
    # one possible subprocess-based sketch follows below.
    return 0.8  # Placeholder score
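
# One possible (hypothetical) replacement that actually runs the candidate code.
# It assumes each test case is a dict like {"call": "add(2, 3)", "expected": "5"}
# -- a made-up schema, not something reward_kit prescribes -- and a subprocess
# timeout is NOT a real sandbox: untrusted code still needs proper isolation.
import subprocess
import sys

def evaluate_code_correctness_subprocess(code: str, test_cases: list, timeout: float = 5.0) -> float:
    """Fraction of test cases whose printed result matches the expected value."""
    if not test_cases:
        return 0.0
    passed = 0
    for case in test_cases:
        # Append a print of the tested expression to the candidate code.
        script = f"{code}\nprint({case['call']})"
        try:
            result = subprocess.run(
                [sys.executable, "-c", script],
                capture_output=True,
                text=True,
                timeout=timeout,
            )
            if result.returncode == 0 and result.stdout.strip() == str(case["expected"]):
                passed += 1
        except subprocess.TimeoutExpired:
            continue  # Treat hangs as failed cases.
    return passed / len(test_cases)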
from scipy.stats import pearsonr
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
@reward_function
def statistical_similarity_reward(response: str, expected_response: str, **kwargs) -> float:
"""
Uses statistical methods to evaluate response similarity.
"""
# Convert to numerical representations (e.g., using embeddings)
response_embedding = get_text_embedding(response)
expected_embedding = get_text_embedding(expected_response)
# Cosine similarity
cos_sim = cosine_similarity([response_embedding], [expected_embedding])[0][0]
# Pearson correlation (if applicable)
if len(response_embedding) == len(expected_embedding):
corr, _ = pearsonr(response_embedding, expected_embedding)
corr = max(0, corr) # Only positive correlations
else:
corr = 0
# Combine metrics
final_score = 0.7 * cos_sim + 0.3 * corr
return max(0.0, min(1.0, final_score))
def get_text_embedding(text: str) -> np.ndarray:
"""Placeholder for text embedding function."""
    # In practice, swap in a real embedding model (see the sketch below);
    # random vectors make the similarity score meaningless.
    return np.random.rand(100)  # Placeholder
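
# A drop-in alternative to the placeholder above, assuming the optional
# sentence-transformers package is installed (not a reward_kit requirement;
# the model name is just a commonly used small default):
def get_text_embedding_real(text: str) -> np.ndarray:
    """Embed text with an actual model instead of random vectors."""
    from sentence_transformers import SentenceTransformer  # optional dependency
    # For production use, load the model once at module import instead of per call.
    model = SentenceTransformer("all-MiniLM-L6-v2")
    return np.asarray(model.encode(text))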
@reward_function
def hierarchical_reward(response: str, expected_response: str, **kwargs) -> float:
"""
Hierarchical evaluation with multiple levels of assessment.
"""
# Level 1: Basic format validation
if not basic_format_check(response):
return 0.0
# Level 2: Content relevance
relevance_score = content_relevance(response, expected_response)
    if relevance_score < 0.3:
        return relevance_score * 0.5  # Early exit: heavily discount low-relevance responses
# Level 3: Detailed accuracy
accuracy_score = detailed_accuracy(response, expected_response)
# Level 4: Style and presentation
style_score = evaluate_style(response)
# Weighted combination based on hierarchy
final_score = (
0.1 * 1.0 + # Format passed
0.3 * relevance_score +
0.5 * accuracy_score +
0.1 * style_score
)
return final_score
def basic_format_check(response: str) -> bool:
"""Basic format validation."""
return len(response.strip()) > 0 and len(response) < 10000
def content_relevance(response: str, expected: str) -> float:
"""Evaluate content relevance."""
# Placeholder for semantic similarity
common_words = set(response.lower().split()) & set(expected.lower().split())
return len(common_words) / max(len(set(expected.lower().split())), 1)
def detailed_accuracy(response: str, expected: str) -> float:
"""Detailed accuracy evaluation."""
return accuracy(response, expected)
def evaluate_style(response: str) -> float:
"""Evaluate writing style and presentation."""
score = 0.0
    if response and response[0].isupper():  # Starts with a capital letter (guard against empty strings)
        score += 0.3
if response.endswith('.'): # Ends with period
score += 0.3
if 10 <= len(response.split()) <= 100: # Appropriate length
score += 0.4
return score
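
# Illustrative checks for the pure-Python helpers above (sample strings are
# hypothetical; detailed_accuracy is not exercised here because it delegates
# to reward_kit.rewards.accuracy):
assert basic_format_check("A short but valid response.")
assert abs(content_relevance("the cat sat", "the cat sat on the mat") - 0.6) < 1e-9
assert abs(evaluate_style("This sentence is a reasonably sized, well formed answer for our test.") - 1.0) < 1e-9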