{ "cells": [ { "cell_type": "markdown", "source": [ "# Fine-Tuning GPT-2 with RLHF on Drugs.com Reviews for High-Quality Drug Reviews on Depression\n", "\n", "\n", "**Author**: Zakia Salod\n", "\n", "**Affiliation**: University of KwaZulu-Natal (UKZN), Durban, South Africa\n", "\n", "**Contact**: zakia.salod@gmail.com\n", "\n", "**Machine Used**: Google Colab T4 GPU\n", "\n", "**Last Updated**: 10 December 2023\n", "\n", "**Description**:\n", "This notebook demonstrates fine-tuning a GPT-2 model (specifically, Zakia/gpt2-drugscom_depression_reviews) with Reinforcement Learning from Human Feedback (RLHF), using the PPO trainer of the TRL (Transformer Reinforcement Learning) library. The base model (GPT-2) and the reward model (DistilBERT, specifically, Zakia/distilbert-drugscom_depression_reviews) were both previously fine-tuned on the same Drugs.com reviews dataset, filtered to the depression condition. The goal is to further improve GPT-2's ability to generate high-quality patient reviews of depression drugs: GPT-2 generates the reviews, DistilBERT scores them, and these scores serve as the reward signal that PPO uses to update GPT-2.\n", "\n", "\n", "**License**:\n", "This work is licensed under the [Apache 2.0 License](https://www.apache.org/licenses/LICENSE-2.0). Free for educational and research use.\n", "\n" ], "metadata": { "id": "McyCLRHCRADe" } }, { "cell_type": "markdown", "source": [ "
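In this RLHF setup, GPT-2 generates a review, the DistilBERT reward model scores it, and PPO updates GPT-2 using that score. The snippet below is a minimal, framework-free sketch of the KL-shaped reward commonly used in PPO-based RLHF. It illustrates the general technique and is not code taken from TRL or from this notebook. Each generated token is penalized by `kl_coef` times the gap between the policy's and the frozen reference model's log-probabilities, and the reward model's scalar score is added to the final token. The function `shaped_rewards` and the example numbers are hypothetical.

```python
# Sketch only: KL-shaped rewards as commonly used in PPO-based RLHF.
# Assumption: this approximates the general technique; it is not TRL's code.
from typing import List


def shaped_rewards(score: float, logprobs: List[float],
                   ref_logprobs: List[float], kl_coef: float) -> List[float]:
    # Per-token KL estimate (kl_penalty='kl'): log p_policy - log p_reference
    kls = [lp - ref for lp, ref in zip(logprobs, ref_logprobs)]
    # Each generated token is penalized for drifting from the reference model
    rewards = [-kl_coef * kl for kl in kls]
    # The reward model's scalar score is credited to the final token only
    rewards[-1] += score
    return rewards


# Hypothetical log-probabilities for a 3-token response scored 2.0
print(shaped_rewards(2.0, [-1.0, -0.5, -2.0], [-1.2, -0.5, -1.0], kl_coef=0.2))
```

The KL term discourages the policy from drifting too far from the original fine-tuned GPT-2 (the reference model) while it pursues a higher reward-model score." ], "metadata": {} }, { "cell_type": "markdown", "source": [ "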
\n", " \n", "

\n", " Figure 1: The RLHF process applied to the GPT-2 model using the DrugsCom DepressionReviews dataset. The fine-tuned GPT-2 model (purple), the DistilBERT reward model (orange) and the dataset (turquoise box; filtered for the 'Depression' condition in the 'train' split) are highlighted to show how they are integrated into the fine-tuning process.

" ], "metadata": { "id": "flWxHE44V_gh" } }, { "cell_type": "markdown", "source": [ "## STEP 1: SETTING UP THE ENVIRONMENT" ], "metadata": { "id": "2LWsWZo5K0c1" } }, { "cell_type": "markdown", "source": [ "### Load Necessary Libraries" ], "metadata": { "id": "ta54fa0GLBzP" } }, { "cell_type": "code", "source": [ "# Enable automatic module reloading to reflect changes in external .py files\n", "%load_ext autoreload\n", "# Reload all modules before executing code, keeping modules up-to-date\n", "%autoreload 2" ], "metadata": { "id": "OQoTv4uQhxGa" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "### Install Required Packages" ], "metadata": { "id": "ZWPOCVIsLLpt" } }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "LNu3PqsV4NNo", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "46070413-cf44-4583-896c-3ad8985430c0" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Collecting Accelerator\n", " Downloading accelerator-2023.11.3.dev1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.3/1.3 MB\u001b[0m \u001b[31m6.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hCollecting peft\n", " Downloading peft-0.7.0-py3-none-any.whl (168 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m168.3/168.3 kB\u001b[0m \u001b[31m24.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hCollecting trl\n", " Downloading trl-0.7.4-py3-none-any.whl (133 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m133.9/133.9 kB\u001b[0m \u001b[31m19.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hCollecting wandb\n", " Downloading wandb-0.16.1-py3-none-any.whl (2.1 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.1/2.1 MB\u001b[0m \u001b[31m32.8 
MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hCollecting setproctitle>=1.1.8 (from Accelerator)\n", " Downloading setproctitle-1.3.3-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (30 kB)\n", "Collecting bottle<0.13,>=0.12.7 (from Accelerator)\n", " Downloading bottle-0.12.25-py3-none-any.whl (90 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m90.2/90.2 kB\u001b[0m \u001b[31m13.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hCollecting waitress>=1.0 (from Accelerator)\n", " Downloading waitress-2.1.2-py3-none-any.whl (57 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m57.7/57.7 kB\u001b[0m \u001b[31m9.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hRequirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from peft) (1.23.5)\n", "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from peft) (23.2)\n", "Requirement already satisfied: psutil in /usr/local/lib/python3.10/dist-packages (from peft) (5.9.5)\n", "Requirement already satisfied: pyyaml in /usr/local/lib/python3.10/dist-packages (from peft) (6.0.1)\n", "Requirement already satisfied: torch>=1.13.0 in /usr/local/lib/python3.10/dist-packages (from peft) (2.1.0+cu118)\n", "Requirement already satisfied: transformers in /usr/local/lib/python3.10/dist-packages (from peft) (4.35.2)\n", "Requirement already satisfied: tqdm in /usr/local/lib/python3.10/dist-packages (from peft) (4.66.1)\n", "Collecting accelerate>=0.21.0 (from peft)\n", " Downloading accelerate-0.25.0-py3-none-any.whl (265 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m265.7/265.7 kB\u001b[0m \u001b[31m34.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hRequirement already satisfied: safetensors in /usr/local/lib/python3.10/dist-packages (from peft) 
(0.4.1)\n", "Requirement already satisfied: huggingface-hub>=0.17.0 in /usr/local/lib/python3.10/dist-packages (from peft) (0.19.4)\n", "Collecting datasets (from trl)\n", " Downloading datasets-2.15.0-py3-none-any.whl (521 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m521.2/521.2 kB\u001b[0m \u001b[31m53.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hCollecting tyro>=0.5.11 (from trl)\n", " Downloading tyro-0.6.0-py3-none-any.whl (100 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m100.9/100.9 kB\u001b[0m \u001b[31m15.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hRequirement already satisfied: Click!=8.0.0,>=7.1 in /usr/local/lib/python3.10/dist-packages (from wandb) (8.1.7)\n", "Collecting GitPython!=3.1.29,>=1.0.0 (from wandb)\n", " Downloading GitPython-3.1.40-py3-none-any.whl (190 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m190.6/190.6 kB\u001b[0m \u001b[31m25.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hRequirement already satisfied: requests<3,>=2.0.0 in /usr/local/lib/python3.10/dist-packages (from wandb) (2.31.0)\n", "Collecting sentry-sdk>=1.0.0 (from wandb)\n", " Downloading sentry_sdk-1.38.0-py2.py3-none-any.whl (252 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m252.8/252.8 kB\u001b[0m \u001b[31m32.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hCollecting docker-pycreds>=0.4.0 (from wandb)\n", " Downloading docker_pycreds-0.4.0-py2.py3-none-any.whl (9.0 kB)\n", "Requirement already satisfied: setuptools in /usr/local/lib/python3.10/dist-packages (from wandb) (67.7.2)\n", "Requirement already satisfied: appdirs>=1.4.3 in /usr/local/lib/python3.10/dist-packages (from wandb) (1.4.4)\n", "Requirement already satisfied: protobuf!=4.21.0,<5,>=3.19.0 in /usr/local/lib/python3.10/dist-packages (from wandb) (3.20.3)\n", 
"Requirement already satisfied: six>=1.4.0 in /usr/local/lib/python3.10/dist-packages (from docker-pycreds>=0.4.0->wandb) (1.16.0)\n", "Collecting gitdb<5,>=4.0.1 (from GitPython!=3.1.29,>=1.0.0->wandb)\n", " Downloading gitdb-4.0.11-py3-none-any.whl (62 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m62.7/62.7 kB\u001b[0m \u001b[31m9.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hRequirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.17.0->peft) (3.13.1)\n", "Requirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.17.0->peft) (2023.6.0)\n", "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.17.0->peft) (4.5.0)\n", "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2.0.0->wandb) (3.3.2)\n", "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2.0.0->wandb) (3.6)\n", "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2.0.0->wandb) (2.0.7)\n", "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2.0.0->wandb) (2023.11.17)\n", "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (1.12)\n", "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (3.2.1)\n", "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (3.1.2)\n", "Requirement already satisfied: triton==2.1.0 in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (2.1.0)\n", "Requirement already satisfied: regex!=2019.12.17 in 
/usr/local/lib/python3.10/dist-packages (from transformers->peft) (2023.6.3)\n", "Requirement already satisfied: tokenizers<0.19,>=0.14 in /usr/local/lib/python3.10/dist-packages (from transformers->peft) (0.15.0)\n", "Collecting docstring-parser>=0.14.1 (from tyro>=0.5.11->trl)\n", " Downloading docstring_parser-0.15-py3-none-any.whl (36 kB)\n", "Requirement already satisfied: rich>=11.1.0 in /usr/local/lib/python3.10/dist-packages (from tyro>=0.5.11->trl) (13.7.0)\n", "Collecting shtab>=1.5.6 (from tyro>=0.5.11->trl)\n", " Downloading shtab-1.6.5-py3-none-any.whl (13 kB)\n", "Requirement already satisfied: pyarrow>=8.0.0 in /usr/local/lib/python3.10/dist-packages (from datasets->trl) (9.0.0)\n", "Collecting pyarrow-hotfix (from datasets->trl)\n", " Downloading pyarrow_hotfix-0.6-py3-none-any.whl (7.9 kB)\n", "Collecting dill<0.3.8,>=0.3.0 (from datasets->trl)\n", " Downloading dill-0.3.7-py3-none-any.whl (115 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m115.3/115.3 kB\u001b[0m \u001b[31m17.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hRequirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from datasets->trl) (1.5.3)\n", "Requirement already satisfied: xxhash in /usr/local/lib/python3.10/dist-packages (from datasets->trl) (3.4.1)\n", "Collecting multiprocess (from datasets->trl)\n", " Downloading multiprocess-0.70.15-py310-none-any.whl (134 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m134.8/134.8 kB\u001b[0m \u001b[31m16.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hRequirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from datasets->trl) (3.9.1)\n", "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets->trl) (23.1.0)\n", "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from 
aiohttp->datasets->trl) (6.0.4)\n", "Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets->trl) (1.9.3)\n", "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets->trl) (1.4.0)\n", "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets->trl) (1.3.1)\n", "Requirement already satisfied: async-timeout<5.0,>=4.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets->trl) (4.0.3)\n", "Collecting smmap<6,>=3.0.1 (from gitdb<5,>=4.0.1->GitPython!=3.1.29,>=1.0.0->wandb)\n", " Downloading smmap-5.0.1-py3-none-any.whl (24 kB)\n", "Requirement already satisfied: markdown-it-py>=2.2.0 in /usr/local/lib/python3.10/dist-packages (from rich>=11.1.0->tyro>=0.5.11->trl) (3.0.0)\n", "Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /usr/local/lib/python3.10/dist-packages (from rich>=11.1.0->tyro>=0.5.11->trl) (2.16.1)\n", "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=1.13.0->peft) (2.1.3)\n", "Requirement already satisfied: python-dateutil>=2.8.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets->trl) (2.8.2)\n", "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets->trl) (2023.3.post1)\n", "Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=1.13.0->peft) (1.3.0)\n", "Requirement already satisfied: mdurl~=0.1 in /usr/local/lib/python3.10/dist-packages (from markdown-it-py>=2.2.0->rich>=11.1.0->tyro>=0.5.11->trl) (0.1.2)\n", "Installing collected packages: bottle, waitress, smmap, shtab, setproctitle, sentry-sdk, pyarrow-hotfix, docstring-parser, docker-pycreds, dill, multiprocess, gitdb, Accelerator, tyro, GitPython, accelerate, wandb, datasets, trl, peft\n", "Successfully installed 
Accelerator-2023.11.3.dev1 GitPython-3.1.40 accelerate-0.25.0 bottle-0.12.25 datasets-2.15.0 dill-0.3.7 docker-pycreds-0.4.0 docstring-parser-0.15 gitdb-4.0.11 multiprocess-0.70.15 peft-0.7.0 pyarrow-hotfix-0.6 sentry-sdk-1.38.0 setproctitle-1.3.3 shtab-1.6.5 smmap-5.0.1 trl-0.7.4 tyro-0.6.0 waitress-2.1.2 wandb-0.16.1\n" ] } ], "source": [ "# Install Hugging Face accelerate (the PyPI package named Accelerator is unrelated), PEFT, TRL and Weights & Biases\n", "!pip install accelerate peft trl wandb" ] }, { "cell_type": "markdown", "source": [ "### Import Necessary Libraries" ], "metadata": { "id": "PSvI3SMCLSgK" } }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "HdOCuPm91iKd", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "4d8433b9-ef28-4719-b293-209fa870be8d" }, "outputs": [ { "output_type": "stream", "name": "stderr", "text": [ "/usr/local/lib/python3.10/dist-packages/trl/trainer/ppo_config.py:141: UserWarning: The `optimize_cuda_cache` arguement will be deprecated soon, please use `optimize_device_cache` instead.\n", " warnings.warn(\n" ] } ], "source": [ "from dataclasses import dataclass, field\n", "from typing import Optional\n", "import pandas as pd\n", "import re\n", "import html\n", "import numpy as np\n", "import random\n", "\n", "import torch\n", "from accelerate import Accelerator\n", "from datasets import load_dataset\n", "from peft import LoraConfig\n", "from tqdm import tqdm\n", "from transformers import AutoTokenizer, pipeline\n", "from datasets import concatenate_datasets\n", "\n", "from trl import AutoModelForCausalLMWithValueHead, AutoModelForSeq2SeqLMWithValueHead, PPOConfig, PPOTrainer, set_seed\n", "from trl.core import LengthSampler\n", "from trl.import_utils import is_xpu_available" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ZE2hblIQ11lI" }, "outputs": [], "source": [ "tqdm.pandas()" ] }, { "cell_type": "markdown", "source": [ "### Set Random Seeds for Reproducibility" ], "metadata": { "id": "iWSnSLePLf2e" } }, { "cell_type": "code", "source": [ "seed_value = 42\n", "\n", 
"np.random.seed(seed_value)\n", 
"random.seed(seed_value)\n", "torch.manual_seed(seed_value)" ], "metadata": { "id": "WLhPyBaWiBll", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "b3788113-1664-425e-bf01-9aa14c35a885" }, "execution_count": null, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "" ] }, "metadata": {}, "execution_count": 5 } ] }, { "cell_type": "markdown", "source": [ "### Initialize Weights & Biases for Tracking" ], "metadata": { "id": "_TpL4t1yLrWW" } }, { "cell_type": "code", "source": [ "import wandb\n", "\n", "wandb.init(project=\"gpt2-drugscom_depression_reviews-hq-v1\")" ], "metadata": { "id": "83T68a9XiHHs", "colab": { "base_uri": "https://localhost:8080/", "height": 211 }, "outputId": "a84d1912-1487-4303-d75c-165e11ddeb35" }, "execution_count": null, "outputs": [ { "output_type": "display_data", "data": { "text/plain": [ "" ], "application/javascript": [ "\n", " window._wandbApiKey = new Promise((resolve, reject) => {\n", " function loadScript(url) {\n", " return new Promise(function(resolve, reject) {\n", " let newScript = document.createElement(\"script\");\n", " newScript.onerror = reject;\n", " newScript.onload = resolve;\n", " document.body.appendChild(newScript);\n", " newScript.src = url;\n", " });\n", " }\n", " loadScript(\"https://cdn.jsdelivr.net/npm/postmate/build/postmate.min.js\").then(() => {\n", " const iframe = document.createElement('iframe')\n", " iframe.style.cssText = \"width:0;height:0;border:none\"\n", " document.body.appendChild(iframe)\n", " const handshake = new Postmate({\n", " container: iframe,\n", " url: 'https://wandb.ai/authorize'\n", " });\n", " const timeout = setTimeout(() => reject(\"Couldn't auto authenticate\"), 5000)\n", " handshake.then(function(child) {\n", " child.on('authorize', data => {\n", " clearTimeout(timeout)\n", " resolve(data)\n", " });\n", " });\n", " })\n", " });\n", " " ] }, "metadata": {} }, { "output_type": "stream", "name": "stderr", "text": [ 
"\u001b[34m\u001b[1mwandb\u001b[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)\n", "\u001b[34m\u001b[1mwandb\u001b[0m: You can find your API key in your browser here: https://wandb.ai/authorize\n", "wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit:" ] }, { "name": "stdout", "output_type": "stream", "text": [ " ··········\n" ] }, { "output_type": "stream", "name": "stderr", "text": [ "\u001b[34m\u001b[1mwandb\u001b[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc\n" ] }, { "output_type": "display_data", "data": { "text/plain": [ "" ], "text/html": [ "Tracking run with wandb version 0.16.1" ] }, "metadata": {} }, { "output_type": "display_data", "data": { "text/plain": [ "" ], "text/html": [ "Run data is saved locally in /content/wandb/run-20231210_103320-e2k30lm5" ] }, "metadata": {} }, { "output_type": "display_data", "data": { "text/plain": [ "" ], "text/html": [ "Syncing run devoted-pyramid-6 to Weights & Biases (docs)
" ] }, "metadata": {} }, { "output_type": "display_data", "data": { "text/plain": [ "" ], "text/html": [ " View project at https://wandb.ai/team-zakia/gpt2-drugscom_depression_reviews-hq-v1" ] }, "metadata": {} }, { "output_type": "display_data", "data": { "text/plain": [ "" ], "text/html": [ " View run at https://wandb.ai/team-zakia/gpt2-drugscom_depression_reviews-hq-v1/runs/e2k30lm5" ] }, "metadata": {} }, { "output_type": "execute_result", "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "metadata": {}, "execution_count": 6 } ] }, { "cell_type": "markdown", "source": [ "## STEP 2: CONFIGURATION" ], "metadata": { "id": "JGTvUvvrL_mD" } }, { "cell_type": "markdown", "source": [ "### Define script arguments for training configuration" ], "metadata": { "id": "W329hvPKQZEy" } }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "lT6DOtQE13gO" }, "outputs": [], "source": [ "@dataclass\n", "class ScriptArguments:\n", " ppo_config: PPOConfig = field(\n", " default_factory=lambda: PPOConfig(\n", " model_name=\"Zakia/gpt2-drugscom_depression_reviews\",\n", " query_dataset=\"Zakia/drugscom_reviews\",\n", " reward_model=\"sentiment-analysis:Zakia/distilbert-drugscom_depression_reviews\",\n", " learning_rate=1.41e-5,\n", " log_with=\"wandb\",\n", " mini_batch_size=128,\n", " batch_size=128,\n", " gradient_accumulation_steps=1,\n", " early_stopping=False,\n", " target_kl=6.0,\n", " kl_penalty=\"kl\",\n", " seed=0,\n", " use_score_scaling=False,\n", " use_score_norm=False,\n", " score_clip=None,\n", " )\n", " )\n", " use_seq2seq: bool = False\n", " \"\"\"whether to use seq2seq models\"\"\"\n", " use_peft: bool = False\n", " \"\"\"whether to use peft\"\"\"\n", " peft_config: Optional[LoraConfig] = field(\n", " default_factory=lambda: LoraConfig(\n", " r=16,\n", " lora_alpha=16,\n", " bias=\"none\",\n", " task_type=\"CAUSAL_LM\",\n", " ),\n", " )\n", " trust_remote_code: bool = field(default=False, metadata={\"help\": \"Enable `trust_remote_code`\"})" ] 
}, { "cell_type": "markdown", "source": [ "### Initialize script arguments" ], "metadata": { "id": "r2VBoJb3QdSv" } }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "7NVo3ViFCaXq" }, "outputs": [], "source": [ "# Instantiate the arguments explicitly (the values match the ScriptArguments defaults defined above)\n", "args = ScriptArguments(\n", " ppo_config=PPOConfig(\n", " model_name=\"Zakia/gpt2-drugscom_depression_reviews\",\n", " query_dataset=\"Zakia/drugscom_reviews\",\n", " reward_model=\"sentiment-analysis:Zakia/distilbert-drugscom_depression_reviews\",\n", " learning_rate=1.41e-5,\n", " log_with=\"wandb\",\n", " mini_batch_size=128,\n", " batch_size=128,\n", " gradient_accumulation_steps=1,\n", " early_stopping=False,\n", " target_kl=6.0,\n", " kl_penalty=\"kl\",\n", " seed=0,\n", " use_score_scaling=False,\n", " use_score_norm=False,\n", " score_clip=None,\n", " ),\n", " use_seq2seq=False,\n", " use_peft=False,\n", " peft_config=LoraConfig(\n", " r=16,\n", " lora_alpha=16,\n", " bias=\"none\",\n", " task_type=\"CAUSAL_LM\",\n", " ),\n", " trust_remote_code=False\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "xXRNMIgu2G8u" }, "outputs": [], "source": [ "# Define the keyword arguments for the reward model's sentiment-analysis pipeline.\n", "# `return_all_scores=True` returns a score for every label (not for every token), and\n", "# `function_to_apply` set to `none` returns the raw logits instead of softmax probabilities.\n", "sent_kwargs = {\"return_all_scores\": True, \"function_to_apply\": \"none\", \"batch_size\": 16}\n", "\n", "# Select the appropriate TRL model class (causal LM or seq2seq) based on the arguments\n", "trl_model_class = AutoModelForCausalLMWithValueHead if not args.use_seq2seq else AutoModelForSeq2SeqLMWithValueHead" ] }, { "cell_type": "markdown", "source": [ "## STEP 3: DATASET PREPARATION" ], "metadata": { "id": "OW8FrsVUMbtu" } }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "XRyup-PP2dp-" }, "outputs": [], "source": [ "# Function to clean review text\n", "def clean_review(text):\n", " # Check if the text is a string\n", " if not isinstance(text, str):\n", " return \"\" # Return an empty string if the 
input is not a string\n", " text = html.unescape(text) # Decode HTML entities\n", " text = re.sub(r'\"', '', text) # Remove quotes\n", " text = re.sub(r'<.*?>', '', text) # Remove HTML tags\n", " return text" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "mbAzTGjW2eJZ" }, "outputs": [], "source": [ "# Clean the reviews of the dataset\n", "# Apply the clean_review function in a batched manner\n", "def clean_reviews(batch):\n", " # Apply clean_review to each review in the batch and return the modified batch\n", " return {\"review\": [clean_review(review) for review in batch[\"review\"]]}" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "dZx9ibFT2J0F" }, "outputs": [], "source": [ "# Function to build and preprocess the dataset\n", "def build_dataset(config, query_dataset, input_min_text_length=2, input_max_text_length=8):\n", " \"\"\"\n", " Build the query dataset for PPO training using `load_dataset`.\n", "\n", " Args:\n", " config (`PPOConfig`):\n", " Configuration whose `model_name` selects the tokenizer.\n", " query_dataset (`str`):\n", " The name of the dataset to be loaded.\n", " input_min_text_length (`int`), input_max_text_length (`int`):\n", " Range of query lengths, in tokens, sampled by `LengthSampler`.\n", "\n", " Returns:\n", " ds (`datasets.Dataset`):\n", " The filtered, cleaned and tokenized dataset (torch format), with added `input_ids` and `query` columns.\n", " \"\"\"\n", " tokenizer = AutoTokenizer.from_pretrained(config.model_name)\n", " tokenizer.pad_token = tokenizer.eos_token\n", "\n", " # Load the dataset\n", " ds = load_dataset(query_dataset, split=\"train\")\n", "\n", " # Filter the dataset for the condition 'Depression'\n", " ds = ds.filter(lambda x: x[\"condition\"] == \"Depression\")\n", "\n", " # Filter out (remove) rows with a missing drugName or review\n", " ds = ds.filter(lambda x: all([x.get(\"drugName\"), x.get(\"review\")]))\n", "\n", " # Clean the reviews\n", " ds = ds.map(clean_reviews, batched=True)\n", "\n", " # Get the number of records\n", " num_records = ds.num_rows\n", " print(f\"Number of records with Depression condition: {num_records}\")\n", "\n", " input_size = LengthSampler(input_min_text_length, input_max_text_length)\n", "\n", " # 
Tokenization\n", " def tokenize(sample):\n", " sample[\"input_ids\"] = tokenizer.encode(sample[\"review\"])[: input_size()]\n", " sample[\"query\"] = tokenizer.decode(sample[\"input_ids\"])\n", " return sample\n", "\n", " ds = ds.map(tokenize, batched=False)\n", "\n", " ds.set_format(type=\"torch\")\n", " return ds" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "xj_TD6Ct2NSe", "colab": { "base_uri": "https://localhost:8080/", "height": 514, "referenced_widgets": [ "cacd4ec7666a495f95ee42b049f545c2", "7d75ab5cb0b649f79d8ddf80e3f62807", "c1859def222a4df2b272072d85fa7676", "a390233698064a18a3e6b9937c621611", "c268a07ba37043e7a825a2724bbb6a0f", "fa781c1b8b864ebcbcf5e4a7c8cd98e7", "03bc941dea4f498e838e37006f8985e5", "d2ad6562faf342ed8766fd5f9f9dba40", "553fc4dfcbe5465fa21db472f86a8bea", "117235d9b98c43e7bf356daf77132272", "8bac8cd3bdb14708996df8d33a834336", "4e9adafe6671414fa178fb67146e53e0", "642513898ccb4cfd98c554f2d8665756", "4019ffabb9904efcb54a49c26c2247f4", "e12eafe5438d4052be51c60bc65b3b23", "14ada4f3f977438ba35d1566c9b5e2a9", "36fb6a78f9fd42baa9526e7c467b32a7", "51dc5aff47ee42698b7211362715e48e", "1a647ffbe9204b24a1434d3d2aba145a", "48258c33345346379ccd9872320cdc9d", "ec8c1ea1ef284c18992cbdfa9079d538", "23a115d6c2be4e4698c2e415591707f3", "0bb15371613b4da7b0135ce82bc5709a", "6136b1b5b74b43bdbf117bbe2f89d7c6", "8deb499353a3427d82121c8360af8272", "713e49faa756459eafff87beedd0053c", "cf69c35d47f744d59fb5fff54eccfb23", "84958a3ee2d842bf9c8389dee41ee667", "00f52244f15e46708b1e01942aa575ea", "8711a4b7109e4efba7fe533266af44b6", "c6703fde5def47bd952f9210c548ec5b", "6e67b256683c493aa44afac1881e592d", "574a3fe648fe498cbefc60218fa8fdc7", "39d3486b29094fae85cb112025af8342", "5d603bdb7c02421897ddce63917c035b", "c88049b17b744f64b3088ddae6824b68", "3b7bdc11cd9a42d5ad1b48223f930eaa", "4ee8a79a0954481cacc3f65a3cda7f12", "da2ca6f9c54c461d97e99b325582eaf1", "3f0788dc42024a3c9ea7c2e1f2687cf7", "0d808afa735c4d3895ba4f7dcafc3d98", 
"0fca0bc3490648e9b310fc5be2f005a6", "f290d4a2b8684d8581daf46dcfc1670e", "9359f059c38d460a8af54d4cd3599111", "9308e1edb78443edbbcf7ca8d20f8466", "0e3cc3db69c74a409c9224b27ab416ef", "59ae5f4383664c69b407b3652facd47a", "cb22224ffb2a423d83499f36e3b64985", "9e558ee8d00b43f09fddd93c4332026f", "e4d151f2f1c044348e39a0236fd62b6b", "0ab1a3050f204290ad901df76570a0b6", "9929d97f1b52427289095c978bd4415d", "d78e3edcd5b84a9ea8c6a06152a1d701", "86cc71c17f7843dab2c5ed9b1a65a512", "3ccdcbb154be45b2989ecf7638a81952", "0fe6d99031d448f0a5a3051be4409c30", "de8fadeb89384b9ebf256528585f7580", "a54a338b324e45aa8e11aa147c191cfc", "6c3fa0a121a84bdb826006fc660a7bd8", "8acbcf1272c84327bf1931dd57e3f8ce", "e6df030e8761480b815cc179100081fb", "52970f15fe5a47eda52b1cb55e9602ef", "e029124232b24bcbae055b2702511fe4", "35386540399c412a81d3b3464b961c28", "9faa4121a8cd4c36958f107f65c05e73", "470988e6b2824cf19ea6e388068b68c7", "9845c6765c07438e93860c270ae2451f", "17a9a8d22f084866b0e6a5c11a291c7c", "8b5c4df5338f4551923c067bfb45d79a", "092e05ec073d449ba98d77aa1bba4514", "582e9e8c0f224808b61e641d9f6857ce", "ecae45cb4be74d40ae16f32061231ee8", "28124aef1dc84c8fa15ba48b8e8cfbf1", "41d7ace4772e4cc0b523a2a5477bacf7", "ca0c21bdfcbb4c528877594a0a971581", "25a7f4bcbd1c4eaab834eeefcb5993d7", "b3445e6da72f441bbfe1f0dd4b1e4932", "cbfda3d9879e4748a78525d5dda70577", "faa302e9d1f545049bac36fb5d52818f", "c4b58f67858f46719703f20d8b099c67", "edd370fe6a8140d888af7c372590ce7d", "64d9c92cb26145a4984492605fc10d5b", "110adee53e4a4bf783093df05da9ce09", "73e84da761404fe0a709cfad72b8b7df", "b0c3adbf4fbe4e80bac83b9d2c94ce5b", "0515fd3cba86455099d89c261c2f80d7", "e920d4c12a024911b28e8bb405c14e03", "beb65b676e344002834650cc7ec61432", "cd0b4756d94c4ce1add21a03181ff26d", "cf39e0337a4341688a632742ff608807", "a5023c84fadd443183d2d7625631dfaf", "bb731a5bda2347768cd825e9c61e15a0", "9ade7ed95933444db5c73620addc8788", "9ee531eab41b477389e857696d09e2f8", "333c43311a4844239b0bf83e2420ff49", "0ecdcb5f225641f09395a257e80110dd", 
"6ab2f360a3474d8db6dc41a876ffa201", "e9247f7c246f4884b654205784c1ae09", "e4e020ccb3c84add8faf496e8962075c", "b7d4d69778694809a93365fd2624cd81", "9a3645005a3a47b3adc7fa2ee256857e", "c421dcf25c7b4534be0a9a062d60b149", "af0de2848fba4b63811f01bc7170a804", "9f5c50b1834447edaf061bea891abab1", "b7cd2971836a4a55add06ae8e8d50c6e", "43692b418e384051a1418a3e57f84e32", "d3718aed59184f47b661e83702157aec", "72ccb093cc9545ce8e914b0d742e3f63", "3f85f633656b4bbabae188c5c4e14f3d", "ba2397bcd3d9498fb13a1bea270dc9b3", "2c401223ab314184a9f9fa060dee7c16", "324ef4675ae9438cafb857d1e3d01663", "082ae5f30421430d9cd4393b60b1bc16", "83664e84c6b0453cba357733181daecc", "afd94bbf15ae48d195e274ea3b7586f0", "f0a2f11543fd4e27a9d4c74577301f72", "de5974832860470287e7ba5c8b4cb7bf", "460d3dd5c55847b1939f7d977acb949c", "c02b2a87303b4933a37cea947ac6b882", "e12ff38dc40048958292d13e479095f7", "3b383f70c3ba4da68a2be5547c745fb7", "d56542c79e4d403582d581863172be45", "7b6909e6715d49e0a210300806ec5708", "fa13db2eb72b4e5880fad82e9e4bd09f", "96e807c7152b41cb93e1d074a72aa335", "d8a79792f16846b1aa9f4081f29ec3fe", "bdc1a36a76224b2ebee5a65949b38f44", "26dbe6cf82c846c8834b6a16596bbce9", "1bbec64ae8744c35b1392da846781109", "cc5db5dc2cb746a889332f90a8457ccf", "6ad4b1a0219b4d5ea2d391c496bc6185", "866ae5fa78944680b122e2f22eeadc76", "b991a41dfb34468aa7ada5bc97a5f55d", "448fc530bc6d43329f09449c8b715c2f", "08796b0a45dd40ec9a787a14a0ebc43a", "f912bd92b07d44bcb57414ee8a736739", "2ec2932072cc45f69967b0698124b90c", "80617d1e7fad47298432f6c2d1bb37c6", "20b8f8471ce848f9945d8a584e709f26", "1566da9437d842c78714e524000a6651", "b032accd178c4c94ae7028868166abc1", "fd62fedb98a045e0bb1d95557036bcc3", "133a2fd9c2234c0680705fe2227ecc75", "00ae9bdf78304320b35d8f075ab73b63", "b26d0326d13f40ccac3af420813ec886", "c9d6362992934ed19677e8b4941be19b", "4bca47a41c2041e1a0df01333c1ae066", "4e5e0a03b56f4dd6928baf6f4b11fa89", "91648b294ebe4c21b34f153ada728980", "a1d59508c4494ba9afdef0b1a3e9ba89", "6ef2df2906cd4414a8875085a241e0ec", 
"202251b464a74089a38387a5bcbe7da9", "70a9be6f4beb44ec907e00efea964319", "45df2ef15ea9415fbea235c829c5674a", "4d45625c2df243a09c6b6de642526af6", "e883f0adf748406487cddcdc5e1fdbec", "56803e4a735a49a8aeee90e37ee28000", "6a1a5226b6d74c11b7c99349f799a833", "94315059a05a4b0caf65cf71c2975762", "aca5bdcb2ea243c5aadcc51b21cec9db", "7f01b61a6853496b92604c39996e0945", "5cf68a35ecf44b67a8f658a4de8c2a18", "7338589c0ce44338a06acdecc13e0fa6", "3d698d7bff144c66b6502750677ae259", "96d8315ff43a4d31980d1fd8d0f96a67" ] }, "outputId": "219d8213-1441-4eb4-a94b-dd1ef1783403" }, "outputs": [ { "output_type": "display_data", "data": { "text/plain": [ "tokenizer_config.json: 0%| | 0.00/525 [00:00 model is loaded from 'Zakia/gpt2-drugscom_depression_reviews', and no v_head weight is found. This IS expected if you are not resuming PPO training.\n" ] } ], "source": [ "# Now let's build the model, the reference model, and the tokenizer.\n", "if not args.use_peft:\n", " ref_model = trl_model_class.from_pretrained(args.ppo_config.model_name, trust_remote_code=args.trust_remote_code)\n", " device_map = None\n", " peft_config = None\n", "else:\n", " peft_config = args.peft_config\n", " ref_model = None\n", " # Copy the model to each device\n", " device_map = {\"\": Accelerator().local_process_index}" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "B9rCL7qq2TGY", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "e8b6cec5-7154-4343-9854-bfdac871b8d5" }, "outputs": [ { "output_type": "stream", "name": "stderr", "text": [ "WARNING:root:A model is loaded from 'Zakia/gpt2-drugscom_depression_reviews', and no v_head weight is found. 
This IS expected if you are not resuming PPO training.\n" ] } ], "source": [ "model = trl_model_class.from_pretrained(\n", " args.ppo_config.model_name,\n", " trust_remote_code=args.trust_remote_code,\n", " device_map=device_map,\n", " peft_config=peft_config,\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "lHWI6nnr2XaQ" }, "outputs": [], "source": [ "tokenizer = AutoTokenizer.from_pretrained(args.ppo_config.model_name)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Xx9_1KgQ2ZI5" }, "outputs": [], "source": [ "# Some tokenizers like GPT-2's don't have a padding token by default, so we set one here.\n", "tokenizer.pad_token_id = tokenizer.eos_token_id" ] }, { "cell_type": "markdown", "source": [ "## STEP 5: INITIALIZE PPO TRAINER" ], "metadata": { "id": "bMeivsaCNp9Q" } }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "1GjvgKIh2at9", "colab": { "base_uri": "https://localhost:8080/", "height": 191, "referenced_widgets": [ "93ca61ad1458490eb9ff7dcaa65ed78e", "88905cb2d9b04ce3b9943ade60021877", "8496b0087b5b411592b6a86cac6114f8", "f8ebe36d0a18494aaac534843b50b98a", "ecb0f8feabb74e24957d608ed30b9e79", "83b2b5923753462bb883155f25e2ed97", "1273cc0a098b48c988f245b5a1695971", "437fa49dd83640f8ace33962f6e02d39" ] }, "outputId": "eaadc134-f9af-41b8-edae-fc2c4c2f0ce7" }, "outputs": [ { "output_type": "display_data", "data": { "text/plain": [ "" ], "text/html": [ "Finishing last run (ID:e2k30lm5) before initializing another..." 
] }, "metadata": {} }, { "output_type": "display_data", "data": { "text/plain": [ "VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\\r'), FloatProgress(value=1.0, max=1.0)))" ], "application/vnd.jupyter.widget-view+json": { "version_major": 2, "version_minor": 0, "model_id": "93ca61ad1458490eb9ff7dcaa65ed78e" } }, "metadata": {} }, { "output_type": "display_data", "data": { "text/plain": [ "" ], "text/html": [ " View run devoted-pyramid-6 at: https://wandb.ai/team-zakia/gpt2-drugscom_depression_reviews-hq-v1/runs/e2k30lm5
Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)" ] }, "metadata": {} }, { "output_type": "display_data", "data": { "text/plain": [ "" ], "text/html": [ "Find logs at: ./wandb/run-20231210_103320-e2k30lm5/logs" ] }, "metadata": {} }, { "output_type": "display_data", "data": { "text/plain": [ "" ], "text/html": [ "Successfully finished last run (ID:e2k30lm5). Initializing new run:
" ] }, "metadata": {} }, { "output_type": "display_data", "data": { "text/plain": [ "" ], "text/html": [ "Tracking run with wandb version 0.16.1" ] }, "metadata": {} }, { "output_type": "display_data", "data": { "text/plain": [ "" ], "text/html": [ "Run data is saved locally in /content/wandb/run-20231210_103402-ku689w1y" ] }, "metadata": {} }, { "output_type": "display_data", "data": { "text/plain": [ "" ], "text/html": [ "Syncing run feasible-meadow-29 to Weights & Biases (docs)
" ] }, "metadata": {} }, { "output_type": "display_data", "data": { "text/plain": [ "" ], "text/html": [ " View project at https://wandb.ai/team-zakia/trl" ] }, "metadata": {} }, { "output_type": "display_data", "data": { "text/plain": [ "" ], "text/html": [ " View run at https://wandb.ai/team-zakia/trl/runs/ku689w1y" ] }, "metadata": {} } ], "source": [ "# We then build the PPOTrainer, passing the model, the reference model, the tokenizer\n", "ppo_trainer = PPOTrainer(args.ppo_config, model, ref_model, tokenizer, dataset=dataset, data_collator=collator)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "uDij2You2cso", "colab": { "base_uri": "https://localhost:8080/", "height": 177, "referenced_widgets": [ "4e7ebc1b23034242906da99e5e8c555f", "91e5df2345d044ffab3b29cd9f9a4a01", "25163b69edb64fa4abdbdbc51a1e4b8b", "0f19699df5504e92903bac4e3f34b728", "50210578809040f69a411838e88642b0", "5a0b1274bd65407187c1c3f1bb1b17e7", "651d0502265b4f31a21e32bf5f6eb043", "94cd62e0c6394064b36dc514e7b0e711", "2d83e027a50a4e6d9d5f5a619f2efcb5", "958a20e7b5e74f5ea785673978fc8aef", "c7518aa06b214475905cae03703c0573", "595ca5de46674303a2bbf9a4ef21d102", "89d819b308b2471fa34afec09d5629ac", "13ee8507149f43a699605094d960db04", "c120e61807e746a6a125ed08945f3912", "91e84dae0ad643e88ce0ad3516eb4f9f", "619bb66b520a4f7594473fb979427a2b", "c00d612864434bbdb2622c637539ceec", "4f2598daf7e64afd9bb3da624dcc7348", "2bb689313c24450081e1cd7ebc2e7918", "4495c8d4edbd47c389abf1bdf8b29e76", "cba28340fc2a441899455d8a873e7e84", "ae95cbd521714941993a632ab542fd9a", "3efae665789e41d4af5ed34c31c74d90", "3622487cfbf246cca077dd251883e53a", "7ddef805691d43a6a4a2b35b1f5d883d", "e2217559e65a473ab9b71b03d7cbb297", "9fb8236da6f54afd9253011b4717cb72", "bae9451263a1407baf7e70c05f531de5", "b2d70d5230a64539bb60d0c2e7e01c2b", "1140492da05e4e63b094f8f29c0540dc", "9bf2ef98e1fc440e866513e35bd87125", "7180eb905ed440ec86c1e4b975add76c", "6376474b1c6948359b695138eaf11eaa", "2b4324bfbaae41719ca9c5677777efce", 
"4f947347f0514eaea9a5bcb5ec392a38", "a7903a646288497a916d33014a2a76c5", "d64795eec91e484f897582bcdd3cc0f2", "a24ac1b77217463b88b36fc5f2f9fdad", "9ff1606224064eacae17d86b54410601", "4297a80d97ec472a944489c72873c54d", "159fb4426ea541d5a3bd5482524d973b", "7d2e185527a44e9994190082f8c64490", "0e2fa3f7bb3f4f4f82c326dd5ecbf79c", "5b8926d0f99f43628eb948987c1db07e", "c4b810a689cd43f382110a0eec61b20a", "e54a320d0fc24f399c1c2c5cef22e89c", "6608ffa3f5854051b71713bb30577e00", "b02392ce4ae54f59953dfdb05b084850", "879977d9712b48b3b5aa33d62c034b2a", "566f643dfbde469eaa2085be1c350e42", "0f122f75e3e546c1b9902561f8f86235", "7fa037c9f14843cbb07875e5e1cd0b26", "2a979135ea6a4fceac69ad30f17f8676", "c1cb2f93dd184640ad4c108656f1ca2a" ] }, "outputId": "8c266d22-5c22-4212-ae2e-ff4c0d516aea" }, "outputs": [ { "output_type": "display_data", "data": { "text/plain": [ "config.json: 0%| | 0.00/781 [00:00\n", "\n", "

Figure 2: env/reward_mean plot showing the average reward per training step.

\n", "" ], "metadata": { "id": "OXSYzDTEfQ_c" } }, { "cell_type": "markdown", "source": [ "This plot shows the mean reward the model received at each training step. Here, a reward is a numerical score from the DistilBERT reward model indicating how closely a generated review matches what that model has learned to rate as high quality. The upward trend suggests that, as training progresses, the model increasingly produces outputs that the reward model scores favorably, indicating that RLHF is steering GPT-2 toward the kind of reviews the reward model prefers." ], "metadata": { "id": "FL3MN1NNftq0" } }, { "cell_type": "markdown", "source": [ "### Reward Distribution Heatmap Plot" ], "metadata": { "id": "bjAICf9XerQS" } }, { "cell_type": "markdown", "source": [ "
\n", "\n", "

Figure 3: env/reward_dist heatmap plot showing the distribution of rewards over training steps.

\n", "
" ], "metadata": { "id": "SqAlkuGegBR_" } }, { "cell_type": "markdown", "source": [ "This heatmap illustrates the distribution of rewards over training steps. Each vertical slice of the plot can be thought of as a snapshot of the reward landscape at a given step, with the color intensity representing the frequency of rewards at different levels. As the training continues, we expect to see the color bands shift upwards, which would indicate that the model is more frequently generating higher-quality responses." ], "metadata": { "id": "2RqTT-dRgZro" } }, { "cell_type": "markdown", "source": [ "## STEP 7: EVALUATE THE MODEL" ], "metadata": { "id": "1jJoW0FZPERB" } }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "VLjobu0AoWTN", "colab": { "base_uri": "https://localhost:8080/", "height": 766 }, "outputId": "35c5b8c3-3669-4e26-9c55-4338516bce63" }, "outputs": [ { "output_type": "stream", "name": "stderr", "text": [ "/usr/local/lib/python3.10/dist-packages/transformers/pipelines/text_classification.py:105: UserWarning: `return_all_scores` is now deprecated, if want a similar functionality use `top_k=None` instead of `return_all_scores=True` or `top_k=1` instead of `return_all_scores=False`.\n", " warnings.warn(\n", "/usr/local/lib/python3.10/dist-packages/transformers/pipelines/base.py:1101: UserWarning: You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a dataset\n", " warnings.warn(\n" ] }, { "output_type": "execute_result", "data": { "text/plain": [ " query \\\n", "0 Very Very good. Helps \n", "1 It worked for about \n", "2 Started on 20 \n", "3 I am a 43 \n", "4 I got on Pro \n", "5 This drug has changed me \n", "6 Good Med! Little skeptical \n", "7 I used to take cl \n", "8 Been on Prozac \n", "9 I've been on \n", "10 I've been taking this pill \n", "11 Was put on 30mg. \n", "12 I was addicted to \n", "13 I've been on approximately \n", "14 Super good for depression. 
\n", "15 Hyped me up like crazy \n", "16 I am a 32 \n", "17 My only complaint is that I \n", "18 I've been on \n", "19 Lexapro has been \n", "\n", " response (before) \\\n", "0 Very Very good. Helps to deal with some of the... \n", "1 a month. The nausea is gone and I no longer f... \n", "2 mg. It seems to help somewhat with some of t... \n", "3 year old woman with severe anxiety and depres... \n", "4 zac. I would sleep all night, feel sick one d... \n", "5 drastically. A year after taking an XL I bec... \n", "6 at first, because some things that seem so go... \n", "7 onazepam for 4 years and that caused much worse \n", "8 for 11 years now and still getting worse as t... \n", "9 Pristiq for 1 week now. The first 1 day was h... \n", "10 for 6 months and I feel 100% better after just \n", "11 Still depressed the first 3 months. Now I'm o... \n", "12 opiates for years when I was younger and I co... \n", "13 10 mg for six months and it's taking away som... \n", "14 Super good for depression. Diarrhea is better... \n", "15 , as did my MGs who said they thought this was... \n", "16 Y male and have been on up to 300 mg of zoloft... \n", "17 complaint is that I experience extreme foggy ... \n", "18 this 300mg for just a little over a year and ... \n", "19 better than anything I've tried to battle my ... \n", "\n", " response (after) rewards (before) \\\n", "0 me with extreme depression and anxiety. Can h... -1.911440 \n", "1 6 months and I feel so much better that I've ... 1.584692 \n", "2 mg for 4 days now...I feel great. I've been si... -3.230841 \n", "3 year old mother of two and a mother of two da... 0.613143 \n", "4 I got on Prozac for about two years. Prozac re... -3.422938 \n", "5 from an excited procrastinator! I feel like I... -3.109789 \n", "6 Good Med! Little skeptical & SEXy!<|endoftext|> -3.691891 \n", "7 onazepam and 5mg prozac at the same -1.095920 \n", "8 for 13 years. I have never felt this good in ... 
-1.745274 \n", "9 I've been on this for over three months, I fee... -0.377317 \n", "10 for almost 14 years and I took I F 3 times 1.914182 \n", "11 It made me feel much better! I could finally ... -3.148672 \n", "12 SSRI's, hallucinogens, and drug- just therapy... -3.175338 \n", "13 two dozen different SSRI's over the past thir... -3.352825 \n", "14 Super good for depression. I am 5'1 and I have... 1.200642 \n", "15 before and I'm a great mom now! I work as an ... -4.075667 \n", "16 yr old male and live in a very nice suburb (wi... 0.437383 \n", "17 feel fatigued during the day. And also depres... -4.083458 \n", "18 I've been on this anti-depressant for over 18 ... -0.072659 \n", "19 Lexapro has been AMAZING, so far. I've been on... -1.699961 \n", "\n", " rewards (after) \n", "0 2.428029 \n", "1 2.306569 \n", "2 2.124902 \n", "3 2.237940 \n", "4 2.157211 \n", "5 0.991710 \n", "6 -1.087502 \n", "7 1.416539 \n", "8 2.083820 \n", "9 2.041740 \n", "10 2.104877 \n", "11 2.363128 \n", "12 0.608868 \n", "13 2.426083 \n", "14 1.900504 \n", "15 1.796120 \n", "16 2.238822 \n", "17 -3.843148 \n", "18 0.724089 \n", "19 1.292072 " ], "text/html": [ "\n", "
\n" ] }, "metadata": {}, "execution_count": 27 } ], "source": [ "output_min_length = 10\n", "output_max_length = 50\n", "output_length_sampler = LengthSampler(output_min_length, output_max_length)\n", "\n", "#### Sample a batch of prompts from the dataset\n", "bs = 20\n", "game_data = {}\n", "dataset.set_format(\"pandas\")\n", "df_batch = dataset[:].sample(bs)\n", "game_data[\"query\"] = df_batch[\"query\"].tolist()\n", "query_tensors = df_batch[\"input_ids\"].tolist()\n", "\n", "response_tensors_ref, response_tensors = [], []\n", "\n", "#### Generate responses from the reference model (before) and the RLHF-tuned model (after)\n", "for i in range(bs):\n", "    gen_len = output_length_sampler()\n", "\n", "    # Copy generation_kwargs and set a per-sample max_new_tokens value\n", "    dynamic_generation_kwargs = generation_kwargs.copy()\n", "    dynamic_generation_kwargs['max_new_tokens'] = gen_len\n", "\n", "    query = torch.tensor(query_tensors[i]).unsqueeze(dim=0).to(device)\n", "    # Keep only the newly generated tokens. Slicing with [-gen_len:] would pull prompt\n", "    # tokens back into the response whenever generation stops early at EOS.\n", "    output = ref_model.generate(query, **dynamic_generation_kwargs).squeeze()[query.shape[-1]:]\n", "    response_tensors_ref.append(output)\n", "\n", "    output = model.generate(query, **dynamic_generation_kwargs).squeeze()[query.shape[-1]:]\n", "    response_tensors.append(output)\n", "\n", "#### Decode responses\n", "game_data[\"response (before)\"] = [tokenizer.decode(response_tensors_ref[i]) for i in range(bs)]\n", "game_data[\"response (after)\"] = [tokenizer.decode(response_tensors[i]) for i in range(bs)]\n", "\n", "#### Score query/response pairs with the reward model, before and after RLHF\n", "# Index 1 selects the score of the second label returned by the reward pipeline\n", "texts = [q + r for q, r in zip(game_data[\"query\"], game_data[\"response (before)\"])]\n", "game_data[\"rewards (before)\"] = [scores[1][\"score\"] for scores in sentiment_pipe(texts, **sent_kwargs)]\n", "\n", "texts = [q + r for q, r in zip(game_data[\"query\"], game_data[\"response (after)\"])]\n", "game_data[\"rewards (after)\"] = [scores[1][\"score\"] for scores in sentiment_pipe(texts, **sent_kwargs)]\n", "\n", "# Store results in a 
dataframe\n", "df_results = pd.DataFrame(game_data)\n", "df_results" ] }, { "cell_type": "markdown", "source": [ "Comparing the mean and median rewards of the generated responses shows a clear shift after RLHF on this sample of 20 prompts: the mean rises from about -1.62 to 1.42, and the median from about -1.83 to 2.06." ], "metadata": { "id": "UCZqr9bhPgQ8" } }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "0xiQcEO_tcT8", "colab": { "base_uri": "https://localhost:8080/", "height": 173 }, "outputId": "da58cc66-31a6-46e7-bddb-5a2ac4b0ca7a" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "mean:\n" ] }, { "output_type": "display_data", "data": { "text/plain": [ "rewards (before) -1.622197\n", "rewards (after) 1.415619\n", "dtype: float64" ] }, "metadata": {} }, { "output_type": "stream", "name": "stdout", "text": [ "\n", "median:\n" ] }, { "output_type": "display_data", "data": { "text/plain": [ "rewards (before) -1.828357\n", "rewards (after) 2.062780\n", "dtype: float64" ] }, "metadata": {} } ], "source": [ "print(\"mean:\")\n", "display(df_results[[\"rewards (before)\", \"rewards (after)\"]].mean())\n", "print()\n", "print(\"median:\")\n", "display(df_results[[\"rewards (before)\", \"rewards (after)\"]].median())" ] }, { "cell_type": "markdown", "source": [ "## STEP 8: SAVE THE FINE-TUNED GPT-2 MODEL: gpt2-drugscom_depression_reviews-hq-v1" ], "metadata": { "id": "77DlUMK_PlWv" } }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "1aKcZ0sWtlQq" }, "outputs": [], "source": [ "from huggingface_hub import notebook_login # Used to log in to a Hugging Face account so that models can be uploaded to the Hub."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "sEQVHh8ztnlC", "colab": { "base_uri": "https://localhost:8080/", "height": 484, "referenced_widgets": [ "06614bb33a674d7c9529dba7237e29f2", "effe58d07862430c9f14fb6a8f0d9339", "a71cb398750c4c5691b30243f7a80c80", "cbaa0c4d6c6849f98e4b631a9611227b", "bff1b054f9b64676816ebb84827d9d77", "755ee2ee027848c79dfbb836b87e4968", "26bb5803b0ce442d91cc36a96d412372", "d53284c8dcd14b9390aad1c8f2baa51b", "e3879a4fe4914aa3a3eb81977b8d85bf", "41d0df2831114687937cd8c601006b77", "0e6fcb61bf02408eb6bdfaa00bc35f4c", "ec3e0454cb814b1991711056b51bf3f1", "62b05b7dec9e40229836f0c7125963e4", "a3d025fee186451a83fb8c69d59efca2", "350521650a3d46ce91c80a74036e6522", "4ab9212ce7e945dfa714e212c3203c79", "1478c65fa2db433e85479e80ca0edf1a", "c4c301ccd6184a8393eaf9b7200b8fc0", "7b49a91fabcc4ccab7a53bf90d2ff3ee", "ef5ca4e85fd84d66b416440fc4e93a1a", "ee37eb8f5a69425c929dde0d945b8e1c", "cbbeb89e584246a2b79ea9bb3896b766", "97fbd74292ea4250aae33e90fc9c5f12", "b0afc5b4436348d1a12265337cebeb6d", "d73d45eafd6e45a8b9bcd9f69863820f", "5293daa1efc64a18a1d421ce3c922932", "d81a469ab0e341818e0ba77539c6e9ae", "0360660e0ca243318cc19d79aca7d469", "585b9c1d27b947c682228091f6dbf13f", "f4d1be3572a04a6bb8d29e0026ec13b5", "1b643386556b44e4a244407165336202", "2ede2fedf6474bea9ae143c670fd10fc" ] }, "outputId": "73dfc078-35c1-4caa-ae45-d0e41cda6885" }, "outputs": [ { "output_type": "display_data", "data": { "text/plain": [ "VBox(children=(HTML(value='
\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;31m# Login to Hugging Face within the notebook\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0mnotebook_login\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0mget_ipython\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msystem\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'git config --global credential.helper store'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/google/colab/_shell.py\u001b[0m in \u001b[0;36msystem\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 97\u001b[0m \u001b[0mkwargs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mupdate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m{\u001b[0m\u001b[0;34m'also_return_output'\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 98\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 99\u001b[0;31m \u001b[0moutput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_system_commands\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_system_compat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# pylint:disable=protected-access\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 100\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 101\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mpip_warn\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/google/colab/_system_commands.py\u001b[0m 
in \u001b[0;36m_system_compat\u001b[0;34m(shell, cmd, also_return_output)\u001b[0m\n\u001b[1;32m 452\u001b[0m \u001b[0;31m# is expected to call this function, thus adding one level of nesting to the\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 453\u001b[0m \u001b[0;31m# stack.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 454\u001b[0;31m result = _run_command(\n\u001b[0m\u001b[1;32m 455\u001b[0m \u001b[0mshell\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvar_expand\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcmd\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdepth\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mclear_streamed_output\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 456\u001b[0m )\n", "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/google/colab/_system_commands.py\u001b[0m in \u001b[0;36m_run_command\u001b[0;34m(cmd, clear_streamed_output)\u001b[0m\n\u001b[1;32m 166\u001b[0m \u001b[0mlocale_encoding\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlocale\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgetpreferredencoding\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 167\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mlocale_encoding\u001b[0m \u001b[0;34m!=\u001b[0m \u001b[0m_ENCODING\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 168\u001b[0;31m raise NotImplementedError(\n\u001b[0m\u001b[1;32m 169\u001b[0m \u001b[0;34m'A UTF-8 locale is required. Got {}'\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlocale_encoding\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 170\u001b[0m )\n", "\u001b[0;31mNotImplementedError\u001b[0m: A UTF-8 locale is required. 
Got ANSI_X3.4-1968" ] } ], "source": [ "# Log in to Hugging Face within the notebook\n", "notebook_login()\n", "\n", "# Workaround for the Colab error 'A UTF-8 locale is required' raised by shell commands\n", "import locale\n", "locale.getpreferredencoding = lambda *args, **kwargs: 'UTF-8'\n", "!git config --global credential.helper store" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "4BMQTQlJtqpj", "colab": { "base_uri": "https://localhost:8080/", "height": 153, "referenced_widgets": [ "1faadbaa1c784c73832ddde87d1e1ff6", "8c212475c37f4dcd9cd3dc343a17ea40", "dbd51cfbd3f245ccb8abe6d87389044e", "b103dcf2c0fd4b6dac63ad514f1f04f6", "dad9d0e3faa64c6098160e5fa56ed77f", "9a553633ac7445618dfea4ba6711ab73", "688092b1d169419aa3f2312a66ac728e", "f299507a827044d19f888e5f85a65f7f", "81827f9769ad47aca5b3fbe23e8d898b", "046b4f8ae7d241cb898d5629d4530f07", "7e74fd2b373144d9be3dd744d4314256" ] }, "outputId": "44dfdabf-a968-4a15-ba94-013669b0ca97" }, "outputs": [ { "output_type": "display_data", "data": { "text/plain": [ "model.safetensors: 0%| | 0.00/498M [00:00
Copy a token from your Hugging Face\ntokens page and paste it below.
Immediately click login after copying\nyour token or it might be stored in plain text in this notebook file.
" } }, "a71cb398750c4c5691b30243f7a80c80": { "model_module": "@jupyter-widgets/controls", "model_name": "PasswordModel", "model_module_version": "1.5.0", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "PasswordModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "PasswordView", "continuous_update": true, "description": "Token:", "description_tooltip": null, "disabled": false, "layout": "IPY_MODEL_41d0df2831114687937cd8c601006b77", "placeholder": "​", "style": "IPY_MODEL_0e6fcb61bf02408eb6bdfaa00bc35f4c", "value": "" } }, "cbaa0c4d6c6849f98e4b631a9611227b": { "model_module": "@jupyter-widgets/controls", "model_name": "CheckboxModel", "model_module_version": "1.5.0", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "CheckboxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "CheckboxView", "description": "Add token as git credential?", "description_tooltip": null, "disabled": false, "indent": true, "layout": "IPY_MODEL_ec3e0454cb814b1991711056b51bf3f1", "style": "IPY_MODEL_62b05b7dec9e40229836f0c7125963e4", "value": true } }, "bff1b054f9b64676816ebb84827d9d77": { "model_module": "@jupyter-widgets/controls", "model_name": "ButtonModel", "model_module_version": "1.5.0", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ButtonModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ButtonView", "button_style": "", "description": "Login", "disabled": false, "icon": "", "layout": "IPY_MODEL_a3d025fee186451a83fb8c69d59efca2", "style": "IPY_MODEL_350521650a3d46ce91c80a74036e6522", "tooltip": "" } }, "755ee2ee027848c79dfbb836b87e4968": { 
"model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "model_module_version": "1.5.0", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_4ab9212ce7e945dfa714e212c3203c79", "placeholder": "​", "style": "IPY_MODEL_1478c65fa2db433e85479e80ca0edf1a", "value": "\nPro Tip: If you don't already have one, you can create a dedicated\n'notebooks' token with 'write' access, that you can then easily reuse for all\nnotebooks. " } }, "26bb5803b0ce442d91cc36a96d412372": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "model_module_version": "1.2.0", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": "center", "align_self": null, "border": null, "bottom": null, "display": "flex", "flex": null, "flex_flow": "column", "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": "50%" } }, "d53284c8dcd14b9390aad1c8f2baa51b": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", 
"model_module_version": "1.2.0", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "e3879a4fe4914aa3a3eb81977b8d85bf": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "model_module_version": "1.5.0", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "41d0df2831114687937cd8c601006b77": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "model_module_version": "1.2.0", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, 
"grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "0e6fcb61bf02408eb6bdfaa00bc35f4c": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "model_module_version": "1.5.0", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "ec3e0454cb814b1991711056b51bf3f1": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "model_module_version": "1.2.0", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, 
"object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "62b05b7dec9e40229836f0c7125963e4": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "model_module_version": "1.5.0", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "a3d025fee186451a83fb8c69d59efca2": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "model_module_version": "1.2.0", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "350521650a3d46ce91c80a74036e6522": { "model_module": "@jupyter-widgets/controls", "model_name": "ButtonStyleModel", "model_module_version": "1.5.0", "state": { "_model_module": "@jupyter-widgets/controls", 
"_model_module_version": "1.5.0", "_model_name": "ButtonStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "button_color": null, "font_weight": "" } }, "4ab9212ce7e945dfa714e212c3203c79": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "model_module_version": "1.2.0", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "1478c65fa2db433e85479e80ca0edf1a": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "model_module_version": "1.5.0", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "c4c301ccd6184a8393eaf9b7200b8fc0": { "model_module": "@jupyter-widgets/controls", "model_name": "LabelModel", "model_module_version": "1.5.0", "state": { "_dom_classes": [], 
"_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "LabelModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "LabelView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_7b49a91fabcc4ccab7a53bf90d2ff3ee", "placeholder": "​", "style": "IPY_MODEL_ef5ca4e85fd84d66b416440fc4e93a1a", "value": "Connecting..." } }, "7b49a91fabcc4ccab7a53bf90d2ff3ee": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "model_module_version": "1.2.0", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "ef5ca4e85fd84d66b416440fc4e93a1a": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "model_module_version": "1.5.0", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", 
"description_width": "" } }, "ee37eb8f5a69425c929dde0d945b8e1c": { "model_module": "@jupyter-widgets/controls", "model_name": "LabelModel", "model_module_version": "1.5.0", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "LabelModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "LabelView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_d73d45eafd6e45a8b9bcd9f69863820f", "placeholder": "​", "style": "IPY_MODEL_5293daa1efc64a18a1d421ce3c922932", "value": "Token is valid (permission: write)." } }, "cbbeb89e584246a2b79ea9bb3896b766": { "model_module": "@jupyter-widgets/controls", "model_name": "LabelModel", "model_module_version": "1.5.0", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "LabelModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "LabelView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_d81a469ab0e341818e0ba77539c6e9ae", "placeholder": "​", "style": "IPY_MODEL_0360660e0ca243318cc19d79aca7d469", "value": "Your token has been saved in your configured git credential helpers (store)." 
} }, "97fbd74292ea4250aae33e90fc9c5f12": { "model_module": "@jupyter-widgets/controls", "model_name": "LabelModel", "model_module_version": "1.5.0", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "LabelModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "LabelView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_585b9c1d27b947c682228091f6dbf13f", "placeholder": "​", "style": "IPY_MODEL_f4d1be3572a04a6bb8d29e0026ec13b5", "value": "Your token has been saved to /root/.cache/huggingface/token" } }, "b0afc5b4436348d1a12265337cebeb6d": { "model_module": "@jupyter-widgets/controls", "model_name": "LabelModel", "model_module_version": "1.5.0", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "LabelModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "LabelView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_1b643386556b44e4a244407165336202", "placeholder": "​", "style": "IPY_MODEL_2ede2fedf6474bea9ae143c670fd10fc", "value": "Login successful" } }, "d73d45eafd6e45a8b9bcd9f69863820f": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "model_module_version": "1.2.0", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, 
"grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "5293daa1efc64a18a1d421ce3c922932": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "model_module_version": "1.5.0", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "d81a469ab0e341818e0ba77539c6e9ae": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "model_module_version": "1.2.0", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": 
null } }, "0360660e0ca243318cc19d79aca7d469": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "model_module_version": "1.5.0", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "585b9c1d27b947c682228091f6dbf13f": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "model_module_version": "1.2.0", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "f4d1be3572a04a6bb8d29e0026ec13b5": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "model_module_version": "1.5.0", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", 
"description_width": "" } }, "1b643386556b44e4a244407165336202": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "model_module_version": "1.2.0", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "2ede2fedf6474bea9ae143c670fd10fc": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "model_module_version": "1.5.0", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "1faadbaa1c784c73832ddde87d1e1ff6": { "model_module": "@jupyter-widgets/controls", "model_name": "HBoxModel", "model_module_version": "1.5.0", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", 
"_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_8c212475c37f4dcd9cd3dc343a17ea40", "IPY_MODEL_dbd51cfbd3f245ccb8abe6d87389044e", "IPY_MODEL_b103dcf2c0fd4b6dac63ad514f1f04f6" ], "layout": "IPY_MODEL_dad9d0e3faa64c6098160e5fa56ed77f" } }, "8c212475c37f4dcd9cd3dc343a17ea40": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "model_module_version": "1.5.0", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_9a553633ac7445618dfea4ba6711ab73", "placeholder": "​", "style": "IPY_MODEL_688092b1d169419aa3f2312a66ac728e", "value": "model.safetensors: 100%" } }, "dbd51cfbd3f245ccb8abe6d87389044e": { "model_module": "@jupyter-widgets/controls", "model_name": "FloatProgressModel", "model_module_version": "1.5.0", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "success", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_f299507a827044d19f888e5f85a65f7f", "max": 497777468, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_81827f9769ad47aca5b3fbe23e8d898b", "value": 497777468 } }, "b103dcf2c0fd4b6dac63ad514f1f04f6": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "model_module_version": "1.5.0", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": 
"", "description_tooltip": null, "layout": "IPY_MODEL_046b4f8ae7d241cb898d5629d4530f07", "placeholder": "​", "style": "IPY_MODEL_7e74fd2b373144d9be3dd744d4314256", "value": " 498M/498M [00:11<00:00, 55.6MB/s]" } }, "dad9d0e3faa64c6098160e5fa56ed77f": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "model_module_version": "1.2.0", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "9a553633ac7445618dfea4ba6711ab73": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "model_module_version": "1.2.0", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, 
"grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "688092b1d169419aa3f2312a66ac728e": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "model_module_version": "1.5.0", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "f299507a827044d19f888e5f85a65f7f": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "model_module_version": "1.2.0", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, 
"padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "81827f9769ad47aca5b3fbe23e8d898b": { "model_module": "@jupyter-widgets/controls", "model_name": "ProgressStyleModel", "model_module_version": "1.5.0", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "" } }, "046b4f8ae7d241cb898d5629d4530f07": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "model_module_version": "1.2.0", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "7e74fd2b373144d9be3dd744d4314256": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "model_module_version": "1.5.0", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": 
"@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } } } } }, "nbformat": 4, "nbformat_minor": 0 }