zetavg committed
Commit a5e11b9
1 Parent(s): 890373d

split finetune ui

llama_lora/ui/finetune/data_processing.py ADDED
@@ -0,0 +1,74 @@
+import json
+from ...utils.data import get_dataset_content
+
+from .values import (
+    default_dataset_plain_text_input_variables_separator,
+    default_dataset_plain_text_input_and_output_separator,
+    default_dataset_plain_text_data_separator,
+)
+
+
+def get_data_from_input(load_dataset_from, dataset_text, dataset_text_format,
+                        dataset_plain_text_input_variables_separator,
+                        dataset_plain_text_input_and_output_separator,
+                        dataset_plain_text_data_separator,
+                        dataset_from_data_dir, prompter):
+    if load_dataset_from == "Text Input":
+        if dataset_text_format == "JSON":
+            data = json.loads(dataset_text)
+
+        elif dataset_text_format == "JSON Lines":
+            lines = dataset_text.split('\n')
+            data = []
+            for i, line in enumerate(lines):
+                line_number = i + 1
+                try:
+                    data.append(json.loads(line))
+                except Exception as e:
+                    raise ValueError(
+                        f"Error parsing JSON on line {line_number}: {e}")
+
+        else:  # Plain Text
+            data = parse_plain_text_input(
+                dataset_text,
+                (
+                    dataset_plain_text_input_variables_separator or
+                    default_dataset_plain_text_input_variables_separator
+                ).replace("\\n", "\n"),
+                (
+                    dataset_plain_text_input_and_output_separator or
+                    default_dataset_plain_text_input_and_output_separator
+                ).replace("\\n", "\n"),
+                (
+                    dataset_plain_text_data_separator or
+                    default_dataset_plain_text_data_separator
+                ).replace("\\n", "\n"),
+                prompter.get_variable_names()
+            )
+
+    else:  # Load dataset from data directory
+        data = get_dataset_content(dataset_from_data_dir)
+
+    return data
+
+
+def parse_plain_text_input(
+    value,
+    variables_separator, input_output_separator, data_separator,
+    variable_names
+):
+    items = value.split(data_separator)
+    result = []
+    for item in items:
+        parts = item.split(input_output_separator)
+        variables = get_val_from_arr(parts, 0, "").split(variables_separator)
+        variables = [it.strip() for it in variables]
+        variables_dict = {name: var for name,
+                          var in zip(variable_names, variables)}
+        output = get_val_from_arr(parts, 1, "").strip()
+        result.append({'variables': variables_dict, 'output': output})
+    return result
+
+
+def get_val_from_arr(arr, index, default=None):
+    return arr[index] if -len(arr) <= index < len(arr) else default
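
As a usage sketch (not part of the commit itself), this is how the new plain-text parser splits a dataset string into records. The sample text, separators, and variable names below are made up for illustration; only parse_plain_text_input comes from the file above.

from llama_lora.ui.finetune.data_processing import parse_plain_text_input

sample_text = (
    "Translate to French | Hello\n"
    "----\n"
    "Bonjour\n"
    "====\n"
    "Translate to French | Goodbye\n"
    "----\n"
    "Au revoir"
)
records = parse_plain_text_input(
    sample_text,
    "|",          # variables_separator
    "\n----\n",   # input_output_separator
    "\n====\n",   # data_separator
    ["instruction", "input"],  # variable_names, e.g. from prompter.get_variable_names()
)
# records == [
#     {'variables': {'instruction': 'Translate to French', 'input': 'Hello'},
#      'output': 'Bonjour'},
#     {'variables': {'instruction': 'Translate to French', 'input': 'Goodbye'},
#      'output': 'Au revoir'},
# ]
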
llama_lora/ui/finetune/finetune_ui.py CHANGED
The diff for this file is too large to render. See raw diff
 
llama_lora/ui/finetune/previewing.py ADDED
@@ -0,0 +1,155 @@
+import os
+import traceback
+import re
+import gradio as gr
+import math
+
+from ...config import Config
+from ...utils.prompter import Prompter
+
+from .data_processing import get_data_from_input
+
+
+def refresh_preview(
+    template,
+    load_dataset_from,
+    dataset_from_data_dir,
+    dataset_text,
+    dataset_text_format,
+    dataset_plain_text_input_variables_separator,
+    dataset_plain_text_input_and_output_separator,
+    dataset_plain_text_data_separator,
+    max_preview_count,
+):
+    try:
+        prompter = Prompter(template)
+        variable_names = prompter.get_variable_names()
+
+        data = get_data_from_input(
+            load_dataset_from=load_dataset_from,
+            dataset_text=dataset_text,
+            dataset_text_format=dataset_text_format,
+            dataset_plain_text_input_variables_separator=dataset_plain_text_input_variables_separator,
+            dataset_plain_text_input_and_output_separator=dataset_plain_text_input_and_output_separator,
+            dataset_plain_text_data_separator=dataset_plain_text_data_separator,
+            dataset_from_data_dir=dataset_from_data_dir,
+            prompter=prompter
+        )
+
+        train_data = prompter.get_train_data_from_dataset(
+            data, max_preview_count)
+
+        train_data = train_data[:max_preview_count]
+
+        data_count = len(data)
+
+        headers = ['Prompt', 'Completion']
+        preview_data = [
+            [item.get("prompt", ""), item.get("completion", "")]
+            for item in train_data
+        ]
+
+        if not prompter.template_module:
+            variable_names = prompter.get_variable_names()
+            headers += [f"Variable: {variable_name}" for variable_name in variable_names]
+            variables = [
+                [item.get(f"_var_{name}", "") for name in variable_names]
+                for item in train_data
+            ]
+            preview_data = [d + v for d, v in zip(preview_data, variables)]
+
+        preview_info_message = f"The dataset has about {data_count} item(s)."
+        if data_count > max_preview_count:
+            preview_info_message += f" Previewing the first {max_preview_count}."
+
+        info_message = f"about {data_count} item(s)."
+        if load_dataset_from == "Data Dir":
+            info_message = "This dataset contains about " + info_message
+        update_message = gr.Markdown.update(info_message, visible=True)
+
+        return (
+            gr.Dataframe.update(
+                value={'data': preview_data, 'headers': headers}),
+            gr.Markdown.update(preview_info_message),
+            update_message,
+            update_message
+        )
+    except Exception as e:
+        update_message = gr.Markdown.update(
+            f"<span class=\"finetune_dataset_error_message\">Error: {e}.</span>",
+            visible=True)
+        return (
+            gr.Dataframe.update(value={'data': [], 'headers': []}),
+            gr.Markdown.update(
+                "Set the dataset in the \"Prepare\" tab, then preview it here."),
+            update_message,
+            update_message
+        )
+
+
+def refresh_dataset_items_count(
+    template,
+    load_dataset_from,
+    dataset_from_data_dir,
+    dataset_text,
+    dataset_text_format,
+    dataset_plain_text_input_variables_separator,
+    dataset_plain_text_input_and_output_separator,
+    dataset_plain_text_data_separator,
+    max_preview_count,
+):
+    try:
+        prompter = Prompter(template)
+
+        data = get_data_from_input(
+            load_dataset_from=load_dataset_from,
+            dataset_text=dataset_text,
+            dataset_text_format=dataset_text_format,
+            dataset_plain_text_input_variables_separator=dataset_plain_text_input_variables_separator,
+            dataset_plain_text_input_and_output_separator=dataset_plain_text_input_and_output_separator,
+            dataset_plain_text_data_separator=dataset_plain_text_data_separator,
+            dataset_from_data_dir=dataset_from_data_dir,
+            prompter=prompter
+        )
+
+        train_data = prompter.get_train_data_from_dataset(
+            data)
+        data_count = len(train_data)
+
+        preview_info_message = f"The dataset contains {data_count} item(s)."
+        if data_count > max_preview_count:
+            preview_info_message += f" Previewing the first {max_preview_count}."
+
+        info_message = f"{data_count} item(s)."
+        if load_dataset_from == "Data Dir":
+            info_message = "This dataset contains " + info_message
+        update_message = gr.Markdown.update(info_message, visible=True)
+
+        return (
+            gr.Markdown.update(preview_info_message),
+            update_message,
+            update_message,
+            gr.Slider.update(maximum=math.floor(data_count / 2))
+        )
+    except Exception as e:
+        update_message = gr.Markdown.update(
+            f"<span class=\"finetune_dataset_error_message\">Error: {e}.</span>",
+            visible=True)
+
+        trace = traceback.format_exc()
+        traces = [s.strip() for s in re.split("\n * File ", trace)]
+        traces_to_show = [s for s in traces if os.path.join(
+            Config.data_dir, "templates") in s]
+        traces_to_show = [re.sub(" *\n *", ": ", s) for s in traces_to_show]
+        if len(traces_to_show) > 0:
+            update_message = gr.Markdown.update(
+                f"<span class=\"finetune_dataset_error_message\">Error: {e} ({','.join(traces_to_show)}).</span>",
+                visible=True)
+
+        return (
+            gr.Markdown.update(
+                "Set the dataset in the \"Prepare\" tab, then preview it here."),
+            update_message,
+            update_message,
+            gr.Slider.update(maximum=1)
+        )
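
A side note on the error handling in refresh_dataset_items_count: the formatted traceback is split into per-frame chunks and only the frames that originate from the templates directory are folded into the message shown in the UI. A standalone sketch of that filtering idea (the "data/templates" default here is a hypothetical stand-in for os.path.join(Config.data_dir, "templates")):

import re

def collapse_template_frames(trace, templates_dir="data/templates"):
    # Split the formatted traceback at each '  File ...' frame boundary.
    frames = [s.strip() for s in re.split("\n * File ", trace)]
    # Keep only frames whose file path points into the templates directory.
    frames = [s for s in frames if templates_dir in s]
    # Flatten each remaining frame into a single 'location: code' line.
    return [re.sub(" *\n *", ": ", s) for s in frames]
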
llama_lora/ui/finetune/training.py ADDED
@@ -0,0 +1,345 @@
+import os
+import json
+import time
+import gradio as gr
+import math
+
+from transformers import TrainerCallback
+from huggingface_hub import try_to_load_from_cache, snapshot_download
+
+from ...config import Config
+from ...globals import Global
+from ...models import clear_cache, unload_models
+from ...utils.prompter import Prompter
+
+from .data_processing import get_data_from_input
+
+should_training_progress_track_tqdm = True
+
+if Global.gpu_total_cores is not None and Global.gpu_total_cores > 2560:
+    should_training_progress_track_tqdm = False
+
+
+def do_train(
+    # Dataset
+    template,
+    load_dataset_from,
+    dataset_from_data_dir,
+    dataset_text,
+    dataset_text_format,
+    dataset_plain_text_input_variables_separator,
+    dataset_plain_text_input_and_output_separator,
+    dataset_plain_text_data_separator,
+    # Training Options
+    max_seq_length,
+    evaluate_data_count,
+    micro_batch_size,
+    gradient_accumulation_steps,
+    epochs,
+    learning_rate,
+    train_on_inputs,
+    lora_r,
+    lora_alpha,
+    lora_dropout,
+    lora_target_modules,
+    lora_modules_to_save,
+    load_in_8bit,
+    fp16,
+    bf16,
+    gradient_checkpointing,
+    save_steps,
+    save_total_limit,
+    logging_steps,
+    additional_training_arguments,
+    additional_lora_config,
+    model_name,
+    continue_from_model,
+    continue_from_checkpoint,
+    progress=gr.Progress(track_tqdm=should_training_progress_track_tqdm),
+):
+    try:
+        base_model_name = Global.base_model_name
+        tokenizer_name = Global.tokenizer_name or Global.base_model_name
+
+        resume_from_checkpoint_param = None
+        if continue_from_model == "-" or continue_from_model == "None":
+            continue_from_model = None
+        if continue_from_checkpoint == "-" or continue_from_checkpoint == "None":
+            continue_from_checkpoint = None
+        if continue_from_model:
+            resume_from_model_path = os.path.join(
+                Config.data_dir, "lora_models", continue_from_model)
+            resume_from_checkpoint_param = resume_from_model_path
+            if continue_from_checkpoint:
+                resume_from_checkpoint_param = os.path.join(
+                    resume_from_checkpoint_param, continue_from_checkpoint)
+                will_be_resume_from_checkpoint_file = os.path.join(
+                    resume_from_checkpoint_param, "pytorch_model.bin")
+                if not os.path.exists(will_be_resume_from_checkpoint_file):
+                    raise ValueError(
+                        f"Unable to resume from checkpoint {continue_from_model}/{continue_from_checkpoint}. Resuming is only possible from checkpoints stored locally in the data directory. Please ensure that the file '{will_be_resume_from_checkpoint_file}' exists.")
+            else:
+                will_be_resume_from_checkpoint_file = os.path.join(
+                    resume_from_checkpoint_param, "adapter_model.bin")
+                if not os.path.exists(will_be_resume_from_checkpoint_file):
+                    # Try to get model in Hugging Face cache
+                    resume_from_checkpoint_param = None
+                    possible_hf_model_name = None
+                    possible_model_info_file = os.path.join(
+                        resume_from_model_path, "info.json")
+                    if "/" in continue_from_model:
+                        possible_hf_model_name = continue_from_model
+                    elif os.path.exists(possible_model_info_file):
+                        with open(possible_model_info_file, "r") as file:
+                            model_info = json.load(file)
+                            possible_hf_model_name = model_info.get(
+                                "hf_model_name")
+                    if possible_hf_model_name:
+                        possible_hf_model_cached_path = try_to_load_from_cache(
+                            possible_hf_model_name, 'adapter_model.bin')
+                        if not possible_hf_model_cached_path:
+                            snapshot_download(possible_hf_model_name)
+                            possible_hf_model_cached_path = try_to_load_from_cache(
+                                possible_hf_model_name, 'adapter_model.bin')
+                        if possible_hf_model_cached_path:
+                            resume_from_checkpoint_param = os.path.dirname(
+                                possible_hf_model_cached_path)
+
+                    if not resume_from_checkpoint_param:
+                        raise ValueError(
+                            f"Unable to continue from model {continue_from_model}. Continuation is only possible from models stored locally in the data directory. Please ensure that the file '{will_be_resume_from_checkpoint_file}' exists.")
+
+        output_dir = os.path.join(Config.data_dir, "lora_models", model_name)
+        if os.path.exists(output_dir):
+            if (not os.path.isdir(output_dir)) or os.path.exists(os.path.join(output_dir, 'adapter_config.json')):
+                raise ValueError(
+                    f"The output directory already exists and is not empty. ({output_dir})")
+
+        if not should_training_progress_track_tqdm:
+            progress(0, desc="Preparing train data...")
+
+        # Need RAM for training
+        unload_models()
+        Global.new_base_model_that_is_ready_to_be_used = None
+        Global.name_of_new_base_model_that_is_ready_to_be_used = None
+        clear_cache()
+
+        prompter = Prompter(template)
+        # variable_names = prompter.get_variable_names()
+
+        data = get_data_from_input(
+            load_dataset_from=load_dataset_from,
+            dataset_text=dataset_text,
+            dataset_text_format=dataset_text_format,
+            dataset_plain_text_input_variables_separator=dataset_plain_text_input_variables_separator,
+            dataset_plain_text_input_and_output_separator=dataset_plain_text_input_and_output_separator,
+            dataset_plain_text_data_separator=dataset_plain_text_data_separator,
+            dataset_from_data_dir=dataset_from_data_dir,
+            prompter=prompter
+        )
+
+        train_data = prompter.get_train_data_from_dataset(data)
+
+        def get_progress_text(epoch, epochs, last_loss):
+            progress_detail = f"Epoch {math.ceil(epoch)}/{epochs}"
+            if last_loss is not None:
+                progress_detail += f", Loss: {last_loss:.4f}"
+            return f"Training... ({progress_detail})"
+
+        if Config.ui_dev_mode:
+            Global.should_stop_training = False
+
+            message = f"""Currently in UI dev mode, not doing the actual training.
+
+Train options: {json.dumps({
+    'max_seq_length': max_seq_length,
+    'val_set_size': evaluate_data_count,
+    'micro_batch_size': micro_batch_size,
+    'gradient_accumulation_steps': gradient_accumulation_steps,
+    'epochs': epochs,
+    'learning_rate': learning_rate,
+    'train_on_inputs': train_on_inputs,
+    'lora_r': lora_r,
+    'lora_alpha': lora_alpha,
+    'lora_dropout': lora_dropout,
+    'lora_target_modules': lora_target_modules,
+    'lora_modules_to_save': lora_modules_to_save,
+    'load_in_8bit': load_in_8bit,
+    'fp16': fp16,
+    'bf16': bf16,
+    'gradient_checkpointing': gradient_checkpointing,
+    'model_name': model_name,
+    'continue_from_model': continue_from_model,
+    'continue_from_checkpoint': continue_from_checkpoint,
+    'resume_from_checkpoint_param': resume_from_checkpoint_param,
+}, indent=2)}
+
+Train data (first 10):
+{json.dumps(train_data[:10], indent=2)}
+"""
+            print(message)
+
+            for i in range(300):
+                if (Global.should_stop_training):
+                    return
+                epochs = 3
+                epoch = i / 100
+                last_loss = None
+                if (i > 20):
+                    last_loss = 3 + (i - 0) * (0.5 - 3) / (300 - 0)
+
+                progress(
+                    (i, 300),
+                    desc="(Simulate) " +
+                    get_progress_text(epoch, epochs, last_loss)
+                )
+
+                time.sleep(0.1)
+
+            time.sleep(2)
+            return message
+
+        if not should_training_progress_track_tqdm:
+            progress(
+                0, desc=f"Preparing model {base_model_name} for training...")
+
+        log_history = []
+
+        class UiTrainerCallback(TrainerCallback):
+            def _on_progress(self, args, state, control):
+                nonlocal log_history
+
+                if Global.should_stop_training:
+                    control.should_training_stop = True
+                total_steps = (
+                    state.max_steps if state.max_steps is not None else state.num_train_epochs * state.steps_per_epoch)
+                log_history = state.log_history
+                last_history = None
+                last_loss = None
+                if len(log_history) > 0:
+                    last_history = log_history[-1]
+                    last_loss = last_history.get('loss', None)
+
+                progress_detail = f"Epoch {math.ceil(state.epoch)}/{epochs}"
+                if last_loss is not None:
+                    progress_detail += f", Loss: {last_loss:.4f}"
+
+                progress(
+                    (state.global_step, total_steps),
+                    desc=f"Training... ({progress_detail})"
+                )
+
+            def on_epoch_begin(self, args, state, control, **kwargs):
+                self._on_progress(args, state, control)
+
+            def on_step_end(self, args, state, control, **kwargs):
+                self._on_progress(args, state, control)
+
+        training_callbacks = [UiTrainerCallback]
+
+        Global.should_stop_training = False
+
+        # Do not let other tqdm iterations interfere the progress reporting after training starts.
+        # progress.track_tqdm = False  # setting this dynamically is not working, determining if track_tqdm should be enabled based on GPU cores at start instead.
+
+        if not os.path.exists(output_dir):
+            os.makedirs(output_dir)
+
+        with open(os.path.join(output_dir, "info.json"), 'w') as info_json_file:
+            dataset_name = "N/A (from text input)"
+            if load_dataset_from == "Data Dir":
+                dataset_name = dataset_from_data_dir
+
+            info = {
+                'base_model': base_model_name,
+                'prompt_template': template,
+                'dataset_name': dataset_name,
+                'dataset_rows': len(train_data),
+                'timestamp': time.time(),
+
+                # These will be saved in another JSON file by the train function
+                # 'max_seq_length': max_seq_length,
+                # 'train_on_inputs': train_on_inputs,
+
+                # 'micro_batch_size': micro_batch_size,
+                # 'gradient_accumulation_steps': gradient_accumulation_steps,
+                # 'epochs': epochs,
+                # 'learning_rate': learning_rate,
+
+                # 'evaluate_data_count': evaluate_data_count,
+
+                # 'lora_r': lora_r,
+                # 'lora_alpha': lora_alpha,
+                # 'lora_dropout': lora_dropout,
+                # 'lora_target_modules': lora_target_modules,
+            }
+            if continue_from_model:
+                info['continued_from_model'] = continue_from_model
+                if continue_from_checkpoint:
+                    info['continued_from_checkpoint'] = continue_from_checkpoint
+
+            if Global.version:
+                info['tuner_version'] = Global.version
+
+            json.dump(info, info_json_file, indent=2)
+
+        if not should_training_progress_track_tqdm:
+            progress(0, desc="Train starting...")
+
+        wandb_group = template
+        wandb_tags = [f"template:{template}"]
+        if load_dataset_from == "Data Dir" and dataset_from_data_dir:
+            wandb_group += f"/{dataset_from_data_dir}"
+            wandb_tags.append(f"dataset:{dataset_from_data_dir}")
+
+        train_output = Global.finetune_train_fn(
+            base_model=base_model_name,
+            tokenizer=tokenizer_name,
+            output_dir=output_dir,
+            train_data=train_data,
+            # 128,  # batch_size (is not used, use gradient_accumulation_steps instead)
+            micro_batch_size=micro_batch_size,
+            gradient_accumulation_steps=gradient_accumulation_steps,
+            num_train_epochs=epochs,
+            learning_rate=learning_rate,
+            cutoff_len=max_seq_length,
+            val_set_size=evaluate_data_count,
+            lora_r=lora_r,
+            lora_alpha=lora_alpha,
+            lora_dropout=lora_dropout,
+            lora_target_modules=lora_target_modules,
+            lora_modules_to_save=lora_modules_to_save,
+            train_on_inputs=train_on_inputs,
+            load_in_8bit=load_in_8bit,
+            fp16=fp16,
+            bf16=bf16,
+            gradient_checkpointing=gradient_checkpointing,
+            group_by_length=False,
+            resume_from_checkpoint=resume_from_checkpoint_param,
+            save_steps=save_steps,
+            save_total_limit=save_total_limit,
+            logging_steps=logging_steps,
+            additional_training_arguments=additional_training_arguments,
+            additional_lora_config=additional_lora_config,
+            callbacks=training_callbacks,
+            wandb_api_key=Config.wandb_api_key,
+            wandb_project=Config.default_wandb_project if Config.enable_wandb else None,
+            wandb_group=wandb_group,
+            wandb_run_name=model_name,
+            wandb_tags=wandb_tags
+        )
+
+        logs_str = "\n".join([json.dumps(log)
+                              for log in log_history]) or "None"
+
+        result_message = f"Training ended:\n{str(train_output)}"
+        print(result_message)
+        # result_message += f"\n\nLogs:\n{logs_str}"
+
+        clear_cache()
+
+        return result_message
+
+    except Exception as e:
+        raise gr.Error(
+            f"{e} (To dismiss this error, click the 'Abort' button)")
llama_lora/ui/finetune/values.py ADDED
The diff for this file is too large to render. See raw diff