ChatTTS-Forge / modules /finetune /train_speaker.py
zhzluke96
update
1df74c6
raw
history blame
10.6 kB
import torch
import torch.nn.functional as F
import transformers
from modules.finetune.model.encoder import DVAEEncoder, get_encoder_config
from modules.finetune.utils.output import get_ansi_len, output_iter, ansi
from .utils.logger import MetricLogger
from .utils.dataset import AudioCollator, XzListTar
from .utils.model import quantize
IGNORE_TOKEN_ID = transformers.trainer_pt_utils.LabelSmoother.ignore_index
def train_speaker_embeddings(
    chat,
    dataset,
    gpt,
    batch_size=16,
    epochs=10,
    train_text=True,
    speaker_embeds=None,
):
    """Optimize one embedding vector per speaker against ``dataset``.

    Only the speaker embeddings are handed to the optimizer; the "decoder"
    and "dvae" models taken from ``chat.pretrain_models`` are switched to
    eval mode with gradients disabled.

    Args:
        chat: Object exposing ``pretrain_models`` with the keys
            ``"tokenizer"``, ``"decoder"``, ``"dvae"`` and ``"spk_stat"``.
        dataset: Dataset exposing ``device`` and ``speakers``. Batches are
            built with ``AudioCollator`` and must provide the keys
            ``"speaker"``, ``"text_input_ids"``, ``"text_attention_mask"``,
            ``"audio_mel_specs"`` and ``"audio_attention_mask"``.
        gpt: Model exposing ``num_vq``, ``get_emb``, ``gpt.forward``,
            ``head_code`` (indexable by VQ level) and ``head_text``.
        batch_size: Batch size passed to the ``DataLoader``.
        epochs: Number of passes over the dataset; also used as ``T_max``
            of the cosine learning-rate schedule.
        train_text: When true, a text cross-entropy term is added to the loss.
        speaker_embeds: Optional mapping ``speaker -> tensor`` to start from.
            When ``None``, a 768-element vector is drawn per speaker from a
            normal distribution and rescaled with ``spk_stat``.

    Returns:
        The mapping ``speaker -> tensor`` holding the trained embeddings
        (the same dict object that was passed in, if one was given).
    """
    tokenizer = chat.pretrain_models["tokenizer"]

    decoder_decoder = chat.pretrain_models["decoder"]
    decoder_decoder.eval().requires_grad_(False)
    # NOTE(review): this encoder is built but never used below. It is kept so
    # that construction side effects (e.g. RNG consumption before the random
    # speaker initialisation) stay as they were -- confirm it can be dropped.
    decoder_encoder = DVAEEncoder(**get_encoder_config(decoder_decoder.decoder)).to(
        device=dataset.device
    )
    decoder_encoder.eval().requires_grad_(False)

    dvae_decoder = chat.pretrain_models["dvae"]
    dvae_decoder.eval().requires_grad_(False)
    dvae_encoder = DVAEEncoder(**get_encoder_config(dvae_decoder.decoder)).to(
        device=dataset.device
    )
    dvae_encoder.eval().requires_grad_(False)

    if speaker_embeds is None:
        speaker_embeds = {
            speaker: torch.randn(
                768,
                device=dataset.device,
                requires_grad=True,
            )
            for speaker in dataset.speakers
        }
        # The statistics tensor is split in two halves: first std, then mean.
        std, mean = chat.pretrain_models["spk_stat"].chunk(2)
        for speaker_embed in speaker_embeds.values():
            # Assign through .data so the tensor stays a leaf for the optimizer.
            speaker_embed.data = speaker_embed.data * std + mean

    SPEAKER_TOKEN_ID = tokenizer.convert_tokens_to_ids("[spk_emb]")
    AUDIO_EOS_TOKEN_ID = 0
    AUDIO_PAD_TOKEN_ID = AUDIO_EOS_TOKEN_ID

    optimizer = torch.optim.Adam(
        speaker_embeds.values(), lr=1e-2, weight_decay=0, betas=[0.9, 0.95], eps=1e-5
    )
    loss_fn = torch.nn.CrossEntropyLoss()
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs, 1e-7)

    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=AudioCollator(text_pad=tokenizer.pad_token_id),
    )
    logger = MetricLogger()
    logger.create_meters(loss=None, mse_loss=None, audio_loss=None, text_loss=None)

    for epoch in range(1, epochs + 1):
        logger.reset()
        header = "{blue_light}{0}: {1}{reset}".format(
            "Epoch", output_iter(epoch, epochs), **ansi
        )
        # Pad to a fixed visible width; ANSI escape codes do not count.
        header = header.ljust(max(len("Epoch"), 30) + get_ansi_len(header))
        iterator = logger.log_every(loader, header=header, tqdm_header="Batch")

        for batch in iterator:
            speakers = batch["speaker"]
            text_input_ids = batch["text_input_ids"]
            text_attention_mask = batch["text_attention_mask"]
            audio_mel_specs = batch["audio_mel_specs"]
            audio_attention_mask = batch["audio_attention_mask"]

            # Actual size of this batch (the last one may be smaller).
            cur_batch_size, text_len = text_attention_mask.size()

            # Encode mel spectrograms and quantize them into audio token ids.
            dvae_audio_latents = dvae_encoder(audio_mel_specs, audio_attention_mask)
            _, dvae_audio_input_ids = quantize(
                dvae_decoder.vq_layer.quantizer, dvae_audio_latents
            )
            dvae_audio_input_ids[~audio_attention_mask.bool()] = AUDIO_PAD_TOKEN_ID

            # Append one extra time step so every sequence has room for EOS.
            extended_audio_attention_mask = torch.cat(
                [
                    audio_attention_mask,
                    torch.zeros(
                        (cur_batch_size, 1),
                        dtype=audio_attention_mask.dtype,
                        device=audio_attention_mask.device,
                    ),
                ],
                dim=1,
            )
            extended_audio_input_ids = torch.cat(
                [
                    dvae_audio_input_ids,
                    AUDIO_PAD_TOKEN_ID
                    * torch.ones(
                        (cur_batch_size, 1, gpt.num_vq),
                        dtype=dvae_audio_input_ids.dtype,
                        device=dvae_audio_input_ids.device,
                    ),
                ],
                dim=1,
            )
            # Number of attended audio frames per sample == index of the EOS slot.
            audio_lengths = audio_attention_mask.int().sum(dim=1)
            for i in range(cur_batch_size):
                extended_audio_attention_mask[i, audio_lengths[i]] = 1
                extended_audio_input_ids[i, audio_lengths[i]] = AUDIO_EOS_TOKEN_ID

            # Text ids are repeated across the VQ axis to match the audio ids.
            input_ids = torch.cat(
                [
                    text_input_ids.unsqueeze(-1).repeat(1, 1, gpt.num_vq),
                    extended_audio_input_ids,
                ],
                dim=1,
            )
            attention_mask = torch.cat(
                [text_attention_mask, extended_audio_attention_mask], dim=1
            )
            text_mask = torch.cat(
                [
                    torch.ones_like(text_attention_mask, dtype=bool),
                    torch.zeros_like(extended_audio_attention_mask, dtype=bool),
                ],
                dim=1,
            )

            # Unattended positions are excluded from the cross-entropy loss.
            labels = input_ids.clone()
            labels[~attention_mask.bool()] = IGNORE_TOKEN_ID

            inputs_embeds = gpt.get_emb(input_ids=input_ids, text_mask=text_mask)

            # Replace the embedding at every [spk_emb] position with the
            # L2-normalized trainable embedding of that sample's speaker.
            speaker_positions = torch.all(input_ids == SPEAKER_TOKEN_ID, dim=-1)
            for i, speaker in enumerate(speakers):
                inputs_embeds[i, speaker_positions[i]] = F.normalize(
                    speaker_embeds[speaker].to(dtype=inputs_embeds.dtype),
                    p=2.0,
                    dim=-1,
                    eps=1e-12,
                ).unsqueeze(0)

            outputs = gpt.gpt.forward(
                inputs_embeds=inputs_embeds, attention_mask=attention_mask
            )
            hidden_states = outputs.last_hidden_state
            # The hidden state at position t is scored against the label at t + 1.
            text_hidden_states = hidden_states[:, : text_len - 1]
            audio_hidden_states = hidden_states[:, text_len - 1 : -1]

            audio_logits = torch.stack(
                [gpt.head_code[i](audio_hidden_states) for i in range(gpt.num_vq)],
                dim=2,
            )
            audio_loss = loss_fn(
                audio_logits.flatten(0, 2), labels[:, text_len:].flatten(0, 2)
            )

            # Accumulate the total out of place. An in-place `loss += ...` on
            # an alias of audio_loss would overwrite audio_loss itself and make
            # its meter report the total loss instead of the audio term.
            loss = audio_loss
            if train_text:
                text_logits = gpt.head_text(text_hidden_states)
                text_loss = loss_fn(
                    text_logits.flatten(0, 1), labels[:, 1:text_len, 0].flatten(0, 1)
                )
                loss = loss + text_loss
                logger.meters["text_loss"].update(text_loss.item(), n=cur_batch_size)

            # Auxiliary term: decode the audio hidden states (EOS step dropped)
            # and compare the result with the input mel spectrograms.
            gpt_gen_mel_specs = decoder_decoder(
                audio_hidden_states[:, :-1].transpose(1, 2)
            ).transpose(1, 2)
            mse_loss = F.mse_loss(gpt_gen_mel_specs, audio_mel_specs)
            loss = loss + 0.01 * mse_loss

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(speaker_embeds.values(), 1.0)
            optimizer.step()

            logger.meters["loss"].update(loss.item(), n=cur_batch_size)
            logger.meters["mse_loss"].update(mse_loss.item(), n=cur_batch_size)
            logger.meters["audio_loss"].update(audio_loss.item(), n=cur_batch_size)

        # One scheduler step per epoch (T_max of the schedule is `epochs`).
        lr_scheduler.step()

    optimizer.zero_grad()
    return speaker_embeds
