File size: 4,386 Bytes
40a8f4e
82f1bf5
 
9bd8d8b
 
 
82f1bf5
40a8f4e
 
62b53be
8b0ae10
40a8f4e
fdddf65
40a8f4e
fdddf65
82f1bf5
a5d7977
82f1bf5
40a8f4e
 
 
 
82f1bf5
 
fa0947b
40a8f4e
 
 
 
 
79d936d
82f1bf5
a5d7977
 
 
 
 
fd15ecb
 
 
a5d7977
 
 
 
 
 
 
9bd8d8b
 
 
9b019ea
 
 
9bd8d8b
 
 
 
 
40a8f4e
 
82f1bf5
40a8f4e
 
 
 
 
82f1bf5
40a8f4e
 
fd15ecb
40a8f4e
 
fd15ecb
40a8f4e
 
fd15ecb
40a8f4e
 
03b5741
79d936d
 
 
40a8f4e
 
82f1bf5
40a8f4e
 
a5d7977
40a8f4e
 
 
 
 
 
 
 
 
 
82f1bf5
62b53be
82f1bf5
40a8f4e
 
a5d7977
fdddf65
82f1bf5
 
40a8f4e
 
82f1bf5
 
9bd8d8b
 
 
 
 
 
 
 
 
 
 
 
 
82f1bf5
 
d019027
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
from typing import Union

import gradio as gr
import fire
import os
import yaml

from llama_lora.config import Config, process_config
from llama_lora.globals import initialize_global
from llama_lora.utils.data import init_data_dir
from llama_lora.models import prepare_base_model
from llama_lora.ui.main_page import (
    main_page, get_page_title
)
from llama_lora.ui.css_styles import get_css_styles


def main(
    base_model: Union[str, None] = None,
    data_dir: Union[str, None] = None,
    base_model_choices: Union[str, None] = None,
    trust_remote_code: Union[bool, None] = None,
    server_name: Union[str, None] = None,
    share: bool = False,
    skip_loading_base_model: bool = False,
    load_8bit: Union[bool, None] = None,
    ui_show_sys_info: Union[bool, None] = None,
    ui_dev_mode: Union[bool, None] = None,
    wandb_api_key: Union[str, None] = None,
    wandb_project: Union[str, None] = None,
    timezone: Union[str, None] = None,
):
    '''
    Start the LLaMA-LoRA Tuner UI.

    Settings are applied in two layers: values from `config.yaml` (next to
    this script) first, then every command line argument that is not None,
    so an explicit command line argument always wins over config.yaml.

    :param base_model: (required, here or in config.yaml) The name of the default base model to use.
    :param data_dir: (required, here or in config.yaml) The path to the directory to store data.

    :param base_model_choices: Base model selections to display on the UI, separated by ",". For example: 'decapoda-research/llama-7b-hf,nomic-ai/gpt4all-j'.
    :param trust_remote_code: Overrides `Config.trust_remote_code`.

    :param server_name: Address to listen on; defaults to '127.0.0.1' when neither this nor config.yaml sets it. Provide '0.0.0.0' to listen on all interfaces.
    :param share: Create a public Gradio URL.

    :param skip_loading_base_model: Do not load the base model at startup (it is also skipped when UI dev mode is enabled).
    :param load_8bit: Overrides `Config.load_8bit`.
    :param ui_show_sys_info: Overrides `Config.ui_show_sys_info`.
    :param ui_dev_mode: Overrides `Config.ui_dev_mode`.

    :param wandb_api_key: The API key for Weights & Biases. Setting either this or `wandb_project` will enable Weights & Biases.
    :param wandb_project: The default project name for Weights & Biases. Setting either this or `wandb_api_key` will enable Weights & Biases.
    :param timezone: Overrides `Config.timezone`.

    :raises ValueError: If config.yaml is not a mapping or contains an unknown key.
    :raises AssertionError: If no base model or data directory is configured.
    '''

    server_name_from_file = _apply_yaml_config(read_yaml_config())

    # Command line arguments override config.yaml; None means "not given".
    cli_overrides = {
        "default_base_model_name": base_model,
        "base_model_choices": base_model_choices,
        "trust_remote_code": trust_remote_code,
        "data_dir": data_dir,
        "load_8bit": load_8bit,
        "wandb_api_key": wandb_api_key,
        "default_wandb_project": wandb_project,
        "timezone": timezone,
        "ui_dev_mode": ui_dev_mode,
        "ui_show_sys_info": ui_show_sys_info,
    }
    for attr_name, value in cli_overrides.items():
        if value is not None:
            setattr(Config, attr_name, value)

    # server_name is not a Config attribute, so resolve its precedence here:
    # CLI value, then config.yaml, then the historical default.
    if server_name is None:
        server_name = (
            server_name_from_file
            if server_name_from_file is not None
            else "127.0.0.1"
        )

    process_config()
    initialize_global()

    # Raise explicitly instead of using `assert`, which `python -O` strips.
    # AssertionError is kept because the gradio reload-mode wrapper at the
    # bottom of this file catches it to append a hint about config.yaml.
    if not Config.default_base_model_name:
        raise AssertionError(
            "Please specify a --base_model, e.g. --base_model='decapoda-research/llama-7b-hf'")

    if not Config.data_dir:
        raise AssertionError(
            "Please specify a --data_dir, e.g. --data_dir='./data'")

    init_data_dir()

    if (not skip_loading_base_model) and (not Config.ui_dev_mode):
        prepare_base_model(Config.default_base_model_name)

    with gr.Blocks(title=get_page_title(), css=get_css_styles()) as demo:
        main_page()

    demo.queue(concurrency_count=1).launch(
        server_name=server_name, share=share)


