Martín Santillán Cooper committed
Commit 2b6005c
1 Parent(s): f466ad8

Adapt model to guardian cookbook

Files changed (1):
  1. model.py +38 -26
model.py CHANGED
@@ -15,19 +15,21 @@ from logger import logger
 
 safe_token = "No"
 risky_token = "Yes"
-nlogprobs = 5
+nlogprobs = 20
 
 inference_engine = os.getenv("INFERENCE_ENGINE", "VLLM")
 logger.debug(f"Inference engine is: '{inference_engine}'")
 
 if inference_engine == "VLLM":
+    device = torch.device("gpu")
 
     model_path = os.getenv("MODEL_PATH", "ibm-granite/granite-guardian-3.0-8b")
     logger.debug(f"model_path is {model_path}")
     tokenizer = AutoTokenizer.from_pretrained(model_path)
     # sampling_params = SamplingParams(temperature=0.0, logprobs=nlogprobs)
     # model = LLM(model=model_path, tensor_parallel_size=1)
-    model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16, device_map="auto")
+    model = AutoModelForCausalLM.from_pretrained(model_path)
+    model = model.to(device).eval()
 
 elif inference_engine == "WATSONX":
     client = APIClient(
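The VLLM branch now loads the checkpoint onto an explicitly chosen device instead of relying on torch_dtype=torch.float16 with device_map="auto". For reference, a minimal sketch of that explicit-device loading pattern is below; it assumes a CUDA GPU and uses the "cuda" device string (standard torch.device does not accept "gpu"), with the model id taken from the diff's default MODEL_PATH.

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    # Explicit-device loading (sketch): torch.device expects "cuda"/"cpu", not "gpu".
    model_path = "ibm-granite/granite-guardian-3.0-8b"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForCausalLM.from_pretrained(model_path)
    model = model.to(device).eval()  # single-device placement, inference mode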
@@ -42,16 +44,17 @@ elif inference_engine == "WATSONX":
42
  model = ModelInference(model_id=model_id, api_client=client)
43
 
44
 
45
- def parse_output(output):
46
- label, prob = None, None
47
-
48
  if nlogprobs > 0:
49
- logprobs = next(iter(output.outputs)).logprobs
50
- if logprobs is not None:
51
- prob = get_probablities(logprobs)
 
 
52
  prob_of_risk = prob[1]
53
 
54
- res = next(iter(output.outputs)).text.strip()
55
  if risky_token.lower() == res.lower():
56
  label = risky_token
57
  elif safe_token.lower() == res.lower():
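parse_output no longer reads vLLM-style output.outputs logprobs; it works on the per-step score tensors that transformers' generate() returns when output_scores=True, keeping the top nlogprobs candidates per step via torch.topk. A small self-contained sketch of that pattern, using made-up score tensors rather than real model output:

    import torch

    nlogprobs = 5
    # Stand-in for output.scores: one (batch, vocab_size) tensor per generated token.
    scores = [torch.randn(1, 32) for _ in range(3)]

    # Top-k scores and token ids per generation step (last step dropped, as in the diff).
    topk_per_step = [torch.topk(step, k=nlogprobs, largest=True, sorted=True)
                     for step in scores[:-1]]

    for step in topk_per_step:
        # step.values / step.indices are (1, k) tensors; the ids can be decoded
        # with tokenizer.convert_ids_to_tokens(...) as get_probablities does.
        print(step.indices.tolist()[0], step.values.tolist()[0])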
@@ -61,29 +64,29 @@ def parse_output(output):
 
     return label, prob_of_risk.item()
 
-
-def softmax(values):
-    exp_values = [math.exp(v) for v in values]
-    total = sum(exp_values)
-    return [v / total for v in exp_values]
-
-
 def get_probablities(logprobs):
     safe_token_prob = 1e-50
     unsafe_token_prob = 1e-50
     for gen_token_i in logprobs:
-        for token_prob in gen_token_i.values():
-            decoded_token = token_prob.decoded_token
+        for logprob, index in zip(gen_token_i.values.tolist()[0], gen_token_i.indices.tolist()[0]):
+            decoded_token = tokenizer.convert_ids_to_tokens(index)
             if decoded_token.strip().lower() == safe_token.lower():
-                safe_token_prob += math.exp(token_prob.logprob)
+                safe_token_prob += math.exp(logprob)
             if decoded_token.strip().lower() == risky_token.lower():
-                unsafe_token_prob += math.exp(token_prob.logprob)
+                unsafe_token_prob += math.exp(logprob)
 
-    probabilities = torch.softmax(torch.tensor([math.log(safe_token_prob), math.log(unsafe_token_prob)]), dim=0)
+    probabilities = torch.softmax(
+        torch.tensor([math.log(safe_token_prob), math.log(unsafe_token_prob)]), dim=0
+    )
 
     return probabilities
 
 
+def softmax(values):
+    exp_values = [math.exp(v) for v in values]
+    total = sum(exp_values)
+    return [v / total for v in exp_values]
+
 def get_probablities_watsonx(top_tokens_list):
     safe_token_prob = 1e-50
     risky_token_prob = 1e-50
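The scoring rule itself is unchanged in spirit: the exponentiated log-probabilities observed for the safe ("No") and risky ("Yes") tokens are accumulated, then renormalized with a two-way softmax over their logs, so prob_of_risk is the risky share of the mass assigned to either label. A tiny numeric sketch with made-up log-probabilities:

    import math
    import torch

    # Made-up log-probabilities for the "No" (safe) and "Yes" (risky) tokens.
    safe_token_prob = 1e-50 + math.exp(-0.2)    # ~0.819
    unsafe_token_prob = 1e-50 + math.exp(-1.7)  # ~0.183

    probabilities = torch.softmax(
        torch.tensor([math.log(safe_token_prob), math.log(unsafe_token_prob)]), dim=0
    )
    prob_of_risk = probabilities[1]  # ~0.18 for these numbers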
@@ -99,10 +102,10 @@ def get_probablities_watsonx(top_tokens_list):
     return probabilities
 
 
-def get_prompt(messages, criteria_name):
+def get_prompt(messages, criteria_name, return_tensors=None):
    guardian_config = {"risk_name": criteria_name if criteria_name != "general_harm" else "harm"}
     return tokenizer.apply_chat_template(
-        messages, guardian_config=guardian_config, tokenize=False, add_generation_prompt=True
+        messages, guardian_config=guardian_config, tokenize=False, add_generation_prompt=True, return_tensors=return_tensors
     )
 
 
@@ -155,15 +158,24 @@ def generate_text(messages, criteria_name):
         label, prob_of_risk = "Yes", 0.97
 
     elif inference_engine == "WATSONX":
+        chat = get_prompt(messages, criteria_name)
         generated_tokens = generate_tokens(chat)
         label, prob_of_risk = parse_output_watsonx(generated_tokens)
 
     elif inference_engine == "VLLM":
+        input_ids = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt").to(model.device)
+        input_len = input_ids.shape[1]
+
         with torch.no_grad():
             # output = model.generate(chat, sampling_params, use_tqdm=False)
-            output = model.generate(chat)
-
-        label, prob_of_risk = parse_output(output[0])
+            output = model.generate(
+                chat,
+                do_sample=False,
+                max_new_tokens=nlogprobs,
+                return_dict_in_generate=True,
+                output_scores=True,)
+
+        label, prob_of_risk = parse_output(output, input_len)
     else:
         raise Exception("Environment variable 'INFERENCE_ENGINE' must be one of [WATSONX, MOCK, VLLM]")
 
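Putting the pieces together, the updated VLLM branch follows the standard transformers pattern for obtaining per-step scores alongside the generated label: build the Guardian prompt with apply_chat_template (the guardian_config kwarg is forwarded to the chat template, as get_prompt does above), tokenize it to input_ids, call generate() with return_dict_in_generate=True and output_scores=True, and decode only the tokens after input_len. A minimal sketch of that flow, with a made-up user message and the diff's default MODEL_PATH:

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_path = "ibm-granite/granite-guardian-3.0-8b"
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto").eval()

    messages = [{"role": "user", "content": "Example user turn to screen."}]
    input_ids = tokenizer.apply_chat_template(
        messages, guardian_config={"risk_name": "harm"},
        add_generation_prompt=True, return_tensors="pt"
    ).to(model.device)
    input_len = input_ids.shape[1]

    with torch.no_grad():
        output = model.generate(
            input_ids,                     # token ids are passed to generate()
            do_sample=False,
            max_new_tokens=20,
            return_dict_in_generate=True,  # gives .sequences and .scores
            output_scores=True,
        )

    # Decode only the newly generated tokens ("Yes"/"No" for Granite Guardian).
    label = tokenizer.decode(output.sequences[:, input_len:][0], skip_special_tokens=True).strip()
    print(label)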