add support for query to be a string or an array of sentences that need encoding
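
In short: `query` previously had to be a single string; it may now also be a pre-built list of sentence dicts, which is passed straight through to `encode_sentences`. A minimal usage sketch — it assumes the handler exposes the usual `__call__(self, data)` entry point (consistent with `data.pop(...)` in the diff) and that `query` is a top-level key of the payload; neither is shown in full here:

from handler import EndpointHandler

handler = EndpointHandler()

# Before this commit: `query` had to be a single string.
result = handler({"query": "what model encodes the transcripts?"})

# After this commit: an array of sentences is accepted too and passed
# straight through to encode_sentences. The element shape mirrors the
# dict that the string case is wrapped into.
result = handler({"query": [
    {"text": "what model encodes the transcripts?", "id": "q1"},
    {"text": "which whisper checkpoint is loaded?", "id": "q2"},
]})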
handler.py  +9 -7
@@ -9,23 +9,24 @@ import time
 
 
 class EndpointHandler():
+    # load the model
+    WHISPER_MODEL_NAME = "tiny.en"
+    SENTENCE_TRANSFORMER_MODEL_NAME = "multi-qa-mpnet-base-dot-v1"
+
     def __init__(self, path=""):
-        # load the model
-        WHISPER_MODEL_NAME = "tiny.en"
-        SENTENCE_TRANSFORMER_MODEL_NAME = "multi-qa-mpnet-base-dot-v1"
 
         device = "cuda" if torch.cuda.is_available() else "cpu"
         print(f'whisper will use: {device}')
 
         t0 = time.time()
-        self.whisper_model = whisper.load_model(WHISPER_MODEL_NAME).to(device)
+        self.whisper_model = whisper.load_model(self.WHISPER_MODEL_NAME).to(device)
         t1 = time.time()
 
         total = t1 - t0
         print(f'Finished loading whisper_model in {total} seconds')
 
         t0 = time.time()
-        self.sentence_transformer_model = SentenceTransformer(SENTENCE_TRANSFORMER_MODEL_NAME)
+        self.sentence_transformer_model = SentenceTransformer(self.SENTENCE_TRANSFORMER_MODEL_NAME)
         t1 = time.time()
 
         total = t1 - t0
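
The model names move from `__init__` locals to class attributes, so every method can reach them through `self`; the `transcription_source` tag in the next hunk depends on that. A stripped-down sketch of the attribute-lookup pattern (illustrative class, not the real handler):

class Sketch:
    # Class attribute: defined once on the class, readable from any method.
    WHISPER_MODEL_NAME = "tiny.en"

    def source_tag(self) -> str:
        # `self.WHISPER_MODEL_NAME` falls back to the class attribute
        # because no instance attribute shadows it.
        return f"whisper_{self.WHISPER_MODEL_NAME}"

print(Sketch().source_tag())  # -> whisper_tiny.en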
@@ -50,6 +51,7 @@ class EndpointHandler():
         encoded_segments = {}
         if video_url:
             video_with_transcript = self.transcribe_video(video_url)
+            video_with_transcript['transcript']['transcription_source'] = f"whisper_{self.WHISPER_MODEL_NAME}"
             encode_transcript = data.pop("encode_transcript", True)
             if encode_transcript:
                 encoded_segments = self.combine_transcripts(video_with_transcript)
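
The added line records which model produced the transcript. Assuming `transcribe_video` (not part of this diff) returns a dict with a nested 'transcript' dict, the tagged result would look roughly like the sketch below; every key except 'transcript' and 'transcription_source' is an illustrative guess:

video_with_transcript = {
    "video_url": "https://example.com/talk.mp4",  # hypothetical key
    "transcript": {
        "text": "...full transcription...",       # hypothetical key
        "segments": [{"start": 0.0, "end": 4.2, "text": "..."}],  # hypothetical
        "transcription_source": "whisper_tiny.en",  # written by this hunk
    },
}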
@@ -61,7 +63,7 @@
                 **encoded_segments
             }
         elif query:
-            query = [{"text": query, "id": ""}]
+            query = [{"text": query, "id": ""}] if isinstance(query, str) else query
             encoded_segments = self.encode_sentences(query)
 
         return {
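
This one-liner implements the commit title: a bare string is wrapped into the one-element list shape that `encode_sentences` already consumes, while a caller-supplied list passes through untouched. The same normalization in isolation (the list-element shape is inferred from the string branch, since `encode_sentences` is not shown):

def normalize_query(query):
    # Wrap a single string; leave a pre-built list of sentence dicts alone.
    return [{"text": query, "id": ""}] if isinstance(query, str) else query

assert normalize_query("hello") == [{"text": "hello", "id": ""}]
assert normalize_query([{"text": "hi", "id": "1"}]) == [{"text": "hi", "id": "1"}]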
@@ -131,7 +133,7 @@
            batch_details = [
                {
                    **batch_meta[x],
-                    'vectors':batch_vectors[x]
+                    'vectors': batch_vectors[x]
                } for x in range(0, len(batch_meta))
            ]
            all_batches.extend(batch_details)
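
The final hunk is cosmetic (a space after the colon), but the comprehension it touches is the step that pairs each segment's metadata with its embedding. An equivalent `zip` formulation, assuming `batch_meta` and `batch_vectors` are parallel sequences of equal length:

batch_details = [
    {**meta, "vectors": vec}  # merge the metadata dict with its vector
    for meta, vec in zip(batch_meta, batch_vectors)
]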