Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	Add an extra interface for performing diarization
Browse files
    	
        app.py
    CHANGED
    
    | 
         @@ -1,7 +1,7 @@ 
     | 
|
| 1 | 
         
             
            from datetime import datetime
         
     | 
| 2 | 
         
             
            import json
         
     | 
| 3 | 
         
             
            import math
         
     | 
| 4 | 
         
            -
            from typing import Iterator, Union
         
     | 
| 5 | 
         
             
            import argparse
         
     | 
| 6 | 
         | 
| 7 | 
         
             
            from io import StringIO
         
     | 
| 
         @@ -16,14 +16,14 @@ import torch 
     | 
|
| 16 | 
         
             
            from src.config import VAD_INITIAL_PROMPT_MODE_VALUES, ApplicationConfig, VadInitialPromptMode
         
     | 
| 17 | 
         
             
            from src.diarization.diarization import Diarization
         
     | 
| 18 | 
         
             
            from src.diarization.diarizationContainer import DiarizationContainer
         
     | 
| 
         | 
|
| 19 | 
         
             
            from src.hooks.progressListener import ProgressListener
         
     | 
| 20 | 
         
             
            from src.hooks.subTaskProgressListener import SubTaskProgressListener
         
     | 
| 21 | 
         
            -
            from src.hooks.whisperProgressHook import create_progress_listener_handle
         
     | 
| 22 | 
         
             
            from src.languages import get_language_names
         
     | 
| 23 | 
         
             
            from src.modelCache import ModelCache
         
     | 
| 24 | 
         
             
            from src.prompts.jsonPromptStrategy import JsonPromptStrategy
         
     | 
| 25 | 
         
             
            from src.prompts.prependPromptStrategy import PrependPromptStrategy
         
     | 
| 26 | 
         
            -
            from src.source import get_audio_source_collection
         
     | 
| 27 | 
         
             
            from src.vadParallel import ParallelContext, ParallelTranscription
         
     | 
| 28 | 
         | 
| 29 | 
         
             
            # External programs
         
     | 
| 
         @@ -101,7 +101,8 @@ class WhisperTranscriber: 
     | 
|
| 101 | 
         
             
                    self.diarization_kwargs = kwargs
         
     | 
| 102 | 
         | 
| 103 | 
         
             
                def unset_diarization(self):
         
     | 
| 104 | 
         
            -
                    self.diarization 
     | 
| 
         | 
|
| 105 | 
         
             
                    self.diarization_kwargs = None
         
     | 
| 106 | 
         | 
| 107 | 
         
             
                # Entry function for the simple tab
         
     | 
| 
         @@ -185,19 +186,59 @@ class WhisperTranscriber: 
     | 
|
| 185 | 
         
             
                                                 word_timestamps=word_timestamps, prepend_punctuations=prepend_punctuations, append_punctuations=append_punctuations, highlight_words=highlight_words,
         
     | 
| 186 | 
         
             
                                                 progress=progress)
         
     | 
| 187 | 
         | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 188 | 
         
             
                def transcribe_webui(self, modelName, languageName, urlData, multipleFiles, microphoneData, task, 
         
     | 
| 189 | 
         
             
                                     vadOptions: VadOptions, progress: gr.Progress = None, highlight_words: bool = False, 
         
     | 
| 
         | 
|
| 190 | 
         
             
                                     **decodeOptions: dict):
         
     | 
| 191 | 
         
             
                    try:
         
     | 
| 192 | 
         
             
                        sources = self.__get_source(urlData, multipleFiles, microphoneData)
         
     | 
| 193 | 
         | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 194 | 
         
             
                        try:
         
     | 
| 195 | 
         
             
                            selectedLanguage = languageName.lower() if len(languageName) > 0 else None
         
     | 
| 196 | 
         
             
                            selectedModel = modelName if modelName is not None else "base"
         
     | 
| 197 | 
         | 
| 198 | 
         
            -
                             
     | 
| 199 | 
         
            -
             
     | 
| 200 | 
         
            -
             
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 201 | 
         | 
| 202 | 
         
             
                            # Result
         
     | 
| 203 | 
         
             
                            download = []
         
     | 
| 
         @@ -234,8 +275,12 @@ class WhisperTranscriber: 
     | 
|
| 234 | 
         
             
                                                               sub_task_start=current_progress,
         
     | 
| 235 | 
         
             
                                                               sub_task_total=source_audio_duration)
         
     | 
| 236 | 
         | 
| 237 | 
         
            -
                                # Transcribe
         
     | 
