qMTEB / run_mteb.py
varun4's picture
added results
fba41a4
raw
history blame
3.03 kB
import argparse
import logging
import os
import time
from mteb import MTEB
from sentence_transformers import SentenceTransformer
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("main")
os.environ["HF_DATASETS_OFFLINE"] = "1" # 1 for offline
os.environ["TRANSFORMERS_OFFLINE"] = "1" # 1 for offline
os.environ["TRANSFORMERS_CACHE"] = "./transformers_cache/"
os.environ["HF_DATASETS_CACHE"] = "./hf_datasets_cache/"
os.environ["HF_MODULES_CACHE"] = "./hf_modules_cache/"
os.environ["HF_METRICS_CACHE"] = "./hf_metrics_cache/"
# os.environ["TOKENIZERS_PARALLELISM"] = "false"
TASK_LIST_CLUSTERING = [
"ArxivClusteringP2P",
"ArxivClusteringS2S",
"BiorxivClusteringP2P",
"BiorxivClusteringS2S",
"MedrxivClusteringP2P",
"MedrxivClusteringS2S",
"RedditClustering",
"RedditClusteringP2P",
"StackExchangeClustering",
"StackExchangeClusteringP2P",
"TwentyNewsgroupsClustering",
]
TASK_LIST_PAIR_CLASSIFICATION = [
"SprintDuplicateQuestions",
"TwitterSemEval2015",
"TwitterURLCorpus",
]
TASK_LIST = TASK_LIST_CLUSTERING + TASK_LIST_PAIR_CLASSIFICATION
def parse_args():
# Parse command line arguments
parser = argparse.ArgumentParser()
# parser.add_argument("--startid", type=int)
# parser.add_argument("--endid", type=int)
parser.add_argument("--modelpath", type=str, default="./models/")
parser.add_argument("--lang", type=str, default="en")
parser.add_argument("--taskname", type=str, default=None)
parser.add_argument("--batchsize", type=int, default=128)
parser.add_argument("--device", type=str, default="mps")
args = parser.parse_args()
return args
def main(args):
"""
ex: python run_mteb.py --modelpath ./models/all-MiniLM-L6-v2-unquantized
Optimum/onnx models need to contain string "opt" somewhere in model path or name.
"""
model = SentenceTransformer(args.modelpath, device=args.device)
model_name = args.modelpath.split("/")[-1].split("_")[-1]
if not model_name:
print(f"Model name is empty. Make sure not to end modelpath with a /")
return
print(f"Running on {model._target_device} with model {model_name}.")
for task in TASK_LIST:
print("Running task: ", task)
# this args. notation seems anti-pythonic
evaluation = MTEB(tasks=[task], task_langs=[args.lang])
retries = 5
for attempt in range(retries):
try:
evaluation.run(model, output_folder=f"results/{model_name}", batch_size=args.batchsize, eval_splits=["test"])
break
except ConnectionError:
if attempt < retries - 1:
print(f"Connection error occurred during task {task}. Waiting for 1 minute before retrying...")
time.sleep(60)
else:
print(f"Failed to execute task {task} after {retries} attempts due to connection errors.")
if __name__ == "__main__":
args = parse_args()
main(args)