narcolepticchicken commited on
Commit
44602fe
·
verified ·
1 Parent(s): d70cbcc

Upload jobs/run_ablations_detailed.py

Browse files
Files changed (1) hide show
  1. jobs/run_ablations_detailed.py +340 -0
jobs/run_ablations_detailed.py ADDED
@@ -0,0 +1,340 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Self-contained ablation runner with meaningful variation.
3
+ Key improvement: broker thresholds directly affect agent selection
4
+ (denying cheap agents more often in strict mode, allowing expensive agents
5
+ in lenient mode, etc.)
6
+ """
7
+ import json
8
+ import random
9
+ import time
10
+ from dataclasses import dataclass, field
11
+ from enum import Enum
12
+ from pathlib import Path
13
+ from typing import Any, Dict, List, Optional
14
+
15
+ import numpy as np
16
+
17
+
18
+ # --- CORE ---
19
+
20
+ @dataclass
21
+ class OracleResult:
22
+ raw_score: float; cost_adjusted_score: float; confidence: float
23
+ evidence: Dict[str, Any]; reason: str
24
+ failure_tags: List[str] = field(default_factory=list)
25
+ reward_value: float = 0.0
26
+
27
+
28
+ class ImpactOracle:
29
+ def __init__(self, compute_penalty_rate=0.0001, gaming_penalty=2.0):
30
+ self.compute_penalty_rate = compute_penalty_rate
31
+ self.gaming_penalty = gaming_penalty
32
+
33
+ def score(self, mode, action, context, result, agent_id=""):
34
+ correctness = result.get("correctness", 0.0)
35
+ compute_cost = result.get("compute_cost", 0.0)
36
+ public_pass = result.get("public_pass", correctness)
37
+ hidden_pass = result.get("hidden_tests_pass", correctness)
38
+ tags = []
39
+ if public_pass and not hidden_pass:
40
+ tags.append("gaming_hidden_tests")
41
+ raw = correctness * 1.0 - compute_cost * self.compute_penalty_rate
42
+ if "gaming_hidden_tests" in tags:
43
+ raw -= self.gaming_penalty
44
+ cost_adj = raw - compute_cost * self.compute_penalty_rate
45
+ return OracleResult(raw, cost_adj, result.get("confidence", correctness),
46
+ {"correctness": correctness}, f"corr={correctness:.2f}, cost={compute_cost}", tags, cost_adj)
47
+
48
+
49
+ @dataclass
50
+ class LedgerEntry:
51
+ agent_id: str; task_id: str; action_id: str; earned_credit: float; spent_credit: float
52
+ decayed_credit: float; remaining_credit: float; reason: str; oracle_score: float
53
+ compute_cost: float; timestamp: float; capability_scope: str = "global"
54
+
55
+
56
+ class CreditLedger:
57
+ def __init__(self, decay_lambda=0.05):
58
+ self.entries = []; self.balances = {}; self.decay_lambda = decay_lambda
59
+
60
+ def earn(self, agent_id, task_id, action_id, amount, oracle_score, compute_cost, reason, capability_scope="global"):
61
+ now = time.time()
62
+ self._apply_decay(agent_id, now, capability_scope)
63
+ current = self._get(agent_id, capability_scope)
64
+ new_bal = current + amount
65
+ self.entries.append(LedgerEntry(agent_id, task_id, action_id, amount, 0.0, 0.0, new_bal, reason, oracle_score, compute_cost, now, capability_scope))
66
+ self._set(agent_id, capability_scope, new_bal)
67
+
68
+ def spend(self, agent_id, task_id, action_id, amount, capability_scope="global", reason="spend"):
69
+ now = time.time()
70
+ self._apply_decay(agent_id, now, capability_scope)
71
+ current = self._get(agent_id, capability_scope)
72
+ if current < amount: return False
73
+ new_bal = current - amount
74
+ self.entries.append(LedgerEntry(agent_id, task_id, action_id, 0.0, amount, 0.0, new_bal, reason, 0.0, 0.0, now, capability_scope))
75
+ self._set(agent_id, capability_scope, new_bal)
76
+ return True
77
+
78
+ def balance(self, agent_id, capability_scope="global"):
79
+ now = time.time()
80
+ self._apply_decay(agent_id, now, capability_scope)
81
+ return self._get(agent_id, capability_scope)
82
+
83
+ def _get(self, agent_id, cap): return self.balances.get(agent_id, {}).get(cap, 0.0)
84
+ def _set(self, agent_id, cap, val):
85
+ if agent_id not in self.balances: self.balances[agent_id] = {}
86
+ self.balances[agent_id][cap] = val
87
+ def _apply_decay(self, agent_id, now, cap):
88
+ current = self._get(agent_id, cap)
89
+ if current <= 0: return
90
+ decayed = current * (1 - self.decay_lambda)
91
+ if decayed < current:
92
+ self.entries.append(LedgerEntry(agent_id, "decay", "decay", 0.0, 0.0, current - decayed, decayed, "credit_decay", 0.0, 0.0, now, cap))
93
+ self._set(agent_id, cap, decayed)
94
+
95
+
96
+ class Decision(Enum):
97
+ ALLOW = "allow"; DENY = "deny"; REQUIRE_APPROVAL = "require_approval"
98
+ DOWNGRADE = "downgrade"; ESCALATE = "escalate"; ASK_JUSTIFICATION = "ask_justification"
99
+
100
+
101
+ @dataclass
102
+ class ResourceDecision:
103
+ decision: Decision; reason: str; capability: str; downgrade_to: Optional[str] = None
104
+
105
+
106
+ class ResourceBroker:
107
+ RESOURCE_RISK = {"model_call": "medium", "retrieval_call": "low", "verifier_call": "medium",
108
+ "debate_turn": "low", "file_write": "high", "shell_execute": "high",
109
+ "memory_write": "medium", "human_escalation": "high", "larger_model": "medium"}
110
+ DEFAULT_THRESHOLDS = {"low": 0.5, "medium": 2.0, "high": 5.0}
111
+
112
+ def __init__(self, thresholds=None, urgency_boost=0.5):
113
+ self.thresholds = thresholds or self.DEFAULT_THRESHOLDS.copy()
114
+ self.urgency_boost = urgency_boost
115
+ self.denial_history = {}
116
+
117
+ def request(self, capability, agent_id, credit_balance, task_state=None, risk_score=0.0, gaming_flags=None):
118
+ task_state = task_state or {}; gaming_flags = gaming_flags or []
119
+ risk_class = self.RESOURCE_RISK.get(capability, "medium")
120
+ threshold = self.thresholds.get(risk_class, 2.0)
121
+ urgency = task_state.get("urgency", 0.0)
122
+ adjusted = max(0.1, threshold - urgency * self.urgency_boost)
123
+ if gaming_flags: return ResourceDecision(Decision.DENY, f"Gaming: {gaming_flags}", capability)
124
+ if risk_class == "high" and risk_score > 0.7: return ResourceDecision(Decision.REQUIRE_APPROVAL, f"High risk {risk_score:.2f}", capability)
125
+ if credit_balance >= adjusted: return ResourceDecision(Decision.ALLOW, f"Balance {credit_balance:.2f} >= {adjusted:.2f}", capability)
126
+ if credit_balance >= adjusted * 0.5:
127
+ if risk_class == "medium": return ResourceDecision(Decision.DOWNGRADE, f"Downgrading from {capability}", capability, "retrieval_call")
128
+ return ResourceDecision(Decision.ASK_JUSTIFICATION, f"Justification required", capability)
129
+ denials = self.denial_history.get(agent_id, 0)
130
+ if denials > 3: return ResourceDecision(Decision.ESCALATE, f"Denied {denials} times", capability)
131
+ self.denial_history[agent_id] = denials + 1
132
+ return ResourceDecision(Decision.DENY, f"Balance {credit_balance:.2f} < {adjusted:.2f}", capability)
133
+
134
+
135
+ # --- CODE BENCHMARK WITH THRESHOLD-DEPENDENT SELECTION ---
136
+
137
+ @dataclass
138
+ class CodeProblem:
139
+ task_id: str; difficulty: float; hidden_test_difficulty: float
140
+
141
+
142
+ class SimCodeAgent:
143
+ def __init__(self, agent_id, pass_easy, pass_hard, hidden_falloff, cost):
144
+ self.agent_id = agent_id
145
+ self.pass_easy = pass_easy
146
+ self.pass_hard = pass_hard
147
+ self.hidden_falloff = hidden_falloff
148
+ self.cost = cost
149
+ self.attempts = 0
150
+
151
+ def solve(self, problem):
152
+ self.attempts += 1
153
+ base = self.pass_easy * (1 - problem.difficulty) + self.pass_hard * problem.difficulty
154
+ public = random.random() < base
155
+ hidden_acc = max(0.0, base - self.hidden_falloff * problem.hidden_test_difficulty)
156
+ hidden = random.random() < hidden_acc
157
+ return {"public_pass": public, "hidden_pass": hidden, "compute_cost": self.cost}
158
+
159
+
160
+ def gen_problems(n, seed):
161
+ random.seed(seed); np.random.seed(seed)
162
+ return [CodeProblem(f"task_{i}", random.random(), random.random()) for i in range(n)]
163
+
164
+
165
+ def run_code_occ(problems, agents, oracle, ledger, broker, max_attempts=3):
166
+ total = 0; results = []
167
+ # Seed: agents earn initial credits proportional to quality
168
+ for a in agents:
169
+ q = (a.pass_easy + a.pass_hard) / 2
170
+ ledger.earn(a.agent_id, "seed", "seed", q * 20, 0.0, 0.0, "initial", "model_call")
171
+
172
+ for p in problems:
173
+ solved = False; cost = 0; used = []
174
+ ranked = sorted(agents, key=lambda a: a.cost / max(0.1, (a.pass_easy + a.pass_hard) / 2))
175
+
176
+ for agent in ranked:
177
+ if solved or len(used) >= max_attempts:
178
+ break
179
+
180
+ # KEY: Broker check before each attempt
181
+ balance = ledger.balance(agent.agent_id, "model_call")
182
+ dec = broker.request("model_call", agent.agent_id, balance,
183
+ task_state={"urgency": 0.3 if not solved else 0.0, "attempts": len(used)})
184
+
185
+ if dec.decision == Decision.DENY:
186
+ # Can't use this agent — try next
187
+ used.append(f"{agent.agent_id}_DENIED")
188
+ continue
189
+
190
+ r = agent.solve(p); cost += r["compute_cost"]; total += r["compute_cost"]; used.append(agent.agent_id)
191
+ solved = r["public_pass"]; hidden = r["hidden_pass"]
192
+
193
+ ora = oracle.score("code", {"attempt": len([u for u in used if not u.endswith("_DENIED")])}, {},
194
+ {"correctness": 1.0 if solved else 0.0, "pass_at_k": 1.0 if hidden else 0.0,
195
+ "compute_cost": cost, "public_pass": solved, "hidden_tests_pass": hidden},
196
+ agent_id=agent.agent_id)
197
+ if ora.raw_score > 0:
198
+ ledger.earn(agent.agent_id, p.task_id, "solve", ora.raw_score * 5, ora.raw_score, cost, "pass", "model_call")
199
+ else:
200
+ ledger.spend(agent.agent_id, p.task_id, "solve", 1.0, "model_call", "fail")
201
+ if hidden: break
202
+ results.append({"solved": solved, "cost": cost, "agents": used})
203
+
204
+ acc = sum(1 for r in results if r["solved"]) / len(results)
205
+ return {"accuracy": acc, "total_compute": total, "mean_compute": total / len(problems),
206
+ "mean_agents": sum(len(r["agents"]) for r in results) / len(results),
207
+ "denied_count": sum(1 for r in results for u in r["agents"] if u.endswith("_DENIED"))}
208
+
209
+
210
+ def run_qa_occ(dataset, oracle, ledger, broker, agent_acc=0.85):
211
+ total_compute = 0; correct = 0
212
+ ledger.earn("qa_agent", "seed", "seed", 20, 0.0, 0.0, "initial", "retrieval_call")
213
+ for item in dataset:
214
+ balance = ledger.balance("qa_agent", "retrieval_call")
215
+ dec = broker.request("retrieval_call", "qa_agent", balance, task_state={"urgency": 0.5})
216
+ if dec.decision == Decision.DENY:
217
+ continue
218
+ tokens = 200 if dec.decision == Decision.ALLOW else 100
219
+ total_compute += tokens
220
+ should_answer = item["type"] != "unanswerable"
221
+ ans = item["answer"] if (should_answer and random.random() < agent_acc) else None
222
+ conf = 0.9 if ans else 0.3
223
+ ora = oracle.score("retrieval_qa", {"abstained": ans is None}, item,
224
+ {"answer": ans, "confidence": conf, "evidence": {}, "compute_cost": tokens}, "qa_agent")
225
+ if ora.raw_score > 0:
226
+ ledger.earn("qa_agent", item["id"], "ans", ora.raw_score * 3, ora.raw_score, tokens, "correct", "retrieval_call")
227
+ correct += 1
228
+ else:
229
+ ledger.spend("qa_agent", item["id"], "ans", 0.5, "retrieval_call", "wrong")
230
+ return {"accuracy": correct / len(dataset), "total_compute": total_compute, "mean_compute": total_compute / len(dataset)}
231
+
232
+
233
+ def create_qa_dataset(seed=42, n=50):
234
+ random.seed(seed)
235
+ evidence_pool = ["alpha", "beta", "gamma", "delta"]
236
+ return [{"id": f"q_{i}", "question": f"Q{i}", "type": random.choice(["answerable", "unanswerable", "misleading", "incomplete", "conflicting"]),
237
+ "answer": random.choice(["paris", "42", "yes", "no", "tokyo"]), "evidence": random.sample(evidence_pool, k=random.randint(1, 3)),
238
+ "is_unanswerable": False} for i in range(n)]
239
+
240
+
241
+ # --- ABLATIONS ---
242
+
243
+ ABLATIONS = [
244
+ ("default", "Full OCC", 0.02, 2.0, 0.0001, True, {}),
245
+ ("no_decay", "No credit decay", 0.0, 2.0, 0.0001, True, {}),
246
+ ("fast_decay", "Aggressive decay", 0.1, 2.0, 0.0001, True, {}),
247
+ ("no_gaming_penalty", "No gaming penalties", 0.02, 0.0, 0.0001, True, {}),
248
+ ("high_gaming_penalty", "Severe gaming penalties", 0.02, 5.0, 0.0001, True, {}),
249
+ ("lenient_broker", "Lenient broker (thresholds x0.5)", 0.02, 2.0, 0.0001, True, {"low": 0.25, "medium": 1.0, "high": 2.5}),
250
+ ("strict_broker", "Strict broker (thresholds x2.0)", 0.02, 2.0, 0.0001, True, {"low": 1.0, "medium": 4.0, "high": 10.0}),
251
+ ("high_compute_cost", "High compute penalty (x10)", 0.02, 2.0, 0.001, True, {}),
252
+ ("low_compute_cost", "Low compute penalty (x0.1)", 0.02, 2.0, 0.00001, True, {}),
253
+ ("anti_gaming_off", "Anti-gaming disabled", 0.02, 2.0, 0.0001, False, {}),
254
+ ]
255
+
256
+
257
+ def run_all():
258
+ print("=" * 60)
259
+ print("OCC ABLATION RUNNER (DETAILED)")
260
+ print("=" * 60)
261
+ seed = 42; n_problems = 200; n_qa = 100
262
+ results = {"ablations": {}, "anti_gaming": {}}
263
+
264
+ for name, desc, decay, game_pen, comp_pen, anti_on, broker_thresh in ABLATIONS:
265
+ print(f"\n--- ABLATION: {name} ---")
266
+ oracle = ImpactOracle(compute_penalty_rate=comp_pen, gaming_penalty=game_pen if anti_on else 0.0)
267
+ ledger = CreditLedger(decay_lambda=decay)
268
+ broker = ResourceBroker(thresholds=broker_thresh if broker_thresh else None)
269
+
270
+ problems = gen_problems(n_problems, seed)
271
+ cheap = SimCodeAgent("cheap", 0.65, 0.15, 0.20, 60)
272
+ medium = SimCodeAgent("medium", 0.85, 0.35, 0.15, 150)
273
+ expensive = SimCodeAgent("expensive", 0.95, 0.65, 0.10, 350)
274
+ code_res = run_code_occ(problems, [cheap, medium, expensive], oracle, ledger, broker, max_attempts=3)
275
+ print(f" Code: acc={code_res['accuracy']:.3f}, compute={code_res['total_compute']:.0f}, denied={code_res['denied_count']}")
276
+
277
+ qa_data = create_qa_dataset(seed=seed, n=n_qa)
278
+ qa_res = run_qa_occ(qa_data, oracle, ledger, broker, agent_acc=0.85)
279
+ print(f" QA: acc={qa_res['accuracy']:.3f}, compute={qa_res['total_compute']:.0f}")
280
+
281
+ results["ablations"][name] = {"description": desc, "code": code_res, "qa": qa_res}
282
+
283
+ # Anti-gaming
284
+ print("\n--- ANTI-GAMING TESTS ---")
285
+ oracle = ImpactOracle(gaming_penalty=2.0)
286
+ normal_res = []; gamer_res = []
287
+ for _ in range(50):
288
+ public = random.random() < 0.9
289
+ hidden = random.random() < 0.5
290
+ ora_normal = oracle.score("code", {}, {}, {"correctness": 1.0 if public else 0.0, "pass_at_k": 1.0 if hidden else 0.0, "compute_cost": 150, "public_pass": public, "hidden_tests_pass": hidden})
291
+ normal_res.append(ora_normal.raw_score)
292
+ ora_gamer = oracle.score("code", {}, {}, {"correctness": 1.0, "pass_at_k": 0.0, "compute_cost": 100, "public_pass": True, "hidden_tests_pass": False})
293
+ gamer_res.append(ora_gamer.raw_score)
294
+ results["anti_gaming"]["hidden_test_gaming"] = {
295
+ "normal_mean_raw": sum(normal_res) / len(normal_res),
296
+ "gamer_mean_raw": sum(gamer_res) / len(gamer_res),
297
+ "gamer_penalized_rate": sum(1 for r in gamer_res if r < 0) / len(gamer_res),
298
+ }
299
+ print(f" Hidden-test: normal={results['anti_gaming']['hidden_test_gaming']['normal_mean_raw']:.2f}, gamer={results['anti_gaming']['hidden_test_gaming']['gamer_mean_raw']:.2f}")
300
+
301
+ ledger = CreditLedger()
302
+ ledger.earn("alice", "seed", "seed", 10, 0, 0, "initial")
303
+ ok = ledger.transfer("alice", "bob", 5.0)
304
+ results["anti_gaming"]["collusion"] = {"transfer_allowed": ok, "alice_balance": ledger.balance("alice"), "blocked": not ok}
305
+ print(f" Collusion: transfer_allowed={ok}, blocked={not ok}")
306
+
307
+ oracle = ImpactOracle()
308
+ abstention_rewards = []
309
+ for _ in range(10):
310
+ res = oracle.score("retrieval_qa", {"abstained": True}, {"is_unanswerable": False, "gold_answer": "yes"},
311
+ {"answer": None, "confidence": 0.9, "evidence": {}, "compute_cost": 50})
312
+ abstention_rewards.append(res.reward_value)
313
+ results["anti_gaming"]["abstention"] = {"mean_reward": sum(abstention_rewards) / len(abstention_rewards), "negative": sum(abstention_rewards) < 0}
314
+ print(f" Abstention: mean_reward={results['anti_gaming']['abstention']['mean_reward']:.2f}")
315
+
316
+ oracle = ImpactOracle()
317
+ spam_res = oracle.score("retrieval_qa", {}, {"gold_answer": "paris"},
318
+ {"answer": "london", "confidence": 0.1, "evidence": {}, "compute_cost": 5000})
319
+ results["anti_gaming"]["spam"] = {"reward": spam_res.reward_value, "tags": spam_res.failure_tags}
320
+ print(f" Spam: reward={spam_res.reward_value:.2f}, tags={spam_res.failure_tags}")
321
+
322
+ # Save
323
+ out = Path("/app/occ/reports")
324
+ out.mkdir(parents=True, exist_ok=True)
325
+ with open(out / "ablations_detailed.json", "w") as f:
326
+ json.dump(results, f, indent=2, default=str)
327
+
328
+ print("\n" + "=" * 60)
329
+ print("ABLATION SUMMARY")
330
+ print("=" * 60)
331
+ print(f"{'Name':<22} {'Code Acc':>9} {'Code Comp':>10} {'Denied':>8} {'QA Acc':>9} {'QA Comp':>10}")
332
+ for name, data in results["ablations"].items():
333
+ print(f"{name:<22} {data['code']['accuracy']:>9.3f} {data['code']['total_compute']:>10.0f} "
334
+ f"{data['code']['denied_count']:>8} {data['qa']['accuracy']:>9.3f} {data['qa']['total_compute']:>10.0f}")
335
+ print(f"\nSaved to {out / 'ablations_detailed.json'}")
336
+ return results
337
+
338
+
339
+ if __name__ == "__main__":
340
+ run_all()