Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	Fix unit test
Browse files- src/whisper/abstractWhisperContainer.py +10 -2
 - tests/vad_test.py +10 -4
 
    	
        src/whisper/abstractWhisperContainer.py
    CHANGED
    
    | 
         @@ -1,5 +1,5 @@ 
     | 
|
| 1 | 
         
             
            import abc
         
     | 
| 2 | 
         
            -
            from typing import List
         
     | 
| 3 | 
         | 
| 4 | 
         
             
            from src.config import ModelConfig, VadInitialPromptMode
         
     | 
| 5 | 
         | 
| 
         @@ -9,7 +9,7 @@ from src.prompts.abstractPromptStrategy import AbstractPromptStrategy 
     | 
|
| 9 | 
         | 
| 10 | 
         
             
            class AbstractWhisperCallback:
         
     | 
| 11 | 
         
             
                def __init__(self):
         
     | 
| 12 | 
         
            -
                     
     | 
| 13 | 
         | 
| 14 | 
         
             
                @abc.abstractmethod
         
     | 
| 15 | 
         
             
                def invoke(self, audio, segment_index: int, prompt: str, detected_language: str, progress_listener: ProgressListener = None):
         
     | 
| 
         @@ -29,6 +29,14 @@ class AbstractWhisperCallback: 
     | 
|
| 29 | 
         
             
                    """
         
     | 
| 30 | 
         
             
                    raise NotImplementedError()
         
     | 
| 31 | 
         | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 32 | 
         
             
            class AbstractWhisperContainer:
         
     | 
| 33 | 
         
             
                def __init__(self, model_name: str, device: str = None, compute_type: str = "float16",
         
     | 
| 34 | 
         
             
                             download_root: str = None,
         
     | 
| 
         | 
|
| 1 | 
         
             
            import abc
         
     | 
| 2 | 
         
            +
            from typing import Any, Callable, List
         
     | 
| 3 | 
         | 
| 4 | 
         
             
            from src.config import ModelConfig, VadInitialPromptMode
         
     | 
| 5 | 
         | 
| 
         | 
|
| 9 | 
         | 
| 10 | 
         
             
            class AbstractWhisperCallback:
         
     | 
| 11 | 
         
             
                def __init__(self):
         
     | 
| 12 | 
         
            +
                    pass
         
     | 
| 13 | 
         | 
| 14 | 
         
             
                @abc.abstractmethod
         
     | 
| 15 | 
         
             
                def invoke(self, audio, segment_index: int, prompt: str, detected_language: str, progress_listener: ProgressListener = None):
         
     | 
| 
         | 
|
| 29 | 
         
             
                    """
         
     | 
| 30 | 
         
             
                    raise NotImplementedError()
         
     | 
| 31 | 
         | 
| 32 | 
         
            +
            class LambdaWhisperCallback(AbstractWhisperCallback):
         
     | 
| 33 | 
         
            +
                def __init__(self, callback_lambda: Callable[[Any, int, str, str, ProgressListener], None]):
         
     | 
| 34 | 
         
            +
                    super().__init__()
         
     | 
| 35 | 
         
            +
                    self.callback_lambda = callback_lambda
         
     | 
| 36 | 
         
            +
             
     | 
| 37 | 
         
            +
                def invoke(self, audio, segment_index: int, prompt: str, detected_language: str, progress_listener: ProgressListener = None):
         
     | 
| 38 | 
         
            +
                    return self.callback_lambda(audio, segment_index, prompt, detected_language, progress_listener)
         
     | 
| 39 | 
         
            +
             
     | 
| 40 | 
         
             
            class AbstractWhisperContainer:
         
     | 
