Martín Santillán Cooper commited on
Commit
6047b61
1 Parent(s): a89905d

Continue adaptation

Browse files
Files changed (9) hide show
  1. .dockerignore +1 -1
  2. .env.example +1 -1
  3. .gitignore +1 -1
  4. catalog.json +1 -1
  5. convert_to_string.py +7 -13
  6. src/app.py +38 -30
  7. src/model.py +85 -84
  8. src/styles.css +1 -1
  9. src/utils.py +4 -3
.dockerignore CHANGED
@@ -5,4 +5,4 @@
5
  *.sh
6
  *.md
7
  __pycache__/
8
- flagged/
 
5
  *.sh
6
  *.md
7
  __pycache__/
8
+ flagged/
.env.example CHANGED
@@ -1,4 +1,4 @@
1
  MODEL_PATH='ibm-granite/granite-guardian-3.1-8b'
2
  INFERENCE_ENGINE='VLLM' # one of [WATSONX, MOCK, VLLM]
3
  WATSONX_API_KEY=''
4
- WATSONX_PROJECT_ID=''
 
1
  MODEL_PATH='ibm-granite/granite-guardian-3.1-8b'
2
  INFERENCE_ENGINE='VLLM' # one of [WATSONX, MOCK, VLLM]
3
  WATSONX_API_KEY=''
4
+ WATSONX_PROJECT_ID=''
.gitignore CHANGED
@@ -4,4 +4,4 @@ parse.py
4
  unparsed_catalog.json
5
  __pycache__/
6
  logs.txt
7
- secrets.yaml
 
4
  unparsed_catalog.json
5
  __pycache__/
6
  logs.txt
7
+ secrets.yaml
catalog.json CHANGED
@@ -112,4 +112,4 @@
112
  "context": null
113
  }
114
  ]
115
- }
 
112
  "context": null
113
  }
114
  ]
115
+ }
convert_to_string.py CHANGED
@@ -1,5 +1,6 @@
1
  import json
2
 
 
3
  def dict_to_json_with_newlines(data):
4
  """
5
  Converts a dictionary into a JSON string with explicit newlines (\n) added.
@@ -12,28 +13,21 @@ def dict_to_json_with_newlines(data):
12
  """
13
  # Convert the dictionary to a pretty-printed JSON string
14
  pretty_json = json.dumps(data, indent=2)
15
-
16
  # Replace actual newlines with escaped newlines (\n)
17
  json_with_newlines = pretty_json.replace("\n", "\\n")
18
-
19
  # Escape double quotes for embedding inside other JSON
20
  json_with_newlines = json_with_newlines.replace('"', '\\"')
21
-
22
  return json_with_newlines
23
 
 
24
  # Example dictionary
25
- example_dict =[
26
- {
27
- "name": "comment_list",
28
- "arguments": {
29
- "video_id": 456789123,
30
- "count": 15
31
- }
32
- }
33
- ]
34
 
35
  # Convert the dictionary
36
  result = dict_to_json_with_newlines(example_dict)
37
 
38
  print("Resulting JSON string:")
39
- print(result)
 
1
  import json
2
 
3
+
4
  def dict_to_json_with_newlines(data):
5
  """
6
  Converts a dictionary into a JSON string with explicit newlines (\n) added.
 
13
  """
14
  # Convert the dictionary to a pretty-printed JSON string
15
  pretty_json = json.dumps(data, indent=2)
16
+
17
  # Replace actual newlines with escaped newlines (\n)
18
  json_with_newlines = pretty_json.replace("\n", "\\n")
19
+
20
  # Escape double quotes for embedding inside other JSON
21
  json_with_newlines = json_with_newlines.replace('"', '\\"')
22
+
23
  return json_with_newlines
24
 
25
+
26
  # Example dictionary
27
+ example_dict = [{"name": "comment_list", "arguments": {"video_id": 456789123, "count": 15}}]
 
 
 
 
 
 
 
 
28
 
29
  # Convert the dictionary
30
  result = dict_to_json_with_newlines(example_dict)
