Spaces:
Runtime error
Runtime error
File size: 4,956 Bytes
e011405 d26e120 e011405 d26e120 e011405 d26e120 e011405 d26e120 e011405 d26e120 e011405 d26e120 e011405 d26e120 e011405 d26e120 e011405 d26e120 e011405 d26e120 e011405 d26e120 e011405 d26e120 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 |
from langchain.document_loaders.unstructured import UnstructuredFileLoader
from langchain.text_splitter import CharacterTextSplitter
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import Chroma
from langchain.chains import RetrievalQA
from langchain.chat_models import ChatOpenAI
from langchain.schema import AIMessage, HumanMessage, SystemMessage, Document
from langchain.document_loaders import PyPDFLoader
from transformers import AutoTokenizer, T5ForConditionalGeneration
from retrieval.retrieval import Retrieval, BM25
import os, time, torch
from torch.nn import Softmax
class Agent:
    """Question-answering agent built on a retriever and a T5 model.

    For each question the agent fetches candidate contexts (from the default
    corpus or from a user-supplied document), prompts the model once per
    context, scores every generated answer and returns the best one.
    """

    def __init__(self, args=None) -> None:
        """Load the retriever, the replacement table, the model and tokenizer.

        Args:
            args: Configuration object. Despite the ``None`` default it is
                required: the attributes ``choices``, ``model``, ``tokenizer``
                and ``device`` are read here, and ``seq_len`` / ``out_len``
                are read later by :meth:`asking`.
        """
        self.args = args
        self.choices = args.choices
        self.corpus = Retrieval(k=args.choices)
        self.context_value = ""
        self.use_context = False
        self.softmax = Softmax(dim=1)
        # Scored results of the most recent `asking` call.
        self.temp = []
        # NOTE: depending on the torch version, torch.load may unpickle
        # arbitrary objects; only ship a trusted replace.pt.
        self.replace_list = torch.load('retrieval/replace.pt')
        print("Model is loading...")
        self.model = T5ForConditionalGeneration.from_pretrained(args.model).to(args.device)
        self.tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)
        print("Model loaded!")

    def load_context(self, doc_path) -> str:
        """Index an uploaded file and use it as context for later questions.

        Args:
            doc_path: Object whose ``name`` attribute is the path of the file.
                Files ending in ``.pdf`` (any letter case) go through
                :meth:`read_pdf`; anything else is read as UTF-8 text.

        Returns:
            A status message naming the file that is now in use.
        """
        print('Loading file:', doc_path.name)
        if doc_path.name.lower().endswith('.pdf'):
            context = self.read_pdf(doc_path.name)
        else:
            # Context manager so the handle is closed even if reading fails.
            with open(doc_path.name, encoding='utf-8') as handle:
                context = handle.read()
        # NOTE(review): unlike get_context, no `k` is passed here, so the
        # retriever's own default applies - confirm this is intended.
        self.retrieval = Retrieval(docs=context)
        self.choices = self.retrieval.k
        self.use_context = True
        return f"Using file from {doc_path.name}"

    def asking(self, question) -> str:
        """Answer ``question`` using the active retriever.

        One prompt is built per retrieved context and all prompts are
        generated in a single batch. Each answer is scored with the softmax
        probability (in percent) of the first generated token; the answers
        are then ranked by ``score**2 * score_bm``. The ranked records are
        kept in ``self.temp``.

        Args:
            question: The user's question.

        Returns:
            The top-ranked answer, prefixed with a hedging sentence when its
            score is 50 or lower, or a fallback message when no context could
            be retrieved.
        """
        s_query = time.time()
        if self.use_context:
            print("Answering with your context:", question)
            contexts = self.retrieval.get_context(question)
        else:
            print("Answering without your context:", question)
            contexts = self.corpus.get_context(question)

        # The retriever may return fewer contexts than `choices` (or none at
        # all); never index past what it actually returned.
        limit = min(self.choices, len(contexts))
        if limit <= 0:
            self.temp = []
            return "Tôi không tìm thấy nội dung phù hợp để trả lời câu hỏi này."

        prompts = [
            f"Trả lời câu hỏi: {question} Trong nội dung: {context['context']}"
            for context in contexts
        ]

        s_token = time.time()
        tokens = self.tokenizer(
            prompts,
            max_length=self.args.seq_len,
            truncation=True,
            padding='max_length',
            return_tensors='pt',
        )

        s_gen = time.time()
        outputs = self.model.generate(
            input_ids=tokens.input_ids.to(self.args.device),
            attention_mask=tokens.attention_mask.to(self.args.device),
            max_new_tokens=self.args.out_len,
            output_scores=True,
            return_dict_in_generate=True,
        )

        s_de = time.time()
        # outputs.scores[0] holds the first generation step; the highest
        # softmax probability per prompt, in percent, is the confidence.
        confidences = self.softmax(outputs.scores[0]).max(dim=1).values * 100

        results = []
        for index, result in zip(range(limit), contexts):
            # The retrieved record itself is annotated, as callers reading
            # self.temp expect the 'score' and 'answer' keys on it.
            result['score'] = round(confidences[index].item())
            result['answer'] = self.tokenizer.decode(
                outputs.sequences[index], skip_special_tokens=True
            )
            results.append(result)

        def get_score(record):
            """Combine model confidence (squared) with the 'score_bm' value."""
            return record['score']**2 * record['score_bm']

        results.sort(key=get_score, reverse=True)
        self.temp = results

        t_mess = "t_query: {:.2f}\t t_token: {:.2f}\t t_gen: {:.2f}\t t_decode: {:.2f}\t".format(
            s_token-s_query, s_gen-s_token, s_de-s_gen, time.time()-s_de
        )
        print(t_mess, len(self.temp))

        best = results[0]
        if best['score'] > 50:
            return best['answer']
        return f"Tôi không chắc nhưng câu trả lời có thể là: {best['answer']}\nBạn có thể tham khảo các câu trả lời bên cạnh!"

    def get_context(self, context) -> str:
        """Use the given text as the context for later questions.

        Args:
            context: Raw text to index.

        Returns:
            The same ``context`` value, unchanged.
        """
        self.context_value = context
        self.retrieval = Retrieval(k=self.choices, docs=context)
        self.choices = self.retrieval.k
        self.use_context = True
        return context

    def load_context_file(self, file) -> str:
        """Read a UTF-8 text file and remember its content.

        Only ``self.context_value`` is updated; the retriever is not rebuilt.

        Args:
            file: Object whose ``name`` attribute is the path of the file.

        Returns:
            The full text of the file.
        """
        print('Loading file:', file.name)
        # Read in one go and close the handle deterministically.
        with open(file.name, 'r', encoding='utf8') as handle:
            text = handle.read()
        self.context_value = text
        return text

    def clear_context(self) -> str:
        """Drop the user context and fall back to the default corpus.

        Returns:
            An empty string.
        """
        self.context_value = ""
        self.use_context = False
        self.choices = self.args.choices
        return ""

    def replace(self, text) -> str:
        """Apply every (old, new) pair of ``self.replace_list`` to ``text``.

        Pairs are applied in order, so a later pair sees the output of the
        earlier ones.

        Args:
            text: String to clean up.

        Returns:
            The text after all replacements.
        """
        for key, value in self.replace_list:
            text = text.replace(key, value)
        return text

    def read_pdf(self, file_path) -> str:
        """Extract the text of a PDF and run it through :meth:`replace`.

        Args:
            file_path: Path of the PDF file.

        Returns:
            The cleaned text of all loaded pages, concatenated with no
            separator.
        """
        loader = PyPDFLoader(file_path)
        pages = loader.load_and_split()
        return ''.join(self.replace(page.page_content) for page in pages)
|