import os
import math
from time import time, sleep

from logger import logger

# Tokens the guardian model emits to signal its assessment.
safe_token = "No"
unsafe_token = "Yes"
nlogprobs = 5

# When MOCK_MODEL_CALL=true the heavy ML dependencies are never imported
# and generate_text() returns a canned result instead of running the model.
mock_model_call = os.getenv('MOCK_MODEL_CALL') == 'true'
if not mock_model_call:
    import torch
    from vllm import LLM, SamplingParams
    from transformers import AutoTokenizer

    model_path = os.getenv('MODEL_PATH')  # e.g. "granite-guardian-3b-pipecleaner-r241024a"
    sampling_params = SamplingParams(temperature=0.0, logprobs=nlogprobs)
    model = LLM(model=model_path, tensor_parallel_size=1)
    tokenizer = AutoTokenizer.from_pretrained(model_path)

def parse_output(output):
    label, prob_of_risk = None, None

    if nlogprobs > 0:
        logprobs = next(iter(output.outputs)).logprobs
        if logprobs is not None:
            # prob is the [safe, unsafe] distribution; index 1 is the probability of risk.
            prob = get_probabilities(logprobs)
            prob_of_risk = prob[1]

    res = next(iter(output.outputs)).text.strip()
    if unsafe_token.lower() == res.lower():
        label = unsafe_token
    elif safe_token.lower() == res.lower():
        label = safe_token
    else:
        label = "Failed"

    return label, prob_of_risk.item() if prob_of_risk is not None else None

def get_probabilities(logprobs):
    # Accumulate the probability mass assigned to the safe/unsafe tokens across
    # all generated positions; the tiny initial value avoids log(0) below.
    safe_token_prob = 1e-50
    unsafe_token_prob = 1e-50
    for gen_token_i in logprobs:
        for token_prob in gen_token_i.values():
            decoded_token = token_prob.decoded_token
            if decoded_token.strip().lower() == safe_token.lower():
                safe_token_prob += math.exp(token_prob.logprob)
            if decoded_token.strip().lower() == unsafe_token.lower():
                unsafe_token_prob += math.exp(token_prob.logprob)

    # Renormalise the two accumulated masses into a [safe, unsafe] distribution.
    probabilities = torch.softmax(
        torch.tensor([math.log(safe_token_prob), math.log(unsafe_token_prob)]), dim=0
    )

    return probabilities

def generate_text(prompt):
    logger.debug(f'Prompt content is: \n{prompt["content"]}')
    # Re-read the flag so mocking can be toggled via the environment at request time.
    mock_model_call = os.getenv('MOCK_MODEL_CALL') == 'true'
    if mock_model_call:
        logger.debug('Returning mocked model result.')
        sleep(2)
        return {'assessment': 'Yes', 'certainty': 0.97}
    else:
        start = time()

        # Render the chat prompt with the model's chat template before generation.
        tokenized_chat = tokenizer.apply_chat_template([prompt], tokenize=False, add_generation_prompt=True)

        with torch.no_grad():
            output = model.generate(tokenized_chat, sampling_params, use_tqdm=False)

        label, prob_of_risk = parse_output(output[0])

        logger.debug(f'Model generated label: \n{label}')
        logger.debug(f'Model prob_of_risk: \n{prob_of_risk}')

        end = time()
        total = end - start
        logger.debug(f'The evaluation took {total} secs')

        return {'assessment': label, 'certainty': prob_of_risk}
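

# --- Usage sketch (illustrative only; the prompt text below is hypothetical) ---
# Run with MOCK_MODEL_CALL=true to exercise the flow without loading the model,
# or with MODEL_PATH pointing at a local Granite Guardian checkpoint.
if __name__ == '__main__':
    example_prompt = {'role': 'user', 'content': 'Is this message harmful?'}
    result = generate_text(example_prompt)
    print(result)  # e.g. {'assessment': 'Yes', 'certainty': 0.97} in mock mode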