kanav0183 committed
Commit e1a5a68 · 1 Parent(s): 06f079f

Upload 3 files

Files changed (3)
  1. app.py +132 -0
  2. infer.ipynb +171 -0
  3. requirements.txt +3 -0
app.py ADDED
@@ -0,0 +1,132 @@
+ import os
+
+ import pandas as pd
+ import torch
+ from tqdm import tqdm
+ from torch.utils.data import Dataset, DataLoader
+ from transformers import DistilBertTokenizer, DistilBertModel
+
+ # Pin the process to GPU index 1; has no effect on CPU-only machines.
+ os.environ["CUDA_VISIBLE_DEVICES"] = "1"
+
+ MAX_LEN = 512
+ TRAIN_BATCH_SIZE = 16
+ VALID_BATCH_SIZE = 16
+ EPOCHS = 3
+ LEARNING_RATE = 1e-05
+ DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+ tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased', truncation=True, do_lower_case=True)
+
+
+ class MultiLabelDataset(Dataset):
+     """Wraps a dataframe of comments; when new_data=True, no targets are returned."""
+
+     def __init__(self, dataframe, tokenizer, max_len, new_data=False):
+         self.tokenizer = tokenizer
+         self.data = dataframe
+         self.text = dataframe.comment_text
+         self.new_data = new_data
+
+         if not new_data:
+             self.targets = self.data.labels
+         self.max_len = max_len
+
+     def __len__(self):
+         return len(self.text)
+
+     def __getitem__(self, index):
+         # Collapse runs of whitespace before tokenizing.
+         text = str(self.text[index])
+         text = " ".join(text.split())
+
+         inputs = self.tokenizer.encode_plus(
+             text,
+             None,
+             add_special_tokens=True,
+             max_length=self.max_len,
+             padding='max_length',
+             truncation=True,
+             return_token_type_ids=True
+         )
+         ids = inputs['input_ids']
+         mask = inputs['attention_mask']
+         token_type_ids = inputs["token_type_ids"]
+
+         out = {
+             'ids': torch.tensor(ids, dtype=torch.long),
+             'mask': torch.tensor(mask, dtype=torch.long),
+             'token_type_ids': torch.tensor(token_type_ids, dtype=torch.long),
+         }
+
+         if not self.new_data:
+             out['targets'] = torch.tensor(self.targets[index], dtype=torch.float)
+
+         return out
+
+
+ class DistilBERTClass(torch.nn.Module):
+     """DistilBERT encoder with a small MLP head over the [CLS] token for the 6 labels."""
+
+     def __init__(self):
+         super().__init__()
+         self.bert = DistilBertModel.from_pretrained("distilbert-base-uncased")
+         self.classifier = torch.nn.Sequential(
+             torch.nn.Linear(768, 768),
+             torch.nn.ReLU(),
+             torch.nn.Dropout(0.1),
+             torch.nn.Linear(768, 6)
+         )
+
+     def forward(self, input_ids, attention_mask, token_type_ids):
+         hidden_state = self.bert(input_ids=input_ids, attention_mask=attention_mask)[0]
+         out = hidden_state[:, 0]  # [CLS] token representation
+         return self.classifier(out)
+
+
+ model = DistilBERTClass()
+ model.to(DEVICE)
+
+ # map_location lets the checkpoint load on CPU-only machines as well.
+ model_loaded = torch.load('model/inference_models_output_4fold_distilbert_fold_best_model.pth',
+                           map_location=torch.device(DEVICE))
+ model.load_state_dict(model_loaded['model'])
+
+ val_params = {
+     'batch_size': VALID_BATCH_SIZE,
+     'shuffle': False,
+ }
+
+
+ def give_toxic(text):
+     test_data = pd.DataFrame([text], columns=['comment_text'])
+     test_set = MultiLabelDataset(test_data, tokenizer, MAX_LEN, new_data=True)
+     test_loader = DataLoader(test_set, **val_params)
+
+     all_test_pred = []
+
+     def test():
+         model.eval()
+         with torch.inference_mode():
+             for _, data in tqdm(enumerate(test_loader, 0)):
+                 ids = data['ids'].to(DEVICE, dtype=torch.long)
+                 mask = data['mask'].to(DEVICE, dtype=torch.long)
+                 token_type_ids = data['token_type_ids'].to(DEVICE, dtype=torch.long)
+                 outputs = model(ids, mask, token_type_ids)
+                 all_test_pred.append(torch.sigmoid(outputs))
+
+     test()
+
+     all_test_pred = torch.cat(all_test_pred)
+
+     label_columns = ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"]
+     preds = all_test_pred.detach().cpu().numpy()[0]
+
+     return dict(zip(label_columns, preds))
+
+
+ def device():
+     return DEVICE
+
+
+ print(give_toxic("fuck"))
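
For reference, a minimal sketch of how `give_toxic` might be exercised from another module, assuming `app.py` is on the import path and the checkpoint exists at `model/inference_models_output_4fold_distilbert_fold_best_model.pth`; the probabilities in the comment are illustrative placeholders, not real model output:

```python
from app import give_toxic  # assumes app.py is importable from the working directory

scores = give_toxic("have a nice day")
# scores maps each of the six labels to a sigmoid probability, e.g. (illustrative):
# {'toxic': 0.02, 'severe_toxic': 0.001, 'obscene': 0.01,
#  'threat': 0.001, 'insult': 0.008, 'identity_hate': 0.001}
for label, proba in scores.items():
    print(f"{label}: {proba:.3f}")
```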
infer.ipynb ADDED
@@ -0,0 +1,171 @@
+ {
+  "cells": [
+   {
+    "cell_type": "code",
+    "execution_count": 20,
+    "id": "d136f503-bb1b-404e-8657-ce3168eae54b",
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "import pandas as pd\n",
+     "import torch\n",
+     "from tqdm import tqdm\n",
+     "from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler\n",
+     "from transformers import DistilBertTokenizer, DistilBertModel\n",
+     "import streamlit as st\n",
+     "\n",
+     "\n",
+     "MAX_LEN = 512\n",
+     "TRAIN_BATCH_SIZE = 16\n",
+     "VALID_BATCH_SIZE = 16\n",
+     "EPOCHS = 3\n",
+     "LEARNING_RATE = 1e-05\n",
+     "DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'\n",
+     "print(DEVICE)\n",
+     "\n",
+     "tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased', truncation=True, do_lower_case=True)\n",
+     "\n",
+     "class MultiLabelDataset(Dataset):\n",
+     "\n",
+     "    def __init__(self, dataframe, tokenizer, max_len, new_data=False):\n",
+     "        self.tokenizer = tokenizer\n",
+     "        self.data = dataframe\n",
+     "        self.text = dataframe.comment_text\n",
+     "        self.new_data = new_data\n",
+     "\n",
+     "        if not new_data:\n",
+     "            self.targets = self.data.labels\n",
+     "        self.max_len = max_len\n",
+     "\n",
+     "    def __len__(self):\n",
+     "        return len(self.text)\n",
+     "\n",
+     "    def __getitem__(self, index):\n",
+     "        text = str(self.text[index])\n",
+     "        text = \" \".join(text.split())\n",
+     "\n",
+     "        inputs = self.tokenizer.encode_plus(\n",
+     "            text,\n",
+     "            None,\n",
+     "            add_special_tokens=True,\n",
+     "            max_length=self.max_len,\n",
+     "            padding='max_length',\n",
+     "            truncation=True,\n",
+     "            return_token_type_ids=True\n",
+     "        )\n",
+     "        ids = inputs['input_ids']\n",
+     "        mask = inputs['attention_mask']\n",
+     "        token_type_ids = inputs[\"token_type_ids\"]\n",
+     "\n",
+     "        out = {\n",
+     "            'ids': torch.tensor(ids, dtype=torch.long),\n",
+     "            'mask': torch.tensor(mask, dtype=torch.long),\n",
+     "            'token_type_ids': torch.tensor(token_type_ids, dtype=torch.long),\n",
+     "        }\n",
+     "\n",
+     "        if not self.new_data:\n",
+     "            out['targets'] = torch.tensor(self.targets[index], dtype=torch.float)\n",
+     "\n",
+     "        return out\n",
+     "\n",
+     "class DistilBERTClass(torch.nn.Module):\n",
+     "    def __init__(self):\n",
+     "        super().__init__()\n",
+     "\n",
+     "        self.bert = DistilBertModel.from_pretrained(\"distilbert-base-uncased\")\n",
+     "        self.classifier = torch.nn.Sequential(\n",
+     "            torch.nn.Linear(768, 768),\n",
+     "            torch.nn.ReLU(),\n",
+     "            torch.nn.Dropout(0.1),\n",
+     "            torch.nn.Linear(768, 6)\n",
+     "        )\n",
+     "\n",
+     "    def forward(self, input_ids, attention_mask, token_type_ids):\n",
+     "        hidden_state = self.bert(input_ids=input_ids, attention_mask=attention_mask)[0]\n",
+     "        out = hidden_state[:, 0]\n",
+     "        return self.classifier(out)\n",
+     "\n",
+     "model = DistilBERTClass()\n",
+     "model.to(DEVICE)\n",
+     "\n",
+     "model_loaded = torch.load('model/inference_models_output_4fold_distilbert_fold_best_model.pth', map_location=torch.device('cpu'))\n",
+     "\n",
+     "model.load_state_dict(model_loaded['model'])\n",
+     "\n",
+     "\n",
+     "val_params = {'batch_size': VALID_BATCH_SIZE,\n",
+     "              'shuffle': False,\n",
+     "              'num_workers': 8\n",
+     "              }\n",
+     "\n",
+     "def give_toxic(text):\n",
+     "    test_data = pd.DataFrame([text], columns=['comment_text'])\n",
+     "    test_set = MultiLabelDataset(test_data, tokenizer, MAX_LEN, new_data=True)\n",
+     "    test_loader = DataLoader(test_set, **val_params)\n",
+     "\n",
+     "    all_test_pred = []\n",
+     "\n",
+     "    def test():\n",
+     "        model.eval()\n",
+     "\n",
+     "        with torch.inference_mode():\n",
+     "            for _, data in tqdm(enumerate(test_loader, 0)):\n",
+     "                ids = data['ids'].to(DEVICE, dtype=torch.long)\n",
+     "                mask = data['mask'].to(DEVICE, dtype=torch.long)\n",
+     "                token_type_ids = data['token_type_ids'].to(DEVICE, dtype=torch.long)\n",
+     "                outputs = model(ids, mask, token_type_ids)\n",
+     "                probas = torch.sigmoid(outputs)\n",
+     "\n",
+     "                all_test_pred.append(probas)\n",
+     "\n",
+     "    test()\n",
+     "\n",
+     "    all_test_pred = torch.cat(all_test_pred)\n",
+     "\n",
+     "    label_columns = [\"toxic\", \"severe_toxic\", \"obscene\", \"threat\", \"insult\", \"identity_hate\"]\n",
+     "\n",
+     "    preds = all_test_pred.detach().cpu().numpy()[0]\n",
+     "\n",
+     "    final_dict = dict(zip(label_columns, preds))\n",
+     "    return final_dict\n",
+     "\n"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "id": "db651873-60cd-4cd7-8ba0-da6c62e22ca8",
+    "metadata": {},
+    "outputs": [],
+    "source": []
+   }
+  ],
+  "metadata": {
+   "kernelspec": {
+    "display_name": "Python 3 (ipykernel)",
+    "language": "python",
+    "name": "python3"
+   },
+   "language_info": {
+    "codemirror_mode": {
+     "name": "ipython",
+     "version": 3
+    },
+    "file_extension": ".py",
+    "mimetype": "text/x-python",
+    "name": "python",
+    "nbconvert_exporter": "python",
+    "pygments_lexer": "ipython3",
+    "version": "3.9.11"
+   }
+  },
+  "nbformat": 4,
+  "nbformat_minor": 5
+ }
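
The notebook imports `streamlit as st` without ever using it, which suggests a web front end was planned. A minimal sketch of what that UI might look like, assuming `give_toxic` can be imported from `app.py` (this sketch is not part of the commit):

```python
import streamlit as st

from app import give_toxic  # hypothetical import; requires app.py on the path

st.title("Toxic Comment Classifier")
text = st.text_area("Comment to score")
if st.button("Classify") and text:
    # give_toxic returns {label: probability} for the six toxicity labels.
    for label, proba in give_toxic(text).items():
        st.write(f"{label}: {proba:.3f}")
```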
requirements.txt ADDED
@@ -0,0 +1,3 @@
+ torch
+ pandas
+ transformers
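
Note that `app.py` also imports `tqdm` (in practice pulled in as a dependency of `transformers`) and the notebook imports `streamlit`, which is not listed here. A more self-contained set, offered as a suggestion rather than as part of the commit, might be:

```text
torch
pandas
transformers
tqdm
streamlit
```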