File size: 4,555 Bytes
491e087
 
a887cae
344f958
491e087
5344f77
491e087
344f958
491e087
 
 
 
 
 
 
 
 
5344f77
491e087
 
 
2704216
 
 
491e087
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5344f77
2704216
 
 
 
 
 
 
 
 
 
 
 
 
 
491e087
 
344f958
 
491e087
 
344f958
491e087
 
344f958
491e087
344f958
491e087
344f958
491e087
5344f77
344f958
a887cae
5344f77
a887cae
344f958
a887cae
491e087
344f958
 
491e087
2704216
344f958
a887cae
2704216
 
 
a887cae
344f958
2704216
 
 
 
344f958
2704216
344f958
2704216
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import streamlit as st
import pandas as pd
# from plms.language_model import TransformersQG
import time
import os
import numpy as np

st.set_page_config(page_icon='🧪', page_title='ViQAG for Vietnamese Education', layout='centered', initial_sidebar_state="collapsed")

# Inject the app-wide stylesheet. Read it explicitly as UTF-8 so the result does not
# depend on the platform's default text encoding.
with open(r"./static/styles.css", encoding="utf-8") as f:
    st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)

# Logo banner. NOTE(review): the "./app/static/..." URL presumably relies on Streamlit's
# static file serving being enabled for the ./static folder — confirm in the app config.
st.markdown("""
    <div class=logo_area>
        <img src="./app/static/AlphaEdu_logo_trans.png"/>
    </div>
    """, unsafe_allow_html=True)
st.markdown("<h1 style='text-align: center;'>AlphaEdu</h1>", unsafe_allow_html=True)

# =====================================================================================================

# Holds the most recently generated QA text so it survives Streamlit reruns
# (every widget interaction re-executes this script); starts out empty.
st.session_state.setdefault('output', '')

def file_selector(folder_path=r'./Resources/'):
    """List the subject data files available in *folder_path*.

    Only regular files are returned: hidden entries (names starting with ``.``,
    e.g. ``.DS_Store``) and sub-directories are skipped, because every name the
    user picks is later handed to ``pd.read_csv``. The result is sorted so the
    subject drop-down has a stable order (``os.listdir`` returns entries in
    arbitrary order).

    Args:
        folder_path: Directory to scan (expected to hold the subject CSV files).

    Returns:
        list[str]: Sorted file names, without the directory prefix.
    """
    return sorted(
        name
        for name in os.listdir(folder_path)
        if not name.startswith('.') and os.path.isfile(os.path.join(folder_path, name))
    )
# Subject files offered in the "Subject:" drop-down, scanned from the default data folder.
filenames = file_selector(folder_path=r'./Resources/')

def load_grades(file_name, folder_path=r'./Resources/'):
    """Load a subject CSV and return its distinct grades plus the full frame.

    The path is built with ``os.path.join`` so *folder_path* works with or
    without a trailing separator (the old string concatenation silently
    produced a wrong path when the separator was missing).

    Args:
        file_name: Name of the CSV file inside *folder_path*.
        folder_path: Directory holding the subject files.

    Returns:
        tuple: ``(grades, df)`` where ``grades`` holds the unique values of the
        ``'grade'`` column in first-appearance order and ``df`` is the whole
        DataFrame read from the file.
    """
    df = pd.read_csv(os.path.join(folder_path, file_name))
    list_grades = df['grade'].drop_duplicates().values
    return list_grades, df

def load_chapters(df, grade_name):
    """Return the distinct chapters of one grade, plus the input frame unchanged.

    Args:
        df: Subject DataFrame with ``'grade'`` and ``'chapter'`` columns.
        grade_name: Grade value to filter on.

    Returns:
        tuple: ``(chapters, df)`` — the unique chapter values for *grade_name*
        in first-appearance order, and the very same *df* object that was
        passed in (not filtered).
    """
    in_grade = df['grade'] == grade_name
    chapters = df.loc[in_grade, 'chapter'].drop_duplicates().values
    return chapters, df

def load_lessons(df, grade_name, chapter_name):
    """Return the distinct lessons belonging to one grade/chapter pair.

    Args:
        df: Subject DataFrame with ``'grade'``, ``'chapter'`` and ``'lesson'`` columns.
        grade_name: Grade value to filter on.
        chapter_name: Chapter value to filter on.

    Returns:
        The unique ``'lesson'`` values of the matching rows, in first-appearance order.
    """
    selected = (df['grade'] == grade_name) & (df['chapter'] == chapter_name)
    lessons = df.loc[selected, 'lesson']
    return lessons.drop_duplicates().values

def load_context(df, grade_name, chapter_name, lesson_name):
    """Collect the paragraphs of one lesson.

    Args:
        df: Subject DataFrame with ``'grade'``, ``'chapter'``, ``'lesson'`` and
            ``'context'`` columns.
        grade_name: Grade value to filter on.
        chapter_name: Chapter value to filter on.
        lesson_name: Lesson value to filter on.

    Returns:
        tuple: ``(count, paragraphs)`` — the number of matching rows and the
        array of their ``'context'`` values, in row order (duplicates kept).
    """
    row_mask = (
        (df['grade'] == grade_name)
        & (df['chapter'] == chapter_name)
        & (df['lesson'] == lesson_name)
    )
    paragraphs = df.loc[row_mask, 'context'].values
    return len(paragraphs), paragraphs

