stal76 commited on
Commit
14ca625
1 Parent(s): 80b88cf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +65 -1
app.py CHANGED
@@ -3,6 +3,70 @@ import torch
3
  import torch.nn as nn
4
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  st.markdown("### Hello, world!")
7
  st.markdown("<img width=200px src='https://rozetked.me/images/uploads/dwoilp3BVjlE.jpg'>", unsafe_allow_html=True)
8
  # ^-- можно показывать пользователю текст, картинки, ограниченное подмножество html - всё как в jupyter
@@ -17,4 +81,4 @@ abstract = st.text_area("Abstract")
17
  #raw_predictions = pipe(text)
18
  # тут уже знакомый вам код с huggingface.transformers -- его можно заменить на что угодно от fairseq до catboost
19
 
20
- st.markdown(f"{title + ' ' + abstract}")
 
3
  import torch.nn as nn
4
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
5
 
6
+ class Net(nn.Module):
7
+ def __init__(self):
8
+ super(Net,self).__init__()
9
+ self.layer = nn.Sequential(
10
+ nn.Linear(768, 512),
11
+ nn.ReLU(),
12
+ nn.Linear(512, 256),
13
+ nn.ReLU(),
14
+ nn.Linear(256, 128),
15
+ nn.ReLU(),
16
+ nn.Linear(128, 8),
17
+ )
18
+
19
+ def forward(self,x):
20
+ return self.layer(x)
21
+
22
+ model = Net()
23
+ model.load_state_dict(torch.load('model.dat'))
24
+ tokenizer = AutoTokenizer.from_pretrained("Callidior/bert2bert-base-arxiv-titlegen")
25
+ model_emb = AutoModelForSeq2SeqLM.from_pretrained("Callidior/bert2bert-base-arxiv-titlegen")
26
+
27
+ def BuildAnswer(txt):
28
+ def get_hidden_states(encoded, model):
29
+ with torch.no_grad():
30
+ output = model(decoder_input_ids=encoded['input_ids'], output_hidden_states=True, **encoded)
31
+
32
+ layers = [-4, -3, -2, -1]
33
+ states = output['decoder_hidden_states']
34
+ output = torch.stack([states[i] for i in layers]).sum(0).squeeze()
35
+
36
+ return output.mean(dim=0)
37
+
38
+ def get_word_vector(sent, tokenizer, model):
39
+ encoded = tokenizer.encode_plus(sent, return_tensors="pt", truncation=True)
40
+ return get_hidden_states(encoded, model)
41
+
42
+ labels_articles = {
43
+ 1: 'Computer Science',
44
+ 2: 'Economics',
45
+ 3: "Electrical Engineering And Systems Science",
46
+ 4: "Mathematics",
47
+ 5: "Physics",
48
+ 6: "Quantitative Biology",
49
+ 7: "Quantitative Finance",
50
+ 8: "Statistics"
51
+ }
52
+
53
+ embed = get_word_vector(txt, tokenizer, model_emb)
54
+ logits = torch.nn.functional.softmax(model(embed), dim=0)
55
+ best_tags = torch.argsort(logits, descending=True)
56
+
57
+ sum = 0
58
+ result = []
59
+ for tag in best_tags:
60
+ if sum > 0.95:
61
+ break
62
+ sum += logits[tag.item()]
63
+ res = round(float(logits[tag.item()].cpu()) * 100)
64
+ label = labels_articles[tag.item() + 1]
65
+ result.append(f'{res:3d}% - {label}')
66
+ return '\n'.join(result)
67
+
68
+
69
+
70
  st.markdown("### Hello, world!")
71
  st.markdown("<img width=200px src='https://rozetked.me/images/uploads/dwoilp3BVjlE.jpg'>", unsafe_allow_html=True)
72
  # ^-- можно показывать пользователю текст, картинки, ограниченное подмножество html - всё как в jupyter
 
81
  #raw_predictions = pipe(text)
82
  # тут уже знакомый вам код с huggingface.transformers -- его можно заменить на что угодно от fairseq до catboost
83
 
84
+ st.markdown(f"{BuildAnswer(title + ' ' + abstract)}")