# alex6095's picture
# Update app.py
# 620a618
import torch
import torch.nn as nn
import re
import streamlit as st
from transformers import DistilBertModel
from tokenization_kobert import KoBertTokenizer
class SanctiMoly(nn.Module):
    """Holy Moly News BERT: an encoder followed by a 120-way scoring head.

    The encoder's hidden state at sequence position 0 is passed through a
    768 -> 1024 -> 768 -> 120 stack of Linear / BatchNorm / Tanh / Dropout
    layers. No softmax is applied; ``forward`` returns raw scores.

    Args:
        bert_model: Encoder module. It is called positionally as
            ``bert_model(input_ids, attention_mask)`` and must return an
            object exposing ``last_hidden_state`` with a last dimension of 768.
        freeze_bert: When true, the encoder parameters are excluded from
            gradient computation; otherwise they are made trainable.
    """

    def __init__(self, bert_model, freeze_bert=True):
        super().__init__()
        self.encoder = bert_model
        # FC-BN-Tanh(-Dropout) twice, then the final projection to 120 scores.
        self.linear = nn.Sequential(
            nn.Linear(768, 1024),
            nn.BatchNorm1d(1024),
            nn.Tanh(),
            nn.Dropout(),
            nn.Linear(1024, 768),
            nn.BatchNorm1d(768),
            nn.Tanh(),
            nn.Dropout(),
            nn.Linear(768, 120),
        )
        # One pass sets the flag either way, so a previously frozen encoder
        # is also un-frozen when freeze_bert is false.
        trainable = not freeze_bert
        for param in self.encoder.parameters():
            param.requires_grad = trainable

    def forward(self, input_ids, input_length):
        """Score a batch of token sequences.

        Args:
            input_ids: Token-id tensor indexed as (batch, sequence position).
            input_length: 1-D tensor with the number of valid tokens in each
                row; positions at or beyond that length are masked out.

        Returns:
            Tensor of raw scores with 120 entries per batch row.
        """
        # Build the mask on the same device as the inputs; a bare
        # torch.arange() lands on the default device and breaks the
        # comparison as soon as the batch lives anywhere else.
        device = input_ids.device
        positions = torch.arange(input_ids.size(1), device=device)
        attn_mask = positions[None, :] < input_length.to(device)[:, None]
        enc_o = self.encoder(input_ids, attn_mask)
        # Only the hidden state of the first token feeds the head.
        return self.linear(enc_o.last_hidden_state[:, 0, :])
@st.cache(allow_output_mutation=True)
def get_model():
    """Build the classifier and its tokenizer for CPU inference.

    The encoder weights come from the 'alex6095/SanctiMolyOH_Cpu' checkpoint
    and the tokenizer from 'monologg/distilkobert'. The full model state is
    then restored from the 'model_state_dict' entry of ./model.pt, and the
    model is switched to evaluation mode.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    encoder = DistilBertModel.from_pretrained('alex6095/SanctiMolyOH_Cpu')
    tokenizer = KoBertTokenizer.from_pretrained('monologg/distilkobert')
    classifier = SanctiMoly(encoder, freeze_bert=False)
    # NOTE(review): torch.load unpickles the file - only ship trusted checkpoints.
    saved = torch.load("./model.pt", map_location=torch.device('cpu'))
    classifier.load_state_dict(saved['model_state_dict'])
    classifier.eval()
    return classifier, tokenizer
# Module-level handles used by the page code further down: the classifier
# (already in eval mode) and the tokenizer returned by get_model().
model, tokenizer = get_model()
class RegexSubstitution:
    """Callable text transform that replaces regex matches, applied twice.

    Args:
        regex: A pattern string or an already compiled ``re.Pattern``.
        sub: Replacement passed to ``Pattern.sub`` (defaults to removal).

    Calling the instance with a string returns the transformed string;
    calling it with a ``list`` returns a new list with every element
    transformed. Any other argument is handed to ``Pattern.sub`` as-is.
    """

    def __init__(self, regex, sub=''):
        self.regex = regex if isinstance(regex, re.Pattern) else re.compile(regex)
        self.sub = sub

    def _substitute_twice(self, string):
        # Two passes: presumably so that matches exposed by the first pass
        # (e.g. the outer pair of nested brackets) are removed as well.
        first_pass = self.regex.sub(self.sub, string)
        return self.regex.sub(self.sub, first_pass)

    def __call__(self, target):
        if not isinstance(target, list):
            return self._substitute_twice(target)
        return list(map(self._substitute_twice, target))
def i2ym(fl):
    """Turn a month index into ``(year, month)`` strings.

    Index 0 corresponds to ('2009', '1'); every 12 steps advance the year.
    The month is 1-based and not zero-padded.
    """
    years_elapsed, month_offset = divmod(fl, 12)
    return str(years_elapsed + 2009), str(month_offset + 1)
# Sample Korean news article pre-filled in the input box. The literal had been
# corrupted (UTF-8 bytes decoded through a single-byte code page); this is the
# restored text. Keep this file saved as UTF-8.
default_text = '''헌법재판소가 박근혜 대통령의 파면을 만장일치로 결정했다. 현직 대통령 탄핵이 인용된 것은 헌정 사상 최초다. 박 전 대통령에 대한 파면이 결정되면서 헌법과 공직선거법에 따라 앞으로 60일 이내에 차기 대통령 선거가 치러진다.
이정미 헌재소장 권한대행(재판관)은 10일 오전 11시 23분 서울 종로구 헌법재판소 대심판정에서 “피청구인 대통령 박근혜를 파면한다”고 주문을 선고했다. 그 순간 대심판정 곳곳에서 무겁고 나직한 탄성이 터져 나왔다. 이날 대심판정에선 박근혜 전 대통령 측과 국회소추위원 측 관계자들과 취재진 80명, 온라인 접수를 통해 795대 1의 경쟁률을 뚫고 선정된 일반방청객 24명이 숨을 죽이고 있었다.
'''
# ---- Streamlit page: read an article, show it, predict its three most likely months ----
st.title("Date prediction")
text = st.text_area("Input news :", value=default_text)
st.markdown("## Original News Data")
st.write(text)
st.markdown("## Predict Top 3 Date")
if text:
    with st.spinner('processing..'):
        # Drop parenthesised asides plus quote/angle-bracket characters and
        # the bullet glyphs U+25B3, U+25B2, U+25A1, U+25A0 (white/black
        # up-pointing triangle, white/black square). The glyphs are written as
        # \u escapes, which the re module resolves itself, so the pattern
        # cannot be damaged by a mis-encoded source file again.
        text = RegexSubstitution(
            r'\([^()]+\)|[<>\'"\u25b3\u25b2\u25a1\u25a0]'
        )(text)
        encoded_dict = tokenizer(
            text=[text],
            add_special_tokens=True,
            max_length=512,
            truncation=True,
            return_tensors='pt',
            return_length=True
        )
        input_ids = encoded_dict['input_ids']
        input_ids_len = encoded_dict['length']
        # Inference only: skip autograd bookkeeping (the encoder parameters
        # are trainable, so a graph would otherwise be recorded on each run).
        with torch.no_grad():
            pred = model(input_ids, input_ids_len)
        _, indices = torch.topk(pred, 3)
        # Each predicted index maps to a (year, month) pair shown as "YYYY-M".
        pred_print = ["-".join(i2ym(i.item())) for i in indices.squeeze(0)]
    st.write(", ".join(pred_print))