{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "id": "Y7Wdw0O6o2xa" }, "outputs": [], "source": [ "# FineTuning code " ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "6-t3HhxN8joX", "outputId": "e570e7a5-fbb1-423b-8d42-f922319fc83c" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com\n", "Requirement already satisfied: pip in /usr/local/lib/python3.10/dist-packages (24.3.1)\n", "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager, possibly rendering your system unusable.It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv. Use the --root-user-action option if you know what you are doing and want to suppress this warning.\u001b[0m\u001b[33m\n", "\u001b[0mLooking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com\n", "Requirement already satisfied: transformers==4.46.3 in /usr/local/lib/python3.10/dist-packages (4.46.3)\n", "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from transformers==4.46.3) (3.12.2)\n", "Requirement already satisfied: huggingface-hub<1.0,>=0.23.2 in /usr/local/lib/python3.10/dist-packages (from transformers==4.46.3) (0.26.2)\n", "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from transformers==4.46.3) (1.26.4)\n", "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from transformers==4.46.3) (24.2)\n", "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from transformers==4.46.3) (6.0.2)\n", "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers==4.46.3) (2023.6.3)\n", "Requirement already satisfied: 
requests in /usr/local/lib/python3.10/dist-packages (from transformers==4.46.3) (2.32.3)\n", "Requirement already satisfied: tokenizers<0.21,>=0.20 in /usr/local/lib/python3.10/dist-packages (from transformers==4.46.3) (0.20.3)\n", "Requirement already satisfied: safetensors>=0.4.1 in /usr/local/lib/python3.10/dist-packages (from transformers==4.46.3) (0.4.5)\n", "Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.10/dist-packages (from transformers==4.46.3) (4.67.1)\n", "Requirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.23.2->transformers==4.46.3) (2023.6.0)\n", "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.23.2->transformers==4.46.3) (4.12.2)\n", "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->transformers==4.46.3) (3.1.0)\n", "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->transformers==4.46.3) (3.4)\n", "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->transformers==4.46.3) (1.26.16)\n", "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->transformers==4.46.3) (2023.5.7)\n", "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager, possibly rendering your system unusable.It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv. 
Use the --root-user-action option if you know what you are doing and want to suppress this warning.\u001b[0m\u001b[33m\n", "\u001b[0mLooking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com\n", "Requirement already satisfied: bitsandbytes in /usr/local/lib/python3.10/dist-packages (0.45.0)\n", "Requirement already satisfied: torch in /usr/local/lib/python3.10/dist-packages (from bitsandbytes) (2.5.1)\n", "Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from bitsandbytes) (1.26.4)\n", "Requirement already satisfied: typing_extensions>=4.8.0 in /usr/local/lib/python3.10/dist-packages (from bitsandbytes) (4.12.2)\n", "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch->bitsandbytes) (3.12.2)\n", "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch->bitsandbytes) (3.4.2)\n", "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch->bitsandbytes) (3.1.2)\n", "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch->bitsandbytes) (2023.6.0)\n", "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.4.127 in /usr/local/lib/python3.10/dist-packages (from torch->bitsandbytes) (12.4.127)\n", "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.4.127 in /usr/local/lib/python3.10/dist-packages (from torch->bitsandbytes) (12.4.127)\n", "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.4.127 in /usr/local/lib/python3.10/dist-packages (from torch->bitsandbytes) (12.4.127)\n", "Requirement already satisfied: nvidia-cudnn-cu12==9.1.0.70 in /usr/local/lib/python3.10/dist-packages (from torch->bitsandbytes) (9.1.0.70)\n", "Requirement already satisfied: nvidia-cublas-cu12==12.4.5.8 in /usr/local/lib/python3.10/dist-packages (from torch->bitsandbytes) (12.4.5.8)\n", "Requirement already satisfied: nvidia-cufft-cu12==11.2.1.3 in 
/usr/local/lib/python3.10/dist-packages (from torch->bitsandbytes) (11.2.1.3)\n", "Requirement already satisfied: nvidia-curand-cu12==10.3.5.147 in /usr/local/lib/python3.10/dist-packages (from torch->bitsandbytes) (10.3.5.147)\n", "Requirement already satisfied: nvidia-cusolver-cu12==11.6.1.9 in /usr/local/lib/python3.10/dist-packages (from torch->bitsandbytes) (11.6.1.9)\n", "Requirement already satisfied: nvidia-cusparse-cu12==12.3.1.170 in /usr/local/lib/python3.10/dist-packages (from torch->bitsandbytes) (12.3.1.170)\n", "Requirement already satisfied: nvidia-nccl-cu12==2.21.5 in /usr/local/lib/python3.10/dist-packages (from torch->bitsandbytes) (2.21.5)\n", "Requirement already satisfied: nvidia-nvtx-cu12==12.4.127 in /usr/local/lib/python3.10/dist-packages (from torch->bitsandbytes) (12.4.127)\n", "Requirement already satisfied: nvidia-nvjitlink-cu12==12.4.127 in /usr/local/lib/python3.10/dist-packages (from torch->bitsandbytes) (12.4.127)\n", "Requirement already satisfied: triton==3.1.0 in /usr/local/lib/python3.10/dist-packages (from torch->bitsandbytes) (3.1.0)\n", "Requirement already satisfied: sympy==1.13.1 in /usr/local/lib/python3.10/dist-packages (from torch->bitsandbytes) (1.13.1)\n", "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.10/dist-packages (from sympy==1.13.1->torch->bitsandbytes) (1.3.0)\n", "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch->bitsandbytes) (2.1.3)\n", "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager, possibly rendering your system unusable.It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv. 
Use the --root-user-action option if you know what you are doing and want to suppress this warning.\u001b[0m\u001b[33m\n", "\u001b[0mLooking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com\n", "Requirement already satisfied: accelerate in /usr/local/lib/python3.10/dist-packages (1.2.1)\n", "Requirement already satisfied: numpy<3.0.0,>=1.17 in /usr/local/lib/python3.10/dist-packages (from accelerate) (1.26.4)\n", "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from accelerate) (24.2)\n", "Requirement already satisfied: psutil in /usr/local/lib/python3.10/dist-packages (from accelerate) (5.9.4)\n", "Requirement already satisfied: pyyaml in /usr/local/lib/python3.10/dist-packages (from accelerate) (6.0.2)\n", "Requirement already satisfied: torch>=1.10.0 in /usr/local/lib/python3.10/dist-packages (from accelerate) (2.5.1)\n", "Requirement already satisfied: huggingface-hub>=0.21.0 in /usr/local/lib/python3.10/dist-packages (from accelerate) (0.26.2)\n", "Requirement already satisfied: safetensors>=0.4.3 in /usr/local/lib/python3.10/dist-packages (from accelerate) (0.4.5)\n", "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.21.0->accelerate) (3.12.2)\n", "Requirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.21.0->accelerate) (2023.6.0)\n", "Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.21.0->accelerate) (2.32.3)\n", "Requirement already satisfied: tqdm>=4.42.1 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.21.0->accelerate) (4.67.1)\n", "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.21.0->accelerate) (4.12.2)\n", "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) 
(3.4.2)\n", "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (3.1.2)\n", "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.4.127 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (12.4.127)\n", "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.4.127 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (12.4.127)\n", "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.4.127 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (12.4.127)\n", "Requirement already satisfied: nvidia-cudnn-cu12==9.1.0.70 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (9.1.0.70)\n", "Requirement already satisfied: nvidia-cublas-cu12==12.4.5.8 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (12.4.5.8)\n", "Requirement already satisfied: nvidia-cufft-cu12==11.2.1.3 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (11.2.1.3)\n", "Requirement already satisfied: nvidia-curand-cu12==10.3.5.147 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (10.3.5.147)\n", "Requirement already satisfied: nvidia-cusolver-cu12==11.6.1.9 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (11.6.1.9)\n", "Requirement already satisfied: nvidia-cusparse-cu12==12.3.1.170 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (12.3.1.170)\n", "Requirement already satisfied: nvidia-nccl-cu12==2.21.5 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (2.21.5)\n", "Requirement already satisfied: nvidia-nvtx-cu12==12.4.127 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (12.4.127)\n", "Requirement already satisfied: nvidia-nvjitlink-cu12==12.4.127 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (12.4.127)\n", "Requirement 
already satisfied: triton==3.1.0 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (3.1.0)\n", "Requirement already satisfied: sympy==1.13.1 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (1.13.1)\n", "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.10/dist-packages (from sympy==1.13.1->torch>=1.10.0->accelerate) (1.3.0)\n", "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=1.10.0->accelerate) (2.1.3)\n", "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface-hub>=0.21.0->accelerate) (3.1.0)\n", "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface-hub>=0.21.0->accelerate) (3.4)\n", "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface-hub>=0.21.0->accelerate) (1.26.16)\n", "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface-hub>=0.21.0->accelerate) (2023.5.7)\n", "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager, possibly rendering your system unusable.It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv. 
Use the --root-user-action option if you know what you are doing and want to suppress this warning.\u001b[0m\u001b[33m\n", "\u001b[0mLooking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com\n", "Requirement already satisfied: datasets in /usr/local/lib/python3.10/dist-packages (3.2.0)\n", "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from datasets) (3.12.2)\n", "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from datasets) (1.26.4)\n", "Requirement already satisfied: pyarrow>=15.0.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (18.0.0)\n", "Requirement already satisfied: dill<0.3.9,>=0.3.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (0.3.8)\n", "Requirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from datasets) (1.5.2)\n", "Requirement already satisfied: requests>=2.32.2 in /usr/local/lib/python3.10/dist-packages (from datasets) (2.32.3)\n", "Requirement already satisfied: tqdm>=4.66.3 in /usr/local/lib/python3.10/dist-packages (from datasets) (4.67.1)\n", "Requirement already satisfied: xxhash in /usr/local/lib/python3.10/dist-packages (from datasets) (3.5.0)\n", "Requirement already satisfied: multiprocess<0.70.17 in /usr/local/lib/python3.10/dist-packages (from datasets) (0.70.16)\n", "Requirement already satisfied: fsspec<=2024.9.0,>=2023.1.0 in /usr/local/lib/python3.10/dist-packages (from fsspec[http]<=2024.9.0,>=2023.1.0->datasets) (2023.6.0)\n", "Requirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from datasets) (3.11.7)\n", "Requirement already satisfied: huggingface-hub>=0.23.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (0.26.2)\n", "Requirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from datasets) (24.2)\n", "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from datasets) (6.0.2)\n", 
"Requirement already satisfied: aiohappyeyeballs>=2.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (2.4.3)\n", "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.3.1)\n", "Requirement already satisfied: async-timeout<6.0,>=4.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (4.0.2)\n", "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (24.2.0)\n", "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.3.3)\n", "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (6.0.4)\n", "Requirement already satisfied: propcache>=0.2.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (0.2.0)\n", "Requirement already satisfied: yarl<2.0,>=1.17.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.18.0)\n", "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.23.0->datasets) (4.12.2)\n", "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32.2->datasets) (3.1.0)\n", "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32.2->datasets) (3.4)\n", "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32.2->datasets) (1.26.16)\n", "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32.2->datasets) (2023.5.7)\n", "Requirement already satisfied: python-dateutil>=2.8.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2.8.2)\n", "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from 
pandas->datasets) (2023.3)\n", "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.8.1->pandas->datasets) (1.16.0)\n", "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager, possibly rendering your system unusable.It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv. Use the --root-user-action option if you know what you are doing and want to suppress this warning.\u001b[0m\u001b[33m\n", "\u001b[0mLooking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com\n", "Requirement already satisfied: peft in /usr/local/lib/python3.10/dist-packages (0.14.0)\n", "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from peft) (1.26.4)\n", "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from peft) (24.2)\n", "Requirement already satisfied: psutil in /usr/local/lib/python3.10/dist-packages (from peft) (5.9.4)\n", "Requirement already satisfied: pyyaml in /usr/local/lib/python3.10/dist-packages (from peft) (6.0.2)\n", "Requirement already satisfied: torch>=1.13.0 in /usr/local/lib/python3.10/dist-packages (from peft) (2.5.1)\n", "Requirement already satisfied: transformers in /usr/local/lib/python3.10/dist-packages (from peft) (4.46.3)\n", "Requirement already satisfied: tqdm in /usr/local/lib/python3.10/dist-packages (from peft) (4.67.1)\n", "Requirement already satisfied: accelerate>=0.21.0 in /usr/local/lib/python3.10/dist-packages (from peft) (1.2.1)\n", "Requirement already satisfied: safetensors in /usr/local/lib/python3.10/dist-packages (from peft) (0.4.5)\n", "Requirement already satisfied: huggingface-hub>=0.25.0 in /usr/local/lib/python3.10/dist-packages (from peft) (0.26.2)\n", "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.25.0->peft) 
(3.12.2)\n", "Requirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.25.0->peft) (2023.6.0)\n", "Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.25.0->peft) (2.32.3)\n", "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.25.0->peft) (4.12.2)\n", "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (3.4.2)\n", "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (3.1.2)\n", "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.4.127 in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (12.4.127)\n", "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.4.127 in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (12.4.127)\n", "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.4.127 in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (12.4.127)\n", "Requirement already satisfied: nvidia-cudnn-cu12==9.1.0.70 in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (9.1.0.70)\n", "Requirement already satisfied: nvidia-cublas-cu12==12.4.5.8 in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (12.4.5.8)\n", "Requirement already satisfied: nvidia-cufft-cu12==11.2.1.3 in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (11.2.1.3)\n", "Requirement already satisfied: nvidia-curand-cu12==10.3.5.147 in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (10.3.5.147)\n", "Requirement already satisfied: nvidia-cusolver-cu12==11.6.1.9 in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (11.6.1.9)\n", "Requirement already satisfied: nvidia-cusparse-cu12==12.3.1.170 in /usr/local/lib/python3.10/dist-packages (from 
torch>=1.13.0->peft) (12.3.1.170)\n", "Requirement already satisfied: nvidia-nccl-cu12==2.21.5 in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (2.21.5)\n", "Requirement already satisfied: nvidia-nvtx-cu12==12.4.127 in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (12.4.127)\n", "Requirement already satisfied: nvidia-nvjitlink-cu12==12.4.127 in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (12.4.127)\n", "Requirement already satisfied: triton==3.1.0 in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (3.1.0)\n", "Requirement already satisfied: sympy==1.13.1 in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (1.13.1)\n", "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.10/dist-packages (from sympy==1.13.1->torch>=1.13.0->peft) (1.3.0)\n", "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers->peft) (2023.6.3)\n", "Requirement already satisfied: tokenizers<0.21,>=0.20 in /usr/local/lib/python3.10/dist-packages (from transformers->peft) (0.20.3)\n", "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=1.13.0->peft) (2.1.3)\n", "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface-hub>=0.25.0->peft) (3.1.0)\n", "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface-hub>=0.25.0->peft) (3.4)\n", "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface-hub>=0.25.0->peft) (1.26.16)\n", "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface-hub>=0.25.0->peft) (2023.5.7)\n", "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting 
behaviour with the system package manager, possibly rendering your system unusable.It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv. Use the --root-user-action option if you know what you are doing and want to suppress this warning.\u001b[0m\u001b[33m\n", "\u001b[0mLooking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com\n", "Requirement already satisfied: trl==0.12.2 in /usr/local/lib/python3.10/dist-packages (0.12.2)\n", "Requirement already satisfied: accelerate>=0.34.0 in /usr/local/lib/python3.10/dist-packages (from trl==0.12.2) (1.2.1)\n", "Requirement already satisfied: datasets>=2.21.0 in /usr/local/lib/python3.10/dist-packages (from trl==0.12.2) (3.2.0)\n", "Requirement already satisfied: rich in /usr/local/lib/python3.10/dist-packages (from trl==0.12.2) (13.9.4)\n", "Requirement already satisfied: transformers<4.47.0 in /usr/local/lib/python3.10/dist-packages (from trl==0.12.2) (4.46.3)\n", "Requirement already satisfied: numpy<3.0.0,>=1.17 in /usr/local/lib/python3.10/dist-packages (from accelerate>=0.34.0->trl==0.12.2) (1.26.4)\n", "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from accelerate>=0.34.0->trl==0.12.2) (24.2)\n", "Requirement already satisfied: psutil in /usr/local/lib/python3.10/dist-packages (from accelerate>=0.34.0->trl==0.12.2) (5.9.4)\n", "Requirement already satisfied: pyyaml in /usr/local/lib/python3.10/dist-packages (from accelerate>=0.34.0->trl==0.12.2) (6.0.2)\n", "Requirement already satisfied: torch>=1.10.0 in /usr/local/lib/python3.10/dist-packages (from accelerate>=0.34.0->trl==0.12.2) (2.5.1)\n", "Requirement already satisfied: huggingface-hub>=0.21.0 in /usr/local/lib/python3.10/dist-packages (from accelerate>=0.34.0->trl==0.12.2) (0.26.2)\n", "Requirement already satisfied: safetensors>=0.4.3 in /usr/local/lib/python3.10/dist-packages (from accelerate>=0.34.0->trl==0.12.2) (0.4.5)\n", "Requirement already satisfied: 
filelock in /usr/local/lib/python3.10/dist-packages (from datasets>=2.21.0->trl==0.12.2) (3.12.2)\n", "Requirement already satisfied: pyarrow>=15.0.0 in /usr/local/lib/python3.10/dist-packages (from datasets>=2.21.0->trl==0.12.2) (18.0.0)\n", "Requirement already satisfied: dill<0.3.9,>=0.3.0 in /usr/local/lib/python3.10/dist-packages (from datasets>=2.21.0->trl==0.12.2) (0.3.8)\n", "Requirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from datasets>=2.21.0->trl==0.12.2) (1.5.2)\n", "Requirement already satisfied: requests>=2.32.2 in /usr/local/lib/python3.10/dist-packages (from datasets>=2.21.0->trl==0.12.2) (2.32.3)\n", "Requirement already satisfied: tqdm>=4.66.3 in /usr/local/lib/python3.10/dist-packages (from datasets>=2.21.0->trl==0.12.2) (4.67.1)\n", "Requirement already satisfied: xxhash in /usr/local/lib/python3.10/dist-packages (from datasets>=2.21.0->trl==0.12.2) (3.5.0)\n", "Requirement already satisfied: multiprocess<0.70.17 in /usr/local/lib/python3.10/dist-packages (from datasets>=2.21.0->trl==0.12.2) (0.70.16)\n", "Requirement already satisfied: fsspec<=2024.9.0,>=2023.1.0 in /usr/local/lib/python3.10/dist-packages (from fsspec[http]<=2024.9.0,>=2023.1.0->datasets>=2.21.0->trl==0.12.2) (2023.6.0)\n", "Requirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from datasets>=2.21.0->trl==0.12.2) (3.11.7)\n", "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers<4.47.0->trl==0.12.2) (2023.6.3)\n", "Requirement already satisfied: tokenizers<0.21,>=0.20 in /usr/local/lib/python3.10/dist-packages (from transformers<4.47.0->trl==0.12.2) (0.20.3)\n", "Requirement already satisfied: markdown-it-py>=2.2.0 in /usr/local/lib/python3.10/dist-packages (from rich->trl==0.12.2) (3.0.0)\n", "Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /usr/local/lib/python3.10/dist-packages (from rich->trl==0.12.2) (2.15.1)\n", "Requirement already satisfied: 
typing-extensions<5.0,>=4.0.0 in /usr/local/lib/python3.10/dist-packages (from rich->trl==0.12.2) (4.12.2)\n", "Requirement already satisfied: aiohappyeyeballs>=2.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets>=2.21.0->trl==0.12.2) (2.4.3)\n", "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets>=2.21.0->trl==0.12.2) (1.3.1)\n", "Requirement already satisfied: async-timeout<6.0,>=4.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets>=2.21.0->trl==0.12.2) (4.0.2)\n", "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets>=2.21.0->trl==0.12.2) (24.2.0)\n", "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets>=2.21.0->trl==0.12.2) (1.3.3)\n", "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets>=2.21.0->trl==0.12.2) (6.0.4)\n", "Requirement already satisfied: propcache>=0.2.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets>=2.21.0->trl==0.12.2) (0.2.0)\n", "Requirement already satisfied: yarl<2.0,>=1.17.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets>=2.21.0->trl==0.12.2) (1.18.0)\n", "Requirement already satisfied: mdurl~=0.1 in /usr/local/lib/python3.10/dist-packages (from markdown-it-py>=2.2.0->rich->trl==0.12.2) (0.1.2)\n", "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32.2->datasets>=2.21.0->trl==0.12.2) (3.1.0)\n", "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32.2->datasets>=2.21.0->trl==0.12.2) (3.4)\n", "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32.2->datasets>=2.21.0->trl==0.12.2) (1.26.16)\n", "Requirement already satisfied: 
certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32.2->datasets>=2.21.0->trl==0.12.2) (2023.5.7)\n", "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate>=0.34.0->trl==0.12.2) (3.4.2)\n", "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate>=0.34.0->trl==0.12.2) (3.1.2)\n", "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.4.127 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate>=0.34.0->trl==0.12.2) (12.4.127)\n", "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.4.127 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate>=0.34.0->trl==0.12.2) (12.4.127)\n", "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.4.127 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate>=0.34.0->trl==0.12.2) (12.4.127)\n", "Requirement already satisfied: nvidia-cudnn-cu12==9.1.0.70 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate>=0.34.0->trl==0.12.2) (9.1.0.70)\n", "Requirement already satisfied: nvidia-cublas-cu12==12.4.5.8 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate>=0.34.0->trl==0.12.2) (12.4.5.8)\n", "Requirement already satisfied: nvidia-cufft-cu12==11.2.1.3 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate>=0.34.0->trl==0.12.2) (11.2.1.3)\n", "Requirement already satisfied: nvidia-curand-cu12==10.3.5.147 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate>=0.34.0->trl==0.12.2) (10.3.5.147)\n", "Requirement already satisfied: nvidia-cusolver-cu12==11.6.1.9 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate>=0.34.0->trl==0.12.2) (11.6.1.9)\n", "Requirement already satisfied: nvidia-cusparse-cu12==12.3.1.170 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate>=0.34.0->trl==0.12.2) 
(12.3.1.170)\n", "Requirement already satisfied: nvidia-nccl-cu12==2.21.5 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate>=0.34.0->trl==0.12.2) (2.21.5)\n", "Requirement already satisfied: nvidia-nvtx-cu12==12.4.127 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate>=0.34.0->trl==0.12.2) (12.4.127)\n", "Requirement already satisfied: nvidia-nvjitlink-cu12==12.4.127 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate>=0.34.0->trl==0.12.2) (12.4.127)\n", "Requirement already satisfied: triton==3.1.0 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate>=0.34.0->trl==0.12.2) (3.1.0)\n", "Requirement already satisfied: sympy==1.13.1 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate>=0.34.0->trl==0.12.2) (1.13.1)\n", "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.10/dist-packages (from sympy==1.13.1->torch>=1.10.0->accelerate>=0.34.0->trl==0.12.2) (1.3.0)\n", "Requirement already satisfied: python-dateutil>=2.8.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets>=2.21.0->trl==0.12.2) (2.8.2)\n", "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets>=2.21.0->trl==0.12.2) (2023.3)\n", "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.8.1->pandas->datasets>=2.21.0->trl==0.12.2) (1.16.0)\n", "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=1.10.0->accelerate>=0.34.0->trl==0.12.2) (2.1.3)\n", "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager, possibly rendering your system unusable.It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv. 
Use the --root-user-action option if you know what you are doing and want to suppress this warning.\u001b[0m\u001b[33m\n", "\u001b[0mLooking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com\n", "Requirement already satisfied: wandb in /usr/local/lib/python3.10/dist-packages (0.19.1)\n", "Requirement already satisfied: click!=8.0.0,>=7.1 in /usr/local/lib/python3.10/dist-packages (from wandb) (8.1.7)\n", "Requirement already satisfied: docker-pycreds>=0.4.0 in /usr/local/lib/python3.10/dist-packages (from wandb) (0.4.0)\n", "Requirement already satisfied: gitpython!=3.1.29,>=1.0.0 in /usr/local/lib/python3.10/dist-packages (from wandb) (3.1.43)\n", "Requirement already satisfied: platformdirs in /usr/local/lib/python3.10/dist-packages (from wandb) (3.8.0)\n", "Requirement already satisfied: protobuf!=4.21.0,!=5.28.0,<6,>=3.19.0 in /usr/local/lib/python3.10/dist-packages (from wandb) (3.20.3)\n", "Requirement already satisfied: psutil>=5.0.0 in /usr/local/lib/python3.10/dist-packages (from wandb) (5.9.4)\n", "Requirement already satisfied: pydantic<3,>=2.6 in /usr/local/lib/python3.10/dist-packages (from wandb) (2.9.2)\n", "Requirement already satisfied: pyyaml in /usr/local/lib/python3.10/dist-packages (from wandb) (6.0.2)\n", "Requirement already satisfied: requests<3,>=2.0.0 in /usr/local/lib/python3.10/dist-packages (from wandb) (2.32.3)\n", "Requirement already satisfied: sentry-sdk>=2.0.0 in /usr/local/lib/python3.10/dist-packages (from wandb) (2.19.0)\n", "Requirement already satisfied: setproctitle in /usr/local/lib/python3.10/dist-packages (from wandb) (1.3.4)\n", "Requirement already satisfied: setuptools in /usr/local/lib/python3.10/dist-packages (from wandb) (68.0.0)\n", "Requirement already satisfied: typing-extensions<5,>=4.4 in /usr/local/lib/python3.10/dist-packages (from wandb) (4.12.2)\n", "Requirement already satisfied: six>=1.4.0 in /usr/local/lib/python3.10/dist-packages (from docker-pycreds>=0.4.0->wandb) (1.16.0)\n", 
"Requirement already satisfied: gitdb<5,>=4.0.1 in /usr/local/lib/python3.10/dist-packages (from gitpython!=3.1.29,>=1.0.0->wandb) (4.0.11)\n", "Requirement already satisfied: annotated-types>=0.6.0 in /usr/local/lib/python3.10/dist-packages (from pydantic<3,>=2.6->wandb) (0.7.0)\n", "Requirement already satisfied: pydantic-core==2.23.4 in /usr/local/lib/python3.10/dist-packages (from pydantic<3,>=2.6->wandb) (2.23.4)\n", "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2.0.0->wandb) (3.1.0)\n", "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2.0.0->wandb) (3.4)\n", "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2.0.0->wandb) (1.26.16)\n", "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2.0.0->wandb) (2023.5.7)\n", "Requirement already satisfied: smmap<6,>=3.0.1 in /usr/local/lib/python3.10/dist-packages (from gitdb<5,>=4.0.1->gitpython!=3.1.29,>=1.0.0->wandb) (5.0.1)\n", "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager, possibly rendering your system unusable.It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv. 
Use the --root-user-action option if you know what you are doing and want to suppress this warning.\u001b[0m\u001b[33m\n", "\u001b[0mLooking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com\n", "Requirement already satisfied: ipywidgets in /usr/local/lib/python3.10/dist-packages (8.1.5)\n", "Requirement already satisfied: comm>=0.1.3 in /usr/local/lib/python3.10/dist-packages (from ipywidgets) (0.1.3)\n", "Requirement already satisfied: ipython>=6.1.0 in /usr/local/lib/python3.10/dist-packages (from ipywidgets) (8.14.0)\n", "Requirement already satisfied: traitlets>=4.3.1 in /usr/local/lib/python3.10/dist-packages (from ipywidgets) (5.9.0)\n", "Requirement already satisfied: widgetsnbextension~=4.0.12 in /usr/local/lib/python3.10/dist-packages (from ipywidgets) (4.0.13)\n", "Requirement already satisfied: jupyterlab-widgets~=3.0.12 in /usr/local/lib/python3.10/dist-packages (from ipywidgets) (3.0.13)\n", "Requirement already satisfied: backcall in /usr/local/lib/python3.10/dist-packages (from ipython>=6.1.0->ipywidgets) (0.2.0)\n", "Requirement already satisfied: decorator in /usr/local/lib/python3.10/dist-packages (from ipython>=6.1.0->ipywidgets) (5.1.1)\n", "Requirement already satisfied: jedi>=0.16 in /usr/local/lib/python3.10/dist-packages (from ipython>=6.1.0->ipywidgets) (0.18.2)\n", "Requirement already satisfied: matplotlib-inline in /usr/local/lib/python3.10/dist-packages (from ipython>=6.1.0->ipywidgets) (0.1.6)\n", "Requirement already satisfied: pickleshare in /usr/local/lib/python3.10/dist-packages (from ipython>=6.1.0->ipywidgets) (0.7.5)\n", "Requirement already satisfied: prompt-toolkit!=3.0.37,<3.1.0,>=3.0.30 in /usr/local/lib/python3.10/dist-packages (from ipython>=6.1.0->ipywidgets) (3.0.39)\n", "Requirement already satisfied: pygments>=2.4.0 in /usr/local/lib/python3.10/dist-packages (from ipython>=6.1.0->ipywidgets) (2.15.1)\n", "Requirement already satisfied: stack-data in /usr/local/lib/python3.10/dist-packages (from 
ipython>=6.1.0->ipywidgets) (0.6.2)\n", "Requirement already satisfied: pexpect>4.3 in /usr/local/lib/python3.10/dist-packages (from ipython>=6.1.0->ipywidgets) (4.8.0)\n", "Requirement already satisfied: parso<0.9.0,>=0.8.0 in /usr/local/lib/python3.10/dist-packages (from jedi>=0.16->ipython>=6.1.0->ipywidgets) (0.8.3)\n", "Requirement already satisfied: ptyprocess>=0.5 in /usr/local/lib/python3.10/dist-packages (from pexpect>4.3->ipython>=6.1.0->ipywidgets) (0.7.0)\n", "Requirement already satisfied: wcwidth in /usr/local/lib/python3.10/dist-packages (from prompt-toolkit!=3.0.37,<3.1.0,>=3.0.30->ipython>=6.1.0->ipywidgets) (0.2.6)\n", "Requirement already satisfied: executing>=1.2.0 in /usr/local/lib/python3.10/dist-packages (from stack-data->ipython>=6.1.0->ipywidgets) (1.2.0)\n", "Requirement already satisfied: asttokens>=2.1.0 in /usr/local/lib/python3.10/dist-packages (from stack-data->ipython>=6.1.0->ipywidgets) (2.2.1)\n", "Requirement already satisfied: pure-eval in /usr/local/lib/python3.10/dist-packages (from stack-data->ipython>=6.1.0->ipywidgets) (0.2.2)\n", "Requirement already satisfied: six in /usr/local/lib/python3.10/dist-packages (from asttokens>=2.1.0->stack-data->ipython>=6.1.0->ipywidgets) (1.16.0)\n", "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager, possibly rendering your system unusable.It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv. 
Use the --root-user-action option if you know what you are doing and want to suppress this warning.\u001b[0m\u001b[33m\n", "\u001b[0m" ] } ], "source": [ "# python 3.10.12\n", "!pip install -U pip\n", "!pip install transformers==4.46.3\n", "!pip install -U bitsandbytes\n", "!pip install -U accelerate\n", "!pip install -U datasets\n", "!pip install -U peft\n", "!pip install trl==0.12.2\n", "!pip install -U wandb\n", "!pip install ipywidgets --upgrade" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "id": "LbmtYWUH8p_J" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "4.46.3\n", "0.12.2\n" ] } ], "source": [ "from transformers import (\n", " AutoModelForCausalLM,\n", " AutoTokenizer,\n", " BitsAndBytesConfig,\n", " TrainingArguments,\n", " logging,\n", ")\n", "from peft import (\n", " LoraConfig,\n", " PeftModel,\n", " get_peft_model,\n", ")\n", "import os, torch, gc\n", "from datasets import load_dataset\n", "import bitsandbytes as bnb\n", "from trl import SFTTrainer\n", "import wandb\n", "\n", "# version check\n", "import transformers, trl\n", "print(transformers.__version__)\n", "print(trl.__version__)" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "id": "WAaS0RKXgG72" }, "outputs": [], "source": [ "# Hugging Face Token\n", "HF_TOKEN = \"Your Token\" #\"write権限のあるトークン\"\n", "\n", "# WANDB Token\n", "WB_TOKEN = \"Your Token\"" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "\u001b[34m\u001b[1mwandb\u001b[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.\n", "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33maoikazama\u001b[0m (\u001b[33mweblab-geniac-leaderboard\u001b[0m). 
Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n", "\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m If you're specifying your api key in code, ensure this code is not shared publicly.\n", "\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m Consider setting the WANDB_API_KEY environment variable, or running `wandb login` from the command line.\n", "\u001b[34m\u001b[1mwandb\u001b[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc\n" ] }, { "data": { "text/html": [ "Tracking run with wandb version 0.19.1" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Run data is saved locally in /workspace/wandb/run-20241216_173153-f39d2amo" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Syncing run competition to Weights & Biases (docs)
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ " View project at https://wandb.ai/weblab-geniac-leaderboard/llm-jp-3-13b-finetune" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ " View run at https://wandb.ai/weblab-geniac-leaderboard/llm-jp-3-13b-finetune/runs/f39d2amo" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "wandb.login(key=WB_TOKEN)\n", "\n", "wandb.init(\n", " # set the wandb project where this run will be logged\n", " project=\"llm-jp-3-13b-finetune\",\n", " name=\"competition\",\n", " entity=None\n", " # track hyperparameters and run metadata\n", "# config={\n", "# \"learning_rate\": 0.02,\n", "# \"architecture\": \"CNN\",\n", "# \"dataset\": \"CIFAR-100\",\n", "# \"epochs\": 10,\n", "# }\n", ")" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "id": "Oh-kvG8LQ2EZ" }, "outputs": [], "source": [ "# モデルを読み込み。\n", "base_model_id = \"models/models--llm-jp--llm-jp-3-13b/snapshots/cd3823f4c1fcbb0ad2e2af46036ab1b0ca13192a\" #Fine-Tuningするベースモデル\n", "# omnicampus以外の環境をご利用の方は以下をご利用ください。\n", "# base_model_id = \"llm-jp/llm-jp-3-13b\" \n", "new_model_id = \"llm-jp-3-13b-finetune-4bit\" #Fine-Tuningしたモデルにつけたい名前" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "id": "HXXd9RiiQqZP" }, "outputs": [], "source": [ "\"\"\"\n", "bnb_config: 量子化の設定\n", "\n", " - load_in_4bit:\n", " - 4bit量子化形式でモデルをロード\n", "\n", " - bnb_4bit_quant_type:\n", " - 量子化の形式を指定\n", "\n", " - bnb_4bit_compute_dtype:\n", " - 量子化された重みを用いて計算する際のデータ型\n", "\n", "\"\"\"\n", "\n", "bnb_config = BitsAndBytesConfig(\n", "# load_in_8bit=True,\n", " load_in_4bit=True,\n", " bnb_4bit_quant_type=\"nf4\", # nf4は通常のINT4より精度が高く、ニューラルネットワークの分布に最適です\n", " bnb_4bit_compute_dtype=torch.bfloat16,\n", ")" 
] }, { "cell_type": "code", "execution_count": 10, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 401, "referenced_widgets": [ "3fe703debbf1497caae9915117779e53", "fad1e4a67d534f7daff83ebd9c4ddcd5", "5c107055da3f4eab932c05a553b54138", "f7cb82147e3740a1b42765e2821b1406", "0df44d34108a4bc9b03e9e3d2d9c9f39", "74f02d0e9dfa4113a09fba63b91dc4d4", "14aa0af9f71f41a19d772be51ca497fc", "243e0ae2059348f6bd96b1e75d1ab10a", "a9272eb6817641ddab4215c093a188eb", "64c3605c035c4a05bd4f09fa8e57db27", "8e7ed64251944b04bacee1e4e5ed3f8f", "2cf93e5e79524a1389088ca61dece4a7", "6fb5a3e033894480a11ac1c84748f78a", "02728c0efbc1485d907e6e50e78f59e5", "905bfdabca504290bdc1650334055d01", "d191c300da464fc789584c511edfa0f4", "e88771cdaa364f3abc7c2ec435eb0f3b", "a8bf8e287d48469a88d62bd091d8c11b", "9f279e9b5fff467ca3b6e6cb118d6e50", "832dc49461d74069970fbc9e1222c6d6", "0667024824214c94b951a7dc70297d17", "d94325b5ec314b5a9f86e0f751751915", "dce159ea8d044df3a45a3e4f29a4ba72", "1f4f62895a934fb8915606f74d5a7ee5", "3990b59c86fe4b2aba553c75b1277236", "0dd3bc2741c5477584d36cf55418c6d3", "118acfc083b4451a94fd804bb072933c", "474ddb4f30184f6f85625887ca3d196e", "80bc28fd6b6a4fee981d9765c8b5cdc8", "803e25a0f6d54389b42af103e128c3dc", "3ff7a9bf9f7449e19b401cf013c6e8b1", "42cf0d0bf171456387e7673d749347b1", "1a57cd358d9340489db23102efb85ae5", "7f3e7d71089448fbbde8d2e97042a5c9", "f607dcaf1eb945749f47a6250ac045a6", "0c8c3448b89c4412999320b81d35da05", "f66cfefdaba24d2286c3ab4136a34de9", "af78632c384b4bd79c386d041e59538b", "6992a8e5871143bfb8d8780466fb873c", "2aea6652168049d6bdf7c3374d58e626", "3a6c152f32af4cc684615b5c8aff5e0a", "7f4032e451bd4ba3a6444586800c79b3", "6bb8b30b8bc043dc80081400766f6fd5", "65c1a99f4beb42f29fbd218aead6926e", "f93b0e8a93cb4bbfa23116ec2bcfaf9d", "3e6f50e9fee4422c9f9db9d154d06d7a", "9c0126041cbc46c3aa7a061f758c31fc", "16a58a0692244ca2919c4e59ac36414b", "6bcb1aa977c841eca32b5ad53cf5a831", "845dcfe9d54f411e94f86d93c0b01f31", "6cef8f1e30d9441798590fc14035ad7a", 
"8ad6ccc80a4143369f1db1b8c596d3e3", "738a1587d4224d4aadd7caf958aa1fa0", "aa1efa9e8b684f3c8b9dfd60fd1efa09", "241e4ea451ee4ae692bf56f6f7b6038c", "33f1f362a8ff4a54b24457e84577d3d1", "05c50d20f1a8443197f6cef646f53f2a", "db5ef37ca8ac4627b6904716b02f1886", "cac33376943145cd98117b048be98e1c", "2329b5e67a83484a83339043f15a3cbe", "c65e9a6ba744410ea9dde0483147333e", "00529ad1ecc743139470951bc9e3844b", "f89fd75f35d04155a160d7c8fd682caf", "a805ec68255d4eb7bbe86249c8c6df3a", "87da3e1776cf48e7931aba041a85658e", "8f383d44ec4f42579f5a95a7d3c42185", "ec05bd42d6ad497b8b6c80629e1510cf", "3d446fe4f0b141a0be8632914ee9fd43", "82572304ba064ea3bbdc68f3a4539b72", "0f6a46bdf781499686207e3c86492dd0", "b621e2c693bd473aace10ca918785ce9", "38d59da7d3454de1b36770e2cecfed86", "42626beb06554a91be23be268bc31d99", "62e68873093c45aa9bc25f4c4363e303", "90aa2c9da9194a66ba448a9ae45b4dd6", "1011a11ef9914fe19d6b1bc709f9fa6d", "7f6ed588fe4c4df392e483c16c1b1176", "d3af364de03f4802befe1075e1acc046", "1a97cefc463d4eb895a96dc18a8f3e56", "1fa736be8ecb454ea42b66d48e75835c", "01b4108a5b1e44528103f8174f534649", "d7a1f7044d6941d7a5f67bec1e218872", "f2313bcb866c47a496e2a4af3fea29c9", "817b06b936db4805acedfd9910a96cf9", "68b12c6beed44cc6ac4cb3d761f2d4a9", "e8eaf5c3b4b24689a169447e3046585c", "c4f3583bc2734209a98887b5cb351c19", "b4a42a693912466ea5bd97a43fd61dd4", "5e430b4726574824ba5f9edacac5ac90", "54f6a71dea464477b1b2638d01751397", "5e4c23a2b44d4251a73f945f8e19907d", "30c565e7bffd49a8b46c60875868e93c", "23a1aec80f87415e82b075ff8f4d2e9a", "7e6e242da8934733be06a683897448bc", "90ff6b7dfd5d47b1844a60fe7c43e85d", "c803aa096b5f40338e9d1252c243feb4", "08d80dc86d604066a8f0c03847915b4b", "0e3ca03f3a0244b8ab899ddbf2f7c5ac", "33927af459084290b3be6987a2ce10c9", "3dc260da4d4e49ea819be7714b945560", "af8a8cf2cf054553a0f83dc9a24fbd85", "3ea2f76d0a324ad2afcd9f3349aab629", "9948b19e079f4dc38f25d8ffd2c599cb", "d81f78eee45c45a8aea9090095d86fa5", "bcadff03c76b42909778a99d46158bfa", "07ae4654791d4ce69b6fba4f7e59591c", 
"6da6deb743104a89a528011146ad56f2", "1f31cca0ab024278999f9fc9dfb9197d", "504a61177ac945ac95eb1733d451926c", "808ac42289ef449aabeb0910f86f477b", "48a36332916b43829f92b02ab690bc05", "c112699bddc048bb823f1b98f0f11e88", "0d66a8d592b44637a88d3e0544112782", "41d20020ad724c9580a655f7d25495c8", "f1e09e3d03414cb0b7423ac482a7063f", "2640020b8fd94f53a7de055a66800f4f", "73aae4d1b2e44b8f88d07e0cc6b1311d", "8dd7ff468f474c8d8dd156f1cb50d59d", "30008d457d8440318c0912b450112e88", "a43b7e0343474e859c05a6104ca666cb", "e5f3be63c131463b94b44eb9ebf871bd", "fd5f0315bc7b408698047f061d8c1702", "2e7d1c67971c459a993ac305cb0e2c60", "f93ca9355a42406e941a9baa45868a5d", "d38a00663d534c36b9f18b0ea09e84f7", "08705d7203644a2e8c3d9045f2f8cc71", "b4d0e2955cd048c096769892c4149b6a", "60e22bf0e4f04f3fb570f2e1f1988136", "8fca9cc4ce594716bbfa0f146a29352b", "b83f5d8e63ca4c76844f18626f5ef3a9", "a5d6a0d0ca5946af8bfcc914729156f5", "e9a3f0df81424137ab5d7e0fd074469c" ] }, "id": "St-tJNuJQviq", "outputId": "42fc7dbb-376d-4e96-bfa5-5192fccf228e" }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "287e5b777b5d4b6b811770e1b6c9bfad", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Loading checkpoint shards: 0%| | 0/6 [00:00\n" ] } ], "source": [ "# データを確認\n", "print(dataset[\"train\"][\"formatted_text\"][3])" ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "id": "tp9vHUYtTvly" }, "outputs": [ { "data": { "text/plain": [ "DatasetDict({\n", " train: Dataset({\n", " features: ['ID', 'text', 'output', 'formatted_text'],\n", " num_rows: 6030\n", " })\n", " test: Dataset({\n", " features: ['ID', 'text', 'output', 'formatted_text'],\n", " num_rows: 671\n", " })\n", "})" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# データをtrainデータとtestデータに分割 (test_sizeの比率に)\n", "dataset = dataset[\"train\"].train_test_split(test_size=0.1)\n", "dataset" ] }, { "cell_type": "code", "execution_count": 19, "metadata": { "id": 
"6gJAYhfCacf7" }, "outputs": [], "source": [ "\"\"\"\n", "training_arguments: 学習の設定\n", "\n", " - output_dir:\n", " - トレーニング後のモデルを保存するディレクトリ\n", "\n", " - per_device_train_batch_size:\n", " - デバイスごとのトレーニングバッチサイズ\n", "\n", " - per_device_eval_batch_size:\n", " - デバイスごとの評価バッチサイズ\n", "\n", " - gradient_accumulation_steps:\n", " - 勾配を更新する前にステップを積み重ねる回数\n", "\n", " - optim:\n", " - オプティマイザの設定\n", "\n", " - num_train_epochs:\n", " - エポック数\n", "\n", " - eval_strategy:\n", " - 評価の戦略 (\"no\"/\"steps\"/\"epoch\")\n", "\n", " - eval_steps:\n", " - eval_strategyが\"steps\"のとき、評価を行うstep間隔\n", "\n", " - logging_strategy:\n", " - ログ記録の戦略\n", "\n", " - logging_steps:\n", " - ログを出力するステップ間隔\n", "\n", " - warmup_steps:\n", " - 学習率のウォームアップステップ数\n", "\n", " - save_steps:\n", " - モデルを保存するステップ間隔\n", "\n", " - save_total_limit:\n", " - 保存しておくcheckpointの数\n", "\n", " - max_steps:\n", " - トレーニングの最大ステップ数\n", "\n", " - learning_rate:\n", " - 学習率\n", "\n", " - fp16:\n", " - 16bit浮動小数点の使用設定(第8回演習を参考にすると良いです)\n", "\n", " - bf16:\n", " - BFloat16の使用設定\n", "\n", " - group_by_length:\n", " - 入力シーケンスの長さによりバッチをグループ化 (トレーニングの効率化)\n", "\n", " - report_to:\n", " - ログの送信先 (\"wandb\"/\"tensorboard\"など)\n", "\"\"\"\n", "\n", "training_arguments = TrainingArguments(\n", " output_dir=new_model_id,\n", " per_device_train_batch_size=1,\n", " gradient_accumulation_steps=2,\n", " optim=\"paged_adamw_32bit\",\n", " num_train_epochs=2,\n", " logging_strategy=\"steps\",\n", " logging_steps=10,\n", " warmup_steps=10,\n", " save_steps=100,\n", " save_total_limit = 2,\n", " max_steps = -1,\n", " learning_rate=5e-5,\n", " fp16=False,\n", " bf16=True,\n", " seed = 1001,\n", " group_by_length=True,\n", " report_to=\"wandb\"\n", ")" ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 716, "referenced_widgets": [ "78534a3b629e46e789d050ee831b36f0", "3a88aa5b8e63454296cd7d9c20a259e9", "fe1782bdb5184e038e35f15587e92a17", 
"4cc95ad598fc4b23896059eb93a6a20f", "0ae87864173144f28d794ab20299e8ed", "65d3374fee0743f3abf778eb335f264e", "44cbdf0fa68447e6bd1e76544fb0e325", "dcf4f14a7c534338832e563c4e4db4bc", "6087d4e855e6404fbae3fefb87c94f83", "1a1353e761034590ae75f6cc5e8eb7e0", "e99a962208eb4406902f4f1250c74df2", "50305082c58c4d9baad07b43feb5859a", "cda8cbcafb774162a296ad496321b132", "a93de36f379e477f90ddd537bc7029df", "3ec66c74ff4f477ca4b64f0b53a1ce84", "1a7869459eeb4fcdbc65a930e539a403", "528fc0fa43264c8998c53e7bd8e73c74", "1f68f4ad4b804530bab95e69bcf4300a", "ea667453bfff4a469a869666f1f3336f", "eb2207236a7f4901a53742b8383cd7e6", "93e39bbe8b9b4bc7b0233dbd6b72033c", "8fd295bbca4841adace1f06357c13b29" ] }, "id": "f3U8FUkwTx_K", "outputId": "69fcf52d-1c65-40b1-fc8d-f8c2e49e1f0c" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/usr/local/lib/python3.10/dist-packages/huggingface_hub/utils/_deprecation.py:100: FutureWarning: Deprecated argument(s) used in '__init__': max_seq_length, dataset_text_field. 
Will not be supported from version '0.13.0'.\n", "\n", "Deprecated positional argument(s) used in SFTTrainer, please use the SFTConfig to set these arguments instead.\n", " warnings.warn(message, FutureWarning)\n", "/usr/local/lib/python3.10/dist-packages/trl/trainer/sft_trainer.py:300: UserWarning: You passed a `max_seq_length` argument to the SFTTrainer, the value you passed will override the one in the `SFTConfig`.\n", " warnings.warn(\n", "/usr/local/lib/python3.10/dist-packages/trl/trainer/sft_trainer.py:328: UserWarning: You passed a `dataset_text_field` argument to the SFTTrainer, the value you passed will override the one in the `SFTConfig`.\n", " warnings.warn(\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "cbaa010b00784c8ea9301fc1bf06c11d", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Map: 0%| | 0/6030 [00:00\n", " \n", " \n", " [6030/6030 5:52:48, Epoch 2/2]\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", 
" \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", 
" \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", 
" \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", 
" \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
StepTraining Loss
101.961200
202.018200
301.944200
401.994000
501.829200
601.908400
701.927700
801.830200
901.793300
1001.872200
1101.858000
1201.873700
1301.848200
1401.818900
1501.610300
1601.748700
1701.695200
1801.766500
1901.800700
2001.797800
2101.852000
2202.021700
2301.873100
2401.845600
2501.700900
2601.787000
2701.900300
2801.801300
2901.834900
3001.769000
3101.752700
3201.699700
3301.835300
3401.904400
3501.711800
3601.820400
3701.872600
3801.791800
3901.688700
4001.648000
4101.834100
4201.704500
4301.828900
4401.772200
4501.869200
4601.853300
4701.828600
4801.782500
4901.797200
5001.554300
5101.968300
5201.819600
5301.748300
5401.672100
5501.667700
5601.740300
5701.844800
5801.700100
5901.923500
6001.602200
6101.721900
6201.845300
6301.755200
6401.871900
6501.516700
6601.660100
6701.825200
6801.887600
6901.952000
7001.821900
7101.651700
7201.784400
7301.908700
7401.583200
7501.779800
7601.776600
7701.781900
7801.735600
7901.779800
8001.892400
8101.694600
8201.948600
8301.804600
8401.712100
8501.777100
8601.892500
8701.789500
8801.629100
8901.756500
9001.671200
9101.809600
9201.814400
9301.882500
9401.875000
9501.691000
9601.694400
9701.674500
9801.839300
9901.777200
10001.662100
10101.658900
10201.728600
10301.684200
10401.658700
10501.718900
10601.796700
10701.897500
10801.756300
10901.719600
11001.867600
11101.719000
11201.688400
11301.793800
11401.791400
11501.608300
11601.682400
11701.729900
11801.799600
11901.737600
12001.648800
12101.713400
12201.892100
12301.803900
12401.773400
12501.651600
12601.670400
12701.872300
12801.692800
12901.653200
13001.794100
13101.700400
13201.837600
13301.678100
13401.712900
13501.707900
13601.700900
13701.789200
13801.674400
13901.677000
14001.688200
14101.734600
14201.768400
14301.723900
14401.847700
14501.719000
14601.767800
14701.854200
14801.641500
14901.792700
15001.653700
15101.706000
15201.676400
15301.557900
15401.736100
15501.557000
15601.728500
15701.866000
15801.702300
15901.796400
16001.582600
16101.704900
16201.813400
16301.785800
16401.701100
16501.734400
16601.861500
16701.881400
16801.767900
16901.538100
17001.550900
17101.729500
17201.705000
17301.766700
17401.614300
17501.735300
17601.793400
17701.802600
17801.799300
17901.595600
18001.536000
18101.719800
18201.718500
18301.700600
18401.666400
18501.595100
18601.668200
18701.876900
18801.751900
18901.771100
19001.654500
19101.691200
19201.967800
19301.680200
19401.729400
19501.673300
19601.740600
19701.796200
19801.667000
19901.565400
20001.627600
20101.837300
20201.796100
20301.731000
20401.629000
20501.561900
20601.812300
20701.585900
20801.748200
20901.617300
21001.739400
21101.884300
21201.806800
21301.765100
21401.830300
21501.570200
21601.744500
21701.829700
21801.654700
21901.659100
22001.545500
22101.729000
22201.717800
22301.723700
22401.688300
22501.451600
22601.767300
22701.793800
22801.764200
22901.877700
23001.528400
23101.831400
23201.842400
23301.755600
23401.823600
23501.649100
23601.788000
23701.777000
23801.661500
23901.841000
24001.497000
24101.840600
24201.729300
24301.674200
24401.667900
24501.635600
24601.774500
24701.701800
24801.687800
24901.624200
25001.714600
25101.619400
25201.788400
25301.591800
25401.589800
25501.684000
25601.713400
25701.838800
25801.673000
25901.821900
26001.729400
26101.844800
26201.715700
26301.612700
26401.731900
26501.522400
26601.666900
26701.767800
26801.556700
26901.753700
27001.343700
27101.813500
27201.776400
27301.669800
27401.658900
27501.709100
27601.753000
27701.799500
27801.723300
27901.672500
28001.729000
28101.884900
28201.759800
28301.601500
28401.621900
28501.640300
28601.851100
28701.697500
28801.733300
28901.825100
29001.670700
29101.594800
29201.740700
29301.679700
29401.707700
29501.494000
29601.802000
29701.703400
29801.650200
29901.715800
30001.746600
30101.636800
30201.581000
30301.838800
30401.473300
30501.555600
30601.371000
30701.365600
30801.564100
30901.515400
31001.591800
31101.248400
31201.434400
31301.647700
31401.592800
31501.428100
31601.308500
31701.501500
31801.441800
31901.543400
32001.538800
32101.310700
32201.401000
32301.642300
32401.674600
32501.439600
32601.314200
32701.433500
32801.639200
32901.617100
33001.594900
33101.227400
33201.401700
33301.691800
33401.648200
33501.446100
33601.328200
33701.400100
33801.674800
33901.439000
34001.471400
34101.264400
34201.451200
34301.695500
34401.476300
34501.334200
34601.146900
34701.206900
34801.646800
34901.527800
35001.325900
35101.383000
35201.422100
35301.675800
35401.593200
35501.451000
35601.318200
35701.317100
35801.646500
35901.425100
36001.492900
36101.374500
36201.366200
36301.645000
36401.498900
36501.419800
36601.365800
36701.386000
36801.630000
36901.610100
37001.444900
37101.358000
37201.309800
37301.676100
37401.500100
37501.356100
37601.173800
37701.240900
37801.617100
37901.408100
38001.497600
38101.469900
38201.338600
38301.414600
38401.403600
38501.474200
38601.174500
38701.353100
38801.700900
38901.702500
39001.335600
39101.339500
39201.539900
39301.705500
39401.523600
39501.427000
39601.283300
39701.233300
39801.701500
39901.599200
40001.455500
40101.127700
40201.341700
40301.620000
40401.666100
40501.263000
40601.485300
40701.515200
40801.675900
40901.496800
41001.472200
41101.456300
41201.518100
41301.567500
41401.452800
41501.441000
41601.321000
41701.391200
41801.677700
41901.518500
42001.336500
42101.259500
42201.260200
42301.538300
42401.489100
42501.293700
42601.300800
42701.350100
42801.590800
42901.500700
43001.474800
43101.297000
43201.349600
43301.582200
43401.664700
43501.290600
43601.387700
43701.459100
43801.558800
43901.661700
44001.438600
44101.240200
44201.332500
44301.549900
44401.349600
44501.423000
44601.392300
44701.295400
44801.644600
44901.553700
45001.286200
45101.198400
45201.241300
45301.563900
45401.535700
45501.346900
45601.211700
45701.248500
45801.656100
45901.600000
46001.442000
46101.294100
46201.346000
46301.664200
46401.375100
46501.421700
46601.260900
46701.318100
46801.634100
46901.265400
47001.379100
47101.228000
47201.357200
47301.578100
47401.523300
47501.502200
47601.377200
47701.517800
47801.591300
47901.620500
48001.480600
48101.332500
48201.476400
48301.628700
48401.477400
48501.462200
48601.340800
48701.370300
48801.426400
48901.446700
49001.270700
49101.275200
49201.393500
49301.549100
49401.565600
49501.437300
49601.251500
49701.345700
49801.677200
49901.319000
50001.401200
50101.184300
50201.305400
50301.586500
50401.441800
50501.479200
50601.326400
50701.168900
50801.750900
50901.554400
51001.435300
51101.300900
51201.417300
51301.726400
51401.492300
51501.444400
51601.381500
51701.382000
51801.429300
51901.564100
52001.242100
52101.415700
52201.309000
52301.633300
52401.510800
52501.481000
52601.179000
52701.190000
52801.707500
52901.450600
53001.524000
53101.306400
53201.310400
53301.541700
53401.616800
53501.553500
53601.359800
53701.323200
53801.609500
53901.508500
54001.353700
54101.342500
54201.263800
54301.632100
54401.400300
54501.490700
54601.329400
54701.431900
54801.657400
54901.496000
55001.313000
55101.232400
55201.261400
55301.515200
55401.505300
55501.301700
55601.243200
55701.408100
55801.575300
55901.388400
56001.420700
56101.066600
56201.307400
56301.646000
56401.505700
56501.271800
56601.226700
56701.280700
56801.501700
56901.557900
57001.223000
57101.182300
57201.256300
57301.639100
57401.533400
57501.487400
57601.212000
57701.373400
57801.542200
57901.547300
58001.379900
58101.119400
58201.227400
58301.566800
58401.474600
58501.313100
58601.253100
58701.355000
58801.580500
58901.460400
59001.326400
59101.366700
59201.366700
59301.508300
59401.497400
59501.304600
59601.263900
59701.214800
59801.570500
59901.511800
60001.379300
60101.313300
60201.183900
60301.064000

" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "TrainOutput(global_step=6030, training_loss=1.5882165481795125, metrics={'train_runtime': 21174.035, 'train_samples_per_second': 0.57, 'train_steps_per_second': 0.285, 'total_flos': 2.3816783854718976e+17, 'train_loss': 1.5882165481795125, 'epoch': 2.0})" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Supervised fine-tuning (SFT): configure the SFTTrainer and run training.\n", "# No eval_dataset is passed to the trainer in this cell.\n", "trainer = SFTTrainer(\n", "    # The loaded base model.\n", "    model=model,\n", "    # Tokenizer that corresponds to the model.\n", "    tokenizer=tokenizer,\n", "    # PEFT (Parameter-Efficient Fine-Tuning) settings; specified when LoRA is used.\n", "    peft_config=peft_config,\n", "    # Training hyperparameters (the TrainingArguments configuration).\n", "    args=training_arguments,\n", "    # Dataset split used for training.\n", "    train_dataset=dataset[\"train\"],\n", "    # Name of the dataset field that holds the text to learn from.\n", "    dataset_text_field=\"formatted_text\",\n", "    # Maximum token length of a sequence fed to the model.\n", "    max_seq_length=512,\n", "    # Packing disabled: each input is handled independently.\n", "    packing=False,\n", ")\n", "\n", "# Disable the cache feature.\n", "model.config.use_cache = False\n", "# Run training; this stays the last expression so the result is shown as the cell output.\n", "trainer.train()" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "data": { "text/html": [], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "

Run history:


train/epoch▁▁▁▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▄▅▅▅▅▅▅▅▆▆▆▆▆▆▇███
train/global_step▁▁▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▅▅▅▆▆▆▆▇▇▇▇▇██
train/grad_norm▂▁▁▂▁▃▁▃▂▂▂▂▃▂▄▂▂▂▂▂▂▂▃▃▇▄▅▃▄▃▃▃▇▅█▅▃▄▅▄
train/learning_rate███▇▇▇▇▇▇▇▇▇▆▆▆▆▆▆▅▅▅▅▅▅▅▄▄▄▄▄▃▃▃▃▃▂▂▂▂▁
train/loss██▇▇▆▅▆▆▇▆▅▆▆▆▇▃▃▂▅▅▂▃▃▅▅▁▂▃▄▂▂▁▁▆▂▂▂▄▄▃

Run summary:


total_flos2.3816783854718976e+17
train/epoch2
train/global_step6030
train/grad_norm5.02325
train/learning_rate0
train/loss1.064
train_loss1.58822
train_runtime21174.035
train_samples_per_second0.57
train_steps_per_second0.285

" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ " View run competition at: https://wandb.ai/weblab-geniac-leaderboard/llm-jp-3-13b-finetune/runs/f39d2amo
View project at: https://wandb.ai/weblab-geniac-leaderboard/llm-jp-3-13b-finetune
Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Find logs at: ./wandb/run-20241216_173153-f39d2amo/logs" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Finish the Weights & Biases run.\n", "wandb.finish()" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [], "source": [ "# Load the task data.\n", "# In the omnicampus development environment, drag and drop the task jsonl into the left pane before running.\n", "import json\n", "\n", "datasets = []\n", "# Read as UTF-8 explicitly: the outputs file is written as UTF-8 and the platform default encoding may differ.\n", "with open(\"./elyza-tasks-100-TV_0.jsonl\", \"r\", encoding=\"utf-8\") as f:\n", "    item = \"\"\n", "    for line in f:\n", "        # Accumulate stripped lines until the buffer ends with \"}\", then parse it as one JSON object.\n", "        item += line.strip()\n", "        if item.endswith(\"}\"):\n", "            datasets.append(json.loads(item))\n", "            item = \"\"" ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 100/100 [28:10<00:00, 16.91s/it]\n" ] } ], "source": [ "# Run the model on every task.\n", "from tqdm import tqdm\n", "\n", "results = []\n", "for data in tqdm(datasets):\n", "\n", "    # Named task_input so the builtin input() is not shadowed for the rest of the kernel session.\n", "    task_input = data[\"input\"]\n", "\n", "    # NOTE(review): whitespace inside the template is part of the prompt text;\n", "    # confirm it matches the format used to build the training text.\n", "    prompt = f\"\"\"### 指示\n", " {task_input}\n", " ### 回答\n", " \"\"\"\n", "\n", "    tokenized_input = tokenizer.encode(prompt, add_special_tokens=False, return_tensors=\"pt\").to(model.device)\n", "    # All-ones mask with the same shape as the encoded prompt.\n", "    attention_mask = torch.ones_like(tokenized_input)\n", "\n", "    with torch.no_grad():\n", "        outputs = model.generate(\n", "            tokenized_input,\n", "            attention_mask=attention_mask,\n", "            max_new_tokens=512,\n", "            do_sample=True,\n", "            repetition_penalty=1.2,\n", "            top_k=4,\n", "            num_beams=1,\n", "            pad_token_id=tokenizer.eos_token_id\n", "        )[0]\n", "    # Decode only the tokens that follow the prompt.\n", "    output = tokenizer.decode(outputs[tokenized_input.size(1):], skip_special_tokens=True)\n", "\n", "    results.append({\"task_id\": data[\"task_id\"], \"input\": task_input, \"output\": output})" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [], "source": [ "# Submit the jsonl file generated by this cell.\n", 
"# Each record holds task_id, input and output; input is optional.\n", "# Only task_id and output are required.\n", "import re\n", "\n", "# Drop everything up to and including the last \"/\" of new_model_id.\n", "jsonl_id = re.sub(\".*/\", \"\", new_model_id)\n", "with open(f\"./{jsonl_id}-outputs.jsonl\", 'w', encoding='utf-8') as f:\n", "    for result in results:\n", "        json.dump(result, f, ensure_ascii=False)  # ensure_ascii=False for handling non-ASCII characters\n", "        f.write('\\n')" ] }, { "cell_type": "code", "execution_count": 25, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 310, "referenced_widgets": [ "9fbdc638b16d42aa933b76ddd0aeb0d2", "8040ebc72b98422eb850ca74faa5010c", "8b524ed252fd48beb2ea120289255c4c", "9a1bfe71fe4447d7aa6dd5b0e2c76323", "bc2a31be20c448f482d90e14c4dd4217", "f6aca67c93084ccfa69e46ff3458c2c1", "25a574c44a4b40988e923cde73c7ab98", "dc8158ad183d4ed5999a1d3cbed374dc", "5ee3034b9229440d8f32caeab790c18c", "6ddd54fceed548c9b9952a0fae3235e2", "9f609b4e3e6940739da8e2c96cad556e", "ed486c3edc3a4956bbcff1e277616eb8", "138532a62d6d4871881aff833b4b7952", "f8ac748a9fce40a2b7f596fc71d7d562", "81935de0c7fe47ba821e74a7ca2e9d71", "f964e93ca65a4a97bf45e67508ecd046", "69c3365d22e04899a3de30545ddcbcf0", "cd983f1b0cc24d5ca3fb6e51dcb94581", "c6e94ed442d74945ae5a4cfdee57105a", "2d069ed4a5ed459e8953b68909d839c9", "2b9c44446bf941e9b005ae4b0d649c06", "74454e323b6e4430a87400f217ff1e92", "7a14519d06c9451891a4a912a657d48f", "0cbbd35186f34dbaba91dc380ce238c1", "3b6dac137ccb4c18a3507a5169942587", "77abad15b55e410cad7849c01b9fea57", "37d8c7779c9a44f4a81307d2f0de36f4", "fe8f8fa5feb446cda268299922eee583", "891caf0eafd64152aa598ce6ddc10d0b", "877e90cb92af40ba99a0eebc68b4c44d", "6f8ec884fb364cf8a75c4c3040071b26", "7655b84283d44186b6ab817ed4047680", "99feab4c24314d1dbd54d0afb1d66e40", "5088748592e149558900927fc5432333", "b1330a0fcccc4dbdaf67b4875a942b74", "8983f82330da4c489c0f376ae3cb6628", "7523b32b860d46c182cde9e2d31b20d8", "1bb23d81d42142128b2acbce907e92a4", "1474e70adb91432aa7ac671dd1136cc7", "4609b95fbe5341a1b64e746f459e0527", 
"b730a06833584f2bb5120496595cea8c", "5e1a33f2f3a741faba291158c64e375d", "d6816b614a2b46b3934f349e49cb4bf2", "13394cc0087146768536e655e09b957d", "a328ac71be8b4a6ea2bed286e511954b", "ec2433cabbee4461acb7d4b0912b4a85", "472675ee6ae549528940eb5d035498cd", "d402e1af963942958d870d0b40418ce9", "c751fc8cf55845d79cc34ca3887946a6", "41e53d2b2cdf43df826daf7d70afeaeb", "4e0859adc5c44a0eb2ca2ca535bf61a0", "c2b2e2153e4d46e08ca16493a6981cc4", "eb02a59ffb68430388c7e5f3b6b8b500", "c223a1fc63564a6da61712411402ff80", "95e76e766c7d4d3ea14d45acf5cfaa68", "f8513cccb925451a99df5a2dbdccfec5", "90dc91a7cc9144499e5986f226ecf820", "dde5aab1876b4d6e8ebbb164f7a4344a", "9826004c024743a8a387077651f527e1", "a3f7b3472b354ac7a08e85cbb4600c76", "7647bde4c9d544e7beb01614b648dbab", "c9d39178a137478bbc6309be489dad34", "cefed373abfc4eb08bd28b80e8926287", "7072597723e14c2e876942b4d40b0c31", "7f6bb847d0ba446586fd20f6d65ed849", "10555a1ffa0748daa432377971a547d6", "246fa3a7ee464046b646aa2c7e90e54b", "a797ac8cf5fb4a8ea7999ea129e2e0a5", "74207d399112401b89304ce251b11082", "8ec531982ce54767840b8a1a6fd4f3f2", "106cd9d0f87f4582b64b6e0531bf5cad", "ddbcb5a9304e4e3a855963d06966b5ec", "f62357d48e3b4ca88bb846bbe14dcb67", "673bccf1a944493bb221246fbc4505c9", "1ae2437d91334084bd2a9f709e770739", "59fe3fb9e27e4f13a7ecdbbf7b021bab", "0bfaaff4cf9a4bd5b713a522871d10df" ] }, "id": "zq4Ko1FWakX9", "outputId": "48922ec2-6a85-460b-8486-8627fef91972" }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "17106f90b1424d0b939c4533b98fb5bd", "version_major": 2, "version_minor": 0 }, "text/plain": [ "adapter_model.safetensors: 0%| | 0.00/501M [00:00