{"cells":[{"cell_type":"markdown","metadata":{"id":"28e4c4d1-a73f-437b-a1bd-c2cc3874924a"},"source":["# 강의 11주차: midm-food-order-understanding\n","\n","1. KT-AI/midm-bitext-S-7B-inst-v1 를 주문 문장 이해에 미세 튜닝\n","\n","- food-order-understanding-small-3200.json (학습)\n","- food-order-understanding-small-800.json (검증)\n","\n","\n","종속적인 필요 내용\n","- huggingface 계정 설정 및 llama-2 사용 승인\n","- 로깅을 위한 wandb\n","\n","\n","history\n","\n","v1.2\n","- KT-AI/midm-bitext-S-7B-inst-v1 에 safetensors 포맷이 올라왔기에, 해당 리포에서 받도록 설정 변경\n","- 전체 과정 재검증"],"id":"28e4c4d1-a73f-437b-a1bd-c2cc3874924a"},{"cell_type":"code","execution_count":2,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"nDZe_wqKU6J3","outputId":"031e0ee2-9385-44c0-ab12-97cb3c95ffc9","executionInfo":{"status":"ok","timestamp":1702304409865,"user_tz":-540,"elapsed":14624,"user":{"displayName":"조수연","userId":"03810862007552836948"}}},"outputs":[{"output_type":"stream","name":"stdout","text":["Requirement already satisfied: transformers in /usr/local/lib/python3.10/dist-packages (4.35.2)\n","Requirement already satisfied: peft in /usr/local/lib/python3.10/dist-packages (0.7.0)\n","Requirement already satisfied: accelerate in /usr/local/lib/python3.10/dist-packages (0.25.0)\n","Requirement already satisfied: optimum in /usr/local/lib/python3.10/dist-packages (1.15.0)\n","Requirement already satisfied: bitsandbytes in /usr/local/lib/python3.10/dist-packages (0.41.3.post1)\n","Requirement already satisfied: trl in /usr/local/lib/python3.10/dist-packages (0.7.4)\n","Requirement already satisfied: wandb in /usr/local/lib/python3.10/dist-packages (0.16.1)\n","Requirement already satisfied: einops in /usr/local/lib/python3.10/dist-packages (0.7.0)\n","Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from transformers) (3.13.1)\n","Requirement already satisfied: huggingface-hub<1.0,>=0.16.4 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.19.4)\n","Requirement already 
satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (1.23.5)\n","Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from transformers) (23.2)\n","Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (6.0.1)\n","Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (2023.6.3)\n","Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from transformers) (2.31.0)\n","Requirement already satisfied: tokenizers<0.19,>=0.14 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.15.0)\n","Requirement already satisfied: safetensors>=0.3.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.4.1)\n","Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.10/dist-packages (from transformers) (4.66.1)\n","Requirement already satisfied: psutil in /usr/local/lib/python3.10/dist-packages (from peft) (5.9.5)\n","Requirement already satisfied: torch>=1.13.0 in /usr/local/lib/python3.10/dist-packages (from peft) (2.1.0+cu118)\n","Requirement already satisfied: coloredlogs in /usr/local/lib/python3.10/dist-packages (from optimum) (15.0.1)\n","Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from optimum) (1.12)\n","Requirement already satisfied: datasets in /usr/local/lib/python3.10/dist-packages (from optimum) (2.15.0)\n","Requirement already satisfied: tyro>=0.5.11 in /usr/local/lib/python3.10/dist-packages (from trl) (0.6.0)\n","Requirement already satisfied: Click!=8.0.0,>=7.1 in /usr/local/lib/python3.10/dist-packages (from wandb) (8.1.7)\n","Requirement already satisfied: GitPython!=3.1.29,>=1.0.0 in /usr/local/lib/python3.10/dist-packages (from wandb) (3.1.40)\n","Requirement already satisfied: sentry-sdk>=1.0.0 in /usr/local/lib/python3.10/dist-packages (from wandb) (1.38.0)\n","Requirement already 
satisfied: docker-pycreds>=0.4.0 in /usr/local/lib/python3.10/dist-packages (from wandb) (0.4.0)\n","Requirement already satisfied: setproctitle in /usr/local/lib/python3.10/dist-packages (from wandb) (1.3.3)\n","Requirement already satisfied: setuptools in /usr/local/lib/python3.10/dist-packages (from wandb) (67.7.2)\n","Requirement already satisfied: appdirs>=1.4.3 in /usr/local/lib/python3.10/dist-packages (from wandb) (1.4.4)\n","Requirement already satisfied: protobuf!=4.21.0,<5,>=3.19.0 in /usr/local/lib/python3.10/dist-packages (from wandb) (3.20.3)\n","Requirement already satisfied: six>=1.4.0 in /usr/local/lib/python3.10/dist-packages (from docker-pycreds>=0.4.0->wandb) (1.16.0)\n","Requirement already satisfied: gitdb<5,>=4.0.1 in /usr/local/lib/python3.10/dist-packages (from GitPython!=3.1.29,>=1.0.0->wandb) (4.0.11)\n","Requirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.16.4->transformers) (2023.6.0)\n","Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.16.4->transformers) (4.5.0)\n","Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (3.3.2)\n","Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (3.6)\n","Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2.0.7)\n","Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2023.11.17)\n","Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (3.2.1)\n","Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (3.1.2)\n","Requirement already satisfied: triton==2.1.0 in 
/usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (2.1.0)\n","Requirement already satisfied: sentencepiece!=0.1.92,>=0.1.91 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.1.99)\n","Requirement already satisfied: docstring-parser>=0.14.1 in /usr/local/lib/python3.10/dist-packages (from tyro>=0.5.11->trl) (0.15)\n","Requirement already satisfied: rich>=11.1.0 in /usr/local/lib/python3.10/dist-packages (from tyro>=0.5.11->trl) (13.7.0)\n","Requirement already satisfied: shtab>=1.5.6 in /usr/local/lib/python3.10/dist-packages (from tyro>=0.5.11->trl) (1.6.5)\n","Requirement already satisfied: humanfriendly>=9.1 in /usr/local/lib/python3.10/dist-packages (from coloredlogs->optimum) (10.0)\n","Requirement already satisfied: pyarrow>=8.0.0 in /usr/local/lib/python3.10/dist-packages (from datasets->optimum) (9.0.0)\n","Requirement already satisfied: pyarrow-hotfix in /usr/local/lib/python3.10/dist-packages (from datasets->optimum) (0.6)\n","Requirement already satisfied: dill<0.3.8,>=0.3.0 in /usr/local/lib/python3.10/dist-packages (from datasets->optimum) (0.3.7)\n","Requirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from datasets->optimum) (1.5.3)\n","Requirement already satisfied: xxhash in /usr/local/lib/python3.10/dist-packages (from datasets->optimum) (3.4.1)\n","Requirement already satisfied: multiprocess in /usr/local/lib/python3.10/dist-packages (from datasets->optimum) (0.70.15)\n","Requirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from datasets->optimum) (3.9.1)\n","Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->optimum) (1.3.0)\n","Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets->optimum) (23.1.0)\n","Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets->optimum) 
(6.0.4)\n","Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets->optimum) (1.9.3)\n","Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets->optimum) (1.4.0)\n","Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets->optimum) (1.3.1)\n","Requirement already satisfied: async-timeout<5.0,>=4.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets->optimum) (4.0.3)\n","Requirement already satisfied: smmap<6,>=3.0.1 in /usr/local/lib/python3.10/dist-packages (from gitdb<5,>=4.0.1->GitPython!=3.1.29,>=1.0.0->wandb) (5.0.1)\n","Requirement already satisfied: markdown-it-py>=2.2.0 in /usr/local/lib/python3.10/dist-packages (from rich>=11.1.0->tyro>=0.5.11->trl) (3.0.0)\n","Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /usr/local/lib/python3.10/dist-packages (from rich>=11.1.0->tyro>=0.5.11->trl) (2.16.1)\n","Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=1.13.0->peft) (2.1.3)\n","Requirement already satisfied: python-dateutil>=2.8.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets->optimum) (2.8.2)\n","Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets->optimum) (2023.3.post1)\n","Requirement already satisfied: mdurl~=0.1 in /usr/local/lib/python3.10/dist-packages (from markdown-it-py>=2.2.0->rich>=11.1.0->tyro>=0.5.11->trl) (0.1.2)\n"]}],"source":["pip install transformers peft accelerate optimum bitsandbytes trl wandb 
einops"],"id":"nDZe_wqKU6J3"},{"cell_type":"code","execution_count":3,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"51eb00d7-2928-41ad-9ae9-7f0da7d64d6d","outputId":"e7e31196-fa10-4589-e5e8-c4086486db5f","executionInfo":{"status":"ok","timestamp":1702304447771,"user_tz":-540,"elapsed":30386,"user":{"displayName":"조수연","userId":"03810862007552836948"}}},"outputs":[{"output_type":"stream","name":"stderr","text":["/usr/local/lib/python3.10/dist-packages/trl/trainer/ppo_config.py:141: UserWarning: The `optimize_cuda_cache` arguement will be deprecated soon, please use `optimize_device_cache` instead.\n"," warnings.warn(\n"]}],"source":["import os\n","from dataclasses import dataclass, field\n","from typing import Optional\n","import re\n","\n","import torch\n","import tyro\n","from accelerate import Accelerator\n","from datasets import load_dataset, Dataset\n","from peft import AutoPeftModelForCausalLM, LoraConfig\n","from tqdm import tqdm\n","from transformers import (\n"," AutoModelForCausalLM,\n"," AutoTokenizer,\n"," BitsAndBytesConfig,\n"," TrainingArguments,\n",")\n","\n","from trl import SFTTrainer\n","\n","from trl.trainer import 
ConstantLengthDataset"],"id":"51eb00d7-2928-41ad-9ae9-7f0da7d64d6d"},{"cell_type":"code","execution_count":4,"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":162,"referenced_widgets":["dbe8b80107f646fca9ce17fc6898688e","25bab324b2b9446bad5f3a73eed40e68","1e5df26c96974f9e80ec411cc2efb005","726bbc9eda2647089f64254e9afc18a6","730a80d2060d4c0d9ddd2e17f2da0045","cd2ea8d1f93c436c8045979227f28f39","e520cbc12c7f45809976dfbfcf56dd64","cacc47dd52114b3caa6a0a420f748793","435d3880497f437fbe82c5c5aea4723b","f2c6a7c598a2446d980e5b099f8b0504","380d699b391e443594c77e0618acc1e6","81c738cb1572429fad029c865af5864e","1dbd9abdfd9f441a9a2a92797469029f","bdff58ba27c74f89acc6ce2fa028b322","a8d2283aa6d44f1ab1549f4311e88e2d","ff6ee54fece6482fa4908c5bd6f35331","4552475fe488474e98941eb5bc34fe1e","349de155fbbb411b98558636e5b363e5","29721702addc4325b2d6578e51ad6212","ff3d0f971a534f23928c1c9b133ade05","38d4d232d70d49dd8c3ab620e6cfb96c","7dcd8bfea49a447390fd3d693ce473f8","a827efea829546b7b7e5e42a465849e4","fee5d6bf794f4cb7962ef9985fbf4348","bb9ba62e3cd74e5d965fd6d7cbfffcdb","6d01340c7ea248da9b089906ddb0743f","520fd7520fe4457f88e1e7bdcbff3e99","66775e202d174977937a2bb33552e08d","ab2576b47a964778a4fb23a0177c2372","a99d5e99af0748a289fa755b80c2ceaf","129d75c4582a42b98245c5a79ea22525","92fdf3c90389449595e1d7b3605f6953"]},"id":"tX7gYxZaVhYL","outputId":"368e5df8-8976-47c1-a8be-d407e4e16a4d","executionInfo":{"status":"ok","timestamp":1702304450076,"user_tz":-540,"elapsed":364,"user":{"displayName":"조수연","userId":"03810862007552836948"}}},"outputs":[{"output_type":"display_data","data":{"text/plain":["VBox(children=(HTML(value='
/content/wandb/run-20231211_142441-q0brniqd
"]},"metadata":{}},{"output_type":"display_data","data":{"text/plain":["Step | \n","Training Loss | \n","
---|---|
50 | \n","1.040400 | \n","
100 | \n","0.548100 | \n","
150 | \n","0.504600 | \n","
200 | \n","0.495700 | \n","
250 | \n","0.518000 | \n","
300 | \n","0.497100 | \n","
"]},"metadata":{}},{"output_type":"execute_result","data":{"text/plain":["TrainOutput(global_step=300, training_loss=0.6006682777404785, metrics={'train_runtime': 940.0842, 'train_samples_per_second': 0.638, 'train_steps_per_second': 0.319, 'total_flos': 9315508499251200.0, 'train_loss': 0.6006682777404785, 'epoch': 0.19})"]},"metadata":{},"execution_count":28}],"source":["trainer.train()"],"id":"14019fa9-0c6f-4729-ac99-0d407af375b8"},{"cell_type":"code","execution_count":29,"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":35},"id":"3Y4FQSyRghQt","outputId":"60b008f1-1e1c-42f3-bd0c-1157fa7412b7","executionInfo":{"status":"ok","timestamp":1702305626226,"user_tz":-540,"elapsed":412,"user":{"displayName":"조수연","userId":"03810862007552836948"}}},"outputs":[{"output_type":"execute_result","data":{"text/plain":["'/gdrive/MyDrive/lora-midm-7b-food-order-understanding'"],"application/vnd.google.colaboratory.intrinsic+json":{"type":"string"}},"metadata":{},"execution_count":29}],"source":["script_args.training_args.output_dir"],"id":"3Y4FQSyRghQt"},{"cell_type":"code","execution_count":30,"metadata":{"id":"49f05450-da2a-4edd-9db2-63836a0ec73a","executionInfo":{"status":"ok","timestamp":1702305629228,"user_tz":-540,"elapsed":851,"user":{"displayName":"조수연","userId":"03810862007552836948"}}},"outputs":[],"source":["trainer.save_model(script_args.training_args.output_dir)"],"id":"49f05450-da2a-4edd-9db2-63836a0ec73a"},{"cell_type":"markdown","metadata":{"id":"652f307e-e1d7-43ae-b083-dba2d94c2296"},"source":["# 추론 테스트"],"id":"652f307e-e1d7-43ae-b083-dba2d94c2296"},{"cell_type":"code","execution_count":31,"metadata":{"id":"ea8a1fea-7499-4386-9dea-0509110f61af","executionInfo":{"status":"ok","timestamp":1702305631310,"user_tz":-540,"elapsed":857,"user":{"displayName":"조수연","userId":"03810862007552836948"}}},"outputs":[],"source":["from transformers import pipeline, 
TextStreamer"],"id":"ea8a1fea-7499-4386-9dea-0509110f61af"},{"cell_type":"code","execution_count":32,"metadata":{"id":"52626888-1f6e-46b6-a8dd-836622149ff5","executionInfo":{"status":"ok","timestamp":1702305633700,"user_tz":-540,"elapsed":481,"user":{"displayName":"조수연","userId":"03810862007552836948"}}},"outputs":[],"source":["instruction_prompt_template = \"\"\"###System;다음은 매장에서 고객이 음식을 주문하는 주문 문장이다. 이를 분석하여 음식명, 옵션명, 수량을 추출하여 고객의 의도를 이해하고자 한다.\n","분석 결과를 완성해주기 바란다.\n","\n","### 주문 문장: {0} ### 분석 결과:\n","\"\"\"\n","\n","prompt_template = \"\"\"###System;{System}\n","###User;{User}\n","###Midm;\"\"\"\n","\n","default_system_msg = (\n"," \"너는 먼저 사용자가 입력한 주문 문장을 분석하는 에이전트이다. 이로부터 주문을 구성하는 음식명, 옵션명, 수량을 차례대로 추출해야 한다.\"\n",")"],"id":"52626888-1f6e-46b6-a8dd-836622149ff5"},{"cell_type":"code","execution_count":33,"metadata":{"id":"46e844fa-8f63-4359-a4fb-df66e8171796","executionInfo":{"status":"ok","timestamp":1702305636576,"user_tz":-540,"elapsed":1,"user":{"displayName":"조수연","userId":"03810862007552836948"}}},"outputs":[],"source":["evaluation_queries = [\n"," \"오늘은 비가오니깐 이거 먹자. 삼선짬뽕 곱배기 하나하구요, 사천 탕수육 중짜 한그릇 주세요.\",\n"," \"아이스아메리카노 톨사이즈 한잔 하고요. 딸기스무디 한잔 주세요. 또, 콜드브루라떼 하나요.\",\n"," \"참이슬 한병, 코카콜라 1.5리터 한병, 테슬라 한병이요.\",\n"," \"꼬막무침 1인분하고요, 닭도리탕 중자 주세요. 그리고 소주도 한병 주세요.\",\n"," \"김치찌개 3인분하고요, 계란말이 주세요.\",\n"," \"불고기버거세트 1개하고요 감자튀김 추가해주세요.\",\n"," \"불닭볶음면 1개랑 사리곰탕면 2개 주세요.\",\n"," \"카페라떼 아이스 샷추가 한잔하구요. 스콘 하나 주세요\",\n"," \"여기요 춘천닭갈비 4인분하고요. 라면사리 추가하겠습니다. 콜라 300ml 두캔주세요.\",\n"," \"있잖아요 조랭이떡국 3인분하고요. 
떡만두 한세트 주세요.\",\n"," \"깐풍탕수 2인분 하고요 콜라 1.5리터 한병이요.\",\n","]"],"id":"46e844fa-8f63-4359-a4fb-df66e8171796"},{"cell_type":"code","execution_count":34,"metadata":{"id":"1919cf1f-482e-4185-9d06-e3cea1918416","executionInfo":{"status":"ok","timestamp":1702305639801,"user_tz":-540,"elapsed":344,"user":{"displayName":"조수연","userId":"03810862007552836948"}}},"outputs":[],"source":["def wrapper_generate(model, input_prompt, do_stream=False, max_new_tokens=512):\n","    \"\"\"Greedily generate a completion and return the text that follows the prompt.\n","\n","    Uses the notebook-level tokenizer for both encoding and decoding.\n","\n","    Args:\n","        model: causal LM exposing generate(); the input ids are moved to CUDA first.\n","        input_prompt: fully formatted prompt string.\n","        do_stream: when True, echo generated text to stdout through a TextStreamer.\n","        max_new_tokens: integer cap on the number of generated tokens, so a model that\n","            never emits an end-of-sequence token cannot keep generating without bound.\n","\n","    Returns:\n","        The decoded output with <[!newline]> markers turned into newlines and the\n","        first len(input_prompt) characters removed.\n","    \"\"\"\n","    data = tokenizer(input_prompt, return_tensors=\"pt\")\n","    # Drop the final token id from the encoded prompt (presumably an end-of-sequence\n","    # marker appended by the tokenizer -- TODO confirm).\n","    input_ids = data.input_ids[..., :-1]\n","    # Only build a streamer when streaming was requested.\n","    streamer = None\n","    if do_stream:\n","        streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)\n","    with torch.no_grad():\n","        pred = model.generate(\n","            input_ids=input_ids.cuda(),\n","            streamer=streamer,\n","            use_cache=True,\n","            max_new_tokens=max_new_tokens,\n","            do_sample=False,\n","        )\n","    decoded_text = tokenizer.batch_decode(pred, skip_special_tokens=True)\n","    decoded_text = decoded_text[0].replace(\"<[!newline]>\", \"\\n\")\n","    # Remove the echoed prompt by character count.\n","    return decoded_text[len(input_prompt):]"],"id":"1919cf1f-482e-4185-9d06-e3cea1918416"},{"cell_type":"code","execution_count":35,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"eaac1f6f-c823-4488-8edb-2f931ddf0daa","outputId":"c632e94d-faad-4244-b32d-139ace8783f8","executionInfo":{"status":"ok","timestamp":1702306195075,"user_tz":-540,"elapsed":552708,"user":{"displayName":"조수연","userId":"03810862007552836948"}}},"outputs":[{"output_type":"stream","name":"stderr","text":["/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1473: UserWarning: You have modified the pretrained model configuration to control generation. This is a deprecated strategy to control generation and will be removed soon, in a future version. 
Please use and modify the model generation configuration (see https://huggingface.co/docs/transformers/generation_strategies#default-text-generation-configuration )\n"," warnings.warn(\n"]}],"source":["# Run every evaluation query through the base model; keys are the query positions.\n","eval_dic = {\n","    idx: wrapper_generate(\n","        model=base_model,\n","        input_prompt=prompt_template.format(System=default_system_msg, User=order_text),\n","    )\n","    for idx, order_text in enumerate(evaluation_queries)\n","}"],"id":"eaac1f6f-c823-4488-8edb-2f931ddf0daa"},{"cell_type":"code","execution_count":36,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"fefd04ba-2ed8-4f84-bdd0-86d52b3f39f6","outputId":"0d52da0b-d64c-4d60-a624-81d094fbbb13","executionInfo":{"status":"ok","timestamp":1702306195075,"user_tz":-540,"elapsed":18,"user":{"displayName":"조수연","userId":"03810862007552836948"}}},"outputs":[{"output_type":"stream","name":"stdout","text":["- 분석 결과 0: 음식명:삼선짬뽕, 옵션:곱배기, 수량:하나\n","- 분석 결과 1: 음식명:사천 탕수육, 옵션:중짜, 수량:한그릇\n"]}],"source":["print(eval_dic[0])"],"id":"fefd04ba-2ed8-4f84-bdd0-86d52b3f39f6"},{"cell_type":"markdown","metadata":{"id":"3f471e3a-723b-4df5-aa72-46f571f6bab6"},"source":["# 미세튜닝된 모델 로딩 후 테스트"],"id":"3f471e3a-723b-4df5-aa72-46f571f6bab6"},{"cell_type":"code","execution_count":37,"metadata":{"id":"a43bdd07-7555-42b2-9888-a614afec892f","executionInfo":{"status":"ok","timestamp":1702306199550,"user_tz":-540,"elapsed":368,"user":{"displayName":"조수연","userId":"03810862007552836948"}}},"outputs":[],"source":["bnb_config = BitsAndBytesConfig(\n"," load_in_4bit=True,\n"," bnb_4bit_quant_type=\"nf4\",\n"," 
bnb_4bit_compute_dtype=torch.bfloat16,\n",")"],"id":"a43bdd07-7555-42b2-9888-a614afec892f"},{"cell_type":"code","execution_count":39,"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":705},"id":"39db2ee4-23c8-471f-89b2-bca34964bf81","outputId":"d00d2dc2-cd2f-480c-85a2-33cf265314b2","executionInfo":{"status":"error","timestamp":1702306279779,"user_tz":-540,"elapsed":15084,"user":{"displayName":"조수연","userId":"03810862007552836948"}}},"outputs":[{"output_type":"error","ename":"ValueError","evalue":"ignored","traceback":["\u001b[0;31m---------------------------------------------------------------------------\u001b[0m","\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)","\u001b[0;32m'))\n","\n","trained_model.config.pad_token_id = tokenizer.pad_token_id\n","trained_model.config.bos_token_id = tokenizer.bos_token_id"],"id":"b0b75ca4-730d-4bde-88bb-a86462a76d52"},{"cell_type":"markdown","metadata":{"id":"X1tRCa4EiYXp"},"source":["추론 과정에서는 GPU 메모리를 약 5.5 GB 활용"],"id":"X1tRCa4EiYXp"},{"cell_type":"code","execution_count":null,"metadata":{"colab":{"background_save":true,"base_uri":"https://localhost:8080/"},"id":"e374555b-9f8a-4617-8ea7-c1e6ee1b2999","outputId":"526d2827-6422-4399-d7ed-107b822b2bb2"},"outputs":[{"output_type":"stream","name":"stderr","text":["/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1473: UserWarning: You have modified the pretrained model configuration to control generation. This is a deprecated strategy to control generation and will be removed soon, in a future version. 
Please use and modify the model generation configuration (see https://huggingface.co/docs/transformers/generation_strategies#default-text-generation-configuration )\n"," warnings.warn(\n"]},{"output_type":"stream","name":"stdout","text":["- 분석 결과 0: 음식명:삼선짬뽕, 옵션:곱배기, 수량:하나<[!newline]>- 분석 결과 1: 음식명:사천 탕수육, 옵션:중짜, 수량:한그릇\n","- 분석 결과 0: 음식명:아이스아메리카노,옵션:톨사이즈,수량:한잔<[!newline]>- 분석 결과 1: 음식명:딸기스무디,수량:한잔<[!newline]>- 분석 결과 2: 음식명:콜드브루라떼,수량:하나\n","- 분석 결과 0: 음식명:참이슬,수량:한병<[!newline]>- 분석 결과 1: 음식명:코카콜라,옵션:1.5리터,수량:한병<[!newline]>- 분석 결과 2: 음식명:테슬라,수량:한병\n","- 분석 결과 0: 음식명:꼬막무침, 수량:1인분<[!newline]>- 분석 결과 1: 음식명:닭도리탕, 옵션:중자<[!newline]>- 분석 결과 2: 음식명:소주, 수량:한병\n","- 분석 결과 0: 음식명:김치찌개, 수량:3인분<[!newline]>- 분석 결과 1: 음식명:계란말이\n","- 분석 결과 0: 음식명:불고기버거세트, 수량:1개<[!newline]>- 분석 결과 1: 음식명:감자튀김, 수량:추가\n","- 분석 결과 0: "]}],"source":["# Same evaluation sweep with the fine-tuned model, streaming text as it is generated.\n","eval_dic = {\n","    idx: wrapper_generate(\n","        model=trained_model,\n","        do_stream=True,\n","        input_prompt=prompt_template.format(System=default_system_msg, User=order_text),\n","    )\n","    for idx, order_text in enumerate(evaluation_queries)\n","}"],"id":"e374555b-9f8a-4617-8ea7-c1e6ee1b2999"},{"cell_type":"code","execution_count":null,"metadata":{"id":"5d055bb0-5e5f-4221-a634-45d903c0f3b5"},"outputs":[],"source":["print(eval_dic[0])"],"id":"5d055bb0-5e5f-4221-a634-45d903c0f3b5"}],"metadata":{"accelerator":"GPU","colab":{"provenance":[]},"kernelspec":{"display_name":"Python 
3","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.10.13"},"widgets":{"application/vnd.jupyter.widget-state+json":{"dbe8b80107f646fca9ce17fc6898688e":{"model_module":"@jupyter-widgets/controls","model_name":"VBoxModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"VBoxModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"VBoxView","box_style":"","children":["IPY_MODEL_38d4d232d70d49dd8c3ab620e6cfb96c","IPY_MODEL_7dcd8bfea49a447390fd3d693ce473f8","IPY_MODEL_a827efea829546b7b7e5e42a465849e4","IPY_MODEL_fee5d6bf794f4cb7962ef9985fbf4348"],"layout":"IPY_MODEL_e520cbc12c7f45809976dfbfcf56dd64"}},"25bab324b2b9446bad5f3a73eed40e68":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_cacc47dd52114b3caa6a0a420f748793","placeholder":"","style":"IPY_MODEL_435d3880497f437fbe82c5c5aea4723b","value":"
Copy a token from your Hugging Face\ntokens page and paste it below.
Immediately click login after copying\nyour token or it might be stored in plain text in this notebook file.