Martín Santillán Cooper
UX improvements
e5f0735
raw
history blame
8.44 kB
import gradio as gr
from dotenv import load_dotenv
from utils import get_evaluated_component, get_evaluated_component_adjective, to_title_case, get_prompt_from_test_case, to_snake_case
load_dotenv()
import json
from model import generate_text
from logger import logger
import os
from gradio_modal import Modal
# Test-case catalog: maps each sub-catalog name to its list of test-case dicts
# (each with at least 'name', 'criteria', 'context', 'user_message' and
# 'assistant_message', as read by the handlers and UI below).
# JSON is UTF-8 by specification; pass the encoding explicitly so loading does
# not depend on the platform's locale default.
with open('catalog.json', encoding='utf-8') as f:
    logger.debug('Loading catalog from json.')
    catalog = json.load(f)
def update_selected_test_case(button_name, state: gr.State, event: gr.EventData):
    """Store the test case of the clicked catalog button in ``state``.

    The clicked button's ``elem_id`` is ``"<sub_catalog_name>---<test_case_name>"``,
    built from the raw catalog keys when the buttons are created. It therefore
    identifies the test case exactly and is used for the lookup. ``button_name``
    (the button's title-cased label) is still accepted because the click
    listener passes it. It is not used for the lookup, because converting the
    label back with ``to_snake_case`` may not reproduce the catalog name.

    Args:
        button_name: label of the clicked button (unused).
        state: session state dict; updated in place.
        event: Gradio event data; ``event.target`` is the clicked button.

    Returns:
        ``state`` with 'selected_sub_catalog', 'selected_criteria_name' and
        'selected_test_case' set.

    Raises:
        ValueError: if the ``elem_id`` does not match any catalog entry.
    """
    # maxsplit=1 so a '---' inside the test-case name cannot break the unpacking.
    target_sub_catalog_name, target_test_case_name = event.target.elem_id.split('---', 1)
    selected_test_case = next(
        (t for t in catalog.get(target_sub_catalog_name, []) if t['name'] == target_test_case_name),
        None,
    )
    if selected_test_case is None:
        raise ValueError(
            f'No test case "{target_test_case_name}" in sub-catalog "{target_sub_catalog_name}".')
    # Mutate state only after the lookup succeeded, so a failed click leaves it intact.
    state['selected_sub_catalog'] = target_sub_catalog_name
    state['selected_criteria_name'] = target_test_case_name
    state['selected_test_case'] = selected_test_case
    return state
def on_test_case_click(state: gr.State):
    """Return updates that load the test case recorded in ``state`` into the form.

    Reads 'selected_sub_catalog', 'selected_criteria_name' and
    'selected_test_case' from ``state``. The returned dict maps each output
    component to its update:
    - the title and criteria are refreshed;
    - context and assistant message are shown only when the test case defines
      them (are not ``None``);
    - the user message is made editable;
    - any previous result is hidden.
    """
    sub_catalog_id = state['selected_sub_catalog']
    criteria_id = state['selected_criteria_name']
    case = state['selected_test_case']
    logger.debug(f'Changing to test case "{criteria_id}" from catalog "{sub_catalog_id}".')

    def optional_textbox(value):
        # Hide and clear a field the test case leaves as None; otherwise show it.
        if value is None:
            return gr.update(visible=False, value='')
        return gr.update(value=value, visible=True)

    return {
        test_case_name: f'<h2>{to_title_case(case["name"])}</h2>',
        criteria: case['criteria'],
        context: optional_textbox(case['context']),
        user_message: gr.update(value=case['user_message'], elem_classes=[], interactive=True),
        assistant_message: optional_textbox(case['assistant_message']),
        result_text: gr.update(value='', visible=False),
        result_container: gr.update(visible=False),
    }
