Terry Zhuo
update
3204d18
raw
history blame
9.62 kB
import gradio as gr
import json
import multiprocessing
import os
import pickle
import threading
import time
from collections import Counter, defaultdict
from concurrent.futures import ProcessPoolExecutor, as_completed, wait, FIRST_COMPLETED
from datetime import datetime
from typing import Any, Dict, List, Tuple
from warnings import warn
import numpy as np
from termcolor import cprint
from tqdm import tqdm
from bigcodebench.data import get_bigcodebench, get_bigcodebench_hash, load_solutions
from bigcodebench.data.utils import CACHE_DIR
from bigcodebench.eval import PASS, compatible_eval_result, estimate_pass_at_k, untrusted_check
from bigcodebench.gen.util import trusted_check
# Type alias: a pair of (str, list of bools).
Result = Tuple[str, List[bool]]
def get_groundtruth(n_workers, problems, hashcode, check_gt_only, max_as_limit, max_data_limit, max_stack_limit, min_time_limit):
    """Run every canonical solution once and record its execution time.

    The result is cached at ``CACHE_DIR/<hashcode>.pkl``. An existing cache is
    returned as-is unless ``check_gt_only`` is set, in which case it is deleted
    and the ground truth is re-run.

    Args:
        n_workers: Number of worker processes for the pool.
        problems: Mapping whose values are problem dicts providing
            ``complete_prompt``, ``canonical_solution``, ``test`` and ``task_id``.
        hashcode: Dataset hash, used as the cache file name.
        check_gt_only: Discard any cached result and recompute it.
        max_as_limit, max_data_limit, max_stack_limit, min_time_limit:
            Forwarded unchanged to ``trusted_check``.

    Returns:
        Dict mapping each ``task_id`` to the ``"time"`` field reported by
        ``trusted_check``. ``evaluate`` treats a ``None`` time as a failing
        canonical solution.
    """
    cache_file = os.path.join(CACHE_DIR, f"{hashcode}.pkl")
    if os.path.exists(cache_file):
        if check_gt_only:
            # Concurrent callers may remove the same cache file; a file that
            # is already gone is not an error.
            try:
                os.remove(cache_file)
            except FileNotFoundError:
                pass
        else:
            print(f"Load from ground-truth from {cache_file}")
            try:
                with open(cache_file, "rb") as f:
                    # NOTE(review): pickle.load can execute arbitrary code; this
                    # assumes CACHE_DIR only holds files written by this function.
                    return pickle.load(f)
            except (EOFError, pickle.UnpicklingError):
                # A truncated/corrupt cache (e.g. from an interrupted write) is
                # recomputed instead of aborting the whole evaluation.
                warn(f"Ignoring unreadable ground-truth cache {cache_file}")

    os.makedirs(CACHE_DIR, exist_ok=True)
    print("\nAsserting the groundtruth...")
    tbegin = time.time()
    expected_time = {}
    with ProcessPoolExecutor(max_workers=n_workers) as executor:
        futures = [
            executor.submit(
                trusted_check,
                problem["complete_prompt"] + "\n" + problem["canonical_solution"],
                problem["test"],
                problem["task_id"],
                max_as_limit,
                max_data_limit,
                max_stack_limit,
                min_time_limit,
            )
            for problem in problems.values()
        ]
        for future in tqdm(as_completed(futures), total=len(futures)):
            result = future.result()
            expected_time[result["task_id"]] = result["time"]
    print(f"Expected outputs computed in {time.time() - tbegin:.2f}s")

    # Only persist runs in which at least one task produced a truthy time.
    if any(expected_time.values()):
        # Write to a per-process/per-thread temp file, then atomically rename
        # it over the cache, so readers never see a half-written pickle.
        tmp_path = f"{cache_file}.{os.getpid()}.{threading.get_ident()}.tmp"
        try:
            with open(tmp_path, "wb") as f:
                pickle.dump(expected_time, f)
            os.replace(tmp_path, cache_file)
        except BaseException:
            try:
                os.remove(tmp_path)
            except OSError:
                pass
            raise
    return expected_time
