XTTSv2-Finetuning-Vi / train_dvae_xtts.py
tuandaodev's picture
Upload folder using huggingface_hub
c1bab10 verified
import torch
# import wandb
from TTS.tts.layers.xtts.dvae import DiscreteVAE
from TTS.tts.layers.tortoise.arch_utils import TorchMelSpectrogram
from torch.utils.data import DataLoader
from TTS.tts.layers.xtts.trainer.dvae_dataset import DVAEDataset
from torch.optim import Adam
from torch.nn.utils import clip_grad_norm_
from tqdm import tqdm
from TTS.tts.datasets import load_tts_samples
from TTS.config.shared_configs import BaseDatasetConfig
from dataclasses import dataclass, field
from typing import Optional
import os
import datetime
from transformers import HfArgumentParser
@dataclass
class DVAETrainerArgs:
"""
Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
"""
output_path: str = field(
metadata={"help": "Path to pretrained + checkpoint model"}
)
train_csv_path: str = field(
metadata={"help": "Path to train metadata file"},
)
eval_csv_path: Optional[str] = field(
default="",
metadata={"help": "Path to eval metadata file"},
)
language: Optional[str] = field(
default="en",
metadata={"help": "The language you want to train (language in your dataset)"},
)
lr: Optional[float] = field(
default=5e-6,
metadata={"help": "Learning rate"},
)
num_epochs: Optional[int] = field(
default=5,
)
batch_size: Optional[int] = field(
default=512,
)
def train(output_path, train_csv_path, eval_csv_path="", language="en", lr=5e-6, num_epochs=5, batch_size=512):
dvae_pretrained = os.path.join(output_path, 'XTTS_v2.0_original_model_files/dvae.pth')
mel_norm_file = os.path.join(output_path, 'XTTS_v2.0_original_model_files/mel_stats.pth')
now = datetime.datetime.now()
now_without_ms = now.replace(microsecond=0)
# CHECKPOINTS_OUT_PATH = os.path.join(output_path, f"DVAE_checkpoint_{now_without_ms}/")
# os.makedirs(CHECKPOINTS_OUT_PATH, exist_ok=True)
config_dataset = BaseDatasetConfig(
formatter="coqui",
dataset_name="large",
path=os.path.dirname(train_csv_path),
meta_file_train=os.path.basename(train_csv_path),
meta_file_val=os.path.basename(eval_csv_path),
language=language,
)
# Add here the configs of the datasets
DATASETS_CONFIG_LIST = [config_dataset]
GRAD_CLIP_NORM = 0.5
LEARNING_RATE = lr
dvae = DiscreteVAE(
channels=80,
normalization=None,
positional_dims=1,
num_tokens=1024,
codebook_dim=512,
hidden_dim=512,
num_resnet_blocks=3,
kernel_size=3,
num_layers=2,
use_transposed_convs=False,
)
dvae.load_state_dict(torch.load(dvae_pretrained), strict=False)
dvae.cuda()
opt = Adam(dvae.parameters(), lr = LEARNING_RATE)
torch_mel_spectrogram_dvae = TorchMelSpectrogram(
mel_norm_file=mel_norm_file, sampling_rate=22050
).cuda()
train_samples, eval_samples = load_tts_samples(
DATASETS_CONFIG_LIST,
eval_split=True,
eval_split_max_size=256,
eval_split_size=0.01,
)
eval_dataset = DVAEDataset(eval_samples, 22050, True, max_wav_len=15*22050)
train_dataset = DVAEDataset(train_samples, 22050, False, max_wav_len=15*22050)
eval_data_loader = DataLoader(
eval_dataset,
batch_size=batch_size,
shuffle=False,
drop_last=False,
collate_fn=eval_dataset.collate_fn,
num_workers=0,
pin_memory=False,
)
train_data_loader = DataLoader(
train_dataset,
batch_size=batch_size,
shuffle=False,
drop_last=False,
collate_fn=train_dataset.collate_fn,
num_workers=4,
pin_memory=False,
)
torch.set_grad_enabled(True)
dvae.train()
# wandb.init(project = 'train_dvae')
# wandb.watch(dvae)
def to_cuda(x: torch.Tensor) -> torch.Tensor:
if x is None:
return None
if torch.is_tensor(x):
x = x.contiguous()
if torch.cuda.is_available():
x = x.cuda(non_blocking=True)
return x
@torch.no_grad()
def format_batch(batch):
if isinstance(batch, dict):
for k, v in batch.items():
batch[k] = to_cuda(v)
elif isinstance(batch, list):
batch = [to_cuda(v) for v in batch]
try:
batch['mel'] = torch_mel_spectrogram_dvae(batch['wav'])
# if the mel spectogram is not divisible by 4 then input.shape != output.shape
# for dvae
remainder = batch['mel'].shape[-1] % 4
if remainder:
batch['mel'] = batch['mel'][:, :, :-remainder]
except NotImplementedError:
pass
return batch
best_loss = 1e6
for i in range(num_epochs):
dvae.train()
for cur_step, batch in enumerate(train_data_loader):
opt.zero_grad()
batch = format_batch(batch)
recon_loss, commitment_loss, out = dvae(batch['mel'])
recon_loss = recon_loss.mean()
total_loss = recon_loss + commitment_loss
# print(f"commitment_loss shape: {commitment_loss.shape}")
# print(f"recon_loss shape: {recon_loss.shape}")
# print(f"total_loss shape: {total_loss.shape}")
total_loss.backward()
clip_grad_norm_(dvae.parameters(), GRAD_CLIP_NORM)
opt.step()
log = {'epoch': i,
'cur_step': cur_step,
'loss': total_loss.item(),
'recon_loss': recon_loss.item(),
'commit_loss': commitment_loss.item()}
print(f"epoch: {i}", print(f"step: {cur_step}"), f'loss - {total_loss.item()}', f'recon_loss - {recon_loss.item()}', f'commit_loss - {commitment_loss.item()}')
# wandb.log(log)
torch.cuda.empty_cache()
with torch.no_grad():
dvae.eval()
eval_loss = 0
for cur_step, batch in enumerate(eval_data_loader):
batch = format_batch(batch)
recon_loss, commitment_loss, out = dvae(batch['mel'])
recon_loss = recon_loss.mean()
eval_loss += (recon_loss + commitment_loss).item()
eval_loss = eval_loss/len(eval_data_loader)
if eval_loss < best_loss:
best_loss = eval_loss
torch.save(dvae.state_dict(), dvae_pretrained)
print(f"#######################################\nepoch: {i}\tEVAL loss: {eval_loss}\n#######################################")
print(f'Checkpoint saved at {dvae_pretrained}')
if __name__ == "__main__":
parser = HfArgumentParser(DVAETrainerArgs)
args = parser.parse_args_into_dataclasses()[0]
trainer_out_path = train(
language=args.language,
train_csv_path=args.train_csv_path,
eval_csv_path=args.eval_csv_path,
output_path=args.output_path,
num_epochs=args.num_epochs,
batch_size=args.batch_size,
lr=args.lr
)