import torch
import torch.nn.functional as F
import transformers

from modules.finetune.model.encoder import DVAEEncoder, get_encoder_config
from modules.finetune.utils.output import ansi, get_ansi_len, output_iter

from .utils.dataset import AudioCollator, XzListTar
from .utils.logger import MetricLogger
from .utils.model import quantize

# Label value excluded from the loss. `loss_fn` below is a default
# CrossEntropyLoss, so masking relies on this matching its default
# ignore_index.
IGNORE_TOKEN_ID = transformers.trainer_pt_utils.LabelSmoother.ignore_index


def train_speaker_embeddings(
    chat,
    dataset,
    gpt,
    batch_size=16,
    epochs=10,
    train_text=True,
    speaker_embeds=None,
):
    """Optimise one embedding vector per speaker against frozen ChatTTS models.

    Only the tensors in ``speaker_embeds`` are handed to the optimizer (Adam,
    with a cosine LR schedule stepped once per epoch).

    For every batch:
      - the mel spectrograms are encoded and quantised into DVAE codes;
      - the text ids and codes are concatenated into one sequence;
      - the ``[spk_emb]`` token positions are overwritten with the
        L2-normalised speaker embedding;
      - the GPT is run on the result.

    Args:
        chat: object whose ``pretrain_models`` dict provides "tokenizer",
            "decoder", "dvae" and "spk_stat". The "decoder" and "dvae"
            modules are switched to eval mode with gradients disabled
            (mutates them).
        dataset: dataset exposing ``device`` and ``speakers``, whose items are
            collated by ``AudioCollator``.
        gpt: GPT wrapper exposing ``num_vq``, ``get_emb``, ``gpt``,
            ``head_code`` and ``head_text``.
        batch_size: DataLoader batch size.
        epochs: number of passes over ``dataset``; also the cosine schedule's
            ``T_max``.
        train_text: if True, back-propagate only the text loss (experimental;
            the original code marks it "just for test"). If False,
            back-propagate ``audio_loss + text_loss + 0.01 * mse_loss``.
        speaker_embeds: optional mapping speaker -> leaf tensor with
            ``requires_grad=True`` used as the starting point. If None, a
            random 768-dim vector per speaker is drawn and rescaled with
            ``spk_stat``.

    Returns:
        The ``speaker_embeds`` mapping; its tensors are updated in place.
    """
    tokenizer = chat.pretrain_models["tokenizer"]

    # Mel decoder: used only for the auxiliary MSE term on GPT hidden states.
    decoder_decoder = chat.pretrain_models["decoder"]
    decoder_decoder.eval().requires_grad_(False)

    dvae_decoder = chat.pretrain_models["dvae"]
    dvae_decoder.eval().requires_grad_(False)
    # NOTE(review): the encoder is built from the decoder's config only; no
    # weights are loaded into it in this file. Confirm that DVAEEncoder /
    # get_encoder_config supply pretrained weights.
    dvae_encoder = DVAEEncoder(**get_encoder_config(dvae_decoder.decoder)).to(
        device=dataset.device
    )
    dvae_encoder.eval().requires_grad_(False)

    if speaker_embeds is None:
        speaker_embeds = {
            speaker: torch.randn(
                768,
                device=dataset.device,
                requires_grad=True,
            )
            for speaker in dataset.speakers
        }
        # Rescale the standard-normal draw with the pretrained speaker
        # statistics (stored as std and mean halves of one tensor).
        std, mean = chat.pretrain_models["spk_stat"].chunk(2)
        for speaker_embed in speaker_embeds.values():
            speaker_embed.data = speaker_embed.data * std + mean

    SPEAKER_TOKEN_ID = tokenizer.convert_tokens_to_ids("[spk_emb]")
    AUDIO_EOS_TOKEN_ID = 0
    AUDIO_PAD_TOKEN_ID = AUDIO_EOS_TOKEN_ID

    # Materialise once: reused by the optimizer, backward() and grad clipping.
    params = list(speaker_embeds.values())
    optimizer = torch.optim.Adam(
        params, lr=1e-2, weight_decay=0, betas=[0.9, 0.95], eps=1e-5
    )
    loss_fn = torch.nn.CrossEntropyLoss()
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs, 1e-7)

    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=AudioCollator(text_pad=tokenizer.pad_token_id),
    )
    logger = MetricLogger()
    logger.create_meters(loss=None, mse_loss=None, audio_loss=None, text_loss=None)

    for epoch in range(1, epochs + 1):
        logger.reset()
        header = "{blue_light}{0}: {1}{reset}".format(
            "Epoch", output_iter(epoch, epochs), **ansi
        )
        header = header.ljust(max(len("Epoch"), 30) + get_ansi_len(header))
        iterator = logger.log_every(loader, header=header, tqdm_header="Batch")
        for batch in iterator:
            speakers = batch["speaker"]
            text_input_ids = batch["text_input_ids"]
            text_attention_mask = batch["text_attention_mask"]
            audio_mel_specs = batch["audio_mel_specs"]
            audio_attention_mask = batch["audio_attention_mask"]

            # Actual size of this batch (the last one may be smaller than
            # ``batch_size``); kept separate so the parameter is not shadowed.
            cur_batch_size, text_len = text_attention_mask.size()

            # Audio -> discrete DVAE codes; padded frames get the pad id.
            dvae_audio_latents = dvae_encoder(audio_mel_specs, audio_attention_mask)
            _, dvae_audio_input_ids = quantize(
                dvae_decoder.vq_layer.quantizer, dvae_audio_latents
            )
            dvae_audio_input_ids[~audio_attention_mask.bool()] = AUDIO_PAD_TOKEN_ID

            # Append one extra (initially masked) time step per row to make
            # room for an EOS code.
            extended_audio_attention_mask = torch.cat(
                [
                    audio_attention_mask,
                    torch.zeros(
                        (cur_batch_size, 1),
                        dtype=audio_attention_mask.dtype,
                        device=audio_attention_mask.device,
                    ),
                ],
                dim=1,
            )
            extended_audio_input_ids = torch.cat(
                [
                    dvae_audio_input_ids,
                    AUDIO_PAD_TOKEN_ID
                    * torch.ones(
                        (cur_batch_size, 1, gpt.num_vq),
                        dtype=dvae_audio_input_ids.dtype,
                        device=dvae_audio_input_ids.device,
                    ),
                ],
                dim=1,
            )
            # Place EOS right after each row's last valid frame (vectorised
            # form of a per-row loop).
            eos_positions = audio_attention_mask.int().sum(dim=1)
            rows = torch.arange(cur_batch_size, device=eos_positions.device)
            extended_audio_attention_mask[rows, eos_positions] = 1
            extended_audio_input_ids[rows, eos_positions] = AUDIO_EOS_TOKEN_ID

            # Text ids are repeated across the num_vq channels so text and
            # audio share one (batch, seq, num_vq) id tensor.
            input_ids = torch.cat(
                [
                    text_input_ids.unsqueeze(-1).repeat(1, 1, gpt.num_vq),
                    extended_audio_input_ids,
                ],
                dim=1,
            )
            attention_mask = torch.cat(
                [text_attention_mask, extended_audio_attention_mask], dim=1
            )
            text_mask = torch.cat(
                [
                    torch.ones_like(text_attention_mask, dtype=torch.bool),
                    torch.zeros_like(extended_audio_attention_mask, dtype=torch.bool),
                ],
                dim=1,
            )
            labels = input_ids.clone()
            labels[~attention_mask.bool()] = IGNORE_TOKEN_ID

            # Swap in the trainable speaker vector at every [spk_emb] slot.
            inputs_embeds = gpt.get_emb(input_ids=input_ids, text_mask=text_mask)
            speaker_slots = torch.all(input_ids == SPEAKER_TOKEN_ID, dim=-1)
            for i, speaker in enumerate(speakers):
                inputs_embeds[i, speaker_slots[i]] = F.normalize(
                    speaker_embeds[speaker].to(dtype=inputs_embeds.dtype),
                    p=2.0,
                    dim=-1,
                    eps=1e-12,
                ).unsqueeze(0)

            outputs = gpt.gpt.forward(
                inputs_embeds=inputs_embeds, attention_mask=attention_mask
            )
            hidden_states = outputs.last_hidden_state

            # Next-token alignment: the hidden state at position t is scored
            # against the label at t + 1.
            text_hidden_states = hidden_states[:, : text_len - 1]
            audio_hidden_states = hidden_states[:, text_len - 1 : -1]

            audio_logits = torch.stack(
                [gpt.head_code[i](audio_hidden_states) for i in range(gpt.num_vq)],
                dim=2,
            )
            audio_loss = loss_fn(
                audio_logits.flatten(0, 2), labels[:, text_len:].flatten(0, 2)
            )

            text_logits = gpt.head_text(text_hidden_states)
            # Text labels are identical across vq channels; take channel 0.
            text_loss = loss_fn(
                text_logits.flatten(0, 1), labels[:, 1:text_len, 0].flatten(0, 1)
            )

            # Drop the EOS step before decoding hidden states back to mel.
            gpt_gen_mel_specs = decoder_decoder(
                audio_hidden_states[:, :-1].transpose(1, 2)
            ).transpose(1, 2)
            mse_loss = F.mse_loss(gpt_gen_mel_specs, audio_mel_specs)

            # Out-of-place sum: ``loss = audio_loss; loss += ...`` would add
            # in place into ``audio_loss`` and make its logged value the total.
            loss = audio_loss + text_loss + 0.01 * mse_loss

            optimizer.zero_grad()
            objective = text_loss if train_text else loss
            # Accumulate gradients only into the optimised embeddings; any
            # grads landing on GPT weights would never be zeroed or used.
            objective.backward(inputs=params)
            torch.nn.utils.clip_grad_norm_(params, 1.0)
            optimizer.step()

            logger.meters["loss"].update(loss.item(), n=cur_batch_size)
            logger.meters["mse_loss"].update(mse_loss.item(), n=cur_batch_size)
            logger.meters["audio_loss"].update(audio_loss.item(), n=cur_batch_size)
            logger.meters["text_loss"].update(text_loss.item(), n=cur_batch_size)
        lr_scheduler.step()
    optimizer.zero_grad()
    return speaker_embeds


