# Spaces: Running on Zero
import gradio as gr | |
from dotenv import load_dotenv | |
from utils import to_title_case, get_prompt_from_test_case, to_snake_case | |
load_dotenv() | |
import json | |
from model import generate_text | |
from logger import logger | |
import os | |
from gradio_modal import Modal | |
# Load the test-case catalog once at import time. As consumed below, it maps
# each sub-catalog name to a list of test-case dicts. JSON is UTF-8 by spec,
# so decode explicitly rather than relying on the platform's locale encoding.
with open('catalog.json', encoding='utf-8') as f:
    logger.debug('Loading catalog from json.')
    catalog = json.load(f)
def update_selected_test_case(button_name, state: gr.State, event: gr.EventData):
    """Record the clicked catalog button's test case in the session state.

    Each catalog button is created with ``elem_id="<sub_catalog>---<test_case>"``,
    so the clicked button identifies both the sub-catalog and the test case
    exactly, without having to reverse the title-cased button label.

    Args:
        button_name: Label of the clicked button. Kept for interface
            compatibility; the lookup uses the button's ``elem_id`` because
            the label is a title-cased rendering that is not guaranteed to
            convert back to the original snake_case name.
        state: Session state dict; updated in place and returned.
        event: Gradio event data whose ``target`` is the clicked button.

    Returns:
        The updated state with ``selected_sub_catalog``,
        ``selected_criteria_name`` and ``selected_test_case`` set.

    Raises:
        KeyError: If the sub-catalog or the test case is not in ``catalog``.
    """
    target_sub_catalog_name, target_test_case_name = event.target.elem_id.split('---', 1)

    # Catalog keys are the raw sub-catalog names used to build elem_id, so
    # index directly instead of scanning (and re-casing) every sub-catalog.
    selected_test_case = next(
        (t for t in catalog[target_sub_catalog_name] if t['name'] == target_test_case_name),
        None,
    )
    if selected_test_case is None:
        raise KeyError(
            f'Test case "{target_test_case_name}" not found in catalog "{target_sub_catalog_name}".'
        )

    # Only mutate the state once the lookup has succeeded, so a failed click
    # does not leave the state half-updated.
    state['selected_sub_catalog'] = target_sub_catalog_name
    state['selected_criteria_name'] = target_test_case_name
    state['selected_test_case'] = selected_test_case
    return state
def on_test_case_click(state: gr.State):
    """Produce the component updates that display the selected test case.

    Reads the selection previously stored in ``state`` by
    ``update_selected_test_case``. It returns a dict keyed by the UI
    components. Optional fields (context, assistant message) are hidden and
    cleared when the test case has no value for them. The user message is
    made read-only for the ``harmful_content_in_assistant_message``
    sub-catalog. Any previous evaluation result is cleared.
    """
    sub_catalog = state['selected_sub_catalog']
    criteria_name = state['selected_criteria_name']
    test_case = state['selected_test_case']
    logger.debug(f'Changing to test case "{criteria_name}" from catalog "{sub_catalog}".')

    if test_case['context'] is None:
        context_update = gr.update(visible=False, value='')
    else:
        context_update = gr.update(value=test_case['context'], visible=True)

    if sub_catalog == 'harmful_content_in_assistant_message':
        user_message_update = gr.update(value=test_case['user_message'], interactive=False, elem_classes=['read-only'])
    else:
        user_message_update = gr.update(value=test_case['user_message'], elem_classes=[], interactive=True)

    if test_case['assistant_message'] is None:
        assistant_message_update = gr.update(visible=False, value='')
    else:
        assistant_message_update = gr.update(value=test_case['assistant_message'], visible=True)

    return {
        test_case_name: f'<h2>{to_title_case(test_case["name"])}</h2>',
        criteria: test_case['criteria'],
        context: context_update,
        user_message: user_message_update,
        assistant_message: assistant_message_update,
        result_text: gr.update(value=''),
        result_certainty: gr.update(value=''),
    }
def change_button_color(event: gr.EventData):
    """Return one class update per catalog button, marking only the clicked one.

    Updates are emitted in ``catalog_buttons`` iteration order (sub-catalog,
    then test case). This is the same order as the ``outputs`` list this
    handler is wired to.
    """
    updates = []
    for sub_catalog_buttons in catalog_buttons.values():
        for button in sub_catalog_buttons.values():
            if button.elem_id == event.target.elem_id:
                classes = ['catalog-button', 'selected']
            else:
                classes = ['catalog-button']
            updates.append(gr.update(elem_classes=classes))
    return updates
def _build_prompt(criteria, context, user_message, assistant_message, state):
    """Assemble the model prompt from the current form values and the selected test case."""
    return get_prompt_from_test_case({
        'name': state['selected_criteria_name'],
        'criteria': criteria,
        'context': context,
        'user_message': user_message,
        'assistant_message': assistant_message,
    }, state['selected_sub_catalog'])


def on_submit(criteria, context, user_message, assistant_message, state):
    """Evaluate the current form values with the model.

    Returns:
        ``(assessment, certainty)`` taken from the ``generate_text`` result,
        in the order expected by the ``[result_text, result_certainty]`` outputs.
    """
    prompt = _build_prompt(criteria, context, user_message, assistant_message, state)
    result = generate_text(prompt)
    return result['assessment'], result['certainty']


