"""
Entry point for the Gradio app.
"""

import os
import gradio as gr

from src.train_workflow import run, DEFAULT_TRAINING_ARGS
from src.calibration_datasets import CalibrationDataset
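# Assumed roles of the imports above: `run` is the training entry point,
# DEFAULT_TRAINING_ARGS its default hyperparameter string, and CalibrationDataset
# the base class whose subclasses describe the selectable datasets.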


# TODO: install FA2 in a better way, e.g. via a Docker image
os.system("pip install flash-attn --no-build-isolation")


DESCRIPTION = """
The steps to create [medusa](https://sites.google.com/view/medusa-llm) heads are the following:

1. Input a public model id from the Hub
2. Select a dataset to train the medusa heads on. The dataset should be representative of the downstream use case.
3. Click "Submit"
4. That's it! You'll get feedback on whether the run succeeded, and if it did, the name of the new repo 🔥
"""

title = "Create LLM medusa heads in a new repo 🐍"

with gr.Blocks(title=title) as demo:
    gr.Markdown(f"# {title}")
    gr.Markdown(DESCRIPTION)

    with gr.Row():
        with gr.Column():
            model_id = gr.Text(max_lines=1, label="model_id")
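            # Each CalibrationDataset subclass is assumed to expose a `dataset`
            # attribute (its dataset name), which populates the dropdown choices.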
            dataset_names = [
                cls.dataset for cls in CalibrationDataset.__subclasses__()
            ]
            dataset = gr.Dropdown(dataset_names, label="dataset")
            with gr.Accordion("Training arguments (advanced)", open=False):
                training_args = gr.Textbox(DEFAULT_TRAINING_ARGS, interactive=True, lines=20, label="training_args")
            with gr.Row():
                clean = gr.ClearButton([model_id, dataset, training_args])
                submit = gr.Button("Submit", variant="primary")

        with gr.Column():
            status_box = gr.Markdown()

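    # `run` is expected to return (or stream) a markdown status string rendered in
    # `status_box`; concurrency_limit=1 keeps training jobs strictly serial.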
    submit.click(run, inputs=[model_id, training_args, dataset], outputs=status_box, concurrency_limit=1)

demo.queue(max_size=10).launch(show_api=True)
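
# Usage sketch (not executed here): with show_api=True the click handler is exposed
# over the API, so the Space can also be driven from gradio_client. The Space id,
# the example input values, and the "/run" endpoint name below are assumptions,
# not guaranteed by this file.
#
#   from gradio_client import Client
#   client = Client("<user>/<this-space>")   # hypothetical Space id
#   status = client.predict(
#       "mistralai/Mistral-7B-v0.1",   # model_id (example value)
#       DEFAULT_TRAINING_ARGS,         # training_args
#       "wikitext2",                   # dataset (must match a dropdown choice)
#       api_name="/run",               # assumed auto-generated from the `run` function
#   )
#   print(status)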