zetavg committed
Commit 40d6251 • 2 parents: 29067b4, 889210b

Merge branch 'main' into hf-ui-demo
LLaMA_LoRA.ipynb CHANGED
@@ -60,20 +60,15 @@
60
  "# @title A small workaround { display-mode: \"form\" }\n",
61
  "# @markdown Don't panic if you see an error here. Just click the `RESTART RUNTIME` button in the output below, then Run All again.\n",
62
  "# @markdown The error will disappear on the next run.\n",
63
- "!pip install Pillow==9.3.0 numpy==1.23.5\n",
64
  "\n",
 
65
  "import PIL\n",
66
- "major, minor = map(float, PIL.__version__.split(\".\")[:2])\n",
67
- "version_float = major + minor / 10**len(str(minor))\n",
68
- "print('PIL', version_float)\n",
69
- "if version_float < 9.003:\n",
70
- " raise Exception(\"Restart the runtime by clicking the 'RESTART RUNTIME' button above (or Runtime > Restart Runtime).\")\n",
71
- "\n",
72
  "import numpy\n",
73
- "major, minor = map(float, numpy.__version__.split(\".\")[:2])\n",
74
- "version_float = major + minor / 10**len(str(minor))\n",
75
- "print('numpy', version_float)\n",
76
- "if version_float < 1.0023:\n",
77
  " raise Exception(\"Restart the runtime by clicking the 'RESTART RUNTIME' button above (or Runtime > Restart Runtime).\")"
78
  ],
79
  "metadata": {
@@ -144,15 +139,17 @@
144
  "# colab_notebook_name = remove_ipynb_extension(colab_notebook_filename)\n",
145
  "\n",
146
  "from google.colab import drive\n",
147
- "drive.mount(google_drive_mount_path)\n",
 
148
  "\n",
149
- "# google_drive_data_directory_relative_path = f\"{google_drive_colab_data_folder}/{colab_notebook_name}\"\n",
150
- "google_drive_data_directory_relative_path = google_drive_folder\n",
151
- "google_drive_data_directory_path = f\"{google_drive_mount_path}/My Drive/{google_drive_data_directory_relative_path}\"\n",
152
- "!mkdir -p \"{google_drive_data_directory_path}\"\n",
153
- "!ln -nsf \"{google_drive_data_directory_path}\" ./data\n",
154
- "!touch \"data/This folder is used by the Colab notebook \\\"{colab_notebook_filename}\\\".txt\"\n",
155
- "!echo \"Data will be stored in Google Drive folder: \\\"{google_drive_data_directory_relative_path}\\\", which is mounted under \\\"{google_drive_data_directory_path}\\\"\"\n"
 
156
  ],
157
  "metadata": {
158
  "id": "iZmRtUY68U5f"
 
60
  "# @title A small workaround { display-mode: \"form\" }\n",
61
  "# @markdown Don't panic if you see an error here. Just click the `RESTART RUNTIME` button in the output below, then Run All again.\n",
62
  "# @markdown The error will disappear on the next run.\n",
63
+ "%pip install Pillow==9.3.0 numpy==1.23.5\n",
64
  "\n",
65
+ "import pkg_resources as r\n",
66
  "import PIL\n",
 
 
 
 
 
 
67
  "import numpy\n",
68
+ "for module, min_version in [(PIL, \"9.3\"), (numpy, \"1.23\")]:\n",
69
+ " lib_version = r.parse_version(module.__version__)\n",
70
+ " print(module.__name__, lib_version)\n",
71
+ " if lib_version < r.parse_version(min_version):\n",
72
  " raise Exception(\"Restart the runtime by clicking the 'RESTART RUNTIME' button above (or Runtime > Restart Runtime).\")"
73
  ],
74
  "metadata": {
 
139
  "# colab_notebook_name = remove_ipynb_extension(colab_notebook_filename)\n",
140
  "\n",
141
  "from google.colab import drive\n",
142
+ "try:\n",
143
+ " drive.mount(google_drive_mount_path)\n",
144
  "\n",
145
+ " google_drive_data_directory_relative_path = google_drive_folder\n",
146
+ " google_drive_data_directory_path = f\"{google_drive_mount_path}/My Drive/{google_drive_data_directory_relative_path}\"\n",
147
+ " !mkdir -p \"{google_drive_data_directory_path}\"\n",
148
+ " !ln -nsf \"{google_drive_data_directory_path}\" ./data\n",
149
+ " !touch \"data/This folder is used by the Colab notebook \\\"{colab_notebook_filename}\\\".txt\"\n",
150
+ " !echo \"Data will be stored in Google Drive folder: \\\"{google_drive_data_directory_relative_path}\\\", which is mounted under \\\"{google_drive_data_directory_path}\\\"\"\n",
151
+ "except Exception as e:\n",
152
+ " print(\"Drive won't be mounted!\")"
153
  ],
154
  "metadata": {
155
  "id": "iZmRtUY68U5f"
README.md CHANGED
@@ -35,6 +35,7 @@ Making evaluating and fine-tuning LLaMA models with low-rank adaptation (LoRA) e
35
  * **[1-click up and running in Google Colab](#run-on-google-colab)** with a standard GPU runtime.
36
  * Loads and stores data in Google Drive.
37
  * Evaluate various LLaMA LoRA models stored in your folder or from Hugging Face.<br /><a href="https://youtu.be/IoEMgouZ5xU"><img width="640px" src="https://user-images.githubusercontent.com/3784687/231023326-f28c84e2-df74-4179-b0ac-c25c4e8ca001.gif" /></a>
 
38
  * Fine-tune LLaMA models with different prompt templates and training dataset formats.<br /><a href="https://youtu.be/IoEMgouZ5xU?t=60"><img width="640px" src="https://user-images.githubusercontent.com/3784687/231026640-b5cf5c79-9fe9-430b-8d4e-7346eb9567ad.gif" /></a>
39
  * Load JSON and JSONL datasets from your folder, or even paste plain text directly into the UI.
40
  * Supports Stanford Alpaca [seed_tasks](https://github.com/tatsu-lab/stanford_alpaca/blob/main/seed_tasks.jsonl), [alpaca_data](https://github.com/tatsu-lab/stanford_alpaca/blob/main/alpaca_data.json) and [OpenAI "prompt"-"completion"](https://platform.openai.com/docs/guides/fine-tuning/data-formatting) format.
@@ -86,11 +87,13 @@ setup: |
86
  pip install wandb
87
  cd ..
88
  echo 'Dependencies installed.'
 
 
89
 
90
  # Start the app.
91
  run: |
92
  echo 'Starting...'
93
- python llama_lora_tuner/app.py --data_dir='/data' --wandb_api_key "$([ -f /data/secrets/wandb_api_key ] && cat /data/secrets/wandb_api_key | tr -d '\n')" --base_model='decapoda-research/llama-7b-hf' --share
94
  ```
95
 
96
  Then launch a cluster to run the task:
 
35
  * **[1-click up and running in Google Colab](#run-on-google-colab)** with a standard GPU runtime.
36
  * Loads and stores data in Google Drive.
37
  * Evaluate various LLaMA LoRA models stored in your folder or from Hugging Face.<br /><a href="https://youtu.be/IoEMgouZ5xU"><img width="640px" src="https://user-images.githubusercontent.com/3784687/231023326-f28c84e2-df74-4179-b0ac-c25c4e8ca001.gif" /></a>
38
+ * Switch between base models such as `decapoda-research/llama-7b-hf`, `nomic-ai/gpt4all-j`, `databricks/dolly-v2-7b`, `EleutherAI/gpt-j-6b`, or `EleutherAI/pythia-6.9b`.
39
  * Fine-tune LLaMA models with different prompt templates and training dataset formats.<br /><a href="https://youtu.be/IoEMgouZ5xU?t=60"><img width="640px" src="https://user-images.githubusercontent.com/3784687/231026640-b5cf5c79-9fe9-430b-8d4e-7346eb9567ad.gif" /></a>
40
  * Load JSON and JSONL datasets from your folder, or even paste plain text directly into the UI.
41
  * Supports Stanford Alpaca [seed_tasks](https://github.com/tatsu-lab/stanford_alpaca/blob/main/seed_tasks.jsonl), [alpaca_data](https://github.com/tatsu-lab/stanford_alpaca/blob/main/alpaca_data.json) and [OpenAI "prompt"-"completion"](https://platform.openai.com/docs/guides/fine-tuning/data-formatting) format.
 
87
  pip install wandb
88
  cd ..
89
  echo 'Dependencies installed.'
90
+ echo "Pre-downloading base models so that you won't have to wait long once the app is ready..."
91
+ python llama_lora_tuner/download_base_model.py --base_model_names='decapoda-research/llama-7b-hf,nomic-ai/gpt4all-j,databricks/dolly-v2-7b'
92
 
93
  # Start the app.
94
  run: |
95
  echo 'Starting...'
96
+ python llama_lora_tuner/app.py --data_dir='/data' --wandb_api_key="$([ -f /data/secrets/wandb_api_key ] && cat /data/secrets/wandb_api_key | tr -d '\n')" --base_model=decapoda-research/llama-7b-hf --base_model_choices='decapoda-research/llama-7b-hf,nomic-ai/gpt4all-j,databricks/dolly-v2-7b' --share
97
  ```
98
 
99
  Then launch a cluster to run the task:
app.py CHANGED
@@ -15,6 +15,7 @@ def main(
15
  base_model: str = "",
16
  data_dir: str = "",
17
  base_model_choices: str = "",
 
18
  # Allows to listen on all interfaces by providing '0.0.0.0'.
19
  server_name: str = "127.0.0.1",
20
  share: bool = False,
@@ -60,6 +61,8 @@ def main(
60
  if base_model not in Global.base_model_choices:
61
  Global.base_model_choices = [base_model] + Global.base_model_choices
62
 
 
 
63
  Global.data_dir = os.path.abspath(data_dir)
64
  Global.load_8bit = load_8bit
65
 
 
15
  base_model: str = "",
16
  data_dir: str = "",
17
  base_model_choices: str = "",
18
+ trust_remote_code: bool = False,
19
  # Allows to listen on all interfaces by providing '0.0.0.0'.
20
  server_name: str = "127.0.0.1",
21
  share: bool = False,
 
61
  if base_model not in Global.base_model_choices:
62
  Global.base_model_choices = [base_model] + Global.base_model_choices
63
 
64
+ Global.trust_remote_code = trust_remote_code
65
+
66
  Global.data_dir = os.path.abspath(data_dir)
67
  Global.load_8bit = load_8bit
68
 
llama_lora/globals.py CHANGED
@@ -20,6 +20,8 @@ class Global:
20
  base_model_name: str = ""
21
  base_model_choices: List[str] = []
22
 
 
 
23
  # Functions
24
  train_fn: Any = train
25
 
 
20
  base_model_name: str = ""
21
  base_model_choices: List[str] = []
22
 
23
+ trust_remote_code = False
24
+
25
  # Functions
26
  train_fn: Any = train
27
 
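Taken together, the `app.py` and `globals.py` changes thread a single `trust_remote_code` setting from the command line into every `from_pretrained` call in `llama_lora/models.py` (shown further below). A hedged sketch of the intended flow; the wiring is inferred from this diff rather than spelled out in it:

```python
# Sketch only: how the new flag is expected to propagate at runtime.
from llama_lora.globals import Global
from llama_lora.models import get_new_base_model

# app.py sets this from its trust_remote_code keyword argument (CLI flag)
# before the UI starts; it defaults to False.
Global.trust_remote_code = True

# Model and tokenizer loading then pass trust_remote_code=Global.trust_remote_code
# through to transformers' from_pretrained calls.
model = get_new_base_model("nomic-ai/gpt4all-j")  # example model name taken from the README
```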
llama_lora/lib/inference.py CHANGED
@@ -66,14 +66,14 @@ def generate(
66
  with generate_with_streaming(**generate_params) as generator:
67
  for output in generator:
68
  decoded_output = tokenizer.decode(output, skip_special_tokens=skip_special_tokens)
69
- yield decoded_output, output
70
  if output[-1] in [tokenizer.eos_token_id]:
71
  break
72
 
73
  if generation_output:
74
  output = generation_output.sequences[0]
75
  decoded_output = tokenizer.decode(output, skip_special_tokens=skip_special_tokens)
76
- yield decoded_output, output
77
 
78
  return # early return for stream_output
79
 
@@ -82,5 +82,5 @@ def generate(
82
  generation_output = model.generate(**generate_params)
83
  output = generation_output.sequences[0]
84
  decoded_output = tokenizer.decode(output, skip_special_tokens=skip_special_tokens)
85
- yield decoded_output, output
86
  return
 
66
  with generate_with_streaming(**generate_params) as generator:
67
  for output in generator:
68
  decoded_output = tokenizer.decode(output, skip_special_tokens=skip_special_tokens)
69
+ yield decoded_output, output, False
70
  if output[-1] in [tokenizer.eos_token_id]:
71
  break
72
 
73
  if generation_output:
74
  output = generation_output.sequences[0]
75
  decoded_output = tokenizer.decode(output, skip_special_tokens=skip_special_tokens)
76
+ yield decoded_output, output, True
77
 
78
  return # early return for stream_output
79
 
 
82
  generation_output = model.generate(**generate_params)
83
  output = generation_output.sequences[0]
84
  decoded_output = tokenizer.decode(output, skip_special_tokens=skip_special_tokens)
85
+ yield decoded_output, output, True
86
  return
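
The net effect of this change is that `generate` now yields a third value, a `completed` flag that is `False` for streamed partial outputs and `True` for the final one. A hedged sketch of how a caller consumes the updated generator (the handler names are placeholders, not functions from this repo):

```python
# Sketch only: consuming llama_lora.lib.inference.generate after this change.
for decoded_output, output, completed in generate(**generation_args):
    if completed:
        handle_final(decoded_output, output)  # placeholder: final render / logging
    else:
        handle_partial(decoded_output)        # placeholder: streaming UI update
```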
llama_lora/models.py CHANGED
@@ -5,7 +5,10 @@ import json
5
  import re
6
 
7
  import torch
8
- from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer
 
 
 
9
  from peft import PeftModel
10
 
11
  from .globals import Global
@@ -27,37 +30,83 @@ def get_new_base_model(base_model_name):
27
  Global.name_of_new_base_model_that_is_ready_to_be_used = None
28
  clear_cache()
29
30
  device = get_device()
31
 
32
  if device == "cuda":
33
- model = AutoModelForCausalLM.from_pretrained(
34
- base_model_name,
35
  load_in_8bit=Global.load_8bit,
36
  torch_dtype=torch.float16,
37
  # device_map="auto",
38
  # ? https://github.com/tloen/alpaca-lora/issues/21
39
  device_map={'': 0},
 
 
 
40
  )
41
  elif device == "mps":
42
- model = AutoModelForCausalLM.from_pretrained(
43
- base_model_name,
44
  device_map={"": device},
45
  torch_dtype=torch.float16,
 
 
 
46
  )
47
  else:
48
- model = AutoModelForCausalLM.from_pretrained(
49
- base_model_name, device_map={"": device}, low_cpu_mem_usage=True
 
 
 
 
 
50
  )
51
 
52
- tokenizer = get_tokenizer(base_model_name)
53
-
54
- if re.match("[^/]+/llama", base_model_name):
55
- model.config.pad_token_id = tokenizer.pad_token_id = 0
56
- model.config.bos_token_id = tokenizer.bos_token_id = 1
57
- model.config.eos_token_id = tokenizer.eos_token_id = 2
58
-
59
- return model
60
-
61
 
62
  def get_tokenizer(base_model_name):
63
  if Global.ui_dev_mode:
@@ -68,10 +117,16 @@ def get_tokenizer(base_model_name):
68
  return loaded_tokenizer
69
 
70
  try:
71
- tokenizer = AutoTokenizer.from_pretrained(base_model_name)
 
 
 
72
  except Exception as e:
73
  if 'LLaMATokenizer' in str(e):
74
- tokenizer = LlamaTokenizer.from_pretrained(base_model_name)
 
 
 
75
  else:
76
  raise e
77
 
@@ -100,13 +155,15 @@ def get_model(
100
  peft_model_name_or_path = peft_model_name
101
 
102
  if peft_model_name:
103
- lora_models_directory_path = os.path.join(Global.data_dir, "lora_models")
 
104
  possible_lora_model_path = os.path.join(
105
  lora_models_directory_path, peft_model_name)
106
  if os.path.isdir(possible_lora_model_path):
107
  peft_model_name_or_path = possible_lora_model_path
108
 
109
- possible_model_info_json_path = os.path.join(possible_lora_model_path, "info.json")
 
110
  if os.path.isfile(possible_model_info_json_path):
111
  try:
112
  with open(possible_model_info_json_path, "r") as file:
@@ -115,7 +172,8 @@ def get_model(
115
  if possible_hf_model_name and json_data.get("load_from_hf"):
116
  peft_model_name_or_path = possible_hf_model_name
117
  except Exception as e:
118
- raise ValueError("Error reading model info from {possible_model_info_json_path}: {e}")
 
119
 
120
  Global.loaded_models.prepare_to_set()
121
  clear_cache()
@@ -148,7 +206,8 @@ def get_model(
148
  )
149
 
150
  if re.match("[^/]+/llama", base_model_name):
151
- model.config.pad_token_id = get_tokenizer(base_model_name).pad_token_id = 0
 
152
  model.config.bos_token_id = 1
153
  model.config.eos_token_id = 2
154
 
@@ -166,7 +225,8 @@ def get_model(
166
 
167
 
168
  def prepare_base_model(base_model_name=Global.default_base_model_name):
169
- Global.new_base_model_that_is_ready_to_be_used = get_new_base_model(base_model_name)
 
170
  Global.name_of_new_base_model_that_is_ready_to_be_used = base_model_name
171
 
172
 
 
5
  import re
6
 
7
  import torch
8
+ from transformers import (
9
+ AutoModelForCausalLM, AutoModel,
10
+ AutoTokenizer, LlamaTokenizer
11
+ )
12
  from peft import PeftModel
13
 
14
  from .globals import Global
 
30
  Global.name_of_new_base_model_that_is_ready_to_be_used = None
31
  clear_cache()
32
 
33
+ model_class = AutoModelForCausalLM
34
+ from_tf = False
35
+ force_download = False
36
+ has_tried_force_download = False
37
+ while True:
38
+ try:
39
+ model = _get_model_from_pretrained(
40
+ model_class, base_model_name, from_tf=from_tf, force_download=force_download)
41
+ break
42
+ except Exception as e:
43
+ if 'from_tf' in str(e):
44
+ print(
45
+ f"Got error while loading model {base_model_name} with AutoModelForCausalLM: {e}.")
46
+ print("Retrying with from_tf=True...")
47
+ from_tf = True
48
+ force_download = False
49
+ elif model_class == AutoModelForCausalLM:
50
+ print(
51
+ f"Got error while loading model {base_model_name} with AutoModelForCausalLM: {e}.")
52
+ print("Retrying with AutoModel...")
53
+ model_class = AutoModel
54
+ force_download = False
55
+ else:
56
+ if has_tried_force_download:
57
+ raise e
58
+ print(
59
+ f"Got error while loading model {base_model_name}: {e}.")
60
+ print("Retrying with force_download=True...")
61
+ model_class = AutoModelForCausalLM
62
+ from_tf = False
63
+ force_download = True
64
+ has_tried_force_download = True
65
+
66
+ tokenizer = get_tokenizer(base_model_name)
67
+
68
+ if re.match("[^/]+/llama", base_model_name):
69
+ model.config.pad_token_id = tokenizer.pad_token_id = 0
70
+ model.config.bos_token_id = tokenizer.bos_token_id = 1
71
+ model.config.eos_token_id = tokenizer.eos_token_id = 2
72
+
73
+ return model
74
+
75
+
76
+ def _get_model_from_pretrained(model_class, model_name, from_tf=False, force_download=False):
77
  device = get_device()
78
 
79
  if device == "cuda":
80
+ return model_class.from_pretrained(
81
+ model_name,
82
  load_in_8bit=Global.load_8bit,
83
  torch_dtype=torch.float16,
84
  # device_map="auto",
85
  # ? https://github.com/tloen/alpaca-lora/issues/21
86
  device_map={'': 0},
87
+ from_tf=from_tf,
88
+ force_download=force_download,
89
+ trust_remote_code=Global.trust_remote_code
90
  )
91
  elif device == "mps":
92
+ return model_class.from_pretrained(
93
+ model_name,
94
  device_map={"": device},
95
  torch_dtype=torch.float16,
96
+ from_tf=from_tf,
97
+ force_download=force_download,
98
+ trust_remote_code=Global.trust_remote_code
99
  )
100
  else:
101
+ return model_class.from_pretrained(
102
+ model_name,
103
+ device_map={"": device},
104
+ low_cpu_mem_usage=True,
105
+ from_tf=from_tf,
106
+ force_download=force_download,
107
+ trust_remote_code=Global.trust_remote_code
108
  )
109
 
 
 
 
 
 
 
 
 
 
110
 
111
  def get_tokenizer(base_model_name):
112
  if Global.ui_dev_mode:
 
117
  return loaded_tokenizer
118
 
119
  try:
120
+ tokenizer = AutoTokenizer.from_pretrained(
121
+ base_model_name,
122
+ trust_remote_code=Global.trust_remote_code
123
+ )
124
  except Exception as e:
125
  if 'LLaMATokenizer' in str(e):
126
+ tokenizer = LlamaTokenizer.from_pretrained(
127
+ base_model_name,
128
+ trust_remote_code=Global.trust_remote_code
129
+ )
130
  else:
131
  raise e
132
 
 
155
  peft_model_name_or_path = peft_model_name
156
 
157
  if peft_model_name:
158
+ lora_models_directory_path = os.path.join(
159
+ Global.data_dir, "lora_models")
160
  possible_lora_model_path = os.path.join(
161
  lora_models_directory_path, peft_model_name)
162
  if os.path.isdir(possible_lora_model_path):
163
  peft_model_name_or_path = possible_lora_model_path
164
 
165
+ possible_model_info_json_path = os.path.join(
166
+ possible_lora_model_path, "info.json")
167
  if os.path.isfile(possible_model_info_json_path):
168
  try:
169
  with open(possible_model_info_json_path, "r") as file:
 
172
  if possible_hf_model_name and json_data.get("load_from_hf"):
173
  peft_model_name_or_path = possible_hf_model_name
174
  except Exception as e:
175
+ raise ValueError(
176
+ "Error reading model info from {possible_model_info_json_path}: {e}")
177
 
178
  Global.loaded_models.prepare_to_set()
179
  clear_cache()
 
206
  )
207
 
208
  if re.match("[^/]+/llama", base_model_name):
209
+ model.config.pad_token_id = get_tokenizer(
210
+ base_model_name).pad_token_id = 0
211
  model.config.bos_token_id = 1
212
  model.config.eos_token_id = 2
213
 
 
225
 
226
 
227
  def prepare_base_model(base_model_name=Global.default_base_model_name):
228
+ Global.new_base_model_that_is_ready_to_be_used = get_new_base_model(
229
+ base_model_name)
230
  Global.name_of_new_base_model_that_is_ready_to_be_used = base_model_name
231
 
232
 
llama_lora/ui/inference_ui.py CHANGED
@@ -1,4 +1,5 @@
1
  import gradio as gr
 
2
  import time
3
  import json
4
 
@@ -21,13 +22,21 @@ default_show_raw = True
21
  inference_output_lines = 12
22
 
23
 
 
 
 
 
 
 
 
 
24
  def prepare_inference(lora_model_name, progress=gr.Progress(track_tqdm=True)):
25
  base_model_name = Global.base_model_name
26
 
27
  try:
28
  get_tokenizer(base_model_name)
29
  get_model(base_model_name, lora_model_name)
30
- return ("", "")
31
 
32
  except Exception as e:
33
  raise gr.Error(e)
@@ -65,6 +74,31 @@ def do_inference(
65
  prompter = Prompter(prompt_template)
66
  prompt = prompter.generate_prompt(variables)
67
68
  if Global.ui_dev_mode:
69
  message = f"Hi, I’m currently in UI-development mode and do not have access to resources to process your request. However, this behavior is similar to what will actually happen, so you can try and see how it will work!\n\nBase model: {base_model_name}\nLoRA model: {lora_model_name}\n\nThe following is your prompt:\n\n{prompt}"
70
  print(message)
@@ -83,35 +117,50 @@ def do_inference(
83
  out += "\n"
84
  yield out
85
 
 
86
  for partial_sentence in word_generator(message):
 
87
  yield (
88
  gr.Textbox.update(
89
- value=partial_sentence, lines=inference_output_lines),
 
90
  json.dumps(
91
- list(range(len(partial_sentence.split()))), indent=2)
 
 
 
 
 
92
  )
93
  time.sleep(0.05)
94
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
  return
96
  time.sleep(1)
97
  yield (
98
  gr.Textbox.update(value=message, lines=inference_output_lines),
99
- json.dumps(list(range(len(message.split()))), indent=2)
 
 
 
100
  )
101
  return
102
 
103
  tokenizer = get_tokenizer(base_model_name)
104
  model = get_model(base_model_name, lora_model_name)
105
 
106
- generation_config = GenerationConfig(
107
- temperature=float(temperature), # to avoid ValueError('`temperature` has to be a strictly positive float, but is 2')
108
- top_p=top_p,
109
- top_k=top_k,
110
- repetition_penalty=repetition_penalty,
111
- num_beams=num_beams,
112
- do_sample=temperature > 0, # https://github.com/huggingface/transformers/issues/22405#issuecomment-1485527953
113
- )
114
-
115
  def ui_generation_stopping_criteria(input_ids, score, **kwargs):
116
  if Global.should_stop_generating:
117
  return True
@@ -129,10 +178,8 @@ def do_inference(
129
  'stream_output': stream_output
130
  }
131
 
132
- for (decoded_output, output) in generate(**generation_args):
133
- raw_output_str = None
134
- if show_raw:
135
- raw_output_str = str(output)
136
  response = prompter.get_response(decoded_output)
137
 
138
  if Global.should_stop_generating:
@@ -141,7 +188,12 @@ def do_inference(
141
  yield (
142
  gr.Textbox.update(
143
  value=response, lines=inference_output_lines),
144
- raw_output_str)
 
 
 
 
 
145
 
146
  if Global.should_stop_generating:
147
  # If the user stops the generation, and then clicks the
@@ -199,11 +251,13 @@ def get_warning_message_for_lora_model_and_prompt_template(lora_model, prompt_te
199
  if lora_mode_info and isinstance(lora_mode_info, dict):
200
  model_base_model = lora_mode_info.get("base_model")
201
  if model_base_model and model_base_model != Global.base_model_name:
202
- messages.append(f"⚠️ This model was trained on top of base model `{model_base_model}`, it might not work properly with the selected base model `{Global.base_model_name}`.")
 
203
 
204
  model_prompt_template = lora_mode_info.get("prompt_template")
205
  if model_prompt_template and model_prompt_template != prompt_template:
206
- messages.append(f"This model was trained with prompt template `{model_prompt_template}`.")
 
207
 
208
  return " ".join(messages)
209
 
@@ -221,7 +275,8 @@ def handle_prompt_template_change(prompt_template, lora_model):
221
 
222
  model_prompt_template_message_update = gr.Markdown.update(
223
  "", visible=False)
224
- warning_message = get_warning_message_for_lora_model_and_prompt_template(lora_model, prompt_template)
 
225
  if warning_message:
226
  model_prompt_template_message_update = gr.Markdown.update(
227
  warning_message, visible=True)
@@ -241,7 +296,8 @@ def handle_lora_model_change(lora_model, prompt_template):
241
 
242
  model_prompt_template_message_update = gr.Markdown.update(
243
  "", visible=False)
244
- warning_message = get_warning_message_for_lora_model_and_prompt_template(lora_model, prompt_template)
 
245
  if warning_message:
246
  model_prompt_template_message_update = gr.Markdown.update(
247
  warning_message, visible=True)
@@ -260,6 +316,56 @@ def update_prompt_preview(prompt_template,
260
 
261
 
262
  def inference_ui():
263
  things_that_might_timeout = []
264
 
265
  with gr.Blocks() as inference_ui_blocks:
@@ -387,6 +493,47 @@ def inference_ui():
387
  inference_output = gr.Textbox(
388
  lines=inference_output_lines, label="Output", elem_id="inference_output")
389
  inference_output.style(show_copy_button=True)
390
  with gr.Accordion(
391
  "Raw Output",
392
  open=not default_show_raw,
@@ -400,7 +547,8 @@ def inference_ui():
400
  interactive=False,
401
  elem_id="inference_raw_output")
402
 
403
- reload_selected_models_btn = gr.Button("", elem_id="inference_reload_selected_models_btn")
 
404
 
405
  show_raw_change_event = show_raw.change(
406
  fn=lambda show_raw: gr.Accordion.update(visible=show_raw),
@@ -440,7 +588,8 @@ def inference_ui():
440
  generate_event = generate_btn.click(
441
  fn=prepare_inference,
442
  inputs=[lora_model],
443
- outputs=[inference_output, inference_raw_output],
 
444
  ).then(
445
  fn=do_inference,
446
  inputs=[
@@ -457,7 +606,8 @@ def inference_ui():
457
  stream_output,
458
  show_raw,
459
  ],
460
- outputs=[inference_output, inference_raw_output],
 
461
  api_name="inference"
462
  )
463
  stop_btn.click(
 
1
  import gradio as gr
2
+ import os
3
  import time
4
  import json
5
 
 
22
  inference_output_lines = 12
23
 
24
 
25
+ class LoggingItem:
26
+ def __init__(self, label):
27
+ self.label = label
28
+
29
+ def deserialize(self, value, **kwargs):
30
+ return value
31
+
32
+
33
  def prepare_inference(lora_model_name, progress=gr.Progress(track_tqdm=True)):
34
  base_model_name = Global.base_model_name
35
 
36
  try:
37
  get_tokenizer(base_model_name)
38
  get_model(base_model_name, lora_model_name)
39
+ return ("", "", gr.Textbox.update(visible=False))
40
 
41
  except Exception as e:
42
  raise gr.Error(e)
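
A note on the `LoggingItem` shim added above: `gr.CSVLogger` normally works with Gradio components, reading each component's `label` for the CSV header and calling `deserialize()` on flagged values; `LoggingItem` appears to mimic just that interface so plain strings can be logged without creating real components. A small hedged sketch of the idea:

```python
# Sketch only: LoggingItem stands in for a Gradio component in CSVLogger calls.
item = LoggingItem("Base Model")
print(item.label)                        # used as the CSV column header
print(item.deserialize("llama-7b-hf"))   # values are passed through unchanged
```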
 
74
  prompter = Prompter(prompt_template)
75
  prompt = prompter.generate_prompt(variables)
76
 
77
+ generation_config = GenerationConfig(
78
+ # to avoid ValueError('`temperature` has to be a strictly positive float, but is 2')
79
+ temperature=float(temperature),
80
+ top_p=top_p,
81
+ top_k=top_k,
82
+ repetition_penalty=repetition_penalty,
83
+ num_beams=num_beams,
84
+ # https://github.com/huggingface/transformers/issues/22405#issuecomment-1485527953
85
+ do_sample=temperature > 0,
86
+ )
87
+
88
+ def get_output_for_flagging(output, raw_output, completed=True):
89
+ return json.dumps({
90
+ 'base_model': base_model_name,
91
+ 'adaptor_model': lora_model_name,
92
+ 'prompt': prompt,
93
+ 'output': output,
94
+ 'completed': completed,
95
+ 'raw_output': raw_output,
96
+ 'max_new_tokens': max_new_tokens,
97
+ 'prompt_template': prompt_template,
98
+ 'prompt_template_variables': variables,
99
+ 'generation_config': generation_config.to_dict(),
100
+ })
101
+
102
  if Global.ui_dev_mode:
103
  message = f"Hi, I’m currently in UI-development mode and do not have access to resources to process your request. However, this behavior is similar to what will actually happen, so you can try and see how it will work!\n\nBase model: {base_model_name}\nLoRA model: {lora_model_name}\n\nThe following is your prompt:\n\n{prompt}"
104
  print(message)
 
117
  out += "\n"
118
  yield out
119
 
120
+ output = ""
121
  for partial_sentence in word_generator(message):
122
+ output = partial_sentence
123
  yield (
124
  gr.Textbox.update(
125
+ value=output,
126
+ lines=inference_output_lines),
127
  json.dumps(
128
+ list(range(len(output.split()))),
129
+ indent=2),
130
+ gr.Textbox.update(
131
+ value=get_output_for_flagging(
132
+ output, "", completed=False),
133
+ visible=True)
134
  )
135
  time.sleep(0.05)
136
 
137
+ yield (
138
+ gr.Textbox.update(
139
+ value=output,
140
+ lines=inference_output_lines),
141
+ json.dumps(
142
+ list(range(len(output.split()))),
143
+ indent=2),
144
+ gr.Textbox.update(
145
+ value=get_output_for_flagging(
146
+ output, "", completed=True),
147
+ visible=True)
148
+ )
149
+
150
  return
151
  time.sleep(1)
152
  yield (
153
  gr.Textbox.update(value=message, lines=inference_output_lines),
154
+ json.dumps(list(range(len(message.split()))), indent=2),
155
+ gr.Textbox.update(
156
+ value=get_output_for_flagging(message, ""),
157
+ visible=True)
158
  )
159
  return
160
 
161
  tokenizer = get_tokenizer(base_model_name)
162
  model = get_model(base_model_name, lora_model_name)
163
 
 
 
 
 
 
 
 
 
 
164
  def ui_generation_stopping_criteria(input_ids, score, **kwargs):
165
  if Global.should_stop_generating:
166
  return True
 
178
  'stream_output': stream_output
179
  }
180
 
181
+ for (decoded_output, output, completed) in generate(**generation_args):
182
+ raw_output_str = str(output)
 
 
183
  response = prompter.get_response(decoded_output)
184
 
185
  if Global.should_stop_generating:
 
188
  yield (
189
  gr.Textbox.update(
190
  value=response, lines=inference_output_lines),
191
+ raw_output_str,
192
+ gr.Textbox.update(
193
+ value=get_output_for_flagging(
194
+ decoded_output, raw_output_str, completed=completed),
195
+ visible=True)
196
+ )
197
 
198
  if Global.should_stop_generating:
199
  # If the user stops the generation, and then clicks the
 
251
  if lora_mode_info and isinstance(lora_mode_info, dict):
252
  model_base_model = lora_mode_info.get("base_model")
253
  if model_base_model and model_base_model != Global.base_model_name:
254
+ messages.append(
255
+ f"⚠️ This model was trained on top of base model `{model_base_model}`, it might not work properly with the selected base model `{Global.base_model_name}`.")
256
 
257
  model_prompt_template = lora_mode_info.get("prompt_template")
258
  if model_prompt_template and model_prompt_template != prompt_template:
259
+ messages.append(
260
+ f"This model was trained with prompt template `{model_prompt_template}`.")
261
 
262
  return " ".join(messages)
263
 
 
275
 
276
  model_prompt_template_message_update = gr.Markdown.update(
277
  "", visible=False)
278
+ warning_message = get_warning_message_for_lora_model_and_prompt_template(
279
+ lora_model, prompt_template)
280
  if warning_message:
281
  model_prompt_template_message_update = gr.Markdown.update(
282
  warning_message, visible=True)
 
296
 
297
  model_prompt_template_message_update = gr.Markdown.update(
298
  "", visible=False)
299
+ warning_message = get_warning_message_for_lora_model_and_prompt_template(
300
+ lora_model, prompt_template)
301
  if warning_message:
302
  model_prompt_template_message_update = gr.Markdown.update(
303
  warning_message, visible=True)
 
316
 
317
 
318
  def inference_ui():
319
+ flagging_dir = os.path.join(Global.data_dir, "flagging", "inference")
320
+ if not os.path.exists(flagging_dir):
321
+ os.makedirs(flagging_dir)
322
+
323
+ flag_callback = gr.CSVLogger()
324
+ flag_components = [
325
+ LoggingItem("Base Model"),
326
+ LoggingItem("Adaptor Model"),
327
+ LoggingItem("Type"),
328
+ LoggingItem("Prompt"),
329
+ LoggingItem("Output"),
330
+ LoggingItem("Completed"),
331
+ LoggingItem("Config"),
332
+ LoggingItem("Raw Output"),
333
+ LoggingItem("Max New Tokens"),
334
+ LoggingItem("Prompt Template"),
335
+ LoggingItem("Prompt Template Variables"),
336
+ LoggingItem("Generation Config"),
337
+ ]
338
+ flag_callback.setup(flag_components, flagging_dir)
339
+
340
+ def get_flag_callback_args(output_for_flagging_str, flag_type):
341
+ output_for_flagging = json.loads(output_for_flagging_str)
342
+ generation_config = output_for_flagging.get("generation_config", {})
343
+ config = []
344
+ if generation_config.get('do_sample', False):
345
+ config.append(
346
+ f"Temperature: {generation_config.get('temperature')}")
347
+ config.append(f"Top P: {generation_config.get('top_p')}")
348
+ config.append(f"Top K: {generation_config.get('top_k')}")
349
+ num_beams = generation_config.get('num_beams', 1)
350
+ if num_beams > 1:
351
+ config.append(f"Beams: {generation_config.get('num_beams')}")
352
+ config.append(f"RP: {generation_config.get('repetition_penalty')}")
353
+ return [
354
+ output_for_flagging.get("base_model", ""),
355
+ output_for_flagging.get("adaptor_model", ""),
356
+ flag_type,
357
+ output_for_flagging.get("prompt", ""),
358
+ output_for_flagging.get("output", ""),
359
+ str(output_for_flagging.get("completed", "")),
360
+ ", ".join(config),
361
+ output_for_flagging.get("raw_output", ""),
362
+ str(output_for_flagging.get("max_new_tokens", "")),
363
+ output_for_flagging.get("prompt_template", ""),
364
+ json.dumps(output_for_flagging.get(
365
+ "prompt_template_variables", "")),
366
+ json.dumps(output_for_flagging.get("generation_config", "")),
367
+ ]
368
+
369
  things_that_might_timeout = []
370
 
371
  with gr.Blocks() as inference_ui_blocks:
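
Since `gr.CSVLogger` appends each flagged record to a CSV file under the flagging directory set up above, the collected feedback can be inspected offline. A hedged sketch, assuming Gradio's default `log.csv` file name and the column labels defined in `flag_components`:

```python
# Sketch only: reading back records flagged from the inference UI.
import csv
import os

flagging_dir = os.path.join("data", "flagging", "inference")  # assumes Global.data_dir is ./data
with open(os.path.join(flagging_dir, "log.csv"), newline="") as f:
    for row in csv.DictReader(f):
        print(row["Type"], row["Base Model"], row["Output"][:80])
```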
 
493
  inference_output = gr.Textbox(
494
  lines=inference_output_lines, label="Output", elem_id="inference_output")
495
  inference_output.style(show_copy_button=True)
496
+
497
+ with gr.Row(elem_id="inference_flagging_group"):
498
+ output_for_flagging = gr.Textbox(
499
+ interactive=False, visible=False,
500
+ elem_id="inference_output_for_flagging")
501
+ flag_btn = gr.Button(
502
+ "Flag", elem_id="inference_flag_btn")
503
+ flag_up_btn = gr.Button(
504
+ "πŸ‘", elem_id="inference_flag_up_btn")
505
+ flag_down_btn = gr.Button(
506
+ "πŸ‘Ž", elem_id="inference_flag_down_btn")
507
+ flag_output = gr.Markdown(
508
+ "", elem_id="inference_flag_output")
509
+ flag_btn.click(
510
+ lambda d: (flag_callback.flag(
511
+ get_flag_callback_args(d, "Flag"),
512
+ flag_option="Flag",
513
+ username=None
514
+ ), "")[1],
515
+ inputs=[output_for_flagging],
516
+ outputs=[flag_output],
517
+ preprocess=False)
518
+ flag_up_btn.click(
519
+ lambda d: (flag_callback.flag(
520
+ get_flag_callback_args(d, "πŸ‘"),
521
+ flag_option="Up Vote",
522
+ username=None
523
+ ), "")[1],
524
+ inputs=[output_for_flagging],
525
+ outputs=[flag_output],
526
+ preprocess=False)
527
+ flag_down_btn.click(
528
+ lambda d: (flag_callback.flag(
529
+ get_flag_callback_args(d, "πŸ‘Ž"),
530
+ flag_option="Down Vote",
531
+ username=None
532
+ ), "")[1],
533
+ inputs=[output_for_flagging],
534
+ outputs=[flag_output],
535
+ preprocess=False)
536
+
537
  with gr.Accordion(
538
  "Raw Output",
539
  open=not default_show_raw,
 
547
  interactive=False,
548
  elem_id="inference_raw_output")
549
 
550
+ reload_selected_models_btn = gr.Button(
551
+ "", elem_id="inference_reload_selected_models_btn")
552
 
553
  show_raw_change_event = show_raw.change(
554
  fn=lambda show_raw: gr.Accordion.update(visible=show_raw),
 
588
  generate_event = generate_btn.click(
589
  fn=prepare_inference,
590
  inputs=[lora_model],
591
+ outputs=[inference_output,
592
+ inference_raw_output, output_for_flagging],
593
  ).then(
594
  fn=do_inference,
595
  inputs=[
 
606
  stream_output,
607
  show_raw,
608
  ],
609
+ outputs=[inference_output,
610
+ inference_raw_output, output_for_flagging],
611
  api_name="inference"
612
  )
613
  stop_btn.click(
llama_lora/ui/main_page.py CHANGED
@@ -398,6 +398,45 @@ def main_page_custom_css():
398
  bottom: 16px;
399
  }
400
401
  #dataset_plain_text_input_variables_separator textarea,
402
  #dataset_plain_text_input_and_output_separator textarea,
403
  #dataset_plain_text_data_separator textarea {
 
398
  bottom: 16px;
399
  }
400
 
401
+ #inference_flagging_group {
402
+ position: relative;
403
+ }
404
+ #inference_flag_output {
405
+ min-height: 1px !important;
406
+ position: absolute;
407
+ top: 0;
408
+ bottom: 0;
409
+ right: 0;
410
+ pointer-events: none;
411
+ opacity: 0.5;
412
+ }
413
+ #inference_flag_output .wrap {
414
+ top: 0;
415
+ bottom: 0;
416
+ right: 0;
417
+ justify-content: center;
418
+ align-items: flex-end;
419
+ padding: 4px !important;
420
+ }
421
+ #inference_flag_output .wrap svg {
422
+ display: none;
423
+ }
424
+ .form:has(> #inference_output_for_flagging),
425
+ #inference_output_for_flagging {
426
+ display: none;
427
+ }
428
+ #inference_flagging_group:has(#inference_output_for_flagging.hidden) {
429
+ opacity: 0.5;
430
+ pointer-events: none;
431
+ }
432
+ #inference_flag_up_btn, #inference_flag_down_btn {
433
+ min-width: 44px;
434
+ flex-grow: 1;
435
+ }
436
+ #inference_flag_btn {
437
+ flex-grow: 2;
438
+ }
439
+
440
  #dataset_plain_text_input_variables_separator textarea,
441
  #dataset_plain_text_input_and_output_separator textarea,
442
  #dataset_plain_text_data_separator textarea {