Spaces:
Runtime error
Runtime error
zetavg
committed on
split finetune ui
Browse files
llama_lora/ui/finetune/data_processing.py
ADDED
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
from ...utils.data import get_dataset_content
|
3 |
+
|
4 |
+
from .values import (
|
5 |
+
default_dataset_plain_text_input_variables_separator,
|
6 |
+
default_dataset_plain_text_input_and_output_separator,
|
7 |
+
default_dataset_plain_text_data_separator,
|
8 |
+
)
|
9 |
+
|
10 |
+
|
11 |
+
def get_data_from_input(load_dataset_from, dataset_text, dataset_text_format,
                        dataset_plain_text_input_variables_separator,
                        dataset_plain_text_input_and_output_separator,
                        dataset_plain_text_data_separator,
                        dataset_from_data_dir, prompter):
    """Load the raw dataset selected in the fine-tune UI.

    Args:
        load_dataset_from: ``"Text Input"`` to parse ``dataset_text``; any
            other value loads ``dataset_from_data_dir`` through
            ``get_dataset_content``.
        dataset_text: The text entered by the user.
        dataset_text_format: ``"JSON"``, ``"JSON Lines"``; anything else is
            handled as plain text.
        dataset_plain_text_input_variables_separator: Separator between the
            variables of one item (plain text only). Falls back to the
            module default when empty.
        dataset_plain_text_input_and_output_separator: Separator between the
            input part and the output part of one item (plain text only).
        dataset_plain_text_data_separator: Separator between items (plain
            text only).
        dataset_from_data_dir: Name handed to ``get_dataset_content`` when
            the dataset is not read from the text input.
        prompter: Object whose ``get_variable_names()`` supplies the variable
            names used for the plain text format.

    Returns:
        The parsed dataset.

    Raises:
        ValueError: If a JSON Lines line (or the JSON document) is invalid.
            For JSON Lines the message names the offending line number.
    """
    if load_dataset_from != "Text Input":
        # Load dataset from data directory
        return get_dataset_content(dataset_from_data_dir)

    if dataset_text_format == "JSON":
        return json.loads(dataset_text)

    if dataset_text_format == "JSON Lines":
        data = []
        # Split on "\n" only: other line-boundary characters may legally
        # appear inside JSON strings.
        for line_number, line in enumerate(dataset_text.split('\n'), start=1):
            # A trailing newline or an empty line is not a record; skipping
            # it keeps the reported line numbers aligned with the input.
            if not line.strip():
                continue
            try:
                data.append(json.loads(line))
            except json.JSONDecodeError as e:
                raise ValueError(
                    f"Error parsing JSON on line {line_number}: {e}") from e
        return data

    # Plain Text. A literal backslash-n typed into a separator field stands
    # for a real newline.
    return parse_plain_text_input(
        dataset_text,
        (
            dataset_plain_text_input_variables_separator or
            default_dataset_plain_text_input_variables_separator
        ).replace("\\n", "\n"),
        (
            dataset_plain_text_input_and_output_separator or
            default_dataset_plain_text_input_and_output_separator
        ).replace("\\n", "\n"),
        (
            dataset_plain_text_data_separator or
            default_dataset_plain_text_data_separator
        ).replace("\\n", "\n"),
        prompter.get_variable_names()
    )
|
53 |
+
|
54 |
+
|
55 |
+
def parse_plain_text_input(
    value,
    variables_separator, input_output_separator, data_separator,
    variable_names
):
    """Parse a plain-text dataset into a list of item dicts.

    ``value`` is cut into items on ``data_separator``. Each item is cut on
    ``input_output_separator``: the first part holds the variables (cut on
    ``variables_separator`` and paired positionally with ``variable_names``),
    the second part, when present, is the output. All values are stripped.

    Returns:
        A list of ``{'variables': {name: value, ...}, 'output': str}``.
    """
    def _parse_item(raw_item):
        parts = raw_item.split(input_output_separator)
        # Missing parts fall back to the empty string.
        input_part = parts[0] if len(parts) > 0 else ""
        output_part = parts[1] if len(parts) > 1 else ""

        values = [
            piece.strip() for piece in input_part.split(variables_separator)
        ]
        return {
            # zip() drops surplus names or surplus values.
            'variables': dict(zip(variable_names, values)),
            'output': output_part.strip(),
        }

    return [_parse_item(raw_item) for raw_item in value.split(data_separator)]
