narcolepticchicken commited on
Commit
a47c93c
·
verified ·
1 Parent(s): 1d13104

Upload jobs/occ_cheap_baselines.py

Browse files
Files changed (1) hide show
  1. jobs/occ_cheap_baselines.py +279 -0
jobs/occ_cheap_baselines.py ADDED
@@ -0,0 +1,279 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ OCC Cheap Baselines — Additional conditions for workshop-paper strength.
4
+ ======================================================================
5
+ Adds 6 baselines without new infrastructure:
6
+ 1. Confidence-gated debate (drop low-confidence turns)
7
+ 2. Disagreement-gated debate (extra turns only when agents disagree)
8
+ 3. Round-robin capped (equal turns but hard token cap)
9
+ 4. Single-agent best-of-N (same budget, no multi-agent)
10
+ 5. No-adversary 4-agent (all honest — does collapse still happen?)
11
+ 6. Reputation-only allocator (earn/lose score, no decay/transfer)
12
+
13
+ 30 topics, 2 seeds, Qwen3-Coder-30B-A3B-Instruct.
14
+ """
15
+
16
+ import json, os, sys, time, random, torch
17
+ from pathlib import Path
18
+ from transformers import AutoModelForCausalLM, AutoTokenizer
19
+
20
+ MODEL = os.environ.get("MODEL", "Qwen/Qwen3-Coder-30B-A3B-Instruct")
21
+ SEEDS = [int(s) for s in os.environ.get("SEEDS", "42,123").split(",")]
22
+ OUT = Path(os.environ.get("OUT_DIR", "/app/results"))
23
+ OUT.mkdir(parents=True, exist_ok=True)
24
+ START = time.time()
25
+
26
+ def log(msg):
27
+ print(f"[+{time.time()-START:5.0f}s] {msg}", flush=True)
28
+
29
+ DEBATE_TOPICS = [
30
+ {"q": "Is Python faster than C for numerical computation?", "truth": "no"},
31
+ {"q": "Does water boil at 100C at all altitudes?", "truth": "no"},
32
+ {"q": "Can quantum computers break RSA-2048 today (2026)?", "truth": "no"},
33
+ {"q": "Is the Earth core hotter than the surface of the Sun?", "truth": "yes"},
34
+ {"q": "Does the Moon have an atmosphere?", "truth": "no"},
35
+ {"q": "Can sound travel through a vacuum?", "truth": "no"},
36
+ {"q": "Is JavaScript single-threaded by default in browsers?", "truth": "yes"},
37
+ {"q": "Does DNA replication occur in the nucleus of eukaryotic cells?", "truth": "yes"},
38
+ {"q": "Can a protein structure be determined with 100% certainty from X-ray?", "truth": "no"},
39
+ {"q": "Is gradient descent guaranteed to find global min for convex functions?", "truth": "yes"},
40
+ {"q": "Can GPT-4 reliably solve novel math proofs without supervision?", "truth": "no"},
41
+ {"q": "Is P vs NP solved as of 2026?", "truth": "no"},
42
+ {"q": "Do all metals expand when heated?", "truth": "no"},
43
+ {"q": "Is the speed of light constant in all reference frames?", "truth": "yes"},
44
+ {"q": "Can a program determine if an arbitrary program halts?", "truth": "no"},
45
+ {"q": "Is the Earth flat?", "truth": "no"},
46
+ {"q": "Does CO2 make up more than 1 percent of Earth atmosphere?", "truth": "no"},
47
+ {"q": "Can classical computers efficiently simulate quantum?", "truth": "no"},
48
+ {"q": "Is the golden ratio exactly (1+sqrt5)/2?", "truth": "yes"},
49
+ {"q": "Can 1-hidden-layer NN approximate any continuous function?", "truth": "yes"},
50
+ {"q": "Does entropy always increase in isolated systems?", "truth": "yes"},
51
+ {"q": "Is Python GIL removed in CPython 3.13+?", "truth": "yes"},
52
+ {"q": "Do sharks get cancer?", "truth": "yes"},
53
+ {"q": "Is Antarctica a country?", "truth": "no"},
54
+ {"q": "Can humans survive without gut bacteria?", "truth": "yes"},
55
+ {"q": "Do all birds fly?", "truth": "no"},
56
+ {"q": "Is lightning hotter than the Sun surface?", "truth": "yes"},
57
+ {"q": "Can finite-tape TM recognize all recursive languages?", "truth": "no"},
58
+ {"q": "Is the Riemann Hypothesis proved as of 2026?", "truth": "no"},
59
+ {"q": "Does gravitational lensing confirm GR?", "truth": "yes"},
60
+ ]
61
+
62
+ _model = None
63
+ _tok = None
64
+
65
+ def get_model():
66
+ global _model, _tok
67
+ if _model is None:
68
+ log(f"Loading {MODEL}...")
69
+ _tok = AutoTokenizer.from_pretrained(MODEL, trust_remote_code=True)
70
+ _tok.pad_token = _tok.eos_token
71
+ _model = AutoModelForCausalLM.from_pretrained(
72
+ MODEL, trust_remote_code=True, torch_dtype=torch.bfloat16, device_map="auto")
73
+ log(f"Loaded. Device: {_model.device}")
74
+ return _model, _tok
75
+
76
+ def generate(prompt, max_tokens=512, temperature=0.7):
77
+ model, tok = get_model()
78
+ inputs = tok(prompt, return_tensors="pt", truncation=True, max_length=2048).to(model.device)
79
+ ilen = inputs.input_ids.shape[1]
80
+ with torch.no_grad():
81
+ out = model.generate(**inputs, max_new_tokens=max_tokens, do_sample=True,
82
+ temperature=temperature, top_p=0.9, pad_token_id=tok.eos_token_id)
83
+ ntok = out.shape[1] - ilen
84
+ return tok.decode(out[0][ilen:], skip_special_tokens=True), ntok
85
+
86
+ def extract_position(text):
87
+ t = text.strip()
88
+ fl = t.split("\n")[0].lower()
89
+ if fl.startswith("yes:") or fl.startswith("yes "): return "yes"
90
+ if fl.startswith("no:") or fl.startswith("no "): return "no"
91
+ for w in ["clearly yes", "definitely yes", "answer is yes"]:
92
+ if w in fl: return "yes"
93
+ for w in ["clearly no", "definitely no", "answer is no"]:
94
+ if w in fl: return "no"
95
+ tc = t.lower()
96
+ return "yes" if tc.count(" yes ") > tc.count(" no ") else "no" if tc.count(" no ") > tc.count(" yes ") else "unclear"
97
+
98
+ def extract_confidence(text):
99
+ tlow = text.lower()
100
+ if "high confidence" in tlow or "confidence: high" in tlow: return 1.0
101
+ if "medium confidence" in tlow or "confidence: medium" in tlow: return 0.6
102
+ if "low confidence" in tlow or "confidence: low" in tlow: return 0.3
103
+ return 0.5
104
+
105
+ def accuracy(results):
106
+ n = len(results); c = sum(1 for r in results if r["correct"]); return c/n, c, n
107
+
108
+ def run_confidence_gated():
109
+ results = []
110
+ for t in DEBATE_TOPICS:
111
+ positions = []
112
+ for rnd in range(3):
113
+ for a in range(4):
114
+ persona = "honest" if a < 3 else "adversarial"
115
+ text, _ = generate(
116
+ f"Question: {t['q']}\nYou are a {persona} debater. Round {rnd+1}/3. "
117
+ f"Start with YES: or NO:. State your confidence as HIGH, MEDIUM, or LOW.\n", 512)
118
+ if extract_confidence(text) >= 0.5:
119
+ positions.append(extract_position(text))
120
+ votes = [p for p in positions if p != "unclear"]
121
+ winner = max(set(votes), key=votes.count) if votes else "unclear"
122
+ results.append({"topic": t["q"], "truth": t["truth"], "winner": winner, "correct": winner == t["truth"]})
123
+ return results
124
+
125
+ def run_disagreement_gated():
126
+ results = []
127
+ for t in DEBATE_TOPICS:
128
+ r1_positions = []
129
+ for a in range(4):
130
+ persona = "honest" if a < 3 else "adversarial"
131
+ text, _ = generate(
132
+ f"Question: {t['q']}\nYou are a {persona} debater. Round 1/3. "
133
+ f"Start with YES: or NO: followed by brief explanation.\n", 512)
134
+ r1_positions.append(extract_position(text))
135
+ yes_c = sum(1 for p in r1_positions if p == "yes")
136
+ no_c = sum(1 for p in r1_positions if p == "no")
137
+ disagreed = yes_c > 0 and no_c > 0
138
+ all_positions = list(r1_positions)
139
+ if disagreed:
140
+ for rnd in [2, 3]:
141
+ for a in range(4):
142
+ persona = "honest" if a < 3 else "adversarial"
143
+ text, _ = generate(
144
+ f"Question: {t['q']}\nYou are a {persona} debater. Round {rnd}/3. "
145
+ f"Start with YES: or NO: followed by brief explanation.\n", 512)
146
+ all_positions.append(extract_position(text))
147
+ votes = [p for p in all_positions if p != "unclear"]
148
+ winner = max(set(votes), key=votes.count) if votes else "unclear"
149
+ results.append({"topic": t["q"], "truth": t["truth"], "winner": winner,
150
+ "correct": winner == t["truth"], "disagreed": disagreed})
151
+ return results
152
+
153
+ def run_capped_debate(max_tok=2000):
154
+ results = []
155
+ for t in DEBATE_TOPICS:
156
+ all_positions = []; tok_used = 0
157
+ for rnd in range(3):
158
+ for a in range(4):
159
+ if tok_used >= max_tok: break
160
+ persona = "honest" if a < 3 else "adversarial"
161
+ rem = max(50, max_tok - tok_used)
162
+ text, nt = generate(
163
+ f"Question: {t['q']}\nYou are a {persona} debater. Round {rnd+1}/3. "
164
+ f"Be very concise. Start with YES: or NO:.\n", max_tokens=min(128, rem))
165
+ all_positions.append(extract_position(text)); tok_used += nt
166
+ if tok_used >= max_tok: break
167
+ votes = [p for p in all_positions if p != "unclear"]
168
+ winner = max(set(votes), key=votes.count) if votes else "unclear"
169
+ results.append({"topic": t["q"], "truth": t["truth"], "winner": winner,
170
+ "correct": winner == t["truth"], "tokens_used": tok_used})
171
+ return results
172
+
173
+ def run_single_agent_best_of_n(n=12):
174
+ results = []
175
+ for t in DEBATE_TOPICS:
176
+ positions = []
177
+ for _ in range(n):
178
+ text, _ = generate(
179
+ f"Question: {t['q']}\nYou are a honest debater. "
180
+ f"Start with YES: or NO: followed by brief explanation.\n", 512)
181
+ positions.append(extract_position(text))
182
+ votes = [p for p in positions if p != "unclear"]
183
+ winner = max(set(votes), key=votes.count) if votes else "unclear"
184
+ results.append({"topic": t["q"], "truth": t["truth"], "winner": winner, "correct": winner == t["truth"]})
185
+ return results
186
+
187
+ def run_no_adversary_3round():
188
+ results = []
189
+ for t in DEBATE_TOPICS:
190
+ all_positions = []
191
+ for rnd in range(3):
192
+ for a in range(4):
193
+ text, _ = generate(
194
+ f"Question: {t['q']}\nYou are a honest debater. Round {rnd+1}/3. "
195
+ f"Start with YES: or NO: followed by brief explanation.\n", 512)
196
+ all_positions.append(extract_position(text))
197
+ votes = [p for p in all_positions if p != "unclear"]
198
+ winner = max(set(votes), key=votes.count) if votes else "unclear"
199
+ results.append({"topic": t["q"], "truth": t["truth"], "winner": winner, "correct": winner == t["truth"]})
200
+ return results
201
+
202
+ def run_reputation_only():
203
+ results = []
204
+ for t in DEBATE_TOPICS:
205
+ rep = [1.0, 1.0, 1.0, 1.0]; all_positions = []; rnd_positions = []
206
+ for rnd in range(3):
207
+ rp = []
208
+ for a in range(4):
209
+ persona = "honest" if a < 3 else "adversarial"
210
+ text, _ = generate(
211
+ f"Question: {t['q']}\nYou are a {persona} debater. Round {rnd+1}/3. "
212
+ f"Start with YES: or NO: followed by brief explanation.\n", 512)
213
+ pos = extract_position(text); rp.append(pos); all_positions.append(pos)
214
+ rnd_positions.append(rp)
215
+ votes = [p for p in rp if p != "unclear"]
216
+ if votes:
217
+ w = max(set(votes), key=votes.count)
218
+ for a in range(4): rep[a] += 0.1 if rp[a] == w else -0.1; rep[a] = max(0, rep[a])
219
+ yes_w = sum(rep[a] for a in range(4) if rnd_positions[-1][a] == "yes")
220
+ no_w = sum(rep[a] for a in range(4) if rnd_positions[-1][a] == "no")
221
+ winner = "yes" if yes_w > no_w else "no" if no_w > yes_w else "unclear"
222
+ results.append({"topic": t["q"], "truth": t["truth"], "winner": winner, "correct": winner == t["truth"]})
223
+ return results
224
+
225
+ CONDITIONS = [
226
+ ("confidence_gated", run_confidence_gated),
227
+ ("disagreement_gated", run_disagreement_gated),
228
+ ("capped_debate", run_capped_debate),
229
+ ("single_agent_best_of_n", run_single_agent_best_of_n),
230
+ ("no_adversary_3round", run_no_adversary_3round),
231
+ ("reputation_only", run_reputation_only),
232
+ ]
233
+
234
+ all_results = {"model": MODEL, "seeds": {}, "conditions": [c[0] for c in CONDITIONS]}
235
+
236
+ for seed in SEEDS:
237
+ torch.manual_seed(seed); random.seed(seed)
238
+ if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed)
239
+ log(f"\n{'='*60}\nSEED {seed}\n{'='*60}")
240
+ get_model()
241
+ seed_results = {}
242
+ for name, fn in CONDITIONS:
243
+ log(f"--- {name} ---"); t0 = time.time()
244
+ try:
245
+ res = fn(); acc, corr, total = accuracy(res)
246
+ extra = {}
247
+ if name == "disagreement_gated":
248
+ extra = {"disagreed": sum(1 for r in res if r.get("disagreed"))}
249
+ elif name == "capped_debate" and res:
250
+ extra = {"avg_tokens": round(sum(r.get("tokens_used", 0) for r in res)/len(res))}
251
+ seed_results[name] = {"accuracy": acc, "correct": corr, "total": total, **extra}
252
+ log(f" {corr}/{total} ({acc:.3f}) ({time.time()-t0:.0f}s)")
253
+ except Exception as e:
254
+ log(f" ERROR: {e}"); seed_results[name] = {"accuracy": None, "error": str(e)}
255
+ all_results["seeds"][str(seed)] = seed_results
256
+
257
+ summary = {}
258
+ for name, _ in CONDITIONS:
259
+ accs = [all_results["seeds"][str(s)][name].get("accuracy", 0) or 0 for s in SEEDS
260
+ if all_results["seeds"].get(str(s), {}).get(name, {}).get("accuracy") is not None]
261
+ if accs:
262
+ mn, mx = min(accs), max(accs); mean = sum(accs)/len(accs)
263
+ log(f" {name:<25} {mean:7.3f} [{mn:.3f}, {mx:.3f}]")
264
+ summary[name] = {"mean": mean, "min": mn, "max": mx}
265
+ all_results["summary"] = summary
266
+
267
+ path = OUT / "cheap_baselines_results.json"
268
+ path.write_text(json.dumps(all_results, indent=2))
269
+ log(f"\nSaved -> {path}")
270
+
271
+ try:
272
+ from huggingface_hub import HfApi
273
+ HfApi().upload_file(path_or_fileobj=str(path), path_in_repo="reports/cheap_baselines_results.json",
274
+ repo_id="narcolepticchicken/occ-stack", repo_type="model",
275
+ commit_message="Cheap baselines results")
276
+ log("Pushed to Hub")
277
+ except Exception as e:
278
+ log(f"Push failed: {e}")
279
+ log(f"Total elapsed: {time.time()-START:.0f}s")