Spaces:
Running
on
Zero
Running
on
Zero
import gradio as gr | |
from dotenv import load_dotenv | |
import json | |
from generate import generate_text, get_prompt_from_test_case | |
# load_dotenv()

# Test-case catalog loaded from catalog.json: a mapping of sub-catalog name ->
# list of test-case dicts. Each test case is read below via its 'name',
# 'criteria', 'context', 'user_message' and 'assistant_message' keys.
all_test_cases = []  # NOTE(review): currently unused in this file.
with open('catalog.json', encoding='utf-8') as f:
    catalog = json.load(f)

# Test case displayed when the page first loads.
starting_test_case = next(
    (
        t
        for sub_catalog_name, sub_catalog in catalog.items()
        for t in sub_catalog
        if t['name'] == 'Harmful' and sub_catalog_name == "Harmful content in user message"
    ),
    None,
)
if starting_test_case is None:
    raise LookupError(
        "catalog.json has no 'Harmful' test case in 'Harmful content in user message'")

# Detail-panel components. They are created outside the Blocks context and
# placed into the layout later with .render().
test_case_name = gr.HTML(f'<h2>{starting_test_case["name"]}</h2>')
criteria = gr.Textbox(label="Definition", lines=3, interactive=False, value=starting_test_case['criteria'])
# Show the context box only when the displayed test case actually has context.
context = gr.Textbox(
    label="Context",
    lines=3,
    interactive=True,
    value=starting_test_case['context'],
    visible=starting_test_case['context'] is not None,
)
user_message = gr.Textbox(label="User Message", lines=3, interactive=True, value=starting_test_case['user_message'])
assistant_message = gr.Textbox(label="Assistant Message", lines=3, interactive=True, value=starting_test_case['assistant_message'])
# One button per test case, keyed by sub-catalog name, then test-case name.
# Buttons are created unrendered here and placed into the layout later.
catalog_buttons: dict[str, dict[str, gr.Button]] = {}
for sub_catalog_name, sub_catalog in catalog.items():
    catalog_buttons[sub_catalog_name] = {}
    for test_case in sub_catalog:
        elem_classes = ['catalog-button']
        # The id combines the sub-catalog and test-case names, so equally
        # named test cases in different sub-catalogs get distinct ids.
        elem_id = f"{sub_catalog_name}_{test_case['name']}"
        # Highlight the button of the test case shown on initial load. Deriving
        # this from starting_test_case keeps the highlight in sync with the
        # displayed case instead of repeating its names as literals.
        if test_case is starting_test_case:
            elem_classes.append('selected')
        catalog_buttons[sub_catalog_name][test_case['name']] = gr.Button(
            test_case['name'],
            elem_classes=elem_classes,
            variant='secondary',
            size='sm',
            elem_id=elem_id,
        )
# watsonx_api_url = os.getenv("WATSONX_URL", None) | |
# watsonx_project_id = os.getenv("WATSONX_PROJECT_ID", None)
# watsonx_api_key = os.getenv("WATSONX_API_KEY", None)
# client = APIClient(credentials={ | |
# "url": watsonx_api_url, | |
# "project_id": watsonx_project_id, | |
# "api_key": watsonx_api_key | |
# }) | |
# model = ModelInference(model_id=ModelTypes.LLAMA_3_8B_INSTRUCT, api_client=client) | |
# client.set.default_project(watsonx_project_id) | |
# bam_api_url = os.getenv("BAM_URL", None) | |
# bam_api_key = os.getenv("BAM_API_KEY", None) | |
# client = Client(credentials=Credentials(api_endpoint=bam_api_url, api_key=bam_api_key )) | |
def on_test_case_click(link):
    """Load the clicked test case into the detail panel.

    Args:
        link: Value of the clicked catalog button, i.e. the test-case name.

    Returns:
        dict mapping each detail component to its new value or ``gr.update``.

    Raises:
        IndexError: if no test case in the catalog has the given name.
    """
    # NOTE(review): the lookup is by name only, so when two sub-catalogs share a
    # test-case name the first match wins — confirm names are unique or pass
    # the sub-catalog through from the click wiring.
    selected_test_case = [t for sub_catalog in catalog.values() for t in sub_catalog if t['name'] == link][0]
    selected_context = selected_test_case['context']
    return {
        test_case_name: f'<h2>{selected_test_case["name"]}</h2>',
        criteria: selected_test_case['criteria'],
        # Always overwrite the value, because on_submit reads this box even
        # when it is hidden, and stale text from a previous test case would
        # leak into the prompt. Toggle visibility both ways so a case with
        # context actually shows the box.
        context: gr.update(value=selected_context, visible=selected_context is not None),
        user_message: selected_test_case['user_message'],
        assistant_message: selected_test_case['assistant_message'],
    }
