{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "private_outputs": true,
      "toc_visible": true,
      "include_colab_link": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "gpuClass": "standard",
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/zetavg/LLaMA-LoRA-Tuner/blob/main/LLaMA_LoRA.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# 🦙🎛️ LLaMA-LoRA Tuner\n",
        "\n",
        "TL;DR: **Runtime > Run All** (`⌘/Ctrl+F9`). Takes about 5 minutes to start. You will be prompted to authorize Google Drive access."
      ],
      "metadata": {
        "id": "bb4nzBvLfZUj"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "#@markdown To prevent Colab from disconnecting you, here is a music player that will loop infinitely (it's silent):\n",
        "%%html\n",
        "<!-- <audio> is not a void element, so it is closed with an explicit end tag. -->\n",
        "<audio src=\"https://github.com/anars/blank-audio/raw/master/1-hour-of-silence.mp3\" autoplay muted loop controls></audio>"
      ],
      "metadata": {
        "id": "DwarOgXbG77C",
        "cellView": "form"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# @title A small workaround { display-mode: \"form\" }\n",
        "# @markdown Don't panic if you see an error here. Just click the `RESTART RUNTIME` button in the output below, then Run All again.\n",
        "# @markdown The error will disappear on the next run.\n",
        "%pip install Pillow==9.3.0 numpy==1.23.5\n",
        "\n",
        "import pkg_resources as r\n",
        "import PIL\n",
        "import numpy\n",
        "\n",
        "# Each imported module paired with the lowest version this notebook accepts.\n",
        "minimum_versions = {PIL: \"9.3\", numpy: \"1.23\"}\n",
        "\n",
        "for module, min_version in minimum_versions.items():\n",
        "  lib_version = r.parse_version(module.__version__)\n",
        "  print(module.__name__, lib_version)\n",
        "  if lib_version >= r.parse_version(min_version):\n",
        "    continue\n",
        "  # The module currently loaded in this runtime is older than required.\n",
        "  raise Exception(\"Restart the runtime by clicking the 'RESTART RUNTIME' button above (or Runtime > Restart Runtime).\")"
      ],
      "metadata": {
        "id": "XcJ4WO3KhOX1"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Config\n",
        "\n",
        "Settings needed to run this notebook."
      ],
      "metadata": {
        "id": "5uS5jJ8063f_"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# @title Git/Project { display-mode: \"form\", run: \"auto\" }\n",
        "# @markdown Project settings.\n",
        "\n",
        "# @markdown The URL of the LLaMA-LoRA-Tuner project<br>&nbsp;&nbsp;(default: `https://github.com/zetavg/LLaMA-LoRA-Tuner.git`):\n",
        "llama_lora_project_url = \"https://github.com/zetavg/LLaMA-LoRA-Tuner.git\" # @param {type:\"string\"}\n",
        "# @markdown The branch to use for LLaMA-LoRA-Tuner project:\n",
        "llama_lora_project_branch = \"main\" # @param {type:\"string\"}\n",
        "\n",
        "# # @markdown Forces the local directory to be updated by the remote branch:\n",
        "# force_update = True # @param {type:\"boolean\"}\n"
      ],
      "metadata": {
        "id": "v3ZCPW0JBCcH"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# @title Google Drive { display-mode: \"form\", run: \"auto\" }\n",
        "# @markdown Google Drive will be used to store data that is used or output by this notebook. You will be prompted to authorize access while running this notebook.\n",
        "#\n",
        "# @markdown Currently, it's not possible to access only a specific folder of Google Drive, so we have no choice but to mount the entire Google Drive; however, only the given folder will be accessed.\n",
        "#\n",
        "# @markdown You can customize the location of the stored data here.\n",
        "\n",
        "# @markdown The folder in Google Drive where Colab Notebook data are stored<br />&nbsp;&nbsp;**(WARNING: The content of this folder will be modified by this notebook)**:\n",
        "google_drive_folder = \"Colab Data/LLaMA-LoRA Tuner\" # @param {type:\"string\"}\n",
        "# google_drive_colab_data_folder = \"Colab Notebooks/Notebook Data\"\n",
        "\n",
        "# Where Google Drive will be mounted in the Colab runtime.\n",
        "google_drive_mount_path = \"/content/drive\"\n",
        "\n",
        "from requests import get\n",
        "from socket import gethostname, gethostbyname\n",
        "\n",
        "# Look up this notebook's filename from the sessions API on port 9000 of the runtime host.\n",
        "# The name is only used to label the marker file created below, so a failed or hanging\n",
        "# lookup must not abort the cell: bound the request with a timeout and fall back to a fixed name.\n",
        "try:\n",
        "  host_ip = gethostbyname(gethostname())\n",
        "  colab_notebook_filename = get(f\"http://{host_ip}:9000/api/sessions\", timeout=10).json()[0][\"name\"]\n",
        "except Exception as e:\n",
        "  colab_notebook_filename = \"LLaMA_LoRA.ipynb\"\n",
        "  print(f\"Could not determine the notebook filename ({e!r}), using '{colab_notebook_filename}' instead.\")\n",
        "\n",
        "# def remove_ipynb_extension(filename: str) -> str:\n",
        "#     extension = \".ipynb\"\n",
        "#     if filename.endswith(extension):\n",
        "#         return filename[:-len(extension)]\n",
        "#     return filename\n",
        "\n",
        "# colab_notebook_name = remove_ipynb_extension(colab_notebook_filename)\n",
        "\n",
        "from google.colab import drive\n",
        "try:\n",
        "  drive.mount(google_drive_mount_path)\n",
        "\n",
        "  # Create the data folder on Drive and expose it to the rest of the notebook as ./data.\n",
        "  google_drive_data_directory_relative_path = google_drive_folder\n",
        "  google_drive_data_directory_path = f\"{google_drive_mount_path}/My Drive/{google_drive_data_directory_relative_path}\"\n",
        "  !mkdir -p \"{google_drive_data_directory_path}\"\n",
        "  !ln -nsf \"{google_drive_data_directory_path}\" ./data\n",
        "  !touch \"data/This folder is used by the Colab notebook \\\"{colab_notebook_filename}\\\".txt\"\n",
        "  !echo \"Data will be stored in Google Drive folder: \\\"{google_drive_data_directory_relative_path}\\\", which is mounted under \\\"{google_drive_data_directory_path}\\\"\"\n",
        "except Exception as e:\n",
        "  # Best effort: the notebook keeps going without Drive, but report why the mount failed.\n",
        "  print(\"Drive won't be mounted!\")\n",
        "  print(f\"Reason: {e!r}\")"
      ],
      "metadata": {
        "id": "iZmRtUY68U5f"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# @title Model/Training Settings { display-mode: \"form\", run: \"auto\" }\n",
        "\n",
        "base_model = \"decapoda-research/llama-7b-hf\" # @param {type:\"string\"}"
      ],
      "metadata": {
        "id": "Ep3Qhwj0Bzwf"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Runtime Info\n",
        "\n",
        "Print out some information about the Colab runtime. Code from https://colab.research.google.com/notebooks/pro.ipynb."
      ],
      "metadata": {
        "id": "Qg8tzrgb6ya_"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "AHbRt8sK6YWy"
      },
      "outputs": [],
      "source": [
        "# @title GPU Info { display-mode: \"form\" }\n",
        "#\n",
        "# @markdown By running this cell, you can see what GPU you've been assigned.\n",
        "#\n",
        "# @markdown If the execution result of running the code cell below is \"Not connected to a GPU\", you can change the runtime by going to `Runtime / Change runtime type` in the menu to enable a GPU accelerator, and then re-execute the code cell.\n",
        "\n",
        "\n",
        "# `!` captures the command output as a list of lines; join them into one string.\n",
        "gpu_info_lines = !nvidia-smi\n",
        "gpu_info = '\\n'.join(gpu_info_lines)\n",
        "\n",
        "# Treat both an nvidia-smi failure message ('failed') and a missing nvidia-smi\n",
        "# binary (the shell's 'not found' error) as having no GPU attached.\n",
        "if 'failed' in gpu_info or 'not found' in gpu_info:\n",
        "  print('Not connected to a GPU')\n",
        "else:\n",
        "  print(gpu_info)"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# @title RAM Info { display-mode: \"form\" }\n",
        "#\n",
        "# @markdown By running this code cell, you can see how much memory you have available.\n",
        "#\n",
        "# @markdown Normally, a high-RAM runtime is not needed, but if you need more RAM, you can enable a high-RAM runtime via `Runtime / Change runtime type` in the menu. Then select High-RAM in the Runtime shape dropdown. After, re-execute the code cell.\n",
        "\n",
        "from psutil import virtual_memory\n",
        "\n",
        "# Total memory of the runtime, converted from bytes to decimal gigabytes.\n",
        "ram_gb = virtual_memory().total / 1e9\n",
        "print(f'Your runtime has {ram_gb:.1f} gigabytes of available RAM\\n')\n",
        "\n",
        "# Anything below 20 GB is reported as a standard (not high-RAM) runtime.\n",
        "print('Not using a high-RAM runtime' if ram_gb < 20 else 'You are using a high-RAM runtime!')"
      ],
      "metadata": {
        "id": "rGM5MYjR7yeS"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Prepare the Project\n",
        "\n",
        "Clone the project and install dependencies (takes about 5m on the first run)."
      ],
      "metadata": {
        "id": "8vSPMNtNAqOo"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Clone the project on the first run only (skipped when ./llama_lora already exists).\n",
        "# The configurable branch and URL are quoted so unusual characters cannot split the shell command.\n",
        "![ ! -d llama_lora ] && git clone -b \"{llama_lora_project_branch}\" --filter=tree:0 \"{llama_lora_project_url}\" llama_lora\n",
        "# Stash any local changes, then hard-reset the working tree to the configured remote branch.\n",
        "!cd llama_lora && git add --all && git stash && git fetch origin \"{llama_lora_project_branch}\" && git checkout \"{llama_lora_project_branch}\" && git reset \"origin/{llama_lora_project_branch}\" --hard\n",
        "# Install the locked requirements once; the marker file records that this step has been done.\n",
        "![ ! -f llama-lora-requirements-installed ] && cd llama_lora && pip install -r requirements.lock.txt && touch ../llama-lora-requirements-installed"
      ],
      "metadata": {
        "id": "JGYz2VDoAzC8"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Launch"
      ],
      "metadata": {
        "id": "o90p1eYQimyr"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# The following command will launch the app in one shot, but we will not do this here.\n",
        "# Instead, we will import and run Python code from the runtime, so that each part\n",
        "# can be reloaded easily in the Colab notebook and provide readable outputs.\n",
        "# It also resolves the GPU out-of-memory issue on training.\n",
        "# !python llama_lora/app.py --base_model='{base_model}' --data_dir='./data' --share"
      ],
      "metadata": {
        "id": "HYVjcvwXimB6"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# @title Load the App (set config, prepare data dir, load base model)\n",
        "\n",
        "# @markdown For a LLaMA-7B model, it will take about 5m to load for the first execution,\n",
        "# @markdown including download. Subsequent executions will take about 2m to load.\n",
        "\n",
        "# Set Configs\n",
        "from llama_lora.llama_lora.config import Config, process_config\n",
        "from llama_lora.llama_lora.globals import initialize_global\n",
        "Config.default_base_model_name = base_model\n",
        "Config.base_model_choices = [base_model]\n",
        "data_dir_realpath = !realpath ./data\n",
        "Config.data_dir = data_dir_realpath[0]\n",
        "Config.load_8bit = True\n",
        "process_config()\n",
        "initialize_global()\n",
        "\n",
        "# Prepare Data Dir\n",
        "from llama_lora.llama_lora.utils.data import init_data_dir\n",
        "init_data_dir()\n",
        "\n",
        "# Load the Base Model\n",
        "from llama_lora.llama_lora.models import prepare_base_model\n",
        "prepare_base_model()"
      ],
      "metadata": {
        "id": "Yf6g248ylteP"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Start Gradio UI 🚀 (open the app from the URL shown in the output below)\n",
        "\n",
        "You will see `Running on public URL: https://...` in the output of the following code cell; click on the URL to open the Gradio UI."
      ],
      "metadata": {
        "id": "K-hCouBClKAe"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import gradio as gr\n",
        "from llama_lora.llama_lora.ui.main_page import main_page, get_page_title\n",
        "from llama_lora.llama_lora.ui.css_styles import get_css_styles\n",
        "\n",
        "# Create the Blocks container, then build the project's main page inside its context.\n",
        "app = gr.Blocks(title=get_page_title(), css=get_css_styles())\n",
        "with app:\n",
        "    main_page()\n",
        "\n",
        "# Enable the request queue with a concurrency of one, then launch with a public share link.\n",
        "app.queue(concurrency_count=1)\n",
        "app.launch(share=True, debug=True, server_name=\"127.0.0.1\")"
      ],
      "metadata": {
        "id": "iLygNTcHk0N8"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}