Nathan Habib commited on
Commit
d295afa
·
1 Parent(s): ef627e9

look for model type in request file

Browse files
src/auto_leaderboard/model_metadata_type.py CHANGED
@@ -1,5 +1,8 @@
1
  from dataclasses import dataclass
2
  from enum import Enum
 
 
 
3
  from typing import Dict, List
4
 
5
  from ..utils_display import AutoEvalColumn
@@ -9,6 +12,11 @@ class ModelInfo:
9
  name: str
10
  symbol: str # emoji
11
 
 
 
 
 
 
12
 
13
  class ModelType(Enum):
14
  PT = ModelInfo(name="pretrained", symbol="🟢")
@@ -526,21 +534,19 @@ def get_model_type(leaderboard_data: List[dict]):
526
  # Todo @clefourrier once requests are connected with results
527
  is_delta = False # (model_data["weight_type"] != "Original")
528
  # Stored information
529
- if model_data["model_name_for_query"] in TYPE_METADATA:
530
- model_data[AutoEvalColumn.model_type.name] = TYPE_METADATA[model_data["model_name_for_query"]].value.name
531
- model_data[AutoEvalColumn.model_type_symbol.name] = TYPE_METADATA[model_data["model_name_for_query"]].value.symbol + ("🔺" if is_delta else "")
532
- # Inferred from the name or the selected type
533
- elif model_data[AutoEvalColumn.model_type.name] == "pretrained" or any([i in model_data["model_name_for_query"] for i in ["pretrained"]]):
534
- model_data[AutoEvalColumn.model_type.name] = ModelType.PT.value.name
535
- model_data[AutoEvalColumn.model_type_symbol.name] = ModelType.PT.value.symbol + ("🔺" if is_delta else "")
536
- elif model_data[AutoEvalColumn.model_type.name] == "finetuned" or any([i in model_data["model_name_for_query"] for i in ["finetuned", "-ft-"]]):
537
- model_data[AutoEvalColumn.model_type.name] = ModelType.SFT.value.name
538
- model_data[AutoEvalColumn.model_type_symbol.name] = ModelType.SFT.value.symbol + ("🔺" if is_delta else "")
539
- elif model_data[AutoEvalColumn.model_type.name] == "with RL" or any([i in model_data["model_name_for_query"] for i in ["-rl-", "-rlhf-"]]):
540
- model_data[AutoEvalColumn.model_type.name] = ModelType.RL.value.name
541
- model_data[AutoEvalColumn.model_type_symbol.name] = ModelType.RL.value.symbol + ("🔺" if is_delta else "")
542
- else:
543
- model_data[AutoEvalColumn.model_type.name] = "N/A"
544
- model_data[AutoEvalColumn.model_type_symbol.name] = ("🔺" if is_delta else "")
545
-
546
 
 
1
  from dataclasses import dataclass
2
  from enum import Enum
3
+ import glob
4
+ import json
5
+ import os
6
  from typing import Dict, List
7
 
8
  from ..utils_display import AutoEvalColumn
 
12
  name: str
13
  symbol: str # emoji
14
 
15
+ model_type_symbols = {
16
+ "fine-tuned": "🔶",
17
+ "pretrained": "🟢",
18
+ "with RL": "🟦",
19
+ }
20
 
21
  class ModelType(Enum):
22
  PT = ModelInfo(name="pretrained", symbol="🟢")
 
534
  # Todo @clefourrier once requests are connected with results
535
  is_delta = False # (model_data["weight_type"] != "Original")
536
  # Stored information
537
+ request_file = os.path.join("eval-queue", model_data["model_name_for_query"] + "_eval_request_*" + ".json")
538
+ request_file = glob.glob(request_file)
539
+
540
+ try:
541
+ request_file = request_file[0]
542
+ with open(request_file, "r") as f:
543
+ request = json.load(f)
544
+ model_type = request["model_type"]
545
+ is_delta = request["weight_type"] != "Original"
546
+ model_data[AutoEvalColumn.model_type.name] = model_type
547
+ model_data[AutoEvalColumn.model_type_symbol.name] = model_type_symbols[model_type] + ("🔺" if is_delta else "")
548
+ except Exception:
549
+ if model_data["model_name_for_query"] in TYPE_METADATA:
550
+ model_data[AutoEvalColumn.model_type.name] = TYPE_METADATA[model_data["model_name_for_query"]].value.name
551
+ model_data[AutoEvalColumn.model_type_symbol.name] = TYPE_METADATA[model_data["model_name_for_query"]].value.symbol + ("🔺" if is_delta else "")
 
 
552