File size: 7,482 Bytes
5269ad1
f97dae7
5b7f169
 
 
f97dae7
 
026d799
5b7f169
 
182a21a
5269ad1
f97dae7
2b6005c
5269ad1
5b7f169
f97dae7
 
5b7f169
1eece35
5b7f169
6047b61
 
 
33193a0
2e81d77
2b6005c
 
2cb730a
f97dae7
5b7f169
 
 
 
 
1eece35
f97dae7
 
0caab14
5b7f169
 
6047b61
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2b6005c
6047b61
5269ad1
6047b61
 
 
2b6005c
6047b61
5269ad1
f97dae7
 
5269ad1
 
 
 
 
6047b61
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5269ad1
 
 
 
 
2b6005c
 
5269ad1
2b6005c
f97dae7
2b6005c
5269ad1
6047b61
5269ad1
 
 
5b7f169
6047b61
 
 
2b6005c
6047b61
 
 
 
 
 
f97dae7
6047b61
 
 
 
 
 
 
f97dae7
6047b61
f97dae7
5b7f169
6047b61
4bd44f6
6047b61
 
 
1eece35
 
 
 
6047b61
 
 
 
5b7f169
6047b61
4bd44f6
 
 
 
 
6047b61
5b7f169
6047b61
4bd44f6
5b7f169
2e81d77
c786139
6047b61
5b7f169
2e81d77
5b7f169
 
f97dae7
5b7f169
 
 
2b6005c
fb6a6b8
6047b61
f97dae7
5269ad1
5b7f169
6047b61
 
 
 
0caab14
6047b61
4bd44f6
fb6a6b8
2b6005c
6047b61
2b6005c
f97dae7
bce909e
2b6005c
42e2cdc
2b6005c
 
 
6047b61
 
51c0b7a
2b6005c
 
fb6a6b8
 
f97dae7
 
d46878a
5b7f169
 
 
20a9c66
 
5b7f169
d46878a
5b7f169
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
import math
import os
from time import sleep, time

import spaces
from ibm_watsonx_ai.client import APIClient
from ibm_watsonx_ai.foundation_models import ModelInference
from transformers import AutoModelForCausalLM, AutoTokenizer

from logger import logger

# Label strings the guardian model is expected to emit; the parsers in this
# module compare decoded output against them case-insensitively.
safe_token = "No"
risky_token = "Yes"
# Used both as `k` for torch.topk in parse_output and as `max_new_tokens`
# for the local model in get_guardian_response.
nlogprobs = 20

# Backend selector read once at import time. This module handles "VLLM"
# (the default), "WATSONX" and "MOCK"; any other value leaves `model` and
# `tokenizer` undefined and is rejected later by get_guardian_response.
inference_engine = os.getenv("INFERENCE_ENGINE", "VLLM")
logger.debug(f"Inference engine is: '{inference_engine}'")

if inference_engine == "VLLM":
    # torch is only imported for this engine, so the torch-based helpers
    # (get_probablities, parse_output) are only usable on this path.
    import torch

    # NOTE: despite the engine name, this branch loads the model through
    # Hugging Face transformers and places it on the CPU.
    device = torch.device("cpu")

    model_path = os.getenv("MODEL_PATH", "ibm-granite/granite-guardian-3.1-2b")
    logger.debug(f"model_path is {model_path}")
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForCausalLM.from_pretrained(model_path)
    # eval() switches the module to inference mode.
    model = model.to(device).eval()