def check_correctness(
    completion_id: int,
    problem: Dict[str, Any],
    solution: str,
    max_as_limit: float,
    max_data_limit: float,
    max_stack_limit: float,
    identifier=None,
    min_time_limit: float = 0.1,
    gt_time_limit: float = 2.0,
) -> Dict[str, Any]:
    """Run one candidate solution against its task's tests.

    Args:
        completion_id: Index of this sample among the samples of its task.
        problem: Problem dict; ``task_id``, ``test`` and ``entry_point`` are read.
        solution: Full source code to execute.
        max_as_limit, max_data_limit, max_stack_limit: Resource limits
            forwarded to ``untrusted_check``.
        identifier: Opaque sample identifier, echoed back as ``_identifier``.
        min_time_limit, gt_time_limit: Time limits forwarded to
            ``untrusted_check``.

    Returns:
        A dict with ``completion_id``, ``task_id``, ``_identifier``,
        ``solution`` and ``base``. ``base`` holds whatever ``untrusted_check``
        returned; ``evaluate`` unpacks it as ``(status, details)``. The values
        have mixed types, hence ``Dict[str, Any]`` rather than a single
        result type.
    """
    return {
        "completion_id": completion_id,
        "task_id": problem["task_id"],
        "_identifier": identifier,
        "solution": solution,
        "base": untrusted_check(
            solution,
            problem["test"],
            problem["entry_point"],
            max_as_limit,
            max_data_limit,
            max_stack_limit,
            min_time_limit,
            gt_time_limit,
        ),
    }