def _apply_yaml_config(config_from_file):
    '''
    Copy settings from a parsed config.yaml onto `Config`.

    :param config_from_file: The value returned by `read_yaml_config()`;
        None or an empty mapping means there is nothing to apply.
    :return: The `server_name` entry (it is not a `Config` attribute), or None
        when the file does not set it.
    :raises ValueError: If the content is not a mapping or a key is unknown.
    '''
    if not config_from_file:
        return None

    if not isinstance(config_from_file, dict):
        raise ValueError(
            "config.yaml must contain a mapping of config keys to values, "
            f"got {type(config_from_file).__name__}")

    server_name = None
    for key, value in config_from_file.items():
        if key == "server_name":
            server_name = value
            continue
        # Non-string YAML keys (e.g. `1: x`) would make hasattr raise TypeError.
        if not isinstance(key, str) or not hasattr(Config, key):
            available_keys = [k for k in vars(Config) if not k.startswith('__')]
            raise ValueError(f"Invalid config key '{key}' in config.yaml. Available keys: {', '.join(available_keys)}")
        setattr(Config, key, value)
    return server_name


def read_yaml_config(config_path: Union[str, None] = None):
    '''
    Load the YAML config file if it exists.

    :param config_path: Path of the YAML file to read. Defaults to
        `config.yaml` in the directory containing this script.
    :return: Whatever `yaml.safe_load` parses from the file (None for an
        empty file), or None when the file does not exist.
    '''
    if config_path is None:
        app_dir = os.path.dirname(os.path.abspath(__file__))
        config_path = os.path.join(app_dir, 'config.yaml')

    if not os.path.exists(config_path):
        return None

    print(f"Loading config from {config_path}...")
    # Decode as UTF-8 explicitly so the result does not depend on the
    # platform's locale encoding (e.g. cp1252 on Windows).
    with open(config_path, 'r', encoding='utf-8') as yaml_file:
        return yaml.safe_load(yaml_file)


if __name__ == "__main__":
    # Normal CLI launch: expose main()'s parameters as command line flags.
    fire.Fire(main)
elif __name__ == "app":  # running in gradio reload mode (`gradio`)
    try:
        main()
    except AssertionError as err:
        # Reload mode ignores CLI flags, so point the user at config.yaml.
        raise AssertionError(
            f"{err}\nNote that command line args are not supported while running in gradio reload mode, config.yaml must be used."
        ) from err