if __name__ == "__main__":
    # Command-line entry point: train speaker embeddings on a list/tar dataset
    # and write one `spk_<speaker>_<time>_ep<epochs>.pt` file per speaker.
    import argparse
    import numpy as np
    import pathlib
    from modules.models import load_chat_tts
    from modules.devices import devices
    from modules import config
    from modules.speaker import Speaker

    # Disable half precision before the device is (re)initialised.
    config.runtime_env_vars.no_half = True
    devices.reset_device()

    parser = argparse.ArgumentParser()
    parser.add_argument("--save_folder", type=str, default="./")
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--epochs", type=int, default=100)
    parser.add_argument("--train_text", action="store_true", help="train text loss")
    # Speaker file used to initialise the embeddings.
    parser.add_argument("--init_speaker", type=str)
    parser.add_argument(
        "--data_path",
        type=str,
        default="datasets/data_speaker_a/speaker_a.list",
        help="the data_path to json/list file",
    )
    parser.add_argument("--tar_path", type=str, help="the tarball path with wavs")
    parser.add_argument(
        "--tar_in_memory", action="store_true", help="load tarball in memory"
    )
    args = parser.parse_args()

    data_path: str = args.data_path
    tar_path: str | None = args.tar_path
    tar_in_memory: bool = args.tar_in_memory
    train_text: bool = args.train_text
    save_folder: str = args.save_folder
    batch_size: int = args.batch_size
    epochs: int = args.epochs
    init_speaker: str = args.init_speaker

    # Create the output folder up front so that a missing directory fails
    # (or is fixed) before training rather than after it has finished.
    save_dir = pathlib.Path(save_folder)
    save_dir.mkdir(parents=True, exist_ok=True)

    chat = load_chat_tts()
    dataset = XzListTar(
        root=data_path,
        tokenizer=chat.pretrain_models["tokenizer"],
        vocos_model=chat.pretrain_models["vocos"],
        tar_path=tar_path,
        tar_in_memory=tar_in_memory,
        device=devices.device,
        # speakers=None,  # set(['speaker_A', 'speaker_B'])
    )
    print("len(dataset)", len(dataset))

    speaker_embeds = None
    if init_speaker:
        # Every dataset speaker starts from its own copy of the loaded embedding.
        spk: Speaker = Speaker.from_file(init_speaker)
        speaker_embeds = {
            speaker: torch.tensor(
                spk.emb,
                device=devices.device,
                requires_grad=True,
            )
            for speaker in dataset.speakers
        }

    speaker_embeds = train_speaker_embeddings(
        chat,
        dataset,
        chat.pretrain_models["gpt"],
        batch_size=batch_size,
        epochs=epochs,
        train_text=train_text,
        speaker_embeds=speaker_embeds,
    )

    speaker_outs = {
        speaker: Speaker(speaker_embed.detach().cpu(), f"ep{epochs}_{speaker}")
        for speaker, speaker_embed in speaker_embeds.items()
    }

    # Timestamp with ':' and '-' replaced so it is safe inside a file name.
    time_str = np.datetime_as_string(np.datetime64("now", "s"))
    time_str = time_str.replace(":", "_").replace(" ", "_").replace("-", "_")
    for speaker, speaker_out in speaker_outs.items():
        torch.save(
            speaker_out,
            save_dir / f"spk_{speaker}_{time_str}_ep{epochs}.pt",
        )
# example
"""
python -m modules.finetune.train_speaker \
--data_path datasets/data_speaker_a/speaker_a.list \
--save_folder ./data \
--init_speaker ./data/speakers/Bob.pt \
--epochs 100 \
--batch_size 6 \
--train_text
"""