# Page header captured together with this file: "Spaces: Sleeping"
import functools
from pathlib import Path

import numpy as np
import peft
import torch
import transformers
from transformers.trainer_pt_utils import LabelSmoother

from utils.dataset import AudioCollator
from utils.logger import MetricLogger
from utils.output import ansi, get_ansi_len, output_iter
# Label value marking positions (padding) that must not contribute to the loss.
# Taken from transformers' LabelSmoother so it stays in sync with HF conventions.
IGNORE_TOKEN_ID = LabelSmoother.ignore_index
def train_gpt_lora(
    chat,
    dataset,
    decoder_encoder,
    dvae_encoder,
    batch_size=16,
    epochs=10,
    train_text=True,
    speaker_embeds=None,
    lora_r=8,
    lora_alpha=16,
):
    """Fine-tune the GPT model with LoRA adapters and learn one embedding per speaker.

    The tokenizer, decoder and DVAE models are taken from ``chat.pretrain_models``
    and frozen. ``chat.pretrain_models["gpt"].gpt`` is replaced (on the shared
    object) by a PEFT LoRA wrapper and trained together with the GPT module's
    other trainable parameters and the speaker embeddings.

    Args:
        chat: object whose ``pretrain_models`` mapping provides "tokenizer",
            "decoder", "dvae", "gpt" and "spk_stat".
        dataset: training dataset exposing ``device`` and ``speakers``; its items
            are batched with ``AudioCollator``.
        decoder_encoder: moved to ``dataset.device`` and frozen; not otherwise
            used by this function (kept for interface compatibility).
        dvae_encoder: encoder turning mel spectrograms into DVAE latents.
        batch_size: DataLoader batch size.
        epochs: number of passes over ``dataset``; also the cosine schedule length.
        train_text: if True, the next-text-token loss is added to the objective.
        speaker_embeds: optional mapping speaker -> initial embedding tensor.
            These tensors are used exactly as given (not rescaled, not mutated)
            and override the random initialisation for their speakers. They
            should be leaf tensors with ``requires_grad=True`` to be trained.
        lora_r: LoRA rank.
        lora_alpha: LoRA scaling factor.

    Returns:
        dict mapping every speaker to its (trained) embedding tensor.
    """
    if speaker_embeds is None:
        speaker_embeds = {}
    tokenizer = chat.pretrain_models["tokenizer"]

    # Frozen models: gradients may flow through them, but they are never updated.
    decoder_decoder = chat.pretrain_models["decoder"]
    decoder_decoder.eval().requires_grad_(False)
    decoder_encoder.to(device=dataset.device).eval().requires_grad_(False)
    dvae_decoder = chat.pretrain_models["dvae"]
    dvae_decoder.eval().requires_grad_(False)
    dvae_encoder.to(device=dataset.device).eval().requires_grad_(False)

    gpt = chat.pretrain_models["gpt"]
    gpt.train().requires_grad_()
    # Add LoRA adapters to the transformer backbone. This replaces ``gpt.gpt`` on
    # the shared model object, so the caller sees the wrapped backbone afterwards.
    lora_config = peft.LoraConfig(r=lora_r, lora_alpha=lora_alpha)
    gpt.gpt = peft.get_peft_model(gpt.gpt, lora_config)

    # Speakers without a caller-supplied embedding get a random N(0, 1) vector
    # mapped with the pretrained speaker statistics. Caller-supplied tensors are
    # used as-is: re-applying std/mean to them would double-transform (and mutate)
    # embeddings that are already in the final space.
    std, mean = chat.pretrain_models["spk_stat"].chunk(2)
    initial_embeds = {}
    for speaker in dataset.speakers:
        if speaker in speaker_embeds:
            initial_embeds[speaker] = speaker_embeds[speaker]
        else:
            random_embed = torch.randn(768, device=dataset.device) * std + mean
            initial_embeds[speaker] = random_embed.detach().requires_grad_()
    # Also keep caller-supplied speakers that the dataset does not list.
    speaker_embeds = initial_embeds | speaker_embeds

    SPEAKER_TOKEN_ID = tokenizer.convert_tokens_to_ids("[spk_emb]")
    AUDIO_EOS_TOKEN_ID = 0
    AUDIO_PAD_TOKEN_ID = AUDIO_EOS_TOKEN_ID

    # Only parameters that can receive gradients are optimised and clipped.
    gpt_params = [param for param in gpt.parameters() if param.requires_grad]
    embed_params = list(speaker_embeds.values())
    train_params = gpt_params + embed_params
    optimizer = torch.optim.Adam(
        gpt_params, lr=1e-3, weight_decay=0, betas=(0.9, 0.95), eps=1e-5
    )
    # Speaker embeddings use a much larger learning rate than the GPT weights.
    optimizer.add_param_group({"params": embed_params, "lr": 1e-1})
    loss_fn = torch.nn.CrossEntropyLoss(ignore_index=IGNORE_TOKEN_ID)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs, 1e-7)

    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=AudioCollator(text_pad=tokenizer.pad_token_id),
    )
    logger = MetricLogger()
    logger.create_meters(loss=None, mse_loss=None, audio_loss=None, text_loss=None)

    for epoch in range(1, epochs + 1):
        logger.reset()
        header = "{blue_light}{0}: {1}{reset}".format(
            "Epoch", output_iter(epoch, epochs), **ansi
        )
        header = header.ljust(max(len("Epoch"), 30) + get_ansi_len(header))
        iterator = logger.log_every(loader, header=header, tqdm_header="Batch")
        for batch in iterator:
            speakers = batch["speaker"]
            text_input_ids = batch["text_input_ids"]
            text_attention_mask = batch["text_attention_mask"]
            audio_mel_specs = batch["audio_mel_specs"]
            audio_attention_mask = batch["audio_attention_mask"]
            # The last batch may be smaller than ``batch_size``; use the real size.
            cur_batch_size, text_len = text_attention_mask.size()

            # Target audio tokens from the frozen DVAE; padded frames -> PAD token.
            # NOTE(review): ``quantize`` is neither defined nor imported in this
            # file; it must be imported from the module that provides it.
            dvae_audio_latents = dvae_encoder(audio_mel_specs, audio_attention_mask)
            _, dvae_audio_input_ids = quantize(
                dvae_decoder.vq_layer.quantizer, dvae_audio_latents
            )
            dvae_audio_input_ids[~audio_attention_mask.bool()] = AUDIO_PAD_TOKEN_ID

            # Append one extra step to every sample, then put the audio EOS token
            # at the first position after each sample's valid frames and mark it
            # as attended. This assumes valid frames form a prefix of the mask —
            # TODO confirm against AudioCollator.
            extended_audio_attention_mask = torch.cat(
                [
                    audio_attention_mask,
                    torch.zeros(
                        (cur_batch_size, 1),
                        dtype=audio_attention_mask.dtype,
                        device=audio_attention_mask.device,
                    ),
                ],
                dim=1,
            )
            extended_audio_input_ids = torch.cat(
                [
                    dvae_audio_input_ids,
                    torch.full(
                        (cur_batch_size, 1, gpt.num_vq),
                        AUDIO_PAD_TOKEN_ID,
                        dtype=dvae_audio_input_ids.dtype,
                        device=dvae_audio_input_ids.device,
                    ),
                ],
                dim=1,
            )
            eos_positions = audio_attention_mask.int().sum(dim=1)
            rows = torch.arange(cur_batch_size, device=eos_positions.device)
            extended_audio_attention_mask[rows, eos_positions] = 1
            extended_audio_input_ids[rows, eos_positions] = AUDIO_EOS_TOKEN_ID

            # Text ids are repeated across the num_vq codebook slots, so text and
            # audio can share one (batch, seq, num_vq) id tensor.
            input_ids = torch.cat(
                [
                    text_input_ids.unsqueeze(-1).repeat(1, 1, gpt.num_vq),
                    extended_audio_input_ids,
                ],
                dim=1,
            )
            attention_mask = torch.cat(
                [text_attention_mask, extended_audio_attention_mask], dim=1
            )
            text_mask = torch.cat(
                [
                    torch.ones_like(text_attention_mask, dtype=torch.bool),
                    torch.zeros_like(extended_audio_attention_mask, dtype=torch.bool),
                ],
                dim=1,
            )
            labels = input_ids.clone()
            labels[~attention_mask.bool()] = IGNORE_TOKEN_ID

            # Replace every [spk_emb] position with the speaker's L2-normalised
            # learned embedding.
            inputs_embeds = gpt.get_emb(input_ids=input_ids, text_mask=text_mask)
            speaker_positions = torch.all(input_ids == SPEAKER_TOKEN_ID, dim=-1)
            for i, speaker in enumerate(speakers):
                inputs_embeds[i, speaker_positions[i]] = torch.nn.functional.normalize(
                    speaker_embeds[speaker].to(dtype=inputs_embeds.dtype),
                    p=2.0,
                    dim=-1,
                    eps=1e-12,
                ).unsqueeze(0)

            outputs = gpt.gpt(inputs_embeds=inputs_embeds, attention_mask=attention_mask)
            hidden_states = outputs.last_hidden_state
            # Next-token prediction: position t predicts label t + 1, so the last
            # text position predicts the first audio token.
            text_hidden_states = hidden_states[:, : text_len - 1]
            audio_hidden_states = hidden_states[:, text_len - 1 : -1]

            audio_logits = torch.stack(
                [gpt.head_code[i](audio_hidden_states) for i in range(gpt.num_vq)],
                dim=2,
            )
            audio_loss = loss_fn(
                audio_logits.flatten(0, 2), labels[:, text_len:].flatten(0, 2)
            )
            # Out-of-place sums: ``loss += ...`` would modify ``audio_loss`` itself
            # (same tensor object) and corrupt the logged audio loss.
            loss = audio_loss
            if train_text:
                text_logits = gpt.head_text(text_hidden_states)
                text_loss = loss_fn(
                    text_logits.flatten(0, 1), labels[:, 1:text_len, 0].flatten(0, 1)
                )
                loss = loss + text_loss
                logger.meters["text_loss"].update(text_loss.item(), n=cur_batch_size)

            # Auxiliary loss: the frozen decoder should reconstruct the mel
            # spectrogram from the GPT audio hidden states (EOS step dropped).
            # NOTE(review): padded frames are included in the MSE and the decoder
            # output is assumed to match audio_mel_specs' shape — confirm.
            gpt_gen_mel_specs = decoder_decoder(
                audio_hidden_states[:, :-1].transpose(1, 2)
            ).transpose(1, 2)
            mse_loss = torch.nn.functional.mse_loss(gpt_gen_mel_specs, audio_mel_specs)
            loss = loss + 0.01 * mse_loss

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(train_params, 1.0)
            optimizer.step()
            logger.meters["loss"].update(loss.item(), n=cur_batch_size)
            logger.meters["mse_loss"].update(mse_loss.item(), n=cur_batch_size)
            logger.meters["audio_loss"].update(audio_loss.item(), n=cur_batch_size)
        lr_scheduler.step()
    optimizer.zero_grad()
    return speaker_embeds
