import json
import logging
import traceback
from argparse import Namespace
from datetime import datetime

from huggingface_hub import HfApi
from huggingface_hub.errors import InferenceEndpointTimeoutError
from lighteval.logging.evaluation_tracker import EnhancedJSONEncoder
from lighteval.main_accelerate import main

from src.backend.manage_requests import EvalRequest
from src.envs import CACHE_PATH, OWNER, TOKEN

# quiet the `openai` logger down to warnings
logging.getLogger("openai").setLevel(logging.WARNING)

class DefaultNamespace(Namespace):
    """Namespace that returns None for missing attributes instead of raising AttributeError."""
    def __getattr__(self, name):
        return self.__dict__.get(name, None)
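# e.g. DefaultNamespace(a=1).a == 1 while DefaultNamespace(a=1).b is None, so any
# option lighteval's `main` looks up but is omitted from the args below reads as unset.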

def run_evaluation(eval_request: EvalRequest, task_names: str, batch_size: int, local_dir: str, accelerator: str, region: str, vendor: str, instance_size: str, instance_type: str, limit=None):
    """Evaluate `eval_request.model` with lighteval on a dedicated inference endpoint,
    falling back across TGI image versions and retrying on endpoint timeouts."""
    if limit:
        print("WARNING: --limit SHOULD ONLY BE USED FOR TESTING. REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT.")

    api = HfApi(token=TOKEN)

    completed = False
    results = None  # stays None if every attempt below fails
    # TGI image versions to try, newest first, falling back on failure
    img_versions = ['2.1.0', '2.0.2', '1.4.5']
    if 'gemma-2' in eval_request.model:
        # gemma-2 models need a larger instance than the caller's default
        instance_size = 'x2'
        instance_type = 'nvidia-a100'
    for img_version in img_versions:
        args = DefaultNamespace(**{
                "model_config": dict(model=dict(
                    type="endpoint",
                    base_params=dict(
                        # sanitize the model name for the endpoint API (lowercase, hyphens
                        # only) and keep at most the last 32 characters
                        endpoint_name=f'{eval_request.model.split("/")[1].replace(".", "-").replace("_", "-").lower()}-lighteval'[-32:].strip('-'),
                        model=eval_request.model,
                        revision=eval_request.revision,
                        dtype=eval_request.precision,
                        reuse_existing=False
                    ),
                    instance=dict(
                        accelerator=accelerator,
                        region=region,
                        vendor=vendor,
                        instance_size=instance_size,
                        instance_type=instance_type,
                        framework='pytorch',
                        endpoint_type='protected',
                        namespace=OWNER,
                        # pin the TGI image for this attempt
                        image_url='ghcr.io/huggingface/text-generation-inference:' + img_version
                    ),
                    generation=dict(
                        add_special_tokens=True
                    )
                )),
                "max_samples": limit,
                "job_id": str(datetime.now()),
                "push_results_to_hub": True,
                "save_details": False,
                "push_details_to_hub": False,
                "public_run": False,
                "cache_dir": CACHE_PATH,
                "results_org": OWNER,
                "output_dir": local_dir,
                "override_batch_size": batch_size,
                "custom_tasks": "custom_tasks.py",
                "tasks": task_names,
                "dataset_loading_processes": 24,
                "num_fewshot_seeds": 0
        })


        try:
            # in case of a timeout, retry up to three times, reusing the existing endpoint
            for i in range(3):
                try:
                    results = main(args)
                    completed = True  # success!

                    dumped = json.dumps(results, cls=EnhancedJSONEncoder, indent=2)
                    print(dumped)

                    # on a retry (i > 0) we reused an endpoint that is still running,
                    # so raise to route through the cleanup in the outer handler
                    if i > 0:
                        raise Exception("forcing cleanup of the reused endpoint")
                    break  # completed on the first attempt, nothing to clean up
                except InferenceEndpointTimeoutError:
                    if i < 2:  # range(3) ends at i == 2, so only retry before then
                        print('Timed out, trying again...')
                        args.model_config['model']['base_params']['reuse_existing'] = True
                    # loop around and retry the timed-out evaluation

        except Exception as ex:  # if the eval failed, force a cleanup of the endpoint
            traceback.print_exception(ex)
            try:
                api.delete_inference_endpoint(
                        name=args.model_config['model']['base_params']['endpoint_name'],
                        namespace=args.model_config['model']['instance']['namespace']
                )
            except Exception as cleanup_ex:
                traceback.print_exception(cleanup_ex)

        if completed: break  # no need to try a different image version

    return results  # None if every image version failed
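
# A minimal usage sketch (hypothetical throughout: the repo id, task spec, and
# instance settings are illustrative, and the EvalRequest fields shown are just
# the ones this module reads -- the real dataclass may require more):
#
# if __name__ == "__main__":
#     request = EvalRequest(
#         model="org/some-model",          # namespaced repo id, required by the split("/") above
#         revision="main",
#         precision="bfloat16",
#     )
#     results = run_evaluation(
#         eval_request=request,
#         task_names="custom|mytask|0|0",  # lighteval task spec, illustrative
#         batch_size=1,
#         local_dir="./eval-results",
#         accelerator="gpu",
#         region="us-east-1",
#         vendor="aws",
#         instance_size="x1",
#         instance_type="nvidia-a10g",
#         limit=10,                        # testing only; see the warning above
#     )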