from langchain.document_loaders.unstructured import UnstructuredFileLoader
from langchain.text_splitter import CharacterTextSplitter
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import Chroma
from langchain.chains import RetrievalQA
from langchain.chat_models import ChatOpenAI
from langchain.schema import AIMessage, HumanMessage, SystemMessage, Document
from langchain.document_loaders import PyPDFLoader
from transformers import AutoTokenizer, T5ForConditionalGeneration
from retrieval.retrieval import Retrieval, BM25
import os
import time
import torch
from torch.nn import Softmax


class Agent:
    """Question-answering agent: retrieve passages, then answer each with a T5 model.

    Questions are answered against a default retriever (``self.corpus``) or,
    after ``load_context`` / ``get_context``, against user-supplied text. One
    prompt is built per retrieved passage; the generated answers are ranked by
    model confidence combined with the retriever's ``score_bm``.
    """

    # Answers with a confidence (0-100) strictly above this are returned
    # without the "not sure" caveat.
    CONFIDENCE_THRESHOLD = 60

    def __init__(self, args=None) -> None:
        """Load the text-replacement table, the T5 model and its tokenizer.

        Args:
            args: Configuration namespace. Despite the ``None`` default it is
                required: ``choices``, ``model``, ``tokenizer`` and ``device``
                are read here, and ``seq_len`` / ``out_len`` in ``asking``.
        """
        self.args = args
        # Number of candidate passages (and therefore answers) per question.
        self.choices = args.choices
        # Retriever built without user docs; used while use_context is False.
        self.corpus = Retrieval(k=args.choices)
        self.context_value = ""
        self.use_context = False
        self.softmax = Softmax(dim=1)
        # Ranked results of the most recent ``asking`` call.
        self.temp = []
        # NOTE(review): torch.load unpickles this file -- only ever point it at
        # a trusted file. The path is relative to the current working directory.
        self.replace_list = torch.load('retrieval/replace.pt')
        print("Model is loading...")
        self.model = T5ForConditionalGeneration.from_pretrained(args.model).to(args.device)
        self.tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)
        print("Model loaded!")

    def load_context(self, doc_path):
        """Index an uploaded file (PDF or UTF-8 text) as the answering context.

        Args:
            doc_path: Object whose ``name`` attribute is the file path
                (presumably an upload handle from the UI -- confirm).

        Returns:
            A status message naming the loaded file.
        """
        file_name = doc_path.name
        print('Loading file:', file_name)
        # Case-insensitive so ".PDF" files are not decoded as UTF-8 text.
        if file_name.lower().endswith('.pdf'):
            context = self.read_pdf(file_name)
        else:
            with open(file_name, encoding='utf-8') as f:
                context = f.read()
        self.retrieval = Retrieval(docs=context)
        self.choices = self.retrieval.k
        self.use_context = True
        return f"Using file from {file_name}"

    def asking(self, question):
        """Answer ``question`` from the loaded context or the default corpus.

        Every retrieved passage (up to ``self.choices``) is turned into a prompt.
        Each generated answer is scored by the model's confidence in its first
        token. The candidates are then ranked and stored in ``self.temp``.

        Returns:
            The best-ranked answer. When its confidence is not above
            ``CONFIDENCE_THRESHOLD``, it is wrapped in a "not sure" caveat.
            A fallback message is returned when no passage was retrieved.
        """
        s_query = time.time()
        if self.use_context:
            print("Answering with your context:", question)
            contexts = self.retrieval.get_context(question)
        else:
            print("Answering without your context:", question)
            contexts = self.corpus.get_context(question)

        if not contexts:
            # Nothing to rank: ``results[0]`` below would raise IndexError.
            self.temp = []
            return "Tôi không tìm thấy nội dung phù hợp để trả lời câu hỏi này."

        prompts = [
            f"Trả lời câu hỏi: {question} Trong nội dung: {context['context']}"
            for context in contexts
        ]

        s_token = time.time()
        tokens = self.tokenizer(prompts, max_length=self.args.seq_len, truncation=True,
                                padding='max_length', return_tensors='pt')
        s_gen = time.time()
        outputs = self.model.generate(
            input_ids=tokens.input_ids.to(self.args.device),
            attention_mask=tokens.attention_mask.to(self.args.device),
            max_new_tokens=self.args.out_len,
            output_scores=True,
            return_dict_in_generate=True,
        )
        s_de = time.time()

        # Confidence per prompt = highest softmax probability over the
        # vocabulary for the first generated token, scaled to 0-100.
        # (Assumes greedy decoding, i.e. one score row per prompt -- confirm.)
        scores = self.softmax(outputs.scores[0]).max(dim=1).values * 100

        results = []
        # zip stops at the shortest input, so fewer retrieved passages than
        # ``self.choices`` no longer raises IndexError.
        for context, score, sequence in zip(contexts[:self.choices], scores, outputs.sequences):
            result = dict(context)  # copy: do not mutate the retriever's records
            result['score'] = round(score.item())
            result['answer'] = self.tokenizer.decode(sequence, skip_special_tokens=True)
            results.append(result)

        # Rank by confidence squared times the retrieval score
        # (assumes each context record carries 'score_bm' from the retriever).
        results.sort(key=lambda record: record['score'] ** 2 * record['score_bm'], reverse=True)
        self.temp = results

        t_mess = "t_query: {:.2f}\t t_token: {:.2f}\t t_gen: {:.2f}\t t_decode: {:.2f}\t".format(
            s_token - s_query, s_gen - s_token, s_de - s_gen, time.time() - s_de
        )
        print(t_mess, len(self.temp))

        best = results[0]
        if best['score'] > self.CONFIDENCE_THRESHOLD:
            return best['answer']
        return (f"Tôi không chắc nhưng câu trả lời có thể là: {best['answer']}\n"
                "Bạn có thể tham khảo các câu trả lời bên cạnh!")

    def get_context(self, context):
        """Index raw ``context`` text as the answering context and return it unchanged."""
        self.context_value = context
        self.retrieval = Retrieval(k=self.choices, docs=context)
        self.choices = self.retrieval.k
        self.use_context = True
        return context

    def load_context_file(self, file):
        """Read a UTF-8 text file into ``self.context_value`` and return its text.

        This only stores the text; it does not build a retriever or switch
        ``use_context`` on.

        Args:
            file: Object whose ``name`` attribute is the file path.
        """
        print('Loading file:', file.name)
        with open(file.name, 'r', encoding='utf8') as f:
            text = f.read()
        self.context_value = text
        return text

    def clear_context(self):
        """Drop any user context and fall back to the default corpus; return ""."""
        self.context_value = ""
        self.use_context = False
        self.choices = self.args.choices
        return ""

    def replace(self, text):
        """Apply each ``(old, new)`` pair of ``self.replace_list`` to ``text`` in order."""
        for key, value in self.replace_list:
            text = text.replace(key, value)
        return text

    def read_pdf(self, file_path):
        """Return the text of a PDF, with ``replace`` applied to each loaded chunk."""
        pages = PyPDFLoader(file_path).load_and_split()
        return ''.join(self.replace(page.page_content) for page in pages)