Martín Santillán Cooper committed
Commit 51c0b7a
1 Parent(s): 42e2cdc

fix: guardian_config is not passed

Files changed (1)
  1. model.py +3 -3
model.py CHANGED
@@ -105,7 +105,7 @@ def get_probablities_watsonx(top_tokens_list):
 def get_prompt(messages, criteria_name, return_tensors=None):
     guardian_config = {"risk_name": criteria_name if criteria_name != "general_harm" else "harm"}
     return tokenizer.apply_chat_template(
-        messages, guardian_config=guardian_config, tokenize=False, add_generation_prompt=True, return_tensors=return_tensors
+        messages, guardian_config=guardian_config, tokenize=False, add_generation_prompt=return_tensors is not None, return_tensors=return_tensors
     )
 
 
@@ -159,7 +159,7 @@ def generate_text(messages, criteria_name):
         label, prob_of_risk = parse_output_watsonx(generated_tokens)
 
     elif inference_engine == "VLLM":
-        input_ids = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt").to(model.device)
+        input_ids = get_prompt(messages=messages, criteria_name=criteria_name, return_tensors="pt").to(model.device)
         logger.debug(f"input_ids are: {input_ids}")
         input_len = input_ids.shape[1]
         logger.debug(f"input_len are: {input_len}")
@@ -172,7 +172,7 @@ def generate_text(messages, criteria_name):
             max_new_tokens=nlogprobs,
             return_dict_in_generate=True,
             output_scores=True,)
-        logger.debug(f"model output is are: {output}")
+        logger.debug(f"model output is:\n{output}")
 
         label, prob_of_risk = parse_output(output, input_len)
         logger.debug(f"label is are: {label}")