Martín Santillán Cooper committed on
Commit
0f55f50
1 Parent(s): 2e6c988

Continue adaption

Files changed (4)
  1. .env.example +1 -1
  2. src/logger.py +5 -4
  3. src/model.py +11 -12
  4. src/utils.py +3 -2
.env.example CHANGED
@@ -1,4 +1,4 @@
 MODEL_PATH='ibm-granite/granite-guardian-3.1-8b'
-INFERENCE_ENGINE='VLLM' # one of [WATSONX, MOCK, VLLM]
+INFERENCE_ENGINE='VLLM' # one of [WATSONX, MOCK, TORCH]
 WATSONX_API_KEY=''
 WATSONX_PROJECT_ID=''
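For orientation, a minimal sketch of how these variables are typically consumed at startup. The demo's actual loading code is not part of this diff; the use of python-dotenv is an assumption, while the variable names and allowed values come from .env.example and src/model.py below.

# Sketch only, assuming the .env file is loaded with python-dotenv;
# the demo's real startup code is not shown in this commit.
import os

from dotenv import load_dotenv

load_dotenv()  # populate os.environ from .env

model_path = os.getenv("MODEL_PATH", "ibm-granite/granite-guardian-3.1-8b")
inference_engine = os.getenv("INFERENCE_ENGINE", "TORCH")

if inference_engine not in ("WATSONX", "MOCK", "TORCH"):
    raise ValueError("INFERENCE_ENGINE must be one of [WATSONX, MOCK, TORCH]")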
src/logger.py CHANGED
@@ -1,14 +1,15 @@
 import logging
 
-logger = logging.getLogger("demo")
+logger = logging.getLogger("guardian-demo")
 logger.setLevel(logging.DEBUG)
 
+formatter = logging.Formatter("%(asctime)s - %(filename)s:%(lineno)d - %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
+
 stream_handler = logging.StreamHandler()
 stream_handler.setLevel(logging.DEBUG)
+stream_handler.setFormatter(formatter)
 logger.addHandler(stream_handler)
 
 file_handler = logging.FileHandler("logs.txt")
-file_handler.setFormatter(
-    logging.Formatter("%(asctime)s - %(filename)s:%(lineno)d - %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
-)
+file_handler.setFormatter(formatter)
 logger.addHandler(file_handler)
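Both handlers now share a single formatter, so console output and logs.txt are formatted identically. The other files in this commit call logger.debug directly, which suggests they import this pre-configured logger; a minimal usage sketch, assuming the module is importable as logger (the exact import path is not shown in the diff):

# Usage sketch; the import path is an assumption based on the src/ layout.
from logger import logger

logger.debug("Inference engine is: 'TORCH'")
# With the shared formatter, both handlers emit lines such as:
# 2025-01-31 12:00:00 - model.py:17 - Inference engine is: 'TORCH'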
src/model.py CHANGED
@@ -13,15 +13,15 @@ safe_token = "No"
 risky_token = "Yes"
 nlogprobs = 20
 
-inference_engine = os.getenv("INFERENCE_ENGINE", "VLLM")
+inference_engine = os.getenv("INFERENCE_ENGINE", "TORCH")
 logger.debug(f"Inference engine is: '{inference_engine}'")
 
-if inference_engine == "VLLM":
+if inference_engine == "TORCH":
     import torch
 
     device = torch.device("cpu")
 
-    model_path = os.getenv("MODEL_PATH", "ibm-granite/granite-guardian-3.1-2b")
+    model_path = os.getenv("MODEL_PATH", "ibm-granite/granite-guardian-3.1-8b")
     logger.debug(f"model_path is {model_path}")
     tokenizer = AutoTokenizer.from_pretrained(model_path)
     model = AutoModelForCausalLM.from_pretrained(model_path)
@@ -160,7 +160,6 @@ def get_prompt(messages, criteria_name, tokenize=False, add_generation_prompt=Fa
 
 @spaces.GPU
 def get_guardian_response(messages, criteria_name):
-    logger.debug(f"Messages used to create the prompt are: \n{messages}")
     start = time()
     if inference_engine == "MOCK":
         logger.debug("Returning mocked model result.")
@@ -173,7 +172,7 @@ def get_guardian_response(messages, criteria_name):
         generated_tokens = generate_tokens_watsonx(chat)
         label, prob_of_risk = parse_output_watsonx(generated_tokens)
 
-    elif inference_engine == "VLLM":
+    elif inference_engine == "TORCH":
         input_ids = get_prompt(
             messages=messages,
             criteria_name=criteria_name,
@@ -181,7 +180,7 @@ def get_guardian_response(messages, criteria_name):
             add_generation_prompt=True,
             return_tensors="pt",
         ).to(model.device)
-        logger.debug(f"input_ids are: {input_ids}")
+        # logger.debug(f"input_ids are: {input_ids}")
         input_len = input_ids.shape[1]
         logger.debug(f"input_len is: {input_len}")
 
@@ -194,16 +193,16 @@ def get_guardian_response(messages, criteria_name):
             return_dict_in_generate=True,
             output_scores=True,
         )
-        logger.debug(f"model output is:\n{output}")
+        # logger.debug(f"model output is:\n{output}")
 
         label, prob_of_risk = parse_output(output, input_len)
-        logger.debug(f"label is are: {label}")
-        logger.debug(f"prob_of_risk is are: {prob_of_risk}")
+        logger.debug(f"Label is: {label}")
+        logger.debug(f"Prob_of_risk is: {prob_of_risk}")
     else:
-        raise Exception("Environment variable 'INFERENCE_ENGINE' must be one of [WATSONX, MOCK, VLLM]")
+        raise Exception("Environment variable 'INFERENCE_ENGINE' must be one of [WATSONX, MOCK, TORCH]")
 
-    logger.debug(f"Model generated label: \n{label}")
-    logger.debug(f"Model prob_of_risk: \n{prob_of_risk}")
+    logger.debug(f"Model generated label: {label}")
+    logger.debug(f"Model prob_of_risk: {prob_of_risk}")
 
     end = time()
     total = end - start
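With the tensor dumps commented out, parse_output is the step that turns the raw generate output into a label and prob_of_risk. Its body is not part of this commit; below is a plausible sketch, assuming it follows the common Granite Guardian pattern of comparing the probabilities of the safe and risky answer tokens at the first generated position. Only the names safe_token, risky_token, tokenizer, and the (output, input_len) signature come from the diff; everything else is an assumption.

import torch

def parse_output(output, input_len):
    # Sketch only -- the repository's real implementation is not shown here.
    # Relies on the module-level tokenizer, safe_token ("No") and risky_token ("Yes")
    # defined earlier in model.py.
    safe_id = tokenizer.convert_tokens_to_ids(safe_token)
    risky_id = tokenizer.convert_tokens_to_ids(risky_token)

    # output.scores[0] holds the logits for the first generated token.
    probs = torch.softmax(output.scores[0][0], dim=-1)
    prob_of_risk = probs[risky_id] / (probs[safe_id] + probs[risky_id])

    # The decoded answer itself ("Yes"/"No") gives the label.
    generated = tokenizer.decode(
        output.sequences[0, input_len:], skip_special_tokens=True
    ).strip()
    label = risky_token if generated.lower().startswith("yes") else safe_token
    return label, float(prob_of_risk)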
src/utils.py CHANGED
@@ -1,4 +1,5 @@
 import argparse
+import json
 import os
 
 
@@ -25,9 +26,9 @@ def get_messages(test_case, sub_catalog_name) -> list[dict[str, str]]:
         messages += create_message("user", test_case["user_message"])
         messages += create_message("assistant", test_case["assistant_message"])
     elif sub_catalog_name == "risks_in_agentic_workflows":
-        messages += create_message("tools", test_case["tools"])
+        messages += create_message("tools", json.loads(test_case["tools"]))
         messages += create_message("user", test_case["user_message"])
-        messages += create_message("assistant", test_case["assistant_message"])
+        messages += create_message("assistant", json.loads(test_case["assistant_message"]))
     return messages
 
 
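The switch to json.loads implies that, for the agentic-workflow catalog, the tool definitions and the assistant turn are stored in the test case as JSON strings rather than plain text. A small sketch of what that looks like end to end; create_message's body and the example test case are assumptions, only the names and the parsing calls come from the diff.

import json

def create_message(role, content):
    # Assumed helper: wraps one turn as a single-element list of chat messages.
    return [{"role": role, "content": content}]

# Assumed test-case shape for the agentic catalog: tool specs and the
# assistant's tool call are serialized JSON, so they must be parsed first.
test_case = {
    "tools": '[{"name": "get_weather", "parameters": {"city": "string"}}]',
    "user_message": "What is the weather in Madrid?",
    "assistant_message": '[{"name": "get_weather", "arguments": {"city": "Madrid"}}]',
}

messages = []
messages += create_message("tools", json.loads(test_case["tools"]))
messages += create_message("user", test_case["user_message"])
messages += create_message("assistant", json.loads(test_case["assistant_message"]))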