from transformers import AutoModelForSequenceClassification, PreTrainedModel, AutoConfig, AutoModel, AutoTokenizer
import torch
import torch.nn as nn

from text_utils import TextCleaner

textcleaner = TextCleaner()


def length_to_mask(lengths):
    # Boolean padding mask: True marks positions past each sequence's length.
    mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
    mask = torch.gt(mask + 1, lengths.unsqueeze(1))
    return mask


device = 'cuda' if torch.cuda.is_available() else 'cpu'

# tokenizer_koto_prompt = AutoTokenizer.from_pretrained("google/mt5-small", trust_remote_code=True)
tokenizer_koto_prompt = AutoTokenizer.from_pretrained("ku-nlp/deberta-v3-base-japanese", trust_remote_code=True)
tokenizer_koto_text = AutoTokenizer.from_pretrained("line-corporation/line-distilbert-base-japanese", trust_remote_code=True)


class KotoDama_Prompt(PreTrainedModel):

    def __init__(self, config):
        super().__init__(config)
        self.backbone = AutoModel.from_config(config)
        self.output = nn.Sequential(
            nn.Linear(config.hidden_size, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, config.num_labels),
        )

    def forward(
        self,
        input_ids,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        labels=None,
    ):
        outputs = self.backbone(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
        )
        # Pool with the [CLS] token representation, then project to num_labels.
        sequence_output = outputs.last_hidden_state[:, 0, :]
        outputs = self.output(sequence_output)

        # If labels are given, we are training: regress onto them with MSE.
        loss = None
        if labels is not None:
            loss_fn = nn.MSELoss()
            # labels = labels.unsqueeze(1)
            loss = loss_fn(outputs, labels)

        return {"loss": loss, "logits": outputs}


class KotoDama_Text(PreTrainedModel):

    def __init__(self, config):
        super().__init__(config)
        self.backbone = AutoModel.from_config(config)
        self.output = nn.Sequential(
            nn.Linear(config.hidden_size, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, config.num_labels),
        )

    def forward(
        self,
        input_ids,
        attention_mask=None,
        # token_type_ids=None,
        # position_ids=None,
        labels=None,
    ):
        outputs = self.backbone(
            input_ids,
            attention_mask=attention_mask,
            # token_type_ids=token_type_ids,
            # position_ids=position_ids,
        )
        # Pool with the [CLS] token representation, then project to num_labels.
        sequence_output = outputs.last_hidden_state[:, 0, :]
        outputs = self.output(sequence_output)

        # If labels are given, we are training: regress onto them with MSE.
        loss = None
        if labels is not None:
            loss_fn = nn.MSELoss()
            # labels = labels.unsqueeze(1)
            loss = loss_fn(outputs, labels)

        return {"loss": loss, "logits": outputs}


def inference(model, diffusion_sampler, text=None, ref_s=None,
              alpha=0.3, beta=0.7, diffusion_steps=5,
              embedding_scale=1, rate_of_speech=1.):

    tokens = textcleaner(text)
    tokens.insert(0, 0)
    tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)

    with torch.no_grad():
        input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
        text_mask = length_to_mask(input_lengths).to(device)

        t_en = model.text_encoder(tokens, input_lengths, text_mask)
        bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
        d_en = model.bert_encoder(bert_dur).transpose(-1, -2)

        # Sample a 256-dim style vector conditioned on the text embedding.
        s_pred = diffusion_sampler(
            noise=torch.randn((1, 256)).unsqueeze(1).to(device),
            embedding=bert_dur,
            embedding_scale=embedding_scale,
            features=ref_s,  # reference from the same speaker as the embedding
            num_steps=diffusion_steps,
        ).squeeze(1)

        s = s_pred[:, 128:]
        ref = s_pred[:, :128]

        # Blend the sampled style with the reference style.
        ref = alpha * ref + (1 - alpha) * ref_s[:, :128]
        s = beta * s + (1 - beta) * ref_s[:, 128:]

        d = model.predictor.text_encoder(d_en, s, input_lengths, text_mask)

        x = model.predictor.lstm(d)
        x_mod = model.predictor.prepare_projection(x)
        duration = model.predictor.duration_proj(x_mod)

        duration = torch.sigmoid(duration).sum(axis=-1) / rate_of_speech
        pred_dur = torch.round(duration.squeeze()).clamp(min=1)

        # Build a hard alignment: each token repeats for its predicted duration.
        pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data))
        c_frame = 0
        for i in range(pred_aln_trg.size(0)):
            pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1
            c_frame += int(pred_dur[i].data)

        # encode prosody
        en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device))
        F0_pred, N_pred = model.predictor.F0Ntrain(en, s)

        asr = (t_en @ pred_aln_trg.unsqueeze(0).to(device))

        out = model.decoder(asr, F0_pred, N_pred, ref.squeeze().unsqueeze(0))

    # Drop the last 50 samples, which tend to carry an end-of-utterance artifact.
    return out.squeeze().cpu().numpy()[..., :-50]


def Longform(model, diffusion_sampler, text, s_prev, ref_s,
             alpha=0.3, beta=0.7, t=0.7, diffusion_steps=5,
             embedding_scale=1, rate_of_speech=1.0):

    tokens = textcleaner(text)
    tokens.insert(0, 0)
    tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)

    with torch.no_grad():
        input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
        text_mask = length_to_mask(input_lengths).to(device)

        t_en = model.text_encoder(tokens, input_lengths, text_mask)
        bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
        d_en = model.bert_encoder(bert_dur).transpose(-1, -2)

        s_pred = diffusion_sampler(
            noise=torch.randn((1, 256)).unsqueeze(1).to(device),
            embedding=bert_dur,
            embedding_scale=embedding_scale,
            features=ref_s,
            num_steps=diffusion_steps,
        ).squeeze(1)

        if s_prev is not None:
            # Convex combination of the previous and current style, so
            # consecutive chunks stay acoustically consistent.
            s_pred = t * s_prev + (1 - t) * s_pred

        s = s_pred[:, 128:]
        ref = s_pred[:, :128]

        ref = alpha * ref + (1 - alpha) * ref_s[:, :128]
        s = beta * s + (1 - beta) * ref_s[:, 128:]

        s_pred = torch.cat([ref, s], dim=-1)

        d = model.predictor.text_encoder(d_en, s, input_lengths, text_mask)

        x = model.predictor.lstm(d)
        x_mod = model.predictor.prepare_projection(x)  # 640 -> 512
        duration = model.predictor.duration_proj(x_mod)

        duration = torch.sigmoid(duration).sum(axis=-1) / rate_of_speech
        pred_dur = torch.round(duration.squeeze()).clamp(min=1)

        # Build a hard alignment: each token repeats for its predicted duration.
        pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data))
        c_frame = 0
        for i in range(pred_aln_trg.size(0)):
            pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1
            c_frame += int(pred_dur[i].data)

        # encode prosody
        en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device))
        F0_pred, N_pred = model.predictor.F0Ntrain(en, s)

        asr = (t_en @ pred_aln_trg.unsqueeze(0).to(device))

        out = model.decoder(asr, F0_pred, N_pred, ref.squeeze().unsqueeze(0))

    # Drop the last 100 samples and return the blended style for the next chunk.
    return out.squeeze().cpu().numpy()[..., :-100], s_pred


def merge_short_elements(lst):
    # Fold any element shorter than 10 characters into its predecessor.
    i = 0
    while i < len(lst):
        if i > 0 and len(lst[i]) < 10:
            lst[i - 1] += ' ' + lst[i]
            lst.pop(i)
        else:
            i += 1
    return lst


def merge_three(text_list, maxim=2):
    # Join consecutive elements in groups of `maxim`.
    merged_list = []
    for i in range(0, len(text_list), maxim):
        merged_text = ' '.join(text_list[i:i + maxim])
        merged_list.append(merged_text)
    return merged_list


def merging_sentences(lst):
    return merge_three(merge_short_elements(lst))
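# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the original module: one plausible way to
# instantiate and query the KotoDama_Prompt head. The num_labels value and the
# checkpoint path below are hypothetical placeholders, not the project's
# actual settings.
# ---------------------------------------------------------------------------
def example_prompt_encoding(prompt):
    config = AutoConfig.from_pretrained("ku-nlp/deberta-v3-base-japanese",
                                        num_labels=256)  # assumed output size
    model = KotoDama_Prompt(config).to(device).eval()
    # model.load_state_dict(torch.load("kotodama_prompt.pth", map_location=device))  # hypothetical checkpoint
    batch = tokenizer_koto_prompt(prompt, return_tensors="pt").to(device)
    with torch.no_grad():
        out = model(input_ids=batch["input_ids"],
                    attention_mask=batch["attention_mask"])
    return out["logits"]  # shape (1, num_labels): the predicted style target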
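# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the original module: how Longform() can be
# chained over a long passage, carrying s_prev between chunks so the style
# evolves smoothly. `model`, `diffusion_sampler`, and `ref_s` are assumed to
# come from the surrounding project; splitting on '。' is a naive placeholder.
# ---------------------------------------------------------------------------
def example_longform_synthesis(model, diffusion_sampler, passage, ref_s):
    import numpy as np

    chunks = merging_sentences([c for c in passage.split('。') if c])
    s_prev, audio = None, []
    for chunk in chunks:
        wav, s_prev = Longform(model, diffusion_sampler, chunk,
                               s_prev, ref_s, diffusion_steps=10)
        audio.append(wav)
    return np.concatenate(audio)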