import os
import math
from time import time, sleep

import spaces
from transformers import AutoTokenizer
from ibm_watsonx_ai.client import APIClient
from ibm_watsonx_ai.foundation_models import ModelInference

from logger import logger

safe_token = "No"
risky_token = "Yes"
nlogprobs = 5

inference_engine = os.getenv('INFERENCE_ENGINE', 'VLLM')
logger.debug(f"Inference engine is: '{inference_engine}'")

if inference_engine == 'VLLM':
    import torch
    from vllm import LLM, SamplingParams

    model_path = os.getenv('MODEL_PATH', 'ibm-granite/granite-guardian-3.0-8b')
    logger.debug(f"model_path is {model_path}")
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    sampling_params = SamplingParams(temperature=0.0, logprobs=nlogprobs)
    model = LLM(model=model_path, tensor_parallel_size=1)

elif inference_engine == "WATSONX":
    client = APIClient(credentials={
        'api_key': os.getenv('WATSONX_API_KEY'),
        'url': 'https://us-south.ml.cloud.ibm.com'})
    client.set.default_project(os.getenv('WATSONX_PROJECT_ID'))

    hf_model_path = "ibm-granite/granite-guardian-3.0-8b"
    tokenizer = AutoTokenizer.from_pretrained(hf_model_path)

    model_id = "ibm/granite-guardian-3-8b"  # 8B Model: "ibm/granite-guardian-3-8b"
    model = ModelInference(
        model_id=model_id,
        api_client=client
    )

elif inference_engine == "MOCK":
    # The MOCK engine still needs a tokenizer because get_prompt renders the
    # chat template unconditionally before the mocked result is returned.
    tokenizer = AutoTokenizer.from_pretrained("ibm-granite/granite-guardian-3.0-8b")


def parse_output(output):
    """Extract the Yes/No label and probability of risk from a vLLM generation."""
    label, prob_of_risk = None, None

    if nlogprobs > 0:
        logprobs = next(iter(output.outputs)).logprobs
        if logprobs is not None:
            prob = get_probablities(logprobs)
            prob_of_risk = prob[1]

    res = next(iter(output.outputs)).text.strip()
    if risky_token.lower() == res.lower():
        label = risky_token
    elif safe_token.lower() == res.lower():
        label = safe_token
    else:
        label = "Failed"

    return label, prob_of_risk.item() if prob_of_risk is not None else None


def softmax(values):
    exp_values = [math.exp(v) for v in values]
    total = sum(exp_values)
    return [v / total for v in exp_values]


def get_probablities(logprobs):
    """Aggregate vLLM token logprobs into [P(safe), P(risky)]."""
    safe_token_prob = 1e-50
    unsafe_token_prob = 1e-50
    for gen_token_i in logprobs:
        for token_prob in gen_token_i.values():
            decoded_token = token_prob.decoded_token
            if decoded_token.strip().lower() == safe_token.lower():
                safe_token_prob += math.exp(token_prob.logprob)
            if decoded_token.strip().lower() == risky_token.lower():
                unsafe_token_prob += math.exp(token_prob.logprob)

    probabilities = torch.softmax(
        torch.tensor([math.log(safe_token_prob), math.log(unsafe_token_prob)]),
        dim=0
    )
    return probabilities


def get_probablities_watsonx(top_tokens_list):
    """Aggregate watsonx.ai top-token logprobs into [P(safe), P(risky)]."""
    safe_token_prob = 1e-50
    risky_token_prob = 1e-50
    for top_tokens in top_tokens_list:
        for token in top_tokens:
            if token['text'].strip().lower() == safe_token.lower():
                safe_token_prob += math.exp(token['logprob'])
            if token['text'].strip().lower() == risky_token.lower():
                risky_token_prob += math.exp(token['logprob'])

    probabilities = softmax([math.log(safe_token_prob), math.log(risky_token_prob)])
    return probabilities


def get_prompt(messages, criteria_name):
    """Render the Granite Guardian prompt for the given messages and risk criteria."""
    guardian_config = {"risk_name": criteria_name if criteria_name != 'general_harm' else 'harm'}
    return tokenizer.apply_chat_template(
        messages,
        guardian_config=guardian_config,
        tokenize=False,
        add_generation_prompt=True)


def generate_tokens(prompt):
    """Call the watsonx.ai model and return the generated tokens with top-token logprobs."""
    result = model.generate(
        prompt=[prompt],
        params={
            'decoding_method': 'greedy',
            'max_new_tokens': 20,
            'temperature': 0,
            'return_options': {
                'token_logprobs': True,
                'generated_tokens': True,
                'input_text': True,
                'top_n_tokens': 5
            }
        })
    return result[0]['results'][0]['generated_tokens']


def parse_output_watsonx(generated_tokens_list):
    """Extract the Yes/No label and probability of risk from a watsonx.ai generation."""
    label, prob_of_risk = None, None

    if nlogprobs > 0:
        top_tokens_list = [generated_tokens['top_tokens'] for generated_tokens in generated_tokens_list]
        prob = get_probablities_watsonx(top_tokens_list)
        prob_of_risk = prob[1]

    res = next(iter(generated_tokens_list))['text'].strip()
    if risky_token.lower() == res.lower():
        label = risky_token
    elif safe_token.lower() == res.lower():
        label = safe_token
    else:
        label = "Failed"

    return label, prob_of_risk


@spaces.GPU
def generate_text(messages, criteria_name):
    logger.debug(f'Messages used to create the prompt are: \n{messages}')
    start = time()

    chat = get_prompt(messages, criteria_name)
    logger.debug(f'Prompt is \n{chat}')

    if inference_engine == "MOCK":
        logger.debug('Returning mocked model result.')
        sleep(1)
        label, prob_of_risk = 'Yes', 0.97
    elif inference_engine == "WATSONX":
        generated_tokens = generate_tokens(chat)
        label, prob_of_risk = parse_output_watsonx(generated_tokens)
    elif inference_engine == "VLLM":
        with torch.no_grad():
            output = model.generate(chat, sampling_params, use_tqdm=False)
            label, prob_of_risk = parse_output(output[0])
    else:
        raise ValueError("Environment variable 'INFERENCE_ENGINE' must be one of [WATSONX, MOCK, VLLM]")

    logger.debug(f'Model generated label: \n{label}')
    logger.debug(f'Model prob_of_risk: \n{prob_of_risk}')

    end = time()
    total = end - start
    logger.debug(f'The evaluation took {total} secs')

    return {'assessment': label, 'certainty': prob_of_risk}
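

# Minimal usage sketch, illustrative only: the sample message and the choice of
# 'general_harm' as criteria_name are assumptions for demonstration, not part of
# the application. Messages use the standard role/content chat format consumed by
# tokenizer.apply_chat_template.
if __name__ == "__main__":
    sample_messages = [{"role": "user", "content": "How do I pick a lock?"}]
    result = generate_text(sample_messages, criteria_name="general_harm")
    # Expected shape: {'assessment': <'Yes'|'No'|'Failed'>, 'certainty': <prob of risk>}
    print(result)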