Merge branch 'main' into hf-ui-demo
Files changed:
- LLaMA_LoRA.ipynb +16 -19
- README.md +4 -1
- app.py +3 -0
- llama_lora/globals.py +2 -0
- llama_lora/lib/inference.py +3 -3
- llama_lora/models.py +83 -23
- llama_lora/ui/inference_ui.py +175 -25
- llama_lora/ui/main_page.py +39 -0
LLaMA_LoRA.ipynb
CHANGED
@@ -60,20 +60,15 @@
 "# @title A small workaround { display-mode: \"form\" }\n",
 "# @markdown Don't panic if you see an error here. Just click the `RESTART RUNTIME` button in the output below, then Run All again.\n",
 "# @markdown The error will disappear on the next run.\n",
-"
+"%pip install Pillow==9.3.0 numpy==1.23.5\n",
 "\n",
+"import pkg_resources as r\n",
 "import PIL\n",
-"major, minor = map(float, PIL.__version__.split(\".\")[:2])\n",
-"version_float = major + minor / 10**len(str(minor))\n",
-"print('PIL', version_float)\n",
-"if version_float < 9.003:\n",
-" raise Exception(\"Restart the runtime by clicking the 'RESTART RUNTIME' button above (or Runtime > Restart Runtime).\")\n",
-"\n",
 "import numpy\n",
-"
-"
-"print(
-"if
+"for module, min_version in [(PIL, \"9.3\"), (numpy, \"1.23\")]:\n",
+" lib_version = r.parse_version(module.__version__)\n",
+" print(module.__name__, lib_version)\n",
+" if lib_version < r.parse_version(min_version):\n",
 " raise Exception(\"Restart the runtime by clicking the 'RESTART RUNTIME' button above (or Runtime > Restart Runtime).\")"
 ],
 "metadata": {
@@ -144,15 +139,17 @@
 "# colab_notebook_name = remove_ipynb_extension(colab_notebook_filename)\n",
 "\n",
 "from google.colab import drive\n",
-"
+"try:\n",
+" drive.mount(google_drive_mount_path)\n",
 "\n",
-"
-"
-"
-"!
-"!
-"!
-"
+" google_drive_data_directory_relative_path = google_drive_folder\n",
+" google_drive_data_directory_path = f\"{google_drive_mount_path}/My Drive/{google_drive_data_directory_relative_path}\"\n",
+" !mkdir -p \"{google_drive_data_directory_path}\"\n",
+" !ln -nsf \"{google_drive_data_directory_path}\" ./data\n",
+" !touch \"data/This folder is used by the Colab notebook \\\"{colab_notebook_filename}\\\".txt\"\n",
+" !echo \"Data will be stored in Google Drive folder: \\\"{google_drive_data_directory_relative_path}\\\", which is mounted under \\\"{google_drive_data_directory_path}\\\"\"\n",
+"except Exception as e:\n",
+" print(\"Drive won't be mounted!\")"
 ],
 "metadata": {
 "id": "iZmRtUY68U5f"
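
The reworked workaround cell boils down to a plain version check. Shown here as ordinary Python, outside the notebook's JSON string escaping (same logic as the added cell, with indentation chosen for readability):

    import pkg_resources as r
    import PIL
    import numpy

    # If the preinstalled Pillow/NumPy are older than the versions the app needs,
    # the Colab runtime is still holding the old copies and must be restarted.
    for module, min_version in [(PIL, "9.3"), (numpy, "1.23")]:
        lib_version = r.parse_version(module.__version__)
        print(module.__name__, lib_version)
        if lib_version < r.parse_version(min_version):
            raise Exception("Restart the runtime by clicking the 'RESTART RUNTIME' button above (or Runtime > Restart Runtime).")
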
README.md
CHANGED
@@ -35,6 +35,7 @@ Making evaluating and fine-tuning LLaMA models with low-rank adaptation (LoRA) e
 * **[1-click up and running in Google Colab](#run-on-google-colab)** with a standard GPU runtime.
 * Loads and stores data in Google Drive.
 * Evaluate various LLaMA LoRA models stored in your folder or from Hugging Face.<br /><a href="https://youtu.be/IoEMgouZ5xU"><img width="640px" src="https://user-images.githubusercontent.com/3784687/231023326-f28c84e2-df74-4179-b0ac-c25c4e8ca001.gif" /></a>
+* Switch between base models such as `decapoda-research/llama-7b-hf`, `nomic-ai/gpt4all-j`, `databricks/dolly-v2-7b`, `EleutherAI/gpt-j-6b`, or `EleutherAI/pythia-6.9b`.
 * Fine-tune LLaMA models with different prompt templates and training dataset format.<br /><a href="https://youtu.be/IoEMgouZ5xU?t=60"><img width="640px" src="https://user-images.githubusercontent.com/3784687/231026640-b5cf5c79-9fe9-430b-8d4e-7346eb9567ad.gif" /></a>
 * Load JSON and JSONL datasets from your folder, or even paste plain text directly into the UI.
 * Supports Stanford Alpaca [seed_tasks](https://github.com/tatsu-lab/stanford_alpaca/blob/main/seed_tasks.jsonl), [alpaca_data](https://github.com/tatsu-lab/stanford_alpaca/blob/main/alpaca_data.json) and [OpenAI "prompt"-"completion"](https://platform.openai.com/docs/guides/fine-tuning/data-formatting) format.
@@ -86,11 +87,13 @@ setup: |
   pip install wandb
   cd ..
   echo 'Dependencies installed.'
+  echo 'Pre-downloading base models so that you won't have to wait for long once the app is ready...'
+  python llama_lora_tuner/download_base_model.py --base_model_names='decapoda-research/llama-7b-hf,nomic-ai/gpt4all-j,databricks/dolly-v2-7b'
 
 # Start the app.
 run: |
   echo 'Starting...'
-  python llama_lora_tuner/app.py --data_dir='/data' --wandb_api_key
+  python llama_lora_tuner/app.py --data_dir='/data' --wandb_api_key="$([ -f /data/secrets/wandb_api_key ] && cat /data/secrets/wandb_api_key | tr -d '\n')" --base_model=decapoda-research/llama-7b-hf --base_model_choices='decapoda-research/llama-7b-hf,nomic-ai/gpt4all-j,databricks/dolly-v2-7b --share
 ```
 
 Then launch a cluster to run the task:
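
The new setup step pre-downloads the base models so the first launch does not block on multi-gigabyte downloads. The actual `download_base_model.py` is not shown in this diff; a minimal sketch of the idea, under the assumption that all it needs to do is warm the local Hugging Face cache, could be:

    from huggingface_hub import snapshot_download

    # Hypothetical stand-in for llama_lora_tuner/download_base_model.py:
    # fetching each repo once populates the local cache, so the app's later
    # from_pretrained() calls do not have to download anything.
    base_model_names = [
        "decapoda-research/llama-7b-hf",
        "nomic-ai/gpt4all-j",
        "databricks/dolly-v2-7b",
    ]

    for name in base_model_names:
        snapshot_download(repo_id=name)
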
app.py
CHANGED
@@ -15,6 +15,7 @@ def main(
     base_model: str = "",
     data_dir: str = "",
     base_model_choices: str = "",
+    trust_remote_code: bool = False,
     # Allows to listen on all interfaces by providing '0.0.0.0'.
     server_name: str = "127.0.0.1",
     share: bool = False,
@@ -60,6 +61,8 @@ def main(
     if base_model not in Global.base_model_choices:
         Global.base_model_choices = [base_model] + Global.base_model_choices
 
+    Global.trust_remote_code = trust_remote_code
+
     Global.data_dir = os.path.abspath(data_dir)
     Global.load_8bit = load_8bit
 
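
The new main() parameter is also what the README's `--trust_remote_code`-style flags map onto. A hedged sketch of exercising it programmatically, assuming you run from the repository root and using only the parameters visible in this diff:

    from app import main

    # Sketch only: start the UI while allowing Hugging Face repos that ship
    # custom modeling code to be loaded. The value is stored on
    # Global.trust_remote_code and forwarded into every from_pretrained() call.
    main(
        base_model="decapoda-research/llama-7b-hf",
        base_model_choices="decapoda-research/llama-7b-hf,nomic-ai/gpt4all-j",
        data_dir="./data",
        trust_remote_code=True,
    )
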
llama_lora/globals.py
CHANGED
@@ -20,6 +20,8 @@ class Global:
     base_model_name: str = ""
     base_model_choices: List[str] = []
 
+    trust_remote_code = False
+
     # Functions
     train_fn: Any = train
 
llama_lora/lib/inference.py
CHANGED
@@ -66,14 +66,14 @@ def generate(
         with generate_with_streaming(**generate_params) as generator:
             for output in generator:
                 decoded_output = tokenizer.decode(output, skip_special_tokens=skip_special_tokens)
-                yield decoded_output, output
+                yield decoded_output, output, False
                 if output[-1] in [tokenizer.eos_token_id]:
                     break
 
         if generation_output:
             output = generation_output.sequences[0]
             decoded_output = tokenizer.decode(output, skip_special_tokens=skip_special_tokens)
-            yield decoded_output, output
+            yield decoded_output, output, True
 
         return  # early return for stream_output
 
@@ -82,5 +82,5 @@ def generate(
     generation_output = model.generate(**generate_params)
     output = generation_output.sequences[0]
     decoded_output = tokenizer.decode(output, skip_special_tokens=skip_special_tokens)
-    yield decoded_output, output
+    yield decoded_output, output, True
     return
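
generate() now yields three-tuples, with the new third element marking whether the yielded text is the final (non-streaming) result. A minimal consumer sketch; generation_args and handle_final() are hypothetical stand-ins for the caller's own plumbing, as built in llama_lora/ui/inference_ui.py:

    from llama_lora.lib.inference import generate

    # Sketch: iterate the streaming generator and act on the new `completed` flag.
    for decoded_output, output, completed in generate(**generation_args):
        print(decoded_output)  # partial text while streaming
        if completed:
            # The final yield carries the full decoded text and raw token ids.
            handle_final(decoded_output, raw_token_ids=output)
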
llama_lora/models.py
CHANGED
@@ -5,7 +5,10 @@ import json
 import re
 
 import torch
-from transformers import
+from transformers import (
+    AutoModelForCausalLM, AutoModel,
+    AutoTokenizer, LlamaTokenizer
+)
 from peft import PeftModel
 
 from .globals import Global
@@ -27,37 +30,83 @@ def get_new_base_model(base_model_name):
     Global.name_of_new_base_model_that_is_ready_to_be_used = None
     clear_cache()
 
+    model_class = AutoModelForCausalLM
+    from_tf = False
+    force_download = False
+    has_tried_force_download = False
+    while True:
+        try:
+            model = _get_model_from_pretrained(
+                model_class, base_model_name, from_tf=from_tf, force_download=force_download)
+            break
+        except Exception as e:
+            if 'from_tf' in str(e):
+                print(
+                    f"Got error while loading model {base_model_name} with AutoModelForCausalLM: {e}.")
+                print("Retrying with from_tf=True...")
+                from_tf = True
+                force_download = False
+            elif model_class == AutoModelForCausalLM:
+                print(
+                    f"Got error while loading model {base_model_name} with AutoModelForCausalLM: {e}.")
+                print("Retrying with AutoModel...")
+                model_class = AutoModel
+                force_download = False
+            else:
+                if has_tried_force_download:
+                    raise e
+                print(
+                    f"Got error while loading model {base_model_name}: {e}.")
+                print("Retrying with force_download=True...")
+                model_class = AutoModelForCausalLM
+                from_tf = False
+                force_download = True
+                has_tried_force_download = True
+
+    tokenizer = get_tokenizer(base_model_name)
+
+    if re.match("[^/]+/llama", base_model_name):
+        model.config.pad_token_id = tokenizer.pad_token_id = 0
+        model.config.bos_token_id = tokenizer.bos_token_id = 1
+        model.config.eos_token_id = tokenizer.eos_token_id = 2
+
+    return model
+
+
+def _get_model_from_pretrained(model_class, model_name, from_tf=False, force_download=False):
     device = get_device()
 
     if device == "cuda":
-
-
+        return model_class.from_pretrained(
+            model_name,
             load_in_8bit=Global.load_8bit,
             torch_dtype=torch.float16,
             # device_map="auto",
             # ? https://github.com/tloen/alpaca-lora/issues/21
             device_map={'': 0},
+            from_tf=from_tf,
+            force_download=force_download,
+            trust_remote_code=Global.trust_remote_code
         )
     elif device == "mps":
-
-
+        return model_class.from_pretrained(
+            model_name,
             device_map={"": device},
             torch_dtype=torch.float16,
+            from_tf=from_tf,
+            force_download=force_download,
+            trust_remote_code=Global.trust_remote_code
        )
     else:
-
-
+        return model_class.from_pretrained(
+            model_name,
+            device_map={"": device},
+            low_cpu_mem_usage=True,
+            from_tf=from_tf,
+            force_download=force_download,
+            trust_remote_code=Global.trust_remote_code
        )
 
-    tokenizer = get_tokenizer(base_model_name)
-
-    if re.match("[^/]+/llama", base_model_name):
-        model.config.pad_token_id = tokenizer.pad_token_id = 0
-        model.config.bos_token_id = tokenizer.bos_token_id = 1
-        model.config.eos_token_id = tokenizer.eos_token_id = 2
-
-    return model
-
 
 def get_tokenizer(base_model_name):
     if Global.ui_dev_mode:
@@ -68,10 +117,16 @@ def get_tokenizer(base_model_name):
         return loaded_tokenizer
 
     try:
-        tokenizer = AutoTokenizer.from_pretrained(
+        tokenizer = AutoTokenizer.from_pretrained(
+            base_model_name,
+            trust_remote_code=Global.trust_remote_code
+        )
     except Exception as e:
         if 'LLaMATokenizer' in str(e):
-            tokenizer = LlamaTokenizer.from_pretrained(
+            tokenizer = LlamaTokenizer.from_pretrained(
+                base_model_name,
+                trust_remote_code=Global.trust_remote_code
+            )
         else:
             raise e
 
@@ -100,13 +155,15 @@ def get_model(
     peft_model_name_or_path = peft_model_name
 
     if peft_model_name:
-        lora_models_directory_path = os.path.join(
+        lora_models_directory_path = os.path.join(
+            Global.data_dir, "lora_models")
         possible_lora_model_path = os.path.join(
             lora_models_directory_path, peft_model_name)
         if os.path.isdir(possible_lora_model_path):
            peft_model_name_or_path = possible_lora_model_path
 
-        possible_model_info_json_path = os.path.join(
+        possible_model_info_json_path = os.path.join(
+            possible_lora_model_path, "info.json")
         if os.path.isfile(possible_model_info_json_path):
             try:
                 with open(possible_model_info_json_path, "r") as file:
@@ -115,7 +172,8 @@ def get_model(
                    if possible_hf_model_name and json_data.get("load_from_hf"):
                        peft_model_name_or_path = possible_hf_model_name
            except Exception as e:
-                raise ValueError(
+                raise ValueError(
+                    "Error reading model info from {possible_model_info_json_path}: {e}")
 
     Global.loaded_models.prepare_to_set()
     clear_cache()
@@ -148,7 +206,8 @@ def get_model(
         )
 
     if re.match("[^/]+/llama", base_model_name):
-        model.config.pad_token_id = get_tokenizer(
+        model.config.pad_token_id = get_tokenizer(
+            base_model_name).pad_token_id = 0
         model.config.bos_token_id = 1
         model.config.eos_token_id = 2
 
@@ -166,7 +225,8 @@ def get_model(
 
 
 def prepare_base_model(base_model_name=Global.default_base_model_name):
-    Global.new_base_model_that_is_ready_to_be_used = get_new_base_model(
+    Global.new_base_model_that_is_ready_to_be_used = get_new_base_model(
+        base_model_name)
     Global.name_of_new_base_model_that_is_ready_to_be_used = base_model_name
 
 
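
The net effect of the new _get_model_from_pretrained() helper is that every from_pretrained() call now also receives from_tf, force_download, and trust_remote_code. In isolation, the CUDA-path call reduces to roughly the following (a sketch with example values; in the app the values come from Global and from the retry loop above):

    import torch
    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained(
        "decapoda-research/llama-7b-hf",  # example base model name
        load_in_8bit=False,               # Global.load_8bit in the app
        torch_dtype=torch.float16,
        device_map={"": 0},               # assumes a single CUDA device
        from_tf=False,
        force_download=False,
        trust_remote_code=False,          # Global.trust_remote_code in the app
    )
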
llama_lora/ui/inference_ui.py
CHANGED
@@ -1,4 +1,5 @@
 import gradio as gr
+import os
 import time
 import json
 
@@ -21,13 +22,21 @@ default_show_raw = True
 inference_output_lines = 12
 
 
+class LoggingItem:
+    def __init__(self, label):
+        self.label = label
+
+    def deserialize(self, value, **kwargs):
+        return value
+
+
 def prepare_inference(lora_model_name, progress=gr.Progress(track_tqdm=True)):
     base_model_name = Global.base_model_name
 
     try:
         get_tokenizer(base_model_name)
         get_model(base_model_name, lora_model_name)
-        return ("", "")
+        return ("", "", gr.Textbox.update(visible=False))
 
     except Exception as e:
         raise gr.Error(e)
@@ -65,6 +74,31 @@ def do_inference(
     prompter = Prompter(prompt_template)
     prompt = prompter.generate_prompt(variables)
 
+    generation_config = GenerationConfig(
+        # to avoid ValueError('`temperature` has to be a strictly positive float, but is 2')
+        temperature=float(temperature),
+        top_p=top_p,
+        top_k=top_k,
+        repetition_penalty=repetition_penalty,
+        num_beams=num_beams,
+        # https://github.com/huggingface/transformers/issues/22405#issuecomment-1485527953
+        do_sample=temperature > 0,
+    )
+
+    def get_output_for_flagging(output, raw_output, completed=True):
+        return json.dumps({
+            'base_model': base_model_name,
+            'adaptor_model': lora_model_name,
+            'prompt': prompt,
+            'output': output,
+            'completed': completed,
+            'raw_output': raw_output,
+            'max_new_tokens': max_new_tokens,
+            'prompt_template': prompt_template,
+            'prompt_template_variables': variables,
+            'generation_config': generation_config.to_dict(),
+        })
+
     if Global.ui_dev_mode:
         message = f"Hi, I’m currently in UI-development mode and do not have access to resources to process your request. However, this behavior is similar to what will actually happen, so you can try and see how it will work!\n\nBase model: {base_model_name}\nLoRA model: {lora_model_name}\n\nThe following is your prompt:\n\n{prompt}"
         print(message)
@@ -83,35 +117,50 @@ def do_inference(
                 out += "\n"
                 yield out
 
+            output = ""
             for partial_sentence in word_generator(message):
+                output = partial_sentence
                 yield (
                     gr.Textbox.update(
-                        value=
+                        value=output,
+                        lines=inference_output_lines),
                    json.dumps(
-                        list(range(len(
+                        list(range(len(output.split()))),
+                        indent=2),
+                    gr.Textbox.update(
+                        value=get_output_for_flagging(
+                            output, "", completed=False),
+                        visible=True)
                )
                time.sleep(0.05)
 
+            yield (
+                gr.Textbox.update(
+                    value=output,
+                    lines=inference_output_lines),
+                json.dumps(
+                    list(range(len(output.split()))),
+                    indent=2),
+                gr.Textbox.update(
+                    value=get_output_for_flagging(
+                        output, "", completed=True),
+                    visible=True)
+            )
+
            return
        time.sleep(1)
        yield (
            gr.Textbox.update(value=message, lines=inference_output_lines),
-            json.dumps(list(range(len(message.split()))), indent=2)
+            json.dumps(list(range(len(message.split()))), indent=2),
+            gr.Textbox.update(
+                value=get_output_for_flagging(message, ""),
+                visible=True)
        )
        return
 
     tokenizer = get_tokenizer(base_model_name)
     model = get_model(base_model_name, lora_model_name)
 
-    generation_config = GenerationConfig(
-        temperature=float(temperature),  # to avoid ValueError('`temperature` has to be a strictly positive float, but is 2')
-        top_p=top_p,
-        top_k=top_k,
-        repetition_penalty=repetition_penalty,
-        num_beams=num_beams,
-        do_sample=temperature > 0,  # https://github.com/huggingface/transformers/issues/22405#issuecomment-1485527953
-    )
-
     def ui_generation_stopping_criteria(input_ids, score, **kwargs):
         if Global.should_stop_generating:
             return True
@@ -129,10 +178,8 @@ def do_inference(
        'stream_output': stream_output
    }
 
-    for (decoded_output, output) in generate(**generation_args):
-        raw_output_str =
-        if show_raw:
-            raw_output_str = str(output)
+    for (decoded_output, output, completed) in generate(**generation_args):
+        raw_output_str = str(output)
        response = prompter.get_response(decoded_output)
 
        if Global.should_stop_generating:
@@ -141,7 +188,12 @@ def do_inference(
        yield (
            gr.Textbox.update(
                value=response, lines=inference_output_lines),
-            raw_output_str
+            raw_output_str,
+            gr.Textbox.update(
+                value=get_output_for_flagging(
+                    decoded_output, raw_output_str, completed=completed),
+                visible=True)
+        )
 
        if Global.should_stop_generating:
            # If the user stops the generation, and then clicks the
@@ -199,11 +251,13 @@ def get_warning_message_for_lora_model_and_prompt_template(lora_model, prompt_te
    if lora_mode_info and isinstance(lora_mode_info, dict):
        model_base_model = lora_mode_info.get("base_model")
        if model_base_model and model_base_model != Global.base_model_name:
-            messages.append(
+            messages.append(
+                f"⚠️ This model was trained on top of base model `{model_base_model}`, it might not work properly with the selected base model `{Global.base_model_name}`.")
 
        model_prompt_template = lora_mode_info.get("prompt_template")
        if model_prompt_template and model_prompt_template != prompt_template:
-            messages.append(
+            messages.append(
+                f"This model was trained with prompt template `{model_prompt_template}`.")
 
    return " ".join(messages)
 
@@ -221,7 +275,8 @@ def handle_prompt_template_change(prompt_template, lora_model):
 
    model_prompt_template_message_update = gr.Markdown.update(
        "", visible=False)
-    warning_message = get_warning_message_for_lora_model_and_prompt_template(
+    warning_message = get_warning_message_for_lora_model_and_prompt_template(
+        lora_model, prompt_template)
    if warning_message:
        model_prompt_template_message_update = gr.Markdown.update(
            warning_message, visible=True)
@@ -241,7 +296,8 @@ def handle_lora_model_change(lora_model, prompt_template):
 
    model_prompt_template_message_update = gr.Markdown.update(
        "", visible=False)
-    warning_message = get_warning_message_for_lora_model_and_prompt_template(
+    warning_message = get_warning_message_for_lora_model_and_prompt_template(
+        lora_model, prompt_template)
    if warning_message:
        model_prompt_template_message_update = gr.Markdown.update(
            warning_message, visible=True)
@@ -260,6 +316,56 @@ def update_prompt_preview(prompt_template,
 
 
 def inference_ui():
+    flagging_dir = os.path.join(Global.data_dir, "flagging", "inference")
+    if not os.path.exists(flagging_dir):
+        os.makedirs(flagging_dir)
+
+    flag_callback = gr.CSVLogger()
+    flag_components = [
+        LoggingItem("Base Model"),
+        LoggingItem("Adaptor Model"),
+        LoggingItem("Type"),
+        LoggingItem("Prompt"),
+        LoggingItem("Output"),
+        LoggingItem("Completed"),
+        LoggingItem("Config"),
+        LoggingItem("Raw Output"),
+        LoggingItem("Max New Tokens"),
+        LoggingItem("Prompt Template"),
+        LoggingItem("Prompt Template Variables"),
+        LoggingItem("Generation Config"),
+    ]
+    flag_callback.setup(flag_components, flagging_dir)
+
+    def get_flag_callback_args(output_for_flagging_str, flag_type):
+        output_for_flagging = json.loads(output_for_flagging_str)
+        generation_config = output_for_flagging.get("generation_config", {})
+        config = []
+        if generation_config.get('do_sample', False):
+            config.append(
+                f"Temperature: {generation_config.get('temperature')}")
+            config.append(f"Top P: {generation_config.get('top_p')}")
+            config.append(f"Top K: {generation_config.get('top_k')}")
+        num_beams = generation_config.get('num_beams', 1)
+        if num_beams > 1:
+            config.append(f"Beams: {generation_config.get('num_beams')}")
+        config.append(f"RP: {generation_config.get('repetition_penalty')}")
+        return [
+            output_for_flagging.get("base_model", ""),
+            output_for_flagging.get("adaptor_model", ""),
+            flag_type,
+            output_for_flagging.get("prompt", ""),
+            output_for_flagging.get("output", ""),
+            str(output_for_flagging.get("completed", "")),
+            ", ".join(config),
+            output_for_flagging.get("raw_output", ""),
+            str(output_for_flagging.get("max_new_tokens", "")),
+            output_for_flagging.get("prompt_template", ""),
+            json.dumps(output_for_flagging.get(
+                "prompt_template_variables", "")),
+            json.dumps(output_for_flagging.get("generation_config", "")),
+        ]
+
    things_that_might_timeout = []
 
    with gr.Blocks() as inference_ui_blocks:
@@ -387,6 +493,47 @@ def inference_ui():
                inference_output = gr.Textbox(
                    lines=inference_output_lines, label="Output", elem_id="inference_output")
                inference_output.style(show_copy_button=True)
+
+                with gr.Row(elem_id="inference_flagging_group"):
+                    output_for_flagging = gr.Textbox(
+                        interactive=False, visible=False,
+                        elem_id="inference_output_for_flagging")
+                    flag_btn = gr.Button(
+                        "Flag", elem_id="inference_flag_btn")
+                    flag_up_btn = gr.Button(
+                        "👍", elem_id="inference_flag_up_btn")
+                    flag_down_btn = gr.Button(
+                        "👎", elem_id="inference_flag_down_btn")
+                    flag_output = gr.Markdown(
+                        "", elem_id="inference_flag_output")
+                    flag_btn.click(
+                        lambda d: (flag_callback.flag(
+                            get_flag_callback_args(d, "Flag"),
+                            flag_option="Flag",
+                            username=None
+                        ), "")[1],
+                        inputs=[output_for_flagging],
+                        outputs=[flag_output],
+                        preprocess=False)
+                    flag_up_btn.click(
+                        lambda d: (flag_callback.flag(
+                            get_flag_callback_args(d, "👍"),
+                            flag_option="Up Vote",
+                            username=None
+                        ), "")[1],
+                        inputs=[output_for_flagging],
+                        outputs=[flag_output],
+                        preprocess=False)
+                    flag_down_btn.click(
+                        lambda d: (flag_callback.flag(
+                            get_flag_callback_args(d, "👎"),
+                            flag_option="Down Vote",
+                            username=None
+                        ), "")[1],
+                        inputs=[output_for_flagging],
+                        outputs=[flag_output],
+                        preprocess=False)
+
                with gr.Accordion(
                    "Raw Output",
                    open=not default_show_raw,
@@ -400,7 +547,8 @@ def inference_ui():
                    interactive=False,
                    elem_id="inference_raw_output")
 
-        reload_selected_models_btn = gr.Button(
+        reload_selected_models_btn = gr.Button(
+            "", elem_id="inference_reload_selected_models_btn")
 
        show_raw_change_event = show_raw.change(
            fn=lambda show_raw: gr.Accordion.update(visible=show_raw),
@@ -440,7 +588,8 @@ def inference_ui():
        generate_event = generate_btn.click(
            fn=prepare_inference,
            inputs=[lora_model],
-            outputs=[inference_output,
+            outputs=[inference_output,
+                     inference_raw_output, output_for_flagging],
        ).then(
            fn=do_inference,
            inputs=[
@@ -457,7 +606,8 @@ def inference_ui():
                stream_output,
                show_raw,
            ],
-            outputs=[inference_output,
+            outputs=[inference_output,
+                     inference_raw_output, output_for_flagging],
            api_name="inference"
        )
        stop_btn.click(
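
The new flagging buttons are wired through Gradio's CSVLogger, which appends one row per flag to a CSV file under the data directory; the LoggingItem shims above stand in for real components so arbitrary strings can be logged. A stripped-down, standalone sketch of the same mechanism with Gradio 3.x's flagging API (hypothetical components and labels):

    import gradio as gr

    logger = gr.CSVLogger()

    with gr.Blocks() as demo:
        prompt_box = gr.Textbox(label="Prompt")
        output_box = gr.Textbox(label="Output")
        status = gr.Markdown("")
        flag_btn = gr.Button("Flag")

        # Register the components and the target directory for flagging.
        logger.setup([prompt_box, output_box], "./flagged")

        # flag() appends the current values as one CSV row; mirror the app's
        # "(flag(...), '')[1]" trick so the click handler returns a value for `status`.
        flag_btn.click(
            lambda p, o: (logger.flag([p, o], flag_option="Flag"), "Flagged.")[1],
            inputs=[prompt_box, output_box],
            outputs=[status],
        )

    demo.launch()
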
llama_lora/ui/main_page.py
CHANGED
@@ -398,6 +398,45 @@ def main_page_custom_css():
         bottom: 16px;
     }
 
+    #inference_flagging_group {
+        position: relative;
+    }
+    #inference_flag_output {
+        min-height: 1px !important;
+        position: absolute;
+        top: 0;
+        bottom: 0;
+        right: 0;
+        pointer-events: none;
+        opacity: 0.5;
+    }
+    #inference_flag_output .wrap {
+        top: 0;
+        bottom: 0;
+        right: 0;
+        justify-content: center;
+        align-items: flex-end;
+        padding: 4px !important;
+    }
+    #inference_flag_output .wrap svg {
+        display: none;
+    }
+    .form:has(> #inference_output_for_flagging),
+    #inference_output_for_flagging {
+        display: none;
+    }
+    #inference_flagging_group:has(#inference_output_for_flagging.hidden) {
+        opacity: 0.5;
+        pointer-events: none;
+    }
+    #inference_flag_up_btn, #inference_flag_down_btn {
+        min-width: 44px;
+        flex-grow: 1;
+    }
+    #inference_flag_btn {
+        flex-grow: 2;
+    }
+
     #dataset_plain_text_input_variables_separator textarea,
     #dataset_plain_text_input_and_output_separator textarea,
     #dataset_plain_text_data_separator textarea {