me@hg.co committed on
Commit
17643e1
1 Parent(s): b57b542
Files changed (1) hide show
  1. instruct.ipynb +633 -0
instruct.ipynb ADDED
@@ -0,0 +1,633 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "63b957e0-d83b-48a6-8ae0-a276b983e181",
6
+ "metadata": {},
7
+ "source": [
8
+ "### Optional: install the necessary packages"
9
+ ]
10
+ },
11
+ {
12
+ "cell_type": "code",
13
+ "execution_count": 1,
14
+ "id": "5569a330-ea6b-4402-9fd5-e7a0ce981bc9",
15
+ "metadata": {},
16
+ "outputs": [
17
+ {
18
+ "name": "stdout",
19
+ "output_type": "stream",
20
+ "text": [
21
+ "Requirement already satisfied: huggingface_hub in /workspace/.miniconda3/lib/python3.12/site-packages (0.26.2)\n",
22
+ "Requirement already satisfied: filelock in /workspace/.miniconda3/lib/python3.12/site-packages (from huggingface_hub) (3.16.1)\n",
23
+ "Requirement already satisfied: fsspec>=2023.5.0 in /workspace/.miniconda3/lib/python3.12/site-packages (from huggingface_hub) (2024.9.0)\n",
24
+ "Requirement already satisfied: packaging>=20.9 in /workspace/.miniconda3/lib/python3.12/site-packages (from huggingface_hub) (24.1)\n",
25
+ "Requirement already satisfied: pyyaml>=5.1 in /workspace/.miniconda3/lib/python3.12/site-packages (from huggingface_hub) (6.0.2)\n",
26
+ "Requirement already satisfied: requests in /workspace/.miniconda3/lib/python3.12/site-packages (from huggingface_hub) (2.32.3)\n",
27
+ "Requirement already satisfied: tqdm>=4.42.1 in /workspace/.miniconda3/lib/python3.12/site-packages (from huggingface_hub) (4.66.5)\n",
28
+ "Requirement already satisfied: typing-extensions>=3.7.4.3 in /workspace/.miniconda3/lib/python3.12/site-packages (from huggingface_hub) (4.11.0)\n",
29
+ "Requirement already satisfied: charset-normalizer<4,>=2 in /workspace/.miniconda3/lib/python3.12/site-packages (from requests->huggingface_hub) (3.3.2)\n",
30
+ "Requirement already satisfied: idna<4,>=2.5 in /workspace/.miniconda3/lib/python3.12/site-packages (from requests->huggingface_hub) (3.7)\n",
31
+ "Requirement already satisfied: urllib3<3,>=1.21.1 in /workspace/.miniconda3/lib/python3.12/site-packages (from requests->huggingface_hub) (2.2.3)\n",
32
+ "Requirement already satisfied: certifi>=2017.4.17 in /workspace/.miniconda3/lib/python3.12/site-packages (from requests->huggingface_hub) (2024.8.30)\n",
33
+ "Note: you may need to restart the kernel to use updated packages.\n",
34
+ "Requirement already satisfied: datasets in /workspace/.miniconda3/lib/python3.12/site-packages (3.1.0)\n",
35
+ "Requirement already satisfied: filelock in /workspace/.miniconda3/lib/python3.12/site-packages (from datasets) (3.16.1)\n",
36
+ "Requirement already satisfied: numpy>=1.17 in /workspace/.miniconda3/lib/python3.12/site-packages (from datasets) (2.1.3)\n",
37
+ "Requirement already satisfied: pyarrow>=15.0.0 in /workspace/.miniconda3/lib/python3.12/site-packages (from datasets) (18.0.0)\n",
38
+ "Requirement already satisfied: dill<0.3.9,>=0.3.0 in /workspace/.miniconda3/lib/python3.12/site-packages (from datasets) (0.3.8)\n",
39
+ "Requirement already satisfied: pandas in /workspace/.miniconda3/lib/python3.12/site-packages (from datasets) (2.2.3)\n",
40
+ "Requirement already satisfied: requests>=2.32.2 in /workspace/.miniconda3/lib/python3.12/site-packages (from datasets) (2.32.3)\n",
41
+ "Requirement already satisfied: tqdm>=4.66.3 in /workspace/.miniconda3/lib/python3.12/site-packages (from datasets) (4.66.5)\n",
42
+ "Requirement already satisfied: xxhash in /workspace/.miniconda3/lib/python3.12/site-packages (from datasets) (3.5.0)\n",
43
+ "Requirement already satisfied: multiprocess<0.70.17 in /workspace/.miniconda3/lib/python3.12/site-packages (from datasets) (0.70.16)\n",
44
+ "Requirement already satisfied: fsspec<=2024.9.0,>=2023.1.0 in /workspace/.miniconda3/lib/python3.12/site-packages (from fsspec[http]<=2024.9.0,>=2023.1.0->datasets) (2024.9.0)\n",
45
+ "Requirement already satisfied: aiohttp in /workspace/.miniconda3/lib/python3.12/site-packages (from datasets) (3.11.2)\n",
46
+ "Requirement already satisfied: huggingface-hub>=0.23.0 in /workspace/.miniconda3/lib/python3.12/site-packages (from datasets) (0.26.2)\n",
47
+ "Requirement already satisfied: packaging in /workspace/.miniconda3/lib/python3.12/site-packages (from datasets) (24.1)\n",
48
+ "Requirement already satisfied: pyyaml>=5.1 in /workspace/.miniconda3/lib/python3.12/site-packages (from datasets) (6.0.2)\n",
49
+ "Requirement already satisfied: aiohappyeyeballs>=2.3.0 in /workspace/.miniconda3/lib/python3.12/site-packages (from aiohttp->datasets) (2.4.3)\n",
50
+ "Requirement already satisfied: aiosignal>=1.1.2 in /workspace/.miniconda3/lib/python3.12/site-packages (from aiohttp->datasets) (1.3.1)\n",
51
+ "Requirement already satisfied: attrs>=17.3.0 in /workspace/.miniconda3/lib/python3.12/site-packages (from aiohttp->datasets) (24.2.0)\n",
52
+ "Requirement already satisfied: frozenlist>=1.1.1 in /workspace/.miniconda3/lib/python3.12/site-packages (from aiohttp->datasets) (1.5.0)\n",
53
+ "Requirement already satisfied: multidict<7.0,>=4.5 in /workspace/.miniconda3/lib/python3.12/site-packages (from aiohttp->datasets) (6.1.0)\n",
54
+ "Requirement already satisfied: propcache>=0.2.0 in /workspace/.miniconda3/lib/python3.12/site-packages (from aiohttp->datasets) (0.2.0)\n",
55
+ "Requirement already satisfied: yarl<2.0,>=1.17.0 in /workspace/.miniconda3/lib/python3.12/site-packages (from aiohttp->datasets) (1.17.1)\n",
56
+ "Requirement already satisfied: typing-extensions>=3.7.4.3 in /workspace/.miniconda3/lib/python3.12/site-packages (from huggingface-hub>=0.23.0->datasets) (4.11.0)\n",
57
+ "Requirement already satisfied: charset-normalizer<4,>=2 in /workspace/.miniconda3/lib/python3.12/site-packages (from requests>=2.32.2->datasets) (3.3.2)\n",
58
+ "Requirement already satisfied: idna<4,>=2.5 in /workspace/.miniconda3/lib/python3.12/site-packages (from requests>=2.32.2->datasets) (3.7)\n",
59
+ "Requirement already satisfied: urllib3<3,>=1.21.1 in /workspace/.miniconda3/lib/python3.12/site-packages (from requests>=2.32.2->datasets) (2.2.3)\n",
60
+ "Requirement already satisfied: certifi>=2017.4.17 in /workspace/.miniconda3/lib/python3.12/site-packages (from requests>=2.32.2->datasets) (2024.8.30)\n",
61
+ "Requirement already satisfied: python-dateutil>=2.8.2 in /workspace/.miniconda3/lib/python3.12/site-packages (from pandas->datasets) (2.9.0.post0)\n",
62
+ "Requirement already satisfied: pytz>=2020.1 in /workspace/.miniconda3/lib/python3.12/site-packages (from pandas->datasets) (2024.1)\n",
63
+ "Requirement already satisfied: tzdata>=2022.7 in /workspace/.miniconda3/lib/python3.12/site-packages (from pandas->datasets) (2024.2)\n",
64
+ "Requirement already satisfied: six>=1.5 in /workspace/.miniconda3/lib/python3.12/site-packages (from python-dateutil>=2.8.2->pandas->datasets) (1.16.0)\n",
65
+ "Note: you may need to restart the kernel to use updated packages.\n",
66
+ "Requirement already satisfied: bitsandbytes in /workspace/.miniconda3/lib/python3.12/site-packages (0.44.2.dev0)\n",
67
+ "Requirement already satisfied: torch in /workspace/.miniconda3/lib/python3.12/site-packages (from bitsandbytes) (2.5.1)\n",
68
+ "Requirement already satisfied: numpy in /workspace/.miniconda3/lib/python3.12/site-packages (from bitsandbytes) (2.1.3)\n",
69
+ "Requirement already satisfied: filelock in /workspace/.miniconda3/lib/python3.12/site-packages (from torch->bitsandbytes) (3.16.1)\n",
70
+ "Requirement already satisfied: typing-extensions>=4.8.0 in /workspace/.miniconda3/lib/python3.12/site-packages (from torch->bitsandbytes) (4.11.0)\n",
71
+ "Requirement already satisfied: networkx in /workspace/.miniconda3/lib/python3.12/site-packages (from torch->bitsandbytes) (3.4.2)\n",
72
+ "Requirement already satisfied: jinja2 in /workspace/.miniconda3/lib/python3.12/site-packages (from torch->bitsandbytes) (3.1.4)\n",
73
+ "Requirement already satisfied: fsspec in /workspace/.miniconda3/lib/python3.12/site-packages (from torch->bitsandbytes) (2024.9.0)\n",
74
+ "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.4.127 in /workspace/.miniconda3/lib/python3.12/site-packages (from torch->bitsandbytes) (12.4.127)\n",
75
+ "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.4.127 in /workspace/.miniconda3/lib/python3.12/site-packages (from torch->bitsandbytes) (12.4.127)\n",
76
+ "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.4.127 in /workspace/.miniconda3/lib/python3.12/site-packages (from torch->bitsandbytes) (12.4.127)\n",
77
+ "Requirement already satisfied: nvidia-cudnn-cu12==9.1.0.70 in /workspace/.miniconda3/lib/python3.12/site-packages (from torch->bitsandbytes) (9.1.0.70)\n",
78
+ "Requirement already satisfied: nvidia-cublas-cu12==12.4.5.8 in /workspace/.miniconda3/lib/python3.12/site-packages (from torch->bitsandbytes) (12.4.5.8)\n",
79
+ "Requirement already satisfied: nvidia-cufft-cu12==11.2.1.3 in /workspace/.miniconda3/lib/python3.12/site-packages (from torch->bitsandbytes) (11.2.1.3)\n",
80
+ "Requirement already satisfied: nvidia-curand-cu12==10.3.5.147 in /workspace/.miniconda3/lib/python3.12/site-packages (from torch->bitsandbytes) (10.3.5.147)\n",
81
+ "Requirement already satisfied: nvidia-cusolver-cu12==11.6.1.9 in /workspace/.miniconda3/lib/python3.12/site-packages (from torch->bitsandbytes) (11.6.1.9)\n",
82
+ "Requirement already satisfied: nvidia-cusparse-cu12==12.3.1.170 in /workspace/.miniconda3/lib/python3.12/site-packages (from torch->bitsandbytes) (12.3.1.170)\n",
83
+ "Requirement already satisfied: nvidia-nccl-cu12==2.21.5 in /workspace/.miniconda3/lib/python3.12/site-packages (from torch->bitsandbytes) (2.21.5)\n",
84
+ "Requirement already satisfied: nvidia-nvtx-cu12==12.4.127 in /workspace/.miniconda3/lib/python3.12/site-packages (from torch->bitsandbytes) (12.4.127)\n",
85
+ "Requirement already satisfied: nvidia-nvjitlink-cu12==12.4.127 in /workspace/.miniconda3/lib/python3.12/site-packages (from torch->bitsandbytes) (12.4.127)\n",
86
+ "Requirement already satisfied: triton==3.1.0 in /workspace/.miniconda3/lib/python3.12/site-packages (from torch->bitsandbytes) (3.1.0)\n",
87
+ "Requirement already satisfied: setuptools in /workspace/.miniconda3/lib/python3.12/site-packages (from torch->bitsandbytes) (75.1.0)\n",
88
+ "Requirement already satisfied: sympy==1.13.1 in /workspace/.miniconda3/lib/python3.12/site-packages (from torch->bitsandbytes) (1.13.1)\n",
89
+ "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /workspace/.miniconda3/lib/python3.12/site-packages (from sympy==1.13.1->torch->bitsandbytes) (1.3.0)\n",
90
+ "Requirement already satisfied: MarkupSafe>=2.0 in /workspace/.miniconda3/lib/python3.12/site-packages (from jinja2->torch->bitsandbytes) (2.1.3)\n",
91
+ "Note: you may need to restart the kernel to use updated packages.\n",
92
+ "Note: you may need to restart the kernel to use updated packages.\n",
93
+ "Note: you may need to restart the kernel to use updated packages.\n"
94
+ ]
95
+ }
96
+ ],
97
+ "source": [
98
+ "!git config --global credential.helper store\n",
99
+ "%pip install huggingface_hub\n",
100
+ "%pip install -U datasets\n",
101
+ "%pip install -U bitsandbytes\n",
102
+ "%pip install -q git+https://github.com/huggingface/transformers.git\n",
103
+ "%pip install -q accelerate datasets peft torchvision torchaudio"
104
+ ]
105
+ },
106
+ {
107
+ "cell_type": "markdown",
108
+ "id": "ca7f5161-0104-45c4-83c9-1dd0dad15e29",
109
+ "metadata": {},
110
+ "source": [
111
+ "## Login on Hugging Face"
112
+ ]
113
+ },
114
+ {
115
+ "cell_type": "code",
116
+ "execution_count": 2,
117
+ "id": "14c445aa-c8bc-43b9-8d18-46d79055e1f0",
118
+ "metadata": {},
119
+ "outputs": [
120
+ {
121
+ "name": "stdout",
122
+ "output_type": "stream",
123
+ "text": [
124
+ "Hugging Face token found in environment variable\n"
125
+ ]
126
+ },
127
+ {
128
+ "name": "stderr",
129
+ "output_type": "stream",
130
+ "text": [
131
+ "Note: Environment variable`HF_TOKEN` is set and is the current active token independently from the token you've just configured.\n"
132
+ ]
133
+ }
134
+ ],
135
+ "source": [
136
+ "from huggingface_hub import login\n",
137
+ "import os\n",
138
+ "\n",
139
+ "HF_TOKEN = \"hf_C…………\"\n",
140
+ "\n",
141
+ "if os.environ.get('HF_TOKEN') is not None:\n",
142
+ " HF_TOKEN = os.environ.get('HF_TOKEN')\n",
143
+ " print(f\"Hugging Face token found in environment variable\")\n",
144
+ "try:\n",
145
+ " import google.colab\n",
146
+ " from google.colab import userdata\n",
147
+ " if (userdata.get('HF_TOKEN') is not None) and (HF_TOKEN == \"\"):\n",
148
+ " HF_TOKEN = userdata.get('HF_TOKEN')\n",
149
+ " else:\n",
150
+ " raise ValueError(\"Please set your Hugging Face token in the user data panel, or pass it as an environment variable\")\n",
151
+ "except ModuleNotFoundError:\n",
152
+ " if HF_TOKEN is None:\n",
153
+ " raise ValueError(\"Please set your Hugging Face token in the user data panel, or pass it as an environment variable\")\n",
154
+ "\n",
155
+ "login(\n",
156
+ " token=HF_TOKEN,\n",
157
+ " add_to_git_credential=True\n",
158
+ ")"
159
+ ]
160
+ },
161
+ {
162
+ "cell_type": "markdown",
163
+ "id": "c33a0325-f056-43c4-a500-6ad1b5632ee5",
164
+ "metadata": {},
165
+ "source": [
166
+ "### Set the environment variables"
167
+ ]
168
+ },
169
+ {
170
+ "cell_type": "code",
171
+ "execution_count": 3,
172
+ "id": "bfd9dbb4-59c4-4d97-9727-5c2d9e60724e",
173
+ "metadata": {},
174
+ "outputs": [],
175
+ "source": [
176
+ "#source_model_id = \"HuggingFaceM4/Idefics3-8B-Llama3\"\n",
177
+ "source_model_id = \"meta-llama/Llama-3.2-11B-Vision-Instruct\"\n",
178
+ "destination_model_id = \"eltorio/IDEFICS3_medical_instruct\"\n",
179
+ "dataset_id = \"ruslanmv/ai-medical-dataset\"\n",
180
+ "prompt= \"You are a medical doctor with 15 years of experience verifying the knowledge of a new diploma medical doctor\"\n",
181
+ "output_dir = \"IDEFICS3_medical_instruct\""
182
+ ]
183
+ },
184
+ {
185
+ "cell_type": "markdown",
186
+ "id": "da89fd1d-44dd-4d4e-803f-ebf58b23165f",
187
+ "metadata": {},
188
+ "source": [
189
+ "### Optionally clone the model repository"
190
+ ]
191
+ },
192
+ {
193
+ "cell_type": "code",
194
+ "execution_count": 4,
195
+ "id": "f6510468-c903-4bd2-80fd-2511b7fb2f72",
196
+ "metadata": {},
197
+ "outputs": [],
198
+ "source": [
199
+ "# clone Hugging Face model repository\n",
200
+ "# !git clone https://huggingface.co/$destination_model_id $output_dir"
201
+ ]
202
+ },
203
+ {
204
+ "cell_type": "markdown",
205
+ "id": "2bfc04cd-3867-445f-9764-21bc72a07f60",
206
+ "metadata": {},
207
+ "source": [
208
+ "### Load the dataset"
209
+ ]
210
+ },
211
+ {
212
+ "cell_type": "code",
213
+ "execution_count": 5,
214
+ "id": "1efffd3f-aece-4858-b615-8fb1f2997068",
215
+ "metadata": {
216
+ "scrolled": true
217
+ },
218
+ "outputs": [
219
+ {
220
+ "data": {
221
+ "application/vnd.jupyter.widget-view+json": {
222
+ "model_id": "f67282ac10514645b9ffcdb9f797cf22",
223
+ "version_major": 2,
224
+ "version_minor": 0
225
+ },
226
+ "text/plain": [
227
+ "Resolving data files: 0%| | 0/18 [00:00<?, ?it/s]"
228
+ ]
229
+ },
230
+ "metadata": {},
231
+ "output_type": "display_data"
232
+ },
233
+ {
234
+ "data": {
235
+ "application/vnd.jupyter.widget-view+json": {
236
+ "model_id": "0ff3d2de672641c8b59a0518d2987e4f",
237
+ "version_major": 2,
238
+ "version_minor": 0
239
+ },
240
+ "text/plain": [
241
+ "Resolving data files: 0%| | 0/18 [00:00<?, ?it/s]"
242
+ ]
243
+ },
244
+ "metadata": {},
245
+ "output_type": "display_data"
246
+ },
247
+ {
248
+ "data": {
249
+ "application/vnd.jupyter.widget-view+json": {
250
+ "model_id": "48453247d4564242b8a5c216fc3d7fe7",
251
+ "version_major": 2,
252
+ "version_minor": 0
253
+ },
254
+ "text/plain": [
255
+ "Loading dataset shards: 0%| | 0/18 [00:00<?, ?it/s]"
256
+ ]
257
+ },
258
+ "metadata": {},
259
+ "output_type": "display_data"
260
+ }
261
+ ],
262
+ "source": [
263
+ "from datasets import load_dataset\n",
264
+ "\n",
265
+ "base_dataset = load_dataset(\"ruslanmv/ai-medical-dataset\")\n",
266
+ "# define the train dataset as a random 80% of the data\n",
267
+ "train_dataset = base_dataset[\"train\"].train_test_split(test_size=0.2)[\"train\"]\n",
268
+ "# define the eval dataset as the remaining 20%\n",
269
+ "eval_dataset = base_dataset[\"train\"].train_test_split(test_size=0.2)[\"test\"]"
270
+ ]
271
+ },
272
+ {
273
+ "cell_type": "markdown",
274
+ "id": "6be720c0-90a3-4ec0-bf1c-4d18e6b2bb31",
275
+ "metadata": {
276
+ "id": "Tu199rI3AMtG"
277
+ },
278
+ "source": [
279
+ "### Configure LoRA adapters"
280
+ ]
281
+ },
282
+ {
283
+ "cell_type": "code",
284
+ "execution_count": null,
285
+ "id": "86163110-f084-429e-bc70-5b281a679d1c",
286
+ "metadata": {
287
+ "colab": {
288
+ "base_uri": "https://localhost:8080/",
289
+ "height": 67,
290
+ "referenced_widgets": [
291
+ "9f270a1faba44c1194d683a431d43f6f",
292
+ "7403ea0ef7f04e2389c351b0e10a68ce",
293
+ "fbc17ebe22854738b98bf2488c6105b1",
294
+ "22e04cb812df4f5ca621b041c213d17e",
295
+ "96d4f59ed68749eaaf9ad0c6a4e94f13",
296
+ "7a8bc62ce92a4571ad42ac2beef851ef",
297
+ "ae0117f952d94e469793bd8db3ad9a19",
298
+ "0b8f8a93515d49488642b008e47b4bf3",
299
+ "c9691e94467b4badbb71116136350c93",
300
+ "fcd773b4e6154af78e7d476c59294e25",
301
+ "7bd9e3fc12e64bc7bb95891dcd4153c7"
302
+ ]
303
+ },
304
+ "executionInfo": {
305
+ "elapsed": 106706,
306
+ "status": "ok",
307
+ "timestamp": 1730998316179,
308
+ "user": {
309
+ "displayName": "Ronan Le Meillat",
310
+ "userId": "09161391957806824350"
311
+ },
312
+ "user_tz": -60
313
+ },
314
+ "id": "GTf6VxwwAMtG",
315
+ "outputId": "86a93163-6b30-4e8b-ca1c-7392922a4aa6"
316
+ },
317
+ "outputs": [
318
+ {
319
+ "name": "stderr",
320
+ "output_type": "stream",
321
+ "text": [
322
+ "`low_cpu_mem_usage` was None, now default to True since model is quantized.\n"
323
+ ]
324
+ },
325
+ {
326
+ "data": {
327
+ "application/vnd.jupyter.widget-view+json": {
328
+ "model_id": "932c25f63b9242f0b2bfb42004a8362b",
329
+ "version_major": 2,
330
+ "version_minor": 0
331
+ },
332
+ "text/plain": [
333
+ "Loading checkpoint shards: 0%| | 0/5 [00:00<?, ?it/s]"
334
+ ]
335
+ },
336
+ "metadata": {},
337
+ "output_type": "display_data"
338
+ }
339
+ ],
340
+ "source": [
341
+ "import torch\n",
342
+ "from peft import LoraConfig, get_peft_model\n",
343
+ "from transformers import AutoProcessor, MllamaForConditionalGeneration, BitsAndBytesConfig, Idefics3ForConditionalGeneration\n",
344
+ "\n",
345
+ "DEVICE = \"cuda:0\"\n",
346
+ "USE_LORA = False\n",
347
+ "USE_QLORA = True\n",
348
+ "\n",
349
+ "processor = AutoProcessor.from_pretrained(\n",
350
+ " source_model_id,\n",
351
+ " do_image_splitting=False\n",
352
+ ")\n",
353
+ "\n",
354
+ "if USE_QLORA or USE_LORA:\n",
355
+ " lora_config = LoraConfig(\n",
356
+ " r=8,\n",
357
+ " lora_alpha=8,\n",
358
+ " lora_dropout=0.1,\n",
359
+ " target_modules=[\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\",\n",
360
+ " \"gate_proj\", \"up_proj\", \"down_proj\",],\n",
361
+ " #target_modules='.*(text_model|modality_projection|perceiver_resampler).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$',\n",
362
+ " task_type=\"CAUSAL_LM\",\n",
363
+ " use_dora=False if USE_QLORA else True,\n",
364
+ " init_lora_weights=\"gaussian\"\n",
365
+ " )\n",
366
+ " if USE_QLORA:\n",
367
+ " bnb_config = BitsAndBytesConfig(\n",
368
+ " load_in_4bit=True,\n",
369
+ " bnb_4bit_quant_type=\"nf4\",\n",
370
+ " bnb_4bit_compute_dtype=torch.float16\n",
371
+ " )\n",
372
+ " model = MllamaForConditionalGeneration.from_pretrained(\n",
373
+ " source_model_id,\n",
374
+ " torch_dtype=torch.float16,\n",
375
+ " quantization_config=bnb_config if USE_QLORA else None,\n",
376
+ " )\n",
377
+ " model = get_peft_model(model, lora_config)\n",
378
+ " #model.add_adapter(lora_config)\n",
379
+ " #model.enable_adapters()\n",
380
+ "else:\n",
381
+ " model = MllamaForConditionalGeneration.from_pretrained(\n",
382
+ " source_model_id,\n",
383
+ " torch_dtype=torch.float16,\n",
384
+ " _attn_implementation=\"flash_attention_2\", # This works for A100 or H100\n",
385
+ " ).to(DEVICE)"
386
+ ]
387
+ },
388
+ {
389
+ "cell_type": "code",
390
+ "execution_count": null,
391
+ "id": "ffbbd54b-e579-44ef-9652-cd8496b2fd4d",
392
+ "metadata": {},
393
+ "outputs": [],
394
+ "source": [
395
+ "model"
396
+ ]
397
+ },
398
+ {
399
+ "cell_type": "code",
400
+ "execution_count": null,
401
+ "id": "f2c2fd34-e1e4-427d-86d0-73bf74ff0005",
402
+ "metadata": {},
403
+ "outputs": [],
404
+ "source": [
405
+ "eval_dataset\n",
406
+ "eval_dataset[2310]"
407
+ ]
408
+ },
409
+ {
410
+ "cell_type": "markdown",
411
+ "id": "01b0fe64-1f62-42a8-847d-9162f5015c4e",
412
+ "metadata": {
413
+ "id": "0JeaGZxHAMtG"
414
+ },
415
+ "source": [
416
+ "### Create Data Collator for IDEFICS3 format."
417
+ ]
418
+ },
419
+ {
420
+ "cell_type": "code",
421
+ "execution_count": null,
422
+ "id": "29d96aea-445d-482d-b7dc-861635a5389c",
423
+ "metadata": {
424
+ "executionInfo": {
425
+ "elapsed": 426,
426
+ "status": "ok",
427
+ "timestamp": 1730998596513,
428
+ "user": {
429
+ "displayName": "Ronan Le Meillat",
430
+ "userId": "09161391957806824350"
431
+ },
432
+ "user_tz": -60
433
+ },
434
+ "id": "X6TWyPHaAMtH"
435
+ },
436
+ "outputs": [],
437
+ "source": [
438
+ "class MyDataCollator:\n",
439
+ " def __init__(self, processor):\n",
440
+ " self.processor = processor\n",
441
+ " self.image_token_id = processor.tokenizer.additional_special_tokens_ids[\n",
442
+ " processor.tokenizer.additional_special_tokens.index(\"<image>\")\n",
443
+ " ]\n",
444
+ "\n",
445
+ " def __call__(self, samples):\n",
446
+ " texts = []\n",
447
+ " images = []\n",
448
+ " for sample in samples:\n",
449
+ " question = sample[\"question\"]\n",
450
+ " answer = sample[\"context\"]\n",
451
+ " messages = [\n",
452
+ " {\n",
453
+ " \"role\": \"system\",\n",
454
+ " \"content\": [\n",
455
+ " {\"type\": \"text\", \"text\": prompt}\n",
456
+ " ]\n",
457
+ "\n",
458
+ " },\n",
459
+ " {\n",
460
+ " \"role\": \"user\",\n",
461
+ " \"content\": [\n",
462
+ " {\"type\": \"text\", \"text\": question },\n",
463
+ " ]\n",
464
+ " },\n",
465
+ " {\n",
466
+ " \"role\": \"assistant\",\n",
467
+ " \"content\": [\n",
468
+ " {\"type\": \"text\", \"text\": answer}\n",
469
+ " ]\n",
470
+ " }\n",
471
+ " ]\n",
472
+ " text = processor.apply_chat_template(messages, add_generation_prompt=False)\n",
473
+ " texts.append(text.strip())\n",
474
+ "\n",
475
+ " batch = processor(text=texts, return_tensors=\"pt\", padding=True)\n",
476
+ "\n",
477
+ " #labels = batch[\"input_ids\"].clone()\n",
478
+ " #labels[labels == processor.tokenizer.pad_token_id] = self.image_token_id\n",
479
+ " #batch[\"labels\"] = labels\n",
480
+ "\n",
481
+ " return batch\n",
482
+ "\n",
483
+ "data_collator = MyDataCollator(processor)"
484
+ ]
485
+ },
486
+ {
487
+ "cell_type": "markdown",
488
+ "id": "a8f2e613-5695-4558-9a13-66158d82bed9",
489
+ "metadata": {
490
+ "id": "vsq4TtIJAMtH"
491
+ },
492
+ "source": [
493
+ "### Setup training parameters"
494
+ ]
495
+ },
496
+ {
497
+ "cell_type": "code",
498
+ "execution_count": null,
499
+ "id": "f3cda658-05f6-4078-8d71-2d1c0352ecfa",
500
+ "metadata": {
501
+ "executionInfo": {
502
+ "elapsed": 1008,
503
+ "status": "ok",
504
+ "timestamp": 1730998601172,
505
+ "user": {
506
+ "displayName": "Ronan Le Meillat",
507
+ "userId": "09161391957806824350"
508
+ },
509
+ "user_tz": -60
510
+ },
511
+ "id": "Q_WKQFfoAMtH"
512
+ },
513
+ "outputs": [],
514
+ "source": [
515
+ "from transformers import TrainingArguments, Trainer\n",
516
+ "\n",
517
+ "training_args = TrainingArguments(\n",
518
+ " output_dir = output_dir,\n",
519
+ " overwrite_output_dir = False,\n",
520
+ " auto_find_batch_size = True,\n",
521
+ " learning_rate = 2e-4,\n",
522
+ " fp16 = True,\n",
523
+ " per_device_train_batch_size = 2,\n",
524
+ " per_device_eval_batch_size = 2,\n",
525
+ " gradient_accumulation_steps = 8,\n",
526
+ " dataloader_pin_memory = False,\n",
527
+ " save_total_limit = 3,\n",
528
+ " eval_strategy = \"steps\",\n",
529
+ " save_strategy = \"steps\",\n",
530
+ " eval_steps = 100,\n",
531
+ " save_steps = 10, # checkpoint each 10 steps\n",
532
+ " resume_from_checkpoint = True,\n",
533
+ " logging_steps = 5,\n",
534
+ " remove_unused_columns = False,\n",
535
+ " push_to_hub = True,\n",
536
+ " label_names = [\"labels\"],\n",
537
+ " load_best_model_at_end = False,\n",
538
+ " report_to = \"none\",\n",
539
+ " optim = \"paged_adamw_8bit\",\n",
540
+ " max_steps = 10, # remove this for training\n",
541
+ ")"
542
+ ]
543
+ },
544
+ {
545
+ "cell_type": "code",
546
+ "execution_count": null,
547
+ "id": "e6569265-5941-4482-84e2-faf1b61b685c",
548
+ "metadata": {
549
+ "colab": {
550
+ "base_uri": "https://localhost:8080/"
551
+ },
552
+ "executionInfo": {
553
+ "elapsed": 426,
554
+ "status": "ok",
555
+ "timestamp": 1730998605441,
556
+ "user": {
557
+ "displayName": "Ronan Le Meillat",
558
+ "userId": "09161391957806824350"
559
+ },
560
+ "user_tz": -60
561
+ },
562
+ "id": "vSIo17mgAMtH",
563
+ "outputId": "3bebd35a-ed7f-49ee-e1bc-91594e8dcd24"
564
+ },
565
+ "outputs": [],
566
+ "source": [
567
+ "trainer = Trainer(\n",
568
+ " model = model,\n",
569
+ " args = training_args,\n",
570
+ " data_collator = data_collator,\n",
571
+ " train_dataset = train_dataset,\n",
572
+ " eval_dataset = eval_dataset,\n",
573
+ ")"
574
+ ]
575
+ },
576
+ {
577
+ "cell_type": "markdown",
578
+ "id": "916ac153-206b-488a-b783-3ad0c4ba21b6",
579
+ "metadata": {
580
+ "id": "pmlwDsOpAMtI"
581
+ },
582
+ "source": [
583
+ "### Start (or restart) Training"
584
+ ]
585
+ },
586
+ {
587
+ "cell_type": "code",
588
+ "execution_count": null,
589
+ "id": "fb72a570-97e8-440e-b79f-640d8898e37c",
590
+ "metadata": {
591
+ "colab": {
592
+ "base_uri": "https://localhost:8080/",
593
+ "height": 1000
594
+ },
595
+ "id": "WQA84KnTAMtI",
596
+ "outputId": "ebb15160-f56e-4899-e608-b0d5fd0ba117"
597
+ },
598
+ "outputs": [],
599
+ "source": [
600
+ "trainer.train()"
601
+ ]
602
+ },
603
+ {
604
+ "cell_type": "code",
605
+ "execution_count": null,
606
+ "id": "b109f2b9-f3cb-4732-8318-b74ed9e5aa25",
607
+ "metadata": {},
608
+ "outputs": [],
609
+ "source": []
610
+ }
611
+ ],
612
+ "metadata": {
613
+ "kernelspec": {
614
+ "display_name": "Python 3 (ipykernel)",
615
+ "language": "python",
616
+ "name": "python3"
617
+ },
618
+ "language_info": {
619
+ "codemirror_mode": {
620
+ "name": "ipython",
621
+ "version": 3
622
+ },
623
+ "file_extension": ".py",
624
+ "mimetype": "text/x-python",
625
+ "name": "python",
626
+ "nbconvert_exporter": "python",
627
+ "pygments_lexer": "ipython3",
628
+ "version": "3.12.7"
629
+ }
630
+ },
631
+ "nbformat": 4,
632
+ "nbformat_minor": 5
633
+ }