# Last commit by grahamwhiteuk: "fix: remove sampling params" (bce909e, verified)
# Hugging Face file-viewer residue: raw | history | blame | 5.75 kB
import math
import os
from time import sleep, time
import spaces
import torch
from ibm_watsonx_ai.client import APIClient
from ibm_watsonx_ai.foundation_models import ModelInference
from transformers import AutoModelForCausalLM, AutoTokenizer
from logger import logger
# from vllm import LLM, SamplingParams
# Answers the model is expected to produce; parse_output* compare the generated
# text against these (case-insensitively) to pick the label.
safe_token = "No"
risky_token = "Yes"
# Number of top candidate tokens (with log-probabilities) requested per generated
# position; 0 disables probability-of-risk computation.
nlogprobs = 5

inference_engine = os.getenv("INFERENCE_ENGINE", "VLLM")
logger.debug(f"Inference engine is: '{inference_engine}'")

if inference_engine == "VLLM":
    # NOTE: despite the engine name, this path runs the model locally through
    # Hugging Face transformers (AutoModelForCausalLM), not through vLLM.
    model_path = os.getenv("MODEL_PATH", "ibm-granite/granite-guardian-3.0-8b")
    logger.debug(f"model_path is {model_path}")
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16, device_map="auto")
elif inference_engine == "WATSONX":
    # The service URL may be overridden via WATSONX_URL; the default is the
    # original us-south endpoint.
    client = APIClient(
        credentials={
            "api_key": os.getenv("WATSONX_API_KEY"),
            "url": os.getenv("WATSONX_URL", "https://us-south.ml.cloud.ibm.com"),
        }
    )
    client.set.default_project(os.getenv("WATSONX_PROJECT_ID"))
    # The tokenizer is still loaded locally: it renders the chat template into
    # the prompt string that is sent to watsonx.
    hf_model_path = "ibm-granite/granite-guardian-3.0-8b"
    tokenizer = AutoTokenizer.from_pretrained(hf_model_path)
    model_id = "ibm/granite-guardian-3-8b"
    model = ModelInference(model_id=model_id, api_client=client)
# NOTE(review): for "MOCK" (or any other value) neither `tokenizer` nor `model`
# is defined at module level.
def parse_output(output):
    """Extract the (label, probability of risk) pair from a vLLM-style result.

    Args:
        output: an object whose ``outputs`` iterable yields, first, a completion
            exposing ``text`` (the generated string) and ``logprobs`` (``None`` or
            a list of per-step mappings whose values expose ``decoded_token`` and
            ``logprob``).

    Returns:
        ``(label, prob_of_risk)``. ``label`` is ``risky_token`` or ``safe_token``
        when the stripped text matches one of them case-insensitively, otherwise
        ``"Failed"``. ``prob_of_risk`` is a float, or ``None`` when ``nlogprobs``
        is 0 or the completion carries no log-probabilities.
    """
    prob_of_risk = None
    completion = next(iter(output.outputs))
    # Only compute a probability when log-probabilities are actually available;
    # otherwise prob_of_risk stays None instead of being left unbound.
    if nlogprobs > 0 and completion.logprobs is not None:
        # get_probablities returns a torch tensor [p_safe, p_risky].
        prob_of_risk = get_probablities(completion.logprobs)[1].item()

    answer = completion.text.strip().lower()
    if answer == risky_token.lower():
        label = risky_token
    elif answer == safe_token.lower():
        label = safe_token
    else:
        label = "Failed"
    return label, prob_of_risk
def softmax(values):
    """Return the softmax of a sequence of real numbers as a list of floats.

    The maximum value is subtracted before exponentiating. This leaves the
    result mathematically unchanged but prevents ``math.exp`` from overflowing
    on large inputs. It also keeps every weight from underflowing to zero on
    very negative inputs, which would otherwise divide by zero.

    Args:
        values: iterable of real numbers.

    Returns:
        A list of non-negative floats summing to 1, in input order; an empty
        list for empty input.
    """
    values = list(values)
    if not values:
        return []
    peak = max(values)
    exp_values = [math.exp(v - peak) for v in values]
    total = sum(exp_values)
    return [v / total for v in exp_values]