def change_button_color(event: gr.EventData):
    """Return one update per catalog button, marking only the clicked one 'selected'.

    Updates follow the nested iteration order of ``catalog_buttons``. That
    order must match the outputs list this handler is wired to.
    """
    updates = []
    for sub_catalog_buttons in catalog_buttons.values():
        for btn in sub_catalog_buttons.values():
            if btn.elem_id == event.target.elem_id:
                updates.append(gr.update(elem_classes=['catalog-button', 'selected']))
            else:
                updates.append(gr.update(elem_classes=['catalog-button']))
    return updates
def _build_prompt(criteria, context, user_message, assistant_message, state):
    """Build the judge prompt from the (possibly edited) form fields and selection.

    Shared by ``on_submit`` and ``on_show_prompt_click``, so the prompt shown
    to the user is the same one that is evaluated. ``state`` is the session
    dict holding 'selected_criteria_name' and 'selected_sub_catalog'.
    """
    return get_prompt_from_test_case({
        'name': state['selected_criteria_name'],
        'criteria': criteria,
        'context': context,
        'user_message': user_message,
        'assistant_message': assistant_message,
    }, state['selected_sub_catalog'])


def on_submit(criteria, context, user_message, assistant_message, state):
    """Evaluate the current test data with the model and return the result HTML.

    Args:
        criteria, context, user_message, assistant_message: current textbox values.
        state: session dict with 'selected_sub_catalog' and 'selected_criteria_name'.

    Returns:
        A ``gr.update`` whose value is a ``<p>`` sentence with the model's assessment.
    """
    import html

    prompt = _build_prompt(criteria, context, user_message, assistant_message, state)
    evaluated_component = get_evaluated_component(state['selected_sub_catalog'], state['selected_criteria_name'])
    evaluated_component_adjective = get_evaluated_component_adjective(state['selected_sub_catalog'], state['selected_criteria_name'])
    logger.debug(f"Starting evaluation for subcatelog {state['selected_sub_catalog']} and criteria name {state['selected_criteria_name']}")
    # NOTE(review): presumably 'Yes' or 'No' per the original author's note; confirm in model.generate_text.
    result_label = generate_text(prompt)['assessment']
    # The model output is rendered as HTML, so escape it rather than trusting its content.
    html_str = f"<p>Is the {evaluated_component} {evaluated_component_adjective}: <strong>{html.escape(str(result_label))}</strong></p>"
    return gr.update(value=html_str)


def on_show_prompt_click(criteria, context, user_message, assistant_message, state):
    """Return a Markdown component showing the prompt that would be evaluated.

    The prompt dict is rendered as indented JSON. Its 'content' is HTML-escaped
    so tags in the prompt display literally, and newlines become ``<br>``.
    """
    import html

    prompt = _build_prompt(criteria, context, user_message, assistant_message, state)
    # Escape '&' too (not just '<' / '>'), otherwise entity-like text would be decoded on display.
    prompt['content'] = html.escape(prompt['content'], quote=False).replace('\n', '<br>')
    # ensure_ascii=False keeps non-ASCII characters readable instead of \uXXXX escapes.
    prompt_json = json.dumps(prompt, indent=4, ensure_ascii=False)
    return gr.Markdown(prompt_json)
