{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[],"gpuType":"T4","authorship_tag":"ABX9TyPWZ8Jb4Kxe+LPy00eQFSll"},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"},"accelerator":"GPU"},"cells":[{"cell_type":"code","source":["# Install the exact versions recorded in this cell's output so a fresh runtime reproduces the run.\n","# trl is pinned because the SFTTrainer keyword arguments used below are reported as deprecated.\n","%pip install transformers==4.44.2 datasets==3.0.2 trl==0.11.4 bitsandbytes==0.44.1 peft==0.13.2 accelerate==1.0.1"],"metadata":{"collapsed":true,"colab":{"base_uri":"https://localhost:8080/"},"id":"-84y3O9audxh","executionInfo":{"status":"ok","timestamp":1730211323839,"user_tz":-540,"elapsed":24581,"user":{"displayName":"‍구원정[ 학부재학 / 수학과 ]","userId":"15682121601729926510"}},"outputId":"62401782-8e8b-4215-efdd-41331def3f6d"},"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["Requirement already satisfied: transformers in /usr/local/lib/python3.10/dist-packages (4.44.2)\n","Collecting datasets\n"," Downloading datasets-3.0.2-py3-none-any.whl.metadata (20 kB)\n","Collecting trl\n"," Downloading trl-0.11.4-py3-none-any.whl.metadata (12 kB)\n","Collecting bitsandbytes\n"," Downloading bitsandbytes-0.44.1-py3-none-manylinux_2_24_x86_64.whl.metadata (3.5 kB)\n","Collecting peft\n"," Downloading peft-0.13.2-py3-none-any.whl.metadata (13 kB)\n","Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from transformers) (3.16.1)\n","Requirement already satisfied: huggingface-hub<1.0,>=0.23.2 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.24.7)\n","Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (1.26.4)\n","Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from transformers) (24.1)\n","Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (6.0.2)\n","Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers) 
(2024.9.11)\n","Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from transformers) (2.32.3)\n","Requirement already satisfied: safetensors>=0.4.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.4.5)\n","Requirement already satisfied: tokenizers<0.20,>=0.19 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.19.1)\n","Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.10/dist-packages (from transformers) (4.66.5)\n","Requirement already satisfied: pyarrow>=15.0.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (16.1.0)\n","Collecting dill<0.3.9,>=0.3.0 (from datasets)\n"," Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)\n","Requirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from datasets) (2.2.2)\n","Collecting xxhash (from datasets)\n"," Downloading xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)\n","Collecting multiprocess<0.70.17 (from datasets)\n"," Downloading multiprocess-0.70.16-py310-none-any.whl.metadata (7.2 kB)\n","Requirement already satisfied: fsspec<=2024.9.0,>=2023.1.0 in /usr/local/lib/python3.10/dist-packages (from fsspec[http]<=2024.9.0,>=2023.1.0->datasets) (2024.6.1)\n","Requirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from datasets) (3.10.10)\n","Requirement already satisfied: torch>=1.4.0 in /usr/local/lib/python3.10/dist-packages (from trl) (2.5.0+cu121)\n","Requirement already satisfied: accelerate in /usr/local/lib/python3.10/dist-packages (from trl) (0.34.2)\n","Collecting tyro>=0.5.11 (from trl)\n"," Downloading tyro-0.8.14-py3-none-any.whl.metadata (8.4 kB)\n","Requirement already satisfied: psutil in /usr/local/lib/python3.10/dist-packages (from peft) (5.9.5)\n","Requirement already satisfied: aiohappyeyeballs>=2.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (2.4.3)\n","Requirement already satisfied: 
aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.3.1)\n","Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (24.2.0)\n","Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.4.1)\n","Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (6.1.0)\n","Requirement already satisfied: yarl<2.0,>=1.12.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.16.0)\n","Requirement already satisfied: async-timeout<5.0,>=4.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (4.0.3)\n","Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.23.2->transformers) (4.12.2)\n","Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (3.4.0)\n","Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (3.10)\n","Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2.2.3)\n","Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2024.8.30)\n","Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=1.4.0->trl) (3.4.2)\n","Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=1.4.0->trl) (3.1.4)\n","Requirement already satisfied: sympy==1.13.1 in /usr/local/lib/python3.10/dist-packages (from torch>=1.4.0->trl) (1.13.1)\n","Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.10/dist-packages (from sympy==1.13.1->torch>=1.4.0->trl) (1.3.0)\n","Requirement already satisfied: 
docstring-parser>=0.16 in /usr/local/lib/python3.10/dist-packages (from tyro>=0.5.11->trl) (0.16)\n","Requirement already satisfied: rich>=11.1.0 in /usr/local/lib/python3.10/dist-packages (from tyro>=0.5.11->trl) (13.9.3)\n","Collecting shtab>=1.5.6 (from tyro>=0.5.11->trl)\n"," Downloading shtab-1.7.1-py3-none-any.whl.metadata (7.3 kB)\n","Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2.8.2)\n","Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2024.2)\n","Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2024.2)\n","Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.8.2->pandas->datasets) (1.16.0)\n","Requirement already satisfied: markdown-it-py>=2.2.0 in /usr/local/lib/python3.10/dist-packages (from rich>=11.1.0->tyro>=0.5.11->trl) (3.0.0)\n","Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /usr/local/lib/python3.10/dist-packages (from rich>=11.1.0->tyro>=0.5.11->trl) (2.18.0)\n","Requirement already satisfied: propcache>=0.2.0 in /usr/local/lib/python3.10/dist-packages (from yarl<2.0,>=1.12.0->aiohttp->datasets) (0.2.0)\n","Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=1.4.0->trl) (3.0.2)\n","Requirement already satisfied: mdurl~=0.1 in /usr/local/lib/python3.10/dist-packages (from markdown-it-py>=2.2.0->rich>=11.1.0->tyro>=0.5.11->trl) (0.1.2)\n","Downloading datasets-3.0.2-py3-none-any.whl (472 kB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m472.7/472.7 kB\u001b[0m \u001b[31m11.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hDownloading trl-0.11.4-py3-none-any.whl (316 kB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m316.6/316.6 
kB\u001b[0m \u001b[31m16.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hDownloading bitsandbytes-0.44.1-py3-none-manylinux_2_24_x86_64.whl (122.4 MB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m122.4/122.4 MB\u001b[0m \u001b[31m6.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hDownloading peft-0.13.2-py3-none-any.whl (320 kB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m320.7/320.7 kB\u001b[0m \u001b[31m13.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m116.3/116.3 kB\u001b[0m \u001b[31m8.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hDownloading multiprocess-0.70.16-py310-none-any.whl (134 kB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m134.8/134.8 kB\u001b[0m \u001b[31m12.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hDownloading tyro-0.8.14-py3-none-any.whl (109 kB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m109.8/109.8 kB\u001b[0m \u001b[31m10.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hDownloading xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (194 kB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m194.1/194.1 kB\u001b[0m \u001b[31m11.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hDownloading shtab-1.7.1-py3-none-any.whl (14 kB)\n","Installing collected packages: xxhash, shtab, dill, multiprocess, tyro, bitsandbytes, peft, datasets, trl\n","Successfully installed bitsandbytes-0.44.1 datasets-3.0.2 dill-0.3.8 multiprocess-0.70.16 peft-0.13.2 shtab-1.7.1 trl-0.11.4 tyro-0.8.14 xxhash-3.5.0\n","Requirement already satisfied: datasets in /usr/local/lib/python3.10/dist-packages (3.0.2)\n","Requirement already 
satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from datasets) (3.16.1)\n","Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from datasets) (1.26.4)\n","Requirement already satisfied: pyarrow>=15.0.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (16.1.0)\n","Requirement already satisfied: dill<0.3.9,>=0.3.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (0.3.8)\n","Requirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from datasets) (2.2.2)\n","Requirement already satisfied: requests>=2.32.2 in /usr/local/lib/python3.10/dist-packages (from datasets) (2.32.3)\n","Requirement already satisfied: tqdm>=4.66.3 in /usr/local/lib/python3.10/dist-packages (from datasets) (4.66.5)\n","Requirement already satisfied: xxhash in /usr/local/lib/python3.10/dist-packages (from datasets) (3.5.0)\n","Requirement already satisfied: multiprocess<0.70.17 in /usr/local/lib/python3.10/dist-packages (from datasets) (0.70.16)\n","Requirement already satisfied: fsspec<=2024.9.0,>=2023.1.0 in /usr/local/lib/python3.10/dist-packages (from fsspec[http]<=2024.9.0,>=2023.1.0->datasets) (2024.6.1)\n","Requirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from datasets) (3.10.10)\n","Requirement already satisfied: huggingface-hub>=0.23.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (0.24.7)\n","Requirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from datasets) (24.1)\n","Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from datasets) (6.0.2)\n","Requirement already satisfied: aiohappyeyeballs>=2.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (2.4.3)\n","Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.3.1)\n","Requirement already satisfied: attrs>=17.3.0 in 
/usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (24.2.0)\n","Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.4.1)\n","Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (6.1.0)\n","Requirement already satisfied: yarl<2.0,>=1.12.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.16.0)\n","Requirement already satisfied: async-timeout<5.0,>=4.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (4.0.3)\n","Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.23.0->datasets) (4.12.2)\n","Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32.2->datasets) (3.4.0)\n","Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32.2->datasets) (3.10)\n","Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32.2->datasets) (2.2.3)\n","Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32.2->datasets) (2024.8.30)\n","Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2.8.2)\n","Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2024.2)\n","Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2024.2)\n","Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.8.2->pandas->datasets) (1.16.0)\n","Requirement already satisfied: propcache>=0.2.0 in /usr/local/lib/python3.10/dist-packages (from yarl<2.0,>=1.12.0->aiohttp->datasets) 
(0.2.0)\n","Requirement already satisfied: accelerate in /usr/local/lib/python3.10/dist-packages (0.34.2)\n","Collecting accelerate\n"," Downloading accelerate-1.0.1-py3-none-any.whl.metadata (19 kB)\n","Requirement already satisfied: numpy<3.0.0,>=1.17 in /usr/local/lib/python3.10/dist-packages (from accelerate) (1.26.4)\n","Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from accelerate) (24.1)\n","Requirement already satisfied: psutil in /usr/local/lib/python3.10/dist-packages (from accelerate) (5.9.5)\n","Requirement already satisfied: pyyaml in /usr/local/lib/python3.10/dist-packages (from accelerate) (6.0.2)\n","Requirement already satisfied: torch>=1.10.0 in /usr/local/lib/python3.10/dist-packages (from accelerate) (2.5.0+cu121)\n","Requirement already satisfied: huggingface-hub>=0.21.0 in /usr/local/lib/python3.10/dist-packages (from accelerate) (0.24.7)\n","Requirement already satisfied: safetensors>=0.4.3 in /usr/local/lib/python3.10/dist-packages (from accelerate) (0.4.5)\n","Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.21.0->accelerate) (3.16.1)\n","Requirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.21.0->accelerate) (2024.6.1)\n","Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.21.0->accelerate) (2.32.3)\n","Requirement already satisfied: tqdm>=4.42.1 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.21.0->accelerate) (4.66.5)\n","Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.21.0->accelerate) (4.12.2)\n","Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (3.4.2)\n","Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from 
torch>=1.10.0->accelerate) (3.1.4)\n","Requirement already satisfied: sympy==1.13.1 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (1.13.1)\n","Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.10/dist-packages (from sympy==1.13.1->torch>=1.10.0->accelerate) (1.3.0)\n","Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=1.10.0->accelerate) (3.0.2)\n","Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface-hub>=0.21.0->accelerate) (3.4.0)\n","Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface-hub>=0.21.0->accelerate) (3.10)\n","Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface-hub>=0.21.0->accelerate) (2.2.3)\n","Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface-hub>=0.21.0->accelerate) (2024.8.30)\n","Downloading accelerate-1.0.1-py3-none-any.whl (330 kB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m330.9/330.9 kB\u001b[0m \u001b[31m11.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hInstalling collected packages: accelerate\n"," Attempting uninstall: accelerate\n"," Found existing installation: accelerate 0.34.2\n"," Uninstalling accelerate-0.34.2:\n"," Successfully uninstalled accelerate-0.34.2\n","Successfully installed accelerate-1.0.1\n"]}]},{"cell_type":"code","execution_count":null,"metadata":{"collapsed":true,"id":"dCTa8Ekcs2ZB"},"outputs":[],"source":["import os\n","import torch\n","import torchvision\n","torchvision.disable_beta_transforms_warning()\n","from datasets import load_dataset, concatenate_datasets\n","from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, BitsAndBytesConfig\n","from 
trl import SFTTrainer\n","from peft import LoraConfig, get_peft_model"]},{"cell_type":"markdown","source":["# W&B 비활성화"],"metadata":{"id":"3QxeXUHqyZSL"}},{"cell_type":"code","source":["import os\n","\n","# Turn off Weights & Biases logging for this run.\n","os.environ[\"WANDB_MODE\"] = \"disabled\"\n","\n","# Set the allocator option before the first torch.cuda call in this notebook.\n","# NOTE(review): PyTorch presumably reads this when the CUDA allocator is first initialised - confirm\n","# for the installed torch version.\n","os.environ.setdefault(\"PYTORCH_CUDA_ALLOC_CONF\", \"expandable_segments:True\")"],"metadata":{"id":"G5KNSMgQyZJs"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# Release cached blocks and print the allocator report. print() renders the table,\n","# whereas a bare expression echoes the escaped string repr.\n","if torch.cuda.is_available():\n","    torch.cuda.empty_cache()\n","    print(torch.cuda.memory_summary(device=None, abbreviated=False))\n","else:\n","    print(\"CUDA is not available; skipping the memory summary.\")\n"],"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":227},"collapsed":true,"id":"dvrZxfhRoSX_","executionInfo":{"status":"ok","timestamp":1730211361971,"user_tz":-540,"elapsed":375,"user":{"displayName":"‍구원정[ 학부재학 / 수학과 ]","userId":"15682121601729926510"}},"outputId":"27bcfaf8-f420-451f-824c-d6e329218838"},"execution_count":null,"outputs":[{"output_type":"execute_result","data":{"text/plain":["'|===========================================================================|\\n| PyTorch CUDA memory summary, device ID 0 |\\n|---------------------------------------------------------------------------|\\n| CUDA OOMs: 0 | cudaMalloc retries: 0 |\\n|===========================================================================|\\n| Metric | Cur Usage | Peak Usage | Tot Alloc | Tot Freed |\\n|---------------------------------------------------------------------------|\\n| Allocated memory | 0 B | 0 B | 0 B | 0 B |\\n| from large pool | 0 B | 0 B | 0 B | 0 B |\\n| from small pool | 0 B | 0 B | 0 B | 0 B |\\n|---------------------------------------------------------------------------|\\n| Active memory | 0 B | 0 B | 0 B | 0 B |\\n| from large pool | 0 B | 0 B | 0 B | 0 B |\\n| from small pool | 0 B | 0 B | 0 B | 0 B |\\n|---------------------------------------------------------------------------|\\n| Requested memory | 0 B | 0 B | 0 B | 0 B |\\n| from large pool | 0 B | 0 B | 0 B | 0 B |\\n| from small pool | 0 B | 0 B | 0 B | 0 B |\\n|---------------------------------------------------------------------------|\\n| GPU reserved memory | 
0 B | 0 B | 0 B | 0 B |\\n| from large pool | 0 B | 0 B | 0 B | 0 B |\\n| from small pool | 0 B | 0 B | 0 B | 0 B |\\n|---------------------------------------------------------------------------|\\n| Non-releasable memory | 0 B | 0 B | 0 B | 0 B |\\n| from large pool | 0 B | 0 B | 0 B | 0 B |\\n| from small pool | 0 B | 0 B | 0 B | 0 B |\\n|---------------------------------------------------------------------------|\\n| Allocations | 0 | 0 | 0 | 0 |\\n| from large pool | 0 | 0 | 0 | 0 |\\n| from small pool | 0 | 0 | 0 | 0 |\\n|---------------------------------------------------------------------------|\\n| Active allocs | 0 | 0 | 0 | 0 |\\n| from large pool | 0 | 0 | 0 | 0 |\\n| from small pool | 0 | 0 | 0 | 0 |\\n|---------------------------------------------------------------------------|\\n| GPU reserved segments | 0 | 0 | 0 | 0 |\\n| from large pool | 0 | 0 | 0 | 0 |\\n| from small pool | 0 | 0 | 0 | 0 |\\n|---------------------------------------------------------------------------|\\n| Non-releasable allocs | 0 | 0 | 0 | 0 |\\n| from large pool | 0 | 0 | 0 | 0 |\\n| from small pool | 0 | 0 | 0 | 0 |\\n|---------------------------------------------------------------------------|\\n| Oversize allocations | 0 | 0 | 0 | 0 |\\n|---------------------------------------------------------------------------|\\n| Oversize GPU segments | 0 | 0 | 0 | 0 |\\n|===========================================================================|\\n'"],"application/vnd.google.colaboratory.intrinsic+json":{"type":"string"}},"metadata":{},"execution_count":5}]},{"cell_type":"markdown","source":["# 환경 변수 로드 및 Google Colab 환경 설정"],"metadata":{"id":"hvemiqXbtPVo"}},{"cell_type":"code","source":["# Optional dotenv file that provides HF_TOKEN; point KRX_ENV_FILE at it to override the default name.\n","env_file = os.getenv(\"KRX_ENV_FILE\", \"KRX.env\")\n","if os.path.exists(env_file):\n","    # Imported lazily: python-dotenv is only required when such a file is present.\n","    from dotenv import load_dotenv\n","    load_dotenv(env_file)\n","\n","# Use None (anonymous Hub access) when no token is configured, instead of passing a\n","# placeholder string as if it were a real token.\n","hf_token = os.getenv(\"HF_TOKEN\") or None\n","\n","model_name = \"Qwen/Qwen2-1.5B\"\n","max_seq_length = 
2048"],"metadata":{"id":"2nEoHW7itR3t"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["# 양자화 설정 (4비트 양자화 사용)"],"metadata":{"id":"XPOupITbtTe1"}},{"cell_type":"code","source":["# 4-bit NF4 quantisation with double quantisation and a float16 compute dtype.\n","bnb_config = BitsAndBytesConfig(\n","    load_in_4bit=True,  # load the weights in 4-bit precision\n","    bnb_4bit_use_double_quant=True,\n","    bnb_4bit_quant_type=\"nf4\",\n","    bnb_4bit_compute_dtype=torch.float16\n",")"],"metadata":{"id":"LxQopNuFtTOy"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["# 모델 및 토크나이저 로드 (GPU 사용하도록 설정)"],"metadata":{"id":"dcD5t1OstWbx"}},{"cell_type":"code","source":["# Pass the same token to both loaders so gated or private checkpoints load consistently.\n","tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token)\n","model = AutoModelForCausalLM.from_pretrained(\n","    model_name,\n","    quantization_config=bnb_config,\n","    device_map=\"auto\",\n","    token=hf_token,\n",")"],"metadata":{"collapsed":true,"id":"uXH2KLTytX0A"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["# LoRA 설정 추가"],"metadata":{"id":"230hvGJmtZU6"}},{"cell_type":"code","source":["lora_config = LoraConfig(\n","    r=16,\n","    lora_alpha=32,\n","    target_modules=[\"q_proj\", \"v_proj\"],\n","    lora_dropout=0.05,\n","    bias=\"none\",\n","    task_type=\"CAUSAL_LM\",  # declare the task so PEFT uses its causal-LM model wrapper\n",")\n","model = get_peft_model(model, lora_config)"],"metadata":{"id":"RM4Xvc74tZD8"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["# 두 개의 데이터셋 로드 및 병합"],"metadata":{"id":"Y-jCFuCdtb7i"}},{"cell_type":"code","source":["first_dataset = load_dataset(\"amphora/krx-sample-instructions\", split=\"train\")\n","second_dataset = load_dataset(\"Cartinoe5930/web_text_synthetic_dataset_50k\", split=\"train\")\n","\n","# Merge the two datasets\n","dataset = concatenate_datasets([first_dataset, second_dataset])\n","\n","# Prompt template\n","prompt_format = \"\"\"Below is an instruction that describes a task, paired with an input that provides further context. 
Write a response that appropriately completes the request.\n","\n","### Instruction:\n","{}\n","\n","### Response:\n","{}\"\"\"\n","\n","EOS_TOKEN = tokenizer.eos_token\n","\n","def formatting_prompts_func(examples):\n","    \"\"\"Build the prompt-formatted training text for one batch of rows.\n","\n","    Args:\n","        examples: Column-name -> list-of-values mapping, as passed by\n","            ``dataset.map(..., batched=True)``. The instruction is taken from\n","            \"prompt\", falling back to \"question\"; the answer from \"response\".\n","\n","    Returns:\n","        ``{\"formatted_text\": [...]}`` with one string per row, each ending with EOS_TOKEN.\n","\n","    Raises:\n","        KeyError: If \"response\" is missing, or neither \"prompt\" nor \"question\" exists.\n","        ValueError: If a row has no instruction in either column, or no response.\n","    \"\"\"\n","    if \"response\" not in examples or not (\"prompt\" in examples or \"question\" in examples):\n","        raise KeyError(\"The dataset fields do not match the expected format.\")\n","\n","    outputs = examples[\"response\"]\n","    prompts = examples[\"prompt\"] if \"prompt\" in examples else [None] * len(outputs)\n","    questions = examples[\"question\"] if \"question\" in examples else [None] * len(outputs)\n","\n","    texts = []\n","    for prompt_text, question_text, output in zip(prompts, questions, outputs):\n","        # Choose the instruction per row, not per batch: a column can exist for the whole\n","        # batch while being empty for rows that came from the other source dataset.\n","        # NOTE(review): assumes concatenate_datasets fills missing columns with None - confirm.\n","        instr = prompt_text if prompt_text is not None else question_text\n","        if instr is None or output is None:\n","            raise ValueError(\"Row has no instruction (prompt/question) or no response.\")\n","        texts.append(prompt_format.format(instr, output) + EOS_TOKEN)\n","    return {\"formatted_text\": texts}\n","\n","# Add the formatted_text column to every row\n","dataset = dataset.map(formatting_prompts_func, batched=True)"],"metadata":{"collapsed":true,"id":"FNBsMfZWtdYm"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["# 모델 학습 설정"],"metadata":{"id":"yhoxZxuBtmuu"}},{"cell_type":"code","source":["training_args = TrainingArguments(\n","    output_dir=\"./output\",\n","    per_device_train_batch_size=1,\n","    gradient_accumulation_steps=8,\n","    max_steps=100,\n","    logging_steps=10,\n","    learning_rate=2e-5,\n","    seed=42,\n","    save_steps=100,\n","    fp16=True,  # mixed precision to reduce memory use\n","    report_to=\"none\",\n",")"],"metadata":{"id":"ijdYp1Wxtmgi"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["# GPU 메모리 관리 최적화 환경 변수 설정"],"metadata":{"id":"O7iV-HEyxagR"}},{"cell_type":"code","source":["# NOTE(review): this cell runs after the model has already been loaded with device_map=\"auto\";\n","# allocator options are presumably only read when CUDA is first initialised, so confirm\n","# whether setting the variable this late has any effect.\n","os.environ[\"PYTORCH_CUDA_ALLOC_CONF\"] = \"expandable_segments:True\""],"metadata":{"id":"WqNvG7vwxboU"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["# SFTTrainer 초기화"],"metadata":{"id":"bPlQN84Nto4j"}},{"cell_type":"code","source":["trainer = SFTTrainer(\n","    model=model,\n","    tokenizer=tokenizer,\n","    train_dataset=dataset,  # merged dataset\n"," 
dataset_text_field=\"formatted_text\",  # text column built by formatting_prompts_func\n"," max_seq_length=1024,  # literal cap; the max_seq_length variable set in the config cell is not used here\n"," args=training_args,\n",")"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"collapsed":true,"id":"O4MNZHjNtp11","executionInfo":{"status":"ok","timestamp":1730212244752,"user_tz":-540,"elapsed":765,"user":{"displayName":"‍구원정[ 학부재학 / 수학과 ]","userId":"15682121601729926510"}},"outputId":"0e779eee-4195-4ef5-93bc-ce3096a1e2a9"},"execution_count":null,"outputs":[{"output_type":"stream","name":"stderr","text":["/usr/local/lib/python3.10/dist-packages/huggingface_hub/utils/_deprecation.py:100: FutureWarning: Deprecated argument(s) used in '__init__': dataset_text_field, max_seq_length. Will not be supported from version '1.0.0'.\n","\n","Deprecated positional argument(s) used in SFTTrainer, please use the SFTConfig to set these arguments instead.\n"," warnings.warn(message, FutureWarning)\n","/usr/local/lib/python3.10/dist-packages/trl/trainer/sft_trainer.py:283: UserWarning: You passed a `max_seq_length` argument to the SFTTrainer, the value you passed will override the one in the `SFTConfig`.\n"," warnings.warn(\n","/usr/local/lib/python3.10/dist-packages/trl/trainer/sft_trainer.py:321: UserWarning: You passed a `dataset_text_field` argument to the SFTTrainer, the value you passed will override the one in the `SFTConfig`.\n"," warnings.warn(\n","max_steps is given, it will override any value given in num_train_epochs\n"]}]},{"cell_type":"markdown","source":["# 학습"],"metadata":{"id":"qbLwnuqYtqYh"}},{"cell_type":"code","source":["# Mount Google Drive at /content/drive; the training cell saves the model and tokenizer under it.\n","from google.colab import drive\n","drive.mount('/content/drive')"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"t586zh7owDHf","executionInfo":{"status":"ok","timestamp":1730211648880,"user_tz":-540,"elapsed":18263,"user":{"displayName":"‍구원정[ 학부재학 / 수학과 ]","userId":"15682121601729926510"}},"outputId":"0409cadd-f76f-41aa-a2ff-a738e6d575f5"},"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["Mounted at 
/content/drive\n"]}]},{"cell_type":"code","source":["# Single destination folder on the mounted Drive for both artifacts.\n","save_dir = \"/content/drive/My Drive/KRX_Qwen2_1_5B\"\n","\n","# Run the fine-tuning loop.\n","print(\"모델 학습 시작...\")\n","trainer.train()\n","print(\"모델 학습 완료.\")\n","\n","# Save the model first, then the tokenizer, announcing each step before and after.\n","for start_msg, artifact, done_msg in (\n","    (\"모델 저장 중...\", model, \"모델 저장 완료.\"),\n","    (\"토크나이저 저장 중...\", tokenizer, \"토크나이저 저장 완료.\"),\n","):\n","    print(start_msg)\n","    artifact.save_pretrained(save_dir)\n","    print(done_msg)\n"],"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":493},"collapsed":true,"id":"42X2FaIPtsdj","executionInfo":{"status":"ok","timestamp":1730212538132,"user_tz":-540,"elapsed":287807,"user":{"displayName":"‍구원정[ 학부재학 / 수학과 ]","userId":"15682121601729926510"}},"outputId":"8e271ef3-1f6b-4754-841d-fe7797d070fb"},"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["모델 학습 시작...\n"]},{"output_type":"display_data","data":{"text/plain":[""],"text/html":["\n","
\n"," \n"," \n"," [100/100 04:37, Epoch 0/1]\n","
\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n","
StepTraining Loss
101.821300
201.804400
301.777700
401.827900
501.770500
601.734600
701.803100
801.693200
901.727900
1001.709200

"]},"metadata":{}},{"output_type":"stream","name":"stdout","text":["모델 학습 완료.\n","모델 저장 중...\n","모델 저장 완료.\n","토크나이저 저장 중...\n","토크나이저 저장 완료.\n"]}]},{"cell_type":"markdown","source":["# 학습된 모델 로드"],"metadata":{"id":"TiLfumsA1TVr"}},{"cell_type":"code","source":["from transformers import AutoTokenizer, AutoModelForCausalLM\n","from google.colab import drive\n","\n","# Mount Google Drive (the recorded output shows it was already mounted at this point)\n","drive.mount('/content/drive')\n","\n","# Load the fine-tuned model and tokenizer from the Drive folder written by the training cell.\n","# This rebinds model_name, tokenizer and model; no quantization_config or device_map is passed here.\n","model_name = \"/content/drive/My Drive/KRX_Qwen2_1_5B\"\n","tokenizer = AutoTokenizer.from_pretrained(model_name)\n","model = AutoModelForCausalLM.from_pretrained(model_name)\n","\n","print(\"모델과 토크나이저가 성공적으로 로드되었습니다.\")"],"metadata":{"collapsed":true,"id":"onqwou0O1VU8","executionInfo":{"status":"ok","timestamp":1730212718308,"user_tz":-540,"elapsed":41590,"user":{"displayName":"‍구원정[ 학부재학 / 수학과 ]","userId":"15682121601729926510"}},"colab":{"base_uri":"https://localhost:8080/"},"outputId":"07c75506-35c0-4647-8a83-d78bf82f4a39"},"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount(\"/content/drive\", force_remount=True).\n"]},{"output_type":"stream","name":"stderr","text":["/usr/local/lib/python3.10/dist-packages/huggingface_hub/utils/_token.py:89: UserWarning: \n","The secret `HF_TOKEN` does not exist in your Colab secrets.\n","To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.\n","You will be able to reuse this secret in all of your notebooks.\n","Please note that authentication is recommended but still optional to access public models or datasets.\n"," warnings.warn(\n"]},{"output_type":"stream","name":"stdout","text":["모델과 토크나이저가 성공적으로 로드되었습니다.\n"]}]},{"cell_type":"markdown","source":["# 추론을 위한 프롬프트 
설정"],"metadata":{"id":"J3YP2p4b1YMF"}},{"cell_type":"code","source":["# NOTE(review): this preamble differs from the template used to build the training text\n","# (\"Below is an instruction that describes a task, ...\"); confirm the difference is intended.\n","prompt_format = \"\"\"The following is a detailed financial question or instruction, and the corresponding answer is expected to be precise and informative. Use relevant financial terms and provide a comprehensive explanation.\n","\n","### Instruction:\n","{}\n","\n","### Response:\"\"\""],"metadata":{"id":"bEuDvZ9N1Z-l"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["# 예제 프롬프트"],"metadata":{"id":"WAcjaaz91av1"}},{"cell_type":"code","source":["instruction = \"선물옵션에 대해 설명해줘.\"\n","prompt = prompt_format.format(instruction)\n","# Keep the input tensors on the same device as the model so generate() works on CPU or GPU.\n","inputs = tokenizer(prompt, return_tensors=\"pt\").to(model.device)"],"metadata":{"id":"5rCUMNMT1cKw"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["# 텍스트 생성"],"metadata":{"id":"OT6wiZby1dvy"}},{"cell_type":"code","source":["from transformers import set_seed\n","\n","# Seed the RNGs so the sampled completion is reproducible across runs.\n","set_seed(42)\n","\n","outputs = model.generate(\n","    **inputs,\n","    max_new_tokens=256,\n","    do_sample=True,  # required for temperature / top_k to take effect\n","    temperature=0.7,  # controls diversity\n","    top_k=50,  # only consider the K most likely tokens\n","    repetition_penalty=1.2,  # penalty that discourages repetition\n","    pad_token_id=tokenizer.eos_token_id,  # set explicitly instead of relying on the fallback\n","    use_cache=True\n",")"],"metadata":{"id":"ZQEsBUq21eyc","executionInfo":{"status":"ok","timestamp":1730212993303,"user_tz":-540,"elapsed":24179,"user":{"displayName":"‍구원정[ 학부재학 / 수학과 ]","userId":"15682121601729926510"}},"colab":{"base_uri":"https://localhost:8080/"},"outputId":"08c63ca0-9b2f-4437-9a80-5fc0d814a7a0"},"execution_count":null,"outputs":[{"output_type":"stream","name":"stderr","text":["/usr/local/lib/python3.10/dist-packages/transformers/generation/configuration_utils.py:567: UserWarning: `do_sample` is set to `False`. However, `temperature` is set to `0.7` -- this flag is only used in sample-based generation modes. 
You should set `do_sample=True` or unset `temperature`.\n"," warnings.warn(\n","Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.\n"]}]},{"cell_type":"markdown","source":["# 결과 출력"],"metadata":{"id":"RG3dENu01gNC"}},{"cell_type":"code","source":["# Decode only the tokens generated after the prompt. Slicing by token count avoids relying on\n","# decode() reproducing the prompt text character for character.\n","prompt_token_count = inputs[\"input_ids\"].shape[-1]\n","response = tokenizer.decode(outputs[0][prompt_token_count:], skip_special_tokens=True)\n","\n","print(response.strip())"],"metadata":{"id":"IgPVmSqQ1hlC","executionInfo":{"status":"ok","timestamp":1730213030713,"user_tz":-540,"elapsed":385,"user":{"displayName":"‍구원정[ 학부재학 / 수학과 ]","userId":"15682121601729926510"}},"colab":{"base_uri":"https://localhost:8080/"},"outputId":"1f5ee5aa-8144-4b70-a47d-953e81782543"},"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["Orange price at the port of origin?\n","\n","Yes, that's correct! The term \"orange\" in this context refers to the orange fruit itself rather than its value as an investment asset.\n"]}]}]}