File size: 3,548 Bytes
c344d9f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import streamlit as st
from pyparsing import empty
from mymodel import CustomBertForSequenceClassification
from transformers import BertTokenizer, Trainer

# Integer class indices for the three sentiment categories.
POSITIVE, NEGATIVE, NEUTRAL = range(3)

# Lookup from a class index to its lowercase English label.
idx_target = dict(
    zip(
        (POSITIVE, NEGATIVE, NEUTRAL),
        ('positive', 'negative', 'neutral'),
    )
)

# Module-level UI state; both names are reassigned inside tab1_page().
# Holds the get_model_type() result for the current widget selection.
g_selected_model_type = None
# Holds the text taken from the input area when the confirm button is pressed.
g_input_text = ' '

def get_model_type(select_model, add_layers):
    """Map the model-picker values to a checkpoint name and a variant tag.

    Args:
        select_model: Model family name; "BERT" and "RoBERTa" are recognised.
        add_layers: Truthy to select the "_layered" variant of the family,
            falsy for the plain variant.

    Returns:
        A ``(checkpoint_name, variant_tag)`` tuple, for example
        ``("klue/bert-base", "bert_layered")``, or ``None`` when
        ``select_model`` is not a recognised family.
    """
    if select_model == "BERT":
        checkpoint, tag = "klue/bert-base", "bert"
    elif select_model == "RoBERTa":
        checkpoint, tag = "klue/roberta-base", "roberta"
    else:
        # Unknown family: keep returning None so existing callers that check
        # for it continue to work.
        return None

    # Plain truthiness instead of `== True` / `== False`, so a non-bool flag
    # (None, 0/1, ...) still selects a variant rather than falling through.
    if add_layers:
        tag = tag + "_layered"
    return checkpoint, tag

def show_predict_result(model_type, input_text):
    """Render the prediction image caption and text for a request.

    NOTE(review): ``model_type`` and ``input_text`` are not read yet; the
    class passed to both renderers is fixed to POSITIVE.

    Args:
        model_type: Currently unused.
        input_text: Currently unused.
    """
    predicted = POSITIVE
    # Image caption first, then the text, matching their on-page order of use.
    for render in (show_response_img, show_response_text):
        render(predicted)

def show_response_img(target=None):
    """Write the expression caption for a predicted class into ``st.con3``.

    ``st.con3`` is looked up as an attribute of the ``st`` module; tab1_page()
    assigns it, so that must have run before this function is called.

    Args:
        target: POSITIVE, NEGATIVE or NEUTRAL to write the matching caption,
            or None (the default) to write an empty string. Any other value
            writes nothing.
    """
    # Identity check: None is a singleton, and `==` can be overridden by the
    # object being compared.
    if target is None:
        st.con3.write("")
        return

    # NOTE(review): these captions look mis-encoded (mojibake) - confirm the
    # file's encoding; they are kept exactly as found.
    captions = (
        (POSITIVE, "μ›ƒλŠ” ν‘œμ •"),
        (NEGATIVE, "μš°λŠ” ν‘œμ •"),
        (NEUTRAL, "λ¬΄λ€λ€ν•œ ν‘œμ •"),
    )
    # `==` matching (not a dict lookup) so any value that compares equal to a
    # class index is still accepted.
    for label, caption in captions:
        if target == label:
            st.con3.write(caption)
            break

def show_response_text(target=None):
    """Write the reaction text for a predicted class into ``st.con4``.

    ``st.con4`` is looked up as an attribute of the ``st`` module; tab1_page()
    assigns it, so that must have run before this function is called.

    Args:
        target: POSITIVE, NEGATIVE or NEUTRAL to write the matching text,
            or None (the default) to write an empty string. Any other value
            writes nothing.
    """
    # Identity check: None is a singleton, and `==` can be overridden by the
    # object being compared.
    if target is None:
        st.con4.write("")
        return

    # NOTE(review): these messages look mis-encoded (mojibake) - confirm the
    # file's encoding; they are kept exactly as found.
    messages = (
        (POSITIVE, "긍정적인 λ°˜μ‘"),
        (NEGATIVE, "뢀정정인 λ°˜μ‘"),
        (NEUTRAL, "λ¬΄λ€λ€ν•œ λ°˜μ‘"),
    )
    # `==` matching (not a dict lookup) so any value that compares equal to a
    # class index is still accepted.
    for label, message in messages:
        if target == label:
            st.con4.write(message)
            break
    

