File size: 5,867 Bytes
f396b8b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import streamlit as st
import json
from streamlit_shortcuts import add_keyboard_shortcuts

# Lay the page out across the full browser width.
# NOTE(review): Streamlit has historically required set_page_config to run
# before other Streamlit commands — presumably why it sits first; keep it here.
st.set_page_config(layout="wide")

# JSON file holding the queries to grade. load_data() reads it, and
# save_data() derives the "<name>_graded.json" output path from it.
file_path = 'synth_toy_eval.json'

# Load your data
@st.cache_data()
def load_data():
    """Read and parse the JSON evaluation file named by the module-level ``file_path``.

    The function is wrapped in ``st.cache_data`` and takes no arguments.

    Returns:
        The decoded JSON document. The rest of this script iterates it as a
        list of query dicts, each carrying a ``'results'`` list.

    Raises:
        OSError: if ``file_path`` cannot be opened.
        json.JSONDecodeError: if the file does not contain valid JSON.
    """
    # JSON interchange files are UTF-8; decode explicitly rather than relying
    # on the platform's locale encoding, which breaks non-ASCII text on some
    # systems.
    with open(file_path, 'r', encoding='utf-8') as file:
        return json.load(file)

def save_data(data):
    """Dump ``data`` as 4-space-indented JSON to ``<file_path stem>_graded.json``.

    The stem is everything in the module-level ``file_path`` before the first
    ``.json``; it is also echoed to stdout.

    Args:
        data: JSON-serialisable object to write.
    """
    stem = file_path.split(".json")[0]
    print(stem)
    graded_path = f"{stem}_graded.json"
    with open(graded_path, 'w') as out_file:
        json.dump(data, out_file, indent=4)

data = load_data()

# Make sure every result carries a 'verified' flag; existing values are kept.
for query in data:
    for result in query['results']:
        result.setdefault('verified', False)

# Seed the per-session state on first use only: the index of the query being
# shown, the working copy of the data, and the graded-query counter.
for state_key, initial_value in (
    ('current_query_index', 0),
    ('data', data),
    ('graded_queries', 0),
):
    if state_key not in st.session_state:
        st.session_state[state_key] = initial_value
    
def truncate_text(text, length=250):
    """Return ``text`` cut to at most ``length`` characters.

    Text longer than ``length`` is sliced to its first ``length`` characters
    and suffixed with ``'...'``; shorter or equal text is returned unchanged.
    """
    if len(text) > length:
        return text[:length] + '...'
    return text

# CSS snippet that display_query() injects with
# st.markdown(..., unsafe_allow_html=True). It defines the `.rounded-box`
# class referenced by the <div> markers emitted around each result.
result_box_style = """
<style>
.rounded-box {
    border: 1px solid #ddd;
    border-radius: 10px;
    padding: 0.01px;
    margin-bottom: 10px;
}
</style>
"""
# Navigation to next query
def next_query():
    """Advance to the following query (if there is one) and rerun the script.

    The index stored in ``st.session_state.current_query_index`` is only
    incremented while it is below the last position of the module-level
    ``data``; ``st.rerun()`` is called in every case.
    """
    last_position = len(data) - 1
    already_at_end = st.session_state.current_query_index >= last_position
    if not already_at_end:
        st.session_state.current_query_index += 1
    st.rerun()
    
