File size: 4,555 Bytes
491e087 a887cae 344f958 491e087 5344f77 491e087 344f958 491e087 5344f77 491e087 2704216 491e087 5344f77 2704216 491e087 344f958 491e087 344f958 491e087 344f958 491e087 344f958 491e087 344f958 491e087 5344f77 344f958 a887cae 5344f77 a887cae 344f958 a887cae 491e087 344f958 491e087 2704216 344f958 a887cae 2704216 a887cae 344f958 2704216 344f958 2704216 344f958 2704216 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 |
import streamlit as st
import pandas as pd
# from plms.language_model import TransformersQG
import time
import os
import numpy as np
st.set_page_config(page_icon='🧪', page_title='ViQAG for Vietnamese Education', layout='centered', initial_sidebar_state="collapsed")

# Inject the app's custom stylesheet. Decode it explicitly as UTF-8 so the result
# does not depend on the host platform's default text encoding.
with open(r"./static/styles.css", encoding="utf-8") as f:
    st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)

# Logo image; the "./app/static/..." URL presumably relies on Streamlit static file
# serving of ./static (enableStaticServing) — confirm in .streamlit/config.toml.
st.markdown("""
<div class=logo_area>
<img src="./app/static/AlphaEdu_logo_trans.png"/>
</div>
""", unsafe_allow_html=True)
st.markdown("<h1 style='text-align: center;'>AlphaEdu</h1>", unsafe_allow_html=True)
# =====================================================================================================
# The generated QA text lives in session state so it persists across script
# reruns; seed it with an empty string the first time the script runs.
if 'output' not in st.session_state:
    st.session_state['output'] = ''
def file_selector(folder_path=r'./Resources/', suffix='.csv'):
    """List the subject data files available in *folder_path*.

    Every returned name is later opened with ``pd.read_csv``, so sub-directories
    and files without the expected suffix (e.g. ``.DS_Store``) are skipped
    instead of being offered in the Subject dropdown.

    Args:
        folder_path: Directory to scan.
        suffix: Keep only names ending with this suffix (case-insensitive).
            Pass ``None`` to keep every regular file.

    Returns:
        Sorted list of file names (names only, not full paths).
    """
    wanted = None if suffix is None else suffix.lower()
    names = (
        name
        for name in os.listdir(folder_path)
        if os.path.isfile(os.path.join(folder_path, name))
        and (wanted is None or name.lower().endswith(wanted))
    )
    # os.listdir returns entries in arbitrary order; sort for a stable dropdown.
    return sorted(names)
# Subject files offered in the first dropdown, read from the resources folder.
filenames = file_selector(folder_path=r'./Resources/')
def load_grades(file_name, folder_path=r'./Resources/'):
    """Read one subject CSV and list the grades it covers.

    Args:
        file_name: Name of the CSV file inside *folder_path*.
        folder_path: Directory holding the subject CSVs; a trailing path
            separator is optional.

    Returns:
        ``(list_grades, df)``: the unique values of the ``'grade'`` column in
        first-seen order (a NumPy array), and the full DataFrame.
    """
    # os.path.join handles folder_path with or without a trailing separator;
    # plain string concatenation produced a wrong path without one.
    file_path = os.path.join(folder_path, file_name)
    df = pd.read_csv(file_path)
    list_grades = df['grade'].drop_duplicates().values
    return list_grades, df
def load_chapters(df, grade_name):
    """List the chapters belonging to *grade_name*.

    Returns:
        ``(chapters, df)``: unique ``'chapter'`` values among rows whose
        ``'grade'`` equals *grade_name* (first-seen order, NumPy array), and
        the input DataFrame itself, unfiltered.
    """
    in_grade = df['grade'] == grade_name
    chapters = df.loc[in_grade, 'chapter'].drop_duplicates()
    return chapters.values, df
def load_lessons(df, grade_name, chapter_name):
    """Return the unique ``'lesson'`` values of *chapter_name* in *grade_name*.

    The result is a NumPy array in first-seen row order.
    """
    selected = (df['grade'] == grade_name) & (df['chapter'] == chapter_name)
    lessons = df.loc[selected, 'lesson']
    return lessons.drop_duplicates().values
def load_context(df, grade_name, chapter_name, lesson_name):
    """Collect the paragraph texts of one lesson.

    Returns:
        ``(count, contexts)``: the number of rows matching grade, chapter and
        lesson, and a NumPy array of their ``'context'`` values in row order.
    """
    mask = (
        (df['grade'] == grade_name)
        & (df['chapter'] == chapter_name)
        & (df['lesson'] == lesson_name)
    )
    paragraphs = df.loc[mask, 'context'].values
    return len(paragraphs), paragraphs
