# File size: 1,822 Bytes
# d46878a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import logging.handlers
import torch
from torch.nn.functional import softmax
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModel
import jinja2
import os
from time import time
from logger import logger

# Runtime configuration, read from the environment at import time.
# USE_CONDA="true" loads the model with device_map=`device`; any other value
# (or unset) loads it with device_map=None.
use_conda = os.getenv('USE_CONDA', "false") == "true"
device = "cuda"
# Local path or hub id of the model, e.g. "granite-guardian-3b-pipecleaner-r241024a".
model_path = os.getenv('MODEL_PATH')
if not model_path:
    # Fail fast with an actionable message instead of letting
    # from_pretrained(None) raise an opaque error inside transformers.
    raise RuntimeError('MODEL_PATH environment variable is not set')
logger.info(f'Model path is "{model_path}"')
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    device_map=device if use_conda else None
)


def generate_text(prompt):
    """Ask the model for a Yes/No judgement on a single chat message.

    The prompt is rendered with the tokenizer's chat template and run through
    the model. P(Yes) is read from a softmax over the final-position logits of
    the 'No' and 'Yes' tokens only. A generation of up to 128 new tokens is
    also produced, but it is only written to the debug log.

    Args:
        prompt: A single chat message, passed to
            ``tokenizer.apply_chat_template`` inside a one-element list. It
            must have a ``"content"`` key, which is logged.

    Returns:
        dict with keys:
            ``'assessment'``: ``'Yes'`` if P(Yes) > 0.5, otherwise ``'No'``.
            ``'certainty'``: probability of the chosen assessment, rounded to
            3 decimals and returned as a string.

    Raises:
        KeyError: if ``prompt`` has no ``"content"`` key, or the tokenizer
            vocabulary has no ``'No'`` / ``'Yes'`` token.
    """
    logger.debug('Starting evaluation...')
    logger.debug(f'Prompts content is: \n{prompt["content"]}')
    start = time()

    tokenized_chat = tokenizer.apply_chat_template(
        [prompt],
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
    )
    # Tensor.to() returns a new tensor rather than moving in place, so the
    # result must be rebound. Otherwise the inputs stay on the CPU while the
    # model sits on `device`. Targeting model.device follows the model
    # wherever it was loaded (a no-op when it is on the CPU).
    tokenized_chat = tokenized_chat.to(model.device)

    with torch.no_grad():
        # Only the logits at the last prompt position are needed.
        last_logits = model(tokenized_chat).logits[0, -1]
        # NOTE(review): this generation is used only for the debug log below;
        # consider skipping it when debug logging is disabled.
        gen_outputs = model.generate(tokenized_chat, max_new_tokens=128)

    # gen_outputs echoes the prompt tokens first; decode only the new ones.
    prompt_len = tokenized_chat.shape[-1]
    generated_text = tokenizer.decode(gen_outputs[0, prompt_len:])
    logger.debug(f'Model generated text: \n{generated_text}')

    vocab = tokenizer.get_vocab()
    selected_logits = last_logits[[vocab['No'], vocab['Yes']]]
    probabilities = softmax(selected_logits, dim=0)

    prob = probabilities[1].item()
    logger.debug(f'Certainty is: {prob} from probabilities {probabilities}')
    assessment = 'Yes' if prob > 0.5 else 'No'
    # Confidence in whichever label was chosen: P(Yes), or P(No) = 1 - P(Yes).
    certainty = f'{round(max(prob, 1 - prob), 3)}'

    total = time() - start
    logger.debug(f'it took {round(total/60, 2)} mins')

    return {'assessment': assessment, 'certainty': certainty}