"""Streamlit app for manually grading model-evaluated search results.

Queries are read from ``file_path``. Each query can be marked as graded,
skipped or junk, and each result can be individually accepted. Once every
query has a status, the data is written to ``<stem>_graded.json``.
"""
import streamlit as st
import json
from streamlit_shortcuts import add_keyboard_shortcuts

st.set_page_config(layout="wide")

file_path = 'synth_toy_eval.json'


@st.cache_data()
def load_data():
    """Read and return the parsed JSON content of ``file_path``."""
    with open(file_path, 'r') as file:
        return json.load(file)


def save_data(data):
    """Write ``data`` as indented JSON next to the input file.

    The output name is the input name up to its first ``.json`` with
    ``_graded.json`` appended.
    """
    stem = file_path.split(".json")[0]
    print(stem)
    with open(f"{stem}_graded.json", 'w') as file:
        json.dump(data, file, indent=4)


data = load_data()
for query in data:
    for result in query['results']:
        # Every result needs a 'verified' flag for the Accept checkbox.
        result.setdefault('verified', False)

# Session state: the working copy of the data and the navigation position.
if 'current_query_index' not in st.session_state:
    st.session_state.current_query_index = 0
if 'data' not in st.session_state:
    st.session_state.data = data
if 'graded_queries' not in st.session_state:
    st.session_state.graded_queries = 0


def truncate_text(text, length=250):
    """Return ``text`` cut to ``length`` characters, with '...' if it was cut."""
    return text if len(text) <= length else text[:length] + '...'


result_box_style = """ """


def next_query():
    """Advance to the next query (if there is one) and rerun the script.

    The rerun also happens on the last query so that a status assigned just
    before this call is reflected in the graded counter and in the
    save-when-complete check, both of which run earlier in the script.
    """
    if st.session_state.current_query_index < len(st.session_state.data) - 1:
        st.session_state.current_query_index += 1
    st.rerun()


def display_query():
    """Render the navigation bar, the current query and its results.

    Reads the module-level ``current_query``, which must be assigned before
    this function is called. Saves the data whenever every query has a
    status.
    """
    queries = st.session_state.data
    total = len(queries)
    st.session_state.graded_queries = sum(
        query.get('status', None) is not None for query in queries
    )
    print(f"Current Query Index: {st.session_state.current_query_index} | Graded Queries: {st.session_state.graded_queries} | Total Queries: {len(st.session_state.data)} | Current Query Status {current_query.get('status', None)}")

    # Navigation bar
    nav_left, nav_right = st.columns([5, 1], gap="small")
    with nav_left:
        if st.button('Previous'):
            if st.session_state.current_query_index > 0:
                st.session_state.current_query_index -= 1
                st.rerun()
        st.progress((st.session_state.current_query_index + 1) / total)
    with nav_right:
        next_col, skip_col, junk_col = st.columns([1, 1, 1], gap="small")
        with next_col:
            if st.button('Next'):
                if st.session_state.current_query_index < total - 1:
                    st.session_state.current_query_index += 1
                    st.rerun()
        with skip_col:
            if st.button('Skip'):
                current_query['status'] = 'skipped'
                next_query()
        with junk_col:
            if st.button('Junk'):
                current_query['status'] = 'nonsense'
                next_query()

    # NOTE(review): several markdown strings below contain only newlines
    # around their text; they look like they once carried HTML markup
    # (unsafe_allow_html=True is set) - confirm against the original file.
    st.markdown(
        f"\n\nAt index {st.session_state.current_query_index + 1}. Graded Queries: {st.session_state.graded_queries}/{len(st.session_state.data)}\n\n",
        unsafe_allow_html=True,
    )

    if st.session_state.graded_queries >= total:
        save_data(queries)
        st.success(f"{total} Queries graded and data saved!")

    st.markdown(result_box_style, unsafe_allow_html=True)
    st.header(f"Query: {current_query['query']}")
    is_graded = current_query.get('status', None) is not None
    st.markdown(
        f"{current_query['grid_pos_str']} | Query Grade: {'Graded' if is_graded else 'Ungraded'}",
        unsafe_allow_html=True,
    )

    st.subheader("Results:")
    query_index = st.session_state.current_query_index
    for index, result in enumerate(current_query['results']):
        st.markdown("\n", unsafe_allow_html=True)
        text_col, grade_col = st.columns([3, 2], gap="small")
        with text_col:
            st.markdown(f"\n{result['title']}\n", unsafe_allow_html=True)
            st.markdown(
                f"[{truncate_text(result['url'], length=50)}]({result['url']}) | {result['published_date']}",
                unsafe_allow_html=True,
            )
            # The result text is cut to the length of the model trace shown
            # in the neighbouring column.
            st.markdown(
                f"{truncate_text(result['text'], length=len(result['model_trace']))}"
            )
        with grade_col:
            st.markdown(f"Model Grade: {result['grade']}", unsafe_allow_html=True)
            st.write(result['model_trace'])
            # The key includes the query index so a tick on one query does
            # not carry over to the same-position result of another query.
            # Assigning the return value lets an untick clear the flag.
            result['verified'] = st.checkbox(
                "Accept",
                value=bool(result.get('verified')),
                key=f'verify-{query_index}-{index}',
            )
        st.markdown("\n", unsafe_allow_html=True)
        st.markdown("\n", unsafe_allow_html=True)


# Show current query and its results
current_query = st.session_state.data[st.session_state.current_query_index]
display_query()

col1, col2 = st.columns([5, 1], gap="small")
with col2:
    if st.button('Mark Done and Go to Next'):
        current_query['status'] = 'graded'
        next_query()

# One call per shortcut; each maps a key to the label of the button it clicks.
for shortcut_key, button_label in (
    ('s', 'Skip'),
    ('j', 'Junk'),
    ('p', 'Previous'),
    ('n', 'Next'),
):
    add_keyboard_shortcuts({shortcut_key: button_label})