elif inference_engine == "WATSONX":
    # Credentials and project come from the environment; os.getenv returns
    # None when a variable is unset, which is passed through unchanged here.
    client = APIClient(
        credentials={"api_key": os.getenv("WATSONX_API_KEY"), "url": "https://us-south.ml.cloud.ibm.com"}
    )

    client.set.default_project(os.getenv("WATSONX_PROJECT_ID"))
    # The tokenizer is loaded locally only to render the chat template;
    # generation itself goes through the remote `model` below.
    hf_model_path = "ibm-granite/granite-guardian-3.1-8b"
    tokenizer = AutoTokenizer.from_pretrained(hf_model_path)

    # NOTE(review): the tokenizer above is the "3.1-8b" checkpoint while the
    # hosted model id is "3-8b" -- confirm the chat templates are compatible.
    model_id = "ibm/granite-guardian-3-8b"  # 8B Model: "ibm/granite-guardian-3-8b"
    model = ModelInference(model_id=model_id, api_client=client)


def get_probablities_watsonx(top_tokens_list):
    """Estimate safe/risky probabilities from watsonx top-token candidates.

    Args:
        top_tokens_list: Iterable with one entry per generated position; each
            entry is an iterable of dicts exposing ``"text"`` and ``"logprob"``.

    Returns:
        A two-element list ``[p_safe, p_risky]`` produced by ``softmax`` over
        the logs of the accumulated probability mass of each label.
    """
    safe_label = safe_token.lower()
    risky_label = risky_token.lower()

    # A tiny floor keeps math.log defined when a label never appears.
    safe_mass = 1e-50
    risky_mass = 1e-50

    for candidates in top_tokens_list:
        for candidate in candidates:
            text = candidate["text"].strip().lower()
            if text == safe_label:
                safe_mass += math.exp(candidate["logprob"])
            if text == risky_label:
                risky_mass += math.exp(candidate["logprob"])

    return softmax([math.log(safe_mass), math.log(risky_mass)])


def parse_output_watsonx(generated_tokens_list):
    """Turn watsonx generated-token records into a label and a risk probability.

    Args:
        generated_tokens_list: Iterable of per-token dicts; each is read for
            ``"text"`` and, when ``nlogprobs > 0``, for ``"top_tokens"``.

    Returns:
        ``(label, prob_of_risk)``. ``label`` is ``risky_token``, ``safe_token``
        or ``"Failed"`` depending on the text of the first generated token;
        ``prob_of_risk`` is the risky entry of ``get_probablities_watsonx`` or
        ``None`` when ``nlogprobs`` is not positive.
    """
    prob_of_risk = None
    if nlogprobs > 0:
        candidates_per_step = [entry["top_tokens"] for entry in generated_tokens_list]
        prob_of_risk = get_probablities_watsonx(candidates_per_step)[1]

    # Only the first generated token decides the label.
    first_text = next(iter(generated_tokens_list))["text"].strip()

    # The risky entry is inserted last so it wins if both labels normalise to
    # the same key, mirroring a risky-first comparison.
    labels_by_text = {safe_token.lower(): safe_token, risky_token.lower(): risky_token}
    label = labels_by_text.get(first_text.lower(), "Failed")

    return label, prob_of_risk


def generate_tokens_watsonx(prompt):
    """Send one prompt to the module-level watsonx ``model`` and return its tokens.

    Args:
        prompt: Prompt text; it is wrapped in a single-element list for the call.

    Returns:
        The ``"generated_tokens"`` entry of the first result of the first
        response returned by ``model.generate``.
    """
    return_options = {"token_logprobs": True, "generated_tokens": True, "input_text": True, "top_n_tokens": 5}
    generation_params = {
        "decoding_method": "greedy",
        "max_new_tokens": 20,
        "temperature": 0,
        "return_options": return_options,
    }

    responses = model.generate(prompt=[prompt], params=generation_params)
    first_response = responses[0]
    return first_response["results"][0]["generated_tokens"]


def softmax(values):
    """Return the softmax of *values* as a list of floats.

    The maximum is subtracted before exponentiating. This leaves the result
    mathematically unchanged but prevents ``math.exp`` from overflowing on
    large inputs and prevents the normaliser from underflowing to zero on very
    negative ones.

    Args:
        values: Iterable of finite real numbers.

    Returns:
        A list of the same length whose entries sum to 1; an empty list when
        *values* is empty.
    """
    values = list(values)
    if not values:
        return []

    # After shifting, the largest exponent is exp(0) == 1, so total >= 1.
    peak = max(values)
    exp_values = [math.exp(v - peak) for v in values]
    total = sum(exp_values)
    return [v / total for v in exp_values]


