joaogante (HF staff) committed
Commit 0b94c41 • 1 Parent(s): 6ba93dd
Files changed (2):
  1. app.py +1 -152
  2. medusa_training.py +152 -0
app.py CHANGED
@@ -1,15 +1,7 @@
-import json
-import os
-import multiprocessing as mp
-
 from git import Repo
 import gradio as gr
-from huggingface_hub import HfApi
-from huggingface_hub.utils import RepositoryNotFoundError
-from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
-import torch
-import torch.distributed.run as distributed_run
 
+from medusa_training import run, DEFAULT_TRAINING_ARGS
 
 # Clone the medusa repo locally
 print("Cloning the medusa repo locally...")
@@ -18,149 +10,6 @@ print("Cloning the vicuna data locally...")
 Repo.clone_from("https://huggingface.co/datasets/Aeala/ShareGPT_Vicuna_unfiltered", "data")
 print("Done")
 
-OUTPUT_DIR = "medusa_heads"
-MEDUSA_NUM_HEADS = 3
-MEDUSA_NUM_LAYERS = 1
-LR = 1e-3
-
-DATASET = "vicuna"
-
-# These can't be changed (e.g. they control the output path)
-FIXED_TRAINING_ARGS = \
-"""medusa/medusa/train/train.py
---model_name_or_path {model_id}
---output_dir {output_dir}
---run_name {model_id}-medusa-{dataset}
---medusa_num_heads {medusa_num_heads}
---medusa_num_layers {medusa_num_layers}
---learning_rate {lr}
---data_path data/ShareGPT_V4.3_unfiltered_cleaned_split.json"""
-
-# These can be freely changed
-DEFAULT_TRAINING_ARGS = \
-"""--bf16 True
---num_train_epochs 1
---per_device_train_batch_size 64
---per_device_eval_batch_size 64
---gradient_accumulation_steps 4
---evaluation_strategy no
---save_strategy no
---weight_decay 0.0
---warmup_ratio 0.1
---lr_scheduler_type cosine
---logging_steps 10
---tf32 True
---model_max_length 2048
---lazy_preprocess True
---auto_find_batch_size True"""
-
-
-def train_medusa_heads(model_id: str, training_args: str):
-    all_training_args = FIXED_TRAINING_ARGS.format(
-        model_id=model_id,
-        output_dir=OUTPUT_DIR,
-        dataset=DATASET,
-        medusa_num_heads=MEDUSA_NUM_HEADS,
-        lr=LR,
-        medusa_num_layers=MEDUSA_NUM_LAYERS,
-    ) + "\n" + training_args
-    all_training_arg_list = []
-    for arg in all_training_args.split("\n"):
-        all_training_arg_list += arg.split(" ")
-    print("Full argument list:", all_training_arg_list)
-
-    parser = distributed_run.get_args_parser()
-    args = parser.parse_args(all_training_arg_list)
-    distributed_run.run(args)
-
-
-def run(model_id: str, training_args: str) -> str:
-    print(f"\n\n\nNEW RUN: {model_id}")
-    api = HfApi()
-    model_name = model_id.split("/")[-1]
-    repo_id = f"joaogante/{model_name}-medusa-{DATASET}"
-
-    # Input validation
-    if model_id == "":
-        return """
-        ### Invalid input 🐞
-
-        Please fill in a model_id.
-        """
-    if api.repo_exists(repo_id):
-        return f"""
-        ### Invalid input 🐞
-
-        {repo_id} already exists, which means that {model_id} has already been used to create medusa heads.
-        """
-    print(f"Valid inputs ✅\nValidating model_id: {model_id}")
-
-    # Attempt to load the base model
-    try:
-        config = AutoConfig.from_pretrained(model_id)
-        tokenizer = AutoTokenizer.from_pretrained(model_id)
-        model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
-        del config, tokenizer, model
-    except Exception as e:
-        return f"""
-        ### {model_id} can't be loaded with AutoClasses 🐞
-
-        {e}
-        """
-    print(f"{model_id} can be loaded ✅\nCreating medusa heads (will take a few hours)")
-
-    # Run the medusa heads creation
-    try:
-        proc = mp.Process(target=train_medusa_heads, args=(model_id, training_args))
-        proc.start()
-        proc.join()
-        print("Medusa heads training process completed (it might have crashed!)")
-    except Exception as e:
-        print("Error ❌\n", e)
-        return f"""
-        ### Error 😢😢😢
-
-        {e}
-        """
-
-    # Upload the medusa heads to the Hub
-    try:
-        # Folder path from https://github.com/FasterDecoding/Medusa/blob/main/medusa/train/train.py#L399
-        folder_path = (
-            f"{OUTPUT_DIR}_medusa_mlp_{model_name}_medusa_{MEDUSA_NUM_HEADS}_lr_{LR}_layers_{MEDUSA_NUM_LAYERS}"
-        )
-        if not any(x.endswith(".pt") for x in os.listdir(folder_path)):
-            raise Exception(
-                "No model data in the expected model folder, the training run probably failed. Check the logs "
-                "for more information."
-            )
-
-        api.create_repo(
-            repo_id=repo_id,
-            exist_ok=True,
-        )
-        api.upload_folder(
-            folder_path=folder_path,
-            repo_id=repo_id,
-        )
-        print("Medusa heads upload success ✅\nUploaded to: ", repo_id)
-        return f"""
-        ### Success 🔥
-
-        Yay! Medusa heads were successfully created and uploaded to the following repo: {repo_id}
-        """
-    except Exception as e:
-        print("Error ❌\n", e)
-        try:
-            api.delete_repo(repo_id)
-        except RepositoryNotFoundError:
-            pass
-        return f"""
-        ### Error 😢😢😢
-
-        {e}
-        """
-
 
 DESCRIPTION = """
 The steps to create [medusa](https://sites.google.com/view/medusa-llm) heads are the following:
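Note: after this commit, the only functional line left in app.py besides the clones and the DESCRIPTION is the new import of run and DEFAULT_TRAINING_ARGS. How app.py feeds them into the Gradio UI is outside this hunk; the sketch below is only a hypothetical reconstruction of that wiring (the Interface layout, textbox labels, and launch call are assumptions, not part of the commit):

# Hypothetical wiring sketch -- the actual Gradio layout is outside this diff.
import gradio as gr

from medusa_training import run, DEFAULT_TRAINING_ARGS

demo = gr.Interface(
    fn=run,  # run(model_id, training_args) returns a markdown status string
    inputs=[
        gr.Textbox(label="model_id", placeholder="user/model"),
        gr.Textbox(label="training_args", value=DEFAULT_TRAINING_ARGS, lines=16),
    ],
    outputs=gr.Markdown(),
)
demo.launch()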
medusa_training.py ADDED
@@ -0,0 +1,152 @@
+import json
+import os
+import multiprocessing as mp
+
+from huggingface_hub import HfApi
+from huggingface_hub.utils import RepositoryNotFoundError
+from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
+import torch
+import torch.distributed.run as distributed_run
+
+OUTPUT_DIR = "medusa_heads"
+MEDUSA_NUM_HEADS = 3
+MEDUSA_NUM_LAYERS = 1
+LR = 1e-3
+
+DATASET = "vicuna"
+
+# These can't be changed (e.g. they control the output path)
+FIXED_TRAINING_ARGS = \
+"""medusa/medusa/train/train.py
+--model_name_or_path {model_id}
+--output_dir {output_dir}
+--run_name {model_id}-medusa-{dataset}
+--medusa_num_heads {medusa_num_heads}
+--medusa_num_layers {medusa_num_layers}
+--learning_rate {lr}
+--data_path data/ShareGPT_V4.3_unfiltered_cleaned_split.json"""
+
+# These can be freely changed
+DEFAULT_TRAINING_ARGS = \
+"""--bf16 True
+--num_train_epochs 1
+--per_device_train_batch_size 64
+--per_device_eval_batch_size 64
+--gradient_accumulation_steps 4
+--evaluation_strategy no
+--save_strategy no
+--weight_decay 0.0
+--warmup_ratio 0.1
+--lr_scheduler_type cosine
+--logging_steps 10
+--tf32 True
+--model_max_length 2048
+--lazy_preprocess True
+--auto_find_batch_size True"""
+
+
+def train_medusa_heads(model_id: str, training_args: str):
+    all_training_args = FIXED_TRAINING_ARGS.format(
+        model_id=model_id,
+        output_dir=OUTPUT_DIR,
+        dataset=DATASET,
+        medusa_num_heads=MEDUSA_NUM_HEADS,
+        lr=LR,
+        medusa_num_layers=MEDUSA_NUM_LAYERS,
+    ) + "\n" + training_args
+    all_training_arg_list = []
+    for arg in all_training_args.split("\n"):
+        all_training_arg_list += arg.split(" ")
+    print("Full argument list:", all_training_arg_list)
+
+    parser = distributed_run.get_args_parser()
+    args = parser.parse_args(all_training_arg_list)
+    distributed_run.run(args)
+
+
+def run(model_id: str, training_args: str) -> str:
+    print(f"\n\n\nNEW RUN: {model_id}")
+    api = HfApi()
+    model_name = model_id.split("/")[-1]
+    repo_id = f"joaogante/{model_name}-medusa-{DATASET}"
+
+    # Input validation
+    if model_id == "":
+        return """
+        ### Invalid input 🐞
+
+        Please fill in a model_id.
+        """
+    if api.repo_exists(repo_id):
+        return f"""
+        ### Invalid input 🐞
+
+        {repo_id} already exists, which means that {model_id} has already been used to create medusa heads.
+        """
+    print(f"Valid inputs ✅\nValidating model_id: {model_id}")
+
+    # Attempt to load the base model
+    try:
+        config = AutoConfig.from_pretrained(model_id)
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
+        del config, tokenizer, model
+    except Exception as e:
+        return f"""
+        ### {model_id} can't be loaded with AutoClasses 🐞
+
+        {e}
+        """
+    print(f"{model_id} can be loaded ✅\nCreating medusa heads (will take a few hours)")
+
+    # Run the medusa heads creation
+    try:
+        proc = mp.Process(target=train_medusa_heads, args=(model_id, training_args))
+        proc.start()
+        proc.join()
+        print("Medusa heads training process completed (it might have crashed!)")
+    except Exception as e:
+        print("Error ❌\n", e)
+        return f"""
+        ### Error 😢😢😢
+
+        {e}
+        """
+
+    # Upload the medusa heads to the Hub
+    try:
+        # Folder path from https://github.com/FasterDecoding/Medusa/blob/main/medusa/train/train.py#L399
+        folder_path = (
+            f"{OUTPUT_DIR}_medusa_mlp_{model_name}_medusa_{MEDUSA_NUM_HEADS}_lr_{LR}_layers_{MEDUSA_NUM_LAYERS}"
+        )
+        if not any(x.endswith(".pt") for x in os.listdir(folder_path)):
+            raise Exception(
+                "No model data in the expected model folder, the training run probably failed. Check the logs "
+                "for more information."
+            )
+
+        api.create_repo(
+            repo_id=repo_id,
+            exist_ok=True,
+        )
+        api.upload_folder(
+            folder_path=folder_path,
+            repo_id=repo_id,
+        )
+        print("Medusa heads upload success ✅\nUploaded to: ", repo_id)
+        return f"""
+        ### Success 🔥
+
+        Yay! Medusa heads were successfully created and uploaded to the following repo: {repo_id}
+        """
+    except Exception as e:
+        print("Error ❌\n", e)
+        try:
+            api.delete_repo(repo_id)
+        except RepositoryNotFoundError:
+            pass
+        return f"""
+        ### Error 😢😢😢
+
+        {e}
+        """
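As a sanity check on the argument plumbing above, here is a hypothetical snippet (not part of the commit; the model id is a placeholder) that mirrors what train_medusa_heads assembles before handing off to torch.distributed.run, and where the trained heads land:

# Hypothetical usage sketch, assuming medusa_training.py is importable.
from medusa_training import (
    DATASET, DEFAULT_TRAINING_ARGS, FIXED_TRAINING_ARGS, LR,
    MEDUSA_NUM_HEADS, MEDUSA_NUM_LAYERS, OUTPUT_DIR,
)

model_id = "lmsys/vicuna-7b-v1.3"  # placeholder model id

# Format the fixed args, append the user-editable args, and flatten the
# resulting string into the argv list given to torch.distributed.run.
all_args = FIXED_TRAINING_ARGS.format(
    model_id=model_id,
    output_dir=OUTPUT_DIR,
    dataset=DATASET,
    medusa_num_heads=MEDUSA_NUM_HEADS,
    medusa_num_layers=MEDUSA_NUM_LAYERS,
    lr=LR,
) + "\n" + DEFAULT_TRAINING_ARGS
argv = [tok for line in all_args.split("\n") for tok in line.split(" ")]
print(argv[:5])
# ['medusa/medusa/train/train.py', '--model_name_or_path',
#  'lmsys/vicuna-7b-v1.3', '--output_dir', 'medusa_heads']

# The Medusa trainer derives its output folder from the hyperparameters;
# this is the folder that run() later uploads to the Hub:
model_name = model_id.split("/")[-1]
print(f"{OUTPUT_DIR}_medusa_mlp_{model_name}_medusa_{MEDUSA_NUM_HEADS}_lr_{LR}_layers_{MEDUSA_NUM_LAYERS}")
# medusa_heads_medusa_mlp_vicuna-7b-v1.3_medusa_3_lr_0.001_layers_1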