{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "71fbfca2",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoModelForCausalLM\n",
    "from peft import get_peft_config, get_peft_model, PromptTuningInit, PromptTuningConfig, TaskType, PeftType\n",
    "import torch\n",
    "from datasets import load_dataset\n",
    "import os\n",
    "from transformers import AutoTokenizer\n",
    "from torch.utils.data import DataLoader\n",
    "from transformers import default_data_collator, get_linear_schedule_with_warmup\n",
    "from tqdm import tqdm\n",
    "from datasets import load_dataset\n",
    "\n",
    "device = \"cuda\"\n",
    "model_name_or_path = \"bigscience/bloomz-560m\"\n",
    "tokenizer_name_or_path = \"bigscience/bloomz-560m\"\n",
    "peft_config = PromptTuningConfig(\n",
    "    task_type=TaskType.CAUSAL_LM,\n",
    "    prompt_tuning_init=PromptTuningInit.TEXT,\n",
    "    num_virtual_tokens=8,\n",
    "    prompt_tuning_init_text=\"Classify if the tweet is a complaint or not:\",\n",
    "    tokenizer_name_or_path=model_name_or_path,\n",
    ")\n",
    "\n",
    "dataset_name = \"twitter_complaints\"\n",
    "checkpoint_name = f\"{dataset_name}_{model_name_or_path}_{peft_config.peft_type}_{peft_config.task_type}_v1.pt\".replace(\n",
    "    \"/\", \"_\"\n",
    ")\n",
    "text_column = \"Tweet text\"\n",
    "label_column = \"text_label\"\n",
    "max_length = 64\n",
    "lr = 3e-2\n",
    "num_epochs = 50\n",
    "batch_size = 8"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e1a3648b",
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "\n",
    "dataset = load_dataset(\"ought/raft\", dataset_name)\n",
    "\n",
    "classes = [k.replace(\"_\", \" \") for k in dataset[\"train\"].features[\"Label\"].names]\n",
    "print(classes)\n",
    "dataset = dataset.map(\n",
    "    lambda x: {\"text_label\": [classes[label] for label in x[\"Label\"]]},\n",
    "    batched=True,\n",
    "    num_proc=1,\n",
    ")\n",
    "print(dataset)\n",
    "dataset[\"train\"][0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fe12d4d3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# data preprocessing\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)\n",
    "if tokenizer.pad_token_id is None:\n",
    "    tokenizer.pad_token_id = tokenizer.eos_token_id\n",
    "target_max_length = max([len(tokenizer(class_label)[\"input_ids\"]) for class_label in classes])\n",
    "print(target_max_length)\n",
    "\n",
    "\n",
    "def preprocess_function(examples):\n",
    "    batch_size = len(examples[text_column])\n",
    "    inputs = [f\"{text_column} : {x} Label : \" for x in examples[text_column]]\n",
    "    targets = [str(x) for x in examples[label_column]]\n",
    "    model_inputs = tokenizer(inputs)\n",
    "    labels = tokenizer(targets, add_special_tokens=False)  # don't add bos token because we concatenate with inputs\n",
    "    for i in range(batch_size):\n",
    "        sample_input_ids = model_inputs[\"input_ids\"][i]\n",
    "        label_input_ids = labels[\"input_ids\"][i] + [tokenizer.eos_token_id]\n",
    "        # print(i, sample_input_ids, label_input_ids)\n",
    "        model_inputs[\"input_ids\"][i] = sample_input_ids + label_input_ids\n",
    "        labels[\"input_ids\"][i] = [-100] * len(sample_input_ids) + label_input_ids\n",
    "        model_inputs[\"attention_mask\"][i] = [1] * len(model_inputs[\"input_ids\"][i])\n",
    "    # print(model_inputs)\n",
    "    for i in range(batch_size):\n",
    "        sample_input_ids = model_inputs[\"input_ids\"][i]\n",
    "        label_input_ids = labels[\"input_ids\"][i]\n",
    "        model_inputs[\"input_ids\"][i] = [tokenizer.pad_token_id] * (\n",
    "            max_length - len(sample_input_ids)\n",
    "        ) + sample_input_ids\n",
    "        model_inputs[\"attention_mask\"][i] = [0] * (max_length - len(sample_input_ids)) + model_inputs[\n",
    "            \"attention_mask\"\n",
    "        ][i]\n",
    "        labels[\"input_ids\"][i] = [-100] * (max_length - len(sample_input_ids)) + label_input_ids\n",
    "        model_inputs[\"input_ids\"][i] = torch.tensor(model_inputs[\"input_ids\"][i][:max_length])\n",
    "        model_inputs[\"attention_mask\"][i] = torch.tensor(model_inputs[\"attention_mask\"][i][:max_length])\n",
    "        labels[\"input_ids\"][i] = torch.tensor(labels[\"input_ids\"][i][:max_length])\n",
    "    model_inputs[\"labels\"] = labels[\"input_ids\"]\n",
    "    return model_inputs\n",
    "\n",
    "\n",
    "processed_datasets = dataset.map(\n",
    "    preprocess_function,\n",
    "    batched=True,\n",
    "    num_proc=1,\n",
    "    remove_columns=dataset[\"train\"].column_names,\n",
    "    load_from_cache_file=False,\n",
    "    desc=\"Running tokenizer on dataset\",\n",
    ")\n",
    "\n",
    "train_dataset = processed_datasets[\"train\"]\n",
    "eval_dataset = processed_datasets[\"train\"]\n",
    "\n",
    "\n",
    "train_dataloader = DataLoader(\n",
    "    train_dataset, shuffle=True, collate_fn=default_data_collator, batch_size=batch_size, pin_memory=True\n",
    ")\n",
    "eval_dataloader = DataLoader(eval_dataset, collate_fn=default_data_collator, batch_size=batch_size, pin_memory=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "641b21fe",
   "metadata": {},
   "outputs": [],
   "source": [
    "def test_preprocess_function(examples):\n",
    "    batch_size = len(examples[text_column])\n",
    "    inputs = [f\"{text_column} : {x} Label : \" for x in examples[text_column]]\n",
    "    model_inputs = tokenizer(inputs)\n",
    "    # print(model_inputs)\n",
    "    for i in range(batch_size):\n",
    "        sample_input_ids = model_inputs[\"input_ids\"][i]\n",
    "        model_inputs[\"input_ids\"][i] = [tokenizer.pad_token_id] * (\n",
    "            max_length - len(sample_input_ids)\n",
    "        ) + sample_input_ids\n",
    "        model_inputs[\"attention_mask\"][i] = [0] * (max_length - len(sample_input_ids)) + model_inputs[\n",
    "            \"attention_mask\"\n",
    "        ][i]\n",
    "        model_inputs[\"input_ids\"][i] = torch.tensor(model_inputs[\"input_ids\"][i][:max_length])\n",
    "        model_inputs[\"attention_mask\"][i] = torch.tensor(model_inputs[\"attention_mask\"][i][:max_length])\n",
    "    return model_inputs\n",
    "\n",
    "\n",
    "test_dataset = dataset[\"test\"].map(\n",
    "    test_preprocess_function,\n",
    "    batched=True,\n",
    "    num_proc=1,\n",
    "    remove_columns=dataset[\"train\"].column_names,\n",
    "    load_from_cache_file=False,\n",
    "    desc=\"Running tokenizer on dataset\",\n",
    ")\n",
    "\n",
    "test_dataloader = DataLoader(test_dataset, collate_fn=default_data_collator, batch_size=batch_size, pin_memory=True)\n",
    "next(iter(test_dataloader))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "accc5012",
   "metadata": {},
   "outputs": [],
   "source": [
    "next(iter(train_dataloader))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "218df807",
   "metadata": {},
   "outputs": [],
   "source": [
    "len(test_dataloader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "47d1fedf",
   "metadata": {},
   "outputs": [],
   "source": [
    "next(iter(test_dataloader))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a773e092",
   "metadata": {},
   "outputs": [],
   "source": [
    "# creating model\n",
    "model = AutoModelForCausalLM.from_pretrained(model_name_or_path)\n",
    "model = get_peft_model(model, peft_config)\n",
    "model.print_trainable_parameters()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "b2f91568",
   "metadata": {},
   "outputs": [],
   "source": [
    "# model\n",
    "# optimizer and lr scheduler\n",
    "optimizer = torch.optim.AdamW(model.parameters(), lr=lr)\n",
    "lr_scheduler = get_linear_schedule_with_warmup(\n",
    "    optimizer=optimizer,\n",
    "    num_warmup_steps=0,\n",
    "    num_training_steps=(len(train_dataloader) * num_epochs),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "e4fb69fc",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:01<00:00,  5.68it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 21.48it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=0: train_ppl=tensor(2.2720e+13, device='cuda:0') train_epoch_loss=tensor(30.7543, device='cuda:0') eval_ppl=tensor(483597.5625, device='cuda:0') eval_epoch_loss=tensor(13.0890, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.91it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 20.96it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=1: train_ppl=tensor(452658.3750, device='cuda:0') train_epoch_loss=tensor(13.0229, device='cuda:0') eval_ppl=tensor(275088.1875, device='cuda:0') eval_epoch_loss=tensor(12.5248, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.90it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 21.41it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=2: train_ppl=tensor(199203.3906, device='cuda:0') train_epoch_loss=tensor(12.2021, device='cuda:0') eval_ppl=tensor(143637.0312, device='cuda:0') eval_epoch_loss=tensor(11.8750, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.92it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 21.31it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=3: train_ppl=tensor(114743.9531, device='cuda:0') train_epoch_loss=tensor(11.6505, device='cuda:0') eval_ppl=tensor(54962., device='cuda:0') eval_epoch_loss=tensor(10.9144, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.81it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 21.34it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=4: train_ppl=tensor(40786.5977, device='cuda:0') train_epoch_loss=tensor(10.6161, device='cuda:0') eval_ppl=tensor(18342.5430, device='cuda:0') eval_epoch_loss=tensor(9.8170, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.89it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 21.34it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=5: train_ppl=tensor(14023.0830, device='cuda:0') train_epoch_loss=tensor(9.5485, device='cuda:0') eval_ppl=tensor(6316.8540, device='cuda:0') eval_epoch_loss=tensor(8.7510, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.84it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 21.32it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=6: train_ppl=tensor(5635.3262, device='cuda:0') train_epoch_loss=tensor(8.6368, device='cuda:0') eval_ppl=tensor(2476.5776, device='cuda:0') eval_epoch_loss=tensor(7.8146, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.88it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 21.30it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=7: train_ppl=tensor(1818.4940, device='cuda:0') train_epoch_loss=tensor(7.5058, device='cuda:0') eval_ppl=tensor(934.1146, device='cuda:0') eval_epoch_loss=tensor(6.8396, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.05it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 18.97it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=8: train_ppl=tensor(645.2143, device='cuda:0') train_epoch_loss=tensor(6.4696, device='cuda:0') eval_ppl=tensor(361.9093, device='cuda:0') eval_epoch_loss=tensor(5.8914, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.67it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 19.12it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=9: train_ppl=tensor(293.8047, device='cuda:0') train_epoch_loss=tensor(5.6829, device='cuda:0') eval_ppl=tensor(215.8185, device='cuda:0') eval_epoch_loss=tensor(5.3744, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.54it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 20.83it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=10: train_ppl=tensor(191.2377, device='cuda:0') train_epoch_loss=tensor(5.2535, device='cuda:0') eval_ppl=tensor(177.1512, device='cuda:0') eval_epoch_loss=tensor(5.1770, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.02it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 18.98it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=11: train_ppl=tensor(153.6052, device='cuda:0') train_epoch_loss=tensor(5.0344, device='cuda:0') eval_ppl=tensor(126.6154, device='cuda:0') eval_epoch_loss=tensor(4.8412, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.54it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 18.78it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=12: train_ppl=tensor(122.8925, device='cuda:0') train_epoch_loss=tensor(4.8113, device='cuda:0') eval_ppl=tensor(97.3331, device='cuda:0') eval_epoch_loss=tensor(4.5781, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.66it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 19.72it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=13: train_ppl=tensor(84.8845, device='cuda:0') train_epoch_loss=tensor(4.4413, device='cuda:0') eval_ppl=tensor(70.3213, device='cuda:0') eval_epoch_loss=tensor(4.2531, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:01<00:00,  6.73it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 16.07it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=14: train_ppl=tensor(64.6705, device='cuda:0') train_epoch_loss=tensor(4.1693, device='cuda:0') eval_ppl=tensor(50.4688, device='cuda:0') eval_epoch_loss=tensor(3.9214, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  8.41it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 15.63it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=15: train_ppl=tensor(44.2937, device='cuda:0') train_epoch_loss=tensor(3.7908, device='cuda:0') eval_ppl=tensor(34.8210, device='cuda:0') eval_epoch_loss=tensor(3.5502, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  8.31it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 15.67it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=16: train_ppl=tensor(30.0995, device='cuda:0') train_epoch_loss=tensor(3.4045, device='cuda:0') eval_ppl=tensor(24.7703, device='cuda:0') eval_epoch_loss=tensor(3.2096, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  8.31it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 15.59it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=17: train_ppl=tensor(23.3086, device='cuda:0') train_epoch_loss=tensor(3.1488, device='cuda:0') eval_ppl=tensor(20.8131, device='cuda:0') eval_epoch_loss=tensor(3.0356, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  8.29it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 16.04it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=18: train_ppl=tensor(16.4479, device='cuda:0') train_epoch_loss=tensor(2.8002, device='cuda:0') eval_ppl=tensor(12.0876, device='cuda:0') eval_epoch_loss=tensor(2.4922, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  8.37it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 15.37it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=19: train_ppl=tensor(11.1977, device='cuda:0') train_epoch_loss=tensor(2.4157, device='cuda:0') eval_ppl=tensor(9.0399, device='cuda:0') eval_epoch_loss=tensor(2.2016, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  8.23it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 17.29it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=20: train_ppl=tensor(8.1847, device='cuda:0') train_epoch_loss=tensor(2.1023, device='cuda:0') eval_ppl=tensor(6.7486, device='cuda:0') eval_epoch_loss=tensor(1.9093, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  8.30it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 15.58it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=21: train_ppl=tensor(6.1145, device='cuda:0') train_epoch_loss=tensor(1.8107, device='cuda:0') eval_ppl=tensor(5.5931, device='cuda:0') eval_epoch_loss=tensor(1.7215, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  8.34it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 21.36it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=22: train_ppl=tensor(5.2963, device='cuda:0') train_epoch_loss=tensor(1.6670, device='cuda:0') eval_ppl=tensor(5.0573, device='cuda:0') eval_epoch_loss=tensor(1.6208, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.84it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 21.26it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=23: train_ppl=tensor(4.7485, device='cuda:0') train_epoch_loss=tensor(1.5578, device='cuda:0') eval_ppl=tensor(3.6277, device='cuda:0') eval_epoch_loss=tensor(1.2886, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.84it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 21.31it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=24: train_ppl=tensor(3.4080, device='cuda:0') train_epoch_loss=tensor(1.2261, device='cuda:0') eval_ppl=tensor(3.0467, device='cuda:0') eval_epoch_loss=tensor(1.1141, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.88it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 21.25it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=25: train_ppl=tensor(3.3052, device='cuda:0') train_epoch_loss=tensor(1.1955, device='cuda:0') eval_ppl=tensor(2.7784, device='cuda:0') eval_epoch_loss=tensor(1.0219, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.86it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 21.22it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=26: train_ppl=tensor(2.9487, device='cuda:0') train_epoch_loss=tensor(1.0814, device='cuda:0') eval_ppl=tensor(2.9471, device='cuda:0') eval_epoch_loss=tensor(1.0808, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.85it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 21.25it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=27: train_ppl=tensor(2.8738, device='cuda:0') train_epoch_loss=tensor(1.0556, device='cuda:0') eval_ppl=tensor(2.5801, device='cuda:0') eval_epoch_loss=tensor(0.9478, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.84it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 21.28it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=28: train_ppl=tensor(2.3241, device='cuda:0') train_epoch_loss=tensor(0.8433, device='cuda:0') eval_ppl=tensor(2.2198, device='cuda:0') eval_epoch_loss=tensor(0.7974, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.84it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 20.89it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=29: train_ppl=tensor(2.0376, device='cuda:0') train_epoch_loss=tensor(0.7118, device='cuda:0') eval_ppl=tensor(1.8572, device='cuda:0') eval_epoch_loss=tensor(0.6191, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.76it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 18.83it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=30: train_ppl=tensor(1.8301, device='cuda:0') train_epoch_loss=tensor(0.6044, device='cuda:0') eval_ppl=tensor(1.8864, device='cuda:0') eval_epoch_loss=tensor(0.6347, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.80it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 19.81it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=31: train_ppl=tensor(1.7301, device='cuda:0') train_epoch_loss=tensor(0.5482, device='cuda:0') eval_ppl=tensor(1.6340, device='cuda:0') eval_epoch_loss=tensor(0.4910, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.60it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 19.11it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=32: train_ppl=tensor(1.5842, device='cuda:0') train_epoch_loss=tensor(0.4601, device='cuda:0') eval_ppl=tensor(1.6179, device='cuda:0') eval_epoch_loss=tensor(0.4811, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.11it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 18.35it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=33: train_ppl=tensor(1.5193, device='cuda:0') train_epoch_loss=tensor(0.4183, device='cuda:0') eval_ppl=tensor(1.5543, device='cuda:0') eval_epoch_loss=tensor(0.4410, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.59it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 18.60it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=34: train_ppl=tensor(1.5402, device='cuda:0') train_epoch_loss=tensor(0.4319, device='cuda:0') eval_ppl=tensor(1.4924, device='cuda:0') eval_epoch_loss=tensor(0.4004, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.80it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 19.63it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=35: train_ppl=tensor(1.4410, device='cuda:0') train_epoch_loss=tensor(0.3654, device='cuda:0') eval_ppl=tensor(1.3888, device='cuda:0') eval_epoch_loss=tensor(0.3284, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:01<00:00,  6.60it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 15.36it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=36: train_ppl=tensor(1.3675, device='cuda:0') train_epoch_loss=tensor(0.3130, device='cuda:0') eval_ppl=tensor(1.4001, device='cuda:0') eval_epoch_loss=tensor(0.3366, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  8.40it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 15.58it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=37: train_ppl=tensor(1.4197, device='cuda:0') train_epoch_loss=tensor(0.3505, device='cuda:0') eval_ppl=tensor(1.3214, device='cuda:0') eval_epoch_loss=tensor(0.2787, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  8.27it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 15.56it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=38: train_ppl=tensor(1.3855, device='cuda:0') train_epoch_loss=tensor(0.3261, device='cuda:0') eval_ppl=tensor(1.3501, device='cuda:0') eval_epoch_loss=tensor(0.3001, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  8.25it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 15.57it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=39: train_ppl=tensor(1.3643, device='cuda:0') train_epoch_loss=tensor(0.3107, device='cuda:0') eval_ppl=tensor(1.3549, device='cuda:0') eval_epoch_loss=tensor(0.3037, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  8.28it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 15.41it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=40: train_ppl=tensor(1.3093, device='cuda:0') train_epoch_loss=tensor(0.2695, device='cuda:0') eval_ppl=tensor(1.3233, device='cuda:0') eval_epoch_loss=tensor(0.2801, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  8.24it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 15.51it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=41: train_ppl=tensor(1.3108, device='cuda:0') train_epoch_loss=tensor(0.2706, device='cuda:0') eval_ppl=tensor(1.3440, device='cuda:0') eval_epoch_loss=tensor(0.2957, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  8.78it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 15.61it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=42: train_ppl=tensor(1.2944, device='cuda:0') train_epoch_loss=tensor(0.2581, device='cuda:0') eval_ppl=tensor(1.2711, device='cuda:0') eval_epoch_loss=tensor(0.2399, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  8.29it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 15.56it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=43: train_ppl=tensor(1.2616, device='cuda:0') train_epoch_loss=tensor(0.2323, device='cuda:0') eval_ppl=tensor(1.2449, device='cuda:0') eval_epoch_loss=tensor(0.2190, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.85it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 21.27it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=44: train_ppl=tensor(1.2478, device='cuda:0') train_epoch_loss=tensor(0.2214, device='cuda:0') eval_ppl=tensor(1.2202, device='cuda:0') eval_epoch_loss=tensor(0.1990, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.85it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 21.31it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=45: train_ppl=tensor(1.2350, device='cuda:0') train_epoch_loss=tensor(0.2111, device='cuda:0') eval_ppl=tensor(1.2180, device='cuda:0') eval_epoch_loss=tensor(0.1972, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.86it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 21.33it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=46: train_ppl=tensor(1.2277, device='cuda:0') train_epoch_loss=tensor(0.2052, device='cuda:0') eval_ppl=tensor(1.2077, device='cuda:0') eval_epoch_loss=tensor(0.1887, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.87it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 21.35it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=47: train_ppl=tensor(1.2037, device='cuda:0') train_epoch_loss=tensor(0.1854, device='cuda:0') eval_ppl=tensor(1.2041, device='cuda:0') eval_epoch_loss=tensor(0.1857, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.83it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 21.29it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=48: train_ppl=tensor(1.2026, device='cuda:0') train_epoch_loss=tensor(0.1845, device='cuda:0') eval_ppl=tensor(1.1982, device='cuda:0') eval_epoch_loss=tensor(0.1808, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.86it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 21.35it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=49: train_ppl=tensor(1.2005, device='cuda:0') train_epoch_loss=tensor(0.1827, device='cuda:0') eval_ppl=tensor(1.1968, device='cuda:0') eval_epoch_loss=tensor(0.1796, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# training and evaluation\n",
    "model = model.to(device)\n",
    "\n",
    "for epoch in range(num_epochs):\n",
    "    model.train()\n",
    "    total_loss = 0\n",
    "    for step, batch in enumerate(tqdm(train_dataloader)):\n",
    "        batch = {k: v.to(device) for k, v in batch.items()}\n",
    "        #         print(batch)\n",
    "        #         print(batch[\"input_ids\"].shape)\n",
    "        outputs = model(**batch)\n",
    "        loss = outputs.loss\n",
    "        total_loss += loss.detach().float()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        lr_scheduler.step()\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "    model.eval()\n",
    "    eval_loss = 0\n",
    "    eval_preds = []\n",
    "    for step, batch in enumerate(tqdm(eval_dataloader)):\n",
    "        batch = {k: v.to(device) for k, v in batch.items()}\n",
    "        with torch.no_grad():\n",
    "            outputs = model(**batch)\n",
    "        loss = outputs.loss\n",
    "        eval_loss += loss.detach().float()\n",
    "        eval_preds.extend(\n",
    "            tokenizer.batch_decode(torch.argmax(outputs.logits, -1).detach().cpu().numpy(), skip_special_tokens=True)\n",
    "        )\n",
    "\n",
    "    eval_epoch_loss = eval_loss / len(eval_dataloader)\n",
    "    eval_ppl = torch.exp(eval_epoch_loss)\n",
    "    train_epoch_loss = total_loss / len(train_dataloader)\n",
    "    train_ppl = torch.exp(train_epoch_loss)\n",
    "    print(f\"{epoch=}: {train_ppl=} {train_epoch_loss=} {eval_ppl=} {eval_epoch_loss=}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "53752a7b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "@TommyHilfiger Dramatic shopping exp. ordered 6 jeans same size (30/32) 2 fits / 2 too large / 2 too slim : same brand &gt; different sizing\n",
      "{'input_ids': tensor([[227985,   5484,    915,   2566, 226154, 126015,   5385,    259, 239364,\n",
      "           3396,  70823,   5853,     17,  57247,   1231, 191040,   5025,   7869,\n",
      "            375,   2324, 149349,     12,    415, 122321,    897,    415,  10136,\n",
      "          10021,    897,    415,  10136,   6497,    381,    915,   5025,  51950,\n",
      "          66869,   5955,    272,  20311,  77658,    915,    210]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
      "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}\n",
      "tensor([[227985,   5484,    915,   2566, 226154, 126015,   5385,    259, 239364,\n",
      "           3396,  70823,   5853,     17,  57247,   1231, 191040,   5025,   7869,\n",
      "            375,   2324, 149349,     12,    415, 122321,    897,    415,  10136,\n",
      "          10021,    897,    415,  10136,   6497,    381,    915,   5025,  51950,\n",
      "          66869,   5955,    272,  20311,  77658,    915,    210,  16449,   5952,\n",
      "              3]], device='cuda:0')\n",
      "['Tweet text : @TommyHilfiger Dramatic shopping exp. ordered 6 jeans same size (30/32) 2 fits / 2 too large / 2 too slim : same brand &gt; different sizing Label : complaint']\n"
     ]
    }
   ],
   "source": [
    "model.eval()\n",
    "i = 33\n",
    "inputs = tokenizer(f'{text_column} : {dataset[\"test\"][i][\"Tweet text\"]} Label : ', return_tensors=\"pt\")\n",
    "print(dataset[\"test\"][i][\"Tweet text\"])\n",
    "print(inputs)\n",
    "\n",
    "with torch.no_grad():\n",
    "    inputs = {k: v.to(device) for k, v in inputs.items()}\n",
    "    outputs = model.generate(\n",
    "        input_ids=inputs[\"input_ids\"], attention_mask=inputs[\"attention_mask\"], max_new_tokens=10, eos_token_id=3\n",
    "    )\n",
    "    print(outputs)\n",
    "    print(tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c8f35152",
   "metadata": {},
   "source": [
    "You can push model to hub or save model locally. \n",
    "\n",
    "- Option1: Pushing the model to Hugging Face Hub\n",
    "```python\n",
    "model.push_to_hub(\n",
    "    f\"{dataset_name}_{model_name_or_path}_{peft_config.peft_type}_{peft_config.task_type}\".replace(\"/\", \"_\"),\n",
    "    token = \"hf_...\"\n",
    ")\n",
    "```\n",
    "token (`bool` or `str`, *optional*):\n",
    "    `token` is to be used for HTTP Bearer authorization when accessing remote files. If `True`, will use the token generated\n",
    "    when running `huggingface-cli login` (stored in `~/.huggingface`). Will default to `True` if `repo_url`\n",
    "    is not specified.\n",
    "    Or you can get your token from https://huggingface.co/settings/token\n",
    "```\n",
    "- Or save model locally\n",
    "```python\n",
    "peft_model_id = f\"{dataset_name}_{model_name_or_path}_{peft_config.peft_type}_{peft_config.task_type}\".replace(\"/\", \"_\")\n",
    "model.save_pretrained(peft_model_id)\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "d8ba1f8c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# saving model\n",
    "peft_model_id = f\"{dataset_name}_{model_name_or_path}_{peft_config.peft_type}_{peft_config.task_type}\".replace(\n",
    "    \"/\", \"_\"\n",
    ")\n",
    "model.save_pretrained(peft_model_id)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "4928c7f1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
      "To disable this warning, you can either:\n",
      "\t- Avoid using `tokenizers` before the fork if possible\n",
      "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
      "36K\tbigscience/bloomz-560m_PROMPT_TUNING_CAUSAL_LM/adapter_model.bin\n"
     ]
    }
   ],
   "source": [
    "ckpt = f\"{peft_model_id}/adapter_model.bin\"\n",
    "!du -h $ckpt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "4d9476e1",
   "metadata": {},
   "outputs": [],
   "source": [
    "from peft import PeftModel, PeftConfig\n",
    "\n",
    "peft_model_id = f\"{dataset_name}_{model_name_or_path}_{peft_config.peft_type}_{peft_config.task_type}\".replace(\n",
    "    \"/\", \"_\"\n",
    ")\n",
    "\n",
    "config = PeftConfig.from_pretrained(peft_model_id)\n",
    "model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path)\n",
    "model = PeftModel.from_pretrained(model, peft_model_id)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "ebe174a6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "@greateranglia Ok thanks...\n",
      "{'input_ids': tensor([[227985,   5484,    915,   2566,  14173,   2960,  29906,    387,  20706,\n",
      "          49337,   1369,  77658,    915,    210]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}\n",
      "tensor([[227985,   5484,    915,   2566,  14173,   2960,  29906,    387,  20706,\n",
      "          49337,   1369,  77658,    915,    210,   1936, 106863,      3]],\n",
      "       device='cuda:0')\n",
      "['Tweet text : @greateranglia Ok thanks... Label : no complaint']\n"
     ]
    }
   ],
   "source": [
    "model.to(device)\n",
    "model.eval()\n",
    "i = 4\n",
    "inputs = tokenizer(f'{text_column} : {dataset[\"test\"][i][\"Tweet text\"]} Label : ', return_tensors=\"pt\")\n",
    "print(dataset[\"test\"][i][\"Tweet text\"])\n",
    "print(inputs)\n",
    "\n",
    "with torch.no_grad():\n",
    "    inputs = {k: v.to(device) for k, v in inputs.items()}\n",
    "    outputs = model.generate(\n",
    "        input_ids=inputs[\"input_ids\"], attention_mask=inputs[\"attention_mask\"], max_new_tokens=10, eos_token_id=3\n",
    "    )\n",
    "    print(outputs)\n",
    "    print(tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "24041ee1",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.5"
  },
  "vscode": {
   "interpreter": {
    "hash": "aee8b7b246df8f9039afb4144a1f6fd8d2ca17a180786b69acc140d282b71a49"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}