# Spaces:
# Sleeping
# Sleeping
import gradio as gr | |
from modules.Enhancer.ResembleEnhance import unload_enhancer | |
from modules.models import unload_chat_tts | |
from modules.speaker import speaker_mgr | |
from modules.webui import webui_config | |
from modules.webui.webui_utils import get_speaker_names | |
from .ft_ui_utils import get_datasets_listfile, run_speaker_ft | |
from .ProcessMonitor import ProcessMonitor | |
class SpeakerFt:
    """Launch speaker fine-tuning in a child process and relay its logs.

    The class builds the command line for ``modules.finetune.train_speaker``,
    hands it to a ``ProcessMonitor`` and keeps a one-line status string that
    is shown together with the captured stdout/stderr.
    """

    def __init__(self):
        self.process_monitor = ProcessMonitor()
        # Short human-readable state; first line of the text built by flush().
        self.status_str = "idle"

    def unload_main_thread_models(self):
        """Unload the ChatTTS and enhancer models held by this process."""
        unload_chat_tts()
        unload_enhancer()

    @staticmethod
    def _parse_speaker_name(label: str) -> str:
        """Extract the speaker name from a dropdown label.

        Labels of the form ``"<prefix> : <name>"`` yield ``<name>``; a label
        without the ``" : "`` separator is used as the name unchanged instead
        of raising ``IndexError``.
        """
        parts = label.split(" : ")
        name = parts[1] if len(parts) > 1 else parts[0]
        return name.strip()

    def run(
        self,
        batch_size: int,
        epochs: int,
        lr: str,
        train_text: bool,
        data_path: str,
        select_speaker: str = "",
    ):
        """Validate the inputs and start the training subprocess.

        Args:
            batch_size: Forwarded as ``--batch_size``.
            epochs: Forwarded as ``--epochs``.
            lr: Accepted so the UI can pass its learning-rate field.
                NOTE(review): this value is not forwarded to the training
                command — confirm whether ``train_speaker`` supports a
                learning-rate option before wiring it through.
            train_text: When true, ``--train_text`` is appended.
            data_path: Forwarded as ``--data_path``; must be non-empty.
            select_speaker: Dropdown label of the initial speaker; ``""``,
                ``"none"`` or ``None`` means no initial speaker.

        Returns:
            ``None`` when training was started or a process is already
            tracked by the monitor; a one-element list holding the error
            message when validation fails.
        """
        if self.process_monitor.process:
            return None

        # A dropdown without a selection yields None/""; forwarding that
        # would hand the trainer a literal "--data_path=None".
        if not data_path:
            self.status("No dataset selected")
            return ["No dataset selected"]

        # Resolve the initial speaker BEFORE unloading anything, so a failed
        # lookup does not leave the UI with its models unloaded for nothing.
        spk_path = None
        if select_speaker and select_speaker != "none":
            spk_name = self._parse_speaker_name(select_speaker)
            spk = speaker_mgr.get_speaker(spk_name)
            if spk is None:
                # The return value alone is invisible when the click handler
                # has no outputs, so surface the problem in the status line.
                self.status("Speaker not found")
                return ["Speaker not found"]
            spk_filename = speaker_mgr.get_speaker_filename(spk.id)
            spk_path = f"./data/speakers/{spk_filename}"

        self.unload_main_thread_models()

        import sys

        # Prefer the interpreter running this UI so the child process sees
        # the same environment; fall back to "python3" if it is unknown.
        command = [
            sys.executable or "python3",
            "-m",
            "modules.finetune.train_speaker",
            f"--batch_size={batch_size}",
            f"--epochs={epochs}",
            f"--data_path={data_path}",
        ]
        if train_text:
            command.append("--train_text")
        if spk_path:
            command.append(f"--init_speaker={spk_path}")

        self.status("Training process starting")
        self.process_monitor.start_process(command)
        self.status("Training started")
        return None

    def status(self, text: str):
        """Replace the current status line with ``text``."""
        self.status_str = text

    def flush(self):
        """Return the status line followed by the captured stdout and stderr."""
        stdout, stderr = self.process_monitor.get_output()
        return f"{self.status_str}\n{stdout}\n{stderr}"

    def clear(self):
        """Reset the monitor's stdout/stderr buffers and note it in the status."""
        self.process_monitor.stdout = ""
        self.process_monitor.stderr = ""
        self.status("Logs cleared")

    def stop(self):
        """Ask the monitor to stop the process and update the status line."""
        self.process_monitor.stop_process()
        self.status("Training stopped")
def create_speaker_ft_tab(demo: gr.Blocks):
    """Build the speaker fine-tuning tab and connect it to a ``SpeakerFt``.

    Args:
        demo: The enclosing Blocks app; used to register the recurring log
            refresh when ``webui_config.experimental`` is set.
    """
    trainer = SpeakerFt()
    _, names = get_speaker_names()
    # "none" is the sentinel the trainer treats as "no initial speaker".
    speaker_choices = ["none"] + names

    with gr.Row():
        # Narrower column: hyper-parameters and the control buttons.
        with gr.Column(scale=2):
            with gr.Group():
                gr.Markdown("🎛️hparams")
                dataset_dropdown = gr.Dropdown(
                    label="Dataset",
                    choices=get_datasets_listfile(),
                )
                lr_textbox = gr.Textbox(label="Learning Rate", value="1e-2")
                epochs_slider = gr.Slider(
                    label="Epochs",
                    value=10,
                    minimum=1,
                    maximum=100,
                    step=1,
                )
                batch_size_slider = gr.Slider(
                    label="Batch Size",
                    value=4,
                    minimum=1,
                    maximum=64,
                    step=1,
                )
                train_text_toggle = gr.Checkbox(
                    label="Train text_loss",
                    value=True,
                )
                init_speaker_dropdown = gr.Dropdown(
                    label="Initial Speaker",
                    choices=speaker_choices,
                    value="none",
                )

            with gr.Group():
                start_button = gr.Button("Start Training")
                stop_button = gr.Button("Stop Training")
                clear_button = gr.Button("Clear logs")

        # Wider column: the log view.
        with gr.Column(scale=5):
            with gr.Group():
                gr.Markdown("📜logs")
                log_box = gr.Textbox(
                    show_label=False,
                    label="Log",
                    value="",
                    lines=20,
                    interactive=True,
                )

    # Order mirrors the positional parameters of SpeakerFt.run.
    run_inputs = [
        batch_size_slider,
        epochs_slider,
        lr_textbox,
        train_text_toggle,
        dataset_dropdown,
        init_speaker_dropdown,
    ]
    start_button.click(trainer.run, inputs=run_inputs, outputs=[])
    stop_button.click(trainer.stop)
    clear_button.click(trainer.clear)

    # The recurring log refresh is only registered in experimental mode.
    if webui_config.experimental:
        demo.load(trainer.flush, every=1, outputs=[log_box])