def show_data():
    """Write the placeholder text "data" via ``st.write``."""
    placeholder_text = "data"
    st.write(placeholder_text)

def show_api_usage():
    """Write the API usage text inside a bordered Streamlit container."""
    bordered_box = st.container(border=True)
    with bordered_box:
        st.write("api μ‚¬μš©λ²•")

def tab1_page():
    """Build the interactive tab: model picker, text input and result areas.

    Creates three rows of columns and stores them as attributes on the ``st``
    module (``st.con1`` .. ``st.con5``). It then fills them with the model
    selection widgets, the text input with its confirm button, the two
    response areas and the data area.

    Side effects:
        Reassigns the module globals ``g_selected_model_type`` (result of
        get_model_type() for the current widget values) and ``g_input_text``
        (text area content, only when the confirm button was pressed).
    """
    global g_selected_model_type
    global g_input_text
    # The columns live on the ``st`` module itself because show_response_img()
    # and show_response_text() look up ``st.con3`` / ``st.con4`` there.
    st.con1, st.con2 = st.columns([0.3, 0.7])
    st.con3, st.con4 = st.columns([0.3, 0.7])
    st.con5, empty1 = st.columns([0.9999, 0.0001])

    with st.container():
        with st.con1:
            with st.con1.container(border=True):
                selected_model = st.selectbox("예츑 λͺ¨λΈμ„ μ„ νƒν•˜μ„Έμš”.", ["BERT", "RoBERTa"])
                add_layers = st.checkbox('Layer μΆ”κ°€')
                g_selected_model_type = get_model_type(selected_model, add_layers)
                # NOTE(review): this call is unconditional, so a result is
                # rendered before anything is submitted and rendered again
                # after a submit - confirm that is intended.
                show_predict_result(g_selected_model_type, g_input_text)
        with st.con2:
            with st.con2.container(border=True):
                input_text = st.text_area("input_text")
                submit_button = st.button('확인')
                if submit_button:
                    g_input_text = input_text
                    show_predict_result(g_selected_model_type, g_input_text)
    # Called without a target, both helpers write an empty string.
    with st.con3:
        with st.con3.container(border=True):
            show_response_img()
    with st.con4:
        with st.con4.container(border=True):
            show_response_text()
    with st.con5:
        with st.con5.container(border=True):
            show_data()
    with empty1:
        # Use Streamlit's own placeholder element rather than pyparsing's
        # ``empty`` parser element, which has nothing to do with page layout.
        st.empty()


def tab2_page():
    """Render the second tab, which consists solely of the API usage section."""
    show_api_usage()


# Load the tokenizer and the classifier from their local directories.
# NOTE(review): this runs at module level, i.e. on every execution of the
# script; consider wrapping the loading in an ``st.cache_resource`` function.
tokenizer = BertTokenizer.from_pretrained("bert_tokenizer_layered")
model = CustomBertForSequenceClassification.from_pretrained("bert_model_layered")
# ``Trainer`` cannot be constructed without a model (or ``model_init``) and
# has no ``load_model()`` method, so wrap the already-loaded model instead.
# NOTE(review): the previous code passed "bert_trainer_layered" here; if that
# directory holds a checkpoint or training arguments that must be used, load
# them explicitly - confirm against the training script.
trainer = Trainer(model=model)



# Page heading.
st.title("Semi Project - Sentiment analysis")
# An st.subheader(...) line with a one-sentence description of the app is
# currently disabled.


# Two tabs; each one is filled by its own page-building function.
tab1, tab2 = st.tabs(['λ™μž‘ 확인', 'API'])

for tab, render_page in ((tab1, tab1_page), (tab2, tab2_page)):
    with tab:
        render_page()