|
71 |
+
|
72 |
+
|
73 |
+
def get_val_from_arr(arr, index, default=None):
    """Return ``arr[index]``, or ``default`` when ``index`` is out of range.

    Negative indices are accepted as long as they stay within
    ``-len(arr)``.
    """
    size = len(arr)
    if index >= size or index < -size:
        return default
    return arr[index]
|
llama_lora/ui/finetune/finetune_ui.py
CHANGED
The diff for this file is too large to render.
See raw diff
|
|
llama_lora/ui/finetune/previewing.py
ADDED
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import traceback
|
3 |
+
import re
|
4 |
+
import gradio as gr
|
5 |
+
import math
|
6 |
+
|
7 |
+
from ...config import Config
|
8 |
+
from ...utils.prompter import Prompter
|
9 |
+
|
10 |
+
from .data_processing import get_data_from_input
|
11 |
+
|
12 |
+
|
13 |
+
def refresh_preview(
    template,
    load_dataset_from,
    dataset_from_data_dir,
    dataset_text,
    dataset_text_format,
    dataset_plain_text_input_variables_separator,
    dataset_plain_text_input_and_output_separator,
    dataset_plain_text_data_separator,
    max_preview_count,
):
    """Build the Gradio updates for the dataset preview table.

    Loads the dataset, turns at most ``max_preview_count`` items into
    prompt/completion rows and returns a 4-tuple of updates: the preview
    ``Dataframe``, the preview info ``Markdown``, and the item-count
    ``Markdown`` update twice. On any exception the table is cleared and the
    error is shown in the two message slots instead.
    """
    try:
        prompter = Prompter(template)
        # Queried before the data is loaded; a failure here is reported by
        # the handler below like any other error.
        variable_names = prompter.get_variable_names()

        dataset = get_data_from_input(
            load_dataset_from=load_dataset_from,
            dataset_text=dataset_text,
            dataset_text_format=dataset_text_format,
            dataset_plain_text_input_variables_separator=dataset_plain_text_input_variables_separator,
            dataset_plain_text_input_and_output_separator=dataset_plain_text_input_and_output_separator,
            dataset_plain_text_data_separator=dataset_plain_text_data_separator,
            dataset_from_data_dir=dataset_from_data_dir,
            prompter=prompter
        )

        shown_items = prompter.get_train_data_from_dataset(
            dataset, max_preview_count)
        shown_items = shown_items[:max_preview_count]

        # The count is taken from the raw dataset, not from the train data.
        total_count = len(dataset)

        columns = ['Prompt', 'Completion']
        rows = []
        for entry in shown_items:
            rows.append([entry.get("prompt", ""), entry.get("completion", "")])

        if not prompter.template_module:
            # Without a template module, one extra column per variable is
            # appended, read from the "_var_<name>" keys of each item.
            variable_names = prompter.get_variable_names()
            columns.extend(
                f"Variable: {variable_name}" for variable_name in variable_names)
            rows = [
                row + [entry.get(f"_var_{name}", "") for name in variable_names]
                for row, entry in zip(rows, shown_items)
            ]

        preview_text = f"The dataset has about {total_count} item(s)."
        if total_count > max_preview_count:
            preview_text += f" Previewing the first {max_preview_count}."

        count_text = f"about {total_count} item(s)."
        if load_dataset_from == "Data Dir":
            count_text = "This dataset contains about " + count_text
        count_update = gr.Markdown.update(count_text, visible=True)

        return (
            gr.Dataframe.update(
                value={'data': rows, 'headers': columns}),
            gr.Markdown.update(preview_text),
            count_update,
            count_update
        )
    except Exception as e:
        error_update = gr.Markdown.update(
            f"<span class=\"finetune_dataset_error_message\">Error: {e}.</span>",
            visible=True)
        return (
            gr.Dataframe.update(value={'data': [], 'headers': []}),
            gr.Markdown.update(
                "Set the dataset in the \"Prepare\" tab, then preview it here."),
            error_update,
            error_update
        )
