File size: 5,250 Bytes
c61c6e7
4e0be52
 
 
c61c6e7
4e0be52
 
c61c6e7
4e0be52
 
 
c61c6e7
4e0be52
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c61c6e7
4e0be52
 
 
c61c6e7
4e0be52
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import streamlit as st
import pandas as pd
import os
import requests

# Browser-tab metadata and overall layout for the labeling page.
st.set_page_config(
    page_title='Text Generation Labeling Tool',
    page_icon='🦜',
    layout='wide',
    initial_sidebar_state="collapsed",
)

# Centered page heading rendered as raw HTML.
st.markdown(
    "<h1 style='text-align: center;'>Text Generation Labeling Tool</h1>",
    unsafe_allow_html=True,
)

def file_selector(folder_path=r'./Datasets'):
    """List the dataset files available in *folder_path*.

    Args:
        folder_path: Directory to scan. Defaults to ``./Datasets``.

    Returns:
        A ``(filenames, folder_path)`` tuple. ``filenames`` is the
        alphabetically sorted list of regular files found directly inside the
        folder; ``folder_path`` is the argument echoed back unchanged.

    Raises:
        OSError: Propagated from ``os.listdir`` when the folder is missing or
            cannot be read.
    """
    # os.listdir() yields entries in arbitrary order and also returns
    # sub-directories. Sort for a stable option order and keep only regular
    # files, since a directory cannot be opened as a dataset.
    filenames = sorted(
        name
        for name in os.listdir(folder_path)
        if os.path.isfile(os.path.join(folder_path, name))
    )
    return filenames, folder_path

def revert_question_type_id(txt_question_type):
    """Convert a question-type label into its numeric id.

    The id is the position of the label in the sequence
    What, Who, When, Where, Why, How, Others (0 through 6).

    Args:
        txt_question_type: Label to convert.

    Returns:
        The integer id of a known label, or ``None`` when the value matches
        none of the labels.
    """
    labels = ('What', 'Who', 'When', 'Where', 'Why', 'How', 'Others')
    for type_id, label in enumerate(labels):
        if txt_question_type == label:
            return type_id
    return None

# Labels offered in the question-type selectbox; a label's position is the
# integer stored in the dataset's 'question_type' column.
QUESTION_TYPES = ['What', 'Who', 'When', 'Where', 'Why', 'How', 'Others']

# Persona used in the generation prompt, keyed by dataset file name.
EXPERT_BY_DATASET = {
    'BiologyQA.csv': 'biologist',
    'GeographyQA.csv': 'geographer',
    'HistoryQA.csv': 'historian',
}

GEMINI_URL = 'https://generativelanguage.googleapis.com/v1beta/models/gemini-pro:generateContent'

# The API key is read from the environment so that no credential lives in the
# source tree. NOTE: the key that used to be hard-coded here was committed to
# source control and should be revoked.
GEMINI_API_KEY_ENV = 'GEMINI_API_KEY'


def _request_distractors(expert, context, question, answer, api_key):
    """Ask the Gemini endpoint for three distractor answers.

    Args:
        expert: Persona inserted into the prompt (e.g. ``'biologist'``).
        context: Context passage of the current sample.
        question: Question text of the current sample.
        answer: Correct answer of the current sample.
        api_key: Key sent as the ``key`` query parameter.

    Returns:
        The generated text on success, or ``None`` when the request fails,
        the server answers with a non-200 status, or the response body does
        not contain the expected ``candidates`` structure.
    """
    payload = {
        'contents': [
            {
                'parts': [
                    {
                        'text': f"You are a great {expert}, here is the following content: context: '{context}', question: '{question}', answer: '{answer}' generate three distract answers. Distractor answers are separated by [SEP]. Example: Distract answer 1 [SEP] Distract answer 2 [SEP] Distract answer 3"
                    }
                ]
            }
        ]
    }
    try:
        # A timeout keeps a stalled connection from freezing the app.
        response = requests.post(
            GEMINI_URL,
            headers={'Content-Type': 'application/json'},
            json=payload,
            params={'key': api_key},
            timeout=30,
        )
    except requests.RequestException:
        return None
    if response.status_code != 200:
        return None
    try:
        return response.json()['candidates'][0]['content']['parts'][0]['text']
    except (KeyError, IndexError, TypeError, ValueError):
        # A 200 response can still lack the expected candidate text.
        return None