# Example usage | |
def main():
    """Example entry point: LoRA-fine-tune the ChatTTS GPT and save the results.

    Writes the LoRA adapter to ``./saved_models/gpt_lora/`` and the learned
    speaker embeddings to ``./saved_models/speaker_embeds.npz``.

    NOTE(review): ``ChatTTS``, ``XzListTar``, ``DVAEEncoder`` and
    ``get_encoder_config`` are neither defined nor imported in this file; they
    must be imported from their providing modules before this can run.
    """
    # Load the pretrained models and the training data.
    chat = ChatTTS.Chat()
    chat.load_models()
    dataset = XzListTar(
        root="data/all.list",
        tokenizer=chat.pretrain_models["tokenizer"],
        vocos_model=chat.pretrain_models["vocos"],
        tar_path="data/Xz.tar",
        tar_in_memory=True,
        process_ahead=True,
    )
    decoder_encoder = DVAEEncoder(
        **get_encoder_config(chat.pretrain_models["decoder"].decoder)
    )
    dvae_encoder = DVAEEncoder(
        **get_encoder_config(chat.pretrain_models["dvae"].decoder)
    )

    # Train the GPT with LoRA.
    speaker_embeds = train_gpt_lora(
        chat=chat,
        dataset=dataset,
        decoder_encoder=decoder_encoder,
        dvae_encoder=dvae_encoder,
        batch_size=32,
        epochs=10,
        train_text=True,
        lora_r=8,
        lora_alpha=16,
    )

    # Save the LoRA parameters and the speaker embeddings.
    save_dir = Path("./saved_models")
    save_dir.mkdir(parents=True, exist_ok=True)  # np.savez does not create dirs
    # train_gpt_lora replaced the backbone with a PEFT model in place; its
    # save_pretrained writes the adapter weights + config into a directory.
    gpt = chat.pretrain_models["gpt"]
    gpt.gpt.save_pretrained(str(save_dir / "gpt_lora"))
    # The embeddings require grad, so they must be detached before .numpy().
    np.savez(
        save_dir / "speaker_embeds.npz",
        **{k: v.detach().cpu().numpy() for k, v in speaker_embeds.items()},
    )
# Run the example only when executed as a script, not on import.
if __name__ == "__main__":
    main()