import gc
import logging
import multiprocessing
import os
from collections import Counter, defaultdict
from concurrent.futures import ProcessPoolExecutor, as_completed
from datetime import datetime
from typing import Dict, List, Tuple

from fastapi import FastAPI, HTTPException
from fastapi.responses import RedirectResponse

from api.code_execution import untrusted_check

# Evaluation outcome for a single solution: (status, per-test pass/fail details).
Result = Tuple[str, List[bool]]


def create_app() -> FastAPI:
    level = os.environ.get("LOG_LEVEL", logging.INFO)
    logging.basicConfig(level=level)
    logger = logging.getLogger(__name__)

    app = FastAPI()

    @app.get("/")
    def root():
        return RedirectResponse("/docs")

    @app.get("/health", status_code=204)
    def health():
        return

    @app.post("/evaluate/")
    async def evaluate(
        samples: List[dict],
        calibrate: bool = True,
        parallel: int = -1,
        min_time_limit: float = 1,
        max_as_limit: int = 30 * 1024,
        max_data_limit: int = 30 * 1024,
        max_stack_limit: int = 10,
        no_gt: bool = True,
    ) -> dict:
        """
        Evaluate the correctness of the solutions in the given samples.

        Each sample must provide "task_id", "res_id", "test", "solution",
        and "entry_point" ("code_prompt" is also required when calibrate
        is True).
        """
        # Default to half of the available CPU cores when no explicit
        # parallelism is requested.
        if parallel < 1:
            n_workers = max(1, multiprocessing.cpu_count() // 2)
        else:
            n_workers = parallel

        if not no_gt:
            expected_time = get_groundtruth()
        else:
            expected_time = {}

        results = {
            "date": datetime.now().strftime("%Y-%m-%d %H:%M"),
            "eval": {},
        }

        with ProcessPoolExecutor(max_workers=n_workers) as executor:
            futures = []
            completion_id = Counter()
            n_samples = 0
            eval_results = defaultdict(list)  # task_id -> list of results
            remainings = set()

            for i, sample in enumerate(samples):
                # TODO: investigate why HTTPException detail is not passed to client.
                required_keys = ["task_id", "res_id", "test", "solution", "entry_point"]
                if calibrate:
                    required_keys.append("code_prompt")
                for key in required_keys:
                    if key not in sample:
                        raise HTTPException(status_code=400, detail=f"'{key}' not in sample {i}!")
                if not isinstance(sample["solution"], str):
                    raise HTTPException(status_code=400, detail="Solution must be a string!")

                sample["_identifier"] = sample["task_id"] + f" (line {i + 1})"
                task_id = sample["task_id"]
                solution = sample["solution"]
                if calibrate:
                    solution = sample["code_prompt"] + "\n    pass\n" + solution
                remainings.add(sample["_identifier"])
                args = (
                    completion_id[task_id],
                    sample["res_id"],
                    task_id,
                    solution,
                    sample["test"],
                    sample["entry_point"],
                    max_as_limit,
                    max_data_limit,
                    max_stack_limit,
                    sample["_identifier"],
                    min_time_limit,
                    expected_time.get(task_id) or 20,
                )
                futures.append(executor.submit(check_correctness, *args))
                completion_id[task_id] += 1
                n_samples += 1

            assert n_samples == len(remainings), "Missing problems in unfinished"
            # assert len(completion_id) == len(problems), "Missing problems in samples"

            for future in as_completed(futures):
                result = future.result()
                remainings.remove(result["_identifier"])
                eval_results[result["task_id"]].append(result)
                # Free per-future memory eagerly; results can hold large strings.
                del future, result
                gc.collect()

        # Sort the results for each problem by completion_id.
        for task_id, task_results in eval_results.items():
            task_results.sort(key=lambda x: x["completion_id"])
            results["eval"][task_id] = []
            for res in task_results:
                stat, details = res["base"]
                results["eval"][task_id].append(
                    {
                        "res_id": res["res_id"],
                        "task_id": task_id,
                        "solution": res["solution"],
                        "status": stat,
                        "details": details,
                    }
                )

        return results

    return app
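
# A minimal example request body for POST /evaluate/ (a sketch; the field
# values are illustrative, not taken from any real benchmark):
#
#   [
#     {
#       "task_id": "Example/0",
#       "res_id": 0,
#       "test": "def check(inc):\n    assert inc(1) == 2\ncheck(inc)",
#       "solution": "def inc(x):\n    return x + 1",
#       "entry_point": "inc",
#       "code_prompt": "def inc(x):"
#     }
#   ]
#
# "code_prompt" is only consulted when calibrate=True, in which case it is
# prepended (with a placeholder body) to the solution before execution.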


def check_correctness(
    completion_id: int,
    res_id: int,
    task_id: str,
    solution: str,
    test: str,
    entry_point: str,
    max_as_limit: float,
    max_data_limit: float,
    max_stack_limit: float,
    identifier=None,
    min_time_limit: float = 0.1,
    gt_time_limit: float = 2.0,
) -> Dict[str, Result]:
    ret = {
        "completion_id": completion_id,
        "res_id": res_id,
        "task_id": task_id,
        "_identifier": identifier,
        "solution": solution,
    }
    ret["base"] = untrusted_check(
        solution,
        test,
        entry_point,
        max_as_limit,
        max_data_limit,
        max_stack_limit,
        min_time_limit,
        gt_time_limit,
    )
    return ret
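
# The "base" entry is expected to match the Result alias above: a
# (status, details) pair, which evaluate() unpacks via
# `stat, details = res["base"]`.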


def get_groundtruth():
    raise HTTPException(status_code=501, detail="Ground-truth execution is not implemented yet!")
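
# To serve the app locally (a sketch, assuming uvicorn is installed and this
# file is importable as `main`; adjust the module path to the actual name):
#
#   uvicorn main:create_app --factory --host 0.0.0.0 --port 8000
#
# The --factory flag tells uvicorn that create_app is a factory returning the
# application rather than an app instance.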