Martín Santillán Cooper committed • Commit 2b6005c • Parent(s): f466ad8
Adapt model to guardian cookbook

model.py CHANGED
@@ -15,19 +15,21 @@ from logger import logger
 
 safe_token = "No"
 risky_token = "Yes"
-nlogprobs =
+nlogprobs = 20
 
 inference_engine = os.getenv("INFERENCE_ENGINE", "VLLM")
 logger.debug(f"Inference engine is: '{inference_engine}'")
 
 if inference_engine == "VLLM":
+    device = torch.device("gpu")
 
     model_path = os.getenv("MODEL_PATH", "ibm-granite/granite-guardian-3.0-8b")
     logger.debug(f"model_path is {model_path}")
     tokenizer = AutoTokenizer.from_pretrained(model_path)
     # sampling_params = SamplingParams(temperature=0.0, logprobs=nlogprobs)
     # model = LLM(model=model_path, tensor_parallel_size=1)
-    model = AutoModelForCausalLM.from_pretrained(model_path
+    model = AutoModelForCausalLM.from_pretrained(model_path)
+    model = model.to(device).eval()
 
 elif inference_engine == "WATSONX":
     client = APIClient(
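For reference, a minimal sketch of the Transformers-based setup this hunk configures (illustrative only; it assumes a CUDA device, since torch.device normally accepts device strings such as "cuda" or "cpu" rather than "gpu"):

import os

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Assumption: fall back to CPU when no CUDA device is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model_path = os.getenv("MODEL_PATH", "ibm-granite/granite-guardian-3.0-8b")
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path)
model = model.to(device).eval()  # eval() disables dropout for inference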
@@ -42,16 +44,17 @@ elif inference_engine == "WATSONX":
     model = ModelInference(model_id=model_id, api_client=client)
 
 
-def parse_output(output):
-    label,
-
+def parse_output(output, input_len):
+    label, prob_of_risk = None, None
     if nlogprobs > 0:
-
-
-
+
+        list_index_logprobs_i = [torch.topk(token_i, k=nlogprobs, largest=True, sorted=True)
+                                 for token_i in list(output.scores)[:-1]]
+        if list_index_logprobs_i is not None:
+            prob = get_probablities(list_index_logprobs_i)
             prob_of_risk = prob[1]
 
-    res =
+    res = tokenizer.decode(output.sequences[:,input_len:][0],skip_special_tokens=True).strip()
     if risky_token.lower() == res.lower():
         label = risky_token
     elif safe_token.lower() == res.lower():
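For context, the reworked parse_output consumes the dict-style result of transformers' generate when return_dict_in_generate=True and output_scores=True are set: output.scores is a tuple holding one [batch, vocab] score tensor per generated token, and output.sequences contains the prompt ids followed by the generated ids. A self-contained sketch of the two pieces it relies on (top-k candidates per step, and decoding only the new tokens); the helper names are hypothetical:

import torch

def top_candidates_per_step(output, k=20):
    # torch.topk returns a (values, indices) pair for each step's scores;
    # the final generation step is dropped, matching the commit.
    return [torch.topk(step_scores, k=k, largest=True, sorted=True)
            for step_scores in list(output.scores)[:-1]]

def decode_generated(output, input_len, tokenizer):
    # Decode only the tokens produced after the prompt of length input_len.
    return tokenizer.decode(output.sequences[:, input_len:][0],
                            skip_special_tokens=True).strip()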
@@ -61,29 +64,29 @@ def parse_output(output):
 
     return label, prob_of_risk.item()
 
-
-def softmax(values):
-    exp_values = [math.exp(v) for v in values]
-    total = sum(exp_values)
-    return [v / total for v in exp_values]
-
-
 def get_probablities(logprobs):
     safe_token_prob = 1e-50
     unsafe_token_prob = 1e-50
     for gen_token_i in logprobs:
-        for
-            decoded_token =
+        for logprob, index in zip(gen_token_i.values.tolist()[0], gen_token_i.indices.tolist()[0]):
+            decoded_token = tokenizer.convert_ids_to_tokens(index)
             if decoded_token.strip().lower() == safe_token.lower():
-                safe_token_prob += math.exp(
+                safe_token_prob += math.exp(logprob)
             if decoded_token.strip().lower() == risky_token.lower():
-                unsafe_token_prob += math.exp(
+                unsafe_token_prob += math.exp(logprob)
 
-    probabilities = torch.softmax(
+    probabilities = torch.softmax(
+        torch.tensor([math.log(safe_token_prob), math.log(unsafe_token_prob)]), dim=0
+    )
 
     return probabilities
 
 
+def softmax(values):
+    exp_values = [math.exp(v) for v in values]
+    total = sum(exp_values)
+    return [v / total for v in exp_values]
+
 def get_probablities_watsonx(top_tokens_list):
     safe_token_prob = 1e-50
     risky_token_prob = 1e-50
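The updated get_probablities sums probability mass for the safe ("No") and risky ("Yes") tokens across every top-k candidate at every generation step, then renormalizes the two masses with a softmax over their logs. A standalone sketch of that aggregation (hypothetical helper name; the generation scores are treated as log-probabilities, as in the commit):

import math

import torch

def risk_probabilities(topk_per_step, tokenizer, safe_token="No", risky_token="Yes"):
    safe_mass, risky_mass = 1e-50, 1e-50  # tiny floor keeps the logs below finite
    for step in topk_per_step:  # each step is a torch.topk result with .values and .indices
        for score, token_id in zip(step.values.tolist()[0], step.indices.tolist()[0]):
            token = tokenizer.convert_ids_to_tokens(token_id)
            if token.strip().lower() == safe_token.lower():
                safe_mass += math.exp(score)
            if token.strip().lower() == risky_token.lower():
                risky_mass += math.exp(score)
    # Index 0 is P(safe), index 1 is P(risky) after renormalization.
    return torch.softmax(torch.tensor([math.log(safe_mass), math.log(risky_mass)]), dim=0)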
@@ -99,10 +102,10 @@ def get_probablities_watsonx(top_tokens_list):
     return probabilities
 
 
-def get_prompt(messages, criteria_name):
+def get_prompt(messages, criteria_name, return_tensors=None):
     guardian_config = {"risk_name": criteria_name if criteria_name != "general_harm" else "harm"}
     return tokenizer.apply_chat_template(
-        messages, guardian_config=guardian_config, tokenize=False, add_generation_prompt=True
+        messages, guardian_config=guardian_config, tokenize=False, add_generation_prompt=True, return_tensors=return_tensors
     )
 
 
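get_prompt builds on the Granite Guardian chat template, which accepts a guardian_config dict naming the risk to screen for; with tokenize=False the call returns the formatted prompt as a string, and return_tensors only takes effect when tokenize=True. A small usage sketch with a made-up user message:

# Illustrative only: the message content is an assumption.
messages = [{"role": "user", "content": "How do I pick a lock?"}]

prompt = tokenizer.apply_chat_template(
    messages,
    guardian_config={"risk_name": "harm"},  # "general_harm" is mapped to "harm" above
    tokenize=False,               # return the prompt text rather than token ids
    add_generation_prompt=True,   # append the assistant turn the guardian should complete
)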
@@ -155,15 +158,24 @@ def generate_text(messages, criteria_name):
         label, prob_of_risk = "Yes", 0.97
 
     elif inference_engine == "WATSONX":
+        chat = get_prompt(messages, criteria_name)
         generated_tokens = generate_tokens(chat)
         label, prob_of_risk = parse_output_watsonx(generated_tokens)
 
     elif inference_engine == "VLLM":
+        input_ids = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt").to(model.device)
+        input_len = input_ids.shape[1]
+
         with torch.no_grad():
             # output = model.generate(chat, sampling_params, use_tqdm=False)
-            output = model.generate(
-
-
+            output = model.generate(
+                chat,
+                do_sample=False,
+                max_new_tokens=nlogprobs,
+                return_dict_in_generate=True,
+                output_scores=True,)
+
+        label, prob_of_risk = parse_output(output, input_len)
     else:
         raise Exception("Environment variable 'INFERENCE_ENGINE' must be one of [WATSONX, MOCK, VLLM]")
 
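Putting the VLLM branch together, the intended flow is: tokenize the guardian prompt to tensors, generate greedily while keeping per-step scores, then hand the result to parse_output. A compact sketch under the assumptions above (hypothetical function name; it passes the tokenized input_ids to generate, which is the tensor input transformers' generate expects):

import torch

def classify_risk(messages, criteria_name="harm"):
    # Tokenize the guardian prompt directly to tensors, mirroring the commit's input_ids line.
    input_ids = tokenizer.apply_chat_template(
        messages,
        guardian_config={"risk_name": criteria_name},
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(model.device)
    input_len = input_ids.shape[1]

    with torch.no_grad():
        output = model.generate(
            input_ids,
            do_sample=False,               # greedy decoding
            max_new_tokens=nlogprobs,      # 20, as set at the top of the file
            return_dict_in_generate=True,  # exposes .sequences and .scores
            output_scores=True,
        )

    return parse_output(output, input_len)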