File size: 5,396 Bytes
25db7e9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
import logging
import os
from collections import Counter, defaultdict
import multiprocessing
from datetime import datetime
from concurrent.futures import ProcessPoolExecutor, as_completed
from typing import Dict, List, Tuple
import gc

from fastapi import FastAPI, HTTPException
from fastapi.responses import RedirectResponse

from api.code_execution import untrusted_check

# Outcome of one check: a status string paired with a list of booleans.
# NOTE(review): presumably this is what `untrusted_check` returns — confirm
# against api.code_execution.
Result = Tuple[str, List[bool]]

def create_app() -> FastAPI:
    """Build and return the FastAPI application for the evaluation service.

    Logging is configured from the ``LOG_LEVEL`` environment variable, which
    may hold a level name in any letter case or a numeric level; when it is
    unset or blank, INFO is used.

    The application exposes three routes:

    * ``GET /`` redirects to ``/docs``.
    * ``GET /health`` answers with an empty 204 response.
    * ``POST /evaluate/`` runs the submitted solutions and reports results.
    """
    # The logging module only recognizes upper-case level names and integer
    # levels, so normalize the raw value before handing it over.
    raw_level = (os.environ.get("LOG_LEVEL") or "").strip().upper()
    if not raw_level:
        level = logging.INFO
    elif raw_level.isdigit():
        level = int(raw_level)
    else:
        level = raw_level
    logging.basicConfig(level=level)
    logger = logging.getLogger(__name__)

    app = FastAPI()

    @app.get("/")
    def root():
        return RedirectResponse("/docs")

    @app.get("/health", status_code=204)
    def health():
        return

    @app.post("/evaluate/")
    async def evaluate(
        samples: List[dict],
        calibrate: bool = True,
        parallel: int = -1,
        min_time_limit: float = 1,
        max_as_limit: int = 30 * 1024,
        max_data_limit: int = 30 * 1024,
        max_stack_limit: int = 10,
        no_gt: bool = True,
    ) -> dict:
        """
        Evaluate the correctness of the solutions in the given samples data.

        Every sample must provide "task_id", "res_id", "test", "solution" and
        "entry_point"; "code_prompt" is also required when `calibrate` is
        true, because it is prepended to the solution. All samples are
        validated before any of them is executed, and a malformed sample is
        rejected with HTTP 400.

        `parallel` is the number of worker processes; values below 1 select
        half of the available CPUs (at least one). `min_time_limit`,
        `max_as_limit`, `max_data_limit` and `max_stack_limit` are forwarded
        unchanged to the checker. Setting `no_gt` to false requests
        ground-truth timing, which currently answers with HTTP 405.

        Returns a dict with a "date" string and an "eval" mapping from task id
        to that task's result records, in submission order.
        """
        # NOTE(review): waiting on the process pool below blocks inside an
        # `async def` handler, so other requests handled by the same event
        # loop wait until evaluation finishes.
        if parallel < 1:
            n_workers = max(1, multiprocessing.cpu_count() // 2)
        else:
            n_workers = parallel

        expected_time = {} if no_gt else get_groundtruth()

        results = {
            "date": datetime.now().strftime("%Y-%m-%d %H:%M"),
            "eval": {},
        }

        required_keys = ["task_id", "res_id", "test", "solution", "entry_point"]
        if calibrate:
            # Calibration prepends the code prompt, so it must be present.
            required_keys.append("code_prompt")

        # Validate every sample and prepare its arguments BEFORE starting the
        # pool: leaving the executor's `with` block waits for all submitted
        # jobs, so rejecting a bad sample after submission would stall the
        # error response until earlier samples had finished running.
        jobs = []
        completion_id = Counter()
        for i, sample in enumerate(samples):
            # TODO: investigate why HTTPException detail is not passed to client.
            for key in required_keys:
                if key not in sample:
                    raise HTTPException(status_code=400, detail=f"'{key}' not in sample {i}!")

            if not isinstance(sample["solution"], str):
                raise HTTPException(status_code=400, detail="Solution must be a string!")

            task_id = sample["task_id"]
            if not isinstance(task_id, str):
                raise HTTPException(
                    status_code=400, detail=f"'task_id' must be a string in sample {i}!"
                )

            solution = sample["solution"]
            if calibrate:
                if not isinstance(sample["code_prompt"], str):
                    raise HTTPException(
                        status_code=400, detail=f"'code_prompt' must be a string in sample {i}!"
                    )
                solution = sample["code_prompt"] + "\n    pass\n" + solution

            # The sample's position makes the identifier unique even when
            # several samples share a task id.
            identifier = task_id + f" (line {i+1} )"

            jobs.append(
                (
                    completion_id[task_id],
                    sample["res_id"],
                    task_id,
                    solution,
                    sample["test"],
                    sample["entry_point"],
                    max_as_limit,
                    max_data_limit,
                    max_stack_limit,
                    identifier,
                    min_time_limit,
                    # Fall back to 20 when no expected time is known for the task.
                    expected_time.get(task_id) or 20,
                )
            )
            completion_id[task_id] += 1

        logger.debug("Evaluating %d samples with %d workers", len(jobs), n_workers)

        eval_results = defaultdict(list)  # task_id -> list of result records
        with ProcessPoolExecutor(max_workers=n_workers) as executor:
            futures = [executor.submit(check_correctness, *job) for job in jobs]

            for future in as_completed(futures):
                result = future.result()
                eval_results[result["task_id"]].append(result)
                # Drop local references and collect garbage after each result,
                # as the original implementation did.
                del future, result
                gc.collect()

        # sort the results for each problem by completion_id
        for task_id, task_results in eval_results.items():
            task_results.sort(key=lambda x: x["completion_id"])
            results["eval"][task_id] = []
            for res in task_results:
                stat, details = res["base"]
                results["eval"][task_id].append(
                    {
                        "res_id": res["res_id"],
                        "task_id": task_id,
                        "solution": res["solution"],
                        "status": stat,
                        "details": details,
                    }
                )
        return results

    return app

def check_correctness(
    completion_id: int,
    res_id: int,
    task_id: str,
    solution: str,
    test: str,
    entry_point: str,
    max_as_limit: float,
    max_data_limit: float,
    max_stack_limit: float,
    identifier=None,
    min_time_limit: float = 0.1,
    gt_time_limit: float = 2.0,
) -> Dict[str, Result]:
    """Run one solution through `untrusted_check` and package the outcome.

    The solution, test, entry point, the three resource limits and the two
    time limits are handed to `untrusted_check` unchanged, in that order.

    Returns a dict echoing `completion_id`, `res_id`, `task_id`, `identifier`
    (under the key "_identifier") and `solution`, plus the checker's return
    value under the key "base".
    """
    outcome = untrusted_check(
        solution,
        test,
        entry_point,
        max_as_limit,
        max_data_limit,
        max_stack_limit,
        min_time_limit,
        gt_time_limit,
    )
    # Echo the bookkeeping fields so the caller can match this record to the
    # sample it was produced from.
    return {
        "completion_id": completion_id,
        "res_id": res_id,
        "task_id": task_id,
        "_identifier": identifier,
        "solution": solution,
        "base": outcome,
    }


def get_groundtruth():
    """Placeholder for ground-truth execution.

    Raises:
        HTTPException: always, with status 405, because the feature is not
            implemented.
    """
    not_implemented = HTTPException(
        status_code=405,
        detail="Groundtruth execution is not implemented yet!",
    )
    raise not_implemented