# --- Dataset selection ------------------------------------------------------
filenames, folder_path = file_selector()
filename_input = st.sidebar.selectbox(label='Input dataset file:', options=filenames)
if filename_input is None:
    # No file to choose from: stop here instead of crashing in read_csv.
    st.error(f'No dataset files found in {folder_path}.')
    st.stop()

# One path for both reading and saving, so the two can never diverge.
dataset_path = os.path.join(folder_path, filename_input)
df = pd.read_csv(dataset_path)

# --- Current-sample cursor --------------------------------------------------
if 'idx' not in st.session_state:
    st.session_state.idx = 0
# Switching to a shorter dataset can leave the cursor past the last row.
if st.session_state.idx >= len(df):
    st.session_state.idx = max(len(df) - 1, 0)

st.markdown(f"<h4 style='text-align: center;'>Sample {st.session_state.idx + 1}/{len(df)}</h4>", unsafe_allow_html=True)

# --- Toolbar ----------------------------------------------------------------
# Ten equal columns are allocated so the five controls fill half the width.
col_1, col_2, col_3, col_4, col_5, *_ = st.columns([1] * 10)

btn_previous = col_1.button(label=':arrow_backward: Previous sample', use_container_width=True)
btn_next = col_2.button(label='Next sample :arrow_forward:', use_container_width=True)
btn_save = col_3.button(label=':heavy_check_mark: Save change', use_container_width=True)
txt_goto = col_4.selectbox(label='Sample', label_visibility='collapsed', options=list(range(1, len(df) + 1)))
btn_goto = col_5.button(label=':fast_forward: Move to', use_container_width=True)

if len(df) != 0:
    idx = st.session_state.idx

    # --- Editable fields for the current sample -----------------------------
    txt_context = st.text_area(height=200, label='Your context:', value=df.at[idx, 'context'])

    col_11, col_12, col_13 = st.columns([4.5, 1, 4.5])
    txt_question = col_11.text_area(height=90, label='Your question:', value=df.at[idx, 'question'])
    txt_question_type = col_12.selectbox(label='Your question type:', options=QUESTION_TYPES, index=int(df.at[idx, 'question_type']))
    txt_answer = col_13.text_area(height=90, label='Your answer:', value=df.at[idx, 'answer'])

    st.markdown("<p style='text-align: left; font-weight: normal; font-size: 14px'>Your distractors:</p>", unsafe_allow_html=True)

    col_21, col_22 = st.columns(spec=[9, 1])
    txt_distractors = col_21.text_area(height=90, label='Your distractors:', label_visibility='collapsed', value=df.at[idx, 'distract'])
    btn_generate_distractor = col_22.button(label='Generate distractors', use_container_width=True)

    # --- Distractor generation ----------------------------------------------
    if btn_generate_distractor:
        # Unknown dataset names fall back to a generic persona instead of
        # leaving the variable unbound.
        expert = EXPERT_BY_DATASET.get(filename_input, 'domain expert')
        api_key = os.environ.get(GEMINI_API_KEY_ENV)
        if not api_key:
            st.error(f'Gemini API key is not configured. Set the {GEMINI_API_KEY_ENV} environment variable.')
        else:
            generated = _request_distractors(expert, txt_context, txt_question, txt_answer, api_key)
            if generated is not None:
                st.success(f'3 distraction answers: {generated}')
                st.cache_data.clear()
            else:
                # No rerun here: rerunning would wipe the message before the
                # user could read it.
                st.error('Failed to generate distractors. Please check API and inputs.')

    # --- Navigation ---------------------------------------------------------
    if btn_previous and st.session_state.idx > 0:
        st.session_state.idx -= 1
        st.rerun()

    if btn_next and st.session_state.idx < (len(df) - 1):
        st.session_state.idx += 1
        st.rerun()

    # --- Persist edits ------------------------------------------------------
    if btn_save:
        # .at writes directly into the frame; chained df[col][idx] assignment
        # may write to a temporary copy and is a no-op under Copy-on-Write.
        df.at[idx, 'context'] = txt_context
        df.at[idx, 'question'] = txt_question
        df.at[idx, 'answer'] = txt_answer
        df.at[idx, 'distract'] = txt_distractors
        df.at[idx, 'question_type'] = revert_question_type_id(txt_question_type)

        df.to_csv(dataset_path, index=False)

    if btn_goto:
        # The selectbox is 1-based; the cursor is 0-based.
        st.session_state.idx = int(txt_goto) - 1
        st.rerun()