def get_probablities(logprobs):
    """Estimate safe/risky probabilities from per-step top-k results.

    Args:
        logprobs: Iterable of objects exposing ``.values`` and ``.indices``
            (as returned by ``torch.topk``); only row 0 of each is read.

    Returns:
        A 1-D tensor ``[p_safe, p_risky]`` obtained by a softmax over the logs
        of the accumulated mass of each label.
    """
    # NOTE(review): the values are exponentiated as if they were
    # log-probabilities; parse_output passes top-k of raw `output.scores` --
    # confirm those are normalised.
    safe_label = safe_token.lower()
    risky_label = risky_token.lower()

    # A tiny floor keeps math.log defined when a label never appears.
    safe_mass = 1e-50
    risky_mass = 1e-50

    for step_topk in logprobs:
        step_values = step_topk.values.tolist()[0]
        step_indices = step_topk.indices.tolist()[0]
        for logprob, token_id in zip(step_values, step_indices):
            token_text = tokenizer.convert_ids_to_tokens(token_id).strip().lower()
            if token_text == safe_label:
                safe_mass += math.exp(logprob)
            if token_text == risky_label:
                risky_mass += math.exp(logprob)

    log_masses = torch.tensor([math.log(safe_mass), math.log(risky_mass)])
    return torch.softmax(log_masses, dim=0)


def parse_output(output, input_len):
    """Turn a local ``model.generate`` result into a label and a risk probability.

    Args:
        output: Generation result exposing ``.scores`` (per-step score tensors)
            and ``.sequences`` (token ids including the prompt).
        input_len: Number of prompt tokens to skip at the start of
            ``output.sequences`` before decoding.

    Returns:
        ``(label, prob_of_risk)``. ``label`` is ``risky_token``, ``safe_token``
        or ``"Failed"`` depending on the decoded completion; ``prob_of_risk``
        is a Python float, or ``None`` when ``nlogprobs`` is not positive
        (matching ``parse_output_watsonx``).
    """
    prob_of_risk = None
    if nlogprobs > 0:
        # The final step's scores are excluded
        # (presumably the end-of-sequence step -- TODO confirm).
        topk_per_step = [
            torch.topk(step_scores, k=nlogprobs, largest=True, sorted=True) for step_scores in list(output.scores)[:-1]
        ]
        # Convert to a plain float here so the return below never calls
        # .item() on None when probability extraction is disabled.
        prob_of_risk = get_probablities(topk_per_step)[1].item()

    res = tokenizer.decode(output.sequences[:, input_len:][0], skip_special_tokens=True).strip()
    if risky_token.lower() == res.lower():
        label = risky_token
    elif safe_token.lower() == res.lower():
        label = safe_token
    else:
        label = "Failed"

    return label, prob_of_risk


