tomiwa1a committed
Commit 93a37ae
Parent: 81a0884

add support for query to be a string or an array of sentences that need encoding

Files changed (1): handler.py (+9 −7)
handler.py CHANGED
@@ -9,23 +9,24 @@ import time
 
 
 class EndpointHandler():
+    # load the model
+    WHISPER_MODEL_NAME = "tiny.en"
+    SENTENCE_TRANSFORMER_MODEL_NAME = "multi-qa-mpnet-base-dot-v1"
+
     def __init__(self, path=""):
-        # load the model
-        WHISPER_MODEL_NAME = "tiny.en"
-        SENTENCE_TRANSFORMER_MODEL_NAME = "multi-qa-mpnet-base-dot-v1"
 
         device = "cuda" if torch.cuda.is_available() else "cpu"
         print(f'whisper will use: {device}')
 
         t0 = time.time()
-        self.whisper_model = whisper.load_model(WHISPER_MODEL_NAME).to(device)
+        self.whisper_model = whisper.load_model(self.WHISPER_MODEL_NAME).to(device)
         t1 = time.time()
 
         total = t1 - t0
         print(f'Finished loading whisper_model in {total} seconds')
 
         t0 = time.time()
-        self.sentence_transformer_model = SentenceTransformer(SENTENCE_TRANSFORMER_MODEL_NAME)
+        self.sentence_transformer_model = SentenceTransformer(self.SENTENCE_TRANSFORMER_MODEL_NAME)
         t1 = time.time()
 
         total = t1 - t0
@@ -50,6 +51,7 @@ class EndpointHandler():
         encoded_segments = {}
         if video_url:
             video_with_transcript = self.transcribe_video(video_url)
+            video_with_transcript['transcript']['transcription_source'] = f"whisper_{self.WHISPER_MODEL_NAME}"
             encode_transcript = data.pop("encode_transcript", True)
             if encode_transcript:
                 encoded_segments = self.combine_transcripts(video_with_transcript)
@@ -61,7 +63,7 @@ class EndpointHandler():
                 **encoded_segments
             }
         elif query:
-            query = [{"text": query, "id": ""}]
+            query = [{"text": query, "id": ""}] if isinstance(query, str) else query
             encoded_segments = self.encode_sentences(query)
 
         return {
@@ -131,7 +133,7 @@ class EndpointHandler():
         batch_details = [
             {
                 **batch_meta[x],
-                'vectors':batch_vectors[x]
+                'vectors': batch_vectors[x]
             } for x in range(0, len(batch_meta))
         ]
         all_batches.extend(batch_details)
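
Taken together, the hunks move the model names to class attributes (so the same name can be reused for the new transcription_source tag) and let "query" arrive either as a bare string or as a pre-built list of sentences. A minimal usage sketch of the new query path, assuming the Inference Endpoints convention of calling the handler with a plain dict payload; the exact request schema is not shown in the diff:

from handler import EndpointHandler

handler = EndpointHandler()

# A bare string is wrapped into [{"text": ..., "id": ""}] by the new code.
result = handler({"query": "what is attention in transformers?"})

# A pre-built array of sentences now passes through unchanged, so callers
# can batch queries and keep their own ids ("q1"/"q2" are made up here).
result = handler({"query": [
    {"text": "what is attention in transformers?", "id": "q1"},
    {"text": "how does whisper transcribe audio?", "id": "q2"},
]})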
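The final hunk sits inside the batching loop of encode_sentences. A self-contained sketch of that pattern, where only batch_meta, batch_vectors, batch_details, and all_batches come from the diff; the batch size and the encode() call are assumptions:

from sentence_transformers import SentenceTransformer

model = SentenceTransformer("multi-qa-mpnet-base-dot-v1")

sentences = [
    {"text": "what is attention in transformers?", "id": "q1"},
    {"text": "how does whisper transcribe audio?", "id": "q2"},
]

batch_size = 64  # assumed; the real value is not visible in the diff
all_batches = []
for i in range(0, len(sentences), batch_size):
    batch_meta = sentences[i:i + batch_size]
    # encode() returns a numpy array; tolist() makes it JSON-serializable
    batch_vectors = model.encode([s["text"] for s in batch_meta]).tolist()
    batch_details = [
        {
            **batch_meta[x],
            'vectors': batch_vectors[x]
        } for x in range(0, len(batch_meta))
    ]
    all_batches.extend(batch_details)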