rtabrizi committed
Commit 2042d5a
1 Parent(s): cbd01e9

Update app.py

Files changed (1)
  1. app.py +9 -8
app.py CHANGED
```diff
@@ -99,14 +99,13 @@ class Retriever:
 
         return retrieved_texts
 
-
 class RAG:
     def __init__(self,
                  file_path,
                  device,
                  context_model_name="facebook/dpr-ctx_encoder-multiset-base",
                  question_model_name="facebook/dpr-question_encoder-multiset-base",
-                 generator_name="valhalla/bart-large-finetuned-squadv1"):
+                 generator_name="facebook/bart-large"):
 
         # generator_name = "valhalla/bart-large-finetuned-squadv1"
         # generator_name = "'vblagoje/bart_lfqa'"
@@ -122,22 +121,24 @@ class RAG:
 
 
     def abstractive_query(self, question):
-        self.generator_tokenizer = BartTokenizer.from_pretrained(self.generator_name)
-        self.generator_model = BartForConditionalGeneration.from_pretrained(self.generator_name).to(device)
+        self.generator_tokenizer = BartTokenizer.from_pretrained(generator_name)
+        self.generator_model = BartForConditionalGeneration.from_pretrained(generator_name).to(device)
         context = self.retriever.retrieve_top_k(question, k=5)
+        # input_text = question + " " + " ".join(context)
 
         input_text = "answer: " + " ".join(context) + " " + question
 
-        inputs = self.generator_tokenizer.encode(input_text, return_tensors='pt', max_length=300, truncation=True).to(device)
-        outputs = self.generator_model.generate(inputs, max_length=300, min_length=2, length_penalty=2.0, num_beams=4, early_stopping=True)
+        inputs = self.generator_tokenizer.encode(input_text, return_tensors='pt', max_length=500, truncation=True).to(device)
+        outputs = self.generator_model.generate(inputs, max_length=150, min_length=2, length_penalty=2.0, num_beams=4, early_stopping=True)
 
         answer = self.generator_tokenizer.decode(outputs[0], skip_special_tokens=True)
         return answer
 
     def extractive_query(self, question):
-        context = self.retriever.retrieve_top_k(question, k=7)
+        context = self.retriever.retrieve_top_k(question, k=15)
+
 
-        inputs = self.generator_tokenizer(question, ". ".join(context), return_tensors="pt", truncation=True, max_length=150, padding="max_length")
+        inputs = self.generator_tokenizer(question, ". ".join(context), return_tensors="pt", truncation=True, max_length=300, padding="max_length")
         with torch.no_grad():
             model_inputs = inputs.to(device)
             outputs = self.generator_model(**model_inputs)
```
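The first hunk ends inside `Retriever`, whose `retrieve_top_k` both query methods call; the commit never shows that method. As orientation only, here is a minimal sketch of DPR-style top-k retrieval consistent with the encoder checkpoints named in `__init__`. The function name and the question-encoder checkpoint come from the diff; the precomputed `context_embeddings`/`texts` pair and everything else are assumptions.

```python
import torch
from transformers import DPRQuestionEncoder, DPRQuestionEncoderTokenizer

# Sketch of DPR-style top-k retrieval; NOT the code in app.py.
# `context_embeddings` (N x 768) and `texts` (N strings) are assumed
# to be precomputed from the corpus at file_path.
def retrieve_top_k(question, context_embeddings, texts, k=5, device="cpu"):
    name = "facebook/dpr-question_encoder-multiset-base"
    tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(name)
    encoder = DPRQuestionEncoder.from_pretrained(name).to(device).eval()
    inputs = tokenizer(question, return_tensors="pt").to(device)
    with torch.no_grad():
        q = encoder(**inputs).pooler_output    # (1, 768) question embedding
    scores = q @ context_embeddings.T          # inner-product relevance, (1, N)
    top = torch.topk(scores, k=k, dim=-1).indices[0]
    return [texts[i] for i in top]
```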
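For the abstractive path, the commit swaps the default generator from the SQuAD-finetuned valhalla/bart-large-finetuned-squadv1 to plain facebook/bart-large, widens the encoder input budget from 300 to 500 tokens, and caps generation at 150 tokens instead of 300 (note the new `from_pretrained` calls reference `generator_name` without `self.`, so they rely on that name being in scope when the method runs). A hedged usage sketch of the class after this commit; only the constructor keywords and method name come from the diff, and the corpus path and question are placeholders:

```python
import torch

# Hypothetical driver for the RAG class in app.py.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
rag = RAG(file_path="corpus.txt", device=device)  # placeholder corpus path

# After this commit: up to 500 input tokens, at most 150 generated
# tokens, 4-beam search with early stopping.
print(rag.abstractive_query("What does dense passage retrieval do?"))
```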
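The second hunk stops right after the forward pass in `extractive_query`, so the span-decoding step is not visible in this diff. If the underlying model is a standard Hugging Face extractive-QA head, the continuation would look roughly like this sketch (an assumption, not the file's actual code):

```python
import torch

# Sketch: turning extractive-QA logits into an answer string.
# Assumes `outputs` carries start_logits/end_logits, as
# *ForQuestionAnswering models return; the diff does not show this.
def decode_span(outputs, model_inputs, tokenizer):
    start = torch.argmax(outputs.start_logits, dim=-1).item()
    end = torch.argmax(outputs.end_logits, dim=-1).item()
    span_ids = model_inputs["input_ids"][0][start : end + 1]
    return tokenizer.decode(span_ids, skip_special_tokens=True)
```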