def on_show_prompt_click(criteria, context, user_message, assistant_message, state):
    """Render the prompt that would be sent to the model, for display in the modal.

    The prompt's ``content`` is HTML-escaped so that literal ``<``, ``>`` and
    ``&`` in it are shown as text instead of being interpreted as markup by
    ``gr.Markdown``. Newlines are then turned into ``<br>`` so the line
    structure survives rendering.

    Returns:
        A ``gr.Markdown`` holding the prompt serialized as indented JSON.
    """
    import html

    # Assumes the prompt is a dict-like, JSON-serializable object with a
    # 'content' string, as indexed below. TODO confirm against utils.
    prompt = _build_prompt(criteria, context, user_message, assistant_message, state)
    # Escape before inserting <br>, otherwise the inserted tags would be escaped too.
    prompt['content'] = html.escape(prompt['content'], quote=False).replace('\n', '<br>')
    # ensure_ascii=False keeps non-ASCII text readable instead of \uXXXX escapes.
    prompt_json = json.dumps(prompt, indent=4, ensure_ascii=False)
    logger.debug(prompt_json)
    return gr.Markdown(prompt_json)
with gr.Blocks(
    title='Granite Guardian',
    theme=gr.themes.Soft(font=[gr.themes.GoogleFont("IBM Plex Sans")]), css='styles.css') as demo:
    # Per-session selection state. 'selected_test_case' is only added by
    # update_selected_test_case once a catalog button has been clicked.
    state = gr.State(value={
        'selected_sub_catalog': 'harmful_content_in_user_message',
        'selected_criteria_name': 'harmful'
    })

    # Test case shown on first load, matching the initial state above. Catalog
    # keys are sub-catalog names, so index directly instead of scanning all of them.
    starting_test_case = next(
        t for t in catalog[state.value['selected_sub_catalog']]
        if t['name'] == state.value['selected_criteria_name']
    )

    with gr.Row():
        gr.HTML('<h1>Granite Guardian</h1>', elem_classes='title')

    with gr.Row(elem_classes='column-gap'):
        # Left column: one accordion per sub-catalog, one button per test case.
        with gr.Column(scale=0):
            title_display_left = gr.HTML("<h2>Harms & Risks</h2>")
            accordions = []
            catalog_buttons: dict[str, dict[str, gr.Button]] = {}
            for i, (sub_catalog_name, sub_catalog) in enumerate(catalog.items()):
                with gr.Accordion(to_title_case(sub_catalog_name), open=(i == 0), elem_classes='accordion-align') as accordion:
                    for test_case in sub_catalog:
                        elem_classes = ['catalog-button']
                        if starting_test_case == test_case:
                            elem_classes.append('selected')
                        # update_selected_test_case splits this id on '---' to
                        # recover the sub-catalog and test-case names.
                        elem_id = f"{sub_catalog_name}---{test_case['name']}"
                        catalog_buttons.setdefault(sub_catalog_name, {})[test_case['name']] = gr.Button(
                            to_title_case(test_case['name']),
                            elem_classes=elem_classes,
                            variant='secondary',
                            size='sm',
                            elem_id=elem_id,
                        )
                accordions.append(accordion)

        # Right column: the selected test case, its editable data and the results.
        with gr.Column(visible=True) as test_case_content:
            test_case_name = gr.HTML(f'<h2>{to_title_case(starting_test_case["name"])}</h2>')

            gr.HTML("Evaluation Criteria", elem_classes='subtitle')
            criteria = gr.Textbox(label="Definition", lines=3, interactive=False, value=starting_test_case['criteria'], elem_classes=['read-only'])

            gr.HTML("Test Data", elem_classes='subtitle')
            # Initial visibility mirrors on_test_case_click: optional fields are
            # shown only when the test case provides a value for them.
            context = gr.Textbox(
                label="Context", lines=3, interactive=True,
                value=starting_test_case['context'],
                visible=starting_test_case['context'] is not None,
            )
            user_message = gr.Textbox(label="User Message", lines=3, interactive=True, value=starting_test_case['user_message'])
            assistant_message = gr.Textbox(
                label="Assistant Message", lines=3, interactive=True,
                visible=starting_test_case['assistant_message'] is not None,
                value=starting_test_case['assistant_message'],
            )
            submit_button = gr.Button("Evaluate", variant='primary')

            gr.HTML("Evaluation results", elem_classes='subtitle')
            with gr.Row():
                result_text = gr.Textbox(label="Result", interactive=False, elem_classes=['read-only'])
                result_certainty = gr.Number(label="Certainty", interactive=False, value='', elem_classes=['read-only'])
                show_prompt_button = gr.Button('Show prompt', size='sm', scale=0)

    with Modal(visible=False) as modal:
        prompt = gr.Markdown("Hello world!")

    # Fill the modal with the rendered prompt, then open it.
    show_prompt_button.click(
        on_show_prompt_click,
        inputs=[criteria, context, user_message, assistant_message, state],
        outputs=prompt
    ).then(lambda: gr.update(visible=True), None, modal)

    submit_button.click(
        on_submit,
        inputs=[criteria, context, user_message, assistant_message, state],
        outputs=[result_text, result_certainty])

    # Flattened in catalog_buttons order. change_button_color yields its
    # updates in this same order.
    all_catalog_buttons = [
        button
        for sub_catalog_buttons in catalog_buttons.values()
        for button in sub_catalog_buttons.values()
    ]
    for button in all_catalog_buttons:
        button.click(update_selected_test_case, inputs=[button, state], outputs=[state]) \
            .then(on_test_case_click, inputs=state, outputs={test_case_name, criteria, context, user_message, assistant_message, result_text, result_certainty}) \
            .then(change_button_color, None, all_catalog_buttons)

demo.launch(server_name='0.0.0.0')