{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "intro",
   "metadata": {},
   "source": [
    "# Toxic comment classifier: inference\n",
    "\n",
    "Rebuilds the DistilBERT multi-label classifier, loads the fine-tuned weights from `MODEL_PATH`, and defines `give_toxic(text)`, which returns a sigmoid probability for each of six labels (`toxic`, `severe_toxic`, `obscene`, `threat`, `insult`, `identity_hate`) for a single comment."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d136f503-bb1b-404e-8657-ce3168eae54b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import torch\n",
    "from tqdm import tqdm\n",
    "from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler\n",
    "from transformers import DistilBertTokenizer, DistilBertModel\n",
    "import streamlit as st"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "config",
   "metadata": {},
   "outputs": [],
   "source": [
    "MAX_LEN = 512  # every comment is truncated / padded to this many tokens\n",
    "# Training-time settings; not used by the inference code below.\n",
    "TRAIN_BATCH_SIZE = 16\n",
    "VALID_BATCH_SIZE = 16\n",
    "EPOCHS = 3\n",
    "LEARNING_RATE = 1e-05\n",
    "DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'\n",
    "\n",
    "MODEL_PATH = 'model/inference_models_output_4fold_distilbert_fold_best_model.pth'\n",
    "LABEL_COLUMNS = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']\n",
    "\n",
    "# num_workers=0: give_toxic scores a single comment per call, so worker\n",
    "# processes only add start-up cost (and, under the 'spawn' start method,\n",
    "# cannot unpickle classes defined in this notebook).\n",
    "val_params = {'batch_size': VALID_BATCH_SIZE,\n",
    "              'shuffle': False,\n",
    "              'num_workers': 0\n",
    "              }\n",
    "\n",
    "print(DEVICE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dataset",
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased', truncation=True, do_lower_case=True)\n",
    "\n",
    "\n",
    "class MultiLabelDataset(Dataset):\n",
    "    \"\"\"Tokenize the `comment_text` column of a DataFrame for DistilBERT.\n",
    "\n",
    "    Args:\n",
    "        dataframe: must have a `comment_text` column; when `new_data` is False\n",
    "            it must also have a `labels` column (presumably one 6-element\n",
    "            target vector per row - TODO confirm against the training data).\n",
    "        tokenizer: Hugging Face tokenizer used to encode each comment.\n",
    "        max_len: every sequence is truncated / padded to exactly this length.\n",
    "        new_data: True for unlabeled data; items then carry no `targets`.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, dataframe, tokenizer, max_len, new_data=False):\n",
    "        self.tokenizer = tokenizer\n",
    "        # Reset to a 0..n-1 index so rows can be fetched by DataLoader position\n",
    "        # regardless of the caller's index.\n",
    "        self.data = dataframe.reset_index(drop=True)\n",
    "        self.text = self.data.comment_text\n",
    "        self.new_data = new_data\n",
    "        self.max_len = max_len\n",
    "        if not new_data:\n",
    "            self.targets = self.data.labels\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.text)\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        # Collapse runs of whitespace (newlines, tabs, repeated spaces).\n",
    "        text = ' '.join(str(self.text.iloc[index]).split())\n",
    "\n",
    "        inputs = self.tokenizer.encode_plus(\n",
    "            text,\n",
    "            None,\n",
    "            add_special_tokens=True,\n",
    "            max_length=self.max_len,\n",
    "            padding='max_length',\n",
    "            truncation=True,\n",
    "            return_token_type_ids=True,\n",
    "        )\n",
    "\n",
    "        item = {\n",
    "            'ids': torch.tensor(inputs['input_ids'], dtype=torch.long),\n",
    "            'mask': torch.tensor(inputs['attention_mask'], dtype=torch.long),\n",
    "            'token_type_ids': torch.tensor(inputs['token_type_ids'], dtype=torch.long),\n",
    "        }\n",
    "        if not self.new_data:\n",
    "            item['targets'] = torch.tensor(self.targets.iloc[index], dtype=torch.float)\n",
    "        return item"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "model-class",
   "metadata": {},
   "outputs": [],
   "source": [
    "class DistilBERTClass(torch.nn.Module):\n",
    "    \"\"\"DistilBERT encoder plus an MLP head that emits one logit per label.\"\"\"\n",
    "\n",
    "    def __init__(self, num_labels=len(LABEL_COLUMNS)):\n",
    "        super().__init__()\n",
    "        self.bert = DistilBertModel.from_pretrained('distilbert-base-uncased')\n",
    "        hidden = self.bert.config.dim  # 768 for distilbert-base-uncased\n",
    "        self.classifier = torch.nn.Sequential(\n",
    "            torch.nn.Linear(hidden, hidden),\n",
    "            torch.nn.ReLU(),\n",
    "            torch.nn.Dropout(0.1),\n",
    "            torch.nn.Linear(hidden, num_labels),\n",
    "        )\n",
    "\n",
    "    def forward(self, input_ids, attention_mask, token_type_ids):\n",
    "        \"\"\"Return raw logits of shape (batch, num_labels).\n",
    "\n",
    "        `token_type_ids` is accepted so callers can pass the dataset's fields\n",
    "        uniformly, but it is not forwarded to the encoder.\n",
    "        \"\"\"\n",
    "        output = self.bert(input_ids=input_ids, attention_mask=attention_mask)\n",
    "        hidden_state = output[0]  # last hidden state: (batch, seq_len, hidden)\n",
    "        cls_embedding = hidden_state[:, 0]  # first ([CLS]) token of each sequence\n",
    "        return self.classifier(cls_embedding)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "load-weights",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = DistilBERTClass()\n",
    "\n",
    "# The fine-tuned state dict is read from the checkpoint's 'model' entry.\n",
    "# NOTE: torch.load unpickles arbitrary Python objects - only load trusted checkpoints.\n",
    "model_loaded = torch.load(MODEL_PATH, map_location=torch.device('cpu'))\n",
    "model.load_state_dict(model_loaded['model'])\n",
    "model.to(DEVICE)\n",
    "model.eval();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "inference",
   "metadata": {},
   "outputs": [],
   "source": [
    "def give_toxic(text):\n",
    "    \"\"\"Score one comment against every label in `LABEL_COLUMNS`.\n",
    "\n",
    "    Args:\n",
    "        text: the comment to classify (converted with `str` before tokenizing).\n",
    "\n",
    "    Returns:\n",
    "        dict mapping each label name to the model's sigmoid probability (a\n",
    "        Python float in [0, 1]) for `text`.\n",
    "    \"\"\"\n",
    "    test_data = pd.DataFrame([text], columns=['comment_text'])\n",
    "    test_set = MultiLabelDataset(test_data, tokenizer, MAX_LEN, new_data=True)\n",
    "    test_loader = DataLoader(test_set, **val_params)\n",
    "\n",
    "    model.eval()\n",
    "    all_test_pred = []\n",
    "    with torch.inference_mode():\n",
    "        for data in test_loader:\n",
    "            ids = data['ids'].to(DEVICE, dtype=torch.long)\n",
    "            mask = data['mask'].to(DEVICE, dtype=torch.long)\n",
    "            token_type_ids = data['token_type_ids'].to(DEVICE, dtype=torch.long)\n",
    "            outputs = model(ids, mask, token_type_ids)\n",
    "            all_test_pred.append(torch.sigmoid(outputs))\n",
    "\n",
    "    # Exactly one row was scored; take its probabilities.\n",
    "    preds = torch.cat(all_test_pred).detach().cpu().numpy()[0]\n",
    "    return {label: float(p) for label, p in zip(LABEL_COLUMNS, preds)}"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}