klue_bert_layered / homework.py
winninglife's picture
initial commit
c344d9f
raw
history blame
3.55 kB
import streamlit as st
from pyparsing import empty
from mymodel import CustomBertForSequenceClassification
from transformers import BertTokenizer, Trainer
# Integer class indices for the three sentiment categories.
POSITIVE, NEGATIVE, NEUTRAL = range(3)

# Lookup from a class index to its lowercase label.
idx_target = {
    POSITIVE: 'positive',
    NEGATIVE: 'negative',
    NEUTRAL: 'neutral',
}

# Module-level state; tab1_page rebinds both names through `global`.
g_selected_model_type = None
g_input_text = ' '
def get_model_type(select_model, add_layers):
    """Map a model selection and a layer flag to a pair of strings.

    Args:
        select_model: Model family name; "BERT" and "RoBERTa" are recognised.
        add_layers: Flag choosing the "_layered" variant. Any truthy value
            selects it and any falsy value (including None) does not.

    Returns:
        A two-element tuple: the base model string (for example
        "klue/bert-base") and the variant string (for example
        "bert_layered"). Returns None when ``select_model`` is not
        recognised.
    """
    if select_model == "BERT":
        base_name, variant = "klue/bert-base", "bert"
    elif select_model == "RoBERTa":
        base_name, variant = "klue/roberta-base", "roberta"
    else:
        # Unknown family: keep the historical "no result" value explicit.
        return None

    # Use truthiness rather than `== True` / `== False`, so non-bool flags
    # do not slip past every branch.
    if add_layers:
        variant += "_layered"
    return base_name, variant
def show_predict_result(model_type, input_text):
    """Show the response widgets for a prediction.

    Placeholder implementation: ``model_type`` and ``input_text`` are
    accepted but not used, and the result is always ``POSITIVE``. The
    result is passed to show_response_img and then show_response_text.
    """
    predicted = POSITIVE
    for render in (show_response_img, show_response_text):
        render(predicted)
def show_response_img(target=None):
    """Write the caption that corresponds to ``target`` into ``st.con3``.

    ``st.con3`` must already exist as an attribute on the streamlit module;
    tab1_page assigns it.

    Args:
        target: One of POSITIVE, NEGATIVE or NEUTRAL, or None to write an
            empty string. Any other value writes nothing.
    """
    # `is None` is the identity test for the singleton; `== None` can be
    # fooled by objects that override __eq__.
    if target is None:
        caption = ""
    elif target == POSITIVE:
        caption = "μ›ƒλŠ” ν‘œμ •"
    elif target == NEGATIVE:
        caption = "μš°λŠ” ν‘œμ •"
    elif target == NEUTRAL:
        caption = "λ¬΄λ€λ€ν•œ ν‘œμ •"
    else:
        return
    st.con3.write(caption)
def show_response_text(target=None):
    """Write the message that corresponds to ``target`` into ``st.con4``.

    ``st.con4`` must already exist as an attribute on the streamlit module;
    tab1_page assigns it.

    Args:
        target: One of POSITIVE, NEGATIVE or NEUTRAL, or None to write an
            empty string. Any other value writes nothing.
    """
    # `is None` is the identity test for the singleton; `== None` can be
    # fooled by objects that override __eq__.
    if target is None:
        message = ""
    elif target == POSITIVE:
        message = "긍정적인 λ°˜μ‘"
    elif target == NEGATIVE:
        message = "뢀정정인 λ°˜μ‘"
    elif target == NEUTRAL:
        message = "λ¬΄λ€λ€ν•œ λ°˜μ‘"
    else:
        return
    st.con4.write(message)
def show_data():
    """Write the literal string "data" with ``st.write``."""
    st.write("data")
def show_api_usage():
    """Write the API usage text inside a bordered Streamlit container."""
    bordered = st.container(border=True)
    with bordered:
        st.write("api μ‚¬μš©λ²•")
def tab1_page():
    """Build the first tab: model picker, text input and response areas.

    Creates three rows of columns, stores five of the six columns as
    attributes on the ``st`` module, and fills each with a bordered
    container. Rebinds the module globals ``g_selected_model_type`` and
    ``g_input_text``.
    """
    global g_selected_model_type
    global g_input_text
    # The columns are attached to the `st` module itself so that
    # show_response_img / show_response_text can reach st.con3 / st.con4.
    st.con1,st.con2 = st.columns([0.3,0.7])
    st.con3,st.con4 = st.columns([0.3,0.7])
    st.con5,empty1 = st.columns([0.9999,0.0001])
    with st.container():
        with st.con1:
            with st.con1.container(border=True):
                selected_model = st.selectbox("예츑 λͺ¨λΈμ„ μ„ νƒν•˜μ„Έμš”.", ["BERT", "RoBERTa"])
                add_layers = st.checkbox('Layer μΆ”κ°€')
                g_selected_model_type = get_model_type(selected_model, add_layers)
                # Called on every execution of this function, using whatever
                # text is currently stored in g_input_text.
                show_predict_result(g_selected_model_type, g_input_text)
        with st.con2:
            with st.con2.container(border=True):
                input_text = st.text_area("input_text")
                submit_button = st.button('확인')
                # Only store the text and show a result when the button
                # reports a click.
                if submit_button :
                    g_input_text = input_text
                    show_predict_result(g_selected_model_type, g_input_text)
        with st.con3:
            with st.con3.container(border=True):
                # No argument: target defaults to None, which writes "".
                show_response_img()
        with st.con4:
            with st.con4.container(border=True):
                # No argument: target defaults to None, which writes "".
                show_response_text()
        with st.con5:
            with st.con5.container(border=True):
                show_data()
        with empty1:
            # NOTE(review): `empty` is imported from pyparsing, not Streamlit,
            # and its return value is discarded — confirm this call is needed.
            empty()
def tab2_page():
    """Build the second tab by delegating to show_api_usage."""
    show_api_usage()
#tokenizer = BertTokenizer()
#model = CustomBertForSequenceClassification()
#trainer = Trainer()
# Load the tokenizer and the model from "bert_tokenizer_layered" and
# "bert_model_layered". This runs at module level, i.e. every time the
# script body is executed.
tokenizer = BertTokenizer.from_pretrained("bert_tokenizer_layered")
model = CustomBertForSequenceClassification.from_pretrained("bert_model_layered")
# NOTE(review): Trainer() is constructed with no arguments and `load_model`
# is then called on it — verify both against the installed transformers
# version. `tokenizer`, `model` and `trainer` are not referenced by any
# function shown in this file. TODO confirm.
trainer = Trainer().load_model("bert_trainer_layered")
st.title("Semi Project - Sentiment analysis")
#st.subheader("일상 λŒ€ν™”λ₯Ό μž…λ ₯ν•˜λ©΄ ν•΄λ‹Ή λ¬Έμž₯이 λ‚˜νƒ€λ‚΄λŠ” λŠλ‚Œμ„ ν‘œμ‹œν•©λ‹ˆλ‹€.")
# Two tabs: the first hosts the interactive page, the second the API page.
tab1, tab2 = st.tabs(['λ™μž‘ 확인', 'API'])
with tab1:
    tab1_page()
with tab2:
    tab2_page()