"""
GRPO-compatible reward hook for TRL.

This module provides a reward function factory that wraps the OCC
ImpactOracle into a TRL GRPOTrainer-compatible callable.

Usage with TRL::

    from grpo_hook import make_occ_reward_func
    from trl import GRPOTrainer

    reward_fn = make_occ_reward_func(mode="code", compute_budget=1e5)
    trainer = GRPOTrainer(
        model="Qwen/Qwen2.5-0.5B-Instruct",
        reward_funcs=reward_fn,
        train_dataset=ds,  # must have a "prompt" column
    )

The reward function signature follows TRL conventions::

    def reward_fn(completions, **kwargs) -> list[float]

Each completion's final answer is read from between ``<answer>`` and
``</answer>`` tags (configurable via the factory). When the tags are
missing, the last whitespace-separated token of the completion is used.
"""
import json
from pathlib import Path  # noqa: F401  (not used in this module)
from typing import Any, Callable, Dict, List, Optional

from oracle.oracle import ImpactOracle
from ledger.ledger import CreditLedger  # noqa: F401  (not used in this module)
from broker.broker import ResourceBroker  # noqa: F401  (not used in this module)
from rl.reward import RewardHook, OfflinePolicyComparator

# Default delimiters around the final answer inside a completion.
DEFAULT_ANSWER_OPEN_TAG = "<answer>"
DEFAULT_ANSWER_CLOSE_TAG = "</answer>"

# Heuristic confidences passed to the reward hook: a completion that yields a
# non-empty answer is reported as more confident than one that yields nothing.
_CONFIDENCE_WITH_ANSWER = 0.7
_CONFIDENCE_WITHOUT_ANSWER = 0.3


def _completion_text(comp: Any) -> str:
    """Return the text of one TRL completion (standard or conversational)."""
    if isinstance(comp, str):
        return comp
    if isinstance(comp, list):
        if not comp:
            # Empty conversational completion: no text (str([]) would give "[]").
            return ""
        first = comp[0]
        if isinstance(first, dict):
            # Conversational format: [{"role": "assistant", "content": "..."}]
            content = first.get("content")
            return "" if content is None else str(content)
    return str(comp)


def _extract_answer(text: str, open_tag: str, close_tag: str) -> str:
    """Return the stripped text between the tags, else the last token of *text*."""
    start = text.find(open_tag)
    if start != -1:
        start += len(open_tag)
        # Look for the closing tag only after the opening one, so a stray
        # closing tag earlier in the text cannot produce a bogus slice.
        end = text.find(close_tag, start)
        if end != -1:
            return text[start:end].strip()
    # Fallback: last whitespace-separated token, or "" for blank text.
    parts = text.split()
    return parts[-1] if parts else ""


def make_occ_reward_func(
    mode: str = "retrieval_qa",
    compute_budget: float = 1e5,
    qa_weights: Optional[Dict] = None,
    code_weights: Optional[Dict] = None,
    debate_weights: Optional[Dict] = None,
    answer_open_tag: str = DEFAULT_ANSWER_OPEN_TAG,
    answer_close_tag: str = DEFAULT_ANSWER_CLOSE_TAG,
) -> Callable[..., List[float]]:
    """
    Factory for a TRL-compatible reward function.

    Args:
        mode: Passed to ``RewardHook`` as its ``mode``.
        compute_budget: Passed to ``ImpactOracle``.
        qa_weights: Passed unchanged to ``ImpactOracle``.
        code_weights: Passed unchanged to ``ImpactOracle``.
        debate_weights: Passed unchanged to ``ImpactOracle``.
        answer_open_tag: Marker that precedes the final answer in a completion.
        answer_close_tag: Marker that follows the final answer in a completion.

    Returns:
        A function with signature ``(completions, **kwargs) -> list[float]``.

    Raises:
        ValueError: If either answer tag is empty (an empty tag would match
            at position 0 of every completion and yield empty answers).
    """
    if not answer_open_tag or not answer_close_tag:
        raise ValueError("answer_open_tag and answer_close_tag must be non-empty")

    oracle = ImpactOracle(
        compute_budget=compute_budget,
        qa_weights=qa_weights,
        code_weights=code_weights,
        debate_weights=debate_weights,
    )
    hook = RewardHook(oracle=oracle, mode=mode)

    def _reward_fn(completions, **kwargs):
        """
        Score a batch of completions.

        TRL calls this with completions as list[str] (standard format) or
        list[list[dict]] (conversational format). The text of each completion
        is extracted and its answer is read from between the answer tags.

        Optional keyword inputs: ``prompts``, ``answers`` (gold answers) and
        ``agent_ids``. Missing or empty ``prompts``/``answers`` are replaced
        by blank strings, one per completion.
        """
        texts = [_completion_text(comp) for comp in completions]
        answers = [
            _extract_answer(txt, answer_open_tag, answer_close_tag) for txt in texts
        ]
        confidences = [
            _CONFIDENCE_WITH_ANSWER if ans else _CONFIDENCE_WITHOUT_ANSWER
            for ans in answers
        ]
        # Whitespace-token count of the whole completion is the compute-cost proxy.
        compute_costs = [len(txt.split()) for txt in texts]

        blanks = [""] * len(texts)
        gold_answers = kwargs.get("answers") or blanks
        prompts = kwargs.get("prompts") or blanks

        return hook.compute_rewards(
            prompts=prompts,
            completions=texts,
            answers=answers,
            gold_answers=gold_answers,
            confidences=confidences,
            compute_costs=compute_costs,
            agent_ids=kwargs.get("agent_ids"),
        )

    return _reward_fn


def demo_offline():
    """
    Compare two synthetic policies offline using the reward hook.

    Builds two lists of ten per-episode records with linearly increasing
    rewards, passes them to ``OfflinePolicyComparator.compare``, prints the
    result as JSON (non-serializable values are stringified) and returns it.
    """
    hook = RewardHook(oracle=ImpactOracle(compute_budget=1e5), mode="retrieval_qa")
    comparator = OfflinePolicyComparator(reward_hook=hook)

    policy_a = [{"reward": 0.5 + i * 0.02, "failure_tags": []} for i in range(10)]
    policy_b = [{"reward": 0.4 + i * 0.01, "failure_tags": []} for i in range(10)]

    result = comparator.compare(policy_a, policy_b)
    print(json.dumps(result, indent=2, default=str))
    return result


if __name__ == "__main__":
    demo_offline()