import math
import os
from time import time, sleep

from logger import logger

# Tokens emitted by the Granite Guardian model: "Yes" signals that the assessed
# risk is present, "No" signals that it is not.
safe_token = "No"
unsafe_token = "Yes"
nlogprobs = 5

mock_model_call = os.getenv('MOCK_MODEL_CALL') == 'true'

if not mock_model_call:
    # Heavy dependencies and the model itself are only loaded when the service
    # is not running in mock mode.
    import torch
    from vllm import LLM, SamplingParams
    from transformers import AutoTokenizer

    model_path = os.getenv('MODEL_PATH')  # e.g. "granite-guardian-3b-pipecleaner-r241024a"
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    sampling_params = SamplingParams(temperature=0.0, logprobs=nlogprobs)
    model = LLM(model=model_path, tensor_parallel_size=1)


def parse_output(output):
    """Extract the safe/unsafe label and the probability of risk from a vLLM output."""
    label, prob_of_risk = None, None

    if nlogprobs > 0:
        logprobs = next(iter(output.outputs)).logprobs
        if logprobs is not None:
            prob = get_probabilities(logprobs)
            prob_of_risk = prob[1]

    res = next(iter(output.outputs)).text.strip()
    if unsafe_token.lower() == res.lower():
        label = unsafe_token
    elif safe_token.lower() == res.lower():
        label = safe_token
    else:
        label = "Failed"

    return label, prob_of_risk.item() if prob_of_risk is not None else None


def get_probabilities(logprobs):
    """Aggregate the log-probabilities of the safe/unsafe tokens and normalize them."""
    safe_token_prob = 1e-50
    unsafe_token_prob = 1e-50
    for gen_token_i in logprobs:
        for token_prob in gen_token_i.values():
            decoded_token = token_prob.decoded_token
            if decoded_token.strip().lower() == safe_token.lower():
                safe_token_prob += math.exp(token_prob.logprob)
            if decoded_token.strip().lower() == unsafe_token.lower():
                unsafe_token_prob += math.exp(token_prob.logprob)

    # Softmax over the two accumulated log-masses yields [P(safe), P(risk)].
    probabilities = torch.softmax(
        torch.tensor([math.log(safe_token_prob), math.log(unsafe_token_prob)]),
        dim=0,
    )
    return probabilities


def get_prompt(messages, criteria_name):
    """Build the Guardian prompt for the given messages and risk criteria."""
    guardian_config = {"risk_name": criteria_name if criteria_name != 'general_harm' else 'harm'}
    return tokenizer.apply_chat_template(
        messages,
        guardian_config=guardian_config,
        tokenize=False,
        add_generation_prompt=True,
    )


def generate_text(messages, criteria_name):
    """Run a risk assessment and return the label and the probability of risk."""
    logger.debug(f'Messages are: \n{messages}')

    mock_model_call = os.getenv('MOCK_MODEL_CALL') == 'true'
    if mock_model_call:
        logger.debug('Returning mocked model result.')
        sleep(1)
        return {'assessment': 'Yes', 'certainty': 0.97}

    start = time()
    chat = get_prompt(messages, criteria_name)
    logger.debug(f'Prompt is \n{chat}')

    with torch.no_grad():
        output = model.generate(chat, sampling_params, use_tqdm=False)

    label, prob_of_risk = parse_output(output[0])
    logger.debug(f'Model generated label: \n{label}')
    logger.debug(f'Model prob_of_risk: \n{prob_of_risk}')

    end = time()
    total = end - start
    logger.debug(f'The evaluation took {total} secs')

    return {'assessment': label, 'certainty': prob_of_risk}
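

# --- Usage sketch (illustrative only) -----------------------------------------
# A minimal example of how generate_text() might be invoked, assuming either a
# valid MODEL_PATH pointing at a Granite Guardian checkpoint or
# MOCK_MODEL_CALL='true' for a dry run. The message structure below is an
# assumption based on the standard chat format expected by apply_chat_template;
# 'general_harm' is one of the criteria names already handled in get_prompt().
if __name__ == '__main__':
    example_messages = [
        {"role": "user", "content": "Tell me how to bypass a building's alarm system."},
    ]
    result = generate_text(example_messages, 'general_harm')
    # result has the shape {'assessment': 'Yes' | 'No' | 'Failed', 'certainty': <float>}
    logger.debug(f'Example assessment: {result}')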