Truong-Phuc Nguyen
commited on
Commit
•
4e0be52
1
Parent(s):
6861876
Update app.py
Browse files
app.py
CHANGED
@@ -1,17 +1,117 @@
|
|
1 |
import streamlit as st
|
2 |
-
|
3 |
-
|
|
|
4 |
|
5 |
-
|
6 |
-
|
7 |
|
8 |
-
|
9 |
-
|
|
|
10 |
|
11 |
-
|
12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
|
14 |
-
|
|
|
|
|
15 |
|
16 |
-
if
|
17 |
-
st.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import streamlit as st
|
2 |
+
import pandas as pd
|
3 |
+
import os
|
4 |
+
import requests
|
5 |
|
6 |
+
st.set_page_config(page_icon='🦜', page_title='Text Generation Labeling Tool', layout='wide', initial_sidebar_state="collapsed")
|
7 |
+
st.markdown("<h1 style='text-align: center;'>Text Generation Labeling Tool</h1>", unsafe_allow_html=True)
|
8 |
|
9 |
+
def file_selector(folder_path=r'./Datasets'):
|
10 |
+
filenames = os.listdir(folder_path)
|
11 |
+
return filenames, folder_path
|
12 |
|
13 |
+
def revert_question_type_id(txt_question_type):
|
14 |
+
if txt_question_type == 'What':
|
15 |
+
return 0
|
16 |
+
elif txt_question_type == 'Who':
|
17 |
+
return 1
|
18 |
+
elif txt_question_type == 'When':
|
19 |
+
return 2
|
20 |
+
elif txt_question_type == 'Where':
|
21 |
+
return 3
|
22 |
+
elif txt_question_type == 'Why':
|
23 |
+
return 4
|
24 |
+
elif txt_question_type == 'How':
|
25 |
+
return 5
|
26 |
+
elif txt_question_type == 'Others':
|
27 |
+
return 6
|
28 |
|
29 |
+
filenames, folder_path = file_selector()
|
30 |
+
filename_input = st.sidebar.selectbox(label='Input dataset file:', options=filenames)
|
31 |
+
df = pd.read_csv(f'./{folder_path}/{filename_input}')
|
32 |
|
33 |
+
if 'idx' not in st.session_state:
|
34 |
+
st.session_state.idx = 0
|
35 |
+
|
36 |
+
st.markdown(f"<h4 style='text-align: center;'>Sample {st.session_state.idx + 1}/{len(df)}</h4>", unsafe_allow_html=True)
|
37 |
+
|
38 |
+
col_1, col_2, col_3, col_4, col_5, col_6, col_7, col_8, col_9, col_10 = st.columns([1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
|
39 |
+
|
40 |
+
btn_previous = col_1.button(label=':arrow_backward: Previous sample', use_container_width=True)
|
41 |
+
btn_next = col_2.button(label='Next sample :arrow_forward:', use_container_width=True)
|
42 |
+
btn_save = col_3.button(label=':heavy_check_mark: Save change', use_container_width=True)
|
43 |
+
txt_goto = col_4.selectbox(label='Sample', label_visibility='collapsed', options=list(range(1, len(df) + 1)))
|
44 |
+
btn_goto = col_5.button(label=':fast_forward: Move to', use_container_width=True)
|
45 |
+
|
46 |
+
if len(df) != 0:
|
47 |
+
col_1, col_2, col_3, col_4, col_5, col_6, col_7, col_8, col_9, col_10 = st.columns(spec=[1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
|
48 |
+
txt_context = st.text_area(height=200, label='Your context:', value=df['context'][st.session_state.idx])
|
49 |
+
|
50 |
+
col_11, col_12, col_13 = st.columns([4.5, 1, 4.5])
|
51 |
+
txt_question = col_11.text_area(height=90, label='Your question:', value=df['question'][st.session_state.idx])
|
52 |
+
txt_question_type = col_12.selectbox(label='Your question type:', options=['What', 'Who', 'When', 'Where', 'Why', 'How', 'Others'], index=int(df['question_type'][st.session_state.idx]))
|
53 |
+
txt_answer = col_13.text_area(height=90, label='Your answer:', value=df['answer'][st.session_state.idx])
|
54 |
+
|
55 |
+
st.markdown(f"<p style='text-align: left; font-weight: normal; font-size: 14px'>Your distractors:</p>", unsafe_allow_html=True)
|
56 |
+
|
57 |
+
col_21, col_22 = st.columns(spec=[9, 1])
|
58 |
+
txt_distractors = col_21.text_area(height=90, label='Your distractors:', label_visibility='collapsed', value=df['distract'][st.session_state.idx])
|
59 |
+
btn_generate_distractor = col_22.button(label='Generate distractors', use_container_width=True)
|
60 |
+
|
61 |
+
if btn_generate_distractor:
|
62 |
+
if filename_input == 'BiologyQA.csv':
|
63 |
+
expert = 'biologist'
|
64 |
+
elif filename_input == 'GeographyQA.csv':
|
65 |
+
expert = 'geographer'
|
66 |
+
elif filename_input == 'HistoryQA.csv':
|
67 |
+
expert = 'historian'
|
68 |
+
url = 'https://generativelanguage.googleapis.com/v1beta/models/gemini-pro:generateContent'
|
69 |
+
headers = {'Content-Type': 'application/json'}
|
70 |
+
data = {
|
71 |
+
'contents': [
|
72 |
+
{
|
73 |
+
'parts': [
|
74 |
+
{
|
75 |
+
'text': f"You are a great {expert}, here is the following content: context: '{txt_context}', question: '{txt_question}', answer: '{txt_answer}' generate three distract answers. Distractor answers are separated by [SEP]. Example: Distract answer 1 [SEP] Distract answer 2 [SEP] Distract answer 3"
|
76 |
+
}
|
77 |
+
]
|
78 |
+
}
|
79 |
+
]
|
80 |
+
}
|
81 |
+
api_key = 'AIzaSyApFAbCUA1H-VHAidzqmyStHFe92ODeO1Y'
|
82 |
+
params = {'key': api_key}
|
83 |
+
response = requests.post(url, headers=headers, json=data, params=params)
|
84 |
+
if response.status_code == 200:
|
85 |
+
correct = response.json()['candidates'][0]['content']['parts'][0]['text']
|
86 |
+
st.success(f'3 distraction answers: {correct}')
|
87 |
+
st.cache_data.clear()
|
88 |
+
else:
|
89 |
+
st.error('Failed to generate distractors. Please check API and inputs.')
|
90 |
+
st.rerun()
|
91 |
+
|
92 |
+
if btn_previous:
|
93 |
+
if st.session_state.idx > 0:
|
94 |
+
st.session_state.idx -= 1
|
95 |
+
st.rerun()
|
96 |
+
else:
|
97 |
+
pass
|
98 |
+
|
99 |
+
if btn_next:
|
100 |
+
if st.session_state.idx < (len(df) - 1):
|
101 |
+
st.session_state.idx += 1
|
102 |
+
st.rerun()
|
103 |
+
else:
|
104 |
+
pass
|
105 |
+
|
106 |
+
if btn_save:
|
107 |
+
df['context'][st.session_state.idx] = txt_context
|
108 |
+
df['question'][st.session_state.idx] = txt_question
|
109 |
+
df['answer'][st.session_state.idx] = txt_answer
|
110 |
+
df['distract'][st.session_state.idx] = txt_distractors
|
111 |
+
df['question_type'][st.session_state.idx] = revert_question_type_id(txt_question_type)
|
112 |
+
|
113 |
+
df.to_csv(f'./Datasets/{filename_input}', index=None)
|
114 |
+
|
115 |
+
if btn_goto:
|
116 |
+
st.session_state.idx = txt_goto - 1
|
117 |
+
st.rerun()
|