File size: 5,591 Bytes
46a13bb
 
 
 
d892a20
2dbe361
46a13bb
 
 
 
 
 
 
 
 
 
 
 
 
 
d892a20
 
46a13bb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d892a20
 
46a13bb
 
 
 
 
 
d892a20
 
 
 
 
 
 
 
 
 
46a13bb
 
 
 
 
d892a20
46a13bb
 
 
d892a20
 
46a13bb
 
 
d892a20
46a13bb
 
 
 
 
 
d892a20
46a13bb
 
d892a20
46a13bb
 
 
 
 
d892a20
46a13bb
d892a20
 
 
46a13bb
 
 
 
 
149c109
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import gradio as gr
from dotenv import load_dotenv
import json
from generate import generate_text, get_prompt_from_test_case
from time import sleep
# load_dotenv()

# Sub-catalog and test case that are displayed (and highlighted) on first load.
STARTING_SUB_CATALOG_NAME = "Harmful content in user message"
STARTING_TEST_CASE_NAME = "Harmful"
# Catalog buttons get elem_id "<sub catalog name>_<test case name>" (see the
# loop below); this is the id of the button that starts out selected.
STARTING_ELEM_ID = f"{STARTING_SUB_CATALOG_NAME}_{STARTING_TEST_CASE_NAME}"

catalog = {}
all_test_cases = []  # NOTE(review): never populated or read in this file.

# JSON exchanged between systems is UTF-8 (RFC 8259); pass the encoding
# explicitly so loading does not depend on the platform's locale encoding.
with open('catalog.json', encoding='utf-8') as f:
    catalog = json.load(f)

# Match on sub-catalog *and* name: a test case name on its own is not assumed
# to be unique across sub-catalogs.
starting_test_case = next(
    (t for t in catalog.get(STARTING_SUB_CATALOG_NAME, []) if t['name'] == STARTING_TEST_CASE_NAME),
    None,
)
if starting_test_case is None:
    raise ValueError(
        f"catalog.json has no test case {STARTING_TEST_CASE_NAME!r} "
        f"in sub-catalog {STARTING_SUB_CATALOG_NAME!r}"
    )

# Components are created up front and placed into the layout later via .render().
test_case_name = gr.HTML(f'<h2>{starting_test_case["name"]}</h2>')
criteria = gr.Textbox(label="Definition", lines=3, interactive=False, value=starting_test_case['criteria'])
# Only show the context box when the starting test case actually has a context.
context = gr.Textbox(
    label="Context",
    lines=3,
    interactive=True,
    value=starting_test_case['context'],
    visible=starting_test_case['context'] is not None,
)
user_message = gr.Textbox(label="User Message", lines=3, interactive=True, value=starting_test_case['user_message'])
assistant_message = gr.Textbox(label="Assistant Message", lines=3, interactive=True, value=starting_test_case['assistant_message'])
catalog_buttons: dict[str,dict[str,gr.Button]] = {}
result_text = gr.Textbox(label="Result", interactive=False)
result_certainty = gr.Number(label="Certainty", interactive=False, value='')

# One button per test case, grouped by sub-catalog; the starting one is pre-selected.
for sub_catalog_name, sub_catalog in catalog.items():
    catalog_buttons[sub_catalog_name] = {}
    for test_case in sub_catalog:
        elem_id = f"{sub_catalog_name}_{test_case['name']}"
        elem_classes = ['catalog-button']
        if elem_id == STARTING_ELEM_ID:
            elem_classes.append('selected')
        catalog_buttons[sub_catalog_name][test_case['name']] = gr.Button(
            test_case['name'], elem_classes=elem_classes, variant='secondary', size='sm', elem_id=elem_id)

# watsonx_api_url = os.getenv("WATSONX_URL", None) 
# watsonx_project_id = os.getenv("WATSONX_API_KEY", None) 
# watsonx_api_key = os.getenv("WATSONX_PROJECT_ID", None) 

# client = APIClient(credentials={
#     "url": watsonx_api_url,
#     "project_id": watsonx_project_id,
#     "api_key": watsonx_api_key
# })
# model = ModelInference(model_id=ModelTypes.LLAMA_3_8B_INSTRUCT, api_client=client)

# client.set.default_project(watsonx_project_id)

# bam_api_url = os.getenv("BAM_URL", None) 
# bam_api_key = os.getenv("BAM_API_KEY", None) 
# client = Client(credentials=Credentials(api_endpoint=bam_api_url, api_key=bam_api_key ))


