Add model details and set training parameters
- train_llm.ipynb  +302 -0
- train_llm.py  +96 -31
train_llm.ipynb
ADDED
@@ -0,0 +1,302 @@
Notebook cells and saved outputs:

In [32]:
import os
from uuid import uuid4
import pandas as pd

from datasets import load_dataset
import subprocess
from transformers import AutoTokenizer
In [33]:
# from dotenv import load_dotenv,find_dotenv
# load_dotenv(find_dotenv(),override=True)

def max_token_len(dataset):
    max_seq_length = 0
    for row in dataset:
        tokens = len(tokenizer(row['text'])['input_ids'])
        if tokens > max_seq_length:
            max_seq_length = tokens
    return max_seq_length
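The helper walks a split row by row using the tokenizer loaded in the next cell. An equivalent one-liner, shown only as a sketch under the same assumption that every row carries a 'text' field:

def max_token_len_quick(dataset):
    # Same result as max_token_len above: length of the longest tokenized 'text'.
    return max(len(tokenizer(row['text'])['input_ids']) for row in dataset)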
In [34]:
# model_name='TinyLlama/TinyLlama-1.1B-Chat-v0.1'
model_name = 'mistralai/Mistral-7B-v0.1'
# model_name = 'distilbert-base-uncased'
tokenizer = AutoTokenizer.from_pretrained(model_name)
model_max_length = tokenizer.model_max_length
print("Model Max Length:", model_max_length)

Output:
Model Max Length: 1000000000000000019884624838656
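That printed value is the transformers sentinel for "no limit recorded on the tokenizer" (int(1e30)), not a real context window, so the later comparison of dataset token counts against model_max_length can never trip. One possible fallback, sketched here and not part of the commit, is to use the model config's max_position_embeddings whenever the sentinel shows up:

from transformers import AutoConfig

config = AutoConfig.from_pretrained(model_name)
if model_max_length > 1e12:
    # Tokenizer reports the "unset" sentinel; fall back to the positional limit.
    model_max_length = getattr(config, 'max_position_embeddings', model_max_length)
print('Effective model max length:', model_max_length)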
In [37]:
# Load dataset
dataset_name = 'ai-aerospace/ams_data_train_generic_v0.1_100'
dataset=load_dataset(dataset_name)

# Write dataset files into data directory
data_directory = './fine_tune_data/'

# Create the data directory if it doesn't exist
os.makedirs(data_directory, exist_ok=True)

# Write the train data to a CSV file
train_data='train_data'
train_filename = os.path.join(data_directory, train_data)
dataset['train'].to_pandas().to_csv(train_filename+'.csv', columns=['text'], index=False)
max_token_length_train=max_token_len(dataset['train'])
print('Max token length train: '+str(max_token_length_train))

# Write the validation data to a CSV file
validation_data='validation_data'
validation_filename = os.path.join(data_directory, validation_data)
dataset['validation'].to_pandas().to_csv(validation_filename+'.csv', columns=['text'], index=False)
max_token_length_validation=max_token_len(dataset['validation'])
print('Max token length validation: '+str(max_token_length_validation))

max_token_length=max(max_token_length_train,max_token_length_validation)
if max_token_length > model_max_length:
    raise ValueError("Maximum token length exceeds model limits.")
block_size=2*max_token_length
print('Block size: '+str(block_size))

# Define project parameters
username='ai-aerospace'
project_name='./llms/'+'ams_data_train-100_'+str(uuid4())
repo_name='ams-data-train-100-'+str(uuid4())

Output:
Max token length train: 1121
Max token length validation: 38
Block size: 2242
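These CSVs are what the autotrain run below consumes via --data_path, so a quick read-back can catch an empty or mis-shaped export before a long training job. A small sanity check, not in the original notebook:

import pandas as pd

for split, path in [('train', train_filename + '.csv'),
                    ('validation', validation_filename + '.csv')]:
    df = pd.read_csv(path)
    assert list(df.columns) == ['text'], f'unexpected columns in {path}'
    print(split, len(df), 'rows')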
In [46]:
"""
This set of parameters runs on a low memory gpu on hugging face spaces:
{
  "block_size": 1024,
  "model_max_length": 2048,
  x"use_flash_attention_2": false,
  x"disable_gradient_checkpointing": false,
  "logging_steps": -1,
  "evaluation_strategy": "epoch",
  "save_total_limit": 1,
  "save_strategy": "epoch",
  x"auto_find_batch_size": false,
  "mixed_precision": "fp16",
  "lr": 0.00003,
  "epochs": 3,
  "batch_size": 2,
  "warmup_ratio": 0.1,
  "gradient_accumulation": 1,
  "optimizer": "adamw_torch",
  "scheduler": "linear",
  "weight_decay": 0,
  "max_grad_norm": 1,
  "seed": 42,
  "apply_chat_template": false,
  "quantization": "int4",
  "target_modules": "",
  x"merge_adapter": false,
  "peft": true,
  "lora_r": 16,
  "lora_alpha": 32,
  "lora_dropout": 0.05
}
"""

model_params={
    "project_name": project_name,
    "model_name": model_name,
    "repo_id": username+'/'+repo_name,
    "train_data": train_data,
    "validation_data": validation_data,
    "data_directory": data_directory,
    "block_size": block_size,
    "model_max_length": max_token_length,
    "logging_steps": -1,
    "evaluation_strategy": "epoch",
    "save_total_limit": 1,
    "save_strategy": "epoch",
    "mixed_precision": "fp16",
    "lr": 0.00003,
    "epochs": 3,
    "batch_size": 2,
    "warmup_ratio": 0.1,
    "gradient_accumulation": 1,
    "optimizer": "adamw_torch",
    "scheduler": "linear",
    "weight_decay": 0,
    "max_grad_norm": 1,
    "seed": 42,
    "quantization": "int4",
    "lora_r": 16,
    "lora_alpha": 32,
    "lora_dropout": 0.05
}
for key, value in model_params.items():
    os.environ[key] = str(value)

print(model_params)

# Save parameters to environment variables
# os.environ["project_name"] = project_name
# os.environ["model_name"] = model_name
# os.environ["repo_id"] = username+'/'+repo_name
# os.environ["train_data"] = train_data
# os.environ["validation_data"] = validation_data
# os.environ["data_directory"] = data_directory

Output:
{'project_name': './llms/ams_data_train-100_6abb23dc-cb9d-428e-9079-e47deee0edd9', 'model_name': 'mistralai/Mistral-7B-v0.1', 'repo_id': 'ai-aerospace/ams-data-train-100-4601c8c8-0903-4f18-a6e8-1d2a40a697ce', 'train_data': 'train_data', 'validation_data': 'validation_data', 'data_directory': './fine_tune_data/', 'block_size': 2242, 'model_max_length': 1121, 'logging_steps': -1, 'evaluation_strategy': 'epoch', 'save_total_limit': 1, 'save_strategy': 'epoch', 'mixed_precision': 'fp16', 'lr': 3e-05, 'epochs': 3, 'batch_size': 2, 'warmup_ratio': 0.1, 'gradient_accumulation': 1, 'optimizer': 'adamw_torch', 'scheduler': 'linear', 'weight_decay': 0, 'max_grad_norm': 1, 'seed': 42, 'quantization': 'int4', 'lora_r': 16, 'lora_alpha': 32, 'lora_dropout': 0.05}
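The loop stores every parameter in the process environment as a string, so anything reading them back has to cast to the right type. A minimal read-back sketch using the keys defined above:

lr = float(os.environ['lr'])
epochs = int(os.environ['epochs'])
lora_r = int(os.environ['lora_r'])
mixed_precision = os.environ['mixed_precision']
print(lr, epochs, lora_r, mixed_precision)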
In [49]:
# Set .venv and execute the autotrain script
# To see all parameters: autotrain llm --help
# !autotrain llm --train --project_name my-llm --model TinyLlama/TinyLlama-1.1B-Chat-v0.1 --data_path . --use-peft --use_int4 --learning_rate 2e-4 --train_batch_size 6 --num_train_epochs 3 --trainer sft
command=f"""
autotrain llm --train \
  --trainer sft \
  --project_name {model_params['project_name']} \
  --model {model_params['model_name']} \
  --data_path {model_params['data_directory']} \
  --train_split {model_params['train_data']} \
  --valid_split {model_params['validation_data']} \
  --repo_id {model_params['repo_id']} \
  --push_to_hub \
  --token HUGGINGFACE_TOKEN \
  --block_size {model_params['block_size']} \
  --model_max_length {model_params['model_max_length']} \
  --logging_steps {model_params['logging_steps']} \
  --evaluation_strategy {model_params['evaluation_strategy']} \
  --save_total_limit {model_params['save_total_limit']} \
  --save_strategy {model_params['save_strategy']} \
  --fp16 \
  --lr {model_params['lr']} \
  --num_train_epochs {model_params['epochs']} \
  --train_batch_size {model_params['batch_size']} \
  --warmup_ratio {model_params['warmup_ratio']} \
  --gradient_accumulation {model_params['gradient_accumulation']} \
  --optimizer {model_params['optimizer']} \
  --scheduler linear \
  --weight_decay {model_params['weight_decay']} \
  --max_grad_norm {model_params['max_grad_norm']} \
  --seed {model_params['seed']} \
  --use_int4 \
  --use-peft \
  --lora_r {model_params['lora_r']} \
  --lora_alpha {model_params['lora_alpha']} \
  --lora_dropout {model_params['lora_dropout']}
"""

# Use subprocess.run() to execute the command
subprocess.run(command, shell=True, check=True)

Output (stderr):
⚠️ WARNING | 2023-12-22 10:39:42 | autotrain.cli.run_dreambooth:<module>:14 - ❌ Some DreamBooth components are missing! Please run `autotrain setup` to install it. Ignore this warning if you are not using DreamBooth or running `autotrain setup` already.
usage: autotrain <command> [<args>]
AutoTrain advanced CLI: error: unrecognized arguments: --batch_size 2

---------------------------------------------------------------------------
CalledProcessError                        Traceback (most recent call last)
Cell In[49], line 40
      4 command=f"""
      5 autotrain llm --train \
      6   --trainer sft \
   (...)
     36   --lora_dropout {model_params['lora_dropout']}
     37 """
     39 # Use subprocess.run() to execute the command
---> 40 subprocess.run(command, shell=True, check=True)

File /usr/lib/python3.11/subprocess.py:571, in run(input, capture_output, timeout, check, *popenargs, **kwargs)
    569 retcode = process.poll()
    570 if check and retcode:
--> 571     raise CalledProcessError(retcode, process.args,
    572                              output=stdout, stderr=stderr)
    573 return CompletedProcess(process.args, retcode, stdout, stderr)

CalledProcessError: Command '
autotrain llm --train --trainer sft --project_name ./llms/ams_data_train-100_6abb23dc-cb9d-428e-9079-e47deee0edd9 --model mistralai/Mistral-7B-v0.1 --data_path ./fine_tune_data/ --train_split train_data --valid_split validation_data --repo_id ai-aerospace/ams-data-train-100-4601c8c8-0903-4f18-a6e8-1d2a40a697ce --push_to_hub --token HUGGINGFACE_TOKEN --block_size 2242 --model_max_length 1121 --logging_steps -1 --evaluation_strategy epoch --save_total_limit 1 --save_strategy epoch --fp16 --lr 3e-05 --num_train_epochs 3 --batch_size 2 --warmup_ratio 0.1 --gradient_accumulation 1 --optimizer adamw_torch --scheduler linear --weight_decay 0 --max_grad_norm 1 --seed 42 --use_int4 --use-peft --lora_r 16 --lora_alpha 32 --lora_dropout 0.05
' returned non-zero exit status 2.
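Two things stand out in this cell: --token is passed the literal string HUGGINGFACE_TOKEN rather than a real token, and the saved error comes from an earlier run that still passed --batch_size, which the CLI rejects (the commented example at the top of the cell uses --train_batch_size). A sketch, abbreviated to the arguments relevant here, of pulling the token from the environment instead; the variable name HUGGINGFACE_TOKEN is an assumption about how the token is exported, not something the commit sets up:

import os
import subprocess

hf_token = os.environ['HUGGINGFACE_TOKEN']  # assumed to hold a write-scoped token

command = f"""
autotrain llm --train \
  --trainer sft \
  --project_name {model_params['project_name']} \
  --model {model_params['model_name']} \
  --data_path {model_params['data_directory']} \
  --repo_id {model_params['repo_id']} \
  --push_to_hub \
  --token {hf_token} \
  --train_batch_size {model_params['batch_size']}
"""
subprocess.run(command, shell=True, check=True)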
In [ ]:
(empty cell)

Notebook metadata: kernel ".venv" (python3), Python 3.11.7, nbformat 4.2.
train_llm.py
CHANGED
@@ -4,10 +4,30 @@ import pandas as pd
 
 from datasets import load_dataset
 import subprocess
+from transformers import AutoTokenizer
 
+### Read environment variables
 # from dotenv import load_dotenv,find_dotenv
 # load_dotenv(find_dotenv(),override=True)
 
+### Functions
+def max_token_len(dataset):
+    max_seq_length = 0
+    for row in dataset:
+        tokens = len(tokenizer(row['text'])['input_ids'])
+        if tokens > max_seq_length:
+            max_seq_length = tokens
+    return max_seq_length
+
+### Model details
+# model_name='TinyLlama/TinyLlama-1.1B-Chat-v0.1'
+model_name = 'mistralai/Mistral-7B-v0.1'
+# model_name = 'distilbert-base-uncased'
+tokenizer = AutoTokenizer.from_pretrained(model_name)
+model_max_length = tokenizer.model_max_length
+print("Model Max Length:", model_max_length)
+
+### Repo name, dataset initialization, and data directory
 # Load dataset
 dataset_name = 'ai-aerospace/ams_data_train_generic_v0.1_100'
 dataset=load_dataset(dataset_name)
@@ -21,54 +41,99 @@ os.makedirs(data_directory, exist_ok=True)
 # Write the train data to a CSV file
 train_data='train_data'
 train_filename = os.path.join(data_directory, train_data)
-dataset['train'].to_pandas().to_csv(train_filename, columns=['text'], index=False)
+dataset['train'].to_pandas().to_csv(train_filename+'.csv', columns=['text'], index=False)
+max_token_length_train=max_token_len(dataset['train'])
+print('Max token length train: '+str(max_token_length_train))
 
 # Write the validation data to a CSV file
 validation_data='validation_data'
 validation_filename = os.path.join(data_directory, validation_data)
-dataset['validation'].to_pandas().to_csv(validation_filename, columns=['text'], index=False)
+dataset['validation'].to_pandas().to_csv(validation_filename+'.csv', columns=['text'], index=False)
+max_token_length_validation=max_token_len(dataset['validation'])
+print('Max token length validation: '+str(max_token_length_validation))
+
+max_token_length=max(max_token_length_train,max_token_length_validation)
+if max_token_length > model_max_length:
+    raise ValueError("Maximum token length exceeds model limits.")
+block_size=2*max_token_length
 
 # Define project parameters
 username='ai-aerospace'
 project_name='./llms/'+'ams_data_train-100_'+str(uuid4())
 repo_name='ams-data-train-100-'+str(uuid4())
 
-[17 lines removed here (old lines 36-52); their content is not shown in the rendered diff]
+### Set training params
+model_params={
+    "project_name": project_name,
+    "model_name": model_name,
+    "repo_id": username+'/'+repo_name,
+    "train_data": train_data,
+    "validation_data": validation_data,
+    "data_directory": data_directory,
+    "block_size": block_size,
+    "model_max_length": max_token_length,
+    "logging_steps": -1,
+    "evaluation_strategy": "epoch",
+    "save_total_limit": 1,
+    "save_strategy": "epoch",
+    "mixed_precision": "fp16",
+    "lr": 0.00003,
+    "epochs": 3,
+    "batch_size": 2,
+    "warmup_ratio": 0.1,
+    "gradient_accumulation": 1,
+    "optimizer": "adamw_torch",
+    "scheduler": "linear",
+    "weight_decay": 0,
+    "max_grad_norm": 1,
+    "seed": 42,
+    "quantization": "int4",
+    "target_modules": "",
+    "lora_r": 16,
+    "lora_alpha": 32,
+    "lora_dropout": 0.05
+}
+for key, value in model_params.items():
+    os.environ[key] = str(value)
 
+### Feed into and run autotrain command
 # Set .venv and execute the autotrain script
 # To see all parameters: autotrain llm --help
 # !autotrain llm --train --project_name my-llm --model TinyLlama/TinyLlama-1.1B-Chat-v0.1 --data_path . --use-peft --use_int4 --learning_rate 2e-4 --train_batch_size 6 --num_train_epochs 3 --trainer sft
-command="""
+command=f"""
 autotrain llm --train \
-  --project_name ${project_name} \
-  --model ${model_name} \
-  --data_path ${data_directory} \
-  --train_split ${train_data} \
-  --valid_split ${validation_data} \
-  --use-peft \
-  --learning_rate 2e-4 \
-  --train_batch_size 6 \
-  --num_train_epochs 3 \
  --trainer sft \
+  --project_name {model_params['project_name']} \
+  --model {model_params['model_name']} \
+  --data_path {model_params['data_directory']} \
+  --train_split {model_params['train_data']} \
+  --valid_split {model_params['validation_data']} \
+  --repo_id {model_params['repo_id']} \
  --push_to_hub \
-[2 lines removed here (old lines 70-71); content truncated to "--" in the rendered diff]
+  --token HUGGINGFACE_TOKEN
+  --block_size {model_params['block_size']} \
+  --model_max_length {model_params['model_max_length']} \
+  --logging_steps {model_params['logging_steps']} \
+  --evaluation_strategy {model_params['evaluation_strategy']} \
+  --save_total_limit {model_params['save_total_limit']} \
+  --save_strategy {model_params['save_strategy']} \
+  --fp16 \
+  --lr {model_params['lr']} \
+  --num_train_epochs {model_params['lr']} \
+  --batch_size {model_params['batch_size']} \
+  --warmup_ratio {model_params['warmup_ratio']} \
+  --gradient_accumulation {model_params['gradient_accumulation']} \
+  --optimizer {model_params['gradient_accumulation']} \
+  --scheduler linear \
+  --weight_decay {model_params['weight_decay']} \
+  --max_grad_norm {model_params['max_grad_norm']} \
+  --seed {model_params['seed']} \
+  --use_int4 \
+  --target_modules {model_params['target_modules']} \
+  --use-peft \
+  --lora_r {model_params['lora_r']} \
+  --lora_alpha {model_params['lora_alpha']} \
+  --lora_dropout {model_params['lora_dropout']}
 """
 
 # Use subprocess.run() to execute the command
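Relative to the notebook cell, the script's new command block carries what look like three slips: the --token HUGGINGFACE_TOKEN line is missing its trailing backslash (so the shell command is cut in two), --num_train_epochs is interpolated from model_params['lr'], and --optimizer from model_params['gradient_accumulation']. A sketch of just those arguments with the keys matched to the notebook version; this is a suggested correction, not part of the commit:

corrected_args = (
    " --token HUGGINGFACE_TOKEN \\\n"                      # keep the line continuation
    f" --num_train_epochs {model_params['epochs']} \\\n"
    f" --optimizer {model_params['optimizer']} \\\n"
)
print(corrected_args)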