Commit 40a8f4e (parent: d2eef14), committed by zetavg

extract configs from global
LLaMA_LoRA.ipynb CHANGED
@@ -279,21 +279,23 @@
 {
   "cell_type": "code",
   "source": [
-    "# @title Load the App (set config, prepare data dir, load base bodel)\n",
+    "# @title Load the App (set config, prepare data dir, load base model)\n",
     "\n",
     "# @markdown For a LLaMA-7B model, it will take about ~5m to load for the first execution,\n",
     "# @markdown including download. Subsequent executions will take about 2m to load.\n",
     "\n",
     "# Set Configs\n",
-    "from llama_lora.llama_lora.globals import Global\n",
-    "Global.default_base_model_name = Global.base_model_name = base_model\n",
-    "Global.base_model_choices = [base_model]\n",
+    "from llama_lora.llama_lora.config import Config, process_config\n",
+    "from llama_lora.llama_lora.globals import initialize_global\n",
+    "Config.default_base_model_name = base_model\n",
+    "Config.base_model_choices = [base_model]\n",
     "data_dir_realpath = !realpath ./data\n",
-    "Global.data_dir = data_dir_realpath[0]\n",
-    "Global.load_8bit = True\n",
+    "Config.data_dir = data_dir_realpath[0]\n",
+    "Config.load_8bit = True\n",
+    "process_config()\n",
+    "initialize_global()\n",
     "\n",
     "# Prepare Data Dir\n",
-    "import os\n",
     "from llama_lora.llama_lora.utils.data import init_data_dir\n",
     "init_data_dir()\n",
     "\n",

app.py CHANGED
@@ -1,30 +1,30 @@
-import os
-import sys
+from typing import Union
 
 import fire
 import gradio as gr
 
-from llama_lora.globals import Global
+from llama_lora.config import Config, process_config
+from llama_lora.globals import initialize_global
 from llama_lora.models import prepare_base_model
-from llama_lora.ui.main_page import main_page, get_page_title, main_page_custom_css
 from llama_lora.utils.data import init_data_dir
-
+from llama_lora.ui.main_page import (
+    main_page, get_page_title, main_page_custom_css
+)
 
 
 def main(
-    base_model: str = "",
-    data_dir: str = "",
-    base_model_choices: str = "",
-    trust_remote_code: bool = False,
-    # Allows to listen on all interfaces by providing '0.0.0.0'.
+    base_model: Union[str, None] = None,
+    data_dir: Union[str, None] = None,
+    base_model_choices: Union[str, None] = None,
+    trust_remote_code: Union[bool, None] = None,
     server_name: str = "127.0.0.1",
     share: bool = False,
     skip_loading_base_model: bool = False,
-    load_8bit: bool = False,
-    ui_show_sys_info: bool = True,
-    ui_dev_mode: bool = False,
-    wandb_api_key: str = "",
-    wandb_project: str = "",
+    load_8bit: Union[bool, None] = None,
+    ui_show_sys_info: Union[bool, None] = None,
+    ui_dev_mode: Union[bool, None] = None,
+    wandb_api_key: Union[str, None] = None,
+    wandb_project: Union[str, None] = None,
 ):
     '''
     Start the LLaMA-LoRA Tuner UI.
@@ -41,51 +41,54 @@ def main(
    :param wandb_project: The default project name for Weights & Biases. Setting either this or `wandb_api_key` will enable Weights & Biases.
    '''
 
-    base_model = base_model or os.environ.get("LLAMA_LORA_BASE_MODEL", "")
-    data_dir = data_dir or os.environ.get("LLAMA_LORA_DATA_DIR", "")
-    assert (
-        base_model
-    ), "Please specify a --base_model, e.g. --base_model='decapoda-research/llama-7b-hf'"
-
-    assert (
-        data_dir
-    ), "Please specify a --data_dir, e.g. --data_dir='./data'"
-
-    Global.default_base_model_name = Global.base_model_name = base_model
-
-    if base_model_choices:
-        base_model_choices = base_model_choices.split(',')
-        base_model_choices = [name.strip() for name in base_model_choices]
-        Global.base_model_choices = base_model_choices
-
-    if base_model not in Global.base_model_choices:
-        Global.base_model_choices = [base_model] + Global.base_model_choices
-
-    Global.trust_remote_code = trust_remote_code
-
-    Global.data_dir = os.path.abspath(data_dir)
-    Global.load_8bit = load_8bit
-
-    if len(wandb_api_key) > 0:
-        Global.enable_wandb = True
-        Global.wandb_api_key = wandb_api_key
-    if len(wandb_project) > 0:
-        Global.enable_wandb = True
-        Global.wandb_project = wandb_project
-
-    Global.ui_dev_mode = ui_dev_mode
-    Global.ui_show_sys_info = ui_show_sys_info
+    if base_model is not None:
+        Config.default_base_model_name = base_model
+
+    if base_model_choices is not None:
+        Config.base_model_choices = base_model_choices
+
+    if trust_remote_code is not None:
+        Config.trust_remote_code = trust_remote_code
+
+    if data_dir is not None:
+        Config.data_dir = data_dir
+
+    if load_8bit is not None:
+        Config.load_8bit = load_8bit
+
+    if wandb_api_key is not None:
+        Config.wandb_api_key = wandb_api_key
+
+    if wandb_project is not None:
+        Config.default_wandb_project = wandb_project
+
+    if ui_dev_mode is not None:
+        Config.ui_dev_mode = ui_dev_mode
+
+    if ui_show_sys_info is not None:
+        Config.ui_show_sys_info = ui_show_sys_info
+
+    process_config()
+    initialize_global()
+
+    assert (
+        Config.default_base_model_name
+    ), "Please specify a --base_model, e.g. --base_model='decapoda-research/llama-7b-hf'"
+
+    assert (
+        Config.data_dir
+    ), "Please specify a --data_dir, e.g. --data_dir='./data'"
 
-    os.makedirs(data_dir, exist_ok=True)
     init_data_dir()
 
-    if (not skip_loading_base_model) and (not ui_dev_mode):
-        prepare_base_model(base_model)
+    if (not skip_loading_base_model) and (not Config.ui_dev_mode):
+        prepare_base_model(Config.default_base_model_name)
 
     with gr.Blocks(title=get_page_title(), css=main_page_custom_css()) as demo:
         main_page()
 
-    demo.queue(concurrency_count=1).launch(server_name=server_name, share=share)
+    demo.queue(concurrency_count=1).launch(
+        server_name=server_name, share=share)
 
 
 if __name__ == "__main__":
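
With this change, `main()` no longer reads environment variables or mutates `Global` directly: each CLI flag overrides the corresponding `Config` field only when it is explicitly passed, and validation runs after `process_config()`. The following is a minimal sketch (not part of the commit) of the startup sequence that a launch such as `python app.py --base_model='decapoda-research/llama-7b-hf' --data_dir='./data'` now goes through, assuming the package is importable as `llama_lora`; the model name and data dir are the example values from the assertion messages above.

    # Sketch only: hand-written equivalent of the new startup flow in app.py.
    from llama_lora.config import Config, process_config
    from llama_lora.globals import initialize_global
    from llama_lora.utils.data import init_data_dir

    Config.default_base_model_name = "decapoda-research/llama-7b-hf"  # example value
    Config.data_dir = "./data"                                        # example value

    process_config()     # resolve data_dir to an absolute path, normalize base_model_choices, infer wandb settings
    initialize_global()  # copy the default model name into Global, record the git hash, probe the GPU
    init_data_dir()      # create data_dir and copy the sample templates/datasets/lora_models
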
llama_lora/config.py ADDED
@@ -0,0 +1,47 @@
+import os
+from typing import List, Union
+
+
+class Config:
+    """
+    Stores the application configuration. This is a singleton class.
+    """
+
+    data_dir: str = ""
+    load_8bit: bool = False
+
+    default_base_model_name: str = ""
+    base_model_choices: Union[List[str], str] = []
+
+    trust_remote_code: bool = False
+
+    # WandB
+    enable_wandb: Union[bool, None] = False
+    wandb_api_key: Union[str, None] = None
+    default_wandb_project: str = "llama-lora-tuner"
+
+    # UI related
+    ui_title: str = "LLaMA-LoRA Tuner"
+    ui_emoji: str = "🦙🎛️"
+    ui_subtitle: str = "Toolkit for evaluating and fine-tuning LLaMA models with low-rank adaptation (LoRA)."
+    ui_show_sys_info: bool = True
+    ui_dev_mode: bool = False
+    ui_dev_mode_title_prefix: str = "[UI DEV MODE] "
+
+
+def process_config():
+    Config.data_dir = os.path.abspath(Config.data_dir)
+
+    if isinstance(Config.base_model_choices, str):
+        base_model_choices = Config.base_model_choices.split(',')
+        base_model_choices = [name.strip() for name in base_model_choices]
+        Config.base_model_choices = base_model_choices
+
+    if Config.default_base_model_name not in Config.base_model_choices:
+        Config.base_model_choices = [Config.default_base_model_name] + Config.base_model_choices
+
+    if Config.enable_wandb is None:
+        if Config.wandb_api_key and len(Config.wandb_api_key) > 0:
+            Config.enable_wandb = True
+        if Config.default_wandb_project and len(Config.default_wandb_project) > 0:
+            Config.enable_wandb = True
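
One behavior of the new module worth noting: `base_model_choices` may be set either as a list or as a comma-separated string, and `process_config()` normalizes it, prepending `default_base_model_name` when it is missing from the list. A small illustration follows; the two extra model names are arbitrary examples, and it assumes `llama_lora.config` can be imported on its own.

    from llama_lora.config import Config, process_config

    Config.default_base_model_name = "decapoda-research/llama-7b-hf"
    Config.base_model_choices = "huggyllama/llama-7b, openlm-research/open_llama_3b"
    Config.data_dir = "./data"

    process_config()

    # The string is split on commas, whitespace is stripped, and the default
    # model is prepended because it was not already in the list:
    print(Config.base_model_choices)
    # ['decapoda-research/llama-7b-hf', 'huggyllama/llama-7b', 'openlm-research/open_llama_3b']
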
llama_lora/globals.py CHANGED
@@ -8,23 +8,21 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 from numba import cuda
 import nvidia_smi
 
+from .config import Config
 from .utils.lru_cache import LRUCache
 from .utils.model_lru_cache import ModelLRUCache
 from .lib.finetune import train
 
 
 class Global:
-    version = None
-
-    data_dir: str = ""
-    load_8bit: bool = False
+    """
+    A singleton class holding global states.
+    """
+
+    version: Union[str, None] = None
 
-    default_base_model_name: str = ""
     base_model_name: str = ""
     tokenizer_name = None
-    base_model_choices: List[str] = []
-
-    trust_remote_code = False
 
     # Functions
     train_fn: Any = train
@@ -48,18 +46,15 @@ class Global:
     gpu_total_cores = None  # GPU total cores
     gpu_total_memory = None
 
-    # WandB
-    enable_wandb = False
-    wandb_api_key = None
-    default_wandb_project = "llama-lora-tuner"
-
-    # UI related
-    ui_title: str = "LLaMA-LoRA Tuner"
-    ui_emoji: str = "🦙🎛️"
-    ui_subtitle: str = "Toolkit for evaluating and fine-tuning LLaMA models with low-rank adaptation (LoRA)."
-    ui_show_sys_info: bool = True
-    ui_dev_mode: bool = False
-    ui_dev_mode_title_prefix: str = "[UI DEV MODE] "
+
+def initialize_global():
+    Global.base_model_name = Config.default_base_model_name
+    commit_hash = get_git_commit_hash()
+
+    if commit_hash:
+        Global.version = commit_hash[:8]
+
+    load_gpu_info()
 
 
 def get_package_dir():
@@ -85,12 +80,6 @@ def get_git_commit_hash():
         print(f"Cannot get git commit hash: {e}")
 
 
-commit_hash = get_git_commit_hash()
-
-if commit_hash:
-    Global.version = commit_hash[:8]
-
-
 def load_gpu_info():
     print("")
     try:
@@ -154,5 +143,3 @@ def load_gpu_info():
         print(f"Notice: cannot get GPU info: {e}")
 
     print("")
-
-load_gpu_info()

llama_lora/models.py CHANGED
@@ -11,12 +11,13 @@ from transformers import (
 )
 from peft import PeftModel
 
+from .config import Config
 from .globals import Global
 from .lib.get_device import get_device
 
 
 def get_new_base_model(base_model_name):
-    if Global.ui_dev_mode:
+    if Config.ui_dev_mode:
         return
 
     if Global.new_base_model_that_is_ready_to_be_used:
@@ -79,14 +80,14 @@ def _get_model_from_pretrained(model_class, model_name, from_tf=False, force_dow
     if device == "cuda":
         return model_class.from_pretrained(
             model_name,
-            load_in_8bit=Global.load_8bit,
+            load_in_8bit=Config.load_8bit,
             torch_dtype=torch.float16,
             # device_map="auto",
             # ? https://github.com/tloen/alpaca-lora/issues/21
             device_map={'': 0},
             from_tf=from_tf,
             force_download=force_download,
-            trust_remote_code=Global.trust_remote_code
+            trust_remote_code=Config.trust_remote_code
         )
     elif device == "mps":
         return model_class.from_pretrained(
@@ -95,7 +96,7 @@ def _get_model_from_pretrained(model_class, model_name, from_tf=False, force_dow
             torch_dtype=torch.float16,
             from_tf=from_tf,
             force_download=force_download,
-            trust_remote_code=Global.trust_remote_code
+            trust_remote_code=Config.trust_remote_code
         )
     else:
         return model_class.from_pretrained(
@@ -104,12 +105,12 @@ def _get_model_from_pretrained(model_class, model_name, from_tf=False, force_dow
             low_cpu_mem_usage=True,
             from_tf=from_tf,
             force_download=force_download,
-            trust_remote_code=Global.trust_remote_code
+            trust_remote_code=Config.trust_remote_code
         )
 
 
 def get_tokenizer(base_model_name):
-    if Global.ui_dev_mode:
+    if Config.ui_dev_mode:
         return
 
     loaded_tokenizer = Global.loaded_tokenizers.get(base_model_name)
@@ -119,13 +120,13 @@ def get_tokenizer(base_model_name):
     try:
         tokenizer = AutoTokenizer.from_pretrained(
             base_model_name,
-            trust_remote_code=Global.trust_remote_code
+            trust_remote_code=Config.trust_remote_code
         )
     except Exception as e:
         if 'LLaMATokenizer' in str(e):
             tokenizer = LlamaTokenizer.from_pretrained(
                 base_model_name,
-                trust_remote_code=Global.trust_remote_code
+                trust_remote_code=Config.trust_remote_code
             )
         else:
             raise e
@@ -138,7 +139,7 @@ def get_tokenizer(base_model_name):
 def get_model(
         base_model_name,
         peft_model_name=None):
-    if Global.ui_dev_mode:
+    if Config.ui_dev_mode:
         return
 
     if peft_model_name == "None":
@@ -156,7 +157,7 @@ def get_model(
 
     if peft_model_name:
         lora_models_directory_path = os.path.join(
-            Global.data_dir, "lora_models")
+            Config.data_dir, "lora_models")
         possible_lora_model_path = os.path.join(
             lora_models_directory_path, peft_model_name)
         if os.path.isdir(possible_lora_model_path):
@@ -211,7 +212,7 @@ def get_model(
     model.config.bos_token_id = 1
     model.config.eos_token_id = 2
 
-    if not Global.load_8bit:
+    if not Config.load_8bit:
         model.half()  # seems to fix bugs for some users.
 
     model.eval()
@@ -224,7 +225,7 @@ def get_model(
     return model
 
 
-def prepare_base_model(base_model_name=Global.default_base_model_name):
+def prepare_base_model(base_model_name=Config.default_base_model_name):
     Global.new_base_model_that_is_ready_to_be_used = get_new_base_model(
         base_model_name)
     Global.name_of_new_base_model_that_is_ready_to_be_used = base_model_name

llama_lora/ui/finetune_ui.py CHANGED
@@ -11,6 +11,7 @@ from random_word import RandomWords
 from transformers import TrainerCallback
 from huggingface_hub import try_to_load_from_cache, snapshot_download
 
+from ..config import Config
 from ..globals import Global
 from ..models import (
     get_new_base_model, get_tokenizer,
@@ -240,9 +241,9 @@ def refresh_dataset_items_count(
 
         trace = traceback.format_exc()
         traces = [s.strip() for s in re.split("\n * File ", trace)]
-        templates_path = os.path.join(Global.data_dir, "templates")
+        templates_path = os.path.join(Config.data_dir, "templates")
         traces_to_show = [s for s in traces if os.path.join(
-            Global.data_dir, "templates") in s]
+            Config.data_dir, "templates") in s]
         traces_to_show = [re.sub(" *\n *", ": ", s) for s in traces_to_show]
         if len(traces_to_show) > 0:
             update_message = gr.Markdown.update(
@@ -323,7 +324,7 @@ def do_train(
     continue_from_checkpoint = None
     if continue_from_model:
         resume_from_model_path = os.path.join(
-            Global.data_dir, "lora_models", continue_from_model)
+            Config.data_dir, "lora_models", continue_from_model)
         resume_from_checkpoint_param = resume_from_model_path
         if continue_from_checkpoint:
             resume_from_checkpoint_param = os.path.join(
@@ -360,7 +361,7 @@ def do_train(
        raise ValueError(
            f"Unable to continue from model {continue_from_model}. Continuation is only possible from models stored locally in the data directory. Please ensure that the file '{will_be_resume_from_checkpoint_file}' exists.")
 
-    output_dir = os.path.join(Global.data_dir, "lora_models", model_name)
+    output_dir = os.path.join(Config.data_dir, "lora_models", model_name)
     if os.path.exists(output_dir):
         if (not os.path.isdir(output_dir)) or os.path.exists(os.path.join(output_dir, 'adapter_config.json')):
             raise ValueError(
@@ -399,7 +400,7 @@ def do_train(
            progress_detail += f", Loss: {last_loss:.4f}"
        return f"Training... ({progress_detail})"
 
-    if Global.ui_dev_mode:
+    if Config.ui_dev_mode:
        Global.should_stop_training = False
 
        message = f"""Currently in UI dev mode, not doing the actual training.
@@ -575,8 +576,8 @@ Train data (first 10):
            additional_training_arguments=additional_training_arguments,
            additional_lora_config=additional_lora_config,
            callbacks=training_callbacks,
-            wandb_api_key=Global.wandb_api_key,
-            wandb_project=Global.default_wandb_project if Global.enable_wandb else None,
+            wandb_api_key=Config.wandb_api_key,
+            wandb_project=Config.default_wandb_project if Config.enable_wandb else None,
            wandb_group=wandb_group,
            wandb_run_name=model_name,
            wandb_tags=wandb_tags
@@ -605,7 +606,7 @@ def do_abort_training():
 def handle_continue_from_model_change(model_name):
     try:
         lora_models_directory_path = os.path.join(
-            Global.data_dir, "lora_models")
+            Config.data_dir, "lora_models")
         lora_model_directory_path = os.path.join(
             lora_models_directory_path, model_name)
         all_files = os.listdir(lora_model_directory_path)
@@ -651,7 +652,7 @@ def handle_load_params_from_model(
     unknown_keys = []
     try:
         lora_models_directory_path = os.path.join(
-            Global.data_dir, "lora_models")
+            Config.data_dir, "lora_models")
         lora_model_directory_path = os.path.join(
             lora_models_directory_path, model_name)
 
llama_lora/ui/inference_ui.py CHANGED
@@ -7,6 +7,7 @@ import torch
 import transformers
 from transformers import GenerationConfig
 
+from ..config import Config
 from ..globals import Global
 from ..models import get_model, get_tokenizer, get_device
 from ..lib.inference import generate
@@ -101,7 +102,7 @@ def do_inference(
            'generation_config': generation_config.to_dict(),
        })
 
-    if Global.ui_dev_mode:
+    if Config.ui_dev_mode:
        message = f"Hi, I’m currently in UI-development mode and do not have access to resources to process your request. However, this behavior is similar to what will actually happen, so you can try and see how it will work!\n\nBase model: {base_model_name}\nLoRA model: {lora_model_name}\n\nThe following is your prompt:\n\n{prompt}"
        print(message)
 
@@ -318,7 +319,7 @@ def update_prompt_preview(prompt_template,
 
 
 def inference_ui():
-    flagging_dir = os.path.join(Global.data_dir, "flagging", "inference")
+    flagging_dir = os.path.join(Config.data_dir, "flagging", "inference")
     if not os.path.exists(flagging_dir):
         os.makedirs(flagging_dir)

llama_lora/ui/main_page.py CHANGED
@@ -1,5 +1,6 @@
 import gradio as gr
 
+from ..config import Config
 from ..globals import Global
 
 from .inference_ui import inference_ui
@@ -21,7 +22,7 @@ def main_page():
        gr.Markdown(
            f"""
            <h1 class="app_title_text">{title}</h1> <wbr />
-            <h2 class="app_subtitle_text">{Global.ui_subtitle}</h2>
+            <h2 class="app_subtitle_text">{Config.ui_subtitle}</h2>
            """,
            elem_id="page_title",
        )
@@ -29,7 +30,7 @@ def main_page():
        global_base_model_select = gr.Dropdown(
            label="Base Model",
            elem_id="global_base_model_select",
-            choices=Global.base_model_choices,
+            choices=Config.base_model_choices,
            value=lambda: Global.base_model_name,
            allow_custom_value=True,
        )
@@ -146,11 +147,11 @@ def main_page():
 
 
 def get_page_title():
-    title = Global.ui_title
-    if (Global.ui_dev_mode):
-        title = Global.ui_dev_mode_title_prefix + title
-    if (Global.ui_emoji):
-        title = f"{Global.ui_emoji} {title}"
+    title = Config.ui_title
+    if (Config.ui_dev_mode):
+        title = Config.ui_dev_mode_title_prefix + title
+    if (Config.ui_emoji):
+        title = f"{Config.ui_emoji} {title}"
     return title
 
 
@@ -953,8 +954,8 @@ def get_foot_info():
     info.append(f"Base model: `{Global.base_model_name}`")
     if Global.tokenizer_name and Global.tokenizer_name != Global.base_model_name:
         info.append(f"Tokenizer: `{Global.tokenizer_name}`")
-    if Global.ui_show_sys_info:
-        info.append(f"Data dir: `{Global.data_dir}`")
+    if Config.ui_show_sys_info:
+        info.append(f"Data dir: `{Config.data_dir}`")
     return f"""\
 <small>{"&nbsp;&nbsp;·&nbsp;&nbsp;".join(info)}</small>
 """
llama_lora/ui/tokenizer_ui.py CHANGED
@@ -2,6 +2,7 @@ import gradio as gr
 import time
 import json
 
+from ..config import Config
 from ..globals import Global
 from ..models import get_tokenizer
 
@@ -12,7 +13,7 @@ def handle_decode(encoded_tokens_json):
 
     try:
         encoded_tokens = json.loads(encoded_tokens_json)
-        if Global.ui_dev_mode:
+        if Config.ui_dev_mode:
            return f"Not actually decoding tokens in UI dev mode.", gr.Markdown.update("", visible=False)
        tokenizer = get_tokenizer(tokenizer_name)
        decoded_tokens = tokenizer.decode(encoded_tokens)
@@ -26,7 +27,7 @@ def handle_encode(decoded_tokens):
     tokenizer_name = Global.tokenizer_name or Global.base_model_name
 
     try:
-        if Global.ui_dev_mode:
+        if Config.ui_dev_mode:
            return f"[\"Not actually encoding tokens in UI dev mode.\"]", gr.Markdown.update("", visible=False)
        tokenizer = get_tokenizer(tokenizer_name)
        result = tokenizer(decoded_tokens)

llama_lora/utils/data.py CHANGED
@@ -3,20 +3,22 @@ import shutil
 import fnmatch
 import json
 
+from ..config import Config
 from ..globals import Global
 
 
 def init_data_dir():
+    os.makedirs(Config.data_dir, exist_ok=True)
     current_file_path = os.path.abspath(__file__)
     parent_directory_path = os.path.dirname(current_file_path)
     project_dir_path = os.path.abspath(
         os.path.join(parent_directory_path, "..", ".."))
     copy_sample_data_if_not_exists(os.path.join(project_dir_path, "templates"),
-                                   os.path.join(Global.data_dir, "templates"))
+                                   os.path.join(Config.data_dir, "templates"))
     copy_sample_data_if_not_exists(os.path.join(project_dir_path, "datasets"),
-                                   os.path.join(Global.data_dir, "datasets"))
+                                   os.path.join(Config.data_dir, "datasets"))
     copy_sample_data_if_not_exists(os.path.join(project_dir_path, "lora_models"),
-                                   os.path.join(Global.data_dir, "lora_models"))
+                                   os.path.join(Config.data_dir, "lora_models"))
 
 
 def copy_sample_data_if_not_exists(source, destination):
@@ -28,28 +30,28 @@ def copy_sample_data_if_not_exists(source, destination):
 
 
 def get_available_template_names():
-    templates_directory_path = os.path.join(Global.data_dir, "templates")
+    templates_directory_path = os.path.join(Config.data_dir, "templates")
     all_files = os.listdir(templates_directory_path)
     names = [filename.rstrip(".json") for filename in all_files if fnmatch.fnmatch(filename, "*.json") or fnmatch.fnmatch(filename, "*.py")]
     return sorted(names)
 
 
 def get_available_dataset_names():
-    datasets_directory_path = os.path.join(Global.data_dir, "datasets")
+    datasets_directory_path = os.path.join(Config.data_dir, "datasets")
     all_files = os.listdir(datasets_directory_path)
     names = [filename for filename in all_files if fnmatch.fnmatch(filename, "*.json") or fnmatch.fnmatch(filename, "*.jsonl")]
     return sorted(names)
 
 
 def get_available_lora_model_names():
-    lora_models_directory_path = os.path.join(Global.data_dir, "lora_models")
+    lora_models_directory_path = os.path.join(Config.data_dir, "lora_models")
     all_items = os.listdir(lora_models_directory_path)
     names = [item for item in all_items if os.path.isdir(os.path.join(lora_models_directory_path, item))]
     return sorted(names)
 
 
 def get_path_of_available_lora_model(name):
-    datasets_directory_path = os.path.join(Global.data_dir, "lora_models")
+    datasets_directory_path = os.path.join(Config.data_dir, "lora_models")
     path = os.path.join(datasets_directory_path, name)
     if os.path.isdir(path):
         return path
@@ -73,7 +75,7 @@ def get_info_of_available_lora_model(name):
 
 
 def get_dataset_content(name):
-    file_name = os.path.join(Global.data_dir, "datasets", name)
+    file_name = os.path.join(Config.data_dir, "datasets", name)
     if not os.path.exists(file_name):
         raise ValueError(
             f"Can't read {file_name} from datasets. File does not exist.")

llama_lora/utils/prompter.py CHANGED
@@ -9,6 +9,7 @@ import importlib
 import itertools
 from typing import Union, List
 
+from ..config import Config
 from ..globals import Global
 
 
@@ -31,7 +32,7 @@ class Prompter(object):
        else:
            filename = base_filename + ext
 
-        file_path = osp.join(Global.data_dir, "templates", filename)
+        file_path = osp.join(Config.data_dir, "templates", filename)
 
        if not osp.exists(file_path):
            raise ValueError(f"Can't read {file_path}")