alex6095 committed on
Commit
355910d
•
1 Parent(s): f8d7d27

Create app.py

Files changed (1)
  1. app.py +113 -0
app.py ADDED
@@ -0,0 +1,113 @@
+ import torch
+ import torch.nn as nn
+ import re
+ import streamlit as st
+
+ from transformers import DistilBertModel
+ from tokenization_kobert import KoBertTokenizer
+
+ # Run on GPU when available; the checkpoint below is mapped to the same device.
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+
+ class SanctiMoly(nn.Module):
+     """ Holy Moly News BERT: predicts the publication year-month of a Korean news article. """
+
+     def __init__(self, bert_model, freeze_bert=True):
+         super(SanctiMoly, self).__init__()
+         self.encoder = bert_model
+         # Classification head over the [CLS] embedding: FC-BN-Tanh blocks,
+         # ending in 120 classes (one per month; see i2ym below).
+         self.linear = nn.Sequential(nn.Linear(768, 1024),
+                                     nn.BatchNorm1d(1024),
+                                     nn.Tanh(),
+                                     nn.Dropout(),
+                                     nn.Linear(1024, 768),
+                                     nn.BatchNorm1d(768),
+                                     nn.Tanh(),
+                                     nn.Dropout(),
+                                     nn.Linear(768, 120)
+                                     )
+         # self.softmax = nn.LogSoftmax(dim=-1)
+
+         # Optionally freeze the encoder so only the classification head is trained.
+         for param in self.encoder.parameters():
+             param.requires_grad = not freeze_bert
+
+     def forward(self, input_ids, input_length):
+         # Build the attention mask from the true (unpadded) sequence lengths.
+         attn_mask = torch.arange(input_ids.size(1)).to(device)
+         attn_mask = attn_mask[None, :] < input_length[:, None]
+
+         enc_o = self.encoder(input_ids, attn_mask)
+
+         # Classify from the hidden state of the first ([CLS]) token.
+         output = self.linear(enc_o.last_hidden_state[:, 0, :])
+         return output
+
+
+ @st.cache(allow_output_mutation=True)
+ def get_model():
+     bert_model = DistilBertModel.from_pretrained('monologg/distilkobert')
+     tokenizer = KoBertTokenizer.from_pretrained('monologg/distilkobert')
+
+     # The encoder is passed to the model explicitly so it is defined before use.
+     model = SanctiMoly(bert_model, freeze_bert=False)
+     checkpoint = torch.load("./model.pt", map_location=device)
+     model.load_state_dict(checkpoint['model_state_dict'])
+     model.to(device)
+     model.eval()  # inference mode for BatchNorm/Dropout
+
+     return model, tokenizer
+
+
+ model, tokenizer = get_model()
+
+
+ class RegexSubstitution(object):
+     """ Regex substitution class for transform """
+
+     def __init__(self, regex, sub=''):
+         if isinstance(regex, re.Pattern):
+             self.regex = regex
+         else:
+             self.regex = re.compile(regex)
+         self.sub = sub
+
+     def __call__(self, target):
+         # The pattern is applied twice so that nested parentheses are also removed.
+         if isinstance(target, list):
+             return [self.regex.sub(self.sub, self.regex.sub(self.sub, string)) for string in target]
+         else:
+             return self.regex.sub(self.sub, self.regex.sub(self.sub, target))
+
+
+ def i2ym(fl):
+     # Map a class index (months since January 2009) to (year, month) strings.
+     return (str(fl // 12 + 2009), str(fl % 12 + 1))
+
+
+ default_text = '''์งˆ๋ณ‘๊ด€๋ฆฌ์ฒญ์€ 23์ผ ์ง€๋ฐฉ์ž์น˜๋‹จ์ฒด๊ฐ€ ๋ณด๊ฑด๋‹น๊ตญ๊ณผ ํ˜‘์˜ ์—†์ด ๋‹จ๋…์œผ๋กœ ์ธํ”Œ๋ฃจ์—”์ž(๋…๊ฐ) ๋ฐฑ์‹  ์ ‘์ข… ์ค‘๋‹จ์„ ๊ฒฐ์ •ํ•ด์„œ๋Š” ์•ˆ ๋œ๋‹ค๋Š” ์ž…์žฅ์„ ๋ฐํ˜”๋‹ค.
+ ์งˆ๋ณ‘์ฒญ์€ ์ด๋‚  ์ฐธ๊ณ ์ž๋ฃŒ๋ฅผ ๋ฐฐํฌํ•˜๊ณ  โ€œํ–ฅํ›„ ์ „์ฒด ๊ตญ๊ฐ€ ์˜ˆ๋ฐฉ์ ‘์ข…์‚ฌ์—…์ด ์ฐจ์งˆ ์—†์ด ์ง„ํ–‰๋˜๋„๋ก ์ง€์ž์ฒด๊ฐ€ ์ž์ฒด์ ์œผ๋กœ ์ ‘์ข… ์œ ๋ณด ์—ฌ๋ถ€๋ฅผ ๊ฒฐ์ •ํ•˜์ง€ ์•Š๋„๋ก ์•ˆ๋‚ด๋ฅผ ํ–ˆ๋‹คโ€๊ณ  ์„ค๋ช…ํ–ˆ๋‹ค.
+ ๋…๊ฐ๋ฐฑ์‹ ์„ ์ ‘์ข…ํ•œ ํ›„ ๊ณ ๋ น์ธต์„ ์ค‘์‹ฌ์œผ๋กœ ์ „๊ตญ์—์„œ ์‚ฌ๋ง์ž๊ฐ€ ์ž‡๋”ฐ๋ฅด์ž ์„œ์šธ ์˜๋“ฑํฌ๊ตฌ๋ณด๊ฑด์†Œ๋Š” ์ „๋‚ , ๊ฒฝ๋ถ ํฌํ•ญ์‹œ๋Š” ์ด๋‚  ๊ด€๋‚ด ์˜๋ฃŒ๊ธฐ๊ด€์— ์ ‘์ข…์„ ๋ณด๋ฅ˜ํ•ด๋‹ฌ๋ผ๋Š” ๊ณต๋ฌธ์„ ๋‚ด๋ ค๋ณด๋ƒˆ๋‹ค. ์ด๋Š” ์˜ˆ๋ฐฉ์ ‘์ข…๊ณผ ์‚ฌ๋ง ๊ฐ„ ์ง์ ‘์  ์—ฐ๊ด€์„ฑ์ด ๋‚ฎ์•„ ์ ‘์ข…์„ ์ค‘๋‹จํ•  ์ƒํ™ฉ์€ ์•„๋‹ˆ๋ผ๋Š” ์งˆ๋ณ‘์ฒญ์˜ ํŒ๋‹จ๊ณผ๋Š” ๋‹ค๋ฅธ ๊ฒƒ์ด๋‹ค.
+ ์งˆ๋ณ‘์ฒญ์€ ์ง€๋‚œ 21์ผ ์ „๋ฌธ๊ฐ€ ๋“ฑ์ด ์ฐธ์—ฌํ•œ โ€˜์˜ˆ๋ฐฉ์ ‘์ข… ํ”ผํ•ด์กฐ์‚ฌ๋ฐ˜โ€™์˜ ๋ถ„์„ ๊ฒฐ๊ณผ๋ฅผ ๋ฐ”ํƒ•์œผ๋กœ ๋…๊ฐ ์˜ˆ๋ฐฉ์ ‘์ข… ์‚ฌ์—…์„ ์ผ์ •๋Œ€๋กœ ์ง„ํ–‰ํ•˜๊ธฐ๋กœ ํ–ˆ๋‹ค. ํŠนํžˆ ๊ณ ๋ น ์–ด๋ฅด์‹ ๊ณผ ์–ด๋ฆฐ์ด, ์ž„์‹ ๋ถ€ ๋“ฑ ๋…๊ฐ ๊ณ ์œ„ํ—˜๊ตฐ์€ ๋ฐฑ์‹ ์„ ์ ‘์ข…ํ•˜์ง€ ์•Š์•˜์„ ๋•Œ ํ•ฉ๋ณ‘์ฆ ํ”ผํ•ด๊ฐ€ ํด ์ˆ˜ ์žˆ๋‹ค๋ฉด์„œ ์ ‘์ข…์„ ๋…๋ คํ–ˆ๋‹ค. ํ•˜์ง€๋งŒ ์ ‘์ข…์‚ฌ์—… ์œ ์ง€ ๋ฐœํ‘œ ์ดํ›„์—๋„ ์‚ฌ๋ง ๋ณด๊ณ ๊ฐ€ ์ž‡๋”ฐ๋ฅด์ž ์งˆ๋ณ‘์ฒญ์€ ์ด๋‚  โ€˜์˜ˆ๋ฐฉ์ ‘์ข… ํ”ผํ•ด์กฐ์‚ฌ๋ฐ˜ ํšŒ์˜โ€™์™€ โ€˜์˜ˆ๋ฐฉ์ ‘์ข… ์ „๋ฌธ์œ„์›ํšŒโ€™๋ฅผ ๊ฐœ์ตœํ•ด ๋…๊ฐ๋ฐฑ์‹ ๊ณผ ์‚ฌ๋ง ๊ฐ„ ๊ด€๋ จ์„ฑ, ์ ‘์ข…์‚ฌ์—… ์œ ์ง€ ์—ฌ๋ถ€ ๋“ฑ์— ๋Œ€ํ•ด ๋‹ค์‹œ ๊ฒฐ๋ก  ๋‚ด๋ฆฌ๊ธฐ๋กœ ํ–ˆ๋‹ค. ํšŒ์˜ ๊ฒฐ๊ณผ๋Š” ์ด๋‚  ์˜คํ›„ 7์‹œ ๋„˜์–ด ๋ฐœํ‘œ๋  ์˜ˆ์ •์ด๋‹ค.
+ '''
+
+
+ st.title("Date prediction")
+ text = st.text_area("Input news :", value=default_text)
+ st.markdown("## Original News Data")
+ st.write(text)
+ st.markdown("## Predict Date")
+ col1, col2 = st.columns(2)
+ if text:
+     with st.spinner('processing..'):
+         # Strip parenthesised asides and special symbols before tokenizing.
+         text = RegexSubstitution(r'\([^()]+\)|[<>\'"△▲□■]')(text)
+         encoded_dict = tokenizer(
+             text=[text],
+             add_special_tokens=True,
+             max_length=512,
+             truncation=True,
+             return_tensors='pt',
+             return_length=True
+         )
+         input_ids = encoded_dict['input_ids'].to(device)
+         input_ids_len = encoded_dict['length'].to(device)
+
+         with torch.no_grad():
+             pred = model(input_ids, input_ids_len)
+
+         # Show the three most likely year-month predictions.
+         _, indices = torch.topk(pred, 3)
+         pred_print = []
+         for i in indices.squeeze(0):
+             year, month = i2ym(i.item())
+             pred_print.append(year + "-" + month)
+         st.write(", ".join(pred_print))