@st.cache_resource(show_spinner=False)
def _load_qg_model(model_path):
    """Build the question-generation model for *model_path* once and reuse it.

    Streamlit re-executes this script on every interaction, so a plain
    module-level cache would be rebuilt on each rerun; ``st.cache_resource``
    keeps the instance alive across reruns instead of reloading the model on
    every "Generate" click.
    """
    # NOTE(review): the import of TransformersQG (``from plms.language_model import
    # TransformersQG``) is commented out at the top of this file, so this line raises
    # NameError until that import is restored.
    return TransformersQG(model=model_path, max_length=512)


def _format_qa_pairs(pairs):
    """Render ``(question, answer)`` pairs as display text, dropping exact duplicates.

    Each kept pair becomes a ``question: ...`` line followed by an
    ``answer: ...`` line; pairs keep their first-seen order and are separated
    by a blank line. Returns ``''`` when *pairs* is empty.
    """
    seen = set()
    blocks = []
    for question, answer in pairs:
        if (question, answer) in seen:
            continue
        seen.add((question, answer))
        blocks.append(f'question: {question} \nanswer: {answer}')
    # Joining a list (instead of concatenating with a ' [SEP] ' marker and splitting
    # it back out) leaves no trailing separator and cannot be broken by QA text that
    # itself contains the marker.
    return '\n\n'.join(blocks)


def generateQA(context, model_path = 'shnl/vit5-vinewsqa-qg-ae'):
    """Generate question/answer pairs for a paragraph and format them as text.

    Args:
        context: Paragraph to generate questions from.
        model_path: Model identifier passed to ``TransformersQG``.

    Returns:
        str: Unique pairs formatted as ``question: ...`` / ``answer: ...``
        lines separated by blank lines, or ``''`` if *context* is blank or the
        model produced no pairs.
    """
    # Nothing to ask about: skip loading/running the model entirely.
    if not (context and context.strip()):
        return ''
    model = _load_qg_model(model_path)
    return _format_qa_pairs(model.generate_qa(context))
    
# =====================================================================================================

# ---- Selection widgets: subject -> grade -> chapter -> lesson -> paragraph ----
col_1, col_2, col_3 = st.columns(spec=[2.5, 1.5, 6])
subject = col_1.selectbox(label='Subject:', options=filenames, label_visibility='visible')
if subject is None:
    # selectbox returns None when it has no options, i.e. no subject files were found.
    st.info('No subject files found in ./Resources/.')
    st.stop()

list_grades, df = load_grades(file_name=subject)
grade = col_2.selectbox(label='Grade:', options=list_grades, label_visibility='visible')

list_chapters, df = load_chapters(df=df, grade_name=grade)
chapter = col_3.selectbox(label='Chapter:', options=list_chapters, label_visibility='visible')

col_11, col_21 = st.columns(spec=[8, 2])
lesson_names = load_lessons(df=df, grade_name=grade, chapter_name=chapter)
lesson = col_11.selectbox(label='Lesson:', options=lesson_names, label_visibility='visible')

total_paragraph, context_values = load_context(df=df, grade_name=grade, chapter_name=chapter, lesson_name=lesson)
# Paragraphs are numbered from 1 for display; the index into context_values is one less.
paragraph_idx = col_21.selectbox(label='Paragraph:', options=list(range(1, total_paragraph + 1)), label_visibility='visible')
# With no paragraphs the selectbox yields None; show an empty editor instead of crashing.
default_paragraph = context_values[paragraph_idx - 1] if paragraph_idx is not None else ''
paragraph = st.text_area(label='Paragraph content', label_visibility='visible', height=200, value=default_paragraph)

# ---- Model choice and display options ----
col_13, col_23, col_33 = st.columns(spec=[3.6, 2.4, 3.6])
col_23.selectbox(label='QAG model:', options=['ViT5-ViNewsQA'], label_visibility='visible')
btn_show_answer = col_23.toggle(label='Show answers', disabled=False)

col_14, col_24, col_34, col_44, col_54 = st.columns(spec=[1, 1, 1, 1, 1])
btn_generate = col_34.button(label='Generate', use_container_width=True)

if btn_generate:
    with st.spinner(text='Generating QA pairs from the selected paragraph. Please wait ...'):
        st.session_state.output = generateQA(context=paragraph)

# ---- Output ----
qa_text = st.session_state.output
if qa_text:
    if btn_show_answer:
        st.code(body=qa_text, language='latex')
    else:
        st.markdown("<h8 style='text-align: left; font-weight: normal'>Generated QA pairs:</h8>", unsafe_allow_html=True)
        # The generated text holds "question: ..." / "answer: ..." lines; keep only the
        # question lines so the answers stay hidden.
        question_prefix = 'question: '
        questions = [
            line[len(question_prefix):].rstrip()
            for line in qa_text.split('\n')
            if line.startswith(question_prefix)
        ]
        st.code(body='\n\n'.join(questions), language='latex')