Commit a5d7977 • add wandb support
zetavg committed • 1 parent: fdcd724
Files changed:
- .gitignore (+1 -0)
- README.md (+2 -1)
- app.py (+27 -1)
- llama_lora/globals.py (+5 -0)
- llama_lora/lib/finetune.py (+27 -3)
- llama_lora/ui/finetune_ui.py (+4 -1)
.gitignore CHANGED
@@ -3,4 +3,5 @@ __pycache__/
 /venv
 .vscode
 
+/wandb
 /data
README.md CHANGED
@@ -60,13 +60,14 @@ file_mounts:
 setup: |
   git clone https://github.com/zetavg/LLaMA-LoRA-Tuner.git llama_lora_tuner
   cd llama_lora_tuner && pip install -r requirements.lock.txt
+  pip install wandb
   cd ..
   echo 'Dependencies installed.'
 
 # Start the app.
 run: |
   echo 'Starting...'
-  python llama_lora_tuner/app.py --data_dir='/data' --base_model='decapoda-research/llama-7b-hf' --share
+  python llama_lora_tuner/app.py --data_dir='/data' --wandb_api_key "$([ -f /data/secrets/wandb_api_key ] && cat /data/secrets/wandb_api_key | tr -d '\n')" --base_model='decapoda-research/llama-7b-hf' --share
 ```
 
 Then launch a cluster to run the task:
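The updated run command reads the Weights & Biases API key from /data/secrets/wandb_api_key: the shell substitution passes the file's contents (with newlines stripped) when the file exists, and an empty string otherwise, which leaves W&B disabled. Below is a minimal sketch of how that secrets file could be provisioned ahead of time; the helper function, the permission choice, and the placeholder key are illustrative assumptions, not part of the commit.

```python
import os
import stat

SECRETS_DIR = "/data/secrets"                     # path assumed by the run command above
KEY_PATH = os.path.join(SECRETS_DIR, "wandb_api_key")


def store_wandb_api_key(api_key: str) -> None:
    # Write the key to the file the run command reads at startup.
    os.makedirs(SECRETS_DIR, exist_ok=True)
    with open(KEY_PATH, "w") as f:
        f.write(api_key.strip())                  # the command strips newlines with tr anyway
    # Keep the key readable only by the owner (0600); an illustrative choice.
    os.chmod(KEY_PATH, stat.S_IRUSR | stat.S_IWUSR)


if __name__ == "__main__":
    store_wandb_api_key("your-wandb-api-key-here")  # placeholder value
```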
app.py CHANGED
@@ -5,21 +5,37 @@ import fire
 import gradio as gr
 
 from llama_lora.globals import Global
+from llama_lora.models import prepare_base_model
 from llama_lora.ui.main_page import main_page, get_page_title, main_page_custom_css
 from llama_lora.utils.data import init_data_dir
 
 
+
 def main(
-    load_8bit: bool = False,
     base_model: str = "",
     data_dir: str = "",
     # Allows to listen on all interfaces by providing '0.0.0.0'.
     server_name: str = "127.0.0.1",
     share: bool = False,
     skip_loading_base_model: bool = False,
+    load_8bit: bool = False,
     ui_show_sys_info: bool = True,
     ui_dev_mode: bool = False,
+    wandb_api_key: str = "",
+    wandb_project: str = "",
 ):
+    '''
+    Start the LLaMA-LoRA Tuner UI.
+
+    :param base_model: (required) The name of the default base model to use.
+    :param data_dir: (required) The path to the directory to store data.
+    :param server_name: Allows to listen on all interfaces by providing '0.0.0.0'.
+    :param share: Create a public Gradio URL.
+
+    :param wandb_api_key: The API key for Weights & Biases. Setting either this or `wandb_project` will enable Weights & Biases.
+    :param wandb_project: The default project name for Weights & Biases. Setting either this or `wandb_api_key` will enable Weights & Biases.
+    '''
+
     base_model = base_model or os.environ.get("LLAMA_LORA_BASE_MODEL", "")
     data_dir = data_dir or os.environ.get("LLAMA_LORA_DATA_DIR", "")
     assert (
@@ -34,12 +50,22 @@ def main(
     Global.data_dir = os.path.abspath(data_dir)
     Global.load_8bit = load_8bit
 
+    if len(wandb_api_key) > 0:
+        Global.enable_wandb = True
+        Global.wandb_api_key = wandb_api_key
+    if len(wandb_project) > 0:
+        Global.enable_wandb = True
+        Global.wandb_project = wandb_project
+
     Global.ui_dev_mode = ui_dev_mode
     Global.ui_show_sys_info = ui_show_sys_info
 
     os.makedirs(data_dir, exist_ok=True)
     init_data_dir()
 
+    if (not skip_loading_base_model) and (not ui_dev_mode):
+        prepare_base_model(base_model)
+
     with gr.Blocks(title=get_page_title(), css=main_page_custom_css()) as demo:
         main_page()
 
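app.py exposes main() through python-fire (imported at the top of the file), so each new keyword parameter becomes a --flag and the added docstring becomes the --help text. The sketch below illustrates that wiring with a stripped-down stand-in for main(); the function body and printout are illustrative, not the project's code.

```python
import fire  # python-fire turns a function's keyword parameters into CLI flags


def main(
    base_model: str = "",
    data_dir: str = "",
    wandb_api_key: str = "",
    wandb_project: str = "",
):
    '''
    Illustrative stand-in for app.main().

    :param wandb_api_key: The API key for Weights & Biases.
    :param wandb_project: The default project name for Weights & Biases.
    '''
    # Mirrors the commit's logic: setting either flag enables W&B.
    enable_wandb = len(wandb_api_key) > 0 or len(wandb_project) > 0
    print(f"wandb enabled: {enable_wandb}")


if __name__ == "__main__":
    # e.g. python demo.py --wandb_project=llama-lora-tuner --base_model=decapoda-research/llama-7b-hf
    fire.Fire(main)
```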
llama_lora/globals.py CHANGED
@@ -40,6 +40,11 @@ class Global:
     gpu_total_cores = None  # GPU total cores
     gpu_total_memory = None
 
+    # WandB
+    enable_wandb = False
+    wandb_api_key = None
+    default_wandb_project = "llama-lora-tuner"
+
     # UI related
     ui_title: str = "LLaMA-LoRA Tuner"
     ui_emoji: str = "🦙🎛️"
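Global keeps settings as plain class attributes, so modules can read and mutate shared state without instantiating anything. A small illustrative sketch of that pattern follows, using stand-in names rather than the project's class.

```python
class Config:
    # Process-wide settings stored as class attributes, mirroring llama_lora.globals.Global.
    enable_wandb = False
    wandb_api_key = None
    default_wandb_project = "llama-lora-tuner"


def turn_on_wandb(api_key: str) -> None:
    # Mutating the class attributes is visible to every module that imports Config.
    Config.enable_wandb = True
    Config.wandb_api_key = api_key


turn_on_wandb("dummy-key")                # placeholder value
assert Config.enable_wandb is True
print(Config.default_wandb_project)       # -> llama-lora-tuner
```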
llama_lora/lib/finetune.py CHANGED
@@ -50,8 +50,32 @@ def train(
     save_total_limit: int = 3,
     logging_steps: int = 10,
     # logging
-    callbacks: List[Any] = []
+    callbacks: List[Any] = [],
+    # wandb params
+    wandb_api_key = None,
+    wandb_project: str = "",
+    wandb_run_name: str = "",
+    wandb_watch: str = "false",  # options: false | gradients | all
+    wandb_log_model: str = "true",  # options: false | true
 ):
+    if wandb_api_key:
+        os.environ["WANDB_API_KEY"] = wandb_api_key
+    if wandb_project:
+        os.environ["WANDB_PROJECT"] = wandb_project
+    if wandb_run_name:
+        os.environ["WANDB_RUN_NAME"] = wandb_run_name
+    if wandb_watch:
+        os.environ["WANDB_WATCH"] = wandb_watch
+    if wandb_log_model:
+        os.environ["WANDB_LOG_MODEL"] = wandb_log_model
+    use_wandb = (wandb_project and len(wandb_project) > 0) or (
+        "WANDB_PROJECT" in os.environ and len(os.environ["WANDB_PROJECT"]) > 0
+    )
+    if use_wandb:
+        os.environ['WANDB_MODE'] = "online"
+    else:
+        os.environ['WANDB_MODE'] = "disabled"
+
     if os.path.exists(output_dir):
         if (not os.path.isdir(output_dir)) or os.path.exists(os.path.join(output_dir, 'adapter_config.json')):
             raise ValueError(
@@ -204,8 +228,8 @@ def train(
             load_best_model_at_end=True if val_set_size > 0 else False,
             ddp_find_unused_parameters=False if ddp else None,
             group_by_length=group_by_length,
-
-
+            report_to="wandb" if use_wandb else None,
+            run_name=wandb_run_name if use_wandb else None,
         ),
         data_collator=transformers.DataCollatorForSeq2Seq(
             tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
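The new preamble in train() configures W&B through environment variables (WANDB_API_KEY, WANDB_PROJECT, WANDB_RUN_NAME, WANDB_WATCH, WANDB_LOG_MODEL, and WANDB_MODE) and then points the Hugging Face Trainer at W&B via report_to and run_name. A minimal sketch of that interplay is below; the output directory, project, and run name are placeholder values, not taken from the commit.

```python
import os

# Environment variables consumed by the wandb <-> transformers integration.
os.environ["WANDB_PROJECT"] = "llama-lora-tuner"  # project the run is logged under
os.environ["WANDB_MODE"] = "online"               # "disabled" turns W&B off entirely
os.environ["WANDB_WATCH"] = "false"               # gradient/parameter watching: false | gradients | all
os.environ["WANDB_LOG_MODEL"] = "true"            # upload checkpoints as W&B artifacts

from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./lora-output",        # placeholder path
    logging_steps=10,
    report_to="wandb",                 # route Trainer metrics to Weights & Biases
    run_name="llama-7b-lora-example",  # placeholder run name, shown in the W&B UI
)
```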
llama_lora/ui/finetune_ui.py CHANGED
@@ -491,7 +491,10 @@ Train data (first 10):
             save_steps,  # save_steps
             save_total_limit,  # save_total_limit
             logging_steps,  # logging_steps
-            training_callbacks  # callbacks
+            training_callbacks,  # callbacks
+            Global.wandb_api_key,  # wandb_api_key
+            Global.default_wandb_project if Global.enable_wandb else None,  # wandb_project
+            model_name  # wandb_run_name
         )
 
         logs_str = "\n".join([json.dumps(log)
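The UI passes the new values positionally and relies on trailing comments to keep the call aligned with train()'s signature. As a readability note, the same call can be written with keyword arguments; the sketch below does so with self-contained stubs (train, Global, training_callbacks, and model_name here are illustrative stand-ins, not the project's objects).

```python
def train(*, logging_steps, callbacks, wandb_api_key=None, wandb_project=None, wandb_run_name=None):
    # Stub carrying the wandb-related keyword parameters from lib/finetune.py.
    print(logging_steps, len(callbacks), wandb_api_key is not None, wandb_project, wandb_run_name)


class Global:
    enable_wandb = True
    wandb_api_key = "dummy-key"                 # placeholder
    default_wandb_project = "llama-lora-tuner"


training_callbacks = []
model_name = "my-lora-model"                    # placeholder

train(
    logging_steps=10,
    callbacks=training_callbacks,
    wandb_api_key=Global.wandb_api_key,
    wandb_project=Global.default_wandb_project if Global.enable_wandb else None,
    wandb_run_name=model_name,
)
```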