narcolepticchicken commited on
Commit
522d111
·
verified ·
1 Parent(s): 02d02b3

Upload jobs/run_real_llm_standalone_v8.py

Browse files
Files changed (1) hide show
  1. jobs/run_real_llm_standalone_v8.py +336 -0
jobs/run_real_llm_standalone_v8.py ADDED
@@ -0,0 +1,336 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Self-contained GPU job for real LLM code benchmark — V8.
3
+ CRITICAL FIX: Evalplus tests already contain check(candidate) calls.
4
+ Do NOT append check() — it causes TypeError (missing 'candidate' arg).
5
+ """
6
+ import json
7
+ import os
8
+ import re
9
+ import subprocess
10
+ import sys
11
+ import tempfile
12
+ import time
13
+ from dataclasses import dataclass, field
14
+ from enum import Enum
15
+ from pathlib import Path
16
+ from typing import Any, Dict, List, Optional
17
+
18
+ from datasets import load_dataset
19
+ from transformers import AutoModelForCausalLM, AutoTokenizer
20
+ import torch
21
+
22
+
23
+ # --- CORE (INLINE) ---
24
+
25
+ @dataclass
26
+ class OracleResult:
27
+ raw_score: float; cost_adjusted_score: float; confidence: float
28
+ evidence: Dict[str, Any]; reason: str
29
+ failure_tags: List[str] = field(default_factory=list)
30
+ reward_value: float = 0.0
31
+
32
+
33
+ class ImpactOracle:
34
+ def __init__(self, compute_penalty_rate=0.0001, gaming_penalty=2.0):
35
+ self.compute_penalty_rate = compute_penalty_rate
36
+ self.gaming_penalty = gaming_penalty
37
+
38
+ def score(self, mode, action, context, result, agent_id=""):
39
+ correctness = result.get("correctness", 0.0)
40
+ compute_cost = result.get("compute_cost", 0.0)
41
+ public_pass = result.get("public_pass", correctness)
42
+ hidden_pass = result.get("hidden_tests_pass", correctness)
43
+ failure_tags = []
44
+ if public_pass and not hidden_pass: failure_tags.append("gaming_hidden_tests")
45
+ raw = correctness * 1.0 - compute_cost * self.compute_penalty_rate
46
+ if "gaming_hidden_tests" in failure_tags: raw -= self.gaming_penalty
47
+ cost_adj = raw - compute_cost * self.compute_penalty_rate
48
+ return OracleResult(raw, cost_adj, result.get("confidence", correctness),
49
+ {"correctness": correctness}, f"corr={correctness:.2f}, cost={compute_cost}", failure_tags, cost_adj)
50
+
51
+
52
+ @dataclass
53
+ class LedgerEntry:
54
+ agent_id: str; task_id: str; action_id: str; earned_credit: float; spent_credit: float
55
+ decayed_credit: float; remaining_credit: float; reason: str; oracle_score: float
56
+ compute_cost: float; timestamp: float; capability_scope: str = "global"
57
+
58
+
59
+ class CreditLedger:
60
+ def __init__(self, decay_lambda=0.05):
61
+ self.entries = []; self.balances = {}; self.decay_lambda = decay_lambda
62
+
63
+ def earn(self, agent_id, task_id, action_id, amount, oracle_score, compute_cost, reason, capability_scope="global"):
64
+ now = time.time(); self._apply_decay(agent_id, now, capability_scope)
65
+ current = self._get(agent_id, capability_scope); new_bal = current + amount
66
+ self.entries.append(LedgerEntry(agent_id, task_id, action_id, amount, 0.0, 0.0, new_bal, reason, oracle_score, compute_cost, now, capability_scope))
67
+ self._set(agent_id, capability_scope, new_bal)
68
+
69
+ def spend(self, agent_id, task_id, action_id, amount, capability_scope="global", reason="spend"):
70
+ now = time.time(); self._apply_decay(agent_id, now, capability_scope)
71
+ current = self._get(agent_id, capability_scope)
72
+ if current < amount: return False
73
+ new_bal = current - amount
74
+ self.entries.append(LedgerEntry(agent_id, task_id, action_id, 0.0, amount, 0.0, new_bal, reason, 0.0, 0.0, now, capability_scope))
75
+ self._set(agent_id, capability_scope, new_bal)
76
+ return True
77
+
78
+ def balance(self, agent_id, capability_scope="global"):
79
+ now = time.time(); self._apply_decay(agent_id, now, capability_scope)
80
+ return self._get(agent_id, capability_scope)
81
+
82
+ def _get(self, agent_id, cap): return self.balances.get(agent_id, {}).get(cap, 0.0)
83
+ def _set(self, agent_id, cap, val):
84
+ if agent_id not in self.balances: self.balances[agent_id] = {}
85
+ self.balances[agent_id][cap] = val
86
+ def _apply_decay(self, agent_id, now, cap):
87
+ current = self._get(agent_id, cap)
88
+ if current <= 0: return
89
+ decayed = current * (1 - self.decay_lambda)
90
+ if decayed < current:
91
+ self.entries.append(LedgerEntry(agent_id, "decay", "decay", 0.0, 0.0, current - decayed, decayed, "credit_decay", 0.0, 0.0, now, cap))
92
+ self._set(agent_id, cap, decayed)
93
+
94
+
95
+ class Decision(Enum):
96
+ ALLOW = "allow"; DENY = "deny"; REQUIRE_APPROVAL = "require_approval"
97
+ DOWNGRADE = "downgrade"; ESCALATE = "escalate"; ASK_JUSTIFICATION = "ask_justification"
98
+
99
+
100
+ @dataclass
101
+ class ResourceDecision:
102
+ decision: Decision; reason: str; capability: str; downgrade_to: Optional[str] = None
103
+
104
+
105
+ class ResourceBroker:
106
+ RESOURCE_RISK = {"model_call": "medium", "retrieval_call": "low", "verifier_call": "medium",
107
+ "debate_turn": "low", "file_write": "high", "shell_execute": "high",
108
+ "memory_write": "medium", "human_escalation": "high", "larger_model": "medium"}
109
+ DEFAULT_THRESHOLDS = {"low": 0.5, "medium": 2.0, "high": 5.0}
110
+
111
+ def __init__(self, thresholds=None, urgency_boost=0.5):
112
+ self.thresholds = thresholds or self.DEFAULT_THRESHOLDS.copy()
113
+ self.urgency_boost = urgency_boost
114
+ self.denial_history = {}
115
+
116
+ def request(self, capability, agent_id, credit_balance, task_state=None, risk_score=0.0, gaming_flags=None):
117
+ task_state = task_state or {}; gaming_flags = gaming_flags or []
118
+ risk_class = self.RESOURCE_RISK.get(capability, "medium")
119
+ threshold = self.thresholds.get(risk_class, 2.0)
120
+ urgency = task_state.get("urgency", 0.0)
121
+ adjusted = max(0.1, threshold - urgency * self.urgency_boost)
122
+ if gaming_flags: return ResourceDecision(Decision.DENY, f"Gaming: {gaming_flags}", capability)
123
+ if risk_class == "high" and risk_score > 0.7: return ResourceDecision(Decision.REQUIRE_APPROVAL, f"High risk {risk_score:.2f}", capability)
124
+ if credit_balance >= adjusted: return ResourceDecision(Decision.ALLOW, f"Balance {credit_balance:.2f} >= {adjusted:.2f}", capability)
125
+ if credit_balance >= adjusted * 0.5:
126
+ if risk_class == "medium": return ResourceDecision(Decision.DOWNGRADE, f"Downgrading from {capability}", capability, "retrieval_call")
127
+ return ResourceDecision(Decision.ASK_JUSTIFICATION, f"Justification required", capability)
128
+ denials = self.denial_history.get(agent_id, 0)
129
+ if denials > 3: return ResourceDecision(Decision.ESCALATE, f"Denied {denials} times", capability)
130
+ self.denial_history[agent_id] = denials + 1
131
+ return ResourceDecision(Decision.DENY, f"Balance {credit_balance:.2f} < {adjusted:.2f}", capability)
132
+
133
+
134
+ # --- HELPERS ---
135
+
136
+ def extract_code_block(text: str) -> str:
137
+ """Extract code from markdown fenced code block."""
138
+ text = text.strip()
139
+ match = re.search(r'```(?:\w+)?\s*\n(.*?)\n```', text, re.DOTALL)
140
+ if match: return match.group(1).strip()
141
+ match2 = re.search(r'```(?:\w+)?\s*\n(.*)', text, re.DOTALL)
142
+ if match2:
143
+ candidate = match2.group(1).strip()
144
+ if candidate.endswith("```"): candidate = candidate[:-3].strip()
145
+ return candidate
146
+ return text
147
+
148
+
149
+ def contains_function_definition(code: str, entry_point: str) -> bool:
150
+ return bool(re.search(rf'\bdef\s+{re.escape(entry_point)}\b', code))
151
+
152
+
153
+ def run_test(code: str, test_code: str, timeout: int = 20):
154
+ # CRITICAL: evalplus tests already contain check() calls with arguments
155
+ # Do NOT append check()!
156
+ full = code + "\n\n" + test_code + "\n"
157
+ with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f:
158
+ f.write(full)
159
+ tmp = f.name
160
+ try:
161
+ result = subprocess.run(['python', tmp], capture_output=True, text=True, timeout=timeout)
162
+ passed = result.returncode == 0
163
+ error = result.stderr[:400] if not passed else ""
164
+ except subprocess.TimeoutExpired:
165
+ passed = False; error = "Timeout"
166
+ except Exception as e:
167
+ passed = False; error = str(e)[:400]
168
+ finally:
169
+ os.unlink(tmp)
170
+ return passed, error
171
+
172
+
173
+ def wrap_prompt_chat(prompt: str, tokenizer) -> str:
174
+ system = "You are an expert Python programmer. Write ONLY the function definition. No markdown, no extra text."
175
+ messages = [
176
+ {"role": "system", "content": system},
177
+ {"role": "user", "content": prompt.strip()},
178
+ ]
179
+ if hasattr(tokenizer, "apply_chat_template") and tokenizer.chat_template:
180
+ return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
181
+ return f"system\n{system}\nuser\n{prompt.strip()}\nassistant\n"
182
+
183
+
184
+ # --- BENCHMARK ---
185
+
186
+ class RealLLMBenchmarkV8:
187
+ def __init__(self, model_name="Qwen/Qwen2.5-Coder-1.5B-Instruct", n_problems=20, seed=42):
188
+ self.model_name = model_name
189
+ self.n_problems = n_problems
190
+ self.seed = seed
191
+ self.oracle = ImpactOracle()
192
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
193
+ print(f"Using device: {self.device}, model: {self.model_name}")
194
+
195
+ def load_problems(self):
196
+ ds = load_dataset("evalplus/humanevalplus", split="test")
197
+ return [{"task_id": item["task_id"], "prompt": item["prompt"], "test": item["test"], "entry_point": item["entry_point"]}
198
+ for i, item in enumerate(ds) if i < self.n_problems]
199
+
200
+ def load_model(self):
201
+ print(f"Loading {self.model_name}...")
202
+ tok = AutoTokenizer.from_pretrained(self.model_name, trust_remote_code=True)
203
+ model = AutoModelForCausalLM.from_pretrained(
204
+ self.model_name, trust_remote_code=True,
205
+ torch_dtype=torch.bfloat16 if self.device == "cuda" else torch.float32,
206
+ device_map="auto" if self.device == "cuda" else None,
207
+ )
208
+ if self.device == "cpu": model = model.to("cpu").float()
209
+ print(f"Model loaded. Chat template: {bool(tok.chat_template)}")
210
+ return model, tok
211
+
212
+ def generate(self, model, tok, prompt_raw: str, max_new_tokens: int = 512):
213
+ chat_prompt = wrap_prompt_chat(prompt_raw, tok)
214
+ inputs = tok(chat_prompt, return_tensors="pt").to(model.device)
215
+ with torch.no_grad():
216
+ outputs = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False, pad_token_id=tok.eos_token_id)
217
+ gen = tok.decode(outputs[0], skip_special_tokens=True)
218
+ prompt_decoded = tok.decode(inputs.input_ids[0], skip_special_tokens=True)
219
+ code = gen[len(prompt_decoded):].strip()
220
+ return code
221
+
222
+ def evaluate_one(self, problem, model, tok, max_new_tokens=512):
223
+ raw = self.generate(model, tok, problem["prompt"], max_new_tokens=max_new_tokens)
224
+ tokens = len(tok.encode(raw))
225
+ code = extract_code_block(raw)
226
+
227
+ # If generated code contains the function definition, use as-is
228
+ if contains_function_definition(code, problem["entry_point"]):
229
+ test_code = code
230
+ else:
231
+ test_code = problem["prompt"] + code
232
+
233
+ passed, error = run_test(test_code, problem["test"])
234
+
235
+ # Try alternative if failed
236
+ if not passed:
237
+ if contains_function_definition(code, problem["entry_point"]):
238
+ alt_code = problem["prompt"] + code
239
+ else:
240
+ alt_code = code
241
+ passed2, error2 = run_test(alt_code, problem["test"])
242
+ if passed2:
243
+ passed = True; error = ""
244
+ else:
245
+ error = error if len(error) < len(error2) else error2
246
+
247
+ return passed, tokens, raw, error, code
248
+
249
+ def run_baseline(self, problems, model, tok, max_new_tokens=512):
250
+ results = []; total_compute = 0
251
+ for i, problem in enumerate(problems):
252
+ passed, tokens, raw, error, code = self.evaluate_one(problem, model, tok, max_new_tokens)
253
+ total_compute += tokens
254
+ results.append({"task_id": problem["task_id"], "passed": passed, "tokens": tokens, "raw": raw[:300], "error": error[:200]})
255
+ print(f" {problem['task_id']}: passed={passed}, tokens={tokens}")
256
+ if not passed and i < 3:
257
+ print(f" error={error[:200]!r}")
258
+ print(f" [CODE first 200 chars]: {code[:200]!r}")
259
+ return {"accuracy": sum(1 for r in results if r["passed"]) / len(results), "total_compute": total_compute, "results": results}
260
+
261
+ def run_occ(self, problems, model, tok, max_new_tokens_first=256, max_new_tokens_retry=512):
262
+ ledger = CreditLedger(decay_lambda=0.02)
263
+ broker = ResourceBroker()
264
+ ledger.earn("code_agent", "seed", "seed", 25.0, 0.0, 0.0, "initial", "model_call")
265
+ results = []; total_compute = 0
266
+
267
+ for problem in problems:
268
+ budget_remaining = 3000; attempts = 0; passed = False
269
+ while budget_remaining > 100 and attempts < 3 and not passed:
270
+ attempts += 1
271
+ balance = ledger.balance("code_agent", "model_call")
272
+ dec = broker.request("model_call", "code_agent", balance,
273
+ task_state={"attempts": attempts, "budget_remaining": budget_remaining})
274
+ if dec.decision == Decision.DENY: break
275
+ max_tok = max_new_tokens_first if attempts == 1 else max_new_tokens_retry
276
+ code_raw = self.generate(model, tok, problem["prompt"], max_new_tokens=max_tok)
277
+ tokens = len(tok.encode(code_raw)); budget_remaining -= tokens; total_compute += tokens
278
+ code = extract_code_block(code_raw)
279
+ if contains_function_definition(code, problem["entry_point"]):
280
+ test_code = code
281
+ else:
282
+ test_code = problem["prompt"] + code
283
+ passed, error = run_test(test_code, problem["test"])
284
+ score = 1.0 if passed else 0.0
285
+ ora = self.oracle.score("code", {"attempt": attempts}, {},
286
+ {"correctness": score, "compute_cost": tokens, "public_pass": passed, "hidden_tests_pass": passed}, "code_agent")
287
+ if passed: ledger.earn("code_agent", problem["task_id"], f"att_{attempts}", 5.0, ora.raw_score, tokens, "pass", "model_call")
288
+ else: ledger.spend("code_agent", problem["task_id"], f"att_{attempts}", 1.0, "model_call", "fail")
289
+ if attempts >= 2 and not passed: break
290
+ results.append({"task_id": problem["task_id"], "passed": passed, "attempts": attempts})
291
+ print(f" {problem['task_id']}: passed={passed}, attempts={attempts}")
292
+ return {"accuracy": sum(1 for r in results if r["passed"]) / len(results), "total_compute": total_compute, "results": results}
293
+
294
+ def run_all(self):
295
+ problems = self.load_problems()
296
+ print(f"Loaded {len(problems)} problems")
297
+ model, tok = self.load_model()
298
+ print("\n--- Baseline ---")
299
+ baseline = self.run_baseline(problems, model, tok)
300
+ print(f"Baseline: acc={baseline['accuracy']:.3f}, compute={baseline['total_compute']}")
301
+ print("\n--- OCC ---")
302
+ occ = self.run_occ(problems, model, tok)
303
+ print(f"OCC: acc={occ['accuracy']:.3f}, compute={occ['total_compute']}")
304
+ return {
305
+ "baseline": baseline, "occ": occ,
306
+ "comparison": {
307
+ "baseline_accuracy": baseline["accuracy"], "occ_accuracy": occ["accuracy"],
308
+ "baseline_compute": baseline["total_compute"], "occ_compute": occ["total_compute"],
309
+ "compute_reduction": 1.0 - (occ["total_compute"] / max(baseline["total_compute"], 1)),
310
+ "accuracy_delta": occ["accuracy"] - baseline["accuracy"],
311
+ }
312
+ }
313
+
314
+
315
+ def main():
316
+ bench = RealLLMBenchmarkV8(n_problems=20, seed=42)
317
+ results = bench.run_all()
318
+ print("\n" + "=" * 60)
319
+ print("REAL LLM CODE BENCHMARK (V8)")
320
+ print("=" * 60)
321
+ comp = results["comparison"]
322
+ print(f"Baseline accuracy: {comp['baseline_accuracy']:.3f}")
323
+ print(f"OCC accuracy: {comp['occ_accuracy']:.3f}")
324
+ print(f"Baseline compute: {comp['baseline_compute']}")
325
+ print(f"OCC compute: {comp['occ_compute']}")
326
+ print(f"Compute reduction: {comp['compute_reduction']:.1%}")
327
+ print(f"Accuracy delta: {comp['accuracy_delta']:+.3f}")
328
+ out_dir = Path("/app/occ/reports")
329
+ out_dir.mkdir(parents=True, exist_ok=True)
330
+ with open(out_dir / "benchmark_code_real_llm_v8_results.json", "w") as f:
331
+ json.dump(results, f, indent=2, default=str)
332
+ print(f"\nSaved to {out_dir / 'benchmark_code_real_llm_v8_results.json'}")
333
+
334
+
335
+ if __name__ == "__main__":
336
+ main()