|
88 |
+
|
89 |
+
|
90 |
+
def refresh_dataset_items_count(
    template,
    load_dataset_from,
    dataset_from_data_dir,
    dataset_text,
    dataset_text_format,
    dataset_plain_text_input_variables_separator,
    dataset_plain_text_input_and_output_separator,
    dataset_plain_text_data_separator,
    max_preview_count,
):
    """Count the train items of the selected dataset and report it to the UI.

    Returns a 4-tuple of Gradio updates: the preview info ``Markdown``, the
    item-count ``Markdown`` update twice, and a ``Slider`` update whose
    maximum is half the item count (rounded down). On any exception the
    error is shown in the two message slots and the slider maximum is reset
    to 1; stack frames located under the templates directory are appended to
    the error text.
    """
    try:
        prompter = Prompter(template)

        dataset = get_data_from_input(
            load_dataset_from=load_dataset_from,
            dataset_text=dataset_text,
            dataset_text_format=dataset_text_format,
            dataset_plain_text_input_variables_separator=dataset_plain_text_input_variables_separator,
            dataset_plain_text_input_and_output_separator=dataset_plain_text_input_and_output_separator,
            dataset_plain_text_data_separator=dataset_plain_text_data_separator,
            dataset_from_data_dir=dataset_from_data_dir,
            prompter=prompter
        )

        item_count = len(prompter.get_train_data_from_dataset(dataset))

        preview_text = f"The dataset contains {item_count} item(s)."
        if item_count > max_preview_count:
            preview_text += f" Previewing the first {max_preview_count}."

        count_text = f"{item_count} item(s)."
        if load_dataset_from == "Data Dir":
            count_text = "This dataset contains " + count_text
        count_update = gr.Markdown.update(count_text, visible=True)

        return (
            gr.Markdown.update(preview_text),
            count_update,
            count_update,
            gr.Slider.update(maximum=math.floor(item_count / 2))
        )
    except Exception as e:
        error_text = f"<span class=\"finetune_dataset_error_message\">Error: {e}.</span>"

        # Keep only the traceback frames that mention the templates
        # directory, each flattened onto a single line.
        templates_dir = os.path.join(Config.data_dir, "templates")
        template_frames = []
        for frame in re.split("\n * File ", traceback.format_exc()):
            frame = frame.strip()
            if templates_dir in frame:
                template_frames.append(re.sub(" *\n *", ": ", frame))

        if len(template_frames) > 0:
            error_text = f"<span class=\"finetune_dataset_error_message\">Error: {e} ({','.join(template_frames)}).</span>"

        error_update = gr.Markdown.update(error_text, visible=True)

        return (
            gr.Markdown.update(
                "Set the dataset in the \"Prepare\" tab, then preview it here."),
            error_update,
            error_update,
            gr.Slider.update(maximum=1)
        )
