Spaces:
Running
on
Zero
Running
on
Zero
import math | |
import os | |
from time import sleep, time | |
import spaces | |
import torch | |
from ibm_watsonx_ai.client import APIClient | |
from ibm_watsonx_ai.foundation_models import ModelInference | |
from transformers import AutoModelForCausalLM, AutoTokenizer | |
from logger import logger | |
# from vllm import LLM, SamplingParams | |
# Verdict tokens that the decoded model output is compared against.
safe_token = "No"
risky_token = "Yes"
# Top-k size used when reading per-step token scores; generate_text also
# reuses it as the cap on newly generated tokens for local inference.
nlogprobs = 20

# Backend selector; generate_text accepts "VLLM", "WATSONX" and "MOCK".
inference_engine = os.getenv("INFERENCE_ENGINE", "VLLM")
logger.debug(f"Inference engine is: '{inference_engine}'")

if inference_engine == "VLLM":
    # Despite the engine name, the vLLM calls below are commented out and the
    # model is loaded and run through transformers on a CUDA device.
    device = torch.device("cuda")
    model_path = os.getenv("MODEL_PATH", "ibm-granite/granite-guardian-3.0-8b")
    logger.debug(f"model_path is {model_path}")
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    # sampling_params = SamplingParams(temperature=0.0, logprobs=nlogprobs)
    # model = LLM(model=model_path, tensor_parallel_size=1)
    model = AutoModelForCausalLM.from_pretrained(model_path).to(device).eval()
elif inference_engine == "WATSONX":
    # Remote inference: credentials and project id come from the environment.
    client = APIClient(
        credentials={
            "api_key": os.getenv("WATSONX_API_KEY"),
            "url": "https://us-south.ml.cloud.ibm.com",
        }
    )
    client.set.default_project(os.getenv("WATSONX_PROJECT_ID"))

    # The Hugging Face tokenizer is still needed locally to render the prompt.
    hf_model_path = "ibm-granite/granite-guardian-3.0-8b"
    tokenizer = AutoTokenizer.from_pretrained(hf_model_path)

    # watsonx identifier of the 8B guardian model.
    model_id = "ibm/granite-guardian-3-8b"
    model = ModelInference(model_id=model_id, api_client=client)
def parse_output(output, input_len):
    """Turn a transformers ``generate`` result into a verdict and risk probability.

    Args:
        output: Object returned by ``model.generate`` with
            ``return_dict_in_generate=True`` and ``output_scores=True``; its
            ``scores`` and ``sequences`` attributes are read.
        input_len: Number of leading prompt tokens in ``output.sequences`` to
            skip before decoding the generated text.

    Returns:
        A ``(label, prob_of_risk)`` tuple. ``label`` is ``risky_token``,
        ``safe_token`` or ``"Failed"`` when the decoded text matches neither.
        ``prob_of_risk`` is a Python float, or ``None`` when ``nlogprobs`` is
        not positive (same contract as ``parse_output_watsonx``).
    """
    prob_of_risk = None
    if nlogprobs > 0:
        # Keep the top-k candidates of every generation step except the last
        # one, whose scores are deliberately left out.
        top_candidates = [
            torch.topk(step_scores, k=nlogprobs, largest=True, sorted=True)
            for step_scores in list(output.scores)[:-1]
        ]
        # Index 1 is the probability assigned to the risky token; convert the
        # 0-dim tensor here so the no-probability path can simply return None
        # instead of failing on ``None.item()``.
        prob_of_risk = get_probablities(top_candidates)[1].item()

    generated_ids = output.sequences[:, input_len:][0]
    res = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()

    # Case-insensitive match of the whole decoded answer against the verdicts.
    if risky_token.lower() == res.lower():
        label = risky_token
    elif safe_token.lower() == res.lower():
        label = safe_token
    else:
        label = "Failed"

    return label, prob_of_risk
