Spaces:
Sleeping
Sleeping
File size: 5,867 Bytes
f396b8b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 |
import html
import json
import os

import streamlit as st
from streamlit_shortcuts import add_keyboard_shortcuts
# Evaluation set being graded; the graded copy is written beside it by save_data.
file_path = "synth_toy_eval.json"
# Wide layout so the result / model-trace columns get room. Kept ahead of every
# other st.* call (older Streamlit versions require set_page_config to be first).
st.set_page_config(layout="wide")
# Load your data
@st.cache_data()
def load_data(path=None):
    """Load the evaluation queries from a JSON file.

    Args:
        path: JSON file to read. Defaults to the module-level ``file_path``.

    Returns:
        The decoded JSON document. The rest of this app treats it as a list of
        query dicts, each carrying a ``'results'`` list.

    ``st.cache_data`` memoises the result per argument and returns a copy on
    each call, so callers may mutate what they get back.
    """
    if path is None:
        path = file_path
    # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
    with open(path, 'r', encoding='utf-8') as file:
        return json.load(file)
def save_data(data, path=None):
    """Write *data* to ``<stem>_graded.json`` next to the source JSON file.

    Args:
        data: JSON-serialisable object (the session's working list of queries).
        path: Source JSON path whose stem names the output file. Defaults to
            the module-level ``file_path``.

    Returns:
        The path of the file that was written.
    """
    if path is None:
        path = file_path
    # splitext only strips the final extension; the old split('.json') cut at the
    # first '.json' anywhere in the path (e.g. inside a directory name).
    stem = os.path.splitext(path)[0]
    out_path = f"{stem}_graded.json"
    with open(out_path, 'w', encoding='utf-8') as file:
        json.dump(data, file, indent=4)
    print(f"Saved graded data to {out_path}")
    return out_path
# Load the evaluation set and give every result a 'verified' flag; results
# that have not been reviewed yet start out as not verified.
data = load_data()
for entry in data:
    for res in entry['results']:
        res.setdefault('verified', False)
# Per-session state, created only on the first run of a session:
# the index of the query on screen, the session's editable copy of the data,
# and the number of queries that currently have a status.
_session_defaults = {
    'current_query_index': 0,
    'data': data,
    'graded_queries': 0,
}
for _state_key, _initial in _session_defaults.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _initial
def truncate_text(text, length=250):
    """Return *text* cut to its first *length* characters.

    When anything is cut, '...' is appended, so a truncated result is
    ``length + 3`` characters long. Text already within the limit is returned
    unchanged.
    """
    if len(text) > length:
        return text[:length] + '...'
    return text
# CSS injected once per page; the result list wraps each result with a
# <div class='rounded-box'> marker that this rule styles.
result_box_style = (
    "\n"
    "<style>\n"
    ".rounded-box {\n"
    "border: 1px solid #ddd;\n"
    "border-radius: 10px;\n"
    "padding: 0.01px;\n"
    "margin-bottom: 10px;\n"
    "}\n"
    "</style>\n"
)
# Navigation to next query
def next_query():
    """Advance to the next query, if there is one, then rerun the script.

    Reads the query count from ``st.session_state.data``, the same list the
    rest of the UI renders. It always reruns, even on the last query, because
    callers (Skip / Junk / Mark Done) set a status after the graded count has
    already been drawn in this run. The rerun makes that change show at once.
    """
    if st.session_state.current_query_index < len(st.session_state.data) - 1:
        st.session_state.current_query_index += 1
    st.rerun()