def get_probablities(logprobs):
    """Return normalised ``[P(safe), P(risky)]`` from per-step top-token log-probs.

    Each candidate whose decoded text (stripped, case-insensitive) equals
    ``safe_token`` or ``risky_token`` adds ``exp(logprob)`` to that label's
    mass, summed over every generation step. Both masses start at 1e-50 so the
    logarithm below is always defined. The two log-masses are then passed
    through a softmax.

    Args:
        logprobs: iterable of per-step mappings whose values expose
            ``decoded_token`` (str) and ``logprob`` (float).

    Returns:
        A 2-element torch tensor ``[p_safe, p_risky]``.
    """
    safe_word = safe_token.lower()
    risky_word = risky_token.lower()
    safe_mass = 1e-50
    risky_mass = 1e-50
    for step in logprobs:
        for candidate in step.values():
            word = candidate.decoded_token.strip().lower()
            if word == safe_word:
                safe_mass += math.exp(candidate.logprob)
            if word == risky_word:
                risky_mass += math.exp(candidate.logprob)
    log_masses = torch.tensor([math.log(safe_mass), math.log(risky_mass)])
    return torch.softmax(log_masses, dim=0)
def get_probablities_watsonx(top_tokens_list):
    """Return normalised ``[P(safe), P(risky)]`` from watsonx top-token records.

    Mirrors ``get_probablities`` for the watsonx response format. Each
    candidate dict whose ``"text"`` (stripped, case-insensitive) equals
    ``safe_token`` or ``risky_token`` adds ``exp(candidate["logprob"])`` to
    that label's mass, starting from 1e-50.

    Args:
        top_tokens_list: iterable (one entry per generated position) of
            iterables of dicts with ``"text"`` and ``"logprob"`` keys.

    Returns:
        A list ``[p_safe, p_risky]`` produced by ``softmax``.
    """
    safe_word, risky_word = safe_token.lower(), risky_token.lower()
    masses = [1e-50, 1e-50]  # index 0: safe, index 1: risky
    for position_candidates in top_tokens_list:
        for candidate in position_candidates:
            word = candidate["text"].strip().lower()
            if word == safe_word:
                masses[0] += math.exp(candidate["logprob"])
            if word == risky_word:
                masses[1] += math.exp(candidate["logprob"])
    return softmax([math.log(mass) for mass in masses])
def get_prompt(messages, criteria_name):
    """Render ``messages`` into a prompt string for the given risk criterion.

    The criterion ``"general_harm"`` is handed to the chat template under the
    risk name ``"harm"``; any other criterion name is passed through unchanged.
    The returned text is untokenized and ends with the generation prompt.
    """
    risk_name = "harm" if criteria_name == "general_harm" else criteria_name
    return tokenizer.apply_chat_template(
        messages,
        guardian_config={"risk_name": risk_name},
        tokenize=False,
        add_generation_prompt=True,
    )
@spaces.GPU
def generate_tokens(prompt):
    """Send ``prompt`` to the watsonx model and return its generated-token records.

    Decoding is greedy and capped at 20 new tokens. The request asks for
    per-token log-probabilities and, at each position, the top ``nlogprobs``
    alternative tokens; ``parse_output_watsonx`` reads those ``"top_tokens"``
    only when ``nlogprobs > 0``, so both are driven by the same constant.

    Args:
        prompt: the fully rendered prompt string (see ``get_prompt``).

    Returns:
        The ``"generated_tokens"`` list of the first result in the first response.
    """
    # NOTE(review): this is a remote API call; the GPU reservation requested by
    # @spaces.GPU is presumably unnecessary here — confirm before removing.
    result = model.generate(
        prompt=[prompt],
        params={
            "decoding_method": "greedy",
            "max_new_tokens": 20,
            "temperature": 0,
            "return_options": {
                "token_logprobs": True,
                "generated_tokens": True,
                "input_text": True,
                "top_n_tokens": nlogprobs,
            },
        },
    )
    # A list prompt yields a list of responses; take the first response's first result.
    return result[0]["results"][0]["generated_tokens"]
