Update app.py
Browse files
app.py
CHANGED
@@ -3,14 +3,33 @@ import torch
|
|
3 |
|
4 |
@st.cache
|
5 |
def Model():
|
6 |
-
from transformers import
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
bn_state_dict = torch.load('model_w.pt')
|
11 |
model.load_state_dict(bn_state_dict)
|
12 |
-
return model
|
13 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
14 |
st.title('Классификация статьи по названию и описанию')
|
15 |
# ^-- можно показывать пользователю текст, картинки, ограниченное подмножество html - всё как в jupyter
|
16 |
|
@@ -19,13 +38,13 @@ title = st.text_area("Введите название статьи:")
|
|
19 |
abstract = st.text_area("Введите описание статьи:")
|
20 |
# ^-- показать текстовое поле. В поле text лежит строка, которая находится там в данный момент
|
21 |
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
|
|
|
|
|
|
|
|
26 |
|
27 |
-
st.markdown(f"{raw_predictions}")
|
28 |
|
29 |
-
model = Model()
|
30 |
-
st.markdown(f"{model}")
|
31 |
-
# выводим результаты модели в текстовое поле, на потеху пользователю
|
|
|
3 |
|
4 |
@st.cache
|
5 |
def Model():
|
6 |
+
from transformers import DebertaTokenizer, DebertaForSequenceClassification
|
7 |
+
tokenizer = DebertaTokenizer.from_pretrained("microsoft/deberta-base")
|
8 |
+
model = DebertaForSequenceClassification.from_pretrained("microsoft/deberta-base", num_labels=8)
|
9 |
+
bn_state_dict = torch.load('model_weights.pt')
|
|
|
10 |
model.load_state_dict(bn_state_dict)
|
11 |
+
return model, tokenizer
|
12 |
|
13 |
+
def Predict(model, tokenizer, text):
|
14 |
+
res = tokenizer(s, padding=True, truncation=True, return_tensors="pt", max_length=512)
|
15 |
+
#var.to("cuda:0")
|
16 |
+
res = model(**res)
|
17 |
+
logits = res.logits.softmax(dim=1)
|
18 |
+
logits = logits.numpy()[0]#logits.cpu().detach().numpy()[0]
|
19 |
+
return logits
|
20 |
+
|
21 |
+
def Print(logits, dictionary):
|
22 |
+
z = zip(logits, np.arange(0, 8))
|
23 |
+
z = sorted(z, key=lambda x: x[0], reverse=True)
|
24 |
+
sum, idx = 0, 0
|
25 |
+
while sum < 0.95:
|
26 |
+
st.markdown(f"{idx + 1}. ", dictionary[z[idx][1]])
|
27 |
+
sum += z[idx][0]
|
28 |
+
idx += 1
|
29 |
+
|
30 |
+
def filter(title, abstract):
|
31 |
+
return True
|
32 |
+
|
33 |
st.title('Классификация статьи по названию и описанию')
|
34 |
# ^-- можно показывать пользователю текст, картинки, ограниченное подмножество html - всё как в jupyter
|
35 |
|
|
|
38 |
abstract = st.text_area("Введите описание статьи:")
|
39 |
# ^-- показать текстовое поле. В поле text лежит строка, которая находится там в данный момент
|
40 |
|
41 |
+
text = title + '. ' + abstract
|
42 |
+
dictionary = ['computer science', 'economics', 'Electrical Engineering and Systems Science',
|
43 |
+
'math', 'physics', 'quantitative biology', 'quantitative finance',
|
44 |
+
'statistics']
|
45 |
+
if filter(title, abstract):
|
46 |
+
model, tokenizer = Model()
|
47 |
+
logits = Predict(model, tokenizer, text)
|
48 |
+
Print(logits, dictionary)
|
49 |
|
|
|
50 |
|
|
|
|
|
|