# Display current query and its results
def display_query(query=None):
    """Render the navigation bar, the query header and every result of a query.

    Args:
        query: Query dict to show. Defaults to the entry of
            ``st.session_state.data`` at ``st.session_state.current_query_index``,
            which is the object the script-level ``current_query`` refers to.

    Side effects:
        Recomputes ``st.session_state.graded_queries``. The Skip/Junk buttons
        may set ``query['status']``. Each result's ``'verified'`` flag is set
        from its Accept checkbox. Once every query has a status, the data is
        written via ``save_data``.
    """
    queries = st.session_state.data
    if query is None:
        query = queries[st.session_state.current_query_index]

    # A query counts as graded once it has any status (graded/skipped/nonsense).
    st.session_state.graded_queries = sum(q.get('status', None) is not None for q in queries)
    print(
        f"Current Query Index: {st.session_state.current_query_index} | "
        f"Graded Queries: {st.session_state.graded_queries} | "
        f"Total Queries: {len(queries)} | "
        f"Current Query Status {query.get('status', None)}"
    )

    _render_navigation(query)

    st.markdown(
        f"<p>At index {st.session_state.current_query_index + 1}. "
        f"Graded Queries: {st.session_state.graded_queries}/{len(queries)}</p>",
        unsafe_allow_html=True,
    )
    if st.session_state.graded_queries >= len(queries):
        save_data(queries)
        st.success(f"{len(queries)} Queries graded and data saved!")

    st.markdown(result_box_style, unsafe_allow_html=True)
    st.header(f"Query: {query['query']}")
    is_graded = query.get('status', None) is not None
    status_color = 'green' if is_graded else 'red'
    status_label = 'Graded' if is_graded else 'Ungraded'
    st.markdown(
        f"{query['grid_pos_str']} | Query Grade: "
        f"<b style='color: {status_color};'>{status_label}</b>",
        unsafe_allow_html=True,
    )

    st.subheader("Results:")
    for index, result in enumerate(query['results']):
        _render_result(index, result)
    st.markdown("<div class='rounded-box'>", unsafe_allow_html=True)


def _render_navigation(query):
    """Draw the Previous button, the progress bar and the Next / Skip / Junk buttons."""
    total = len(st.session_state.data)
    left, right = st.columns([5, 1], gap="small")
    with left:
        if st.button('Previous'):
            if st.session_state.current_query_index > 0:
                st.session_state.current_query_index -= 1
                st.rerun()
        st.progress((st.session_state.current_query_index + 1) / total)
    with right:
        next_col, skip_col, junk_col = st.columns([1, 1, 1], gap="small")
        with next_col:
            if st.button('Next'):
                if st.session_state.current_query_index < total - 1:
                    st.session_state.current_query_index += 1
                    st.rerun()
        with skip_col:
            if st.button('Skip'):
                query['status'] = 'skipped'
                next_query()
        with junk_col:
            if st.button('Junk'):
                query['status'] = 'nonsense'
                next_query()


def _render_result(index, result):
    """Draw one result: title/link/snippet on the left, model grade and Accept box on the right."""
    st.markdown("<div class='rounded-box'>", unsafe_allow_html=True)
    left, right = st.columns([3, 2], gap="small")
    with left:
        # These strings are embedded as raw HTML, so escape them.
        st.markdown(f"<h5>{html.escape(result['title'])}</h5>", unsafe_allow_html=True)
        link_label = html.escape(truncate_text(result['url'], length=50))
        st.markdown(
            f"[<span style='font-size: 0.8em;'>{link_label}</span>]({result['url']}) "
            f"| {result['published_date']}",
            unsafe_allow_html=True,
        )
        # Snippet is cut to the length of the model trace (presumably so the two
        # columns stay roughly the same height).
        st.markdown(f"{truncate_text(result['text'], length=len(result['model_trace']))}")
    with right:
        grade_color = 'green' if result['grade'].lower() == 'yes' else 'red'
        st.markdown(
            f"<b style='color: {grade_color};'>Model Grade: {result['grade']}</b>",
            unsafe_allow_html=True,
        )
        st.write(result['model_trace'])
        # The query index in the key keeps checkbox state from leaking between
        # queries. Assigning the return value lets un-ticking clear the flag again.
        result['verified'] = st.checkbox(
            "Accept",
            value=bool(result.get('verified')),
            key=f"verify-{st.session_state.current_query_index}-{index}",
        )
    st.markdown("</div>", unsafe_allow_html=True)
# Show current query and its results
if not st.session_state.data:
    # Nothing to grade: indexing the empty list below would raise IndexError.
    st.warning(f"No queries found in {file_path}.")
    st.stop()
current_query = st.session_state.data[st.session_state.current_query_index]
display_query()
# Right-aligned "done" button under the results.
_, done_col = st.columns([5, 1], gap="small")
with done_col:
    if st.button('Mark Done and Go to Next'):
        current_query['status'] = 'graded'
        next_query()
# Keyboard shortcuts mapping a key to the label of the button it should press.
# A single call registers all of them, instead of injecting one component per key.
add_keyboard_shortcuts({
    's': 'Skip',
    'j': 'Junk',
    'p': 'Previous',
    'n': 'Next',
})
|