File size: 5,887 Bytes
06a60a3
 
a791472
06a60a3
a791472
 
06a60a3
a791472
 
b91e31d
06a60a3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b91e31d
a791472
 
 
 
 
b91e31d
 
a791472
 
8d6d31e
a791472
b91e31d
 
 
06a60a3
 
 
 
 
 
b91e31d
 
 
 
06a60a3
 
 
 
 
dbf5fb1
7162f12
06a60a3
 
b91e31d
 
 
a791472
06a60a3
 
 
 
 
 
 
7162f12
b91e31d
a791472
06a60a3
 
 
 
 
 
 
 
 
b91e31d
 
06a60a3
b91e31d
06a60a3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dbf5fb1
b91e31d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
06a60a3
b91e31d
 
06a60a3
 
 
 
 
b91e31d
 
 
06a60a3
b91e31d
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
from contextlib import contextmanager

import gradio as gr
from markup import get_text, highlight
from template import get_templates


# Template data loaded once at import time via `template.get_templates()`.
# NOTE(review): `templates` is not referenced anywhere else in this file —
# confirm whether this call is still needed (it may be kept for side effects).
templates = get_templates()


def fill_tab(title, explanation):
    """
    Create a heading Markdown followed by a body Markdown.

    Args:
        title (`str`): Markdown text for the heading component.
        explanation (`str`): Markdown text for the body component.

    Returns:
        `tuple`: The `(heading, body)` Markdown components, in creation order.
    """
    heading = gr.Markdown(title)
    body = gr.Markdown(explanation)
    return heading, body


@contextmanager
def new_section():
    """
    Open a new section of the interface: a `gr.Column` nested inside a `gr.Row`.

    Components created inside `with new_section():` are placed in that column.
    Equivalent of:
    ```python
    with gr.Row():
        with gr.Column():
            ...
    ```
    """
    # A multi-item `with` enters Row first, then Column, and exits in reverse,
    # exactly like the two nested `with` statements shown above.
    with gr.Row(), gr.Column():
        yield


def change(inp, textbox):
    """Render the content shown for option `inp` on the tab named `textbox`.

    Args:
        inp (`str`):
            The option currently selected in the tab's radio button group.
        textbox (`str`):
            The tab name, taken from the tab's hidden textbox. Either `"base"`
            or `"training_configuration"`.

    Returns:
        `tuple`: For the `"base"` tab, a 4-tuple
        `(highlighted code, section title, explanation, docs)`.
        For the `"training_configuration"` tab, a 5-tuple
        `(highlighted yaml, highlighted script changes, command, explanation, docs)`.

    Raises:
        `ValueError`: If `textbox` is not one of the known tab names.
    """
    if textbox == "base":
        code, explanation, docs = get_text(inp, textbox)
        # "Basic" gets a descriptive title; every other feature is titled by its
        # own name (the former "Calculating Metrics" branch was identical to this).
        if inp == "Basic":
            title = "## Accelerate Code (Base Integration)"
        else:
            title = f"## Accelerate Code ({inp})"
        return (highlight(code), title, explanation, docs)
    if textbox == "training_configuration":
        yaml, changes, command, explanation, docs = get_text(inp, textbox)
        return (highlight(yaml), highlight(changes), command, explanation, docs)
    raise ValueError(f"Invalid tab name: {textbox}")


# Pre-rendered contents for each tab's default radio selection, used to fill the
# sections before the user picks anything (evaluated left to right at import).
default_base, default_training_config = (
    change("Basic", "base"),
    change("Multi GPU", "training_configuration"),
)


def base_features(textbox):
    """Build the contents of the "Basic Training Integration" tab.

    Creates a radio selector of features, plus three sections (code,
    explanation, documentation links) pre-filled from `default_base`. Selecting
    a feature re-renders all of them through `change`.

    Args:
        textbox (`gr.Textbox`):
            Hidden component whose value is the tab name forwarded to `change`.
    """
    inp = gr.Radio(
        [
            "Basic",
            "Calculating Metrics",
            "Checkpointing",
            "Experiment Tracking",
            "Gradient Accumulation",
        ],
        label="Select a feature you would like to integrate",
        value="Basic",
    )
    with new_section():
        # Use the title `change` produces for the default selection, so the
        # first render matches what the callback shows when "Basic" is chosen.
        feature, out = fill_tab(default_base[1], default_base[0])
    with new_section():
        _, explanation = fill_tab("## Explanation", default_base[2])
    with new_section():
        _, docs = fill_tab("## Documentation Links", default_base[3])
    # `change` returns (code, title, explanation, docs) for the "base" tab.
    inp.change(
        fn=change, inputs=[inp, textbox], outputs=[out, feature, explanation, docs]
    )


def training_config(textbox):
    """Build the contents of the "Launch Configuration" tab.

    Creates a radio selector of distributed setups, plus five headed sections
    pre-filled from `default_training_config`. Changing the selection
    re-renders every section body through `change`.

    Args:
        textbox (`gr.Textbox`):
            Hidden component whose value is the tab name forwarded to `change`.
    """
    selector = gr.Radio(
        [
            "AWS SageMaker",
            "DeepSpeed",
            "Megatron-LM",
            "Multi GPU",
            "Multi Node Multi GPU",
            "PyTorch FSDP",
        ],
        label="Select a distributed YAML configuration you would like to view.",
        value="Multi GPU",
    )
    # Headings in the same order as the tuple `change` returns for this tab:
    # (yaml, changes, command, explanation, docs).
    headings = (
        "## Example YAML Configuration",
        "## Changes to Training Script",
        "## Command to Run Training",
        "## Explanation",
        "## Documentation Links",
    )
    bodies = []
    for heading, initial_text in zip(headings, default_training_config):
        with new_section():
            bodies.append(fill_tab(heading, initial_text)[1])
    selector.change(
        fn=change,
        inputs=[selector, textbox],
        outputs=bodies,
    )


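# NOTE(review): disabled draft — it references an undefined `default` and sets
# `value="Basic"`, which is not one of its choices; fix both before re-enabling.
# def big_model_inference():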
#     inp = gr.Radio(
#         ["Accelerate's Big Model Inference",], # "DeepSpeed ZeRO Stage-3 Offload"
#         label="Select a feature you would like to integrate",
#         value="Basic",
#     )
#     with gr.Row():
#         with gr.Column():
#             feature = gr.Markdown("## Accelerate Code")
#             out = gr.Markdown(default[0])
#     with gr.Row():
#         with gr.Column():
#             gr.Markdown(default[1])
#             explanation = gr.Markdown(default[2])
#     with gr.Row():
#         with gr.Column():
#             gr.Markdown("## Documentation Links")
#             docs = gr.Markdown(default[3])
#     inp.change(fn=change, inputs=[inp, "big_model_inference"], outputs=[out, feature, explanation, docs])


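# NOTE(review): disabled draft — it references an undefined `default` and sets
# `value="Basic"`, which is not one of its choices; fix both before re-enabling.
# def notebook_launcher():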
#     inp = gr.Radio(
#         ["Colab GPU", "Colab TPU", "Kaggle GPU", "Kaggle Multi GPU", "Kaggle TPU", "Multi GPU VMs"],
#         label="Select a feature you would like to integrate",
#         value="Basic",
#     )
#     with gr.Row():
#         with gr.Column():
#             feature = gr.Markdown("## Accelerate Code")
#             out = gr.Markdown(default[0])
#     with gr.Row():
#         with gr.Column():
#             gr.Markdown(default[1])
#             explanation = gr.Markdown(default[2])
#     with gr.Row():
#         with gr.Column():
#             gr.Markdown("## Documentation Links")
#             docs = gr.Markdown(default[3])
#     inp.change(fn=change, inputs=[inp, "notebook_launcher"], outputs=[out, feature, explanation, docs])


with gr.Blocks() as demo:

    with gr.Tabs():
        # Implemented tabs: (tab title, hidden tab name passed to `change`, builder).
        for tab_title, tab_name, build_tab in (
            ("Basic Training Integration", "base", base_features),
            ("Launch Configuration", "training_configuration", training_config),
        ):
            with gr.TabItem(tab_title):
                # Hidden textbox whose value tells `change` which tab fired.
                textbox = gr.Textbox(label="tab_name", visible=False, value=tab_name)
                build_tab(textbox)
        # Placeholder tabs; their builders are currently disabled.
        with gr.TabItem("Big Model Inference"):
            # big_model_inference()
            pass
        with gr.TabItem("Launching from Notebooks"):
            # notebook_launcher()
            pass


demo.launch()