Truong-Phuc Nguyen commited on
Commit
4e0be52
1 Parent(s): 6861876

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +111 -11
app.py CHANGED
@@ -1,17 +1,117 @@
1
  import streamlit as st
2
- q1 = st.text_area(placeholder='Question 1 ...', key='q1', label='Question 1')
3
- a1 = st.text_area(placeholder='Answer for question 1 ...', key='a1', label='Answer for question 1')
 
4
 
5
- q2 = st.text_area(placeholder='Question 2 ...', key='q2', label='Question 2')
6
- a2 = st.text_area(placeholder='Answer for question 2 ...', key='a2', label='Answer for question 2')
7
 
8
- q3 = st.text_area(placeholder='Question 3 ...', key='q3', label='Question 3')
9
- a3 = st.text_area(placeholder='Answer for question 3 ...', key='a3', label='Answer for question 3')
 
10
 
11
- q4 = st.text_area(placeholder='Question 4 ...', key='q4', label='Question 4')
12
- a4 = st.text_area(placeholder='Answer for question 4 ...', key='a4', label='Answer for question 4')
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
- btn_gen_str = st.button('Generate')
 
 
15
 
16
- if btn_gen_str:
17
- st.success("question: " + q1 + ", answer: " + a1 + "[SEP] " + "question: " + q2 + ", answer: " + a2 + "[SEP] " + "question: " + q3 + ", answer: " + a3 + "[SEP] " +"question: " + q4 + ", answer: " + a4 + "[SEP]")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
+ import pandas as pd
3
+ import os
4
+ import requests
5
 
6
+ st.set_page_config(page_icon='🦜', page_title='Text Generation Labeling Tool', layout='wide', initial_sidebar_state="collapsed")
7
+ st.markdown("<h1 style='text-align: center;'>Text Generation Labeling Tool</h1>", unsafe_allow_html=True)
8
 
9
+ def file_selector(folder_path=r'./Datasets'):
10
+ filenames = os.listdir(folder_path)
11
+ return filenames, folder_path
12
 
13
+ def revert_question_type_id(txt_question_type):
14
+ if txt_question_type == 'What':
15
+ return 0
16
+ elif txt_question_type == 'Who':
17
+ return 1
18
+ elif txt_question_type == 'When':
19
+ return 2
20
+ elif txt_question_type == 'Where':
21
+ return 3
22
+ elif txt_question_type == 'Why':
23
+ return 4
24
+ elif txt_question_type == 'How':
25
+ return 5
26
+ elif txt_question_type == 'Others':
27
+ return 6
28
 
29
+ filenames, folder_path = file_selector()
30
+ filename_input = st.sidebar.selectbox(label='Input dataset file:', options=filenames)
31
+ df = pd.read_csv(f'./{folder_path}/{filename_input}')
32
 
33
+ if 'idx' not in st.session_state:
34
+ st.session_state.idx = 0
35
+
36
+ st.markdown(f"<h4 style='text-align: center;'>Sample {st.session_state.idx + 1}/{len(df)}</h4>", unsafe_allow_html=True)
37
+
38
+ col_1, col_2, col_3, col_4, col_5, col_6, col_7, col_8, col_9, col_10 = st.columns([1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
39
+
40
+ btn_previous = col_1.button(label=':arrow_backward: Previous sample', use_container_width=True)
41
+ btn_next = col_2.button(label='Next sample :arrow_forward:', use_container_width=True)
42
+ btn_save = col_3.button(label=':heavy_check_mark: Save change', use_container_width=True)
43
+ txt_goto = col_4.selectbox(label='Sample', label_visibility='collapsed', options=list(range(1, len(df) + 1)))
44
+ btn_goto = col_5.button(label=':fast_forward: Move to', use_container_width=True)
45
+
46
+ if len(df) != 0:
47
+ col_1, col_2, col_3, col_4, col_5, col_6, col_7, col_8, col_9, col_10 = st.columns(spec=[1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
48
+ txt_context = st.text_area(height=200, label='Your context:', value=df['context'][st.session_state.idx])
49
+
50
+ col_11, col_12, col_13 = st.columns([4.5, 1, 4.5])
51
+ txt_question = col_11.text_area(height=90, label='Your question:', value=df['question'][st.session_state.idx])
52
+ txt_question_type = col_12.selectbox(label='Your question type:', options=['What', 'Who', 'When', 'Where', 'Why', 'How', 'Others'], index=int(df['question_type'][st.session_state.idx]))
53
+ txt_answer = col_13.text_area(height=90, label='Your answer:', value=df['answer'][st.session_state.idx])
54
+
55
+ st.markdown(f"<p style='text-align: left; font-weight: normal; font-size: 14px'>Your distractors:</p>", unsafe_allow_html=True)
56
+
57
+ col_21, col_22 = st.columns(spec=[9, 1])
58
+ txt_distractors = col_21.text_area(height=90, label='Your distractors:', label_visibility='collapsed', value=df['distract'][st.session_state.idx])
59
+ btn_generate_distractor = col_22.button(label='Generate distractors', use_container_width=True)
60
+
61
+ if btn_generate_distractor:
62
+ if filename_input == 'BiologyQA.csv':
63
+ expert = 'biologist'
64
+ elif filename_input == 'GeographyQA.csv':
65
+ expert = 'geographer'
66
+ elif filename_input == 'HistoryQA.csv':
67
+ expert = 'historian'
68
+ url = 'https://generativelanguage.googleapis.com/v1beta/models/gemini-pro:generateContent'
69
+ headers = {'Content-Type': 'application/json'}
70
+ data = {
71
+ 'contents': [
72
+ {
73
+ 'parts': [
74
+ {
75
+ 'text': f"You are a great {expert}, here is the following content: context: '{txt_context}', question: '{txt_question}', answer: '{txt_answer}' generate three distract answers. Distractor answers are separated by [SEP]. Example: Distract answer 1 [SEP] Distract answer 2 [SEP] Distract answer 3"
76
+ }
77
+ ]
78
+ }
79
+ ]
80
+ }
81
+ api_key = 'AIzaSyApFAbCUA1H-VHAidzqmyStHFe92ODeO1Y'
82
+ params = {'key': api_key}
83
+ response = requests.post(url, headers=headers, json=data, params=params)
84
+ if response.status_code == 200:
85
+ correct = response.json()['candidates'][0]['content']['parts'][0]['text']
86
+ st.success(f'3 distraction answers: {correct}')
87
+ st.cache_data.clear()
88
+ else:
89
+ st.error('Failed to generate distractors. Please check API and inputs.')
90
+ st.rerun()
91
+
92
+ if btn_previous:
93
+ if st.session_state.idx > 0:
94
+ st.session_state.idx -= 1
95
+ st.rerun()
96
+ else:
97
+ pass
98
+
99
+ if btn_next:
100
+ if st.session_state.idx < (len(df) - 1):
101
+ st.session_state.idx += 1
102
+ st.rerun()
103
+ else:
104
+ pass
105
+
106
+ if btn_save:
107
+ df['context'][st.session_state.idx] = txt_context
108
+ df['question'][st.session_state.idx] = txt_question
109
+ df['answer'][st.session_state.idx] = txt_answer
110
+ df['distract'][st.session_state.idx] = txt_distractors
111
+ df['question_type'][st.session_state.idx] = revert_question_type_id(txt_question_type)
112
+
113
+ df.to_csv(f'./Datasets/{filename_input}', index=None)
114
+
115
+ if btn_goto:
116
+ st.session_state.idx = txt_goto - 1
117
+ st.rerun()