Update app.py
app.py CHANGED
@@ -99,14 +99,13 @@ class Retriever:
 
         return retrieved_texts
 
-
 class RAG:
     def __init__(self,
                  file_path,
                  device,
                  context_model_name="facebook/dpr-ctx_encoder-multiset-base",
                  question_model_name="facebook/dpr-question_encoder-multiset-base",
-                 generator_name="
+                 generator_name="facebook/bart-large"):
 
         # generator_name = "valhalla/bart-large-finetuned-squadv1"
         # generator_name = "'vblagoje/bart_lfqa'"
@@ -122,22 +121,24 @@ class RAG:
 
 
     def abstractive_query(self, question):
-        self.generator_tokenizer = BartTokenizer.from_pretrained(
-        self.generator_model = BartForConditionalGeneration.from_pretrained(
+        self.generator_tokenizer = BartTokenizer.from_pretrained(generator_name)
+        self.generator_model = BartForConditionalGeneration.from_pretrained(generator_name).to(device)
         context = self.retriever.retrieve_top_k(question, k=5)
+        # input_text = question + " " + " ".join(context)
 
         input_text = "answer: " + " ".join(context) + " " + question
 
-        inputs = self.generator_tokenizer.encode(input_text, return_tensors='pt', max_length=
-        outputs = self.generator_model.generate(inputs, max_length=
+        inputs = self.generator_tokenizer.encode(input_text, return_tensors='pt', max_length=500, truncation=True).to(device)
+        outputs = self.generator_model.generate(inputs, max_length=150, min_length=2, length_penalty=2.0, num_beams=4, early_stopping=True)
 
         answer = self.generator_tokenizer.decode(outputs[0], skip_special_tokens=True)
         return answer
 
     def extractive_query(self, question):
-        context = self.retriever.retrieve_top_k(question, k=
+        context = self.retriever.retrieve_top_k(question, k=15)
+
 
-        inputs = self.generator_tokenizer(question, ". ".join(context), return_tensors="pt", truncation=True, max_length=
+        inputs = self.generator_tokenizer(question, ". ".join(context), return_tensors="pt", truncation=True, max_length=300, padding="max_length")
         with torch.no_grad():
             model_inputs = inputs.to(device)
             outputs = self.generator_model(**model_inputs)