Spaces:
Running
Running
import os | |
import shutil | |
import traceback | |
import faiss | |
import gradio as gr | |
import numpy as np | |
from sklearn.cluster import MiniBatchKMeans | |
from random import shuffle | |
from glob import glob | |
from infer.modules.train.train import train | |
from zero import zero | |
def write_filelist(
    exp_dir: str,
    *,
    if_f0: bool = True,
    spk_id: int = 0,
    sr: str = "40k",
    fea_dim: int = 768,
) -> None:
    """Write ``<exp_dir>/filelist.txt`` listing every usable training sample.

    A sample is kept only if a file with the same stem (the part of the file
    name before the first ``.``) exists in every required sub-directory:
    ``0_gt_wavs`` and ``3_feature<fea_dim>``, plus ``2a_f0`` and ``2b-f0nsf``
    when ``if_f0`` is set. Two identical lines pointing at the shared mute
    sample under ``<cwd>/logs/mute`` are appended. All lines are shuffled
    before writing.

    Each line has ``|``-separated columns:
    ``wav|feature.npy[|f0.wav.npy|f0nsf.wav.npy]|spk_id``.

    Args:
        exp_dir: Experiment directory containing the extracted data.
        if_f0: Include the two pitch columns (``2a_f0`` / ``2b-f0nsf``).
        spk_id: Speaker id written in the last column of every line.
        sr: Sample-rate tag used in the mute wav name (``mute<sr>.wav``).
        fea_dim: Feature dimension; selects ``3_feature<fea_dim>`` both in
            ``exp_dir`` and for the mute sample.

    Raises:
        FileNotFoundError: If a required sub-directory does not exist.
    """

    def esc(path):
        # Backslashes in experiment paths are doubled, as the original writer did.
        return path.replace("\\", "\\\\")

    def stems(directory):
        # File-name stems up to the first dot ("x.wav.npy" -> "x").
        return {name.split(".")[0] for name in os.listdir(directory)}

    gt_wavs_dir = "%s/0_gt_wavs" % exp_dir
    feature_dir = "%s/3_feature%s" % (exp_dir, fea_dim)
    required_dirs = [gt_wavs_dir, feature_dir]
    if if_f0:
        f0_dir = "%s/2a_f0" % exp_dir
        f0nsf_dir = "%s/2b-f0nsf" % exp_dir
        required_dirs += [f0_dir, f0nsf_dir]

    # Only samples present in *all* required directories are usable.
    names = set.intersection(*(stems(d) for d in required_dirs))

    lines = []
    for name in names:
        columns = [
            "%s/%s.wav" % (esc(gt_wavs_dir), name),
            "%s/%s.npy" % (esc(feature_dir), name),
        ]
        if if_f0:
            columns += [
                "%s/%s.wav.npy" % (esc(f0_dir), name),
                "%s/%s.wav.npy" % (esc(f0nsf_dir), name),
            ]
        columns.append(str(spk_id))
        lines.append("|".join(columns))

    # Shared mute sample, added twice; its paths are relative to the cwd
    # and, unlike the experiment paths above, are not backslash-escaped.
    now_dir = os.getcwd()
    mute_columns = [
        "%s/logs/mute/0_gt_wavs/mute%s.wav" % (now_dir, sr),
        "%s/logs/mute/3_feature%s/mute.npy" % (now_dir, fea_dim),
    ]
    if if_f0:
        mute_columns += [
            "%s/logs/mute/2a_f0/mute.wav.npy" % now_dir,
            "%s/logs/mute/2b-f0nsf/mute.wav.npy" % now_dir,
        ]
    mute_columns.append(str(spk_id))
    lines.extend(["|".join(mute_columns)] * 2)

    shuffle(lines)
    # Explicit UTF-8 so the output does not depend on the platform locale.
    with open("%s/filelist.txt" % exp_dir, "w", encoding="utf-8") as f:
        f.write("\n".join(lines))
def train_model(exp_dir: str) -> str:
    """Run a training session for ``exp_dir`` and return the newest generator checkpoint.

    Copies ``config.json`` from the current working directory into
    ``exp_dir``, regenerates ``filelist.txt`` with :func:`write_filelist`,
    runs :func:`train`, then returns the most recently modified
    ``G_*.pth`` file found directly in ``exp_dir``.

    Args:
        exp_dir: Existing experiment directory from the earlier pipeline steps.

    Returns:
        Path of the most recently written ``G_*.pth`` checkpoint.

    Raises:
        gr.Error: If ``exp_dir`` is not an existing directory, or if no
            ``G_*.pth`` checkpoint exists after training.
    """
    # Local import: at module level the name ``glob`` is bound to the function.
    from glob import escape

    if not os.path.isdir(exp_dir):
        # Without this guard, shutil.copy would create a *file* named exp_dir.
        raise gr.Error(f"Experiment directory not found: {exp_dir}")

    shutil.copy("config.json", exp_dir)
    write_filelist(exp_dir)
    train(exp_dir)

    # Escape exp_dir so characters like "[" in the path are not glob patterns.
    models = glob(f"{escape(exp_dir)}/G_*.pth")
    print(models)
    if not models:
        raise gr.Error("No model found")
    # mtime rather than ctime: on Windows ctime is the creation time, so a
    # checkpoint overwritten in place would not be recognised as the newest.
    return max(models, key=os.path.getmtime)