# Display current query and its results
def display_query():
    """Render the navigation bar, the current query header and its results.

    Reads the module-level ``current_query`` dict (only read and mutated in
    place here, so no ``global`` declaration is needed) together with
    ``st.session_state``.

    Side effects:
        * Recomputes ``st.session_state.graded_queries`` as the number of
          queries whose ``'status'`` is set.
        * 'Previous' / 'Next' move ``current_query_index`` within bounds and
          rerun; 'Skip' / 'Junk' set ``current_query['status']`` to
          ``'skipped'`` / ``'nonsense'`` and call ``next_query()``.
        * Once every query has a status, ``save_data`` is called and a success
          banner is shown.
        * Each result's ``'verified'`` flag is kept in sync with its "Accept"
          checkbox, in both directions.
    """
    total_queries = len(st.session_state.data)
    query_index = st.session_state.current_query_index

    # A query counts as graded as soon as it has any status
    # ('graded', 'skipped' or 'nonsense').
    st.session_state.graded_queries = sum(
        query.get('status', None) is not None for query in st.session_state.data
    )
    print(f"Current Query Index: {st.session_state.current_query_index} | Graded Queries: {st.session_state.graded_queries} | Total Queries: {len(st.session_state.data)} | Current Query Status {current_query.get('status', None)}")

    # --- Navigation bar -----------------------------------------------------
    nav_col, action_col = st.columns([5, 1], gap="small")
    with nav_col:
        # The button is evaluated first so it is always rendered.
        if st.button('Previous') and query_index > 0:
            st.session_state.current_query_index -= 1
            st.rerun()
        st.progress((query_index + 1) / total_queries)
    with action_col:
        next_col, skip_col, junk_col = st.columns([1, 1, 1], gap="small")
        with next_col:
            if st.button('Next') and query_index < total_queries - 1:
                st.session_state.current_query_index += 1
                st.rerun()
        with skip_col:
            if st.button('Skip'):
                current_query['status'] = 'skipped'
                next_query()
        with junk_col:
            if st.button('Junk'):
                current_query['status'] = 'nonsense'
                next_query()

        st.markdown(f"<p>At index {query_index + 1}. Graded Queries: {st.session_state.graded_queries}/{total_queries}</p>", unsafe_allow_html=True)

    # Persist the grades once every query has a status.
    if st.session_state.graded_queries >= total_queries:
        save_data(st.session_state.data)
        st.success(f"{total_queries} Queries graded and data saved!")

    st.markdown(result_box_style, unsafe_allow_html=True)

    # --- Query header -------------------------------------------------------
    st.header(f"Query: {current_query['query']}")
    is_graded = current_query.get('status', None) is not None
    status_color = 'green' if is_graded else 'red'
    st.markdown(f"{current_query['grid_pos_str']} | Query Grade: <b style='color: {status_color};'>{'Graded' if is_graded else 'Ungraded'}</b>", unsafe_allow_html = True)

    # --- Results ------------------------------------------------------------
    st.subheader("Results:")
    for index, result in enumerate(current_query['results']):
        st.markdown("<div class='rounded-box'>", unsafe_allow_html=True)
        detail_col, grade_col = st.columns([3, 2], gap="small")
        with detail_col:
            st.markdown(f"<h5>{result['title']}</h5>", unsafe_allow_html=True)
            st.markdown(f"[<span style='font-size: 0.8em;'>{truncate_text(result['url'], length = 50)}</span>]({result['url']})  |  {result['published_date']}", unsafe_allow_html=True)
            # The snippet is cut to the length of the model trace shown in the
            # neighbouring column.
            st.markdown(truncate_text(result['text'], length = len(result['model_trace'])))
        with grade_col:
            grade_color = 'green' if result['grade'].lower() == 'yes' else 'red'
            st.markdown(f"<b style='color: {grade_color};'>Model Grade: {result['grade']}</b>", unsafe_allow_html=True)
            st.write(result['model_trace'])

            # The key includes the query index so checkbox state from one query
            # cannot carry over to the result at the same position in another
            # query. The returned value is stored directly so that un-checking
            # clears the flag as well.
            result['verified'] = st.checkbox(
                "Accept",
                value=bool(result.get('verified')),
                key=f'verify-{query_index}-{index}',
            )

        st.markdown("</div>", unsafe_allow_html=True)
    # NOTE(review): this emits another opening <div>, not a closing tag; kept
    # as in the original — confirm whether a trailing box is intended.
    st.markdown("<div class='rounded-box'>", unsafe_allow_html=True)

# Pick the query at the stored index, render it, then offer the
# "mark as graded" action underneath in the narrow right-hand column.
active_index = st.session_state.current_query_index
current_query = st.session_state.data[active_index]

display_query()

col1, col2 = st.columns([5, 1], gap="small")
with col2:
    mark_done_clicked = st.button('Mark Done and Go to Next')
    if mark_done_clicked:
        # Record the grade on the query itself, then move on.
        current_query['status'] = 'graded'
        next_query()
    

# Register one keyboard shortcut per navigation button, in the same order as
# before; each call passes a single {key: button label} mapping.
for shortcut_key, button_label in (
    ('s', 'Skip'),
    ('j', 'Junk'),
    ('p', 'Previous'),
    ('n', 'Next'),
):
    add_keyboard_shortcuts({shortcut_key: button_label})