{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "AKjdG7tbTb-n"
   },
   "source": [
    "# Example notebook for running Axolotl on Google Colab"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "RcbNpOgWRcii"
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "# Check that a GPU is available; a T4 (free tier) is enough to run this notebook\n",
    "assert torch.cuda.is_available(), \"No GPU detected: in Colab, choose Runtime > Change runtime type and select a GPU.\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "h3nLav8oTRA5"
   },
   "source": [
    "## Install Axolotl and dependencies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "3c3yGAwnOIdi",
    "outputId": "e3777b5a-40ef-424f-e181-62dfecd1dd01"
   },
   "outputs": [],
   "source": [
    "!pip install torch==\"2.1.2\"\n",
    "!pip install -e git+https://github.com/OpenAccess-AI-Collective/axolotl#egg=axolotl\n",
    "!pip install flash-attn==\"2.5.0\"\n",
    "!pip install deepspeed==\"0.13.1\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "BW2MFr7HTjub"
   },
   "source": [
    "## Create a YAML config file"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "9pkF2dSoQEUN"
   },
   "outputs": [],
   "source": [
    "import yaml\n",
    "\n",
    "# Axolotl training config (QLoRA adapter on a TinyLlama base model), kept inline as a YAML string\n",
    "yaml_string = \"\"\"\n",
    "base_model: TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T\n",
    "model_type: LlamaForCausalLM\n",
    "tokenizer_type: LlamaTokenizer\n",
    "is_llama_derived_model: true\n",
    "\n",
    "load_in_8bit: false\n",
    "load_in_4bit: true\n",
    "strict: false\n",
    "\n",
    "datasets:\n",
    "  - path: mhenrichsen/alpaca_2k_test\n",
    "    type: alpaca\n",
    "dataset_prepared_path:\n",
    "val_set_size: 0.05\n",
    "output_dir: ./qlora-out\n",
    "\n",
    "adapter: qlora\n",
    "lora_model_dir:\n",
    "\n",
    "sequence_len: 1096\n",
    "sample_packing: true\n",
    "pad_to_sequence_len: true\n",
    "\n",
    "lora_r: 32\n",
    "lora_alpha: 16\n",
    "lora_dropout: 0.05\n",
    "lora_target_modules:\n",
    "lora_target_linear: true\n",
    "lora_fan_in_fan_out:\n",
    "\n",
    "wandb_project:\n",
    "wandb_entity:\n",
    "wandb_watch:\n",
    "wandb_name:\n",
    "wandb_log_model:\n",
    "\n",
    "mlflow_experiment_name: colab-example\n",
    "\n",
    "gradient_accumulation_steps: 1\n",
    "micro_batch_size: 1\n",
    "num_epochs: 4\n",
    "max_steps: 20\n",
    "optimizer: paged_adamw_32bit\n",
    "lr_scheduler: cosine\n",
    "learning_rate: 0.0002\n",
    "\n",
    "train_on_inputs: false\n",
    "group_by_length: false\n",
    "bf16: false\n",
    "fp16: true\n",
    "tf32: false\n",
    "\n",
    "gradient_checkpointing: true\n",
    "early_stopping_patience:\n",
    "resume_from_checkpoint:\n",
    "local_rank:\n",
    "logging_steps: 1\n",
    "xformers_attention:\n",
    "flash_attention: false\n",
    "\n",
    "warmup_steps: 10\n",
    "evals_per_epoch:\n",
    "saves_per_epoch:\n",
    "debug:\n",
    "deepspeed:\n",
    "weight_decay: 0.0\n",
    "fsdp:\n",
    "fsdp_config:\n",
    "special_tokens:\n",
    "\n",
    "\"\"\"\n",
    "\n",
    "# Parse the string first so a malformed config fails here rather than at launch time\n",
    "yaml_dict = yaml.safe_load(yaml_string)\n",
    "\n",
    "# Config file name, written relative to the current working directory\n",
    "file_path = 'test_axolotl.yaml'\n",
    "\n",
    "# Write the parsed config back out as a YAML file\n",
    "with open(file_path, 'w') as file:\n",
    "    yaml.dump(yaml_dict, file)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "bidoj8YLTusD"
   },
   "source": [
    "## Launch the training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "ydTI2Jk2RStU",
    "outputId": "d6d0df17-4b53-439c-c802-22c0456d301b"
   },
   "outputs": [],
   "source": [
    "# Prefixing a line with ! runs it as a shell command instead of Python\n",
    "!accelerate launch -m axolotl.cli.train /content/test_axolotl.yaml"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Play with inference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prefixing a line with ! runs it as a shell command instead of Python\n",
    "# --lora_model_dir points at the output_dir configured for the training run (./qlora-out)\n",
    "!accelerate launch -m axolotl.cli.inference /content/test_axolotl.yaml \\\n",
    "    --lora_model_dir=\"./qlora-out\" --gradio"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "gpuType": "T4",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "name": "python3"
  },
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}