zetavg committed · Commit 40a8f4e · 1 parent: d2eef14

extract configs from global

Files changed:
- LLaMA_LoRA.ipynb +9 -7
- app.py +48 -45
- llama_lora/config.py +47 -0
- llama_lora/globals.py +13 -26
- llama_lora/models.py +13 -12
- llama_lora/ui/finetune_ui.py +10 -9
- llama_lora/ui/inference_ui.py +3 -2
- llama_lora/ui/main_page.py +10 -9
- llama_lora/ui/tokenizer_ui.py +3 -2
- llama_lora/utils/data.py +10 -8
- llama_lora/utils/prompter.py +2 -1
LLaMA_LoRA.ipynb
CHANGED
@@ -279,21 +279,23 @@
   {
     "cell_type": "code",
     "source": [
-      "# @title Load the App (set config, prepare data dir, load base
+      "# @title Load the App (set config, prepare data dir, load base model)\n",
       "\n",
       "# @markdown For a LLaMA-7B model, it will take about ~5m to load for the first execution,\n",
       "# @markdown including download. Subsequent executions will take about 2m to load.\n",
       "\n",
       "# Set Configs\n",
-      "from llama_lora.llama_lora.
-      "
-      "
+      "from llama_lora.llama_lora.config import Config, process_config\n",
+      "from llama_lora.llama_lora.globals import initialize_global\n",
+      "Config.default_base_model_name = base_model\n",
+      "Config.base_model_choices = [base_model]\n",
       "data_dir_realpath = !realpath ./data\n",
-      "
-      "
+      "Config.data_dir = data_dir_realpath[0]\n",
+      "Config.load_8bit = True\n",
+      "process_config()\n",
+      "initialize_global()\n",
       "\n",
       "# Prepare Data Dir\n",
-      "import os\n",
       "from llama_lora.llama_lora.utils.data import init_data_dir\n",
       "init_data_dir()\n",
       "\n",
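Decoded from the JSON cell source above, the updated Colab cell boils down to roughly the following Python (a sketch: `base_model` is assumed to be set by an earlier notebook cell, and the `!realpath` line is IPython shell syntax, so this only runs inside a notebook):

```python
# Set Configs
from llama_lora.llama_lora.config import Config, process_config
from llama_lora.llama_lora.globals import initialize_global

Config.default_base_model_name = base_model  # `base_model` comes from an earlier cell
Config.base_model_choices = [base_model]
data_dir_realpath = !realpath ./data         # IPython shell syntax (Colab/Jupyter only)
Config.data_dir = data_dir_realpath[0]
Config.load_8bit = True
process_config()
initialize_global()

# Prepare Data Dir
from llama_lora.llama_lora.utils.data import init_data_dir
init_data_dir()
```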
app.py
CHANGED
@@ -1,30 +1,30 @@
-import
-import sys
+from typing import Union
 
 import fire
 import gradio as gr
 
-from llama_lora.
+from llama_lora.config import Config, process_config
+from llama_lora.globals import initialize_global
 from llama_lora.models import prepare_base_model
-from llama_lora.ui.main_page import main_page, get_page_title, main_page_custom_css
 from llama_lora.utils.data import init_data_dir
+from llama_lora.ui.main_page import (
+    main_page, get_page_title, main_page_custom_css
+)
 
 
 def main(
-    base_model: str =
-    data_dir: str =
-    base_model_choices: str =
-    trust_remote_code: bool =
-    # Allows to listen on all interfaces by providing '0.0.0.0'.
+    base_model: Union[str, None] = None,
+    data_dir: Union[str, None] = None,
+    base_model_choices: Union[str, None] = None,
+    trust_remote_code: Union[bool, None] = None,
     server_name: str = "127.0.0.1",
     share: bool = False,
     skip_loading_base_model: bool = False,
-    load_8bit: bool =
-    ui_show_sys_info: bool =
-    ui_dev_mode: bool =
-    wandb_api_key: str =
-    wandb_project: str =
+    load_8bit: Union[bool, None] = None,
+    ui_show_sys_info: Union[bool, None] = None,
+    ui_dev_mode: Union[bool, None] = None,
+    wandb_api_key: Union[str, None] = None,
+    wandb_project: Union[str, None] = None,
 ):
     '''
     Start the LLaMA-LoRA Tuner UI.
@@ -41,51 +41,54 @@ def main(
     :param wandb_project: The default project name for Weights & Biases. Setting either this or `wandb_api_key` will enable Weights & Biases.
     '''
 
-    assert (
-        base_model
-    ), "Please specify a --base_model, e.g. --base_model='decapoda-research/llama-7b-hf'"
-
-    if
-        base_model_choices = [name.strip() for name in base_model_choices]
-        Global.base_model_choices = base_model_choices
-
-    if
-
-    if
-        Global.wandb_api_key = wandb_api_key
-    if len(wandb_project) > 0:
-        Global.enable_wandb = True
-        Global.wandb_project = wandb_project
-
-    os.makedirs(data_dir, exist_ok=True)
+    if base_model is not None:
+        Config.default_base_model_name = base_model
+
+    if base_model_choices is not None:
+        Config.base_model_choices = base_model_choices
+
+    if trust_remote_code is not None:
+        Config.trust_remote_code = trust_remote_code
+
+    if data_dir is not None:
+        Config.data_dir = data_dir
+
+    if load_8bit is not None:
+        Config.load_8bit = load_8bit
+
+    if wandb_api_key is not None:
+        Config.wandb_api_key = wandb_api_key
+
+    if wandb_project is not None:
+        Config.default_wandb_project = wandb_project
+
+    if ui_dev_mode is not None:
+        Config.ui_dev_mode = ui_dev_mode
+
+    if ui_show_sys_info is not None:
+        Config.ui_show_sys_info = ui_show_sys_info
+
+    process_config()
+    initialize_global()
+
+    assert (
+        Config.default_base_model_name
+    ), "Please specify a --base_model, e.g. --base_model='decapoda-research/llama-7b-hf'"
+
+    assert (
+        Config.data_dir
+    ), "Please specify a --data_dir, e.g. --data_dir='./data'"
 
     init_data_dir()
 
-    if (not skip_loading_base_model) and (not ui_dev_mode):
-        prepare_base_model(
+    if (not skip_loading_base_model) and (not Config.ui_dev_mode):
+        prepare_base_model(Config.default_base_model_name)
 
     with gr.Blocks(title=get_page_title(), css=main_page_custom_css()) as demo:
         main_page()
 
-    demo.queue(concurrency_count=1).launch(
+    demo.queue(concurrency_count=1).launch(
+        server_name=server_name, share=share)
 
 
 if __name__ == "__main__":
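Since `main()` is exposed through `fire` (presumably via `fire.Fire(main)` under the `if __name__ == "__main__":` guard), each keyword argument doubles as a command-line flag, and anything left as `None` now falls back to whatever is already set on `Config`. A minimal sketch of driving the same code path programmatically; the model name and paths are example values taken from the assert messages, not requirements:

```python
# Roughly equivalent to:
#   python app.py --base_model='decapoda-research/llama-7b-hf' --data_dir='./data' --load_8bit=True
from app import main

main(
    base_model="decapoda-research/llama-7b-hf",  # example value from the assert message
    data_dir="./data",
    load_8bit=True,
    skip_loading_base_model=True,  # skip the slow base-model download for a quick smoke test
)
```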
llama_lora/config.py
ADDED
@@ -0,0 +1,47 @@
+import os
+from typing import List, Union
+
+
+class Config:
+    """
+    Stores the application configuration. This is a singleton class.
+    """
+
+    data_dir: str = ""
+    load_8bit: bool = False
+
+    default_base_model_name: str = ""
+    base_model_choices: Union[List[str], str] = []
+
+    trust_remote_code: bool = False
+
+    # WandB
+    enable_wandb: Union[bool, None] = False
+    wandb_api_key: Union[str, None] = None
+    default_wandb_project: str = "llama-lora-tuner"
+
+    # UI related
+    ui_title: str = "LLaMA-LoRA Tuner"
+    ui_emoji: str = "🦙🎛️"
+    ui_subtitle: str = "Toolkit for evaluating and fine-tuning LLaMA models with low-rank adaptation (LoRA)."
+    ui_show_sys_info: bool = True
+    ui_dev_mode: bool = False
+    ui_dev_mode_title_prefix: str = "[UI DEV MODE] "
+
+
+def process_config():
+    Config.data_dir = os.path.abspath(Config.data_dir)
+
+    if isinstance(Config.base_model_choices, str):
+        base_model_choices = Config.base_model_choices.split(',')
+        base_model_choices = [name.strip() for name in base_model_choices]
+        Config.base_model_choices = base_model_choices
+
+    if Config.default_base_model_name not in Config.base_model_choices:
+        Config.base_model_choices = [Config.default_base_model_name] + Config.base_model_choices
+
+    if Config.enable_wandb is None:
+        if Config.wandb_api_key and len(Config.wandb_api_key) > 0:
+            Config.enable_wandb = True
+        if Config.default_wandb_project and len(Config.default_wandb_project) > 0:
+            Config.enable_wandb = True
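As a quick illustration of what the new `process_config()` does: it makes `data_dir` absolute, splits a comma-separated `base_model_choices` string into a list, prepends the default model if it is missing from the choices, and auto-enables Weights & Biases only when `enable_wandb` is explicitly `None`. A hedged sketch with hypothetical values:

```python
from llama_lora.config import Config, process_config

Config.default_base_model_name = "decapoda-research/llama-7b-hf"
Config.base_model_choices = "huggyllama/llama-7b, huggyllama/llama-13b"  # hypothetical names, comma-separated
Config.data_dir = "./data"
Config.enable_wandb = None          # None means "decide from the other wandb settings"
Config.wandb_api_key = "xxxxxxxx"   # placeholder key

process_config()

# Config.data_dir            -> absolute path of ./data
# Config.base_model_choices  -> ['decapoda-research/llama-7b-hf',
#                                'huggyllama/llama-7b', 'huggyllama/llama-13b']
# Config.enable_wandb        -> True (a wandb_api_key was provided)
```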
llama_lora/globals.py
CHANGED
@@ -8,23 +8,21 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 from numba import cuda
 import nvidia_smi
 
+from .config import Config
 from .utils.lru_cache import LRUCache
 from .utils.model_lru_cache import ModelLRUCache
 from .lib.finetune import train
 
 
 class Global:
+    """
+    A singleton class holding global states.
+    """
 
-    load_8bit: bool = False
+    version: Union[str, None] = None
 
-    default_base_model_name: str = ""
     base_model_name: str = ""
     tokenizer_name = None
-    base_model_choices: List[str] = []
-
-    trust_remote_code = False
 
     # Functions
     train_fn: Any = train
@@ -48,18 +46,15 @@ class Global:
     gpu_total_cores = None  # GPU total cores
     gpu_total_memory = None
 
-    # WandB
-    enable_wandb = False
-    wandb_api_key = None
-    default_wandb_project = "llama-lora-tuner"
 
+def initialize_global():
+    Global.base_model_name = Config.default_base_model_name
+    commit_hash = get_git_commit_hash()
+
+    if commit_hash:
+        Global.version = commit_hash[:8]
+
+    load_gpu_info()
 
 
 def get_package_dir():
@@ -85,12 +80,6 @@ def get_git_commit_hash():
         print(f"Cannot get git commit hash: {e}")
 
 
-commit_hash = get_git_commit_hash()
-
-if commit_hash:
-    Global.version = commit_hash[:8]
-
-
 def load_gpu_info():
     print("")
     try:
@@ -154,5 +143,3 @@ def load_gpu_info():
         print(f"Notice: cannot get GPU info: {e}")
 
     print("")
-
-load_gpu_info()
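The module no longer resolves the commit hash or calls `load_gpu_info()` at import time; that work now lives in `initialize_global()`, which also copies the resolved config into runtime state and therefore has to run after `process_config()`. A minimal sketch of the intended call order (the model name is an example value):

```python
from llama_lora.config import Config, process_config
from llama_lora.globals import Global, initialize_global

Config.default_base_model_name = "decapoda-research/llama-7b-hf"  # example value
Config.data_dir = "./data"

process_config()      # normalize Config first
initialize_global()   # then seed Global.base_model_name, read the git hash, probe the GPU

print(Global.base_model_name)  # -> "decapoda-research/llama-7b-hf"
```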
llama_lora/models.py
CHANGED
@@ -11,12 +11,13 @@ from transformers import (
 )
 from peft import PeftModel
 
+from .config import Config
 from .globals import Global
 from .lib.get_device import get_device
 
 
 def get_new_base_model(base_model_name):
-    if
+    if Config.ui_dev_mode:
         return
 
     if Global.new_base_model_that_is_ready_to_be_used:
@@ -79,14 +80,14 @@ def _get_model_from_pretrained(model_class, model_name, from_tf=False, force_dow
     if device == "cuda":
         return model_class.from_pretrained(
             model_name,
-            load_in_8bit=
+            load_in_8bit=Config.load_8bit,
             torch_dtype=torch.float16,
             # device_map="auto",
             # ? https://github.com/tloen/alpaca-lora/issues/21
             device_map={'': 0},
             from_tf=from_tf,
             force_download=force_download,
-            trust_remote_code=
+            trust_remote_code=Config.trust_remote_code
         )
     elif device == "mps":
         return model_class.from_pretrained(
@@ -95,7 +96,7 @@ def _get_model_from_pretrained(model_class, model_name, from_tf=False, force_dow
             torch_dtype=torch.float16,
             from_tf=from_tf,
             force_download=force_download,
-            trust_remote_code=
+            trust_remote_code=Config.trust_remote_code
         )
     else:
         return model_class.from_pretrained(
@@ -104,12 +105,12 @@ def _get_model_from_pretrained(model_class, model_name, from_tf=False, force_dow
             low_cpu_mem_usage=True,
             from_tf=from_tf,
             force_download=force_download,
-            trust_remote_code=
+            trust_remote_code=Config.trust_remote_code
         )
 
 
 def get_tokenizer(base_model_name):
-    if
+    if Config.ui_dev_mode:
         return
 
     loaded_tokenizer = Global.loaded_tokenizers.get(base_model_name)
@@ -119,13 +120,13 @@ def get_tokenizer(base_model_name):
     try:
         tokenizer = AutoTokenizer.from_pretrained(
             base_model_name,
-            trust_remote_code=
+            trust_remote_code=Config.trust_remote_code
         )
     except Exception as e:
         if 'LLaMATokenizer' in str(e):
             tokenizer = LlamaTokenizer.from_pretrained(
                 base_model_name,
-                trust_remote_code=
+                trust_remote_code=Config.trust_remote_code
             )
         else:
             raise e
@@ -138,7 +139,7 @@ def get_tokenizer(base_model_name):
 def get_model(
         base_model_name,
         peft_model_name=None):
-    if
+    if Config.ui_dev_mode:
         return
 
     if peft_model_name == "None":
@@ -156,7 +157,7 @@ def get_model(
     if peft_model_name:
         lora_models_directory_path = os.path.join(
+            Config.data_dir, "lora_models")
         possible_lora_model_path = os.path.join(
             lora_models_directory_path, peft_model_name)
         if os.path.isdir(possible_lora_model_path):
@@ -211,7 +212,7 @@ def get_model(
     model.config.bos_token_id = 1
     model.config.eos_token_id = 2
 
-    if not
+    if not Config.load_8bit:
         model.half()  # seems to fix bugs for some users.
 
     model.eval()
@@ -224,7 +225,7 @@ def get_model(
     return model
 
 
-def prepare_base_model(base_model_name=
+def prepare_base_model(base_model_name=Config.default_base_model_name):
     Global.new_base_model_that_is_ready_to_be_used = get_new_base_model(
         base_model_name)
     Global.name_of_new_base_model_that_is_ready_to_be_used = base_model_name
llama_lora/ui/finetune_ui.py
CHANGED
@@ -11,6 +11,7 @@ from random_word import RandomWords
 from transformers import TrainerCallback
 from huggingface_hub import try_to_load_from_cache, snapshot_download
 
+from ..config import Config
 from ..globals import Global
 from ..models import (
     get_new_base_model, get_tokenizer,
@@ -240,9 +241,9 @@ def refresh_dataset_items_count(
 
     trace = traceback.format_exc()
     traces = [s.strip() for s in re.split("\n * File ", trace)]
-    templates_path = os.path.join(
+    templates_path = os.path.join(Config.data_dir, "templates")
     traces_to_show = [s for s in traces if os.path.join(
+        Config.data_dir, "templates") in s]
     traces_to_show = [re.sub(" *\n *", ": ", s) for s in traces_to_show]
     if len(traces_to_show) > 0:
         update_message = gr.Markdown.update(
@@ -323,7 +324,7 @@ def do_train(
     continue_from_checkpoint = None
     if continue_from_model:
         resume_from_model_path = os.path.join(
+            Config.data_dir, "lora_models", continue_from_model)
         resume_from_checkpoint_param = resume_from_model_path
         if continue_from_checkpoint:
             resume_from_checkpoint_param = os.path.join(
@@ -360,7 +361,7 @@ def do_train(
             raise ValueError(
                 f"Unable to continue from model {continue_from_model}. Continuation is only possible from models stored locally in the data directory. Please ensure that the file '{will_be_resume_from_checkpoint_file}' exists.")
 
-    output_dir = os.path.join(
+    output_dir = os.path.join(Config.data_dir, "lora_models", model_name)
     if os.path.exists(output_dir):
         if (not os.path.isdir(output_dir)) or os.path.exists(os.path.join(output_dir, 'adapter_config.json')):
             raise ValueError(
@@ -399,7 +400,7 @@ def do_train(
             progress_detail += f", Loss: {last_loss:.4f}"
         return f"Training... ({progress_detail})"
 
-    if
+    if Config.ui_dev_mode:
         Global.should_stop_training = False
 
         message = f"""Currently in UI dev mode, not doing the actual training.
@@ -575,8 +576,8 @@ Train data (first 10):
         additional_training_arguments=additional_training_arguments,
         additional_lora_config=additional_lora_config,
         callbacks=training_callbacks,
-        wandb_api_key=
-        wandb_project=
+        wandb_api_key=Config.wandb_api_key,
+        wandb_project=Config.default_wandb_project if Config.enable_wandb else None,
         wandb_group=wandb_group,
         wandb_run_name=model_name,
         wandb_tags=wandb_tags
@@ -605,7 +606,7 @@ def do_abort_training():
 def handle_continue_from_model_change(model_name):
     try:
         lora_models_directory_path = os.path.join(
+            Config.data_dir, "lora_models")
         lora_model_directory_path = os.path.join(
             lora_models_directory_path, model_name)
         all_files = os.listdir(lora_model_directory_path)
@@ -651,7 +652,7 @@ def handle_load_params_from_model(
     unknown_keys = []
     try:
         lora_models_directory_path = os.path.join(
+            Config.data_dir, "lora_models")
         lora_model_directory_path = os.path.join(
             lora_models_directory_path, model_name)
llama_lora/ui/inference_ui.py
CHANGED
@@ -7,6 +7,7 @@ import torch
 import transformers
 from transformers import GenerationConfig
 
+from ..config import Config
 from ..globals import Global
 from ..models import get_model, get_tokenizer, get_device
 from ..lib.inference import generate
@@ -101,7 +102,7 @@ def do_inference(
         'generation_config': generation_config.to_dict(),
     })
 
-    if
+    if Config.ui_dev_mode:
         message = f"Hi, I’m currently in UI-development mode and do not have access to resources to process your request. However, this behavior is similar to what will actually happen, so you can try and see how it will work!\n\nBase model: {base_model_name}\nLoRA model: {lora_model_name}\n\nThe following is your prompt:\n\n{prompt}"
         print(message)
@@ -318,7 +319,7 @@ def update_prompt_preview(prompt_template,
 
 
 def inference_ui():
-    flagging_dir = os.path.join(
+    flagging_dir = os.path.join(Config.data_dir, "flagging", "inference")
     if not os.path.exists(flagging_dir):
         os.makedirs(flagging_dir)
llama_lora/ui/main_page.py
CHANGED
@@ -1,5 +1,6 @@
 import gradio as gr
 
+from ..config import Config
 from ..globals import Global
 
 from .inference_ui import inference_ui
@@ -21,7 +22,7 @@ def main_page():
         gr.Markdown(
             f"""
             <h1 class="app_title_text">{title}</h1> <wbr />
-            <h2 class="app_subtitle_text">{
+            <h2 class="app_subtitle_text">{Config.ui_subtitle}</h2>
             """,
             elem_id="page_title",
         )
@@ -29,7 +30,7 @@ def main_page():
         global_base_model_select = gr.Dropdown(
             label="Base Model",
             elem_id="global_base_model_select",
-            choices=
+            choices=Config.base_model_choices,
             value=lambda: Global.base_model_name,
             allow_custom_value=True,
         )
@@ -146,11 +147,11 @@ def main_page():
 
 
 def get_page_title():
-    title =
-    if (
-        title =
-    if (
-        title = f"{
+    title = Config.ui_title
+    if (Config.ui_dev_mode):
+        title = Config.ui_dev_mode_title_prefix + title
+    if (Config.ui_emoji):
+        title = f"{Config.ui_emoji} {title}"
     return title
 
 
@@ -953,8 +954,8 @@ def get_foot_info():
     info.append(f"Base model: `{Global.base_model_name}`")
     if Global.tokenizer_name and Global.tokenizer_name != Global.base_model_name:
         info.append(f"Tokenizer: `{Global.tokenizer_name}`")
-    if
-        info.append(f"Data dir: `{
+    if Config.ui_show_sys_info:
+        info.append(f"Data dir: `{Config.data_dir}`")
     return f"""\
 <small>{" · ".join(info)}</small>
 """
llama_lora/ui/tokenizer_ui.py
CHANGED
@@ -2,6 +2,7 @@ import gradio as gr
 import time
 import json
 
+from ..config import Config
 from ..globals import Global
 from ..models import get_tokenizer
@@ -12,7 +13,7 @@ def handle_decode(encoded_tokens_json):
 
     try:
         encoded_tokens = json.loads(encoded_tokens_json)
-        if
+        if Config.ui_dev_mode:
             return f"Not actually decoding tokens in UI dev mode.", gr.Markdown.update("", visible=False)
         tokenizer = get_tokenizer(tokenizer_name)
         decoded_tokens = tokenizer.decode(encoded_tokens)
@@ -26,7 +27,7 @@ def handle_encode(decoded_tokens):
     tokenizer_name = Global.tokenizer_name or Global.base_model_name
 
     try:
-        if
+        if Config.ui_dev_mode:
             return f"[\"Not actually encoding tokens in UI dev mode.\"]", gr.Markdown.update("", visible=False)
         tokenizer = get_tokenizer(tokenizer_name)
         result = tokenizer(decoded_tokens)
llama_lora/utils/data.py
CHANGED
@@ -3,20 +3,22 @@ import shutil
 import fnmatch
 import json
 
+from ..config import Config
 from ..globals import Global
 
 
 def init_data_dir():
+    os.makedirs(Config.data_dir, exist_ok=True)
     current_file_path = os.path.abspath(__file__)
     parent_directory_path = os.path.dirname(current_file_path)
     project_dir_path = os.path.abspath(
         os.path.join(parent_directory_path, "..", ".."))
     copy_sample_data_if_not_exists(os.path.join(project_dir_path, "templates"),
-                                   os.path.join(
+                                   os.path.join(Config.data_dir, "templates"))
     copy_sample_data_if_not_exists(os.path.join(project_dir_path, "datasets"),
-                                   os.path.join(
+                                   os.path.join(Config.data_dir, "datasets"))
     copy_sample_data_if_not_exists(os.path.join(project_dir_path, "lora_models"),
-                                   os.path.join(
+                                   os.path.join(Config.data_dir, "lora_models"))
 
 
 def copy_sample_data_if_not_exists(source, destination):
@@ -28,28 +30,28 @@ def copy_sample_data_if_not_exists(source, destination):
 
 
 def get_available_template_names():
-    templates_directory_path = os.path.join(
+    templates_directory_path = os.path.join(Config.data_dir, "templates")
     all_files = os.listdir(templates_directory_path)
     names = [filename.rstrip(".json") for filename in all_files if fnmatch.fnmatch(filename, "*.json") or fnmatch.fnmatch(filename, "*.py")]
     return sorted(names)
 
 
 def get_available_dataset_names():
-    datasets_directory_path = os.path.join(
+    datasets_directory_path = os.path.join(Config.data_dir, "datasets")
     all_files = os.listdir(datasets_directory_path)
     names = [filename for filename in all_files if fnmatch.fnmatch(filename, "*.json") or fnmatch.fnmatch(filename, "*.jsonl")]
     return sorted(names)
 
 
 def get_available_lora_model_names():
-    lora_models_directory_path = os.path.join(
+    lora_models_directory_path = os.path.join(Config.data_dir, "lora_models")
     all_items = os.listdir(lora_models_directory_path)
     names = [item for item in all_items if os.path.isdir(os.path.join(lora_models_directory_path, item))]
     return sorted(names)
 
 
 def get_path_of_available_lora_model(name):
-    datasets_directory_path = os.path.join(
+    datasets_directory_path = os.path.join(Config.data_dir, "lora_models")
     path = os.path.join(datasets_directory_path, name)
     if os.path.isdir(path):
         return path
@@ -73,7 +75,7 @@ def get_info_of_available_lora_model(name):
 
 
 def get_dataset_content(name):
-    file_name = os.path.join(
+    file_name = os.path.join(Config.data_dir, "datasets", name)
     if not os.path.exists(file_name):
         raise ValueError(
             f"Can't read {file_name} from datasets. File does not exist.")
llama_lora/utils/prompter.py
CHANGED
@@ -9,6 +9,7 @@ import importlib
 import itertools
 from typing import Union, List
 
+from ..config import Config
 from ..globals import Global
 
 
@@ -31,7 +32,7 @@ class Prompter(object):
         else:
             filename = base_filename + ext
 
-        file_path = osp.join(
+        file_path = osp.join(Config.data_dir, "templates", filename)
 
         if not osp.exists(file_path):
             raise ValueError(f"Can't read {file_path}")