File size: 4,216 Bytes
1df74c6
 
 
d2b7e94
 
1df74c6
 
d2b7e94
1df74c6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import gradio as gr

from modules.Enhancer.ResembleEnhance import unload_enhancer
from modules.models import unload_chat_tts
from modules.speaker import speaker_mgr
from modules.webui import webui_config
from modules.webui.webui_utils import get_speaker_names

from .ft_ui_utils import get_datasets_listfile, run_speaker_ft
from .ProcessMonitor import ProcessMonitor


class SpeakerFt:
    """Launch and supervise a speaker fine-tuning subprocess for the web UI.

    Training runs as ``-m modules.finetune.train_speaker`` under a
    ``ProcessMonitor``. This class builds the command line, keeps a one-line
    status string, and relays the output captured by the monitor.
    """

    def __init__(self):
        self.process_monitor = ProcessMonitor()
        # Human-readable state; flush() emits it as the first line of the log.
        self.status_str = "idle"

    def unload_main_thread_models(self):
        """Call the ChatTTS and enhancer unload helpers.

        Presumably this frees memory for the training subprocess -- confirm
        against the unload helpers.
        """
        unload_chat_tts()
        unload_enhancer()

    def run(
        self,
        batch_size: int,
        epochs: int,
        lr: str,
        train_text: bool,
        data_path: str,
        select_speaker: str = "",
    ):
        """Validate the inputs and start the training subprocess.

        Args:
            batch_size: Forwarded as ``--batch_size``.
            epochs: Forwarded as ``--epochs``.
            lr: Learning rate from the UI. NOTE(review): this value is not
                forwarded to the training command; confirm whether
                ``train_speaker`` exposes a matching flag before wiring it.
            train_text: When true, ``--train_text`` is appended.
            data_path: Forwarded as ``--data_path``; must be non-empty.
            select_speaker: Dropdown label of the initial speaker. ``""``,
                ``"none"`` and ``None`` all mean "no initial speaker".

        Returns:
            ``None`` normally, or ``["Speaker not found"]`` when the selected
            speaker cannot be resolved (kept for backward compatibility).
        """
        # Function-scope import keeps this block self-contained.
        import sys

        # A job is already running; do not start a second one.
        if self.process_monitor.process:
            return

        # Validate everything BEFORE unloading models, so a rejected request
        # does not needlessly evict the models used by the rest of the UI.
        if not data_path:
            # Without this guard the command would carry "--data_path=None".
            self.status("No dataset selected")
            return

        spk_path = None
        # Truthiness check also covers None, which a cleared dropdown can send.
        if select_speaker and select_speaker != "none":
            # Labels look like "<prefix> : <name>"; fall back to the whole
            # label when the separator is missing instead of raising.
            parts = select_speaker.split(" : ")
            speaker_name = (parts[1] if len(parts) > 1 else parts[0]).strip()
            spk = speaker_mgr.get_speaker(speaker_name)
            if spk is None:
                # Surface the failure in the log; the click handler has no
                # outputs, so the return value alone is never displayed.
                self.status("Speaker not found")
                return ["Speaker not found"]
            spk_filename = speaker_mgr.get_speaker_filename(spk.id)
            spk_path = f"./data/speakers/{spk_filename}"

        self.unload_main_thread_models()

        # Use the interpreter running the UI so the subprocess sees the same
        # environment; fall back to "python3" if it cannot be determined.
        command = [
            sys.executable or "python3",
            "-m",
            "modules.finetune.train_speaker",
            f"--batch_size={batch_size}",
            f"--epochs={epochs}",
            f"--data_path={data_path}",
        ]
        if train_text:
            command.append("--train_text")
        if spk_path:
            command.append(f"--init_speaker={spk_path}")

        self.status("Training process starting")

        self.process_monitor.start_process(command)

        self.status("Training started")

    def status(self, text: str):
        """Replace the current status line with ``text``."""
        self.status_str = text

    def flush(self):
        """Return the status line followed by the monitor's stdout and stderr."""
        stdout, stderr = self.process_monitor.get_output()
        return f"{self.status_str}\n{stdout}\n{stderr}"

    def clear(self):
        """Empty the monitor's captured stdout/stderr and note it in the status."""
        self.process_monitor.stdout = ""
        self.process_monitor.stderr = ""
        self.status("Logs cleared")

    def stop(self):
        """Ask the monitor to stop the process and update the status line."""
        self.process_monitor.stop_process()
        self.status("Training stopped")


def create_speaker_ft_tab(demo: gr.Blocks):
    """Build the speaker fine-tuning tab and wire its controls.

    The left column holds the hyper-parameter widgets and the three action
    buttons; the right column holds the log textbox. The buttons are bound to
    a fresh ``SpeakerFt`` instance. When ``webui_config.experimental`` is set,
    a ``demo.load`` event with ``every=1`` feeds ``SpeakerFt.flush`` into the
    log box.

    Args:
        demo: The enclosing ``gr.Blocks`` app, used to register the load event.
    """
    trainer = SpeakerFt()
    _, names = get_speaker_names()
    # "none" is the sentinel SpeakerFt.run treats as "no initial speaker".
    initial_speaker_choices = ["none"] + names

    with gr.Row():
        # Left side: hyper-parameters and action buttons.
        with gr.Column(scale=2):
            with gr.Group():
                gr.Markdown("🎛️hparams")
                dataset_dd = gr.Dropdown(
                    choices=get_datasets_listfile(),
                    label="Dataset",
                )
                lr_box = gr.Textbox(value="1e-2", label="Learning Rate")
                epochs_slider = gr.Slider(
                    minimum=1,
                    maximum=100,
                    step=1,
                    value=10,
                    label="Epochs",
                )
                batch_slider = gr.Slider(
                    minimum=1,
                    maximum=64,
                    step=1,
                    value=4,
                    label="Batch Size",
                )
                text_loss_cb = gr.Checkbox(value=True, label="Train text_loss")
                init_speaker_dd = gr.Dropdown(
                    choices=initial_speaker_choices,
                    value="none",
                    label="Initial Speaker",
                )

            with gr.Group():
                start_btn = gr.Button("Start Training")
                stop_btn = gr.Button("Stop Training")
                clear_btn = gr.Button("Clear logs")

        # Right side: training log.
        with gr.Column(scale=5):
            with gr.Group():
                gr.Markdown("📜logs")
                log_box = gr.Textbox(
                    label="Log",
                    show_label=False,
                    value="",
                    lines=20,
                    interactive=True,
                )

    # Order must match the positional parameters of SpeakerFt.run.
    run_inputs = [
        batch_slider,
        epochs_slider,
        lr_box,
        text_loss_cb,
        dataset_dd,
        init_speaker_dd,
    ]
    start_btn.click(trainer.run, inputs=run_inputs, outputs=[])
    stop_btn.click(trainer.stop)
    clear_btn.click(trainer.clear)

    # Log polling is only registered in experimental mode.
    if webui_config.experimental:
        demo.load(trainer.flush, outputs=[log_box], every=1)