| 238 | 
         
            -
                                 
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 239 | 
         
             
                                filePrefix = slugify(source_prefix + source.get_short_name(), allow_unicode=True)
         
     | 
| 240 | 
         | 
| 241 | 
         
             
                                # Update progress
         
     | 
| 
         @@ -363,6 +408,10 @@ class WhisperTranscriber: 
     | 
|
| 363 | 
         
             
                            result = whisperCallable.invoke(audio_path, 0, None, None, progress_listener=progressListener)
         
     | 
| 364 | 
         | 
| 365 | 
         
             
                    # Diarization
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 366 | 
         
             
                    if self.diarization and self.diarization_kwargs:
         
     | 
| 367 | 
         
             
                        print("Diarizing ", audio_path)
         
     | 
| 368 | 
         
             
                        diarization_result = list(self.diarization.run(audio_path, **self.diarization_kwargs))
         
     | 
| 
         @@ -373,9 +422,9 @@ class WhisperTranscriber: 
     | 
|
| 373 | 
         
             
                            print(f"  start={entry.start:.1f}s stop={entry.end:.1f}s speaker_{entry.speaker}")
         
     | 
| 374 | 
         | 
| 375 | 
         
             
                        # Add speakers to result
         
     | 
| 376 | 
         
            -
                         
     | 
| 377 | 
         | 
| 378 | 
         
            -
                    return  
     | 
| 379 | 
         | 
| 380 | 
         
             
                def _create_progress_listener(self, progress: gr.Progress):
         
     | 
| 381 | 
         
             
                    if (progress is None):
         
     | 
| 
         @@ -449,7 +498,7 @@ class WhisperTranscriber: 
     | 
|
| 449 | 
         
             
                        os.makedirs(output_dir)
         
     | 
| 450 | 
         | 
| 451 | 
         
             
                    text = result["text"]
         
     | 
| 452 | 
         
            -
                    language = result["language"]
         
     | 
| 453 | 
         
             
                    languageMaxLineWidth = self.__get_max_line_width(language)
         
     | 
| 454 | 
         | 
| 455 | 
         
             
                    print("Max line width " + str(languageMaxLineWidth))
         
     | 
| 
         @@ -635,7 +684,25 @@ def create_ui(app_config: ApplicationConfig): 
     | 
|
| 635 | 
         
             
                    gr.Text(label="Segments")
         
     | 
| 636 | 
         
             
                ])
         
     | 
| 637 | 
         | 
| 638 | 
         
            -
                 
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 639 | 
         | 
| 640 | 
         
             
                # Queue up the demo
         
     | 
| 641 | 
         
             
                if is_queue_mode:
         
     | 
| 
         | 
|
| 1 | 
         
             
            from datetime import datetime
         
     | 
| 2 | 
         
             
            import json
         
     | 
| 3 | 
         
             
            import math
         
     | 
| 4 | 
         
            +
            from typing import Callable, Iterator, Union
         
     | 
| 5 | 
         
             
            import argparse
         
     | 
| 6 | 
         | 
| 7 | 
         
             
            from io import StringIO
         
     | 
| 
         | 
|
| 16 | 
         
             
            from src.config import VAD_INITIAL_PROMPT_MODE_VALUES, ApplicationConfig, VadInitialPromptMode
         
     | 
| 17 | 
         
             
            from src.diarization.diarization import Diarization
         
     | 
| 18 | 
         
             
            from src.diarization.diarizationContainer import DiarizationContainer
         
     | 
| 19 | 
         
            +
            from src.diarization.transcriptLoader import load_transcript
         
     | 
| 20 | 
         
             
            from src.hooks.progressListener import ProgressListener
         
     | 
| 21 | 
         
             
            from src.hooks.subTaskProgressListener import SubTaskProgressListener
         
     | 
| 
         | 
|
| 22 | 
         
             
            from src.languages import get_language_names
         
     | 
| 23 | 
         
             
            from src.modelCache import ModelCache
         
     | 
| 24 | 
         
             
            from src.prompts.jsonPromptStrategy import JsonPromptStrategy
         
     | 
| 25 | 
         
             
            from src.prompts.prependPromptStrategy import PrependPromptStrategy
         
     | 
| 26 | 
         
            +
            from src.source import AudioSource, get_audio_source_collection
         
     | 
| 27 | 
         
             
            from src.vadParallel import ParallelContext, ParallelTranscription
         
     | 
| 28 | 
         | 
| 29 | 
         
             
            # External programs
         
     | 
| 
         | 
