File size: 9,530 Bytes
cb4cfa0
 
 
 
 
 
 
0c2ae87
cb4cfa0
 
0c2ae87
 
cb4cfa0
 
 
 
 
 
20c5fed
cb4cfa0
 
 
20c5fed
cb4cfa0
 
 
20c5fed
cb4cfa0
 
 
 
 
 
594ab7b
cb4cfa0
 
 
 
112da8d
594ab7b
cb4cfa0
 
0c2ae87
cb4cfa0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91aa7e9
cb4cfa0
 
 
 
 
 
 
91aa7e9
cb4cfa0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
"""Streamlit app for zero-shot classification."""
import streamlit as st
import pandas as pd
import altair as alt
from transformers import pipeline
from transformers import AutoTokenizer
import warnings
import os
warnings.filterwarnings('ignore')

hf_token = os.environ['HF_READ']

# set up altair theme
font = 'NotoSansSinhala.ttf'
font_color = '#858991'
font_title = '#858991'
font_axis = '#858991'

TARGETS = ['දේශපාලන', 'ආර්ථික', 'අනතුරු', 'අපරාධ', 'තාක්ෂණ', 'ක්රීඩා', 'කාලගුණ', 'සමාජ']
SIN_2_ENG = {
    'දේශපාලන':'Political',
    'ආර්ථික':'Economic',
    'අනතුරු':'Accident',
    'අපරාධ':'Crime',
    'තාක්ෂණ':'Technology',
    'ක්රීඩා':'Sports',
    'කාලගුණ':'Weather',
    'සමාජ':'Social'
}

st.set_page_config(page_title="Sinhala zero-shot classification demo", page_icon=":bar_chart:")
st.title("Sinhala zero-shot classification demo")
st.markdown("This is a demo of the zero-shot classification pipeline from the [HuggingFace Transformers library](https://huggingface.co/transformers/).")
st.markdown("The model used is [Ransaka/sinhala-bert-small](https://huggingface.co/Ransaka/sinhala-bert-small). However you can select a different model from the dropdown below.")

# select model
def get_model_id():
    st.subheader("Select a model to use")
    model_list = ["Ransaka/sinhala-bert-small","Ransaka/SinhalaRoberta","Ransaka/sinhala-bert-medium-v2","Ransaka/sinhala-bert-medium-v1"]#,"keshan/SinhalaBERTo"] #commented as tokenizer used here is differ from previous tow models.
    #TODO: Add "keshan/SinhalaBERTo" model as well
    selected_model = st.selectbox("Select Model", model_list)
    st.write(f"Selected model: {selected_model}")
    tokenizer = AutoTokenizer.from_pretrained(selected_model, token=hf_token)
    mask_token = tokenizer.mask_token
    return selected_model,mask_token

# get input text
def get_input_text():
    st.subheader("Input a sentence to classify")
    st.write("Remember: Longer sentences may produce better results and take longer to classify😊")   
    sentence = st.text_area("Input text", height=300)
    return sentence

