import os
import math
from time import time, sleep

import torch
from vllm import LLM, SamplingParams
from transformers import AutoTokenizer

from logger import logger

# Verdict tokens the guardian model is expected to emit, and how many logprobs to request.
safe_token = "No"
unsafe_token = "Yes"
nlogprobs = 5


def parse_output(output):
    """Extract the safety label and the probability of risk from a vLLM RequestOutput."""
    label, prob_of_risk = None, None
    if nlogprobs > 0:
        logprobs = next(iter(output.outputs)).logprobs
        if logprobs is not None:
            prob = get_probabilities(logprobs)
            prob_of_risk = prob[1]
    res = next(iter(output.outputs)).text.strip()
    if unsafe_token.lower() == res.lower():
        label = unsafe_token
    elif safe_token.lower() == res.lower():
        label = safe_token
    else:
        label = "Failed"
    return label, prob_of_risk.item() if prob_of_risk is not None else None


def get_probabilities(logprobs):
    """Aggregate per-position probabilities of the safe/unsafe tokens and normalize them."""
    # Small floor keeps the log below finite even if a token never appears in the top logprobs.
    safe_token_prob = 1e-50
    unsafe_token_prob = 1e-50
    for gen_token_i in logprobs:
        for token_prob in gen_token_i.values():
            decoded_token = token_prob.decoded_token
            if decoded_token.strip().lower() == safe_token.lower():
                safe_token_prob += math.exp(token_prob.logprob)
            if decoded_token.strip().lower() == unsafe_token.lower():
                unsafe_token_prob += math.exp(token_prob.logprob)
    # Renormalize the two accumulated masses into [P(safe), P(unsafe)].
    probabilities = torch.softmax(
        torch.tensor([math.log(safe_token_prob), math.log(unsafe_token_prob)]), dim=0
    )
    return probabilities


# Load the model and tokenizer once at import time unless mocking is enabled.
mock_model_call = os.getenv('MOCK_MODEL_CALL') == 'true'
if not mock_model_call:
    use_conda = os.getenv('USE_CONDA', "false") == "true"
    model_path = os.getenv('MODEL_PATH')  # e.g. "granite-guardian-3b-pipecleaner-r241024a"
    sampling_params = SamplingParams(temperature=0.0, logprobs=nlogprobs)
    model = LLM(model=model_path, tensor_parallel_size=1)
    tokenizer = AutoTokenizer.from_pretrained(model_path)


def generate_text(prompt):
    """Run the guardian model on a single chat message and return its risk assessment."""
    logger.debug('Starting evaluation...')
    logger.debug(f'Prompt content is: \n{prompt["content"]}')
    mock_model_call = os.getenv('MOCK_MODEL_CALL') == 'true'
    if mock_model_call:
        logger.debug('Returning mocked model result.')
        sleep(3)
        return {'assessment': 'Yes', 'certainty': 0.97}
    else:
        start = time()
        tokenized_chat = tokenizer.apply_chat_template(
            [prompt], tokenize=False, add_generation_prompt=True
        )
        with torch.no_grad():
            output = model.generate(tokenized_chat, sampling_params, use_tqdm=False)
        label, prob_of_risk = parse_output(output[0])
        logger.debug(f'Model generated label: \n{label}')
        logger.debug(f'Model prob_of_risk: \n{prob_of_risk}')
        end = time()
        total = end - start
        logger.debug(f'Evaluation took {round(total / 60, 2)} mins')
        return {'assessment': label, 'certainty': prob_of_risk}
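

# Illustrative usage sketch: the exact fields expected by the model's chat template
# (a single message dict with "role" and "content") are an assumption here; only
# prompt["content"] is required by the code above for logging.
if __name__ == "__main__":
    example_prompt = {
        "role": "user",
        "content": "How do I pick a lock?",
    }
    result = generate_text(example_prompt)
    # Expected shape: {'assessment': 'Yes' | 'No' | 'Failed', 'certainty': float | None}
    print(result)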