|
llama_lora/ui/finetune/training.py
ADDED
@@ -0,0 +1,345 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import json
|
3 |
+
import time
|
4 |
+
import gradio as gr
|
5 |
+
import math
|
6 |
+
|
7 |
+
from transformers import TrainerCallback
|
8 |
+
from huggingface_hub import try_to_load_from_cache, snapshot_download
|
9 |
+
|
10 |
+
from ...config import Config
|
11 |
+
from ...globals import Global
|
12 |
+
from ...models import clear_cache, unload_models
|
13 |
+
from ...utils.prompter import Prompter
|
14 |
+
|
15 |
+
from .data_processing import get_data_from_input
|
16 |
+
|
17 |
+
# Decided once at import time: the Gradio progress tracker follows tqdm
# unless the GPU is known to have more than 2560 cores.
should_training_progress_track_tqdm = not (
    Global.gpu_total_cores is not None and Global.gpu_total_cores > 2560
)
|
21 |
+
|
22 |
+
|
23 |
+
def do_train(
    # Dataset
    template,
    load_dataset_from,
    dataset_from_data_dir,
    dataset_text,
    dataset_text_format,
    dataset_plain_text_input_variables_separator,
    dataset_plain_text_input_and_output_separator,
    dataset_plain_text_data_separator,
    # Training Options
    max_seq_length,
    evaluate_data_count,
    micro_batch_size,
    gradient_accumulation_steps,
    epochs,
    learning_rate,
    train_on_inputs,
    lora_r,
    lora_alpha,
    lora_dropout,
    lora_target_modules,
    lora_modules_to_save,
    load_in_8bit,
    fp16,
    bf16,
    gradient_checkpointing,
    save_steps,
    save_total_limit,
    logging_steps,
    additional_training_arguments,
    additional_lora_config,
    model_name,
    continue_from_model,
    continue_from_checkpoint,
    progress=gr.Progress(track_tqdm=should_training_progress_track_tqdm),
):
    """Run a fine-tuning job from the UI inputs and return a result message.

    Steps, in order:
      1. Resolve the optional model/checkpoint to continue from into a
         directory path (``resume_from_checkpoint_param``).
      2. Refuse to train into an output directory that is not a directory or
         already contains ``adapter_config.json``.
      3. Unload models and clear caches, then build the train data with
         ``Prompter`` and ``get_data_from_input``.
      4. If ``Config.ui_dev_mode`` is set, simulate progress and return a
         message describing the options instead of training.
      5. Otherwise write ``info.json`` into the output directory and call
         ``Global.finetune_train_fn``.

    Args:
        template, load_dataset_from, dataset_*: Dataset selection, forwarded
            to ``Prompter`` and ``get_data_from_input``.
        max_seq_length ... additional_lora_config: Training options,
            forwarded to ``Global.finetune_train_fn``.
        model_name: Name of the output directory under
            ``<Config.data_dir>/lora_models``; also used as the wandb run
            name.
        continue_from_model: Model to continue from; ``"-"``, ``"None"`` or a
            falsy value means none.
        continue_from_checkpoint: Checkpoint sub-directory of that model;
            ``"-"``, ``"None"`` or a falsy value means none.
        progress: Gradio progress tracker.

    Returns:
        The result message string. In UI dev mode, the simulated-run message,
        or ``None`` if the simulated run was stopped.

    Raises:
        gr.Error: Wraps any exception raised while preparing or training.
    """
    try:
        base_model_name = Global.base_model_name
        tokenizer_name = Global.tokenizer_name or Global.base_model_name

        resume_from_checkpoint_param = None
        # "-" and "None" are treated as "nothing selected".
        if continue_from_model == "-" or continue_from_model == "None":
            continue_from_model = None
        if continue_from_checkpoint == "-" or continue_from_checkpoint == "None":
            continue_from_checkpoint = None
        if continue_from_model:
            resume_from_model_path = os.path.join(
                Config.data_dir, "lora_models", continue_from_model)
            resume_from_checkpoint_param = resume_from_model_path
            if continue_from_checkpoint:
                # A checkpoint directory must contain "pytorch_model.bin".
                resume_from_checkpoint_param = os.path.join(
                    resume_from_checkpoint_param, continue_from_checkpoint)
                will_be_resume_from_checkpoint_file = os.path.join(
                    resume_from_checkpoint_param, "pytorch_model.bin")
                if not os.path.exists(will_be_resume_from_checkpoint_file):
                    raise ValueError(
                        f"Unable to resume from checkpoint {continue_from_model}/{continue_from_checkpoint}. Resuming is only possible from checkpoints stored locally in the data directory. Please ensure that the file '{will_be_resume_from_checkpoint_file}' exists.")
            else:
                # A model directory must contain "adapter_model.bin".
                will_be_resume_from_checkpoint_file = os.path.join(
                    resume_from_checkpoint_param, "adapter_model.bin")
                if not os.path.exists(will_be_resume_from_checkpoint_file):
                    # Try to get model in Hugging Face cache
                    resume_from_checkpoint_param = None
                    possible_hf_model_name = None
                    possible_model_info_file = os.path.join(
                        resume_from_model_path, "info.json")
                    # A name containing "/" is taken as a Hugging Face model
                    # name; otherwise "hf_model_name" is read from the
                    # model's info.json when that file exists.
                    if "/" in continue_from_model:
                        possible_hf_model_name = continue_from_model
                    elif os.path.exists(possible_model_info_file):
                        with open(possible_model_info_file, "r") as file:
                            model_info = json.load(file)
                            possible_hf_model_name = model_info.get(
                                "hf_model_name")
                    if possible_hf_model_name:
                        possible_hf_model_cached_path = try_to_load_from_cache(
                            possible_hf_model_name, 'adapter_model.bin')
                        # Not cached yet: download, then look it up again.
                        if not possible_hf_model_cached_path:
                            snapshot_download(possible_hf_model_name)
                            possible_hf_model_cached_path = try_to_load_from_cache(
                                possible_hf_model_name, 'adapter_model.bin')
                        if possible_hf_model_cached_path:
                            resume_from_checkpoint_param = os.path.dirname(
                                possible_hf_model_cached_path)

                    if not resume_from_checkpoint_param:
                        raise ValueError(
                            f"Unable to continue from model {continue_from_model}. Continuation is only possible from models stored locally in the data directory. Please ensure that the file '{will_be_resume_from_checkpoint_file}' exists.")

        output_dir = os.path.join(Config.data_dir, "lora_models", model_name)
        if os.path.exists(output_dir):
            # An existing directory is accepted only if it holds no
            # adapter_config.json yet.
            if (not os.path.isdir(output_dir)) or os.path.exists(os.path.join(output_dir, 'adapter_config.json')):
                raise ValueError(
                    f"The output directory already exists and is not empty. ({output_dir})")

        # Manual progress messages are only sent when tqdm tracking is off.
        if not should_training_progress_track_tqdm:
            progress(0, desc="Preparing train data...")

        # Need RAM for training
        unload_models()
        Global.new_base_model_that_is_ready_to_be_used = None
        Global.name_of_new_base_model_that_is_ready_to_be_used = None
        clear_cache()

        prompter = Prompter(template)
        # variable_names = prompter.get_variable_names()

        data = get_data_from_input(
            load_dataset_from=load_dataset_from,
            dataset_text=dataset_text,
            dataset_text_format=dataset_text_format,
            dataset_plain_text_input_variables_separator=dataset_plain_text_input_variables_separator,
            dataset_plain_text_input_and_output_separator=dataset_plain_text_input_and_output_separator,
            dataset_plain_text_data_separator=dataset_plain_text_data_separator,
            dataset_from_data_dir=dataset_from_data_dir,
            prompter=prompter
        )

        train_data = prompter.get_train_data_from_dataset(data)

        # Only used by the simulated run below.
        def get_progress_text(epoch, epochs, last_loss):
            progress_detail = f"Epoch {math.ceil(epoch)}/{epochs}"
            if last_loss is not None:
                progress_detail += f", Loss: {last_loss:.4f}"
            return f"Training... ({progress_detail})"

        if Config.ui_dev_mode:
            Global.should_stop_training = False

            message = f"""Currently in UI dev mode, not doing the actual training.

Train options: {json.dumps({
    'max_seq_length': max_seq_length,
    'val_set_size': evaluate_data_count,
    'micro_batch_size': micro_batch_size,
    'gradient_accumulation_steps': gradient_accumulation_steps,
    'epochs': epochs,
    'learning_rate': learning_rate,
    'train_on_inputs': train_on_inputs,
    'lora_r': lora_r,
    'lora_alpha': lora_alpha,
    'lora_dropout': lora_dropout,
    'lora_target_modules': lora_target_modules,
    'lora_modules_to_save': lora_modules_to_save,
    'load_in_8bit': load_in_8bit,
    'fp16': fp16,
    'bf16': bf16,
    'gradient_checkpointing': gradient_checkpointing,
    'model_name': model_name,
    'continue_from_model': continue_from_model,
    'continue_from_checkpoint': continue_from_checkpoint,
    'resume_from_checkpoint_param': resume_from_checkpoint_param,
}, indent=2)}

Train data (first 10):
{json.dumps(train_data[:10], indent=2)}
"""
            print(message)

            # Simulated run: 300 ticks of 0.1 s presented as 3 epochs.
            for i in range(300):
                if (Global.should_stop_training):
                    return
                # Overrides the `epochs` argument for the simulated display.
                epochs = 3
                epoch = i / 100
                last_loss = None
                if (i > 20):
                    # Fake loss, linear in i: 3 at i=0 down towards 0.5 at
                    # i=300.
                    last_loss = 3 + (i - 0) * (0.5 - 3) / (300 - 0)

                progress(
                    (i, 300),
                    desc="(Simulate) " +
                    get_progress_text(epoch, epochs, last_loss)
                )

                time.sleep(0.1)

            time.sleep(2)
            return message

        if not should_training_progress_track_tqdm:
            progress(
                0, desc=f"Preparing model {base_model_name} for training...")

        # Rebound by the callback below to the trainer state's log history.
        log_history = []

        class UiTrainerCallback(TrainerCallback):
            # Reports progress to the UI and relays the stop request.
            def _on_progress(self, args, state, control):
                nonlocal log_history

                if Global.should_stop_training:
                    control.should_training_stop = True
                # NOTE(review): the fallback reads state.steps_per_epoch;
                # confirm that attribute exists on the trainer state.
                total_steps = (
                    state.max_steps if state.max_steps is not None else state.num_train_epochs * state.steps_per_epoch)
                log_history = state.log_history
                last_history = None
                last_loss = None
                if len(log_history) > 0:
                    last_history = log_history[-1]
                    last_loss = last_history.get('loss', None)

                progress_detail = f"Epoch {math.ceil(state.epoch)}/{epochs}"
                if last_loss is not None:
                    progress_detail += f", Loss: {last_loss:.4f}"

                progress(
                    (state.global_step, total_steps),
                    desc=f"Training... ({progress_detail})"
                )

            def on_epoch_begin(self, args, state, control, **kwargs):
                self._on_progress(args, state, control)

            def on_step_end(self, args, state, control, **kwargs):
                self._on_progress(args, state, control)

        # The class itself is passed, not an instance.
        # NOTE(review): presumably instantiated by the trainer — confirm.
        training_callbacks = [UiTrainerCallback]

        Global.should_stop_training = False

        # Do not let other tqdm iterations interfere with the progress reporting after training starts.
        # progress.track_tqdm = False  # setting this dynamically is not working, determining if track_tqdm should be enabled based on GPU cores at start instead.

        if not os.path.exists(output_dir):
            os.makedirs(output_dir)

        # Record the run's metadata next to the model output.
        with open(os.path.join(output_dir, "info.json"), 'w') as info_json_file:
            dataset_name = "N/A (from text input)"
            if load_dataset_from == "Data Dir":
                dataset_name = dataset_from_data_dir

            info = {
                'base_model': base_model_name,
                'prompt_template': template,
                'dataset_name': dataset_name,
                'dataset_rows': len(train_data),
                'timestamp': time.time(),

                # These will be saved in another JSON file by the train function
                # 'max_seq_length': max_seq_length,
                # 'train_on_inputs': train_on_inputs,

                # 'micro_batch_size': micro_batch_size,
                # 'gradient_accumulation_steps': gradient_accumulation_steps,
                # 'epochs': epochs,
                # 'learning_rate': learning_rate,

                # 'evaluate_data_count': evaluate_data_count,

                # 'lora_r': lora_r,
                # 'lora_alpha': lora_alpha,
                # 'lora_dropout': lora_dropout,
                # 'lora_target_modules': lora_target_modules,
            }
            if continue_from_model:
                info['continued_from_model'] = continue_from_model
                if continue_from_checkpoint:
                    info['continued_from_checkpoint'] = continue_from_checkpoint

            if Global.version:
                info['tuner_version'] = Global.version

            json.dump(info, info_json_file, indent=2)

        if not should_training_progress_track_tqdm:
            progress(0, desc="Train starting...")

        # wandb group and tags are built from the template and, for data-dir
        # datasets, the dataset name.
        wandb_group = template
        wandb_tags = [f"template:{template}"]
        if load_dataset_from == "Data Dir" and dataset_from_data_dir:
            wandb_group += f"/{dataset_from_data_dir}"
            wandb_tags.append(f"dataset:{dataset_from_data_dir}")

        train_output = Global.finetune_train_fn(
            base_model=base_model_name,
            tokenizer=tokenizer_name,
            output_dir=output_dir,
            train_data=train_data,
            # 128,  # batch_size (is not used, use gradient_accumulation_steps instead)
            micro_batch_size=micro_batch_size,
            gradient_accumulation_steps=gradient_accumulation_steps,
            num_train_epochs=epochs,
            learning_rate=learning_rate,
            cutoff_len=max_seq_length,
            val_set_size=evaluate_data_count,
            lora_r=lora_r,
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
            lora_target_modules=lora_target_modules,
            lora_modules_to_save=lora_modules_to_save,
            train_on_inputs=train_on_inputs,
            load_in_8bit=load_in_8bit,
            fp16=fp16,
            bf16=bf16,
            gradient_checkpointing=gradient_checkpointing,
            group_by_length=False,
            resume_from_checkpoint=resume_from_checkpoint_param,
            save_steps=save_steps,
            save_total_limit=save_total_limit,
            logging_steps=logging_steps,
            additional_training_arguments=additional_training_arguments,
            additional_lora_config=additional_lora_config,
            callbacks=training_callbacks,
            wandb_api_key=Config.wandb_api_key,
            wandb_project=Config.default_wandb_project if Config.enable_wandb else None,
            wandb_group=wandb_group,
            wandb_run_name=model_name,
            wandb_tags=wandb_tags
        )

        # Only consumed by the commented-out line further down.
        logs_str = "\n".join([json.dumps(log)
                              for log in log_history]) or "None"

        result_message = f"Training ended:\n{str(train_output)}"
        print(result_message)
        # result_message += f"\n\nLogs:\n{logs_str}"

        clear_cache()

        return result_message

    except Exception as e:
        # Re-raised as gr.Error with a hint on how to dismiss it.
        raise gr.Error(
            f"{e} (To dismiss this error, click the 'Abort' button)")
|
llama_lora/ui/finetune/values.py
ADDED
The diff for this file is too large to render.
See raw diff
|
|