import re

import streamlit as st
import torch
import torch.nn as nn
from transformers import DistilBertModel

from tokenization_kobert import KoBertTokenizer
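# KoBertTokenizer is not part of transformers; tokenization_kobert.py is a
# local module bundled with this Space (it accompanies the monologg
# DistilKoBERT checkpoints loaded below).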
class SanctiMoly(nn.Module):
    """Holy Moly News BERT: a DistilKoBERT encoder topped with an FC-BN-Tanh
    head that classifies an article into one of 120 year-month classes."""

    def __init__(self, bert_model, freeze_bert=True):
        super().__init__()
        self.encoder = bert_model
        # FC-BN-Tanh classification head over the [CLS] representation
        self.linear = nn.Sequential(
            nn.Linear(768, 1024),
            nn.BatchNorm1d(1024),
            nn.Tanh(),
            nn.Dropout(),
            nn.Linear(1024, 768),
            nn.BatchNorm1d(768),
            nn.Tanh(),
            nn.Dropout(),
            nn.Linear(768, 120),
        )
        # self.softmax = nn.LogSoftmax(dim=-1)

        # Optionally freeze the encoder so only the head is trainable
        for param in self.encoder.parameters():
            param.requires_grad = not freeze_bert
    def forward(self, input_ids, input_length):
        # Build the attention mask from the true (unpadded) lengths:
        # position j of sequence i is attended to iff j < input_length[i]
        attn_mask = torch.arange(input_ids.size(1))
        attn_mask = attn_mask[None, :] < input_length[:, None]
        enc_o = self.encoder(input_ids, attn_mask)
        # Classify from the hidden state of the first ([CLS]) token
        output = self.linear(enc_o.last_hidden_state[:, 0, :])
        return output
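# Illustration of the mask built in forward() (toy values, not app data):
# with a padded batch of length 6 and input_length = torch.tensor([5, 3]),
# torch.arange(6)[None, :] < input_length[:, None] yields
#   [[True, True, True, True, True, False],
#    [True, True, True, False, False, False]]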
def get_model():
    # Pretrained DistilKoBERT encoder and matching tokenizer
    bert_model = DistilBertModel.from_pretrained('alex6095/SanctiMolyOH_Cpu')
    tokenizer = KoBertTokenizer.from_pretrained('monologg/distilkobert')

    model = SanctiMoly(bert_model, freeze_bert=False)
    device = torch.device('cpu')
    # Load the fine-tuned weights shipped with this Space
    checkpoint = torch.load("./model.pt", map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    return model, tokenizer
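# Note: Streamlit reruns the whole script on every interaction, so the
# checkpoint is reloaded each time; decorating get_model with st.cache
# (e.g. @st.cache(allow_output_mutation=True)) would avoid that.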
model, tokenizer = get_model()
class RegexSubstitution(object):
    """Regex substitution transform.

    The pattern is applied twice per call, so one level of nested matches
    (e.g. parentheses inside parentheses) is removed as well."""

    def __init__(self, regex, sub=''):
        if isinstance(regex, re.Pattern):
            self.regex = regex
        else:
            self.regex = re.compile(regex)
        self.sub = sub

    def __call__(self, target):
        if isinstance(target, list):
            return [self.regex.sub(self.sub, self.regex.sub(self.sub, string))
                    for string in target]
        else:
            return self.regex.sub(self.sub, self.regex.sub(self.sub, target))
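# Illustration of the double substitution (hypothetical input, not app data):
#   RegexSubstitution(r'\([^()]+\)')('a(b(c))')
#   pass 1 removes the innermost group -> 'a(b)'; pass 2 -> 'a'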
def i2ym(fl):
    # Map a class index (0..119) to a (year, month) string pair;
    # index 0 corresponds to January 2009.
    return (str(fl // 12 + 2009), str(fl % 12 + 1))
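# e.g. i2ym(0) -> ('2009', '1') and i2ym(119) -> ('2018', '12'),
# so the 120 classes span January 2009 through December 2018.

# Default example input: a news article on the Constitutional Court's
# unanimous removal of President Park Geun-hye (10 March 2017).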
default_text = '''헌법재판소가 박근혜 대통령의 파면을 만장일치로 결정했다. 현직 대통령 탄핵이 인용된 것은 헌정 사상 최초다. 박 전 대통령에 대한 파면이 결정되면서 헌법과 공직선거법에 따라 앞으로 60일 이내에 차기 대통령 선거가 치러진다.
이정미 헌재소장 권한대행(재판관)은 10일 오전 11시 23분 서울 종로구 헌법재판소 대심판정에서 “피청구인 대통령 박근혜를 파면한다”고 주문을 선고했다. 그 순간 대심판정 곳곳에서 무겁고 나직한 탄성이 터져 나왔다. 이날 대심판정에는 박근혜 전 대통령 측과 국회 소추위원 측 관계자들과 취재진 80명, 온라인 접수를 통해 795대 1의 경쟁률을 뚫고 선정된 일반 방청객 24명이 숨을 죽이고 있었다.
'''
st.title("Date prediction")
text = st.text_area("Input news :", value=default_text)
st.markdown("## Original News Data")
st.write(text)
st.markdown("## Predict Top 3 Date")
if text:
    with st.spinner('processing...'):
        # Strip parenthesized asides and special symbols before tokenizing
        text = RegexSubstitution(r'\([^()]+\)|[<>\'"△▲□■]')(text)
        encoded_dict = tokenizer(
            text=[text],
            add_special_tokens=True,
            max_length=512,
            truncation=True,
            return_tensors='pt',
            return_length=True
        )
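        # encoded_dict is a BatchEncoding; 'input_ids' has shape (1, seq_len)
        # and, because return_length=True was passed, 'length' holds the
        # unpadded token count that forward() uses for the attention mask.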
        input_ids = encoded_dict['input_ids']
        input_ids_len = encoded_dict['length']

        pred = model(input_ids, input_ids_len)
        # Take the three highest-scoring of the 120 year-month classes
        _, indices = torch.topk(pred, 3)
        pred_print = []
        for i in indices.squeeze(0):
            year, month = i2ym(i.item())
            pred_print.append(year + "-" + month)
        st.write(", ".join(pred_print))