{ "cells": [ { "cell_type": "markdown", "id": "63b957e0-d83b-48a6-8ae0-a276b983e181", "metadata": {}, "source": [ "### Optional: install the necessary packages" ] }, { "cell_type": "code", "execution_count": 1, "id": "5569a330-ea6b-4402-9fd5-e7a0ce981bc9", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: huggingface_hub in /workspace/.miniconda3/lib/python3.12/site-packages (0.26.2)\n", "Requirement already satisfied: filelock in /workspace/.miniconda3/lib/python3.12/site-packages (from huggingface_hub) (3.16.1)\n", "Requirement already satisfied: fsspec>=2023.5.0 in /workspace/.miniconda3/lib/python3.12/site-packages (from huggingface_hub) (2024.9.0)\n", "Requirement already satisfied: packaging>=20.9 in /workspace/.miniconda3/lib/python3.12/site-packages (from huggingface_hub) (24.1)\n", "Requirement already satisfied: pyyaml>=5.1 in /workspace/.miniconda3/lib/python3.12/site-packages (from huggingface_hub) (6.0.2)\n", "Requirement already satisfied: requests in /workspace/.miniconda3/lib/python3.12/site-packages (from huggingface_hub) (2.32.3)\n", "Requirement already satisfied: tqdm>=4.42.1 in /workspace/.miniconda3/lib/python3.12/site-packages (from huggingface_hub) (4.66.5)\n", "Requirement already satisfied: typing-extensions>=3.7.4.3 in /workspace/.miniconda3/lib/python3.12/site-packages (from huggingface_hub) (4.11.0)\n", "Requirement already satisfied: charset-normalizer<4,>=2 in /workspace/.miniconda3/lib/python3.12/site-packages (from requests->huggingface_hub) (3.3.2)\n", "Requirement already satisfied: idna<4,>=2.5 in /workspace/.miniconda3/lib/python3.12/site-packages (from requests->huggingface_hub) (3.7)\n", "Requirement already satisfied: urllib3<3,>=1.21.1 in /workspace/.miniconda3/lib/python3.12/site-packages (from requests->huggingface_hub) (2.2.3)\n", "Requirement already satisfied: certifi>=2017.4.17 in /workspace/.miniconda3/lib/python3.12/site-packages (from 
requests->huggingface_hub) (2024.8.30)\n", "Note: you may need to restart the kernel to use updated packages.\n", "Requirement already satisfied: datasets in /workspace/.miniconda3/lib/python3.12/site-packages (3.1.0)\n", "Requirement already satisfied: filelock in /workspace/.miniconda3/lib/python3.12/site-packages (from datasets) (3.16.1)\n", "Requirement already satisfied: numpy>=1.17 in /workspace/.miniconda3/lib/python3.12/site-packages (from datasets) (2.1.3)\n", "Requirement already satisfied: pyarrow>=15.0.0 in /workspace/.miniconda3/lib/python3.12/site-packages (from datasets) (18.0.0)\n", "Requirement already satisfied: dill<0.3.9,>=0.3.0 in /workspace/.miniconda3/lib/python3.12/site-packages (from datasets) (0.3.8)\n", "Requirement already satisfied: pandas in /workspace/.miniconda3/lib/python3.12/site-packages (from datasets) (2.2.3)\n", "Requirement already satisfied: requests>=2.32.2 in /workspace/.miniconda3/lib/python3.12/site-packages (from datasets) (2.32.3)\n", "Requirement already satisfied: tqdm>=4.66.3 in /workspace/.miniconda3/lib/python3.12/site-packages (from datasets) (4.66.5)\n", "Requirement already satisfied: xxhash in /workspace/.miniconda3/lib/python3.12/site-packages (from datasets) (3.5.0)\n", "Requirement already satisfied: multiprocess<0.70.17 in /workspace/.miniconda3/lib/python3.12/site-packages (from datasets) (0.70.16)\n", "Requirement already satisfied: fsspec<=2024.9.0,>=2023.1.0 in /workspace/.miniconda3/lib/python3.12/site-packages (from fsspec[http]<=2024.9.0,>=2023.1.0->datasets) (2024.9.0)\n", "Requirement already satisfied: aiohttp in /workspace/.miniconda3/lib/python3.12/site-packages (from datasets) (3.11.2)\n", "Requirement already satisfied: huggingface-hub>=0.23.0 in /workspace/.miniconda3/lib/python3.12/site-packages (from datasets) (0.26.2)\n", "Requirement already satisfied: packaging in /workspace/.miniconda3/lib/python3.12/site-packages (from datasets) (24.1)\n", "Requirement already satisfied: pyyaml>=5.1 in 
/workspace/.miniconda3/lib/python3.12/site-packages (from datasets) (6.0.2)\n", "Requirement already satisfied: aiohappyeyeballs>=2.3.0 in /workspace/.miniconda3/lib/python3.12/site-packages (from aiohttp->datasets) (2.4.3)\n", "Requirement already satisfied: aiosignal>=1.1.2 in /workspace/.miniconda3/lib/python3.12/site-packages (from aiohttp->datasets) (1.3.1)\n", "Requirement already satisfied: attrs>=17.3.0 in /workspace/.miniconda3/lib/python3.12/site-packages (from aiohttp->datasets) (24.2.0)\n", "Requirement already satisfied: frozenlist>=1.1.1 in /workspace/.miniconda3/lib/python3.12/site-packages (from aiohttp->datasets) (1.5.0)\n", "Requirement already satisfied: multidict<7.0,>=4.5 in /workspace/.miniconda3/lib/python3.12/site-packages (from aiohttp->datasets) (6.1.0)\n", "Requirement already satisfied: propcache>=0.2.0 in /workspace/.miniconda3/lib/python3.12/site-packages (from aiohttp->datasets) (0.2.0)\n", "Requirement already satisfied: yarl<2.0,>=1.17.0 in /workspace/.miniconda3/lib/python3.12/site-packages (from aiohttp->datasets) (1.17.1)\n", "Requirement already satisfied: typing-extensions>=3.7.4.3 in /workspace/.miniconda3/lib/python3.12/site-packages (from huggingface-hub>=0.23.0->datasets) (4.11.0)\n", "Requirement already satisfied: charset-normalizer<4,>=2 in /workspace/.miniconda3/lib/python3.12/site-packages (from requests>=2.32.2->datasets) (3.3.2)\n", "Requirement already satisfied: idna<4,>=2.5 in /workspace/.miniconda3/lib/python3.12/site-packages (from requests>=2.32.2->datasets) (3.7)\n", "Requirement already satisfied: urllib3<3,>=1.21.1 in /workspace/.miniconda3/lib/python3.12/site-packages (from requests>=2.32.2->datasets) (2.2.3)\n", "Requirement already satisfied: certifi>=2017.4.17 in /workspace/.miniconda3/lib/python3.12/site-packages (from requests>=2.32.2->datasets) (2024.8.30)\n", "Requirement already satisfied: python-dateutil>=2.8.2 in /workspace/.miniconda3/lib/python3.12/site-packages (from pandas->datasets) 
(2.9.0.post0)\n", "Requirement already satisfied: pytz>=2020.1 in /workspace/.miniconda3/lib/python3.12/site-packages (from pandas->datasets) (2024.1)\n", "Requirement already satisfied: tzdata>=2022.7 in /workspace/.miniconda3/lib/python3.12/site-packages (from pandas->datasets) (2024.2)\n", "Requirement already satisfied: six>=1.5 in /workspace/.miniconda3/lib/python3.12/site-packages (from python-dateutil>=2.8.2->pandas->datasets) (1.16.0)\n", "Note: you may need to restart the kernel to use updated packages.\n", "Requirement already satisfied: bitsandbytes in /workspace/.miniconda3/lib/python3.12/site-packages (0.44.2.dev0)\n", "Requirement already satisfied: torch in /workspace/.miniconda3/lib/python3.12/site-packages (from bitsandbytes) (2.5.1)\n", "Requirement already satisfied: numpy in /workspace/.miniconda3/lib/python3.12/site-packages (from bitsandbytes) (2.1.3)\n", "Requirement already satisfied: filelock in /workspace/.miniconda3/lib/python3.12/site-packages (from torch->bitsandbytes) (3.16.1)\n", "Requirement already satisfied: typing-extensions>=4.8.0 in /workspace/.miniconda3/lib/python3.12/site-packages (from torch->bitsandbytes) (4.11.0)\n", "Requirement already satisfied: networkx in /workspace/.miniconda3/lib/python3.12/site-packages (from torch->bitsandbytes) (3.4.2)\n", "Requirement already satisfied: jinja2 in /workspace/.miniconda3/lib/python3.12/site-packages (from torch->bitsandbytes) (3.1.4)\n", "Requirement already satisfied: fsspec in /workspace/.miniconda3/lib/python3.12/site-packages (from torch->bitsandbytes) (2024.9.0)\n", "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.4.127 in /workspace/.miniconda3/lib/python3.12/site-packages (from torch->bitsandbytes) (12.4.127)\n", "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.4.127 in /workspace/.miniconda3/lib/python3.12/site-packages (from torch->bitsandbytes) (12.4.127)\n", "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.4.127 in 
/workspace/.miniconda3/lib/python3.12/site-packages (from torch->bitsandbytes) (12.4.127)\n", "Requirement already satisfied: nvidia-cudnn-cu12==9.1.0.70 in /workspace/.miniconda3/lib/python3.12/site-packages (from torch->bitsandbytes) (9.1.0.70)\n", "Requirement already satisfied: nvidia-cublas-cu12==12.4.5.8 in /workspace/.miniconda3/lib/python3.12/site-packages (from torch->bitsandbytes) (12.4.5.8)\n", "Requirement already satisfied: nvidia-cufft-cu12==11.2.1.3 in /workspace/.miniconda3/lib/python3.12/site-packages (from torch->bitsandbytes) (11.2.1.3)\n", "Requirement already satisfied: nvidia-curand-cu12==10.3.5.147 in /workspace/.miniconda3/lib/python3.12/site-packages (from torch->bitsandbytes) (10.3.5.147)\n", "Requirement already satisfied: nvidia-cusolver-cu12==11.6.1.9 in /workspace/.miniconda3/lib/python3.12/site-packages (from torch->bitsandbytes) (11.6.1.9)\n", "Requirement already satisfied: nvidia-cusparse-cu12==12.3.1.170 in /workspace/.miniconda3/lib/python3.12/site-packages (from torch->bitsandbytes) (12.3.1.170)\n", "Requirement already satisfied: nvidia-nccl-cu12==2.21.5 in /workspace/.miniconda3/lib/python3.12/site-packages (from torch->bitsandbytes) (2.21.5)\n", "Requirement already satisfied: nvidia-nvtx-cu12==12.4.127 in /workspace/.miniconda3/lib/python3.12/site-packages (from torch->bitsandbytes) (12.4.127)\n", "Requirement already satisfied: nvidia-nvjitlink-cu12==12.4.127 in /workspace/.miniconda3/lib/python3.12/site-packages (from torch->bitsandbytes) (12.4.127)\n", "Requirement already satisfied: triton==3.1.0 in /workspace/.miniconda3/lib/python3.12/site-packages (from torch->bitsandbytes) (3.1.0)\n", "Requirement already satisfied: setuptools in /workspace/.miniconda3/lib/python3.12/site-packages (from torch->bitsandbytes) (75.1.0)\n", "Requirement already satisfied: sympy==1.13.1 in /workspace/.miniconda3/lib/python3.12/site-packages (from torch->bitsandbytes) (1.13.1)\n", "Requirement already satisfied: mpmath<1.4,>=1.1.0 in 
/workspace/.miniconda3/lib/python3.12/site-packages (from sympy==1.13.1->torch->bitsandbytes) (1.3.0)\n",
      "Requirement already satisfied: MarkupSafe>=2.0 in /workspace/.miniconda3/lib/python3.12/site-packages (from jinja2->torch->bitsandbytes) (2.1.3)\n",
      "Note: you may need to restart the kernel to use updated packages.\n",
      "Note: you may need to restart the kernel to use updated packages.\n",
      "Note: you may need to restart the kernel to use updated packages.\n"
     ]
    }
   ],
   "source": [
    "!git config --global credential.helper store\n",
    "%pip install huggingface_hub\n",
    "%pip install -U datasets\n",
    "%pip install -U bitsandbytes\n",
    "%pip install -q git+https://github.com/huggingface/transformers.git\n",
    "%pip install -q accelerate datasets peft torchvision torchaudio"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ca7f5161-0104-45c4-83c9-1dd0dad15e29",
   "metadata": {},
   "source": [
    "## Log in to Hugging Face"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "14c445aa-c8bc-43b9-8d18-46d79055e1f0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Hugging Face token found in environment variable\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Note: Environment variable`HF_TOKEN` is set and is the current active token independently from the token you've just configured.\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "\n",
    "from huggingface_hub import login\n",
    "\n",
    "# Never hardcode the token in the notebook: notebooks get committed and shared,\n",
    "# and the value would leak with them. Resolution order: the HF_TOKEN environment\n",
    "# variable first, then the Colab secrets panel (only when running on Colab).\n",
    "HF_TOKEN = os.environ.get(\"HF_TOKEN\")\n",
    "\n",
    "if HF_TOKEN:\n",
    "    print(\"Hugging Face token found in environment variable\")\n",
    "else:\n",
    "    try:\n",
    "        from google.colab import userdata\n",
    "        HF_TOKEN = userdata.get(\"HF_TOKEN\")\n",
    "    except ModuleNotFoundError:\n",
    "        pass  # not running on Colab: only the environment variable applies\n",
    "\n",
    "if not HF_TOKEN:\n",
    "    raise ValueError(\"Please set your Hugging Face token in the user data panel, or pass it as an environment variable\")\n",
    "\n",
    "login(\n",
    "    token=HF_TOKEN,\n",
    "    add_to_git_credential=True\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c33a0325-f056-43c4-a500-6ad1b5632ee5",
   "metadata": {},
   "source": [
    "### Set the environment variables"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "bfd9dbb4-59c4-4d97-9727-5c2d9e60724e",
   "metadata": {},
   "outputs": [],
   "source": [
    "#source_model_id = \"HuggingFaceM4/Idefics3-8B-Llama3\"\n",
    "source_model_id = \"meta-llama/Llama-3.2-11B-Vision-Instruct\"\n",
    "destination_model_id = \"eltorio/IDEFICS3_medical_instruct\"\n",
    "detination_model_id = destination_model_id  # alias for the previous misspelled name\n",
    "dataset_id = \"ruslanmv/ai-medical-dataset\"\n",
    "prompt = \"You are a medical doctor with 15 year of experience verifying the knowledge of a new diploma medical doctor\"\n",
    "output_dir = \"IDEFICS3_medical_instruct\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "da89fd1d-44dd-4d4e-803f-ebf58b23165f",
   "metadata": {},
   "source": [
    "### Optionally clone the model repository"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "f6510468-c903-4bd2-80fd-2511b7fb2f72",
   "metadata": {},
   "outputs": [],
   "source": [
    "# clone Hugging Face model repository\n",
    "# !git clone https://huggingface.co/$destination_model_id $output_dir"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2bfc04cd-3867-445f-9764-21bc72a07f60",
   "metadata": {},
   "source": [
    "### Load the dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "1efffd3f-aece-4858-b615-8fb1f2997068",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b39d97f2fb434903b0521f5dea2fd37c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Resolving data files: 0%| | 0/18 [00:00 12 months after OLT randomized to receive either PRED and tacrolimus (TAC) or MMF and TAC were followed for 24 months. Withdrawal of steroids showed no difference regarding graft and patient survival. Also we demonstrated significantly lower glucose levels with lower HbA1c and a reduced need for insulin as well as a significantly lower serum cholesterol in the MMF group. Patients without steroids showed a lower incidence of osteopenia. Maintenance therapy in OLT patients with AIH may be performed safely using MMF instead of prednisone.'}"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "eval_dataset\n",
    "eval_dataset[24]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "01b0fe64-1f62-42a8-847d-9162f5015c4e",
   "metadata": {
    "id": "0JeaGZxHAMtG"
   },
   "source": [
    "### Create Data Collator for IDEFICS3 format."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "id": "29d96aea-445d-482d-b7dc-861635a5389c",
   "metadata": {
    "executionInfo": {
     "elapsed": 426,
     "status": "ok",
     "timestamp": 1730998596513,
     "user": {
      "displayName": "Ronan Le Meillat",
      "userId": "09161391957806824350"
     },
     "user_tz": -60
    },
    "id": "X6TWyPHaAMtH"
   },
   "outputs": [],
   "source": [
    "class MyDataCollator:\n",
    "    \"\"\"Turn raw dataset samples into a padded, tokenized training batch.\n",
    "\n",
    "    Each sample is rendered with the processor's chat template as a three-turn\n",
    "    exchange: system (the notebook-level `prompt`), user (`sample[\"question\"]`)\n",
    "    and assistant (`sample[\"context\"]`). The returned batch carries `labels`\n",
    "    in which padding and image-placeholder positions are set to -100.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, processor):\n",
    "        self.processor = processor\n",
    "        # presumably the <|image|> token id of the Llama 3.2 Vision tokenizer — confirm\n",
    "        self.image_token_id = 128256\n",
    "\n",
    "    def __call__(self, samples):\n",
    "        texts = []\n",
    "        for sample in samples:\n",
    "            messages = [\n",
    "                {\"role\": \"system\", \"content\": [{\"type\": \"text\", \"text\": prompt}]},\n",
    "                {\"role\": \"user\", \"content\": [{\"type\": \"text\", \"text\": sample[\"question\"]}]},\n",
    "                {\"role\": \"assistant\", \"content\": [{\"type\": \"text\", \"text\": sample[\"context\"]}]},\n",
    "            ]\n",
    "            # Use the processor this collator was built with, not the notebook global.\n",
    "            text = self.processor.apply_chat_template(messages, add_generation_prompt=False)\n",
    "            texts.append(text.strip())\n",
    "\n",
    "        batch = self.processor(text=texts, return_tensors=\"pt\", padding=True)\n",
    "\n",
    "        # -100 is the ignore_index of the cross-entropy loss used by transformers'\n",
    "        # causal-LM heads: padding and image-placeholder positions must not\n",
    "        # contribute to the loss.\n",
    "        labels = batch[\"input_ids\"].clone()\n",
    "        if \"attention_mask\" in batch:\n",
    "            labels[batch[\"attention_mask\"] == 0] = -100\n",
    "        labels[labels == self.image_token_id] = -100\n",
    "        batch[\"labels\"] = labels\n",
    "\n",
    "        return batch\n",
    "\n",
    "data_collator = MyDataCollator(processor)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a8f2e613-5695-4558-9a13-66158d82bed9",
   "metadata": {
    "id": "vsq4TtIJAMtH"
   },
   "source": [
    "### Set up training parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "id": "f3cda658-05f6-4078-8d71-2d1c0352ecfa",
   "metadata": {
    "executionInfo": {
     "elapsed": 1008,
     "status": "ok",
     "timestamp": 1730998601172,
     "user": {
      "displayName": "Ronan Le Meillat",
      "userId": "09161391957806824350"
     },
     "user_tz": -60
    },
    "id": "Q_WKQFfoAMtH"
   },
   "outputs": [],
   "source": [
    "from transformers import TrainingArguments, Trainer\n",
    "\n",
    "training_args = TrainingArguments(\n",
    "    output_dir = output_dir,\n",
    "    overwrite_output_dir = False,\n",
    "    auto_find_batch_size = True,\n",
    "    learning_rate = 2e-4,\n",
    "    fp16 = True,\n",
    "    per_device_train_batch_size = 2,\n",
    "    per_device_eval_batch_size = 2,\n",
    "    gradient_accumulation_steps = 8,\n",
    "    dataloader_pin_memory = False,\n",
    "    save_total_limit = 3,\n",
    "    eval_strategy = \"steps\",\n",
    "    save_strategy = \"steps\",\n",
    "    eval_steps = 100, # not reached while max_steps = 10 below\n",
    "    save_steps = 10, # checkpoint each 10 steps\n",
    "    # Stored for reference only: Trainer does not read this field; resuming has to be\n",
    "    # requested with trainer.train(resume_from_checkpoint=...).\n",
    "    resume_from_checkpoint = True,\n",
    "    logging_steps = 5,\n",
    "    remove_unused_columns = False, # the collator needs the raw question/context columns\n",
    "    push_to_hub = True,\n",
    "    hub_model_id = destination_model_id, # push to the configured repo, not one derived from output_dir\n",
    "    label_names = [\"labels\"],\n",
    "    load_best_model_at_end = False,\n",
    "    report_to = \"none\",\n",
    "    optim = \"paged_adamw_8bit\",\n",
    "    max_steps = 10, # remove this for training\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "id": "e6569265-5941-4482-84e2-faf1b61b685c",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 426,
     "status": "ok",
     "timestamp": 1730998605441,
     "user": {
      "displayName": "Ronan Le Meillat",
      "userId": "09161391957806824350"
     },
     "user_tz": -60
    },
    "id": "vSIo17mgAMtH",
    "outputId": "3bebd35a-ed7f-49ee-e1bc-91594e8dcd24"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
"text": [
      "max_steps is given, it will override any value given in num_train_epochs\n"
     ]
    }
   ],
   "source": [
    "trainer = Trainer(\n",
    "    model = model,\n",
    "    args = training_args,\n",
    "    data_collator = data_collator,\n",
    "    train_dataset = train_dataset,\n",
    "    eval_dataset = eval_dataset,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "916ac153-206b-488a-b783-3ad0c4ba21b6",
   "metadata": {
    "id": "pmlwDsOpAMtI"
   },
   "source": [
    "### Start (or restart) training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fb72a570-97e8-440e-b79f-640d8898e37c",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 1000
    },
    "id": "WQA84KnTAMtI",
    "outputId": "ebb15160-f56e-4899-e608-b0d5fd0ba117"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "from transformers.trainer_utils import get_last_checkpoint\n",
    "\n",
    "# TrainingArguments.resume_from_checkpoint is not consulted by Trainer, so resume\n",
    "# explicitly from the newest checkpoint folder in output_dir when one exists;\n",
    "# otherwise (None) training starts from scratch.\n",
    "last_checkpoint = get_last_checkpoint(output_dir) if os.path.isdir(output_dir) else None\n",
    "trainer.train(resume_from_checkpoint=last_checkpoint)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}