narcolepticchicken commited on
Commit
fddd9d2
·
verified ·
1 Parent(s): 74b60bc

Upload rl/grpo_train_demo.py

Browse files
Files changed (1) hide show
  1. rl/grpo_train_demo.py +152 -1
rl/grpo_train_demo.py CHANGED
@@ -1 +1,152 @@
1
- See /app/occ/rl/grpo_train_demo.py
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ GRPO Training Demonstrator
3
+ Uses Qwen2.5-0.5B-Instruct + DeepMath-103K dataset with cost-aware rewards.
4
+
5
+ This is a minimal demonstrator showing how OCC rewards can be used
6
+ with TRL's GRPOTrainer. If compute is available, it trains for a few
7
+ steps; otherwise it falls back to offline comparison.
8
+ """
9
+ import json
10
+ import sys
11
+ from pathlib import Path
12
+
13
+ from datasets import load_dataset
14
+
15
+ sys.path.insert(0, str(Path(__file__).parent.parent))
16
+ from rl.reward import RewardHook
17
+ from oracle.oracle import ImpactOracle
18
+
19
+
20
+ def occ_reward_func(prompts, completions, **kwargs):
21
+ """OCC cost-aware reward function for GRPO."""
22
+ oracle = ImpactOracle(
23
+ qa_weights={
24
+ "correctness": 1.0,
25
+ "evidence_support": 0.5,
26
+ "calibration": 0.2,
27
+ "abstention_utility": 1.0,
28
+ "hallucination_penalty": 2.0,
29
+ "confident_wrong_penalty": 3.0,
30
+ }
31
+ )
32
+ reward_hook = RewardHook(oracle=oracle, mode="retrieval_qa")
33
+
34
+ answers = []
35
+ confidences = []
36
+ compute_costs = []
37
+
38
+ for comp in completions:
39
+ if "<answer>" in comp and "</answer>" in comp:
40
+ start = comp.find("<answer>") + len("<answer>")
41
+ end = comp.find("</answer>")
42
+ ans = comp[start:end].strip()
43
+ else:
44
+ ans = comp.strip().split()[-1] if comp.strip() else ""
45
+
46
+ answers.append(ans)
47
+ confidences.append(0.7 if len(ans) > 0 else 0.3)
48
+ compute_costs.append(len(comp.split()))
49
+
50
+ gold_answers = kwargs.get("answers", [""] * len(prompts))
51
+ if not gold_answers:
52
+ gold_answers = [""] * len(prompts)
53
+
54
+ rewards = reward_hook.compute_rewards(
55
+ prompts=prompts,
56
+ completions=completions,
57
+ answers=answers,
58
+ gold_answers=gold_answers,
59
+ confidences=confidences,
60
+ compute_costs=compute_costs,
61
+ )
62
+ return rewards
63
+
64
+
65
+ def run_offline_demonstrator():
66
+ """Run offline policy comparison without actual model training."""
67
+ print("=" * 60)
68
+ print("GRPO OFFLINE DEMONSTRATOR")
69
+ print("=" * 60)
70
+ print("\nAttempting to load DeepMath-103K dataset...")
71
+ try:
72
+ ds = load_dataset("trl-lib/DeepMath-103K", split="train")
73
+ sample = ds.select(range(5))
74
+ print(f"Dataset loaded: {len(ds)} examples")
75
+ print(f"Columns: {sample.features}")
76
+ for i, ex in enumerate(sample):
77
+ print(f"\nExample {i}:")
78
+ prompt = ex.get("prompt", "")[:100]
79
+ solution = ex.get("solution", "")[:100]
80
+ print(f" prompt: {prompt}...")
81
+ print(f" solution: {solution}...")
82
+ except Exception as e:
83
+ print(f"Could not load dataset: {e}")
84
+ return {"status": "dataset_load_failed", "error": str(e)}
85
+
86
+ print("\n--- Simulating policy trajectories ---")
87
+ policy_a_completions = ["The answer is 42. <answer>42</answer>" for _ in range(10)]
88
+ policy_b_completions = ["I think it might be 42 or maybe 41. <answer>42</answer>" for _ in range(10)]
89
+
90
+ prompts = ["Solve: 20 + 22 = ?"] * 10
91
+ rewards_a = occ_reward_func(prompts, policy_a_completions, answers=["42"] * 10)
92
+ rewards_b = occ_reward_func(prompts, policy_b_completions, answers=["42"] * 10)
93
+
94
+ print(f"Policy A (concise, confident): mean reward = {sum(rewards_a)/len(rewards_a):.3f}")
95
+ print(f"Policy B (verbose, uncertain): mean reward = {sum(rewards_b)/len(rewards_b):.3f}")
96
+
97
+ from rl.reward import OfflinePolicyComparator
98
+ comparator = OfflinePolicyComparator(RewardHook(oracle=ImpactOracle(), mode="retrieval_qa"))
99
+
100
+ traj_a = [{"reward": r, "failure_tags": []} for r in rewards_a]
101
+ traj_b = [{"reward": r, "failure_tags": []} for r in rewards_b]
102
+ comparison = comparator.compare(traj_a, traj_b)
103
+
104
+ print(f"\nWin rate (A vs B): {comparison['win_rate']:.1%}")
105
+ print(f"Mean reward improvement: {comparison['improvement']:+.3f}")
106
+
107
+ return {
108
+ "status": "offline_demo_complete",
109
+ "policy_a_mean_reward": sum(rewards_a) / len(rewards_a),
110
+ "policy_b_mean_reward": sum(rewards_b) / len(rewards_b),
111
+ "comparison": comparison,
112
+ }
113
+
114
+
115
+ def run_grpo_training(steps: int = 50):
116
+ """Run actual GRPO training if TRL is available."""
117
+ try:
118
+ from trl import GRPOTrainer
119
+ print("\n" + "=" * 60)
120
+ print("GRPO TRAINING DEMONSTRATION")
121
+ print("=" * 60)
122
+ print(f"Loading dataset and model for {steps} steps...")
123
+
124
+ ds = load_dataset("trl-lib/DeepMath-103K", split="train")
125
+
126
+ trainer = GRPOTrainer(
127
+ model="Qwen/Qwen2.5-0.5B-Instruct",
128
+ reward_funcs=occ_reward_func,
129
+ train_dataset=ds.select(range(100)),
130
+ )
131
+ trainer.train()
132
+ print("Training complete!")
133
+ return {"status": "training_complete", "steps": steps}
134
+ except ImportError:
135
+ print("TRL not installed. Falling back to offline demonstrator.")
136
+ return run_offline_demonstrator()
137
+ except Exception as e:
138
+ print(f"Training failed: {e}")
139
+ return {"status": "training_failed", "error": str(e)}
140
+
141
+
142
+ def main():
143
+ results = run_grpo_training(steps=10)
144
+
145
+ Path("/app/occ/reports").mkdir(parents=True, exist_ok=True)
146
+ with open("/app/occ/reports/grpo_results.json", "w") as f:
147
+ json.dump(results, f, indent=2, default=str)
148
+ print("\nSaved to reports/grpo_results.json")
149
+
150
+
151
+ if __name__ == "__main__":
152
+ main()