TRL documentation

Judges

This page documents TRL v0.12.0.

TRL Judges is an experimental API that is subject to change at any time.

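TRL provides judges to easily compare completions: pairwise judges pick the better of two completions, and rank judges order several completions for the same prompt.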

Make sure the required dependencies are installed by running:

pip install "trl[llm_judge]"

Using the provided judges

TRL provides several judges out of the box. For example, you can use the HfPairwiseJudge to compare two completions using a model served through the Hugging Face Inference API:

from trl import HfPairwiseJudge

judge = HfPairwiseJudge()
judge.judge(
    prompts=["What is the capital of France?", "What is the biggest planet in the solar system?"],
    completions=[["Paris", "Lyon"], ["Saturn", "Jupiter"]],
)  # Outputs: [0, 1]

Define your own judge

To define your own judge, we provide several base classes that you can subclass. For rank-based judges, subclass BaseRankJudge and implement the BaseRankJudge.judge() method. For pairwise judges, subclass BasePairwiseJudge and implement the BasePairwiseJudge.judge() method. If you want to define a judge that doesn't fit into these categories, subclass BaseJudge and implement the BaseJudge.judge() method.

As an example, let’s define a pairwise judge that prefers shorter completions:

from trl import BasePairwiseJudge

class PrefersShorterJudge(BasePairwiseJudge):
    def judge(self, prompts, completions, shuffle_order=False):
        # Return the index (0 or 1) of the shorter completion in each pair.
        return [1 if len(completion[0]) > len(completion[1]) else 0 for completion in completions]

You can then use this judge as follows:

judge = PrefersShorterJudge()
judge.judge(
    prompts=["What is the capital of France?", "What is the biggest planet in the solar system?"],
    completions=[["Paris", "The capital of France is Paris."], ["Jupiter is the biggest planet in the solar system.", "Jupiter"]],
)  # Outputs: [0, 1]

BaseJudge

class trl.BaseJudge()

Base class for judges. Subclasses of this class should implement the judge() method.

BaseRankJudge

class trl.BaseRankJudge()

Base class for LLM ranking judges.

Example:

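from trl import BaseRankJudge

class MyRankJudge(BaseRankJudge):
    def judge(self, prompts, completions, shuffle_order=True):
        # Your ranking logic here. Each returned list orders the completion
        # indices from best to worst. This placeholder keeps the original order.
        return [list(range(len(candidates))) for candidates in completions]

judge = MyRankJudge()
judge.judge(
    prompts=["The capital of France is", "The capital of Germany is"],
    completions=[[" Paris", " Marseille", "Lyon"], [" Munich", " Berlin"]],
)  # A judge that ranks the correct capital first would return [[0, 1, 2], [1, 0]]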

judge(prompts: List[str], completions: List[List[str]], shuffle_order: bool = True)

Parameters

  • prompts (List[str]) — List of prompts.
  • completions (List[List[str]]) — List of completion lists, where each element is a list of completions for the corresponding prompt.
  • shuffle_order (bool) — Whether to shuffle the order of the completions to avoid positional bias.

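Judge the completions for the given prompts and return their ranks. For each prompt, the returned list contains completion indices ordered from best to worst; for example, [1, 0] means the second completion is preferred over the first.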
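Assuming, as in the MyRankJudge example, that each inner list orders completion indices from best to worst, a caller can recover each prompt's preferred completion by taking the first index. The pure-Python sketch below does this with a hypothetical helper, top_completions; it does not call TRL.

```python
from typing import List


def top_completions(completions: List[List[str]], ranks: List[List[int]]) -> List[str]:
    """Return the best-ranked completion for each prompt.

    Assumes each inner list of `ranks` holds completion indices ordered
    from best to worst (the first entry is the preferred completion).
    """
    if len(completions) != len(ranks):
        raise ValueError("Expected one ranking per prompt.")
    return [candidates[order[0]] for candidates, order in zip(completions, ranks)]


# Ranks as a correct judge would return them for the MyRankJudge prompts.
completions = [[" Paris", " Marseille", "Lyon"], [" Munich", " Berlin"]]
ranks = [[0, 1, 2], [1, 0]]
print(top_completions(completions, ranks))  # [' Paris', ' Berlin']
```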

BasePairwiseJudge

class trl.BasePairwiseJudge()

Base class for pairwise judges.

judge(prompts: List[str], completions: List[List[str]], shuffle_order: bool = True)

Parameters

  • prompts (List[str]) — List of prompts.
  • completions (List[List[str]]) — List of completion pairs, where each element is a pair of completions for the corresponding prompt.
  • shuffle_order (bool) — Whether to shuffle the order of the completions to avoid positional bias.

Judge the completion pairs for the given prompts and return, for each prompt, the index (0 or 1) of the preferred completion.
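The shuffle_order flag exists because a judge, especially an LLM, may favor whichever completion is shown first. One common way to counter this is to randomly swap each pair before judging and then map the answer back to the original order. The plain-Python sketch below illustrates that idea; it is not TRL's implementation, and compare_fn is a hypothetical stand-in for the underlying comparison that returns 0, 1, or -1.

```python
import random
from typing import Callable, List, Optional, Sequence


def judge_with_shuffle(
    prompts: List[str],
    completions: List[Sequence[str]],
    compare_fn: Callable[[str, str, str], int],
    shuffle_order: bool = True,
    rng: Optional[random.Random] = None,
) -> List[int]:
    """Judge each pair, optionally swapping it first to reduce positional bias.

    `compare_fn(prompt, first, second)` is a hypothetical comparison that
    returns 0 if `first` wins, 1 if `second` wins, or -1 on failure.
    """
    rng = rng or random.Random()
    results = []
    for prompt, (a, b) in zip(prompts, completions):
        swapped = shuffle_order and rng.random() < 0.5
        first, second = (b, a) if swapped else (a, b)
        choice = compare_fn(prompt, first, second)
        if swapped and choice in (0, 1):
            choice = 1 - choice  # map the answer back to the original order
        results.append(choice)
    return results


def shorter_wins(prompt: str, first: str, second: str) -> int:
    """Toy comparison that prefers the shorter completion."""
    return 0 if len(first) <= len(second) else 1


print(judge_with_shuffle(
    ["What is the capital of France?"],
    [["Paris", "The capital of France is Paris."]],
    shorter_wins,
    rng=random.Random(0),
))  # [0]
```

Because the swap is undone before the result is stored, an order-independent comparison gives the same answers with or without shuffling; only a position-biased comparison is affected.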

Note: If the judge returns -1 for any prompt, it indicates that the inner process used to compute the preference has failed. For instance, this could occur if the underlying language model returned an invalid answer. In such cases, the caller should handle these invalid indices appropriately, possibly by implementing fallback logic or error handling.
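As one way to handle such failures, the sketch below separates valid judgments from -1 entries so that only valid ones are used downstream, and reports which prompts need a retry or a fallback. It is plain Python operating on a judge's returned list; the helper name split_judgments is hypothetical.

```python
from typing import List, Tuple


def split_judgments(results: List[int]) -> Tuple[List[Tuple[int, int]], List[int]]:
    """Split pairwise results into (prompt_index, choice) pairs and failed indices.

    A value of -1 marks a prompt whose preference could not be computed.
    """
    valid, failed = [], []
    for i, choice in enumerate(results):
        if choice == -1:
            failed.append(i)
        else:
            valid.append((i, choice))
    return valid, failed


valid, failed = split_judgments([0, -1, 1, 0])
print(valid)   # [(0, 0), (2, 1), (3, 0)]
print(failed)  # [1]
```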

RandomRankJudge

class trl.RandomRankJudge()

Random rank judge, for testing purposes.

RandomPairwiseJudge

class trl.RandomPairwiseJudge()

Random pairwise judge, for testing purposes.
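For testing, the useful property of a random judge is that its output has the right shape without loading any model. The pure-Python sketch below shows that shape: a random permutation of completion indices per prompt for ranking, and a random 0 or 1 per prompt for pairwise comparison. The function names are hypothetical and this is not TRL's source code.

```python
import random
from typing import List, Optional


def random_ranks(completions: List[List[str]], rng: Optional[random.Random] = None) -> List[List[int]]:
    """Return a random ordering of completion indices for each prompt."""
    rng = rng or random.Random()
    ranks = []
    for candidates in completions:
        order = list(range(len(candidates)))
        rng.shuffle(order)
        ranks.append(order)
    return ranks


def random_pairwise(completions: List[List[str]], rng: Optional[random.Random] = None) -> List[int]:
    """Return a random preferred index (0 or 1) for each pair."""
    rng = rng or random.Random()
    return [rng.randint(0, 1) for _ in completions]


rng = random.Random(42)
print(random_ranks([["a", "b", "c"], ["d", "e"]], rng))
print(random_pairwise([["a", "b"], ["c", "d"]], rng))
```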

PairRMJudge

class trl.PairRMJudge()

Attributes

  • blender (llm_blender.Blender) — An instance of the Blender class from llm-blender.

LLM judge based on the PairRM model from AllenAI.

This judge uses the PairRM model to rank pairs of completions for given prompts. It’s designed for pairwise comparison of language model outputs.

The PairRM model is loaded using the llm-blender library and runs on the default Accelerator device.

Example:

from trl import PairRMJudge

pairrm_judge = PairRMJudge()
prompts = ["Translate 'hello' to French", "What's the capital of Japan?"]
completions = [["Bonjour", "Salut"], ["Kyoto", "Tokyo"]]
results = pairrm_judge.judge(prompts, completions)
print(results)  # [0, 1]: the first completion is preferred for the first prompt, the second for the second prompt

Note: This class requires the llm-blender library to be installed. Install it with: pip install llm-blender

judge(prompts: List[str], completions: List[List[str]], shuffle_order: bool = True, return_scores: bool = False, temperature: float = 1.0) -> List[Union[int, float]]

Parameters

  • prompts (List[str]) — List of prompts to judge.
  • completions (List[List[str]]) — List of completion pairs for each prompt.
  • shuffle_order (bool, optional) — Whether to shuffle the order of completions to avoid positional bias. Defaults to True.
  • return_scores (bool, optional) — If True, return probability scores instead of ranks (i.e. a soft-judge). Defaults to False.
  • temperature (float, optional) — Temperature for scaling logits if return_scores is True. Defaults to 1.0.

Returns

List[Union[int, float]]

List of ranks (0 or 1) or scores for each prompt, indicating which completion is preferred or its score.

Raises

  • ValueError — If the number of completions per prompt is not exactly 2.

Judge the completion pairs for the given prompts using the PairRM model.

Note:

  • Ranks are 0-indexed (0 means the first completion is preferred).
  • If return_scores is True, returns softmax probabilities for the first completion.
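To illustrate what a temperature-scaled soft judgment means, the sketch below turns a pair of scores (one per completion) into the softmax probability that the first completion is preferred, and then into hard 0/1 choices. It assumes the model yields one logit per completion; it is a generic temperature-scaled softmax, not PairRM's internal code, and the function names are hypothetical.

```python
import math
from typing import List, Tuple


def first_completion_probability(logits: Tuple[float, float], temperature: float = 1.0) -> float:
    """Softmax probability of the first completion after dividing logits by `temperature`."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    first, second = (x / temperature for x in logits)
    # Subtract the max before exponentiating for numerical stability.
    m = max(first, second)
    e_first, e_second = math.exp(first - m), math.exp(second - m)
    return e_first / (e_first + e_second)


def hard_choices(probabilities: List[float]) -> List[int]:
    """Map soft scores to ranks: 0 if the first completion is at least as likely, else 1."""
    return [0 if p >= 0.5 else 1 for p in probabilities]


p = first_completion_probability((2.0, 0.0), temperature=1.0)
print(round(p, 4))  # 0.8808
print(round(first_completion_probability((2.0, 0.0), temperature=2.0), 4))  # 0.7311
print(hard_choices([p, 0.2]))  # [0, 1]
```

Raising the temperature pulls the probability toward 0.5, so soft scores become less confident while the preferred completion stays the same.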

HfPairwiseJudge

class trl.HfPairwiseJudge(model = 'meta-llama/Meta-Llama-3-70B-Instruct', token: Optional[str] = None, system_prompt: Optional[str] = None)

Parameters

  • model (str, optional) — The model to use for the judge. Defaults to “meta-llama/Meta-Llama-3-70B-Instruct”.
  • token (str, optional) — The Hugging Face API token to use for the InferenceClient.
  • system_prompt (str, optional) — The system prompt to be used for the judge. If not provided, a default prompt is used. Note that the system prompt should contain the following placeholders: {prompt}, {response0}, and {response1}. Also, inference is called with max_tokens=1, so the system prompt should ask for a single-token response.
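A custom system_prompt must contain the {prompt}, {response0} and {response1} placeholders and ask for a single-token answer. The sketch below shows a hypothetical prompt, checks the placeholders, fills them with Python's str.format (an assumption about how the placeholders are substituted), and parses the one-token reply into 0 or 1, returning -1 for anything else to match the failure convention described for BasePairwiseJudge. The prompt wording and helper names are illustrative only.

```python
from typing import Sequence

# Hypothetical system prompt: any wording works as long as the three
# placeholders are present and a single-token answer is requested.
MY_SYSTEM_PROMPT = (
    "You compare two answers to the same question.\n"
    "Question: {prompt}\n"
    "Answer 0: {response0}\n"
    "Answer 1: {response1}\n"
    "Reply with a single character: 0 if Answer 0 is better, 1 if Answer 1 is better."
)

REQUIRED_PLACEHOLDERS = ("{prompt}", "{response0}", "{response1}")


def check_system_prompt(template: str) -> None:
    """Raise if the template is missing any required placeholder."""
    missing = [p for p in REQUIRED_PLACEHOLDERS if p not in template]
    if missing:
        raise ValueError(f"system prompt is missing placeholders: {missing}")


def build_message(template: str, prompt: str, pair: Sequence[str]) -> str:
    """Fill the placeholders for one prompt and its two candidate responses."""
    return template.format(prompt=prompt, response0=pair[0], response1=pair[1])


def parse_reply(reply: str) -> int:
    """Turn the single-token reply into 0 or 1, or -1 if it is not a valid answer."""
    reply = reply.strip()
    return int(reply) if reply in ("0", "1") else -1


check_system_prompt(MY_SYSTEM_PROMPT)
print(build_message(MY_SYSTEM_PROMPT, "What is the capital of France?", ["Paris", "Lyon"]))
print(parse_reply("0"), parse_reply("Paris"))  # 0 -1
```

Once validated, such a string could be passed as HfPairwiseJudge(system_prompt=MY_SYSTEM_PROMPT). Because str.format treats every brace pair in the template as a field, the template itself should not contain other literal braces.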

Pairwise judge based on the Hugging Face API with chat completion.

This judge is relevant for assessing the quality of chat models, where the completion is a response to a given prompt.

OpenAIPairwiseJudge

class trl.OpenAIPairwiseJudge(model = 'gpt-4-turbo-preview', system_prompt: Optional[str] = None, max_requests: Optional[int] = 1000)

Parameters

  • model (str, optional) — The model to use for the judge. Defaults to "gpt-4-turbo-preview".
  • system_prompt (str, optional) — The system prompt to be used for the judge. If not provided, a default prompt is used. Note that the system prompt should contain the following placeholders: {prompt}, {response0}, and {response1}. Also, inference is called with max_tokens=1, so the system prompt should ask for a single-token response.
  • max_requests (int, optional) — The maximum number of requests to make to the OpenAI API. Defaults to 1000. If set to None, there is no limit.
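To illustrate the max_requests semantics (a cap on API calls, with None meaning no limit), the sketch below counts requests and stops calling once the budget is spent. What a judge should return for pairs beyond the cap is an assumption here: this sketch reuses the -1 failure value described for BasePairwiseJudge. RequestBudget, judge_within_budget and fake_api are hypothetical stand-ins; no OpenAI client is used.

```python
from typing import Callable, List, Optional


class RequestBudget:
    """Track API calls against an optional maximum (None means unlimited)."""

    def __init__(self, max_requests: Optional[int] = 1000):
        self.max_requests = max_requests
        self.num_requests = 0

    def try_acquire(self) -> bool:
        """Return True and count the request if the budget allows it."""
        if self.max_requests is not None and self.num_requests >= self.max_requests:
            return False
        self.num_requests += 1
        return True


def judge_within_budget(
    pairs: List[List[str]], call_api: Callable[[List[str]], int], budget: RequestBudget
) -> List[int]:
    """Call `call_api` per pair until the budget runs out; mark skipped pairs with -1."""
    return [call_api(pair) if budget.try_acquire() else -1 for pair in pairs]


def fake_api(pair: List[str]) -> int:
    """Stand-in for a real API call that always prefers the first completion."""
    return 0


budget = RequestBudget(max_requests=2)
print(judge_within_budget([["a", "b"]] * 3, fake_api, budget))  # [0, 0, -1]
print(budget.num_requests)  # 2
```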

Pairwise judge based on the OpenAI API.

This judge is relevant for assessing the quality of chat models, where the completion is a response to a given prompt.