def get_probablities(logprobs):
    """Estimate safe/risky probabilities from per-step top-k candidates.

    Args:
        logprobs: Iterable of ``torch.topk`` results, one per generation step.
            Only the first row (``[0]``) of each ``values``/``indices`` pair is
            read, and the values are exponentiated as if they were
            log-probabilities.

    Returns:
        A 2-element tensor ``[p_safe, p_risky]`` produced by a softmax over the
        logs of the accumulated masses.
    """
    safe_target = safe_token.lower()
    risky_target = risky_token.lower()

    # Tiny starting mass keeps math.log defined when a token never shows up.
    safe_mass = 1e-50
    risky_mass = 1e-50

    for step in logprobs:
        step_values = step.values.tolist()[0]
        step_indices = step.indices.tolist()[0]
        for value, token_id in zip(step_values, step_indices):
            candidate = tokenizer.convert_ids_to_tokens(token_id).strip().lower()
            if candidate == safe_target:
                safe_mass += math.exp(value)
            if candidate == risky_target:
                risky_mass += math.exp(value)

    log_masses = torch.tensor([math.log(safe_mass), math.log(risky_mass)])
    return torch.softmax(log_masses, dim=0)
def softmax(values):
    """Return the softmax of *values* as a list of floats.

    The largest input is subtracted from every element before exponentiating.
    This leaves the result mathematically unchanged but keeps ``math.exp`` from
    overflowing on large inputs.

    Args:
        values: Iterable of real numbers.

    Returns:
        A list with one probability per input, summing to 1; an empty list for
        empty input.
    """
    values = list(values)
    if not values:
        return []
    peak = max(values)
    exp_values = [math.exp(v - peak) for v in values]
    total = sum(exp_values)
    return [v / total for v in exp_values]
def get_probablities_watsonx(top_tokens_list):
    """Estimate safe/risky probabilities from watsonx top-token candidates.

    Args:
        top_tokens_list: Iterable with one entry per generated token; each
            entry is an iterable of dicts whose ``"text"`` and ``"logprob"``
            keys are read.

    Returns:
        ``[p_safe, p_risky]`` as computed by ``softmax`` over the logs of the
        accumulated masses.
    """
    safe_target = safe_token.lower()
    risky_target = risky_token.lower()

    # Tiny starting mass keeps math.log defined when a token never shows up.
    safe_mass = 1e-50
    risky_mass = 1e-50

    every_candidate = (
        candidate for step_candidates in top_tokens_list for candidate in step_candidates
    )
    for candidate in every_candidate:
        text = candidate["text"].strip().lower()
        if text == safe_target:
            safe_mass += math.exp(candidate["logprob"])
        if text == risky_target:
            risky_mass += math.exp(candidate["logprob"])

    return softmax([math.log(safe_mass), math.log(risky_mass)])
def get_prompt(messages, criteria_name, return_tensors=None):
    """Render the guardian prompt for *messages* without tokenizing it.

    Args:
        messages: Chat messages passed to the tokenizer's chat template.
        criteria_name: Risk criterion; ``"general_harm"`` is sent to the
            template as ``"harm"``, any other name is passed through.
        return_tensors: Forwarded unchanged to ``apply_chat_template``.

    Returns:
        Whatever ``tokenizer.apply_chat_template`` returns with
        ``tokenize=False`` and ``add_generation_prompt=True``.
    """
    if criteria_name == "general_harm":
        risk_name = "harm"
    else:
        risk_name = criteria_name

    template_kwargs = {
        "guardian_config": {"risk_name": risk_name},
        "tokenize": False,
        "add_generation_prompt": True,
        "return_tensors": return_tensors,
    }
    return tokenizer.apply_chat_template(messages, **template_kwargs)
def generate_tokens(prompt):
    """Send *prompt* to the watsonx model and return its generated tokens.

    Args:
        prompt: Prompt text; it is submitted as a single-element batch.

    Returns:
        The ``"generated_tokens"`` entry of the first result for the first
        (and only) prompt in the response.
    """
    return_options = {
        "token_logprobs": True,
        "generated_tokens": True,
        "input_text": True,
        "top_n_tokens": 5,
    }
    generation_params = {
        "decoding_method": "greedy",
        "max_new_tokens": 20,
        "temperature": 0,
        "return_options": return_options,
    }
    response = model.generate(prompt=[prompt], params=generation_params)
    first_prompt_response = response[0]
    return first_prompt_response["results"][0]["generated_tokens"]