| 41 | 
         
             
                def __init__(self, model_name: str, device: str = None, compute_type: str = "float16",
         
     | 
| 42 | 
         
             
                             download_root: str = None,
         
     | 
    	
        tests/vad_test.py
    CHANGED
    
    | 
         @@ -1,10 +1,11 @@ 
     | 
|
| 1 | 
         
            -
            import pprint
         
     | 
| 2 | 
         
             
            import unittest
         
     | 
| 3 | 
         
             
            import numpy as np
         
     | 
| 4 | 
         
             
            import sys
         
     | 
| 5 | 
         | 
| 6 | 
         
             
            sys.path.append('../whisper-webui')
         
     | 
| 
         | 
|
| 7 | 
         | 
| 
         | 
|
| 8 | 
         
             
            from src.vad import AbstractTranscription, TranscriptionConfig, VadSileroTranscription
         
     | 
| 9 | 
         | 
| 10 | 
         
             
            class TestVad(unittest.TestCase):
         
     | 
| 
         @@ -13,10 +14,11 @@ class TestVad(unittest.TestCase): 
     | 
|
| 13 | 
         
             
                    self.transcribe_calls = []
         
     | 
| 14 | 
         | 
| 15 | 
         
             
                def test_transcript(self):
         
     | 
| 16 | 
         
            -
                    mock = MockVadTranscription()
         
     | 
| 
         | 
|
| 17 | 
         | 
| 18 | 
         
             
                    self.transcribe_calls.clear()
         
     | 
| 19 | 
         
            -
                    result = mock.transcribe("mock", lambda segment : self.transcribe_segments(segment))
         
     | 
| 20 | 
         | 
| 21 | 
         
             
                    self.assertListEqual(self.transcribe_calls, [ 
         
     | 
| 22 | 
         
             
                        [30, 30],
         
     | 
| 
         @@ -45,8 +47,9 @@ class TestVad(unittest.TestCase): 
     | 
|
| 45 | 
         
             
                    }
         
     | 
| 46 | 
         | 
| 47 | 
         
             
            class MockVadTranscription(AbstractTranscription):
         
     | 
| 48 | 
         
            -
                def __init__(self):
         
     | 
| 49 | 
         
             
                    super().__init__()
         
     | 
| 
         | 
|
| 50 | 
         | 
| 51 | 
         
             
                def get_audio_segment(self, str, start_time: str = None, duration: str = None):
         
     | 
| 52 | 
         
             
                    start_time_seconds = float(start_time.removesuffix("s"))
         
     | 
| 
         @@ -61,6 +64,9 @@ class MockVadTranscription(AbstractTranscription): 
     | 
|
| 61 | 
         
             
                    result.append( {  'start': 30, 'end': 60 } )
         
     | 
| 62 | 
         
             
                    result.append( {  'start': 100, 'end': 200 } )
         
     | 
| 63 | 
         
             
                    return result
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 64 | 
         | 
| 65 | 
         
             
            if __name__ == '__main__':
         
     | 
| 66 | 
         
             
                unittest.main()
         
     | 
| 
         | 
|
| 
         | 
|
| 1 | 
         
             
            import unittest
         
     | 
| 2 | 
         
             
            import numpy as np
         
     | 
| 3 | 
         
             
            import sys
         
     | 
| 4 | 
         | 
| 5 | 
         
             
            sys.path.append('../whisper-webui')
         
     | 
| 6 | 
         
            +
            #print("Sys path: " + str(sys.path))
         
     | 
| 7 | 
         | 
| 8 | 
         
            +
            from src.whisper.abstractWhisperContainer import LambdaWhisperCallback
         
     | 
| 9 | 
         
             
            from src.vad import AbstractTranscription, TranscriptionConfig, VadSileroTranscription
         
     | 
| 10 | 
         | 
| 11 | 
         
             
            class TestVad(unittest.TestCase):
         
     | 
| 
         | 
|
| 14 | 
         
             
                    self.transcribe_calls = []
         
     | 
| 15 | 
         | 
| 16 | 
         
             
                def test_transcript(self):
         
     | 
| 17 | 
         
            +
                    mock = MockVadTranscription(mock_audio_length=120)
         
     | 
| 18 | 
         
            +
                    config = TranscriptionConfig()
         
     | 
| 19 | 
         | 
| 20 | 
         
             
                    self.transcribe_calls.clear()
         
     | 
| 21 | 
         
            +
                    result = mock.transcribe("mock", LambdaWhisperCallback(lambda segment, _1, _2, _3, _4: self.transcribe_segments(segment)), config)
         
     | 
| 22 | 
         | 
| 23 | 
         
             
                    self.assertListEqual(self.transcribe_calls, [ 
         
     | 
| 24 | 
         
             
                        [30, 30],
         
     | 
| 
         | 
|
| 47 | 
         
             
                    }
         
     | 
| 48 | 
         | 
| 49 | 
         
             
            class MockVadTranscription(AbstractTranscription):
         
     | 
| 50 | 
         
            +
                def __init__(self, mock_audio_length: float = 1000):
         
     | 
| 51 | 
         
             
                    super().__init__()
         
     | 
| 52 | 
         
            +
                    self.mock_audio_length = mock_audio_length
         
     | 
| 53 | 
         | 
| 54 | 
         
             
                def get_audio_segment(self, str, start_time: str = None, duration: str = None):
         
     | 
| 55 | 
         
             
                    start_time_seconds = float(start_time.removesuffix("s"))
         
     | 
| 
         | 
|
| 64 | 
         
             
                    result.append( {  'start': 30, 'end': 60 } )
         
     | 
| 65 | 
         
             
                    result.append( {  'start': 100, 'end': 200 } )
         
     | 
| 66 | 
         
             
                    return result
         
     | 
| 67 | 
         
            +
                    
         
     | 
| 68 | 
         
            +
                def get_audio_duration(self, audio: str, config: TranscriptionConfig):
         
     | 
| 69 | 
         
            +
                    return self.mock_audio_length
         
     | 
| 70 | 
         | 
| 71 | 
         
             
            if __name__ == '__main__':
         
     | 
| 72 | 
         
             
                unittest.main()
         
     |