def evaluate(
    split: str,
    subset: str,
    samples: str,
    pass_k: str = "1,5,10",
    parallel: int = None,
    min_time_limit: float = 1,
    max_as_limit: int = 30 * 1024,
    max_data_limit: int = 30 * 1024,
    max_stack_limit: int = 10,
    check_gt_only: bool = False,
    no_gt: bool = False,
):
    """Evaluate a file of BigCodeBench samples and compute pass@k.

    Args:
        split: Split name; only copied into the returned summary.
        subset: Dataset subset passed to ``get_bigcodebench``.
        samples: Path of the samples file read by ``load_solutions``. The model
            name is taken from the file name before ``"--bigcodebench-"``; if
            the path contains ``"sanitized-calibrated"``, the task's
            ``code_prompt`` plus a ``pass`` body is prepended to each solution.
        pass_k: Comma-separated k values. Entries that are not positive
            integers are ignored. Any k larger than the smallest number of
            samples per task is skipped.
        parallel: Worker process count; defaults to half the CPU count (min 1).
        min_time_limit, max_as_limit, max_data_limit, max_stack_limit:
            Limits forwarded to the ground-truth and sample checks.
        check_gt_only: Only validate the canonical solutions and return a
            ``{"gt_pass_rate", "failed_tasks"}`` summary dict.
        no_gt: Skip the ground-truth run; every sample then uses the fallback
            time limit of 20.

    Returns:
        With ``check_gt_only``: the ground-truth summary dict. Otherwise a
        ``(results, pass_at_k)`` tuple. ``results`` holds the run date and the
        per-task sample statuses/details. ``pass_at_k`` holds the pass@k scores
        plus model/split/subset/calibration/ground-truth metadata.

    Raises:
        AssertionError: If sample ``_identifier``s are not unique, or if some
            dataset task has no sample.
    """
    # isdecimal() only admits characters int() can parse (isdigit() also
    # accepts e.g. "²", which would make int() raise); pass@0 is meaningless.
    pass_k = [
        int(k)
        for k in (part.strip() for part in pass_k.split(","))
        if k.isdecimal() and int(k) > 0
    ]
    if parallel is None:
        n_workers = max(1, multiprocessing.cpu_count() // 2)
    else:
        n_workers = parallel

    if check_gt_only:
        # Samples are never loaded in this mode (the function returns first).
        samples = "__dummy__.jsonl"

    problems = get_bigcodebench(subset=subset)
    dataset_hash = get_bigcodebench_hash(subset=subset)

    if not no_gt:
        expected_time = get_groundtruth(
            n_workers, problems, dataset_hash, check_gt_only,
            max_as_limit, max_data_limit, max_stack_limit, min_time_limit,
        )
    else:
        expected_time = {task_id: None for task_id in problems}

    # A canonical solution counts as passing when it has a non-None time.
    gt_pass_rate = np.mean([1 if v is not None else 0 for k, v in expected_time.items() if k in problems])
    failed_tasks = [k for k, v in expected_time.items() if v is None and k in problems]

    if check_gt_only:
        if gt_pass_rate > 0.99:
            cprint(f"Groundtruth pass rate: {gt_pass_rate:.3f}", "green")
        else:
            cprint(f"Groundtruth pass rate: {gt_pass_rate:.3f}\nPlease be cautious!", "red")
        if len(failed_tasks) > 0:
            cprint(f"Failed tasks: {failed_tasks}", "red")
        return {"gt_pass_rate": float(gt_pass_rate), "failed_tasks": failed_tasks}

    results = {
        "date": datetime.now().strftime("%Y-%m-%d %H:%M"),
        "eval": {},
    }

    with ProcessPoolExecutor(max_workers=n_workers) as executor:
        futures = []
        completion_id = Counter()  # task_id -> number of samples submitted so far
        n_samples = 0
        eval_results = defaultdict(list)  # task_id -> list of check_correctness dicts
        remainings = set()  # _identifiers of samples not yet finished

        print("Reading samples...")
        for sample in tqdm(load_solutions(samples)):
            task_id = sample["task_id"]
            if task_id not in problems:
                warn(
                    f"Task {task_id} is found in the samples but not found in the dataset"
                )
                continue
            solution = (
                sample["solution"]
                if "solution" in sample
                else problems[task_id]["complete_prompt"] + sample["completion"]
            )
            if "sanitized-calibrated" in samples:
                solution = problems[task_id]["code_prompt"] + "\n pass\n" + solution
            remainings.add(sample["_identifier"])
            args = (
                completion_id[task_id],
                problems[task_id],
                solution,
                max_as_limit,
                max_data_limit,
                max_stack_limit,
                sample["_identifier"],
                min_time_limit,
                # Fall back to 20 when there is no truthy ground-truth time.
                expected_time[task_id] if expected_time[task_id] else 20,
            )
            futures.append(executor.submit(check_correctness, *args))
            completion_id[task_id] += 1
            n_samples += 1

        # Explicit raises instead of `assert` so these checks survive `python -O`.
        if n_samples != len(remainings):
            # Duplicate _identifiers collapse in the set.
            raise AssertionError("Missing problems in unfinished")
        if len(completion_id) != len(problems):
            raise AssertionError("Missing problems in samples")

        def stucking_checker():
            # Warn whenever 240s pass without any sample finishing.
            not_done = futures
            while len(not_done) > 0:
                done, not_done = wait(not_done, timeout=240, return_when=FIRST_COMPLETED)
                if len(done) == 0:
                    warn("No samples have finished testing in the last 240s")
                    warn(f"{len(remainings)} samples to be tested: {remainings}")

        threading.Thread(target=stucking_checker).start()

        for future in tqdm(as_completed(futures), total=n_samples):
            result = future.result()
            remainings.remove(result["_identifier"])
            eval_results[result["task_id"]].append(result)

    # Sort each task's results by completion_id so the output order is stable.
    for task_id, task_results in eval_results.items():
        task_results.sort(key=lambda x: x["completion_id"])
        results["eval"][task_id] = []
        for res in task_results:
            stat, details = res["base"]
            results["eval"][task_id].append(
                {
                    "task_id": task_id,
                    "solution": res["solution"],
                    "status": stat,
                    "details": details,
                }
            )

    # Calculate pass@k.
    total = np.array([len(r) for k, r in results["eval"].items() if k in problems])
    base_correct = np.array(
        [
            sum(r["status"] == PASS for r in res)
            for key, res in results["eval"].items()
            if key in problems
        ]
    )
    # .min() raises on an empty array; with no evaluated tasks no pass@k is reported.
    min_total = total.min() if total.size else 0
    pass_at_k = {
        f"pass@{k}": estimate_pass_at_k(total, base_correct, k).mean()
        for k in pass_k
        if min_total >= k
    }
    pass_at_k["model"] = os.path.basename(samples).split("--bigcodebench-")[0]
    pass_at_k["split"] = split
    pass_at_k["subset"] = subset
    pass_at_k["calibrated"] = "sanitized-calibrated" in samples
    pass_at_k["gt_pass_rate"] = gt_pass_rate
    pass_at_k["failed_tasks"] = failed_tasks
    return results, pass_at_k
def run_gradio():
    """Build a Gradio form around ``evaluate`` and launch it.

    The inputs map positionally onto ``evaluate``'s parameters. The two JSON
    outputs receive the ``(results, pass_at_k)`` pair it returns. In
    "Check GT Only" mode, the single ground-truth summary goes to the second
    output.
    """

    def evaluate_for_ui(split, subset, samples, pass_k, parallel, min_time_limit,
                        max_as_limit, max_data_limit, max_stack_limit,
                        check_gt_only, no_gt):
        # Depending on the Gradio version, gr.File yields a path string or a
        # file-like object that carries the path in `.name`.
        samples = getattr(samples, "name", samples)
        if samples is None and not check_gt_only:
            raise gr.Error("Please upload a samples .jsonl file.")
        outcome = evaluate(
            split, subset, samples, pass_k, parallel, min_time_limit,
            max_as_limit, max_data_limit, max_stack_limit, check_gt_only, no_gt,
        )
        if check_gt_only:
            # evaluate returns one summary dict in this mode, but the
            # interface declares two outputs.
            return None, outcome
        return outcome

    interface = gr.Interface(
        fn=evaluate_for_ui,
        inputs=[
            gr.Dropdown(["complete", "instruct"], label="Split"),
            gr.Dropdown(["full", "hard"], label="Subset"),
            gr.File(label="Samples Path (.jsonl)"),
            gr.Textbox(label="Pass k Values (comma-separated)", value="1,5,10"),
            # Default mirrors evaluate()'s own default worker count.
            gr.Slider(1, multiprocessing.cpu_count(), step=1, label="Parallel Workers",
                      value=max(1, multiprocessing.cpu_count() // 2)),
            gr.Slider(0.1, 10, step=0.1, label="Min Time Limit", value=1),
            # Slider steps are counted from the minimum, so the minimum must be a
            # multiple of the step for the 30 * 1024 default to lie on the grid.
            gr.Slider(1024, 100 * 1024, step=1024, label="Max AS Limit", value=30 * 1024),
            gr.Slider(1024, 100 * 1024, step=1024, label="Max Data Limit", value=30 * 1024),
            gr.Slider(1, 100, step=1, label="Max Stack Limit", value=10),
            gr.Checkbox(label="Check GT Only"),
            gr.Checkbox(label="No GT"),
        ],
        outputs=[
            gr.JSON(label="Results"),
            gr.JSON(label="Eval Results"),
        ],
    )
    # No cap on concurrently running requests.
    interface.queue(default_concurrency_limit=None)
    interface.launch(show_error=True)
if __name__ == "__main__":
    # Serve the web UI when executed as a script. For a quick run without the
    # UI, evaluate() can be called directly instead, e.g.:
    # evaluate("complete", "hard", "meta-llama--Llama-3.2-3B-Instruct--bigcodebench-instruct--vllm-0-1.jsonl")
    run_gradio()