Spaces: Running on Zero
Martín Santillán Cooper committed
Commit 0f55f50 · 1 Parent(s): 2e6c988

Continue adaption
Files changed:
- .env.example +1 -1
- src/logger.py +5 -4
- src/model.py +11 -12
- src/utils.py +3 -2
.env.example  CHANGED
@@ -1,4 +1,4 @@
 MODEL_PATH='ibm-granite/granite-guardian-3.1-8b'
-INFERENCE_ENGINE='VLLM' # one of [WATSONX, MOCK,
+INFERENCE_ENGINE='VLLM' # one of [WATSONX, MOCK, TORCH]
 WATSONX_API_KEY=''
 WATSONX_PROJECT_ID=''
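The updated comment now matches the three engines handled in src/model.py (see below). A minimal sketch of how these variables might be consumed, assuming the Space loads a .env file with python-dotenv (the loader itself is not part of this commit):

    import os

    from dotenv import load_dotenv  # assumption: python-dotenv is available

    load_dotenv()  # pulls MODEL_PATH / INFERENCE_ENGINE / WATSONX_* into os.environ

    inference_engine = os.getenv("INFERENCE_ENGINE", "TORCH")
    model_path = os.getenv("MODEL_PATH", "ibm-granite/granite-guardian-3.1-8b")

    if inference_engine not in ("WATSONX", "MOCK", "TORCH"):
        raise ValueError("INFERENCE_ENGINE must be one of [WATSONX, MOCK, TORCH]")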
src/logger.py  CHANGED
@@ -1,14 +1,15 @@
 import logging
 
-logger = logging.getLogger("demo")
+logger = logging.getLogger("guardian-demo")
 logger.setLevel(logging.DEBUG)
 
+formatter = logging.Formatter("%(asctime)s - %(filename)s:%(lineno)d - %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
+
 stream_handler = logging.StreamHandler()
 stream_handler.setLevel(logging.DEBUG)
+stream_handler.setFormatter(formatter)
 logger.addHandler(stream_handler)
 
 file_handler = logging.FileHandler("logs.txt")
-file_handler.setFormatter(
-    logging.Formatter("%(asctime)s - %(filename)s:%(lineno)d - %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
-)
+file_handler.setFormatter(formatter)
 logger.addHandler(file_handler)
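The hunk above covers the whole file, so the post-commit src/logger.py can be read straight through: one shared formatter, now applied to both the stream handler (which previously had no formatter) and the file handler.

    import logging

    logger = logging.getLogger("guardian-demo")
    logger.setLevel(logging.DEBUG)

    formatter = logging.Formatter("%(asctime)s - %(filename)s:%(lineno)d - %(message)s", datefmt="%Y-%m-%d %H:%M:%S")

    stream_handler = logging.StreamHandler()
    stream_handler.setLevel(logging.DEBUG)
    stream_handler.setFormatter(formatter)
    logger.addHandler(stream_handler)

    file_handler = logging.FileHandler("logs.txt")
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)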
src/model.py  CHANGED
@@ -13,15 +13,15 @@ safe_token = "No"
 risky_token = "Yes"
 nlogprobs = 20
 
-inference_engine = os.getenv("INFERENCE_ENGINE", "
+inference_engine = os.getenv("INFERENCE_ENGINE", "TORCH")
 logger.debug(f"Inference engine is: '{inference_engine}'")
 
-if inference_engine == "
+if inference_engine == "TORCH":
     import torch
 
     device = torch.device("cpu")
 
-    model_path = os.getenv("MODEL_PATH", "ibm-granite/granite-guardian-3.1-
+    model_path = os.getenv("MODEL_PATH", "ibm-granite/granite-guardian-3.1-8b")
     logger.debug(f"model_path is {model_path}")
     tokenizer = AutoTokenizer.from_pretrained(model_path)
     model = AutoModelForCausalLM.from_pretrained(model_path)
@@ -160,7 +160,6 @@ def get_prompt(messages, criteria_name, tokenize=False, add_generation_prompt=False
 
 @spaces.GPU
 def get_guardian_response(messages, criteria_name):
-    logger.debug(f"Messages used to create the prompt are: \n{messages}")
     start = time()
     if inference_engine == "MOCK":
         logger.debug("Returning mocked model result.")
@@ -173,7 +172,7 @@ def get_guardian_response(messages, criteria_name):
         generated_tokens = generate_tokens_watsonx(chat)
         label, prob_of_risk = parse_output_watsonx(generated_tokens)
 
-    elif inference_engine == "
+    elif inference_engine == "TORCH":
         input_ids = get_prompt(
             messages=messages,
             criteria_name=criteria_name,
@@ -181,7 +180,7 @@ def get_guardian_response(messages, criteria_name):
             add_generation_prompt=True,
             return_tensors="pt",
         ).to(model.device)
-        logger.debug(f"input_ids are: {input_ids}")
+        # logger.debug(f"input_ids are: {input_ids}")
         input_len = input_ids.shape[1]
         logger.debug(f"input_len is: {input_len}")
 
@@ -194,16 +193,16 @@ def get_guardian_response(messages, criteria_name):
             return_dict_in_generate=True,
             output_scores=True,
         )
-        logger.debug(f"model output is:\n{output}")
+        # logger.debug(f"model output is:\n{output}")
 
         label, prob_of_risk = parse_output(output, input_len)
-        logger.debug(f"
-        logger.debug(f"
+        logger.debug(f"Label is: {label}")
+        logger.debug(f"Prob_of_risk is: {prob_of_risk}")
     else:
-        raise Exception("Environment variable 'INFERENCE_ENGINE' must be one of [WATSONX, MOCK,
+        raise Exception("Environment variable 'INFERENCE_ENGINE' must be one of [WATSONX, MOCK, TORCH]")
 
-    logger.debug(f"Model generated label:
-    logger.debug(f"Model prob_of_risk:
+    logger.debug(f"Model generated label: {label}")
+    logger.debug(f"Model prob_of_risk: {prob_of_risk}")
 
     end = time()
     total = end - start
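parse_output itself is not touched by this commit and its body is not shown here. Purely as an illustration of how a label and prob_of_risk are commonly derived from output_scores for a Granite Guardian-style "Yes"/"No" classifier, a hedged sketch; the function name, signature, and the single-token assumption for "Yes"/"No" are all assumptions, not the repository's actual code:

    import torch

    safe_token = "No"     # as defined at module level in src/model.py
    risky_token = "Yes"

    def parse_output_sketch(output, tokenizer):
        # Logits for the first generated position: shape (batch, vocab_size).
        first_step_logits = output.scores[0][0]
        probs = torch.softmax(first_step_logits, dim=-1)

        # Assumes "No" and "Yes" each map to a single vocabulary id.
        prob_safe = probs[tokenizer.convert_tokens_to_ids(safe_token)].item()
        prob_risky = probs[tokenizer.convert_tokens_to_ids(risky_token)].item()

        label = risky_token if prob_risky > prob_safe else safe_token
        prob_of_risk = prob_risky / (prob_safe + prob_risky)
        return label, prob_of_risk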
src/utils.py  CHANGED
@@ -1,4 +1,5 @@
 import argparse
+import json
 import os
 
 
@@ -25,9 +26,9 @@ def get_messages(test_case, sub_catalog_name) -> list[dict[str, str]]:
         messages += create_message("user", test_case["user_message"])
         messages += create_message("assistant", test_case["assistant_message"])
     elif sub_catalog_name == "risks_in_agentic_workflows":
-        messages += create_message("tools", test_case["tools"])
+        messages += create_message("tools", json.loads(test_case["tools"]))
         messages += create_message("user", test_case["user_message"])
-        messages += create_message("assistant", test_case["assistant_message"])
+        messages += create_message("assistant", json.loads(test_case["assistant_message"]))
     return messages
 
 
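In the agentic-workflows branch, the test-case fields "tools" and "assistant_message" are evidently stored as JSON strings and are now decoded before being turned into chat messages. A small self-contained sketch of the updated flow; create_message is not shown in this commit, so the stand-in below only assumes it returns a list (the caller uses messages += create_message(...)):

    import json

    def create_message(role, content):  # hypothetical stand-in for the real helper
        return [{"role": role, "content": content}]

    def get_messages_sketch(test_case):
        messages = []
        messages += create_message("tools", json.loads(test_case["tools"]))
        messages += create_message("user", test_case["user_message"])
        messages += create_message("assistant", json.loads(test_case["assistant_message"]))
        return messages

    # Example with a made-up agentic test case:
    example = {
        "tools": '[{"name": "get_weather", "parameters": {"city": "string"}}]',
        "user_message": "What is the weather in Paris?",
        "assistant_message": '[{"name": "get_weather", "arguments": {"city": "Paris"}}]',
    }
    print(get_messages_sketch(example))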