def parse_output_watsonx(generated_tokens_list):
    """Extract the (label, probability of risk) pair from watsonx token records.

    Args:
        generated_tokens_list: list of per-token dicts as returned by
            ``generate_tokens``. Each dict has a ``"text"`` key and, presumably
            when top tokens were requested, a ``"top_tokens"`` list of
            ``{"text", "logprob"}`` dicts.

    Returns:
        ``(label, prob_of_risk)``. ``label`` is ``risky_token`` or
        ``safe_token`` when the first generated token's stripped text matches
        one of them case-insensitively, else ``"Failed"``. ``prob_of_risk`` is
        ``None`` when ``nlogprobs`` is 0.
    """
    prob_of_risk = None
    if nlogprobs > 0:
        per_position = [record["top_tokens"] for record in generated_tokens_list]
        prob_of_risk = get_probablities_watsonx(per_position)[1]

    first_text = next(iter(generated_tokens_list))["text"].strip().lower()
    # Risky is inserted last so it wins if both tokens ever lowercase alike,
    # matching the risky-first precedence of an if/elif check.
    known_labels = {safe_token.lower(): safe_token, risky_token.lower(): risky_token}
    label = known_labels.get(first_text, "Failed")
    return label, prob_of_risk
def _generate_local(prompt):
    """Greedily generate with the local transformers model, returning a vLLM-like result.

    The result mimics a vLLM ``RequestOutput``: an object with an ``outputs``
    list whose first item has ``text`` and ``logprobs``. ``logprobs`` is a list
    (one entry per generated step) of ``{token_id: candidate}`` dicts, where each
    candidate exposes ``decoded_token`` and ``logprob``. This lets
    ``parse_output`` consume it unchanged. ``logprobs`` is ``None`` when
    ``nlogprobs`` is 0.
    """
    from types import SimpleNamespace

    # The chat template already rendered any special tokens into the prompt text.
    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(model.device)
    input_len = inputs["input_ids"].shape[1]
    with torch.no_grad():
        generated = model.generate(
            **inputs,
            do_sample=False,
            max_new_tokens=20,
            return_dict_in_generate=True,
            output_scores=True,
        )
    text = tokenizer.decode(generated.sequences[0, input_len:], skip_special_tokens=True)

    step_logprobs = None
    if nlogprobs > 0:
        step_logprobs = []
        for step_scores in generated.scores:
            # `scores` are per-step logits over the vocabulary; normalise them so
            # each candidate's `logprob` is a true log-probability.
            log_probs = torch.log_softmax(step_scores[0].float(), dim=-1)
            top = torch.topk(log_probs, k=nlogprobs)
            step_logprobs.append(
                {
                    token_id: SimpleNamespace(decoded_token=tokenizer.decode([token_id]), logprob=logprob)
                    for logprob, token_id in zip(top.values.tolist(), top.indices.tolist())
                }
            )
    completion = SimpleNamespace(text=text, logprobs=step_logprobs)
    return SimpleNamespace(outputs=[completion])


@spaces.GPU
def generate_text(messages, criteria_name):
    """Assess ``messages`` against ``criteria_name`` with the configured engine.

    Args:
        messages: chat messages passed to ``get_prompt``.
        criteria_name: risk criterion name passed to ``get_prompt``.

    Returns:
        ``{"assessment": label, "certainty": prob_of_risk}`` where ``label`` is
        ``risky_token``, ``safe_token`` or ``"Failed"``.

    Raises:
        Exception: if ``INFERENCE_ENGINE`` is not one of WATSONX, MOCK, VLLM.
    """
    logger.debug(f"Messages used to create the prompt are: \n{messages}")
    start = time()
    # Validate up front: for unknown engines (and MOCK) no tokenizer is loaded,
    # so rendering the prompt first would fail with an unrelated NameError.
    if inference_engine not in ("MOCK", "WATSONX", "VLLM"):
        raise Exception("Environment variable 'INFERENCE_ENGINE' must be one of [WATSONX, MOCK, VLLM]")

    if inference_engine == "MOCK":
        logger.debug("Returning mocked model result.")
        sleep(1)
        label, prob_of_risk = "Yes", 0.97
    else:
        chat = get_prompt(messages, criteria_name)
        logger.debug(f"Prompt is \n{chat}")
        if inference_engine == "WATSONX":
            generated_tokens = generate_tokens(chat)
            label, prob_of_risk = parse_output_watsonx(generated_tokens)
        else:  # "VLLM": local transformers model
            label, prob_of_risk = parse_output(_generate_local(chat))

    logger.debug(f"Model generated label: \n{label}")
    logger.debug(f"Model prob_of_risk: \n{prob_of_risk}")
    end = time()
    total = end - start
    logger.debug(f"The evaluation took {total} secs")
    return {"assessment": label, "certainty": prob_of_risk}