asigalov61 committed on
Commit
35dfe93
1 Parent(s): 3c83c87

Upload MIDIstral_pixtral_fine_tune_code.ipynb

code/MIDIstral_pixtral_fine_tune_code.ipynb ADDED
@@ -0,0 +1,674 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "***\n",
8
+ "\n",
9
+ "# MIDIstral Pixtral 12B Fine-Tuning Code\n",
10
+ "\n",
11
+ "***\n",
12
+ "\n",
13
+ "## Based upon fine-tuning code by Tomasz Stankiewicz\n",
14
+ "\n",
15
+ "## https://github.com/tomstaan/Clarivex-Pixtral-12B\n",
16
+ "\n",
17
+ "***\n",
18
+ "\n",
19
+ "### Project Los Angeles\n",
20
+ "### Tegridy Code 2024\n",
21
+ "\n",
22
+ "***"
23
+ ]
24
+ },
25
+ {
26
+ "cell_type": "markdown",
27
+ "metadata": {},
28
+ "source": [
29
+ "# Setup"
30
+ ]
31
+ },
32
+ {
33
+ "cell_type": "code",
34
+ "execution_count": null,
35
+ "metadata": {},
36
+ "outputs": [],
37
+ "source": [
38
+ "!python3 -m pip install --upgrade pip -q\n",
39
+ "!pip3 install -U transformers\n",
40
+ "!pip3 install -q accelerate datasets peft bitsandbytes hf_transfer flash_attn tensorboard\n",
41
+ "!pip3 install ipywidgets\n",
42
+ "!pip3 install --upgrade jinja2\n",
43
+ "!pip3 install --upgrade peft\n",
44
+ "!pip3 install -U pillow\n",
45
+ "!pip3 install tf-keras\n",
46
+ "\n",
47
+ "# It can be a good idea to restart the kernel after this step"
48
+ ]
49
+ },
50
+ {
51
+ "cell_type": "code",
52
+ "execution_count": null,
53
+ "metadata": {},
54
+ "outputs": [],
55
+ "source": [
56
+ "!sudo pip3 install tf-keras"
57
+ ]
58
+ },
59
+ {
60
+ "cell_type": "code",
61
+ "execution_count": null,
62
+ "metadata": {},
63
+ "outputs": [],
64
+ "source": [
65
+ "!sudo pip install -U numpy==1.26.1"
66
+ ]
67
+ },
68
+ {
69
+ "cell_type": "code",
70
+ "execution_count": null,
71
+ "metadata": {},
72
+ "outputs": [],
73
+ "source": [
74
+ "# Enable fast weights download and upload\n",
75
+ "import os\n",
76
+ "os.environ[\"HF_HUB_ENABLE_HF_TRANSFER\"] = \"1\""
77
+ ]
78
+ },
79
+ {
80
+ "cell_type": "markdown",
81
+ "metadata": {},
82
+ "source": [
83
+ "# Download model"
84
+ ]
85
+ },
86
+ {
87
+ "cell_type": "code",
88
+ "execution_count": null,
89
+ "metadata": {},
90
+ "outputs": [],
91
+ "source": [
92
+ "import torch\n",
93
+ "from PIL import Image\n",
94
+ "from transformers import AutoProcessor, LlavaForConditionalGeneration\n",
95
+ "from transformers import BitsAndBytesConfig\n",
96
+ "\n",
97
+ "model_id = \"mistral-community/pixtral-12b\"\n",
98
+ "\n",
99
+ "model = LlavaForConditionalGeneration.from_pretrained(\n",
100
+ " model_id,\n",
101
+ " torch_dtype=torch.bfloat16,\n",
102
+ " device_map='auto',\n",
103
+ " #attn_implementation=\"sdpa\",\n",
104
+ ")\n",
105
+ "\n",
106
+ "processor = AutoProcessor.from_pretrained(model_id)\n",
107
+ "\n",
108
+ "# Extract the tokenizer from the processor\n",
109
+ "tokenizer = processor.tokenizer\n",
110
+ "\n",
111
+ "# Set the padding side to 'left' for Flash Attention compatibility\n",
112
+ "tokenizer.padding_side = \"left\""
113
+ ]
114
+ },
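The cell above imports BitsAndBytesConfig but then loads the full bf16 weights. If VRAM is tight, a 4-bit quantized load is an option; the sketch below is an optional addition (not part of the original notebook) that only uses standard transformers/bitsandbytes arguments.

```python
# Optional, hedged sketch: 4-bit NF4 loading of the same base model.
# Not used in this notebook, which fine-tunes on full bf16 weights.
import torch
from transformers import BitsAndBytesConfig, LlavaForConditionalGeneration

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

model_4bit = LlavaForConditionalGeneration.from_pretrained(
    "mistral-community/pixtral-12b",
    quantization_config=bnb_config,
    device_map="auto",
)
```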
115
+ {
116
+ "cell_type": "markdown",
117
+ "metadata": {},
118
+ "source": [
119
+ "# Chat Template"
120
+ ]
121
+ },
122
+ {
123
+ "cell_type": "code",
124
+ "execution_count": null,
125
+ "metadata": {},
126
+ "outputs": [],
127
+ "source": [
128
+ "CHAT_TEMPLATE = \"\"\"\n",
129
+ "{%- for message in messages %} \n",
130
+ " {%- if message.role == \"user\" %} \n",
131
+ " <s>[INST] \n",
132
+ " {%- for item in message.content %} \n",
133
+ " {%- if item.type == \"text\" %} \n",
134
+ " {{ item.text }} \n",
135
+ " {%- elif item.type == \"image\" %} \n",
136
+ " \\n[IMG] \n",
137
+ " {%- endif %} \n",
138
+ " {%- endfor %} \n",
139
+ " [/INST] \n",
140
+ " {%- elif message.role == \"assistant\" %} \n",
141
+ " {%- for item in message.content %} \n",
142
+ " {%- if item.type == \"text\" %} \n",
143
+ " {{ item.text }} \n",
144
+ " {%- endif %} \n",
145
+ " {%- endfor %} \n",
146
+ " </s>\n",
147
+ " {%- endif %} \n",
148
+ "{%- endfor %} \n",
149
+ "\"\"\"\n",
150
+ "\n",
151
+ "# Set the chat template on the processor\n",
152
+ "processor.chat_template = CHAT_TEMPLATE.replace(' ', '')\n",
153
+ "\n",
154
+ "processor.tokenizer.pad_token = processor.tokenizer.eos_token"
155
+ ]
156
+ },
157
+ {
158
+ "cell_type": "code",
159
+ "execution_count": null,
160
+ "metadata": {},
161
+ "outputs": [],
162
+ "source": [
163
+ "# Example conversation input with user and assistant roles\n",
164
+ "messages = [\n",
165
+ " {\n",
166
+ " \"role\": \"user\",\n",
167
+ " \"content\": [\n",
168
+ " {\"type\": \"text\", \"text\": \"Please describe the song music in detail. Thank you.\"},\n",
169
+ " {\"type\": \"image\"}\n",
170
+ " ]\n",
171
+ " },\n",
172
+ " {\n",
173
+ " \"role\": \"assistant\",\n",
174
+ " \"content\": [\n",
175
+ " {\"type\": \"text\", \"text\": \"The song 'Man In Black' by Johnny Cash in key A# has fast tempo and average pace with Acoustic Guitar(steel) lead, accompanying Acoustic Grand and predominant Acoustic Snare drums\"}\n",
176
+ " ]\n",
177
+ " }\n",
178
+ "]\n",
179
+ "\n",
180
+ "# Apply the chat template to format the messages\n",
181
+ "formatted_text = processor.apply_chat_template(messages, add_generation_prompt=False)\n",
182
+ "\n",
183
+ "# Output the formatted text\n",
184
+ "print(\"Formatted text:\\n\", formatted_text)"
185
+ ]
186
+ },
187
+ {
188
+ "cell_type": "markdown",
189
+ "metadata": {},
190
+ "source": [
191
+ "# Download dataset"
192
+ ]
193
+ },
194
+ {
195
+ "cell_type": "code",
196
+ "execution_count": null,
197
+ "metadata": {},
198
+ "outputs": [],
199
+ "source": [
200
+ "from PIL import Image\n",
201
+ "import io\n",
202
+ "from datasets import load_dataset\n",
203
+ "\n",
204
+ "def deserialize_image(byte_data):\n",
205
+ " img_byte_arr = io.BytesIO(byte_data)\n",
206
+ " img = Image.open(img_byte_arr)\n",
207
+ " return img\n",
208
+ "\n",
209
+ "dataset = load_dataset(\"asigalov61/MIDIstral\", split='train').train_test_split(test_size=0.001)\n",
210
+ "\n",
211
+ "# Access the training and test sets\n",
212
+ "train_dataset = dataset[\"train\"]\n",
213
+ "eval_dataset = dataset[\"test\"]"
214
+ ]
215
+ },
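A quick sanity check can confirm the split sizes and that the stored image bytes decode with the deserialize_image() helper defined above; the field names other than "image" are assumptions based on how the rest of this notebook accesses the dataset.

```python
# Hedged sanity check: decode one stored image and inspect the available fields.
sample = train_dataset[0]
print(len(train_dataset), len(eval_dataset))   # split sizes
print(list(sample.keys()))                     # e.g. "image", "question", "answer" (assumed)
img = deserialize_image(sample["image"])       # bytes -> PIL.Image, as defined above
print(img.size, img.mode)
```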
216
+ {
217
+ "cell_type": "code",
218
+ "execution_count": null,
219
+ "metadata": {},
220
+ "outputs": [],
221
+ "source": [
222
+ "len(train_dataset)"
223
+ ]
224
+ },
225
+ {
226
+ "cell_type": "code",
227
+ "execution_count": null,
228
+ "metadata": {},
229
+ "outputs": [],
230
+ "source": [
231
+ "eval_dataset[0]"
232
+ ]
233
+ },
234
+ {
235
+ "cell_type": "markdown",
236
+ "metadata": {},
237
+ "source": [
238
+ "# Evaluation before fine-tuning"
239
+ ]
240
+ },
241
+ {
242
+ "cell_type": "code",
243
+ "execution_count": null,
244
+ "metadata": {},
245
+ "outputs": [],
246
+ "source": [
247
+ "import torch\n",
248
+ "from PIL import Image\n",
249
+ "from torchvision.transforms.functional import to_pil_image, resize\n",
250
+ "\n",
251
+ "def run_model_evaluation(model, dataset, num_samples=None, device='cuda', constant_query=None):\n",
252
+ " model.eval()\n",
253
+ " results = []\n",
254
+ "\n",
255
+ " # Limit the dataset if a specific number of samples is provided\n",
256
+ " if num_samples is not None:\n",
257
+ " dataset = torch.utils.data.Subset(dataset, range(num_samples))\n",
258
+ "\n",
259
+ " for example in dataset:\n",
260
+ " image = deserialize_image(example[\"image\"])\n",
261
+ " if constant_query is None:\n",
262
+ " query = example[\"query\"][\"en\"]\n",
263
+ " else:\n",
264
+ " query = constant_query # Use the constant query if provided\n",
265
+ " \n",
266
+ " # Display a reduced size version of the image\n",
267
+ " pil_image = image\n",
268
+ " aspect_ratio = pil_image.width / pil_image.height\n",
269
+ " new_width = 300\n",
270
+ " new_height = int(new_width / aspect_ratio)\n",
271
+ " display_image = resize(pil_image, (new_height, new_width))\n",
272
+ " display_image.show() # This will open the image in the default image viewer\n",
273
+ "\n",
274
+ " # Construct the message template\n",
275
+ " messages = [\n",
276
+ " {\n",
277
+ " \"role\": \"user\",\n",
278
+ " \"content\": [\n",
279
+ " # {\"type\": \"text\", \"text\": \"Answer briefly.\"},\n",
280
+ " {\"type\": \"text\", \"text\": query},\n",
281
+ " {\"type\": \"image\"}, # YOU CAN COMMENT THIS OUT IF THERE ARE NO IMAGES\n",
282
+ " # {\"type\": \"image\"}, # ADD A SECOND IMAGE!!! Note that the text is also possible here.\n",
283
+ " ]\n",
284
+ " }\n",
285
+ " ]\n",
286
+ "\n",
287
+ " # Apply the chat template to preprocess input\n",
288
+ " formatted_prompt = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)\n",
289
+ " print(f\"Formatted prompt: {formatted_prompt}\")\n",
290
+ " text = processor.apply_chat_template(messages, add_generation_prompt=True)\n",
291
+ " inputs = processor(text=[text.strip()], images=[image], return_tensors=\"pt\", padding=True).to(device)\n",
292
+ " # inputs = processor(text=[text.strip()], images=[image, image2], return_tensors=\"pt\", padding=True).to(device)\n",
293
+ "\n",
294
+ " # Generate output from the model\n",
295
+ " generated_ids = model.generate(**inputs, max_new_tokens=64)\n",
296
+ " generated_texts = processor.batch_decode(generated_ids[:, inputs[\"input_ids\"].shape[-1]:])\n",
297
+ "\n",
298
+ " print(f\"Prediction: {generated_texts[0]}\\n\")\n",
299
+ "\n",
300
+ " results.append(generated_texts[0]) # Store the result\n",
301
+ "\n",
302
+ " return results\n",
303
+ "\n",
304
+ "\n"
305
+ ]
306
+ },
307
+ {
308
+ "cell_type": "code",
309
+ "execution_count": null,
310
+ "metadata": {},
311
+ "outputs": [],
312
+ "source": [
313
+ "# Usage\n",
314
+ "eval_results_before_fine_tuning = run_model_evaluation(model, \n",
315
+ " eval_dataset, \n",
316
+ " num_samples=2, \n",
317
+ " device='cuda', \n",
318
+ " constant_query='Please describe the song music in detail. Thank you.')\n",
319
+ "\n",
320
+ "print('eval_results_before_fine_tuning:', eval_results_before_fine_tuning)"
321
+ ]
322
+ },
323
+ {
324
+ "cell_type": "markdown",
325
+ "metadata": {},
326
+ "source": [
327
+ "# Fine-tuning"
328
+ ]
329
+ },
330
+ {
331
+ "cell_type": "code",
332
+ "execution_count": null,
333
+ "metadata": {},
334
+ "outputs": [],
335
+ "source": [
336
+ "import torch\n",
337
+ "\n",
338
+ "class MyDataCollator:\n",
339
+ " def __init__(self, processor):\n",
340
+ " self.processor = processor\n",
341
+ "\n",
342
+ " def __call__(self, examples):\n",
343
+ " texts = []\n",
344
+ " images = []\n",
345
+ " assistant_responses = [] # To track assistant responses for proper masking\n",
346
+ " for example in examples:\n",
347
+ " image = deserialize_image(example[\"image\"])\n",
348
+ " question = example[\"question\"]\n",
349
+ " answer = example[\"answer\"]\n",
350
+ "\n",
351
+ " messages = [\n",
352
+ " {\n",
353
+ " \"role\": \"user\",\n",
354
+ " \"content\": [\n",
355
+ " {\"type\": \"text\", \"text\": question},\n",
356
+ " {\"type\": \"image\"}, # Images after the text.\n",
357
+ " ]\n",
358
+ " },\n",
359
+ " {\n",
360
+ " \"role\": \"assistant\",\n",
361
+ " \"content\": [\n",
362
+ " {\"type\": \"text\", \"text\": answer}\n",
363
+ " ]\n",
364
+ " }\n",
365
+ " ]\n",
366
+ "\n",
367
+ " # Convert messages to the desired text format using processor's template\n",
368
+ " text = self.processor.apply_chat_template(messages, add_generation_prompt=False)\n",
369
+ "\n",
370
+ " texts.append(text.strip())\n",
371
+ " images.append([image])\n",
372
+ " assistant_responses.append(answer) # Track assistant's response for later use\n",
373
+ "\n",
374
+ " # Tokenize and process batch\n",
375
+ " batch = self.processor(text=texts, images=images, return_tensors=\"pt\", padding=True)\n",
376
+ "\n",
377
+ " # Prepare labels; we will mask non-assistant tokens for generation\n",
378
+ " labels = batch[\"input_ids\"].clone() \n",
379
+ "\n",
380
+ " # For each example, find assistant tokens and mask everything else\n",
381
+ " for i, (input_ids, assistant_response) in enumerate(zip(batch[\"input_ids\"], assistant_responses)):\n",
382
+ " # Tokenize just the assistant response\n",
383
+ " assistant_tokens = self.processor.tokenizer(assistant_response, return_tensors=\"pt\")[\"input_ids\"][0]\n",
384
+ "\n",
385
+ " # Find where the assistant tokens start in the input sequence\n",
386
+ " start_idx = self.find_subsequence(input_ids, assistant_tokens)\n",
387
+ "\n",
388
+ " if start_idx is not None:\n",
389
+ " # Mask everything except the assistant tokens\n",
390
+ " labels[i, :start_idx] = -100 # Ignore everything before the assistant's response\n",
391
+ " labels[i, start_idx + len(assistant_tokens):] = -100 # Ignore everything after\n",
392
+ "\n",
393
+ " # Assign masked labels back to the batch\n",
394
+ " batch[\"labels\"] = labels\n",
395
+ "\n",
396
+ " return batch\n",
397
+ " \n",
398
+ " def find_subsequence(self, sequence, subsequence):\n",
399
+ " \"\"\"\n",
400
+ " Find the start index of a subsequence (assistant tokens) in a sequence (input tokens).\n",
401
+ " \"\"\"\n",
402
+ " seq_len = len(sequence)\n",
403
+ " sub_len = len(subsequence)\n",
404
+ "\n",
405
+ " for i in range(seq_len - sub_len + 1):\n",
406
+ " if torch.equal(sequence[i:i + sub_len], subsequence):\n",
407
+ " return i\n",
408
+ " return None\n",
409
+ " \n",
410
+ "data_collator = MyDataCollator(processor)"
411
+ ]
412
+ },
413
+ {
414
+ "cell_type": "code",
415
+ "execution_count": null,
416
+ "metadata": {},
417
+ "outputs": [],
418
+ "source": [
419
+ "import torch\n",
420
+ "\n",
421
+ "# Select a small batch of examples (e.g., 2 examples for quick testing)\n",
422
+ "sample_batch = [train_dataset[i] for i in range(2)]\n",
423
+ "\n",
424
+ "# Call the data collator with the sample batch to process it\n",
425
+ "processed_batch = data_collator(sample_batch)\n",
426
+ "\n",
427
+ "# Print the processed batch keys to check what's inside\n",
428
+ "print(\"Processed batch keys:\", processed_batch.keys())\n",
429
+ "\n",
430
+ "# Print out the texts after applying the chat template\n",
431
+ "print(\"\\nTokenized input IDs (before padding):\")\n",
432
+ "print(processed_batch[\"input_ids\"])"
433
+ ]
434
+ },
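To confirm that the collator supervises only the assistant answer, the masked labels can be decoded; -100 is the ignore index assigned in the collator above.

```python
# Hedged check: decode only the unmasked label tokens of the first example.
labels = processed_batch["labels"][0]
kept = labels[labels != -100]                  # tokens that will contribute to the loss
print(processor.tokenizer.decode(kept))        # should match that example's answer text
```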
435
+ {
436
+ "cell_type": "code",
437
+ "execution_count": null,
438
+ "metadata": {},
439
+ "outputs": [],
440
+ "source": [
441
+ "processed_batch[\"input_ids\"].shape"
442
+ ]
443
+ },
444
+ {
445
+ "cell_type": "code",
446
+ "execution_count": null,
447
+ "metadata": {},
448
+ "outputs": [],
449
+ "source": [
450
+ "print(model)"
451
+ ]
452
+ },
453
+ {
454
+ "cell_type": "code",
455
+ "execution_count": null,
456
+ "metadata": {},
457
+ "outputs": [],
458
+ "source": [
459
+ "from peft import LoraConfig\n",
460
+ "\n",
461
+ "lora_config = LoraConfig(\n",
462
+ " r=32, # Rank (usually 8, 16, or 32 depending on model size and needs)\n",
463
+ " lora_alpha=32, # Scaling factor for the low-rank updates\n",
464
+ " use_rslora=True, # Use rank-stabilized LoRA scaling (alpha / sqrt(r))\n",
465
+ " target_modules=\"all-linear\", # Apply LoRA to all linear layers\n",
466
+ " # modules_to_save=['lm_head','embed_tokens'],\n",
467
+ " lora_dropout=0.1, # Dropout for low-rank adapter layers\n",
468
+ " bias=\"none\", # Bias in adapter layers: \"none\", \"all\", \"lora_only\"\n",
469
+ " task_type=\"CAUSAL_LM\" # Task type: \"CAUSAL_LM\", \"SEQ_2_SEQ_LM\", or \"TOKEN_CLS\"\n",
470
+ ")"
471
+ ]
472
+ },
473
+ {
474
+ "cell_type": "code",
475
+ "execution_count": null,
476
+ "metadata": {},
477
+ "outputs": [],
478
+ "source": [
479
+ "from peft import get_peft_model\n",
480
+ "\n",
481
+ "model=get_peft_model(model, lora_config)"
482
+ ]
483
+ },
484
+ {
485
+ "cell_type": "code",
486
+ "execution_count": null,
487
+ "metadata": {},
488
+ "outputs": [],
489
+ "source": [
490
+ "model.print_trainable_parameters()"
491
+ ]
492
+ },
493
+ {
494
+ "cell_type": "code",
495
+ "execution_count": null,
496
+ "metadata": {},
497
+ "outputs": [],
498
+ "source": [
499
+ "from transformers import TrainingArguments, Trainer\n",
500
+ "\n",
501
+ "# for main fine-tuning\n",
502
+ "epochs = 1\n",
503
+ "lr = 3e-5\n",
504
+ "schedule = \"constant\"\n",
505
+ "\n",
506
+ "# Optional, for annealing\n",
507
+ "# epochs = 0.4\n",
508
+ "# lr = 3e-5\n",
509
+ "# schedule = \"linear\"\n",
510
+ "\n",
511
+ "run_name = f\"MIDIstral-{lr}_lr-{epochs}_epochs-{schedule}_schedule\"\n",
512
+ "\n",
513
+ "training_args = TrainingArguments(\n",
514
+ " # max_steps=1, # Optional: run only for one step, useful for debugging\n",
515
+ " num_train_epochs=epochs, # Number of training epochs\n",
516
+ " per_device_train_batch_size=8, # Batch size per device for training\n",
517
+ " per_device_eval_batch_size=8, # Batch size per device for evaluation\n",
518
+ " gradient_accumulation_steps=1, # Number of steps to accumulate gradients before updating\n",
519
+ " # warmup_steps=10, # Optional: number of warmup steps (uncomment if needed)\n",
520
+ " learning_rate=lr, # Learning rate for the optimizer\n",
521
+ " weight_decay=0.01, # Weight decay to apply (for regularization)\n",
522
+ " logging_steps=0.001, # Log every 0.1% of the total training steps (fractional values are a ratio of total steps)\n",
523
+ " output_dir=\"MIDIstral_pixtral\", # Directory where the fine-tuned model will be saved. Make sure 'pixtral' is in the name\n",
524
+ " eval_strategy=\"steps\", # Strategy for evaluation: perform evaluation every few steps\n",
525
+ " eval_steps=0.02, # Evaluate every 2% of the total training steps (fractional values are a ratio of total steps)\n",
526
+ " lr_scheduler_type=schedule, # Set learning rate scheduler type\n",
527
+ " # save_strategy=\"steps\", # Optional: save model every few steps (commented out)\n",
528
+ " # save_steps=250, # Optional: how many steps between saves (commented out)\n",
529
+ " # save_total_limit=1, # Optional: total number of checkpoints to keep (commented out)\n",
530
+ " bf16=True, # Use bf16 precision for training\n",
531
+ " remove_unused_columns=False, # Do not remove unused columns from the dataset\n",
532
+ " report_to=\"tensorboard\", # Report results to TensorBoard for visualization\n",
533
+ " run_name=run_name, # Set the run name for tracking experiments\n",
534
+ " logging_dir=f\"./logs/{run_name}\", # Directory for logging\n",
535
+ " gradient_checkpointing=True, # Enable gradient checkpointing to save VRAM\n",
536
+ " gradient_checkpointing_kwargs={'use_reentrant': True} # Additional settings for gradient checkpointing\n",
537
+ ")\n",
538
+ "\n",
539
+ "\n",
540
+ "trainer = Trainer(\n",
541
+ " model=model, # The model to be trained\n",
542
+ " args=training_args, # Training arguments defined earlier\n",
543
+ " data_collator=data_collator, # Data collator to handle batches\n",
544
+ " train_dataset=train_dataset, # Training dataset\n",
545
+ " eval_dataset=eval_dataset, # Evaluation dataset for computing loss or metrics\n",
546
+ ")"
547
+ ]
548
+ },
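Since report_to="tensorboard" and logging_dir are set above, the training curves can be monitored from inside the notebook; the magic commands below are a standard Jupyter option, not part of the original notebook.

```python
# Optional: inline TensorBoard for the runs logged under ./logs/<run_name>.
%load_ext tensorboard
%tensorboard --logdir ./logs
```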
549
+ {
550
+ "cell_type": "code",
551
+ "execution_count": null,
552
+ "metadata": {},
553
+ "outputs": [],
554
+ "source": [
555
+ "trainer.train()"
556
+ ]
557
+ },
558
+ {
559
+ "cell_type": "code",
560
+ "execution_count": null,
561
+ "metadata": {},
562
+ "outputs": [],
563
+ "source": [
564
+ "trainer.save_model('./MIDIstral/')"
565
+ ]
566
+ },
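trainer.save_model() above saves the LoRA adapter. If a standalone checkpoint is preferred, the adapter can be merged into the base weights; this is an optional sketch using PEFT's merge_and_unload(), not a step the original notebook performs, and the output directory name is hypothetical.

```python
# Optional, hedged sketch: merge the LoRA adapter into the base model weights.
merged_model = model.merge_and_unload()               # PeftModel -> plain LlavaForConditionalGeneration
merged_model.save_pretrained("./MIDIstral_merged/")   # hypothetical output directory
processor.save_pretrained("./MIDIstral_merged/")
```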
567
+ {
568
+ "cell_type": "code",
569
+ "execution_count": null,
570
+ "metadata": {},
571
+ "outputs": [],
572
+ "source": [
573
+ "trainer.push_to_hub(token='your-auth-token-here')"
574
+ ]
575
+ },
576
+ {
577
+ "cell_type": "code",
578
+ "execution_count": null,
579
+ "metadata": {},
580
+ "outputs": [],
581
+ "source": [
582
+ "processor.push_to_hub(\"asigalov61/MIDIstral_pixtral\", token='your-auth-token-here')"
583
+ ]
584
+ },
585
+ {
586
+ "cell_type": "markdown",
587
+ "metadata": {},
588
+ "source": [
589
+ "# Inference"
590
+ ]
591
+ },
592
+ {
593
+ "cell_type": "code",
594
+ "execution_count": null,
595
+ "metadata": {},
596
+ "outputs": [],
597
+ "source": [
598
+ "from transformers import LlavaForConditionalGeneration, AutoProcessor\n",
599
+ "import torch\n",
600
+ "\n",
601
+ "model = LlavaForConditionalGeneration.from_pretrained(\n",
602
+ " 'asigalov61/MIDIstral_pixtral',\n",
603
+ " torch_dtype=torch.bfloat16, # Adjust dtype if needed\n",
604
+ " device_map='auto'\n",
605
+ ")\n",
606
+ "processor = AutoProcessor.from_pretrained('asigalov61/MIDIstral_pixtral')\n",
607
+ "tokenizer = processor.tokenizer\n",
608
+ "tokenizer.padding_side = \"left\" # For Flash Attention compatibility\n",
609
+ "\n",
610
+ "print(\"Model and processor loaded successfully from asigalov61/MIDIstral_pixtral.\")"
611
+ ]
612
+ },
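For a single ad-hoc prediction outside run_model_evaluation(), the same prompt construction can be reused; this is a hedged sketch that mirrors the evaluation code above, where `image` stands for any PIL image prepared the same way as the dataset images.

```python
# Hedged single-example inference, mirroring run_model_evaluation() above.
# `image` is assumed to be a PIL.Image prepared like the dataset images.
messages = [{
    "role": "user",
    "content": [
        {"type": "text", "text": "Please describe the song music in detail. Thank you."},
        {"type": "image"},
    ],
}]

prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
inputs = processor(text=[prompt.strip()], images=[image],
                   return_tensors="pt", padding=True).to(model.device)

generated = model.generate(**inputs, max_new_tokens=64)
answer = processor.batch_decode(generated[:, inputs["input_ids"].shape[-1]:])[0]
print(answer)
```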
613
+ {
614
+ "cell_type": "markdown",
615
+ "metadata": {},
616
+ "source": [
617
+ "# Evaluation after fine-tuning"
618
+ ]
619
+ },
620
+ {
621
+ "cell_type": "code",
622
+ "execution_count": null,
623
+ "metadata": {},
624
+ "outputs": [],
625
+ "source": [
626
+ "eval_results_after_fine_tuning = run_model_evaluation(model, eval_dataset, num_samples=5, device='cuda', constant_query='Please write the most appropriate lyrics for the song. Thank you.')\n",
627
+ "\n",
628
+ "print('eval_results_before_fine_tuning:', eval_results_before_fine_tuning)\n",
629
+ "print('eval_results_after_fine_tuning:', eval_results_after_fine_tuning)"
630
+ ]
631
+ },
632
+ {
633
+ "cell_type": "code",
634
+ "execution_count": null,
635
+ "metadata": {},
636
+ "outputs": [],
637
+ "source": [
638
+ "eval_dataset[0]"
639
+ ]
640
+ },
641
+ {
642
+ "cell_type": "code",
643
+ "execution_count": null,
644
+ "metadata": {},
645
+ "outputs": [],
646
+ "source": [
647
+ "with open('eval_results.txt', 'w') as f:\n",
648
+ " f.write('eval_results_before_fine_tuning: ' + str(eval_results_before_fine_tuning) + '\\n')\n",
649
+ " f.write('eval_results_after_fine_tuning: ' + str(eval_results_after_fine_tuning) + '\\n')"
650
+ ]
651
+ }
652
+ ],
653
+ "metadata": {
654
+ "kernelspec": {
655
+ "display_name": "Python 3 (ipykernel)",
656
+ "language": "python",
657
+ "name": "python3"
658
+ },
659
+ "language_info": {
660
+ "codemirror_mode": {
661
+ "name": "ipython",
662
+ "version": 3
663
+ },
664
+ "file_extension": ".py",
665
+ "mimetype": "text/x-python",
666
+ "name": "python",
667
+ "nbconvert_exporter": "python",
668
+ "pygments_lexer": "ipython3",
669
+ "version": "3.12.7"
670
+ }
671
+ },
672
+ "nbformat": 4,
673
+ "nbformat_minor": 4
674
+ }