|
| 101 | 
         
             
                    self.diarization_kwargs = kwargs
         
     | 
| 102 | 
         | 
| 103 | 
         
             
                def unset_diarization(self):
         
     | 
| 104 | 
         
            +
                    if self.diarization is not None:
         
     | 
| 105 | 
         
            +
                        self.diarization.cleanup()
         
     | 
| 106 | 
         
             
                    self.diarization_kwargs = None
         
     | 
| 107 | 
         | 
| 108 | 
         
             
                # Entry function for the simple tab
         
     | 
| 
         | 
|
| 186 | 
         
             
                                                 word_timestamps=word_timestamps, prepend_punctuations=prepend_punctuations, append_punctuations=append_punctuations, highlight_words=highlight_words,
         
     | 
| 187 | 
         
             
                                                 progress=progress)
         
     | 
| 188 | 
         | 
| 189 | 
         
            +
                # Perform diarization given a specific input audio file and whisper file
         
     | 
| 190 | 
         
            +
                def perform_extra(self, languageName, urlData, singleFile, whisper_file: str, 
         
     | 
| 191 | 
         
            +
                                  highlight_words: bool = False,
         
     | 
| 192 | 
         
            +
                                  diarization: bool = False, diarization_speakers: int = 2, diarization_min_speakers = 1, diarization_max_speakers = 5, progress=gr.Progress()):
         
     | 
| 193 | 
         
            +
                
         
     | 
| 194 | 
         
            +
                    if whisper_file is None:
         
     | 
| 195 | 
         
            +
                        raise ValueError("whisper_file is required")
         
     | 
| 196 | 
         
            +
             
     | 
| 197 | 
         
            +
                     # Set diarization
         
     | 
| 198 | 
         
            +
                    if diarization:
         
     | 
| 199 | 
         
            +
                        self.set_diarization(auth_token=self.app_config.auth_token, num_speakers=diarization_speakers, 
         
     | 
| 200 | 
         
            +
                                            min_speakers=diarization_min_speakers, max_speakers=diarization_max_speakers)
         
     | 
| 201 | 
         
            +
                    else:
         
     | 
| 202 | 
         
            +
                        self.unset_diarization()
         
     | 
| 203 | 
         
            +
                    
         
     | 
| 204 | 
         
            +
                    def custom_transcribe_file(source: AudioSource):
         
     | 
| 205 | 
         
            +
                        result = load_transcript(whisper_file.name)
         
     | 
| 206 | 
         
            +
                        
         
     | 
| 207 | 
         
            +
                        # Set language if not set
         
     | 
| 208 | 
         
            +
                        if not "language" in result:
         
     | 
| 209 | 
         
            +
                            result["language"] = languageName
         
     | 
| 210 | 
         
            +
             
     | 
| 211 | 
         
            +
                        # Mark speakers
         
     | 
| 212 | 
         
            +
                        result = self._handle_diarization(source.source_path, result)
         
     | 
| 213 | 
         
            +
                        return result
         
     | 
| 214 | 
         
            +
                    
         
     | 
| 215 | 
         
            +
                    multipleFiles = [singleFile] if singleFile else None
         
     | 
| 216 | 
         
            +
                    
         
     | 
| 217 | 
         
            +
                    # Will return download, text, vtt
         
     | 
| 218 | 
         
            +
                    return self.transcribe_webui("base", "", urlData, multipleFiles, None, None, None, 
         
     | 
| 219 | 
         
            +
                                                   progress=progress,highlight_words=highlight_words,
         
     | 
| 220 | 
         
            +
                                                   override_transcribe_file=custom_transcribe_file, override_max_sources=1)
         
     | 
| 221 | 
         
            +
             
     | 
| 222 | 
         
             
                def transcribe_webui(self, modelName, languageName, urlData, multipleFiles, microphoneData, task, 
         
     | 
| 223 | 
         
             
                                     vadOptions: VadOptions, progress: gr.Progress = None, highlight_words: bool = False, 
         
     | 
| 224 | 
         
            +
                                     override_transcribe_file: Callable[[AudioSource], dict] = None, override_max_sources = None,
         
     | 
| 225 | 
         
             
                                     **decodeOptions: dict):
         
     | 
| 226 | 
         
             
                    try:
         
     | 
| 227 | 
         
             
                        sources = self.__get_source(urlData, multipleFiles, microphoneData)
         
     | 
| 228 | 
         | 
| 229 | 
         
            +
                        if override_max_sources is not None and len(sources) > override_max_sources:
         
     | 