def show_example():
    examples = [
        """ශ්‍රී ලංකාවේ චීන සංස්කෘතික මධ්‍යස්ථානය සහ නැන්ජින් සංචාරක හා සංස්කෘතික මණ්ඩලය විසින් “ගිම්හාන දිනය” සැමරීම සඳහා පවත්වන ලද සංස්කෘතික උත්සව මාලාවක් පසුගියදා කොළඹ සහ මහනුවර නගරවලදී පැවත්විණි. “ගිම්හාන දිනය” යනු චීන සංස්කෘතිය තුළ “චීන නව වසර” තරමටම වැදගත් සහ ඉතා ඉහළින් සමරනු ලබන වැදගත් දිනයකි. මෙම උත්සව මාලාව සැප්තැම්බර් 22 වැනි දින සිට 25 වැනිදා දක්වා පැවත්විණි.
        කොළඹ චීන සංස්කෘතික මධ්‍යස්ථානයේදී පැවත්වුනු ප්‍රධාන උත්සවය සාම්ප්‍රදායයික චීන සහ ශ්‍රී ලාංකික සංස්කෘතික සංදර්ශන සහ කලා සහ ඡායාරූප ප්‍රදර්ශන, සාම්ප්‍රදායයික චීන තේ පානෝත්සව  සමඟින් ඉතා වර්ණවත් අයුරින් පැවත්විණි. එහිදී චීන සංස්කෘතිය තුළ ‘ගිම්හාන දිනයේ’ ඇති වැදගත්කම සහ ඓතිහාසික චීන ශ්‍රී ලාංකික සබඳතාවයන් ගැන හරබර දේශන රැසක්ද ප්‍රකට කථිකයින් විසින් සිදු කරන ලදි.
        """,
        """මාලදිවයිනේ පැවති ජනාධිපතිවරණයෙන් චීන හිතවාදී අපේක්ෂක 45 හැවිරිදි මොහොමඩ් මුයිසු ජනාධිපති ධුරයට පත් වී තිබේ.

        ඉන්දියාව සමඟ සබඳතා ශක්තිමත් කළ වත්මන් ජනාධිපති ඊබ්‍රාහිම් මොහොමඩ් සෝලිහ් පරාජයට පත් කරමින් මොහොමඩ් මුයිසු ජනාධිපතිවරණය ජයග්‍රහණය කර ඇත.

        මොහොමඩ් මුයිසු 54%ක ඡන්ද ප්‍රතිශතයකින් ජනාධිපතිවරණය ජයග්‍රහණය කර තිබේ.

        'ඉන්දියාව ඉවතට' යන සටන් පාඨය ඔස්සේ මොහොමඩ් මුයිසු සිය ජනාධිපතිවරණ ව්‍යාපාරය සිදු කළේය.
        """,
        """ආසියානු ක්‍රීඩා උළෙලේ කාන්තා ක්‍රිකට් අවසාන තරගයේ කාසියේ වාසිය දිනාගැනීමට ඉන්දීය නායිකාව සමත්වුණි.

            ඒ අනුව ඇය පළමුවෙන් පන්දුවට පහරදීමට තීරණය කළාය.

            තරගය මෙරට වේලාවෙන් පෙරවරු 11.30ට චීනයේ හැන්ග්ෂු හිදී ආරම්භ වීමට නියමිතය.

            ඊයේ (24) පැවති දෙවන අවසන් පූර්ව තරගයෙන් පාකිස්තාන කණ්ඩායම පරදා කඩුලු 06ක ජයක් හිමිකරගනිමින් රන් පදක්කම සඳහා වූ අවසන් තරගයට සුදුසුකම් ලබාගැනීමට ශ්‍රී ලංකා කාන්තා කණ්ඩායම සමත්වුණි.

            අද පැවැත්වෙන තරගයෙන් ජයගතහොත් ශ්‍රී ලංකා කණ්ඩායමට රන් පදක්කම හිමිවන අතර පරාජය වුවහොත් තරග ඉසව්වේ රිදී පදක්කම හිමි වේ.

            ඒ අනුව 2014 වසරට පසුව එනම් වසර 9කට පසුව ආසියානු ක්‍රීඩා උළෙලකදී ශ්‍රී ලංකාවට පදක්කමක් හිමිවීමට නියමිතය."""
    ]
    st.subheader("Examples")
    st.table(pd.DataFrame(examples, columns=['Example']))

# get prompt
def get_prompt(mask_token):
    st.subheader("Input a prompt")
    # user can toggle between default prompt and custom prompt
    default_prompt = st.checkbox("Use default prompt",value=True)
    if default_prompt:
        prompt = f"මෙය {mask_token} ඝණයේ තොරතුරක්."
    else:
        prompt = st.text_input("Prompt", f"මෙය {mask_token} ඝණයේ තොරතුරක්.")
    return prompt

if __name__ == "__main__":
    model_id,mask_token = get_model_id()
    pipe = pipeline("fill-mask", model=model_id, token=hf_token)
    if st.checkbox("Show example"):
        show_example()
    sentence = get_input_text()
    # submit button
    if sentence:
        prompt = get_prompt(mask_token)
        if prompt and st.button("Classify"):
            # pipe = pipeline("fill-mask", model=model_id, token=hf_token)
            output = pipe(sentence + prompt, targets=TARGETS, top_k =len(TARGETS))
            output = pd.DataFrame(output)
            output['score'] = output['score'].apply(lambda x:x/sum(output['score']))
            output.rename(columns={'token_str':'label'}, inplace=True)
            # plot altair bar chart
            bar_chart = alt.Chart(output).mark_bar().encode(
                x='label:N',
                y='score:Q',
                # increase blue gradient as score increases
                color=alt.Color('score:Q', scale=alt.Scale(scheme='blues')),
                tooltip=['label:N', 'score:Q']
            ).properties(
                title='Zeroshot Classification Results',
                width=800,
                height=400
            )
            bar_chart.configure_axis(grid=False, labelFont=font, labelColor=font_color, titleColor= font_title).configure_view(strokeOpacity=0)
            bar_chart.configure_title(anchor='start')
            predicted_class = output.loc[output['score'].idxmax()]['label']
            predicted_class_en = SIN_2_ENG[predicted_class]
            st.altair_chart(bar_chart, use_container_width=True)
            st.markdown(
                "It seems this sentence belongs to the :green[{}]({}) category.".format(predicted_class,predicted_class_en)
            )
            # st.markdown(
            #     "This demo was created by [Ransaka Ravihara](https://www.linkedin.com/in/ransaka/)."
            #     )