datasets refactor

- .gitignore +169 -0
- app.py +15 -13
- medusa_heads_medusa_TinyLlama-1.1B-Chat-v1.0/config.json +6 -0
- requirements.txt +0 -1
- src/calibration_datasets.py +603 -0
- src/medusa_training_script.py +269 -0
- medusa_training.py → src/train_workflow.py +17 -24
.gitignore
ADDED
@@ -0,0 +1,169 @@
+# Initially taken from Github's Python gitignore file
+
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# tests and logs
+tests/fixtures/cached_*_text.txt
+logs/
+lightning_logs/
+lang_code_data/
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+.python-version
+
+# celery beat schedule file
+celerybeat-schedule
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# vscode
+.vs
+.vscode
+
+# Pycharm
+.idea
+
+# TF code
+tensorflow_code
+
+# Models
+proc_data
+
+# examples
+runs
+/runs_old
+/wandb
+/examples/runs
+/examples/**/*.args
+/examples/rag/sweep
+
+# data
+/data
+serialization_dir
+
+# emacs
+*.*~
+debug.env
+
+# vim
+.*.swp
+
+#ctags
+tags
+
+# pre-commit
+.pre-commit*
+
+# .lock
+*.lock
+
+# DS_Store (MacOS)
+.DS_Store
+
+# ruff
+.ruff_cache
app.py
CHANGED
@@ -1,22 +1,20 @@
-
-
+"""
+Holds the gradio app itself
+"""
 
-
+import gradio as gr
 
-
-
-Repo.clone_from("https://github.com/FasterDecoding/Medusa.git", "medusa")
-print("Cloning the vicuna data locally...")
-Repo.clone_from("https://huggingface.co/datasets/Aeala/ShareGPT_Vicuna_unfiltered", "data")
-print("Done")
+from src.train_workflow import run, DEFAULT_TRAINING_ARGS
+from src.calibration_datasets import CalibrationDataset
 
 
 DESCRIPTION = """
 The steps to create [medusa](https://sites.google.com/view/medusa-llm) heads are the following:
 
 1. Input a public model id from the Hub
-2.
-3.
+2. Select a dataset to train the medusa heads on. The dataset should be representative of the downstream use case.
+3. Click "Submit"
+4. That's it! You'll get feedback if it works or not, and if it worked, you'll get the name of the new repo π₯
 """
 
 title="Create LLM medusa heads in a new repo π"

@@ -28,8 +26,12 @@ with gr.Blocks(title=title) as demo:
     with gr.Row() as r:
         with gr.Column() as c:
             model_id = gr.Text(max_lines=1, label="model_id")
+            dataset_names = [
+                cls.dataset for cls in CalibrationDataset.__subclasses__()
+            ]
+            dataset = gr.Dropdown(dataset_names, label="dataset")
            with gr.Accordion("Training arguments (advanced)", open=False):
-                training_args = gr.Textbox(DEFAULT_TRAINING_ARGS, interactive=True, lines=
+                training_args = gr.Textbox(DEFAULT_TRAINING_ARGS, interactive=True, lines=20, label="training_args")
     with gr.Row() as c:
         clean = gr.ClearButton()
         submit = gr.Button("Submit", variant="primary")

@@ -37,6 +39,6 @@ with gr.Blocks(title=title) as demo:
     with gr.Column() as d:
         status_box = gr.Markdown()
 
-    submit.click(run, inputs=[model_id, training_args], outputs=status_box, concurrency_limit=1)
+    submit.click(run, inputs=[model_id, training_args, dataset], outputs=status_box, concurrency_limit=1)
 
 demo.queue(max_size=10).launch(show_api=True)
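The dropdown above is populated purely by subclass registration: anything that subclasses CalibrationDataset in src/calibration_datasets.py appears in the UI with no extra wiring. A minimal sketch of that registry pattern (the printed list is illustrative):

    from src.calibration_datasets import CalibrationDataset

    # Each subclass registers itself just by being defined; its `dataset`
    # attribute doubles as the dropdown choice and the --dataset key.
    dataset_names = [cls.dataset for cls in CalibrationDataset.__subclasses__()]
    print(dataset_names)  # e.g. ['wikitext', 'c4', 'thai', ...]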
medusa_heads_medusa_TinyLlama-1.1B-Chat-v1.0/config.json
ADDED
@@ -0,0 +1,6 @@
+{
+    "base_model_name_or_path": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+    "medusa_num_heads": 3,
+    "medusa_num_layers": 1,
+    "transformers_version": "4.37.0.dev0"
+}
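This checked-in config matches what medusa_config.save_pretrained(...) writes in src/medusa_training_script.py. Assuming MedusaConfig follows the usual transformers PretrainedConfig API (an assumption, not shown in this commit), it can be read back like so:

    from medusa.model.medusa_model import MedusaConfig

    # Assumption: MedusaConfig behaves like transformers.PretrainedConfig,
    # so from_pretrained() parses the config.json in this folder.
    config = MedusaConfig.from_pretrained("medusa_heads_medusa_TinyLlama-1.1B-Chat-v1.0")
    print(config.medusa_num_heads, config.medusa_num_layers)  # 3 1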
requirements.txt
CHANGED
@@ -1,2 +1 @@
 medusa-llm[train]
-gitpython
src/calibration_datasets.py
ADDED
@@ -0,0 +1,603 @@
+"""Prepares the datasets for calibration. Original code gently shared by TheBloke"""
+
+from abc import ABC
+import time
+from typing import Dict, List, Optional
+from datasets import load_dataset, Dataset
+from transformers import PreTrainedTokenizerBase
+
+
+class CalibrationDataset(ABC):
+    tokenizer: Optional[PreTrainedTokenizerBase] = None
+    num_samples: int = 128
+    seqlen: int = 4096
+    dataset_config: dict
+    dataset: str
+    dataset_name: str
+    dataset_limit: int = int(1e7)
+
+    # Defines the field to extract from the HF dataset
+    # If specified, just this field will be returned, and no transformation will be done.
+    dataset_field: Optional[str] = None
+
+    # Define the default parameters for a dataset which requires a transformation
+    # Only used if dataset_field is None.
+    # The fields to extract from the original dataset
+    transform_fields: List[str] = []
+
+    # A format string describing how the fields should be joined
+    # Can use {field1}, {field2}, etc. as placeholders for the field names
+    # Or can use actual names, eg "{input} {output}"
+    transform_join: str = "{field1} {field2}"
+
+    # Optional override for the dataset URL
+    # By default this is automatically derived from the dataset name and config
+    dataset_url: Optional[str] = None
+
+    data: Optional[Dataset] = None
+    samples: List[str] = []
+    tokenized_samples: List[Dict[str, List[int]]] = []
+
+    randomize: bool = False
+    randomize_seed: int = 42
+
+    def __init__(
+        self,
+        num_samples: int = 128,
+        seqlen: int = 4096,
+        tokenizer: Optional[PreTrainedTokenizerBase] = None
+    ):
+        self.num_samples = num_samples
+        self.seqlen = seqlen
+        self.tokenizer = tokenizer
+
+    @classmethod
+    def get_dataset(cls, dataset_name, **kwargs):
+        for subclass in cls.__subclasses__():
+            if hasattr(subclass, "dataset") and subclass.dataset == dataset_name:
+                return subclass(**kwargs)
+
+        raise ValueError(f"No dataset class found for name: {dataset_name}")
+
+    def tokenize_dataset(self, samples: Optional[List[str]] = None) -> List[Dict[str, List[int]]]:
+        """
+        Tokenize the dataset and return a list of tokens of `seqlen` length
+
+        First tokenize the List[str] of samples, as a batch.
+
+        Then flatten the batch, and split it into `num_samples` rows of `seqlen` length.
+        """
+        if not self.tokenizer:
+            raise ValueError("No tokenizer provided to tokenize_dataset()")
+        else:
+            if not samples:
+                if not self.samples:
+                    self.get_samples()
+                samples = self.samples
+
+            print(f"Tokenizing {self.dataset_name} of length {len(samples)}")
+
+            start_time = time.time()
+            # Tokenize the list of samples. We don't use return_tensors="pt",
+            # as that requires the samples to be the same length, or padding to be used.
+            tokenized = self.tokenizer(samples)
+
+            # Output of tokenizer will be:
+            # {"input_ids": [[1,2,3], [4,5], [6,7]], "attention_mask": [[1,1,1], [1,1], [1,1]]}
+            # Flatten that so as to concatenate the samples into a single input_mask and attention_mask
+            flattened = {
+                key: [
+                    item for sublist in value
+                    for item in sublist
+                ]
+                for key, value in tokenized.items()
+            }
+            print(
+                f"Tokenized length: {len(flattened['input_ids'])} tokens."
+            )
+
+            # Slice our single input_mask list into num_samples samples of seqlen length
+            tokenized_samples = []
+            for i in range(0, self.num_samples * self.seqlen, self.seqlen):
+                if i + self.seqlen >= len(flattened["input_ids"]):
+                    break
+                sample = {
+                    "input_ids": flattened["input_ids"][i:i + self.seqlen],
+                    "attention_mask": flattened["attention_mask"][i:i + self.seqlen]
+                }
+                tokenized_samples.append(sample)
+
+            print(
+                f"Return {len(tokenized_samples)} samples of {self.seqlen} length. "
+                f"Time taken: {time.time() - start_time:.2f}s."
+            )
+            self.tokenized_samples = tokenized_samples
+            return self.tokenized_samples
+
+    def get_hf_dataset(
+        self,
+        path: str,
+        limit: Optional[int] = None,
+        **kwargs
+    ) -> Dataset:
+        """Load the Hugging Face dataset at `path`, using the provided kwargs."""
+
+        print(f"Loading HF dataset {path} with params: {kwargs}")
+        data: Dataset = load_dataset(path=path, **kwargs)
+
+        limit = limit and min(limit, len(data)) or len(data)
+        return data.select(range(limit))
+
+    @staticmethod
+    def list_with_nls(samples: List[str]) -> List[str]:
+        """
+        Return a List[str] with each sample ending in a newline.
+
+        Also filters the list by stripping, then removing any empty samples.
+        """
+        return [
+            x.rstrip() + '\n'
+            for x in samples
+            if x and len(x.strip()) > 0
+        ]
+
+    def get_samples(self) -> List[str]:
+        """
+        Return a list of samples for the dataset.
+
+        If the subclass implements `dataset_field`, this is used to filter the HF Dataset.
+
+        Otherwise, the subclass must implement `process_samples()`, for custom filtering.
+
+        Samples are returned as a List[str], each ending in a newline.
+        """
+        # Load HF dataset. Subclasses provide HF dataset details in `dataset_config`
+        if not self.data:
+            self.data = self.get_hf_dataset(**self.dataset_config, limit=self.dataset_limit)
+
+        if not self.samples:
+            if hasattr(self, "dataset_field") and self.dataset_field:
+                samples = self.data[self.dataset_field]
+            else:
+                try:
+                    samples = self.process_samples()
+                except NotImplementedError:
+                    raise ValueError(
+                        f"No dataset field specified for class {self.__class__}, "
+                        f"and process_samples() method not defined."
+                    )
+            if self.randomize:
+                import random
+                random.seed(self.randomize_seed)
+                random.shuffle(samples)
+            self.samples = self.list_with_nls(samples)
+        return self.samples
+
+    def process_samples(self) -> List[str]:
+        if not self.transform_fields or not isinstance(self.transform_fields, list):
+            raise ValueError("transform_fields must be a List[str], defined in the subclass")
+
+        if not self.transform_join or not isinstance(self.transform_join, str):
+            raise ValueError("transform_join must be a str, defined in the subclass")
+
+        def transform_sample(sample):
+            field_values = {field: sample[field] for field in self.transform_fields}
+            # We support both:
+            # generic numbered fields: "{field1} {field2}"
+            # and named fields: "{input} {output}"
+            # Creating a combined dictionary to handle both specific field names and generic placeholders
+            combined_dict = {**field_values, **{f'field{i+1}': field for i, field in enumerate(field_values.values())}}
+            output = self.transform_join.format_map(combined_dict)
+            return {"output": output}
+
+        return self.data.map(transform_sample)["output"]
+
+    def generate_checksum(self) -> str:
+        # Create a sha256sum checksum of the joined samples
+        # Can be used to confirm that code updates haven't changed the output
+        import hashlib
+        samples = self.get_samples()
+        combined_samples = ''.join(samples)
+        checksum = hashlib.sha256(combined_samples.encode()).hexdigest()
+        return checksum
+
+    @classmethod
+    def get_dataset_url(cls) -> str:
+        """Return the Hugging Face dataset URL for this dataset."""
+        if hasattr(cls, "dataset_url") and cls.dataset_url:
+            return cls.dataset_url
+        else:
+            return "https://huggingface.co/datasets/{}/viewer/{}".format(
+                cls.dataset_config["path"],
+                cls.dataset_config.get("name", "")
+            )
+
+
+class WikitextDataset(CalibrationDataset):
+    dataset = "wikitext"
+    dataset_config = {
+        "path": "wikitext",
+        "name": "wikitext-2-raw-v1",
+        "split": "train"
+    }
+    dataset_name = "Wikitext2 Full"
+
+    def process_samples(self) -> List[str]:
+        return [
+            "\n" if len(item) == 0 else item
+            for item in self.data["text"]
+        ]
+
+
+class C4Dataset(CalibrationDataset):
+    dataset = "c4"
+    dataset_field = "text"
+    dataset_config = {
+        "path": "allenai/c4",
+        "data_files": {
+            "train": "en/c4-train.00000-of-01024.json.gz"
+        },
+        "split": "train"
+    }
+    dataset_name = "C4"
+
+
+class ThaiDataset(CalibrationDataset):
+    dataset = "thai"
+    dataset_field = "text"
+    dataset_config = {
+        "path": "pbwt/all-thai",
+        "data_files": {
+            "train": "data/train-00000-of-00047-985fbaed08d034cf.parquet"
+        },
+        "split": "train"
+    }
+    dataset_name = "All Thai"
+
+
+class MovieScriptDataset(CalibrationDataset):
+    dataset = "movie-scripts"
+    dataset_field = "full_script"
+    dataset_config = {
+        "path": "jondurbin/cinematika-v0.1",
+        "data_files": { "train": "full_script.parquet" },
+        "split": "train"
+    }
+    dataset_name = "Cinematika Full Scripts"
+
+
+class JapaneseEnglishDataset(CalibrationDataset):
+    dataset = "japanese-english"
+    dataset_config = {
+        "path": "augmxnt/shisa-en-ja-dpo-v1",
+        "split": "train"
+    }
+    dataset_name = "Shisa English Japanese DPO"
+    randomize = True
+
+    def process_samples(self) -> List[str]:
+        def transform_samples(sample):
+            prompt = sample["prompt"]
+            chosen = sample["chosen"]
+            # prompt example: "[INST] <<SYS>>\nYou are a helpful, unbiased, uncensored assistant.\n<</SYS>>\n\nWhat are cardigans made of? Leather or wood? [/INST]"
+
+            try:
+                part1 = prompt.split('\n<</SYS>>\n\n')[1]
+                extracted_text = part1.split(' [/INST]')[0]
+            except Exception as e:
+                print(f"Error extracting text from prompt '{prompt}': {e}")
+                raise
+
+            prompt = extracted_text
+
+            return {"output": f"{prompt} {chosen}"}
+
+        return self.data.map(transform_samples)["output"]
+
+
+class PortugueseDataset(CalibrationDataset):
+    dataset = "portuguese"
+    dataset_config = {
+        "path": "adalbertojunior/portuguese_orca",
+        "split": "train"
+    }
+    dataset_name = "Portuguese Orca"
+    transform_fields = [ "question", "response" ]
+
+
+class MathsDataset(CalibrationDataset):
+    dataset = "maths"
+    dataset_config = {
+        "path": "andersonbcdefg/math",
+        "split": "train"
+    }
+    dataset_name = "CamelAI Math"
+    transform_fields = [ "message_1", "message_2" ]
+
+
+class MedicalDataset(CalibrationDataset):
+    dataset = "medical"
+    dataset_config = {
+        "path": "medalpaca/medical_meadow_wikidoc",
+        "split": "train"
+    }
+    dataset_name = "Medical Meadow WikiDoc"
+    transform_fields = [ "input", "output" ]
+
+
+class OpenInstructDataset(CalibrationDataset):
+    dataset = "open-instruct"
+    dataset_config = {
+        "path": "VMware/open-instruct",
+        "split": "train"
+    }
+    dataset_name = "VMware Open Instruct"
+    transform_fields = [ "instruction", "response" ]
+
+
+class KoreanDataset(CalibrationDataset):
+    dataset = "korean"
+    dataset_config = {
+        "path": "beomi/KoAlpaca-v1.1a",
+        "split": "train"
+    }
+    dataset_name = "Korean Alpaca"
+    transform_fields = [ "instruction", "output" ]
+
+
+class CodeDataset(CalibrationDataset):
+    dataset = "code"
+    dataset_field = "output"
+    dataset_config = {
+        "path": "nickrosh/Evol-Instruct-Code-80k-v1",
+        "split": "train"
+    }
+    dataset_name = "Evol Instruct Code"
+
+
+class MultiLanguageDataset(CalibrationDataset):
+    dataset = "multi-language"
+    dataset_field = "text"
+    dataset_config = {
+        "path": "papluca/language-identification",
+        "split": "train"
+    }
+    dataset_name = "Language Identification"
+
+
+class RussianDataset(CalibrationDataset):
+    dataset = "russian"
+    dataset_config = {
+        "path": "Den4ikAI/russian_instructions_2",
+        "split": "train"
+    }
+    dataset_name = "Russian Instructions 2"
+    transform_fields = [ "question", "answer" ]
+
+
+class DutchDataset(CalibrationDataset):
+    dataset = "dutch"
+    dataset_config = {
+        "path": "BramVanroy/dolly-15k-dutch",
+        "split": "train"
+    }
+    dataset_name = "Dolly 15K Dutch"
+    transform_fields = [ "instruction", "context", "response" ]
+    transform_join = "{field1} {field2} {field3}"
+
+
+class VietnameseChineseDataset(CalibrationDataset):
+    dataset = "vietnamesechinese"
+    dataset_config = {
+        "path": "nRuaif/Vietnamese_x_Alpaca",
+        "split": "train"
+    }
+    dataset_name = "Vietnamese and Chinese"
+
+    def get_dataset_url(self) -> None:
+        return None
+
+    def process_samples(self) -> List[str]:
+        samples = self.data["output"]
+        chinese_samples = CalibrationDataset.get_dataset("chinese").get_samples()
+
+        joined_list = samples + chinese_samples
+
+        import random
+        random.shuffle(joined_list)
+
+        return joined_list[:self.dataset_limit]
+
+
+class VietnameseDataset(CalibrationDataset):
+    dataset = "vietnamese"
+    dataset_field = "output"
+    dataset_config = {
+        "path": "nRuaif/Vietnamese_x_Alpaca",
+        "split": "train"
+    }
+    dataset_name = "Alpaca Vietnamese"
+
+
+class ChineseDataset(CalibrationDataset):
+    dataset = "chinese"
+    dataset_config = {
+        "path": "TigerResearch/tigerbot-alpaca-zh-0.5m",
+        "split": "train"
+    }
+    dataset_name = "Tiger Alpaca ZH"
+    transform_fields = [ "instruction", "input", "output" ]
+    transform_join = "{field1} {field2} {field3}"
+
+
+class LatinEnglishDataset(CalibrationDataset):
+    dataset = "latin-english"
+    dataset_config = {
+        "path": "grosenthal/latin_english_parallel",
+        "split": "train"
+    }
+    dataset_name = "Latin English Parallel"
+    transform_fields = [ "la", "en" ]
+    transform_join = "{field1}\n{field2}"
+
+
+class PolishDataset(CalibrationDataset):
+    dataset = "polish"
+    dataset_field = "content"
+    dataset_config = {
+        "path": "WiktorS/polish-news",
+        "split": "train"
+    }
+    dataset_name = "Polish News"
+
+
+class JapaneseDataset(CalibrationDataset):
+    dataset = "japanese"
+    dataset_field = "output"
+    dataset_config = {
+        "path": "fujiki/japanese_alpaca_data",
+        "split": "train"
+    }
+    dataset_name = "Alpaca Japanese"
+
+
+class SpanishDataset(CalibrationDataset):
+    dataset = "spanish"
+    dataset_field = "output"
+    dataset_config = {
+        "path": "bertin-project/alpaca-spanish",
+        "split": "train"
+    }
+    dataset_name = "Alpaca Spanish"
+
+
+class GermanDataset(CalibrationDataset):
+    dataset = "german"
+    dataset_config = {
+        "path": "deepset/germanquad",
+        "split": "train"
+    }
+    dataset_name = "German Quad"
+
+    def process_samples(self) -> List[str]:
+        def transform_samples(sample):
+            split_context = sample["context"].split("===")
+            if len(split_context) >= 3:
+                trans_context = split_context[2]
+            else:
+                trans_context = sample["context"]
+            return {"output": trans_context.strip()}
+
+        return self.data.map(transform_samples)["output"]
+
+
+class FrenchDataset(CalibrationDataset):
+    dataset = "french"
+    dataset_field = "text"
+    dataset_config = {
+        "path": "Kant1/French_Wikipedia_articles",
+        "data_files": { "train": "wiki_00.txt" },
+        "split": "train"
+    }
+    dataset_name = "French Wikipedia Articles"
+
+
+def validate_dataset(dataset_name: str, **kwargs):
+    for cls in CalibrationDataset.__subclasses__():
+        if hasattr(cls, "dataset") and cls.dataset == dataset_name:
+            return True
+    return False
+
+# FIXME: a temp function put in for AutoAWQ, pending full refactor where it won't be necessary
+def get_dataset_url(dataset_name: str):
+    for cls in CalibrationDataset.__subclasses__():
+        if hasattr(cls, "dataset") and cls.dataset == dataset_name:
+            return cls.get_dataset_url()
+    raise ValueError(f"No dataset class found for name: {dataset_name}")
+
+def get_dataset_name(dataset_name: str):
+    for cls in CalibrationDataset.__subclasses__():
+        if hasattr(cls, "dataset") and cls.dataset == dataset_name:
+            return cls.dataset_name
+    raise ValueError(f"No dataset class found for name: {dataset_name}")
+
+def test_datasets(datasets: Optional[List[str]] = None, checksum_only=False):
+    import sys
+    from transformers import AutoTokenizer
+    try:
+        failed = []
+        for cls in CalibrationDataset.__subclasses__():
+            if not hasattr(cls, "dataset") or not cls.dataset:
+                failed.append(cls.__name__)
+        if failed:
+            print(f"The following classes have no 'dataset' attribute: {failed}")
+            sys.exit(-1)
+        else:
+            print("All classes have 'dataset' attribute.")
+
+        print("Enumerating CalibrationDataset classes")
+        classes = CalibrationDataset.__subclasses__()
+        dataset_names = [
+            cls.dataset
+            for cls in classes
+            if cls.dataset and (not datasets or cls.dataset in datasets)
+        ]
+
+        print(f"Found {len(classes)} total dataset classes: {[c.dataset for c in classes]}")
+        if datasets:
+            print(f"Will test {len(dataset_names)} datasets: {dataset_names}")
+
+        print("Starting test: loading Llama-2 tokenizer")
+        tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf", use_fast=True)
+
+        for name in dataset_names:
+            print(f"{name} test: loading dataset.")
+            dataset = CalibrationDataset.get_dataset(name, tokenizer=tokenizer)
+            if not checksum_only:
+                print(f"{name} test: running tokenize_dataset.")
+                toks = dataset.tokenize_dataset()
+                print(f"{name} test: getting dataset_url.")
+                url = dataset.get_dataset_url()
+                print(f"{name} - randomized? {dataset.randomize}")
+                print(
+                    f"{name} - result: cls.data: length: {len(dataset.data)}, "
+                    f"first row length: {len(dataset.data[0])}, "
+                    f"first row data: '{dataset.data[0]}'."
+                )
+                print(
+                    f"{name} - result: cls.samples: length: {len(dataset.samples)}, "
+                    f"first row length: {len(dataset.samples[0])}, "
+                    f"first row sample: '{dataset.samples[0]}'."
+                )
+                print(
+                    f"{name} - result: tokenize_dataset result: length: {len(toks)}, "
+                    f"length first row input_ids: {len(toks[0]['input_ids'])}."
+                )
+                print(
+                    f"{name} - result: dataset_url: {url}"
+                )
+            checksum = dataset.generate_checksum()
+            print(
+                f"{name} - result: sha256 checksum: {checksum}"
+            )
+
+    except KeyboardInterrupt:
+        print("Test aborted")
+
+    except Exception as e:
+        print(
+            f"Received an exception during test. Test failed. "
+            f"Exception: {e}"
+        )
+        raise
+
+
+if __name__ == "__main__":
+    import argparse
+
+    parser = argparse.ArgumentParser(description="Test calibration datasets")
+    parser.add_argument("--datasets", "-d", "-n", nargs="*", type=str, help="Dataset(s) to check; default is all")
+    parser.add_argument("--checksum_only", "-co", action="store_true", help="Only output the checksums for the datasets")
+    args = parser.parse_args()
+
+    test_datasets(args.datasets, checksum_only=args.checksum_only)
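Adding a new calibration set only takes a subclass with a `dataset` key plus either a `dataset_field` or `transform_fields`; get_dataset() then finds it by name. A minimal sketch (the tatsu-lab/alpaca path and its field names are illustrative assumptions, not part of this commit):

    from transformers import AutoTokenizer
    from src.calibration_datasets import CalibrationDataset

    class AlpacaDataset(CalibrationDataset):
        dataset = "alpaca"  # key used by get_dataset() and the app dropdown
        dataset_config = {"path": "tatsu-lab/alpaca", "split": "train"}
        dataset_name = "Alpaca"
        # Joined with the default transform_join, "{field1} {field2}"
        transform_fields = ["instruction", "output"]

    tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
    ds = CalibrationDataset.get_dataset("alpaca", num_samples=8, seqlen=512, tokenizer=tokenizer)
    samples = ds.tokenize_dataset()  # up to 8 dicts of 512-token input_ids/attention_mask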
src/medusa_training_script.py
ADDED
@@ -0,0 +1,269 @@
+"""
+Holds the training script for the medusa model.
+
+Adapted from the original code here: https://github.com/FasterDecoding/Medusa/blob/main/medusa/train/train.py
+"""
+
+import os
+from dataclasses import dataclass, field
+import pathlib
+from typing import Dict, Optional
+
+import torch
+from torch.utils.data import Dataset
+import transformers
+from transformers import Trainer, BitsAndBytesConfig
+from transformers.trainer_pt_utils import LabelSmoother
+from torch.nn import CrossEntropyLoss
+from medusa.model.medusa_model import MedusaModel, MedusaConfig
+
+from calibration_datasets import CalibrationDataset
+
+
+IGNORE_TOKEN_ID = LabelSmoother.ignore_index
+
+
+# Customized for training Medusa heads
+class CustomizedTrainer(Trainer):
+    def compute_loss(self, model, inputs, return_outputs=False):
+        """
+        Compute the training loss for the model.
+
+        Args:
+            model (torch.nn.Module): The model for which to compute the loss.
+            inputs (dict): The input data, including input IDs, attention mask, and labels.
+            return_outputs (bool): Whether to return model outputs along with the loss.
+
+        Returns:
+            Union[float, Tuple[float, torch.Tensor]]: The computed loss, optionally with model outputs.
+        """
+        # DDP will give us model.module
+        if hasattr(model, "module"):
+            medusa = model.module.medusa
+        else:
+            medusa = model.medusa
+
+        logits = model(
+            input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"]
+        )
+        labels = inputs["labels"]
+        # Shift so that tokens < n predict n
+        loss = 0
+        loss_fct = CrossEntropyLoss()
+        log = {}
+        for i in range(medusa):
+            medusa_logits = logits[i, :, : -(2 + i)].contiguous()
+            medusa_labels = labels[..., 2 + i :].contiguous()
+            medusa_logits = medusa_logits.view(-1, logits.shape[-1])
+            medusa_labels = medusa_labels.view(-1)
+            medusa_labels = medusa_labels.to(medusa_logits.device)
+            loss_i = loss_fct(medusa_logits, medusa_labels)
+            loss += loss_i
+            not_ignore = medusa_labels.ne(IGNORE_TOKEN_ID)
+            medusa_labels = medusa_labels[not_ignore]
+
+            # Add top-k accuracy
+            for k in range(1, 6):
+                _, topk = medusa_logits.topk(k, dim=-1)
+                topk = topk[not_ignore]
+                correct = topk.eq(medusa_labels.unsqueeze(-1)).any(-1)
+                log[f"medusa{i}_top{k}"] = correct.float().mean().item()
+
+            log[f"medusa{i}_loss"] = loss_i.item()
+        self.log(log)
+        return (loss, logits) if return_outputs else loss
+
+
+@dataclass
+class ModelArguments:
+    model_name_or_path: Optional[str] = field()
+    load_in_4bit: bool = field(
+        default=False,
+        metadata={"help": "Load in 4 bit."},
+    )
+    load_in_8bit: bool = field(
+        default=False,
+        metadata={"help": "Load in 8 bit."},
+    )
+
+
+@dataclass
+class DataArguments:
+    dataset: str = field(
+        metadata={"help": "One of the datasets names in a CalibrationDataset subclass."},
+    )
+
+
+@dataclass
+class TrainingArguments(transformers.TrainingArguments):
+    cache_dir: Optional[str] = field(default=None)
+    optim: str = field(default="adamw_torch")
+    model_max_length: int = field(
+        default=2048,
+        metadata={
+            "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
+        },
+    )
+    medusa_num_heads: int = field(
+        default=1,
+        metadata={"help": "Number of Medusa heads."},
+    )
+    medusa_num_layers: int = field(
+        default=1,
+        metadata={"help": "Number of layers for each Medusa head."},
+    )
+
+
+local_rank = None
+
+
+def rank0_print(*args):
+    if local_rank == 0:
+        print(*args)
+
+
+def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
+    """
+    Save the model's state dictionary to a specified directory.
+
+    Args:
+        trainer (transformers.Trainer): The Hugging Face Trainer object.
+        output_dir (str): The directory where the model state dictionary will be saved.
+    """
+    state_dict = trainer.model.state_dict()
+    if trainer.args.should_save:
+        cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
+        del state_dict
+        trainer._save(output_dir, state_dict=cpu_state_dict)  # noqa
+
+
+class SupervisedDataset(Dataset):
+    """Dataset for supervised fine-tuning.
+
+    Args:
+        dataset (str): One of the datasets names in a CalibrationDataset subclass.
+        tokenizer (transformers.PreTrainedTokenizer): The tokenizer to use for data preprocessing.
+    """
+
+    def __init__(self, dataset, tokenizer: transformers.PreTrainedTokenizer):
+        super(SupervisedDataset, self).__init__()
+
+        rank0_print("Formatting inputs...")
+        dataset_classes = CalibrationDataset.__subclasses__()
+        for dataset_class in dataset_classes:
+            if dataset_class.dataset == dataset:
+                dataset = dataset_class(num_samples=int(1e6), seqlen=tokenizer.model_max_length, tokenizer=tokenizer)
+                break
+        tokenized = dataset.tokenize_dataset()
+        self.input_ids = torch.tensor([data["input_ids"] for data in tokenized], dtype=torch.long)
+        self.attention_mask = torch.tensor([data["attention_mask"] for data in tokenized], dtype=torch.long)
+
+    def __len__(self):
+        return self.input_ids.shape[0]
+
+    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
+        return dict(
+            input_ids=self.input_ids[i],
+            labels=self.input_ids[i],
+            attention_mask=self.attention_mask[i],
+        )
+
+
+def train():
+    global local_rank
+
+    parser = transformers.HfArgumentParser(
+        (ModelArguments, DataArguments, TrainingArguments)
+    )
+    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+    local_rank = training_args.local_rank
+
+    config = transformers.AutoConfig.from_pretrained(
+        model_args.model_name_or_path,
+        cache_dir=training_args.cache_dir,
+    )
+    config.use_cache = False
+
+    quantization_config = BitsAndBytesConfig(
+        load_in_4bit=True,
+        bnb_4bit_compute_dtype=torch.bfloat16,
+        bnb_4bit_use_double_quant=True,
+        bnb_4bit_quant_type="nf4",
+    )
+
+    # Load model and tokenizer
+    model = transformers.AutoModelForCausalLM.from_pretrained(
+        model_args.model_name_or_path,
+        config=config,
+        cache_dir=training_args.cache_dir,
+        low_cpu_mem_usage=True,
+        torch_dtype=torch.bfloat16,
+        quantization_config=quantization_config if model_args.load_in_4bit else None,
+        load_in_4bit=model_args.load_in_4bit,
+        load_in_8bit=model_args.load_in_8bit,
+    )
+
+    # Freeze the base model
+    for param in model.base_model.parameters():
+        param.requires_grad = False
+
+    # Add Medusa heads
+    medusa_lm_head = MedusaModel(
+        model,
+        medusa_num_heads=training_args.medusa_num_heads,
+        medusa_num_layers=training_args.medusa_num_layers,
+        base_model_name_or_path=model_args.model_name_or_path,
+    )
+
+    # Format output dir
+    training_args.output_dir = f"{training_args.output_dir}_medusa_{model_args.model_name_or_path.split('/')[-1]}"
+
+    tokenizer = transformers.AutoTokenizer.from_pretrained(
+        model_args.model_name_or_path,
+        cache_dir=training_args.cache_dir,
+        model_max_length=training_args.model_max_length,
+        padding_side="right",
+        use_fast=False,
+    )
+    tokenizer.pad_token = tokenizer.unk_token
+
+    # Load data
+    data_module = {"train_dataset": SupervisedDataset(data_args.dataset, tokenizer), "eval_dataset": None}
+
+    # Generate Medusa config for pushing to HF hub
+    medusa_config = MedusaConfig(
+        medusa_num_heads=training_args.medusa_num_heads,
+        medusa_num_layers=training_args.medusa_num_layers,
+        base_model_name_or_path=model_args.model_name_or_path,
+    )
+
+    # Save Medusa config
+    medusa_config.save_pretrained(training_args.output_dir)
+
+    # Start trainer
+    trainer = CustomizedTrainer(
+        model=medusa_lm_head, tokenizer=tokenizer, args=training_args, **data_module
+    )
+
+    if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
+        trainer.train(resume_from_checkpoint=True)
+    else:
+        trainer.train()
+    model.config.use_cache = True
+
+    # Save MedusaHead separately
+    if hasattr(medusa_lm_head, "module"):
+        lm_head = medusa_lm_head.module.medusa_head
+    else:
+        lm_head = medusa_lm_head.medusa_head
+
+    # Save Medusa heads
+    torch.save(
+        lm_head.state_dict(),
+        os.path.join(training_args.output_dir, "medusa_lm_head.pt"),
+    )
+
+
+if __name__ == "__main__":
+    train()
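The slicing in compute_loss is the subtle part: head i at position t is scored against the token at position t + 2 + i, i.e. the usual next-token shift plus one extra step per head. A standalone sketch of that alignment with a toy label sequence (not Medusa code, just the same indexing):

    import torch

    labels = torch.tensor([10, 11, 12, 13, 14, 15])  # toy token ids
    seq_len = labels.shape[0]
    for i in range(3):  # 3 medusa heads
        # Mirrors `logits[i, :, : -(2 + i)]`: drop the last (2 + i) positions...
        positions = range(seq_len - (2 + i))
        # ...and `labels[..., 2 + i:]`: position t targets the token (2 + i) ahead.
        targets = labels[2 + i:]
        for t, tgt in zip(positions, targets.tolist()):
            print(f"head {i}: position {t} predicts token {tgt}")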
medusa_training.py → src/train_workflow.py
RENAMED
@@ -1,4 +1,6 @@
-
+"""
+Holds the interface between the gradio app and the medusa training script
+"""
 import os
 import multiprocessing as mp
 

@@ -9,26 +11,23 @@
 import torch.distributed.run as distributed_run
 
 OUTPUT_DIR = "medusa_heads"
-MEDUSA_NUM_HEADS = 3
-MEDUSA_NUM_LAYERS = 1
-LR = 1e-3
 
 DATASET = "vicuna"
 
 # These can't be changed (e.g. they control the output path)
 FIXED_TRAINING_ARGS = \
-"""
+"""src/medusa_training_script.py
 --model_name_or_path {model_id}
 --output_dir {output_dir}
 --run_name {model_id}-medusa-{dataset}
---
---medusa_num_layers {medusa_num_layers}
---learning_rate {lr}
---data_path data/ShareGPT_V4.3_unfiltered_cleaned_split.json"""
+--dataset {dataset}"""
 
 # These can be freely changed
 DEFAULT_TRAINING_ARGS = \
-"""--
+"""--medusa_num_heads 3
+--medusa_num_layers 1
+--model_max_length 2048
+--bf16 True
 --num_train_epochs 1
 --per_device_train_batch_size 64
 --per_device_eval_batch_size 64

@@ -40,19 +39,13 @@ DEFAULT_TRAINING_ARGS = \
 --lr_scheduler_type cosine
 --logging_steps 10
 --tf32 True
---
---
---auto_find_batch_size True"""
+--auto_find_batch_size True
+--learning_rate 1e-3"""
 
 
-def train_medusa_heads(model_id: str, training_args: str):
+def train_medusa_heads(model_id: str, training_args: str, dataset: str):
     all_training_args = FIXED_TRAINING_ARGS.format(
-        model_id=model_id,
-        output_dir=OUTPUT_DIR,
-        dataset=DATASET,
-        medusa_num_heads=MEDUSA_NUM_HEADS,
-        lr=LR,
-        medusa_num_layers=MEDUSA_NUM_LAYERS
+        model_id=model_id, output_dir=OUTPUT_DIR, dataset=dataset,
     ) + "\n" + training_args
     all_training_arg_list = []
     for arg in all_training_args.split("\n"):

@@ -64,11 +57,11 @@ def train_medusa_heads(model_id: str, training_args: str):
     distributed_run.run(args)
 
 
-def run(model_id: str, training_args: str) -> str:
+def run(model_id: str, training_args: str, dataset: str) -> str:
     print(f"\n\n\nNEW RUN: {model_id}")
     api = HfApi()
     model_name = model_id.split("/")[-1]
-    repo_id = f"joaogante/{model_name}-medusa-{
+    repo_id = f"joaogante/{model_name}-medusa-{dataset}"
 
     # Input validation
     if model_id == "":

@@ -101,7 +94,7 @@ def run(model_id: str, training_args: str) -> str:
 
     # Run the medusa heads creation
     try:
-        proc = mp.Process(target=train_medusa_heads, args=(model_id, training_args))
+        proc = mp.Process(target=train_medusa_heads, args=(model_id, training_args, dataset))
         proc.start()
         proc.join()
         print("Medusa heads training process completed (it might have crashed!)")

@@ -117,7 +110,7 @@ def run(model_id: str, training_args: str) -> str:
     try:
         # Folder path from https://github.com/FasterDecoding/Medusa/blob/main/medusa/train/train.py#L399
        folder_path = (
-            f"{OUTPUT_DIR}
+            f"{OUTPUT_DIR}_medusa_{model_name}"
        )
        if not any([x for x in os.listdir(folder_path) if len(x) >= 3 and x[-3:] == ".pt"]):
            raise Exception(
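Taken together, run() is the only entry point the app needs. A hedged sketch of invoking it directly, equivalent to pressing Submit in the gradio UI (the model id and dataset key are just examples):

    from src.train_workflow import run, DEFAULT_TRAINING_ARGS

    # Spawns the training subprocess, validates the inputs, and returns a
    # status string naming the new heads repo on success.
    status = run("TinyLlama/TinyLlama-1.1B-Chat-v1.0", DEFAULT_TRAINING_ARGS, "c4")
    print(status)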