def train_index(exp_dir: str) -> str:
    """Train a FAISS ``IVF<n>,Flat`` index on the features in ``<exp_dir>/3_feature768``.

    Steps:

    1. Load every ``*.npy`` file in the feature directory (sorted by name)
       and concatenate along axis 0.
    2. Shuffle the rows with the global NumPy RNG.
    3. If there are more than 200 000 rows, replace them with 10 000
       MiniBatchKMeans cluster centres.
    4. Save the resulting matrix to ``<exp_dir>/total_fea.npy``.
    5. Train the index, with ``nprobe = 1`` set on its IVF part, and write
       it to ``trained_IVF<n>_Flat_nprobe_1.index``.
    6. Add the vectors in batches of 8192 and write the final index to
       ``added_IVF<n>_Flat_nprobe_1.index``.

    Args:
        exp_dir: Experiment directory whose features were already extracted.

    Returns:
        Path of the ``added_IVF…`` index file.

    Raises:
        gr.Error: If the feature directory is missing, contains no ``.npy``
            files, or holds no vectors. Also raised, carrying the formatted
            traceback, if k-means reduction fails.
    """
    feature_dir = "%s/3_feature768" % (exp_dir)
    if not os.path.isdir(feature_dir):
        raise gr.Error("Please extract features first.")
    # Only feature arrays are loaded; stray files would make np.load fail.
    npy_names = sorted(
        name for name in os.listdir(feature_dir) if name.endswith(".npy")
    )
    if not npy_names:
        raise gr.Error("Please extract features first.")
    big_npy = np.concatenate(
        [np.load("%s/%s" % (feature_dir, name)) for name in npy_names], 0
    )
    if big_npy.shape[0] == 0:
        raise gr.Error("Please extract features first.")

    # Shuffle rows via a permuted index (global NumPy RNG).
    big_npy_idx = np.arange(big_npy.shape[0])
    np.random.shuffle(big_npy_idx)
    big_npy = big_npy[big_npy_idx]

    if big_npy.shape[0] > 2e5:
        print("Trying doing kmeans %s shape to 10k centers." % big_npy.shape[0])
        try:
            big_npy = (
                MiniBatchKMeans(
                    n_clusters=10000,
                    verbose=True,
                    batch_size=256 * 8,
                    compute_labels=False,
                    init="random",
                )
                .fit(big_npy)
                .cluster_centers_
            )
        except Exception as err:  # not bare: KeyboardInterrupt/SystemExit propagate
            info = traceback.format_exc()
            print(info)
            raise gr.Error(info) from err

    np.save("%s/total_fea.npy" % exp_dir, big_npy)
    # Recompute after the optional k-means reduction; dim comes from the data.
    n_rows, dim = big_npy.shape
    # 16*sqrt(N) lists, capped at N // 39 (presumably faiss's default of 39
    # training points per centroid). The max(1, ...) avoids "IVF0" when N < 39.
    n_ivf = max(1, min(int(16 * np.sqrt(n_rows)), n_rows // 39))
    print("%s,%s" % (big_npy.shape, n_ivf))
    index = faiss.index_factory(dim, "IVF%s,Flat" % n_ivf)
    print("training")
    index_ivf = faiss.extract_index_ivf(index)
    index_ivf.nprobe = 1
    index.train(big_npy)
    suffix = "IVF%s_Flat_nprobe_%s.index" % (n_ivf, index_ivf.nprobe)
    faiss.write_index(index, "%s/trained_%s" % (exp_dir, suffix))

    print("adding")
    batch_size_add = 8192
    for start in range(0, n_rows, batch_size_add):
        index.add(big_npy[start : start + batch_size_add])
    added_path = "%s/added_%s" % (exp_dir, suffix)
    faiss.write_index(index, added_path)
    print("built added_%s" % suffix)
    return added_path
class TrainTab:
    """Gradio "Training" section that triggers :func:`train_model` and :func:`train_index`.

    Call :meth:`ui` first to create the components, then :meth:`build` to
    wire the buttons to their handlers.
    """

    def __init__(self):
        pass

    def ui(self):
        """Create the training components and store them on ``self``.

        Creates a "Train" button with a "Latest checkpoint" file output, and
        a "Train index" button with a "Trained index" file output. Presumably
        this is called inside an enclosing ``gr.Blocks``/tab context; that
        context is not visible here.
        """
        gr.Markdown("# Training")
        gr.Markdown(
            "You can start training the model by clicking the button below. "
            "Each time you click the button, the model will train for 10 epochs, which takes about 3 minutes on ZeroGPU (A100). "
            "The latest *training checkpoint* will be available below."
        )
        with gr.Row():
            self.train_btn = gr.Button(value="Train", variant="primary")
            self.latest_checkpoint = gr.File(label="Latest checkpoint")
        with gr.Row():
            self.train_index_btn = gr.Button(value="Train index", variant="primary")
            self.trained_index = gr.File(label="Trained index")

    def build(self, exp_dir: gr.Textbox):
        """Connect the buttons created by :meth:`ui` to the training handlers.

        Args:
            exp_dir: Textbox whose value is passed as the experiment
                directory to both handlers.
        """
        self.train_btn.click(
            fn=train_model,
            inputs=[exp_dir],
            outputs=[self.latest_checkpoint],
        )
        self.train_index_btn.click(
            fn=train_index,
            inputs=[exp_dir],
            outputs=[self.trained_index],
        )