CreatorPhan committed on
Commit
d26e120
1 Parent(s): 16f230f

Update agent_t5.py

Browse files
Files changed (1) hide show
  1. agent_t5.py +60 -16
agent_t5.py CHANGED
@@ -5,21 +5,26 @@ from langchain.vectorstores import Chroma
5
  from langchain.chains import RetrievalQA
6
  from langchain.chat_models import ChatOpenAI
7
  from langchain.schema import AIMessage, HumanMessage, SystemMessage, Document
 
8
 
9
  from transformers import AutoTokenizer, T5ForConditionalGeneration
10
  from retrieval.retrieval import Retrieval, BM25
11
- import os, time
 
12
 
13
 
14
 
15
  class Agent:
16
  def __init__(self, args=None) -> None:
17
  self.args = args
18
- self.corpus = Retrieval()
19
  self.choices = args.choices
 
20
 
21
  self.context_value = ""
22
  self.use_context = False
 
 
 
23
 
24
  print("Model is loading...")
25
  self.model = T5ForConditionalGeneration.from_pretrained(args.model).to(args.device)
@@ -28,9 +33,12 @@ class Agent:
28
 
29
 
30
  def load_context(self, doc_path):
31
- loader = UnstructuredFileLoader(doc_path.name)
32
  print('Loading file:', doc_path.name)
33
- context = loader.load()[0].page_content
 
 
 
 
34
 
35
  self.retrieval = Retrieval(docs=context)
36
  self.choices = self.retrieval.k
@@ -42,10 +50,10 @@ class Agent:
42
  def asking(self, question):
43
  s_query = time.time()
44
  if self.use_context:
45
- print("Answering with your context")
46
  contexts = self.retrieval.get_context(question)
47
  else:
48
- print("Answering without your context")
49
  contexts = self.corpus.get_context(question)
50
 
51
  prompts = []
@@ -60,28 +68,48 @@ class Agent:
60
  outputs = self.model.generate(
61
  input_ids=tokens.input_ids.to(self.args.device),
62
  attention_mask=tokens.attention_mask.to(self.args.device),
63
- max_new_tokens=self.args.out_len
 
 
64
  )
65
 
66
- s_de = time.time()
67
- answers = []
68
- for output in outputs:
69
- sequence = self.tokenizer.decode(output, skip_special_tokens=True)
70
- answers.append(sequence)
71
 
72
- self.temp = [contexts, answers]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
  t_mess = "t_query: {:.2f}\t t_token: {:.2f}\t t_gen: {:.2f}\t t_decode: {:.2f}\t".format(
74
  s_token-s_query, s_gen-s_token, s_de-s_gen, time.time()-s_de
75
  )
76
- print(t_mess)
77
- return answers
 
 
 
78
 
79
 
80
 
81
  def get_context(self, context):
82
  self.context_value = context
83
 
84
- self.retrieval = Retrieval(docs=context)
85
  self.choices = self.retrieval.k
86
  self.use_context = True
87
  return context
@@ -100,3 +128,19 @@ class Agent:
100
  self.use_context = False
101
  self.choices = self.args.choices
102
  return ""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  from langchain.chains import RetrievalQA
6
  from langchain.chat_models import ChatOpenAI
7
  from langchain.schema import AIMessage, HumanMessage, SystemMessage, Document
8
+ from langchain.document_loaders import PyPDFLoader
9
 
10
  from transformers import AutoTokenizer, T5ForConditionalGeneration
11
  from retrieval.retrieval import Retrieval, BM25
12
+ import os, time, torch
13
+ from torch.nn import Softmax
14
 
15
 
16
 
17
  class Agent:
18
  def __init__(self, args=None) -> None:
19
  self.args = args
 
20
  self.choices = args.choices
21
+ self.corpus = Retrieval(k=args.choices)
22
 
23
  self.context_value = ""
24
  self.use_context = False
25
+ self.softmax = Softmax(dim=1)
26
+ self.temp = []
27
+ self.replace_list = torch.load('retrieval/replace.pt')
28
 
29
  print("Model is loading...")
30
  self.model = T5ForConditionalGeneration.from_pretrained(args.model).to(args.device)
 