def parse_output_watsonx(generated_tokens_list):
    """Turn watsonx generated-token records into a verdict and risk probability.

    Args:
        generated_tokens_list: Iterable of per-token dicts; the ``"text"`` of
            the first one and the ``"top_tokens"`` of every one are read.

    Returns:
        A ``(label, prob_of_risk)`` tuple. ``label`` is ``risky_token``,
        ``safe_token`` or ``"Failed"``; ``prob_of_risk`` is ``None`` when
        ``nlogprobs`` is not positive.
    """
    prob_of_risk = None
    if nlogprobs > 0:
        top_tokens_list = []
        for token_record in generated_tokens_list:
            top_tokens_list.append(token_record["top_tokens"])
        # Index 1 is the probability assigned to the risky token.
        prob_of_risk = get_probablities_watsonx(top_tokens_list)[1]

    # Only the first generated token decides the label.
    first_text = next(iter(generated_tokens_list))["text"].strip()
    verdict = first_text.lower()

    if verdict == risky_token.lower():
        return risky_token, prob_of_risk
    if verdict == safe_token.lower():
        return safe_token, prob_of_risk
    return "Failed", prob_of_risk
def generate_text(messages, criteria_name):
    """Assess *messages* against a risk criterion with the configured engine.

    Args:
        messages: Chat messages handed to the tokenizer's chat template.
        criteria_name: Risk criterion to evaluate; ``"general_harm"`` is sent
            to the template as ``"harm"``.

    Returns:
        ``{"assessment": label, "certainty": prob_of_risk}``.

    Raises:
        ValueError: If ``inference_engine`` is not WATSONX, MOCK or VLLM.
    """
    # NOTE(review): `spaces` is imported by this file but no `@spaces.GPU`
    # decorator is visible on this function -- confirm whether one is needed.
    logger.debug(f"Messages used to create the prompt are: \n{messages}")
    start = time()

    if inference_engine == "MOCK":
        logger.debug("Returning mocked model result.")
        sleep(1)
        label, prob_of_risk = "Yes", 0.97
    elif inference_engine == "WATSONX":
        chat = get_prompt(messages, criteria_name)
        logger.debug(f"Prompt is \n{chat}")
        generated_tokens = generate_tokens(chat)
        label, prob_of_risk = parse_output_watsonx(generated_tokens)
    elif inference_engine == "VLLM":
        # The local path previously rendered the prompt without
        # guardian_config, so criteria_name was ignored. Build the same config
        # get_prompt uses so both engines evaluate the requested risk.
        guardian_config = {"risk_name": criteria_name if criteria_name != "general_harm" else "harm"}
        input_ids = tokenizer.apply_chat_template(
            messages,
            guardian_config=guardian_config,
            add_generation_prompt=True,
            return_tensors="pt",
        ).to(model.device)
        logger.debug(f"input_ids are: {input_ids}")
        input_len = input_ids.shape[1]
        logger.debug(f"input_len are: {input_len}")

        with torch.no_grad():
            # Greedy decoding; the new-token cap reuses nlogprobs.
            output = model.generate(
                input_ids,
                do_sample=False,
                max_new_tokens=nlogprobs,
                return_dict_in_generate=True,
                output_scores=True,
            )
        logger.debug(f"model output is are: {output}")

        label, prob_of_risk = parse_output(output, input_len)
        logger.debug(f"label is are: {label}")
        logger.debug(f"prob_of_risk is are: {prob_of_risk}")
    else:
        raise ValueError("Environment variable 'INFERENCE_ENGINE' must be one of [WATSONX, MOCK, VLLM]")

    logger.debug(f"Model generated label: \n{label}")
    logger.debug(f"Model prob_of_risk: \n{prob_of_risk}")

    end = time()
    total = end - start
    logger.debug(f"The evaluation took {total} secs")

    return {"assessment": label, "certainty": prob_of_risk}