File size: 3,025 Bytes
82f1bf5
 
 
 
 
 
 
a5d7977
d754e91
62b53be
82f1bf5
 
a5d7977
82f1bf5
 
 
fd15ecb
82f1bf5
 
 
fa0947b
a5d7977
82f1bf5
d754e91
a5d7977
 
82f1bf5
a5d7977
 
 
 
 
fd15ecb
 
 
a5d7977
 
 
 
 
 
 
d754e91
 
82f1bf5
 
 
 
 
 
 
 
fd15ecb
 
 
 
 
 
 
 
 
 
d754e91
82f1bf5
 
a5d7977
 
 
 
 
 
 
d754e91
82f1bf5
 
 
62b53be
82f1bf5
a5d7977
 
 
d754e91
82f1bf5
 
27c4f46
82f1bf5
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import os
import sys

import fire
import gradio as gr

from llama_lora.globals import Global
from llama_lora.models import prepare_base_model
from llama_lora.ui.main_page import main_page, get_page_title, main_page_custom_css
from llama_lora.utils.data import init_data_dir



def main(
    base_model: str = "",
    data_dir: str = "",
    base_model_choices: str = "",
    # Allows to listen on all interfaces by providing '0.0.0.0'.
    server_name: str = "127.0.0.1",
    share: bool = False,
    skip_loading_base_model: bool = False,
    load_8bit: bool = False,
    ui_show_sys_info: bool = True,
    ui_dev_mode: bool = False,
    wandb_api_key: str = "",
    wandb_project: str = "",
):
    '''
    Start the LLaMA-LoRA Tuner UI.

    :param base_model: (required) The name of the default base model to use.
    :param data_dir: (required) The path to the directory to store data.

    :param base_model_choices: Base model selections to display on the UI, seperated by ",". For example: 'decapoda-research/llama-7b-hf,nomic-ai/gpt4all-j'.

    :param server_name: Allows to listen on all interfaces by providing '0.0.0.0'.
    :param share: Create a public Gradio URL.

    :param wandb_api_key: The API key for Weights & Biases. Setting either this or `wandb_project` will enable Weights & Biases.
    :param wandb_project: The default project name for Weights & Biases. Setting either this or `wandb_api_key` will enable Weights & Biases.
    '''

    base_model = base_model or os.environ.get("LLAMA_LORA_BASE_MODEL", "")
    data_dir = data_dir or os.environ.get("LLAMA_LORA_DATA_DIR", "")
    assert (
        base_model
    ), "Please specify a --base_model, e.g. --base_model='decapoda-research/llama-7b-hf'"

    assert (
        data_dir
    ), "Please specify a --data_dir, e.g. --data_dir='./data'"

    Global.default_base_model_name = Global.base_model_name = base_model

    if base_model_choices:
        base_model_choices = base_model_choices.split(',')
        base_model_choices = [name.strip() for name in base_model_choices]
        Global.base_model_choices = base_model_choices

    if base_model not in Global.base_model_choices:
        Global.base_model_choices = [base_model] + Global.base_model_choices

    Global.data_dir = os.path.abspath(data_dir)
    Global.load_8bit = load_8bit

    if len(wandb_api_key) > 0:
        Global.enable_wandb = True
        Global.wandb_api_key = wandb_api_key
    if len(wandb_project) > 0:
        Global.enable_wandb = True
        Global.wandb_project = wandb_project

    Global.ui_dev_mode = ui_dev_mode
    Global.ui_show_sys_info = ui_show_sys_info

    os.makedirs(data_dir, exist_ok=True)
    init_data_dir()

    if (not skip_loading_base_model) and (not ui_dev_mode):
        prepare_base_model(base_model)

    with gr.Blocks(title=get_page_title(), css=main_page_custom_css()) as demo:
        main_page()

    demo.queue(concurrency_count=1).launch(server_name=server_name, share=share)


if __name__ == "__main__":
    fire.Fire(main)