File size: 2,945 Bytes
a1f5641
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import argparse
import logging
import os
import time

from mteb import MTEB
from sentence_transformers import SentenceTransformer

# Configure the root logger at INFO level and create this script's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("main")

# Hugging Face settings: "1" turns the two offline flags on; the cache
# variables are set to relative paths.
# NOTE(review): these are assigned after the mteb / sentence_transformers
# imports at the top of the file; libraries that read such variables at import
# time may not pick them up -- confirm.
os.environ.update(
    {
        "HF_DATASETS_OFFLINE": "1",
        "TRANSFORMERS_OFFLINE": "1",
        "TRANSFORMERS_CACHE": "./transformers_cache/",
        "HF_DATASETS_CACHE": "./hf_datasets_cache/",
        "HF_MODULES_CACHE": "./hf_modules_cache/",
        "HF_METRICS_CACHE": "./hf_metrics_cache/",
    }
)
# Left disabled by the original author:
# os.environ["TOKENIZERS_PARALLELISM"] = "false"



# Names of the clustering tasks to evaluate.
TASK_LIST_CLUSTERING = [
    # Paired P2P / S2S variants.
    "ArxivClusteringP2P",
    "ArxivClusteringS2S",
    "BiorxivClusteringP2P",
    "BiorxivClusteringS2S",
    "MedrxivClusteringP2P",
    "MedrxivClusteringS2S",
    # Plain / P2P variants.
    "RedditClustering",
    "RedditClusteringP2P",
    "StackExchangeClustering",
    "StackExchangeClusteringP2P",
    "TwentyNewsgroupsClustering",
]

# Names of the pair-classification tasks to evaluate.
TASK_LIST_PAIR_CLASSIFICATION = [
    "SprintDuplicateQuestions",
    "TwitterSemEval2015",
    "TwitterURLCorpus",
]

# Every task, clustering first, as a new list independent of the two above.
TASK_LIST = [*TASK_LIST_CLUSTERING, *TASK_LIST_PAIR_CLASSIFICATION]


def parse_args(argv=None):
    """Parse the command-line options of this script.

    Args:
        argv: Optional sequence of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, which is
            the behaviour of the original zero-argument call.

    Returns:
        argparse.Namespace with the attributes ``modelpath``, ``lang``,
        ``taskname``, ``batchsize`` and ``device``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--modelpath",
        type=str,
        default="./models/",
        help="path of the model to load (default: ./models/)",
    )
    parser.add_argument(
        "--lang",
        type=str,
        default="en",
        help="task language code (default: en)",
    )
    parser.add_argument(
        "--taskname",
        type=str,
        default=None,
        help="optional task name (default: none)",
    )
    parser.add_argument(
        "--batchsize",
        type=int,
        default=128,
        help="batch size (default: 128)",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="mps",
        help="device string handed to the model (default: mps)",
    )
    return parser.parse_args(argv)


def main(args):
    """Run the MTEB tasks against a SentenceTransformer model.

    ex: python run_array.py --modelpath ./models/all-MiniLM-L6-v2

    Args:
        args: Namespace providing ``modelpath``, ``lang``, ``batchsize`` and
            ``device``. If it also carries a non-empty ``taskname``, only that
            task is run; otherwise every entry of ``TASK_LIST`` is run.

    Returns:
        None. ``results/<model_name>`` is passed to ``MTEB.run`` as the
        output folder.
    """
    # Derive the results-folder name first: the last path component, then the
    # part after its last underscore. Validating it before the model is loaded
    # makes a bad path fail fast instead of after an expensive load.
    model_name = args.modelpath.split("/")[-1].split("_")[-1]
    if not model_name:
        print("Model name is empty. Make sure not to end modelpath with a /")
        return

    model = SentenceTransformer(args.modelpath, device=args.device)
    print(f"Running on {model._target_device} with model {model_name}.")

    # Honour --taskname when given (it was previously parsed but ignored).
    # getattr keeps namespaces without the attribute working.
    taskname = getattr(args, "taskname", None)
    tasks = [taskname] if taskname else TASK_LIST

    retries = 5
    for task in tasks:
        print("Running task: ", task)
        evaluation = MTEB(tasks=[task], task_langs=[args.lang])
        for attempt in range(1, retries + 1):
            try:
                evaluation.run(
                    model,
                    output_folder=f"results/{model_name}",
                    batch_size=args.batchsize,
                    eval_splits=["test"],
                )
            except ConnectionError:
                if attempt == retries:
                    # Out of attempts: report and move on to the next task.
                    print(f"Failed to execute task {task} after {retries} attempts due to connection errors.")
                else:
                    print(f"Connection error occurred during task {task}. Waiting for 1 minute before retrying...")
                    time.sleep(60)
            else:
                # The run finished without a connection error; stop retrying.
                break

if __name__ == "__main__":
    args = parse_args()
    main(args)