Spaces:
Runtime error
Runtime error
Duplicate from nateraw/voice-cloning
Browse filesCo-authored-by: Nate Raw <nateraw@users.noreply.huggingface.co>
- .gitattributes +34 -0
- Makefile +11 -0
- README.md +14 -0
- app.py +229 -0
- packages.txt +3 -0
- pyproject.toml +17 -0
- requirements.txt +6 -0
- training_so_vits_svc_fork.ipynb +540 -0
.gitattributes
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
29 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
30 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
31 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
32 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
33 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
34 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
Makefile
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
.PHONY: quality style
|
2 |
+
|
3 |
+
# Check that source code meets quality standards
|
4 |
+
quality:
|
5 |
+
black --check --diff .
|
6 |
+
ruff .
|
7 |
+
|
8 |
+
# Format source code automatically
|
9 |
+
style:
|
10 |
+
black .
|
11 |
+
ruff . --fix
|
README.md
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
title: Voice Cloning
|
3 |
+
emoji: 😻
|
4 |
+
colorFrom: blue
|
5 |
+
colorTo: yellow
|
6 |
+
sdk: gradio
|
7 |
+
sdk_version: 3.27.0
|
8 |
+
app_file: app.py
|
9 |
+
pinned: false
|
10 |
+
license: mit
|
11 |
+
duplicated_from: nateraw/voice-cloning
|
12 |
+
---
|
13 |
+
|
14 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
app.py
ADDED
@@ -0,0 +1,229 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import os
|
3 |
+
import subprocess
|
4 |
+
from pathlib import Path
|
5 |
+
|
6 |
+
import gradio as gr
|
7 |
+
import librosa
|
8 |
+
import numpy as np
|
9 |
+
import torch
|
10 |
+
from demucs.apply import apply_model
|
11 |
+
from demucs.pretrained import DEFAULT_MODEL, get_model
|
12 |
+
from huggingface_hub import hf_hub_download, list_repo_files
|
13 |
+
|
14 |
+
from so_vits_svc_fork.hparams import HParams
|
15 |
+
from so_vits_svc_fork.inference.core import Svc
|
16 |
+
|
17 |
+
|
18 |
+
###################################################################
|
19 |
+
# REPLACE THESE VALUES TO CHANGE THE MODEL REPO/CKPT NAME/SETTINGS
|
20 |
+
###################################################################
|
21 |
+
# The Hugging Face Hub repo ID
|
22 |
+
repo_id = "dog/kanye"
|
23 |
+
|
24 |
+
# If None, Uses latest ckpt in the repo
|
25 |
+
ckpt_name = None
|
26 |
+
|
27 |
+
# If None, Uses "kmeans.pt" if it exists in the repo
|
28 |
+
cluster_model_name = None
|
29 |
+
|
30 |
+
# Set the default f0 type to use - use the one it was trained on.
|
31 |
+
# The default for so-vits-svc-fork is "dio".
|
32 |
+
# Options: "crepe", "crepe-tiny", "parselmouth", "dio", "harvest"
|
33 |
+
default_f0_method = "crepe"
|
34 |
+
|
35 |
+
# The default ratio of cluster inference to SVC inference.
|
36 |
+
# If cluster_model_name is not found in the repo, this is set to 0.
|
37 |
+
default_cluster_infer_ratio = 0.5
|
38 |
+
|
39 |
+
# Limit on duration of audio at inference time. increase if you can
|
40 |
+
# In this parent app, we set the limit with an env var to 30 seconds
|
41 |
+
# If you didnt set env var + you go OOM try changing 9e9 to <=300ish
|
42 |
+
duration_limit = int(os.environ.get("MAX_DURATION_SECONDS", 9e9))
|
43 |
+
###################################################################
|
44 |
+
|
45 |
+
# Figure out the latest generator by taking highest value one.
|
46 |
+
# Ex. if the repo has: G_0.pth, G_100.pth, G_200.pth, we'd use G_200.pth
|
47 |
+
if ckpt_name is None:
|
48 |
+
latest_id = sorted(
|
49 |
+
[
|
50 |
+
int(Path(x).stem.split("_")[1])
|
51 |
+
for x in list_repo_files(repo_id)
|
52 |
+
if x.startswith("G_") and x.endswith(".pth")
|
53 |
+
]
|
54 |
+
)[-1]
|
55 |
+
ckpt_name = f"G_{latest_id}.pth"
|
56 |
+
|
57 |
+
cluster_model_name = cluster_model_name or "kmeans.pt"
|
58 |
+
if cluster_model_name in list_repo_files(repo_id):
|
59 |
+
print(f"Found Cluster model - Downloading {cluster_model_name} from {repo_id}")
|
60 |
+
cluster_model_path = hf_hub_download(repo_id, cluster_model_name)
|
61 |
+
else:
|
62 |
+
print(f"Could not find {cluster_model_name} in {repo_id}. Using None")
|
63 |
+
cluster_model_path = None
|
64 |
+
default_cluster_infer_ratio = default_cluster_infer_ratio if cluster_model_path else 0
|
65 |
+
|
66 |
+
generator_path = hf_hub_download(repo_id, ckpt_name)
|
67 |
+
config_path = hf_hub_download(repo_id, "config.json")
|
68 |
+
hparams = HParams(**json.loads(Path(config_path).read_text()))
|
69 |
+
speakers = list(hparams.spk.keys())
|
70 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
71 |
+
model = Svc(net_g_path=generator_path, config_path=config_path, device=device, cluster_model_path=cluster_model_path)
|
72 |
+
demucs_model = get_model(DEFAULT_MODEL)
|
73 |
+
|
74 |
+
|
75 |
+
def extract_vocal_demucs(model, filename, sr=44100, device=None, shifts=1, split=True, overlap=0.25, jobs=0):
|
76 |
+
wav, sr = librosa.load(filename, mono=False, sr=sr)
|
77 |
+
wav = torch.tensor(wav)
|
78 |
+
ref = wav.mean(0)
|
79 |
+
wav = (wav - ref.mean()) / ref.std()
|
80 |
+
sources = apply_model(
|
81 |
+
model, wav[None], device=device, shifts=shifts, split=split, overlap=overlap, progress=True, num_workers=jobs
|
82 |
+
)[0]
|
83 |
+
sources = sources * ref.std() + ref.mean()
|
84 |
+
# We take just the vocals stem. I know the vocals for this model are at index -1
|
85 |
+
# If using different model, check model.sources.index('vocals')
|
86 |
+
vocal_wav = sources[-1]
|
87 |
+
# I did this because its the same normalization the so-vits model required
|
88 |
+
vocal_wav = vocal_wav / max(1.01 * vocal_wav.abs().max(), 1)
|
89 |
+
vocal_wav = vocal_wav.numpy()
|
90 |
+
vocal_wav = librosa.to_mono(vocal_wav)
|
91 |
+
vocal_wav = vocal_wav.T
|
92 |
+
instrumental_wav = sources[:-1].sum(0).numpy().T
|
93 |
+
return vocal_wav, instrumental_wav
|
94 |
+
|
95 |
+
|
96 |
+
def download_youtube_clip(
|
97 |
+
video_identifier,
|
98 |
+
start_time,
|
99 |
+
end_time,
|
100 |
+
output_filename,
|
101 |
+
num_attempts=5,
|
102 |
+
url_base="https://www.youtube.com/watch?v=",
|
103 |
+
quiet=False,
|
104 |
+
force=False,
|
105 |
+
):
|
106 |
+
output_path = Path(output_filename)
|
107 |
+
if output_path.exists():
|
108 |
+
if not force:
|
109 |
+
return output_path
|
110 |
+
else:
|
111 |
+
output_path.unlink()
|
112 |
+
|
113 |
+
quiet = "--quiet --no-warnings" if quiet else ""
|
114 |
+
command = f"""
|
115 |
+
yt-dlp {quiet} -x --audio-format wav -f bestaudio -o "{output_filename}" --download-sections "*{start_time}-{end_time}" "{url_base}{video_identifier}" # noqa: E501
|
116 |
+
""".strip()
|
117 |
+
|
118 |
+
attempts = 0
|
119 |
+
while True:
|
120 |
+
try:
|
121 |
+
_ = subprocess.check_output(command, shell=True, stderr=subprocess.STDOUT)
|
122 |
+
except subprocess.CalledProcessError:
|
123 |
+
attempts += 1
|
124 |
+
if attempts == num_attempts:
|
125 |
+
return None
|
126 |
+
else:
|
127 |
+
break
|
128 |
+
|
129 |
+
if output_path.exists():
|
130 |
+
return output_path
|
131 |
+
else:
|
132 |
+
return None
|
133 |
+
|
134 |
+
|
135 |
+
def predict(
|
136 |
+
speaker,
|
137 |
+
audio,
|
138 |
+
transpose: int = 0,
|
139 |
+
auto_predict_f0: bool = False,
|
140 |
+
cluster_infer_ratio: float = 0,
|
141 |
+
noise_scale: float = 0.4,
|
142 |
+
f0_method: str = "crepe",
|
143 |
+
db_thresh: int = -40,
|
144 |
+
pad_seconds: float = 0.5,
|
145 |
+
chunk_seconds: float = 0.5,
|
146 |
+
absolute_thresh: bool = False,
|
147 |
+
):
|
148 |
+
audio, _ = librosa.load(audio, sr=model.target_sample, duration=duration_limit)
|
149 |
+
audio = model.infer_silence(
|
150 |
+
audio.astype(np.float32),
|
151 |
+
speaker=speaker,
|
152 |
+
transpose=transpose,
|
153 |
+
auto_predict_f0=auto_predict_f0,
|
154 |
+
cluster_infer_ratio=cluster_infer_ratio,
|
155 |
+
noise_scale=noise_scale,
|
156 |
+
f0_method=f0_method,
|
157 |
+
db_thresh=db_thresh,
|
158 |
+
pad_seconds=pad_seconds,
|
159 |
+
chunk_seconds=chunk_seconds,
|
160 |
+
absolute_thresh=absolute_thresh,
|
161 |
+
)
|
162 |
+
return model.target_sample, audio
|
163 |
+
|
164 |
+
SPACE_ID = "nateraw/voice-cloning"
|
165 |
+
description = f"""
|
166 |
+
# Attention - This Space may be slow in the shared UI if there is a long queue. To speed it up, you can duplicate and use it with a paid private T4 GPU.
|
167 |
+
|
168 |
+
<center><a class="duplicate-button" style="display:inline-block" target="_blank" href="https://huggingface.co/spaces/{SPACE_ID}?duplicate=true"><img src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=&logoWidth=14" alt="Duplicate Space"></a></center>
|
169 |
+
|
170 |
+
#### This app uses models trained with [so-vits-svc-fork](https://github.com/voicepaw/so-vits-svc-fork) to clone a voice. Model currently being used is https://hf.co/{repo_id}. To change the model being served, duplicate the space and update the `repo_id`/other settings in `app.py`.
|
171 |
+
|
172 |
+
#### Train Your Own: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/nateraw/voice-cloning/blob/main/training_so_vits_svc_fork.ipynb)
|
173 |
+
""".strip()
|
174 |
+
|
175 |
+
article = """
|
176 |
+
<p style='text-align: center'>
|
177 |
+
<a href='https://github.com/voicepaw/so-vits-svc-fork' target='_blank'>Github Repo</a>
|
178 |
+
</p>
|
179 |
+
""".strip()
|
180 |
+
|
181 |
+
|
182 |
+
interface_mic = gr.Interface(
|
183 |
+
predict,
|
184 |
+
inputs=[
|
185 |
+
gr.Dropdown(speakers, value=speakers[0], label="Target Speaker"),
|
186 |
+
gr.Audio(type="filepath", source="microphone", label="Source Audio"),
|
187 |
+
gr.Slider(-12, 12, value=0, step=1, label="Transpose (Semitones)"),
|
188 |
+
gr.Checkbox(False, label="Auto Predict F0"),
|
189 |
+
gr.Slider(0.0, 1.0, value=default_cluster_infer_ratio, step=0.1, label="cluster infer ratio"),
|
190 |
+
gr.Slider(0.0, 1.0, value=0.4, step=0.1, label="noise scale"),
|
191 |
+
gr.Dropdown(
|
192 |
+
choices=["crepe", "crepe-tiny", "parselmouth", "dio", "harvest"],
|
193 |
+
value=default_f0_method,
|
194 |
+
label="f0 method",
|
195 |
+
),
|
196 |
+
],
|
197 |
+
outputs="audio",
|
198 |
+
title="Voice Cloning",
|
199 |
+
description=description,
|
200 |
+
article=article,
|
201 |
+
)
|
202 |
+
interface_file = gr.Interface(
|
203 |
+
predict,
|
204 |
+
inputs=[
|
205 |
+
gr.Dropdown(speakers, value=speakers[0], label="Target Speaker"),
|
206 |
+
gr.Audio(type="filepath", source="upload", label="Source Audio"),
|
207 |
+
gr.Slider(-12, 12, value=0, step=1, label="Transpose (Semitones)"),
|
208 |
+
gr.Checkbox(False, label="Auto Predict F0"),
|
209 |
+
gr.Slider(0.0, 1.0, value=default_cluster_infer_ratio, step=0.1, label="cluster infer ratio"),
|
210 |
+
gr.Slider(0.0, 1.0, value=0.4, step=0.1, label="noise scale"),
|
211 |
+
gr.Dropdown(
|
212 |
+
choices=["crepe", "crepe-tiny", "parselmouth", "dio", "harvest"],
|
213 |
+
value=default_f0_method,
|
214 |
+
label="f0 method",
|
215 |
+
),
|
216 |
+
],
|
217 |
+
outputs="audio",
|
218 |
+
title="Voice Cloning",
|
219 |
+
description=description,
|
220 |
+
article=article,
|
221 |
+
)
|
222 |
+
interface = gr.TabbedInterface(
|
223 |
+
[interface_mic, interface_file],
|
224 |
+
["Clone From Mic", "Clone From File"],
|
225 |
+
)
|
226 |
+
|
227 |
+
|
228 |
+
if __name__ == "__main__":
|
229 |
+
interface.launch()
|
packages.txt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
ffmpeg
|
2 |
+
x264
|
3 |
+
libx264-dev
|
pyproject.toml
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[tool.black]
|
2 |
+
line-length = 119
|
3 |
+
target_version = ['py37']
|
4 |
+
|
5 |
+
[tool.ruff]
|
6 |
+
# Never enforce `E501` (line length violations).
|
7 |
+
ignore = ["C901", "E501", "E741", "W605"]
|
8 |
+
select = ["C", "E", "F", "I", "W"]
|
9 |
+
line-length = 119
|
10 |
+
|
11 |
+
# Ignore import violations in all `__init__.py` files.
|
12 |
+
[tool.ruff.per-file-ignores]
|
13 |
+
"__init__.py" = ["E402", "F401", "F403", "F811"]
|
14 |
+
|
15 |
+
[tool.ruff.isort]
|
16 |
+
known-first-party = ["so_vits_svc_fork"]
|
17 |
+
lines-after-imports = 2
|
requirements.txt
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
so-vits-svc-fork
|
2 |
+
gradio
|
3 |
+
huggingface_hub
|
4 |
+
yt-dlp
|
5 |
+
demucs
|
6 |
+
gradio
|
training_so_vits_svc_fork.ipynb
ADDED
@@ -0,0 +1,540 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"metadata": {
|
6 |
+
"id": "view-in-github",
|
7 |
+
"colab_type": "text"
|
8 |
+
},
|
9 |
+
"source": [
|
10 |
+
"<a href=\"https://colab.research.google.com/github/nateraw/voice-cloning/blob/main/training_so_vits_svc_fork.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
|
11 |
+
]
|
12 |
+
},
|
13 |
+
{
|
14 |
+
"cell_type": "code",
|
15 |
+
"execution_count": null,
|
16 |
+
"metadata": {
|
17 |
+
"id": "jIcNJ5QfDsV_"
|
18 |
+
},
|
19 |
+
"outputs": [],
|
20 |
+
"source": [
|
21 |
+
"# %%capture\n",
|
22 |
+
"! pip install git+https://github.com/nateraw/so-vits-svc-fork@main\n",
|
23 |
+
"! pip install openai-whisper yt-dlp huggingface_hub demucs"
|
24 |
+
]
|
25 |
+
},
|
26 |
+
{
|
27 |
+
"cell_type": "markdown",
|
28 |
+
"metadata": {
|
29 |
+
"id": "6uZAhUPOhFv9"
|
30 |
+
},
|
31 |
+
"source": [
|
32 |
+
"---\n",
|
33 |
+
"\n",
|
34 |
+
"# Restart runtime\n",
|
35 |
+
"\n",
|
36 |
+
"After running the cell above, you'll need to restart the Colab runtime because we installed a different version of numpy.\n",
|
37 |
+
"\n",
|
38 |
+
"`Runtime -> Restart runtime`\n",
|
39 |
+
"\n",
|
40 |
+
"---"
|
41 |
+
]
|
42 |
+
},
|
43 |
+
{
|
44 |
+
"cell_type": "code",
|
45 |
+
"execution_count": null,
|
46 |
+
"metadata": {
|
47 |
+
"id": "DROusQatF-wF"
|
48 |
+
},
|
49 |
+
"outputs": [],
|
50 |
+
"source": [
|
51 |
+
"from huggingface_hub import login\n",
|
52 |
+
"\n",
|
53 |
+
"login()"
|
54 |
+
]
|
55 |
+
},
|
56 |
+
{
|
57 |
+
"cell_type": "markdown",
|
58 |
+
"source": [
|
59 |
+
"## Settings"
|
60 |
+
],
|
61 |
+
"metadata": {
|
62 |
+
"id": "yOM9WWmmRqTA"
|
63 |
+
}
|
64 |
+
},
|
65 |
+
{
|
66 |
+
"cell_type": "code",
|
67 |
+
"execution_count": null,
|
68 |
+
"metadata": {
|
69 |
+
"id": "5oTDjDEKFz3W"
|
70 |
+
},
|
71 |
+
"outputs": [],
|
72 |
+
"source": [
|
73 |
+
"CHARACTER = \"kanye\"\n",
|
74 |
+
"DO_EXTRACT_VOCALS = False\n",
|
75 |
+
"MODEL_REPO_ID = \"dog/kanye\""
|
76 |
+
]
|
77 |
+
},
|
78 |
+
{
|
79 |
+
"cell_type": "markdown",
|
80 |
+
"metadata": {
|
81 |
+
"id": "BFd_ly1P_5Ht"
|
82 |
+
},
|
83 |
+
"source": [
|
84 |
+
"## Data Preparation\n",
|
85 |
+
"\n",
|
86 |
+
"Prepare a data.csv file here with `ytid,start,end` as the first line (they're the expected column names). Then, prepare a training set given YouTube IDs and their start and end segment times in seconds. Try to pick segments that have dry vocal only, as that'll provide the best results.\n",
|
87 |
+
"\n",
|
88 |
+
"An example is given below for Kanye West."
|
89 |
+
]
|
90 |
+
},
|
91 |
+
{
|
92 |
+
"cell_type": "code",
|
93 |
+
"execution_count": null,
|
94 |
+
"metadata": {
|
95 |
+
"id": "rBrtgDtWmhRb"
|
96 |
+
},
|
97 |
+
"outputs": [],
|
98 |
+
"source": [
|
99 |
+
"%%writefile data.csv\n",
|
100 |
+
"ytid,start,end\n",
|
101 |
+
"lkK4de9nbzQ,0,137\n",
|
102 |
+
"gXU9Am2Seo0,30,69\n",
|
103 |
+
"gXU9Am2Seo0,94,135\n",
|
104 |
+
"iVgrhWvQpqU,0,55\n",
|
105 |
+
"iVgrhWvQpqU,58,110\n",
|
106 |
+
"UIV-q-gneKA,85,99\n",
|
107 |
+
"UIV-q-gneKA,110,125\n",
|
108 |
+
"UIV-q-gneKA,127,141\n",
|
109 |
+
"UIV-q-gneKA,173,183\n",
|
110 |
+
"GmlyYCGE9ak,0,102\n",
|
111 |
+
"x-7aWcPmJ60,25,43\n",
|
112 |
+
"x-7aWcPmJ60,47,72\n",
|
113 |
+
"x-7aWcPmJ60,98,113\n",
|
114 |
+
"DK2LCIzIBrU,0,56\n",
|
115 |
+
"DK2LCIzIBrU,80,166\n",
|
116 |
+
"_W56nZk0fCI,184,224"
|
117 |
+
]
|
118 |
+
},
|
119 |
+
{
|
120 |
+
"cell_type": "code",
|
121 |
+
"execution_count": null,
|
122 |
+
"metadata": {
|
123 |
+
"id": "cxxp4uYoC0aG"
|
124 |
+
},
|
125 |
+
"outputs": [],
|
126 |
+
"source": [
|
127 |
+
"import subprocess\n",
|
128 |
+
"from pathlib import Path\n",
|
129 |
+
"import librosa\n",
|
130 |
+
"from scipy.io import wavfile\n",
|
131 |
+
"import numpy as np\n",
|
132 |
+
"from demucs.pretrained import get_model, DEFAULT_MODEL\n",
|
133 |
+
"from demucs.apply import apply_model\n",
|
134 |
+
"import torch\n",
|
135 |
+
"import csv\n",
|
136 |
+
"import whisper\n",
|
137 |
+
"\n",
|
138 |
+
"\n",
|
139 |
+
"def download_youtube_clip(video_identifier, start_time, end_time, output_filename, num_attempts=5, url_base=\"https://www.youtube.com/watch?v=\"):\n",
|
140 |
+
" status = False\n",
|
141 |
+
"\n",
|
142 |
+
" output_path = Path(output_filename)\n",
|
143 |
+
" if output_path.exists():\n",
|
144 |
+
" return True, \"Already Downloaded\"\n",
|
145 |
+
"\n",
|
146 |
+
" command = f\"\"\"\n",
|
147 |
+
" yt-dlp --quiet --no-warnings -x --audio-format wav -f bestaudio -o \"{output_filename}\" --download-sections \"*{start_time}-{end_time}\" \"{url_base}{video_identifier}\"\n",
|
148 |
+
" \"\"\".strip()\n",
|
149 |
+
"\n",
|
150 |
+
" attempts = 0\n",
|
151 |
+
" while True:\n",
|
152 |
+
" try:\n",
|
153 |
+
" output = subprocess.check_output(command, shell=True, stderr=subprocess.STDOUT)\n",
|
154 |
+
" except subprocess.CalledProcessError as err:\n",
|
155 |
+
" attempts += 1\n",
|
156 |
+
" if attempts == num_attempts:\n",
|
157 |
+
" return status, err.output\n",
|
158 |
+
" else:\n",
|
159 |
+
" break\n",
|
160 |
+
"\n",
|
161 |
+
" status = output_path.exists()\n",
|
162 |
+
" return status, \"Downloaded\"\n",
|
163 |
+
"\n",
|
164 |
+
"\n",
|
165 |
+
"def split_long_audio(model, filepaths, character_name, save_dir=\"data_dir\", out_sr=44100):\n",
|
166 |
+
" if isinstance(filepaths, str):\n",
|
167 |
+
" filepaths = [filepaths]\n",
|
168 |
+
"\n",
|
169 |
+
" for file_idx, filepath in enumerate(filepaths):\n",
|
170 |
+
"\n",
|
171 |
+
" save_path = Path(save_dir) / character_name\n",
|
172 |
+
" save_path.mkdir(exist_ok=True, parents=True)\n",
|
173 |
+
"\n",
|
174 |
+
" print(f\"Transcribing file {file_idx}: '{filepath}' to segments...\")\n",
|
175 |
+
" result = model.transcribe(filepath, word_timestamps=True, task=\"transcribe\", beam_size=5, best_of=5)\n",
|
176 |
+
" segments = result['segments']\n",
|
177 |
+
" \n",
|
178 |
+
" wav, sr = librosa.load(filepath, sr=None, offset=0, duration=None, mono=True)\n",
|
179 |
+
" wav, _ = librosa.effects.trim(wav, top_db=20)\n",
|
180 |
+
" peak = np.abs(wav).max()\n",
|
181 |
+
" if peak > 1.0:\n",
|
182 |
+
" wav = 0.98 * wav / peak\n",
|
183 |
+
" wav2 = librosa.resample(wav, orig_sr=sr, target_sr=out_sr)\n",
|
184 |
+
" wav2 /= max(wav2.max(), -wav2.min())\n",
|
185 |
+
"\n",
|
186 |
+
" for i, seg in enumerate(segments):\n",
|
187 |
+
" start_time = seg['start']\n",
|
188 |
+
" end_time = seg['end']\n",
|
189 |
+
" wav_seg = wav2[int(start_time * out_sr):int(end_time * out_sr)]\n",
|
190 |
+
" wav_seg_name = f\"{character_name}_{file_idx}_{i}.wav\"\n",
|
191 |
+
" out_fpath = save_path / wav_seg_name\n",
|
192 |
+
" wavfile.write(out_fpath, rate=out_sr, data=(wav_seg * np.iinfo(np.int16).max).astype(np.int16))\n",
|
193 |
+
"\n",
|
194 |
+
"\n",
|
195 |
+
"def extract_vocal_demucs(model, filename, out_filename, sr=44100, device=None, shifts=1, split=True, overlap=0.25, jobs=0):\n",
|
196 |
+
" wav, sr = librosa.load(filename, mono=False, sr=sr)\n",
|
197 |
+
" wav = torch.tensor(wav)\n",
|
198 |
+
" ref = wav.mean(0)\n",
|
199 |
+
" wav = (wav - ref.mean()) / ref.std()\n",
|
200 |
+
" sources = apply_model(\n",
|
201 |
+
" model,\n",
|
202 |
+
" wav[None],\n",
|
203 |
+
" device=device,\n",
|
204 |
+
" shifts=shifts,\n",
|
205 |
+
" split=split,\n",
|
206 |
+
" overlap=overlap,\n",
|
207 |
+
" progress=True,\n",
|
208 |
+
" num_workers=jobs\n",
|
209 |
+
" )[0]\n",
|
210 |
+
" sources = sources * ref.std() + ref.mean()\n",
|
211 |
+
"\n",
|
212 |
+
" wav = sources[-1]\n",
|
213 |
+
" wav = wav / max(1.01 * wav.abs().max(), 1)\n",
|
214 |
+
" wavfile.write(out_filename, rate=sr, data=wav.numpy().T)\n",
|
215 |
+
" return out_filename\n",
|
216 |
+
"\n",
|
217 |
+
"\n",
|
218 |
+
"def create_dataset(\n",
|
219 |
+
" clips_csv_filepath = \"data.csv\",\n",
|
220 |
+
" character = \"somebody\",\n",
|
221 |
+
" do_extract_vocals = False,\n",
|
222 |
+
" whisper_size = \"medium\",\n",
|
223 |
+
" # Where raw yt clips will be downloaded to\n",
|
224 |
+
" dl_dir = \"downloads\",\n",
|
225 |
+
" # Where actual data will be organized\n",
|
226 |
+
" data_dir = \"dataset_raw\",\n",
|
227 |
+
" **kwargs\n",
|
228 |
+
"):\n",
|
229 |
+
" dl_path = Path(dl_dir) / character\n",
|
230 |
+
" dl_path.mkdir(exist_ok=True, parents=True)\n",
|
231 |
+
" if do_extract_vocals:\n",
|
232 |
+
" demucs_model = get_model(DEFAULT_MODEL)\n",
|
233 |
+
"\n",
|
234 |
+
" with Path(clips_csv_filepath).open() as f:\n",
|
235 |
+
" reader = csv.DictReader(f)\n",
|
236 |
+
" for i, row in enumerate(reader):\n",
|
237 |
+
" outfile_path = dl_path / f\"{character}_{i:04d}.wav\"\n",
|
238 |
+
" download_youtube_clip(row['ytid'], row['start'], row['end'], outfile_path)\n",
|
239 |
+
" if do_extract_vocals:\n",
|
240 |
+
" extract_vocal_demucs(demucs_model, outfile_path, outfile_path)\n",
|
241 |
+
"\n",
|
242 |
+
" filenames = sorted([str(x) for x in dl_path.glob(\"*.wav\")])\n",
|
243 |
+
" whisper_model = whisper.load_model(whisper_size)\n",
|
244 |
+
" split_long_audio(whisper_model, filenames, character, data_dir) "
|
245 |
+
]
|
246 |
+
},
|
247 |
+
{
|
248 |
+
"cell_type": "code",
|
249 |
+
"execution_count": null,
|
250 |
+
"metadata": {
|
251 |
+
"id": "D9GrcDUKEGro"
|
252 |
+
},
|
253 |
+
"outputs": [],
|
254 |
+
"source": [
|
255 |
+
"\"\"\"\n",
|
256 |
+
"Here, we override config to have num_workers=0 because\n",
|
257 |
+
"of a limitation in HF Spaces Docker /dev/shm.\n",
|
258 |
+
"\"\"\"\n",
|
259 |
+
"\n",
|
260 |
+
"import json\n",
|
261 |
+
"from pathlib import Path\n",
|
262 |
+
"import multiprocessing\n",
|
263 |
+
"\n",
|
264 |
+
"def update_config(config_file=\"configs/44k/config.json\"):\n",
|
265 |
+
" config_path = Path(config_file)\n",
|
266 |
+
" data = json.loads(config_path.read_text())\n",
|
267 |
+
" data['train']['batch_size'] = 32\n",
|
268 |
+
" data['train']['eval_interval'] = 500\n",
|
269 |
+
" data['train']['num_workers'] = multiprocessing.cpu_count()\n",
|
270 |
+
" data['train']['persistent_workers'] = True\n",
|
271 |
+
" data['train']['push_to_hub'] = True\n",
|
272 |
+
" data['train']['repo_id'] = MODEL_REPO_ID # tuple(data['spk'])[0]\n",
|
273 |
+
" data['train']['private'] = True\n",
|
274 |
+
" config_path.write_text(json.dumps(data, indent=2, sort_keys=False))"
|
275 |
+
]
|
276 |
+
},
|
277 |
+
{
|
278 |
+
"cell_type": "markdown",
|
279 |
+
"source": [
|
280 |
+
"## Run all Preprocessing Steps"
|
281 |
+
],
|
282 |
+
"metadata": {
|
283 |
+
"id": "aF6OZkTZRzhj"
|
284 |
+
}
|
285 |
+
},
|
286 |
+
{
|
287 |
+
"cell_type": "code",
|
288 |
+
"execution_count": null,
|
289 |
+
"metadata": {
|
290 |
+
"id": "OAPnD3xKD_Gw"
|
291 |
+
},
|
292 |
+
"outputs": [],
|
293 |
+
"source": [
|
294 |
+
"create_dataset(character=CHARACTER, do_extract_vocals=DO_EXTRACT_VOCALS)\n",
|
295 |
+
"! svc pre-resample\n",
|
296 |
+
"! svc pre-config\n",
|
297 |
+
"! svc pre-hubert -fm crepe\n",
|
298 |
+
"update_config()"
|
299 |
+
]
|
300 |
+
},
|
301 |
+
{
|
302 |
+
"cell_type": "markdown",
|
303 |
+
"source": [
|
304 |
+
"## Training"
|
305 |
+
],
|
306 |
+
"metadata": {
|
307 |
+
"id": "VpyGazF6R3CE"
|
308 |
+
}
|
309 |
+
},
|
310 |
+
{
|
311 |
+
"cell_type": "code",
|
312 |
+
"execution_count": null,
|
313 |
+
"metadata": {
|
314 |
+
"colab": {
|
315 |
+
"background_save": true
|
316 |
+
},
|
317 |
+
"id": "MByHpf_wEByg"
|
318 |
+
},
|
319 |
+
"outputs": [],
|
320 |
+
"source": [
|
321 |
+
"from __future__ import annotations\n",
|
322 |
+
"\n",
|
323 |
+
"import os\n",
|
324 |
+
"import re\n",
|
325 |
+
"import warnings\n",
|
326 |
+
"from logging import getLogger\n",
|
327 |
+
"from multiprocessing import cpu_count\n",
|
328 |
+
"from pathlib import Path\n",
|
329 |
+
"from typing import Any\n",
|
330 |
+
"\n",
|
331 |
+
"import lightning.pytorch as pl\n",
|
332 |
+
"import torch\n",
|
333 |
+
"from lightning.pytorch.accelerators import MPSAccelerator, TPUAccelerator\n",
|
334 |
+
"from lightning.pytorch.loggers import TensorBoardLogger\n",
|
335 |
+
"from lightning.pytorch.strategies.ddp import DDPStrategy\n",
|
336 |
+
"from lightning.pytorch.tuner import Tuner\n",
|
337 |
+
"from torch.cuda.amp import autocast\n",
|
338 |
+
"from torch.nn import functional as F\n",
|
339 |
+
"from torch.utils.data import DataLoader\n",
|
340 |
+
"from torch.utils.tensorboard.writer import SummaryWriter\n",
|
341 |
+
"\n",
|
342 |
+
"import so_vits_svc_fork.f0\n",
|
343 |
+
"import so_vits_svc_fork.modules.commons as commons\n",
|
344 |
+
"import so_vits_svc_fork.utils\n",
|
345 |
+
"\n",
|
346 |
+
"from so_vits_svc_fork import utils\n",
|
347 |
+
"from so_vits_svc_fork.dataset import TextAudioCollate, TextAudioDataset\n",
|
348 |
+
"from so_vits_svc_fork.logger import is_notebook\n",
|
349 |
+
"from so_vits_svc_fork.modules.descriminators import MultiPeriodDiscriminator\n",
|
350 |
+
"from so_vits_svc_fork.modules.losses import discriminator_loss, feature_loss, generator_loss, kl_loss\n",
|
351 |
+
"from so_vits_svc_fork.modules.mel_processing import mel_spectrogram_torch\n",
|
352 |
+
"from so_vits_svc_fork.modules.synthesizers import SynthesizerTrn\n",
|
353 |
+
"\n",
|
354 |
+
"from so_vits_svc_fork.train import VitsLightning, VCDataModule\n",
|
355 |
+
"\n",
|
356 |
+
"LOG = getLogger(__name__)\n",
|
357 |
+
"torch.set_float32_matmul_precision(\"high\")\n",
|
358 |
+
"\n",
|
359 |
+
"\n",
|
360 |
+
"from pathlib import Path\n",
|
361 |
+
"\n",
|
362 |
+
"from huggingface_hub import create_repo, upload_folder, login, list_repo_files, delete_file\n",
|
363 |
+
"\n",
|
364 |
+
"# if os.environ.get(\"HF_TOKEN\"):\n",
|
365 |
+
"# login(os.environ.get(\"HF_TOKEN\"))\n",
|
366 |
+
"\n",
|
367 |
+
"\n",
|
368 |
+
"class HuggingFacePushCallback(pl.Callback):\n",
|
369 |
+
" def __init__(self, repo_id, private=False, every=100):\n",
|
370 |
+
" self.repo_id = repo_id\n",
|
371 |
+
" self.private = private\n",
|
372 |
+
" self.every = every\n",
|
373 |
+
"\n",
|
374 |
+
" def on_validation_epoch_end(self, trainer, pl_module):\n",
|
375 |
+
" self.repo_url = create_repo(\n",
|
376 |
+
" repo_id=self.repo_id,\n",
|
377 |
+
" exist_ok=True,\n",
|
378 |
+
" private=self.private\n",
|
379 |
+
" )\n",
|
380 |
+
" self.repo_id = self.repo_url.repo_id\n",
|
381 |
+
" if pl_module.global_step == 0:\n",
|
382 |
+
" return\n",
|
383 |
+
" print(f\"\\n🤗 Pushing to Hugging Face Hub: {self.repo_url}...\")\n",
|
384 |
+
" model_dir = pl_module.hparams.model_dir\n",
|
385 |
+
" upload_folder(\n",
|
386 |
+
" repo_id=self.repo_id,\n",
|
387 |
+
" folder_path=model_dir,\n",
|
388 |
+
" path_in_repo=\".\",\n",
|
389 |
+
" commit_message=\"🍻 cheers\",\n",
|
390 |
+
" ignore_patterns=[\"*.git*\", \"*README.md*\", \"*__pycache__*\"],\n",
|
391 |
+
" )\n",
|
392 |
+
" ckpt_pattern = r'^(D_|G_)\\d+\\.pth$'\n",
|
393 |
+
" todelete = []\n",
|
394 |
+
" repo_ckpts = [x for x in list_repo_files(self.repo_id) if re.match(ckpt_pattern, x) and x not in [\"G_0.pth\", \"D_0.pth\"]]\n",
|
395 |
+
" local_ckpts = [x.name for x in Path(model_dir).glob(\"*.pth\") if re.match(ckpt_pattern, x.name)]\n",
|
396 |
+
" to_delete = set(repo_ckpts) - set(local_ckpts)\n",
|
397 |
+
"\n",
|
398 |
+
" for fname in to_delete:\n",
|
399 |
+
" print(f\"🗑 Deleting {fname} from repo\")\n",
|
400 |
+
" delete_file(fname, self.repo_id)\n",
|
401 |
+
"\n",
|
402 |
+
"\n",
|
403 |
+
"def train(\n",
|
404 |
+
" config_path: Path | str, model_path: Path | str, reset_optimizer: bool = False\n",
|
405 |
+
"):\n",
|
406 |
+
" config_path = Path(config_path)\n",
|
407 |
+
" model_path = Path(model_path)\n",
|
408 |
+
"\n",
|
409 |
+
" hparams = utils.get_backup_hparams(config_path, model_path)\n",
|
410 |
+
" utils.ensure_pretrained_model(model_path, hparams.model.get(\"type_\", \"hifi-gan\"))\n",
|
411 |
+
"\n",
|
412 |
+
" datamodule = VCDataModule(hparams)\n",
|
413 |
+
" strategy = (\n",
|
414 |
+
" (\n",
|
415 |
+
" \"ddp_find_unused_parameters_true\"\n",
|
416 |
+
" if os.name != \"nt\"\n",
|
417 |
+
" else DDPStrategy(find_unused_parameters=True, process_group_backend=\"gloo\")\n",
|
418 |
+
" )\n",
|
419 |
+
" if torch.cuda.device_count() > 1\n",
|
420 |
+
" else \"auto\"\n",
|
421 |
+
" )\n",
|
422 |
+
" LOG.info(f\"Using strategy: {strategy}\")\n",
|
423 |
+
" \n",
|
424 |
+
" callbacks = []\n",
|
425 |
+
" if hparams.train.push_to_hub:\n",
|
426 |
+
" callbacks.append(HuggingFacePushCallback(hparams.train.repo_id, hparams.train.private))\n",
|
427 |
+
" if not is_notebook():\n",
|
428 |
+
" callbacks.append(pl.callbacks.RichProgressBar())\n",
|
429 |
+
" if callbacks == []:\n",
|
430 |
+
" callbacks = None\n",
|
431 |
+
"\n",
|
432 |
+
" trainer = pl.Trainer(\n",
|
433 |
+
" logger=TensorBoardLogger(\n",
|
434 |
+
" model_path, \"lightning_logs\", hparams.train.get(\"log_version\", 0)\n",
|
435 |
+
" ),\n",
|
436 |
+
" # profiler=\"simple\",\n",
|
437 |
+
" val_check_interval=hparams.train.eval_interval,\n",
|
438 |
+
" max_epochs=hparams.train.epochs,\n",
|
439 |
+
" check_val_every_n_epoch=None,\n",
|
440 |
+
" precision=\"16-mixed\"\n",
|
441 |
+
" if hparams.train.fp16_run\n",
|
442 |
+
" else \"bf16-mixed\"\n",
|
443 |
+
" if hparams.train.get(\"bf16_run\", False)\n",
|
444 |
+
" else 32,\n",
|
445 |
+
" strategy=strategy,\n",
|
446 |
+
" callbacks=callbacks,\n",
|
447 |
+
" benchmark=True,\n",
|
448 |
+
" enable_checkpointing=False,\n",
|
449 |
+
" )\n",
|
450 |
+
" tuner = Tuner(trainer)\n",
|
451 |
+
" model = VitsLightning(reset_optimizer=reset_optimizer, **hparams)\n",
|
452 |
+
"\n",
|
453 |
+
" # automatic batch size scaling\n",
|
454 |
+
" batch_size = hparams.train.batch_size\n",
|
455 |
+
" batch_split = str(batch_size).split(\"-\")\n",
|
456 |
+
" batch_size = batch_split[0]\n",
|
457 |
+
" init_val = 2 if len(batch_split) <= 1 else int(batch_split[1])\n",
|
458 |
+
" max_trials = 25 if len(batch_split) <= 2 else int(batch_split[2])\n",
|
459 |
+
" if batch_size == \"auto\":\n",
|
460 |
+
" batch_size = \"binsearch\"\n",
|
461 |
+
" if batch_size in [\"power\", \"binsearch\"]:\n",
|
462 |
+
" model.tuning = True\n",
|
463 |
+
" tuner.scale_batch_size(\n",
|
464 |
+
" model,\n",
|
465 |
+
" mode=batch_size,\n",
|
466 |
+
" datamodule=datamodule,\n",
|
467 |
+
" steps_per_trial=1,\n",
|
468 |
+
" init_val=init_val,\n",
|
469 |
+
" max_trials=max_trials,\n",
|
470 |
+
" )\n",
|
471 |
+
" model.tuning = False\n",
|
472 |
+
" else:\n",
|
473 |
+
" batch_size = int(batch_size)\n",
|
474 |
+
" # automatic learning rate scaling is not supported for multiple optimizers\n",
|
475 |
+
" \"\"\"if hparams.train.learning_rate == \"auto\":\n",
|
476 |
+
" lr_finder = tuner.lr_find(model)\n",
|
477 |
+
" LOG.info(lr_finder.results)\n",
|
478 |
+
" fig = lr_finder.plot(suggest=True)\n",
|
479 |
+
" fig.savefig(model_path / \"lr_finder.png\")\"\"\"\n",
|
480 |
+
"\n",
|
481 |
+
" trainer.fit(model, datamodule=datamodule)\n",
|
482 |
+
"\n",
|
483 |
+
"if __name__ == '__main__':\n",
|
484 |
+
" train('configs/44k/config.json', 'logs/44k')"
|
485 |
+
]
|
486 |
+
},
|
487 |
+
{
|
488 |
+
"cell_type": "markdown",
|
489 |
+
"source": [
|
490 |
+
"## Train Cluster Model"
|
491 |
+
],
|
492 |
+
"metadata": {
|
493 |
+
"id": "b2vNCDrSR8Xo"
|
494 |
+
}
|
495 |
+
},
|
496 |
+
{
|
497 |
+
"cell_type": "code",
|
498 |
+
"execution_count": null,
|
499 |
+
"metadata": {
|
500 |
+
"id": "DBBEx-6Y1sOy"
|
501 |
+
},
|
502 |
+
"outputs": [],
|
503 |
+
"source": [
|
504 |
+
"! svc train-cluster"
|
505 |
+
]
|
506 |
+
},
|
507 |
+
{
|
508 |
+
"cell_type": "code",
|
509 |
+
"execution_count": null,
|
510 |
+
"metadata": {
|
511 |
+
"id": "y_qYMuNY1tlm"
|
512 |
+
},
|
513 |
+
"outputs": [],
|
514 |
+
"source": [
|
515 |
+
"from huggingface_hub import upload_file\n",
|
516 |
+
"\n",
|
517 |
+
"upload_file(path_or_fileobj=\"/content/logs/44k/kmeans.pt\", repo_id=MODEL_REPO_ID, path_in_repo=\"kmeans.pt\")"
|
518 |
+
]
|
519 |
+
}
|
520 |
+
],
|
521 |
+
"metadata": {
|
522 |
+
"accelerator": "GPU",
|
523 |
+
"colab": {
|
524 |
+
"machine_shape": "hm",
|
525 |
+
"provenance": [],
|
526 |
+
"authorship_tag": "ABX9TyOQeFSvxop9rlCaglNlNoXI",
|
527 |
+
"include_colab_link": true
|
528 |
+
},
|
529 |
+
"gpuClass": "premium",
|
530 |
+
"kernelspec": {
|
531 |
+
"display_name": "Python 3",
|
532 |
+
"name": "python3"
|
533 |
+
},
|
534 |
+
"language_info": {
|
535 |
+
"name": "python"
|
536 |
+
}
|
537 |
+
},
|
538 |
+
"nbformat": 4,
|
539 |
+
"nbformat_minor": 0
|
540 |
+
}
|