File size: 2,473 Bytes
d0bfdd6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 |
import contextlib
from typing import Any, Iterable, Iterator, Optional
# Optional progress-bar backends. Each name is bound to None when its
# package cannot be imported, so this module imports cleanly without them.
try:
    from tqdm import tqdm
except ImportError:
    tqdm = None

try:
    from ray.experimental.tqdm_ray import tqdm as ray_tqdm
except Exception:
    # Best-effort: any import-time failure simply disables the ray_tqdm
    # backend. Catch Exception rather than using a bare ``except:`` so that
    # KeyboardInterrupt / SystemExit raised during the import still propagate.
    ray_tqdm = None
# Module-level progress state: written by ``progress_bar`` and read by
# ``get_new_progress_bar``. The values below are the inactive defaults.
_is_progress_bar_active: bool = False
_current_progress_type: str = "none"
class DummyProgressBar:
    """A no-op progress bar that mimics the commonly used tqdm interface.

    Lets callers use one code path whether or not a real progress bar is
    being displayed.

    Args:
        iterable: Optional iterable to wrap. Iterating the bar iterates it
            unchanged. Iterating a bar built without one raises ``TypeError``
            (from ``iter(None)``).
        **kwargs: Accepted and ignored, so tqdm-style keyword arguments
            (e.g. ``total``, ``desc``) can be passed through unchanged.
    """

    def __init__(self, iterable=None, **kwargs):
        self.iterable = iterable

    def __iter__(self):
        return iter(self.iterable)

    def __enter__(self):
        # tqdm supports ``with tqdm(...) as pbar:``. Mirror it so callers using
        # the context-manager form keep working when no real bar is active.
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        return False  # never suppress exceptions raised in the with-body

    def update(self, n=1):
        """Ignore a progress increment of ``n``."""

    def close(self):
        """Do nothing; there is no display to tear down."""

    def set_description(self, desc=None, refresh=True):
        """Ignore a description change (signature mirrors tqdm's)."""

    def set_postfix(self, ordered_dict=None, refresh=True, **kwargs):
        """Ignore a postfix change (signature mirrors tqdm's)."""
def get_new_progress_bar(iterable: Optional[Iterable] = None, **kwargs) -> Any:
    """Build a progress bar matching the currently active ``progress_bar`` context.

    Args:
        iterable: Optional iterable for the bar to wrap.
        **kwargs: Forwarded unchanged to the bar constructor.

    Returns:
        One of the following:
        - A tqdm bar when the active type is "tqdm".
        - A ray tqdm bar when the active type is "ray_tqdm".
        - A ``DummyProgressBar`` otherwise, including when no context is active.

    Raises:
        ImportError: If the active type's backing package is unavailable.
    """
    # Outside any progress_bar context the effective kind is always "none".
    kind = _current_progress_type if _is_progress_bar_active else "none"

    if kind == "tqdm":
        factory = tqdm
        missing = "tqdm is required but not installed. Please install tqdm to use the tqdm progress bar."
    elif kind == "ray_tqdm":
        factory = ray_tqdm
        missing = "ray is required but not installed. Please install ray to use the ray_tqdm progress bar."
    else:
        return DummyProgressBar(iterable=iterable, **kwargs)

    if factory is None:
        raise ImportError(missing)
    return factory(iterable=iterable, **kwargs)
@contextlib.contextmanager
def progress_bar(type: str = "none", enabled=True):
    """
    Context manager selecting which progress bar ``get_new_progress_bar`` builds.

    While the context is active, ``get_new_progress_bar`` creates bars of the
    selected type. On exit, whether normal or via an exception, the module
    state is reset to inactive / "none". Bar options such as ``total`` or
    ``desc`` are not accepted here; pass them to ``get_new_progress_bar``.

    Args:
        type: Type of progress bar: "none", "tqdm", or "ray_tqdm".
            (Named ``type`` for API compatibility; it shadows the builtin
            only inside this function.)
        enabled: If False, behave as if ``type="none"``. The ``type``
            argument is still validated first.

    Raises:
        ValueError: If ``type`` is not one of the supported values.
        RuntimeError: If a ``progress_bar`` context is already active
            (nesting is not supported).

    Example:
        with progress_bar(type="tqdm"):
            for i in get_new_progress_bar(range(100), desc="processing"):
                process(i)
    """
    global _current_progress_type, _is_progress_bar_active

    if type not in ("none", "tqdm", "ray_tqdm"):
        raise ValueError("Progress bar type must be 'none' or 'tqdm' or 'ray_tqdm'")
    if not enabled:
        type = "none"

    if _is_progress_bar_active:
        raise RuntimeError("Nested progress bars are not supported")

    _is_progress_bar_active = True
    _current_progress_type = type
    try:
        yield
    finally:
        # Always reset, even if the with-body raised, so a failed run does not
        # leave the module believing a progress bar is still active.
        _is_progress_bar_active = False
        _current_progress_type = "none"
|