# File size: 5,580 Bytes
# 56523b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
from langchain.document_loaders.unstructured import UnstructuredFileLoader 
from langchain.text_splitter import CharacterTextSplitter
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import Chroma
from langchain.chains import RetrievalQA
from langchain.chat_models import ChatOpenAI
from langchain.schema import AIMessage, HumanMessage, SystemMessage, Document
from langchain.document_loaders import PyPDFLoader

from transformers import AutoTokenizer, T5ForConditionalGeneration
from retrieval.retrieval import Retrieval, BM25
from datetime import datetime
import os, time, torch
from torch.nn import Softmax
import requests

# Hugging Face Inference API endpoint of the hosted ViQA-small model.
API_URL = "https://api-inference.huggingface.co/models/CreatorPhan/ViQA-small"

# NOTE(security): a bearer token is committed in this file. Supply the token
# through the HF_API_TOKEN environment variable instead; the literal below is
# kept only as a fallback so existing deployments keep working, and it should
# be rotated and then removed from source control.
headers = {
    "Authorization": "Bearer "
    + (os.environ.get("HF_API_TOKEN") or "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")
}

class Agent:
    """Question-answering agent: retrieve candidate passages, then generate answers.

    Passages come from a ``Retrieval`` instance (the default corpus, or one
    built from user-supplied text). A T5 model generates one answer per
    passage, and the answers are ranked by a combination of the model's
    confidence and the retrieval score stored under ``'score_bm'``.
    """

    def __init__(self, args=None) -> None:
        """Build the default retriever and load the model and tokenizer.

        Args:
            args: Configuration object. Although it defaults to ``None`` it is
                required: ``choices``, ``model``, ``tokenizer`` and ``device``
                are read here, and ``seq_len`` / ``out_len`` are read by
                :meth:`asking`.
        """
        self.args = args
        self.choices = args.choices
        self.corpus = Retrieval(k=args.choices)

        self.context_value = ""
        self.use_context = False
        self.softmax = Softmax(dim=1)
        # Ranked results of the most recent `asking` call.
        self.temp = []
        # NOTE(security): torch.load unpickles the file; only load a
        # 'retrieval/replace.pt' that comes from a trusted source.
        self.replace_list = torch.load('retrieval/replace.pt')

        print("Model is loading...")
        self.model = T5ForConditionalGeneration.from_pretrained(args.model).to(args.device)
        self.tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)
        print("Model loaded!")

    def load_context(self, doc_path):
        """Read a document from disk and make it the active retrieval source.

        Args:
            doc_path: Object whose ``name`` attribute is the file path. Paths
                ending in ``.pdf`` (any letter case) are parsed with
                :meth:`read_pdf`; everything else is read as UTF-8 text.

        Returns:
            A status string naming the file that is now in use.
        """
        print('Loading file:', doc_path.name)
        if doc_path.name.lower().endswith('.pdf'):
            context = self.read_pdf(doc_path.name)
        else:
            # Context manager so the handle is closed deterministically.
            with open(doc_path.name, encoding='utf-8') as handle:
                context = handle.read()

        self.retrieval = Retrieval(docs=context)
        self.choices = self.retrieval.k
        self.use_context = True

        return f"Using file from {doc_path.name}"

    def API_call(self, prompt, max_retries=None, retry_delay=3):
        """Query the hosted inference endpoint, retrying until it answers.

        A list response is treated as success; any other JSON payload is
        treated as "not ready yet" and the request is retried after a pause.

        Args:
            prompt: Text sent as the ``inputs`` field of the request.
            max_retries: Maximum number of retries after the first attempt.
                ``None`` (the default) retries indefinitely.
            retry_delay: Seconds to sleep between attempts.

        Returns:
            The ``generated_text`` of the first item in the response list.

        Raises:
            RuntimeError: If ``max_retries`` is set and is exhausted.
        """
        failures = 0
        # Iterative retry: a long outage must not end in RecursionError.
        while True:
            response = requests.post(API_URL, headers=headers, json={"inputs": prompt}).json()
            if isinstance(response, list):
                return response[0]['generated_text']

            failures += 1
            if max_retries is not None and failures > max_retries:
                raise RuntimeError(
                    f"Inference API returned no result after {failures} attempt(s): {response!r}"
                )
            time.sleep(retry_delay)

    def asking(self, question):
        """Answer ``question`` using the active retrieval source.

        One prompt is built per retrieved passage and the model generates an
        answer for each. Every candidate gets a ``'score'`` (rounded percentage
        of the highest softmax probability at the first generation step) and an
        ``'answer'``; candidates are ranked by ``score**2 * score_bm`` and the
        ranked list is stored in ``self.temp``.

        NOTE(review): assumes the retriever returns at least one passage, each
        a dict with ``'context'`` and ``'score_bm'`` keys - confirm in Retrieval.
        The passage dicts are updated in place.

        Args:
            question: The user's question.

        Returns:
            The best answer when its score is above 60, otherwise the best
            answer wrapped in a message saying the agent is unsure.
        """
        timestamp = datetime.now().strftime("[%Y-%m-%d %H:%M:%S]")
        print(timestamp, end=' ')

        s_query = time.time()
        if self.use_context:
            print("Answering with your context:", question)
            contexts = self.retrieval.get_context(question)
        else:
            print("Answering without your context:", question)
            contexts = self.corpus.get_context(question)

        prompts = [
            f"Trả lời câu hỏi: {question} Trong nội dung: {context['context']}"
            for context in contexts
        ]

        s_token = time.time()
        tokens = self.tokenizer(
            prompts,
            max_length=self.args.seq_len,
            truncation=True,
            padding='max_length',
            return_tensors='pt',
        )

        s_gen = time.time()
        outputs = self.model.generate(
            input_ids=tokens.input_ids.to(self.args.device),
            attention_mask=tokens.attention_mask.to(self.args.device),
            max_new_tokens=self.args.out_len,
            output_scores=True,
            return_dict_in_generate=True,
        )

        s_de = time.time()

        # Confidence per candidate: top softmax probability of the first
        # generation step, expressed as a percentage.
        confidences = self.softmax(outputs.scores[0]).max(dim=1).values * 100

        # The retriever may return fewer passages than `choices`; never index
        # past what was actually generated.
        n_candidates = min(self.choices, len(prompts))
        results = []
        for i in range(n_candidates):
            result = contexts[i]
            result['score'] = round(confidences[i].item())
            result['answer'] = self.tokenizer.decode(outputs.sequences[i], skip_special_tokens=True)
            results.append(result)

        results.sort(
            key=lambda record: record['score']**2 * record['score_bm'],
            reverse=True,
        )

        self.temp = results
        t_mess = "t_query: {:.2f}\t t_token: {:.2f}\t t_gen: {:.2f}\t t_decode: {:.2f}\t".format(
            s_token-s_query, s_gen-s_token, s_de-s_gen, time.time()-s_de
        )
        print(t_mess, len(self.temp))
        if results[0]['score'] > 60:
            return results[0]['answer']
        else:
            return f"Tôi không chắc nhưng câu trả lời có thể là: {results[0]['answer']}\nBạn có thể tham khảo các câu trả lời bên cạnh!"

    def get_context(self, context):
        """Make ``context`` the active retrieval source and return it unchanged.

        Stores the text in ``self.context_value``, builds a new ``Retrieval``
        over it, and adopts that retriever's ``k`` as ``self.choices``.
        """
        self.context_value = context

        self.retrieval = Retrieval(k=self.choices, docs=context)
        self.choices = self.retrieval.k
        self.use_context = True
        return context

    def load_context_file(self, file):
        """Read a UTF-8 text file into ``self.context_value`` and return its text.

        Args:
            file: Object whose ``name`` attribute is the file path.

        Returns:
            The full file content. The retriever is not rebuilt here.
        """
        print('Loading file:', file.name)
        # Single read inside a context manager: closes the handle and avoids
        # line-by-line string concatenation.
        with open(file.name, 'r', encoding='utf8') as handle:
            text = handle.read()

        self.context_value = text
        return text

    def clear_context(self):
        """Drop the user context and fall back to the default corpus.

        Returns:
            An empty string.
        """
        self.context_value = ""
        self.use_context = False
        self.choices = self.args.choices
        return ""

    def replace(self, text):
        """Apply every ``(old, new)`` pair of ``self.replace_list`` to ``text`` in order."""
        for key, value in self.replace_list:
            text = text.replace(key, value)
        return text

    def read_pdf(self, file_path):
        """Extract the text of a PDF, applying :meth:`replace` to each chunk.

        Args:
            file_path: Path of the PDF file.

        Returns:
            The processed chunks produced by ``PyPDFLoader.load_and_split``,
            concatenated without a separator.
        """
        loader = PyPDFLoader(file_path)
        pages = loader.load_and_split()
        return ''.join(self.replace(page.page_content) for page in pages)