Fix diarization in CLI
Browse files
app.py
CHANGED
@@ -240,19 +240,6 @@ class WhisperTranscriber:
|
|
240 |
# Update progress
|
241 |
current_progress += source_audio_duration
|
242 |
|
243 |
-
# Diarization
|
244 |
-
if self.diarization and self.diarization_kwargs:
|
245 |
-
print("Diarizing ", source.source_path)
|
246 |
-
diarization_result = list(self.diarization.run(source.source_path, **self.diarization_kwargs))
|
247 |
-
|
248 |
-
# Print result
|
249 |
-
print("Diarization result: ")
|
250 |
-
for entry in diarization_result:
|
251 |
-
print(f" start={entry.start:.1f}s stop={entry.end:.1f}s speaker_{entry.speaker}")
|
252 |
-
|
253 |
-
# Add speakers to result
|
254 |
-
result = self.diarization.mark_speakers(diarization_result, result)
|
255 |
-
|
256 |
source_download, source_text, source_vtt = self.write_result(result, filePrefix, outputDirectory, highlight_words)
|
257 |
|
258 |
if len(sources) > 1:
|
@@ -373,6 +360,19 @@ class WhisperTranscriber:
|
|
373 |
else:
|
374 |
# Default VAD
|
375 |
result = whisperCallable.invoke(audio_path, 0, None, None, progress_listener=progressListener)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
376 |
|
377 |
return result
|
378 |
|
|
|
240 |
# Update progress
|
241 |
current_progress += source_audio_duration
|
242 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
243 |
source_download, source_text, source_vtt = self.write_result(result, filePrefix, outputDirectory, highlight_words)
|
244 |
|
245 |
if len(sources) > 1:
|
|
|
360 |
else:
|
361 |
# Default VAD
|
362 |
result = whisperCallable.invoke(audio_path, 0, None, None, progress_listener=progressListener)
|
363 |
+
|
364 |
+
# Diarization
|
365 |
+
if self.diarization and self.diarization_kwargs:
|
366 |
+
print("Diarizing ", audio_path)
|
367 |
+
diarization_result = list(self.diarization.run(audio_path, **self.diarization_kwargs))
|
368 |
+
|
369 |
+
# Print result
|
370 |
+
print("Diarization result: ")
|
371 |
+
for entry in diarization_result:
|
372 |
+
print(f" start={entry.start:.1f}s stop={entry.end:.1f}s speaker_{entry.speaker}")
|
373 |
+
|
374 |
+
# Add speakers to result
|
375 |
+
result = self.diarization.mark_speakers(diarization_result, result)
|
376 |
|
377 |
return result
|
378 |
|
cli.py
CHANGED
@@ -111,9 +111,9 @@ def cli():
|
|
111 |
parser.add_argument('--auth_token', type=str, default=None, help='HuggingFace API Token (optional)')
|
112 |
parser.add_argument("--diarization", type=str2bool, default=app_config.diarization, \
|
113 |
help="whether to perform speaker diarization")
|
114 |
-
parser.add_argument("--… [remainder of this removed line was truncated in the diff extraction]
|
115 |
-
parser.add_argument("--… [remainder of this removed line was truncated in the diff extraction]
|
116 |
-
parser.add_argument("--… [remainder of this removed line was truncated in the diff extraction]
|
117 |
|
118 |
args = parser.parse_args().__dict__
|
119 |
model_name: str = args.pop("model")
|
@@ -151,11 +151,11 @@ def cli():
|
|
151 |
compute_type = args.pop("compute_type")
|
152 |
highlight_words = args.pop("highlight_words")
|
153 |
|
154 |
-
diarization = args.pop("diarization")
|
155 |
auth_token = args.pop("auth_token")
|
156 |
-
|
157 |
-
|
158 |
-
|
|
|
159 |
|
160 |
transcriber = WhisperTranscriber(delete_uploaded_files=False, vad_cpu_cores=vad_cpu_cores, app_config=app_config)
|
161 |
transcriber.set_parallel_devices(args.pop("vad_parallel_devices"))
|
|
|
111 |
parser.add_argument('--auth_token', type=str, default=None, help='HuggingFace API Token (optional)')
|
112 |
parser.add_argument("--diarization", type=str2bool, default=app_config.diarization, \
|
113 |
help="whether to perform speaker diarization")
|
114 |
+
parser.add_argument("--diarization_num_speakers", type=int, default=None, help="Number of speakers")
|
115 |
+
parser.add_argument("--diarization_min_speakers", type=int, default=None, help="Minimum number of speakers")
|
116 |
+
parser.add_argument("--diarization_max_speakers", type=int, default=None, help="Maximum number of speakers")
|
117 |
|
118 |
args = parser.parse_args().__dict__
|
119 |
model_name: str = args.pop("model")
|
|
|
151 |
compute_type = args.pop("compute_type")
|
152 |
highlight_words = args.pop("highlight_words")
|
153 |
|
|
|
154 |
auth_token = args.pop("auth_token")
|
155 |
+
diarization = args.pop("diarization")
|
156 |
+
num_speakers = args.pop("diarization_num_speakers")
|
157 |
+
min_speakers = args.pop("diarization_min_speakers")
|
158 |
+
max_speakers = args.pop("diarization_max_speakers")
|
159 |
|
160 |
transcriber = WhisperTranscriber(delete_uploaded_files=False, vad_cpu_cores=vad_cpu_cores, app_config=app_config)
|
161 |
transcriber.set_parallel_devices(args.pop("vad_parallel_devices"))
|