Spaces: Running on Zero

Martín Santillán Cooper committed · Commit 6047b61
Parent(s): a89905d

Continue adaptation

Browse files
- .dockerignore        +1 -1
- .env.example         +1 -1
- .gitignore           +1 -1
- catalog.json         +1 -1
- convert_to_string.py +7 -13
- src/app.py           +38 -30
- src/model.py         +85 -84
- src/styles.css       +1 -1
- src/utils.py         +4 -3
.dockerignore CHANGED
@@ -5,4 +5,4 @@
 *.sh
 *.md
 __pycache__/
-flagged/
+flagged/
.env.example CHANGED
@@ -1,4 +1,4 @@
 MODEL_PATH='ibm-granite/granite-guardian-3.1-8b'
 INFERENCE_ENGINE='VLLM' # one of [WATSONX, MOCK, VLLM]
 WATSONX_API_KEY=''
-WATSONX_PROJECT_ID=''
+WATSONX_PROJECT_ID=''
.gitignore CHANGED
@@ -4,4 +4,4 @@ parse.py
 unparsed_catalog.json
 __pycache__/
 logs.txt
-secrets.yaml
+secrets.yaml
catalog.json CHANGED
@@ -112,4 +112,4 @@
       "context": null
     }
   ]
-}
+}
convert_to_string.py CHANGED
@@ -1,5 +1,6 @@
 import json
 
+
 def dict_to_json_with_newlines(data):
     """
     Converts a dictionary into a JSON string with explicit newlines (\n) added.
@@ -12,28 +13,21 @@ def dict_to_json_with_newlines(data):
     """
     # Convert the dictionary to a pretty-printed JSON string
    pretty_json = json.dumps(data, indent=2)
-
+
     # Replace actual newlines with escaped newlines (\n)
     json_with_newlines = pretty_json.replace("\n", "\\n")
-
+
     # Escape double quotes for embedding inside other JSON
     json_with_newlines = json_with_newlines.replace('"', '\\"')
-
+
     return json_with_newlines
 
+
 # Example dictionary
-example_dict =[
-    {
-        "name": "comment_list",
-        "arguments": {
-            "video_id": 456789123,
-            "count": 15
-        }
-    }
-]
+example_dict = [{"name": "comment_list", "arguments": {"video_id": 456789123, "count": 15}}]
 
 # Convert the dictionary
 result = dict_to_json_with_newlines(example_dict)
 
 print("Resulting JSON string:")
-print(result)
+print(result)
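Note on the example collapsed above (illustration only, not part of the commit): dict_to_json_with_newlines pretty-prints a Python object and then escapes the newlines and double quotes, so the result can be stored as a single string value inside another JSON document. A minimal rerun of that logic:

import json

def dict_to_json_with_newlines(data):
    # Pretty-print, then escape newlines and double quotes so the result can be
    # embedded as a plain string value inside another JSON file.
    pretty_json = json.dumps(data, indent=2)
    json_with_newlines = pretty_json.replace("\n", "\\n")
    return json_with_newlines.replace('"', '\\"')

example_dict = [{"name": "comment_list", "arguments": {"video_id": 456789123, "count": 15}}]
print(dict_to_json_with_newlines(example_dict))
# prints one line, e.g. [\n  {\n    \"name\": \"comment_list\", ... with all quotes escaped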
src/app.py CHANGED
@@ -6,7 +6,7 @@ from dotenv import load_dotenv
 from gradio_modal import Modal
 
 from logger import logger
-from model import
+from model import get_guardian_response, get_prompt
 from utils import (
     get_messages,
     get_result_description,
@@ -20,7 +20,7 @@ load_dotenv()
 
 catalog = {}
 
-toy_json =
+toy_json = '{"name": "John"}'
 
 with open("catalog.json") as f:
     logger.debug("Loading catalog from json.")
@@ -63,12 +63,9 @@ def on_test_case_click(state: gr.State):
     # update context field:
     if is_context_editable:
         context = gr.update(
-            value=selected_test_case["context"],
-            interactive=True,
-            visible=True,
-            elem_classes=["input-box"]
+            value=selected_test_case["context"], interactive=True, visible=True, elem_classes=["input-box"]
         )
-    else:
+    else:
         context = gr.update(
             visible=selected_test_case["context"] is not None,
             value=selected_test_case["context"],
@@ -85,19 +82,13 @@ def on_test_case_click(state: gr.State):
     # update user message field
     if is_user_message_editable:
         user_message = gr.update(
-            value=selected_test_case["user_message"],
-            visible=True,
-            interactive=True,
-            elem_classes=["input-box"]
+            value=selected_test_case["user_message"], visible=True, interactive=True, elem_classes=["input-box"]
         )
     else:
         user_message = gr.update(
-            value=selected_test_case["user_message"],
-            interactive=False,
-            elem_classes=["read-only", "input-box"]
+            value=selected_test_case["user_message"], interactive=False, elem_classes=["read-only", "input-box"]
         )
 
-
     # update assistant message field
     if is_tools_present:
         assistant_message_json = gr.update(
@@ -124,8 +115,16 @@ def on_test_case_click(state: gr.State):
         assistant_message_json = gr.update(visible=False)
 
     result_text = gr.update(visible=False, value="")
-    return
-
+    return (
+        test_case_name,
+        criteria,
+        context,
+        user_message,
+        assistant_message_text,
+        assistant_message_json,
+        tools,
+        result_text,
+    )
 
 
 def change_button_color(event: gr.EventData):
@@ -144,14 +143,15 @@ def on_submit(criteria, context, user_message, assistant_message_text, assistant
     criteria_name = state["selected_criteria_name"]
     if criteria_name == "function_calling_hallucination":
         assistant_message = assistant_message_json
-    else:
+    else:
+        assistant_message = assistant_message_text
     test_case = {
         "name": criteria_name,
         "criteria": criteria,
         "context": context,
         "user_message": user_message,
         "assistant_message": assistant_message,
-        "tools": tools
+        "tools": tools,
     }
 
     messages = get_messages(test_case=test_case, sub_catalog_name=state["selected_sub_catalog"])
@@ -160,7 +160,7 @@ def on_submit(criteria, context, user_message, assistant_message_text, assistant
         f"Starting evaluation for subcatelog {state['selected_sub_catalog']} and criteria name {state['selected_criteria_name']}"
     )
 
-    result_label =
+    result_label = get_guardian_response(messages=messages, criteria_name=criteria_name)["assessment"]  # Yes or No
 
     html_str = f"<p>{get_result_description(state['selected_sub_catalog'], state['selected_criteria_name'])} <strong>{result_label}</strong></p>"
     # html_str = f"{get_result_description(state['selected_sub_catalog'], state['selected_criteria_name'])} {result_label}"
@@ -171,7 +171,8 @@ def on_show_prompt_click(criteria, context, user_message, assistant_message_text
     criteria_name = state["selected_criteria_name"]
     if criteria_name == "function_calling_hallucination":
         assistant_message = assistant_message_json
-    else:
+    else:
+        assistant_message = assistant_message_text
     test_case = {
         "name": criteria_name,
         "criteria": criteria,
@@ -247,7 +248,7 @@ with gr.Blocks(
         gr.HTML("<h2>IBM Granite Guardian 3.1</h2>", elem_classes="title")
         gr.HTML(
             elem_classes="system-description",
-            value="<p>Granite Guardian models are specialized language models in the Granite family that can detect harms and risks in generative AI systems. They can be used with any large language model to make interactions with generative AI systems safe. Select an example in the left panel to see how the Granite Guardian model evaluates harms and risks in user prompts, assistant responses, and for hallucinations in
+            value="<p>Granite Guardian models are specialized language models in the Granite family that can detect harms and risks in generative AI systems. They can be used with any large language model to make interactions with generative AI systems safe. Select an example in the left panel to see how the Granite Guardian model evaluates harms and risks in user prompts, assistant responses, and for hallucinations in retrival-augmented generation and function calling. In this demo, we use granite-guardian-3.1-8b.</p>",
         )
         with gr.Row(elem_classes="column-gap"):
             with gr.Column(scale=0, elem_classes="no-gap"):
@@ -301,11 +302,7 @@ with gr.Blocks(
                     elem_classes=["input-box"],
                 )
 
-                tools = gr.Code(
-                    label="API Definition (Tools)",
-                    visible=False,
-                    language='json'
-                )
+                tools = gr.Code(label="API Definition (Tools)", visible=False, language="json")
 
                 user_message = gr.Textbox(
                     label="User Prompt",
@@ -327,7 +324,7 @@ with gr.Blocks(
                 assistant_message_json = gr.Code(
                     label="Assistant Response",
                     visible=False,
-                    language=
+                    language="json",
                    value=None,
                     elem_classes=["input-box"],
                 )
@@ -350,7 +347,9 @@ with gr.Blocks(
     # events
 
     show_propt_button.click(
-        on_show_prompt_click,
+        on_show_prompt_click,
+        inputs=[criteria, context, user_message, assistant_message_text, assistant_message_json, tools, state],
+        outputs=prompt,
     ).then(lambda: gr.update(visible=True), None, modal)
 
     submit_button.click(lambda: gr.update(visible=True, value=""), None, result_text).then(
@@ -368,7 +367,16 @@ with gr.Blocks(
     ).then(update_selected_test_case, inputs=[button, state], outputs=[state]).then(
         on_test_case_click,
        inputs=state,
-        outputs=[
+        outputs=[
+            test_case_name,
+            criteria,
+            context,
+            user_message,
+            assistant_message_text,
+            assistant_message_json,
+            tools,
+            result_text,
+        ],
     )
 
 demo.launch(server_name="0.0.0.0")
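Note on the reshaped return value and outputs list in on_test_case_click (illustration only; the component names below are invented, not the Space's): Gradio maps a handler's returned tuple positionally onto the components listed in outputs, so each gr.update(...) lands on the matching widget. A minimal standalone sketch of that pattern:

import gradio as gr

def on_click(state):
    # One gr.update(...) per output component, in the same order as `outputs`.
    return gr.update(value="user text", interactive=True), gr.update(visible=False)

with gr.Blocks() as demo:
    state = gr.State({})
    user_message = gr.Textbox(label="User Prompt")
    result_text = gr.HTML(visible=False)
    button = gr.Button("Load test case")
    button.click(on_click, inputs=state, outputs=[user_message, result_text])

demo.launch()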
src/model.py CHANGED
@@ -9,9 +9,6 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
 
 from logger import logger
 
-# from vllm import LLM, SamplingParams
-
-
 safe_token = "No"
 risky_token = "Yes"
 nlogprobs = 20
@@ -21,9 +18,10 @@ logger.debug(f"Inference engine is: '{inference_engine}'")
 
 if inference_engine == "VLLM":
     import torch
-    device = torch.device("cuda")
 
-
+    device = torch.device("cpu")
+
+    model_path = os.getenv("MODEL_PATH", "ibm-granite/granite-guardian-3.1-2b")
     logger.debug(f"model_path is {model_path}")
     tokenizer = AutoTokenizer.from_pretrained(model_path)
     model = AutoModelForCausalLM.from_pretrained(model_path)
@@ -41,17 +39,32 @@ elif inference_engine == "WATSONX":
     model_id = "ibm/granite-guardian-3-8b" # 8B Model: "ibm/granite-guardian-3-8b"
     model = ModelInference(model_id=model_id, api_client=client)
 
-
+
+def get_probablities_watsonx(top_tokens_list):
+    safe_token_prob = 1e-50
+    risky_token_prob = 1e-50
+    for top_tokens in top_tokens_list:
+        for token in top_tokens:
+            if token["text"].strip().lower() == safe_token.lower():
+                safe_token_prob += math.exp(token["logprob"])
+            if token["text"].strip().lower() == risky_token.lower():
+                risky_token_prob += math.exp(token["logprob"])
+
+    probabilities = softmax([math.log(safe_token_prob), math.log(risky_token_prob)])
+
+    return probabilities
+
+
+def parse_output_watsonx(generated_tokens_list):
     label, prob_of_risk = None, None
+
     if nlogprobs > 0:
-
-        for token_i in list(output.scores)[:-1]]
-        if list_index_logprobs_i is not None:
-            prob = get_probablities(list_index_logprobs_i)
-            prob_of_risk = prob[1]
+        top_tokens_list = [generated_tokens["top_tokens"] for generated_tokens in generated_tokens_list]
+        prob = get_probablities_watsonx(top_tokens_list)
+        prob_of_risk = prob[1]
 
-    res = tokenizer.decode(output.sequences[:,input_len:][0],skip_special_tokens=True).strip()
+    res = next(iter(generated_tokens_list))["text"].strip()
+
     if risky_token.lower() == res.lower():
         label = risky_token
     elif safe_token.lower() == res.lower():
@@ -59,7 +72,27 @@ def parse_output(output, input_len):
     else:
         label = "Failed"
 
-    return label, prob_of_risk
+    return label, prob_of_risk
+
+
+def generate_tokens_watsonx(prompt):
+    result = model.generate(
+        prompt=[prompt],
+        params={
+            "decoding_method": "greedy",
+            "max_new_tokens": 20,
+            "temperature": 0,
+            "return_options": {"token_logprobs": True, "generated_tokens": True, "input_text": True, "top_n_tokens": 5},
+        },
+    )
+    return result[0]["results"][0]["generated_tokens"]
+
+
+def softmax(values):
+    exp_values = [math.exp(v) for v in values]
+    total = sum(exp_values)
+    return [v / total for v in exp_values]
+
 
 def get_probablities(logprobs):
     safe_token_prob = 1e-50
@@ -72,87 +105,61 @@ def get_probablities(logprobs):
     if decoded_token.strip().lower() == risky_token.lower():
         unsafe_token_prob += math.exp(logprob)
 
-    probabilities = torch.softmax(
-        torch.tensor([math.log(safe_token_prob), math.log(unsafe_token_prob)]), dim=0
-    )
+    probabilities = torch.softmax(torch.tensor([math.log(safe_token_prob), math.log(unsafe_token_prob)]), dim=0)
 
     return probabilities
 
 
-def
-
-
-    return [v / total for v in exp_values]
+def parse_output(output, input_len):
+    label, prob_of_risk = None, None
+    if nlogprobs > 0:
 
-
-
-
-
-
-
-            safe_token_prob += math.exp(token["logprob"])
-        if token["text"].strip().lower() == risky_token.lower():
-            risky_token_prob += math.exp(token["logprob"])
+        list_index_logprobs_i = [
+            torch.topk(token_i, k=nlogprobs, largest=True, sorted=True) for token_i in list(output.scores)[:-1]
+        ]
+        if list_index_logprobs_i is not None:
+            prob = get_probablities(list_index_logprobs_i)
+            prob_of_risk = prob[1]
 
-
+    res = tokenizer.decode(output.sequences[:, input_len:][0], skip_special_tokens=True).strip()
+    if risky_token.lower() == res.lower():
+        label = risky_token
+    elif safe_token.lower() == res.lower():
+        label = safe_token
+    else:
+        label = "Failed"
 
-    return
+    return label, prob_of_risk.item()
 
 
+@spaces.GPU
 def get_prompt(messages, criteria_name, tokenize=False, add_generation_prompt=False, return_tensors=None):
+    logger.debug("Creating prompt for the model.")
+    logger.debug(f"Messages used to create the prompt are: \n{messages}")
+    logger.debug("Criteria name is: " + criteria_name)
     if criteria_name == "general_harm":
         criteria_name = "harm"
     elif criteria_name == "function_calling_hallucination":
         criteria_name = "function_call"
-
+    logger.debug("Criteria name was changed too: " + criteria_name)
+    logger.debug(f"Tokenize: {tokenize}")
+    logger.debug(f"add_generation_prompt: {add_generation_prompt}")
+    logger.debug(f"return_tensors: {return_tensors}")
     guardian_config = {"risk_name": criteria_name if criteria_name != "general_harm" else "harm"}
+    logger.debug(f"guardian_config is: {guardian_config}")
     prompt = tokenizer.apply_chat_template(
         messages,
         guardian_config=guardian_config,
         tokenize=tokenize,
         add_generation_prompt=add_generation_prompt,
-        return_tensors=return_tensors
+        return_tensors=return_tensors,
     )
-    logger.debug(f
+    logger.debug(f"Prompt (type {type(prompt)}) is: {prompt}")
     return prompt
 
 
 @spaces.GPU
-def
-    result = model.generate(
-        prompt=[prompt],
-        params={
-            "decoding_method": "greedy",
-            "max_new_tokens": 20,
-            "temperature": 0,
-            "return_options": {"token_logprobs": True, "generated_tokens": True, "input_text": True, "top_n_tokens": 5},
-        },
-    )
-    return result[0]["results"][0]["generated_tokens"]
-
-
-def parse_output_watsonx(generated_tokens_list):
-    label, prob_of_risk = None, None
-
-    if nlogprobs > 0:
-        top_tokens_list = [generated_tokens["top_tokens"] for generated_tokens in generated_tokens_list]
-        prob = get_probablities_watsonx(top_tokens_list)
-        prob_of_risk = prob[1]
-
-    res = next(iter(generated_tokens_list))["text"].strip()
-
-    if risky_token.lower() == res.lower():
-        label = risky_token
-    elif safe_token.lower() == res.lower():
-        label = safe_token
-    else:
-        label = "Failed"
-
-    return label, prob_of_risk
-
-
-@spaces.GPU
-def generate_text(messages, criteria_name):
+def get_guardian_response(messages, criteria_name):
     logger.debug(f"Messages used to create the prompt are: \n{messages}")
     start = time()
     if inference_engine == "MOCK":
@@ -163,27 +170,20 @@ def generate_text(messages, criteria_name):
     elif inference_engine == "WATSONX":
         chat = get_prompt(messages, criteria_name)
         logger.debug(f"Prompt is \n{chat}")
-        generated_tokens =
+        generated_tokens = generate_tokens_watsonx(chat)
         label, prob_of_risk = parse_output_watsonx(generated_tokens)
 
     elif inference_engine == "VLLM":
-
-
-
-
-        # add_generation_prompt=True,
-        # return_tensors="pt").to(model.device)
-        guardian_config = {"risk_name": criteria_name if criteria_name != "general_harm" else "harm"}
-        logger.debug(f'guardian_config is: {guardian_config}')
-        input_ids = tokenizer.apply_chat_template(
-            messages,
-            guardian_config=guardian_config,
+        input_ids = get_prompt(
+            messages=messages,
+            criteria_name=criteria_name,
+            tokenize=True,
             add_generation_prompt=True,
-            return_tensors=
+            return_tensors="pt",
         ).to(model.device)
         logger.debug(f"input_ids are: {input_ids}")
         input_len = input_ids.shape[1]
-        logger.debug(f"input_len
+        logger.debug(f"input_len is: {input_len}")
 
         with torch.no_grad():
             # output = model.generate(chat, sampling_params, use_tqdm=False)
@@ -192,7 +192,8 @@
                 do_sample=False,
                 max_new_tokens=nlogprobs,
                 return_dict_in_generate=True,
-                output_scores=True,
+                output_scores=True,
+            )
         logger.debug(f"model output is:\n{output}")
 
         label, prob_of_risk = parse_output(output, input_len)
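Note on the new WATSONX helpers (illustration only; the logprob values below are invented): the risk score is built by summing the probability mass of the "Yes" and "No" tokens over the returned top tokens at each generated position, then normalizing with the plain-Python softmax added in this commit. A self-contained rerun of that computation:

import math

safe_token = "No"
risky_token = "Yes"

def softmax(values):
    exp_values = [math.exp(v) for v in values]
    total = sum(exp_values)
    return [v / total for v in exp_values]

def get_probablities_watsonx(top_tokens_list):
    # Accumulate probability mass of the safe ("No") and risky ("Yes") tokens.
    safe_token_prob = 1e-50
    risky_token_prob = 1e-50
    for top_tokens in top_tokens_list:
        for token in top_tokens:
            if token["text"].strip().lower() == safe_token.lower():
                safe_token_prob += math.exp(token["logprob"])
            if token["text"].strip().lower() == risky_token.lower():
                risky_token_prob += math.exp(token["logprob"])
    return softmax([math.log(safe_token_prob), math.log(risky_token_prob)])

# One generated position with hypothetical top-token logprobs:
top_tokens_list = [[{"text": "Yes", "logprob": -0.2}, {"text": "No", "logprob": -1.7}]]
prob_safe, prob_of_risk = get_probablities_watsonx(top_tokens_list)
print(round(prob_of_risk, 2))  # roughly 0.82: most probability mass sits on the risky "Yes" token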
src/styles.css CHANGED
@@ -135,4 +135,4 @@
 
 .no-stretch {
     align-items: flex-start;
-}
+}
src/utils.py CHANGED
@@ -5,6 +5,7 @@ import os
 def create_message(role, content):
     return [{"role": role, "content": content}]
 
+
 def get_messages(test_case, sub_catalog_name) -> list[dict[str, str]]:
     messages = []
 
@@ -27,8 +28,6 @@ def get_messages(test_case, sub_catalog_name) -> list[dict[str, str]]:
         messages += create_message("tools", test_case["tools"])
         messages += create_message("user", test_case["user_message"])
         messages += create_message("assistant", test_case["assistant_message"])
-    print('Messages are')
-    print(messages)
     return messages
 
 
@@ -53,7 +52,9 @@ def get_evaluated_component(sub_catalog_name, criteria_name):
     component = None
     if sub_catalog_name == "harmful_content_in_user_prompt":
         component = "user"
-    elif
+    elif (
+        sub_catalog_name == "harmful_content_in_assistant_response" or sub_catalog_name == "risks_in_agentic_workflows"
+    ):
         component = "assistant"
     elif sub_catalog_name == "rag_hallucination_risks":
         if criteria_name == "context_relevance":
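Note on get_messages (illustration only; the field values below are invented): for a function-calling test case it emits one message per role, in tools, user, assistant order, each built by create_message. A minimal trace of that assembly:

def create_message(role, content):
    return [{"role": role, "content": content}]

test_case = {
    "tools": '[{"name": "comment_list", ...}]',   # API definition shown in the Tools box
    "user_message": "Fetch comments for video 456789123",
    "assistant_message": '[{"name": "comment_list", "arguments": {...}}]',
}

messages = []
messages += create_message("tools", test_case["tools"])
messages += create_message("user", test_case["user_message"])
messages += create_message("assistant", test_case["assistant_message"])
# messages is now [{"role": "tools", ...}, {"role": "user", ...}, {"role": "assistant", ...}]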