{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "c3hpiPPEqmf6" }, "source": [ "##### Copyright 2024 Google LLC." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "bVm-2hW9z9HR" }, "outputs": [], "source": [ "# @title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "u71STQRgnQ3a" }, "source": [ "# Fine-tune PaliGemma for Image Description with a Custom Dataset" ] }, { "cell_type": "markdown", "metadata": { "id": "wR53lePHuiP-" }, "source": [ "This notebook guides you through fine-tuning [PaliGemma](https://ai.google.dev/gemma/docs/paligemma), a powerful vision-language model, to describe birds using [JAX](https://jax.readthedocs.io/en/latest/installation.html). We use a curated subset of a bird species dataset and enrich it with a descriptive text for each species. The resulting training split of 3,692 image-description pairs is used to fine-tune PaliGemma so that it generates accurate, detailed descriptions of bird images.\n", "\n", "\n", " \n", "
\n", " Run in Google Colab\n", "
" ] }, { "cell_type": "markdown", "metadata": { "id": "qRi1rF4MWlQi" }, "source": [ "### Get access to PaliGemma\n", "\n", "Before using PaliGemma for the first time, you must request access to the model through Kaggle, either by following the setup instructions at [Gemma setup](https://ai.google.dev/gemma/docs/setup) or by completing the following steps:\n", "\n", "1. Log in to [Kaggle](https://www.kaggle.com), or create a new Kaggle account if you don't already have one.\n", "1. Go to the [PaliGemma model card](https://www.kaggle.com/models/google/paligemma/) (PaliGemma is a Gemma variant) and click **Request Access**.\n", "1. Complete the consent form and accept the terms and conditions.\n", "\n", "To generate a Kaggle API key, open your [**Settings** page in Kaggle](https://www.kaggle.com/settings) and click **Create New Token**. This triggers the download of a `kaggle.json` file containing your API credentials.\n", "\n", "Then, in Colab, select **Secrets** (🔑) in the left pane and add your Kaggle username and Kaggle API key. Store your username under the name `KAGGLE_USERNAME` and your API key under the name `KAGGLE_KEY`." ] }, { "cell_type": "markdown", "metadata": { "id": "KHskrDmKpNGS" }, "source": [ "### Select the runtime\n", "\n", "To complete this tutorial, you need a Colab runtime with sufficient resources to run the PaliGemma model. In this case, you can use a T4 GPU:\n", "\n", "1. In the upper-right of the Colab window, click the **▾ (Additional connection options)** dropdown menu.\n", "1. Select **Change runtime type**.\n", "1. Under **Hardware accelerator**, select **T4 GPU**." 
] }, { "cell_type": "markdown", "metadata": { "id": "Kp6XQ2hQB8lv" }, "source": [ "### Set environment variables for Kaggle API credentials" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "4VpqN2dKjqjl" }, "outputs": [], "source": [ "import os\n", "from google.colab import userdata\n", "\n", "os.environ[\"KAGGLE_USERNAME\"] = userdata.get('KAGGLE_USERNAME')\n", "os.environ[\"KAGGLE_KEY\"] = userdata.get('KAGGLE_KEY')" ] }, { "cell_type": "markdown", "metadata": { "id": "nCE3e7NFpjxZ" }, "source": [ "### Fetch the `big_vision` repository and install related dependencies" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "DfxKb3F839Ks" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m77.9/77.9 kB\u001b[0m \u001b[31m1.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25h Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m43.2/43.2 kB\u001b[0m \u001b[31m3.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25h Building wheel for ml_collections (setup.py) ... \u001b[?25l\u001b[?25hdone\n" ] } ], "source": [ "import sys\n", "\n", "# Colab runtimes with remote TPUs are not supported by this notebook.\n", "if \"COLAB_TPU_ADDR\" in os.environ:\n", "  raise RuntimeError(\"It seems you are using Colab with remote TPUs, which is not supported.\")\n", "\n", "# Fetch the big_vision repository if Python doesn't know about it, and install\n", "# the dependencies needed for this notebook.\n", "if not os.path.exists(\"big_vision_repo\"):\n", "  !git clone --quiet --branch=main --depth=1 \\\n", "     https://github.com/google-research/big_vision big_vision_repo\n", "\n", "# Append big_vision code to the Python import path.\n", "if \"big_vision_repo\" not in sys.path:\n", "  sys.path.append(\"big_vision_repo\")\n", "\n", "# Install missing dependencies. 
Assume jax~=0.4.25 with GPU available.\n", "!pip3 install -q \"overrides\" \"ml_collections\" \"einops~=0.7\" \"sentencepiece\"" ] }, { "cell_type": "markdown", "metadata": { "id": "zDoq0O77GF30" }, "source": [ "### Import JAX and other dependencies\n", "\n", "Import JAX and other dependencies required for PaliGemma, like TensorFlow and NumPy." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "dTfe2k8J4Bw0" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "JAX version: 0.4.26\n", "JAX platform: gpu\n", "JAX devices: 1\n" ] } ], "source": [ "# Import necessary libraries\n", "import base64\n", "import functools\n", "import html\n", "import io\n", "import glob\n", "\n", "import warnings\n", "\n", "import jax\n", "import jax.numpy as jnp\n", "import numpy as np\n", "import ml_collections\n", "\n", "import tensorflow as tf\n", "import sentencepiece\n", "\n", "import pandas as pd\n", "import random\n", "import json\n", "\n", "from IPython.core.display import display, HTML\n", "from PIL import Image\n", "import matplotlib.pyplot as plt\n", "\n", "# Import model definition from big_vision\n", "from big_vision.models.proj.paligemma import paligemma\n", "from big_vision.trainers.proj.paligemma import predict_fns\n", "\n", "# Import big vision utilities\n", "import big_vision.datasets.jsonl\n", "import big_vision.utils\n", "import big_vision.sharding\n", "\n", "# Don't let TF use the GPU or TPUs\n", "tf.config.set_visible_devices([], \"GPU\")\n", "tf.config.set_visible_devices([], \"TPU\")\n", "\n", "backend = jax.lib.xla_bridge.get_backend()\n", "print(f\"JAX version: {jax.__version__}\")\n", "print(f\"JAX platform: {backend.platform}\")\n", "print(f\"JAX devices: {jax.device_count()}\")" ] }, { "cell_type": "markdown", "metadata": { "id": "b9kSadtIhjlX" }, "source": [ "## Download and configure the model\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "gQNOTfF24AV4" }, "outputs": [ { "name": "stdout", 
"output_type": "stream", "text": [ "Downloading the checkpoint from Kaggle, this could take a few minutes....\n", "Downloading from https://www.kaggle.com/api/v1/models/google/paligemma/jax/paligemma-3b-pt-224/1/download/paligemma-3b-pt-224.f16.npz...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 5.45G/5.45G [01:03<00:00, 91.8MB/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Model path: /root/.cache/kagglehub/models/google/paligemma/jax/paligemma-3b-pt-224/1/paligemma-3b-pt-224.f16.npz\n", "Downloading the model tokenizer...\n", "Copying gs://big_vision/paligemma_tokenizer.model...\n", "- [1 files][ 4.1 MiB/ 4.1 MiB] \n", "Operation completed over 1 objects/4.1 MiB. \n", "Tokenizer path: ./paligemma_tokenizer.model\n" ] } ], "source": [ "import kagglehub\n", "\n", "MODEL_PATH = \"./pt_224_128.params.f16.npz\"\n", "if not os.path.exists(MODEL_PATH):\n", "  print(\"Downloading the checkpoint from Kaggle, this could take a few minutes....\")\n", "  # Note: the Kaggle archive contains the same checkpoint in multiple formats.\n", "  # Download only the float16 model.\n", "  MODEL_PATH = kagglehub.model_download('google/paligemma/jax/paligemma-3b-pt-224', 'paligemma-3b-pt-224.f16.npz')\n", "  print(f\"Model path: {MODEL_PATH}\")\n", "\n", "TOKENIZER_PATH = \"./paligemma_tokenizer.model\"\n", "if not os.path.exists(TOKENIZER_PATH):\n", "  print(\"Downloading the model tokenizer...\")\n", "  !gsutil cp gs://big_vision/paligemma_tokenizer.model {TOKENIZER_PATH}\n", "  print(f\"Tokenizer path: {TOKENIZER_PATH}\")" ] }, { "cell_type": "markdown", "metadata": { "id": "R9R0nS8qjqjo" }, "source": [ "# Prepare the Dataset for Fine-tuning\n", "Here, we process the bird image dataset and descriptions for use with PaliGemma.\n", "\n", "1. 
Curating the Dataset:\n", "\n", "* The **525 Bird Species dataset** [(`gpiosenka/100-bird-species`)](https://www.kaggle.com/datasets/gpiosenka/100-bird-species) from Kaggle contains a comprehensive collection of images representing various bird species. Each image is labeled with its corresponding bird species, providing diverse visual data for training and validation.\n", "\n", "* **Bird Species Descriptions Dataset**: The Bird Species Description DataFrame [(`selamw/birds-discription-df`)](https://www.kaggle.com/datasets/selamw/birds-discription-df) complements the image dataset by providing textual descriptions for the first 23 of the 525 bird species. Pairing each image with the description of its species gives PaliGemma the image-text examples it needs for fine-tuning.\n", "\n", "2. Downloading the Datasets from Kaggle:\n", "\n", "* To obtain the datasets containing bird species images and their descriptions, download them directly from Kaggle using the following commands:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "QXGmAq4wK56p" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Dataset URL: https://www.kaggle.com/datasets/gpiosenka/100-bird-species\n", "License(s): CC0-1.0\n", "Downloading 100-bird-species.zip to /content\n", " 99% 1.94G/1.96G [00:15<00:00, 242MB/s]\n", "100% 1.96G/1.96G [00:15<00:00, 138MB/s]\n", "Dataset URL: https://www.kaggle.com/datasets/selamw/birds-discription-df\n", "License(s): Apache 2.0\n", "Downloading birds-discription-df.zip to /content\n", " 0% 0.00/24.2k [00:00\n", "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
\n", " \n" ], "text/plain": [ " class id filepaths labels \\\n", "0 0.0 ABBOTTS BABBLER/001.jpg ABBOTTS BABBLER \n", "1 0.0 ABBOTTS BABBLER/007.jpg ABBOTTS BABBLER \n", "2 0.0 ABBOTTS BABBLER/008.jpg ABBOTTS BABBLER \n", "3 0.0 ABBOTTS BABBLER/009.jpg ABBOTTS BABBLER \n", "4 0.0 ABBOTTS BABBLER/002.jpg ABBOTTS BABBLER \n", "... ... ... ... \n", "3917 524.0 BLACK BREASTED PUFFBIRD/3.jpg BLACK BREASTED PUFFBIRD \n", "3918 524.0 BLACK BREASTED PUFFBIRD/4.jpg BLACK BREASTED PUFFBIRD \n", "3919 524.0 BLACK BREASTED PUFFBIRD/1.jpg BLACK BREASTED PUFFBIRD \n", "3920 524.0 BLACK BREASTED PUFFBIRD/2.jpg BLACK BREASTED PUFFBIRD \n", "3921 524.0 BLACK BREASTED PUFFBIRD/5.jpg BLACK BREASTED PUFFBIRD \n", "\n", " data set scientific name \\\n", "0 train MALACOCINCLA ABBOTTI \n", "1 train MALACOCINCLA ABBOTTI \n", "2 train MALACOCINCLA ABBOTTI \n", "3 train MALACOCINCLA ABBOTTI \n", "4 train MALACOCINCLA ABBOTTI \n", "... ... ... \n", "3917 valid NOTHARCHUS PECTORALIS \n", "3918 valid NOTHARCHUS PECTORALIS \n", "3919 valid NOTHARCHUS PECTORALIS \n", "3920 valid NOTHARCHUS PECTORALIS \n", "3921 valid NOTHARCHUS PECTORALIS \n", "\n", " bird_description \n", "0 Abbott's Babbler: Look for this small insectiv... \n", "1 Abbott's Babbler: Look for this small insectiv... \n", "2 Abbott's Babbler: Look for this small insectiv... \n", "3 Abbott's Babbler: Look for this small insectiv... \n", "4 Abbott's Babbler: Look for this small insectiv... \n", "... ... \n", "3917 Black-breasted Puffbird: Observe the medium-si... \n", "3918 Black-breasted Puffbird: Observe the medium-si... \n", "3919 Black-breasted Puffbird: Observe the medium-si... \n", "3920 Black-breasted Puffbird: Observe the medium-si... \n", "3921 Black-breasted Puffbird: Observe the medium-si... 
\n", "\n", "[3922 rows x 6 columns]" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Load the descriptions DataFrame\n", "birds_discription_df = pd.read_csv(\"birds_description.csv\")\n", "birds_discription_df" ] }, { "cell_type": "markdown", "metadata": { "id": "CZg6Rj0Njqjp" }, "source": [ "### Display the number of unique species and the number of training and validation samples" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "uA-l53mTjqjp" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Number of unique bird species: 23\n", "Number of training samples: 3692\n", "Number of validation samples: 115\n" ] } ], "source": [ "print(\"Number of unique bird species:\", len(birds_discription_df['labels'].unique()))\n", "print(\"Number of training samples:\", (birds_discription_df['data set'] == \"train\").sum())\n", "print(\"Number of validation samples:\", (birds_discription_df['data set'] == \"valid\").sum())" ] }, { "cell_type": "markdown", "metadata": { "id": "1Qkqjs2Yjqjp" }, "source": [ "## Convert DataFrame to JSON Lines for Fine-tuning\n", "\n", "The `big_vision` JSONL data source used for fine-tuning in this notebook reads examples from JSON Lines files, so we convert the DataFrame containing bird information into separate JSON Lines files for the training and validation splits.\n", "\n", "For fine-tuning, we only need two columns from the DataFrame:\n", "\n", "* `\"filepaths\"`: This column contains the paths to the bird images.\n", "* `\"bird_description\"`: This column contains descriptions for each bird species.\n", "\n", "Each record stores the image path under `image` and the description under `suffix`. The `prefix` field is left empty because the data iterators add the `describe en` task prefix when the examples are loaded." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "jzFnL_wujqjp" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "DataFrame converted to JSON Lines format and saved to train.jsonl and valid.jsonl\n" ] } ], "source": [ "def df_to_jsonl(df):\n", "  \"\"\"Converts a pandas DataFrame to separate JSON Lines files for train and validation data.\n", "\n", "  Args:\n", "    df: The pandas DataFrame to convert.\n", "\n", "  Returns:\n", "    None. Writes the JSON Lines to separate 'train.jsonl' and 'valid.jsonl' files.\n", "  \"\"\"\n", "  # The `with` statement closes both files, even if an exception is raised.\n", "  with open('train.jsonl', 'w') as train_file, open('valid.jsonl', 'w') as valid_file:\n", "    files = {'train': train_file, 'valid': valid_file}\n", "    for _, row in df.iterrows():\n", "      out_file = files.get(row['data set'])\n", "      if out_file is None:\n", "        continue  # Skip rows that belong to neither split.\n", "      data = {\n", "          'prefix': '',\n", "          'suffix': row['bird_description'],\n", "          'image': row['filepaths'],\n", "      }\n", "      out_file.write(json.dumps(data) + '\\n')\n", "\n", "  print(\"DataFrame converted to JSON Lines format and saved to train.jsonl and valid.jsonl\")\n", "\n", "# Convert the birds description DataFrame to JSON Lines format\n", "df_to_jsonl(birds_discription_df)\n" ] }, { "cell_type": "markdown", "metadata": { "id": "aMyiB9bUjqjp" }, "source": [ "## Display the first record of the JSON Lines file" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "xnud6HhMjqjp" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "First Record:\n", "{\n", "  \"prefix\": \"\",\n", "  \"suffix\": \"Abbott's Babbler: Look for this small insectivorous bird with distinctive streaked brown plumage and pale buff underparts, making its home in the diverse landscapes of South Asia.\",\n", "  \"image\": 
\"ABBOTTS BABBLER/001.jpg\"\n", "}\n" ] } ], "source": [ "def display_first_record(filename):\n", "  \"\"\"Opens a JSON Lines file and displays only the first record.\n", "\n", "  Args:\n", "    filename: The path to the JSON Lines file.\n", "  \"\"\"\n", "  try:\n", "    with open(filename, 'r') as f:\n", "      first_line = f.readline().strip()\n", "      if first_line:  # Check if there's data in the file\n", "        data = json.loads(first_line)\n", "        print(f\"First Record:\\n{json.dumps(data, indent=2)}\")\n", "      else:\n", "        print(\"File is empty.\")\n", "  except FileNotFoundError:\n", "    print(f\"Error: File '{filename}' not found.\")\n", "\n", "# Display the first record in 'train.jsonl'\n", "display_first_record('train.jsonl')" ] }, { "cell_type": "markdown", "metadata": { "id": "rv7w-cGuLj5o" }, "source": [ "# Load and configure the PaliGemma model\n", "This section defines the model configuration, creates the tokenizer, and loads the pre-trained checkpoint downloaded earlier." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "v5e1RGCDjqjq" }, "outputs": [], "source": [ "# Define model\n", "model_config = ml_collections.FrozenConfigDict({\n", "    \"llm\": {\"vocab_size\": 257_152},\n", "    \"img\": {\"variant\": \"So400m/14\", \"pool_type\": \"none\", \"scan\": True, \"dtype_mm\": \"float16\"}\n", "})\n", "model = paligemma.Model(**model_config)\n", "tokenizer = sentencepiece.SentencePieceProcessor(TOKENIZER_PATH)\n", "\n", "# Load params - this can take up to 1 minute in a notebook.\n", "params = paligemma.load(None, MODEL_PATH, model_config)\n", "\n", "# Define `decode` function to sample outputs from the model.\n", "decode_fn = predict_fns.get_all(model)['decode']\n", "decode = functools.partial(decode_fn, devices=jax.devices(), eos_token=tokenizer.eos_id())" ] }, { "cell_type": "markdown", "metadata": { "id": "uidBwmb8LwZ5" }, "source": [ "### Move model parameters into GPU/TPU memory" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "RWOdf_fw2SAO" }, "outputs": [ { "name": "stdout", 
"output_type": "stream", "text": [ " == Model params == \n", "img/Transformer/encoder_norm/bias (1152,) float16\n", "img/Transformer/encoder_norm/scale (1152,) float16\n", "img/Transformer/encoderblock/LayerNorm_0/bias (27, 1152) float16\n", "img/Transformer/encoderblock/LayerNorm_0/scale (27, 1152) float16\n", "img/Transformer/encoderblock/LayerNorm_1/bias (27, 1152) float16\n", "img/Transformer/encoderblock/LayerNorm_1/scale (27, 1152) float16\n", "img/Transformer/encoderblock/MlpBlock_0/Dense_0/bias (27, 4304) float16\n", "img/Transformer/encoderblock/MlpBlock_0/Dense_0/kernel (27, 1152, 4304) float16\n", "img/Transformer/encoderblock/MlpBlock_0/Dense_1/bias (27, 1152) float16\n", "img/Transformer/encoderblock/MlpBlock_0/Dense_1/kernel (27, 4304, 1152) float16\n", "img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/bias (27, 16, 72) float16\n", "img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/kernel (27, 1152, 16, 72) float16\n", "img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/bias (27, 1152) float16\n", "img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/kernel (27, 16, 72, 1152) float16\n", "img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/bias (27, 16, 72) float16\n", "img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/kernel (27, 1152, 16, 72) float16\n", "img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/bias (27, 16, 72) float16\n", "img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/kernel (27, 1152, 16, 72) float16\n", "img/embedding/bias (1152,) float16\n", "img/embedding/kernel (14, 14, 3, 1152) float16\n", "img/head/bias (2048,) float16\n", "img/head/kernel (1152, 2048) float16\n", "img/pos_embedding (1, 256, 1152) float16\n", "llm/embedder/input_embedding (257152, 2048) float16\n", "llm/final_norm/scale (2048,) float16\n", "llm/layers/attn/attn_vec_einsum/w (18, 8, 256, 2048) float32\n", "llm/layers/attn/kv_einsum/w (18, 2, 1, 
2048, 256) float32\n", "llm/layers/attn/q_einsum/w (18, 8, 2048, 256) float32\n", "llm/layers/mlp/gating_einsum (18, 2, 2048, 16384) float16\n", "llm/layers/mlp/linear (18, 16384, 2048) float16\n", "llm/layers/pre_attention_norm/scale (18, 2048) float16\n", "llm/layers/pre_ffw_norm/scale (18, 2048) float16\n" ] } ], "source": [ "# Create a pytree mask of the trainable params: only the LLM attention\n", "# layers are fine-tuned; all other params stay frozen.\n", "def is_trainable_param(name, param):  # pylint: disable=unused-argument\n", "  if name.startswith(\"llm/layers/attn/\"): return True\n", "  if name.startswith(\"llm/\"): return False\n", "  if name.startswith(\"img/\"): return False\n", "  raise ValueError(f\"Unexpected param name {name}\")\n", "trainable_mask = big_vision.utils.tree_map_with_names(is_trainable_param, params)\n", "\n", "# If more than one device is available (e.g. multiple GPUs) the parameters can\n", "# be sharded across them to reduce HBM usage per device.\n", "mesh = jax.sharding.Mesh(jax.devices(), (\"data\",))\n", "\n", "data_sharding = jax.sharding.NamedSharding(\n", "    mesh, jax.sharding.PartitionSpec(\"data\"))\n", "\n", "params_sharding = big_vision.sharding.infer_sharding(\n", "    params, strategy=[('.*', 'fsdp(axis=\"data\")')], mesh=mesh)\n", "\n", "# Some donated buffers may not be usable; silence the resulting warning.\n", "warnings.filterwarnings(\n", "    \"ignore\", message=\"Some donated buffers were not usable\")\n", "\n", "@functools.partial(jax.jit, donate_argnums=(0,), static_argnums=(1,))\n", "def maybe_cast_to_f32(params, trainable):\n", "  # Cast trainable params to float32; leave frozen params in their original dtype.\n", "  return jax.tree.map(lambda p, m: p.astype(jnp.float32) if m else p,\n", "                      params, trainable)\n", "\n", "# Loading all params simultaneously - albeit much faster and more succinct -\n", "# requires more RAM than the notebook runtimes have by default.\n", "# Instead we do it param by param.\n", "params, treedef = jax.tree.flatten(params)\n", "sharding_leaves = jax.tree.leaves(params_sharding)\n", "trainable_leaves = jax.tree.leaves(trainable_mask)\n", "for idx, (sharding, trainable) in 
enumerate(zip(sharding_leaves, trainable_leaves)):\n", "  params[idx] = big_vision.utils.reshard(params[idx], sharding)\n", "  params[idx] = maybe_cast_to_f32(params[idx], trainable)\n", "  params[idx].block_until_ready()\n", "params = jax.tree.unflatten(treedef, params)\n", "\n", "# Print params to show what the model is made of.\n", "def parameter_overview(params):\n", "  for path, arr in big_vision.utils.tree_flatten_with_names(params)[0]:\n", "    print(f\"{path:80s} {str(arr.shape):22s} {arr.dtype}\")\n", "\n", "print(\" == Model params == \")\n", "parameter_overview(params)" ] }, { "cell_type": "markdown", "metadata": { "id": "iD_9XXQkn1Mv" }, "source": [ "# Prepare to tune the model" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "8SRW0NuU4UcW" }, "outputs": [], "source": [ "def preprocess_image(image, size=224):\n", "  # Model has been trained to handle images of different aspect ratios\n", "  # resized to 224x224 in the range [-1, 1]. Bilinear and antialias resize\n", "  # options are helpful to improve quality in some tasks.\n", "  image = np.asarray(image)\n", "  if image.ndim == 2:  # Greyscale image without a channel axis: replicate it to RGB.\n", "    image = np.stack((image,)*3, axis=-1)\n", "  image = image[..., :3]  # Remove alpha layer.\n", "  assert image.shape[-1] == 3\n", "\n", "  image = tf.constant(image)\n", "  image = tf.image.resize(image, (size, size), method='bilinear', antialias=True)\n", "  return image.numpy() / 127.5 - 1.0  # [0, 255]->[-1,1]\n", "\n", "def preprocess_tokens(prefix, suffix=None, seqlen=None):\n", "  # Model has been trained to handle tokenized text composed of a prefix with\n", "  # full attention and a suffix with causal attention.\n", "  separator = \"\\n\"\n", "  tokens = tokenizer.encode(prefix, add_bos=True) + tokenizer.encode(separator)\n", "  mask_ar = [0] * len(tokens)    # 0 to use full attention for prefix.\n", "  mask_loss = [0] * len(tokens)  # 0 to not use prefix tokens in the loss.\n", "\n", "  if suffix:\n", "    suffix = 
tokenizer.encode(suffix, add_eos=True)\n", "    tokens += suffix\n", "    mask_ar += [1] * len(suffix)    # 1 to use causal attention for suffix.\n", "    mask_loss += [1] * len(suffix)  # 1 to use suffix tokens in the loss.\n", "\n", "  mask_input = [1] * len(tokens)    # 1 if it's a token, 0 if padding.\n", "  if seqlen:\n", "    padding = [0] * max(0, seqlen - len(tokens))\n", "    tokens = tokens[:seqlen] + padding\n", "    mask_ar = mask_ar[:seqlen] + padding\n", "    mask_loss = mask_loss[:seqlen] + padding\n", "    mask_input = mask_input[:seqlen] + padding\n", "\n", "  return jax.tree.map(np.array, (tokens, mask_ar, mask_loss, mask_input))\n", "\n", "def postprocess_tokens(tokens):\n", "  tokens = tokens.tolist()  # np.array to list[int]\n", "  try:  # Remove tokens at and after EOS if any.\n", "    eos_pos = tokens.index(tokenizer.eos_id())\n", "    tokens = tokens[:eos_pos]\n", "  except ValueError:\n", "    pass\n", "  return tokenizer.decode(tokens)\n" ] }, { "cell_type": "markdown", "metadata": { "id": "ovgWBgdHJZq3" }, "source": [ "### Create the training and validation iterators" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "whzWOojGOtzi" }, "outputs": [], "source": [ "SEQLEN = 128\n", "\n", "TRAIN_DATA_DIR = 'train/'\n", "VALID_DATA_DIR = 'valid/'\n", "\n", "# Load training data. `fopen_keys` reads each `image` path relative to TRAIN_DATA_DIR.\n", "train_dataset = big_vision.datasets.jsonl.DataSource(\n", "    \"train.jsonl\",\n", "    fopen_keys={\"image\": TRAIN_DATA_DIR})\n", "\n", "# Load validation data\n", "val_dataset = big_vision.datasets.jsonl.DataSource(\n", "    \"valid.jsonl\",\n", "    fopen_keys={\"image\": VALID_DATA_DIR})\n", "\n", "def train_data_iterator():\n", "  \"\"\"Never ending iterator over training examples.\"\"\"\n", "  # Shuffle examples and repeat so one can train for many epochs.\n", "  dataset = train_dataset.get_tfdata().shuffle(1_000).repeat()\n", "  for example in dataset.as_numpy_iterator():\n", "    image = Image.open(io.BytesIO(example[\"image\"]))\n", "    image = preprocess_image(image)\n", 
"\n", "    # Define prefix for tokenization\n", "    prefix = \"describe en\"\n", "    suffix = example[\"suffix\"].decode().lower()\n", "    tokens, mask_ar, mask_loss, _ = preprocess_tokens(prefix, suffix, SEQLEN)\n", "\n", "    yield {\n", "        \"image\": np.asarray(image),\n", "        \"text\": np.asarray(tokens),\n", "        \"mask_ar\": np.asarray(mask_ar),\n", "        \"mask_loss\": np.asarray(mask_loss),\n", "    }\n", "\n", "\n", "def validation_data_iterator():\n", "  \"\"\"Single iterator over validation examples.\"\"\"\n", "  for example in val_dataset.get_tfdata(ordered=True).as_numpy_iterator():\n", "    image = Image.open(io.BytesIO(example[\"image\"]))\n", "    image = preprocess_image(image)\n", "\n", "    # Use the same task prefix as in training; no suffix is given at inference.\n", "    prefix = \"describe en\"\n", "    tokens, mask_ar, _, mask_input = preprocess_tokens(prefix, seqlen=SEQLEN)\n", "\n", "    yield {\n", "        \"image\": np.asarray(image),\n", "        \"text\": np.asarray(tokens),\n", "        \"mask_ar\": np.asarray(mask_ar),\n", "        \"mask_input\": np.asarray(mask_input),\n", "    }\n" ] }, { "cell_type": "markdown", "metadata": { "id": "84olaM5dCiAl" }, "source": [ "### View sample training examples" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "BzJfb5t0nsLq" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Training Examples\n" ] }, { "data": { "text/html": [ "\n", "
\n", " \n", "

yellow-billed chough: observe this medium-sized crow easily recognizable by its sleek black feathers and contrasting long yellow bill, a common sight in the mountainous regions of europe and asia, noted for its acrobatic flight and yellow bill.

\n", "
\n", " \n", "
\n", " \n", "

american bittern: notice this medium-sized heron, a master of camouflage, blending seamlessly into the marshes of north and central america, distinguished by its cryptic plumage and elongated neck.

\n", "
\n", " \n", "
\n", " \n", "

american pipit: look for this small, ground-dwelling songbird with streaked brown plumage and long tail, filling the air with melodic tunes across the grasslands of north america, europe, and asia.

\n", "
\n", " \n", "
\n", " \n", "

alexandrine parakeet: notice this medium-sized parrot bursting with a rainbow of colors, a vibrant resident of the forests and woodlands of south and southeast asia, distinguished by its large size and distinctive red beak.

\n", "
\n", " \n", "
\n", " \n", "

crowned crane: observe this epitome of elegance with a golden crown, calling the wetlands and grasslands of sub-saharan africa home, distinguished by its tall stature, long legs, and regal posture.

\n", "
\n", " \n", "
\n", " \n", "

african pygmy goose: look for this small freshwater goose with dark brown feathers and contrasting white markings, gracing the lakes and rivers of sub-saharan africa, noted for its petite size and distinctive facial markings.

\n", "
\n", " \n", "
\n", " \n", "

american bittern: notice this medium-sized heron, a master of camouflage, blending seamlessly into the marshes of north and central america, distinguished by its cryptic plumage and elongated neck.

\n", "
\n", " \n", "
\n", " \n", "

african pygmy goose: look for this small freshwater goose with dark brown feathers and contrasting white markings, gracing the lakes and rivers of sub-saharan africa, noted for its petite size and distinctive facial markings.

\n", "
\n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "def render_inline(image, resize=(128, 128)):\n", "  \"\"\"Converts an image array into a base64-encoded JPEG data URI for inline HTML.\n", "\n", "  Args:\n", "    image (numpy.ndarray): The image array to convert.\n", "    resize (tuple): Optional. Size to resize the image to before conversion. Default is (128, 128).\n", "\n", "  Returns:\n", "    str: A `data:image/jpeg;base64,...` URI that can be used as an image `src` attribute.\n", "  \"\"\"\n", "  image = Image.fromarray(image)\n", "  image = image.resize(resize)  # PIL's resize returns a new image.\n", "  with io.BytesIO() as buffer:\n", "    image.save(buffer, format='jpeg')\n", "    image_b64 = str(base64.b64encode(buffer.getvalue()), \"utf-8\")\n", "  return f\"data:image/jpeg;base64,{image_b64}\"\n", "\n", "def render_example(image, description):\n", "  \"\"\"Generates HTML for displaying an image with its description.\n", "\n", "  Args:\n", "    image (numpy.ndarray): The image array to display.\n", "    description (str): The description of the image.\n", "\n", "  Returns:\n", "    str: HTML representation of the image with description.\n", "  \"\"\"\n", "  image = ((image + 1)/2 * 255).astype(np.uint8)  # [-1,1] -> [0, 255]\n", "  return f\"\"\"\n", "
\n", " \n", "

{html.escape(description)}

\n", "
\n", " \"\"\"\n", "\n", "html_out = \"\"\n", "for idx, example in zip(range(8), train_data_iterator()):\n", " description = postprocess_tokens(example[\"text\"]) # detokenize model input.\n", " description = description[len(\"describe en\\n\"):]\n", " html_out += render_example(example[\"image\"], description)\n", "\n", "print(\"Training Examples\")\n", "display(HTML(html_out))" ] }, { "cell_type": "markdown", "metadata": { "id": "N2BwpXkfI8OT" }, "source": [ "### Define the training and evaluation loops" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "dwUV_imW3WQJ" }, "outputs": [], "source": [ "# The main update_fn using a simple stochastic gradient descent (SGD).\n", "@functools.partial(jax.jit, donate_argnums=(0,))\n", "def update_fn(params, batch, learning_rate):\n", " \"\"\"Performs one update step using stochastic gradient descent (SGD).\n", "\n", " Args:\n", " params (dict): Current model parameters.\n", " batch (dict): Batch of data containing images, texts, and masks.\n", " learning_rate (float): Learning rate for the update.\n", "\n", " Returns:\n", " tuple: Updated parameters and the calculated loss.\n", " \"\"\"\n", "\n", " imgs, txts, mask_ar = batch[\"image\"], batch[\"text\"], batch[\"mask_ar\"]\n", "\n", " def loss_fn(params):\n", " text_logits, _ = model.apply({\"params\": params}, imgs, txts[:, :-1], mask_ar[:, :-1], train=True)\n", " logp = jax.nn.log_softmax(text_logits, axis=-1)\n", "\n", " # The model takes as input txts[:, :-1] but the loss is defined as predicting\n", " # next tokens txts[:, 1:]. Additionally, mask_loss[:, 1:] indicates which tokens\n", " # are part of the loss (e.g. prefix and padded tokens are not included).\n", " mask_loss = batch[\"mask_loss\"][:, 1:]\n", " targets = jax.nn.one_hot(txts[:, 1:], text_logits.shape[-1])\n", "\n", " # Compute the loss per example. i.e. 
the mean per-token negative log-likelihood.\n", " # Since each example has a different number of loss tokens, we normalize by that count.\n", " token_logp = jnp.sum(logp * targets, axis=-1) # log-prob of each target token (sum across vocab_size).\n", " example_loss = -jnp.sum(token_logp * mask_loss, axis=-1) # sum across seq_len.\n", " example_loss /= jnp.clip(jnp.sum(mask_loss, -1), 1) # divide by the number of loss tokens.\n", "\n", " # batch_loss: mean of the per-example losses.\n", " return jnp.mean(example_loss)\n", "\n", " loss, grads = jax.value_and_grad(loss_fn)(params)\n", "\n", " # Apply gradients to trainable params using SGD.\n", " def apply_grad(param, gradient, trainable):\n", " if not trainable: return param\n", " return param - learning_rate * gradient\n", "\n", " params = jax.tree_util.tree_map(apply_grad, params, grads, trainable_mask)\n", "\n", " return params, loss\n", "\n", "# Evaluation/inference loop.\n", "def make_predictions(data_iterator, *, num_examples=None,\n", " batch_size=4, seqlen=SEQLEN, sampler=\"greedy\"):\n", " \"\"\"Generates model predictions for the examples yielded by a data iterator.\n", "\n", " Args:\n", " data_iterator (iterator): Iterator yielding individual examples (not batches).\n", " num_examples (int, optional): Maximum number of examples to generate predictions for.\n", " batch_size (int, optional): Batch size for inference. Default is 4.\n", " seqlen (int, optional): Maximum sequence length for decoding. Default is SEQLEN.\n", " sampler (str, optional): Sampling method for generating predictions. Default is \"greedy\".\n", "\n", " Returns:\n", " list: List of (image, response) tuples.\n", " \"\"\"\n", " outputs = []\n", " while True:\n", " # Construct a list of examples in the batch.\n", " examples = []\n", " try:\n", " for _ in range(batch_size):\n", " examples.append(next(data_iterator))\n", " examples[-1][\"_mask\"] = np.array(True) # Indicates true example.\n", " except StopIteration:\n", " if len(examples) == 0:\n", " return outputs\n", "\n", " # Not enough examples to complete a batch.
Pad by repeating the last example.\n", " while len(examples) % batch_size:\n", " examples.append(dict(examples[-1]))\n", " examples[-1][\"_mask\"] = np.array(False) # Indicates padding example.\n", "\n", " # Convert list of examples into a dict of np.arrays and load onto devices.\n", " batch = jax.tree.map(lambda *x: np.stack(x), *examples)\n", " batch = big_vision.utils.reshard(batch, data_sharding)\n", "\n", " # Make model predictions\n", " tokens = decode({\"params\": params}, batch=batch,\n", " max_decode_len=seqlen, sampler=sampler)\n", "\n", " # Fetch model predictions from device and detokenize.\n", " tokens, mask = jax.device_get((tokens, batch[\"_mask\"]))\n", " tokens = tokens[mask] # remove padding examples.\n", " responses = [postprocess_tokens(t) for t in tokens]\n", "\n", " # Collect (image, response) pairs; zip() stops before the trailing padding examples.\n", " for example, response in zip(examples, responses):\n", " outputs.append((example[\"image\"], response))\n", " if num_examples and len(outputs) >= num_examples:\n", " return outputs" ] }, { "cell_type": "markdown", "metadata": { "id": "n9r9V1jwJvu9" }, "source": [ "# Fine-tune the model" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "kXIT0lB9jqjs" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "step: 1/128 lr: 0.00005 loss: 3.5541\n", "step: 2/128 lr: 0.00009 loss: 3.2215\n", "step: 3/128 lr: 0.00014 loss: 3.7010\n", "step: 4/128 lr: 0.00019 loss: 3.2698\n", "step: 5/128 lr: 0.00023 loss: 3.5187\n", "step: 6/128 lr: 0.00028 loss: 3.3232\n", "step: 7/128 lr: 0.00033 loss: 3.2168\n", "step: 8/128 lr: 0.00038 loss: 3.2819\n", "step: 9/128 lr: 0.00042 loss: 3.2132\n", "step: 10/128 lr: 0.00047 loss: 3.2118\n", "step: 11/128 lr: 0.00052 loss: 3.1371\n", "step: 12/128 lr: 0.00056 loss: 3.1616\n", "step: 13/128 lr: 0.00061 loss: 3.5072\n", "step: 14/128 lr: 0.00066 loss: 2.9888\n", "step: 15/128 lr: 0.00070 loss: 3.2852\n", "step: 16/128 lr: 0.00075 loss: 3.1900\n", "step: 17/128 lr: 0.00080 loss: 3.1098\n", 
"step: 18/128 lr: 0.00084 loss: 3.0129\n", "step: 19/128 lr: 0.00089 loss: 3.0052\n", "step: 20/128 lr: 0.00094 loss: 2.9630\n", "step: 21/128 lr: 0.00098 loss: 2.8648\n", "step: 22/128 lr: 0.00103 loss: 2.9091\n", "step: 23/128 lr: 0.00108 loss: 2.7960\n", "step: 24/128 lr: 0.00112 loss: 2.9222\n", "step: 25/128 lr: 0.00117 loss: 2.7515\n", "step: 26/128 lr: 0.00122 loss: 2.7756\n", "step: 27/128 lr: 0.00127 loss: 2.6340\n", "step: 28/128 lr: 0.00131 loss: 2.9309\n", "step: 29/128 lr: 0.00136 loss: 2.7362\n", "step: 30/128 lr: 0.00141 loss: 2.4330\n", "step: 31/128 lr: 0.00145 loss: 2.4895\n", "step: 32/128 lr: 0.00150 loss: 2.7376\n", "step: 33/128 lr: 0.00155 loss: 2.3519\n", "step: 34/128 lr: 0.00159 loss: 2.3741\n", "step: 35/128 lr: 0.00164 loss: 2.1744\n", "step: 36/128 lr: 0.00169 loss: 2.2894\n", "step: 37/128 lr: 0.00173 loss: 2.5071\n", "step: 38/128 lr: 0.00178 loss: 2.2978\n", "step: 39/128 lr: 0.00183 loss: 2.4412\n", "step: 40/128 lr: 0.00188 loss: 2.1825\n", "step: 41/128 lr: 0.00192 loss: 2.1376\n", "step: 42/128 lr: 0.00197 loss: 2.1129\n", "step: 43/128 lr: 0.00202 loss: 2.3329\n", "step: 44/128 lr: 0.00206 loss: 1.9307\n", "step: 45/128 lr: 0.00211 loss: 2.0200\n", "step: 46/128 lr: 0.00216 loss: 1.9894\n", "step: 47/128 lr: 0.00220 loss: 1.8897\n", "step: 48/128 lr: 0.00225 loss: 1.9321\n", "step: 49/128 lr: 0.00230 loss: 1.8077\n", "step: 50/128 lr: 0.00234 loss: 1.7965\n", "step: 51/128 lr: 0.00239 loss: 1.8647\n", "step: 52/128 lr: 0.00244 loss: 1.6083\n", "step: 53/128 lr: 0.00248 loss: 1.7444\n", "step: 54/128 lr: 0.00253 loss: 1.6344\n", "step: 55/128 lr: 0.00258 loss: 1.6643\n", "step: 56/128 lr: 0.00262 loss: 1.7359\n", "step: 57/128 lr: 0.00267 loss: 1.5264\n", "step: 58/128 lr: 0.00272 loss: 1.6980\n", "step: 59/128 lr: 0.00277 loss: 1.7417\n", "step: 60/128 lr: 0.00281 loss: 1.5004\n", "step: 61/128 lr: 0.00286 loss: 1.6187\n", "step: 62/128 lr: 0.00291 loss: 1.4576\n", "step: 63/128 lr: 0.00295 loss: 1.4337\n", "step: 64/128 lr: 
0.00300 loss: 1.4088\n", "step: 65/128 lr: 0.00300 loss: 1.6668\n", "step: 66/128 lr: 0.00299 loss: 1.4565\n", "step: 67/128 lr: 0.00298 loss: 1.3991\n", "step: 68/128 lr: 0.00297 loss: 1.5546\n", "step: 69/128 lr: 0.00296 loss: 1.4990\n", "step: 70/128 lr: 0.00294 loss: 0.9921\n", "step: 71/128 lr: 0.00291 loss: 1.4244\n", "step: 72/128 lr: 0.00289 loss: 1.2937\n", "step: 73/128 lr: 0.00286 loss: 1.1714\n", "step: 74/128 lr: 0.00283 loss: 0.9958\n", "step: 75/128 lr: 0.00279 loss: 1.3715\n", "step: 76/128 lr: 0.00275 loss: 1.1864\n", "step: 77/128 lr: 0.00271 loss: 1.2027\n", "step: 78/128 lr: 0.00267 loss: 1.0820\n", "step: 79/128 lr: 0.00262 loss: 1.1740\n", "step: 80/128 lr: 0.00257 loss: 1.4427\n", "step: 81/128 lr: 0.00252 loss: 1.0802\n", "step: 82/128 lr: 0.00247 loss: 0.8756\n", "step: 83/128 lr: 0.00241 loss: 1.3305\n", "step: 84/128 lr: 0.00235 loss: 0.8690\n", "step: 85/128 lr: 0.00229 loss: 1.0759\n", "step: 86/128 lr: 0.00223 loss: 1.0876\n", "step: 87/128 lr: 0.00216 loss: 0.8533\n", "step: 88/128 lr: 0.00210 loss: 0.7881\n", "step: 89/128 lr: 0.00203 loss: 0.7177\n", "step: 90/128 lr: 0.00196 loss: 0.8630\n", "step: 91/128 lr: 0.00189 loss: 0.8725\n", "step: 92/128 lr: 0.00182 loss: 0.8491\n", "step: 93/128 lr: 0.00175 loss: 0.9923\n", "step: 94/128 lr: 0.00168 loss: 0.6511\n", "step: 95/128 lr: 0.00161 loss: 0.4120\n", "step: 96/128 lr: 0.00154 loss: 0.5675\n", "step: 97/128 lr: 0.00146 loss: 0.9469\n", "step: 98/128 lr: 0.00139 loss: 1.1420\n", "step: 99/128 lr: 0.00132 loss: 0.5743\n", "step: 100/128 lr: 0.00125 loss: 0.7502\n", "step: 101/128 lr: 0.00118 loss: 0.8339\n", "step: 102/128 lr: 0.00111 loss: 0.7139\n", "step: 103/128 lr: 0.00104 loss: 0.4566\n", "step: 104/128 lr: 0.00097 loss: 0.8023\n", "step: 105/128 lr: 0.00090 loss: 0.5136\n", "step: 106/128 lr: 0.00084 loss: 0.7242\n", "step: 107/128 lr: 0.00077 loss: 0.8242\n", "step: 108/128 lr: 0.00071 loss: 0.5768\n", "step: 109/128 lr: 0.00065 loss: 0.6741\n", "step: 110/128 lr: 0.00059 
loss: 0.4422\n", "step: 111/128 lr: 0.00053 loss: 0.3808\n", "step: 112/128 lr: 0.00048 loss: 0.7980\n", "step: 113/128 lr: 0.00043 loss: 0.6143\n", "step: 114/128 lr: 0.00038 loss: 0.7009\n", "step: 115/128 lr: 0.00033 loss: 0.6710\n", "step: 116/128 lr: 0.00029 loss: 0.3956\n", "step: 117/128 lr: 0.00025 loss: 0.7030\n", "step: 118/128 lr: 0.00021 loss: 0.7158\n", "step: 119/128 lr: 0.00017 loss: 0.4146\n", "step: 120/128 lr: 0.00014 loss: 0.6061\n", "step: 121/128 lr: 0.00011 loss: 0.4536\n", "step: 122/128 lr: 0.00009 loss: 0.5641\n", "step: 123/128 lr: 0.00006 loss: 0.6973\n", "step: 124/128 lr: 0.00004 loss: 0.6476\n", "step: 125/128 lr: 0.00003 loss: 0.7660\n", "step: 126/128 lr: 0.00002 loss: 0.4655\n", "step: 127/128 lr: 0.00001 loss: 0.6795\n", "step: 128/128 lr: 0.00000 loss: 0.5314\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Fine-tune the model parameters.\n", "BATCH_SIZE = 4\n", "TRAIN_EXAMPLES = 512\n", "LEARNING_RATE = 0.003\n", "\n", "TRAIN_STEPS = TRAIN_EXAMPLES // BATCH_SIZE\n", "EVAL_STEPS = TRAIN_STEPS // 4 # Number of evaluation steps\n", "\n", "# Lists to store training losses\n", "losses = []\n", "train_data_it = train_data_iterator()\n", "\n", "# Learning rate schedule using cosine decay with warmup\n", "sched_fn = big_vision.utils.create_learning_rate_schedule(\n", " total_steps=TRAIN_STEPS+1, base=LEARNING_RATE,\n", " decay_type=\"cosine\", warmup_percent=0.50)\n", "\n", "# Perform training steps\n", "for step in range(1, TRAIN_STEPS+1):\n", " # Make list of N training examples.\n", " examples = [next(train_data_it) for _ in range(BATCH_SIZE)]\n", "\n", " # Convert list of examples into a dict of np.arrays and load onto devices.\n", " batch = jax.tree.map(lambda *x: np.stack(x), *examples)\n", " batch = big_vision.utils.reshard(batch, data_sharding)\n", "\n", " # Training step and report training loss\n", " learning_rate = sched_fn(step)\n", " params, loss = update_fn(params, batch, learning_rate)\n", "\n", " loss = jax.device_get(loss)\n", " losses.append(loss)\n", " print(f\"step: {step:2d}/{TRAIN_STEPS:2d} lr: {learning_rate:.5f} loss: {loss:.4f}\")\n", "\n", "# Plotting the loss graph\n", "plt.plot(losses, label='Training Loss')\n", "plt.title('Training Loss Over Steps')\n", "plt.xlabel('Steps')\n", "plt.ylabel('Loss')\n", "plt.legend()\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": { "id": "glScsFLVJ52c" }, "source": [ "# Predictions and Validation Comparison" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "R7Xjc9vyjqjt" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Predictions and Validation Examples: 10 Randomly Selected Images \n", "\n" ] }, { "data": { "text/html": [ "\n", "
\n", "
Prediction: juvenile in nest
\n", "
\n", "
\n", " \n", "
\n", "
\n", "
Validation: abbott's booby: keep an eye out for this majestic seabird soaring over the tropical indian and pacific oceans, especially around christmas island, noted for its large size and long wingspan while in flight.
\n", "
\n", " \n", "
\n", "
Prediction: albatross: observe this large seabird known for its brooding behavior, providing a visual spectacle with its wingspan and long flight feathers, and for nesting on remote islands in the southern ocean.
\n", "
\n", "
\n", " \n", "
\n", "
\n", "
Validation: albatross: observe this large seafaring bird known for its impressive wingspan and effortless gliding flight, a majestic sight in the southern hemisphere, noted for its long wings and buoyant flight.
\n", "
\n", " \n", "
\n", "
Prediction: african black-crowned night-jar: observe this large nocturnal bird of prey in the savannas and woodlands of sub-saharan africa, noted for its distinctive black crown and yellow bill.
\n", "
\n", "
\n", " \n", "
\n", "
\n", "
Validation: yellow-billed chough: observe this medium-sized crow easily recognizable by its sleek black feathers and contrasting long yellow bill, a common sight in the mountainous regions of europe and asia, noted for its acrobatic flight and yellow bill.
\n", "
\n", " \n", "
\n", "
Prediction: american redstart: observe this small insectivorous bird with red markings on its wings and tail, found in the deciduous forests of north america.
\n", "
\n", "
\n", " \n", "
\n", "
\n", "
Validation: american redstart: observe this small insectivorous warbler with contrasting black feathers and vibrant orange patches on its wings, a resident of north america, noted for its active foraging behavior and constant tail flicking.
\n", "
\n", " \n", "
\n", "
Prediction: american redstart: observe this small insectivorous bird with red underparts and black feathers, found in the woodlands of north and central america, noted for its slender body and long tail.
\n", "
\n", "
\n", " \n", "
\n", "
\n", "
Validation: american redstart: observe this small insectivorous warbler with contrasting black feathers and vibrant orange patches on its wings, a resident of north america, noted for its active foraging behavior and constant tail flicking.
\n", "
\n", " \n", "
\n", "
Prediction: american flamingo: observe this large wading bird, a standout in any tropical landscape, known for its vibrant pink plumage and long legs, a true icon of the caribbean.
\n", "
\n", "
\n", " \n", "
\n", "
\n", "
Validation: american flamingo: look for this large wading bird recognizable by its vibrant pink to red plumage, long legs, and distinctively curved bill, a striking inhabitant of the caribbean, galΓ‘pagos, and coastal regions of south america, noted for its long neck and legs.
\n", "
\n", " \n", "
\n", "
Prediction: cuban toreador: a close-knit family of medium-sized birds with elongated snouts and robust wings, inhabiting the forests of central and south america, noted for their impressive casque of feathers.
\n", "
\n", "
\n", " \n", "
\n", "
\n", "
Validation: african emerald cuckoo: look for this small parasitic bird with shimmering emerald feathers, a unique inhabitant of the forests of sub-saharan africa, characterized by its slender body and long tail.
\n", "
\n", " \n", "
\n", "
Prediction: african black hornbill: observe this medium-sized african bird with an impressive red bill and long beak, a distinctive sight endemic to the forests of southern and east africa, known for its impressive size and robust build.
\n", "
\n", "
\n", " \n", "
\n", "
\n", "
Validation: abyssinian ground hornbill: notice this large terrestrial bird commanding attention with its distinctive helmet-like structure on its head, a resident of the savannas and woodlands of sub-saharan africa, characterized by its powerful bill and long eyelashes.
\n", "
\n", " \n", "
\n", "
Prediction: american goldfinch: observe this small, colorful finch familiarizing himself with backyard feeders, noted for its distinctive yellow plumage and black cap.
\n", "
\n", "
\n", " \n", "
\n", "
\n", "
Validation: american goldfinch: notice this small, colorful finch with bright yellow feathers and contrasting black wings, bringing joy to the landscapes of north america, noted for its conical bill and black cap.
\n", "
\n", " \n", "
\n", "
Prediction: long-billed dowitch: look for this large, long-billed shorebird in the grasslands and savannas of southern africa, characterized by its distinctive black and white patterning and long, slender bill.
\n", "
\n", "
\n", " \n", "
\n", "
\n", "
Validation: alexandrine parakeet: notice this medium-sized parrot bursting with a rainbow of colors, a vibrant resident of the forests and woodlands of south and southeast asia, distinguished by its large size and distinctive red beak.
\n", "
\n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "def data_iterator(data_type=None):\n", " \"\"\"Iterates over examples for validation or prediction.\n", "\n", " Args:\n", " data_type (str): Type of data to iterate ('prediction' or 'validation').\n", "\n", " Yields:\n", " dict: Dictionary containing image and text data based on data_type.\n", " For 'prediction': {'image': np.array, 'text': np.array, 'mask_ar': np.array, 'mask_input': np.array}\n", " For 'validation': {'image': np.array, 'text': np.array, 'mask_ar': np.array, 'mask_loss': np.array}\n", " \"\"\"\n", " for example in val_dataset.get_tfdata(ordered=True).as_numpy_iterator():\n", " image = Image.open(io.BytesIO(example[\"image\"]))\n", " image = preprocess_image(image)\n", "\n", " prefix = \"describe en\"\n", " if data_type == \"prediction\":\n", " tokens, mask_ar, _, mask_input = preprocess_tokens(prefix, seqlen=SEQLEN)\n", "\n", " yield {\n", " \"image\": np.asarray(image),\n", " \"text\": np.asarray(tokens),\n", " \"mask_ar\": np.asarray(mask_ar),\n", " \"mask_input\": np.asarray(mask_input),\n", " }\n", " elif data_type == \"validation\":\n", " suffix = example[\"suffix\"].decode().lower()\n", " tokens, mask_ar, mask_loss, _ = preprocess_tokens(prefix, suffix, SEQLEN)\n", "\n", " yield {\n", " \"image\": np.asarray(image),\n", " \"text\": np.asarray(tokens),\n", " \"mask_ar\": np.asarray(mask_ar),\n", " \"mask_loss\": np.asarray(mask_loss),\n", " }\n", "\n", "def render_example(image, description):\n", " \"\"\"Renders an image with description in HTML format.\n", "\n", " Args:\n", " image (np.array): Image data as numpy array.\n", " description (str): Description text to display alongside the image.\n", "\n", " Returns:\n", " str: HTML formatted string for displaying the image and description.\n", " \"\"\"\n", " image = ((image + 1) / 2 * 255).astype(np.uint8) # [-1,1] -> [0, 255]\n", " return f\"\"\"\n", "
\n", " \n", "
\n", " \"\"\"\n", "\n", "def display_comparisons(predictions, validations):\n", " \"\"\"Displays side-by-side comparisons of predictions and validations.\n", "\n", " Args:\n", " predictions (list): List of tuples (image, description) for predictions.\n", " validations (list): List of tuples (image, description) for validations.\n", "\n", " Prints:\n", " Displays HTML output showing 10 randomly selected images with predictions and validations.\n", " \"\"\"\n", " html_out = \"\"\n", "\n", " # Select 10 random indices\n", " num_comparisons = min(10, len(predictions), len(validations))\n", " random_indices = random.sample(range(min(len(predictions), len(validations))), num_comparisons)\n", "\n", " for random_index in random_indices:\n", " pred_image, pred_description = predictions[random_index]\n", " val_image, val_description = validations[random_index]\n", "\n", " # Call render_example to get image content with description\n", " pred_content = render_example(pred_image, f\"Prediction: {pred_description}\")\n", " val_content = render_example(val_image, f\"Validation: {val_description}\")\n", "\n", " # Structure container with three columns and set widths\n", " html_out += f\"\"\"\n", "
\n", "
Prediction: {pred_description}
\n", "
{pred_content}
\n", "
Validation: {val_description}
\n", "
\n", " \"\"\"\n", " display(HTML(html_out))\n", "\n", "\n", "# Generate predictions and validations\n", "predictions = []\n", "for image, description_pred in make_predictions(data_iterator(\"prediction\"), batch_size=4):\n", " # Only append if both image and description_pred are not None\n", " if image is not None and description_pred is not None:\n", " predictions.append((image, description_pred))\n", "\n", "validations = []\n", "for example in data_iterator(\"validation\"):\n", " description = postprocess_tokens(example[\"text\"])\n", " description = description[len(\"describe en\\n\"):] # Strip the 'describe en' prompt prefix.\n", " # Only append if image and description are not None\n", " if example[\"image\"] is not None and description is not None:\n", " validations.append((example[\"image\"], description))\n", "\n", "# Display predictions and validations side-by-side\n", "print(\"Predictions and Validation Examples: 10 Randomly Selected Images \\n\")\n", "display_comparisons(predictions, validations)" ] }, { "cell_type": "markdown", "metadata": { "id": "k-8GGxGMjqjt" }, "source": [ "# Conclusion\n", "\n", "This notebook fine-tuned PaliGemma for bird description using a dataset of 3,692 image-description pairs covering 23 curated species from a diverse bird species dataset; the training loop above ran 128 steps with a batch size of 4 (512 examples). In the 10 randomly selected validation examples shown above, the model produced fluent descriptions in the same style as the validation data and named the correct species in 5 of them (for example, the American redstart, American flamingo, and American goldfinch), but it misidentified others, such as the yellow-billed chough and the Alexandrine parakeet. Species identification accuracy therefore still needs improvement, and expanding training to include more of the 525 species available in the source dataset may help." 
] }, { "cell_type": "markdown", "metadata": { "id": "e5HltsZsjqjt" }, "source": [ "# References\n", "- [Fine-tune PaliGemma with JAX](https://www.kaggle.com/code/nilaychauhan/fine-tune-paligemma-with-jax)\n", "- [PaliGemma model README](https://github.com/google-research/big_vision/blob/main/big_vision/configs/proj/paligemma/README.md)\n", "- [BIRDS 525 SPECIES- IMAGE CLASSIFICATION](https://www.kaggle.com/datasets/gpiosenka/100-bird-species)" ] } ], "metadata": { "accelerator": "GPU", "colab": { "name": "Finetune_PaliGemma_for_image_description.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 0 }