Martín Santillán Cooper
General updates
aeabe15
raw
history blame
2.74 kB
import os
from time import time, sleep
from logger import logger
import math
# Answer strings that the generated text and decoded tokens are compared against.
safe_token = "No"
unsafe_token = "Yes"

# Passed to SamplingParams as the number of log-probabilities to request.
nlogprobs = 5

# When MOCK_MODEL_CALL is exactly 'true', the heavy imports and the model load
# below are skipped entirely.
mock_model_call = os.environ.get('MOCK_MODEL_CALL') == 'true'

if not mock_model_call:
    import torch
    from vllm import LLM, SamplingParams
    from transformers import AutoTokenizer

    # Example value: "granite-guardian-3b-pipecleaner-r241024a"
    model_path = os.environ.get('MODEL_PATH')

    # temperature=0.0 is passed so decoding is greedy.
    sampling_params = SamplingParams(
        temperature=0.0,
        logprobs=nlogprobs,
    )
    model = LLM(
        model=model_path,
        tensor_parallel_size=1,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_path)
def parse_output(output):
    """Extract the risk label and the probability of risk from one model output.

    Args:
        output: A single request output as returned by ``model.generate``.
            Only the first entry of ``output.outputs`` is read (its ``text``
            and ``logprobs`` attributes).

    Returns:
        A ``(label, prob_of_risk)`` tuple. ``label`` is ``unsafe_token`` or
        ``safe_token`` when the stripped generated text matches one of them
        case-insensitively, otherwise ``"Failed"``. ``prob_of_risk`` is a
        Python float, or ``None`` when log-probabilities are disabled
        (``nlogprobs <= 0``) or were not returned with the output.
    """
    completion = next(iter(output.outputs))

    # Default to None so that a missing logprobs payload no longer raises
    # UnboundLocalError at the return statement.
    prob_of_risk = None
    if nlogprobs > 0 and completion.logprobs is not None:
        # get_probablities returns [p_safe, p_unsafe]; index 1 is the risk.
        prob_of_risk = get_probablities(completion.logprobs)[1].item()

    generated = completion.text.strip().lower()
    if generated == unsafe_token.lower():
        label = unsafe_token
    elif generated == safe_token.lower():
        label = safe_token
    else:
        # The model answered with something other than the two known tokens.
        label = "Failed"

    return label, prob_of_risk
def get_probablities(logprobs):
    """Return the normalised probabilities of the safe and unsafe tokens.

    Args:
        logprobs: Iterable with one mapping per generated position. The values
            of each mapping expose ``decoded_token`` and ``logprob``.

    Returns:
        A 1-D torch tensor ``[p_safe, p_unsafe]`` whose entries sum to 1.
    """
    safe_target = safe_token.lower()
    unsafe_target = unsafe_token.lower()

    # Start from a tiny positive mass so math.log() below never receives zero
    # when a token does not appear among the candidates.
    safe_token_prob = 1e-50
    unsafe_token_prob = 1e-50

    for position_candidates in logprobs:
        for candidate in position_candidates.values():
            decoded = candidate.decoded_token
            if decoded is None:
                # No decoded text for this candidate, so there is nothing to
                # compare; calling .strip() on it would raise AttributeError.
                continue
            normalised = decoded.strip().lower()
            if normalised == safe_target:
                safe_token_prob += math.exp(candidate.logprob)
            if normalised == unsafe_target:
                unsafe_token_prob += math.exp(candidate.logprob)

    # Softmax over the log-masses renormalises the two accumulated
    # probabilities so that they sum to one.
    probabilities = torch.softmax(
        torch.tensor([math.log(safe_token_prob), math.log(unsafe_token_prob)]),
        dim=0,
    )
    return probabilities
def generate_text(prompt):
    """Assess a single chat message and return the verdict with its certainty.

    Args:
        prompt: Chat message mapping. Its ``"content"`` entry is logged and
            the whole mapping is handed to the tokenizer's chat template.

    Returns:
        A dict ``{'assessment': label, 'certainty': prob_of_risk}``. When the
        ``MOCK_MODEL_CALL`` environment variable equals ``'true'``, a fixed
        result is returned after a two-second pause and the model is not
        called.
    """
    logger.debug(f'Prompts content is: \n{prompt["content"]}')

    # The environment flag is read again on every call.
    if os.getenv('MOCK_MODEL_CALL') == 'true':
        logger.debug('Returning mocked model result.')
        sleep(2)
        return {'assessment': 'Yes', 'certainty': 0.97}

    start = time()

    # Render the message through the chat template as plain text
    # (tokenize=False) before passing it to the model.
    chat_text = tokenizer.apply_chat_template(
        [prompt],
        tokenize=False,
        add_generation_prompt=True,
    )
    with torch.no_grad():
        generations = model.generate(chat_text, sampling_params, use_tqdm=False)

    # A single prompt was submitted, so only the first result is parsed.
    label, prob_of_risk = parse_output(generations[0])
    logger.debug(f'Model generated label: \n{label}')
    logger.debug(f'Model prob_of_risk: \n{prob_of_risk}')

    total = time() - start
    logger.debug(f'The evaluation took {total} secs')

    return {'assessment': label, 'certainty': prob_of_risk}