# whisper-webui / src/hooks/whisperProgressHook.py
# Author: aadnk - commit 33a2c1e: "Add progress listener to none/VAD"
import sys
import threading
from typing import List, Union
import tqdm
class ProgressListener:
    """Base callback interface for progress notifications.

    Subclasses override ``on_progress`` and/or ``on_finished``. The default
    ``on_progress`` only remembers the most recently reported ``total`` on the
    instance; the default ``on_finished`` does nothing.
    """

    def on_progress(self, current: Union[int, float], total: Union[int, float]):
        """Receive a progress update of ``current`` units out of ``total``.

        The base implementation ignores ``current`` and stores ``total`` as
        ``self.total``.
        """
        latest_total = total
        self.total = latest_total

    def on_finished(self):
        """Called once the tracked work completes; a no-op by default."""
class ProgressListenerHandle:
    """Context manager that registers a progress listener for the current thread.

    On entry the listener is registered via
    ``register_thread_local_progress_listener`` and returned, so that
    ``with handle as listener:`` binds it. On exit it is unregistered again.
    ``on_finished`` is called only when the ``with`` body completed without
    raising; exceptions from the body propagate unchanged.
    """

    def __init__(self, listener: ProgressListener):
        # The listener to register while the context is active.
        self.listener = listener

    def __enter__(self) -> ProgressListener:
        register_thread_local_progress_listener(self.listener)
        # Return the listener so ``with ... as listener:`` binds it instead of None.
        return self.listener

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Always unregister, even when the body raised, so the thread-local
        # listener list does not keep stale listeners around.
        unregister_thread_local_progress_listener(self.listener)
        if exc_type is None:
            self.listener.on_finished()
        # Implicitly returning None (falsy) lets any exception propagate.
class SubTaskProgressListener(ProgressListener):
    """
    A sub task listener that reports the progress of a sub task to a base task listener.

    Sub task progress ``current / total`` is mapped linearly onto the interval
    ``[sub_task_start, sub_task_start + sub_task_total]`` of the base task and
    reported to the base listener against ``base_task_total``.

    Parameters
    ----------
    base_task_listener : ProgressListener
        The base progress listener to accumulate overall progress in.
    base_task_total : float
        The maximum total progress that will be reported to the base progress listener.
    sub_task_start : float
        The starting progress of a sub task, with respect to the base progress listener.
    sub_task_total : float
        The total amount of progress a sub task will report to the base progress listener.
    """

    def __init__(
        self,
        base_task_listener: ProgressListener,
        base_task_total: float,
        sub_task_start: float,
        sub_task_total: float,
    ):
        self.base_task_listener = base_task_listener
        self.base_task_total = base_task_total
        self.sub_task_start = sub_task_start
        self.sub_task_total = sub_task_total

    def on_progress(self, current: Union[int, float], total: Union[int, float]):
        """Forward sub task progress to the base listener, rescaled into this sub task's span.

        Updates whose ``total`` is ``None`` or 0 are ignored: they carry no
        progress fraction, and dividing by them would raise inside the
        progress callback and abort the task being tracked.
        """
        if not total:
            return
        sub_task_progress_frac = current / total
        sub_task_progress = self.sub_task_start + self.sub_task_total * sub_task_progress_frac
        self.base_task_listener.on_progress(sub_task_progress, self.base_task_total)

    def on_finished(self):
        """Report this sub task's full span as completed to the base listener."""
        self.base_task_listener.on_progress(self.sub_task_start + self.sub_task_total, self.base_task_total)
class _CustomProgressBar(tqdm.tqdm):
    """tqdm progress bar that also forwards progress to this thread's listeners.

    ``init_progress_hook`` installs this class in place of ``tqdm.tqdm`` so that
    Whisper's internal progress bar reports to registered ``ProgressListener``s.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Progress is tracked separately from ``self.n`` because the bar might be
        # disabled, in which case tqdm's own counter cannot be relied upon.
        self._current = self.n  # Set the initial value

    def update(self, n=1):
        """Advance the bar by ``n`` and notify every listener on the calling thread.

        Keeps ``tqdm.update``'s signature (``n`` defaults to 1) and returns
        whatever ``tqdm.update`` returned.
        """
        displayed = super().update(n)
        # Because the progress bar might be disabled, we need to manually update the progress
        self._current += n
        # Iterate over a snapshot so a listener that (un)registers listeners from
        # inside its callback cannot disturb this loop.
        for listener in tuple(_get_thread_local_listeners()):
            listener.on_progress(self._current, self.total)
        return displayed
# Per-thread storage; each thread lazily gets its own ``listeners`` list.
_thread_local = threading.local()


def _get_thread_local_listeners():
    """Return the calling thread's list of progress listeners.

    The list is created empty on first access from a thread, and the same
    list object is returned on every later call from that thread.
    """
    try:
        return _thread_local.listeners
    except AttributeError:
        # First access from this thread: start with no listeners.
        _thread_local.listeners = []
        return _thread_local.listeners
# Becomes True once the tqdm hook below has been installed.
_hooked = False


def init_progress_hook():
    """Install ``_CustomProgressBar`` as the tqdm progress bar seen by ``whisper.transcribe``.

    Only the first call patches anything; subsequent calls return without
    doing any work.
    """
    global _hooked
    if not _hooked:
        # Inject into tqdm.tqdm of Whisper, so we can see progress
        import whisper.transcribe
        # NOTE(review): the submodule is fetched from sys.modules instead of the
        # ``whisper.transcribe`` attribute, presumably because that attribute may
        # not resolve to the submodule itself - confirm. The assignment rebinds
        # ``tqdm`` on whatever object whisper.transcribe names ``tqdm``
        # (presumably the tqdm module, which would make the patch process-wide).
        sys.modules['whisper.transcribe'].tqdm.tqdm = _CustomProgressBar
        _hooked = True
def register_thread_local_progress_listener(progress_listener: ProgressListener):
    """Add ``progress_listener`` to the calling thread's progress listeners.

    The tqdm hook is installed first (a no-op after the first time). This is a
    workaround for the progress bar not being exposed in the transcription API.
    """
    init_progress_hook()
    _get_thread_local_listeners().append(progress_listener)
def unregister_thread_local_progress_listener(progress_listener: ProgressListener):
    """Remove ``progress_listener`` from the calling thread's progress listeners.

    A listener that is not currently registered is silently ignored.
    """
    try:
        _get_thread_local_listeners().remove(progress_listener)
    except ValueError:
        # Not registered on this thread: nothing to remove.
        pass
def create_progress_listener_handle(progress_listener: ProgressListener):
    """Wrap ``progress_listener`` in a ``ProgressListenerHandle`` context manager.

    Intended for ``with create_progress_listener_handle(listener): ...`` so the
    listener is registered for the current thread only while the block runs.
    """
    handle = ProgressListenerHandle(progress_listener)
    return handle
if __name__ == '__main__':
    # Usage demo: register a no-op listener on this thread for the duration
    # of the block; a real caller would run the transcription inside it.
    with create_progress_listener_handle(progress_listener=ProgressListener()) as listener:
        # Call model.transcribe here
        pass

    print("Done")