import sys

import gradio as gr

from modules.Enhancer.ResembleEnhance import unload_enhancer
from modules.models import unload_chat_tts
from modules.speaker import speaker_mgr
from modules.webui import webui_config
from modules.webui.webui_utils import get_speaker_names

from .ft_ui_utils import get_datasets_listfile, run_speaker_ft
from .ProcessMonitor import ProcessMonitor


class SpeakerFt:
    """Drive a speaker fine-tuning run in a child process from the web UI.

    Holds a ``ProcessMonitor`` for the training subprocess plus a one-line
    status message that ``flush`` prepends to the captured process output.
    """

    def __init__(self):
        self.process_monitor = ProcessMonitor()
        # Human-readable state shown at the top of the log box by ``flush``.
        self.status_str = "idle"

    def unload_main_thread_models(self):
        """Unload the ChatTTS and enhancer models held by this (UI) process."""
        unload_chat_tts()
        unload_enhancer()

    def run(
        self,
        batch_size: int,
        epochs: int,
        lr: str,
        train_text: bool,
        data_path: str,
        select_speaker: str = "",
    ):
        """Validate the form inputs and launch ``modules.finetune.train_speaker``.

        Args:
            batch_size: Value forwarded as ``--batch_size``.
            epochs: Value forwarded as ``--epochs``.
            lr: Text of the "Learning Rate" box. NOTE(review): this value is
                not forwarded to the training command; confirm that
                ``modules.finetune.train_speaker`` accepts ``--lr`` before
                wiring it up (an unknown argparse flag would abort the run).
            train_text: When true, ``--train_text`` is appended.
            data_path: Dataset list file forwarded as ``--data_path``.
            select_speaker: Dropdown label of the initial speaker. ``""``,
                ``None`` or ``"none"`` means start without one. Labels are
                split on ``" : "`` and the second field is used as the
                speaker name.

        Problems are reported through the status line (visible in the log
        box via ``flush``) instead of return values: the Gradio click
        handler is wired with no outputs, so anything returned is discarded.
        """
        if self.process_monitor.process:
            self.status("A training process is already running")
            return

        if not data_path:
            self.status("Please select a dataset")
            return

        spk_path = None
        if select_speaker and select_speaker != "none":
            parts = select_speaker.split(" : ")
            # Fall back to the whole label when the separator is missing
            # instead of raising IndexError.
            speaker_name = (parts[1] if len(parts) > 1 else parts[0]).strip()
            spk = speaker_mgr.get_speaker(speaker_name)
            if spk is None:
                self.status("Speaker not found")
                return
            spk_filename = speaker_mgr.get_speaker_filename(spk.id)
            spk_path = f"./data/speakers/{spk_filename}"

        # Free the UI process's models only once we know training will start.
        self.unload_main_thread_models()

        # sys.executable runs the trainer with the same interpreter/venv as the
        # web UI; a bare "python3" may resolve to a different or missing one.
        command = [sys.executable, "-m", "modules.finetune.train_speaker"]
        command += [
            f"--batch_size={batch_size}",
            f"--epochs={epochs}",
            f"--data_path={data_path}",
        ]
        if train_text:
            command.append("--train_text")
        if spk_path:
            command.append(f"--init_speaker={spk_path}")

        self.status("Training process starting")
        try:
            self.process_monitor.start_process(command)
        except Exception as e:
            # Don't leave the status stuck at "starting"; surface the failure.
            self.status(f"Failed to start training: {e}")
            raise
        self.status("Training started")

    def status(self, text: str):
        """Replace the status line shown above the process output."""
        self.status_str = text

    def flush(self):
        """Return the status line followed by the monitor's stdout and stderr."""
        stdout, stderr = self.process_monitor.get_output()
        return f"{self.status_str}\n{stdout}\n{stderr}"

    def clear(self):
        """Drop the captured stdout/stderr text and note that logs were cleared."""
        self.process_monitor.stdout = ""
        self.process_monitor.stderr = ""
        self.status("Logs cleared")

    def stop(self):
        """Ask the process monitor to stop the training process."""
        self.process_monitor.stop_process()
        self.status("Training stopped")


def create_speaker_ft_tab(demo: gr.Blocks):
    """Build the speaker fine-tuning tab inside the current Gradio layout.

    Args:
        demo: The enclosing Blocks app. It is used only to register the
            1-second log polling, which is enabled only when
            ``webui_config.experimental`` is set.
    """
    spk_ft = SpeakerFt()
    _, speaker_names = get_speaker_names()
    speaker_names = ["none"] + speaker_names

    with gr.Row():
        with gr.Column(scale=2):
            with gr.Group():
                gr.Markdown("🎛️hparams")
                dataset_input = gr.Dropdown(
                    label="Dataset", choices=get_datasets_listfile()
                )
                lr_input = gr.Textbox(label="Learning Rate", value="1e-2")
                epochs_input = gr.Slider(
                    label="Epochs", value=10, minimum=1, maximum=100, step=1
                )
                batch_size_input = gr.Slider(
                    label="Batch Size", value=4, minimum=1, maximum=64, step=1
                )
                train_text_checkbox = gr.Checkbox(
                    label="Train text_loss", value=True
                )
                init_spk_dropdown = gr.Dropdown(
                    label="Initial Speaker",
                    choices=speaker_names,
                    value="none",
                )

            with gr.Group():
                start_train_btn = gr.Button("Start Training")
                stop_train_btn = gr.Button("Stop Training")
                clear_train_btn = gr.Button("Clear logs")
        with gr.Column(scale=5):
            with gr.Group():
                # log
                gr.Markdown("📜logs")
                log_output = gr.Textbox(
                    show_label=False,
                    label="Log",
                    value="",
                    lines=20,
                    interactive=True,
                )

    # Input order must match SpeakerFt.run's positional parameters.
    start_train_btn.click(
        spk_ft.run,
        inputs=[
            batch_size_input,
            epochs_input,
            lr_input,
            train_text_checkbox,
            dataset_input,
            init_spk_dropdown,
        ],
        outputs=[],
    )
    stop_train_btn.click(spk_ft.stop)
    clear_train_btn.click(spk_ft.clear)

    if webui_config.experimental:
        # Poll status + process output into the log box once per second.
        demo.load(spk_ft.flush, every=1, outputs=[log_output])