| 230 | 
         
            +
                            raise ValueError("Maximum number of sources is " + str(override_max_sources) + ", but " + str(len(sources)) + " were provided")
         
     | 
| 231 | 
         
            +
             
     | 
| 232 | 
         
             
                        try:
         
     | 
| 233 | 
         
             
                            selectedLanguage = languageName.lower() if len(languageName) > 0 else None
         
     | 
| 234 | 
         
             
                            selectedModel = modelName if modelName is not None else "base"
         
     | 
| 235 | 
         | 
| 236 | 
         
            +
                            if override_transcribe_file is None:
         
     | 
| 237 | 
         
            +
                                model = create_whisper_container(whisper_implementation=self.app_config.whisper_implementation, 
         
     | 
| 238 | 
         
            +
                                                                model_name=selectedModel, compute_type=self.app_config.compute_type, 
         
     | 
| 239 | 
         
            +
                                                                cache=self.model_cache, models=self.app_config.models)
         
     | 
| 240 | 
         
            +
                            else:
         
     | 
| 241 | 
         
            +
                                model = None
         
     | 
| 242 | 
         | 
| 243 | 
         
             
                            # Result
         
     | 
| 244 | 
         
             
                            download = []
         
     | 
| 
         | 
|
| 275 | 
         
             
                                                               sub_task_start=current_progress,
         
     | 
| 276 | 
         
             
                                                               sub_task_total=source_audio_duration)
         
     | 
| 277 | 
         | 
| 278 | 
         
            +
                                # Transcribe using the override function if specified
         
     | 
| 279 | 
         
            +
                                if override_transcribe_file is None:
         
     | 
| 280 | 
         
            +
                                    result = self.transcribe_file(model, source.source_path, selectedLanguage, task, vadOptions, scaled_progress_listener, **decodeOptions)
         
     | 
| 281 | 
         
            +
                                else:
         
     | 
| 282 | 
         
            +
                                    result = override_transcribe_file(source)
         
     | 
| 283 | 
         
            +
             
     | 
| 284 | 
         
             
                                filePrefix = slugify(source_prefix + source.get_short_name(), allow_unicode=True)
         
     | 
| 285 | 
         | 
| 286 | 
         
             
                                # Update progress
         
     | 
| 
         | 
|
| 408 | 
         
             
                            result = whisperCallable.invoke(audio_path, 0, None, None, progress_listener=progressListener)
         
     | 
| 409 | 
         | 
| 410 | 
         
             
                    # Diarization
         
     | 
| 411 | 
         
            +
                    result = self._handle_diarization(audio_path, result)
         
     | 
| 412 | 
         
            +
                    return result
         
     | 
| 413 | 
         
            +
             
     | 
| 414 | 
         
            +
                def _handle_diarization(self, audio_path: str, input: dict):
         
     | 
| 415 | 
         
             
                    if self.diarization and self.diarization_kwargs:
         
     | 
| 416 | 
         
             
                        print("Diarizing ", audio_path)
         
     | 
| 417 | 
         
             
                        diarization_result = list(self.diarization.run(audio_path, **self.diarization_kwargs))
         
     | 
| 
         | 
|
| 422 | 
         
             
                            print(f"  start={entry.start:.1f}s stop={entry.end:.1f}s speaker_{entry.speaker}")
         
     | 
| 423 | 
         | 
| 424 | 
         
             
                        # Add speakers to result
         
     | 
| 425 | 
         
            +
                        input = self.diarization.mark_speakers(diarization_result, input)
         
     | 
| 426 | 
         | 
| 427 | 
         
            +
                    return input
         
     | 
| 428 | 
         | 
| 429 | 
         
             
                def _create_progress_listener(self, progress: gr.Progress):
         
     | 
| 430 | 
         
             
                    if (progress is None):
         
     | 
| 
         | 
|
| 498 | 
         
             
                        os.makedirs(output_dir)
         
     | 
| 499 | 
         | 
| 500 | 
         
             
                    text = result["text"]
         
     | 
| 501 | 
         
            +
                    language = result["language"] if "language" in result else None
         
     | 
| 502 | 
         
             
                    languageMaxLineWidth = self.__get_max_line_width(language)
         
     | 
| 503 | 
         | 
| 504 | 
         
             
                    print("Max line width " + str(languageMaxLineWidth))
         
     | 
| 
         | 
|
| 684 | 
         
             
                    gr.Text(label="Segments")
         
     | 
| 685 | 
         
             
                ])
         
     | 
| 686 | 
         | 
