| """ |
| GRPO-compatible reward hook for TRL. |
| |
| This module provides a reward function factory that wraps the OCC |
| ImpactOracle into a TRL GRPOTrainer-compatible callable. |
| |
| Usage with TRL:: |
| |
| from grpo_hook import make_occ_reward_func |
| from trl import GRPOTrainer |
| |
| reward_fn = make_occ_reward_func(mode="code", compute_budget=1e5) |
| trainer = GRPOTrainer( |
| model="Qwen/Qwen2.5-0.5B-Instruct", |
| reward_funcs=reward_fn, |
| train_dataset=ds, # must have a "prompt" column |
| ) |
| |
| The reward function signature follows TRL conventions: |
| def reward_fn(completions, **kwargs) -> list[float] |
| """ |
|
|
import json
from pathlib import Path
from typing import Callable, Dict, List, Optional


from oracle.oracle import ImpactOracle
from ledger.ledger import CreditLedger
from broker.broker import ResourceBroker
from rl.reward import RewardHook, OfflinePolicyComparator
|
|
|
|
| def make_occ_reward_func( |
| mode: str = "retrieval_qa", |
| compute_budget: float = 1e5, |
| qa_weights: Optional[Dict] = None, |
| code_weights: Optional[Dict] = None, |
| debate_weights: Optional[Dict] = None, |
| ) -> callable: |
| """ |
| Factory for a TRL-compatible reward function. |
| |
| Returns a function with signature (completions, **kwargs) -> list[float]. |
| """ |
| oracle = ImpactOracle( |
| compute_budget=compute_budget, |
| qa_weights=qa_weights, |
| code_weights=code_weights, |
| debate_weights=debate_weights, |
| ) |
| hook = RewardHook(oracle=oracle, mode=mode) |
|
|
| def _reward_fn(completions, **kwargs): |
| """ |
| TRL calls this with completions as list[str] (standard format) |
| or list[list[dict]] (conversational format). |
| We extract text and look for answer tags. |
| """ |
| texts = [] |
| for comp in completions: |
| if isinstance(comp, list) and len(comp) > 0 and isinstance(comp[0], dict): |
| |
| texts.append(comp[0].get("content", "")) |
| elif isinstance(comp, str): |
| texts.append(comp) |
| else: |
| texts.append(str(comp)) |
|
|
| answers = [] |
| confidences = [] |
| compute_costs = [] |
|
|
| for txt in texts: |
| if "<answer>" in txt and "</answer>" in txt: |
| start = txt.find("<answer>") + len("<answer>") |
| end = txt.find("</answer>") |
| ans = txt[start:end].strip() |
| else: |
| |
| parts = txt.strip().split() |
| ans = parts[-1] if parts else "" |
| answers.append(ans) |
| confidences.append(0.7 if len(ans) > 0 else 0.3) |
| compute_costs.append(len(txt.split())) |
|
|
| gold_answers = kwargs.get("answers", [""] * len(texts)) |
| if not gold_answers: |
| gold_answers = [""] * len(texts) |
|
|
| rewards = hook.compute_rewards( |
| prompts=kwargs.get("prompts", [""] * len(texts)), |
| completions=texts, |
| answers=answers, |
| gold_answers=gold_answers, |
| confidences=confidences, |
| compute_costs=compute_costs, |
| agent_ids=kwargs.get("agent_ids", None), |
| ) |
| return rewards |
|
|
| return _reward_fn |
|
|
|
|
def demo_offline():
    """Run an offline comparison between two synthetic policies.

    A ``retrieval_qa`` ``RewardHook`` is built over an ``ImpactOracle`` with a
    compute budget of 1e5 and handed to ``OfflinePolicyComparator``. Two lists
    of ten records with linearly increasing rewards and no failure tags are
    compared. The result is printed as indented JSON, with non-serialisable
    values rendered via ``str``, and then returned.
    """
    comparator = OfflinePolicyComparator(
        reward_hook=RewardHook(
            oracle=ImpactOracle(compute_budget=1e5),
            mode="retrieval_qa",
        )
    )

    def _synthetic_records(base, step):
        # Ten records whose reward grows by `step` from `base`.
        return [{"reward": base + i * step, "failure_tags": []} for i in range(10)]

    outcome = comparator.compare(
        _synthetic_records(0.5, 0.02),
        _synthetic_records(0.4, 0.01),
    )
    print(json.dumps(outcome, indent=2, default=str))
    return outcome


if __name__ == "__main__":
    # Executing the module directly runs the offline comparison demo.
    demo_offline()
|
|