def change_button_color(event: gr.EventData):
    """Mark the clicked catalog button as selected and un-mark all others.

    Args:
        event: Gradio event data; ``event.target`` is the button that fired.

    Returns:
        list of ``gr.update`` objects, one per catalog button, ordered by
        sub-catalog and then test case (the order of ``catalog_buttons``).
    """
    updates = []
    for buttons_in_group in catalog_buttons.values():
        for button in buttons_in_group.values():
            is_clicked = button.elem_id == event.target.elem_id
            classes = ['catalog-button', 'selected'] if is_clicked else ['catalog-button']
            updates.append(gr.update(elem_classes=classes))
    return updates
def on_submit(inputs):
    """Evaluate the current test data and return the model's verdict.

    Args:
        inputs: Mapping from input component to its current value. Gradio
            passes a dict because the handler's inputs are given as a set.

    Returns:
        Tuple ``(assessment, certainty)`` read from ``generate_text``'s result.
    """
    field_components = (
        ('criteria', criteria),
        ('context', context),
        ('user_message', user_message),
        ('assistant_message', assistant_message),
    )
    test_case_data = {key: inputs[component] for key, component in field_components}
    prompt = get_prompt_from_test_case(test_case_data)
    generation = generate_text(prompt)
    assessment = generation['assessment']
    certainty = generation['certainty']
    return assessment, certainty
# Page layout: a catalog of test-case buttons on the left and the selected test
# case's details, the Evaluate button and the results on the right.
# NOTE(review): indentation was lost in this copy of the file; the nesting
# below is a reconstruction — confirm it against the deployed layout.
with gr.Blocks(
        theme=gr.themes.Soft(font=[gr.themes.GoogleFont("IBM Plex Sans")]), css='styles.css') as demo:
    with gr.Row():
        gr.HTML('<h1>Granite Guardian</h1>', elem_classes='title')
    with gr.Row():
        # Left column: one accordion per sub-catalog holding its buttons.
        with gr.Column(scale=0):
            title_display_left = gr.HTML("<h2>Catalog - Harms & Risks</h2>")
            # Every catalog button, in the order change_button_color returns updates.
            all_catalog_buttons = [v for c in catalog_buttons.values() for v in c.values()]
            accordions = []
            for sub_catalog_name, sub_catalog in catalog.items():
                with gr.Accordion(sub_catalog_name, open=True) as accordion:
                    for test_case in sub_catalog:
                        link = catalog_buttons[sub_catalog_name][test_case['name']]
                        # Buttons were created unrendered; place them here.
                        link.render()
                        # Load the test case, then refresh button highlighting.
                        link.click(
                            on_test_case_click,
                            link,
                            {test_case_name, criteria, context, user_message, assistant_message},
                        ).then(change_button_color, None, all_catalog_buttons)
                accordions.append(accordion)
        # Right column: details of the selected test case and evaluation.
        with gr.Column(visible=True) as test_case_content:
            test_case_name.render()
            gr.HTML("Evaluation Criteria")
            criteria.render()
            gr.Markdown("Test Data")
            context.render()
            user_message.render()
            assistant_message.render()
            submit_button = gr.Button("Evaluate", variant='primary')
            with gr.Row():
                result_text = gr.Textbox(label="Result", interactive=False)
                # None is gr.Number's "empty" value; '' is not a number.
                result_certainty = gr.Number(label="Certainty", interactive=False, value=None)
            submit_button.click(
                on_submit,
                # A set of inputs makes Gradio pass on_submit a single dict
                # keyed by component; only the fields it reads are sent.
                inputs={criteria, context, user_message, assistant_message},
                outputs=[result_text, result_certainty])

demo.launch()