33
 
34
 
35
  def load_context(self, doc_path):
 
36
  print('Loading file:', doc_path.name)
37
+ if doc_path.name[-4:] == '.pdf':
38
+ context = self.read_pdf(doc_path.name)
39
+ else:
40
+ # loader = UnstructuredFileLoader(doc_path.name)
41
+ context = open(doc_path.name, encoding='utf-8').read()
42
 
43
  self.retrieval = Retrieval(docs=context)
44
  self.choices = self.retrieval.k
 
50
  def asking(self, question):
51
  s_query = time.time()
52
  if self.use_context:
53
+ print("Answering with your context:", question)
54
  contexts = self.retrieval.get_context(question)
55
  else:
56
+ print("Answering without your context:", question)
57
  contexts = self.corpus.get_context(question)
58
 
59
  prompts = []
 
68
  outputs = self.model.generate(
69
  input_ids=tokens.input_ids.to(self.args.device),
70
  attention_mask=tokens.attention_mask.to(self.args.device),
71
+ max_new_tokens=self.args.out_len,
72
+ output_scores=True,
73
+ return_dict_in_generate=True
74
  )
75
 
 
 
 
 
 
76
 
77
+ s_de = time.time()
78
+ results = []
79
+
80
+ scores = self.softmax(outputs.scores[0])
81
+ scores = scores.max(dim=1).values*100
82
+ # print(scores)
83
+ for i in range(self.choices):
84
+ result = contexts[i]
85
+ score = round(scores[i].item())
86
+ result['score'] = score
87
+
88
+ answer = self.tokenizer.decode(outputs.sequences[i], skip_special_tokens=True)
89
+ result['answer'] = answer
90
+ results.append(result)
91
+
92
+ def get_score(record):
93
+ return record['score']**2 * record['score_bm']
94
+
95
+ results.sort(key=get_score, reverse=True)
96
+
97
+ self.temp = results
98
  t_mess = "t_query: {:.2f}\t t_token: {:.2f}\t t_gen: {:.2f}\t t_decode: {:.2f}\t".format(
99
  s_token-s_query, s_gen-s_token, s_de-s_gen, time.time()-s_de
100
  )
101
+ print(t_mess, len(self.temp))
102
+ if results[0]['score'] > 50:
103
+ return results[0]['answer']
104
+ else:
105
+ return f"Tôi không chắc nhưng câu trả lời có thể là: {results[0]['answer']}\nBạn có thể tham khảo các câu trả lời bên cạnh!"
106
 
107
 
108
 
109
  def get_context(self, context):
110
  self.context_value = context
111
 
112
+ self.retrieval = Retrieval(k=self.choices, docs=context)
113
  self.choices = self.retrieval.k
114
  self.use_context = True
115
  return context
 
128
  self.use_context = False
129
  self.choices = self.args.choices
130
  return ""
131
+
132
def replace(self, text):
    """Apply every substitution in ``self.replace_list`` to ``text``.

    Substitutions are applied one after another in iteration order, so a
    later pair operates on the output of the earlier ones.

    Args:
        text: String to rewrite.

    Returns:
        ``text`` with each ``key`` occurrence replaced by its ``value``.
    """
    substitutions = self.replace_list
    # Iterating a dict directly yields only its keys, which would then be
    # unpacked character-by-character; go through ``items()`` so a
    # mapping-shaped table behaves like a sequence of (key, value) pairs.
    if isinstance(substitutions, dict):
        substitutions = substitutions.items()
    for key, value in substitutions:
        text = text.replace(key, value)
    return text
136
+
137
def read_pdf(self, file_path):
    """Extract the text of a PDF, passing each chunk through ``self.replace``.

    Args:
        file_path: Path of the PDF, handed straight to ``PyPDFLoader``.

    Returns:
        The processed ``page_content`` of every chunk returned by
        ``load_and_split()``, concatenated in order with no separator.
    """
    loader = PyPDFLoader(file_path)
    # Join once instead of growing a string with ``+=`` for every chunk.
    return ''.join(
        self.replace(chunk.page_content) for chunk in loader.load_and_split()
    )
146
+