# Source: mm/src/genmo/lib/progress.py
# Uploaded by nruto in commit d0bfdd6 ("Upload 31 files", verified).
import contextlib
from typing import Any, Iterable, Iterator, Optional
# Optional progress-bar backends; each name is set to None when its import fails.
try:
    from tqdm import tqdm
except ImportError:
    tqdm = None
try:
    from ray.experimental.tqdm_ray import tqdm as ray_tqdm
except Exception:
    # Broad on purpose (presumably the ray import can fail with errors other
    # than ImportError), but unlike a bare `except:` this still lets
    # KeyboardInterrupt / SystemExit propagate.
    ray_tqdm = None

# Global state, written by progress_bar() and read by get_new_progress_bar().
# Backend chosen by the active progress_bar() context: "none", "tqdm" or "ray_tqdm".
_current_progress_type = "none"
# True while a progress_bar() context is open; used to reject nesting.
_is_progress_bar_active = False
class DummyProgressBar:
    """A no-op progress bar that mimics the commonly used parts of the tqdm interface.

    Iterating the bar yields the wrapped iterable unchanged, and every display
    method does nothing. The bar also works as a context manager. That way,
    code written as ``with get_new_progress_bar(total=n) as bar:`` keeps
    working when progress reporting is disabled, as it does with a real tqdm
    bar.
    """

    def __init__(self, iterable=None, **kwargs):
        # Keyword arguments (total, desc, ...) are accepted for signature
        # compatibility with tqdm and otherwise ignored.
        self.iterable = iterable

    def __iter__(self):
        # A bar created without an iterable raises TypeError here (iter(None)).
        return iter(self.iterable)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        return False  # never suppress exceptions raised inside the with-block

    def update(self, n=1):
        pass

    def close(self):
        pass

    def set_description(self, desc=None, refresh=True):
        # `desc` is now optional and `refresh` accepted, matching tqdm; old
        # positional calls set_description("x") still work.
        pass

    def set_postfix(self, ordered_dict=None, refresh=True, **kwargs):
        pass

    def refresh(self):
        pass
def get_new_progress_bar(iterable: Optional[Iterable] = None, **kwargs) -> Any:
    """Create a progress bar of the type selected by the active ``progress_bar`` context.

    Args:
        iterable: Optional iterable to wrap; forwarded to the bar constructor.
        **kwargs: Extra keyword arguments forwarded to the bar constructor
            (e.g. ``total``, ``desc``).

    Returns:
        The bar depends on the active type:
        - "tqdm": a ``tqdm`` bar.
        - "ray_tqdm": a Ray ``tqdm`` bar.
        - anything else, or no active context: a ``DummyProgressBar``.

    Raises:
        ImportError: If the selected backend's module-level name is None
            (its package could not be imported).
    """
    # Outside a progress_bar() context every request resolves to the no-op bar.
    selected = _current_progress_type if _is_progress_bar_active else "none"

    if selected == "tqdm":
        if tqdm is None:
            raise ImportError("tqdm is required but not installed. Please install tqdm to use the tqdm progress bar.")
        return tqdm(iterable=iterable, **kwargs)

    if selected == "ray_tqdm":
        if ray_tqdm is None:
            raise ImportError("ray is required but not installed. Please install ray to use the ray_tqdm progress bar.")
        return ray_tqdm(iterable=iterable, **kwargs)

    return DummyProgressBar(iterable=iterable, **kwargs)
@contextlib.contextmanager
def progress_bar(type: str = "none", enabled=True):
    """Select the progress-bar backend for the duration of a ``with`` block.

    While the context is open, ``get_new_progress_bar`` builds bars of the
    selected type. On exit, even if the body raised, the module state is
    reset to type "none" and inactive.

    Args:
        type: Backend to use: "none", "tqdm" or "ray_tqdm". (The name shadows
            the builtin but is part of the public keyword interface.)
        enabled: If False, the context is still entered but the type is
            forced to "none", so bars created inside it are no-ops.

    Raises:
        ValueError: If ``type`` is not one of the supported backends.
        RuntimeError: If another ``progress_bar`` context is already active
            (nesting is rejected even when ``enabled`` is False).

    Example:
        with progress_bar(type="tqdm"):
            for i in get_new_progress_bar(range(100), desc="processing"):
                process(i)
    """
    global _current_progress_type, _is_progress_bar_active

    if type not in ("none", "tqdm", "ray_tqdm"):
        raise ValueError("Progress bar type must be 'none' or 'tqdm' or 'ray_tqdm'")
    if not enabled:
        type = "none"
    # Raised before any state is modified (and before yield), so an outer
    # context's settings survive a rejected nested attempt.
    if _is_progress_bar_active:
        raise RuntimeError("Nested progress bars are not supported")

    _is_progress_bar_active = True
    _current_progress_type = type
    try:
        yield
    finally:
        _is_progress_bar_active = False
        _current_progress_type = "none"