RewardFunction Class Reference
The RewardFunction class is a core component of the Reward Kit. It provides a unified interface for calling reward functions locally or remotely.
Overview
The RewardFunction class wraps a reward function (either a local function or a remote endpoint) and provides a consistent interface for evaluation. It supports:
- Local functions (mode="local")
- Remote endpoints (mode="remote")
- Fireworks-hosted models (mode="fireworks_hosted")
Import
Constructor
Parameters
- func (Optional[Callable]): The local function to use (for mode="local").
- func_path (Optional[str]): A string path to a function (e.g., "module.submodule:function_name").
- mode (str): The mode of operation. Options:
  - "local": Run the function locally
  - "remote": Call a remote endpoint
  - "fireworks_hosted": Use a Fireworks-hosted model
- endpoint (Optional[str]): The URL of the remote endpoint (for mode="remote").
- name (Optional[str]): The name of the deployed evaluator (for mode="remote"). If name is provided and endpoint is not, the endpoint is constructed from the name.
- model_id (Optional[str]): The ID of the Fireworks-hosted model (for mode="fireworks_hosted").
- **kwargs: Additional keyword arguments to pass to the function when called.
Exceptions
- ValueError: Raised if required parameters for the specified mode are missing or if an invalid mode is provided.
Methods
__call__
Call the reward function with the provided messages.
Parameters
- messages (List[Dict[str, str]]): List of conversation messages, each with 'role' and 'content' keys.
- original_messages (Optional[List[Dict[str, str]]]): Original conversation messages (for context). If not provided, defaults to all messages except the last one.
- **kwargs: Additional keyword arguments to pass to the function.
Returns
- RewardOutput: Object with score and metrics.
Exceptions
- ValueError: Raised if no function or endpoint is provided for the selected mode.
- TypeError: Raised if the function returns an invalid type.
- requests.exceptions.RequestException: Raised if there is an error calling the remote endpoint.
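The default for original_messages described above can be sketched in plain Python. This only illustrates the documented behavior and is not the library's code; resolve_original_messages is a hypothetical helper name.

```python
from typing import Dict, List, Optional

Message = Dict[str, str]


def resolve_original_messages(
    messages: List[Message],
    original_messages: Optional[List[Message]] = None,
) -> List[Message]:
    """Return the context messages, defaulting to all but the last message."""
    if original_messages is None:
        # Documented default: every message except the final one.
        return messages[:-1]
    return original_messages


conversation = [
    {"role": "user", "content": "What is 2 + 2?"},
    {"role": "assistant", "content": "4"},
]

# No original_messages given, so the context is the user turn only.
context = resolve_original_messages(conversation)
```

In this sketch the final message (typically the response being scored) is excluded from the context unless the caller passes original_messages explicitly.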
get_trl_adapter
Create an adapter function for use with the TRL (Transformer Reinforcement Learning) library.
Returns
- Callable: A function that takes batch inputs and returns a batch of reward values, compatible with TRL.
Adapter Behavior
The returned adapter function:
- Handles batch inputs (list of message lists or list of strings)
- Returns a list of reward scores (one for each input)
- Handles exceptions gracefully, returning 0.0 for any errors
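The three behaviors listed above can be sketched with a stand-in adapter. This is not the adapter returned by get_trl_adapter; make_batch_adapter and length_score are hypothetical names, and wrapping a bare string as a single assistant message is an assumption made for the illustration.

```python
from typing import Callable, Dict, List, Union

Message = Dict[str, str]
BatchItem = Union[List[Message], str]


def make_batch_adapter(
    score_one: Callable[[List[Message]], float],
) -> Callable[[List[BatchItem]], List[float]]:
    """Wrap a per-conversation scorer so it accepts a batch and never raises."""

    def adapter(batch: List[BatchItem]) -> List[float]:
        scores: List[float] = []
        for item in batch:
            try:
                if isinstance(item, str):
                    # Assumption: treat a bare string as one assistant message.
                    item = [{"role": "assistant", "content": item}]
                scores.append(float(score_one(item)))
            except Exception:
                # Any failure for this input yields a reward of 0.0.
                scores.append(0.0)
        return scores

    return adapter


def length_score(messages: List[Message]) -> float:
    """Toy scorer: longer final messages score higher, capped at 1.0."""
    return min(len(messages[-1]["content"]) / 10.0, 1.0)


adapter = make_batch_adapter(length_score)
scores = adapter([
    "hello",  # plain string input
    [{"role": "user", "content": "hi"},
     {"role": "assistant", "content": "a long enough answer"}],
    [],  # empty conversation makes the toy scorer raise
])
```

Returning 0.0 on failure keeps a training loop running, but it also hides scorer bugs, so logging the exception inside the except branch is worth considering.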
Examples
Local Mode
Remote Mode
Fireworks Hosted Mode
Using with TRL
Implementation Details
Mode-Specific Requirements
- Local Mode: Requires either func or func_path.
- Remote Mode: Requires either endpoint or name.
- Fireworks Hosted Mode: Requires model_id.
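A minimal sketch of these checks, assuming they run at construction time and raise ValueError as described under Exceptions. check_mode_requirements is a hypothetical function written for this illustration, and the error messages are invented.

```python
from typing import Callable, Optional


def check_mode_requirements(
    mode: str,
    func: Optional[Callable] = None,
    func_path: Optional[str] = None,
    endpoint: Optional[str] = None,
    name: Optional[str] = None,
    model_id: Optional[str] = None,
) -> None:
    """Raise ValueError if the arguments do not satisfy the chosen mode."""
    if mode == "local":
        if func is None and func_path is None:
            raise ValueError("mode='local' requires func or func_path")
    elif mode == "remote":
        if endpoint is None and name is None:
            raise ValueError("mode='remote' requires endpoint or name")
    elif mode == "fireworks_hosted":
        if model_id is None:
            raise ValueError("mode='fireworks_hosted' requires model_id")
    else:
        raise ValueError(f"Invalid mode: {mode!r}")


# Valid combinations pass silently.
check_mode_requirements("local", func=len)
check_mode_requirements("remote", name="my-evaluator")

# A missing requirement is reported as ValueError.
try:
    check_mode_requirements("fireworks_hosted")
    missing_model_error = None
except ValueError as exc:
    missing_model_error = str(exc)
```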
Function Loading
When providing a func_path, the path can be specified in two formats:
- module.path:function_name - Module with colon separator (preferred)
- module.path.function_name - Module with function as last component
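One way to resolve both formats is with importlib from the standard library. The sketch below is an assumption about how such loading could work, not the library's implementation; load_function is a hypothetical name, and json.dumps is used only because it is always importable.

```python
import importlib
import json
from typing import Callable


def load_function(func_path: str) -> Callable:
    """Resolve 'pkg.mod:func' or 'pkg.mod.func' to the function object."""
    if ":" in func_path:
        # Preferred format: explicit separator between module and function.
        module_name, func_name = func_path.split(":", 1)
    else:
        # Fallback format: the last dotted component is the function name.
        module_name, _, func_name = func_path.rpartition(".")
    if not module_name or not func_name:
        raise ValueError(f"Cannot parse function path: {func_path!r}")
    module = importlib.import_module(module_name)
    return getattr(module, func_name)


dumps_a = load_function("json:dumps")
dumps_b = load_function("json.dumps")
same_object = dumps_a is dumps_b is json.dumps
```

The colon form is unambiguous because it states where the module path ends, which is a likely reason it is listed as preferred.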
Authentication
For remote and Fireworks-hosted modes, the authentication token is retrieved from the FIREWORKS_API_KEY environment variable.
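As a rough sketch, reading the token could look like the following. Only the environment variable name comes from this page; build_auth_headers is a hypothetical helper, and both the bearer-style Authorization header and the error on a missing token are assumptions.

```python
import os
from typing import Dict, Mapping


def build_auth_headers(env: Mapping[str, str] = os.environ) -> Dict[str, str]:
    """Build request headers from the FIREWORKS_API_KEY environment variable."""
    token = env.get("FIREWORKS_API_KEY")
    if not token:
        # Assumption: fail early rather than send an unauthenticated request.
        raise ValueError("FIREWORKS_API_KEY is not set")
    # Assumption: the token is sent as a bearer token.
    return {"Authorization": f"Bearer {token}"}


# A plain dict stands in for the environment so the example has no side effects.
headers = build_auth_headers({"FIREWORKS_API_KEY": "example-token"})
```

In a shell, the variable would typically be exported before running the script that uses remote or Fireworks-hosted mode.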