Spaces:
Running
on
Zero
Running
on
Zero
Martín Santillán Cooper
committed on
Commit
•
51c0b7a
1
Parent(s):
42e2cdc
fix: guardian_config is not passed
Browse files
model.py
CHANGED
@@ -105,7 +105,7 @@ def get_probablities_watsonx(top_tokens_list):
|
|
105 |
def get_prompt(messages, criteria_name, return_tensors=None):
|
106 |
guardian_config = {"risk_name": criteria_name if criteria_name != "general_harm" else "harm"}
|
107 |
return tokenizer.apply_chat_template(
|
108 |
-
messages, guardian_config=guardian_config, tokenize=False, add_generation_prompt=
|
109 |
)
|
110 |
|
111 |
|
@@ -159,7 +159,7 @@ def generate_text(messages, criteria_name):
|
|
159 |
label, prob_of_risk = parse_output_watsonx(generated_tokens)
|
160 |
|
161 |
elif inference_engine == "VLLM":
|
162 |
-
input_ids =
|
163 |
logger.debug(f"input_ids are: {input_ids}")
|
164 |
input_len = input_ids.shape[1]
|
165 |
logger.debug(f"input_len are: {input_len}")
|
@@ -172,7 +172,7 @@ def generate_text(messages, criteria_name):
|
|
172 |
max_new_tokens=nlogprobs,
|
173 |
return_dict_in_generate=True,
|
174 |
output_scores=True,)
|
175 |
-
logger.debug(f"model output is
|
176 |
|
177 |
label, prob_of_risk = parse_output(output, input_len)
|
178 |
logger.debug(f"label is are: {label}")
|
|
|
105 |
def get_prompt(messages, criteria_name, return_tensors=None):
|
106 |
guardian_config = {"risk_name": criteria_name if criteria_name != "general_harm" else "harm"}
|
107 |
return tokenizer.apply_chat_template(
|
108 |
+
messages, guardian_config=guardian_config, tokenize=False, add_generation_prompt=return_tensors is not None, return_tensors=return_tensors
|
109 |
)
|
110 |
|
111 |
|
|
|
159 |
label, prob_of_risk = parse_output_watsonx(generated_tokens)
|
160 |
|
161 |
elif inference_engine == "VLLM":
|
162 |
+
input_ids = get_prompt(messages=messages, criteria_name=criteria_name, return_tensors="pt").to(model.device)
|
163 |
logger.debug(f"input_ids are: {input_ids}")
|
164 |
input_len = input_ids.shape[1]
|
165 |
logger.debug(f"input_len are: {input_len}")
|
|
|
172 |
max_new_tokens=nlogprobs,
|
173 |
return_dict_in_generate=True,
|
174 |
output_scores=True,)
|
175 |
+
logger.debug(f"model output is:\n{output}")
|
176 |
|
177 |
label, prob_of_risk = parse_output(output, input_len)
|
178 |
logger.debug(f"label is are: {label}")
|