31
 
32
  print("Resulting JSON string:")
33
+ print(result)
src/app.py CHANGED
@@ -6,7 +6,7 @@ from dotenv import load_dotenv
6
  from gradio_modal import Modal
7
 
8
  from logger import logger
9
- from model import generate_text, get_prompt
10
  from utils import (
11
  get_messages,
12
  get_result_description,
@@ -20,7 +20,7 @@ load_dotenv()
20
 
21
  catalog = {}
22
 
23
- toy_json = '{"name": "John"}'
24
 
25
  with open("catalog.json") as f:
26
  logger.debug("Loading catalog from json.")
@@ -63,12 +63,9 @@ def on_test_case_click(state: gr.State):
63
  # update context field:
64
  if is_context_editable:
65
  context = gr.update(
66
- value=selected_test_case["context"],
67
- interactive=True,
68
- visible=True,
69
- elem_classes=["input-box"]
70
  )
71
- else:
72
  context = gr.update(
73
  visible=selected_test_case["context"] is not None,
74
  value=selected_test_case["context"],
@@ -85,19 +82,13 @@ def on_test_case_click(state: gr.State):
85
  # update user message field
86
  if is_user_message_editable:
87
  user_message = gr.update(
88
- value=selected_test_case["user_message"],
89
- visible=True,
90
- interactive=True,
91
- elem_classes=["input-box"]
92
  )
93
  else:
94
  user_message = gr.update(
95
- value=selected_test_case["user_message"],
96
- interactive=False,
97
- elem_classes=["read-only", "input-box"]
98
  )
99
 
100
-
101
  # update assistant message field
102
  if is_tools_present:
103
  assistant_message_json = gr.update(
@@ -124,8 +115,16 @@ def on_test_case_click(state: gr.State):
124
  assistant_message_json = gr.update(visible=False)
125
 
126
  result_text = gr.update(visible=False, value="")
127
- return test_case_name,criteria,context,user_message,assistant_message_text,assistant_message_json,tools,result_text
128
-
 
 
 
 
 
 
 
 
129
 
130
 
131
  def change_button_color(event: gr.EventData):
@@ -144,14 +143,15 @@ def on_submit(criteria, context, user_message, assistant_message_text, assistant
144
  criteria_name = state["selected_criteria_name"]
145
  if criteria_name == "function_calling_hallucination":
146
  assistant_message = assistant_message_json
147
- else: assistant_message = assistant_message_text
 
148
  test_case = {
149
  "name": criteria_name,
150
  "criteria": criteria,
151
  "context": context,
152
  "user_message": user_message,
153
  "assistant_message": assistant_message,
154
- "tools": tools
155
  }
156
 
157
  messages = get_messages(test_case=test_case, sub_catalog_name=state["selected_sub_catalog"])
@@ -160,7 +160,7 @@ def on_submit(criteria, context, user_message, assistant_message_text, assistant
160
  f"Starting evaluation for subcatelog {state['selected_sub_catalog']} and criteria name {state['selected_criteria_name']}"
161
  )
162
 
163
- result_label = generate_text(messages=messages, criteria_name=criteria_name)["assessment"] # Yes or No
164
 
165
  html_str = f"<p>{get_result_description(state['selected_sub_catalog'], state['selected_criteria_name'])} <strong>{result_label}</strong></p>"
166
  # html_str = f"{get_result_description(state['selected_sub_catalog'], state['selected_criteria_name'])} {result_label}"
@@ -171,7 +171,8 @@ def on_show_prompt_click(criteria, context, user_message, assistant_message_text
171
  criteria_name = state["selected_criteria_name"]
172
  if criteria_name == "function_calling_hallucination":
173
  assistant_message = assistant_message_json
174
- else: assistant_message = assistant_message_text
 
175
  test_case = {
176
  "name": criteria_name,
177
  "criteria": criteria,
@@ -247,7 +248,7 @@ with gr.Blocks(
247
  gr.HTML("<h2>IBM Granite Guardian 3.1</h2>", elem_classes="title")
248
  gr.HTML(
249
  elem_classes="system-description",
250
- value="<p>Granite Guardian models are specialized language models in the Granite family that can detect harms and risks in generative AI systems. They can be used with any large language model to make interactions with generative AI systems safe. Select an example in the left panel to see how the Granite Guardian model evaluates harms and risks in user prompts, assistant responses, and for hallucinations in retrieval-augmented generation. In this demo, we use granite-guardian-3.1-8b.</p>",
251
  )
252
  with gr.Row(elem_classes="column-gap"):
253
  with gr.Column(scale=0, elem_classes="no-gap"):
@@ -301,11 +302,7 @@ with gr.Blocks(
301
  elem_classes=["input-box"],
302
  )
303
 
304
- tools = gr.Code(
305
- label="API Definition (Tools)",
306
- visible=False,
307
- language='json'
308
- )
309
 
310
  user_message = gr.Textbox(
311
  label="User Prompt",
@@ -327,7 +324,7 @@ with gr.Blocks(
327
  assistant_message_json = gr.Code(
328
  label="Assistant Response",
329
  visible=False,
330
- language='json',
331
  value=None,
332
  elem_classes=["input-box"],
333
  )
@@ -350,7 +347,9 @@ with gr.Blocks(
350
  # events
351
 
352
  show_propt_button.click(
353
- on_show_prompt_click, inputs=[criteria, context, user_message, assistant_message_text, assistant_message_json, tools, state], outputs=prompt
 
 
354
  ).then(lambda: gr.update(visible=True), None, modal)
355
 
356
  submit_button.click(lambda: gr.update(visible=True, value=""), None, result_text).then(
@@ -368,7 +367,16 @@ with gr.Blocks(
368
  ).then(update_selected_test_case, inputs=[button, state], outputs=[state]).then(
369
  on_test_case_click,
370
  inputs=state,
371
- outputs=[test_case_name, criteria, context, user_message, assistant_message_text, assistant_message_json, tools, result_text],
 
 
 
 
 
 
 
 
 
372
  )
373
 
374
  demo.launch(server_name="0.0.0.0")
 
6
  from gradio_modal import Modal
7
 
8
  from logger import logger
9
+ from model import get_guardian_response, get_prompt
10
  from utils import (
11
  get_messages,
12
  get_result_description,
 
20
 
21
  catalog = {}
22
 
23
+ toy_json = '{"name": "John"}'
24
 
25
  with open("catalog.json") as f:
26
  logger.debug("Loading catalog from json.")
 
63
  # update context field:
64
  if is_context_editable:
65
  context = gr.update(
66
+ value=selected_test_case["context"], interactive=True, visible=True, elem_classes=["input-box"]
 
 
 
67
  )
68
+ else:
69
  context = gr.update(
70
  visible=selected_test_case["context"] is not None,
71
  value=selected_test_case["context"],
 
82
  # update user message field
83
  if is_user_message_editable:
84
  user_message = gr.update(
85
+ value=selected_test_case["user_message"], visible=True, interactive=True, elem_classes=["input-box"]
 
 
 
86
  )
87
  else:
88
  user_message = gr.update(
89
+ value=selected_test_case["user_message"], interactive=False, elem_classes=["read-only", "input-box"]
 
 
90
  )
91
 
 
92
  # update assistant message field
93
  if is_tools_present:
94
  assistant_message_json = gr.update(
 
115
  assistant_message_json = gr.update(visible=False)
116
 
117
  result_text = gr.update(visible=False, value="")
118
+ return (
119
+ test_case_name,
120
+ criteria,
121
+ context,
122
+ user_message,
123
+ assistant_message_text,
124
+ assistant_message_json,
125
+ tools,
126
+ result_text,
127
+ )
128
 
129
 
130
  def change_button_color(event: gr.EventData):
 
143
  criteria_name = state["selected_criteria_name"]
144
  if criteria_name == "function_calling_hallucination":
145
  assistant_message = assistant_message_json
146
+ else:
147
+ assistant_message = assistant_message_text
148
  test_case = {
149
  "name": criteria_name,
150
  "criteria": criteria,
151
  "context": context,
152
  "user_message": user_message,
153
  "assistant_message": assistant_message,
154
+ "tools": tools,
155
  }
156
 
157
  messages = get_messages(test_case=test_case, sub_catalog_name=state["selected_sub_catalog"])
 
160
  f"Starting evaluation for subcatelog {state['selected_sub_catalog']} and criteria name {state['selected_criteria_name']}"
161
  )
162
 
163
+ result_label = get_guardian_response(messages=messages, criteria_name=criteria_name)["assessment"] # Yes or No
164
 
165
  html_str = f"<p>{get_result_description(state['selected_sub_catalog'], state['selected_criteria_name'])} <strong>{result_label}</strong></p>"
166
  # html_str = f"{get_result_description(state['selected_sub_catalog'], state['selected_criteria_name'])} {result_label}"
 
171
  criteria_name = state["selected_criteria_name"]
172
  if criteria_name == "function_calling_hallucination":
173
  assistant_message = assistant_message_json
174
+ else:
175
+ assistant_message = assistant_message_text
176
  test_case = {
177
  "name": criteria_name,
178
  "criteria": criteria,
 
248
  gr.HTML("<h2>IBM Granite Guardian 3.1</h2>", elem_classes="title")
249
  gr.HTML(
250
  elem_classes="system-description",
251
+ value="<p>Granite Guardian models are specialized language models in the Granite family that can detect harms and risks in generative AI systems. They can be used with any large language model to make interactions with generative AI systems safe. Select an example in the left panel to see how the Granite Guardian model evaluates harms and risks in user prompts, assistant responses, and for hallucinations in retrival-augmented generation and function calling. In this demo, we use granite-guardian-3.1-8b.</p>",
252
  )
253
  with gr.Row(elem_classes="column-gap"):
254
  with gr.Column(scale=0, elem_classes="no-gap"):
 
302
  elem_classes=["input-box"],
303
  )
304
 
305
+ tools = gr.Code(label="API Definition (Tools)", visible=False, language="json")
 
 
 
 
306
 
307
  user_message = gr.Textbox(
308
  label="User Prompt",
 
324
  assistant_message_json = gr.Code(
325
  label="Assistant Response",
326
  visible=False,
327
+ language="json",
328
  value=None,
329
  elem_classes=["input-box"],
330
  )
 
347
  # events
348
 
349
  show_propt_button.click(
350
+ on_show_prompt_click,
351
+ inputs=[criteria, context, user_message, assistant_message_text, assistant_message_json, tools, state],
352
+ outputs=prompt,
353
  ).then(lambda: gr.update(visible=True), None, modal)
354
 
355
  submit_button.click(lambda: gr.update(visible=True, value=""), None, result_text).then(
 
367
  ).then(update_selected_test_case, inputs=[button, state], outputs=[state]).then(
368
  on_test_case_click,
369
  inputs=state,
370
+ outputs=[
371
+ test_case_name,
372
+ criteria,
373
+ context,
374
+ user_message,
375
+ assistant_message_text,
376
+ assistant_message_json,
377
+ tools,
378
+ result_text,
379
+ ],
380
  )
381
 
382
  demo.launch(server_name="0.0.0.0")
src/model.py CHANGED
@@ -9,9 +9,6 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
9
 
10
  from logger import logger
11
 
12
- # from vllm import LLM, SamplingParams
13
-
14
-
15
  safe_token = "No"
16
  risky_token = "Yes"
17
  nlogprobs = 20
@@ -21,9 +18,10 @@ logger.debug(f"Inference engine is: '{inference_engine}'")
21
 
22
  if inference_engine == "VLLM":
23
  import torch
24
- device = torch.device("cuda")
25
 
26
- model_path = os.getenv("MODEL_PATH", "ibm-granite/granite-guardian-3.1-8b")
 
 
27
  logger.debug(f"model_path is {model_path}")
28
  tokenizer = AutoTokenizer.from_pretrained(model_path)
29
  model = AutoModelForCausalLM.from_pretrained(model_path)
@@ -41,17 +39,32 @@ elif inference_engine == "WATSONX":
41
  model_id = "ibm/granite-guardian-3-8b" # 8B Model: "ibm/granite-guardian-3-8b"
42
  model = ModelInference(model_id=model_id, api_client=client)
43
 
44
- def parse_output(output, input_len):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  label, prob_of_risk = None, None
 
46
  if nlogprobs > 0:
 
 
 
47
 
48
- list_index_logprobs_i = [torch.topk(token_i, k=nlogprobs, largest=True, sorted=True)
49
- for token_i in list(output.scores)[:-1]]
50
- if list_index_logprobs_i is not None:
51
- prob = get_probablities(list_index_logprobs_i)
52
- prob_of_risk = prob[1]
53
 
54
- res = tokenizer.decode(output.sequences[:,input_len:][0],skip_special_tokens=True).strip()
55
  if risky_token.lower() == res.lower():
56
  label = risky_token
57
  elif safe_token.lower() == res.lower():
@@ -59,7 +72,27 @@ def parse_output(output, input_len):
59
  else:
60
  label = "Failed"
61
 
62
- return label, prob_of_risk.item()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
 
64
  def get_probablities(logprobs):
65
  safe_token_prob = 1e-50
@@ -72,87 +105,61 @@ def get_probablities(logprobs):
72
  if decoded_token.strip().lower() == risky_token.lower():
73
  unsafe_token_prob += math.exp(logprob)
74
 
75
- probabilities = torch.softmax(
76
- torch.tensor([math.log(safe_token_prob), math.log(unsafe_token_prob)]), dim=0
77
- )
78
 
79
  return probabilities
80
 
81
 
82
- def softmax(values):
83
- exp_values = [math.exp(v) for v in values]
84
- total = sum(exp_values)
85
- return [v / total for v in exp_values]
86
 
87
- def get_probablities_watsonx(top_tokens_list):
88
- safe_token_prob = 1e-50
89
- risky_token_prob = 1e-50
90
- for top_tokens in top_tokens_list:
91
- for token in top_tokens:
92
- if token["text"].strip().lower() == safe_token.lower():
93
- safe_token_prob += math.exp(token["logprob"])
94
- if token["text"].strip().lower() == risky_token.lower():
95
- risky_token_prob += math.exp(token["logprob"])
96
 
97
- probabilities = softmax([math.log(safe_token_prob), math.log(risky_token_prob)])
 
 
 
 
 
 
98
 
99
- return probabilities
100
 
101
 
 
102
  def get_prompt(messages, criteria_name, tokenize=False, add_generation_prompt=False, return_tensors=None):
 
 
 
103
  if criteria_name == "general_harm":
104
  criteria_name = "harm"
105
  elif criteria_name == "function_calling_hallucination":
106
  criteria_name = "function_call"
107
-
 
 
 
108
  guardian_config = {"risk_name": criteria_name if criteria_name != "general_harm" else "harm"}
 
109
  prompt = tokenizer.apply_chat_template(
110
  messages,
111
  guardian_config=guardian_config,
112
  tokenize=tokenize,
113
  add_generation_prompt=add_generation_prompt,
114
- return_tensors=return_tensors
115
  )
116
- logger.debug(f'prompt is\n{prompt}')
117
  return prompt
118
 
119
 
120
  @spaces.GPU
121
- def generate_tokens(prompt):
122
- result = model.generate(
123
- prompt=[prompt],
124
- params={
125
- "decoding_method": "greedy",
126
- "max_new_tokens": 20,
127
- "temperature": 0,
128
- "return_options": {"token_logprobs": True, "generated_tokens": True, "input_text": True, "top_n_tokens": 5},
129
- },
130
- )
131
- return result[0]["results"][0]["generated_tokens"]
132
-
133
-
134
- def parse_output_watsonx(generated_tokens_list):
135
- label, prob_of_risk = None, None
136
-
137
- if nlogprobs > 0:
138
- top_tokens_list = [generated_tokens["top_tokens"] for generated_tokens in generated_tokens_list]
139
- prob = get_probablities_watsonx(top_tokens_list)
140
- prob_of_risk = prob[1]
141
-
142
- res = next(iter(generated_tokens_list))["text"].strip()
143
-
144
- if risky_token.lower() == res.lower():
145
- label = risky_token
146
- elif safe_token.lower() == res.lower():
147
- label = safe_token
148
- else:
149
- label = "Failed"
150
-
151
- return label, prob_of_risk
152
-
153
-
154
- @spaces.GPU
155
- def generate_text(messages, criteria_name):
156
  logger.debug(f"Messages used to create the prompt are: \n{messages}")
157
  start = time()
158
  if inference_engine == "MOCK":
@@ -163,27 +170,20 @@ def generate_text(messages, criteria_name):
163
  elif inference_engine == "WATSONX":
164
  chat = get_prompt(messages, criteria_name)
165
  logger.debug(f"Prompt is \n{chat}")
166
- generated_tokens = generate_tokens(chat)
167
  label, prob_of_risk = parse_output_watsonx(generated_tokens)
168
 
169
  elif inference_engine == "VLLM":
170
- # input_ids = get_prompt(
171
- # messages=messages,
172
- # criteria_name=criteria_name,
173
- # tokenize=True,
174
- # add_generation_prompt=True,
175
- # return_tensors="pt").to(model.device)
176
- guardian_config = {"risk_name": criteria_name if criteria_name != "general_harm" else "harm"}
177
- logger.debug(f'guardian_config is: {guardian_config}')
178
- input_ids = tokenizer.apply_chat_template(
179
- messages,
180
- guardian_config=guardian_config,
181
  add_generation_prompt=True,
182
- return_tensors='pt'
183
  ).to(model.device)
184
  logger.debug(f"input_ids are: {input_ids}")
185
  input_len = input_ids.shape[1]
186
- logger.debug(f"input_len are: {input_len}")
187
 
188
  with torch.no_grad():
189
  # output = model.generate(chat, sampling_params, use_tqdm=False)
@@ -192,7 +192,8 @@ def generate_text(messages, criteria_name):
192
  do_sample=False,
193
  max_new_tokens=nlogprobs,
194
  return_dict_in_generate=True,
195
- output_scores=True,)
 
196
  logger.debug(f"model output is:\n{output}")
197
 
198
  label, prob_of_risk = parse_output(output, input_len)
 
9
 
10
  from logger import logger
11
 
 
 
 
12
  safe_token = "No"
13
  risky_token = "Yes"
14
  nlogprobs = 20
 
18
 
19
  if inference_engine == "VLLM":
20
  import torch
 
21
 
22
+ device = torch.device("cpu")
23
+
24
+ model_path = os.getenv("MODEL_PATH", "ibm-granite/granite-guardian-3.1-2b")
25
  logger.debug(f"model_path is {model_path}")
26
  tokenizer = AutoTokenizer.from_pretrained(model_path)
27
  model = AutoModelForCausalLM.from_pretrained(model_path)
 
39
  model_id = "ibm/granite-guardian-3-8b" # 8B Model: "ibm/granite-guardian-3-8b"
40
  model = ModelInference(model_id=model_id, api_client=client)
41
 
42
+
43
+ def get_probablities_watsonx(top_tokens_list):
44
+ safe_token_prob = 1e-50
45
+ risky_token_prob = 1e-50
46
+ for top_tokens in top_tokens_list:
47
+ for token in top_tokens:
48
+ if token["text"].strip().lower() == safe_token.lower():
49
+ safe_token_prob += math.exp(token["logprob"])
50
+ if token["text"].strip().lower() == risky_token.lower():
51
+ risky_token_prob += math.exp(token["logprob"])
52
+
53
+ probabilities = softmax([math.log(safe_token_prob), math.log(risky_token_prob)])
54
+
55
+ return probabilities
56
+
57
+
58
+ def parse_output_watsonx(generated_tokens_list):
59
  label, prob_of_risk = None, None
60
+
61
  if nlogprobs > 0:
62
+ top_tokens_list = [generated_tokens["top_tokens"] for generated_tokens in generated_tokens_list]
63
+ prob = get_probablities_watsonx(top_tokens_list)
64
+ prob_of_risk = prob[1]
65
 
66
+ res = next(iter(generated_tokens_list))["text"].strip()
 
 
 
 
67
 
 
68
  if risky_token.lower() == res.lower():
69
  label = risky_token
70
  elif safe_token.lower() == res.lower():
 
72
  else:
73
  label = "Failed"
74
 
75
+ return label, prob_of_risk
76
+
77
+
78
+ def generate_tokens_watsonx(prompt):
79
+ result = model.generate(
80
+ prompt=[prompt],
81
+ params={
82
+ "decoding_method": "greedy",
83
+ "max_new_tokens": 20,
84
+ "temperature": 0,
85
+ "return_options": {"token_logprobs": True, "generated_tokens": True, "input_text": True, "top_n_tokens": 5},
86
+ },
87
+ )
88
+ return result[0]["results"][0]["generated_tokens"]
89
+
90
+
91
+ def softmax(values):
92
+ exp_values = [math.exp(v) for v in values]
93
+ total = sum(exp_values)
94
+ return [v / total for v in exp_values]
95
+
96
 
97
  def get_probablities(logprobs):
98
  safe_token_prob = 1e-50
 
105
  if decoded_token.strip().lower() == risky_token.lower():
106
  unsafe_token_prob += math.exp(logprob)
107
 
108
+ probabilities = torch.softmax(torch.tensor([math.log(safe_token_prob), math.log(unsafe_token_prob)]), dim=0)
 
 
109
 
110
  return probabilities
111
 
112
 
113
+ def parse_output(output, input_len):
114
+ label, prob_of_risk = None, None
115
+ if nlogprobs > 0:
 
116
 
117
+ list_index_logprobs_i = [
118
+ torch.topk(token_i, k=nlogprobs, largest=True, sorted=True) for token_i in list(output.scores)[:-1]
119
+ ]
120
+ if list_index_logprobs_i is not None:
121
+ prob = get_probablities(list_index_logprobs_i)
122
+ prob_of_risk = prob[1]
 
 
 
123
 
124
+ res = tokenizer.decode(output.sequences[:, input_len:][0], skip_special_tokens=True).strip()
125
+ if risky_token.lower() == res.lower():
126
+ label = risky_token
127
+ elif safe_token.lower() == res.lower():
128
+ label = safe_token
129
+ else:
130
+ label = "Failed"
131
 
132
+ return label, prob_of_risk.item()
133
 
134
 
135
+ @spaces.GPU
136
  def get_prompt(messages, criteria_name, tokenize=False, add_generation_prompt=False, return_tensors=None):
137
+ logger.debug("Creating prompt for the model.")
138
+ logger.debug(f"Messages used to create the prompt are: \n{messages}")
139
+ logger.debug("Criteria name is: " + criteria_name)
140
  if criteria_name == "general_harm":
141
  criteria_name = "harm"
142
  elif criteria_name == "function_calling_hallucination":
143
  criteria_name = "function_call"
144
+ logger.debug("Criteria name was changed too: " + criteria_name)
145
+ logger.debug(f"Tokenize: {tokenize}")
146
+ logger.debug(f"add_generation_prompt: {add_generation_prompt}")
147
+ logger.debug(f"return_tensors: {return_tensors}")
148
  guardian_config = {"risk_name": criteria_name if criteria_name != "general_harm" else "harm"}
149
+ logger.debug(f"guardian_config is: {guardian_config}")
150
  prompt = tokenizer.apply_chat_template(
151
  messages,
152
  guardian_config=guardian_config,
153
  tokenize=tokenize,
154
  add_generation_prompt=add_generation_prompt,
155
+ return_tensors=return_tensors,
156
  )
157
+ logger.debug(f"Prompt (type {type(prompt)}) is: {prompt}")
158
  return prompt
159
 
160
 
161
  @spaces.GPU
162
+ def get_guardian_response(messages, criteria_name):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
163
  logger.debug(f"Messages used to create the prompt are: \n{messages}")
164
  start = time()
165
  if inference_engine == "MOCK":
 
170
  elif inference_engine == "WATSONX":
171
  chat = get_prompt(messages, criteria_name)
172
  logger.debug(f"Prompt is \n{chat}")
173
+ generated_tokens = generate_tokens_watsonx(chat)
174
  label, prob_of_risk = parse_output_watsonx(generated_tokens)
175
 
176
  elif inference_engine == "VLLM":
177
+ input_ids = get_prompt(
178
+ messages=messages,
179
+ criteria_name=criteria_name,
180
+ tokenize=True,
 
 
 
 
 
 
 
181
  add_generation_prompt=True,
182
+ return_tensors="pt",
183
  ).to(model.device)
184
  logger.debug(f"input_ids are: {input_ids}")
185
  input_len = input_ids.shape[1]
186
+ logger.debug(f"input_len is: {input_len}")
187
 
188
  with torch.no_grad():
189
  # output = model.generate(chat, sampling_params, use_tqdm=False)
 
192
  do_sample=False,
193
  max_new_tokens=nlogprobs,
194
  return_dict_in_generate=True,
195
+ output_scores=True,
196
+ )
197
  logger.debug(f"model output is:\n{output}")
198
 
199
  label, prob_of_risk = parse_output(output, input_len)
src/styles.css CHANGED
@@ -135,4 +135,4 @@
135
 
136
  .no-stretch {
137
  align-items: flex-start;
138
- }
 
135
 
136
  .no-stretch {
137
  align-items: flex-start;
138
+ }
src/utils.py CHANGED
@@ -5,6 +5,7 @@ import os
5
  def create_message(role, content):
6
  return [{"role": role, "content": content}]
7
 
 
8
  def get_messages(test_case, sub_catalog_name) -> list[dict[str, str]]:
9
  messages = []
10
 
@@ -27,8 +28,6 @@ def get_messages(test_case, sub_catalog_name) -> list[dict[str, str]]:
27
  messages += create_message("tools", test_case["tools"])
28
  messages += create_message("user", test_case["user_message"])
29
  messages += create_message("assistant", test_case["assistant_message"])
30
- print('Messages are')
31
- print(messages)
32
  return messages
33
 
34
 
@@ -53,7 +52,9 @@ def get_evaluated_component(sub_catalog_name, criteria_name):
53
  component = None
54
  if sub_catalog_name == "harmful_content_in_user_prompt":
55
  component = "user"
56
- elif sub_catalog_name == "harmful_content_in_assistant_response" or sub_catalog_name == "risks_in_agentic_workflows":
 
 
57
  component = "assistant"
58
  elif sub_catalog_name == "rag_hallucination_risks":
59
  if criteria_name == "context_relevance":
 
5
  def create_message(role, content):
6
  return [{"role": role, "content": content}]
7
 
8
+
9
  def get_messages(test_case, sub_catalog_name) -> list[dict[str, str]]:
10
  messages = []
11
 
 
28
  messages += create_message("tools", test_case["tools"])
29
  messages += create_message("user", test_case["user_message"])
30
  messages += create_message("assistant", test_case["assistant_message"])
 
 
31
  return messages
32
 
33
 
 
52
  component = None
53
  if sub_catalog_name == "harmful_content_in_user_prompt":
54
  component = "user"
55
+ elif (
56
+ sub_catalog_name == "harmful_content_in_assistant_response" or sub_catalog_name == "risks_in_agentic_workflows"
57
+ ):
58
  component = "assistant"
59
  elif sub_catalog_name == "rag_hallucination_risks":
60
  if criteria_name == "context_relevance":