| 687 | 
         
            +
                perform_extra_interface = gr.Interface(fn=ui.perform_extra,
         
     | 
| 688 | 
         
            +
                                               description="Perform additional processing on a given JSON or SRT file", article=ui_article, inputs=[
         
     | 
| 689 | 
         
            +
                    gr.Dropdown(choices=sorted(get_language_names()), label="Language", value=app_config.language),
         
     | 
| 690 | 
         
            +
                    gr.Text(label="URL (YouTube, etc.)"),
         
     | 
| 691 | 
         
            +
                    gr.File(label="Upload Audio File", file_count="single"),
         
     | 
| 692 | 
         
            +
                    gr.File(label="Upload JSON/SRT File", file_count="single"),
         
     | 
| 693 | 
         
            +
                    gr.Checkbox(label="Word Timestamps - Highlight Words", value=app_config.highlight_words),
         
     | 
| 694 | 
         
            +
             
     | 
| 695 | 
         
            +
                    *common_diarization_inputs(),
         
     | 
| 696 | 
         
            +
                    gr.Number(label="Diarization - Min Speakers", precision=0, value=app_config.diarization_min_speakers, interactive=has_diarization_libs),
         
     | 
| 697 | 
         
            +
                    gr.Number(label="Diarization - Max Speakers", precision=0, value=app_config.diarization_max_speakers, interactive=has_diarization_libs),
         
     | 
| 698 | 
         
            +
             
     | 
| 699 | 
         
            +
                ], outputs=[
         
     | 
| 700 | 
         
            +
                    gr.File(label="Download"),
         
     | 
| 701 | 
         
            +
                    gr.Text(label="Transcription"), 
         
     | 
| 702 | 
         
            +
                    gr.Text(label="Segments")
         
     | 
| 703 | 
         
            +
                ])
         
     | 
| 704 | 
         
            +
             
     | 
| 705 | 
         
            +
                demo = gr.TabbedInterface([simple_transcribe, full_transcribe, perform_extra_interface], tab_names=["Simple", "Full", "Extra"])
         
     | 
| 706 | 
         | 
| 707 | 
         
             
                # Queue up the demo
         
     | 
| 708 | 
         
             
                if is_queue_mode:
         
     | 
    	
        cli.py
    CHANGED
    
    | 
         @@ -108,12 +108,12 @@ def cli(): 
     | 
|
| 108 | 
         
             
                                    help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
         
     | 
| 109 | 
         | 
| 110 | 
         
             
                # Diarization
         
     | 
| 111 | 
         
            -
                parser.add_argument('--auth_token', type=str, default= 
     | 
| 112 | 
         
             
                parser.add_argument("--diarization", type=str2bool, default=app_config.diarization, \
         
     | 
| 113 | 
         
             
                                    help="whether to perform speaker diarization")
         
     | 
| 114 | 
         
            -
                parser.add_argument("--diarization_num_speakers", type=int, default= 
     | 
| 115 | 
         
            -
                parser.add_argument("--diarization_min_speakers", type=int, default= 
     | 
| 116 | 
         
            -
                parser.add_argument("--diarization_max_speakers", type=int, default= 
     | 
| 117 | 
         | 
| 118 | 
         
             
                args = parser.parse_args().__dict__
         
     | 
| 119 | 
         
             
                model_name: str = args.pop("model")
         
     | 
| 
         | 
|
| 108 | 
         
             
                                    help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
         
     | 
| 109 | 
         | 
| 110 | 
         
             
                # Diarization
         
     | 
| 111 | 
         
            +
                parser.add_argument('--auth_token', type=str, default=app_config.auth_token, help='HuggingFace API Token (optional)')
         
     | 
| 112 | 
         
             
                parser.add_argument("--diarization", type=str2bool, default=app_config.diarization, \
         
     | 
| 113 | 
         
             
                                    help="whether to perform speaker diarization")
         
     | 
| 114 | 
         
            +
                parser.add_argument("--diarization_num_speakers", type=int, default=app_config.diarization_speakers, help="Number of speakers")
         
     | 
| 115 | 
         
            +
                parser.add_argument("--diarization_min_speakers", type=int, default=app_config.diarization_min_speakers, help="Minimum number of speakers")
         
     | 
| 116 | 
         
            +
                parser.add_argument("--diarization_max_speakers", type=int, default=app_config.diarization_max_speakers, help="Maximum number of speakers")
         
     | 
| 117 | 
         | 
| 118 | 
         
             
                args = parser.parse_args().__dict__
         
     | 
| 119 | 
         
             
                model_name: str = args.pop("model")
         
     |