with gr.Blocks(
        title='Granite Guardian',
        theme=gr.themes.Soft(font=[gr.themes.GoogleFont("IBM Plex Sans")]),
        css='styles.css') as demo:
    # Per-session selection; update_selected_test_case also stores
    # 'selected_test_case' here when a catalog button is clicked.
    state = gr.State(value={
        'selected_sub_catalog': 'harmful_content_in_user_message',
        'selected_criteria_name': 'harmful'
    })

    # Test case shown on first load; IndexError if catalog.json lacks the default selection.
    starting_test_case = [
        t for sub_catalog_name, sub_catalog in catalog.items() for t in sub_catalog
        if t['name'] == state.value['selected_criteria_name']
        and sub_catalog_name == state.value['selected_sub_catalog']
    ][0]

    with gr.Row():
        gr.HTML('<h1>Granite Guardian</h1>', elem_classes='title')
    with gr.Row(elem_classes='column-gap'):
        # Left column: one accordion per sub-catalog, one button per test case.
        with gr.Column(scale=0):
            title_display_left = gr.HTML("<h2>Harms & Risks</h2>")
            accordions = []
            catalog_buttons: dict[str, dict[str, gr.Button]] = {}
            for i, (sub_catalog_name, sub_catalog) in enumerate(catalog.items()):
                with gr.Accordion(to_title_case(sub_catalog_name), open=(i == 0), elem_classes='accordion-align') as accordion:
                    for test_case in sub_catalog:
                        elem_classes = ['catalog-button']
                        # update_selected_test_case splits this id on '---' to find the test case.
                        elem_id = f"{sub_catalog_name}---{test_case['name']}"
                        # Identity, not equality: highlight exactly the starting entry.
                        if test_case is starting_test_case:
                            elem_classes.append('selected')
                        catalog_buttons.setdefault(sub_catalog_name, {})[test_case['name']] = gr.Button(
                            to_title_case(test_case['name']),
                            elem_classes=elem_classes,
                            variant='secondary',
                            size='sm',
                            elem_id=elem_id)
                accordions.append(accordion)

        # Right column: the selected test case, its editable test data and the result.
        with gr.Column(visible=True) as test_case_content:
            with gr.Row():
                test_case_name = gr.HTML(f'<h2>{to_title_case(starting_test_case["name"])}</h2>')
                show_prompt_button = gr.Button('Show prompt', size='sm', scale=0, min_width=110)
            gr.HTML("Evaluation Criteria", elem_classes='subtitle')
            criteria = gr.Textbox(label="Definition", lines=3, interactive=False, value=starting_test_case['criteria'], elem_classes=['read-only'])
            gr.HTML("Test Data", elem_classes='subtitle')
            # Optional fields start visible only when the starting test case defines them,
            # the same rule on_test_case_click applies to later selections.
            context = gr.Textbox(
                label="Context", lines=3, interactive=True,
                value=starting_test_case['context'],
                visible=starting_test_case['context'] is not None)
            user_message = gr.Textbox(label="User Message", lines=3, interactive=True, value=starting_test_case['user_message'])
            assistant_message = gr.Textbox(
                label="Assistant Message", lines=3, interactive=True,
                visible=starting_test_case['assistant_message'] is not None,
                value=starting_test_case['assistant_message'])
            submit_button = gr.Button("Evaluate", variant='primary')
            with gr.Column(elem_classes="result-container", visible=False) as result_container:
                evaluation_results_label = gr.HTML("<span>Results</span>", elem_classes='result-title', visible=False)
                result_text = gr.HTML(label="Result", elem_classes=['read-only', "result-text"], visible=False)

    with Modal(visible=False, elem_classes='modal') as modal:
        prompt = gr.Markdown("Hello world!")

    # "Show prompt": render the prompt into the modal, then open it.
    show_prompt_button.click(
        on_show_prompt_click,
        inputs=[criteria, context, user_message, assistant_message, state],
        outputs=prompt
    ).then(lambda: gr.update(visible=True), None, modal)

    # "Evaluate": reveal an empty result area first, then fill it with the verdict.
    submit_button.click(
        lambda: [gr.update(visible=True, value=''), gr.update(visible=True), gr.update(visible=True)],
        inputs=None,
        outputs=[result_text, evaluation_results_label, result_container]
    ).then(
        on_submit,
        inputs=[criteria, context, user_message, assistant_message, state],
        outputs=result_text)

    # Catalog buttons: record the selection, refresh the right column, re-highlight buttons.
    # This flat order must match the order in which change_button_color returns its updates.
    all_catalog_buttons = [b for sub_buttons in catalog_buttons.values() for b in sub_buttons.values()]
    for button in all_catalog_buttons:
        (
            button.click(update_selected_test_case, inputs=[button, state], outputs=[state])
            .then(on_test_case_click, inputs=state,
                  outputs={test_case_name, criteria, context, user_message, assistant_message, result_text, result_container})
            .then(change_button_color, None, all_catalog_buttons)
        )

demo.launch(server_name='0.0.0.0')