Spaces:
Runtime error
Runtime error
Upload app.py
Browse files
app.py
CHANGED
@@ -1,4 +1,80 @@
|
|
1 |
import streamlit as st
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
|
3 |
st.markdown("### Hello, world!")
|
4 |
st.markdown("<img width=200px src='https://rozetked.me/images/uploads/dwoilp3BVjlE.jpg'>", unsafe_allow_html=True)
|
@@ -12,4 +88,4 @@ text = st.text_area("TEXT HERE")
|
|
12 |
# raw_predictions = pipe(text)
|
13 |
# тут уже знакомый вам код с huggingface.transformers -- его можно заменить на что угодно от fairseq до catboost
|
14 |
|
15 |
-
st.markdown("It's prediction:
|
|
|
1 |
import streamlit as st
|
2 |
+
from torch.utils.data import Dataset, DataLoader
|
3 |
+
import torch
|
4 |
+
from sklearn.model_selection import train_test_split
|
5 |
+
from transformers import get_linear_schedule_with_warmup, AdamW
|
6 |
+
from torch.cuda.amp import autocast, GradScaler
|
7 |
+
from transformers import DistilBertForSequenceClassification, DistilBertTokenizer, \
|
8 |
+
BigBirdPegasusForSequenceClassification, BigBirdTokenizer
|
9 |
+
from transformers import pipeline
|
10 |
+
from torch.utils.data import TensorDataset, random_split, DataLoader, RandomSampler, SequentialSampler
|
11 |
+
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
|
12 |
+
import streamlit as st
|
13 |
+
|
14 |
+
import pandas as pd
|
15 |
+
import json
|
16 |
+
import ast
|
17 |
+
from scipy import stats
|
18 |
+
import numpy as np
|
19 |
+
import time
|
20 |
+
import datetime
|
21 |
+
|
22 |
+
#
|
23 |
+
def get_top95(y_predict, convert_target, threshold=0.95):
    """Return the most probable labels until their cumulative probability passes ``threshold``.

    Classes are visited from the highest to the lowest probability. Each
    visited class contributes its label, and collection stops right after
    the running sum becomes strictly greater than ``threshold`` (the label
    that crosses the threshold is included).

    Args:
        y_predict: Indexable batch of probability rows; only the first row
            (``y_predict[0]``) is used.
        convert_target: Mapping from the stringified class index
            (``"0"``, ``"1"``, ...) to the label to report.
        threshold: Cumulative probability at which collection stops.
            Defaults to 0.95, the previously hard-coded value.

    Returns:
        List of labels ordered from most to least probable. If the running
        sum never exceeds ``threshold``, every label is returned; an empty
        row yields an empty list.

    Raises:
        KeyError: If a class index has no entry in ``convert_target``.
    """
    # sorted() is stable, so classes with equal probability keep index order.
    ranked = sorted(enumerate(y_predict[0]), key=lambda pair: pair[1], reverse=True)

    labels = []
    cumulative = 0
    for class_idx, prob in ranked:
        cumulative += prob
        labels.append(convert_target[str(class_idx)])
        if cumulative > threshold:
            break
    return labels
|
35 |
+
#
|
36 |
+
# model = MyModel()
|
37 |
+
# Load the serialized model object onto the CPU and switch it to eval mode.
# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code -- only ship a checkpoint from a trusted source.
model = torch.load("distilbert-model1.pt", map_location='cpu').eval()
|
38 |
+
# print(model)
|
39 |
+
# model = DistilBertForSequenceClassification.from_pretrained("model/distilbert-model1.pt", local_files_only=True)
|
40 |
+
# tokenizer = BigBirdTokenizer.from_pretrained('google/bigbird-pegasus-large-arxiv')
|
41 |
+
|
42 |
+
# model = BigBirdPegasusForSequenceClassification.from_pretrained('google/bigbird-pegasus-large-arxiv',
|
43 |
+
# num_labels=8,
|
44 |
+
# return_dict=False)
|
45 |
+
|
46 |
+
def get_predict(text):
    """Classify ``text`` and return the labels covering the top 95% of probability.

    The text is tokenized with the ``distilbert-base-cased`` tokenizer, passed
    through the module-level ``model``, and the softmax of the logits is
    decoded into labels via ``get_top95`` using the index-to-label mapping
    stored in ``decode_target (1).json``.

    Args:
        text: Raw input text to classify.

    Returns:
        The list of labels produced by ``get_top95`` for this text.

    Raises:
        OSError: If the label-mapping JSON file cannot be opened.
    """
    tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-cased')

    # Truncate long inputs; 512 mirrors the max_length used by the earlier
    # encode_plus settings (presumably the model's positional limit -- confirm).
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=512,
    )

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        outputs = model(
            input_ids=inputs['input_ids'],
            attention_mask=inputs['attention_mask'],
        )
    logits = outputs[0]

    # Normalize over the class dimension explicitly; omitting ``dim`` is deprecated.
    y_predict = torch.nn.functional.softmax(logits, dim=-1).cpu().numpy()

    file_path = "decode_target (1).json"
    with open(file_path, 'r', encoding="utf-8") as json_file:
        decode_target = json.load(json_file)

    # Return the labels (instead of only printing them) so the UI can show them.
    return get_top95(y_predict, decode_target)
|
70 |
+
#
|
71 |
+
#
|
72 |
+
#
|
73 |
+
#
|
74 |
+
#
|
75 |
+
# get_predict('''physics physics physics physics physics
|
76 |
+
# physics physics physics physics''')
|
77 |
+
#
|
78 |
|
79 |
st.markdown("### Hello, world!")
|
80 |
st.markdown("<img width=200px src='https://rozetked.me/images/uploads/dwoilp3BVjlE.jpg'>", unsafe_allow_html=True)
|
|
|
88 |
# raw_predictions = pipe(text)
|
89 |
# here goes the huggingface.transformers code you already know -- it can be replaced with anything from fairseq to catboost
|
90 |
|
91 |
+
# f-string so the result of get_predict(text) is interpolated into the
# message instead of the placeholder being rendered literally.
st.markdown(f"It's prediction: {get_predict(text)}")
|