# Author: Martín Santillán Cooper
# Last commit: d46878a — "start using local granite guardian model"
# (source page metadata: raw · history · blame · 5.38 kB)
import gradio as gr
from dotenv import load_dotenv
from utils import get_prompt_from_test_case
load_dotenv()
import json
from model import generate_text
import logging
# Named logger for the demo (no handler/level configuration is done here).
logger = logging.getLogger('demo')

# Test case that is preselected (and highlighted) when the page first loads.
STARTING_SUB_CATALOG_NAME = "Harmful content in user message"
STARTING_TEST_CASE_NAME = 'Harmful'

catalog = {}
all_test_cases = []
# catalog.json maps sub-catalog name -> list of test case dicts carrying
# 'name', 'criteria', 'context', 'user_message' and 'assistant_message'.
# The encoding is explicit so loading does not depend on the platform locale.
with open('catalog.json', encoding='utf-8') as f:
    catalog = json.load(f)

starting_test_case = next(
    (
        t
        for t in catalog.get(STARTING_SUB_CATALOG_NAME, [])
        if t['name'] == STARTING_TEST_CASE_NAME
    ),
    None,
)
if starting_test_case is None:
    raise ValueError(
        f"Starting test case {STARTING_TEST_CASE_NAME!r} not found in "
        f"sub-catalog {STARTING_SUB_CATALOG_NAME!r} of catalog.json")

# Shared components: created here, placed into the layout later via .render().
test_case_name = gr.HTML(f'<h2>{starting_test_case["name"]}</h2>')
criteria = gr.Textbox(label="Definition", lines=3, interactive=False, value=starting_test_case['criteria'])
context = gr.Textbox(label="Context", lines=3, interactive=True, value=starting_test_case['context'], visible=False)
user_message = gr.Textbox(label="User Message", lines=3, interactive=True, value=starting_test_case['user_message'])
assistant_message = gr.Textbox(label="Assistant Message", lines=3, interactive=True, visible=False, value=starting_test_case['assistant_message'])
catalog_buttons: dict[str, dict[str, gr.Button]] = {}
result_text = gr.Textbox(label="Result", interactive=False)
result_certainty = gr.Number(label="Certainty", interactive=False, value='')

# One button per test case, keyed by sub-catalog name, then test case name.
# elem_id is "<sub catalog name>_<test case name>"; the click handlers read it.
for sub_catalog_name, sub_catalog in catalog.items():
    catalog_buttons[sub_catalog_name] = {}
    for test_case in sub_catalog:
        elem_classes = ['catalog-button']
        elem_id = f"{sub_catalog_name}_{test_case['name']}"
        # Highlight the preselected test case's button.
        if (sub_catalog_name == STARTING_SUB_CATALOG_NAME
                and test_case['name'] == STARTING_TEST_CASE_NAME):
            elem_classes.append('selected')
        catalog_buttons[sub_catalog_name][test_case['name']] = gr.Button(
            test_case['name'], elem_classes=elem_classes, variant='secondary', size='sm', elem_id=elem_id)
def on_test_case_click(link, event: gr.EventData):
    """Load the clicked catalog test case into the editing components.

    Args:
        link: Value of the clicked catalog button (its label), i.e. the test
            case name.
        event: Gradio event data; ``event.target.elem_id`` is the button id
            of the form ``"<sub catalog name>_<test case name>"``.

    Returns:
        Dict mapping each output component to its new value or ``gr.update``.

    Raises:
        ValueError: If no test case named ``link`` exists in the sub-catalog
            encoded in the clicked button's elem_id.
    """
    # Split on the first '_' only: a test case name containing '_' would make a
    # plain split() return more than two parts and break two-way unpacking.
    # NOTE(review): assumes sub-catalog names never contain '_'.
    target_sub_catalog_name = event.target.elem_id.split('_', 1)[0]
    selected_test_case = next(
        (t for t in catalog.get(target_sub_catalog_name, []) if t['name'] == link),
        None,
    )
    if selected_test_case is None:
        raise ValueError(
            f"Test case {link!r} not found in sub-catalog {target_sub_catalog_name!r}")
    logging.getLogger('demo').debug(
        "Selected test case %r; assistant message: %r",
        selected_test_case['name'], selected_test_case['assistant_message'])

    context_value = selected_test_case['context']
    assistant_value = selected_test_case['assistant_message']
    return {
        test_case_name: f'<h2>{selected_test_case["name"]}</h2>',
        criteria: selected_test_case['criteria'],
        # The context box is created hidden, so it must be made visible
        # explicitly (as for the assistant message) or it never appears.
        context: (
            gr.update(value=context_value, visible=True)
            if context_value is not None
            else gr.update(visible=False, value='')
        ),
        user_message: selected_test_case['user_message'],
        assistant_message: (
            gr.update(value=assistant_value, visible=True)
            if assistant_value is not None
            else gr.update(visible=False, value='')
        ),
        # Clear any result left over from a previous evaluation.
        result_text: gr.update(value=''),
        result_certainty: gr.update(value=''),
    }
