Fix unit test
Browse files- src/whisper/abstractWhisperContainer.py +10 -2
- tests/vad_test.py +10 -4
src/whisper/abstractWhisperContainer.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
import abc
|
2 |
-
from typing import List
|
3 |
|
4 |
from src.config import ModelConfig, VadInitialPromptMode
|
5 |
|
@@ -9,7 +9,7 @@ from src.prompts.abstractPromptStrategy import AbstractPromptStrategy
|
|
9 |
|
10 |
class AbstractWhisperCallback:
|
11 |
def __init__(self):
|
12 |
-
|
13 |
|
14 |
@abc.abstractmethod
|
15 |
def invoke(self, audio, segment_index: int, prompt: str, detected_language: str, progress_listener: ProgressListener = None):
|
@@ -29,6 +29,14 @@ class AbstractWhisperCallback:
|
|
29 |
"""
|
30 |
raise NotImplementedError()
|
31 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
class AbstractWhisperContainer:
|
33 |
def __init__(self, model_name: str, device: str = None, compute_type: str = "float16",
|
34 |
download_root: str = None,
|
|
|
1 |
import abc
|
2 |
+
from typing import Any, Callable, List
|
3 |
|
4 |
from src.config import ModelConfig, VadInitialPromptMode
|
5 |
|
|
|
9 |
|
10 |
class AbstractWhisperCallback:
|
11 |
def __init__(self):
|
12 |
+
pass
|
13 |
|
14 |
@abc.abstractmethod
|
15 |
def invoke(self, audio, segment_index: int, prompt: str, detected_language: str, progress_listener: ProgressListener = None):
|
|
|
29 |
"""
|
30 |
raise NotImplementedError()
|
31 |
|
32 |
+
class LambdaWhisperCallback(AbstractWhisperCallback):
|
33 |
+
def __init__(self, callback_lambda: Callable[[Any, int, str, str, ProgressListener], None]):
|
34 |
+
super().__init__()
|
35 |
+
self.callback_lambda = callback_lambda
|
36 |
+
|
37 |
+
def invoke(self, audio, segment_index: int, prompt: str, detected_language: str, progress_listener: ProgressListener = None):
|
38 |
+
return self.callback_lambda(audio, segment_index, prompt, detected_language, progress_listener)
|
39 |
+
|
40 |
class AbstractWhisperContainer:
|
41 |
def __init__(self, model_name: str, device: str = None, compute_type: str = "float16",
|
42 |
download_root: str = None,
|
tests/vad_test.py
CHANGED
@@ -1,10 +1,11 @@
|
|
1 |
-
import pprint
|
2 |
import unittest
|
3 |
import numpy as np
|
4 |
import sys
|
5 |
|
6 |
sys.path.append('../whisper-webui')
|
|
|
7 |
|
|
|
8 |
from src.vad import AbstractTranscription, TranscriptionConfig, VadSileroTranscription
|
9 |
|
10 |
class TestVad(unittest.TestCase):
|
@@ -13,10 +14,11 @@ class TestVad(unittest.TestCase):
|
|
13 |
self.transcribe_calls = []
|
14 |
|
15 |
def test_transcript(self):
|
16 |
-
mock = MockVadTranscription()
|
|
|
17 |
|
18 |
self.transcribe_calls.clear()
|
19 |
-
result = mock.transcribe("mock", lambda segment : self.transcribe_segments(segment))
|
20 |
|
21 |
self.assertListEqual(self.transcribe_calls, [
|
22 |
[30, 30],
|
@@ -45,8 +47,9 @@ class TestVad(unittest.TestCase):
|
|
45 |
}
|
46 |
|
47 |
class MockVadTranscription(AbstractTranscription):
|
48 |
-
def __init__(self):
|
49 |
super().__init__()
|
|
|
50 |
|
51 |
def get_audio_segment(self, str, start_time: str = None, duration: str = None):
|
52 |
start_time_seconds = float(start_time.removesuffix("s"))
|
@@ -61,6 +64,9 @@ class MockVadTranscription(AbstractTranscription):
|
|
61 |
result.append( { 'start': 30, 'end': 60 } )
|
62 |
result.append( { 'start': 100, 'end': 200 } )
|
63 |
return result
|
|
|
|
|
|
|
64 |
|
65 |
if __name__ == '__main__':
|
66 |
unittest.main()
|
|
|
|
|
1 |
import unittest
|
2 |
import numpy as np
|
3 |
import sys
|
4 |
|
5 |
sys.path.append('../whisper-webui')
|
6 |
+
#print("Sys path: " + str(sys.path))
|
7 |
|
8 |
+
from src.whisper.abstractWhisperContainer import LambdaWhisperCallback
|
9 |
from src.vad import AbstractTranscription, TranscriptionConfig, VadSileroTranscription
|
10 |
|
11 |
class TestVad(unittest.TestCase):
|
|
|
14 |
self.transcribe_calls = []
|
15 |
|
16 |
def test_transcript(self):
|
17 |
+
mock = MockVadTranscription(mock_audio_length=120)
|
18 |
+
config = TranscriptionConfig()
|
19 |
|
20 |
self.transcribe_calls.clear()
|
21 |
+
result = mock.transcribe("mock", LambdaWhisperCallback(lambda segment, _1, _2, _3, _4: self.transcribe_segments(segment)), config)
|
22 |
|
23 |
self.assertListEqual(self.transcribe_calls, [
|
24 |
[30, 30],
|
|
|
47 |
}
|
48 |
|
49 |
class MockVadTranscription(AbstractTranscription):
|
50 |
+
def __init__(self, mock_audio_length: float = 1000):
|
51 |
super().__init__()
|
52 |
+
self.mock_audio_length = mock_audio_length
|
53 |
|
54 |
def get_audio_segment(self, str, start_time: str = None, duration: str = None):
|
55 |
start_time_seconds = float(start_time.removesuffix("s"))
|
|
|
64 |
result.append( { 'start': 30, 'end': 60 } )
|
65 |
result.append( { 'start': 100, 'end': 200 } )
|
66 |
return result
|
67 |
+
|
68 |
+
def get_audio_duration(self, audio: str, config: TranscriptionConfig):
|
69 |
+
return self.mock_audio_length
|
70 |
|
71 |
if __name__ == '__main__':
|
72 |
unittest.main()
|