if __name__ == "__main__":
    import argparse
    import os
    import pathlib

    import numpy as np

    from modules import config
    from modules.devices import devices
    from modules.models import load_chat_tts
    from modules.speaker import Speaker

    config.runtime_env_vars.no_half = True
    config.runtime_env_vars.use_cpu = []
    devices.reset_device()

    parser = argparse.ArgumentParser()
    parser.add_argument("--save_folder", type=str, default="./")
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--epochs", type=int, default=100)
    parser.add_argument("--train_text", action="store_true", help="train text loss")
    # Optional speaker file whose embedding initialises every speaker.
    parser.add_argument("--init_speaker", type=str)
    parser.add_argument(
        "--data_path",
        type=str,
        default="datasets/data_speaker_a/speaker_a.list",
        help="the data_path to json/list file",
    )
    parser.add_argument("--tar_path", type=str, help="the tarball path with wavs")
    parser.add_argument(
        "--tar_in_memory", action="store_true", help="load tarball in memory"
    )

    args = parser.parse_args()

    data_path: str = args.data_path
    tar_path: str | None = args.tar_path
    tar_in_memory: bool = args.tar_in_memory
    train_text: bool = args.train_text
    save_folder: str = args.save_folder
    batch_size: int = args.batch_size
    epochs: int = args.epochs
    init_speaker: str = args.init_speaker

    chat = load_chat_tts()
    dataset = XzListTar(
        root=data_path,
        tokenizer=chat.pretrain_models["tokenizer"],
        vocos_model=chat.pretrain_models["vocos"],
        tar_path=tar_path,
        tar_in_memory=tar_in_memory,
        device=devices.get_device_for("trainer"),
        # A `speakers=` argument (e.g. set(['speaker_A', 'speaker_B'])) is
        # left at its default here.
    )

    print("len(dataset)", len(dataset))

    speaker_embeds = None
    if init_speaker:
        spk: Speaker = Speaker.from_file(init_speaker)
        # Every speaker starts from its own copy of the same embedding.
        speaker_embeds = {
            speaker: torch.tensor(
                spk.emb,
                device=devices.get_device_for("trainer"),
                requires_grad=True,
            )
            for speaker in dataset.speakers
        }

    speaker_embeds = train_speaker_embeddings(
        chat,
        dataset,
        chat.pretrain_models["gpt"],
        batch_size=batch_size,
        epochs=epochs,
        train_text=train_text,
        speaker_embeds=speaker_embeds,
    )
    speaker_outs = {
        speaker: Speaker(speaker_embed.detach().cpu(), f"ep{epochs}_{speaker}")
        for speaker, speaker_embed in speaker_embeds.items()
    }
    time_str = np.datetime_as_string(np.datetime64("now", "s"))
    time_str = time_str.replace(":", "_").replace(" ", "_").replace("-", "_")

    # torch.save does not create missing directories.
    os.makedirs(save_folder, exist_ok=True)
    for speaker, speaker_out in speaker_outs.items():
        torch.save(
            speaker_out,
            pathlib.Path(save_folder) / f"spk_{speaker}_{time_str}_ep{epochs}.pt",
        )

# example
"""
python -m modules.finetune.train_speaker \
    --data_path datasets/data_speaker_a/speaker_a.list \
    --save_folder ./data \
    --init_speaker ./data/speakers/Bob.pt \
    --epochs 100 \
    --batch_size 6 \
    --train_text
"""