def on_test_case_click(link, event: gr.EventData = None):
    """Load the clicked catalog test case into the editor components.

    Args:
        link: Value of the clicked catalog button, i.e. the test case name.
        event: Event data Gradio injects for the click. Its ``target.elem_id``
            ("<sub catalog name>_<test case name>", as assigned when the
            buttons are created) identifies the exact button, which matters
            when the same test case name appears in several sub-catalogs.
            When it is missing or matches nothing, the first test case named
            ``link`` is used (the previous behaviour).
            NOTE(review): Gradio recognises this parameter by its type hint;
            on Python < 3.11 ``typing.get_type_hints`` wraps a None-defaulted
            hint in Optional, which Gradio may not detect — then only the
            name-based fallback applies.

    Returns:
        Dict mapping each output component to its new value / ``gr.update``.

    Raises:
        IndexError: if no test case in ``catalog`` is named ``link``.
    """
    target_id = getattr(getattr(event, 'target', None), 'elem_id', None)
    candidates = [
        (sub_catalog_name, t)
        for sub_catalog_name, sub_catalog in catalog.items()
        for t in sub_catalog
        if t['name'] == link
    ]
    fallback = candidates[0][1]  # same IndexError as before when nothing matches
    selected_test_case = next(
        (t for sub_catalog_name, t in candidates if f"{sub_catalog_name}_{t['name']}" == target_id),
        fallback,
    )

    # The context box starts hidden, so it must be made visible explicitly when
    # the test case has a context, and cleared + hidden when it has none (so a
    # previous case's context does not linger).
    test_case_context = selected_test_case['context']
    if test_case_context is not None:
        context_update = gr.update(value=test_case_context, visible=True)
    else:
        context_update = gr.update(value='', visible=False)

    return {
        test_case_name: f'<h2>{selected_test_case["name"]}</h2>',
        criteria: selected_test_case['criteria'],
        context: context_update,
        user_message: selected_test_case['user_message'],
        assistant_message: selected_test_case['assistant_message'],
        # Clear results from the previously evaluated test case.
        result_text: gr.update(value=''),
        result_certainty: gr.update(value='')
    }

def change_button_color(event: gr.EventData):
    """Return one class update per catalog button so only the clicked one is 'selected'.

    Updates are produced by walking ``catalog_buttons`` sub-catalog by
    sub-catalog, then test case by test case. This is the same nested order
    used to build the outputs list this handler is attached to.
    """
    updates = []
    for buttons_in_sub_catalog in catalog_buttons.values():
        for button in buttons_in_sub_catalog.values():
            is_clicked = button.elem_id == event.target.elem_id
            if is_clicked:
                updates.append(gr.update(elem_classes=['catalog-button', 'selected']))
            else:
                updates.append(gr.update(elem_classes=['catalog-button']))
    return updates

def on_submit(inputs):
    """Handle the "Evaluate" button.

    This is currently a placeholder. It waits three seconds and returns a fixed
    verdict without looking at ``inputs``. The intended model call is kept
    below, disabled.

    Args:
        inputs: Mapping from input component to its current value. The inputs
            are registered as a set, so Gradio passes them as a single dict.

    Returns:
        Tuple ``(result text, certainty)``, in the order of the click
        handler's ``outputs``.
    """
    # Intended implementation (disabled):
    # prompt = get_prompt_from_test_case({
    #     'criteria': inputs[criteria],
    #     'context': inputs[context],
    #     'user_message': inputs[user_message],
    #     'assistant_message': inputs[assistant_message],
    # })
    # result = generate_text(prompt)
    # return result['assessment'], result['certainty']
    sleep(3)  # presumably stands in for model latency — confirm
    verdict = 'Yes'
    certainty = 0.97
    return verdict, certainty

# Flat list of every catalog button, in the same nested order in which
# change_button_color iterates catalog_buttons. It is built once and reused as
# the outputs of every button's follow-up highlight event.
all_catalog_buttons = [button for buttons in catalog_buttons.values() for button in buttons.values()]

with gr.Blocks(
        theme=gr.themes.Soft(font=[gr.themes.GoogleFont("IBM Plex Sans")]),
        css='styles.css') as demo:
    with gr.Row():
        gr.HTML('<h1>Granite Guardian</h1>', elem_classes='title')
    with gr.Row(elem_classes='column-gap'):
        # Left column: the catalog, one accordion per sub-catalog (the first one open).
        with gr.Column(scale=0):
            title_display_left = gr.HTML("<h2>Catalog - Harms & Risks</h2>")
            accordions = []
            for i, (sub_catalog_name, sub_catalog) in enumerate(catalog.items()):
                with gr.Accordion(sub_catalog_name, open=i == 0, elem_classes='accordion-align') as accordion:
                    for test_case in sub_catalog:
                        link = catalog_buttons[sub_catalog_name][test_case['name']]
                        link.render()
                        # Load the clicked test case, then re-highlight the catalog buttons.
                        link.click(on_test_case_click, link, {test_case_name, criteria, context, user_message, assistant_message, result_text, result_certainty}) \
                            .then(change_button_color, None, all_catalog_buttons)
                    accordions.append(accordion)
        # Right column: the selected test case, its editable data, and the evaluation result.
        with gr.Column(visible=True) as test_case_content:
            test_case_name.render()
            gr.HTML("Evaluation Criteria", elem_classes='subtitle')
            criteria.render()

            gr.HTML("Test Data", elem_classes='subtitle')
            context.render()
            user_message.render()
            assistant_message.render()

            submit_button = gr.Button("Evaluate", variant='primary')
            gr.HTML("Evaluation results", elem_classes='subtitle')
            with gr.Row():
                result_text.render()
                result_certainty.render()

            submit_button.click(
                on_submit,
                inputs={test_case_name, criteria, context, user_message, assistant_message},
                outputs=[result_text, result_certainty])

demo.launch(server_name='0.0.0.0')