File size: 2,473 Bytes
d0bfdd6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import contextlib
from typing import Any, Iterable, Iterator, Optional

try:
    # tqdm is optional; get_new_progress_bar raises ImportError only when the
    # "tqdm" type is actually requested and this import failed.
    from tqdm import tqdm
except ImportError:
    tqdm = None

try:
    from ray.experimental.tqdm_ray import tqdm as ray_tqdm
except Exception:
    # Ray is optional. Catch Exception rather than using a bare ``except:`` so
    # KeyboardInterrupt / SystemExit raised during the import are not swallowed,
    # while any other import-time failure still just disables the backend.
    ray_tqdm = None

# Global state written by ``progress_bar`` and read by ``get_new_progress_bar``.
# Currently selected progress bar type: "none", "tqdm" or "ray_tqdm".
_current_progress_type = "none"
# True while a ``progress_bar`` context is active; used to reject nesting.
_is_progress_bar_active = False


class DummyProgressBar:
    """A no-op stand-in for a tqdm-style progress bar.

    ``get_new_progress_bar`` returns this when no progress bar context is
    active, or when the active type is "none". Calling code can then use the
    same API whether or not a real bar is shown. Every display-related method
    does nothing.

    Args:
        iterable: Optional iterable to wrap. Iterating the dummy iterates this
            object unchanged.
        **kwargs: Accepted and ignored (e.g. ``total``, ``desc``), so the same
            keyword arguments can be passed as to a real progress bar.
    """

    def __init__(self, iterable=None, **kwargs):
        self.iterable = iterable

    def __iter__(self):
        # With no wrapped iterable this is iter(None), which raises TypeError.
        return iter(self.iterable)

    def __enter__(self):
        # Support ``with get_new_progress_bar(total=n) as pbar:`` so callers
        # need no special case when progress bars are disabled.
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        # Returning False lets exceptions from the ``with`` body propagate.
        return False

    def update(self, n=1):
        pass

    def close(self):
        pass

    def set_description(self, desc=None, refresh=True):
        # Optional arguments accept both ``set_description("x")`` and
        # ``set_description("x", refresh=False)`` call forms.
        pass

    def set_postfix(self, ordered_dict=None, refresh=True, **kwargs):
        pass

    def set_postfix_str(self, s="", refresh=True):
        pass


def get_new_progress_bar(iterable: Optional[Iterable] = None, **kwargs) -> Any:
    """Build a progress bar of the type selected by the active ``progress_bar`` context.

    Args:
        iterable: Optional iterable to wrap, forwarded as ``iterable=``.
        **kwargs: Extra keyword arguments forwarded unchanged to the bar
            constructor (e.g. ``total``, ``desc``).

    Returns:
        A ``tqdm`` or ray ``tqdm`` instance when that type is active. Otherwise
        (no active context, or type "none") a ``DummyProgressBar``.

    Raises:
        ImportError: If the active type's backend could not be imported.
    """
    # Outside a progress_bar context the selected type is irrelevant.
    selected = _current_progress_type if _is_progress_bar_active else "none"

    if selected == "tqdm":
        if tqdm is None:
            raise ImportError("tqdm is required but not installed. Please install tqdm to use the tqdm progress bar.")
        factory = tqdm
    elif selected == "ray_tqdm":
        if ray_tqdm is None:
            raise ImportError("ray is required but not installed. Please install ray to use the ray_tqdm progress bar.")
        factory = ray_tqdm
    else:
        factory = DummyProgressBar

    return factory(iterable=iterable, **kwargs)


@contextlib.contextmanager
def progress_bar(type: str = "none", enabled: bool = True):
    """
    Context manager that selects which progress bar ``get_new_progress_bar`` creates.

    While the context is active, ``get_new_progress_bar`` builds a bar of the
    selected type. On exit (normal or via an exception) the module state is
    reset, and ``get_new_progress_bar`` goes back to returning
    ``DummyProgressBar``.

    Per-bar options such as ``total`` or ``desc`` are not accepted here. Pass
    them to ``get_new_progress_bar`` instead. Whether the selected backend is
    installed is not checked here either: ``get_new_progress_bar`` raises
    ImportError if it is missing.

    Args:
        type: Type of progress bar: "none", "tqdm", or "ray_tqdm".
        enabled: If False, behave as if ``type="none"`` was given. ``type`` is
            still validated.

    Raises:
        ValueError: If ``type`` is not one of the supported values.
        RuntimeError: If ``progress_bar`` contexts are nested.

    Example:
        with progress_bar(type="tqdm"):
            for i in get_new_progress_bar(range(100), desc="processing"):
                process(i)
    """
    global _current_progress_type, _is_progress_bar_active

    if type not in ("none", "tqdm", "ray_tqdm"):
        raise ValueError("Progress bar type must be 'none' or 'tqdm' or 'ray_tqdm'")
    if not enabled:
        type = "none"

    if _is_progress_bar_active:
        raise RuntimeError("Nested progress bars are not supported")

    _is_progress_bar_active = True
    _current_progress_type = type

    try:
        yield
    finally:
        # Always reset, so an exception inside the ``with`` body does not leave
        # the module marked active (which would make the next context raise).
        _is_progress_bar_active = False
        _current_progress_type = "none"