# occ-stack / grpo_hook.py
# Uploaded by narcolepticchicken (commit f4a0835, verified).
"""
GRPO-compatible reward hook for TRL.
This module provides a reward function factory that wraps the OCC
ImpactOracle into a TRL GRPOTrainer-compatible callable.
Usage with TRL::

    from grpo_hook import make_occ_reward_func
    from trl import GRPOTrainer

    reward_fn = make_occ_reward_func(mode="code", compute_budget=1e5)
    trainer = GRPOTrainer(
        model="Qwen/Qwen2.5-0.5B-Instruct",
        reward_funcs=reward_fn,
        train_dataset=ds,  # must have a "prompt" column
    )

The reward function signature follows TRL conventions::

    def reward_fn(completions, **kwargs) -> list[float]

Gold answers are read from the optional ``answers`` keyword argument
(for example an ``answers`` column of the training dataset, which TRL
forwards to reward functions as a keyword argument).
"""
import json
from pathlib import Path
from typing import Callable, Dict, List, Optional

from oracle.oracle import ImpactOracle
from ledger.ledger import CreditLedger
from broker.broker import ResourceBroker
from rl.reward import RewardHook, OfflinePolicyComparator
def _completion_text(completion) -> str:
    """Return the plain text of a single TRL completion.

    TRL passes completions either as plain strings (standard format) or as
    lists of message dicts (conversational format), e.g.
    ``[{"role": "assistant", "content": "..."}]``. For the conversational
    format the first message's ``content`` is used. A missing or ``None``
    content yields ``""``. Any other value is stringified.
    """
    if isinstance(completion, str):
        return completion
    if isinstance(completion, list) and completion and isinstance(completion[0], dict):
        content = completion[0].get("content")
        if content is None:
            # A None content would otherwise crash the substring/split calls
            # performed on the extracted text.
            return ""
        return content if isinstance(content, str) else str(content)
    return str(completion)


def _extract_answer(text: str) -> str:
    """Return the answer contained in ``text``.

    Uses the stripped contents of the first ``<answer>...</answer>`` pair.
    The closing tag is searched for only *after* the opening tag, so a
    stray ``</answer>`` earlier in the text cannot produce an empty or
    inverted slice. Without a well-formed tag pair, this falls back to the
    last whitespace-separated token, or ``""`` for blank text.
    """
    open_tag, close_tag = "<answer>", "</answer>"
    start = text.find(open_tag)
    if start != -1:
        start += len(open_tag)
        end = text.find(close_tag, start)
        if end != -1:
            return text[start:end].strip()
    tokens = text.split()
    return tokens[-1] if tokens else ""


def make_occ_reward_func(
    mode: str = "retrieval_qa",
    compute_budget: float = 1e5,
    qa_weights: Optional[Dict] = None,
    code_weights: Optional[Dict] = None,
    debate_weights: Optional[Dict] = None,
) -> Callable[..., List[float]]:
    """Build a TRL ``GRPOTrainer``-compatible reward function backed by OCC.

    A single ``ImpactOracle`` and ``RewardHook`` are created here. They are
    shared by every call of the returned function.

    Args:
        mode: Scoring mode passed to ``RewardHook`` (e.g. ``"retrieval_qa"``
            or ``"code"``).
        compute_budget: Forwarded to ``ImpactOracle``.
        qa_weights: Optional weights forwarded unchanged to ``ImpactOracle``.
        code_weights: Optional weights forwarded unchanged to ``ImpactOracle``.
        debate_weights: Optional weights forwarded unchanged to ``ImpactOracle``.

    Returns:
        A function ``reward_fn(completions, **kwargs)``. It returns whatever
        ``RewardHook.compute_rewards`` produces, which is presumably one float
        per completion as TRL expects (verify against ``rl.reward``).
    """
    oracle = ImpactOracle(
        compute_budget=compute_budget,
        qa_weights=qa_weights,
        code_weights=code_weights,
        debate_weights=debate_weights,
    )
    hook = RewardHook(oracle=oracle, mode=mode)

    def _reward_fn(completions, **kwargs):
        """Score a batch of completions.

        Args:
            completions: ``list[str]`` (standard format) or
                ``list[list[dict]]`` (conversational format).
            **kwargs: Extra arguments supplied by TRL. Reads ``prompts``,
                ``answers`` (gold answers) and ``agent_ids``. All other
                keys are ignored.
        """
        texts = [_completion_text(completion) for completion in completions]
        n = len(texts)

        answers = [_extract_answer(text) for text in texts]
        # Heuristic confidence: higher when a non-empty answer was extracted.
        confidences = [0.7 if answer else 0.3 for answer in answers]
        # Cost proxy: whitespace-token count of the whole completion.
        compute_costs = [len(text.split()) for text in texts]

        prompts = kwargs.get("prompts")
        if prompts is None:
            prompts = [""] * n
        gold_answers = kwargs.get("answers")
        if not gold_answers:
            gold_answers = [""] * n

        return hook.compute_rewards(
            prompts=prompts,
            completions=texts,
            answers=answers,
            gold_answers=gold_answers,
            confidences=confidences,
            compute_costs=compute_costs,
            agent_ids=kwargs.get("agent_ids", None),
        )

    return _reward_fn
def demo_offline():
    """Run an offline comparison of two synthetic policies and print it.

    Policy A has 10 rollouts with rewards ``0.5 + i * 0.02``. Policy B has
    10 rollouts with rewards ``0.4 + i * 0.01``. Neither has failure tags.
    Both are compared by an ``OfflinePolicyComparator`` wrapping a
    ``"retrieval_qa"`` reward hook.

    Returns:
        The comparison result produced by ``OfflinePolicyComparator.compare``.
        It is also printed as JSON, with non-serialisable values stringified.
    """
    def synthetic_rollouts(base, step):
        # Linearly increasing rewards; the expression order matches base + i * step.
        return [{"reward": base + idx * step, "failure_tags": []} for idx in range(10)]

    reward_hook = RewardHook(oracle=ImpactOracle(compute_budget=1e5), mode="retrieval_qa")
    comparator = OfflinePolicyComparator(reward_hook=reward_hook)

    comparison = comparator.compare(
        synthetic_rollouts(0.5, 0.02),
        synthetic_rollouts(0.4, 0.01),
    )
    print(json.dumps(comparison, indent=2, default=str))
    return comparison
if __name__ == "__main__":
    # Script entry point: run the offline policy-comparison demo.
    demo_offline()