def change_button_color(event: gr.EventData):
    """Build one class update per catalog button, marking only the clicked one.

    Updates are produced in ``catalog_buttons`` iteration order (sub-catalog,
    then test case), the same order used to build this handler's outputs list.
    """
    clicked_id = event.target.elem_id
    updates = []
    for buttons in catalog_buttons.values():
        for button in buttons.values():
            if button.elem_id == clicked_id:
                updates.append(gr.update(elem_classes=['catalog-button', 'selected']))
            else:
                updates.append(gr.update(elem_classes=['catalog-button']))
    return updates
def on_submit(inputs):
    """Evaluate the current test data with the model.

    Args:
        inputs: Mapping from input component to its current value.

    Returns:
        Tuple ``(assessment, certainty)`` read from ``generate_text``'s result.
    """
    field_components = (
        ('criteria', criteria),
        ('context', context),
        ('user_message', user_message),
        ('assistant_message', assistant_message),
    )
    test_case = {field: inputs[component] for field, component in field_components}
    prompt = get_prompt_from_test_case(test_case)
    result = generate_text(prompt)
    assessment = result['assessment']
    certainty = result['certainty']
    return assessment, certainty
# Every catalog button in catalog order; change_button_color returns its
# updates in this same order. Built once instead of once per button.
all_catalog_buttons = [button for buttons in catalog_buttons.values() for button in buttons.values()]
# Components refreshed whenever a catalog test case is selected.
test_case_outputs = {test_case_name, criteria, context, user_message, assistant_message, result_text, result_certainty}

with gr.Blocks(
    title='Granite Guardian',
    theme=gr.themes.Soft(font=[gr.themes.GoogleFont("IBM Plex Sans")]),
    css='styles.css',
) as demo:
    with gr.Row():
        gr.HTML('<h1>Granite Guardian</h1>', elem_classes='title')
    with gr.Row(elem_classes='column-gap'):
        # Left column: the catalog, one accordion per sub-catalog.
        with gr.Column(scale=0):
            title_display_left = gr.HTML("<h2>Catalog - Harms & Risks</h2>")
            accordions = []
            for i, (sub_catalog_name, sub_catalog) in enumerate(catalog.items()):
                # Only the first sub-catalog starts expanded.
                with gr.Accordion(sub_catalog_name, open=(i == 0), elem_classes='accordion-align') as accordion:
                    for test_case in sub_catalog:
                        link = catalog_buttons[sub_catalog_name][test_case['name']]
                        link.render()
                        # The button is its own input: its value is the test case name.
                        link.click(on_test_case_click, link, test_case_outputs) \
                            .then(change_button_color, None, all_catalog_buttons)
                accordions.append(accordion)
        # Right column: the selected test case and its evaluation result.
        with gr.Column(visible=True) as test_case_content:
            test_case_name.render()
            gr.HTML("Evaluation Criteria", elem_classes='subtitle')
            criteria.render()
            gr.HTML("Test Data", elem_classes='subtitle')
            context.render()
            user_message.render()
            assistant_message.render()
            submit_button = gr.Button("Evaluate", variant='primary')
            gr.HTML("Evaluation results", elem_classes='subtitle')
            with gr.Row():
                result_text.render()
                result_certainty.render()

    # on_submit reads only these four components, so the test case title
    # is not passed as an input.
    submit_button.click(
        on_submit,
        inputs={criteria, context, user_message, assistant_message},
        outputs=[result_text, result_certainty])

demo.launch(server_name='0.0.0.0')