add generate_answer for long form question answering

see: https://github.com/atilatech/atila-core-service/pull/7
![haystack-lfqa-1.png](https://s3.amazonaws.com/moonup/production/uploads/1674311706276-63a4969d658851481f7729dd.png)
handler.py (+55 -4)
@@ -3,6 +3,7 @@ from typing import Dict
 from sentence_transformers import SentenceTransformer
 from tqdm import tqdm
 import whisper
+from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
 import torch
 import pytube
 import time
@@ -12,11 +13,13 @@ class EndpointHandler():
     # load the model
     WHISPER_MODEL_NAME = "tiny.en"
     SENTENCE_TRANSFORMER_MODEL_NAME = "multi-qa-mpnet-base-dot-v1"
+    QUESTION_ANSWER_MODEL_NAME = "vblagoje/bart_lfqa"
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
     def __init__(self, path=""):
 
         device = "cuda" if torch.cuda.is_available() else "cpu"
-        print(f'whisper will use: {device}')
+        print(f'whisper and question_answer_model will use: {device}')
 
         t0 = time.time()
         self.whisper_model = whisper.load_model(self.WHISPER_MODEL_NAME).to(device)
@@ -31,6 +34,13 @@ class EndpointHandler():
 
         total = t1 - t0
         print(f'Finished loading sentence_transformer_model in {total} seconds')
+
+        self.question_answer_tokenizer = AutoTokenizer.from_pretrained(self.QUESTION_ANSWER_MODEL_NAME)
+        t0 = time.time()
+        self.question_answer_model = AutoModelForSeq2SeqLM.from_pretrained(self.QUESTION_ANSWER_MODEL_NAME).to(device)
+        t1 = time.time()
+        total = t1 - t0
+        print(f'Finished loading question_answer_model in {total} seconds')
 
     def __call__(self, data: Dict[str, str]) -> Dict:
         """
@@ -48,6 +58,7 @@
             f" See: https://huggingface.co/docs/inference-endpoints/guides/custom_handler#2-create-endpointhandler-cp")
         video_url = data.pop("video_url", None)
         query = data.pop("query", None)
+        long_form_answer = data.pop("long_form_answer", None)
         encoded_segments = {}
         if video_url:
             video_with_transcript = self.transcribe_video(video_url)
@@ -63,11 +74,27 @@
                 **encoded_segments
             }
         elif query:
-            query = [{"text": query, "id": ""}] if isinstance(query, str) else query
-            encoded_segments = self.encode_sentences(query)
+            if long_form_answer:
+                context = data.pop("context", None)
+                answer = self.generate_answer(query, context)
+                response = {
+                    "answer": answer
+                }
+
+                return response
+            else:
+                query = [{"text": query, "id": ""}] if isinstance(query, str) else query
+                encoded_segments = self.encode_sentences(query)
 
+                response = {
+                    "encoded_segments": encoded_segments
+                }
+
+                return response
+
+        else:
             return {
-                "encoded_segments": encoded_segments
+                "error": "'video_url' or 'query' must be provided"
             }
 
     def transcribe_video(self, video_url):
@@ -140,6 +167,30 @@ class EndpointHandler():
 
         return all_batches
 
+    def generate_answer(self, query, documents):
+
+        # concatenate question and support documents into BART input
+        conditioned_doc = "<P> " + " <P> ".join([d for d in documents])
+        query_and_docs = "question: {} context: {}".format(query, conditioned_doc)
+
+        model_input = self.question_answer_tokenizer(query_and_docs, truncation=False, padding=True, return_tensors="pt")
+
+        generated_answers_encoded = self.question_answer_model.generate(input_ids=model_input["input_ids"].to(self.device),
+                                                                        attention_mask=model_input["attention_mask"].to(self.device),
+                                                                        min_length=64,
+                                                                        max_length=256,
+                                                                        do_sample=False,
+                                                                        early_stopping=True,
+                                                                        num_beams=8,
+                                                                        temperature=1.0,
+                                                                        top_k=None,
+                                                                        top_p=None,
+                                                                        eos_token_id=self.question_answer_tokenizer.eos_token_id,
+                                                                        no_repeat_ngram_size=3,
+                                                                        num_return_sequences=1)
+        answer = self.question_answer_tokenizer.batch_decode(generated_answers_encoded, skip_special_tokens=True, clean_up_tokenization_spaces=True)
+        return answer
+
     @staticmethod
     def combine_transcripts(video, window=6, stride=3):
         """