@st.cache_resource(show_spinner=False)
def _load_qg_model(model_path):
    """Build the question-generation model once per *model_path* and reuse it."""
    # NOTE(review): TransformersQG is not defined in this file — its import
    # (`from plms.language_model import TransformersQG`) is commented out at the
    # top, so this raises NameError until that import is restored.
    return TransformersQG(model=model_path, max_length=512)


def generateQA(context, model_path='shnl/vit5-vinewsqa-qg-ae'):
    """Generate de-duplicated question/answer pairs for one paragraph.

    Args:
        context: Paragraph text passed to the model's ``generate_qa``.
        model_path: Model identifier handed to ``TransformersQG``.

    Returns:
        One string with an entry ``"question: <q> \\nanswer: <a>"`` per unique
        pair, entries separated by a blank line; an empty string when the model
        yields no pairs.
    """
    # Cached loader: previously the model was rebuilt on every Generate click.
    model = _load_qg_model(model_path)
    output = model.generate_qa(context)
    # Each item is unpacked as a (question, answer) pair; dict.fromkeys drops
    # repeated pairs while keeping first-seen order.
    unique_pairs = dict.fromkeys((question, answer) for question, answer in output)
    return '\n\n'.join(
        f'question: {question} \nanswer: {answer}' for question, answer in unique_pairs
    )
# =====================================================================================================
def _questions_only(qa_text):
    """Return *qa_text* with the answer line of every QA entry removed.

    Expects generateQA's text: entries "question: <q> \\nanswer: <a>" separated
    by blank lines. Each entry is reduced to its question text (without the
    "question: " prefix), and the questions are re-joined with blank lines.
    """
    questions = []
    for entry in qa_text.split('\n\n'):
        if not entry.strip():
            continue  # tolerate a trailing separator / stray blank blocks
        question = entry.split('\nanswer: ', 1)[0].rstrip()
        if question.startswith('question: '):
            question = question[len('question: '):]
        questions.append(question)
    return '\n\n'.join(questions)


# Cascading selectors: subject file -> grade -> chapter -> lesson -> paragraph,
# each level filtering the subject DataFrame by the choices above it.
col_1, col_2, col_3 = st.columns(spec=[2.5, 1.5, 6])
subject = col_1.selectbox(label='Subject:', options=filenames, label_visibility='visible')
list_grades, df = load_grades(file_name=subject)
grade = col_2.selectbox(label='Grade:', options=list_grades, label_visibility='visible')
list_chapters, df = load_chapters(df=df, grade_name=grade)
chapter = col_3.selectbox(label='Chapter:', options=list_chapters, label_visibility='visible')
col_11, col_21 = st.columns(spec=[8, 2])
lesson_names = load_lessons(df=df, grade_name=grade, chapter_name=chapter)
lesson = col_11.selectbox(label='Lesson:', options=lesson_names, label_visibility='visible')
total_paragraph, context_values = load_context(df=df, grade_name=grade, chapter_name=chapter, lesson_name=lesson)
# Paragraphs are numbered from 1 for the user; converted back to 0-based below.
paragraph_idx = col_21.selectbox(label='Paragraph:', options=list(range(1, total_paragraph + 1)), label_visibility='visible')
# selectbox yields None when it has no options (no matching context rows), so
# only index into context_values when a paragraph was actually selected.
paragraph_text = context_values[paragraph_idx - 1] if paragraph_idx is not None else ''
paragraph = st.text_area(label='Paragraph content', label_visibility='visible', height=200, value=paragraph_text)

# Model picker (single option) and the answers toggle share the middle column.
col_13, col_23, col_33 = st.columns(spec=[3.6, 2.4, 3.6])
col_23.selectbox(label='QAG model:', options=['ViT5-ViNewsQA'], label_visibility='visible')
btn_show_answer = col_23.toggle(label='Show answers', disabled=False)
col_14, col_24, col_34, col_44, col_54 = st.columns(spec=[1, 1, 1, 1, 1])
btn_generate = col_34.button(label='Generate', use_container_width=True)

# Generation only runs on an explicit click; the result is stored in session
# state so it is still available on later reruns (e.g. flipping the toggle).
if btn_generate:
    with st.spinner(text='Generating QA pairs from the selected paragraph. Please wait ...'):
        st.session_state.output = generateQA(context=paragraph)

if st.session_state.output != '':
    if btn_show_answer:
        st.code(body=st.session_state.output, language='latex')
    else:
        st.markdown("<h8 style='text-align: left; font-weight: normal'>Generated QA pairs:</h8>", unsafe_allow_html=True)
        # Strip the answer line from each entry; the stored text never contains
        # ' [SEP] ' or ', answer: ', so splitting on those left answers visible.
        st.code(body=_questions_only(st.session_state.output), language='latex')