RewardFunction Class Reference

The RewardFunction class is a core component of the Reward Kit, providing a unified interface for calling reward functions locally or remotely.

Overview

The RewardFunction class wraps a reward function (either a local function or a remote endpoint) and provides a consistent interface for evaluation. It supports:

  • Local functions (mode="local")
  • Remote endpoints (mode="remote")
  • Fireworks-hosted models (mode="fireworks_hosted")

Import

from reward_kit.reward_function import RewardFunction

Constructor

RewardFunction(
    func: Optional[Callable] = None,
    func_path: Optional[str] = None,
    mode: str = "local",
    endpoint: Optional[str] = None,
    name: Optional[str] = None,
    model_id: Optional[str] = None,
    **kwargs
)

Parameters

  • func (Optional[Callable]): The local function to use (for mode="local").

  • func_path (Optional[str]): A string path to a function (e.g., "module.submodule:function_name").

  • mode (str): The mode of operation. Options:

    • "local": Run the function locally
    • "remote": Call a remote endpoint
    • "fireworks_hosted": Use a Fireworks-hosted model
  • endpoint (Optional[str]): The URL of the remote endpoint (for mode="remote").

  • name (Optional[str]): The name of the deployed evaluator (for mode="remote"). If name is provided and endpoint is not, the endpoint is constructed from the name.

  • model_id (Optional[str]): The ID of the Fireworks-hosted model (for mode="fireworks_hosted").

  • **kwargs: Additional keyword arguments to pass to the function when called.

Exceptions

  • ValueError: Raised if required parameters for the specified mode are missing or if an invalid mode is provided.

Methods

__call__

Call the reward function with the provided messages.

__call__(
    messages: List[Dict[str, str]],
    original_messages: Optional[List[Dict[str, str]]] = None,
    **kwargs
) -> RewardOutput

Parameters

  • messages (List[Dict[str, str]]): List of conversation messages, each with "role" and "content" keys.

  • original_messages (Optional[List[Dict[str, str]]]): Original conversation messages (for context). Defaults to all messages except the last one if not provided.

  • **kwargs: Additional keyword arguments to pass to the function.
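
The default for original_messages can be pictured with a small stand-alone sketch. The helper name resolve_original_messages is hypothetical and not part of Reward Kit; it only mirrors the documented rule that the context defaults to every message except the last one.

```python
from typing import Dict, List, Optional

Message = Dict[str, str]


def resolve_original_messages(
    messages: List[Message],
    original_messages: Optional[List[Message]] = None,
) -> List[Message]:
    """Hypothetical helper: apply the documented default for original_messages."""
    if original_messages is None:
        # Documented default: all messages except the last one.
        return messages[:-1]
    return original_messages


conversation = [
    {"role": "user", "content": "Hello"},
    {"role": "assistant", "content": "Hi there!"},
]

# No original_messages given, so the context is the conversation minus the reply.
context = resolve_original_messages(conversation)
print(context)
```

Passing original_messages explicitly skips the default, which is useful when the context differs from the conversation prefix.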

Returns

  • RewardOutput: An object containing the overall score and a metrics dictionary.

Exceptions

  • ValueError: Raised if no function or endpoint is provided for the selected mode.
  • TypeError: Raised if the function returns an invalid type.
  • requests.exceptions.RequestException: Raised if there is an error calling the remote endpoint.

get_trl_adapter

Create an adapter function for use with the TRL (Transformer Reinforcement Learning) library.

get_trl_adapter() -> Callable

Returns

  • Callable: A function that takes batch inputs and returns a batch of reward values, compatible with TRL.

Adapter Behavior

The returned adapter function:

  1. Handles batch inputs (list of message lists or list of strings)
  2. Returns a list of reward scores (one for each input)
  3. Catches exceptions raised during evaluation and returns 0.0 instead of propagating the error
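
The three behaviors above can be sketched without Reward Kit installed. Everything in this block is an illustrative stand-in: make_batch_adapter and length_score are hypothetical names, the conversion of a bare string into a single assistant message is an assumption, and so is the choice to apply the 0.0 fallback per input rather than to the whole batch. The real adapter may differ on each of these points.

```python
from typing import Callable, Dict, List, Union

Message = Dict[str, str]
BatchItem = Union[List[Message], str]


def make_batch_adapter(
    score_one: Callable[[List[Message]], float],
) -> Callable[[List[BatchItem]], List[float]]:
    """Hypothetical stand-in for the adapter's documented behavior."""

    def adapter(batch: List[BatchItem]) -> List[float]:
        scores: List[float] = []
        for item in batch:
            try:
                if isinstance(item, str):
                    # Assumption: a bare string becomes one assistant message.
                    item = [{"role": "assistant", "content": item}]
                scores.append(float(score_one(item)))
            except Exception:
                # Documented behavior: errors produce a score of 0.0.
                scores.append(0.0)
        return scores

    return adapter


def length_score(messages: List[Message]) -> float:
    """Toy scorer: response length, capped at 1.0."""
    return min(len(messages[-1]["content"]) / 100, 1.0)


adapter = make_batch_adapter(length_score)
scores = adapter([
    [{"role": "user", "content": "Q"}, {"role": "assistant", "content": "A" * 50}],
    "B" * 200,
    [],  # Empty conversation: length_score raises IndexError, so the score is 0.0.
])
print(scores)  # [0.5, 1.0, 0.0]
```

Because failures are mapped to 0.0, a silent run of zeros during training can indicate a misconfigured evaluator rather than genuinely poor responses.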

Examples

Local Mode

from reward_kit import RewardFunction, RewardOutput, MetricRewardOutput

# Define a reward function
def my_reward_fn(messages, **kwargs):
    response = messages[-1].get("content", "")
    score = min(len(response) / 100, 1.0)  # Simple score based on length
    return RewardOutput(
        score=score,
        metrics={"length": MetricRewardOutput(score=score, reason=f"Length: {len(response)}")}
    )

# Create a reward function in local mode
reward_fn = RewardFunction(func=my_reward_fn, mode="local")

# Call the reward function
result = reward_fn(messages=[
    {"role": "user", "content": "Hello"},
    {"role": "assistant", "content": "Hi there! How can I help you today?"}
])

print(f"Score: {result.score}")

Remote Mode

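from reward_kit import RewardFunction

# Create a reward function in remote mode
# (requires FIREWORKS_API_KEY in the environment; see Authentication)
remote_reward = RewardFunction(
    name="my-deployed-evaluator",
    mode="remote"
)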

# Call the reward function
result = remote_reward(messages=[
    {"role": "user", "content": "What is machine learning?"},
    {"role": "assistant", "content": "Machine learning is a method of data analysis..."}
])

print(f"Score: {result.score}")

Fireworks Hosted Mode

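from reward_kit import RewardFunction

# Create a reward function using a Fireworks-hosted model
# (requires FIREWORKS_API_KEY in the environment; see Authentication)
hosted_reward = RewardFunction(
    model_id="accounts/fireworks/models/llama-v3-8b-instruct",
    mode="fireworks_hosted"
)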

# Call the reward function
result = hosted_reward(messages=[
    {"role": "user", "content": "Explain quantum computing"},
    {"role": "assistant", "content": "Quantum computing uses quantum bits or qubits..."}
])

print(f"Score: {result.score}")

Using with TRL

from reward_kit import RewardFunction

# Create a reward function
reward_fn = RewardFunction(name="my-deployed-evaluator", mode="remote")

# Get a TRL-compatible adapter
trl_reward_fn = reward_fn.get_trl_adapter()

# Use in TRL (example)
batch_inputs = [
    [{"role": "user", "content": "Question 1"}, {"role": "assistant", "content": "Answer 1"}],
    [{"role": "user", "content": "Question 2"}, {"role": "assistant", "content": "Answer 2"}]
]

# Get reward scores for the batch
reward_scores = trl_reward_fn(batch_inputs)
print(reward_scores)  # [score1, score2]

Implementation Details

Mode-Specific Requirements

  • Local Mode: Requires either func or func_path.
  • Remote Mode: Requires either endpoint or name.
  • Fireworks Hosted Mode: Requires model_id.
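
As a rough illustration, the requirements above amount to a check like the following. The function check_mode_requirements and its error messages are hypothetical; only the rules themselves and the use of ValueError come from this reference.

```python
from typing import Callable, Optional

VALID_MODES = ("local", "remote", "fireworks_hosted")


def check_mode_requirements(
    mode: str,
    func: Optional[Callable] = None,
    func_path: Optional[str] = None,
    endpoint: Optional[str] = None,
    name: Optional[str] = None,
    model_id: Optional[str] = None,
) -> None:
    """Hypothetical check mirroring the documented per-mode requirements."""
    if mode not in VALID_MODES:
        raise ValueError(f"Invalid mode {mode!r}; expected one of {VALID_MODES}")
    if mode == "local" and func is None and func_path is None:
        raise ValueError("local mode requires func or func_path")
    if mode == "remote" and endpoint is None and name is None:
        raise ValueError("remote mode requires endpoint or name")
    if mode == "fireworks_hosted" and model_id is None:
        raise ValueError("fireworks_hosted mode requires model_id")


# A remote configuration identified by name satisfies the rules.
check_mode_requirements("remote", name="my-deployed-evaluator")

# A hosted configuration without model_id does not.
error_message = ""
try:
    check_mode_requirements("fireworks_hosted")
except ValueError as exc:
    error_message = str(exc)
print(error_message)
```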

Function Loading

When providing a func_path, the path can be specified in two formats:

  • module.path:function_name - Module with colon separator (preferred)
  • module.path.function_name - Module with function as last component
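
One plausible way to resolve both formats uses importlib from the standard library. This is a sketch of the general technique, not Reward Kit's actual loader; load_function is a hypothetical name, and json.dumps is used only because it is always importable.

```python
import importlib
from typing import Callable


def load_function(func_path: str) -> Callable:
    """Hypothetical loader accepting 'pkg.mod:func' or 'pkg.mod.func'."""
    if ":" in func_path:
        # Preferred format: the colon separates module from function.
        module_path, func_name = func_path.split(":", 1)
    else:
        # Dotted format: the last component is taken as the function name.
        module_path, _, func_name = func_path.rpartition(".")
    if not module_path or not func_name:
        raise ValueError(f"Cannot parse function path: {func_path!r}")

    module = importlib.import_module(module_path)
    func = getattr(module, func_name)
    if not callable(func):
        raise TypeError(f"{func_path!r} does not refer to a callable")
    return func


dumps_a = load_function("json:dumps")
dumps_b = load_function("json.dumps")
print(dumps_a is dumps_b)  # True
```

The colon form is unambiguous because it states exactly where the module path ends, which is a likely reason it is marked as preferred.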

Authentication

For remote and Fireworks-hosted modes, the authentication token is retrieved from the FIREWORKS_API_KEY environment variable.
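
A minimal sketch of reading the key before making a request follows. The helper build_auth_headers is hypothetical, and sending the token as a Bearer Authorization header is an assumption; this reference only states which environment variable is read.

```python
import os
from typing import Dict, Mapping, Optional


def build_auth_headers(env: Optional[Mapping[str, str]] = None) -> Dict[str, str]:
    """Hypothetical helper: read FIREWORKS_API_KEY and build request headers."""
    env = os.environ if env is None else env
    api_key = env.get("FIREWORKS_API_KEY")
    if not api_key:
        raise ValueError("FIREWORKS_API_KEY is not set")
    # Assumption: the token is sent as a Bearer Authorization header.
    return {"Authorization": f"Bearer {api_key}"}


# A plain dict stands in for os.environ so the example does not need a real key.
headers = build_auth_headers({"FIREWORKS_API_KEY": "example-key"})
print(headers)
```

Checking for the variable up front gives a clearer failure than an authentication error returned by the remote endpoint.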