@spaces.GPU
def get_prompt(messages, criteria_name, tokenize=False, add_generation_prompt=False, return_tensors=None):
    """Render the guardian chat template for a conversation and a risk criterion.

    Args:
        messages: Conversation passed straight to ``tokenizer.apply_chat_template``.
        criteria_name: Risk identifier; ``"general_harm"`` is mapped to
            ``"harm"`` and ``"function_calling_hallucination"`` to
            ``"function_call"``; any other value is used unchanged.
        tokenize: Forwarded to ``apply_chat_template``.
        add_generation_prompt: Forwarded to ``apply_chat_template``.
        return_tensors: Forwarded to ``apply_chat_template``.

    Returns:
        Whatever ``tokenizer.apply_chat_template`` returns for these options.
    """
    logger.debug("Creating prompt for the model.")
    logger.debug(f"Messages used to create the prompt are: \n{messages}")
    logger.debug("Criteria name is: " + criteria_name)

    # UI-facing criterion names that differ from the template's risk names.
    risk_aliases = {
        "general_harm": "harm",
        "function_calling_hallucination": "function_call",
    }
    criteria_name = risk_aliases.get(criteria_name, criteria_name)
    logger.debug("Criteria name was changed too: " + criteria_name)

    logger.debug(f"Tokenize: {tokenize}")
    logger.debug(f"add_generation_prompt: {add_generation_prompt}")
    logger.debug(f"return_tensors: {return_tensors}")

    # "general_harm" has already been remapped above, so the name is used as is.
    guardian_config = {"risk_name": criteria_name}
    logger.debug(f"guardian_config is: {guardian_config}")

    template_options = {
        "guardian_config": guardian_config,
        "tokenize": tokenize,
        "add_generation_prompt": add_generation_prompt,
        "return_tensors": return_tensors,
    }
    prompt = tokenizer.apply_chat_template(messages, **template_options)
    logger.debug(f"Prompt (type {type(prompt)}) is: {prompt}")
    return prompt


@spaces.GPU
def get_guardian_response(messages, criteria_name):
    """Evaluate a conversation with the configured engine and report the verdict.

    Args:
        messages: Conversation to assess; forwarded to ``get_prompt``.
        criteria_name: Risk criterion name; forwarded to ``get_prompt``.

    Returns:
        A dict with ``"assessment"`` (the label) and ``"certainty"`` (the
        probability of risk reported by the engine-specific parser).

    Raises:
        Exception: If ``inference_engine`` is not MOCK, WATSONX or VLLM.
    """
    logger.debug(f"Messages used to create the prompt are: \n{messages}")
    started_at = time()

    # Reject unknown engines first so the branches below only handle supported ones.
    if inference_engine not in ("MOCK", "WATSONX", "VLLM"):
        raise Exception("Environment variable 'INFERENCE_ENGINE' must be one of [WATSONX, MOCK, VLLM]")

    if inference_engine == "MOCK":
        # Fixed answer after a short delay, without touching any model.
        logger.debug("Returning mocked model result.")
        sleep(1)
        label = "Yes"
        prob_of_risk = 0.97

    elif inference_engine == "WATSONX":
        # NOTE(review): the prompt is built with get_prompt's defaults, i.e.
        # without a generation prompt -- confirm this is intended.
        chat = get_prompt(messages, criteria_name)
        logger.debug(f"Prompt is \n{chat}")
        label, prob_of_risk = parse_output_watsonx(generate_tokens_watsonx(chat))

    else:
        # VLLM: local model, tokenised prompt moved onto the model's device.
        prompt_tensor = get_prompt(
            messages=messages,
            criteria_name=criteria_name,
            tokenize=True,
            add_generation_prompt=True,
            return_tensors="pt",
        )
        input_ids = prompt_tensor.to(model.device)
        logger.debug(f"input_ids are: {input_ids}")
        input_len = input_ids.shape[1]
        logger.debug(f"input_len is: {input_len}")

        with torch.no_grad():
            generation_options = {
                "do_sample": False,
                "max_new_tokens": nlogprobs,
                "return_dict_in_generate": True,
                "output_scores": True,
            }
            output = model.generate(input_ids, **generation_options)
            logger.debug(f"model output is:\n{output}")

            label, prob_of_risk = parse_output(output, input_len)
            logger.debug(f"label is are: {label}")
            logger.debug(f"prob_of_risk is are: {prob_of_risk}")

    logger.debug(f"Model generated label: \n{label}")
    logger.debug(f"Model prob_of_risk: \n{prob_of_risk}")

    total = time() - started_at
    logger.debug(f"The evaluation took {total} secs")

    return {"assessment": label, "certainty": prob_of_risk}