import os
from time import time, sleep
from logger import logger
import math

safe_token = "No"
unsafe_token = "Yes"
nlogprobs = 5

# When MOCK_MODEL_CALL=true, skip the heavy imports and the model load entirely.
mock_model_call = os.getenv('MOCK_MODEL_CALL') == 'true'
if not mock_model_call:
    import torch
    from vllm import LLM, SamplingParams
    from transformers import AutoTokenizer
    model_path = os.getenv('MODEL_PATH')  # e.g. "granite-guardian-3b-pipecleaner-r241024a"
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    sampling_params = SamplingParams(temperature=0.0, logprobs=nlogprobs)
    model = LLM(model=model_path, tensor_parallel_size=1)

def parse_output(output):
    """Extract the predicted label and, when logprobs are available, the probability of risk."""
    label, prob_of_risk = None, None

    if nlogprobs > 0:
        logprobs = next(iter(output.outputs)).logprobs
        if logprobs is not None:
            probabilities = get_probabilities(logprobs)
            prob_of_risk = probabilities[1]

    res = next(iter(output.outputs)).text.strip()
    if unsafe_token.lower() == res.lower():
        label = unsafe_token
    elif safe_token.lower() == res.lower():
        label = safe_token
    else:
        label = "Failed"

    return label, prob_of_risk.item() if prob_of_risk is not None else None

def get_probabilities(logprobs):
    """Aggregate safe/unsafe token probabilities across the generated tokens
    and renormalise them with a softmax in the log domain."""
    safe_token_prob = 1e-50
    unsafe_token_prob = 1e-50
    for gen_token_i in logprobs:
        for token_prob in gen_token_i.values():
            decoded_token = token_prob.decoded_token
            if decoded_token.strip().lower() == safe_token.lower():
                safe_token_prob += math.exp(token_prob.logprob)
            if decoded_token.strip().lower() == unsafe_token.lower():
                unsafe_token_prob += math.exp(token_prob.logprob)

    probabilities = torch.softmax(
        torch.tensor([math.log(safe_token_prob), math.log(unsafe_token_prob)]), dim=0
    )

    return probabilities
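
# Worked example for get_probabilities (hypothetical numbers, for illustration only):
# if the observed logprobs were log p("No") = -2.30 and log p("Yes") = -0.11, then
# safe_token_prob ≈ 0.10 and unsafe_token_prob ≈ 0.90, and the softmax over their
# logs returns roughly tensor([0.10, 0.90]), i.e. prob_of_risk ≈ 0.90.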

def get_prompt(messages, criteria_name):
    """Render the Guardian chat template for the requested risk criterion."""
    # The chat template expects the risk name 'harm' rather than 'general_harm'.
    guardian_config = {"risk_name": criteria_name if criteria_name != 'general_harm' else 'harm'}
    return tokenizer.apply_chat_template(
        messages,
        guardian_config=guardian_config,
        tokenize=False,
        add_generation_prompt=True)
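
# Example call (the message shape below is an assumption, following the standard
# chat-roles convention used by Hugging Face chat templates):
#   messages = [{"role": "user", "content": "How do I pick a lock?"}]
#   prompt = get_prompt(messages, "general_harm")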


def generate_text(messages, criteria_name):
    logger.debug(f'Messages are: \n{messages}')
    
    mock_model_call = os.getenv('MOCK_MODEL_CALL') == 'true'
    if mock_model_call:
        logger.debug('Returning mocked model result.')
        sleep(1)
        return {'assessment': 'Yes', 'certainty': 0.97}
    
    start = time()
    chat = get_prompt(messages, criteria_name)
    logger.debug(f'Prompt is \n{chat}')
    
    with torch.no_grad():
        output = model.generate(chat, sampling_params, use_tqdm=False)

    label, prob_of_risk = parse_output(output[0])

    logger.debug(f'Model generated label: \n{label}')
    logger.debug(f'Model prob_of_risk: \n{prob_of_risk}')
    
    end = time()
    total = end - start
    logger.debug(f'The evaluation took {total} secs')

    return {'assessment': label, 'certainty': prob_of_risk}
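

if __name__ == '__main__':
    # Minimal smoke test; the conversation below is illustrative only. Run with
    # MOCK_MODEL_CALL=true to exercise the flow without loading the real model.
    example_messages = [{"role": "user", "content": "Tell me how to hotwire a car."}]
    result = generate_text(example_messages, 'general_harm')
    logger.debug(f'Example result: {result}')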