diff --git a/cosyvoice/__init__.py b/cosyvoice/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/cosyvoice/__pycache__/__init__.cpython-310.pyc b/cosyvoice/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..804b16392c253ff69c74f6e9a3dc0ccbb6847f1a
Binary files /dev/null and b/cosyvoice/__pycache__/__init__.cpython-310.pyc differ
diff --git a/cosyvoice/bin/inference.py b/cosyvoice/bin/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..6b777fa1cba925f9786db60b7efa15dcd189adeb
--- /dev/null
+++ b/cosyvoice/bin/inference.py
@@ -0,0 +1,114 @@
+# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import argparse
+import logging
+logging.getLogger('matplotlib').setLevel(logging.WARNING)
+import os
+
+import torch
+from torch.utils.data import DataLoader
+import torchaudio
+from hyperpyyaml import load_hyperpyyaml
+from tqdm import tqdm
+from cosyvoice.cli.model import CosyVoiceModel
+
+from cosyvoice.dataset.dataset import Dataset
+
+def get_args():
+    parser = argparse.ArgumentParser(description='inference with your model')
+    parser.add_argument('--config', required=True, help='config file')
+    parser.add_argument('--prompt_data', required=True, help='prompt data file')
+    parser.add_argument('--prompt_utt2data', required=True, help='prompt data file')
+    parser.add_argument('--tts_text', required=True, help='tts input file')
+    parser.add_argument('--llm_model', required=True, help='llm model file')
+    parser.add_argument('--flow_model', required=True, help='flow model file')
+    parser.add_argument('--hifigan_model', required=True, help='hifigan model file')
+    parser.add_argument('--gpu',
+                        type=int,
+                        default=-1,
+                        help='gpu id for this rank, -1 for cpu')
+    parser.add_argument('--mode',
+                        default='sft',
+                        choices=['sft', 'zero_shot'],
+                        help='inference mode')
+    parser.add_argument('--result_dir', required=True, help='asr result file')
+    args = parser.parse_args()
+    print(args)
+    return args
+
+
+def main():
+    args = get_args()
+    logging.basicConfig(level=logging.DEBUG,
+                        format='%(asctime)s %(levelname)s %(message)s')
+    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
+
+    # Init cosyvoice models from configs
+    use_cuda = args.gpu >= 0 and torch.cuda.is_available()
+    device = torch.device('cuda' if use_cuda else 'cpu')
+    with open(args.config, 'r') as f:
+        configs = load_hyperpyyaml(f)
+
+    model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'])
+    model.load(args.llm_model, args.flow_model, args.hifigan_model)
+
+    test_dataset = Dataset(args.prompt_data, data_pipeline=configs['data_pipeline'], mode='inference', shuffle=False, partition=False, tts_file=args.tts_text, prompt_utt2data=args.prompt_utt2data)
+    test_data_loader = DataLoader(test_dataset, batch_size=None, num_workers=0)
+
+    del configs
+    os.makedirs(args.result_dir, exist_ok=True)
+    fn = os.path.join(args.result_dir, 'wav.scp')
+    f = open(fn, 'w')
+    with torch.no_grad():
+        for batch_idx, batch in tqdm(enumerate(test_data_loader)):
+            utts = batch["utts"]
+            assert len(utts) == 1, "inference mode only support batchsize 1"
+            text = batch["text"]
+            text_token = batch["text_token"].to(device)
+            text_token_len = batch["text_token_len"].to(device)
+            tts_text = batch["tts_text"]
+            tts_index = batch["tts_index"]
+            tts_text_token = batch["tts_text_token"].to(device)
+            tts_text_token_len = batch["tts_text_token_len"].to(device)
+            speech_token = batch["speech_token"].to(device)
+            speech_token_len = batch["speech_token_len"].to(device)
+            speech_feat = batch["speech_feat"].to(device)
+            speech_feat_len = batch["speech_feat_len"].to(device)
+            utt_embedding = batch["utt_embedding"].to(device)
+            spk_embedding = batch["spk_embedding"].to(device)
+            if args.mode == 'sft':
+                model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
+                               'llm_embedding': spk_embedding, 'flow_embedding': spk_embedding}
+            else:
+                model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
+                               'prompt_text': text_token, 'prompt_text_len': text_token_len,
+                               'llm_prompt_speech_token': speech_token, 'llm_prompt_speech_token_len': speech_token_len,
+                               'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
+                               'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
+                               'llm_embedding': utt_embedding, 'flow_embedding': utt_embedding}
+            model_output = model.inference(**model_input)
+            tts_key = '{}_{}'.format(utts[0], tts_index[0])
+            tts_fn = os.path.join(args.result_dir, '{}.wav'.format(tts_key))
+            torchaudio.save(tts_fn, model_output['tts_speech'], sample_rate=22050)
+            f.write('{} {}\n'.format(tts_key, tts_fn))
+            f.flush()
+    f.close()
+    logging.info('Result wav.scp saved in {}'.format(fn))
+
+
+if __name__ == '__main__':
+    main()
diff --git a/cosyvoice/bin/train.py b/cosyvoice/bin/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..2f4c9fee8415823f26ae7d11a3b81ee24b6f31ea
--- /dev/null
+++ b/cosyvoice/bin/train.py
@@ -0,0 +1,140 @@
+# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
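+#
+# NOTE: illustrative launch example, not part of the original patch. This
+# script is intended to be started by a distributed launcher (e.g. torchrun)
+# that sets the usual torch.distributed environment variables; a minimal
+# single-GPU invocation could look roughly like:
+#
+#   torchrun --nproc_per_node=1 cosyvoice/bin/train.py \
+#       --train_engine torch_ddp --model llm --config conf/cosyvoice.yaml \
+#       --train_data data/train.data.list --cv_data data/dev.data.list \
+#       --model_dir exp/cosyvoice/llm
+#
+# The flag names match the argparse options defined below; the config and
+# data-list paths are hypothetical placeholders.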
+ +from __future__ import print_function +import argparse +import datetime +import logging +logging.getLogger('matplotlib').setLevel(logging.WARNING) +from copy import deepcopy +import torch +import torch.distributed as dist +# import deepspeed +import pdb +from hyperpyyaml import load_hyperpyyaml + +from torch.distributed.elastic.multiprocessing.errors import record + +from cosyvoice.utils.executor import Executor +from cosyvoice.utils.train_utils import ( + init_distributed, + init_dataset_and_dataloader, + init_optimizer_and_scheduler, + init_summarywriter, save_model, + wrap_cuda_model, check_modify_and_save_config) + + +def get_args(): + parser = argparse.ArgumentParser(description='training your network') + parser.add_argument('--train_engine', + default='torch_ddp', + choices=['torch_ddp', 'deepspeed'], + help='Engine for paralleled training') + parser.add_argument('--model', required=True, help='model which will be trained') + parser.add_argument('--config', required=True, help='config file') + parser.add_argument('--train_data', required=True, help='train data file') + parser.add_argument('--cv_data', required=True, help='cv data file') + parser.add_argument('--checkpoint', help='checkpoint model') + parser.add_argument('--model_dir', required=True, help='save model dir') + parser.add_argument('--tensorboard_dir', + default='tensorboard', + help='tensorboard log dir') + parser.add_argument('--ddp.dist_backend', + dest='dist_backend', + default='nccl', + choices=['nccl', 'gloo'], + help='distributed backend') + parser.add_argument('--num_workers', + default=0, + type=int, + help='num of subprocess workers for reading') + parser.add_argument('--prefetch', + default=100, + type=int, + help='prefetch number') + parser.add_argument('--pin_memory', + action='store_true', + default=False, + help='Use pinned memory buffers used for reading') + parser.add_argument('--deepspeed.save_states', + dest='save_states', + default='model_only', + choices=['model_only', 'model+optimizer'], + help='save model/optimizer states') + parser.add_argument('--timeout', + default=30, + type=int, + help='timeout (in seconds) of cosyvoice_join.') + # parser = deepspeed.add_config_arguments(parser) + args = parser.parse_args() + return args + + +@record +def main(): + args = get_args() + logging.basicConfig(level=logging.DEBUG, + format='%(asctime)s %(levelname)s %(message)s') + + override_dict = {k: None for k in ['llm', 'flow', 'hift'] if k != args.model} + with open(args.config, 'r') as f: + configs = load_hyperpyyaml(f, overrides=override_dict) + configs['train_conf'].update(vars(args)) + + # Init env for ddp + init_distributed(args) + + # Get dataset & dataloader + train_dataset, cv_dataset, train_data_loader, cv_data_loader = \ + init_dataset_and_dataloader(args, configs) + + # Do some sanity checks and save config to arsg.model_dir + configs = check_modify_and_save_config(args, configs) + + # Tensorboard summary + writer = init_summarywriter(args) + + # load checkpoint + model = configs[args.model] + if args.checkpoint is not None: + model.load_state_dict(torch.load(args.checkpoint, map_location='cpu')) + + # Dispatch model from cpu to gpu + model = wrap_cuda_model(args, model) + + # Get optimizer & scheduler + model, optimizer, scheduler = init_optimizer_and_scheduler(args, configs, model) + # pdb.set_trace() + # Save init checkpoints + info_dict = deepcopy(configs['train_conf']) + save_model(model, 'init', info_dict) + + # Get executor + executor = Executor() + + # Start training loop + for epoch in 
range(info_dict['max_epoch']): + executor.epoch = epoch + train_dataset.set_epoch(epoch) + dist.barrier() + # try: + # dist.barrier() + # except RuntimeError as e: + # logging.info('except RuntimeError as e: {}'.format(e)) + group_join = dist.new_group(backend="gloo", timeout=datetime.timedelta(seconds=args.timeout)) + executor.train_one_epoc(model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, group_join) + dist.destroy_process_group(group_join) + +if __name__ == '__main__': + main() diff --git a/cosyvoice/cli/__init__.py b/cosyvoice/cli/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/cosyvoice/cli/cosyvoice.py b/cosyvoice/cli/cosyvoice.py new file mode 100644 index 0000000000000000000000000000000000000000..ea8c4482891a62df6cbac39faa88972c81f5412f --- /dev/null +++ b/cosyvoice/cli/cosyvoice.py @@ -0,0 +1,83 @@ +# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import torch +from hyperpyyaml import load_hyperpyyaml +from modelscope import snapshot_download +from cosyvoice.cli.frontend import CosyVoiceFrontEnd +from cosyvoice.cli.model import CosyVoiceModel + +class CosyVoice: + + def __init__(self, model_dir): + instruct = True if '-Instruct' in model_dir else False + self.model_dir = model_dir + if not os.path.exists(model_dir): + model_dir = snapshot_download(model_dir) + with open('{}/cosyvoice.yaml'.format(model_dir), 'r') as f: + configs = load_hyperpyyaml(f) + self.frontend = CosyVoiceFrontEnd(configs['get_tokenizer'], + configs['feat_extractor'], + '{}/campplus.onnx'.format(model_dir), + '{}/speech_tokenizer_v1.onnx'.format(model_dir), + '{}/spk2info.pt'.format(model_dir), + instruct, + configs['allowed_special']) + self.model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift']) + self.model.load('{}/llm.pt'.format(model_dir), + '{}/flow.pt'.format(model_dir), + '{}/hift.pt'.format(model_dir)) + del configs + + def list_avaliable_spks(self): + spks = list(self.frontend.spk2info.keys()) + return spks + + def inference_sft(self, tts_text, spk_id): + tts_speeches = [] + for i in self.frontend.text_normalize(tts_text, split=True): + model_input = self.frontend.frontend_sft(i, spk_id) + model_output = self.model.inference(**model_input) + tts_speeches.append(model_output['tts_speech']) + return {'tts_speech': torch.concat(tts_speeches, dim=1)} + + def inference_zero_shot(self, tts_text, prompt_text, prompt_speech_16k): + prompt_text = self.frontend.text_normalize(prompt_text, split=False) + tts_speeches = [] + for i in self.frontend.text_normalize(tts_text, split=True): + model_input = self.frontend.frontend_zero_shot(i, prompt_text, prompt_speech_16k) + model_output = self.model.inference(**model_input) + tts_speeches.append(model_output['tts_speech']) + return {'tts_speech': torch.concat(tts_speeches, dim=1)} + + def inference_cross_lingual(self, tts_text, prompt_speech_16k): + if 
self.frontend.instruct is True: + raise ValueError('{} do not support cross_lingual inference'.format(self.model_dir)) + tts_speeches = [] + for i in self.frontend.text_normalize(tts_text, split=True): + model_input = self.frontend.frontend_cross_lingual(i, prompt_speech_16k) + model_output = self.model.inference(**model_input) + tts_speeches.append(model_output['tts_speech']) + return {'tts_speech': torch.concat(tts_speeches, dim=1)} + + def inference_instruct(self, tts_text, spk_id, instruct_text): + if self.frontend.instruct is False: + raise ValueError('{} do not support instruct inference'.format(self.model_dir)) + instruct_text = self.frontend.text_normalize(instruct_text, split=False) + tts_speeches = [] + for i in self.frontend.text_normalize(tts_text, split=True): + model_input = self.frontend.frontend_instruct(i, spk_id, instruct_text) + model_output = self.model.inference(**model_input) + tts_speeches.append(model_output['tts_speech']) + return {'tts_speech': torch.concat(tts_speeches, dim=1)} diff --git a/cosyvoice/cli/frontend.py b/cosyvoice/cli/frontend.py new file mode 100644 index 0000000000000000000000000000000000000000..3ed85500cd3ab65f8f4f7540c084adb0c648186f --- /dev/null +++ b/cosyvoice/cli/frontend.py @@ -0,0 +1,168 @@ +# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
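+#
+# NOTE: context sketch, not part of the original patch. CosyVoiceFrontEnd
+# below handles text normalization, text/speech tokenization, mel-feature
+# extraction and speaker-embedding extraction; it is constructed and driven
+# by cosyvoice.cli.cosyvoice.CosyVoice, whose typical use looks roughly like
+# this (the model directory is an illustrative placeholder):
+#
+#   from cosyvoice.cli.cosyvoice import CosyVoice
+#   import torchaudio
+#   cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M-SFT')
+#   spk = cosyvoice.list_avaliable_spks()[0]
+#   output = cosyvoice.inference_sft('Hello, nice to meet you.', spk)
+#   torchaudio.save('sft_out.wav', output['tts_speech'], 22050)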
+from functools import partial +import onnxruntime +import torch +import numpy as np +import whisper +from typing import Callable +import torchaudio.compliance.kaldi as kaldi +import torchaudio +import os +import re +import inflect +try: + import ttsfrd + use_ttsfrd = True +except ImportError: + print("failed to import ttsfrd, use WeTextProcessing instead") + from tn.chinese.normalizer import Normalizer as ZhNormalizer + from tn.english.normalizer import Normalizer as EnNormalizer + use_ttsfrd = False +from cosyvoice.utils.frontend_utils import contains_chinese, replace_blank, replace_corner_mark, remove_bracket, spell_out_number, split_paragraph + + +class CosyVoiceFrontEnd: + + def __init__(self, + get_tokenizer: Callable, + feat_extractor: Callable, + campplus_model: str, + speech_tokenizer_model: str, + spk2info: str = '', + instruct: bool = False, + allowed_special: str = 'all'): + self.tokenizer = get_tokenizer() + self.feat_extractor = feat_extractor + self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + option = onnxruntime.SessionOptions() + option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL + option.intra_op_num_threads = 1 + self.campplus_session = onnxruntime.InferenceSession(campplus_model, sess_options=option, providers=["CPUExecutionProvider"]) + self.speech_tokenizer_session = onnxruntime.InferenceSession(speech_tokenizer_model, sess_options=option, providers=["CUDAExecutionProvider"if torch.cuda.is_available() else "CPUExecutionProvider"]) + if os.path.exists(spk2info): + self.spk2info = torch.load(spk2info, map_location=self.device) + self.instruct = instruct + self.allowed_special = allowed_special + self.inflect_parser = inflect.engine() + self.use_ttsfrd = use_ttsfrd + if self.use_ttsfrd: + self.frd = ttsfrd.TtsFrontendEngine() + ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) + assert self.frd.initialize('{}/../../pretrained_models/CosyVoice-ttsfrd/resource'.format(ROOT_DIR)) is True, 'failed to initialize ttsfrd resource' + self.frd.set_lang_type('pinyin') + self.frd.enable_pinyin_mix(True) + self.frd.set_breakmodel_index(1) + else: + self.zh_tn_model = ZhNormalizer(remove_erhua=False, full_to_half=False) + self.en_tn_model = EnNormalizer() + + def _extract_text_token(self, text): + text_token = self.tokenizer.encode(text, allowed_special=self.allowed_special) + text_token = torch.tensor([text_token], dtype=torch.int32).to(self.device) + text_token_len = torch.tensor([text_token.shape[1]], dtype=torch.int32).to(self.device) + return text_token, text_token_len + + def _extract_speech_token(self, speech): + feat = whisper.log_mel_spectrogram(speech, n_mels=128) + speech_token = self.speech_tokenizer_session.run(None, {self.speech_tokenizer_session.get_inputs()[0].name: feat.detach().cpu().numpy(), + self.speech_tokenizer_session.get_inputs()[1].name: np.array([feat.shape[2]], dtype=np.int32)})[0].flatten().tolist() + speech_token = torch.tensor([speech_token], dtype=torch.int32).to(self.device) + speech_token_len = torch.tensor([speech_token.shape[1]], dtype=torch.int32).to(self.device) + return speech_token, speech_token_len + + def _extract_spk_embedding(self, speech): + feat = kaldi.fbank(speech, + num_mel_bins=80, + dither=0, + sample_frequency=16000) + feat = feat - feat.mean(dim=0, keepdim=True) + embedding = self.campplus_session.run(None, {self.campplus_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist() + embedding = torch.tensor([embedding]).to(self.device) + 
return embedding + + def _extract_speech_feat(self, speech): + speech_feat = self.feat_extractor(speech).squeeze(dim=0).transpose(0, 1).to(self.device) + speech_feat = speech_feat.unsqueeze(dim=0) + speech_feat_len = torch.tensor([speech_feat.shape[1]], dtype=torch.int32).to(self.device) + return speech_feat, speech_feat_len + + def text_normalize(self, text, split=True): + text = text.strip() + if contains_chinese(text): + if self.use_ttsfrd: + text = self.frd.get_frd_extra_info(text, 'input') + else: + text = self.zh_tn_model.normalize(text) + text = text.replace("\n", "") + text = replace_blank(text) + text = replace_corner_mark(text) + text = text.replace(".", "、") + text = text.replace(" - ", ",") + text = remove_bracket(text) + text = re.sub(r'[,,]+$', '。', text) + texts = [i for i in split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "zh", token_max_n=80, + token_min_n=60, merge_len=20, + comma_split=False)] + else: + if self.use_ttsfrd: + text = self.frd.get_frd_extra_info(text, 'input') + else: + text = self.en_tn_model.normalize(text) + text = spell_out_number(text, self.inflect_parser) + texts = [i for i in split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "en", token_max_n=80, + token_min_n=60, merge_len=20, + comma_split=False)] + if split is False: + return text + return texts + + def frontend_sft(self, tts_text, spk_id): + tts_text_token, tts_text_token_len = self._extract_text_token(tts_text) + embedding = self.spk2info[spk_id]['embedding'] + model_input = {'text': tts_text_token, 'text_len': tts_text_token_len, 'llm_embedding': embedding, 'flow_embedding': embedding} + return model_input + + def frontend_zero_shot(self, tts_text, prompt_text, prompt_speech_16k): + tts_text_token, tts_text_token_len = self._extract_text_token(tts_text) + prompt_text_token, prompt_text_token_len = self._extract_text_token(prompt_text) + prompt_speech_22050 = torchaudio.transforms.Resample(orig_freq=16000, new_freq=22050)(prompt_speech_16k) + speech_feat, speech_feat_len = self._extract_speech_feat(prompt_speech_22050) + speech_token, speech_token_len = self._extract_speech_token(prompt_speech_16k) + embedding = self._extract_spk_embedding(prompt_speech_16k) + model_input = {'text': tts_text_token, 'text_len': tts_text_token_len, + 'prompt_text': prompt_text_token, 'prompt_text_len': prompt_text_token_len, + 'llm_prompt_speech_token': speech_token, 'llm_prompt_speech_token_len': speech_token_len, + 'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len, + 'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len, + 'llm_embedding': embedding, 'flow_embedding': embedding} + return model_input + + def frontend_cross_lingual(self, tts_text, prompt_speech_16k): + model_input = self.frontend_zero_shot(tts_text, '', prompt_speech_16k) + # in cross lingual mode, we remove prompt in llm + del model_input['prompt_text'] + del model_input['prompt_text_len'] + del model_input['llm_prompt_speech_token'] + del model_input['llm_prompt_speech_token_len'] + return model_input + + def frontend_instruct(self, tts_text, spk_id, instruct_text): + model_input = self.frontend_sft(tts_text, spk_id) + # in instruct mode, we remove spk_embedding in llm due to information leakage + del model_input['llm_embedding'] + instruct_text_token, instruct_text_token_len = self._extract_text_token(instruct_text + '') + model_input['prompt_text'] = instruct_text_token + 
model_input['prompt_text_len'] = instruct_text_token_len + return model_input diff --git a/cosyvoice/cli/model.py b/cosyvoice/cli/model.py new file mode 100644 index 0000000000000000000000000000000000000000..446a84e079dcef74a018ef7fe6b2038709b97b0f --- /dev/null +++ b/cosyvoice/cli/model.py @@ -0,0 +1,95 @@ +# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch + +class CosyVoiceModel: + + def __init__(self, + llm: torch.nn.Module, + flow: torch.nn.Module, + hift: torch.nn.Module): + self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + self.llm = llm + self.flow = flow + self.hift = hift + + def load(self, llm_model, flow_model, hift_model): + self.llm.load_state_dict(torch.load(llm_model, map_location=self.device)) + self.llm.to(self.device).eval() + self.flow.load_state_dict(torch.load(flow_model, map_location=self.device)) + self.flow.to(self.device).eval() + self.hift.load_state_dict(torch.load(hift_model, map_location=self.device)) + self.hift.to(self.device).eval() + + def inference(self, text, text_len, flow_embedding, llm_embedding=torch.zeros(0, 192), + prompt_text=torch.zeros(1, 0, dtype=torch.int32), prompt_text_len=torch.zeros(1, dtype=torch.int32), + llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32), llm_prompt_speech_token_len=torch.zeros(1, dtype=torch.int32), + flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32), flow_prompt_speech_token_len=torch.zeros(1, dtype=torch.int32), + prompt_speech_feat=torch.zeros(1, 0, 80), prompt_speech_feat_len=torch.zeros(1, dtype=torch.int32)): + tts_speech_token = self.llm.inference(text=text.to(self.device), + text_len=text_len.to(self.device), + prompt_text=prompt_text.to(self.device), + prompt_text_len=prompt_text_len.to(self.device), + prompt_speech_token=llm_prompt_speech_token.to(self.device), + prompt_speech_token_len=llm_prompt_speech_token_len.to(self.device), + embedding=llm_embedding.to(self.device), + beam_size=1, + sampling=25, + max_token_text_ratio=30, + min_token_text_ratio=3) + tts_mel = self.flow.inference(token=tts_speech_token, + token_len=torch.tensor([tts_speech_token.size(1)], dtype=torch.int32).to(self.device), + prompt_token=flow_prompt_speech_token.to(self.device), + prompt_token_len=flow_prompt_speech_token_len.to(self.device), + prompt_feat=prompt_speech_feat.to(self.device), + prompt_feat_len=prompt_speech_feat_len.to(self.device), + embedding=flow_embedding.to(self.device)) + tts_speech = self.hift.inference(mel=tts_mel).cpu() + torch.cuda.empty_cache() + return {'tts_speech': tts_speech} + + def text_to_token(self, text, text_len, flow_embedding, llm_embedding=torch.zeros(0, 192), + prompt_text=torch.zeros(1, 0, dtype=torch.int32), prompt_text_len=torch.zeros(1, dtype=torch.int32), + llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32), llm_prompt_speech_token_len=torch.zeros(1, dtype=torch.int32), + flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32), 
flow_prompt_speech_token_len=torch.zeros(1, dtype=torch.int32), + prompt_speech_feat=torch.zeros(1, 0, 80), prompt_speech_feat_len=torch.zeros(1, dtype=torch.int32)): + tts_speech_token = self.llm.inference(text=text.to(self.device), + text_len=text_len.to(self.device), + prompt_text=prompt_text.to(self.device), + prompt_text_len=prompt_text_len.to(self.device), + prompt_speech_token=llm_prompt_speech_token.to(self.device), + prompt_speech_token_len=llm_prompt_speech_token_len.to(self.device), + embedding=llm_embedding.to(self.device), + beam_size=1, + sampling=25, + max_token_text_ratio=30, + min_token_text_ratio=3) + return tts_speech_token + + def token_to_speech(self, tts_speech_token, flow_embedding, llm_embedding=torch.zeros(0, 192), + prompt_text=torch.zeros(1, 0, dtype=torch.int32), prompt_text_len=torch.zeros(1, dtype=torch.int32), + llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32), llm_prompt_speech_token_len=torch.zeros(1, dtype=torch.int32), + flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32), flow_prompt_speech_token_len=torch.zeros(1, dtype=torch.int32), + prompt_speech_feat=torch.zeros(1, 0, 80), prompt_speech_feat_len=torch.zeros(1, dtype=torch.int32)): + + tts_mel = self.flow.inference(token=tts_speech_token, + token_len=torch.tensor([tts_speech_token.size(1)], dtype=torch.int32).to(self.device), + prompt_token=flow_prompt_speech_token.to(self.device), + prompt_token_len=flow_prompt_speech_token_len.to(self.device), + prompt_feat=prompt_speech_feat.to(self.device), + prompt_feat_len=prompt_speech_feat_len.to(self.device), + embedding=flow_embedding.to(self.device)) + tts_speech = self.hift.inference(mel=tts_mel).cpu() + torch.cuda.empty_cache() + return {'tts_speech': tts_speech} \ No newline at end of file diff --git a/cosyvoice/dataset/__init__.py b/cosyvoice/dataset/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/cosyvoice/dataset/dataset.py b/cosyvoice/dataset/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..6681504383f73ac9ba1609b9d48bdec7aae23f28 --- /dev/null +++ b/cosyvoice/dataset/dataset.py @@ -0,0 +1,160 @@ +# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang) +# 2024 Alibaba Inc (authors: Xiang Lyu) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import random +import json +import math +from functools import partial + +import torch +import torch.distributed as dist +from torch.utils.data import IterableDataset +from cosyvoice.utils.file_utils import read_lists, read_json_lists + + +class Processor(IterableDataset): + + def __init__(self, source, f, *args, **kw): + assert callable(f) + self.source = source + self.f = f + self.args = args + self.kw = kw + + def set_epoch(self, epoch): + self.source.set_epoch(epoch) + + def __iter__(self): + """ Return an iterator over the source dataset processed by the + given processor. 
+ """ + assert self.source is not None + assert callable(self.f) + return self.f(iter(self.source), *self.args, **self.kw) + + def apply(self, f): + assert callable(f) + return Processor(self, f, *self.args, **self.kw) + + +class DistributedSampler: + + def __init__(self, shuffle=True, partition=True): + self.epoch = -1 + self.update() + self.shuffle = shuffle + self.partition = partition + + def update(self): + assert dist.is_available() + if dist.is_initialized(): + self.rank = dist.get_rank() + self.world_size = dist.get_world_size() + else: + self.rank = 0 + self.world_size = 1 + worker_info = torch.utils.data.get_worker_info() + if worker_info is None: + self.worker_id = 0 + self.num_workers = 1 + else: + self.worker_id = worker_info.id + self.num_workers = worker_info.num_workers + return dict(rank=self.rank, + world_size=self.world_size, + worker_id=self.worker_id, + num_workers=self.num_workers) + + def set_epoch(self, epoch): + self.epoch = epoch + + def sample(self, data): + """ Sample data according to rank/world_size/num_workers + + Args: + data(List): input data list + + Returns: + List: data list after sample + """ + data = list(range(len(data))) + # force datalist even + if self.partition: + if self.shuffle: + random.Random(self.epoch).shuffle(data) + if len(data) < self.world_size: + data = data * math.ceil(self.world_size / len(data)) + data = data[:self.world_size] + data = data[self.rank::self.world_size] + if len(data) < self.num_workers: + data = data * math.ceil(self.num_workers / len(data)) + data = data[:self.num_workers] + data = data[self.worker_id::self.num_workers] + return data + + +class DataList(IterableDataset): + + def __init__(self, lists, shuffle=True, partition=True): + self.lists = lists + self.sampler = DistributedSampler(shuffle, partition) + + def set_epoch(self, epoch): + self.sampler.set_epoch(epoch) + + def __iter__(self): + sampler_info = self.sampler.update() + indexes = self.sampler.sample(self.lists) + for index in indexes: + data = dict(src=self.lists[index]) + data.update(sampler_info) + yield data + + +def Dataset(data_list_file, + data_pipeline, + mode='train', + shuffle=True, + partition=True, + tts_file='', + prompt_utt2data=''): + """ Construct dataset from arguments + + We have two shuffle stage in the Dataset. The first is global + shuffle at shards tar/raw file level. The second is global shuffle + at training samples level. 
+ + Args: + data_type(str): raw/shard + tokenizer (BaseTokenizer): tokenizer to tokenize + partition(bool): whether to do data partition in terms of rank + """ + assert mode in ['train', 'inference'] + lists = read_lists(data_list_file) + # import pdb + # pdb.set_trace() + if mode == 'inference': + with open(tts_file) as f: + tts_data = json.load(f) + utt2lists = read_json_lists(prompt_utt2data) + # filter unnecessary file in inference mode + lists = list(set([utt2lists[utt] for utt in tts_data.keys() if utt2lists[utt] in lists])) + dataset = DataList(lists,shuffle=shuffle,partition=partition) + if mode == 'inference': + # map partial arg tts_data in inference mode + data_pipeline[0] = partial(data_pipeline[0], tts_data=tts_data) + for func in data_pipeline: + dataset = Processor(dataset, func, mode=mode) + return dataset diff --git a/cosyvoice/dataset/processor.py b/cosyvoice/dataset/processor.py new file mode 100644 index 0000000000000000000000000000000000000000..8c8c743fdbe03139a9703ba635e31ab74c459c67 --- /dev/null +++ b/cosyvoice/dataset/processor.py @@ -0,0 +1,965 @@ +# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import random +import json +import tarfile +import json +import io +import pyarrow.parquet as pq +from io import BytesIO +import torch +import torchaudio +from torch.nn.utils.rnn import pad_sequence +import torch.nn.functional as F +import tarfile +import json +import io +import wave +import numpy as np +import torchaudio +import os +import sys +import json +import random +import pickle +import argparse +import itertools +import mmap +import struct +import collections + + + +import shutil +import multiprocessing as mp +from pathlib import Path + +from tqdm import tqdm +from collections import defaultdict +from copy import deepcopy +from datetime import datetime +import pickle + +from wids import wids +import math + +torchaudio.set_audio_backend('soundfile') + +AUDIO_FORMAT_SETS = set(['flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma']) + +try: + MAIN_SPK_EMBEDDING=torch.load("/workspace/audio_checkpoints/flow_model/spk_embedding/0909/mean_embedding.pt") + GPT_SPK_EMBEDDING=torch.load("/workspace/audio_checkpoints/flow_model/spk_embedding/0909/spk_mean_embeddings.pt") +except: + MAIN_SPK_EMBEDDING=torch.zeros(1,192) + GPT_SPK_EMBEDDING=torch.zeros(1,192) + +def parquet_opener(data, mode='train', tts_data={}): + """ Give url or local file, return file descriptor + Inplace operation. 
+ + Args: + data(Iterable[str]): url or local file list + + Returns: + Iterable[{src, stream}] + """ + for sample in data: + assert 'src' in sample + url = sample['src'] + try: + df = pq.read_table(url).to_pandas() + for i in range(len(df)): + if mode == 'inference' and df.loc[i, 'utt'] not in tts_data: + continue + sample.update(dict(df.loc[i])) + if mode == 'train': + # NOTE do not return sample directly, must initialize a new dict + yield {**sample} + else: + for index, text in enumerate(tts_data[df.loc[i, 'utt']]): + yield {**sample, 'tts_index': index, 'tts_text': text} + except Exception as ex: + logging.warning('Failed to open {}, ex info {}'.format(url, ex)) + + + + +def parse_tar_header(header_bytes): + header = struct.unpack("!100s8s8s8s12s12s8s1s100s6s2s32s32s8s8s155s", header_bytes) + return TarHeader(*header) + +TarHeader = collections.namedtuple( + "TarHeader", + [ + "name", + "mode", + "uid", + "gid", + "size", + "mtime", + "chksum", + "typeflag", + "linkname", + "magic", + "version", + "uname", + "gname", + "devmajor", + "devminor", + "prefix", + ], +) + +class MMTar: + def __init__(self, file_path: Path | str): + self.stream = open(file_path, "rb") + self.mmap = mmap.mmap(self.stream.fileno(), 0, access=mmap.ACCESS_READ) + + def __del__(self): + try: + self.mmap.close() + self.stream.close() + except: # noqa + pass + + def get_at_offset(self, offset) -> tuple[str, bytes]: + header = parse_tar_header(self.mmap[offset : offset + 500]) + name = header.name.decode("utf-8").strip("\x00") + start = offset + 512 + end = start + int(header.size.decode("utf-8")[:-1], 8) + return name, self.mmap[start:end] + + +class Tar: + def __init__(self, path: Path): + self.tar = MMTar(path) + indices_path = path.with_suffix(".index") + self.index = pickle.loads(indices_path.read_bytes()) + self.name_mapping = {} + for name, offset, _ in self.index: + self.name_mapping[name] = offset + + def read(self, name: str) -> bytes: + return self.tar.get_at_offset(self.name_mapping[name])[1] + +def cosy_jsonl_opener(data, mode='train', tts_data={}): + """ Give url or local file, return file descriptor + Inplace operation. + + Args: + data(Iterable[str]): url or local file list + + Returns: + Iterable[{src, stream}] + """ + for sample in data: + assert 'src' in sample + cosy_jsonl_path = sample['src'] + tar_file_path=cosy_jsonl_path.replace(".vq0907.jsonl",".tar") + try: + tar_data=Tar(Path(tar_file_path)) + with open(cosy_jsonl_path, 'r') as f: + for line in f: + item=json.loads(line) + cosy_token = item['cosy_token'] + sample['speech_token']=torch.tensor(cosy_token) + sample['speech'], sample['sample_rate']= torchaudio.load(io.BytesIO(tar_data.read(item['filename']))) + # print(item['filename']) + yield {**sample} + + except Exception as ex: + logging.warning('Failed to open {}, ex info {}'.format(cosy_jsonl_path, ex)) + + +def cosy_jsonl_opener_vq0918_nopool(data, mode='train', tts_data={}): + """ Give url or local file, return file descriptor + Inplace operation. 
+ + Args: + data(Iterable[str]): url or local file list + + Returns: + Iterable[{src, stream}] + """ + for sample in data: + assert 'src' in sample + cosy_jsonl_path = sample['src'] + tar_file_path=cosy_jsonl_path.replace(".vq0918-nopool.jsonl",".tar") + + + try: + tar_data=Tar(Path(tar_file_path)) + with open(cosy_jsonl_path, 'r') as f: + # cosy_data = [json.loads(line) for line in f] + for line in f: + item=json.loads(line) + cosy_token = item['cosy_token'] + sample['speech_token']=torch.tensor(cosy_token) + sample['speech'], sample['sample_rate']= torchaudio.load(io.BytesIO(tar_data.read(item['filename']))) + # print(item['filename']) + yield {**sample} + + except Exception as ex: + logging.warning('Failed to open {}, ex info {}'.format(cosy_jsonl_path, ex)) + + + +def cosy_jsonl_opener_vq0918_pool2(data, mode='train', tts_data={}): + """ Give url or local file, return file descriptor + Inplace operation. + + Args: + data(Iterable[str]): url or local file list + + Returns: + Iterable[{src, stream}] + """ + for sample in data: + assert 'src' in sample + cosy_jsonl_path = sample['src'] + tar_file_path=cosy_jsonl_path.replace(".vq0918-pool2.jsonl",".tar") + + try: + tar_data=Tar(Path(tar_file_path)) + with open(cosy_jsonl_path, 'r') as f: + for line in f: + item=json.loads(line) + cosy_token = item['cosy_token'] + sample['speech_token']=torch.tensor(cosy_token) + sample['speech'], sample['sample_rate']= torchaudio.load(io.BytesIO(tar_data.read(item['filename']))) + + yield {**sample} + + except Exception as ex: + logging.warning('Failed to open {}, ex info {}'.format(cosy_jsonl_path, ex)) + + +def cosy_jsonl_opener_vq0918_pool4(data, mode='train', tts_data={}): + """ Give url or local file, return file descriptor + Inplace operation. + + Args: + data(Iterable[str]): url or local file list + + Returns: + Iterable[{src, stream}] + """ + for sample in data: + assert 'src' in sample + cosy_jsonl_path = sample['src'] + tar_file_path=cosy_jsonl_path.replace(".vq0918-pool4.jsonl",".tar") + try: + tar_data=Tar(Path(tar_file_path)) + with open(cosy_jsonl_path, 'r') as f: + # cosy_data = [json.loads(line) for line in f] + for line in f: + item=json.loads(line) + cosy_token = item['cosy_token'] + sample['speech_token']=torch.tensor(cosy_token) + sample['speech'], sample['sample_rate']= torchaudio.load(io.BytesIO(tar_data.read(item['filename']))) + # print(item['filename']) + yield {**sample} + + except Exception as ex: + logging.warning('Failed to open {}, ex info {}'.format(cosy_jsonl_path, ex)) + + +def cosy_jsonl_opener_vq0918_pool8(data, mode='train', tts_data={}): + """ Give url or local file, return file descriptor + Inplace operation. 
+ + Args: + data(Iterable[str]): url or local file list + + Returns: + Iterable[{src, stream}] + """ + for sample in data: + assert 'src' in sample + cosy_jsonl_path = sample['src'] + tar_file_path=cosy_jsonl_path.replace(".vq0918-pool8.jsonl",".tar") + + try: + tar_data=Tar(Path(tar_file_path)) + with open(cosy_jsonl_path, 'r') as f: + # cosy_data = [json.loads(line) for line in f] + for line in f: + item=json.loads(line) + cosy_token = item['cosy_token'] + sample['speech_token']=torch.tensor(cosy_token) + sample['speech'], sample['sample_rate']= torchaudio.load(io.BytesIO(tar_data.read(item['filename']))) + # print(item['filename']) + yield {**sample} + + except Exception as ex: + logging.warning('Failed to open {}, ex info {}'.format(cosy_jsonl_path, ex)) + + + +def process_sft_vq0918_pool4(data, mode='train', tts_data={}): + for sample in data: + assert 'src' in sample + + token_npy_path = sample['src'] + wav_path=token_npy_path.replace(".vq0918-pool4.npy","") + + # wav_path,token_npy_path=sample['src'].split(' ') + try: + sample['speech_token']=torch.tensor(np.load(token_npy_path)) + sample['speech'], sample['sample_rate']= torchaudio.load(wav_path) + if sample['speech'].shape[0] > 1: + sample['speech'] = sample['speech'].mean(dim=0, keepdim=True) + sample['spk_embedding']=torch.zeros_like(MAIN_SPK_EMBEDDING) + yield {**sample} + except Exception as ex: + logging.warning('Failed to open {}, ex info {}'.format(wav_path, ex)) + logging.warning('Failed to open {}'.format(wav_path)) + + +def process_sft_vq0918_pool4_split(data, mode='train',split_token=25, tts_data={}): + for sample in data: + assert 'src' in sample + + token_npy_path = sample['src'] + wav_path=token_npy_path.replace(".vq0918-pool4.npy","") + + # wav_path,token_npy_path=sample['src'].split(' ') + try: + # sample['speech_token']=torch.tensor(np.load(token_npy_path)) + # sample['speech'], sample['sample_rate']= torchaudio.load(wav_path) + # if sample['speech'].shape[0] > 1: + # sample['speech'] = sample['speech'].mean(dim=0, keepdim=True) + + + # sample['spk_embedding']=torch.zeros_like(MAIN_SPK_EMBEDDING) + + + speech_token=torch.tensor(np.load(token_npy_path)) + speech,sample_rate= torchaudio.load(wav_path) + # split_speech=int(split_token / 12.5 * sample_rate) + if speech.shape[0] > 1: + speech = speech.mean(dim=0, keepdim=True) + + sample['spk_embedding']=torch.zeros_like(MAIN_SPK_EMBEDDING) + sample['sample_rate']=sample_rate + + num_splits = (speech_token.size(0) + split_token - 1) // split_token + + for split_id in range(num_splits): + end_token_idx = min((split_id + 1) * split_token, speech_token.size(0)) + end_speech_idx=int(np.ceil(end_token_idx / 12.5 * sample_rate)) + sample['speech_token']=speech_token[:end_token_idx] + sample['speech']=speech[:,:end_speech_idx] + print(sample['speech_token'].size(),sample['speech'].size()) + yield {**sample} + except Exception as ex: + logging.warning('Failed to open {}, ex info {}'.format(wav_path, ex)) + logging.warning('Failed to open {}'.format(wav_path)) + + +def process_sft_vq0918_pool2(data, mode='train', tts_data={}): + for sample in data: + assert 'src' in sample + + token_npy_path = sample['src'].replace(".vq0918-pool4.npy",".vq0918-pool2.npy") + wav_path=token_npy_path.replace(".vq0918-pool2.npy","") + + # wav_path,token_npy_path=sample['src'].split(' ') + try: + sample['speech_token']=torch.tensor(np.load(token_npy_path)) + sample['speech'], sample['sample_rate']= torchaudio.load(wav_path) + if sample['speech'].shape[0] > 1: + sample['speech'] = 
sample['speech'].mean(dim=0, keepdim=True) + + sample['spk_embedding']=torch.zeros_like(MAIN_SPK_EMBEDDING) + yield {**sample} + except Exception as ex: + logging.warning('Failed to open {}, ex info {}'.format(wav_path, ex)) + logging.warning('Failed to open {}'.format(wav_path)) + + +def process_sft_vq0918_pool2_split(data, mode='train',split_token=50, tts_data={}): + for sample in data: + assert 'src' in sample + + token_npy_path = sample['src'] + wav_path=token_npy_path.replace(".vq0918-pool2.npy","") + + # wav_path,token_npy_path=sample['src'].split(' ') + try: + # sample['speech_token']=torch.tensor(np.load(token_npy_path)) + # sample['speech'], sample['sample_rate']= torchaudio.load(wav_path) + # if sample['speech'].shape[0] > 1: + # sample['speech'] = sample['speech'].mean(dim=0, keepdim=True) + + + # sample['spk_embedding']=torch.zeros_like(MAIN_SPK_EMBEDDING) + + + speech_token=torch.tensor(np.load(token_npy_path)) + speech,sample_rate= torchaudio.load(wav_path) + # split_speech=int(split_token / 12.5 * sample_rate) + if speech.shape[0] > 1: + speech = speech.mean(dim=0, keepdim=True) + + sample['spk_embedding']=torch.zeros_like(MAIN_SPK_EMBEDDING) + sample['sample_rate']=sample_rate + + num_splits = (speech_token.size(0) + split_token - 1) // split_token + + for split_id in range(num_splits): + end_token_idx = min((split_id + 1) * split_token, speech_token.size(0)) + end_speech_idx=int(np.ceil(end_token_idx / 25 * sample_rate)) + sample['speech_token']=speech_token[:end_token_idx] + sample['speech']=speech[:,:end_speech_idx] + print(sample['speech_token'].size(),sample['speech'].size()) + yield {**sample} + except Exception as ex: + logging.warning('Failed to open {}, ex info {}'.format(wav_path, ex)) + logging.warning('Failed to open {}'.format(wav_path)) + +def process_sft_vq0918_pool4_gpt(data, mode='train', tts_data={}): + for sample in data: + assert 'src' in sample + try: + entry=json.loads(sample['src']) + sample['spk_embedding']=torch.zeros_like(MAIN_SPK_EMBEDDING) + + for conv in entry["conversations"]: + if "response_wav" in conv: + wav_path = f"/workspace/audio_data/sft/{conv['response_wav']}" + token_npy_path=wav_path.replace(".wav",".wav.vq0918-pool4.npy") + sample['speech_token']=torch.tensor(np.load(token_npy_path)) + sample['speech'], sample['sample_rate']= torchaudio.load(wav_path) + if sample['speech'].shape[0] > 1: + sample['speech'] = sample['speech'].mean(dim=0, keepdim=True) + sample['spk_embedding']=spk_embedding + yield {**sample} + except Exception as ex: + # logging.warning('Failed to open {}, ex info {}'.format(wav_path, ex)) + logging.warning('Failed to open {}'.format(wav_path)) + + +def process_sft_vq0918_pool4_gpt_1010(data, mode='train', tts_data={}): + for sample in data: + assert 'src' in sample + try: + entry=json.loads(sample['src']) + sample['spk_embedding']=torch.zeros_like(MAIN_SPK_EMBEDDING) + + for conv in entry["conversations"]: + if "response_wav" in conv: + wav_path = f"/workspace/audio_data/sft/{conv['response_wav']}" + token_npy_path=wav_path.replace(".wav",".wav.vq0918-pool4.npy") + sample['speech_token']=torch.tensor(np.load(token_npy_path)) + sample['speech'], sample['sample_rate']= torchaudio.load(wav_path) + if sample['speech'].shape[0] > 1: + sample['speech'] = sample['speech'].mean(dim=0, keepdim=True) + sample['spk_embedding']=spk_embedding + yield {**sample} + if "prompt_wav" in conv: + wav_path = f"/workspace/audio_data/sft/{conv['response_wav']}" + token_npy_path=wav_path.replace(".wav",".wav.vq0918-pool4.npy") + 
sample['speech_token']=torch.tensor(np.load(token_npy_path)) + sample['speech'], sample['sample_rate']= torchaudio.load(wav_path) + if sample['speech'].shape[0] > 1: + sample['speech'] = sample['speech'].mean(dim=0, keepdim=True) + sample['spk_embedding']=spk_embedding + yield {**sample} + except Exception as ex: + # logging.warning('Failed to open {}, ex info {}'.format(wav_path, ex)) + logging.warning('Failed to open {}'.format(wav_path)) + + +def filter(data, + max_length=10240, + min_length=10, + token_max_length=200, + token_min_length=1, + min_output_input_ratio=0.0005, + max_output_input_ratio=1, + mode='train'): + """ Filter sample according to feature and label length + Inplace operation. + + Args:: + data: Iterable[{key, wav, label, sample_rate}] + max_length: drop utterance which is greater than max_length(10ms) + min_length: drop utterance which is less than min_length(10ms) + token_max_length: drop utterance which is greater than + token_max_length, especially when use char unit for + english modeling + token_min_length: drop utterance which is + less than token_max_length + min_output_input_ratio: minimal ration of + token_length / feats_length(10ms) + max_output_input_ratio: maximum ration of + token_length / feats_length(10ms) + + Returns: + Iterable[{key, wav, label, sample_rate}] + """ + for sample in data: + # sample['speech'], sample['sample_rate'] = torchaudio.load(BytesIO(sample['audio_data'])) + # del sample['audio_data'] + # sample['wav'] is torch.Tensor, we have 100 frames every second + num_frames = sample['speech'].size(1) / sample['sample_rate'] * 100 + if num_frames < min_length: + continue + if num_frames > max_length: + continue + if len(sample['text_token']) < token_min_length: + continue + if len(sample['text_token']) > token_max_length: + continue + if len(sample['speech_token']) == 0: + continue + if num_frames != 0: + if len(sample['text_token']) / num_frames < min_output_input_ratio: + continue + if len(sample['text_token']) / num_frames > max_output_input_ratio: + continue + yield sample + + +def filter_speech_token(data, + max_length=10240, + min_length=10, + token_max_length=5000, + token_min_length=1, + min_output_input_ratio=0.0005, + max_output_input_ratio=30, + mode='train'): + """ Filter sample according to feature and label length + Inplace operation. 
+ + Args:: + data: Iterable[{key, wav, label, sample_rate}] + max_length: drop utterance which is greater than max_length(10ms) + min_length: drop utterance which is less than min_length(10ms) + token_max_length: drop utterance which is greater than + token_max_length, especially when use char unit for + english modeling + token_min_length: drop utterance which is + less than token_max_length + min_output_input_ratio: minimal ration of + token_length / feats_length(10ms) + max_output_input_ratio: maximum ration of + token_length / feats_length(10ms) + + Returns: + Iterable[{key, wav, label, sample_rate}] + """ + for sample in data: + # sample['speech'], sample['sample_rate'] = torchaudio.load(BytesIO(sample['audio_data'])) + # del sample['audio_data'] + # sample['wav'] is torch.Tensor, we have 100 frames every second + num_frames = sample['speech'].size(1) / sample['sample_rate'] * 100 + if num_frames < min_length: + continue + if num_frames > max_length: + continue + if len(sample['speech_token']) < token_min_length: + continue + if len(sample['speech_token']) > token_max_length: + continue + if len(sample['speech_token']) == 0: + continue + if num_frames != 0: + if len(sample['speech_token']) / num_frames < min_output_input_ratio: + continue + if len(sample['speech_token']) / num_frames > max_output_input_ratio: + continue + yield sample + + +def resample(data, resample_rate=22050, min_sample_rate=16000, mode='train'): + """ Resample data. + Inplace operation. + + Args: + data: Iterable[{key, wav, label, sample_rate}] + resample_rate: target resample rate + + Returns: + Iterable[{key, wav, label, sample_rate}] + """ + for sample in data: + assert 'sample_rate' in sample + assert 'speech' in sample + sample_rate = sample['sample_rate'] + waveform = sample['speech'] + if sample_rate != resample_rate: + if sample_rate < min_sample_rate: + continue + sample['sample_rate'] = resample_rate + sample['speech'] = torchaudio.transforms.Resample( + orig_freq=sample_rate, new_freq=resample_rate)(waveform) + max_val = sample['speech'].abs().max() + if max_val > 1: + sample['speech'] /= max_val + yield sample + + +def compute_fbank(data, + feat_extractor, + mode='train'): + """ Extract fbank + + Args: + data: Iterable[{key, wav, label, sample_rate}] + + Returns: + Iterable[{key, feat, label}] + """ + for sample in data: + assert 'sample_rate' in sample + assert 'speech' in sample + # assert 'utt' in sample + # assert 'text_token' in sample + waveform = sample['speech'] + mat = feat_extractor(waveform).squeeze(dim=0).transpose(0, 1) + sample['speech_feat'] = mat + del sample['speech'] + yield sample + + +def parse_embedding(data, normalize, mode='train'): + """ Parse utt_embedding/spk_embedding + + Args: + data: Iterable[{key, wav, label, sample_rate}] + + Returns: + Iterable[{key, feat, label}] + """ + for sample in data: + sample['utt_embedding'] = torch.tensor(sample['utt_embedding'], dtype=torch.float32) + sample['spk_embedding'] = torch.tensor(sample['spk_embedding'], dtype=torch.float32) + if normalize: + sample['utt_embedding'] = F.normalize(sample['utt_embedding'], dim=0) + sample['spk_embedding'] = F.normalize(sample['spk_embedding'], dim=0) + yield sample + + +def tokenize(data, get_tokenizer, allowed_special, mode='train'): + """ Decode text to chars or BPE + Inplace operation + + Args: + data: Iterable[{key, wav, txt, sample_rate}] + + Returns: + Iterable[{key, wav, txt, tokens, label, sample_rate}] + """ + tokenizer = get_tokenizer() + for sample in data: + assert 'text' in sample + 
sample['text_token'] = tokenizer.encode(sample['text'], allowed_special=allowed_special) + if mode == 'inference': + sample['tts_text_token'] = tokenizer.encode(sample['tts_text'], allowed_special=allowed_special) + yield sample + + +def shuffle(data, shuffle_size=10000, mode='train'): + """ Local shuffle the data + + Args: + data: Iterable[{key, feat, label}] + shuffle_size: buffer size for shuffle + + Returns: + Iterable[{key, feat, label}] + """ + buf = [] + for sample in data: + buf.append(sample) + if len(buf) >= shuffle_size: + random.shuffle(buf) + for x in buf: + yield x + buf = [] + # The sample left over + random.shuffle(buf) + for x in buf: + yield x + + +def sort(data, sort_size=500, mode='train'): + """ Sort the data by feature length. + Sort is used after shuffle and before batch, so we can group + utts with similar lengths into a batch, and `sort_size` should + be less than `shuffle_size` + + Args: + data: Iterable[{key, feat, label}] + sort_size: buffer size for sort + + Returns: + Iterable[{key, feat, label}] + """ + + buf = [] + for sample in data: + buf.append(sample) + if len(buf) >= sort_size: + buf.sort(key=lambda x: x['speech_feat'].size(0)) + for x in buf: + yield x + buf = [] + # The sample left over + buf.sort(key=lambda x: x['speech_feat'].size(0)) + for x in buf: + yield x + + +def static_batch(data, batch_size=16): + """ Static batch the data by `batch_size` + + Args: + data: Iterable[{key, feat, label}] + batch_size: batch size + + Returns: + Iterable[List[{key, feat, label}]] + """ + buf = [] + for sample in data: + buf.append(sample) + if len(buf) >= batch_size: + yield buf + buf = [] + if len(buf) > 0: + yield buf + + +def dynamic_batch(data, max_frames_in_batch=12000, mode='train'): + """ Dynamic batch the data until the total frames in batch + reach `max_frames_in_batch` + + Args: + data: Iterable[{key, feat, label}] + max_frames_in_batch: max_frames in one batch + + Returns: + Iterable[List[{key, feat, label}]] + """ + buf = [] + longest_frames = 0 + for sample in data: + assert 'speech_feat' in sample + assert isinstance(sample['speech_feat'], torch.Tensor) + new_sample_frames = sample['speech_feat'].size(0) + longest_frames = max(longest_frames, new_sample_frames) + frames_after_padding = longest_frames * (len(buf) + 1) + if frames_after_padding > max_frames_in_batch: + yield buf + buf = [sample] + longest_frames = new_sample_frames + else: + buf.append(sample) + if len(buf) > 0: + yield buf + + +def batch(data, batch_type='static', batch_size=16, max_frames_in_batch=12000, mode='train'): + """ Wrapper for static/dynamic batch + """ + if mode == 'inference': + return static_batch(data, 1) + else: + if batch_type == 'static': + return static_batch(data, batch_size) + elif batch_type == 'dynamic': + return dynamic_batch(data, max_frames_in_batch) + else: + logging.fatal('Unsupported batch type {}'.format(batch_type)) + + +def padding(data, use_spk_embedding, mode='train'): + """ Padding the data into training data + + Args: + data: Iterable[List[{key, feat, label}]] + + Returns: + Iterable[Tuple(keys, feats, labels, feats lengths, label lengths)] + """ + for sample in data: + assert isinstance(sample, list) + speech_feat_len = torch.tensor([x['speech_feat'].size(1) for x in sample], + dtype=torch.int32) + order = torch.argsort(speech_feat_len, descending=True) + + utts = [sample[i]['utt'] for i in order] + speech_token = [torch.tensor(sample[i]['speech_token']) for i in order] + speech_token_len = torch.tensor([i.size(0) for i in speech_token], 
dtype=torch.int32) + speech_token = pad_sequence(speech_token, + batch_first=True, + padding_value=0) + speech_feat = [sample[i]['speech_feat'] for i in order] + speech_feat_len = torch.tensor([i.size(0) for i in speech_feat], dtype=torch.int32) + speech_feat = pad_sequence(speech_feat, + batch_first=True, + padding_value=0) + text = [sample[i]['text'] for i in order] + text_token = [torch.tensor(sample[i]['text_token']) for i in order] + text_token_len = torch.tensor([i.size(0) for i in text_token], dtype=torch.int32) + text_token = pad_sequence(text_token, batch_first=True, padding_value=0) + utt_embedding = torch.stack([sample[i]['utt_embedding'] for i in order], dim=0) + spk_embedding = torch.stack([sample[i]['spk_embedding'] for i in order], dim=0) + batch = { + "utts": utts, + "speech_token": speech_token, + "speech_token_len": speech_token_len, + "speech_feat": speech_feat, + "speech_feat_len": speech_feat_len, + "text": text, + "text_token": text_token, + "text_token_len": text_token_len, + "utt_embedding": utt_embedding, + "spk_embedding": spk_embedding, + } + if mode == 'inference': + tts_text = [sample[i]['tts_text'] for i in order] + tts_index = [sample[i]['tts_index'] for i in order] + tts_text_token = [torch.tensor(sample[i]['tts_text_token']) for i in order] + tts_text_token_len = torch.tensor([i.size(0) for i in tts_text_token], dtype=torch.int32) + tts_text_token = pad_sequence(tts_text_token, batch_first=True, padding_value=-1) + batch.update({'tts_text': tts_text, + 'tts_index': tts_index, + 'tts_text_token': tts_text_token, + 'tts_text_token_len': tts_text_token_len}) + if use_spk_embedding is True: + batch["embedding"] = batch["spk_embedding"] + else: + batch["embedding"] = batch["utt_embedding"] + yield batch + + + +def padding_speech_token(data, use_spk_embedding, mode='train'): + """ Padding the data into training data + + Args: + data: Iterable[List[{key, feat, label}]] + + Returns: + Iterable[Tuple(keys, feats, labels, feats lengths, label lengths)] + """ + for sample in data: + assert isinstance(sample, list) + speech_feat_len = torch.tensor([x['speech_feat'].size(1) for x in sample], + dtype=torch.int32) + order = torch.argsort(speech_feat_len, descending=True) + + # utts = [sample[i]['utt'] for i in order] + # speech_token = [torch.tensor(sample[i]['speech_token']) for i in order] + try: + speech_token = [sample[i]['speech_token'].clone().detach() for i in order] + speech_token_len = torch.tensor([i.size(0) for i in speech_token], dtype=torch.int32) + speech_token = pad_sequence(speech_token, + batch_first=True, + padding_value=0) + speech_feat = [sample[i]['speech_feat'] for i in order] + speech_feat_len = torch.tensor([i.size(0) for i in speech_feat], dtype=torch.int32) + speech_feat = pad_sequence(speech_feat, + batch_first=True, + padding_value=0) + batch = { + "speech_token": speech_token, + "speech_token_len": speech_token_len, + "speech_feat": speech_feat, + "speech_feat_len": speech_feat_len, + } + if mode == 'inference': + tts_text = [sample[i]['tts_text'] for i in order] + tts_index = [sample[i]['tts_index'] for i in order] + tts_text_token = [torch.tensor(sample[i]['tts_text_token']) for i in order] + tts_text_token_len = torch.tensor([i.size(0) for i in tts_text_token], dtype=torch.int32) + tts_text_token = pad_sequence(tts_text_token, batch_first=True, padding_value=-1) + batch.update({'tts_text': tts_text, + 'tts_index': tts_index, + 'tts_text_token': tts_text_token, + 'tts_text_token_len': tts_text_token_len}) + # if use_spk_embedding is True: + 
# batch["embedding"] = batch["spk_embedding"] + # else: + # batch["embedding"] = batch["utt_embedding"] + batch["embedding"]=torch.zeros((batch["speech_feat"].size(0),192),device=batch["speech_feat"].device) + yield batch + except Exception as ex: + logging.warning(' ex info {}'.format(ex)) + # assert False + + + +def padding_speech_token_spk(data, use_spk_embedding, mode='train'): + """ Padding the data into training data + + Args: + data: Iterable[List[{key, feat, label}]] + + Returns: + Iterable[Tuple(keys, feats, labels, feats lengths, label lengths)] + """ + for sample in data: + assert isinstance(sample, list) + speech_feat_len = torch.tensor([x['speech_feat'].size(1) for x in sample], + dtype=torch.int32) + order = torch.argsort(speech_feat_len, descending=True) + + # utts = [sample[i]['utt'] for i in order] + # speech_token = [torch.tensor(sample[i]['speech_token']) for i in order] + try: + speech_token = [sample[i]['speech_token'].clone().detach() for i in order] + speech_token_len = torch.tensor([i.size(0) for i in speech_token], dtype=torch.int32) + speech_token = pad_sequence(speech_token, + batch_first=True, + padding_value=0) + speech_feat = [sample[i]['speech_feat'] for i in order] + speech_feat_len = torch.tensor([i.size(0) for i in speech_feat], dtype=torch.int32) + speech_feat = pad_sequence(speech_feat, + batch_first=True, + padding_value=0) + spk_embedding = torch.stack([sample[i]['spk_embedding'] for i in order], dim=0) + batch = { + "speech_token": speech_token, + "speech_token_len": speech_token_len, + "speech_feat": speech_feat, + "speech_feat_len": speech_feat_len, + "spk_embedding": spk_embedding, + } + if mode == 'inference': + tts_text = [sample[i]['tts_text'] for i in order] + tts_index = [sample[i]['tts_index'] for i in order] + tts_text_token = [torch.tensor(sample[i]['tts_text_token']) for i in order] + tts_text_token_len = torch.tensor([i.size(0) for i in tts_text_token], dtype=torch.int32) + tts_text_token = pad_sequence(tts_text_token, batch_first=True, padding_value=-1) + batch.update({'tts_text': tts_text, + 'tts_index': tts_index, + 'tts_text_token': tts_text_token, + 'tts_text_token_len': tts_text_token_len}) + # if use_spk_embedding is True: + # batch["embedding"] = batch["spk_embedding"] + # else: + # batch["embedding"] = batch["utt_embedding"] + # batch["embedding"]=torch.zeros((batch["speech_feat"].size(0),192),device=batch["speech_feat"].device) + batch["embedding"] = batch["spk_embedding"] + yield batch + except Exception as ex: + logging.warning(' ex info {}'.format(ex)) + # assert False \ No newline at end of file diff --git a/cosyvoice/flow/__pycache__/decoder.cpython-310.pyc b/cosyvoice/flow/__pycache__/decoder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c284f92847ddd57aee980cfc0da03b1cc68f542a Binary files /dev/null and b/cosyvoice/flow/__pycache__/decoder.cpython-310.pyc differ diff --git a/cosyvoice/flow/__pycache__/flow.cpython-310.pyc b/cosyvoice/flow/__pycache__/flow.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b0d2bce3b25e6c5bc3a069368a8f025c42288bb4 Binary files /dev/null and b/cosyvoice/flow/__pycache__/flow.cpython-310.pyc differ diff --git a/cosyvoice/flow/__pycache__/flow_matching.cpython-310.pyc b/cosyvoice/flow/__pycache__/flow_matching.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cee8d4e4d05a3e5bd6874b39b7f13ceb95d6d19a Binary files /dev/null and b/cosyvoice/flow/__pycache__/flow_matching.cpython-310.pyc 
differ diff --git a/cosyvoice/flow/__pycache__/length_regulator.cpython-310.pyc b/cosyvoice/flow/__pycache__/length_regulator.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3f1b3d28b835822febd184631100d9f12dbd880f Binary files /dev/null and b/cosyvoice/flow/__pycache__/length_regulator.cpython-310.pyc differ diff --git a/cosyvoice/flow/decoder.py b/cosyvoice/flow/decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..43492799390b44a2843bc53604603842754799f9 --- /dev/null +++ b/cosyvoice/flow/decoder.py @@ -0,0 +1,222 @@ +# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +import torch.nn as nn +from einops import pack, rearrange, repeat +from matcha.models.components.decoder import SinusoidalPosEmb, Block1D, ResnetBlock1D, Downsample1D, TimestepEmbedding, Upsample1D +from matcha.models.components.transformer import BasicTransformerBlock + + +class ConditionalDecoder(nn.Module): + def __init__( + self, + in_channels, + out_channels, + channels=(256, 256), + dropout=0.05, + attention_head_dim=64, + n_blocks=1, + num_mid_blocks=2, + num_heads=4, + act_fn="snake", + ): + """ + This decoder requires an input with the same shape of the target. So, if your text content + is shorter or longer than the outputs, please re-sampling it before feeding to the decoder. 
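        In this repo that resampling is performed by the length regulator
        (cosyvoice/flow/length_regulator.py), which stretches the encoder output
        to the target mel length before this decoder is called.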
+ """ + super().__init__() + channels = tuple(channels) + self.in_channels = in_channels + self.out_channels = out_channels + + self.time_embeddings = SinusoidalPosEmb(in_channels) + time_embed_dim = channels[0] * 4 + self.time_mlp = TimestepEmbedding( + in_channels=in_channels, + time_embed_dim=time_embed_dim, + act_fn="silu", + ) + self.down_blocks = nn.ModuleList([]) + self.mid_blocks = nn.ModuleList([]) + self.up_blocks = nn.ModuleList([]) + + output_channel = in_channels + for i in range(len(channels)): # pylint: disable=consider-using-enumerate + input_channel = output_channel + output_channel = channels[i] + is_last = i == len(channels) - 1 + resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) + transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + dim=output_channel, + num_attention_heads=num_heads, + attention_head_dim=attention_head_dim, + dropout=dropout, + activation_fn=act_fn, + ) + for _ in range(n_blocks) + ] + ) + downsample = ( + Downsample1D(output_channel) if not is_last else nn.Conv1d(output_channel, output_channel, 3, padding=1) + ) + self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample])) + + for i in range(num_mid_blocks): + input_channel = channels[-1] + out_channels = channels[-1] + resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) + + transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + dim=output_channel, + num_attention_heads=num_heads, + attention_head_dim=attention_head_dim, + dropout=dropout, + activation_fn=act_fn, + ) + for _ in range(n_blocks) + ] + ) + + self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks])) + + channels = channels[::-1] + (channels[0],) + for i in range(len(channels) - 1): + input_channel = channels[i] * 2 + output_channel = channels[i + 1] + is_last = i == len(channels) - 2 + resnet = ResnetBlock1D( + dim=input_channel, + dim_out=output_channel, + time_emb_dim=time_embed_dim, + ) + transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + dim=output_channel, + num_attention_heads=num_heads, + attention_head_dim=attention_head_dim, + dropout=dropout, + activation_fn=act_fn, + ) + for _ in range(n_blocks) + ] + ) + upsample = ( + Upsample1D(output_channel, use_conv_transpose=True) + if not is_last + else nn.Conv1d(output_channel, output_channel, 3, padding=1) + ) + self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample])) + self.final_block = Block1D(channels[-1], channels[-1]) + self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1) + self.initialize_weights() + + + def initialize_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv1d): + nn.init.kaiming_normal_(m.weight, nonlinearity="relu") + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.GroupNorm): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Linear): + nn.init.kaiming_normal_(m.weight, nonlinearity="relu") + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def forward(self, x, mask, mu, t, spks=None, cond=None): + """Forward pass of the UNet1DConditional model. + + Args: + x (torch.Tensor): shape (batch_size, in_channels, time) + mask (_type_): shape (batch_size, 1, time) + t (_type_): shape (batch_size) + spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None. + cond (_type_, optional): placeholder for future use. Defaults to None. 
+ + Raises: + ValueError: _description_ + ValueError: _description_ + + Returns: + _type_: _description_ + """ + + t = self.time_embeddings(t) + t = self.time_mlp(t) + + x = pack([x, mu], "b * t")[0] + + if spks is not None: + spks = repeat(spks, "b c -> b c t", t=x.shape[-1]) + x = pack([x, spks], "b * t")[0] + if cond is not None: + x = pack([x, cond], "b * t")[0] + + hiddens = [] + masks = [mask] + for resnet, transformer_blocks, downsample in self.down_blocks: + mask_down = masks[-1] + x = resnet(x, mask_down, t) + x = rearrange(x, "b c t -> b t c").contiguous() + attn_mask = torch.matmul(mask_down.transpose(1, 2).contiguous(), mask_down) + for transformer_block in transformer_blocks: + x = transformer_block( + hidden_states=x, + attention_mask=attn_mask, + timestep=t, + ) + x = rearrange(x, "b t c -> b c t").contiguous() + hiddens.append(x) # Save hidden states for skip connections + x = downsample(x * mask_down) + masks.append(mask_down[:, :, ::2]) + masks = masks[:-1] + mask_mid = masks[-1] + + for resnet, transformer_blocks in self.mid_blocks: + x = resnet(x, mask_mid, t) + x = rearrange(x, "b c t -> b t c").contiguous() + attn_mask = torch.matmul(mask_mid.transpose(1, 2).contiguous(), mask_mid) + for transformer_block in transformer_blocks: + x = transformer_block( + hidden_states=x, + attention_mask=attn_mask, + timestep=t, + ) + x = rearrange(x, "b t c -> b c t").contiguous() + + for resnet, transformer_blocks, upsample in self.up_blocks: + mask_up = masks.pop() + skip = hiddens.pop() + x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0] + x = resnet(x, mask_up, t) + x = rearrange(x, "b c t -> b t c").contiguous() + attn_mask = torch.matmul(mask_up.transpose(1, 2).contiguous(), mask_up) + for transformer_block in transformer_blocks: + x = transformer_block( + hidden_states=x, + attention_mask=attn_mask, + timestep=t, + ) + x = rearrange(x, "b t c -> b c t").contiguous() + x = upsample(x * mask_up) + x = self.final_block(x, mask_up) + output = self.final_proj(x * mask_up) + return output * mask diff --git a/cosyvoice/flow/flow.py b/cosyvoice/flow/flow.py new file mode 100644 index 0000000000000000000000000000000000000000..415b2a98872c29f82a9a49b89fba7996c10c042d --- /dev/null +++ b/cosyvoice/flow/flow.py @@ -0,0 +1,144 @@ +# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
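# Illustrative sketch of the token-to-mel length mapping used by
# MaskedDiffWithXvec.inference below: speech tokens arrive at input_frame_rate
# (50 Hz) and mels are computed at 22.05 kHz with a 256-sample hop
# (~86.13 frames/s). The token count here is a toy value.
import torch

token_len = torch.tensor([100])                    # 2 s of 50 Hz speech tokens
feat_len = (token_len / 50 * 22050 / 256).int()    # -> tensor([172], dtype=torch.int32)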
+import logging +import random +from typing import Dict, Optional +import torch +import torch.nn as nn +from torch.nn import functional as F +from omegaconf import DictConfig +from cosyvoice.utils.mask import make_pad_mask + + +class MaskedDiffWithXvec(torch.nn.Module): + def __init__(self, + input_size: int = 512, + output_size: int = 80, + spk_embed_dim: int = 192, + output_type: str = "mel", + vocab_size: int = 4096, + input_frame_rate: int = 50, + only_mask_loss: bool = True, + encoder: torch.nn.Module = None, + length_regulator: torch.nn.Module = None, + decoder: torch.nn.Module = None, + decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1, 'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine', 'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}), 'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64, 'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}}, + mel_feat_conf: Dict = {'n_fft': 1024, 'num_mels': 80, 'sampling_rate': 22050, 'hop_size': 256, 'win_size': 1024, 'fmin': 0, 'fmax': 8000}): + super().__init__() + self.input_size = input_size + self.output_size = output_size + self.decoder_conf = decoder_conf + self.mel_feat_conf = mel_feat_conf + self.vocab_size = vocab_size + self.output_type = output_type + self.input_frame_rate = input_frame_rate + logging.info(f"input frame rate={self.input_frame_rate}") + self.input_embedding = nn.Embedding(vocab_size, input_size) + self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size) + self.encoder = encoder + self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size) + self.decoder = decoder + self.length_regulator = length_regulator + self.only_mask_loss = only_mask_loss + + def forward( + self, + batch: dict, + device: torch.device, + ) -> Dict[str, Optional[torch.Tensor]]: + token = batch['speech_token'].to(device) + token_len = batch['speech_token_len'].to(device) + feat = batch['speech_feat'].to(device) + feat_len = batch['speech_feat_len'].to(device) + embedding = batch['embedding'].to(device) + + # xvec projection + embedding = F.normalize(embedding, dim=1) + embedding = self.spk_embed_affine_layer(embedding) + # embedding=None + + # concat text and prompt_text + mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(device) + # print(token.max(),self.input_embedding) + token = self.input_embedding(torch.clamp(token, min=0)) * mask + + + # text encode + h, h_lengths = self.encoder(token, token_len) + h = self.encoder_proj(h) + h, h_lengths = self.length_regulator(h, feat_len) + + # get conditions + conds = torch.zeros(feat.shape, device=token.device) + for i, j in enumerate(feat_len): + if random.random() < 0.5: + continue + index = random.randint(0, int(0.8 * j)) + conds[i, :index] = feat[i, :index] + conds = conds.transpose(1, 2) + + mask = (~make_pad_mask(feat_len)).to(h) + feat = F.interpolate(feat.unsqueeze(dim=1), size=h.shape[1:], mode="nearest").squeeze(dim=1) + loss, _ = self.decoder.compute_loss( + feat.transpose(1, 2).contiguous(), + mask.unsqueeze(1), + h.transpose(1, 2).contiguous(), + embedding, + cond=conds + ) + return {'loss': loss} + + @torch.inference_mode() + def inference(self, + token, + token_len, + prompt_token, + prompt_token_len, + prompt_feat, + prompt_feat_len, + embedding): + assert token.shape[0] == 1 + # xvec projection + embedding = F.normalize(embedding, dim=1) + embedding = self.spk_embed_affine_layer(embedding) 
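        # After the affine projection, embedding has shape (batch, output_size)
        # and is later passed to the CFM decoder as its speaker condition (`spks`).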
+ + # concat text and prompt_text + token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len + mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(embedding) + token = self.input_embedding(torch.clamp(token, min=0)) * mask + + # text encode + h, h_lengths = self.encoder(token, token_len) + h = self.encoder_proj(h) + feat_len = (token_len / self.input_frame_rate * 22050 / 256).int() + h, h_lengths = self.length_regulator(h, feat_len) + + # get conditions + conds = torch.zeros([1, feat_len.max().item(), self.output_size], device=token.device) + if prompt_feat.shape[1] != 0: + for i, j in enumerate(prompt_feat_len): + conds[i, :j] = prompt_feat[i] + conds = conds.transpose(1, 2) + + mask = (~make_pad_mask(feat_len)).to(h) + feat = self.decoder( + mu=h.transpose(1, 2).contiguous(), + mask=mask.unsqueeze(1), + spks=embedding, + cond=conds, + n_timesteps=10 + ) + if prompt_feat.shape[1] != 0: + feat = feat[:, :, prompt_feat.shape[1]:] + return feat diff --git a/cosyvoice/flow/flow_gradtts.py b/cosyvoice/flow/flow_gradtts.py new file mode 100644 index 0000000000000000000000000000000000000000..8e558c0ff65a0c6befd1c5aa49c20464307e82b1 --- /dev/null +++ b/cosyvoice/flow/flow_gradtts.py @@ -0,0 +1,142 @@ +# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
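# For contrast with the forward() below: this is the prompt-masking loop as it
# appears in cosyvoice/flow/flow.py, where roughly half of the training samples
# keep a random prefix of the target mel as the condition. In this GradTTS-style
# variant the corresponding loop is commented out, so `conds` stays all zeros.
# Shapes here are toy values.
import random
import torch

feat = torch.randn(2, 300, 80)            # (batch, mel_frames, n_mels)
feat_len = torch.tensor([300, 250])
conds = torch.zeros_like(feat)
for i, j in enumerate(feat_len):
    if random.random() < 0.5:
        continue
    index = random.randint(0, int(0.8 * j))
    conds[i, :index] = feat[i, :index]
conds = conds.transpose(1, 2)             # (batch, n_mels, mel_frames) for the decoder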
+import logging +import random +from typing import Dict, Optional +import torch +import torch.nn as nn +from torch.nn import functional as F +from omegaconf import DictConfig +from cosyvoice.utils.mask import make_pad_mask + + +class MaskedDiffWithXvec(torch.nn.Module): + def __init__(self, + input_size: int = 512, + output_size: int = 80, + spk_embed_dim: int = 192, + output_type: str = "mel", + vocab_size: int = 4096, + input_frame_rate: int = 50, + only_mask_loss: bool = True, + encoder: torch.nn.Module = None, + length_regulator: torch.nn.Module = None, + decoder: torch.nn.Module = None, + decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1, 'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine', 'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}), 'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64, 'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}}, + mel_feat_conf: Dict = {'n_fft': 1024, 'num_mels': 80, 'sampling_rate': 22050, 'hop_size': 256, 'win_size': 1024, 'fmin': 0, 'fmax': 8000}): + super().__init__() + self.input_size = input_size + self.output_size = output_size + self.decoder_conf = decoder_conf + self.mel_feat_conf = mel_feat_conf + self.vocab_size = vocab_size + self.output_type = output_type + self.input_frame_rate = input_frame_rate + logging.info(f"input frame rate={self.input_frame_rate}") + self.input_embedding = nn.Embedding(vocab_size, input_size) + self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size) + self.encoder = encoder + self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size) + self.decoder = decoder + self.length_regulator = length_regulator + self.only_mask_loss = only_mask_loss + + def forward( + self, + batch: dict, + device: torch.device, + ) -> Dict[str, Optional[torch.Tensor]]: + token = batch['speech_token'].to(device) + token_len = batch['speech_token_len'].to(device) + feat = batch['speech_feat'].to(device) + feat_len = batch['speech_feat_len'].to(device) + embedding = batch['embedding'].to(device) + + # xvec projection + embedding = F.normalize(embedding, dim=1) + embedding = self.spk_embed_affine_layer(embedding) + # embedding=None + + # concat text and prompt_text + mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(device) + token = self.input_embedding(torch.clamp(token, min=0)) * mask + + # text encode + h, h_lengths = self.encoder(token, token_len) + h = self.encoder_proj(h) + h, h_lengths = self.length_regulator(h, feat_len) + + # get conditions + conds = torch.zeros(feat.shape, device=token.device) + # for i, j in enumerate(feat_len): + # if random.random() < 0.5: + # continue + # index = random.randint(0, int(0.3 * j)) + # conds[i, :index] = feat[i, :index] + conds = conds.transpose(1, 2) + + mask = (~make_pad_mask(feat_len)).to(h) + feat = F.interpolate(feat.unsqueeze(dim=1), size=h.shape[1:], mode="nearest").squeeze(dim=1) + loss, _ = self.decoder.compute_loss( + feat.transpose(1, 2).contiguous(), + mask.unsqueeze(1), + h.transpose(1, 2).contiguous(), + embedding, + cond=conds + ) + return {'loss': loss} + + @torch.inference_mode() + def inference(self, + token, + token_len, + prompt_token, + prompt_token_len, + prompt_feat, + prompt_feat_len, + embedding): + assert token.shape[0] == 1 + # xvec projection + embedding = F.normalize(embedding, dim=1) + embedding = self.spk_embed_affine_layer(embedding) + + # concat text and prompt_text + 
token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len + mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(embedding) + token = self.input_embedding(torch.clamp(token, min=0)) * mask + + # text encode + h, h_lengths = self.encoder(token, token_len) + h = self.encoder_proj(h) + feat_len = (token_len / self.input_frame_rate * 22050 / 256).int() + h, h_lengths = self.length_regulator(h, feat_len) + + # get conditions + conds = torch.zeros([1, feat_len.max().item(), self.output_size], device=token.device) + if prompt_feat.shape[1] != 0: + for i, j in enumerate(prompt_feat_len): + conds[i, :j] = prompt_feat[i] + conds = conds.transpose(1, 2) + + mask = (~make_pad_mask(feat_len)).to(h) + feat = self.decoder( + mu=h.transpose(1, 2).contiguous(), + mask=mask.unsqueeze(1), + spks=embedding, + cond=conds, + n_timesteps=10 + ) + if prompt_feat.shape[1] != 0: + feat = feat[:, :, prompt_feat.shape[1]:] + return feat diff --git a/cosyvoice/flow/flow_matching.py b/cosyvoice/flow/flow_matching.py new file mode 100644 index 0000000000000000000000000000000000000000..ec487d7d8effbf9c7284624b839184e43df40b9c --- /dev/null +++ b/cosyvoice/flow/flow_matching.py @@ -0,0 +1,142 @@ +# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +import torch.nn.functional as F +from matcha.models.components.flow_matching import BASECFM + +class ConditionalCFM(BASECFM): + def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None): + super().__init__( + n_feats=in_channels, + cfm_params=cfm_params, + n_spks=n_spks, + spk_emb_dim=spk_emb_dim, + ) + self.t_scheduler = cfm_params.t_scheduler + self.training_cfg_rate = cfm_params.training_cfg_rate + self.inference_cfg_rate = cfm_params.inference_cfg_rate + in_channels = in_channels + (spk_emb_dim if n_spks > 0 else 0) + # Just change the architecture of the estimator here + self.estimator = estimator + + @torch.inference_mode() + def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None): + """Forward diffusion + + Args: + mu (torch.Tensor): output of encoder + shape: (batch_size, n_feats, mel_timesteps) + mask (torch.Tensor): output_mask + shape: (batch_size, 1, mel_timesteps) + n_timesteps (int): number of diffusion steps + temperature (float, optional): temperature for scaling noise. Defaults to 1.0. + spks (torch.Tensor, optional): speaker ids. Defaults to None. 
+ shape: (batch_size, spk_emb_dim) + cond: Not used but kept for future purposes + + Returns: + sample: generated mel-spectrogram + shape: (batch_size, n_feats, mel_timesteps) + """ + torch.manual_seed(42) + + z = torch.randn_like(mu) * temperature + + t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device) + if self.t_scheduler == 'cosine': + t_span = 1 - torch.cos(t_span * 0.5 * torch.pi) + return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond) + + def solve_euler(self, x, t_span, mu, mask, spks, cond): + """ + Fixed euler solver for ODEs. + Args: + x (torch.Tensor): random noise + t_span (torch.Tensor): n_timesteps interpolated + shape: (n_timesteps + 1,) + mu (torch.Tensor): output of encoder + shape: (batch_size, n_feats, mel_timesteps) + mask (torch.Tensor): output_mask + shape: (batch_size, 1, mel_timesteps) + spks (torch.Tensor, optional): speaker ids. Defaults to None. + shape: (batch_size, spk_emb_dim) + cond: Not used but kept for future purposes + """ + t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0] + + # I am storing this because I can later plot it by putting a debugger here and saving it to a file + # Or in future might add like a return_all_steps flag + sol = [] + + for step in range(1, len(t_span)): + dphi_dt = self.estimator(x, mask, mu, t, spks, cond) + # Classifier-Free Guidance inference introduced in VoiceBox + if self.inference_cfg_rate > 0: + cfg_dphi_dt = self.estimator( + x, mask, + torch.zeros_like(mu), t, + torch.zeros_like(spks) if spks is not None else None, + torch.zeros_like(cond) + ) + dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - + self.inference_cfg_rate * cfg_dphi_dt) + x = x + dt * dphi_dt + t = t + dt + + sol.append(x) + if step < len(t_span) - 1: + dt = t_span[step + 1] - t + + return sol[-1] + + def compute_loss(self, x1, mask, mu, spks=None, cond=None): + """Computes diffusion loss + + Args: + x1 (torch.Tensor): Target + shape: (batch_size, n_feats, mel_timesteps) + mask (torch.Tensor): target mask + shape: (batch_size, 1, mel_timesteps) + mu (torch.Tensor): output of encoder + shape: (batch_size, n_feats, mel_timesteps) + spks (torch.Tensor, optional): speaker embedding. Defaults to None. 
+ shape: (batch_size, spk_emb_dim) + + Returns: + loss: conditional flow matching loss + y: conditional flow + shape: (batch_size, n_feats, mel_timesteps) + """ + b, _, t = mu.shape + + # random timestep + t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype) + if self.t_scheduler == 'cosine': + t = 1 - torch.cos(t * 0.5 * torch.pi) + # sample noise p(x_0) + z = torch.randn_like(x1) + + y = (1 - (1 - self.sigma_min) * t) * z + t * x1 + u = x1 - (1 - self.sigma_min) * z + + # during training, we randomly drop condition to trade off mode coverage and sample fidelity + if self.training_cfg_rate > 0: + cfg_mask = torch.rand(b, device=x1.device) > self.training_cfg_rate + mu = mu * cfg_mask.view(-1, 1, 1) + spks = spks * cfg_mask.view(-1, 1) + cond = cond * cfg_mask.view(-1, 1, 1) + + pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond) + loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1]) + return loss, y diff --git a/cosyvoice/flow/flow_matching_dit.py b/cosyvoice/flow/flow_matching_dit.py new file mode 100644 index 0000000000000000000000000000000000000000..abadcc218cc3b30bd5c8da829079974b60f09b47 --- /dev/null +++ b/cosyvoice/flow/flow_matching_dit.py @@ -0,0 +1,180 @@ +# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pdb + +import torch +import torch.nn.functional as F +from matcha.models.components.flow_matching import BASECFM + + +class ConditionalCFM(BASECFM): + def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None): + super().__init__( + n_feats=in_channels, + cfm_params=cfm_params, + n_spks=n_spks, + spk_emb_dim=spk_emb_dim, + ) + self.t_scheduler = cfm_params.t_scheduler + self.training_cfg_rate = cfm_params.training_cfg_rate + self.inference_cfg_rate = cfm_params.inference_cfg_rate + in_channels = in_channels + (spk_emb_dim if n_spks > 0 else 0) + # Just change the architecture of the estimator here + + io_channels = 80 + input_concat_dim = 80 + embed_dim = 768 + depth = 24 + num_heads = 24 + project_cond_tokens = False + transformer_type = "continuous_transformer" + self.estimator = estimator + + @torch.inference_mode() + def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None): + """Forward diffusion + + Args: + mu (torch.Tensor): output of encoder + shape: (batch_size, n_feats, mel_timesteps) + mask (torch.Tensor): output_mask + shape: (batch_size, 1, mel_timesteps) + n_timesteps (int): number of diffusion steps + temperature (float, optional): temperature for scaling noise. Defaults to 1.0. + spks (torch.Tensor, optional): speaker ids. Defaults to None. 
+ shape: (batch_size, spk_emb_dim) + cond: Not used but kept for future purposes + + Returns: + sample: generated mel-spectrogram + shape: (batch_size, n_feats, mel_timesteps) + """ + z = torch.randn_like(mu) * temperature + t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device) + if self.t_scheduler == 'cosine': + t_span = 1 - torch.cos(t_span * 0.5 * torch.pi) + return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond) + + def solve_euler(self, x, t_span, mu, mask, spks, cond): + """ + Fixed euler solver for ODEs. + Args: + x (torch.Tensor): random noise torch.Size([1, 80, 621]) + t_span (torch.Tensor): n_timesteps interpolated + shape: (n_timesteps + 1,) + mu (torch.Tensor): output of encoder + shape: (batch_size, n_feats, mel_timesteps) + mask (torch.Tensor): output_mask + shape: (batch_size, 1, mel_timesteps) + spks (torch.Tensor, optional): speaker ids. Defaults to None. + shape: (batch_size, spk_emb_dim) + cond: Not used but kept for future purposes + """ + t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0] + + # I am storing this because I can later plot it by putting a debugger here and saving it to a file + # Or in future might add like a return_all_steps flag + sol = [] + + cfg_dropout_prob = 0.1 + cfg_scale = 1.0 + + # cfg_dropout_prob = 0.0 + # cfg_scale = 3.0 + + for step in range(1, len(t_span)): + # dphi_dt = self.estimator(x, mask, mu, t, spks, cond) + # pdb.set_trace() + dphi_dt = self.estimator(x, # [bs, 80, 229] + t[None], # (bs,) + global_embed=spks, + input_concat_cond=mu, + mask=mask[0], # [bs, 229] + cfg_dropout_prob=cfg_dropout_prob, cfg_scale=cfg_scale) + + # Classifier-Free Guidance inference introduced in VoiceBox + if self.inference_cfg_rate > 0: + # cfg_dphi_dt = self.estimator( + # x, mask, + # torch.zeros_like(mu), t, + # torch.zeros_like(spks) if spks is not None else None, + # torch.zeros_like(cond) + # ) + cfg_dphi_dt = self.estimator(x, # [bs, 80, 229] + t[None], # (bs,) + global_embed=torch.zeros_like(spks) if spks is not None else None, + input_concat_cond=torch.zeros_like(mu), + mask=mask[0], # [bs, 229] + cfg_dropout_prob=cfg_dropout_prob, cfg_scale=cfg_scale) + + dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - + self.inference_cfg_rate * cfg_dphi_dt) + x = x + dt * dphi_dt + t = t + dt + sol.append(x) + if step < len(t_span) - 1: + dt = t_span[step + 1] - t + + return sol[-1] + + def compute_loss(self, x1, mask, mu, spks=None, cond=None): + """Computes diffusion loss + + Args: + x1 (torch.Tensor): Target + shape: (batch_size, n_feats, mel_timesteps) + mask (torch.Tensor): target mask + shape: (batch_size, 1, mel_timesteps) + mu (torch.Tensor): output of encoder + shape: (batch_size, n_feats, mel_timesteps) + spks (torch.Tensor, optional): speaker embedding. Defaults to None. 
+ shape: (batch_size, spk_emb_dim) + + Returns: + loss: conditional flow matching loss + y: conditional flow + shape: (batch_size, n_feats, mel_timesteps) + """ + b, _, t = mu.shape + + # random timestep + t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype) + if self.t_scheduler == 'cosine': + t = 1 - torch.cos(t * 0.5 * torch.pi) + # sample noise p(x_0) + z = torch.randn_like(x1) + + y = (1 - (1 - self.sigma_min) * t) * z + t * x1 + u = x1 - (1 - self.sigma_min) * z + + # during training, we randomly drop condition to trade off mode coverage and sample fidelity + if self.training_cfg_rate > 0: + cfg_mask = torch.rand(b, device=x1.device) > self.training_cfg_rate + mu = mu * cfg_mask.view(-1, 1, 1) + spks = spks * cfg_mask.view(-1, 1) + cond = cond * cfg_mask.view(-1, 1, 1) + + # pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond) + pred = self.estimator(y, # [bs, 80, 229] + t.squeeze(1, 2), # (bs,) + global_embed=spks, + input_concat_cond=mu, + mask=mask.squeeze(1), # [bs, 229] + cfg_dropout_prob=0.1) + + loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1]) + return loss, y + + # def estimator_trans(self): + # pass diff --git a/cosyvoice/flow/length_regulator.py b/cosyvoice/flow/length_regulator.py new file mode 100644 index 0000000000000000000000000000000000000000..622f29aaccc44d8e8cce23ecab7b086ebb853fde --- /dev/null +++ b/cosyvoice/flow/length_regulator.py @@ -0,0 +1,49 @@ +# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
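# Illustrative sketch of the time-axis resizing performed by InterpolateRegulator
# below: the encoder output (B, T_in, D) is nearest-neighbour interpolated to the
# target mel length, after which the Conv1d stack refines and projects it.
# Shapes are toy values.
import torch
import torch.nn.functional as F

x = torch.randn(1, 50, 80)                  # (B, T_in, D)
ylens = torch.tensor([172])                 # target mel frame count per utterance
x_resized = F.interpolate(x.transpose(1, 2).contiguous(), size=int(ylens.max()), mode='nearest')
print(x_resized.shape)                      # torch.Size([1, 80, 172])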
+from typing import Tuple +import torch.nn as nn +from torch.nn import functional as F +from cosyvoice.utils.mask import make_pad_mask + + +class InterpolateRegulator(nn.Module): + def __init__( + self, + channels: int, + sampling_ratios: Tuple, + out_channels: int = None, + groups: int = 1, + ): + super().__init__() + self.sampling_ratios = sampling_ratios + out_channels = out_channels or channels + model = nn.ModuleList([]) + if len(sampling_ratios) > 0: + for _ in sampling_ratios: + module = nn.Conv1d(channels, channels, 3, 1, 1) + norm = nn.GroupNorm(groups, channels) + act = nn.Mish() + model.extend([module, norm, act]) + model.append( + nn.Conv1d(channels, out_channels, 1, 1) + ) + self.model = nn.Sequential(*model) + + def forward(self, x, ylens=None): + # x in (B, T, D) + mask = (~make_pad_mask(ylens)).to(x).unsqueeze(-1) + x = F.interpolate(x.transpose(1, 2).contiguous(), size=ylens.max(), mode='nearest') + out = self.model(x).transpose(1, 2).contiguous() + olens = ylens + return out * mask, olens diff --git a/cosyvoice/flow/stable/adp.py b/cosyvoice/flow/stable/adp.py new file mode 100644 index 0000000000000000000000000000000000000000..a7ff72026df1c0bed73563d025d314dd2ccd4d19 --- /dev/null +++ b/cosyvoice/flow/stable/adp.py @@ -0,0 +1,1591 @@ +# Copied and modified from https://github.com/archinetai/audio-diffusion-pytorch/blob/v0.0.94/audio_diffusion_pytorch/modules.py under MIT License +# License can be found in LICENSES/LICENSE_ADP.txt + +import math +from inspect import isfunction +from math import ceil, floor, log, pi, log2 +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union +from packaging import version + +import torch +import torch.nn as nn +from einops import rearrange, reduce, repeat +from einops.layers.torch import Rearrange +from einops_exts import rearrange_many +from torch import Tensor, einsum +from torch.backends.cuda import sdp_kernel +from torch.nn import functional as F +from dac.nn.layers import Snake1d +import pdb +""" +Utils +""" + + +class ConditionedSequential(nn.Module): + def __init__(self, *modules): + super().__init__() + self.module_list = nn.ModuleList(*modules) + + def forward(self, x: Tensor, mapping: Optional[Tensor] = None): + for module in self.module_list: + x = module(x, mapping) + return x + +T = TypeVar("T") + +def default(val: Optional[T], d: Union[Callable[..., T], T]) -> T: + if exists(val): + return val + return d() if isfunction(d) else d + +def exists(val: Optional[T]) -> T: + return val is not None + +def closest_power_2(x: float) -> int: + exponent = log2(x) + distance_fn = lambda z: abs(x - 2 ** z) # noqa + exponent_closest = min((floor(exponent), ceil(exponent)), key=distance_fn) + return 2 ** int(exponent_closest) + +def group_dict_by_prefix(prefix: str, d: Dict) -> Tuple[Dict, Dict]: + return_dicts: Tuple[Dict, Dict] = ({}, {}) + for key in d.keys(): + no_prefix = int(not key.startswith(prefix)) + return_dicts[no_prefix][key] = d[key] + return return_dicts + +def groupby(prefix: str, d: Dict, keep_prefix: bool = False) -> Tuple[Dict, Dict]: + kwargs_with_prefix, kwargs = group_dict_by_prefix(prefix, d) + if keep_prefix: + return kwargs_with_prefix, kwargs + kwargs_no_prefix = {k[len(prefix) :]: v for k, v in kwargs_with_prefix.items()} + return kwargs_no_prefix, kwargs + +""" +Convolutional Blocks +""" +import typing as tp + +# Copied from https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/modules/conv.py under MIT License +# License available in LICENSES/LICENSE_META.txt + 
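# Worked example for get_extra_padding_for_conv1d just below (toy numbers): with
# length 13, kernel_size 4, stride 2 and padding_total 0, one extra sample of
# right padding is needed so the last convolution window is full.
import math

length, kernel_size, stride, padding_total = 13, 4, 2, 0
n_frames = (length - kernel_size + padding_total) / stride + 1               # 5.5
ideal = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)   # 14
extra_padding = ideal - length                                               # 1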
+def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, + padding_total: int = 0) -> int: + """See `pad_for_conv1d`.""" + length = x.shape[-1] + n_frames = (length - kernel_size + padding_total) / stride + 1 + ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total) + return ideal_length - length + + +def pad_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0): + """Pad for a convolution to make sure that the last window is full. + Extra padding is added at the end. This is required to ensure that we can rebuild + an output of the same length, as otherwise, even with padding, some time steps + might get removed. + For instance, with total padding = 4, kernel size = 4, stride = 2: + 0 0 1 2 3 4 5 0 0 # (0s are padding) + 1 2 3 # (output frames of a convolution, last 0 is never used) + 0 0 1 2 3 4 5 0 # (output of tr. conv., but pos. 5 is going to get removed as padding) + 1 2 3 4 # once you removed padding, we are missing one time step ! + """ + extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total) + return F.pad(x, (0, extra_padding)) + + +def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'constant', value: float = 0.): + """Tiny wrapper around F.pad, just to allow for reflect padding on small input. + If this is the case, we insert extra 0 padding to the right before the reflection happen. + """ + length = x.shape[-1] + padding_left, padding_right = paddings + assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right) + if mode == 'reflect': + max_pad = max(padding_left, padding_right) + extra_pad = 0 + if length <= max_pad: + extra_pad = max_pad - length + 1 + x = F.pad(x, (0, extra_pad)) + padded = F.pad(x, paddings, mode, value) + end = padded.shape[-1] - extra_pad + return padded[..., :end] + else: + return F.pad(x, paddings, mode, value) + + +def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]): + """Remove padding from x, handling properly zero padding. Only for 1d!""" + padding_left, padding_right = paddings + assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right) + assert (padding_left + padding_right) <= x.shape[-1] + end = x.shape[-1] - padding_right + return x[..., padding_left: end] + + +class Conv1d(nn.Conv1d): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, x: Tensor, causal=False) -> Tensor: + kernel_size = self.kernel_size[0] + stride = self.stride[0] + dilation = self.dilation[0] + kernel_size = (kernel_size - 1) * dilation + 1 # effective kernel size with dilations + padding_total = kernel_size - stride + extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total) + if causal: + # Left padding for causal + x = pad1d(x, (padding_total, extra_padding)) + else: + # Asymmetric padding required for odd strides + padding_right = padding_total // 2 + padding_left = padding_total - padding_right + x = pad1d(x, (padding_left, padding_right + extra_padding)) + return super().forward(x) + +class ConvTranspose1d(nn.ConvTranspose1d): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, x: Tensor, causal=False) -> Tensor: + kernel_size = self.kernel_size[0] + stride = self.stride[0] + padding_total = kernel_size - stride + + y = super().forward(x) + + # We will only trim fixed padding. 
Extra padding from `pad_for_conv1d` would be + # removed at the very end, when keeping only the right length for the output, + # as removing it here would require also passing the length at the matching layer + # in the encoder. + if causal: + padding_right = ceil(padding_total) + padding_left = padding_total - padding_right + y = unpad1d(y, (padding_left, padding_right)) + else: + # Asymmetric padding required for odd strides + padding_right = padding_total // 2 + padding_left = padding_total - padding_right + y = unpad1d(y, (padding_left, padding_right)) + return y + + +def Downsample1d( + in_channels: int, out_channels: int, factor: int, kernel_multiplier: int = 2 +) -> nn.Module: + assert kernel_multiplier % 2 == 0, "Kernel multiplier must be even" + + return Conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=factor * kernel_multiplier + 1, + stride=factor + ) + + +def Upsample1d( + in_channels: int, out_channels: int, factor: int, use_nearest: bool = False +) -> nn.Module: + + if factor == 1: + return Conv1d( + in_channels=in_channels, out_channels=out_channels, kernel_size=3 + ) + + if use_nearest: + return nn.Sequential( + nn.Upsample(scale_factor=factor, mode="nearest"), + Conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3 + ), + ) + else: + return ConvTranspose1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=factor * 2, + stride=factor + ) + + +class ConvBlock1d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + *, + kernel_size: int = 3, + stride: int = 1, + dilation: int = 1, + num_groups: int = 8, + use_norm: bool = True, + use_snake: bool = False + ) -> None: + super().__init__() + + self.groupnorm = ( + nn.GroupNorm(num_groups=num_groups, num_channels=in_channels) + if use_norm + else nn.Identity() + ) + + if use_snake: + self.activation = Snake1d(in_channels) + else: + self.activation = nn.SiLU() + + self.project = Conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + dilation=dilation, + ) + + def forward( + self, x: Tensor, scale_shift: Optional[Tuple[Tensor, Tensor]] = None, causal=False + ) -> Tensor: + x = self.groupnorm(x) + if exists(scale_shift): + scale, shift = scale_shift + x = x * (scale + 1) + shift + x = self.activation(x) + return self.project(x, causal=causal) + + +class MappingToScaleShift(nn.Module): + def __init__( + self, + features: int, + channels: int, + ): + super().__init__() + + self.to_scale_shift = nn.Sequential( + nn.SiLU(), + nn.Linear(in_features=features, out_features=channels * 2), + ) + + def forward(self, mapping: Tensor) -> Tuple[Tensor, Tensor]: + scale_shift = self.to_scale_shift(mapping) + scale_shift = rearrange(scale_shift, "b c -> b c 1") + scale, shift = scale_shift.chunk(2, dim=1) + return scale, shift + + +class ResnetBlock1d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + *, + kernel_size: int = 3, + stride: int = 1, + dilation: int = 1, + use_norm: bool = True, + use_snake: bool = False, + num_groups: int = 8, + context_mapping_features: Optional[int] = None, + ) -> None: + super().__init__() + + self.use_mapping = exists(context_mapping_features) + + self.block1 = ConvBlock1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + dilation=dilation, + use_norm=use_norm, + num_groups=num_groups, + use_snake=use_snake + ) + + if self.use_mapping: + assert exists(context_mapping_features) + 
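            # FiLM-style (feature-wise scale/shift) conditioning: the mapping
            # vector is projected to a per-channel (scale, shift) pair that
            # ConvBlock1d applies after its group norm, x * (scale + 1) + shift,
            # inside block2.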
self.to_scale_shift = MappingToScaleShift( + features=context_mapping_features, channels=out_channels + ) + + self.block2 = ConvBlock1d( + in_channels=out_channels, + out_channels=out_channels, + use_norm=use_norm, + num_groups=num_groups, + use_snake=use_snake + ) + + self.to_out = ( + Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=1) + if in_channels != out_channels + else nn.Identity() + ) + + def forward(self, x: Tensor, mapping: Optional[Tensor] = None, causal=False) -> Tensor: + assert_message = "context mapping required if context_mapping_features > 0" + assert not (self.use_mapping ^ exists(mapping)), assert_message + + h = self.block1(x, causal=causal) + + scale_shift = None + if self.use_mapping: + scale_shift = self.to_scale_shift(mapping) + + h = self.block2(h, scale_shift=scale_shift, causal=causal) + + return h + self.to_out(x) + + +class Patcher(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + patch_size: int, + context_mapping_features: Optional[int] = None, + use_snake: bool = False, + ): + super().__init__() + assert_message = f"out_channels must be divisible by patch_size ({patch_size})" + assert out_channels % patch_size == 0, assert_message + self.patch_size = patch_size + + self.block = ResnetBlock1d( + in_channels=in_channels, + out_channels=out_channels // patch_size, + num_groups=1, + context_mapping_features=context_mapping_features, + use_snake=use_snake + ) + + def forward(self, x: Tensor, mapping: Optional[Tensor] = None, causal=False) -> Tensor: + x = self.block(x, mapping, causal=causal) + x = rearrange(x, "b c (l p) -> b (c p) l", p=self.patch_size) + return x + + +class Unpatcher(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + patch_size: int, + context_mapping_features: Optional[int] = None, + use_snake: bool = False + ): + super().__init__() + assert_message = f"in_channels must be divisible by patch_size ({patch_size})" + assert in_channels % patch_size == 0, assert_message + self.patch_size = patch_size + + self.block = ResnetBlock1d( + in_channels=in_channels // patch_size, + out_channels=out_channels, + num_groups=1, + context_mapping_features=context_mapping_features, + use_snake=use_snake + ) + + def forward(self, x: Tensor, mapping: Optional[Tensor] = None, causal=False) -> Tensor: + x = rearrange(x, " b (c p) l -> b c (l p) ", p=self.patch_size) + x = self.block(x, mapping, causal=causal) + return x + + +""" +Attention Components +""" +def FeedForward(features: int, multiplier: int) -> nn.Module: + mid_features = features * multiplier + return nn.Sequential( + nn.Linear(in_features=features, out_features=mid_features), + nn.GELU(), + nn.Linear(in_features=mid_features, out_features=features), + ) + +def add_mask(sim: Tensor, mask: Tensor) -> Tensor: + b, ndim = sim.shape[0], mask.ndim + if ndim == 3: + mask = rearrange(mask, "b n m -> b 1 n m") + if ndim == 2: + mask = repeat(mask, "n m -> b 1 n m", b=b) + max_neg_value = -torch.finfo(sim.dtype).max + sim = sim.masked_fill(~mask, max_neg_value) + return sim + +def causal_mask(q: Tensor, k: Tensor) -> Tensor: + b, i, j, device = q.shape[0], q.shape[-2], k.shape[-2], q.device + mask = ~torch.ones((i, j), dtype=torch.bool, device=device).triu(j - i + 1) + mask = repeat(mask, "n m -> b n m", b=b) + return mask + +class AttentionBase(nn.Module): + def __init__( + self, + features: int, + *, + head_features: int, + num_heads: int, + out_features: Optional[int] = None, + ): + super().__init__() + self.scale = 
head_features**-0.5 + self.num_heads = num_heads + mid_features = head_features * num_heads + out_features = default(out_features, features) + + self.to_out = nn.Linear( + in_features=mid_features, out_features=out_features + ) + + self.use_flash = torch.cuda.is_available() and version.parse(torch.__version__) >= version.parse('2.0.0') + + if not self.use_flash: + return + + device_properties = torch.cuda.get_device_properties(torch.device('cuda')) + + if device_properties.major == 8 and device_properties.minor == 0: + # Use flash attention for A100 GPUs + self.sdp_kernel_config = (True, False, False) + else: + # Don't use flash attention for other GPUs + self.sdp_kernel_config = (False, True, True) + + def forward( + self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None, is_causal: bool = False + ) -> Tensor: + # Split heads + q, k, v = rearrange_many((q, k, v), "b n (h d) -> b h n d", h=self.num_heads) + + if not self.use_flash: + if is_causal and not mask: + # Mask out future tokens for causal attention + mask = causal_mask(q, k) + + # Compute similarity matrix and add eventual mask + sim = einsum("... n d, ... m d -> ... n m", q, k) * self.scale + sim = add_mask(sim, mask) if exists(mask) else sim + + # Get attention matrix with softmax + attn = sim.softmax(dim=-1, dtype=torch.float32) + + # Compute values + out = einsum("... n m, ... m d -> ... n d", attn, v) + else: + with sdp_kernel(*self.sdp_kernel_config): + out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, is_causal=is_causal) + + out = rearrange(out, "b h n d -> b n (h d)") + return self.to_out(out) + +class Attention(nn.Module): + def __init__( + self, + features: int, + *, + head_features: int, + num_heads: int, + out_features: Optional[int] = None, + context_features: Optional[int] = None, + causal: bool = False, + ): + super().__init__() + self.context_features = context_features + self.causal = causal + mid_features = head_features * num_heads + context_features = default(context_features, features) + + self.norm = nn.LayerNorm(features) + self.norm_context = nn.LayerNorm(context_features) + self.to_q = nn.Linear( + in_features=features, out_features=mid_features, bias=False + ) + self.to_kv = nn.Linear( + in_features=context_features, out_features=mid_features * 2, bias=False + ) + self.attention = AttentionBase( + features, + num_heads=num_heads, + head_features=head_features, + out_features=out_features, + ) + + def forward( + self, + x: Tensor, # [b, n, c] + context: Optional[Tensor] = None, # [b, m, d] + context_mask: Optional[Tensor] = None, # [b, m], false is masked, + causal: Optional[bool] = False, + ) -> Tensor: + assert_message = "You must provide a context when using context_features" + assert not self.context_features or exists(context), assert_message + # Use context if provided + context = default(context, x) + # Normalize then compute q from input and k,v from context + x, context = self.norm(x), self.norm_context(context) + + q, k, v = (self.to_q(x), *torch.chunk(self.to_kv(context), chunks=2, dim=-1)) + + if exists(context_mask): + # Mask out cross-attention for padding tokens + mask = repeat(context_mask, "b m -> b m d", d=v.shape[-1]) + k, v = k * mask, v * mask + + # Compute and return attention + return self.attention(q, k, v, is_causal=self.causal or causal) + + +def FeedForward(features: int, multiplier: int) -> nn.Module: + mid_features = features * multiplier + return nn.Sequential( + nn.Linear(in_features=features, out_features=mid_features), + nn.GELU(), + 
nn.Linear(in_features=mid_features, out_features=features), + ) + +""" +Transformer Blocks +""" + + +class TransformerBlock(nn.Module): + def __init__( + self, + features: int, + num_heads: int, + head_features: int, + multiplier: int, + context_features: Optional[int] = None, + ): + super().__init__() + + self.use_cross_attention = exists(context_features) and context_features > 0 + + self.attention = Attention( + features=features, + num_heads=num_heads, + head_features=head_features + ) + + if self.use_cross_attention: + self.cross_attention = Attention( + features=features, + num_heads=num_heads, + head_features=head_features, + context_features=context_features + ) + + self.feed_forward = FeedForward(features=features, multiplier=multiplier) + + def forward(self, x: Tensor, *, context: Optional[Tensor] = None, context_mask: Optional[Tensor] = None, causal: Optional[bool] = False) -> Tensor: + x = self.attention(x, causal=causal) + x + if self.use_cross_attention: + x = self.cross_attention(x, context=context, context_mask=context_mask) + x + x = self.feed_forward(x) + x + return x + + +""" +Transformers +""" + + +class Transformer1d(nn.Module): + def __init__( + self, + num_layers: int, + channels: int, + num_heads: int, + head_features: int, + multiplier: int, + context_features: Optional[int] = None, + ): + super().__init__() + + self.to_in = nn.Sequential( + nn.GroupNorm(num_groups=32, num_channels=channels, eps=1e-6, affine=True), + Conv1d( + in_channels=channels, + out_channels=channels, + kernel_size=1, + ), + Rearrange("b c t -> b t c"), + ) + + self.blocks = nn.ModuleList( + [ + TransformerBlock( + features=channels, + head_features=head_features, + num_heads=num_heads, + multiplier=multiplier, + context_features=context_features, + ) + for i in range(num_layers) + ] + ) + + self.to_out = nn.Sequential( + Rearrange("b t c -> b c t"), + Conv1d( + in_channels=channels, + out_channels=channels, + kernel_size=1, + ), + ) + + def forward(self, x: Tensor, *, context: Optional[Tensor] = None, context_mask: Optional[Tensor] = None, causal=False) -> Tensor: + x = self.to_in(x) + for block in self.blocks: + x = block(x, context=context, context_mask=context_mask, causal=causal) + x = self.to_out(x) + return x + + +""" +Time Embeddings +""" + + +class SinusoidalEmbedding(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.dim = dim + + def forward(self, x: Tensor) -> Tensor: + device, half_dim = x.device, self.dim // 2 + emb = torch.tensor(log(10000) / (half_dim - 1), device=device) + emb = torch.exp(torch.arange(half_dim, device=device) * -emb) + emb = rearrange(x, "i -> i 1") * rearrange(emb, "j -> 1 j") + return torch.cat((emb.sin(), emb.cos()), dim=-1) + + +class LearnedPositionalEmbedding(nn.Module): + """Used for continuous time""" + + def __init__(self, dim: int): + super().__init__() + assert (dim % 2) == 0 + half_dim = dim // 2 + self.weights = nn.Parameter(torch.randn(half_dim)) + + def forward(self, x: Tensor) -> Tensor: + x = rearrange(x, "b -> b 1") + freqs = x * rearrange(self.weights, "d -> 1 d") * 2 * pi + fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1) + fouriered = torch.cat((x, fouriered), dim=-1) + return fouriered + + +def TimePositionalEmbedding(dim: int, out_features: int) -> nn.Module: + return nn.Sequential( + LearnedPositionalEmbedding(dim), + nn.Linear(in_features=dim + 1, out_features=out_features), + ) + + +""" +Encoder/Decoder Components +""" + + +class DownsampleBlock1d(nn.Module): + def __init__( + self, + in_channels: int, + 
out_channels: int, + *, + factor: int, + num_groups: int, + num_layers: int, + kernel_multiplier: int = 2, + use_pre_downsample: bool = True, + use_skip: bool = False, + use_snake: bool = False, + extract_channels: int = 0, + context_channels: int = 0, + num_transformer_blocks: int = 0, + attention_heads: Optional[int] = None, + attention_features: Optional[int] = None, + attention_multiplier: Optional[int] = None, + context_mapping_features: Optional[int] = None, + context_embedding_features: Optional[int] = None, + ): + super().__init__() + self.use_pre_downsample = use_pre_downsample + self.use_skip = use_skip + self.use_transformer = num_transformer_blocks > 0 + self.use_extract = extract_channels > 0 + self.use_context = context_channels > 0 + + channels = out_channels if use_pre_downsample else in_channels + + self.downsample = Downsample1d( + in_channels=in_channels, + out_channels=out_channels, + factor=factor, + kernel_multiplier=kernel_multiplier, + ) + + self.blocks = nn.ModuleList( + [ + ResnetBlock1d( + in_channels=channels + context_channels if i == 0 else channels, + out_channels=channels, + num_groups=num_groups, + context_mapping_features=context_mapping_features, + use_snake=use_snake + ) + for i in range(num_layers) + ] + ) + + if self.use_transformer: + assert ( + (exists(attention_heads) or exists(attention_features)) + and exists(attention_multiplier) + ) + + if attention_features is None and attention_heads is not None: + attention_features = channels // attention_heads + + if attention_heads is None and attention_features is not None: + attention_heads = channels // attention_features + + self.transformer = Transformer1d( + num_layers=num_transformer_blocks, + channels=channels, + num_heads=attention_heads, + head_features=attention_features, + multiplier=attention_multiplier, + context_features=context_embedding_features + ) + + if self.use_extract: + num_extract_groups = min(num_groups, extract_channels) + self.to_extracted = ResnetBlock1d( + in_channels=out_channels, + out_channels=extract_channels, + num_groups=num_extract_groups, + use_snake=use_snake + ) + + def forward( + self, + x: Tensor, + *, + mapping: Optional[Tensor] = None, + channels: Optional[Tensor] = None, + embedding: Optional[Tensor] = None, + embedding_mask: Optional[Tensor] = None, + causal: Optional[bool] = False + ) -> Union[Tuple[Tensor, List[Tensor]], Tensor]: + + if self.use_pre_downsample: + x = self.downsample(x) + + if self.use_context and exists(channels): + x = torch.cat([x, channels], dim=1) + + skips = [] + for block in self.blocks: + x = block(x, mapping=mapping, causal=causal) + skips += [x] if self.use_skip else [] + + if self.use_transformer: + x = self.transformer(x, context=embedding, context_mask=embedding_mask, causal=causal) + skips += [x] if self.use_skip else [] + + if not self.use_pre_downsample: + x = self.downsample(x) + + if self.use_extract: + extracted = self.to_extracted(x) + return x, extracted + + return (x, skips) if self.use_skip else x + + +class UpsampleBlock1d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + *, + factor: int, + num_layers: int, + num_groups: int, + use_nearest: bool = False, + use_pre_upsample: bool = False, + use_skip: bool = False, + use_snake: bool = False, + skip_channels: int = 0, + use_skip_scale: bool = False, + extract_channels: int = 0, + num_transformer_blocks: int = 0, + attention_heads: Optional[int] = None, + attention_features: Optional[int] = None, + attention_multiplier: Optional[int] = None, + 
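# context_mapping_features: width of the shared time/feature mapping injected into
+ # each ResnetBlock1d; context_embedding_features: width of the tokens the optional
+ # Transformer1d cross-attends to. + 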
context_mapping_features: Optional[int] = None, + context_embedding_features: Optional[int] = None, + ): + super().__init__() + + self.use_extract = extract_channels > 0 + self.use_pre_upsample = use_pre_upsample + self.use_transformer = num_transformer_blocks > 0 + self.use_skip = use_skip + self.skip_scale = 2 ** -0.5 if use_skip_scale else 1.0 + + channels = out_channels if use_pre_upsample else in_channels + + self.blocks = nn.ModuleList( + [ + ResnetBlock1d( + in_channels=channels + skip_channels, + out_channels=channels, + num_groups=num_groups, + context_mapping_features=context_mapping_features, + use_snake=use_snake + ) + for _ in range(num_layers) + ] + ) + + if self.use_transformer: + assert ( + (exists(attention_heads) or exists(attention_features)) + and exists(attention_multiplier) + ) + + if attention_features is None and attention_heads is not None: + attention_features = channels // attention_heads + + if attention_heads is None and attention_features is not None: + attention_heads = channels // attention_features + + self.transformer = Transformer1d( + num_layers=num_transformer_blocks, + channels=channels, + num_heads=attention_heads, + head_features=attention_features, + multiplier=attention_multiplier, + context_features=context_embedding_features, + ) + + self.upsample = Upsample1d( + in_channels=in_channels, + out_channels=out_channels, + factor=factor, + use_nearest=use_nearest, + ) + + if self.use_extract: + num_extract_groups = min(num_groups, extract_channels) + self.to_extracted = ResnetBlock1d( + in_channels=out_channels, + out_channels=extract_channels, + num_groups=num_extract_groups, + use_snake=use_snake + ) + + def add_skip(self, x: Tensor, skip: Tensor) -> Tensor: + return torch.cat([x, skip * self.skip_scale], dim=1) + + def forward( + self, + x: Tensor, + *, + skips: Optional[List[Tensor]] = None, + mapping: Optional[Tensor] = None, + embedding: Optional[Tensor] = None, + embedding_mask: Optional[Tensor] = None, + causal: Optional[bool] = False + ) -> Union[Tuple[Tensor, Tensor], Tensor]: + + if self.use_pre_upsample: + x = self.upsample(x) + + for block in self.blocks: + x = self.add_skip(x, skip=skips.pop()) if exists(skips) else x + x = block(x, mapping=mapping, causal=causal) + + if self.use_transformer: + x = self.transformer(x, context=embedding, context_mask=embedding_mask, causal=causal) + + if not self.use_pre_upsample: + x = self.upsample(x) + + if self.use_extract: + extracted = self.to_extracted(x) + return x, extracted + + return x + + +class BottleneckBlock1d(nn.Module): + def __init__( + self, + channels: int, + *, + num_groups: int, + num_transformer_blocks: int = 0, + attention_heads: Optional[int] = None, + attention_features: Optional[int] = None, + attention_multiplier: Optional[int] = None, + context_mapping_features: Optional[int] = None, + context_embedding_features: Optional[int] = None, + use_snake: bool = False, + ): + super().__init__() + self.use_transformer = num_transformer_blocks > 0 + + self.pre_block = ResnetBlock1d( + in_channels=channels, + out_channels=channels, + num_groups=num_groups, + context_mapping_features=context_mapping_features, + use_snake=use_snake + ) + + if self.use_transformer: + assert ( + (exists(attention_heads) or exists(attention_features)) + and exists(attention_multiplier) + ) + + if attention_features is None and attention_heads is not None: + attention_features = channels // attention_heads + + if attention_heads is None and attention_features is not None: + attention_heads = channels // 
attention_features + + self.transformer = Transformer1d( + num_layers=num_transformer_blocks, + channels=channels, + num_heads=attention_heads, + head_features=attention_features, + multiplier=attention_multiplier, + context_features=context_embedding_features, + ) + + self.post_block = ResnetBlock1d( + in_channels=channels, + out_channels=channels, + num_groups=num_groups, + context_mapping_features=context_mapping_features, + use_snake=use_snake + ) + + def forward( + self, + x: Tensor, + *, + mapping: Optional[Tensor] = None, + embedding: Optional[Tensor] = None, + embedding_mask: Optional[Tensor] = None, + causal: Optional[bool] = False + ) -> Tensor: + x = self.pre_block(x, mapping=mapping, causal=causal) + if self.use_transformer: + x = self.transformer(x, context=embedding, context_mask=embedding_mask, causal=causal) + x = self.post_block(x, mapping=mapping, causal=causal) + return x + + +""" +UNet +""" + + +class UNet1d(nn.Module): + def __init__( + self, + in_channels: int, + channels: int, + multipliers: Sequence[int], + factors: Sequence[int], + num_blocks: Sequence[int], + attentions: Sequence[int], + patch_size: int = 1, + resnet_groups: int = 8, + use_context_time: bool = True, + kernel_multiplier_downsample: int = 2, + use_nearest_upsample: bool = False, + use_skip_scale: bool = True, + use_snake: bool = False, + use_stft: bool = False, + use_stft_context: bool = False, + out_channels: Optional[int] = None, + context_features: Optional[int] = None, + context_features_multiplier: int = 4, + context_channels: Optional[Sequence[int]] = None, + context_embedding_features: Optional[int] = None, + **kwargs, + ): + super().__init__() + out_channels = default(out_channels, in_channels) + context_channels = list(default(context_channels, [])) + num_layers = len(multipliers) - 1 + use_context_features = exists(context_features) + use_context_channels = len(context_channels) > 0 + context_mapping_features = None + + attention_kwargs, kwargs = groupby("attention_", kwargs, keep_prefix=True) + + self.num_layers = num_layers + self.use_context_time = use_context_time + self.use_context_features = use_context_features + self.use_context_channels = use_context_channels + self.use_stft = use_stft + self.use_stft_context = use_stft_context + + self.context_features = context_features + context_channels_pad_length = num_layers + 1 - len(context_channels) + context_channels = context_channels + [0] * context_channels_pad_length + self.context_channels = context_channels + self.context_embedding_features = context_embedding_features + + if use_context_channels: + has_context = [c > 0 for c in context_channels] + self.has_context = has_context + self.channels_ids = [sum(has_context[:i]) for i in range(len(has_context))] + + assert ( + len(factors) == num_layers + and len(attentions) >= num_layers + and len(num_blocks) == num_layers + ) + + if use_context_time or use_context_features: + context_mapping_features = channels * context_features_multiplier + + self.to_mapping = nn.Sequential( + nn.Linear(context_mapping_features, context_mapping_features), + nn.GELU(), + nn.Linear(context_mapping_features, context_mapping_features), + nn.GELU(), + ) + + if use_context_time: + assert exists(context_mapping_features) + self.to_time = nn.Sequential( + TimePositionalEmbedding( + dim=channels, out_features=context_mapping_features + ), + nn.GELU(), + ) + + if use_context_features: + assert exists(context_features) and exists(context_mapping_features) + self.to_features = nn.Sequential( + nn.Linear( + 
in_features=context_features, out_features=context_mapping_features + ), + nn.GELU(), + ) + + if use_stft: + stft_kwargs, kwargs = groupby("stft_", kwargs) + assert "num_fft" in stft_kwargs, "stft_num_fft required if use_stft=True" + stft_channels = (stft_kwargs["num_fft"] // 2 + 1) * 2 + in_channels *= stft_channels + out_channels *= stft_channels + context_channels[0] *= stft_channels if use_stft_context else 1 + assert exists(in_channels) and exists(out_channels) + self.stft = STFT(**stft_kwargs) + + assert not kwargs, f"Unknown arguments: {', '.join(list(kwargs.keys()))}" + + self.to_in = Patcher( + in_channels=in_channels + context_channels[0], + out_channels=channels * multipliers[0], + patch_size=patch_size, + context_mapping_features=context_mapping_features, + use_snake=use_snake + ) + + self.downsamples = nn.ModuleList( + [ + DownsampleBlock1d( + in_channels=channels * multipliers[i], + out_channels=channels * multipliers[i + 1], + context_mapping_features=context_mapping_features, + context_channels=context_channels[i + 1], + context_embedding_features=context_embedding_features, + num_layers=num_blocks[i], + factor=factors[i], + kernel_multiplier=kernel_multiplier_downsample, + num_groups=resnet_groups, + use_pre_downsample=True, + use_skip=True, + use_snake=use_snake, + num_transformer_blocks=attentions[i], + **attention_kwargs, + ) + for i in range(num_layers) + ] + ) + + self.bottleneck = BottleneckBlock1d( + channels=channels * multipliers[-1], + context_mapping_features=context_mapping_features, + context_embedding_features=context_embedding_features, + num_groups=resnet_groups, + num_transformer_blocks=attentions[-1], + use_snake=use_snake, + **attention_kwargs, + ) + + self.upsamples = nn.ModuleList( + [ + UpsampleBlock1d( + in_channels=channels * multipliers[i + 1], + out_channels=channels * multipliers[i], + context_mapping_features=context_mapping_features, + context_embedding_features=context_embedding_features, + num_layers=num_blocks[i] + (1 if attentions[i] else 0), + factor=factors[i], + use_nearest=use_nearest_upsample, + num_groups=resnet_groups, + use_skip_scale=use_skip_scale, + use_pre_upsample=False, + use_skip=True, + use_snake=use_snake, + skip_channels=channels * multipliers[i + 1], + num_transformer_blocks=attentions[i], + **attention_kwargs, + ) + for i in reversed(range(num_layers)) + ] + ) + + self.to_out = Unpatcher( + in_channels=channels * multipliers[0], + out_channels=out_channels, + patch_size=patch_size, + context_mapping_features=context_mapping_features, + use_snake=use_snake + ) + + def get_channels( + self, channels_list: Optional[Sequence[Tensor]] = None, layer: int = 0 + ) -> Optional[Tensor]: + """Gets context channels at `layer` and checks that shape is correct""" + use_context_channels = self.use_context_channels and self.has_context[layer] + if not use_context_channels: + return None + assert exists(channels_list), "Missing context" + # Get channels index (skipping zero channel contexts) + channels_id = self.channels_ids[layer] + # Get channels + channels = channels_list[channels_id] + message = f"Missing context for layer {layer} at index {channels_id}" + assert exists(channels), message + # Check channels + num_channels = self.context_channels[layer] + message = f"Expected context with {num_channels} channels at idx {channels_id}" + assert channels.shape[1] == num_channels, message + # STFT channels if requested + channels = self.stft.encode1d(channels) if self.use_stft_context else channels # type: ignore # noqa + return channels 
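+ # Worked example of the context-channel bookkeeping above (illustrative values
+ # only): with context_channels = [2, 0, 4, 0, 8], has_context = [True, False,
+ # True, False, True] and channels_ids = [0, 1, 1, 2, 2]: layer 0 reads
+ # channels_list[0], layer 2 reads channels_list[1] and layer 4 reads
+ # channels_list[2]; layers whose expected channel count is 0 never index
+ # channels_list.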
+ + def get_mapping( + self, time: Optional[Tensor] = None, features: Optional[Tensor] = None + ) -> Optional[Tensor]: + """Combines context time features and features into mapping""" + items, mapping = [], None + # Compute time features + if self.use_context_time: + assert_message = "use_context_time=True but no time features provided" + assert exists(time), assert_message + items += [self.to_time(time)] + # Compute features + if self.use_context_features: + assert_message = "context_features exists but no features provided" + assert exists(features), assert_message + items += [self.to_features(features)] + # Compute joint mapping + if self.use_context_time or self.use_context_features: + mapping = reduce(torch.stack(items), "n b m -> b m", "sum") + mapping = self.to_mapping(mapping) + return mapping + + def forward( + self, + x: Tensor, + time: Optional[Tensor] = None, + *, + features: Optional[Tensor] = None, + channels_list: Optional[Sequence[Tensor]] = None, + embedding: Optional[Tensor] = None, + embedding_mask: Optional[Tensor] = None, + causal: Optional[bool] = False, + ) -> Tensor: + channels = self.get_channels(channels_list, layer=0) + # Apply stft if required + x = self.stft.encode1d(x) if self.use_stft else x # type: ignore + # Concat context channels at layer 0 if provided + x = torch.cat([x, channels], dim=1) if exists(channels) else x + # Compute mapping from time and features + mapping = self.get_mapping(time, features) + x = self.to_in(x, mapping, causal=causal) + skips_list = [x] + + for i, downsample in enumerate(self.downsamples): + channels = self.get_channels(channels_list, layer=i + 1) + x, skips = downsample( + x, mapping=mapping, channels=channels, embedding=embedding, embedding_mask=embedding_mask, causal=causal + ) + skips_list += [skips] + + x = self.bottleneck(x, mapping=mapping, embedding=embedding, embedding_mask=embedding_mask, causal=causal) + for i, upsample in enumerate(self.upsamples): + skips = skips_list.pop() + x = upsample(x, skips=skips, mapping=mapping, embedding=embedding, embedding_mask=embedding_mask, causal=causal) + + x += skips_list.pop() + x = self.to_out(x, mapping, causal=causal) + x = self.stft.decode1d(x) if self.use_stft else x + + return x + + """ Conditioning Modules """ + + class FixedEmbedding(nn.Module): + def __init__(self, max_length: int, features: int): + super().__init__() + self.max_length = max_length + self.embedding = nn.Embedding(max_length, features) + + def forward(self, x: Tensor) -> Tensor: + batch_size, length, device = *x.shape[0:2], x.device + assert_message = "Input sequence length must be <= max_length" + assert length <= self.max_length, assert_message + position = torch.arange(length, device=device) + fixed_embedding = self.embedding(position) + fixed_embedding = repeat(fixed_embedding, "n d -> b n d", b=batch_size) + return fixed_embedding + + def rand_bool(shape: Any, proba: float, device: Any = None) -> Tensor: + if proba == 1: + return torch.ones(shape, device=device, dtype=torch.bool) + elif proba == 0: + return torch.zeros(shape, device=device, dtype=torch.bool) + else: + return torch.bernoulli(torch.full(shape, proba, device=device)).to(torch.bool) + + class UNetCFG1d(UNet1d): + + """UNet1d with Classifier-Free Guidance""" + + def __init__( + self, + context_embedding_max_length: int, + context_embedding_features: int, + use_xattn_time: bool = False, + **kwargs, + ): + super().__init__( + 
context_embedding_features=context_embedding_features, **kwargs + ) + + self.use_xattn_time = use_xattn_time + + if use_xattn_time: + assert exists(context_embedding_features) + self.to_time_embedding = nn.Sequential( + TimePositionalEmbedding( + dim=kwargs["channels"], out_features=context_embedding_features + ), + nn.GELU(), + ) + + context_embedding_max_length += 1 # Add one for time embedding + + self.fixed_embedding = FixedEmbedding( + max_length=context_embedding_max_length, features=context_embedding_features + ) + + def forward( # type: ignore + self, + x: Tensor, + time: Tensor, + *, + embedding: Tensor, + embedding_mask: Optional[Tensor] = None, + embedding_scale: float = 1.0, + embedding_mask_proba: float = 0.0, + batch_cfg: bool = False, + rescale_cfg: bool = False, + scale_phi: float = 0.4, + negative_embedding: Optional[Tensor] = None, + negative_embedding_mask: Optional[Tensor] = None, + **kwargs, + ) -> Tensor: + b, device = embedding.shape[0], embedding.device + + if self.use_xattn_time: + embedding = torch.cat([embedding, self.to_time_embedding(time).unsqueeze(1)], dim=1) + + if embedding_mask is not None: + embedding_mask = torch.cat([embedding_mask, torch.ones((b, 1), device=device)], dim=1) + + fixed_embedding = self.fixed_embedding(embedding) + + if embedding_mask_proba > 0.0: + # Randomly mask embedding + batch_mask = rand_bool( + shape=(b, 1, 1), proba=embedding_mask_proba, device=device + ) + embedding = torch.where(batch_mask, fixed_embedding, embedding) + + if embedding_scale != 1.0: + if batch_cfg: + batch_x = torch.cat([x, x], dim=0) + batch_time = torch.cat([time, time], dim=0) + + if negative_embedding is not None: + if negative_embedding_mask is not None: + negative_embedding_mask = negative_embedding_mask.to(torch.bool).unsqueeze(2) + + negative_embedding = torch.where(negative_embedding_mask, negative_embedding, fixed_embedding) + + batch_embed = torch.cat([embedding, negative_embedding], dim=0) + + else: + batch_embed = torch.cat([embedding, fixed_embedding], dim=0) + + batch_mask = None + if embedding_mask is not None: + batch_mask = torch.cat([embedding_mask, embedding_mask], dim=0) + + batch_features = None + features = kwargs.pop("features", None) + if self.use_context_features: + batch_features = torch.cat([features, features], dim=0) + + batch_channels = None + channels_list = kwargs.pop("channels_list", None) + if self.use_context_channels: + batch_channels = [] + for channels in channels_list: + batch_channels += [torch.cat([channels, channels], dim=0)] + + # Compute both normal and fixed embedding outputs + batch_out = super().forward(batch_x, batch_time, embedding=batch_embed, embedding_mask=batch_mask, features=batch_features, channels_list=batch_channels, **kwargs) + out, out_masked = batch_out.chunk(2, dim=0) + + else: + # Compute both normal and fixed embedding outputs + out = super().forward(x, time, embedding=embedding, embedding_mask=embedding_mask, **kwargs) + out_masked = super().forward(x, time, embedding=fixed_embedding, embedding_mask=embedding_mask, **kwargs) + + out_cfg = out_masked + (out - out_masked) * embedding_scale + + if rescale_cfg: + + out_std = out.std(dim=1, keepdim=True) + out_cfg_std = out_cfg.std(dim=1, keepdim=True) + + return scale_phi * (out_cfg * (out_std/out_cfg_std)) + (1-scale_phi) * out_cfg + + else: + + return out_cfg + + else: + return super().forward(x, time, embedding=embedding, embedding_mask=embedding_mask, **kwargs) + + +class UNetNCCA1d(UNet1d): + + """UNet1d with Noise Channel Conditioning 
Augmentation""" + + def __init__(self, context_features: int, **kwargs): + super().__init__(context_features=context_features, **kwargs) + self.embedder = NumberEmbedder(features=context_features) + + def expand(self, x: Any, shape: Tuple[int, ...]) -> Tensor: + x = x if torch.is_tensor(x) else torch.tensor(x) + return x.expand(shape) + + def forward( # type: ignore + self, + x: Tensor, + time: Tensor, + *, + channels_list: Sequence[Tensor], + channels_augmentation: Union[ + bool, Sequence[bool], Sequence[Sequence[bool]], Tensor + ] = False, + channels_scale: Union[ + float, Sequence[float], Sequence[Sequence[float]], Tensor + ] = 0, + **kwargs, + ) -> Tensor: + b, n = x.shape[0], len(channels_list) + channels_augmentation = self.expand(channels_augmentation, shape=(b, n)).to(x) + channels_scale = self.expand(channels_scale, shape=(b, n)).to(x) + + # Augmentation (for each channel list item) + for i in range(n): + scale = channels_scale[:, i] * channels_augmentation[:, i] + scale = rearrange(scale, "b -> b 1 1") + item = channels_list[i] + channels_list[i] = torch.randn_like(item) * scale + item * (1 - scale) # type: ignore # noqa + + # Scale embedding (sum reduction if more than one channel list item) + channels_scale_emb = self.embedder(channels_scale) + channels_scale_emb = reduce(channels_scale_emb, "b n d -> b d", "sum") + + return super().forward( + x=x, + time=time, + channels_list=channels_list, + features=channels_scale_emb, + **kwargs, + ) + + +class UNetAll1d(UNetCFG1d, UNetNCCA1d): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, *args, **kwargs): # type: ignore + return UNetCFG1d.forward(self, *args, **kwargs) + + +def XUNet1d(type: str = "base", **kwargs) -> UNet1d: + if type == "base": + return UNet1d(**kwargs) + elif type == "all": + return UNetAll1d(**kwargs) + elif type == "cfg": + return UNetCFG1d(**kwargs) + elif type == "ncca": + return UNetNCCA1d(**kwargs) + else: + raise ValueError(f"Unknown XUNet1d type: {type}") + +class NumberEmbedder(nn.Module): + def __init__( + self, + features: int, + dim: int = 256, + ): + super().__init__() + self.features = features + self.embedding = TimePositionalEmbedding(dim=dim, out_features=features) + + def forward(self, x: Union[List[float], Tensor]) -> Tensor: + if not torch.is_tensor(x): + device = next(self.embedding.parameters()).device + x = torch.tensor(x, device=device) + assert isinstance(x, Tensor) + shape = x.shape + x = rearrange(x, "... 
-> (...)") + embedding = self.embedding(x) + x = embedding.view(*shape, self.features) + return x # type: ignore + + +""" +Audio Transforms +""" + + +class STFT(nn.Module): + """Helper for torch stft and istft""" + + def __init__( + self, + num_fft: int = 1023, + hop_length: int = 256, + window_length: Optional[int] = None, + length: Optional[int] = None, + use_complex: bool = False, + ): + super().__init__() + self.num_fft = num_fft + self.hop_length = default(hop_length, floor(num_fft // 4)) + self.window_length = default(window_length, num_fft) + self.length = length + self.register_buffer("window", torch.hann_window(self.window_length)) + self.use_complex = use_complex + + def encode(self, wave: Tensor) -> Tuple[Tensor, Tensor]: + b = wave.shape[0] + wave = rearrange(wave, "b c t -> (b c) t") + + stft = torch.stft( + wave, + n_fft=self.num_fft, + hop_length=self.hop_length, + win_length=self.window_length, + window=self.window, # type: ignore + return_complex=True, + normalized=True, + ) + + if self.use_complex: + # Returns real and imaginary + stft_a, stft_b = stft.real, stft.imag + else: + # Returns magnitude and phase matrices + magnitude, phase = torch.abs(stft), torch.angle(stft) + stft_a, stft_b = magnitude, phase + + return rearrange_many((stft_a, stft_b), "(b c) f l -> b c f l", b=b) + + def decode(self, stft_a: Tensor, stft_b: Tensor) -> Tensor: + b, l = stft_a.shape[0], stft_a.shape[-1] # noqa + length = closest_power_2(l * self.hop_length) + + stft_a, stft_b = rearrange_many((stft_a, stft_b), "b c f l -> (b c) f l") + + if self.use_complex: + real, imag = stft_a, stft_b + else: + magnitude, phase = stft_a, stft_b + real, imag = magnitude * torch.cos(phase), magnitude * torch.sin(phase) + + stft = torch.stack([real, imag], dim=-1) + + wave = torch.istft( + stft, + n_fft=self.num_fft, + hop_length=self.hop_length, + win_length=self.window_length, + window=self.window, # type: ignore + length=default(self.length, length), + normalized=True, + ) + + return rearrange(wave, "(b c) t -> b c t", b=b) + + def encode1d( + self, wave: Tensor, stacked: bool = True + ) -> Union[Tensor, Tuple[Tensor, Tensor]]: + stft_a, stft_b = self.encode(wave) + stft_a, stft_b = rearrange_many((stft_a, stft_b), "b c f l -> b (c f) l") + return torch.cat((stft_a, stft_b), dim=1) if stacked else (stft_a, stft_b) + + def decode1d(self, stft_pair: Tensor) -> Tensor: + f = self.num_fft // 2 + 1 + stft_a, stft_b = stft_pair.chunk(chunks=2, dim=1) + stft_a, stft_b = rearrange_many((stft_a, stft_b), "b (c f) l -> b c f l", f=f) + return self.decode(stft_a, stft_b) diff --git a/cosyvoice/flow/stable/blocks.py b/cosyvoice/flow/stable/blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..3c827fd2441e643717d123847236d3d6c003ef4f --- /dev/null +++ b/cosyvoice/flow/stable/blocks.py @@ -0,0 +1,339 @@ +from functools import reduce +import math +import numpy as np +import torch +from torch import nn +from torch.nn import functional as F + +from torch.backends.cuda import sdp_kernel +from packaging import version + +from dac.nn.layers import Snake1d + +class ResidualBlock(nn.Module): + def __init__(self, main, skip=None): + super().__init__() + self.main = nn.Sequential(*main) + self.skip = skip if skip else nn.Identity() + + def forward(self, input): + return self.main(input) + self.skip(input) + +class ResConvBlock(ResidualBlock): + def __init__(self, c_in, c_mid, c_out, is_last=False, kernel_size=5, conv_bias=True, use_snake=False): + skip = None if c_in == c_out else nn.Conv1d(c_in, c_out, 
1, bias=False) + super().__init__([ + nn.Conv1d(c_in, c_mid, kernel_size, padding=kernel_size//2, bias=conv_bias), + nn.GroupNorm(1, c_mid), + Snake1d(c_mid) if use_snake else nn.GELU(), + nn.Conv1d(c_mid, c_out, kernel_size, padding=kernel_size//2, bias=conv_bias), + nn.GroupNorm(1, c_out) if not is_last else nn.Identity(), + (Snake1d(c_out) if use_snake else nn.GELU()) if not is_last else nn.Identity(), + ], skip) + +class SelfAttention1d(nn.Module): + def __init__(self, c_in, n_head=1, dropout_rate=0.): + super().__init__() + assert c_in % n_head == 0 + self.norm = nn.GroupNorm(1, c_in) + self.n_head = n_head + self.qkv_proj = nn.Conv1d(c_in, c_in * 3, 1) + self.out_proj = nn.Conv1d(c_in, c_in, 1) + self.dropout = nn.Dropout(dropout_rate, inplace=True) + + self.use_flash = torch.cuda.is_available() and version.parse(torch.__version__) >= version.parse('2.0.0') + + if not self.use_flash: + return + + device_properties = torch.cuda.get_device_properties(torch.device('cuda')) + + if device_properties.major == 8 and device_properties.minor == 0: + # Use flash attention for A100 GPUs + self.sdp_kernel_config = (True, False, False) + else: + # Don't use flash attention for other GPUs + self.sdp_kernel_config = (False, True, True) + + def forward(self, input): + n, c, s = input.shape + qkv = self.qkv_proj(self.norm(input)) + qkv = qkv.view( + [n, self.n_head * 3, c // self.n_head, s]).transpose(2, 3) + q, k, v = qkv.chunk(3, dim=1) + scale = k.shape[3]**-0.25 + + if self.use_flash: + with sdp_kernel(*self.sdp_kernel_config): + y = F.scaled_dot_product_attention(q, k, v, is_causal=False).contiguous().view([n, c, s]) + else: + att = ((q * scale) @ (k.transpose(2, 3) * scale)).softmax(3) + y = (att @ v).transpose(2, 3).contiguous().view([n, c, s]) + + + return input + self.dropout(self.out_proj(y)) + +class SkipBlock(nn.Module): + def __init__(self, *main): + super().__init__() + self.main = nn.Sequential(*main) + + def forward(self, input): + return torch.cat([self.main(input), input], dim=1) + +class FourierFeatures(nn.Module): + def __init__(self, in_features, out_features, std=1.): + super().__init__() + assert out_features % 2 == 0 + self.weight = nn.Parameter(torch.randn( + [out_features // 2, in_features]) * std) + + def forward(self, input): + f = 2 * math.pi * input @ self.weight.T + return torch.cat([f.cos(), f.sin()], dim=-1) + +def expand_to_planes(input, shape): + return input[..., None].repeat([1, 1, shape[2]]) + +_kernels = { + 'linear': + [1 / 8, 3 / 8, 3 / 8, 1 / 8], + 'cubic': + [-0.01171875, -0.03515625, 0.11328125, 0.43359375, + 0.43359375, 0.11328125, -0.03515625, -0.01171875], + 'lanczos3': + [0.003689131001010537, 0.015056144446134567, -0.03399861603975296, + -0.066637322306633, 0.13550527393817902, 0.44638532400131226, + 0.44638532400131226, 0.13550527393817902, -0.066637322306633, + -0.03399861603975296, 0.015056144446134567, 0.003689131001010537] +} + +class Downsample1d(nn.Module): + def __init__(self, kernel='linear', pad_mode='reflect', channels_last=False): + super().__init__() + self.pad_mode = pad_mode + kernel_1d = torch.tensor(_kernels[kernel]) + self.pad = kernel_1d.shape[0] // 2 - 1 + self.register_buffer('kernel', kernel_1d) + self.channels_last = channels_last + + def forward(self, x): + if self.channels_last: + x = x.permute(0, 2, 1) + x = F.pad(x, (self.pad,) * 2, self.pad_mode) + weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0]]) + indices = torch.arange(x.shape[1], device=x.device) + weight[indices, indices] = self.kernel.to(weight) + x 
= F.conv1d(x, weight, stride=2) + if self.channels_last: + x = x.permute(0, 2, 1) + return x + + +class Upsample1d(nn.Module): + def __init__(self, kernel='linear', pad_mode='reflect', channels_last=False): + super().__init__() + self.pad_mode = pad_mode + kernel_1d = torch.tensor(_kernels[kernel]) * 2 + self.pad = kernel_1d.shape[0] // 2 - 1 + self.register_buffer('kernel', kernel_1d) + self.channels_last = channels_last + + def forward(self, x): + if self.channels_last: + x = x.permute(0, 2, 1) + x = F.pad(x, ((self.pad + 1) // 2,) * 2, self.pad_mode) + weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0]]) + indices = torch.arange(x.shape[1], device=x.device) + weight[indices, indices] = self.kernel.to(weight) + x = F.conv_transpose1d(x, weight, stride=2, padding=self.pad * 2 + 1) + if self.channels_last: + x = x.permute(0, 2, 1) + return x + +def Downsample1d_2( + in_channels: int, out_channels: int, factor: int, kernel_multiplier: int = 2 +) -> nn.Module: + assert kernel_multiplier % 2 == 0, "Kernel multiplier must be even" + + return nn.Conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=factor * kernel_multiplier + 1, + stride=factor, + padding=factor * (kernel_multiplier // 2), + ) + + +def Upsample1d_2( + in_channels: int, out_channels: int, factor: int, use_nearest: bool = False +) -> nn.Module: + + if factor == 1: + return nn.Conv1d( + in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1 + ) + + if use_nearest: + return nn.Sequential( + nn.Upsample(scale_factor=factor, mode="nearest"), + nn.Conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + padding=1, + ), + ) + else: + return nn.ConvTranspose1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=factor * 2, + stride=factor, + padding=factor // 2 + factor % 2, + output_padding=factor % 2, + ) + +def zero_init(layer): + nn.init.zeros_(layer.weight) + if layer.bias is not None: + nn.init.zeros_(layer.bias) + return layer + +def rms_norm(x, scale, eps): + dtype = reduce(torch.promote_types, (x.dtype, scale.dtype, torch.float32)) + mean_sq = torch.mean(x.to(dtype)**2, dim=-1, keepdim=True) + scale = scale.to(dtype) * torch.rsqrt(mean_sq + eps) + return x * scale.to(x.dtype) + +#rms_norm = torch.compile(rms_norm) + +class AdaRMSNorm(nn.Module): + def __init__(self, features, cond_features, eps=1e-6): + super().__init__() + self.eps = eps + self.linear = zero_init(nn.Linear(cond_features, features, bias=False)) + + def extra_repr(self): + return f"eps={self.eps}," + + def forward(self, x, cond): + return rms_norm(x, self.linear(cond)[:, None, :] + 1, self.eps) + +def normalize(x, eps=1e-4): + dim = list(range(1, x.ndim)) + n = torch.linalg.vector_norm(x, dim=dim, keepdim=True) + alpha = np.sqrt(n.numel() / x.numel()) + return x / torch.add(eps, n, alpha=alpha) + +class ForcedWNConv1d(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size=1): + super().__init__() + self.weight = nn.Parameter(torch.randn([out_channels, in_channels, kernel_size])) + + def forward(self, x): + if self.training: + with torch.no_grad(): + self.weight.copy_(normalize(self.weight)) + + fan_in = self.weight[0].numel() + + w = normalize(self.weight) / math.sqrt(fan_in) + + return F.conv1d(x, w, padding='same') + +# Kernels + +use_compile = True + +def compile(function, *args, **kwargs): + if not use_compile: + return function + try: + return torch.compile(function, *args, **kwargs) + except RuntimeError: + return function + + 
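+ # Hypothetical usage sketch of the compile() helper above: it returns the
+ # function wrapped by torch.compile when use_compile is True and silently
+ # falls back to the plain eager function if torch.compile raises RuntimeError
+ # at wrap time, so call sites never need to branch on compiler availability:
+ #
+ #     def _scaled_sum(x, y, alpha=0.5):
+ #         return x + alpha * y
+ #
+ #     _scaled_sum_fast = compile(_scaled_sum)  # compiled if possible, else eager
+ #     out = _scaled_sum_fast(torch.randn(4, 8), torch.randn(4, 8))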
+@compile +def linear_geglu(x, weight, bias=None): + x = x @ weight.mT + if bias is not None: + x = x + bias + x, gate = x.chunk(2, dim=-1) + return x * F.gelu(gate) + + +@compile +def rms_norm(x, scale, eps): + dtype = reduce(torch.promote_types, (x.dtype, scale.dtype, torch.float32)) + mean_sq = torch.mean(x.to(dtype)**2, dim=-1, keepdim=True) + scale = scale.to(dtype) * torch.rsqrt(mean_sq + eps) + return x * scale.to(x.dtype) + +# Layers + +class LinearGEGLU(nn.Linear): + def __init__(self, in_features, out_features, bias=True): + super().__init__(in_features, out_features * 2, bias=bias) + self.out_features = out_features + + def forward(self, x): + return linear_geglu(x, self.weight, self.bias) + + +class RMSNorm(nn.Module): + def __init__(self, shape, fix_scale = False, eps=1e-6): + super().__init__() + self.eps = eps + + if fix_scale: + self.register_buffer("scale", torch.ones(shape)) + else: + self.scale = nn.Parameter(torch.ones(shape)) + + def extra_repr(self): + return f"shape={tuple(self.scale.shape)}, eps={self.eps}" + + def forward(self, x): + return rms_norm(x, self.scale, self.eps) + +def snake_beta(x, alpha, beta): + return x + (1.0 / (beta + 0.000000001)) * pow(torch.sin(x * alpha), 2) + +# try: +# snake_beta = torch.compile(snake_beta) +# except RuntimeError: +# pass + +# Adapted from https://github.com/NVIDIA/BigVGAN/blob/main/activations.py under MIT license +# License available in LICENSES/LICENSE_NVIDIA.txt +class SnakeBeta(nn.Module): + + def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True): + super(SnakeBeta, self).__init__() + self.in_features = in_features + + # initialize alpha + self.alpha_logscale = alpha_logscale + if self.alpha_logscale: # log scale alphas initialized to zeros + self.alpha = nn.Parameter(torch.zeros(in_features) * alpha) + self.beta = nn.Parameter(torch.zeros(in_features) * alpha) + else: # linear scale alphas initialized to ones + self.alpha = nn.Parameter(torch.ones(in_features) * alpha) + self.beta = nn.Parameter(torch.ones(in_features) * alpha) + + self.alpha.requires_grad = alpha_trainable + self.beta.requires_grad = alpha_trainable + + self.no_div_by_zero = 0.000000001 + + def forward(self, x): + alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T] + beta = self.beta.unsqueeze(0).unsqueeze(-1) + if self.alpha_logscale: + alpha = torch.exp(alpha) + beta = torch.exp(beta) + x = snake_beta(x, alpha, beta) + + return x \ No newline at end of file diff --git a/cosyvoice/flow/stable/dit.py b/cosyvoice/flow/stable/dit.py new file mode 100644 index 0000000000000000000000000000000000000000..dcd5efa0f9430ca550b9845b2e8c13ae32534c2c --- /dev/null +++ b/cosyvoice/flow/stable/dit.py @@ -0,0 +1,415 @@ +import typing as tp + +import torch + +from einops import rearrange +from torch import nn +from torch.nn import functional as F +from x_transformers import ContinuousTransformerWrapper, Encoder + +from .blocks import FourierFeatures +from .transformer import ContinuousTransformer +from .transformer_use_mask import ContinuousTransformer as ContinuousTransformer_mask + + +class DiffusionTransformer(nn.Module): + def __init__(self, + io_channels=32, + patch_size=1, + embed_dim=768, + cond_token_dim=0, + project_cond_tokens=True, + global_cond_dim=0, + project_global_cond=True, + input_concat_dim=0, + prepend_cond_dim=0, + depth=12, + num_heads=8, + transformer_type: tp.Literal["x-transformers", "continuous_transformer"] = "x-transformers", + global_cond_type: tp.Literal["prepend", "adaLN"] = 
"prepend", + **kwargs): + + super().__init__() + + self.cond_token_dim = cond_token_dim + + # Timestep embeddings + timestep_features_dim = 256 + + self.timestep_features = FourierFeatures(1, timestep_features_dim) + + self.to_timestep_embed = nn.Sequential( + nn.Linear(timestep_features_dim, embed_dim, bias=True), + nn.SiLU(), + nn.Linear(embed_dim, embed_dim, bias=True), + ) + + if cond_token_dim > 0: + # Conditioning tokens + + cond_embed_dim = cond_token_dim if not project_cond_tokens else embed_dim + self.to_cond_embed = nn.Sequential( + nn.Linear(cond_token_dim, cond_embed_dim, bias=False), + nn.SiLU(), + nn.Linear(cond_embed_dim, cond_embed_dim, bias=False) + ) + else: + cond_embed_dim = 0 + self.to_cond_embed = nn.Identity() + + if global_cond_dim > 0: + # Global conditioning + global_embed_dim = global_cond_dim if not project_global_cond else embed_dim + self.to_global_embed = nn.Sequential( + nn.Linear(global_cond_dim, global_embed_dim, bias=False), + nn.SiLU(), + nn.Linear(global_embed_dim, global_embed_dim, bias=False) + ) + + if prepend_cond_dim > 0: + # Prepend conditioning + self.to_prepend_embed = nn.Sequential( + nn.Linear(prepend_cond_dim, embed_dim, bias=False), + nn.SiLU(), + nn.Linear(embed_dim, embed_dim, bias=False) + ) + + self.input_concat_dim = input_concat_dim + + dim_in = io_channels + self.input_concat_dim + + self.patch_size = patch_size + + # Transformer + + self.transformer_type = transformer_type + + self.global_cond_type = global_cond_type + + if self.transformer_type == "x-transformers": + self.transformer = ContinuousTransformerWrapper( + dim_in=dim_in * patch_size, + dim_out=io_channels * patch_size, + max_seq_len=0, # Not relevant without absolute positional embeds + attn_layers=Encoder( + dim=embed_dim, + depth=depth, + heads=num_heads, + attn_flash=True, + cross_attend=cond_token_dim > 0, + dim_context=None if cond_embed_dim == 0 else cond_embed_dim, + zero_init_branch_output=True, + use_abs_pos_emb=False, + rotary_pos_emb=True, + ff_swish=True, + ff_glu=True, + **kwargs + ) + ) + + elif self.transformer_type == "continuous_transformer": + + global_dim = None + + if self.global_cond_type == "adaLN": + # The global conditioning is projected to the embed_dim already at this point + global_dim = embed_dim + + self.transformer = ContinuousTransformer( + dim=embed_dim, + depth=depth, + dim_heads=embed_dim // num_heads, + dim_in=dim_in * patch_size, + dim_out=io_channels * patch_size, + cross_attend=cond_token_dim > 0, + cond_token_dim=cond_embed_dim, + global_cond_dim=global_dim, + **kwargs + ) + elif self.transformer_type == "continuous_transformer_with_mask": + + global_dim = None + + if self.global_cond_type == "adaLN": + # The global conditioning is projected to the embed_dim already at this point + global_dim = embed_dim + + self.transformer = ContinuousTransformer_mask( + dim=embed_dim, + depth=depth, + dim_heads=embed_dim // num_heads, + dim_in=dim_in * patch_size, + dim_out=io_channels * patch_size, + cross_attend=cond_token_dim > 0, + cond_token_dim=cond_embed_dim, + global_cond_dim=global_dim, + **kwargs + ) + + else: + raise ValueError(f"Unknown transformer type: {self.transformer_type}") + + self.preprocess_conv = nn.Conv1d(dim_in, dim_in, 1, bias=False) + nn.init.zeros_(self.preprocess_conv.weight) + self.postprocess_conv = nn.Conv1d(io_channels, io_channels, 1, bias=False) + nn.init.zeros_(self.postprocess_conv.weight) + + def _forward( + self, + x, + t, + mask=None, + cross_attn_cond=None, + cross_attn_cond_mask=None, + 
input_concat_cond=None, + global_embed=None, + prepend_cond=None, + prepend_cond_mask=None, + return_info=False, + **kwargs): + ### 1. Needs to be rewritten to handle conditioning inputs of different lengths + if cross_attn_cond is not None: + cross_attn_cond = self.to_cond_embed(cross_attn_cond) + + if global_embed is not None: + # Project the global conditioning to the embedding dimension + global_embed = self.to_global_embed(global_embed) + + prepend_inputs = None + prepend_mask = None + prepend_length = 0 + if prepend_cond is not None: + # Project the prepend conditioning to the embedding dimension + prepend_cond = self.to_prepend_embed(prepend_cond) + + prepend_inputs = prepend_cond + if prepend_cond_mask is not None: + prepend_mask = prepend_cond_mask + + if input_concat_cond is not None: + + # Interpolate input_concat_cond to the same length as x + if input_concat_cond.shape[2] != x.shape[2]: + input_concat_cond = F.interpolate(input_concat_cond, (x.shape[2],), mode='nearest') + + x = torch.cat([x, input_concat_cond], dim=1) + + # Get the batch of timestep embeddings + try: + timestep_embed = self.to_timestep_embed(self.timestep_features(t[:, None])) # (b, embed_dim) + except Exception as e: + print("t.shape:", t.shape, "x.shape", x.shape) + print("t:", t) + raise e + + # Timestep embedding is considered a global embedding. Add to the global conditioning if it exists + if global_embed is not None: + global_embed = global_embed + timestep_embed + else: + global_embed = timestep_embed + + # Add the global_embed to the prepend inputs if there is no global conditioning support in the transformer + if self.global_cond_type == "prepend": + if prepend_inputs is None: + # Prepend inputs are just the global embed, and the mask is all ones + prepend_inputs = global_embed.unsqueeze(1) + prepend_mask = torch.ones((x.shape[0], 1), device=x.device, dtype=torch.bool) + else: + # Prepend inputs are the prepend conditioning + the global embed + prepend_inputs = torch.cat([prepend_inputs, global_embed.unsqueeze(1)], dim=1) + prepend_mask = torch.cat([prepend_mask, torch.ones((x.shape[0], 1), device=x.device, dtype=torch.bool)], + dim=1) + + prepend_length = prepend_inputs.shape[1] + + x = self.preprocess_conv(x) + x + + x = rearrange(x, "b c t -> b t c") + + extra_args = {} + + if self.global_cond_type == "adaLN": + extra_args["global_cond"] = global_embed + + if self.patch_size > 1: + x = rearrange(x, "b (t p) c -> b t (c p)", p=self.patch_size) + + if self.transformer_type == "x-transformers": + output = self.transformer(x, prepend_embeds=prepend_inputs, context=cross_attn_cond, + context_mask=cross_attn_cond_mask, mask=mask, prepend_mask=prepend_mask, + **extra_args, **kwargs) + elif self.transformer_type in ["continuous_transformer","continuous_transformer_with_mask"] : + output = self.transformer(x, prepend_embeds=prepend_inputs, context=cross_attn_cond, + context_mask=cross_attn_cond_mask, mask=mask, prepend_mask=prepend_mask, + return_info=return_info, **extra_args, **kwargs) + + if return_info: + output, info = output + elif self.transformer_type == "mm_transformer": + output = self.transformer(x, context=cross_attn_cond, mask=mask, context_mask=cross_attn_cond_mask, + **extra_args, **kwargs) + + output = rearrange(output, "b t c -> b c t")[:, :, prepend_length:] + + if self.patch_size > 1: + output = rearrange(output, "b (c p) t -> b c (t p)", p=self.patch_size) + + output = self.postprocess_conv(output) + output + + if return_info: + return output, info + + return output + + def forward( + self, + x, + t, + cross_attn_cond=None, + 
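# Classifier-free guidance arguments consumed below:
+ #   negative_cross_attn_cond/_mask - optional negative conditioning used in place of the null (zero) embedding
+ #   cfg_scale        - guidance strength; 1.0 bypasses the duplicated-batch CFG path
+ #   cfg_dropout_prob - probability of replacing the conditioning with zeros (training-time CFG dropout)
+ #   scale_phi        - CFG rescale factor blending the std-rescaled and raw guided outputs + 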
cross_attn_cond_mask=None, + negative_cross_attn_cond=None, + negative_cross_attn_mask=None, + input_concat_cond=None, + global_embed=None, + negative_global_embed=None, + prepend_cond=None, + prepend_cond_mask=None, + cfg_scale=1.0, + cfg_dropout_prob=0.0, + causal=False, + scale_phi=0.0, + mask=None, + return_info=False, + **kwargs): + + assert causal == False, "Causal mode is not supported for DiffusionTransformer" + + if cross_attn_cond_mask is not None: + cross_attn_cond_mask = cross_attn_cond_mask.bool() + + cross_attn_cond_mask = None # Temporarily disabling conditioning masks due to kernel issue for flash attention + + if prepend_cond_mask is not None: + prepend_cond_mask = prepend_cond_mask.bool() + + # CFG dropout + if cfg_dropout_prob > 0.0: + if cross_attn_cond is not None: + null_embed = torch.zeros_like(cross_attn_cond, device=cross_attn_cond.device) + dropout_mask = torch.bernoulli( + torch.full((cross_attn_cond.shape[0], 1, 1), cfg_dropout_prob, device=cross_attn_cond.device)).to( + torch.bool) + cross_attn_cond = torch.where(dropout_mask, null_embed, cross_attn_cond) + + if prepend_cond is not None: + null_embed = torch.zeros_like(prepend_cond, device=prepend_cond.device) + dropout_mask = torch.bernoulli( + torch.full((prepend_cond.shape[0], 1, 1), cfg_dropout_prob, device=prepend_cond.device)).to( + torch.bool) + prepend_cond = torch.where(dropout_mask, null_embed, prepend_cond) + + if cfg_scale != 1.0 and (cross_attn_cond is not None or prepend_cond is not None): + # Classifier-free guidance + # Concatenate conditioned and unconditioned inputs on the batch dimension + batch_inputs = torch.cat([x, x], dim=0) + batch_timestep = torch.cat([t, t], dim=0) + + if global_embed is not None: + batch_global_cond = torch.cat([global_embed, global_embed], dim=0) + else: + batch_global_cond = None + + if input_concat_cond is not None: + batch_input_concat_cond = torch.cat([input_concat_cond, input_concat_cond], dim=0) + else: + batch_input_concat_cond = None + + batch_cond = None + batch_cond_masks = None + + # Handle CFG for cross-attention conditioning + if cross_attn_cond is not None: + + null_embed = torch.zeros_like(cross_attn_cond, device=cross_attn_cond.device) + + # For negative cross-attention conditioning, replace the null embed with the negative cross-attention conditioning + if negative_cross_attn_cond is not None: + + # If there's a negative cross-attention mask, set the masked tokens to the null embed + if negative_cross_attn_mask is not None: + negative_cross_attn_mask = negative_cross_attn_mask.to(torch.bool).unsqueeze(2) + + negative_cross_attn_cond = torch.where(negative_cross_attn_mask, negative_cross_attn_cond, + null_embed) + + batch_cond = torch.cat([cross_attn_cond, negative_cross_attn_cond], dim=0) + + else: + batch_cond = torch.cat([cross_attn_cond, null_embed], dim=0) + + if cross_attn_cond_mask is not None: + batch_cond_masks = torch.cat([cross_attn_cond_mask, cross_attn_cond_mask], dim=0) + + batch_prepend_cond = None + batch_prepend_cond_mask = None + + if prepend_cond is not None: + + null_embed = torch.zeros_like(prepend_cond, device=prepend_cond.device) + + batch_prepend_cond = torch.cat([prepend_cond, null_embed], dim=0) + + if prepend_cond_mask is not None: + batch_prepend_cond_mask = torch.cat([prepend_cond_mask, prepend_cond_mask], dim=0) + + if mask is not None: + batch_masks = torch.cat([mask, mask], dim=0) + else: + batch_masks = None + + batch_output = self._forward( + batch_inputs, + batch_timestep, + cross_attn_cond=batch_cond, + 
cross_attn_cond_mask=batch_cond_masks, + mask=batch_masks, + input_concat_cond=batch_input_concat_cond, + global_embed=batch_global_cond, + prepend_cond=batch_prepend_cond, + prepend_cond_mask=batch_prepend_cond_mask, + return_info=return_info, + **kwargs) + + if return_info: + batch_output, info = batch_output + + cond_output, uncond_output = torch.chunk(batch_output, 2, dim=0) + cfg_output = uncond_output + (cond_output - uncond_output) * cfg_scale + + # CFG Rescale + if scale_phi != 0.0: + cond_out_std = cond_output.std(dim=1, keepdim=True) + out_cfg_std = cfg_output.std(dim=1, keepdim=True) + output = scale_phi * (cfg_output * (cond_out_std / out_cfg_std)) + (1 - scale_phi) * cfg_output + else: + output = cfg_output + + if return_info: + return output, info + + return output + + else: + return self._forward( + x, + t, + cross_attn_cond=cross_attn_cond, + cross_attn_cond_mask=cross_attn_cond_mask, + input_concat_cond=input_concat_cond, + global_embed=global_embed, + prepend_cond=prepend_cond, + prepend_cond_mask=prepend_cond_mask, + mask=mask, + return_info=return_info, + **kwargs + ) diff --git a/cosyvoice/flow/stable/dit_v2.py b/cosyvoice/flow/stable/dit_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..b4baad9e5a91561a1c994f2000d75618fa97ba60 --- /dev/null +++ b/cosyvoice/flow/stable/dit_v2.py @@ -0,0 +1,307 @@ +import typing as tp + +import torch + +from einops import rearrange +from torch import nn +from torch.nn import functional as F +from x_transformers import ContinuousTransformerWrapper, Encoder + +from .blocks import FourierFeatures +from .transformer import ContinuousTransformer +from model.stable import transformer_use_mask + + +class DiffusionTransformerV2(nn.Module): + def __init__(self, + io_channels=32, + patch_size=1, + embed_dim=768, + cond_token_dim=0, + project_cond_tokens=True, + global_cond_dim=0, + project_global_cond=True, + input_concat_dim=0, + prepend_cond_dim=0, + depth=12, + num_heads=8, + transformer_type: tp.Literal["x-transformers", "continuous_transformer"] = "x-transformers", + global_cond_type: tp.Literal["prepend", "adaLN"] = "prepend", + **kwargs): + + super().__init__() + d_model = embed_dim + n_head = num_heads + n_layers = depth + encoder_layer = torch.nn.TransformerEncoderLayer(batch_first=True, + norm_first=True, + d_model=d_model, + nhead=n_head) + self.transformer = torch.nn.TransformerEncoder(encoder_layer, num_layers=n_layers) + + # ===================================== timestep embedding + timestep_features_dim = 256 + self.timestep_features = FourierFeatures(1, timestep_features_dim) + self.to_timestep_embed = nn.Sequential( + nn.Linear(timestep_features_dim, embed_dim, bias=True), + nn.SiLU(), + nn.Linear(embed_dim, embed_dim, bias=True), + ) + + + def _forward( + self, + Xt_btd, + t, #(1d) + mu_btd, + ): + + timestep_embed = self.to_timestep_embed(self.timestep_features(t[:, None])) # (b, embed_dim) + cated_input = torch.cat([t,mu,x_t]) + + ### 1. 
需要重新写过以适应不同长度的con + if cross_attn_cond is not None: + cross_attn_cond = self.to_cond_embed(cross_attn_cond) + + if global_embed is not None: + # Project the global conditioning to the embedding dimension + global_embed = self.to_global_embed(global_embed) + + prepend_inputs = None + prepend_mask = None + prepend_length = 0 + if prepend_cond is not None: + # Project the prepend conditioning to the embedding dimension + prepend_cond = self.to_prepend_embed(prepend_cond) + + prepend_inputs = prepend_cond + if prepend_cond_mask is not None: + prepend_mask = prepend_cond_mask + + if input_concat_cond is not None: + + # Interpolate input_concat_cond to the same length as x + if input_concat_cond.shape[2] != x.shape[2]: + input_concat_cond = F.interpolate(input_concat_cond, (x.shape[2],), mode='nearest') + + x = torch.cat([x, input_concat_cond], dim=1) + + # Get the batch of timestep embeddings + try: + timestep_embed = self.to_timestep_embed(self.timestep_features(t[:, None])) # (b, embed_dim) + except Exception as e: + print("t.shape:", t.shape, "x.shape", x.shape) + print("t:", t) + raise e + + # Timestep embedding is considered a global embedding. Add to the global conditioning if it exists + if global_embed is not None: + global_embed = global_embed + timestep_embed + else: + global_embed = timestep_embed + + # Add the global_embed to the prepend inputs if there is no global conditioning support in the transformer + if self.global_cond_type == "prepend": + if prepend_inputs is None: + # Prepend inputs are just the global embed, and the mask is all ones + prepend_inputs = global_embed.unsqueeze(1) + prepend_mask = torch.ones((x.shape[0], 1), device=x.device, dtype=torch.bool) + else: + # Prepend inputs are the prepend conditioning + the global embed + prepend_inputs = torch.cat([prepend_inputs, global_embed.unsqueeze(1)], dim=1) + prepend_mask = torch.cat([prepend_mask, torch.ones((x.shape[0], 1), device=x.device, dtype=torch.bool)], + dim=1) + + prepend_length = prepend_inputs.shape[1] + + x = self.preprocess_conv(x) + x + + x = rearrange(x, "b c t -> b t c") + + extra_args = {} + + if self.global_cond_type == "adaLN": + extra_args["global_cond"] = global_embed + + if self.patch_size > 1: + x = rearrange(x, "b (t p) c -> b t (c p)", p=self.patch_size) + + if self.transformer_type == "x-transformers": + output = self.transformer(x, prepend_embeds=prepend_inputs, context=cross_attn_cond, + context_mask=cross_attn_cond_mask, mask=mask, prepend_mask=prepend_mask, + **extra_args, **kwargs) + elif self.transformer_type in ["continuous_transformer", "continuous_transformer_with_mask"]: + output = self.transformer(x, prepend_embeds=prepend_inputs, context=cross_attn_cond, + context_mask=cross_attn_cond_mask, mask=mask, prepend_mask=prepend_mask, + return_info=return_info, **extra_args, **kwargs) + + if return_info: + output, info = output + elif self.transformer_type == "mm_transformer": + output = self.transformer(x, context=cross_attn_cond, mask=mask, context_mask=cross_attn_cond_mask, + **extra_args, **kwargs) + + output = rearrange(output, "b t c -> b c t")[:, :, prepend_length:] + + if self.patch_size > 1: + output = rearrange(output, "b (c p) t -> b c (t p)", p=self.patch_size) + + output = self.postprocess_conv(output) + output + + if return_info: + return output, info + + return output + + def forward( + self, + x, + t, + cross_attn_cond=None, + cross_attn_cond_mask=None, + negative_cross_attn_cond=None, + negative_cross_attn_mask=None, + input_concat_cond=None, + global_embed=None, + 
negative_global_embed=None, + prepend_cond=None, + prepend_cond_mask=None, + cfg_scale=1.0, + cfg_dropout_prob=0.0, + causal=False, + scale_phi=0.0, + mask=None, + return_info=False, + **kwargs): + + assert causal == False, "Causal mode is not supported for DiffusionTransformer" + + if cross_attn_cond_mask is not None: + cross_attn_cond_mask = cross_attn_cond_mask.bool() + + cross_attn_cond_mask = None # Temporarily disabling conditioning masks due to kernel issue for flash attention + + if prepend_cond_mask is not None: + prepend_cond_mask = prepend_cond_mask.bool() + + # CFG dropout + if cfg_dropout_prob > 0.0: + if cross_attn_cond is not None: + null_embed = torch.zeros_like(cross_attn_cond, device=cross_attn_cond.device) + dropout_mask = torch.bernoulli( + torch.full((cross_attn_cond.shape[0], 1, 1), cfg_dropout_prob, device=cross_attn_cond.device)).to( + torch.bool) + cross_attn_cond = torch.where(dropout_mask, null_embed, cross_attn_cond) + + if prepend_cond is not None: + null_embed = torch.zeros_like(prepend_cond, device=prepend_cond.device) + dropout_mask = torch.bernoulli( + torch.full((prepend_cond.shape[0], 1, 1), cfg_dropout_prob, device=prepend_cond.device)).to( + torch.bool) + prepend_cond = torch.where(dropout_mask, null_embed, prepend_cond) + + if cfg_scale != 1.0 and (cross_attn_cond is not None or prepend_cond is not None): + # Classifier-free guidance + # Concatenate conditioned and unconditioned inputs on the batch dimension + batch_inputs = torch.cat([x, x], dim=0) + batch_timestep = torch.cat([t, t], dim=0) + + if global_embed is not None: + batch_global_cond = torch.cat([global_embed, global_embed], dim=0) + else: + batch_global_cond = None + + if input_concat_cond is not None: + batch_input_concat_cond = torch.cat([input_concat_cond, input_concat_cond], dim=0) + else: + batch_input_concat_cond = None + + batch_cond = None + batch_cond_masks = None + + # Handle CFG for cross-attention conditioning + if cross_attn_cond is not None: + + null_embed = torch.zeros_like(cross_attn_cond, device=cross_attn_cond.device) + + # For negative cross-attention conditioning, replace the null embed with the negative cross-attention conditioning + if negative_cross_attn_cond is not None: + + # If there's a negative cross-attention mask, set the masked tokens to the null embed + if negative_cross_attn_mask is not None: + negative_cross_attn_mask = negative_cross_attn_mask.to(torch.bool).unsqueeze(2) + + negative_cross_attn_cond = torch.where(negative_cross_attn_mask, negative_cross_attn_cond, + null_embed) + + batch_cond = torch.cat([cross_attn_cond, negative_cross_attn_cond], dim=0) + + else: + batch_cond = torch.cat([cross_attn_cond, null_embed], dim=0) + + if cross_attn_cond_mask is not None: + batch_cond_masks = torch.cat([cross_attn_cond_mask, cross_attn_cond_mask], dim=0) + + batch_prepend_cond = None + batch_prepend_cond_mask = None + + if prepend_cond is not None: + + null_embed = torch.zeros_like(prepend_cond, device=prepend_cond.device) + + batch_prepend_cond = torch.cat([prepend_cond, null_embed], dim=0) + + if prepend_cond_mask is not None: + batch_prepend_cond_mask = torch.cat([prepend_cond_mask, prepend_cond_mask], dim=0) + + if mask is not None: + batch_masks = torch.cat([mask, mask], dim=0) + else: + batch_masks = None + + batch_output = self._forward( + batch_inputs, + batch_timestep, + cross_attn_cond=batch_cond, + cross_attn_cond_mask=batch_cond_masks, + mask=batch_masks, + input_concat_cond=batch_input_concat_cond, + global_embed=batch_global_cond, + 
prepend_cond=batch_prepend_cond, + prepend_cond_mask=batch_prepend_cond_mask, + return_info=return_info, + **kwargs) + + if return_info: + batch_output, info = batch_output + + cond_output, uncond_output = torch.chunk(batch_output, 2, dim=0) + cfg_output = uncond_output + (cond_output - uncond_output) * cfg_scale + + # CFG Rescale + if scale_phi != 0.0: + cond_out_std = cond_output.std(dim=1, keepdim=True) + out_cfg_std = cfg_output.std(dim=1, keepdim=True) + output = scale_phi * (cfg_output * (cond_out_std / out_cfg_std)) + (1 - scale_phi) * cfg_output + else: + output = cfg_output + + if return_info: + return output, info + + return output + + else: + return self._forward( + x, + t, + cross_attn_cond=cross_attn_cond, + cross_attn_cond_mask=cross_attn_cond_mask, + input_concat_cond=input_concat_cond, + global_embed=global_embed, + prepend_cond=prepend_cond, + prepend_cond_mask=prepend_cond_mask, + mask=mask, + return_info=return_info, + **kwargs + ) diff --git a/cosyvoice/flow/stable/sampling.py b/cosyvoice/flow/stable/sampling.py new file mode 100644 index 0000000000000000000000000000000000000000..2229e5089e3407a367df2d382ae039ca6364c489 --- /dev/null +++ b/cosyvoice/flow/stable/sampling.py @@ -0,0 +1,232 @@ +import torch +import math +from tqdm import trange, tqdm + +import k_diffusion as K + +# Define the noise schedule and sampling loop +def get_alphas_sigmas(t): + """Returns the scaling factors for the clean image (alpha) and for the + noise (sigma), given a timestep.""" + return torch.cos(t * math.pi / 2), torch.sin(t * math.pi / 2) + +def alpha_sigma_to_t(alpha, sigma): + """Returns a timestep, given the scaling factors for the clean image and for + the noise.""" + return torch.atan2(sigma, alpha) / math.pi * 2 + +def t_to_alpha_sigma(t): + """Returns the scaling factors for the clean image and for the noise, given + a timestep.""" + return torch.cos(t * math.pi / 2), torch.sin(t * math.pi / 2) + + +@torch.no_grad() +def sample_discrete_euler(model, x, steps, sigma_max=1, **extra_args): + """Draws samples from a model given starting noise. Euler method""" + + # Make tensor of ones to broadcast the single t values + ts = x.new_ones([x.shape[0]]) + + # Create the noise schedule + t = torch.linspace(sigma_max, 0, steps + 1) + + #alphas, sigmas = 1-t, t + + for t_curr, t_prev in tqdm(zip(t[:-1], t[1:])): + # Broadcast the current timestep to the correct shape + t_curr_tensor = t_curr * torch.ones( + (x.shape[0],), dtype=x.dtype, device=x.device + ) + dt = t_prev - t_curr # we solve backwards in our formulation + x = x + dt * model(x, t_curr_tensor, **extra_args) #.denoise(x, denoiser, t_curr_tensor, cond, uc) + + # If we are on the last timestep, output the denoised image + return x + +@torch.no_grad() +def sample(model, x, steps, eta, **extra_args): + """Draws samples from a model given starting noise. v-diffusion""" + ts = x.new_ones([x.shape[0]]) + + # Create the noise schedule + t = torch.linspace(1, 0, steps + 1)[:-1] + + alphas, sigmas = get_alphas_sigmas(t) + + # The sampling loop + for i in trange(steps): + + # Get the model output (v, the predicted velocity) + with torch.cuda.amp.autocast(): + v = model(x, ts * t[i], **extra_args).float() + + # Predict the noise and the denoised image + pred = x * alphas[i] - v * sigmas[i] + eps = x * sigmas[i] + v * alphas[i] + + # If we are not on the last timestep, compute the noisy image for the + # next timestep. 
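+        # At this point `pred` is the model's estimate of the clean signal x0 and `eps`
+        # its estimate of the noise, recovered from the v-prediction via
+        # pred = alpha*x - sigma*v and eps = sigma*x + alpha*v. The block below performs
+        # a DDIM-style transition to step i + 1: `adjusted_sigma` is chosen so the result
+        # carries exactly the next step's noise level, eta == 0 makes the update fully
+        # deterministic, and eta > 0 re-injects `ddim_sigma` worth of fresh Gaussian
+        # noise (ancestral-style sampling).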
+ if i < steps - 1: + # If eta > 0, adjust the scaling factor for the predicted noise + # downward according to the amount of additional noise to add + ddim_sigma = eta * (sigmas[i + 1]**2 / sigmas[i]**2).sqrt() * \ + (1 - alphas[i]**2 / alphas[i + 1]**2).sqrt() + adjusted_sigma = (sigmas[i + 1]**2 - ddim_sigma**2).sqrt() + + # Recombine the predicted noise and predicted denoised image in the + # correct proportions for the next step + x = pred * alphas[i + 1] + eps * adjusted_sigma + + # Add the correct amount of fresh noise + if eta: + x += torch.randn_like(x) * ddim_sigma + + # If we are on the last timestep, output the denoised image + return pred + +# Soft mask inpainting is just shrinking hard (binary) mask inpainting +# Given a float-valued soft mask (values between 0 and 1), get the binary mask for this particular step +def get_bmask(i, steps, mask): + strength = (i+1)/(steps) + # convert to binary mask + bmask = torch.where(mask<=strength,1,0) + return bmask + +def make_cond_model_fn(model, cond_fn): + def cond_model_fn(x, sigma, **kwargs): + with torch.enable_grad(): + x = x.detach().requires_grad_() + denoised = model(x, sigma, **kwargs) + cond_grad = cond_fn(x, sigma, denoised=denoised, **kwargs).detach() + cond_denoised = denoised.detach() + cond_grad * K.utils.append_dims(sigma**2, x.ndim) + return cond_denoised + return cond_model_fn + +# Uses k-diffusion from https://github.com/crowsonkb/k-diffusion +# init_data is init_audio as latents (if this is latent diffusion) +# For sampling, set both init_data and mask to None +# For variations, set init_data +# For inpainting, set both init_data & mask +def sample_k( + model_fn, + noise, + init_data=None, + mask=None, + steps=100, + sampler_type="dpmpp-2m-sde", + sigma_min=0.5, + sigma_max=50, + rho=1.0, device="cuda", + callback=None, + cond_fn=None, + **extra_args + ): + + denoiser = K.external.VDenoiser(model_fn) + + if cond_fn is not None: + denoiser = make_cond_model_fn(denoiser, cond_fn) + + # Make the list of sigmas. 
Sigma values are scalars related to the amount of noise each denoising step has + sigmas = K.sampling.get_sigmas_polyexponential(steps, sigma_min, sigma_max, rho, device=device) + # Scale the initial noise by sigma + noise = noise * sigmas[0] + + wrapped_callback = callback + + if mask is None and init_data is not None: + # VARIATION (no inpainting) + # set the initial latent to the init_data, and noise it with initial sigma + x = init_data + noise + elif mask is not None and init_data is not None: + # INPAINTING + bmask = get_bmask(0, steps, mask) + # initial noising + input_noised = init_data + noise + # set the initial latent to a mix of init_data and noise, based on step 0's binary mask + x = input_noised * bmask + noise * (1-bmask) + # define the inpainting callback function (Note: side effects, it mutates x) + # See https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py#L596C13-L596C105 + # callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + # This is called immediately after `denoised = model(x, sigmas[i] * s_in, **extra_args)` + def inpainting_callback(args): + i = args["i"] + x = args["x"] + sigma = args["sigma"] + #denoised = args["denoised"] + # noise the init_data input with this step's appropriate amount of noise + input_noised = init_data + torch.randn_like(init_data) * sigma + # shrinking hard mask + bmask = get_bmask(i, steps, mask) + # mix input_noise with x, using binary mask + new_x = input_noised * bmask + x * (1-bmask) + # mutate x + x[:,:,:] = new_x[:,:,:] + # wrap together the inpainting callback and the user-submitted callback. + if callback is None: + wrapped_callback = inpainting_callback + else: + wrapped_callback = lambda args: (inpainting_callback(args), callback(args)) + else: + # SAMPLING + # set the initial latent to noise + x = noise + + + with torch.cuda.amp.autocast(): + if sampler_type == "k-heun": + return K.sampling.sample_heun(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args) + elif sampler_type == "k-lms": + return K.sampling.sample_lms(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args) + elif sampler_type == "k-dpmpp-2s-ancestral": + return K.sampling.sample_dpmpp_2s_ancestral(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args) + elif sampler_type == "k-dpm-2": + return K.sampling.sample_dpm_2(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args) + elif sampler_type == "k-dpm-fast": + return K.sampling.sample_dpm_fast(denoiser, x, sigma_min, sigma_max, steps, disable=False, callback=wrapped_callback, extra_args=extra_args) + elif sampler_type == "k-dpm-adaptive": + return K.sampling.sample_dpm_adaptive(denoiser, x, sigma_min, sigma_max, rtol=0.01, atol=0.01, disable=False, callback=wrapped_callback, extra_args=extra_args) + elif sampler_type == "dpmpp-2m-sde": + return K.sampling.sample_dpmpp_2m_sde(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args) + elif sampler_type == "dpmpp-3m-sde": + return K.sampling.sample_dpmpp_3m_sde(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args) + +# Uses discrete Euler sampling for rectified flow models +# init_data is init_audio as latents (if this is latent diffusion) +# For sampling, set both init_data and mask to None +# For variations, set init_data +# For inpainting, set both init_data & mask +def sample_rf( + model_fn, + noise, + 
init_data=None, + steps=100, + sigma_max=1, + device="cuda", + callback=None, + cond_fn=None, + **extra_args + ): + + if sigma_max > 1: + sigma_max = 1 + + if cond_fn is not None: + denoiser = make_cond_model_fn(denoiser, cond_fn) + + wrapped_callback = callback + + if init_data is not None: + # VARIATION (no inpainting) + # Interpolate the init data and the noise for init audio + x = init_data * (1 - sigma_max) + noise * sigma_max + else: + # SAMPLING + # set the initial latent to noise + x = noise + + with torch.cuda.amp.autocast(): + # TODO: Add callback support + #return sample_discrete_euler(model_fn, x, steps, sigma_max, callback=wrapped_callback, **extra_args) + return sample_discrete_euler(model_fn, x, steps, sigma_max, **extra_args) \ No newline at end of file diff --git a/cosyvoice/flow/stable/stable_diffusion.py b/cosyvoice/flow/stable/stable_diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..732ef506c5e95ada6566176d9ffc3ea99a4cb7ab --- /dev/null +++ b/cosyvoice/flow/stable/stable_diffusion.py @@ -0,0 +1,109 @@ +import torch +from torch.nn import functional as F +from .dit import DiffusionTransformer +from .adp import UNet1d +from .sampling import sample +import math +from model.base import BaseModule +import pdb + +target_length = 1536 + + +def pad_and_create_mask(matrix, target_length): + T = matrix.shape[2] + if T > target_length: + raise ValueError("The third dimension length %s should not exceed %s" % (T, target_length)) + + padding_size = target_length - T + + padded_matrix = F.pad(matrix, (0, padding_size), "constant", 0) + + mask = torch.ones((1, target_length)) + mask[:, T:] = 0 # Set the padding part to 0 + + return padded_matrix.to(matrix.device), mask.to(matrix.device) + + +class Stable_Diffusion(BaseModule): + def __init__(self, io_channels, input_concat_dim=None, embed_dim=768, depth=24, num_heads=24, + project_cond_tokens=False, transformer_type="continuous_transformer"): + super(Stable_Diffusion, self).__init__() + self.diffusion = DiffusionTransformer( + io_channels=io_channels, + input_concat_dim=input_concat_dim, + embed_dim=embed_dim, + # cond_token_dim=target_length, + depth=depth, + num_heads=num_heads, + project_cond_tokens=project_cond_tokens, + transformer_type=transformer_type, + ) + # self.diffusion = UNet1d( + # in_channels=80, + # channels=256, + # resnet_groups=16, + # kernel_multiplier_downsample=2, + # multipliers=[4, 4, 4, 5, 5], + # factors=[1, 2, 2, 4], # 输入长度不一致卷积缩短 + # num_blocks=[2, 2, 2, 2], + # attentions=[1, 3, 3, 3, 3], + # attention_heads=16, + # attention_multiplier=4, + # use_nearest_upsample=False, + # use_skip_scale=True, + # use_context_time=True + # ) + self.rng = torch.quasirandom.SobolEngine(1, scramble=True) + + @torch.no_grad() + def forward(self, mu, mask, n_timesteps): + # pdb.set_trace() + mask = mask.squeeze(1) + noise = torch.randn_like(mu).to(mu.device) + # mu_pad, mu_pad_mask = pad_and_create_mask(mu, target_length) + # extra_args = {"cross_attn_cond": mu, "cross_attn_cond_mask": mask, "mask": mask} + extra_args = {"input_concat_cond": mu, "mask": mask} + fakes = sample(self.diffusion, noise, n_timesteps, 0, **extra_args) + + return fakes + + def compute_loss(self, x0, mask, mu): + + # pdb.set_trace() + t = self.rng.draw(x0.shape[0])[:, 0].to(x0.device) + alphas, sigmas = torch.cos(t * math.pi / 2), torch.sin(t * math.pi / 2) + + alphas = alphas[:, None, None] + sigmas = sigmas[:, None, None] + noise = torch.randn_like(x0) + noised_inputs = x0 * alphas + noise * sigmas + targets = noise * 
alphas - x0 * sigmas + mask = mask.squeeze(1) + # mu_pad, mu_pad_mask = pad_and_create_mask(mu, target_length) + # output = self.diffusion(noised_inputs, t, cross_attn_cond=mu, + # cross_attn_cond_mask=mask, mask=mask, cfg_dropout_prob=0.1) + # pdb.set_trace() + output = self.diffusion(noised_inputs, # [bs, 80, 229] + t, # (bs,) + input_concat_cond=mu, + mask=mask, # [bs, 229] + cfg_dropout_prob=0.1) + + return self.mse_loss(output, targets, mask), output + + def mse_loss(self, output, targets, mask): + + mse_loss = F.mse_loss(output, targets, reduction='none') + + if mask.ndim == 2 and mse_loss.ndim == 3: + mask = mask.unsqueeze(1) + + if mask.shape[1] != mse_loss.shape[1]: + mask = mask.repeat(1, mse_loss.shape[1], 1) + + mse_loss = mse_loss * mask + + mse_loss = mse_loss.mean() + + return mse_loss diff --git a/cosyvoice/flow/stable/stable_diffusion_test.py b/cosyvoice/flow/stable/stable_diffusion_test.py new file mode 100644 index 0000000000000000000000000000000000000000..b4fb79da70d084ba6571a12fb51f15bb331ea8d7 --- /dev/null +++ b/cosyvoice/flow/stable/stable_diffusion_test.py @@ -0,0 +1,104 @@ +import torch +from torch.nn import functional as F +from .dit import DiffusionTransformer +from .adp import UNet1d +from .sampling import sample +import math +from model.base import BaseModule +import pdb + +target_length = 1536 +def pad_and_create_mask(matrix, target_length): + + T = matrix.shape[2] + if T > target_length: + raise ValueError("The third dimension length %s should not exceed %s"%(T, target_length)) + + padding_size = target_length - T + + padded_matrix = F.pad(matrix, (0, padding_size), "constant", 0) + + mask = torch.ones((1, target_length)) + mask[:, T:] = 0 # Set the padding part to 0 + + return padded_matrix.to(matrix.device), mask.to(matrix.device) + + +class Stable_Diffusion(BaseModule): + def __init__(self): + super(Stable_Diffusion, self).__init__() + self.diffusion = DiffusionTransformer( + io_channels=80, + # input_concat_dim=80, + embed_dim=768, + # cond_token_dim=target_length, + depth=24, + num_heads=24, + project_cond_tokens=False, + transformer_type="continuous_transformer", + ) + # self.diffusion = UNet1d( + # in_channels=80, + # channels=256, + # resnet_groups=16, + # kernel_multiplier_downsample=2, + # multipliers=[4, 4, 4, 5, 5], + # factors=[1, 2, 2, 4], # 输入长度不一致卷积缩短 + # num_blocks=[2, 2, 2, 2], + # attentions=[1, 3, 3, 3, 3], + # attention_heads=16, + # attention_multiplier=4, + # use_nearest_upsample=False, + # use_skip_scale=True, + # use_context_time=True + # ) + self.rng = torch.quasirandom.SobolEngine(1, scramble=True) + + @torch.no_grad() + def forward(self, mu, mask, n_timesteps): + # pdb.set_trace() + mask = mask.squeeze(1) + # noise = torch.randn_like(mu).to(mu.device) + # mu_pad, mu_pad_mask = pad_and_create_mask(mu, target_length) + # extra_args = {"cross_attn_cond": mu, "cross_attn_cond_mask": mask, "mask": mask} + extra_args = {"mask": mask} + fakes = sample(self.diffusion, mu, n_timesteps, 0, **extra_args) + + return fakes + + + def compute_loss(self, x0, mask, mu): + + # pdb.set_trace() + t = self.rng.draw(x0.shape[0])[:, 0].to(x0.device) + alphas, sigmas = torch.cos(t * math.pi / 2), torch.sin(t * math.pi / 2) + + alphas = alphas[:, None, None] + sigmas = sigmas[:, None, None] + noise = torch.randn_like(x0) + noised_inputs = x0 * alphas + noise * sigmas + targets = mu * alphas - x0 * sigmas + mask = mask.squeeze(1) + # mu_pad, mu_pad_mask = pad_and_create_mask(mu, target_length) + # output = self.diffusion(noised_inputs, t, 
cross_attn_cond=mu, + # cross_attn_cond_mask=mask, mask=mask, cfg_dropout_prob=0.1) + output = self.diffusion(noised_inputs, t, mask=mask, cfg_dropout_prob=0.1) + + return self.mse_loss(output, targets, mask), output + + + def mse_loss(self, output, targets, mask): + + mse_loss = F.mse_loss(output, targets, reduction='none') + + if mask.ndim == 2 and mse_loss.ndim == 3: + mask = mask.unsqueeze(1) + + if mask.shape[1] != mse_loss.shape[1]: + mask = mask.repeat(1, mse_loss.shape[1], 1) + + mse_loss = mse_loss[mask] + + mse_loss = mse_loss.mean() + + return mse_loss \ No newline at end of file diff --git a/cosyvoice/flow/stable/transformer.py b/cosyvoice/flow/stable/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..417beeb78dc70d69fcbb323acbc8d9d7f9e22a09 --- /dev/null +++ b/cosyvoice/flow/stable/transformer.py @@ -0,0 +1,816 @@ +import pdb +from functools import reduce, partial +from packaging import version + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange +import torch +import torch.nn.functional as F +from torch import nn, einsum +from torch.cuda.amp import autocast +from typing import Callable, Literal + +try: + from flash_attn import flash_attn_func, flash_attn_kvpacked_func +except ImportError as e: + print(e) + print('flash_attn not installed, disabling Flash Attention') + flash_attn_kvpacked_func = None + flash_attn_func = None + +try: + import natten +except ImportError: + natten = None + +def checkpoint(function, *args, **kwargs): + kwargs.setdefault("use_reentrant", False) + return torch.utils.checkpoint.checkpoint(function, *args, **kwargs) + + +# Copied and modified from https://github.com/lucidrains/x-transformers/blob/main/x_transformers/attend.py under MIT License +# License can be found in LICENSES/LICENSE_XTRANSFORMERS.txt + +def create_causal_mask(i, j, device): + return torch.ones((i, j), device = device, dtype = torch.bool).triu(j - i + 1) + +def or_reduce(masks): + head, *body = masks + for rest in body: + head = head | rest + return head + +# positional embeddings + +class AbsolutePositionalEmbedding(nn.Module): + def __init__(self, dim, max_seq_len): + super().__init__() + self.scale = dim ** -0.5 + self.max_seq_len = max_seq_len + self.emb = nn.Embedding(max_seq_len, dim) + + def forward(self, x, pos = None, seq_start_pos = None): + seq_len, device = x.shape[1], x.device + assert seq_len <= self.max_seq_len, f'you are passing in a sequence length of {seq_len} but your absolute positional embedding has a max sequence length of {self.max_seq_len}' + + if pos is None: + pos = torch.arange(seq_len, device = device) + + if seq_start_pos is not None: + pos = (pos - seq_start_pos[..., None]).clamp(min = 0) + + pos_emb = self.emb(pos) + pos_emb = pos_emb * self.scale + return pos_emb + +class ScaledSinusoidalEmbedding(nn.Module): + def __init__(self, dim, theta = 10000): + super().__init__() + assert (dim % 2) == 0, 'dimension must be divisible by 2' + self.scale = nn.Parameter(torch.ones(1) * dim ** -0.5) + + half_dim = dim // 2 + freq_seq = torch.arange(half_dim).float() / half_dim + inv_freq = theta ** -freq_seq + self.register_buffer('inv_freq', inv_freq, persistent = False) + + def forward(self, x, pos = None, seq_start_pos = None): + seq_len, device = x.shape[1], x.device + + if pos is None: + pos = torch.arange(seq_len, device = device) + + if seq_start_pos is not None: + pos = pos - seq_start_pos[..., None] + + emb = einsum('i, j -> i j', pos, self.inv_freq) + emb = torch.cat((emb.sin(), emb.cos()), 
dim = -1) + return emb * self.scale + +class RotaryEmbedding(nn.Module): + def __init__( + self, + dim, + use_xpos = False, + scale_base = 512, + interpolation_factor = 1., + base = 10000, + base_rescale_factor = 1. + ): + super().__init__() + # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning + # has some connection to NTK literature + # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/ + base *= base_rescale_factor ** (dim / (dim - 2)) + + inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim)) + self.register_buffer('inv_freq', inv_freq) + + assert interpolation_factor >= 1. + self.interpolation_factor = interpolation_factor + + if not use_xpos: + self.register_buffer('scale', None) + return + + scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim) + + self.scale_base = scale_base + self.register_buffer('scale', scale) + + def forward_from_seq_len(self, seq_len): + device = self.inv_freq.device + + t = torch.arange(seq_len, device = device) + return self.forward(t) + + @autocast(enabled = False) + def forward(self, t): + device = self.inv_freq.device + + t = t.to(torch.float32) + + t = t / self.interpolation_factor + + freqs = torch.einsum('i , j -> i j', t, self.inv_freq) + freqs = torch.cat((freqs, freqs), dim = -1) + + if self.scale is None: + return freqs, 1. + + power = (torch.arange(seq_len, device = device) - (seq_len // 2)) / self.scale_base + scale = self.scale ** rearrange(power, 'n -> n 1') + scale = torch.cat((scale, scale), dim = -1) + + return freqs, scale + +def rotate_half(x): + x = rearrange(x, '... (j d) -> ... j d', j = 2) + x1, x2 = x.unbind(dim = -2) + return torch.cat((-x2, x1), dim = -1) + +@autocast(enabled = False) +def apply_rotary_pos_emb(t, freqs, scale = 1): + out_dtype = t.dtype + + # cast to float32 if necessary for numerical stability + dtype = reduce(torch.promote_types, (t.dtype, freqs.dtype, torch.float32)) + rot_dim, seq_len = freqs.shape[-1], t.shape[-2] + freqs, t = freqs.to(dtype), t.to(dtype) + freqs = freqs[-seq_len:, :] + + if t.ndim == 4 and freqs.ndim == 3: + freqs = rearrange(freqs, 'b n d -> b 1 n d') + + # partial rotary embeddings, Wang et al. GPT-J + t, t_unrotated = t[..., :rot_dim], t[..., rot_dim:] + t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale) + + t, t_unrotated = t.to(out_dtype), t_unrotated.to(out_dtype) + + return torch.cat((t, t_unrotated), dim = -1) + +# norms +class LayerNorm(nn.Module): + def __init__(self, dim, bias=False, fix_scale=False): + """ + bias-less layernorm has been shown to be more stable. 
most newer models have moved towards rmsnorm, also bias-less + """ + super().__init__() + + if fix_scale: + self.register_buffer("gamma", torch.ones(dim)) + else: + self.gamma = nn.Parameter(torch.ones(dim)) + + if bias: + self.beta = nn.Parameter(torch.zeros(dim)) + else: + self.register_buffer("beta", torch.zeros(dim)) + + + def forward(self, x): + return F.layer_norm(x, x.shape[-1:], weight=self.gamma, bias=self.beta) + +# feedforward + +class GLU(nn.Module): + def __init__( + self, + dim_in, + dim_out, + activation: Callable, + use_conv = False, + conv_kernel_size = 3, + ): + super().__init__() + self.act = activation + self.proj = nn.Linear(dim_in, dim_out * 2) if not use_conv else nn.Conv1d(dim_in, dim_out * 2, conv_kernel_size, padding = (conv_kernel_size // 2)) + self.use_conv = use_conv + + def forward(self, x): + if self.use_conv: + x = rearrange(x, 'b n d -> b d n') + x = self.proj(x) + x = rearrange(x, 'b d n -> b n d') + else: + x = self.proj(x) + + x, gate = x.chunk(2, dim = -1) + return x * self.act(gate) + +class FeedForward(nn.Module): + def __init__( + self, + dim, + dim_out = None, + mult = 4, + no_bias = False, + glu = True, + use_conv = False, + conv_kernel_size = 3, + zero_init_output = True, + ): + super().__init__() + inner_dim = int(dim * mult) + + # Default to SwiGLU + + activation = nn.SiLU() + + dim_out = dim if dim_out is None else dim_out + + if glu: + linear_in = GLU(dim, inner_dim, activation) + else: + linear_in = nn.Sequential( + Rearrange('b n d -> b d n') if use_conv else nn.Identity(), + nn.Linear(dim, inner_dim, bias = not no_bias) if not use_conv else nn.Conv1d(dim, inner_dim, conv_kernel_size, padding = (conv_kernel_size // 2), bias = not no_bias), + Rearrange('b n d -> b d n') if use_conv else nn.Identity(), + activation + ) + + linear_out = nn.Linear(inner_dim, dim_out, bias = not no_bias) if not use_conv else nn.Conv1d(inner_dim, dim_out, conv_kernel_size, padding = (conv_kernel_size // 2), bias = not no_bias) + + # init last linear layer to 0 + if zero_init_output: + nn.init.zeros_(linear_out.weight) + if not no_bias: + nn.init.zeros_(linear_out.bias) + + + self.ff = nn.Sequential( + linear_in, + Rearrange('b d n -> b n d') if use_conv else nn.Identity(), + linear_out, + Rearrange('b n d -> b d n') if use_conv else nn.Identity(), + ) + + def forward(self, x): + return self.ff(x) + +class Attention(nn.Module): + def __init__( + self, + dim, + dim_heads = 64, + dim_context = None, + causal = False, + zero_init_output=True, + qk_norm: Literal['l2', 'ln', 'none'] = 'none', + natten_kernel_size = None + ): + super().__init__() + self.dim = dim + self.dim_heads = dim_heads + self.causal = causal + + dim_kv = dim_context if dim_context is not None else dim + + self.num_heads = dim // dim_heads + self.kv_heads = dim_kv // dim_heads + + if dim_context is not None: + self.to_q = nn.Linear(dim, dim, bias=False) + self.to_kv = nn.Linear(dim_kv, dim_kv * 2, bias=False) + else: + self.to_qkv = nn.Linear(dim, dim * 3, bias=False) + + self.to_out = nn.Linear(dim, dim, bias=False) + + if zero_init_output: + nn.init.zeros_(self.to_out.weight) + + self.qk_norm = qk_norm + + if self.qk_norm == "ln": + self.q_norm = nn.LayerNorm(dim_heads, elementwise_affine=True, eps=1.0e-6) + self.k_norm = nn.LayerNorm(dim_heads, elementwise_affine=True, eps=1.0e-6) + + # Using 1d neighborhood attention + self.natten_kernel_size = natten_kernel_size + if natten_kernel_size is not None: + return + + self.use_pt_flash = torch.cuda.is_available() and version.parse(torch.__version__) 
>= version.parse('2.0.0') + + self.use_fa_flash = torch.cuda.is_available() and flash_attn_func is not None + # pdb.set_trace() + self.use_fa_flash = False + + self.sdp_kwargs = dict( + enable_flash = True, + enable_math = True, + enable_mem_efficient = True + ) + + def flash_attn( + self, + q, + k, + v, + mask = None, + causal = None + ): + batch, heads, q_len, _, k_len, device = *q.shape, k.shape[-2], q.device + kv_heads = k.shape[1] + # Recommended for multi-query single-key-value attention by Tri Dao + # kv shape torch.Size([1, 512, 64]) -> torch.Size([1, 8, 512, 64]) + + if heads != kv_heads: + # Repeat interleave kv_heads to match q_heads + heads_per_kv_head = heads // kv_heads + k, v = map(lambda t: t.repeat_interleave(heads_per_kv_head, dim = 1), (k, v)) + + if k.ndim == 3: + k = rearrange(k, 'b ... -> b 1 ...').expand_as(q) + + if v.ndim == 3: + v = rearrange(v, 'b ... -> b 1 ...').expand_as(q) + + causal = self.causal if causal is None else causal + + if q_len == 1 and causal: + causal = False + + if mask is not None: + assert mask.ndim == 4 + mask = mask.expand(batch, heads, q_len, k_len) + + # handle kv cache - this should be bypassable in updated flash attention 2 + + if k_len > q_len and causal: + causal_mask = self.create_causal_mask(q_len, k_len, device = device) + if mask is None: + mask = ~causal_mask + else: + mask = mask & ~causal_mask + causal = False + + # manually handle causal mask, if another mask was given + + row_is_entirely_masked = None + + if mask is not None and causal: + causal_mask = self.create_causal_mask(q_len, k_len, device = device) + mask = mask & ~causal_mask + + # protect against an entire row being masked out + + row_is_entirely_masked = ~mask.any(dim = -1) + mask[..., 0] = mask[..., 0] | row_is_entirely_masked + + causal = False + + with torch.backends.cuda.sdp_kernel(**self.sdp_kwargs): + out = F.scaled_dot_product_attention( + q, k, v, + attn_mask = mask, + is_causal = causal + ) + + # for a row that is entirely masked out, should zero out the output of that row token + + if row_is_entirely_masked is not None: + out = out.masked_fill(row_is_entirely_masked[..., None], 0.) 
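+        # Note on the SDPA path above: when a padding mask is combined with causal
+        # masking, a row of the attention mask can end up entirely False. Such rows had
+        # their first position temporarily re-enabled before calling
+        # scaled_dot_product_attention (softmax over an all-masked row would otherwise
+        # produce NaNs), and their outputs are zeroed here. `out` is still laid out as
+        # (batch, heads, seq, dim_head); heads are merged back in Attention.forward
+        # before the output projection.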
+ + return out + + def forward( + self, + x, + context = None, + mask = None, + context_mask = None, + rotary_pos_emb = None, + causal = None + ): + h, kv_h, has_context = self.num_heads, self.kv_heads, context is not None + + kv_input = context if has_context else x + + if hasattr(self, 'to_q'): + # Use separate linear projections for q and k/v + q = self.to_q(x) + q = rearrange(q, 'b n (h d) -> b h n d', h = h) + + k, v = self.to_kv(kv_input).chunk(2, dim=-1) + + k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = kv_h), (k, v)) + else: + # Use fused linear projection + q, k, v = self.to_qkv(x).chunk(3, dim=-1) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v)) + + # Normalize q and k for cosine sim attention + if self.qk_norm == "l2": + q = F.normalize(q, dim=-1) + k = F.normalize(k, dim=-1) + elif self.qk_norm == "ln": + q = self.q_norm(q) + k = self.k_norm(k) + + if rotary_pos_emb is not None and not has_context: + freqs, _ = rotary_pos_emb + + q_dtype = q.dtype + k_dtype = k.dtype + + q = q.to(torch.float32) + k = k.to(torch.float32) + freqs = freqs.to(torch.float32) + + q = apply_rotary_pos_emb(q, freqs) + k = apply_rotary_pos_emb(k, freqs) + + q = q.to(q_dtype) + k = k.to(k_dtype) + + input_mask = context_mask + + if input_mask is None and not has_context: + input_mask = mask + + # determine masking + masks = [] + final_attn_mask = None # The mask that will be applied to the attention matrix, taking all masks into account + + if input_mask is not None: + input_mask = rearrange(input_mask, 'b j -> b 1 1 j') + masks.append(~input_mask) + + # Other masks will be added here later + + if len(masks) > 0: + final_attn_mask = ~or_reduce(masks) + + n, device = q.shape[-2], q.device + + causal = self.causal if causal is None else causal + + if n == 1 and causal: + causal = False + + if self.natten_kernel_size is not None: + if natten is None: + raise ImportError('natten not installed, please install natten to use neighborhood attention') + + dtype_in = q.dtype + q, k, v = map(lambda t: t.to(torch.float32), (q, k, v)) + + attn = natten.functional.natten1dqk(q, k, kernel_size = self.natten_kernel_size, dilation=1) + + if final_attn_mask is not None: + attn = attn.masked_fill(final_attn_mask, -torch.finfo(attn.dtype).max) + + attn = F.softmax(attn, dim=-1, dtype=torch.float32) + + out = natten.functional.natten1dav(attn, v, kernel_size = self.natten_kernel_size, dilation=1).to(dtype_in) + + # Prioritize Flash Attention 2 + elif self.use_fa_flash: + # pdb.set_trace() + assert final_attn_mask is None, 'masking not yet supported for Flash Attention 2' + # Flash Attention 2 requires FP16 inputs + fa_dtype_in = q.dtype + q, k, v = map(lambda t: rearrange(t, 'b h n d -> b n h d').to(torch.float16), (q, k, v)) + + out = flash_attn_func(q, k, v, causal = causal) + + out = rearrange(out.to(fa_dtype_in), 'b n h d -> b h n d') + + # Fall back to PyTorch implementation + elif self.use_pt_flash: + out = self.flash_attn(q, k, v, causal = causal, mask = final_attn_mask) + + else: + # Fall back to custom implementation + + if h != kv_h: + # Repeat interleave kv_heads to match q_heads + heads_per_kv_head = h // kv_h + k, v = map(lambda t: t.repeat_interleave(heads_per_kv_head, dim = 1), (k, v)) + + scale = 1. 
/ (q.shape[-1] ** 0.5) + + kv_einsum_eq = 'b j d' if k.ndim == 3 else 'b h j d' + + dots = einsum(f'b h i d, {kv_einsum_eq} -> b h i j', q, k) * scale + + i, j, dtype = *dots.shape[-2:], dots.dtype + + mask_value = -torch.finfo(dots.dtype).max + + if final_attn_mask is not None: + dots = dots.masked_fill(~final_attn_mask, mask_value) + + if causal: + causal_mask = self.create_causal_mask(i, j, device = device) + dots = dots.masked_fill(causal_mask, mask_value) + + attn = F.softmax(dots, dim=-1, dtype=torch.float32) + attn = attn.type(dtype) + + out = einsum(f'b h i j, {kv_einsum_eq} -> b h i d', attn, v) + + # merge heads + out = rearrange(out, ' b h n d -> b n (h d)') + + # Communicate between heads + + # with autocast(enabled = False): + # out_dtype = out.dtype + # out = out.to(torch.float32) + # out = self.to_out(out).to(out_dtype) + out = self.to_out(out) + + if mask is not None: + mask = rearrange(mask, 'b n -> b n 1') + out = out.masked_fill(~mask, 0.) + + return out + +class ConformerModule(nn.Module): + def __init__( + self, + dim, + norm_kwargs = {}, + ): + + super().__init__() + + self.dim = dim + + self.in_norm = LayerNorm(dim, **norm_kwargs) + self.pointwise_conv = nn.Conv1d(dim, dim, kernel_size=1, bias=False) + self.glu = GLU(dim, dim, nn.SiLU()) + self.depthwise_conv = nn.Conv1d(dim, dim, kernel_size=17, groups=dim, padding=8, bias=False) + self.mid_norm = LayerNorm(dim, **norm_kwargs) # This is a batch norm in the original but I don't like batch norm + self.swish = nn.SiLU() + self.pointwise_conv_2 = nn.Conv1d(dim, dim, kernel_size=1, bias=False) + + def forward(self, x): + x = self.in_norm(x) + x = rearrange(x, 'b n d -> b d n') + x = self.pointwise_conv(x) + x = rearrange(x, 'b d n -> b n d') + x = self.glu(x) + x = rearrange(x, 'b n d -> b d n') + x = self.depthwise_conv(x) + x = rearrange(x, 'b d n -> b n d') + x = self.mid_norm(x) + x = self.swish(x) + x = rearrange(x, 'b n d -> b d n') + x = self.pointwise_conv_2(x) + x = rearrange(x, 'b d n -> b n d') + + return x + +class TransformerBlock(nn.Module): + def __init__( + self, + dim, + dim_heads = 64, + cross_attend = False, + dim_context = None, + global_cond_dim = None, + causal = False, + zero_init_branch_outputs = True, + conformer = False, + layer_ix = -1, + remove_norms = False, + attn_kwargs = {}, + ff_kwargs = {}, + norm_kwargs = {} + ): + + super().__init__() + self.dim = dim + self.dim_heads = dim_heads + self.cross_attend = cross_attend + self.dim_context = dim_context + self.causal = causal + + self.pre_norm = LayerNorm(dim, **norm_kwargs) if not remove_norms else nn.Identity() + + self.self_attn = Attention( + dim, + dim_heads = dim_heads, + causal = causal, + zero_init_output=zero_init_branch_outputs, + **attn_kwargs + ) + ### 2. 
主要是这边需要修改 + if cross_attend: + self.cross_attend_norm = LayerNorm(dim, **norm_kwargs) if not remove_norms else nn.Identity() + self.cross_attn = Attention( + dim, + dim_heads = dim_heads, + dim_context=dim_context, + causal = causal, + zero_init_output=zero_init_branch_outputs, + **attn_kwargs + ) + + self.ff_norm = LayerNorm(dim, **norm_kwargs) if not remove_norms else nn.Identity() + self.ff = FeedForward(dim, zero_init_output=zero_init_branch_outputs, **ff_kwargs) + + self.layer_ix = layer_ix + + self.conformer = ConformerModule(dim, norm_kwargs=norm_kwargs) if conformer else None + + self.global_cond_dim = global_cond_dim + + if global_cond_dim is not None: + self.to_scale_shift_gate = nn.Sequential( + nn.SiLU(), + nn.Linear(global_cond_dim, dim * 6, bias=False) + ) + + nn.init.zeros_(self.to_scale_shift_gate[1].weight) + #nn.init.zeros_(self.to_scale_shift_gate_self[1].bias) + + def forward( + self, + x, + context = None, + global_cond=None, + mask = None, + context_mask = None, + rotary_pos_emb = None + ): + if self.global_cond_dim is not None and self.global_cond_dim > 0 and global_cond is not None: + + scale_self, shift_self, gate_self, scale_ff, shift_ff, gate_ff = self.to_scale_shift_gate(global_cond).unsqueeze(1).chunk(6, dim = -1) + + # self-attention with adaLN + residual = x + x = self.pre_norm(x) + x = x * (1 + scale_self) + shift_self + x = self.self_attn(x, mask = mask, rotary_pos_emb = rotary_pos_emb) + x = x * torch.sigmoid(1 - gate_self) + x = x + residual + + if context is not None: + x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask) + + if self.conformer is not None: + x = x + self.conformer(x) + + # feedforward with adaLN + residual = x + x = self.ff_norm(x) + x = x * (1 + scale_ff) + shift_ff + x = self.ff(x) + x = x * torch.sigmoid(1 - gate_ff) + x = x + residual + + else: + x = x + self.self_attn(self.pre_norm(x), mask = mask, rotary_pos_emb = rotary_pos_emb) + + if context is not None: + x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask) + + if self.conformer is not None: + x = x + self.conformer(x) + + x = x + self.ff(self.ff_norm(x)) + + return x + +class ContinuousTransformer(nn.Module): + def __init__( + self, + dim, + depth, + *, + dim_in = None, + dim_out = None, + dim_heads = 64, + cross_attend=False, + cond_token_dim=None, + global_cond_dim=None, + causal=False, + rotary_pos_emb=True, + zero_init_branch_outputs=True, + conformer=False, + use_sinusoidal_emb=False, + use_abs_pos_emb=False, + abs_pos_emb_max_length=10000, + **kwargs + ): + + super().__init__() + + self.dim = dim + self.depth = depth + self.causal = causal + self.layers = nn.ModuleList([]) + + self.project_in = nn.Linear(dim_in, dim, bias=False) if dim_in is not None else nn.Identity() + self.project_out = nn.Linear(dim, dim_out, bias=False) if dim_out is not None else nn.Identity() + + if rotary_pos_emb: + self.rotary_pos_emb = RotaryEmbedding(max(dim_heads // 2, 32)) + else: + self.rotary_pos_emb = None + + self.use_sinusoidal_emb = use_sinusoidal_emb + if use_sinusoidal_emb: + self.pos_emb = ScaledSinusoidalEmbedding(dim) + + self.use_abs_pos_emb = use_abs_pos_emb + if use_abs_pos_emb: + self.pos_emb = AbsolutePositionalEmbedding(dim, abs_pos_emb_max_length) + + for i in range(depth): + self.layers.append( + TransformerBlock( + dim, + dim_heads = dim_heads, + cross_attend = cross_attend, + dim_context = cond_token_dim, + global_cond_dim = global_cond_dim, + causal = causal, + 
zero_init_branch_outputs = zero_init_branch_outputs, + conformer=conformer, + layer_ix=i, + **kwargs + ) + ) + + def forward( + self, + x, + mask = None, + prepend_embeds = None, + prepend_mask = None, + global_cond = None, + return_info = False, + **kwargs + ): + batch, seq, device = *x.shape[:2], x.device + + info = { + "hidden_states": [], + } + + x = self.project_in(x) + if prepend_embeds is not None: + prepend_length, prepend_dim = prepend_embeds.shape[1:] + + assert prepend_dim == x.shape[-1], 'prepend dimension must match sequence dimension' + + x = torch.cat((prepend_embeds, x), dim = -2) + + if prepend_mask is not None or mask is not None: + mask = mask if mask is not None else torch.ones((batch, seq), device = device, dtype = torch.bool) + prepend_mask = prepend_mask if prepend_mask is not None else torch.ones((batch, prepend_length), device = device, dtype = torch.bool) + + mask = torch.cat((prepend_mask, mask), dim = -1) + + # Attention layers + + if self.rotary_pos_emb is not None: + rotary_pos_emb = self.rotary_pos_emb.forward_from_seq_len(x.shape[1]) + else: + rotary_pos_emb = None + + if self.use_sinusoidal_emb or self.use_abs_pos_emb: + x = x + self.pos_emb(x) + + # Iterate over the transformer layers + for layer in self.layers: + #x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, **kwargs) + # pdb.set_trace() + x = checkpoint(layer, x, mask=mask.bool(),rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, **kwargs) + + if return_info: + info["hidden_states"].append(x) + + x = self.project_out(x) + + if return_info: + return x, info + + return x diff --git a/cosyvoice/flow/stable/transformer_use_mask.py b/cosyvoice/flow/stable/transformer_use_mask.py new file mode 100644 index 0000000000000000000000000000000000000000..d22c5704f71d14b1f4446869a7f191383d5c31d3 --- /dev/null +++ b/cosyvoice/flow/stable/transformer_use_mask.py @@ -0,0 +1,845 @@ +import pdb +from functools import reduce, partial +from packaging import version + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange +import torch +import torch.nn.functional as F +from torch import nn, einsum +from torch.cuda.amp import autocast +from typing import Callable, Literal + +try: + from flash_attn import flash_attn_func, flash_attn_kvpacked_func +except ImportError as e: + print(e) + print('flash_attn not installed, disabling Flash Attention') + flash_attn_kvpacked_func = None + flash_attn_func = None + +try: + import natten +except ImportError: + natten = None + + +def checkpoint(function, *args, **kwargs): + kwargs.setdefault("use_reentrant", False) + return torch.utils.checkpoint.checkpoint(function, *args, **kwargs) + + +# Copied and modified from https://github.com/lucidrains/x-transformers/blob/main/x_transformers/attend.py under MIT License +# License can be found in LICENSES/LICENSE_XTRANSFORMERS.txt + +def create_causal_mask(i, j, device): + return torch.ones((i, j), device=device, dtype=torch.bool).triu(j - i + 1) + + +def or_reduce(masks): + head, *body = masks + for rest in body: + head = head | rest + return head + + +# positional embeddings + +class AbsolutePositionalEmbedding(nn.Module): + def __init__(self, dim, max_seq_len): + super().__init__() + self.scale = dim ** -0.5 + self.max_seq_len = max_seq_len + self.emb = nn.Embedding(max_seq_len, dim) + + def forward(self, x, pos=None, seq_start_pos=None): + seq_len, device = x.shape[1], x.device + assert seq_len <= self.max_seq_len, f'you are passing in a sequence length of {seq_len} but your absolute 
positional embedding has a max sequence length of {self.max_seq_len}' + + if pos is None: + pos = torch.arange(seq_len, device=device) + + if seq_start_pos is not None: + pos = (pos - seq_start_pos[..., None]).clamp(min=0) + + pos_emb = self.emb(pos) + pos_emb = pos_emb * self.scale + return pos_emb + + +class ScaledSinusoidalEmbedding(nn.Module): + def __init__(self, dim, theta=10000): + super().__init__() + assert (dim % 2) == 0, 'dimension must be divisible by 2' + self.scale = nn.Parameter(torch.ones(1) * dim ** -0.5) + + half_dim = dim // 2 + freq_seq = torch.arange(half_dim).float() / half_dim + inv_freq = theta ** -freq_seq + self.register_buffer('inv_freq', inv_freq, persistent=False) + + def forward(self, x, pos=None, seq_start_pos=None): + seq_len, device = x.shape[1], x.device + + if pos is None: + pos = torch.arange(seq_len, device=device) + + if seq_start_pos is not None: + pos = pos - seq_start_pos[..., None] + + emb = einsum('i, j -> i j', pos, self.inv_freq) + emb = torch.cat((emb.sin(), emb.cos()), dim=-1) + return emb * self.scale + + +class RotaryEmbedding(nn.Module): + def __init__( + self, + dim, + use_xpos=False, + scale_base=512, + interpolation_factor=1., + base=10000, + base_rescale_factor=1. + ): + super().__init__() + # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning + # has some connection to NTK literature + # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/ + base *= base_rescale_factor ** (dim / (dim - 2)) + + inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim)) + self.register_buffer('inv_freq', inv_freq) + + assert interpolation_factor >= 1. + self.interpolation_factor = interpolation_factor + + if not use_xpos: + self.register_buffer('scale', None) + return + + scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim) + + self.scale_base = scale_base + self.register_buffer('scale', scale) + + def forward_from_seq_len(self, seq_len): + device = self.inv_freq.device + + t = torch.arange(seq_len, device=device) + return self.forward(t) + + @autocast(enabled=False) + def forward(self, t): + device = self.inv_freq.device + + t = t.to(torch.float32) + + t = t / self.interpolation_factor + + freqs = torch.einsum('i , j -> i j', t, self.inv_freq) + freqs = torch.cat((freqs, freqs), dim=-1) + + if self.scale is None: + return freqs, 1. + + power = (torch.arange(seq_len, device=device) - (seq_len // 2)) / self.scale_base + scale = self.scale ** rearrange(power, 'n -> n 1') + scale = torch.cat((scale, scale), dim=-1) + + return freqs, scale + + +def rotate_half(x): + x = rearrange(x, '... (j d) -> ... j d', j=2) + x1, x2 = x.unbind(dim=-2) + return torch.cat((-x2, x1), dim=-1) + + +@autocast(enabled=False) +def apply_rotary_pos_emb(t, freqs, scale=1): + out_dtype = t.dtype + + # cast to float32 if necessary for numerical stability + dtype = reduce(torch.promote_types, (t.dtype, freqs.dtype, torch.float32)) + rot_dim, seq_len = freqs.shape[-1], t.shape[-2] + freqs, t = freqs.to(dtype), t.to(dtype) + freqs = freqs[-seq_len:, :] + + if t.ndim == 4 and freqs.ndim == 3: + freqs = rearrange(freqs, 'b n d -> b 1 n d') + + # partial rotary embeddings, Wang et al. 
GPT-J + t, t_unrotated = t[..., :rot_dim], t[..., rot_dim:] + t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale) + + t, t_unrotated = t.to(out_dtype), t_unrotated.to(out_dtype) + + return torch.cat((t, t_unrotated), dim=-1) + + +# norms +class LayerNorm(nn.Module): + def __init__(self, dim, bias=False, fix_scale=False): + """ + bias-less layernorm has been shown to be more stable. most newer models have moved towards rmsnorm, also bias-less + """ + super().__init__() + + if fix_scale: + self.register_buffer("gamma", torch.ones(dim)) + else: + self.gamma = nn.Parameter(torch.ones(dim)) + + if bias: + self.beta = nn.Parameter(torch.zeros(dim)) + else: + self.register_buffer("beta", torch.zeros(dim)) + + def forward(self, x): + return F.layer_norm(x, x.shape[-1:], weight=self.gamma, bias=self.beta) + + +# feedforward + +class GLU(nn.Module): + def __init__( + self, + dim_in, + dim_out, + activation: Callable, + use_conv=False, + conv_kernel_size=3, + ): + super().__init__() + self.act = activation + self.proj = nn.Linear(dim_in, dim_out * 2) if not use_conv else nn.Conv1d(dim_in, dim_out * 2, conv_kernel_size, + padding=(conv_kernel_size // 2)) + self.use_conv = use_conv + + def forward(self, x): + if self.use_conv: + x = rearrange(x, 'b n d -> b d n') + x = self.proj(x) + x = rearrange(x, 'b d n -> b n d') + else: + x = self.proj(x) + + x, gate = x.chunk(2, dim=-1) + return x * self.act(gate) + + +class FeedForward(nn.Module): + def __init__( + self, + dim, + dim_out=None, + mult=4, + no_bias=False, + glu=True, + use_conv=False, + conv_kernel_size=3, + zero_init_output=True, + ): + super().__init__() + inner_dim = int(dim * mult) + + # Default to SwiGLU + + activation = nn.SiLU() + + dim_out = dim if dim_out is None else dim_out + + if glu: + linear_in = GLU(dim, inner_dim, activation) + else: + linear_in = nn.Sequential( + Rearrange('b n d -> b d n') if use_conv else nn.Identity(), + nn.Linear(dim, inner_dim, bias=not no_bias) if not use_conv else nn.Conv1d(dim, inner_dim, + conv_kernel_size, padding=( + conv_kernel_size // 2), bias=not no_bias), + Rearrange('b n d -> b d n') if use_conv else nn.Identity(), + activation + ) + + linear_out = nn.Linear(inner_dim, dim_out, bias=not no_bias) if not use_conv else nn.Conv1d(inner_dim, dim_out, + conv_kernel_size, + padding=( + conv_kernel_size // 2), + bias=not no_bias) + + # init last linear layer to 0 + if zero_init_output: + nn.init.zeros_(linear_out.weight) + if not no_bias: + nn.init.zeros_(linear_out.bias) + + self.ff = nn.Sequential( + linear_in, + Rearrange('b d n -> b n d') if use_conv else nn.Identity(), + linear_out, + Rearrange('b n d -> b d n') if use_conv else nn.Identity(), + ) + + def forward(self, x): + return self.ff(x) + + +class Attention(nn.Module): + def __init__( + self, + dim, + dim_heads=64, + dim_context=None, + causal=False, + zero_init_output=True, + qk_norm: Literal['l2', 'ln', 'none'] = 'none', + natten_kernel_size=None + ): + super().__init__() + self.dim = dim + self.dim_heads = dim_heads + self.causal = causal + + dim_kv = dim_context if dim_context is not None else dim + + self.num_heads = dim // dim_heads + self.kv_heads = dim_kv // dim_heads + + if dim_context is not None: + self.to_q = nn.Linear(dim, dim, bias=False) + self.to_kv = nn.Linear(dim_kv, dim_kv * 2, bias=False) + else: + self.to_qkv = nn.Linear(dim, dim * 3, bias=False) + + self.to_out = nn.Linear(dim, dim, bias=False) + + if zero_init_output: + nn.init.zeros_(self.to_out.weight) + + self.qk_norm = qk_norm + + if self.qk_norm == 
"ln": + self.q_norm = nn.LayerNorm(dim_heads, elementwise_affine=True, eps=1.0e-6) + self.k_norm = nn.LayerNorm(dim_heads, elementwise_affine=True, eps=1.0e-6) + + # Using 1d neighborhood attention + self.natten_kernel_size = natten_kernel_size + if natten_kernel_size is not None: + return + + self.use_pt_flash = torch.cuda.is_available() and version.parse(torch.__version__) >= version.parse('2.0.0') + + self.use_fa_flash = torch.cuda.is_available() and flash_attn_func is not None + # pdb.set_trace() + self.use_fa_flash = False + + self.sdp_kwargs = dict( + enable_flash=True, + enable_math=True, + enable_mem_efficient=True + ) + + def flash_attn( + self, + q, + k, + v, + mask=None, + causal=None + ): + batch, heads, q_len, _, k_len, device = *q.shape, k.shape[-2], q.device + kv_heads = k.shape[1] + # Recommended for multi-query single-key-value attention by Tri Dao + # kv shape torch.Size([1, 512, 64]) -> torch.Size([1, 8, 512, 64]) + + if heads != kv_heads: + # Repeat interleave kv_heads to match q_heads + heads_per_kv_head = heads // kv_heads + k, v = map(lambda t: t.repeat_interleave(heads_per_kv_head, dim=1), (k, v)) + + if k.ndim == 3: + k = rearrange(k, 'b ... -> b 1 ...').expand_as(q) + + if v.ndim == 3: + v = rearrange(v, 'b ... -> b 1 ...').expand_as(q) + + causal = self.causal if causal is None else causal + + if q_len == 1 and causal: + causal = False + + if mask is not None: + assert mask.ndim == 4 + mask = mask.expand(batch, heads, q_len, k_len) + + assert causal + # handle kv cache - this should be bypassable in updated flash attention 2 + if k_len > q_len and causal: + causal_mask = create_causal_mask(q_len, k_len, device=device) + if mask is None: + mask = ~causal_mask + else: + mask = mask & ~causal_mask + causal = False + + # manually handle causal mask, if another mask was given + + row_is_entirely_masked = None + + if mask is not None and causal: + causal_mask = create_causal_mask(q_len, k_len, device=device) + mask = mask & ~causal_mask + + # protect against an entire row being masked out + + row_is_entirely_masked = ~mask.any(dim=-1) + mask[..., 0] = mask[..., 0] | row_is_entirely_masked + + causal = False + + with torch.backends.cuda.sdp_kernel(**self.sdp_kwargs): + out = F.scaled_dot_product_attention( + q, k, v, + attn_mask=mask, + is_causal=causal + ) + + # for a row that is entirely masked out, should zero out the output of that row token + + if row_is_entirely_masked is not None: + out = out.masked_fill(row_is_entirely_masked[..., None], 0.) 
+ + return out + + def forward( + self, + x, + context=None, + mask=None, + context_mask=None, + rotary_pos_emb=None, + causal=None + ): + h, kv_h, has_context = self.num_heads, self.kv_heads, context is not None + + kv_input = context if has_context else x + + if hasattr(self, 'to_q'): + # Use separate linear projections for q and k/v + q = self.to_q(x) + q = rearrange(q, 'b n (h d) -> b h n d', h=h) + + k, v = self.to_kv(kv_input).chunk(2, dim=-1) + + k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=kv_h), (k, v)) + else: + # Use fused linear projection + q, k, v = self.to_qkv(x).chunk(3, dim=-1) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v)) + + # Normalize q and k for cosine sim attention + if self.qk_norm == "l2": + q = F.normalize(q, dim=-1) + k = F.normalize(k, dim=-1) + elif self.qk_norm == "ln": + q = self.q_norm(q) + k = self.k_norm(k) + + if rotary_pos_emb is not None and not has_context: + freqs, _ = rotary_pos_emb + + q_dtype = q.dtype + k_dtype = k.dtype + + q = q.to(torch.float32) + k = k.to(torch.float32) + freqs = freqs.to(torch.float32) + + q = apply_rotary_pos_emb(q, freqs) + k = apply_rotary_pos_emb(k, freqs) + + q = q.to(q_dtype) + k = k.to(k_dtype) + + input_mask = context_mask + + if input_mask is None and not has_context: + input_mask = mask + + # determine masking + masks = [] + final_attn_mask = None # The mask that will be applied to the attention matrix, taking all masks into account + + if input_mask is not None: + input_mask = rearrange(input_mask, 'b j -> b 1 1 j') + masks.append(~input_mask) + + # Other masks will be added here later + + if len(masks) > 0: + final_attn_mask = ~or_reduce(masks) + + n, device = q.shape[-2], q.device + + causal = self.causal if causal is None else causal + + if n == 1 and causal: + causal = False + + if self.natten_kernel_size is not None: + if natten is None: + raise ImportError('natten not installed, please install natten to use neighborhood attention') + + dtype_in = q.dtype + q, k, v = map(lambda t: t.to(torch.float32), (q, k, v)) + + attn = natten.functional.natten1dqk(q, k, kernel_size=self.natten_kernel_size, dilation=1) + + if final_attn_mask is not None: + attn = attn.masked_fill(final_attn_mask, -torch.finfo(attn.dtype).max) + + attn = F.softmax(attn, dim=-1, dtype=torch.float32) + + out = natten.functional.natten1dav(attn, v, kernel_size=self.natten_kernel_size, dilation=1).to(dtype_in) + + # Prioritize Flash Attention 2 + elif self.use_fa_flash: + assert final_attn_mask is None, 'masking not yet supported for Flash Attention 2' + # Flash Attention 2 requires FP16 inputs + fa_dtype_in = q.dtype + q, k, v = map(lambda t: rearrange(t, 'b h n d -> b n h d').to(torch.float16), (q, k, v)) + + out = flash_attn_func(q, k, v, causal=causal) + + out = rearrange(out.to(fa_dtype_in), 'b n h d -> b h n d') + + # Fall back to PyTorch implementation + elif self.use_pt_flash: + # causal=False + # final_attn_mask:[64, 1, 1, 348] + out = self.flash_attn(q, k, v, causal=True, mask=final_attn_mask) + + else: + # Fall back to custom implementation + + if h != kv_h: + # Repeat interleave kv_heads to match q_heads + heads_per_kv_head = h // kv_h + k, v = map(lambda t: t.repeat_interleave(heads_per_kv_head, dim=1), (k, v)) + + scale = 1. 
/ (q.shape[-1] ** 0.5) + + kv_einsum_eq = 'b j d' if k.ndim == 3 else 'b h j d' + + dots = einsum(f'b h i d, {kv_einsum_eq} -> b h i j', q, k) * scale + + i, j, dtype = *dots.shape[-2:], dots.dtype + + mask_value = -torch.finfo(dots.dtype).max + + if final_attn_mask is not None: + dots = dots.masked_fill(~final_attn_mask, mask_value) + + if causal: + causal_mask = create_causal_mask(i, j, device=device) + dots = dots.masked_fill(causal_mask, mask_value) + + attn = F.softmax(dots, dim=-1, dtype=torch.float32) + attn = attn.type(dtype) + + out = einsum(f'b h i j, {kv_einsum_eq} -> b h i d', attn, v) + + # merge heads + out = rearrange(out, ' b h n d -> b n (h d)') + + # Communicate between heads + + # with autocast(enabled = False): + # out_dtype = out.dtype + # out = out.to(torch.float32) + # out = self.to_out(out).to(out_dtype) + out = self.to_out(out) + + if mask is not None: + mask = rearrange(mask, 'b n -> b n 1') + out = out.masked_fill(~mask, 0.) + + return out + + +class ConformerModule(nn.Module): + def __init__( + self, + dim, + norm_kwargs={}, + ): + super().__init__() + + self.dim = dim + + self.in_norm = LayerNorm(dim, **norm_kwargs) + self.pointwise_conv = nn.Conv1d(dim, dim, kernel_size=1, bias=False) + self.glu = GLU(dim, dim, nn.SiLU()) + self.depthwise_conv = nn.Conv1d(dim, dim, kernel_size=17, groups=dim, padding=8, bias=False) + self.mid_norm = LayerNorm(dim, + **norm_kwargs) # This is a batch norm in the original but I don't like batch norm + self.swish = nn.SiLU() + self.pointwise_conv_2 = nn.Conv1d(dim, dim, kernel_size=1, bias=False) + + def forward(self, x): + x = self.in_norm(x) + x = rearrange(x, 'b n d -> b d n') + x = self.pointwise_conv(x) + x = rearrange(x, 'b d n -> b n d') + x = self.glu(x) + x = rearrange(x, 'b n d -> b d n') + x = self.depthwise_conv(x) + x = rearrange(x, 'b d n -> b n d') + x = self.mid_norm(x) + x = self.swish(x) + x = rearrange(x, 'b n d -> b d n') + x = self.pointwise_conv_2(x) + x = rearrange(x, 'b d n -> b n d') + + return x + + +class TransformerBlock(nn.Module): + def __init__( + self, + dim, + dim_heads=64, + cross_attend=False, + dim_context=None, + global_cond_dim=None, + causal=False, + zero_init_branch_outputs=True, + conformer=False, + layer_ix=-1, + remove_norms=False, + attn_kwargs={}, + ff_kwargs={}, + norm_kwargs={} + ): + + super().__init__() + self.dim = dim + self.dim_heads = dim_heads + self.cross_attend = cross_attend + self.dim_context = dim_context + self.causal = causal + + self.pre_norm = LayerNorm(dim, **norm_kwargs) if not remove_norms else nn.Identity() + + self.self_attn = Attention( + dim, + dim_heads=dim_heads, + causal=causal, + zero_init_output=zero_init_branch_outputs, + **attn_kwargs + ) + ### 2. 
主要是这边需要修改 + if cross_attend: + self.cross_attend_norm = LayerNorm(dim, **norm_kwargs) if not remove_norms else nn.Identity() + self.cross_attn = Attention( + dim, + dim_heads=dim_heads, + dim_context=dim_context, + causal=causal, + zero_init_output=zero_init_branch_outputs, + **attn_kwargs + ) + + self.ff_norm = LayerNorm(dim, **norm_kwargs) if not remove_norms else nn.Identity() + self.ff = FeedForward(dim, zero_init_output=zero_init_branch_outputs, **ff_kwargs) + + self.layer_ix = layer_ix + + self.conformer = ConformerModule(dim, norm_kwargs=norm_kwargs) if conformer else None + + self.global_cond_dim = global_cond_dim + + if global_cond_dim is not None: + self.to_scale_shift_gate = nn.Sequential( + nn.SiLU(), + nn.Linear(global_cond_dim, dim * 6, bias=False) + ) + + nn.init.zeros_(self.to_scale_shift_gate[1].weight) + # nn.init.zeros_(self.to_scale_shift_gate_self[1].bias) + + def forward( + self, + x, + context=None, + global_cond=None, + mask=None, + context_mask=None, + rotary_pos_emb=None + ): + if self.global_cond_dim is not None and self.global_cond_dim > 0 and global_cond is not None: + + scale_self, shift_self, gate_self, scale_ff, shift_ff, gate_ff = self.to_scale_shift_gate( + global_cond).unsqueeze(1).chunk(6, dim=-1) + + # self-attention with adaLN + residual = x + x = self.pre_norm(x) + x = x * (1 + scale_self) + shift_self + x = self.self_attn(x, mask=mask, rotary_pos_emb=rotary_pos_emb) + x = x * torch.sigmoid(1 - gate_self) + x = x + residual + + if context is not None: + x = x + self.cross_attn(self.cross_attend_norm(x), context=context, context_mask=context_mask) + + if self.conformer is not None: + x = x + self.conformer(x) + + # feedforward with adaLN + residual = x + x = self.ff_norm(x) + x = x * (1 + scale_ff) + shift_ff + x = self.ff(x) + x = x * torch.sigmoid(1 - gate_ff) + x = x + residual + + else: + x = x + self.self_attn(self.pre_norm(x), mask=mask, rotary_pos_emb=rotary_pos_emb) + + if context is not None: + x = x + self.cross_attn(self.cross_attend_norm(x), context=context, context_mask=context_mask) + + if self.conformer is not None: + x = x + self.conformer(x) + + x = x + self.ff(self.ff_norm(x)) + + return x + + +class ContinuousTransformer(nn.Module): + def __init__( + self, + dim, + depth, + *, + dim_in=None, + dim_out=None, + dim_heads=64, + cross_attend=False, + cond_token_dim=None, + global_cond_dim=None, + causal=False, + rotary_pos_emb=True, + zero_init_branch_outputs=True, + conformer=False, + use_sinusoidal_emb=False, + use_abs_pos_emb=False, + abs_pos_emb_max_length=10000, + **kwargs + ): + + super().__init__() + + self.dim = dim + self.depth = depth + self.causal = causal + self.layers = nn.ModuleList([]) + + self.project_in = nn.Linear(dim_in, dim, bias=False) if dim_in is not None else nn.Identity() + self.project_out = nn.Linear(dim, dim_out, bias=False) if dim_out is not None else nn.Identity() + + if rotary_pos_emb: + self.rotary_pos_emb = RotaryEmbedding(max(dim_heads // 2, 32)) + else: + self.rotary_pos_emb = None + + self.use_sinusoidal_emb = use_sinusoidal_emb + if use_sinusoidal_emb: + self.pos_emb = ScaledSinusoidalEmbedding(dim) + + self.use_abs_pos_emb = use_abs_pos_emb + if use_abs_pos_emb: + self.pos_emb = AbsolutePositionalEmbedding(dim, abs_pos_emb_max_length) + + for i in range(depth): + self.layers.append( + TransformerBlock( + dim, + dim_heads=dim_heads, + cross_attend=cross_attend, + dim_context=cond_token_dim, + global_cond_dim=global_cond_dim, + causal=causal, + zero_init_branch_outputs=zero_init_branch_outputs, + 
conformer=conformer, + layer_ix=i, + **kwargs + ) + ) + + def forward( + self, + x, + mask=None, + prepend_embeds=None, + prepend_mask=None, + global_cond=None, + return_info=False, + **kwargs + ): + batch, seq, device = *x.shape[:2], x.device + + info = { + "hidden_states": [], + } + + x = self.project_in(x) + if prepend_embeds is not None: + prepend_length, prepend_dim = prepend_embeds.shape[1:] + + assert prepend_dim == x.shape[-1], 'prepend dimension must match sequence dimension' + + x = torch.cat((prepend_embeds, x), dim=-2) + + if prepend_mask is not None or mask is not None: + mask = mask if mask is not None else torch.ones((batch, seq), device=device, dtype=torch.bool) + prepend_mask = prepend_mask if prepend_mask is not None else torch.ones((batch, prepend_length), + device=device, dtype=torch.bool) + + mask = torch.cat((prepend_mask, mask), dim=-1) + + # Attention layers + + if self.rotary_pos_emb is not None: + rotary_pos_emb = self.rotary_pos_emb.forward_from_seq_len(x.shape[1]) + else: + rotary_pos_emb = None + + if self.use_sinusoidal_emb or self.use_abs_pos_emb: + x = x + self.pos_emb(x) + + # Iterate over the transformer layers + mask = self.refine_mask(mask) + for layer in self.layers: + # x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, **kwargs) + # pdb.set_trace() + x = checkpoint(layer, x, mask=mask.bool(), rotary_pos_emb=rotary_pos_emb, global_cond=global_cond, **kwargs) + + if return_info: + info["hidden_states"].append(x) + + x = self.project_out(x) + + if return_info: + return x, info + + return x + + def refine_mask(self, mask): + return mask + # pdb.set_trace() + # mask = 1 - torch.triu(torch.ones(seq_length, seq_length), diagonal=1) + # return mask diff --git a/cosyvoice/hifigan/__pycache__/f0_predictor.cpython-310.pyc b/cosyvoice/hifigan/__pycache__/f0_predictor.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9abe015846b124f91743016e18e02688325d8a2b Binary files /dev/null and b/cosyvoice/hifigan/__pycache__/f0_predictor.cpython-310.pyc differ diff --git a/cosyvoice/hifigan/__pycache__/generator.cpython-310.pyc b/cosyvoice/hifigan/__pycache__/generator.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d476ff03e7272ea529d72e774f76129e325d5a3c Binary files /dev/null and b/cosyvoice/hifigan/__pycache__/generator.cpython-310.pyc differ diff --git a/cosyvoice/hifigan/f0_predictor.py b/cosyvoice/hifigan/f0_predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..36b85f4ed90c3a412cb179f49ccb471132a86550 --- /dev/null +++ b/cosyvoice/hifigan/f0_predictor.py @@ -0,0 +1,55 @@ +# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Kai Hu) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
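A note on the `ContinuousTransformer.forward` shown earlier in this diff: when `prepend_embeds` is given, the prompt embeddings are concatenated in front of the projected input along the sequence axis and the padding mask is extended with `prepend_mask` before the checkpointed layers run. The following standalone sketch reproduces just that prepend-and-mask step (shapes are illustrative; the snippet is an editor-added example, not part of this diff):

import torch

batch, seq, prepend_len, dim = 2, 10, 3, 64
x = torch.randn(batch, seq, dim)                       # projected input features
prepend_embeds = torch.randn(batch, prepend_len, dim)  # e.g. prompt / conditioning tokens
mask = torch.ones(batch, seq, dtype=torch.bool)        # True = real frame, False = padding
prepend_mask = torch.ones(batch, prepend_len, dtype=torch.bool)

# prepend along the time axis, mirroring ContinuousTransformer.forward
x = torch.cat((prepend_embeds, x), dim=-2)             # (batch, prepend_len + seq, dim)
mask = torch.cat((prepend_mask, mask), dim=-1)         # (batch, prepend_len + seq)
assert x.shape[1] == mask.shape[1] == prepend_len + seq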
+import torch +import torch.nn as nn +from torch.nn.utils import weight_norm + + +class ConvRNNF0Predictor(nn.Module): + def __init__(self, + num_class: int = 1, + in_channels: int = 80, + cond_channels: int = 512 + ): + super().__init__() + + self.num_class = num_class + self.condnet = nn.Sequential( + weight_norm( + nn.Conv1d(in_channels, cond_channels, kernel_size=3, padding=1) + ), + nn.ELU(), + weight_norm( + nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1) + ), + nn.ELU(), + weight_norm( + nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1) + ), + nn.ELU(), + weight_norm( + nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1) + ), + nn.ELU(), + weight_norm( + nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1) + ), + nn.ELU(), + ) + self.classifier = nn.Linear(in_features=cond_channels, out_features=self.num_class) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.condnet(x) + x = x.transpose(1, 2) + return torch.abs(self.classifier(x).squeeze(-1)) diff --git a/cosyvoice/hifigan/generator.py b/cosyvoice/hifigan/generator.py new file mode 100644 index 0000000000000000000000000000000000000000..a43ac05a7d828c86965d3f56b6798d522689089f --- /dev/null +++ b/cosyvoice/hifigan/generator.py @@ -0,0 +1,398 @@ +# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Kai Hu) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""HIFI-GAN""" + +import typing as tp +import numpy as np +from scipy.signal import get_window +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import Conv1d +from torch.nn import ConvTranspose1d +from torch.nn.utils import remove_weight_norm +from torch.nn.utils import weight_norm +from torch.distributions.uniform import Uniform + +from cosyvoice.transformer.activation import Snake +from cosyvoice.utils.common import get_padding +from cosyvoice.utils.common import init_weights + + +"""hifigan based generator implementation. 
+ +This code is modified from https://github.com/jik876/hifi-gan + ,https://github.com/kan-bayashi/ParallelWaveGAN and + https://github.com/NVIDIA/BigVGAN + +""" + + +class ResBlock(torch.nn.Module): + """Residual block module in HiFiGAN/BigVGAN.""" + def __init__( + self, + channels: int = 512, + kernel_size: int = 3, + dilations: tp.List[int] = [1, 3, 5], + ): + super(ResBlock, self).__init__() + self.convs1 = nn.ModuleList() + self.convs2 = nn.ModuleList() + + for dilation in dilations: + self.convs1.append( + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation, + padding=get_padding(kernel_size, dilation) + ) + ) + ) + self.convs2.append( + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1) + ) + ) + ) + self.convs1.apply(init_weights) + self.convs2.apply(init_weights) + self.activations1 = nn.ModuleList([ + Snake(channels, alpha_logscale=False) + for _ in range(len(self.convs1)) + ]) + self.activations2 = nn.ModuleList([ + Snake(channels, alpha_logscale=False) + for _ in range(len(self.convs2)) + ]) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + for idx in range(len(self.convs1)): + xt = self.activations1[idx](x) + xt = self.convs1[idx](xt) + xt = self.activations2[idx](xt) + xt = self.convs2[idx](xt) + x = xt + x + return x + + def remove_weight_norm(self): + for idx in range(len(self.convs1)): + remove_weight_norm(self.convs1[idx]) + remove_weight_norm(self.convs2[idx]) + + +class SineGen(torch.nn.Module): + """ Definition of sine generator + SineGen(samp_rate, harmonic_num = 0, + sine_amp = 0.1, noise_std = 0.003, + voiced_threshold = 0, + flag_for_pulse=False) + samp_rate: sampling rate in Hz + harmonic_num: number of harmonic overtones (default 0) + sine_amp: amplitude of sine-wavefrom (default 0.1) + noise_std: std of Gaussian noise (default 0.003) + voiced_thoreshold: F0 threshold for U/V classification (default 0) + flag_for_pulse: this SinGen is used inside PulseGen (default False) + Note: when flag_for_pulse is True, the first time step of a voiced + segment is always sin(np.pi) or cos(0) + """ + + def __init__(self, samp_rate, harmonic_num=0, + sine_amp=0.1, noise_std=0.003, + voiced_threshold=0): + super(SineGen, self).__init__() + self.sine_amp = sine_amp + self.noise_std = noise_std + self.harmonic_num = harmonic_num + self.sampling_rate = samp_rate + self.voiced_threshold = voiced_threshold + + def _f02uv(self, f0): + # generate uv signal + uv = (f0 > self.voiced_threshold).type(torch.float32) + return uv + + @torch.no_grad() + def forward(self, f0): + """ + :param f0: [B, 1, sample_len], Hz + :return: [B, 1, sample_len] + """ + + F_mat = torch.zeros((f0.size(0), self.harmonic_num + 1, f0.size(-1))).to(f0.device) + for i in range(self.harmonic_num + 1): + F_mat[:, i: i + 1, :] = f0 * (i + 1) / self.sampling_rate + + theta_mat = 2 * np.pi * (torch.cumsum(F_mat, dim=-1) % 1) + u_dist = Uniform(low=-np.pi, high=np.pi) + phase_vec = u_dist.sample(sample_shape=(f0.size(0), self.harmonic_num + 1, 1)).to(F_mat.device) + phase_vec[:, 0, :] = 0 + + # generate sine waveforms + sine_waves = self.sine_amp * torch.sin(theta_mat + phase_vec) + + # generate uv signal + uv = self._f02uv(f0) + + # noise: for unvoiced should be similar to sine_amp + # std = self.sine_amp/3 -> max value ~ self.sine_amp + # . 
for voiced regions is self.noise_std + noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3 + noise = noise_amp * torch.randn_like(sine_waves) + + # first: set the unvoiced part to 0 by uv + # then: additive noise + sine_waves = sine_waves * uv + noise + return sine_waves, uv, noise + + +class SourceModuleHnNSF(torch.nn.Module): + """ SourceModule for hn-nsf + SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1, + add_noise_std=0.003, voiced_threshod=0) + sampling_rate: sampling_rate in Hz + harmonic_num: number of harmonic above F0 (default: 0) + sine_amp: amplitude of sine source signal (default: 0.1) + add_noise_std: std of additive Gaussian noise (default: 0.003) + note that amplitude of noise in unvoiced is decided + by sine_amp + voiced_threshold: threhold to set U/V given F0 (default: 0) + Sine_source, noise_source = SourceModuleHnNSF(F0_sampled) + F0_sampled (batchsize, length, 1) + Sine_source (batchsize, length, 1) + noise_source (batchsize, length 1) + uv (batchsize, length, 1) + """ + + def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1, + add_noise_std=0.003, voiced_threshod=0): + super(SourceModuleHnNSF, self).__init__() + + self.sine_amp = sine_amp + self.noise_std = add_noise_std + + # to produce sine waveforms + self.l_sin_gen = SineGen(sampling_rate, harmonic_num, + sine_amp, add_noise_std, voiced_threshod) + + # to merge source harmonics into a single excitation + self.l_linear = torch.nn.Linear(harmonic_num + 1, 1) + self.l_tanh = torch.nn.Tanh() + + def forward(self, x): + """ + Sine_source, noise_source = SourceModuleHnNSF(F0_sampled) + F0_sampled (batchsize, length, 1) + Sine_source (batchsize, length, 1) + noise_source (batchsize, length 1) + """ + # source for harmonic branch + with torch.no_grad(): + sine_wavs, uv, _ = self.l_sin_gen(x.transpose(1, 2)) + sine_wavs = sine_wavs.transpose(1, 2) + uv = uv.transpose(1, 2) + sine_merge = self.l_tanh(self.l_linear(sine_wavs)) + + # source for noise branch, in the same shape as uv + noise = torch.randn_like(uv) * self.sine_amp / 3 + return sine_merge, noise, uv + + +class HiFTGenerator(nn.Module): + """ + HiFTNet Generator: Neural Source Filter + ISTFTNet + https://arxiv.org/abs/2309.09493 + """ + def __init__( + self, + in_channels: int = 80, + base_channels: int = 512, + nb_harmonics: int = 8, + sampling_rate: int = 22050, + nsf_alpha: float = 0.1, + nsf_sigma: float = 0.003, + nsf_voiced_threshold: float = 10, + upsample_rates: tp.List[int] = [8, 8], + upsample_kernel_sizes: tp.List[int] = [16, 16], + istft_params: tp.Dict[str, int] = {"n_fft": 16, "hop_len": 4}, + resblock_kernel_sizes: tp.List[int] = [3, 7, 11], + resblock_dilation_sizes: tp.List[tp.List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]], + source_resblock_kernel_sizes: tp.List[int] = [7, 11], + source_resblock_dilation_sizes: tp.List[tp.List[int]] = [[1, 3, 5], [1, 3, 5]], + lrelu_slope: float = 0.1, + audio_limit: float = 0.99, + f0_predictor: torch.nn.Module = None, + ): + super(HiFTGenerator, self).__init__() + + self.out_channels = 1 + self.nb_harmonics = nb_harmonics + self.sampling_rate = sampling_rate + self.istft_params = istft_params + self.lrelu_slope = lrelu_slope + self.audio_limit = audio_limit + + self.num_kernels = len(resblock_kernel_sizes) + self.num_upsamples = len(upsample_rates) + self.m_source = SourceModuleHnNSF( + sampling_rate=sampling_rate, + upsample_scale=np.prod(upsample_rates) * istft_params["hop_len"], + harmonic_num=nb_harmonics, + sine_amp=nsf_alpha, + add_noise_std=nsf_sigma, + 
voiced_threshod=nsf_voiced_threshold) + self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates) * istft_params["hop_len"]) + + self.conv_pre = weight_norm( + Conv1d(in_channels, base_channels, 7, 1, padding=3) + ) + + # Up + self.ups = nn.ModuleList() + for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)): + self.ups.append( + weight_norm( + ConvTranspose1d( + base_channels // (2**i), + base_channels // (2**(i + 1)), + k, + u, + padding=(k - u) // 2, + ) + ) + ) + + # Down + self.source_downs = nn.ModuleList() + self.source_resblocks = nn.ModuleList() + downsample_rates = [1] + upsample_rates[::-1][:-1] + downsample_cum_rates = np.cumprod(downsample_rates) + for i, (u, k, d) in enumerate(zip(downsample_cum_rates[::-1], source_resblock_kernel_sizes, source_resblock_dilation_sizes)): + if u == 1: + self.source_downs.append( + Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), 1, 1) + ) + else: + self.source_downs.append( + Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), u * 2, u, padding=(u // 2)) + ) + + self.source_resblocks.append( + ResBlock(base_channels // (2 ** (i + 1)), k, d) + ) + + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = base_channels // (2**(i + 1)) + for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)): + self.resblocks.append(ResBlock(ch, k, d)) + + self.conv_post = weight_norm(Conv1d(ch, istft_params["n_fft"] + 2, 7, 1, padding=3)) + self.ups.apply(init_weights) + self.conv_post.apply(init_weights) + self.reflection_pad = nn.ReflectionPad1d((1, 0)) + self.stft_window = torch.from_numpy(get_window("hann", istft_params["n_fft"], fftbins=True).astype(np.float32)) + self.f0_predictor = f0_predictor + + def _f02source(self, f0: torch.Tensor) -> torch.Tensor: + f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t + + har_source, _, _ = self.m_source(f0) + return har_source.transpose(1, 2) + + def _stft(self, x): + spec = torch.stft( + x, + self.istft_params["n_fft"], self.istft_params["hop_len"], self.istft_params["n_fft"], window=self.stft_window.to(x.device), + return_complex=True) + spec = torch.view_as_real(spec) # [B, F, TT, 2] + return spec[..., 0], spec[..., 1] + + def _istft(self, magnitude, phase): + magnitude = torch.clip(magnitude, max=1e2) + real = magnitude * torch.cos(phase) + img = magnitude * torch.sin(phase) + inverse_transform = torch.istft(torch.complex(real, img), self.istft_params["n_fft"], self.istft_params["hop_len"], + self.istft_params["n_fft"], window=self.stft_window.to(magnitude.device)) + return inverse_transform + + def forward(self, x: torch.Tensor, cache_source: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor: + f0 = self.f0_predictor(x) + s = self._f02source(f0) + + # use cache_source to avoid glitch + if cache_source.shape[2] != 0: + s[:, :, :cache_source.shape[2]] = cache_source + + s_stft_real, s_stft_imag = self._stft(s.squeeze(1)) + s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1) + + x = self.conv_pre(x) + for i in range(self.num_upsamples): + x = F.leaky_relu(x, self.lrelu_slope) + x = self.ups[i](x) + + if i == self.num_upsamples - 1: + x = self.reflection_pad(x) + + # fusion + si = self.source_downs[i](s_stft) + si = self.source_resblocks[i](si) + x = x + si + + xs = None + for j in range(self.num_kernels): + if xs is None: + xs = self.resblocks[i * self.num_kernels + j](x) + else: + xs += self.resblocks[i * self.num_kernels + j](x) + x = xs / self.num_kernels + + x = F.leaky_relu(x) + x = 
self.conv_post(x) + magnitude = torch.exp(x[:, :self.istft_params["n_fft"] // 2 + 1, :]) + phase = torch.sin(x[:, self.istft_params["n_fft"] // 2 + 1:, :]) # actually, sin is redundancy + + x = self._istft(magnitude, phase) + x = torch.clamp(x, -self.audio_limit, self.audio_limit) + return x, s + + def remove_weight_norm(self): + print('Removing weight norm...') + for l in self.ups: + remove_weight_norm(l) + for l in self.resblocks: + l.remove_weight_norm() + remove_weight_norm(self.conv_pre) + remove_weight_norm(self.conv_post) + self.source_module.remove_weight_norm() + for l in self.source_downs: + remove_weight_norm(l) + for l in self.source_resblocks: + l.remove_weight_norm() + + @torch.inference_mode() + def inference(self, mel: torch.Tensor, cache_source: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor: + return self.forward(x=mel, cache_source=cache_source) \ No newline at end of file diff --git a/cosyvoice/llm/__pycache__/llm.cpython-310.pyc b/cosyvoice/llm/__pycache__/llm.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..34518dc60e25faa5c7518e19ba925c31dd81cd69 Binary files /dev/null and b/cosyvoice/llm/__pycache__/llm.cpython-310.pyc differ diff --git a/cosyvoice/llm/llm.py b/cosyvoice/llm/llm.py new file mode 100644 index 0000000000000000000000000000000000000000..3b418c5d1017c6f8412418dd8d1c1b7790947241 --- /dev/null +++ b/cosyvoice/llm/llm.py @@ -0,0 +1,206 @@ +# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Dict, Optional, Union +import torch +from torch import nn +import torch.nn.functional as F +from torch.nn.utils.rnn import pad_sequence, unpad_sequence +from cosyvoice.utils.common import IGNORE_ID +from cosyvoice.transformer.label_smoothing_loss import LabelSmoothingLoss +from cosyvoice.utils.common import th_accuracy + + +class TransformerLM(torch.nn.Module): + def __init__( + self, + text_encoder_input_size: int, + llm_input_size: int, + llm_output_size: int, + text_token_size: int, + speech_token_size: int, + text_encoder: torch.nn.Module, + llm: torch.nn.Module, + length_normalized_loss: bool = True, + lsm_weight: float = 0.0, + spk_embed_dim: int = 192, + ): + super().__init__() + self.llm_input_size = llm_input_size + self.speech_token_size = speech_token_size + # 1. build text token inputs related modules + self.text_embedding = torch.nn.Embedding(text_token_size, text_encoder_input_size) + self.text_encoder = text_encoder + self.text_encoder_affine_layer = nn.Linear( + self.text_encoder.output_size(), + llm_input_size + ) + + # 2. build speech token language model related modules + self.sos_eos = 0 + self.task_id = 1 + self.llm_embedding = torch.nn.Embedding(2, llm_input_size) + self.llm = llm + self.llm_decoder = nn.Linear(llm_output_size, speech_token_size + 1) + self.criterion_ce = LabelSmoothingLoss( + size=speech_token_size + 1, + padding_idx=IGNORE_ID, + smoothing=lsm_weight, + normalize_length=length_normalized_loss, + ) + + # 3. 
[Optional] build speech token related modules + self.speech_embedding = torch.nn.Embedding(speech_token_size, llm_input_size) + self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, llm_input_size) + + def encode( + self, + text: torch.Tensor, + text_lengths: torch.Tensor, + ): + encoder_out, encoder_mask = self.text_encoder(text, text_lengths, decoding_chunk_size=1, num_decoding_left_chunks=-1) + encoder_out_lens = encoder_mask.squeeze(1).sum(1) + encoder_out = self.text_encoder_affine_layer(encoder_out) + return encoder_out, encoder_out_lens + + def pad_unpad_sequence(self, sos_eos_emb, embedding, text_token, text_token_len, task_id_emb, speech_token, speech_token_len): + text_token = unpad_sequence(text_token, text_token_len.cpu(), batch_first=True) + speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True) + lm_input = [torch.concat([sos_eos_emb.squeeze(dim=0), embedding[i], text_token[i], task_id_emb.squeeze(dim=0), speech_token[i]], dim=0) for i in range(len(text_token))] + lm_input_len = torch.tensor([i.size(0) for i in lm_input], dtype=torch.int32) + lm_input = pad_sequence(lm_input, batch_first=True, padding_value=IGNORE_ID) + return lm_input, lm_input_len + + def forward( + self, + batch: dict, + device: torch.device, + ) -> Dict[str, Optional[torch.Tensor]]: + """ + Args: + text: (B, L, D) + text_lengths: (B,) + audio: (B, T, N) or (B, T) + audio_lengths: (B,) + """ + text_token = batch['text_token'].to(device) + text_token_len = batch['text_token_len'].to(device) + speech_token = batch['speech_token'].to(device) + speech_token_len = batch['speech_token_len'].to(device) + embedding = batch['embedding'].to(device) + + # 1. prepare llm_target + lm_target = [torch.tensor([IGNORE_ID] * (2 + text_token_len[i]) + speech_token[i, :speech_token_len[i]].tolist() + [self.speech_token_size]) for i in range(text_token.size(0))] + lm_target = pad_sequence(lm_target, batch_first=True, padding_value=IGNORE_ID).to(device) + + # 1. encode text_token + text_token = self.text_embedding(text_token) + text_token, text_token_len = self.encode(text_token, text_token_len) + + # 2. embedding projection + embedding = F.normalize(embedding, dim=1) + embedding = self.spk_embed_affine_layer(embedding) + embedding = embedding.unsqueeze(1) + + # 3. eos and task_id + sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1) + task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1) + + # 4. encode speech_token + speech_token = self.speech_embedding(speech_token) + + # 5. unpad and pad + lm_input, lm_input_len = self.pad_unpad_sequence(sos_eos_emb, embedding, text_token, text_token_len, task_id_emb, speech_token, speech_token_len) + + # 6. 
run lm forward + lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device)) + logits = self.llm_decoder(lm_output) + loss = self.criterion_ce(logits, lm_target) + acc = th_accuracy(logits.view(-1, self.speech_token_size + 1), lm_target, ignore_label=IGNORE_ID) + return {'loss': loss, 'acc': acc} + + def sampling_ids( + self, + weighted_scores: torch.Tensor, + sampling: Union[bool, int, float] = True, + beam_size: int = 1, + ignore_eos: bool = True, + ): + while True: + prob, indices = weighted_scores.softmax(dim=-1).topk(sampling) + top_ids = prob.multinomial(beam_size, replacement=True) + top_ids = indices[top_ids] + if (not ignore_eos) or (self.speech_token_size not in top_ids): + break + return top_ids + + @torch.inference_mode() + def inference( + self, + text: torch.Tensor, + text_len: torch.Tensor, + prompt_text: torch.Tensor, + prompt_text_len: torch.Tensor, + prompt_speech_token: torch.Tensor, + prompt_speech_token_len: torch.Tensor, + embedding: torch.Tensor, + beam_size: int = 1, + sampling: int = 25, + max_token_text_ratio: float = 20, + min_token_text_ratio: float = 2, + ) -> torch.Tensor: + device = text.device + text = torch.concat([prompt_text, text], dim=1) + text_len += prompt_text_len + text = self.text_embedding(text) + + # 1. encode text + text, text_len = self.encode(text, text_len) + + # 2. encode embedding + if embedding.shape[0] != 0: + embedding = F.normalize(embedding, dim=1) + embedding = self.spk_embed_affine_layer(embedding) + embedding = embedding.unsqueeze(dim=1) + else: + embedding = torch.zeros(1, 0, self.llm_input_size).to(device) + + # 3. concat llm_input + sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1) + task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1) + if prompt_speech_token_len != 0: + prompt_speech_token_emb = self.speech_embedding(prompt_speech_token) + else: + prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size).to(device) + lm_input = torch.concat([sos_eos_emb, embedding, text, task_id_emb, prompt_speech_token_emb], dim=1) + + # 4. cal min/max_length + min_len = int((text_len - prompt_text_len) * min_token_text_ratio) + max_len = int((text_len - prompt_text_len) * max_token_text_ratio) + + # 5. 
step by step decode + out_tokens = [] + offset = 0 + att_cache, cnn_cache = torch.zeros((0, 0, 0, 0), device=lm_input.device), torch.zeros((0, 0, 0, 0), device=lm_input.device) + for i in range(max_len): + y_pred, att_cache, cnn_cache = self.llm.forward_chunk(lm_input, offset=0, required_cache_size=-1, att_cache=att_cache, cnn_cache=cnn_cache, + att_mask=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]), device=lm_input.device)).to(torch.bool)) + logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1) + top_ids = self.sampling_ids(logp.squeeze(dim=0), sampling, beam_size, ignore_eos=True if i < min_len else False).item() + if top_ids == self.speech_token_size: + break + out_tokens.append(top_ids) + offset += lm_input.size(1) + lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1) + + return torch.tensor([out_tokens], dtype=torch.int64, device=device) diff --git a/cosyvoice/transformer/__init__.py b/cosyvoice/transformer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/cosyvoice/transformer/__pycache__/__init__.cpython-310.pyc b/cosyvoice/transformer/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3d71fb5c11a8fa0a581c86879ffa02a0a807318b Binary files /dev/null and b/cosyvoice/transformer/__pycache__/__init__.cpython-310.pyc differ diff --git a/cosyvoice/transformer/__pycache__/activation.cpython-310.pyc b/cosyvoice/transformer/__pycache__/activation.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9a87ad714ccb43b5ca6064f851fcaa97f072cbda Binary files /dev/null and b/cosyvoice/transformer/__pycache__/activation.cpython-310.pyc differ diff --git a/cosyvoice/transformer/__pycache__/attention.cpython-310.pyc b/cosyvoice/transformer/__pycache__/attention.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9b4983a5d705ce1f4bd3c4e30683062b3eae5326 Binary files /dev/null and b/cosyvoice/transformer/__pycache__/attention.cpython-310.pyc differ diff --git a/cosyvoice/transformer/__pycache__/convolution.cpython-310.pyc b/cosyvoice/transformer/__pycache__/convolution.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..761ecbece7e0d87fa2a8df6bd780d5faea88b6b6 Binary files /dev/null and b/cosyvoice/transformer/__pycache__/convolution.cpython-310.pyc differ diff --git a/cosyvoice/transformer/__pycache__/embedding.cpython-310.pyc b/cosyvoice/transformer/__pycache__/embedding.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a996369750d8f76cd52d97afa5ae4d4ffed908c0 Binary files /dev/null and b/cosyvoice/transformer/__pycache__/embedding.cpython-310.pyc differ diff --git a/cosyvoice/transformer/__pycache__/encoder.cpython-310.pyc b/cosyvoice/transformer/__pycache__/encoder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..82bb76dce5080f5d1de3a09c18a5315134e6564d Binary files /dev/null and b/cosyvoice/transformer/__pycache__/encoder.cpython-310.pyc differ diff --git a/cosyvoice/transformer/__pycache__/encoder_layer.cpython-310.pyc b/cosyvoice/transformer/__pycache__/encoder_layer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2201dae6addd7e8be75bf4340a5d422ce26706cc Binary files /dev/null and b/cosyvoice/transformer/__pycache__/encoder_layer.cpython-310.pyc differ diff --git 
a/cosyvoice/transformer/__pycache__/label_smoothing_loss.cpython-310.pyc b/cosyvoice/transformer/__pycache__/label_smoothing_loss.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c48fd666dca70f5dd90672ad376412862031184c Binary files /dev/null and b/cosyvoice/transformer/__pycache__/label_smoothing_loss.cpython-310.pyc differ diff --git a/cosyvoice/transformer/__pycache__/positionwise_feed_forward.cpython-310.pyc b/cosyvoice/transformer/__pycache__/positionwise_feed_forward.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..70b8b81ffb6bcc87da4941d41b3e11308d9511be Binary files /dev/null and b/cosyvoice/transformer/__pycache__/positionwise_feed_forward.cpython-310.pyc differ diff --git a/cosyvoice/transformer/__pycache__/subsampling.cpython-310.pyc b/cosyvoice/transformer/__pycache__/subsampling.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..731576e9426c5793afd4a318b3d8702fcf9235e6 Binary files /dev/null and b/cosyvoice/transformer/__pycache__/subsampling.cpython-310.pyc differ diff --git a/cosyvoice/transformer/activation.py b/cosyvoice/transformer/activation.py new file mode 100644 index 0000000000000000000000000000000000000000..8cea54816385d3b6585ccc2417bc71630d578177 --- /dev/null +++ b/cosyvoice/transformer/activation.py @@ -0,0 +1,84 @@ +# Copyright (c) 2020 Johns Hopkins University (Shinji Watanabe) +# 2020 Northwestern Polytechnical University (Pengcheng Guo) +# 2020 Mobvoi Inc (Binbin Zhang) +# 2024 Alibaba Inc (Xiang Lyu) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Swish() activation function for Conformer.""" + +import torch +from torch import nn, sin, pow +from torch.nn import Parameter + + +class Swish(torch.nn.Module): + """Construct an Swish object.""" + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Return Swish activation function.""" + return x * torch.sigmoid(x) + + +# Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license. +# LICENSE is in incl_licenses directory. +class Snake(nn.Module): + ''' + Implementation of a sine-based periodic activation function + Shape: + - Input: (B, C, T) + - Output: (B, C, T), same shape as the input + Parameters: + - alpha - trainable parameter + References: + - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda: + https://arxiv.org/abs/2006.08195 + Examples: + >>> a1 = snake(256) + >>> x = torch.randn(256) + >>> x = a1(x) + ''' + def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False): + ''' + Initialization. + INPUT: + - in_features: shape of the input + - alpha: trainable parameter + alpha is initialized to 1 by default, higher values = higher-frequency. + alpha will be trained along with the rest of your model. 
+ ''' + super(Snake, self).__init__() + self.in_features = in_features + + # initialize alpha + self.alpha_logscale = alpha_logscale + if self.alpha_logscale: # log scale alphas initialized to zeros + self.alpha = Parameter(torch.zeros(in_features) * alpha) + else: # linear scale alphas initialized to ones + self.alpha = Parameter(torch.ones(in_features) * alpha) + + self.alpha.requires_grad = alpha_trainable + + self.no_div_by_zero = 0.000000001 + + def forward(self, x): + ''' + Forward pass of the function. + Applies the function to the input elementwise. + Snake ∶= x + 1/a * sin^2 (xa) + ''' + alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T] + if self.alpha_logscale: + alpha = torch.exp(alpha) + x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2) + + return x diff --git a/cosyvoice/transformer/attention.py b/cosyvoice/transformer/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..b9aaa62d1ec0954e9a168b42bc66702e41591aed --- /dev/null +++ b/cosyvoice/transformer/attention.py @@ -0,0 +1,612 @@ +# Copyright (c) 2019 Shigeki Karita +# 2020 Mobvoi Inc (Binbin Zhang) +# 2022 Xingchen Song (sxc19@mails.tsinghua.edu.cn) +# 2024 Alibaba Inc (Xiang Lyu) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Multi-Head Attention layer definition.""" + +import math +from typing import Tuple + +import torch +from torch import nn + + +class MultiHeadedAttention(nn.Module): + """Multi-Head Attention layer. + + Args: + n_head (int): The number of heads. + n_feat (int): The number of features. + dropout_rate (float): Dropout rate. + + """ + + def __init__(self, + n_head: int, + n_feat: int, + dropout_rate: float, + key_bias: bool = True): + """Construct an MultiHeadedAttention object.""" + super().__init__() + assert n_feat % n_head == 0 + # We assume d_v always equals d_k + self.d_k = n_feat // n_head + self.h = n_head + self.linear_q = nn.Linear(n_feat, n_feat) + self.linear_k = nn.Linear(n_feat, n_feat, bias=key_bias) + self.linear_v = nn.Linear(n_feat, n_feat) + self.linear_out = nn.Linear(n_feat, n_feat) + self.dropout = nn.Dropout(p=dropout_rate) + + def forward_qkv( + self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Transform query, key and value. + + Args: + query (torch.Tensor): Query tensor (#batch, time1, size). + key (torch.Tensor): Key tensor (#batch, time2, size). + value (torch.Tensor): Value tensor (#batch, time2, size). + + Returns: + torch.Tensor: Transformed query tensor, size + (#batch, n_head, time1, d_k). + torch.Tensor: Transformed key tensor, size + (#batch, n_head, time2, d_k). + torch.Tensor: Transformed value tensor, size + (#batch, n_head, time2, d_k). 
+ + """ + n_batch = query.size(0) + q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k) + k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k) + v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k) + q = q.transpose(1, 2) # (batch, head, time1, d_k) + k = k.transpose(1, 2) # (batch, head, time2, d_k) + v = v.transpose(1, 2) # (batch, head, time2, d_k) + + return q, k, v + + def forward_attention( + self, + value: torch.Tensor, + scores: torch.Tensor, + mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool) + ) -> torch.Tensor: + """Compute attention context vector. + + Args: + value (torch.Tensor): Transformed value, size + (#batch, n_head, time2, d_k). + scores (torch.Tensor): Attention score, size + (#batch, n_head, time1, time2). + mask (torch.Tensor): Mask, size (#batch, 1, time2) or + (#batch, time1, time2), (0, 0, 0) means fake mask. + + Returns: + torch.Tensor: Transformed value (#batch, time1, d_model) + weighted by the attention score (#batch, time1, time2). + + """ + n_batch = value.size(0) + # NOTE(xcsong): When will `if mask.size(2) > 0` be True? + # 1. onnx(16/4) [WHY? Because we feed real cache & real mask for the + # 1st chunk to ease the onnx export.] + # 2. pytorch training + if mask.size(2) > 0: # time2 > 0 + mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2) + # For last chunk, time2 might be larger than scores.size(-1) + mask = mask[:, :, :, :scores.size(-1)] # (batch, 1, *, time2) + scores = scores.masked_fill(mask, -float('inf')) + attn = torch.softmax(scores, dim=-1).masked_fill( + mask, 0.0) # (batch, head, time1, time2) + # NOTE(xcsong): When will `if mask.size(2) > 0` be False? + # 1. onnx(16/-1, -1/-1, 16/0) + # 2. jit (16/-1, -1/-1, 16/0, 16/4) + else: + attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2) + + p_attn = self.dropout(attn) + x = torch.matmul(p_attn, value) # (batch, head, time1, d_k) + x = (x.transpose(1, 2).contiguous().view(n_batch, -1, + self.h * self.d_k) + ) # (batch, time1, d_model) + + return self.linear_out(x) # (batch, time1, d_model) + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), + pos_emb: torch.Tensor = torch.empty(0), + cache: torch.Tensor = torch.zeros((0, 0, 0, 0)) + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute scaled dot product attention. + + Args: + query (torch.Tensor): Query tensor (#batch, time1, size). + key (torch.Tensor): Key tensor (#batch, time2, size). + value (torch.Tensor): Value tensor (#batch, time2, size). + mask (torch.Tensor): Mask tensor (#batch, 1, time2) or + (#batch, time1, time2). + 1.When applying cross attention between decoder and encoder, + the batch padding mask for input is in (#batch, 1, T) shape. + 2.When applying self attention of encoder, + the mask is in (#batch, T, T) shape. + 3.When applying self attention of decoder, + the mask is in (#batch, L, L) shape. + 4.If the different position in decoder see different block + of the encoder, such as Mocha, the passed in mask could be + in (#batch, L, T) shape. But there is no such case in current + CosyVoice. + cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2), + where `cache_t == chunk_size * num_decoding_left_chunks` + and `head * d_k == size` + + + Returns: + torch.Tensor: Output tensor (#batch, time1, d_model). 
+ torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2) + where `cache_t == chunk_size * num_decoding_left_chunks` + and `head * d_k == size` + + """ + q, k, v = self.forward_qkv(query, key, value) + + # NOTE(xcsong): + # when export onnx model, for 1st chunk, we feed + # cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode) + # or cache(1, head, real_cache_t, d_k * 2) (16/4 mode). + # In all modes, `if cache.size(0) > 0` will alwayse be `True` + # and we will always do splitting and + # concatnation(this will simplify onnx export). Note that + # it's OK to concat & split zero-shaped tensors(see code below). + # when export jit model, for 1st chunk, we always feed + # cache(0, 0, 0, 0) since jit supports dynamic if-branch. + # >>> a = torch.ones((1, 2, 0, 4)) + # >>> b = torch.ones((1, 2, 3, 4)) + # >>> c = torch.cat((a, b), dim=2) + # >>> torch.equal(b, c) # True + # >>> d = torch.split(a, 2, dim=-1) + # >>> torch.equal(d[0], d[1]) # True + if cache.size(0) > 0: + key_cache, value_cache = torch.split(cache, + cache.size(-1) // 2, + dim=-1) + k = torch.cat([key_cache, k], dim=2) + v = torch.cat([value_cache, v], dim=2) + # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's + # non-trivial to calculate `next_cache_start` here. + new_cache = torch.cat((k, v), dim=-1) + + scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k) + return self.forward_attention(v, scores, mask), new_cache + + +class RelPositionMultiHeadedAttention(MultiHeadedAttention): + """Multi-Head Attention layer with relative position encoding. + Paper: https://arxiv.org/abs/1901.02860 + Args: + n_head (int): The number of heads. + n_feat (int): The number of features. + dropout_rate (float): Dropout rate. + """ + + def __init__(self, + n_head: int, + n_feat: int, + dropout_rate: float, + key_bias: bool = True): + """Construct an RelPositionMultiHeadedAttention object.""" + super().__init__(n_head, n_feat, dropout_rate, key_bias) + # linear transformation for positional encoding + self.linear_pos = nn.Linear(n_feat, n_feat, bias=False) + # these two learnable bias are used in matrix c and matrix d + # as described in https://arxiv.org/abs/1901.02860 Section 3.3 + self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k)) + self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k)) + torch.nn.init.xavier_uniform_(self.pos_bias_u) + torch.nn.init.xavier_uniform_(self.pos_bias_v) + + def rel_shift(self, x): + """Compute relative positional encoding. + + Args: + x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1). + time1 means the length of query vector. + + Returns: + torch.Tensor: Output tensor. + + """ + zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype) + x_padded = torch.cat([zero_pad, x], dim=-1) + + x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2)) + x = x_padded[:, :, 1:].view_as(x)[ + :, :, :, : x.size(-1) // 2 + 1 + ] # only keep the positions from 0 to time2 + return x + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), + pos_emb: torch.Tensor = torch.empty(0), + cache: torch.Tensor = torch.zeros((0, 0, 0, 0)) + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute 'Scaled Dot Product Attention' with rel. positional encoding. + Args: + query (torch.Tensor): Query tensor (#batch, time1, size). + key (torch.Tensor): Key tensor (#batch, time2, size). + value (torch.Tensor): Value tensor (#batch, time2, size). 
+ mask (torch.Tensor): Mask tensor (#batch, 1, time2) or + (#batch, time1, time2), (0, 0, 0) means fake mask. + pos_emb (torch.Tensor): Positional embedding tensor + (#batch, time2, size). + cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2), + where `cache_t == chunk_size * num_decoding_left_chunks` + and `head * d_k == size` + Returns: + torch.Tensor: Output tensor (#batch, time1, d_model). + torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2) + where `cache_t == chunk_size * num_decoding_left_chunks` + and `head * d_k == size` + """ + q, k, v = self.forward_qkv(query, key, value) + q = q.transpose(1, 2) # (batch, time1, head, d_k) + + # NOTE(xcsong): + # when export onnx model, for 1st chunk, we feed + # cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode) + # or cache(1, head, real_cache_t, d_k * 2) (16/4 mode). + # In all modes, `if cache.size(0) > 0` will alwayse be `True` + # and we will always do splitting and + # concatnation(this will simplify onnx export). Note that + # it's OK to concat & split zero-shaped tensors(see code below). + # when export jit model, for 1st chunk, we always feed + # cache(0, 0, 0, 0) since jit supports dynamic if-branch. + # >>> a = torch.ones((1, 2, 0, 4)) + # >>> b = torch.ones((1, 2, 3, 4)) + # >>> c = torch.cat((a, b), dim=2) + # >>> torch.equal(b, c) # True + # >>> d = torch.split(a, 2, dim=-1) + # >>> torch.equal(d[0], d[1]) # True + if cache.size(0) > 0: + key_cache, value_cache = torch.split(cache, + cache.size(-1) // 2, + dim=-1) + k = torch.cat([key_cache, k], dim=2) + v = torch.cat([value_cache, v], dim=2) + # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's + # non-trivial to calculate `next_cache_start` here. + new_cache = torch.cat((k, v), dim=-1) + + n_batch_pos = pos_emb.size(0) + p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k) + p = p.transpose(1, 2) # (batch, head, time1, d_k) + + # (batch, head, time1, d_k) + q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2) + # (batch, head, time1, d_k) + q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2) + + # compute attention score + # first compute matrix a and matrix c + # as described in https://arxiv.org/abs/1901.02860 Section 3.3 + # (batch, head, time1, time2) + matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1)) + + # compute matrix b and matrix d + # (batch, head, time1, time2) + matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1)) + # NOTE(Xiang Lyu): Keep rel_shift since espnet rel_pos_emb is used + if matrix_ac.shape != matrix_bd.shape: + matrix_bd = self.rel_shift(matrix_bd) + + scores = (matrix_ac + matrix_bd) / math.sqrt( + self.d_k) # (batch, head, time1, time2) + + return self.forward_attention(v, scores, mask), new_cache + + + + +# class BlockRelPositionMultiHeadedAttention(MultiHeadedAttention): +# """Multi-Head Attention layer with relative position encoding. +# Paper: https://arxiv.org/abs/1901.02860 +# Args: +# n_head (int): The number of heads. +# n_feat (int): The number of features. +# dropout_rate (float): Dropout rate. 
+# """ + +# def __init__(self, +# n_head: int, +# n_feat: int, +# dropout_rate: float, +# key_bias: bool = True, +# block_size=25): +# """Construct an RelPositionMultiHeadedAttention object.""" +# super().__init__(n_head, n_feat, dropout_rate, key_bias) +# # linear transformation for positional encoding +# self.linear_pos = nn.Linear(n_feat, n_feat, bias=False) +# # these two learnable bias are used in matrix c and matrix d +# # as described in https://arxiv.org/abs/1901.02860 Section 3.3 +# self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k)) +# self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k)) +# torch.nn.init.xavier_uniform_(self.pos_bias_u) +# torch.nn.init.xavier_uniform_(self.pos_bias_v) +# self.block_size=block_size + +# def rel_shift(self, x): +# """Compute relative positional encoding. + +# Args: +# x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1). +# time1 means the length of query vector. + +# Returns: +# torch.Tensor: Output tensor. + +# """ +# zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype) +# x_padded = torch.cat([zero_pad, x], dim=-1) + +# x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2)) +# x = x_padded[:, :, 1:].view_as(x)[ +# :, :, :, : x.size(-1) // 2 + 1 +# ] # only keep the positions from 0 to time2 +# return x + +# def forward( +# self, +# query: torch.Tensor, +# key: torch.Tensor, +# value: torch.Tensor, +# mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), +# pos_emb: torch.Tensor = torch.empty(0), +# cache: torch.Tensor = torch.zeros((0, 0, 0, 0)) +# ) -> Tuple[torch.Tensor, torch.Tensor]: +# """Compute 'Scaled Dot Product Attention' with rel. positional encoding. +# Args: +# query (torch.Tensor): Query tensor (#batch, time1, size). +# key (torch.Tensor): Key tensor (#batch, time2, size). +# value (torch.Tensor): Value tensor (#batch, time2, size). +# mask (torch.Tensor): Mask tensor (#batch, 1, time2) or +# (#batch, time1, time2), (0, 0, 0) means fake mask. +# pos_emb (torch.Tensor): Positional embedding tensor +# (#batch, time2, size). +# cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2), +# where `cache_t == chunk_size * num_decoding_left_chunks` +# and `head * d_k == size` +# Returns: +# torch.Tensor: Output tensor (#batch, time1, d_model). +# torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2) +# where `cache_t == chunk_size * num_decoding_left_chunks` +# and `head * d_k == size` +# """ +# q, k, v = self.forward_qkv(query, key, value) +# q = q.transpose(1, 2) # (batch, time1, head, d_k) + +# # NOTE(xcsong): +# # when export onnx model, for 1st chunk, we feed +# # cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode) +# # or cache(1, head, real_cache_t, d_k * 2) (16/4 mode). +# # In all modes, `if cache.size(0) > 0` will alwayse be `True` +# # and we will always do splitting and +# # concatnation(this will simplify onnx export). Note that +# # it's OK to concat & split zero-shaped tensors(see code below). +# # when export jit model, for 1st chunk, we always feed +# # cache(0, 0, 0, 0) since jit supports dynamic if-branch. 
+# # >>> a = torch.ones((1, 2, 0, 4)) +# # >>> b = torch.ones((1, 2, 3, 4)) +# # >>> c = torch.cat((a, b), dim=2) +# # >>> torch.equal(b, c) # True +# # >>> d = torch.split(a, 2, dim=-1) +# # >>> torch.equal(d[0], d[1]) # True +# if cache.size(0) > 0: +# key_cache, value_cache = torch.split(cache, +# cache.size(-1) // 2, +# dim=-1) +# k = torch.cat([key_cache, k], dim=2) +# v = torch.cat([value_cache, v], dim=2) +# # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's +# # non-trivial to calculate `next_cache_start` here. +# new_cache = torch.cat((k, v), dim=-1) + +# n_batch_pos = pos_emb.size(0) +# p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k) +# p = p.transpose(1, 2) # (batch, head, time1, d_k) + +# # (batch, head, time1, d_k) +# q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2) +# # (batch, head, time1, d_k) +# q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2) + +# # compute attention score +# # first compute matrix a and matrix c +# # as described in https://arxiv.org/abs/1901.02860 Section 3.3 +# # (batch, head, time1, time2) + +# # Compute matrix ac and bd +# matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1)) # (batch, head, time1, time2) +# matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1)) # (batch, head, time1, time2) + +# batch_size, num_heads, seq_len, _ = matrix_ac.shape + +# # Create block causal mask +# block_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=self.block_size).to(matrix_ac.device).bool() +# # mask = mask.masked_fill(mask == 1, float('-inf')) # mask upper triangular matrix beyond block + +# # Apply relative shift if necessary +# if matrix_ac.shape != matrix_bd.shape: +# matrix_bd = self.rel_shift(matrix_bd) + +# # Combine ac and bd and apply the block causal mask +# scores = (matrix_ac + matrix_bd) / math.sqrt(self.d_k) # (batch, head, time1, time2) +# scores = scores.masked_fill(block_mask.unsqueeze(0).unsqueeze(0), float('-inf')) # apply the block mask + +# # Forward attention +# return self.forward_attention(v, scores, mask), new_cache + + + +from cosyvoice.utils import block_mask_util +class BlockRelPositionMultiHeadedAttention(MultiHeadedAttention): + """Multi-Head Attention layer with relative position encoding. + Paper: https://arxiv.org/abs/1901.02860 + Args: + n_head (int): The number of heads. + n_feat (int): The number of features. + dropout_rate (float): Dropout rate. + """ + + def __init__(self, + n_head: int, + n_feat: int, + dropout_rate: float, + key_bias: bool = True, block_size=25): + """Construct an RelPositionMultiHeadedAttention object.""" + super().__init__(n_head, n_feat, dropout_rate, key_bias) + # linear transformation for positional encoding + self.linear_pos = nn.Linear(n_feat, n_feat, bias=False) + # these two learnable bias are used in matrix c and matrix d + # as described in https://arxiv.org/abs/1901.02860 Section 3.3 + self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k)) + self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k)) + torch.nn.init.xavier_uniform_(self.pos_bias_u) + torch.nn.init.xavier_uniform_(self.pos_bias_v) + self.block_size = block_size + + def rel_shift(self, x: torch.Tensor) -> torch.Tensor: + """Compute relative positional encoding. + + Args: + x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1). + time1 means the length of query vector. + + Returns: + torch.Tensor: Output tensor. 
+ + """ + zero_pad = torch.zeros((x.size()[0], x.size()[1], x.size()[2], 1), + device=x.device, + dtype=x.dtype) + x_padded = torch.cat([zero_pad, x], dim=-1) + + x_padded = x_padded.view(x.size()[0], + x.size()[1], + x.size(3) + 1, x.size(2)) + x = x_padded[:, :, 1:].view_as(x)[ + :, :, :, : x.size(-1) // 2 + 1 + ] # only keep the positions from 0 to time2 + return x + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), + pos_emb: torch.Tensor = torch.empty(0), + cache: torch.Tensor = torch.zeros((0, 0, 0, 0)) + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute 'Scaled Dot Product Attention' with rel. positional encoding. + Args: + query (torch.Tensor): Query tensor (#batch, time1, size). + key (torch.Tensor): Key tensor (#batch, time2, size). + value (torch.Tensor): Value tensor (#batch, time2, size). + mask (torch.Tensor): Mask tensor (#batch, 1, time2) or + (#batch, time1, time2), (0, 0, 0) means fake mask. + pos_emb (torch.Tensor): Positional embedding tensor + (#batch, time2, size). + cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2), + where `cache_t == chunk_size * num_decoding_left_chunks` + and `head * d_k == size` + Returns: + torch.Tensor: Output tensor (#batch, time1, d_model). + torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2) + where `cache_t == chunk_size * num_decoding_left_chunks` + and `head * d_k == size` + """ + q, k, v = self.forward_qkv(query, key, value) + q = q.transpose(1, 2) # (batch, time1, head, d_k) + + # 0代表被mask的位置 + bs, time_len, _ = query.shape + # mask = torch.tril(torch.ones(time_len, time_len).to(mask), diagonal=0).int() + # block_size = self.block_size + # mask[:, 0:block_size] = 1 + block_mask = block_mask_util.create_grid_mask(time_len,self.block_size,fill_triangle=True).to(query).int() + block_mask = block_mask[None].repeat(bs, 1, 1) + mask=mask*block_mask + + # NOTE(xcsong): + # when export onnx model, for 1st chunk, we feed + # cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode) + # or cache(1, head, real_cache_t, d_k * 2) (16/4 mode). + # In all modes, `if cache.size(0) > 0` will alwayse be `True` + # and we will always do splitting and + # concatnation(this will simplify onnx export). Note that + # it's OK to concat & split zero-shaped tensors(see code below). + # when export jit model, for 1st chunk, we always feed + # cache(0, 0, 0, 0) since jit supports dynamic if-branch. + # >>> a = torch.ones((1, 2, 0, 4)) + # >>> b = torch.ones((1, 2, 3, 4)) + # >>> c = torch.cat((a, b), dim=2) + # >>> torch.equal(b, c) # True + # >>> d = torch.split(a, 2, dim=-1) + # >>> torch.equal(d[0], d[1]) # True + if cache.size(0) > 0: + key_cache, value_cache = torch.split(cache, + cache.size(-1) // 2, + dim=-1) + k = torch.cat([key_cache, k], dim=2) + v = torch.cat([value_cache, v], dim=2) + # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's + # non-trivial to calculate `next_cache_start` here. 
+ new_cache = torch.cat((k, v), dim=-1) + + n_batch_pos = pos_emb.size(0) + p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k) + p = p.transpose(1, 2) # (batch, head, time1, d_k) + + # (batch, head, time1, d_k) + q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2) + # (batch, head, time1, d_k) + q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2) + + # compute attention score + # first compute matrix a and matrix c + # as described in https://arxiv.org/abs/1901.02860 Section 3.3 + # (batch, head, time1, time2) + matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1)) + + # compute matrix b and matrix d + # (batch, head, time1, time2) + matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1)) + # NOTE(Xiang Lyu): Keep rel_shift since espnet rel_pos_emb is used + if matrix_ac.shape != matrix_bd.shape: + matrix_bd = self.rel_shift(matrix_bd) + + scores = (matrix_ac + matrix_bd) / math.sqrt( + self.d_k) # (batch, head, time1, time2) + + return self.forward_attention(v, scores, mask), new_cache diff --git a/cosyvoice/transformer/convolution.py b/cosyvoice/transformer/convolution.py new file mode 100644 index 0000000000000000000000000000000000000000..4d5d96149154776000991a681a666fbe55e562fe --- /dev/null +++ b/cosyvoice/transformer/convolution.py @@ -0,0 +1,145 @@ +# Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu) +# 2024 Alibaba Inc (Xiang Lyu) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Modified from ESPnet(https://github.com/espnet/espnet) +"""ConvolutionModule definition.""" + +from typing import Tuple + +import torch +from torch import nn + + +class ConvolutionModule(nn.Module): + """ConvolutionModule in Conformer model.""" + + def __init__(self, + channels: int, + kernel_size: int = 15, + activation: nn.Module = nn.ReLU(), + norm: str = "batch_norm", + causal: bool = False, + bias: bool = True): + """Construct an ConvolutionModule object. + Args: + channels (int): The number of channels of conv layers. + kernel_size (int): Kernel size of conv layers. + causal (int): Whether use causal convolution or not + """ + super().__init__() + + self.pointwise_conv1 = nn.Conv1d( + channels, + 2 * channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + ) + # self.lorder is used to distinguish if it's a causal convolution, + # if self.lorder > 0: it's a causal convolution, the input will be + # padded with self.lorder frames on the left in forward. 
+ # else: it's a symmetrical convolution + if causal: + padding = 0 + self.lorder = kernel_size - 1 + else: + # kernel_size should be an odd number for none causal convolution + assert (kernel_size - 1) % 2 == 0 + padding = (kernel_size - 1) // 2 + self.lorder = 0 + self.depthwise_conv = nn.Conv1d( + channels, + channels, + kernel_size, + stride=1, + padding=padding, + groups=channels, + bias=bias, + ) + + assert norm in ['batch_norm', 'layer_norm'] + if norm == "batch_norm": + self.use_layer_norm = False + self.norm = nn.BatchNorm1d(channels) + else: + self.use_layer_norm = True + self.norm = nn.LayerNorm(channels) + + self.pointwise_conv2 = nn.Conv1d( + channels, + channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + ) + self.activation = activation + + def forward( + self, + x: torch.Tensor, + mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), + cache: torch.Tensor = torch.zeros((0, 0, 0)), + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute convolution module. + Args: + x (torch.Tensor): Input tensor (#batch, time, channels). + mask_pad (torch.Tensor): used for batch padding (#batch, 1, time), + (0, 0, 0) means fake mask. + cache (torch.Tensor): left context cache, it is only + used in causal convolution (#batch, channels, cache_t), + (0, 0, 0) meas fake cache. + Returns: + torch.Tensor: Output tensor (#batch, time, channels). + """ + # exchange the temporal dimension and the feature dimension + x = x.transpose(1, 2) # (#batch, channels, time) + + # mask batch padding + if mask_pad.size(2) > 0: # time > 0 + x.masked_fill_(~mask_pad, 0.0) + + if self.lorder > 0: + if cache.size(2) == 0: # cache_t == 0 + x = nn.functional.pad(x, (self.lorder, 0), 'constant', 0.0) + else: + assert cache.size(0) == x.size(0) # equal batch + assert cache.size(1) == x.size(1) # equal channel + x = torch.cat((cache, x), dim=2) + assert (x.size(2) > self.lorder) + new_cache = x[:, :, -self.lorder:] + else: + # It's better we just return None if no cache is required, + # However, for JIT export, here we just fake one tensor instead of + # None. + new_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device) + + # GLU mechanism + x = self.pointwise_conv1(x) # (batch, 2*channel, dim) + x = nn.functional.glu(x, dim=1) # (batch, channel, dim) + + # 1D Depthwise Conv + x = self.depthwise_conv(x) + if self.use_layer_norm: + x = x.transpose(1, 2) + x = self.activation(self.norm(x)) + if self.use_layer_norm: + x = x.transpose(1, 2) + x = self.pointwise_conv2(x) + # mask batch padding + if mask_pad.size(2) > 0: # time > 0 + x.masked_fill_(~mask_pad, 0.0) + + return x.transpose(1, 2), new_cache diff --git a/cosyvoice/transformer/decoder.py b/cosyvoice/transformer/decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..961c875eab519f7a9e8a6e56720dc878b7852372 --- /dev/null +++ b/cosyvoice/transformer/decoder.py @@ -0,0 +1,396 @@ +# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang, Di Wu) +# 2024 Alibaba Inc (Xiang Lyu) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
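# Editorial sketch: driving the causal ConvolutionModule above chunk by chunk,
# carrying the returned left-context cache into the next call.  The chunking
# scheme is illustrative only, and the module is assumed to have been built
# with causal=True (otherwise the returned cache stays empty).
import torch


def stream_conv_module(conv_module, chunks):
    """chunks: list of (batch, time, channels) tensors from one utterance."""
    cache = torch.zeros((0, 0, 0))                   # fake cache for the first chunk
    outputs = []
    for chunk in chunks:
        y, cache = conv_module(chunk, cache=cache)   # cache: (batch, channels, lorder)
        outputs.append(y)
    # For a causal module in eval() mode, this concatenation matches a single
    # full-length forward pass over the whole utterance.
    return torch.cat(outputs, dim=1)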
+# See the License for the specific language governing permissions and +# limitations under the License. +# Modified from ESPnet(https://github.com/espnet/espnet) +"""Decoder definition.""" +from typing import Tuple, List, Optional + +import torch +import torch.utils.checkpoint as ckpt +import logging + +from cosyvoice.transformer.decoder_layer import DecoderLayer +from cosyvoice.transformer.positionwise_feed_forward import PositionwiseFeedForward +from cosyvoice.utils.class_utils import ( + COSYVOICE_EMB_CLASSES, + COSYVOICE_ATTENTION_CLASSES, + COSYVOICE_ACTIVATION_CLASSES, +) +from cosyvoice.utils.mask import (subsequent_mask, make_pad_mask) + + +class TransformerDecoder(torch.nn.Module): + """Base class of Transfomer decoder module. + Args: + vocab_size: output dim + encoder_output_size: dimension of attention + attention_heads: the number of heads of multi head attention + linear_units: the hidden units number of position-wise feedforward + num_blocks: the number of decoder blocks + dropout_rate: dropout rate + self_attention_dropout_rate: dropout rate for attention + input_layer: input layer type + use_output_layer: whether to use output layer + pos_enc_class: PositionalEncoding or ScaledPositionalEncoding + normalize_before: + True: use layer_norm before each sub-block of a layer. + False: use layer_norm after each sub-block of a layer. + src_attention: if false, encoder-decoder cross attention is not + applied, such as CIF model + key_bias: whether use bias in attention.linear_k, False for whisper models. + gradient_checkpointing: rerunning a forward-pass segment for each + checkpointed segment during backward. + tie_word_embedding: Tie or clone module weights depending of whether we are + using TorchScript or not + """ + + def __init__( + self, + vocab_size: int, + encoder_output_size: int, + attention_heads: int = 4, + linear_units: int = 2048, + num_blocks: int = 6, + dropout_rate: float = 0.1, + positional_dropout_rate: float = 0.1, + self_attention_dropout_rate: float = 0.0, + src_attention_dropout_rate: float = 0.0, + input_layer: str = "embed", + use_output_layer: bool = True, + normalize_before: bool = True, + src_attention: bool = True, + key_bias: bool = True, + activation_type: str = "relu", + gradient_checkpointing: bool = False, + tie_word_embedding: bool = False, + ): + super().__init__() + attention_dim = encoder_output_size + activation = COSYVOICE_ACTIVATION_CLASSES[activation_type]() + + self.embed = torch.nn.Sequential( + torch.nn.Identity() if input_layer == "no_pos" else + torch.nn.Embedding(vocab_size, attention_dim), + COSYVOICE_EMB_CLASSES[input_layer](attention_dim, + positional_dropout_rate), + ) + + self.normalize_before = normalize_before + self.after_norm = torch.nn.LayerNorm(attention_dim, eps=1e-5) + self.use_output_layer = use_output_layer + if use_output_layer: + self.output_layer = torch.nn.Linear(attention_dim, vocab_size) + else: + self.output_layer = torch.nn.Identity() + self.num_blocks = num_blocks + self.decoders = torch.nn.ModuleList([ + DecoderLayer( + attention_dim, + COSYVOICE_ATTENTION_CLASSES["selfattn"]( + attention_heads, attention_dim, + self_attention_dropout_rate, key_bias), + COSYVOICE_ATTENTION_CLASSES["selfattn"]( + attention_heads, attention_dim, src_attention_dropout_rate, + key_bias) if src_attention else None, + PositionwiseFeedForward(attention_dim, linear_units, + dropout_rate, activation), + dropout_rate, + normalize_before, + ) for _ in range(self.num_blocks) + ]) + + self.gradient_checkpointing = gradient_checkpointing 
+ self.tie_word_embedding = tie_word_embedding + + def forward( + self, + memory: torch.Tensor, + memory_mask: torch.Tensor, + ys_in_pad: torch.Tensor, + ys_in_lens: torch.Tensor, + r_ys_in_pad: torch.Tensor = torch.empty(0), + reverse_weight: float = 0.0, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Forward decoder. + Args: + memory: encoded memory, float32 (batch, maxlen_in, feat) + memory_mask: encoder memory mask, (batch, 1, maxlen_in) + ys_in_pad: padded input token ids, int64 (batch, maxlen_out) + ys_in_lens: input lengths of this batch (batch) + r_ys_in_pad: not used in transformer decoder, in order to unify api + with bidirectional decoder + reverse_weight: not used in transformer decoder, in order to unify + api with bidirectional decode + Returns: + (tuple): tuple containing: + x: decoded token score before softmax (batch, maxlen_out, + vocab_size) if use_output_layer is True, + torch.tensor(0.0), in order to unify api with bidirectional decoder + olens: (batch, ) + NOTE(xcsong): + We pass the `__call__` method of the modules instead of `forward` to the + checkpointing API because `__call__` attaches all the hooks of the module. + https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2 + """ + tgt = ys_in_pad + maxlen = tgt.size(1) + # tgt_mask: (B, 1, L) + tgt_mask = ~make_pad_mask(ys_in_lens, maxlen).unsqueeze(1) + tgt_mask = tgt_mask.to(tgt.device) + # m: (1, L, L) + m = subsequent_mask(tgt_mask.size(-1), + device=tgt_mask.device).unsqueeze(0) + # tgt_mask: (B, L, L) + tgt_mask = tgt_mask & m + x, _ = self.embed(tgt) + if self.gradient_checkpointing and self.training: + x = self.forward_layers_checkpointed(x, tgt_mask, memory, + memory_mask) + else: + x = self.forward_layers(x, tgt_mask, memory, memory_mask) + if self.normalize_before: + x = self.after_norm(x) + if self.use_output_layer: + x = self.output_layer(x) + olens = tgt_mask.sum(1) + return x, torch.tensor(0.0), olens + + def forward_layers(self, x: torch.Tensor, tgt_mask: torch.Tensor, + memory: torch.Tensor, + memory_mask: torch.Tensor) -> torch.Tensor: + for layer in self.decoders: + x, tgt_mask, memory, memory_mask = layer(x, tgt_mask, memory, + memory_mask) + return x + + @torch.jit.ignore(drop=True) + def forward_layers_checkpointed(self, x: torch.Tensor, + tgt_mask: torch.Tensor, + memory: torch.Tensor, + memory_mask: torch.Tensor) -> torch.Tensor: + for layer in self.decoders: + x, tgt_mask, memory, memory_mask = ckpt.checkpoint( + layer.__call__, x, tgt_mask, memory, memory_mask) + return x + + def forward_one_step( + self, + memory: torch.Tensor, + memory_mask: torch.Tensor, + tgt: torch.Tensor, + tgt_mask: torch.Tensor, + cache: Optional[List[torch.Tensor]] = None, + ) -> Tuple[torch.Tensor, List[torch.Tensor]]: + """Forward one step. + This is only used for decoding. + Args: + memory: encoded memory, float32 (batch, maxlen_in, feat) + memory_mask: encoded memory mask, (batch, 1, maxlen_in) + tgt: input token ids, int64 (batch, maxlen_out) + tgt_mask: input token mask, (batch, maxlen_out) + dtype=torch.uint8 in PyTorch 1.2- + dtype=torch.bool in PyTorch 1.2+ (include 1.2) + cache: cached output list of (batch, max_time_out-1, size) + Returns: + y, cache: NN output value and cache per `self.decoders`. 
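# Editorial sketch: a teacher-forcing call into the TransformerDecoder defined
# above.  The sizes (encoder_output_size=512, vocab_size=4000) are made up for
# illustration and do not come from any CosyVoice config.
import torch


def decoder_forward_sketch(decoder, vocab_size=4000, batch=2, t_in=50, t_out=12):
    memory = torch.randn(batch, t_in, 512)                    # encoder output (B, T_in, D)
    memory_mask = torch.ones(batch, 1, t_in, dtype=torch.bool)
    ys_in_pad = torch.randint(0, vocab_size, (batch, t_out))  # <sos>-prefixed targets
    ys_in_lens = torch.tensor([t_out, t_out - 3])
    logits, _, olens = decoder(memory, memory_mask, ys_in_pad, ys_in_lens)
    # logits: (batch, t_out, vocab_size) when use_output_layer=True
    return logits, olens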
+ y.shape` is (batch, maxlen_out, token) + """ + x, _ = self.embed(tgt) + new_cache = [] + for i, decoder in enumerate(self.decoders): + if cache is None: + c = None + else: + c = cache[i] + x, tgt_mask, memory, memory_mask = decoder(x, + tgt_mask, + memory, + memory_mask, + cache=c) + new_cache.append(x) + if self.normalize_before: + y = self.after_norm(x[:, -1]) + else: + y = x[:, -1] + if self.use_output_layer: + y = torch.log_softmax(self.output_layer(y), dim=-1) + return y, new_cache + + def tie_or_clone_weights(self, jit_mode: bool = True): + """Tie or clone module weights (between word_emb and output_layer) + depending of whether we are using TorchScript or not""" + if not self.use_output_layer: + return + if jit_mode: + logging.info("clone emb.weight to output.weight") + self.output_layer.weight = torch.nn.Parameter( + self.embed[0].weight.clone()) + else: + logging.info("tie emb.weight with output.weight") + self.output_layer.weight = self.embed[0].weight + + if getattr(self.output_layer, "bias", None) is not None: + self.output_layer.bias.data = torch.nn.functional.pad( + self.output_layer.bias.data, + ( + 0, + self.output_layer.weight.shape[0] - + self.output_layer.bias.shape[0], + ), + "constant", + 0, + ) + + +class BiTransformerDecoder(torch.nn.Module): + """Base class of Transfomer decoder module. + Args: + vocab_size: output dim + encoder_output_size: dimension of attention + attention_heads: the number of heads of multi head attention + linear_units: the hidden units number of position-wise feedforward + num_blocks: the number of decoder blocks + r_num_blocks: the number of right to left decoder blocks + dropout_rate: dropout rate + self_attention_dropout_rate: dropout rate for attention + input_layer: input layer type + use_output_layer: whether to use output layer + pos_enc_class: PositionalEncoding or ScaledPositionalEncoding + normalize_before: + True: use layer_norm before each sub-block of a layer. + False: use layer_norm after each sub-block of a layer. + key_bias: whether use bias in attention.linear_k, False for whisper models. 
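# Editorial sketch: step-by-step decoding with forward_one_step(), reusing the
# per-layer cache list it returns.  The sos id, fixed step count, and greedy
# selection are illustrative only; this is not the decoding strategy used
# elsewhere in CosyVoice.
import torch

from cosyvoice.utils.mask import subsequent_mask


def greedy_decode_sketch(decoder, memory, memory_mask, sos_id=0, max_steps=10):
    hyp = torch.tensor([[sos_id]], dtype=torch.long)              # (1, 1)
    cache = None
    for _ in range(max_steps):
        tgt_mask = subsequent_mask(hyp.size(1),
                                   device=hyp.device).unsqueeze(0)  # (1, L, L)
        logp, cache = decoder.forward_one_step(memory, memory_mask, hyp,
                                               tgt_mask, cache=cache)
        next_token = logp.argmax(dim=-1, keepdim=True)            # (1, 1)
        hyp = torch.cat([hyp, next_token], dim=1)
    return hyp[:, 1:]                                             # drop the sos token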
+ """ + + def __init__( + self, + vocab_size: int, + encoder_output_size: int, + attention_heads: int = 4, + linear_units: int = 2048, + num_blocks: int = 6, + r_num_blocks: int = 0, + dropout_rate: float = 0.1, + positional_dropout_rate: float = 0.1, + self_attention_dropout_rate: float = 0.0, + src_attention_dropout_rate: float = 0.0, + input_layer: str = "embed", + use_output_layer: bool = True, + normalize_before: bool = True, + key_bias: bool = True, + gradient_checkpointing: bool = False, + tie_word_embedding: bool = False, + ): + + super().__init__() + self.tie_word_embedding = tie_word_embedding + self.left_decoder = TransformerDecoder( + vocab_size, + encoder_output_size, + attention_heads, + linear_units, + num_blocks, + dropout_rate, + positional_dropout_rate, + self_attention_dropout_rate, + src_attention_dropout_rate, + input_layer, + use_output_layer, + normalize_before, + key_bias=key_bias, + gradient_checkpointing=gradient_checkpointing, + tie_word_embedding=tie_word_embedding) + + self.right_decoder = TransformerDecoder( + vocab_size, + encoder_output_size, + attention_heads, + linear_units, + r_num_blocks, + dropout_rate, + positional_dropout_rate, + self_attention_dropout_rate, + src_attention_dropout_rate, + input_layer, + use_output_layer, + normalize_before, + key_bias=key_bias, + gradient_checkpointing=gradient_checkpointing, + tie_word_embedding=tie_word_embedding) + + def forward( + self, + memory: torch.Tensor, + memory_mask: torch.Tensor, + ys_in_pad: torch.Tensor, + ys_in_lens: torch.Tensor, + r_ys_in_pad: torch.Tensor, + reverse_weight: float = 0.0, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Forward decoder. + Args: + memory: encoded memory, float32 (batch, maxlen_in, feat) + memory_mask: encoder memory mask, (batch, 1, maxlen_in) + ys_in_pad: padded input token ids, int64 (batch, maxlen_out) + ys_in_lens: input lengths of this batch (batch) + r_ys_in_pad: padded input token ids, int64 (batch, maxlen_out), + used for right to left decoder + reverse_weight: used for right to left decoder + Returns: + (tuple): tuple containing: + x: decoded token score before softmax (batch, maxlen_out, + vocab_size) if use_output_layer is True, + r_x: x: decoded token score (right to left decoder) + before softmax (batch, maxlen_out, vocab_size) + if use_output_layer is True, + olens: (batch, ) + """ + l_x, _, olens = self.left_decoder(memory, memory_mask, ys_in_pad, + ys_in_lens) + r_x = torch.tensor(0.0) + if reverse_weight > 0.0: + r_x, _, olens = self.right_decoder(memory, memory_mask, + r_ys_in_pad, ys_in_lens) + return l_x, r_x, olens + + def forward_one_step( + self, + memory: torch.Tensor, + memory_mask: torch.Tensor, + tgt: torch.Tensor, + tgt_mask: torch.Tensor, + cache: Optional[List[torch.Tensor]] = None, + ) -> Tuple[torch.Tensor, List[torch.Tensor]]: + """Forward one step. + This is only used for decoding. + Args: + memory: encoded memory, float32 (batch, maxlen_in, feat) + memory_mask: encoded memory mask, (batch, 1, maxlen_in) + tgt: input token ids, int64 (batch, maxlen_out) + tgt_mask: input token mask, (batch, maxlen_out) + dtype=torch.uint8 in PyTorch 1.2- + dtype=torch.bool in PyTorch 1.2+ (include 1.2) + cache: cached output list of (batch, max_time_out-1, size) + Returns: + y, cache: NN output value and cache per `self.decoders`. 
+ y.shape` is (batch, maxlen_out, token) + """ + return self.left_decoder.forward_one_step(memory, memory_mask, tgt, + tgt_mask, cache) + + def tie_or_clone_weights(self, jit_mode: bool = True): + """Tie or clone module weights (between word_emb and output_layer) + depending of whether we are using TorchScript or not""" + self.left_decoder.tie_or_clone_weights(jit_mode) + self.right_decoder.tie_or_clone_weights(jit_mode) diff --git a/cosyvoice/transformer/decoder_layer.py b/cosyvoice/transformer/decoder_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..91c7c5d7fb2a8e79cea7705646e5381016f73466 --- /dev/null +++ b/cosyvoice/transformer/decoder_layer.py @@ -0,0 +1,132 @@ +# Copyright (c) 2019 Shigeki Karita +# 2020 Mobvoi Inc (Binbin Zhang) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Decoder self-attention layer definition.""" +from typing import Optional, Tuple + +import torch +from torch import nn + + +class DecoderLayer(nn.Module): + """Single decoder layer module. + + Args: + size (int): Input dimension. + self_attn (torch.nn.Module): Self-attention module instance. + `MultiHeadedAttention` instance can be used as the argument. + src_attn (torch.nn.Module): Inter-attention module instance. + `MultiHeadedAttention` instance can be used as the argument. + If `None` is passed, Inter-attention is not used, such as + CIF, GPT, and other decoder only model. + feed_forward (torch.nn.Module): Feed-forward module instance. + `PositionwiseFeedForward` instance can be used as the argument. + dropout_rate (float): Dropout rate. + normalize_before (bool): + True: use layer_norm before each sub-block. + False: to use layer_norm after each sub-block. + """ + + def __init__( + self, + size: int, + self_attn: nn.Module, + src_attn: Optional[nn.Module], + feed_forward: nn.Module, + dropout_rate: float, + normalize_before: bool = True, + ): + """Construct an DecoderLayer object.""" + super().__init__() + self.size = size + self.self_attn = self_attn + self.src_attn = src_attn + self.feed_forward = feed_forward + self.norm1 = nn.LayerNorm(size, eps=1e-5) + self.norm2 = nn.LayerNorm(size, eps=1e-5) + self.norm3 = nn.LayerNorm(size, eps=1e-5) + self.dropout = nn.Dropout(dropout_rate) + self.normalize_before = normalize_before + + def forward( + self, + tgt: torch.Tensor, + tgt_mask: torch.Tensor, + memory: torch.Tensor, + memory_mask: torch.Tensor, + cache: Optional[torch.Tensor] = None + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Compute decoded features. + + Args: + tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size). + tgt_mask (torch.Tensor): Mask for input tensor + (#batch, maxlen_out). + memory (torch.Tensor): Encoded memory + (#batch, maxlen_in, size). + memory_mask (torch.Tensor): Encoded memory mask + (#batch, maxlen_in). + cache (torch.Tensor): cached tensors. + (#batch, maxlen_out - 1, size). + + Returns: + torch.Tensor: Output tensor (#batch, maxlen_out, size). + torch.Tensor: Mask for output tensor (#batch, maxlen_out). 
+ torch.Tensor: Encoded memory (#batch, maxlen_in, size). + torch.Tensor: Encoded memory mask (#batch, maxlen_in). + + """ + residual = tgt + if self.normalize_before: + tgt = self.norm1(tgt) + + if cache is None: + tgt_q = tgt + tgt_q_mask = tgt_mask + else: + # compute only the last frame query keeping dim: max_time_out -> 1 + assert cache.shape == ( + tgt.shape[0], + tgt.shape[1] - 1, + self.size, + ), "{cache.shape} == {(tgt.shape[0], tgt.shape[1] - 1, self.size)}" + tgt_q = tgt[:, -1:, :] + residual = residual[:, -1:, :] + tgt_q_mask = tgt_mask[:, -1:, :] + + x = residual + self.dropout( + self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)[0]) + if not self.normalize_before: + x = self.norm1(x) + + if self.src_attn is not None: + residual = x + if self.normalize_before: + x = self.norm2(x) + x = residual + self.dropout( + self.src_attn(x, memory, memory, memory_mask)[0]) + if not self.normalize_before: + x = self.norm2(x) + + residual = x + if self.normalize_before: + x = self.norm3(x) + x = residual + self.dropout(self.feed_forward(x)) + if not self.normalize_before: + x = self.norm3(x) + + if cache is not None: + x = torch.cat([cache, x], dim=1) + + return x, tgt_mask, memory, memory_mask diff --git a/cosyvoice/transformer/embedding.py b/cosyvoice/transformer/embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..46130a503f72f103e09d3392077ed352368ce54f --- /dev/null +++ b/cosyvoice/transformer/embedding.py @@ -0,0 +1,293 @@ +# Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu) +# 2024 Alibaba Inc (Xiang Lyu) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Modified from ESPnet(https://github.com/espnet/espnet) +"""Positonal Encoding Module.""" + +import math +from typing import Tuple, Union + +import torch +import torch.nn.functional as F +import numpy as np + + +class PositionalEncoding(torch.nn.Module): + """Positional encoding. + + :param int d_model: embedding dim + :param float dropout_rate: dropout rate + :param int max_len: maximum input length + + PE(pos, 2i) = sin(pos/(10000^(2i/dmodel))) + PE(pos, 2i+1) = cos(pos/(10000^(2i/dmodel))) + """ + + def __init__(self, + d_model: int, + dropout_rate: float, + max_len: int = 5000, + reverse: bool = False): + """Construct an PositionalEncoding object.""" + super().__init__() + self.d_model = d_model + self.xscale = math.sqrt(self.d_model) + self.dropout = torch.nn.Dropout(p=dropout_rate) + self.max_len = max_len + + self.pe = torch.zeros(self.max_len, self.d_model) + position = torch.arange(0, self.max_len, + dtype=torch.float32).unsqueeze(1) + div_term = torch.exp( + torch.arange(0, self.d_model, 2, dtype=torch.float32) * + -(math.log(10000.0) / self.d_model)) + self.pe[:, 0::2] = torch.sin(position * div_term) + self.pe[:, 1::2] = torch.cos(position * div_term) + self.pe = self.pe.unsqueeze(0) + + def forward(self, + x: torch.Tensor, + offset: Union[int, torch.Tensor] = 0) \ + -> Tuple[torch.Tensor, torch.Tensor]: + """Add positional encoding. + + Args: + x (torch.Tensor): Input. 
Its shape is (batch, time, ...) + offset (int, torch.tensor): position offset + + Returns: + torch.Tensor: Encoded tensor. Its shape is (batch, time, ...) + torch.Tensor: for compatibility to RelPositionalEncoding + """ + + self.pe = self.pe.to(x.device) + pos_emb = self.position_encoding(offset, x.size(1), False) + x = x * self.xscale + pos_emb + return self.dropout(x), self.dropout(pos_emb) + + def position_encoding(self, + offset: Union[int, torch.Tensor], + size: int, + apply_dropout: bool = True) -> torch.Tensor: + """ For getting encoding in a streaming fashion + + Attention!!!!! + we apply dropout only once at the whole utterance level in a none + streaming way, but will call this function several times with + increasing input size in a streaming scenario, so the dropout will + be applied several times. + + Args: + offset (int or torch.tensor): start offset + size (int): required size of position encoding + + Returns: + torch.Tensor: Corresponding encoding + """ + # How to subscript a Union type: + # https://github.com/pytorch/pytorch/issues/69434 + if isinstance(offset, int): + assert offset + size <= self.max_len + pos_emb = self.pe[:, offset:offset + size] + elif isinstance(offset, torch.Tensor) and offset.dim() == 0: # scalar + assert offset + size <= self.max_len + pos_emb = self.pe[:, offset:offset + size] + else: # for batched streaming decoding on GPU + assert torch.max(offset) + size <= self.max_len + index = offset.unsqueeze(1) + \ + torch.arange(0, size).to(offset.device) # B X T + flag = index > 0 + # remove negative offset + index = index * flag + pos_emb = F.embedding(index, self.pe[0]) # B X T X d_model + + if apply_dropout: + pos_emb = self.dropout(pos_emb) + return pos_emb + + +class RelPositionalEncoding(PositionalEncoding): + """Relative positional encoding module. + See : Appendix B in https://arxiv.org/abs/1901.02860 + Args: + d_model (int): Embedding dimension. + dropout_rate (float): Dropout rate. + max_len (int): Maximum input length. + """ + + def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000): + """Initialize class.""" + super().__init__(d_model, dropout_rate, max_len, reverse=True) + + def forward(self, + x: torch.Tensor, + offset: Union[int, torch.Tensor] = 0) \ + -> Tuple[torch.Tensor, torch.Tensor]: + """Compute positional encoding. + Args: + x (torch.Tensor): Input tensor (batch, time, `*`). + Returns: + torch.Tensor: Encoded tensor (batch, time, `*`). + torch.Tensor: Positional embedding tensor (1, time, `*`). 
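# Editorial sketch: the streaming contract of position_encoding() above is that
# asking for `size` frames starting at `offset` equals slicing the full table.
# A quick self-check, with dropout disabled so the comparison is deterministic:
import torch

from cosyvoice.transformer.embedding import PositionalEncoding


def check_streaming_pos_enc():
    pe = PositionalEncoding(d_model=8, dropout_rate=0.0, max_len=100)
    full = pe.position_encoding(offset=0, size=30, apply_dropout=False)
    chunk = pe.position_encoding(offset=10, size=5, apply_dropout=False)
    assert torch.allclose(chunk, full[:, 10:15])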
+ """ + self.pe = self.pe.to(x.device) + x = x * self.xscale + pos_emb = self.position_encoding(offset, x.size(1), False) + return self.dropout(x), self.dropout(pos_emb) + + +class WhisperPositionalEncoding(PositionalEncoding): + """ Sinusoids position encoding used in openai-whisper.encoder + """ + + def __init__(self, d_model: int, dropout_rate: float, max_len: int = 1500): + super().__init__(d_model, dropout_rate, max_len) + self.xscale = 1.0 + log_timescale_increment = np.log(10000) / (d_model // 2 - 1) + inv_timescales = torch.exp(-log_timescale_increment * + torch.arange(d_model // 2)) + scaled_time = torch.arange(max_len)[:, np.newaxis] * \ + inv_timescales[np.newaxis, :] + pe = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1) + delattr(self, "pe") + self.register_buffer("pe", pe.unsqueeze(0)) + + +class LearnablePositionalEncoding(PositionalEncoding): + """ Learnable position encoding used in openai-whisper.decoder + """ + + def __init__(self, d_model: int, dropout_rate: float, max_len: int = 448): + super().__init__(d_model, dropout_rate, max_len) + # NOTE(xcsong): overwrite self.pe & self.xscale + self.pe = torch.nn.Parameter(torch.empty(1, max_len, d_model)) + self.xscale = 1.0 + + +class NoPositionalEncoding(torch.nn.Module): + """ No position encoding + """ + + def __init__(self, d_model: int, dropout_rate: float): + super().__init__() + self.d_model = d_model + self.dropout = torch.nn.Dropout(p=dropout_rate) + + def forward(self, + x: torch.Tensor, + offset: Union[int, torch.Tensor] = 0) \ + -> Tuple[torch.Tensor, torch.Tensor]: + """ Just return zero vector for interface compatibility + """ + pos_emb = torch.zeros(1, x.size(1), self.d_model).to(x.device) + return self.dropout(x), pos_emb + + def position_encoding(self, offset: Union[int, torch.Tensor], + size: int) -> torch.Tensor: + return torch.zeros(1, size, self.d_model) + + +class EspnetRelPositionalEncoding(torch.nn.Module): + """Relative positional encoding module (new implementation). + + Details can be found in https://github.com/espnet/espnet/pull/2816. + + See : Appendix B in https://arxiv.org/abs/1901.02860 + + Args: + d_model (int): Embedding dimension. + dropout_rate (float): Dropout rate. + max_len (int): Maximum input length. + + """ + + def __init__(self, d_model, dropout_rate, max_len=5000): + """Construct an PositionalEncoding object.""" + super(EspnetRelPositionalEncoding, self).__init__() + self.d_model = d_model + self.xscale = math.sqrt(self.d_model) + self.dropout = torch.nn.Dropout(p=dropout_rate) + self.pe = None + self.extend_pe(torch.tensor(0.0).expand(1, max_len)) + + def extend_pe(self, x): + """Reset the positional encodings.""" + if self.pe is not None: + # self.pe contains both positive and negative parts + # the length of self.pe is 2 * input_len - 1 + if self.pe.size(1) >= x.size(1) * 2 - 1: + if self.pe.dtype != x.dtype or self.pe.device != x.device: + self.pe = self.pe.to(dtype=x.dtype, device=x.device) + return + # Suppose `i` means to the position of query vecotr and `j` means the + # position of key vector. We use position relative positions when keys + # are to the left (i>j) and negative relative positions otherwise (i torch.Tensor: + """ For getting encoding in a streaming fashion + + Attention!!!!! + we apply dropout only once at the whole utterance level in a none + streaming way, but will call this function several times with + increasing input size in a streaming scenario, so the dropout will + be applied several times. 
+ + Args: + offset (int or torch.tensor): start offset + size (int): required size of position encoding + + Returns: + torch.Tensor: Corresponding encoding + """ + pos_emb = self.pe[ + :, + self.pe.size(1) // 2 - size + 1 : self.pe.size(1) // 2 + size, + ] + return pos_emb diff --git a/cosyvoice/transformer/encoder.py b/cosyvoice/transformer/encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..f5e98c683bd2e76b51fda7acdca355348d072d58 --- /dev/null +++ b/cosyvoice/transformer/encoder.py @@ -0,0 +1,567 @@ +# Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu) +# 2022 Xingchen Song (sxc19@mails.tsinghua.edu.cn) +# 2024 Alibaba Inc (Xiang Lyu) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Modified from ESPnet(https://github.com/espnet/espnet) +"""Encoder definition.""" +from typing import Tuple + +import torch +import torch.utils.checkpoint as ckpt + +from cosyvoice.transformer.convolution import ConvolutionModule +from cosyvoice.transformer.encoder_layer import TransformerEncoderLayer +from cosyvoice.transformer.encoder_layer import ConformerEncoderLayer +from cosyvoice.transformer.positionwise_feed_forward import PositionwiseFeedForward +from cosyvoice.utils.class_utils import ( + COSYVOICE_EMB_CLASSES, + COSYVOICE_SUBSAMPLE_CLASSES, + COSYVOICE_ATTENTION_CLASSES, + COSYVOICE_ACTIVATION_CLASSES, +) +from cosyvoice.utils.mask import make_pad_mask +from cosyvoice.utils.mask import add_optional_chunk_mask + + +class BaseEncoder(torch.nn.Module): + + def __init__( + self, + input_size: int, + output_size: int = 256, + attention_heads: int = 4, + linear_units: int = 2048, + num_blocks: int = 6, + dropout_rate: float = 0.1, + positional_dropout_rate: float = 0.1, + attention_dropout_rate: float = 0.0, + input_layer: str = "conv2d", + pos_enc_layer_type: str = "abs_pos", + normalize_before: bool = True, + static_chunk_size: int = 0, + use_dynamic_chunk: bool = False, + global_cmvn: torch.nn.Module = None, + use_dynamic_left_chunk: bool = False, + gradient_checkpointing: bool = False, + ): + """ + Args: + input_size (int): input dim + output_size (int): dimension of attention + attention_heads (int): the number of heads of multi head attention + linear_units (int): the hidden units number of position-wise feed + forward + num_blocks (int): the number of decoder blocks + dropout_rate (float): dropout rate + attention_dropout_rate (float): dropout rate in attention + positional_dropout_rate (float): dropout rate after adding + positional encoding + input_layer (str): input layer type. + optional [linear, conv2d, conv2d6, conv2d8] + pos_enc_layer_type (str): Encoder positional encoding layer type. + opitonal [abs_pos, scaled_abs_pos, rel_pos, no_pos] + normalize_before (bool): + True: use layer_norm before each sub-block of a layer. + False: use layer_norm after each sub-block of a layer. 
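# Editorial sketch: the centred slice above returns relative positions
# (size-1) ... -(size-1), i.e. a tensor of length 2*size - 1, assuming the
# table is built ESPnet-style with both positive and negative halves.  The
# d_model, max_len, and dropout_rate values are illustrative.
from cosyvoice.transformer.embedding import EspnetRelPositionalEncoding


def check_espnet_rel_pos_width():
    emb = EspnetRelPositionalEncoding(d_model=8, dropout_rate=0.0, max_len=64)
    pos_emb = emb.position_encoding(offset=0, size=10)
    assert pos_emb.shape == (1, 2 * 10 - 1, 8)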
+ static_chunk_size (int): chunk size for static chunk training and + decoding + use_dynamic_chunk (bool): whether use dynamic chunk size for + training or not, You can only use fixed chunk(chunk_size > 0) + or dyanmic chunk size(use_dynamic_chunk = True) + global_cmvn (Optional[torch.nn.Module]): Optional GlobalCMVN module + use_dynamic_left_chunk (bool): whether use dynamic left chunk in + dynamic chunk training + key_bias: whether use bias in attention.linear_k, False for whisper models. + gradient_checkpointing: rerunning a forward-pass segment for each + checkpointed segment during backward. + """ + super().__init__() + self._output_size = output_size + + self.global_cmvn = global_cmvn + self.embed = COSYVOICE_SUBSAMPLE_CLASSES[input_layer]( + input_size, + output_size, + dropout_rate, + COSYVOICE_EMB_CLASSES[pos_enc_layer_type](output_size, + positional_dropout_rate), + ) + + self.normalize_before = normalize_before + self.after_norm = torch.nn.LayerNorm(output_size, eps=1e-5) + self.static_chunk_size = static_chunk_size + self.use_dynamic_chunk = use_dynamic_chunk + self.use_dynamic_left_chunk = use_dynamic_left_chunk + self.gradient_checkpointing = gradient_checkpointing + + def output_size(self) -> int: + return self._output_size + + def forward( + self, + xs: torch.Tensor, + xs_lens: torch.Tensor, + decoding_chunk_size: int = 0, + num_decoding_left_chunks: int = -1, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Embed positions in tensor. + + Args: + xs: padded input tensor (B, T, D) + xs_lens: input length (B) + decoding_chunk_size: decoding chunk size for dynamic chunk + 0: default for training, use random dynamic chunk. + <0: for decoding, use full chunk. + >0: for decoding, use fixed chunk size as set. + num_decoding_left_chunks: number of left chunks, this is for decoding, + the chunk size is decoding_chunk_size. + >=0: use num_decoding_left_chunks + <0: use all left chunks + Returns: + encoder output tensor xs, and subsampled masks + xs: padded output tensor (B, T' ~= T/subsample_rate, D) + masks: torch.Tensor batch padding mask after subsample + (B, 1, T' ~= T/subsample_rate) + NOTE(xcsong): + We pass the `__call__` method of the modules instead of `forward` to the + checkpointing API because `__call__` attaches all the hooks of the module. 
+ https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2 + """ + T = xs.size(1) + masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T) + if self.global_cmvn is not None: + xs = self.global_cmvn(xs) + xs, pos_emb, masks = self.embed(xs, masks) + mask_pad = masks # (B, 1, T/subsample_rate) + chunk_masks = add_optional_chunk_mask(xs, masks, + self.use_dynamic_chunk, + self.use_dynamic_left_chunk, + decoding_chunk_size, + self.static_chunk_size, + num_decoding_left_chunks) + if self.gradient_checkpointing and self.training: + xs = self.forward_layers_checkpointed(xs, chunk_masks, pos_emb, + mask_pad) + else: + xs = self.forward_layers(xs, chunk_masks, pos_emb, mask_pad) + if self.normalize_before: + xs = self.after_norm(xs) + # Here we assume the mask is not changed in encoder layers, so just + # return the masks before encoder layers, and the masks will be used + # for cross attention with decoder later + return xs, masks + + def forward_layers(self, xs: torch.Tensor, chunk_masks: torch.Tensor, + pos_emb: torch.Tensor, + mask_pad: torch.Tensor) -> torch.Tensor: + for layer in self.encoders: + xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad) + return xs + + @torch.jit.ignore(drop=True) + def forward_layers_checkpointed(self, xs: torch.Tensor, + chunk_masks: torch.Tensor, + pos_emb: torch.Tensor, + mask_pad: torch.Tensor) -> torch.Tensor: + for layer in self.encoders: + xs, chunk_masks, _, _ = ckpt.checkpoint(layer.__call__, xs, + chunk_masks, pos_emb, + mask_pad) + return xs + + def forward_chunk( + self, + xs: torch.Tensor, + offset: int, + required_cache_size: int, + att_cache: torch.Tensor = torch.zeros(0, 0, 0, 0), + cnn_cache: torch.Tensor = torch.zeros(0, 0, 0, 0), + att_mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ Forward just one chunk + + Args: + xs (torch.Tensor): chunk input, with shape (b=1, time, mel-dim), + where `time == (chunk_size - 1) * subsample_rate + \ + subsample.right_context + 1` + offset (int): current offset in encoder output time stamp + required_cache_size (int): cache size required for next chunk + compuation + >=0: actual cache size + <0: means all history cache is required + att_cache (torch.Tensor): cache tensor for KEY & VALUE in + transformer/conformer attention, with shape + (elayers, head, cache_t1, d_k * 2), where + `head * d_k == hidden-dim` and + `cache_t1 == chunk_size * num_decoding_left_chunks`. + cnn_cache (torch.Tensor): cache tensor for cnn_module in conformer, + (elayers, b=1, hidden-dim, cache_t2), where + `cache_t2 == cnn.lorder - 1` + + Returns: + torch.Tensor: output of current input xs, + with shape (b=1, chunk_size, hidden-dim). + torch.Tensor: new attention cache required for next chunk, with + dynamic shape (elayers, head, ?, d_k * 2) + depending on required_cache_size. + torch.Tensor: new conformer cnn cache required for next chunk, with + same shape as the original cnn_cache. 
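# Editorial sketch: an offline (full-context) pass through the encoder defined
# above.  The feature dimension (input_size=80) and lengths are illustrative.
import torch


def encode_batch_sketch(encoder):
    xs = torch.randn(2, 200, 80)              # (batch, frames, mel-dim)
    xs_lens = torch.tensor([200, 180])
    ys, masks = encoder(xs, xs_lens)          # default chunk args: behaviour fixed at construction
    # ys: (2, T', output_size), masks: (2, 1, T'), with T' ~= 200 / subsample_rate
    return ys, masks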
+ + """ + assert xs.size(0) == 1 + # tmp_masks is just for interface compatibility + tmp_masks = torch.ones(1, + xs.size(1), + device=xs.device, + dtype=torch.bool) + tmp_masks = tmp_masks.unsqueeze(1) + if self.global_cmvn is not None: + xs = self.global_cmvn(xs) + # NOTE(xcsong): Before embed, shape(xs) is (b=1, time, mel-dim) + xs, pos_emb, _ = self.embed(xs, tmp_masks, offset) + # NOTE(xcsong): After embed, shape(xs) is (b=1, chunk_size, hidden-dim) + elayers, cache_t1 = att_cache.size(0), att_cache.size(2) + chunk_size = xs.size(1) + attention_key_size = cache_t1 + chunk_size + pos_emb = self.embed.position_encoding(offset=offset - cache_t1, + size=attention_key_size) + if required_cache_size < 0: + next_cache_start = 0 + elif required_cache_size == 0: + next_cache_start = attention_key_size + else: + next_cache_start = max(attention_key_size - required_cache_size, 0) + r_att_cache = [] + r_cnn_cache = [] + for i, layer in enumerate(self.encoders): + # NOTE(xcsong): Before layer.forward + # shape(att_cache[i:i + 1]) is (1, head, cache_t1, d_k * 2), + # shape(cnn_cache[i]) is (b=1, hidden-dim, cache_t2) + xs, _, new_att_cache, new_cnn_cache = layer( + xs, + att_mask, + pos_emb, + att_cache=att_cache[i:i + 1] if elayers > 0 else att_cache, + cnn_cache=cnn_cache[i] if cnn_cache.size(0) > 0 else cnn_cache) + # NOTE(xcsong): After layer.forward + # shape(new_att_cache) is (1, head, attention_key_size, d_k * 2), + # shape(new_cnn_cache) is (b=1, hidden-dim, cache_t2) + r_att_cache.append(new_att_cache[:, :, next_cache_start:, :]) + r_cnn_cache.append(new_cnn_cache.unsqueeze(0)) + if self.normalize_before: + xs = self.after_norm(xs) + + # NOTE(xcsong): shape(r_att_cache) is (elayers, head, ?, d_k * 2), + # ? may be larger than cache_t1, it depends on required_cache_size + r_att_cache = torch.cat(r_att_cache, dim=0) + # NOTE(xcsong): shape(r_cnn_cache) is (e, b=1, hidden-dim, cache_t2) + r_cnn_cache = torch.cat(r_cnn_cache, dim=0) + + return (xs, r_att_cache, r_cnn_cache) + + def forward_chunk_by_chunk( + self, + xs: torch.Tensor, + decoding_chunk_size: int, + num_decoding_left_chunks: int = -1, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ Forward input chunk by chunk with chunk_size like a streaming + fashion + + Here we should pay special attention to computation cache in the + streaming style forward chunk by chunk. Three things should be taken + into account for computation in the current network: + 1. transformer/conformer encoder layers output cache + 2. convolution in conformer + 3. convolution in subsampling + + However, we don't implement subsampling cache for: + 1. We can control subsampling module to output the right result by + overlapping input instead of cache left context, even though it + wastes some computation, but subsampling only takes a very + small fraction of computation in the whole model. + 2. Typically, there are several covolution layers with subsampling + in subsampling module, it is tricky and complicated to do cache + with different convolution layers with different subsampling + rate. + 3. Currently, nn.Sequential is used to stack all the convolution + layers in subsampling, we need to rewrite it to make it work + with cache, which is not prefered. 
+ Args: + xs (torch.Tensor): (1, max_len, dim) + chunk_size (int): decoding chunk size + """ + assert decoding_chunk_size > 0 + # The model is trained by static or dynamic chunk + assert self.static_chunk_size > 0 or self.use_dynamic_chunk + subsampling = self.embed.subsampling_rate + context = self.embed.right_context + 1 # Add current frame + stride = subsampling * decoding_chunk_size + decoding_window = (decoding_chunk_size - 1) * subsampling + context + num_frames = xs.size(1) + att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0), device=xs.device) + cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0), device=xs.device) + outputs = [] + offset = 0 + required_cache_size = decoding_chunk_size * num_decoding_left_chunks + + # Feed forward overlap input step by step + for cur in range(0, num_frames - context + 1, stride): + end = min(cur + decoding_window, num_frames) + chunk_xs = xs[:, cur:end, :] + (y, att_cache, + cnn_cache) = self.forward_chunk(chunk_xs, offset, + required_cache_size, att_cache, + cnn_cache) + outputs.append(y) + offset += y.size(1) + ys = torch.cat(outputs, 1) + masks = torch.ones((1, 1, ys.size(1)), + device=ys.device, + dtype=torch.bool) + return ys, masks + + +class TransformerEncoder(BaseEncoder): + """Transformer encoder module.""" + + def __init__( + self, + input_size: int, + output_size: int = 256, + attention_heads: int = 4, + linear_units: int = 2048, + num_blocks: int = 6, + dropout_rate: float = 0.1, + positional_dropout_rate: float = 0.1, + attention_dropout_rate: float = 0.0, + input_layer: str = "conv2d", + pos_enc_layer_type: str = "abs_pos", + normalize_before: bool = True, + static_chunk_size: int = 0, + use_dynamic_chunk: bool = False, + global_cmvn: torch.nn.Module = None, + use_dynamic_left_chunk: bool = False, + key_bias: bool = True, + selfattention_layer_type: str = "selfattn", + activation_type: str = "relu", + gradient_checkpointing: bool = False, + ): + """ Construct TransformerEncoder + + See Encoder for the meaning of each parameter. 
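# Editorial sketch: streaming the same encoder with forward_chunk_by_chunk().
# The chunk settings below are illustrative; the encoder must have been built
# with static_chunk_size > 0 or use_dynamic_chunk=True for this to be valid.
import torch


def streaming_encode_sketch(encoder):
    xs = torch.randn(1, 600, 80)              # (batch=1, frames, mel-dim)
    ys, masks = encoder.forward_chunk_by_chunk(
        xs,
        decoding_chunk_size=16,               # 16 subsampled frames per chunk
        num_decoding_left_chunks=4,           # keep 4 chunks of attention cache
    )
    return ys, masks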
+ """ + super().__init__(input_size, output_size, attention_heads, + linear_units, num_blocks, dropout_rate, + positional_dropout_rate, attention_dropout_rate, + input_layer, pos_enc_layer_type, normalize_before, + static_chunk_size, use_dynamic_chunk, global_cmvn, + use_dynamic_left_chunk, gradient_checkpointing) + activation = COSYVOICE_ACTIVATION_CLASSES[activation_type]() + self.encoders = torch.nn.ModuleList([ + TransformerEncoderLayer( + output_size, + COSYVOICE_ATTENTION_CLASSES[selfattention_layer_type](attention_heads, + output_size, + attention_dropout_rate, + key_bias), + PositionwiseFeedForward(output_size, linear_units, + dropout_rate, activation), + dropout_rate, normalize_before) for _ in range(num_blocks) + ]) + + +class ConformerEncoder(BaseEncoder): + """Conformer encoder module.""" + + def __init__( + self, + input_size: int, + output_size: int = 256, + attention_heads: int = 4, + linear_units: int = 2048, + num_blocks: int = 6, + dropout_rate: float = 0.1, + positional_dropout_rate: float = 0.1, + attention_dropout_rate: float = 0.0, + input_layer: str = "conv2d", + pos_enc_layer_type: str = "rel_pos", + normalize_before: bool = True, + static_chunk_size: int = 0, + use_dynamic_chunk: bool = False, + global_cmvn: torch.nn.Module = None, + use_dynamic_left_chunk: bool = False, + positionwise_conv_kernel_size: int = 1, + macaron_style: bool = True, + selfattention_layer_type: str = "rel_selfattn", + activation_type: str = "swish", + use_cnn_module: bool = True, + cnn_module_kernel: int = 15, + causal: bool = False, + cnn_module_norm: str = "batch_norm", + key_bias: bool = True, + gradient_checkpointing: bool = False, + ): + """Construct ConformerEncoder + + Args: + input_size to use_dynamic_chunk, see in BaseEncoder + positionwise_conv_kernel_size (int): Kernel size of positionwise + conv1d layer. + macaron_style (bool): Whether to use macaron style for + positionwise layer. + selfattention_layer_type (str): Encoder attention layer type, + the parameter has no effect now, it's just for configure + compatibility. + activation_type (str): Encoder activation function type. + use_cnn_module (bool): Whether to use convolution module. + cnn_module_kernel (int): Kernel size of convolution module. + causal (bool): whether to use causal convolution or not. + key_bias: whether use bias in attention.linear_k, False for whisper models. 
+ """ + super().__init__(input_size, output_size, attention_heads, + linear_units, num_blocks, dropout_rate, + positional_dropout_rate, attention_dropout_rate, + input_layer, pos_enc_layer_type, normalize_before, + static_chunk_size, use_dynamic_chunk, global_cmvn, + use_dynamic_left_chunk, gradient_checkpointing) + activation = COSYVOICE_ACTIVATION_CLASSES[activation_type]() + + # self-attention module definition + encoder_selfattn_layer_args = ( + attention_heads, + output_size, + attention_dropout_rate, + key_bias, + ) + # feed-forward module definition + positionwise_layer_args = ( + output_size, + linear_units, + dropout_rate, + activation, + ) + # convolution module definition + convolution_layer_args = (output_size, cnn_module_kernel, activation, + cnn_module_norm, causal) + + self.encoders = torch.nn.ModuleList([ + ConformerEncoderLayer( + output_size, + COSYVOICE_ATTENTION_CLASSES[selfattention_layer_type]( + *encoder_selfattn_layer_args), + PositionwiseFeedForward(*positionwise_layer_args), + PositionwiseFeedForward( + *positionwise_layer_args) if macaron_style else None, + ConvolutionModule( + *convolution_layer_args) if use_cnn_module else None, + dropout_rate, + normalize_before, + ) for _ in range(num_blocks) + ]) + + + + +class BlockConformerEncoder(BaseEncoder): + """Conformer encoder module.""" + + def __init__( + self, + input_size: int, + output_size: int = 256, + attention_heads: int = 4, + linear_units: int = 2048, + num_blocks: int = 6, + dropout_rate: float = 0.1, + positional_dropout_rate: float = 0.1, + attention_dropout_rate: float = 0.0, + input_layer: str = "conv2d", + pos_enc_layer_type: str = "rel_pos", + normalize_before: bool = True, + static_chunk_size: int = 0, + use_dynamic_chunk: bool = False, + global_cmvn: torch.nn.Module = None, + use_dynamic_left_chunk: bool = False, + positionwise_conv_kernel_size: int = 1, + macaron_style: bool = True, + selfattention_layer_type: str = "rel_selfattn", + activation_type: str = "swish", + use_cnn_module: bool = True, + cnn_module_kernel: int = 15, + causal: bool = False, + cnn_module_norm: str = "batch_norm", + key_bias: bool = True, + gradient_checkpointing: bool = False, + block_size=25, + ): + """Construct ConformerEncoder + + Args: + input_size to use_dynamic_chunk, see in BaseEncoder + positionwise_conv_kernel_size (int): Kernel size of positionwise + conv1d layer. + macaron_style (bool): Whether to use macaron style for + positionwise layer. + selfattention_layer_type (str): Encoder attention layer type, + the parameter has no effect now, it's just for configure + compatibility. + activation_type (str): Encoder activation function type. + use_cnn_module (bool): Whether to use convolution module. + cnn_module_kernel (int): Kernel size of convolution module. + causal (bool): whether to use causal convolution or not. + key_bias: whether use bias in attention.linear_k, False for whisper models. 
+ """ + super().__init__(input_size, output_size, attention_heads, + linear_units, num_blocks, dropout_rate, + positional_dropout_rate, attention_dropout_rate, + input_layer, pos_enc_layer_type, normalize_before, + static_chunk_size, use_dynamic_chunk, global_cmvn, + use_dynamic_left_chunk, gradient_checkpointing) + activation = COSYVOICE_ACTIVATION_CLASSES[activation_type]() + + # self-attention module definition + encoder_selfattn_layer_args = ( + attention_heads, + output_size, + attention_dropout_rate, + key_bias, + block_size, + ) + # feed-forward module definition + positionwise_layer_args = ( + output_size, + linear_units, + dropout_rate, + activation, + ) + # convolution module definition + convolution_layer_args = (output_size, cnn_module_kernel, activation, + cnn_module_norm, causal) + + self.encoders = torch.nn.ModuleList([ + ConformerEncoderLayer( + output_size, + COSYVOICE_ATTENTION_CLASSES[selfattention_layer_type]( + *encoder_selfattn_layer_args), + PositionwiseFeedForward(*positionwise_layer_args), + PositionwiseFeedForward( + *positionwise_layer_args) if macaron_style else None, + ConvolutionModule( + *convolution_layer_args) if use_cnn_module else None, + dropout_rate, + normalize_before, + ) for _ in range(num_blocks) + ]) + self.block_size=block_size diff --git a/cosyvoice/transformer/encoder_layer.py b/cosyvoice/transformer/encoder_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..dfd758bc1cc7780aa4f6a322a264c879b74a6cfe --- /dev/null +++ b/cosyvoice/transformer/encoder_layer.py @@ -0,0 +1,236 @@ +# Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu) +# 2022 Xingchen Song (sxc19@mails.tsinghua.edu.cn) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Modified from ESPnet(https://github.com/espnet/espnet) +"""Encoder self-attention layer definition.""" + +from typing import Optional, Tuple + +import torch +from torch import nn + + +class TransformerEncoderLayer(nn.Module): + """Encoder layer module. + + Args: + size (int): Input dimension. + self_attn (torch.nn.Module): Self-attention module instance. + `MultiHeadedAttention` or `RelPositionMultiHeadedAttention` + instance can be used as the argument. + feed_forward (torch.nn.Module): Feed-forward module instance. + `PositionwiseFeedForward`, instance can be used as the argument. + dropout_rate (float): Dropout rate. + normalize_before (bool): + True: use layer_norm before each sub-block. + False: to use layer_norm after each sub-block. 
+ """ + + def __init__( + self, + size: int, + self_attn: torch.nn.Module, + feed_forward: torch.nn.Module, + dropout_rate: float, + normalize_before: bool = True, + ): + """Construct an EncoderLayer object.""" + super().__init__() + self.self_attn = self_attn + self.feed_forward = feed_forward + self.norm1 = nn.LayerNorm(size, eps=1e-5) + self.norm2 = nn.LayerNorm(size, eps=1e-5) + self.dropout = nn.Dropout(dropout_rate) + self.size = size + self.normalize_before = normalize_before + + def forward( + self, + x: torch.Tensor, + mask: torch.Tensor, + pos_emb: torch.Tensor, + mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), + att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)), + cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)), + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Compute encoded features. + + Args: + x (torch.Tensor): (#batch, time, size) + mask (torch.Tensor): Mask tensor for the input (#batch, time,time), + (0, 0, 0) means fake mask. + pos_emb (torch.Tensor): just for interface compatibility + to ConformerEncoderLayer + mask_pad (torch.Tensor): does not used in transformer layer, + just for unified api with conformer. + att_cache (torch.Tensor): Cache tensor of the KEY & VALUE + (#batch=1, head, cache_t1, d_k * 2), head * d_k == size. + cnn_cache (torch.Tensor): Convolution cache in conformer layer + (#batch=1, size, cache_t2), not used here, it's for interface + compatibility to ConformerEncoderLayer. + Returns: + torch.Tensor: Output tensor (#batch, time, size). + torch.Tensor: Mask tensor (#batch, time, time). + torch.Tensor: att_cache tensor, + (#batch=1, head, cache_t1 + time, d_k * 2). + torch.Tensor: cnn_cahce tensor (#batch=1, size, cache_t2). + + """ + residual = x + if self.normalize_before: + x = self.norm1(x) + x_att, new_att_cache = self.self_attn(x, x, x, mask, pos_emb=pos_emb, cache=att_cache) + x = residual + self.dropout(x_att) + if not self.normalize_before: + x = self.norm1(x) + + residual = x + if self.normalize_before: + x = self.norm2(x) + x = residual + self.dropout(self.feed_forward(x)) + if not self.normalize_before: + x = self.norm2(x) + + fake_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device) + return x, mask, new_att_cache, fake_cnn_cache + + +class ConformerEncoderLayer(nn.Module): + """Encoder layer module. + Args: + size (int): Input dimension. + self_attn (torch.nn.Module): Self-attention module instance. + `MultiHeadedAttention` or `RelPositionMultiHeadedAttention` + instance can be used as the argument. + feed_forward (torch.nn.Module): Feed-forward module instance. + `PositionwiseFeedForward` instance can be used as the argument. + feed_forward_macaron (torch.nn.Module): Additional feed-forward module + instance. + `PositionwiseFeedForward` instance can be used as the argument. + conv_module (torch.nn.Module): Convolution module instance. + `ConvlutionModule` instance can be used as the argument. + dropout_rate (float): Dropout rate. + normalize_before (bool): + True: use layer_norm before each sub-block. + False: use layer_norm after each sub-block. 
+ """ + + def __init__( + self, + size: int, + self_attn: torch.nn.Module, + feed_forward: Optional[nn.Module] = None, + feed_forward_macaron: Optional[nn.Module] = None, + conv_module: Optional[nn.Module] = None, + dropout_rate: float = 0.1, + normalize_before: bool = True, + ): + """Construct an EncoderLayer object.""" + super().__init__() + self.self_attn = self_attn + self.feed_forward = feed_forward + self.feed_forward_macaron = feed_forward_macaron + self.conv_module = conv_module + self.norm_ff = nn.LayerNorm(size, eps=1e-5) # for the FNN module + self.norm_mha = nn.LayerNorm(size, eps=1e-5) # for the MHA module + if feed_forward_macaron is not None: + self.norm_ff_macaron = nn.LayerNorm(size, eps=1e-5) + self.ff_scale = 0.5 + else: + self.ff_scale = 1.0 + if self.conv_module is not None: + self.norm_conv = nn.LayerNorm(size, eps=1e-5) # for the CNN module + self.norm_final = nn.LayerNorm( + size, eps=1e-5) # for the final output of the block + self.dropout = nn.Dropout(dropout_rate) + self.size = size + self.normalize_before = normalize_before + + def forward( + self, + x: torch.Tensor, + mask: torch.Tensor, + pos_emb: torch.Tensor, + mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), + att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)), + cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)), + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Compute encoded features. + + Args: + x (torch.Tensor): (#batch, time, size) + mask (torch.Tensor): Mask tensor for the input (#batch, time,time), + (0, 0, 0) means fake mask. + pos_emb (torch.Tensor): positional encoding, must not be None + for ConformerEncoderLayer. + mask_pad (torch.Tensor): batch padding mask used for conv module. + (#batch, 1,time), (0, 0, 0) means fake mask. + att_cache (torch.Tensor): Cache tensor of the KEY & VALUE + (#batch=1, head, cache_t1, d_k * 2), head * d_k == size. + cnn_cache (torch.Tensor): Convolution cache in conformer layer + (#batch=1, size, cache_t2) + Returns: + torch.Tensor: Output tensor (#batch, time, size). + torch.Tensor: Mask tensor (#batch, time, time). + torch.Tensor: att_cache tensor, + (#batch=1, head, cache_t1 + time, d_k * 2). + torch.Tensor: cnn_cahce tensor (#batch, size, cache_t2). 
+ """ + + # whether to use macaron style + if self.feed_forward_macaron is not None: + residual = x + if self.normalize_before: + x = self.norm_ff_macaron(x) + x = residual + self.ff_scale * self.dropout( + self.feed_forward_macaron(x)) + if not self.normalize_before: + x = self.norm_ff_macaron(x) + + # multi-headed self-attention module + residual = x + if self.normalize_before: + x = self.norm_mha(x) + x_att, new_att_cache = self.self_attn(x, x, x, mask, pos_emb, + att_cache) + x = residual + self.dropout(x_att) + if not self.normalize_before: + x = self.norm_mha(x) + + # convolution module + # Fake new cnn cache here, and then change it in conv_module + new_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device) + if self.conv_module is not None: + residual = x + if self.normalize_before: + x = self.norm_conv(x) + x, new_cnn_cache = self.conv_module(x, mask_pad, cnn_cache) + x = residual + self.dropout(x) + + if not self.normalize_before: + x = self.norm_conv(x) + + # feed forward module + residual = x + if self.normalize_before: + x = self.norm_ff(x) + + x = residual + self.ff_scale * self.dropout(self.feed_forward(x)) + if not self.normalize_before: + x = self.norm_ff(x) + + if self.conv_module is not None: + x = self.norm_final(x) + + return x, mask, new_att_cache, new_cnn_cache diff --git a/cosyvoice/transformer/label_smoothing_loss.py b/cosyvoice/transformer/label_smoothing_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..feacabf09609ee6eb047c89ce18d372256c72c71 --- /dev/null +++ b/cosyvoice/transformer/label_smoothing_loss.py @@ -0,0 +1,96 @@ +# Copyright (c) 2019 Shigeki Karita +# 2020 Mobvoi Inc (Binbin Zhang) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Label smoothing module.""" + +import torch +from torch import nn + + +class LabelSmoothingLoss(nn.Module): + """Label-smoothing loss. + + In a standard CE loss, the label's data distribution is: + [0,1,2] -> + [ + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + [0.0, 0.0, 1.0], + ] + + In the smoothing version CE Loss,some probabilities + are taken from the true label prob (1.0) and are divided + among other labels. + + e.g. 
+ smoothing=0.1 + [0,1,2] -> + [ + [0.9, 0.05, 0.05], + [0.05, 0.9, 0.05], + [0.05, 0.05, 0.9], + ] + + Args: + size (int): the number of class + padding_idx (int): padding class id which will be ignored for loss + smoothing (float): smoothing rate (0.0 means the conventional CE) + normalize_length (bool): + normalize loss by sequence length if True + normalize loss by batch size if False + """ + + def __init__(self, + size: int, + padding_idx: int, + smoothing: float, + normalize_length: bool = False): + """Construct an LabelSmoothingLoss object.""" + super(LabelSmoothingLoss, self).__init__() + self.criterion = nn.KLDivLoss(reduction="none") + self.padding_idx = padding_idx + self.confidence = 1.0 - smoothing + self.smoothing = smoothing + self.size = size + self.normalize_length = normalize_length + + def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + """Compute loss between x and target. + + The model outputs and data labels tensors are flatten to + (batch*seqlen, class) shape and a mask is applied to the + padding part which should not be calculated for loss. + + Args: + x (torch.Tensor): prediction (batch, seqlen, class) + target (torch.Tensor): + target signal masked with self.padding_id (batch, seqlen) + Returns: + loss (torch.Tensor) : The KL loss, scalar float value + """ + assert x.size(2) == self.size + batch_size = x.size(0) + x = x.view(-1, self.size) + target = target.view(-1) + # use zeros_like instead of torch.no_grad() for true_dist, + # since no_grad() can not be exported by JIT + true_dist = torch.zeros_like(x) + true_dist.fill_(self.smoothing / (self.size - 1)) + ignore = target == self.padding_idx # (B,) + total = len(target) - ignore.sum().item() + target = target.masked_fill(ignore, 0) # avoid -1 index + true_dist.scatter_(1, target.unsqueeze(1), self.confidence) + kl = self.criterion(torch.log_softmax(x, dim=1), true_dist) + denom = total if self.normalize_length else batch_size + return kl.masked_fill(ignore.unsqueeze(1), 0).sum() / denom diff --git a/cosyvoice/transformer/positionwise_feed_forward.py b/cosyvoice/transformer/positionwise_feed_forward.py new file mode 100644 index 0000000000000000000000000000000000000000..b7a2cf6e7315e3a5ed2794423daff0a59cc5b208 --- /dev/null +++ b/cosyvoice/transformer/positionwise_feed_forward.py @@ -0,0 +1,115 @@ +# Copyright (c) 2019 Shigeki Karita +# 2020 Mobvoi Inc (Binbin Zhang) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Positionwise feed forward layer definition.""" + +import torch + + +class PositionwiseFeedForward(torch.nn.Module): + """Positionwise feed forward layer. + + FeedForward are appied on each position of the sequence. + The output dim is same with the input dim. + + Args: + idim (int): Input dimenstion. + hidden_units (int): The number of hidden units. + dropout_rate (float): Dropout rate. 
+ activation (torch.nn.Module): Activation function + """ + + def __init__( + self, + idim: int, + hidden_units: int, + dropout_rate: float, + activation: torch.nn.Module = torch.nn.ReLU(), + ): + """Construct a PositionwiseFeedForward object.""" + super(PositionwiseFeedForward, self).__init__() + self.w_1 = torch.nn.Linear(idim, hidden_units) + self.activation = activation + self.dropout = torch.nn.Dropout(dropout_rate) + self.w_2 = torch.nn.Linear(hidden_units, idim) + + def forward(self, xs: torch.Tensor) -> torch.Tensor: + """Forward function. + + Args: + xs: input tensor (B, L, D) + Returns: + output tensor, (B, L, D) + """ + return self.w_2(self.dropout(self.activation(self.w_1(xs)))) + + +class MoEFFNLayer(torch.nn.Module): + """ + Mixture of expert with Positionwise feed forward layer + See also figure 1 in https://arxiv.org/pdf/2305.15663.pdf + The output dim is same with the input dim. + + Modified from https://github.com/Lightning-AI/lit-gpt/pull/823 + https://github.com/mistralai/mistral-src/blob/b46d6/moe_one_file_ref.py#L203-L219 + Args: + n_expert: number of expert. + n_expert_per_token: The actual number of experts used for each frame + idim (int): Input dimenstion. + hidden_units (int): The number of hidden units. + dropout_rate (float): Dropout rate. + activation (torch.nn.Module): Activation function + """ + + def __init__( + self, + n_expert: int, + n_expert_per_token: int, + idim: int, + hidden_units: int, + dropout_rate: float, + activation: torch.nn.Module = torch.nn.ReLU(), + ): + super(MoEFFNLayer, self).__init__() + self.gate = torch.nn.Linear(idim, n_expert, bias=False) + self.experts = torch.nn.ModuleList( + PositionwiseFeedForward(idim, hidden_units, dropout_rate, + activation) for _ in range(n_expert)) + self.n_expert_per_token = n_expert_per_token + + def forward(self, xs: torch.Tensor) -> torch.Tensor: + """Foward function. + Args: + xs: input tensor (B, L, D) + Returns: + output tensor, (B, L, D) + + """ + B, L, D = xs.size( + ) # batch size, sequence length, embedding dimension (idim) + xs = xs.view(-1, D) # (B*L, D) + router = self.gate(xs) # (B*L, n_expert) + logits, indices = torch.topk( + router, self.n_expert_per_token + ) # probs:(B*L, n_expert), indices: (B*L, n_expert) + weights = torch.nn.functional.softmax( + logits, dim=1, + dtype=torch.float).to(dtype=xs.dtype) # (B*L, n_expert_per_token) + output = torch.zeros_like(xs) # (B*L, D) + for i, expert in enumerate(self.experts): + mask = indices == i + batch_idx, ith_expert = torch.where(mask) + output[batch_idx] += weights[batch_idx, ith_expert, None] * expert( + xs[batch_idx]) + return output.view(B, L, D) diff --git a/cosyvoice/transformer/subsampling.py b/cosyvoice/transformer/subsampling.py new file mode 100644 index 0000000000000000000000000000000000000000..e17c2e324e3afb24e1b619effe29cef07c9c5b3a --- /dev/null +++ b/cosyvoice/transformer/subsampling.py @@ -0,0 +1,383 @@ +# Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu) +# 2024 Alibaba Inc (Xiang Lyu) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
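# Illustrative aside (not part of the patch): a quick shape check for the
# MoEFFNLayer defined in cosyvoice/transformer/positionwise_feed_forward.py
# above. All sizes below are made up for the example; each frame is routed to
# its top-k experts and the output keeps the input shape.
import torch
from cosyvoice.transformer.positionwise_feed_forward import MoEFFNLayer

moe = MoEFFNLayer(n_expert=4, n_expert_per_token=2,
                  idim=8, hidden_units=16, dropout_rate=0.1)
xs = torch.randn(3, 5, 8)         # (B, L, D)
out = moe(xs)
assert out.shape == xs.shape      # experts are mixed per frame, D is unchanged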
+# See the License for the specific language governing permissions and +# limitations under the License. +# Modified from ESPnet(https://github.com/espnet/espnet) +"""Subsampling layer definition.""" + +from typing import Tuple, Union + +import torch + + +class BaseSubsampling(torch.nn.Module): + + def __init__(self): + super().__init__() + self.right_context = 0 + self.subsampling_rate = 1 + + def position_encoding(self, offset: Union[int, torch.Tensor], + size: int) -> torch.Tensor: + return self.pos_enc.position_encoding(offset, size) + + +class EmbedinigNoSubsampling(BaseSubsampling): + """Embedding input without subsampling + """ + + def __init__(self, idim: int, odim: int, dropout_rate: float, + pos_enc_class: torch.nn.Module): + super().__init__() + self.embed = torch.nn.Embedding(idim, odim) + self.pos_enc = pos_enc_class + + def forward( + self, + x: torch.Tensor, + x_mask: torch.Tensor, + offset: Union[int, torch.Tensor] = 0 + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Input x. + + Args: + x (torch.Tensor): Input tensor (#batch, time, idim). + x_mask (torch.Tensor): Input mask (#batch, 1, time). + + Returns: + torch.Tensor: linear input tensor (#batch, time', odim), + where time' = time . + torch.Tensor: linear input mask (#batch, 1, time'), + where time' = time . + + """ + x = self.embed(x) + x, pos_emb = self.pos_enc(x, offset) + return x, pos_emb, x_mask + + +class LinearNoSubsampling(BaseSubsampling): + """Linear transform the input without subsampling + + Args: + idim (int): Input dimension. + odim (int): Output dimension. + dropout_rate (float): Dropout rate. + + """ + + def __init__(self, idim: int, odim: int, dropout_rate: float, + pos_enc_class: torch.nn.Module): + """Construct an linear object.""" + super().__init__() + self.out = torch.nn.Sequential( + torch.nn.Linear(idim, odim), + torch.nn.LayerNorm(odim, eps=1e-5), + torch.nn.Dropout(dropout_rate), + ) + self.pos_enc = pos_enc_class + self.right_context = 0 + self.subsampling_rate = 1 + + def forward( + self, + x: torch.Tensor, + x_mask: torch.Tensor, + offset: Union[int, torch.Tensor] = 0 + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Input x. + + Args: + x (torch.Tensor): Input tensor (#batch, time, idim). + x_mask (torch.Tensor): Input mask (#batch, 1, time). + + Returns: + torch.Tensor: linear input tensor (#batch, time', odim), + where time' = time . + torch.Tensor: linear input mask (#batch, 1, time'), + where time' = time . + + """ + x = self.out(x) + x, pos_emb = self.pos_enc(x, offset) + return x, pos_emb, x_mask + + +class Conv1dSubsampling2(BaseSubsampling): + """Convolutional 1D subsampling (to 1/2 length). + It is designed for Whisper, ref: + https://github.com/openai/whisper/blob/main/whisper/model.py + + Args: + idim (int): Input dimension. + odim (int): Output dimension. + dropout_rate (float): Dropout rate. 
+ + """ + + def __init__(self, idim: int, odim: int, dropout_rate: float, + pos_enc_class: torch.nn.Module): + """Construct an Conv1dSubsampling2 object.""" + super().__init__() + self.conv = torch.nn.Sequential( + torch.nn.Conv1d(idim, odim, kernel_size=3, padding=1), + torch.nn.GELU(), + torch.nn.Conv1d(odim, odim, kernel_size=3, stride=2, padding=1), + torch.nn.GELU(), + ) + self.pos_enc = pos_enc_class + # The right context for every conv layer is computed by: + # (kernel_size - 1) * frame_rate_of_this_layer + self.subsampling_rate = 2 + # 4 = (3 - 1) * 1 + (3 - 1) * 1 + self.right_context = 4 + + def forward( + self, + x: torch.Tensor, + x_mask: torch.Tensor, + offset: Union[int, torch.Tensor] = 0 + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Subsample x. + + Args: + x (torch.Tensor): Input tensor (#batch, time, idim). + x_mask (torch.Tensor): Input mask (#batch, 1, time). + + Returns: + torch.Tensor: Subsampled tensor (#batch, time', odim), + where time' = time // 2. + torch.Tensor: Subsampled mask (#batch, 1, time'), + where time' = time // 2. + torch.Tensor: positional encoding + + """ + time = x.size(1) + x = x.transpose(1, 2) # (b, f, t) + x = self.conv(x) + x = x.transpose(1, 2) # (b, t, f) + x, pos_emb = self.pos_enc(x, offset) + return x, pos_emb, x_mask[:, :, (time + 1) % 2::2] + + +class Conv2dSubsampling4(BaseSubsampling): + """Convolutional 2D subsampling (to 1/4 length). + + Args: + idim (int): Input dimension. + odim (int): Output dimension. + dropout_rate (float): Dropout rate. + + """ + + def __init__(self, idim: int, odim: int, dropout_rate: float, + pos_enc_class: torch.nn.Module): + """Construct an Conv2dSubsampling4 object.""" + super().__init__() + self.conv = torch.nn.Sequential( + torch.nn.Conv2d(1, odim, 3, 2), + torch.nn.ReLU(), + torch.nn.Conv2d(odim, odim, 3, 2), + torch.nn.ReLU(), + ) + self.out = torch.nn.Sequential( + torch.nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim)) + self.pos_enc = pos_enc_class + # The right context for every conv layer is computed by: + # (kernel_size - 1) * frame_rate_of_this_layer + self.subsampling_rate = 4 + # 6 = (3 - 1) * 1 + (3 - 1) * 2 + self.right_context = 6 + + def forward( + self, + x: torch.Tensor, + x_mask: torch.Tensor, + offset: Union[int, torch.Tensor] = 0 + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Subsample x. + + Args: + x (torch.Tensor): Input tensor (#batch, time, idim). + x_mask (torch.Tensor): Input mask (#batch, 1, time). + + Returns: + torch.Tensor: Subsampled tensor (#batch, time', odim), + where time' = time // 4. + torch.Tensor: Subsampled mask (#batch, 1, time'), + where time' = time // 4. + torch.Tensor: positional encoding + + """ + x = x.unsqueeze(1) # (b, c=1, t, f) + x = self.conv(x) + b, c, t, f = x.size() + x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) + x, pos_emb = self.pos_enc(x, offset) + return x, pos_emb, x_mask[:, :, 2::2][:, :, 2::2] + + +class Conv2dSubsampling6(BaseSubsampling): + """Convolutional 2D subsampling (to 1/6 length). + Args: + idim (int): Input dimension. + odim (int): Output dimension. + dropout_rate (float): Dropout rate. + pos_enc (torch.nn.Module): Custom position encoding layer. 
+ """ + + def __init__(self, idim: int, odim: int, dropout_rate: float, + pos_enc_class: torch.nn.Module): + """Construct an Conv2dSubsampling6 object.""" + super().__init__() + self.conv = torch.nn.Sequential( + torch.nn.Conv2d(1, odim, 3, 2), + torch.nn.ReLU(), + torch.nn.Conv2d(odim, odim, 5, 3), + torch.nn.ReLU(), + ) + self.linear = torch.nn.Linear(odim * (((idim - 1) // 2 - 2) // 3), + odim) + self.pos_enc = pos_enc_class + # 10 = (3 - 1) * 1 + (5 - 1) * 2 + self.subsampling_rate = 6 + self.right_context = 10 + + def forward( + self, + x: torch.Tensor, + x_mask: torch.Tensor, + offset: Union[int, torch.Tensor] = 0 + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Subsample x. + Args: + x (torch.Tensor): Input tensor (#batch, time, idim). + x_mask (torch.Tensor): Input mask (#batch, 1, time). + + Returns: + torch.Tensor: Subsampled tensor (#batch, time', odim), + where time' = time // 6. + torch.Tensor: Subsampled mask (#batch, 1, time'), + where time' = time // 6. + torch.Tensor: positional encoding + """ + x = x.unsqueeze(1) # (b, c, t, f) + x = self.conv(x) + b, c, t, f = x.size() + x = self.linear(x.transpose(1, 2).contiguous().view(b, t, c * f)) + x, pos_emb = self.pos_enc(x, offset) + return x, pos_emb, x_mask[:, :, 2::2][:, :, 4::3] + + +class Conv2dSubsampling8(BaseSubsampling): + """Convolutional 2D subsampling (to 1/8 length). + + Args: + idim (int): Input dimension. + odim (int): Output dimension. + dropout_rate (float): Dropout rate. + + """ + + def __init__(self, idim: int, odim: int, dropout_rate: float, + pos_enc_class: torch.nn.Module): + """Construct an Conv2dSubsampling8 object.""" + super().__init__() + self.conv = torch.nn.Sequential( + torch.nn.Conv2d(1, odim, 3, 2), + torch.nn.ReLU(), + torch.nn.Conv2d(odim, odim, 3, 2), + torch.nn.ReLU(), + torch.nn.Conv2d(odim, odim, 3, 2), + torch.nn.ReLU(), + ) + self.linear = torch.nn.Linear( + odim * ((((idim - 1) // 2 - 1) // 2 - 1) // 2), odim) + self.pos_enc = pos_enc_class + self.subsampling_rate = 8 + # 14 = (3 - 1) * 1 + (3 - 1) * 2 + (3 - 1) * 4 + self.right_context = 14 + + def forward( + self, + x: torch.Tensor, + x_mask: torch.Tensor, + offset: Union[int, torch.Tensor] = 0 + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Subsample x. + + Args: + x (torch.Tensor): Input tensor (#batch, time, idim). + x_mask (torch.Tensor): Input mask (#batch, 1, time). + + Returns: + torch.Tensor: Subsampled tensor (#batch, time', odim), + where time' = time // 8. + torch.Tensor: Subsampled mask (#batch, 1, time'), + where time' = time // 8. + torch.Tensor: positional encoding + """ + x = x.unsqueeze(1) # (b, c, t, f) + x = self.conv(x) + b, c, t, f = x.size() + x = self.linear(x.transpose(1, 2).contiguous().view(b, t, c * f)) + x, pos_emb = self.pos_enc(x, offset) + return x, pos_emb, x_mask[:, :, 2::2][:, :, 2::2][:, :, 2::2] + + +class LegacyLinearNoSubsampling(BaseSubsampling): + """Linear transform the input without subsampling + + Args: + idim (int): Input dimension. + odim (int): Output dimension. + dropout_rate (float): Dropout rate. 
+ + """ + + def __init__(self, idim: int, odim: int, dropout_rate: float, + pos_enc_class: torch.nn.Module): + """Construct an linear object.""" + super().__init__() + self.out = torch.nn.Sequential( + torch.nn.Linear(idim, odim), + torch.nn.LayerNorm(odim, eps=1e-5), + torch.nn.Dropout(dropout_rate), + torch.nn.ReLU(), + ) + self.pos_enc = pos_enc_class + self.right_context = 0 + self.subsampling_rate = 1 + + def forward( + self, + x: torch.Tensor, + x_mask: torch.Tensor, + offset: Union[int, torch.Tensor] = 0 + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Input x. + + Args: + x (torch.Tensor): Input tensor (#batch, time, idim). + x_mask (torch.Tensor): Input mask (#batch, 1, time). + + Returns: + torch.Tensor: linear input tensor (#batch, time', odim), + where time' = time . + torch.Tensor: linear input mask (#batch, 1, time'), + where time' = time . + + """ + x = self.out(x) + x, pos_emb = self.pos_enc(x, offset) + return x, pos_emb, x_mask diff --git a/cosyvoice/utils/__init__.py b/cosyvoice/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/cosyvoice/utils/__pycache__/__init__.cpython-310.pyc b/cosyvoice/utils/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7507965aacaa240a736f712d57b2e88c8597a2ca Binary files /dev/null and b/cosyvoice/utils/__pycache__/__init__.cpython-310.pyc differ diff --git a/cosyvoice/utils/__pycache__/block_mask_util.cpython-310.pyc b/cosyvoice/utils/__pycache__/block_mask_util.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..170e852125a9d13d61a0b7f856c2108996fab5cc Binary files /dev/null and b/cosyvoice/utils/__pycache__/block_mask_util.cpython-310.pyc differ diff --git a/cosyvoice/utils/__pycache__/class_utils.cpython-310.pyc b/cosyvoice/utils/__pycache__/class_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4046793e01797034d0f97aaaf14f029e6ceb870f Binary files /dev/null and b/cosyvoice/utils/__pycache__/class_utils.cpython-310.pyc differ diff --git a/cosyvoice/utils/__pycache__/common.cpython-310.pyc b/cosyvoice/utils/__pycache__/common.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d7c03d916f917c2a35d4fde60c722b358470808a Binary files /dev/null and b/cosyvoice/utils/__pycache__/common.cpython-310.pyc differ diff --git a/cosyvoice/utils/__pycache__/mask.cpython-310.pyc b/cosyvoice/utils/__pycache__/mask.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..63b8ef9aae859ac1e6498a6c308caa7e6a640f11 Binary files /dev/null and b/cosyvoice/utils/__pycache__/mask.cpython-310.pyc differ diff --git a/cosyvoice/utils/block_mask_util.py b/cosyvoice/utils/block_mask_util.py new file mode 100644 index 0000000000000000000000000000000000000000..58e22b051e2ed91fbe51f3287ae09ac31fee65da --- /dev/null +++ b/cosyvoice/utils/block_mask_util.py @@ -0,0 +1,34 @@ +import torch + + +def create_grid_mask(seq_length, trunck_length, fill_triangle): + assert seq_length > 0 + + # 先不考虑seen_length创建一个grid mask: + if fill_triangle: + mask = 1 - torch.triu(torch.ones(seq_length, seq_length), diagonal=1) + # 下三角与主对角线都为1 + else: + mask = torch.zeros(seq_length, seq_length) + + for i in range(seq_length): + trunck_idx = i // trunck_length + trunck_start = trunck_idx * trunck_length + trunck_end = trunck_length + trunck_start + mask[i][trunck_start:trunck_end] = 1 + + return mask + + +if 
__name__ == "__main__": + mask = create_grid_mask(seq_length=8, trunck_length=3, fill_triangle=True).int() + print(mask) +# tensor([[1, 1, 1, 0, 0, 0, 0, 0], +# [1, 1, 1, 0, 0, 0, 0, 0], +# [1, 1, 1, 0, 0, 0, 0, 0], +# [1, 1, 1, 1, 1, 1, 0, 0], +# [1, 1, 1, 1, 1, 1, 0, 0], +# [1, 1, 1, 1, 1, 1, 0, 0], +# [1, 1, 1, 1, 1, 1, 1, 1], +# [1, 1, 1, 1, 1, 1, 1, 1]] + diff --git a/cosyvoice/utils/class_utils.py b/cosyvoice/utils/class_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f2250c792cbf5f2cdf2705500fcdd22c5615a133 --- /dev/null +++ b/cosyvoice/utils/class_utils.py @@ -0,0 +1,72 @@ +# Copyright [2023-11-28] +# 2024 Alibaba Inc (authors: Xiang Lyu) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch + +from cosyvoice.transformer.activation import Swish +from cosyvoice.transformer.subsampling import ( + LinearNoSubsampling, + EmbedinigNoSubsampling, + Conv1dSubsampling2, + Conv2dSubsampling4, + Conv2dSubsampling6, + Conv2dSubsampling8, +) +from cosyvoice.transformer.embedding import (PositionalEncoding, + RelPositionalEncoding, + WhisperPositionalEncoding, + LearnablePositionalEncoding, + NoPositionalEncoding) +from cosyvoice.transformer.attention import (MultiHeadedAttention, + RelPositionMultiHeadedAttention, + BlockRelPositionMultiHeadedAttention) +from cosyvoice.transformer.embedding import EspnetRelPositionalEncoding +from cosyvoice.transformer.subsampling import LegacyLinearNoSubsampling + + +COSYVOICE_ACTIVATION_CLASSES = { + "hardtanh": torch.nn.Hardtanh, + "tanh": torch.nn.Tanh, + "relu": torch.nn.ReLU, + "selu": torch.nn.SELU, + "swish": getattr(torch.nn, "SiLU", Swish), + "gelu": torch.nn.GELU, +} + +COSYVOICE_SUBSAMPLE_CLASSES = { + "linear": LinearNoSubsampling, + "linear_legacy": LegacyLinearNoSubsampling, + "embed": EmbedinigNoSubsampling, + "conv1d2": Conv1dSubsampling2, + "conv2d": Conv2dSubsampling4, + "conv2d6": Conv2dSubsampling6, + "conv2d8": Conv2dSubsampling8, + 'paraformer_dummy': torch.nn.Identity +} + +COSYVOICE_EMB_CLASSES = { + "embed": PositionalEncoding, + "abs_pos": PositionalEncoding, + "rel_pos": RelPositionalEncoding, + "rel_pos_espnet": EspnetRelPositionalEncoding, + "no_pos": NoPositionalEncoding, + "abs_pos_whisper": WhisperPositionalEncoding, + "embed_learnable_pe": LearnablePositionalEncoding, +} + +COSYVOICE_ATTENTION_CLASSES = { + "selfattn": MultiHeadedAttention, + "rel_selfattn": RelPositionMultiHeadedAttention, + "block_rel_selfattn": BlockRelPositionMultiHeadedAttention, +} diff --git a/cosyvoice/utils/common.py b/cosyvoice/utils/common.py new file mode 100644 index 0000000000000000000000000000000000000000..6ec5e178359031e42c64090eede8aabfdf067afa --- /dev/null +++ b/cosyvoice/utils/common.py @@ -0,0 +1,103 @@ +# Copyright (c) 2020 Mobvoi Inc (Binbin Zhang) +# 2024 Alibaba Inc (authors: Xiang Lyu) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Modified from ESPnet(https://github.com/espnet/espnet) +"""Unility functions for Transformer.""" + +from typing import List + +import torch + +IGNORE_ID = -1 + + +def pad_list(xs: List[torch.Tensor], pad_value: int): + """Perform padding for the list of tensors. + + Args: + xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)]. + pad_value (float): Value for padding. + + Returns: + Tensor: Padded tensor (B, Tmax, `*`). + + Examples: + >>> x = [torch.ones(4), torch.ones(2), torch.ones(1)] + >>> x + [tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])] + >>> pad_list(x, 0) + tensor([[1., 1., 1., 1.], + [1., 1., 0., 0.], + [1., 0., 0., 0.]]) + + """ + max_len = max([len(item) for item in xs]) + batchs = len(xs) + ndim = xs[0].ndim + if ndim == 1: + pad_res = torch.zeros(batchs, + max_len, + dtype=xs[0].dtype, + device=xs[0].device) + elif ndim == 2: + pad_res = torch.zeros(batchs, + max_len, + xs[0].shape[1], + dtype=xs[0].dtype, + device=xs[0].device) + elif ndim == 3: + pad_res = torch.zeros(batchs, + max_len, + xs[0].shape[1], + xs[0].shape[2], + dtype=xs[0].dtype, + device=xs[0].device) + else: + raise ValueError(f"Unsupported ndim: {ndim}") + pad_res.fill_(pad_value) + for i in range(batchs): + pad_res[i, :len(xs[i])] = xs[i] + return pad_res + + +def th_accuracy(pad_outputs: torch.Tensor, pad_targets: torch.Tensor, + ignore_label: int) -> torch.Tensor: + """Calculate accuracy. + + Args: + pad_outputs (Tensor): Prediction tensors (B * Lmax, D). + pad_targets (LongTensor): Target label tensors (B, Lmax). + ignore_label (int): Ignore label id. + + Returns: + torch.Tensor: Accuracy value (0.0 - 1.0). + + """ + pad_pred = pad_outputs.view(pad_targets.size(0), pad_targets.size(1), + pad_outputs.size(1)).argmax(2) + mask = pad_targets != ignore_label + numerator = torch.sum( + pad_pred.masked_select(mask) == pad_targets.masked_select(mask)) + denominator = torch.sum(mask) + return (numerator / denominator).detach() + + +def get_padding(kernel_size, dilation=1): + return int((kernel_size * dilation - dilation) / 2) + + +def init_weights(m, mean=0.0, std=0.01): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + m.weight.data.normal_(mean, std) diff --git a/cosyvoice/utils/executor.py b/cosyvoice/utils/executor.py new file mode 100644 index 0000000000000000000000000000000000000000..5d9159411eab228fb36b46d5c728f3989ce5e920 --- /dev/null +++ b/cosyvoice/utils/executor.py @@ -0,0 +1,132 @@ +# Copyright (c) 2020 Mobvoi Inc (Binbin Zhang) +# 2024 Alibaba Inc (authors: Xiang Lyu) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
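# Illustrative aside (not part of the patch): the padding and accuracy helpers
# from cosyvoice/utils/common.py above, exercised on toy tensors. The pad_list
# values mirror its docstring example; the logits here are random.
import torch
from cosyvoice.utils.common import IGNORE_ID, pad_list, th_accuracy

x = [torch.ones(4), torch.ones(2), torch.ones(1)]
print(pad_list(x, 0))                # (3, 4) tensor, shorter rows zero-padded

logits = torch.randn(2 * 3, 5)                      # (B * Lmax, D)
targets = torch.tensor([[1, 2, IGNORE_ID],
                        [0, 4, 3]])                 # (B, Lmax), -1 is ignored
print(th_accuracy(logits, targets, ignore_label=IGNORE_ID))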
+ +import logging +from contextlib import nullcontext +import os + +import torch +import torch.distributed as dist +import tqdm + +from cosyvoice.utils.train_utils import update_parameter_and_lr, log_per_step, log_per_save, batch_forward, batch_backward, save_model, cosyvoice_join + + +class Executor: + + def __init__(self): + self.step = 0 + self.epoch = 0 + self.rank = int(os.environ.get('RANK', 0)) + self.device = torch.device('cuda:{}'.format(self.rank)) + + def train_one_epoc(self, model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, group_join): + ''' Train one epoch + ''' + + lr = optimizer.param_groups[0]['lr'] + logging.info('Epoch {} TRAIN info lr {} rank {}'.format(self.epoch, lr, self.rank)) + logging.info('using accumulate grad, new batch size is {} times' + ' larger than before'.format(info_dict['accum_grad'])) + # A context manager to be used in conjunction with an instance of + # torch.nn.parallel.DistributedDataParallel to be able to train + # with uneven inputs across participating processes. + model.train() + model_context = model.join if info_dict['train_engine'] == 'torch_ddp' else nullcontext + with model_context(): + for batch_idx, batch_dict in tqdm.tqdm(enumerate(train_data_loader)): + # print("======== forword ========") + info_dict["tag"] = "TRAIN" + info_dict["step"] = self.step + info_dict["epoch"] = self.epoch + info_dict["batch_idx"] = batch_idx + if cosyvoice_join(group_join, info_dict): + break + # import pdb + # pdb.set_trace() + # Disable gradient synchronizations across DDP processes. + # Within this context, gradients will be accumulated on module + # variables, which will later be synchronized. + if info_dict['train_engine'] == 'torch_ddp' and (batch_idx + 1) % info_dict["accum_grad"] != 0: + context = model.no_sync + # Used for single gpu training and DDP gradient synchronization + # processes. 
+ else: + context = nullcontext + + new_batch_dict={ + # "utts":batch_dict["utts"], + "speech_token":batch_dict["speech_token"], + "speech_token_len":batch_dict["speech_token_len"], + "speech_feat":batch_dict["speech_feat"], + "speech_feat_len":batch_dict["speech_feat_len"], + "embedding":batch_dict["embedding"], + # "embedding":torch.zeros((batch_dict["speech_feat"].size(0),192),device=batch_dict["speech_feat"].device) + } + + with context(): + info_dict = batch_forward(model, new_batch_dict, info_dict) + info_dict = batch_backward(model, info_dict) + + info_dict = update_parameter_and_lr(model, optimizer, scheduler, info_dict) + log_per_step(writer, info_dict) + # NOTE specify save_per_step in cosyvoice.yaml if you want to enable step save + if info_dict['save_per_step'] > 0 and (self.step + 1) % info_dict['save_per_step'] == 0 and (batch_idx + 1) % info_dict["accum_grad"] == 0: + dist.barrier() + # try: + # dist.barrier() + # except RuntimeError as e: + # logging.info('except RuntimeError as e: {}'.format(e)) + self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=False) + model.train() + if (batch_idx + 1) % info_dict["accum_grad"] == 0: + self.step += 1 + dist.barrier() + # try: + # dist.barrier() + # except RuntimeError as e: + # logging.info('except RuntimeError as e: {}'.format(e)) + self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=True) + + @torch.inference_mode() + def cv(self, model, cv_data_loader, writer, info_dict, on_batch_end=True): + ''' Cross validation on + ''' + logging.info('Epoch {} Step {} on_batch_end {} CV rank {}'.format(self.epoch, self.step + 1, on_batch_end, self.rank)) + model.eval() + total_num_utts, total_loss_dict = 0, {} # avoid division by 0 + for batch_idx, batch_dict in enumerate(cv_data_loader): + info_dict["tag"] = "CV" + info_dict["step"] = self.step + info_dict["epoch"] = self.epoch + info_dict["batch_idx"] = batch_idx + + # num_utts = len(batch_dict["utts"]) + num_utts=batch_dict["speech_token"].size(0) + total_num_utts += num_utts + + info_dict = batch_forward(model, batch_dict, info_dict) + + for k, v in info_dict['loss_dict'].items(): + if k not in total_loss_dict: + total_loss_dict[k] = [] + total_loss_dict[k].append(v.item() * num_utts) + log_per_step(None, info_dict) + for k, v in total_loss_dict.items(): + total_loss_dict[k] = sum(v) / total_num_utts + info_dict['loss_dict'] = total_loss_dict + log_per_save(writer, info_dict) + model_name = 'epoch_{}_whole'.format(self.epoch) if on_batch_end else 'epoch_{}_step_{}'.format(self.epoch, self.step + 1) + save_model(model, model_name, info_dict) diff --git a/cosyvoice/utils/file_utils.py b/cosyvoice/utils/file_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d4179e109da4073ca9be75767c3f59d2ee68a5cf --- /dev/null +++ b/cosyvoice/utils/file_utils.py @@ -0,0 +1,53 @@ +# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang) +# 2024 Alibaba Inc (authors: Xiang Lyu) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
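# Illustrative aside (not part of the patch): the gradient-accumulation pattern
# used by Executor.train_one_epoc above, reduced to a standalone sketch. The
# model, data, and accum_grad value are stand-ins; in the real loop the model
# is DDP-wrapped and model.no_sync() replaces nullcontext on non-update steps
# to suppress the gradient all-reduce.
import torch
from contextlib import nullcontext

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
accum_grad = 4

for batch_idx in range(16):
    x, y = torch.randn(8, 4), torch.randn(8, 2)
    context = nullcontext          # would be model.no_sync on non-sync steps
    with context():
        loss = torch.nn.functional.mse_loss(model(x), y) / accum_grad
        loss.backward()
    if (batch_idx + 1) % accum_grad == 0:   # only every accum_grad-th batch
        optimizer.step()                    # applies the accumulated gradients
        optimizer.zero_grad()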
+ +import json +import torchaudio + + +def read_lists(list_file): + lists = [] + with open(list_file, 'r', encoding='utf8') as fin: + for line in fin: + lists.append(line.strip()) + return lists + +def read_json_lists(list_file): + lists = read_lists(list_file) + results = {} + for fn in lists: + with open(fn, 'r', encoding='utf8') as fin: + results.update(json.load(fin)) + return results + +def load_wav(wav, target_sr): + speech, sample_rate = torchaudio.load(wav) + speech = speech.mean(dim=0, keepdim=True) + if sample_rate != target_sr: + assert sample_rate > target_sr, 'wav sample rate {} must be greater than {}'.format(sample_rate, target_sr) + speech = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sr)(speech) + return speech + +def speed_change(waveform, sample_rate, speed_factor: str): + effects = [ + ["tempo", speed_factor], # speed_factor + ["rate", f"{sample_rate}"] + ] + augmented_waveform, new_sample_rate = torchaudio.sox_effects.apply_effects_tensor( + waveform, + sample_rate, + effects + ) + return augmented_waveform, new_sample_rate diff --git a/cosyvoice/utils/frontend_utils.py b/cosyvoice/utils/frontend_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..59489a7a6fdb442b1134baac3e5eef0211130954 --- /dev/null +++ b/cosyvoice/utils/frontend_utils.py @@ -0,0 +1,125 @@ +# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +chinese_char_pattern = re.compile(r'[\u4e00-\u9fff]+') + +# whether contain chinese character +def contains_chinese(text): + return bool(chinese_char_pattern.search(text)) + + +# replace special symbol +def replace_corner_mark(text): + text = text.replace('²', '平方') + text = text.replace('³', '立方') + return text + + +# remove meaningless symbol +def remove_bracket(text): + text = text.replace('(', '').replace(')', '') + text = text.replace('【', '').replace('】', '') + text = text.replace('`', '').replace('`', '') + text = text.replace("——", " ") + return text + + +# spell Arabic numerals +def spell_out_number(text: str, inflect_parser): + new_text = [] + st = None + for i, c in enumerate(text): + if not c.isdigit(): + if st is not None: + num_str = inflect_parser.number_to_words(text[st: i]) + new_text.append(num_str) + st = None + new_text.append(c) + else: + if st is None: + st = i + if st is not None and st < len(text): + num_str = inflect_parser.number_to_words(text[st:]) + new_text.append(num_str) + return ''.join(new_text) + + +# split paragrah logic: +# 1. per sentence max len token_max_n, min len token_min_n, merge if last sentence len less than merge_len +# 2. cal sentence len according to lang +# 3. 
split sentence according to puncatation +def split_paragraph(text: str, tokenize, lang="zh", token_max_n=80, token_min_n=60, merge_len=20, comma_split=False): + def calc_utt_length(_text: str): + if lang == "zh": + return len(_text) + else: + return len(tokenize(_text)) + + def should_merge(_text: str): + if lang == "zh": + return len(_text) < merge_len + else: + return len(tokenize(_text)) < merge_len + + if lang == "zh": + pounc = ['。', '?', '!', ';', ':', '、', '.', '?', '!', ';'] + else: + pounc = ['.', '?', '!', ';', ':'] + if comma_split: + pounc.extend([',', ',']) + st = 0 + utts = [] + for i, c in enumerate(text): + if c in pounc: + if len(text[st: i]) > 0: + utts.append(text[st: i] + c) + if i + 1 < len(text) and text[i + 1] in ['"', '”']: + tmp = utts.pop(-1) + utts.append(tmp + text[i + 1]) + st = i + 2 + else: + st = i + 1 + if len(utts) == 0: + if lang == "zh": + utts.append(text + '。') + else: + utts.append(text + '.') + final_utts = [] + cur_utt = "" + for utt in utts: + if calc_utt_length(cur_utt + utt) > token_max_n and calc_utt_length(cur_utt) > token_min_n: + final_utts.append(cur_utt) + cur_utt = "" + cur_utt = cur_utt + utt + if len(cur_utt) > 0: + if should_merge(cur_utt) and len(final_utts) != 0: + final_utts[-1] = final_utts[-1] + cur_utt + else: + final_utts.append(cur_utt) + + return final_utts + + +# remove blank between chinese character +def replace_blank(text: str): + out_str = [] + for i, c in enumerate(text): + if c == " ": + if ((text[i + 1].isascii() and text[i + 1] != " ") and + (text[i - 1].isascii() and text[i - 1] != " ")): + out_str.append(c) + else: + out_str.append(c) + return "".join(out_str) diff --git a/cosyvoice/utils/mask.py b/cosyvoice/utils/mask.py new file mode 100644 index 0000000000000000000000000000000000000000..2b460bbd5adb4bd61d643ace71400a14fe314236 --- /dev/null +++ b/cosyvoice/utils/mask.py @@ -0,0 +1,227 @@ +# Copyright (c) 2019 Shigeki Karita +# 2020 Mobvoi Inc (Binbin Zhang) +# 2024 Alibaba Inc (authors: Xiang Lyu) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +''' +def subsequent_mask( + size: int, + device: torch.device = torch.device("cpu"), +) -> torch.Tensor: + """Create mask for subsequent steps (size, size). + + This mask is used only in decoder which works in an auto-regressive mode. + This means the current step could only do attention with its left steps. + + In encoder, fully attention is used when streaming is not necessary and + the sequence is not long. In this case, no attention mask is needed. + + When streaming is need, chunk-based attention is used in encoder. See + subsequent_chunk_mask for the chunk-based attention mask. 
+ + Args: + size (int): size of mask + str device (str): "cpu" or "cuda" or torch.Tensor.device + dtype (torch.device): result dtype + + Returns: + torch.Tensor: mask + + Examples: + >>> subsequent_mask(3) + [[1, 0, 0], + [1, 1, 0], + [1, 1, 1]] + """ + ret = torch.ones(size, size, device=device, dtype=torch.bool) + return torch.tril(ret) +''' + + +def subsequent_mask( + size: int, + device: torch.device = torch.device("cpu"), +) -> torch.Tensor: + """Create mask for subsequent steps (size, size). + + This mask is used only in decoder which works in an auto-regressive mode. + This means the current step could only do attention with its left steps. + + In encoder, fully attention is used when streaming is not necessary and + the sequence is not long. In this case, no attention mask is needed. + + When streaming is need, chunk-based attention is used in encoder. See + subsequent_chunk_mask for the chunk-based attention mask. + + Args: + size (int): size of mask + str device (str): "cpu" or "cuda" or torch.Tensor.device + dtype (torch.device): result dtype + + Returns: + torch.Tensor: mask + + Examples: + >>> subsequent_mask(3) + [[1, 0, 0], + [1, 1, 0], + [1, 1, 1]] + """ + arange = torch.arange(size, device=device) + mask = arange.expand(size, size) + arange = arange.unsqueeze(-1) + mask = mask <= arange + return mask + + +def subsequent_chunk_mask( + size: int, + chunk_size: int, + num_left_chunks: int = -1, + device: torch.device = torch.device("cpu"), +) -> torch.Tensor: + """Create mask for subsequent steps (size, size) with chunk size, + this is for streaming encoder + + Args: + size (int): size of mask + chunk_size (int): size of chunk + num_left_chunks (int): number of left chunks + <0: use full chunk + >=0: use num_left_chunks + device (torch.device): "cpu" or "cuda" or torch.Tensor.device + + Returns: + torch.Tensor: mask + + Examples: + >>> subsequent_chunk_mask(4, 2) + [[1, 1, 0, 0], + [1, 1, 0, 0], + [1, 1, 1, 1], + [1, 1, 1, 1]] + """ + ret = torch.zeros(size, size, device=device, dtype=torch.bool) + for i in range(size): + if num_left_chunks < 0: + start = 0 + else: + start = max((i // chunk_size - num_left_chunks) * chunk_size, 0) + ending = min((i // chunk_size + 1) * chunk_size, size) + ret[i, start:ending] = True + return ret + + +def add_optional_chunk_mask(xs: torch.Tensor, + masks: torch.Tensor, + use_dynamic_chunk: bool, + use_dynamic_left_chunk: bool, + decoding_chunk_size: int, + static_chunk_size: int, + num_decoding_left_chunks: int, + enable_full_context: bool = True): + """ Apply optional mask for encoder. + + Args: + xs (torch.Tensor): padded input, (B, L, D), L for max length + mask (torch.Tensor): mask for xs, (B, 1, L) + use_dynamic_chunk (bool): whether to use dynamic chunk or not + use_dynamic_left_chunk (bool): whether to use dynamic left chunk for + training. + decoding_chunk_size (int): decoding chunk size for dynamic chunk, it's + 0: default for training, use random dynamic chunk. + <0: for decoding, use full chunk. + >0: for decoding, use fixed chunk size as set. + static_chunk_size (int): chunk size for static chunk training/decoding + if it's greater than 0, if use_dynamic_chunk is true, + this parameter will be ignored + num_decoding_left_chunks: number of left chunks, this is for decoding, + the chunk size is decoding_chunk_size. 
+ >=0: use num_decoding_left_chunks + <0: use all left chunks + enable_full_context (bool): + True: chunk size is either [1, 25] or full context(max_len) + False: chunk size ~ U[1, 25] + + Returns: + torch.Tensor: chunk mask of the input xs. + """ + # Whether to use chunk mask or not + if use_dynamic_chunk: + max_len = xs.size(1) + if decoding_chunk_size < 0: + chunk_size = max_len + num_left_chunks = -1 + elif decoding_chunk_size > 0: + chunk_size = decoding_chunk_size + num_left_chunks = num_decoding_left_chunks + else: + # chunk size is either [1, 25] or full context(max_len). + # Since we use 4 times subsampling and allow up to 1s(100 frames) + # delay, the maximum frame is 100 / 4 = 25. + chunk_size = torch.randint(1, max_len, (1, )).item() + num_left_chunks = -1 + if chunk_size > max_len // 2 and enable_full_context: + chunk_size = max_len + else: + chunk_size = chunk_size % 25 + 1 + if use_dynamic_left_chunk: + max_left_chunks = (max_len - 1) // chunk_size + num_left_chunks = torch.randint(0, max_left_chunks, + (1, )).item() + chunk_masks = subsequent_chunk_mask(xs.size(1), chunk_size, + num_left_chunks, + xs.device) # (L, L) + chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L) + chunk_masks = masks & chunk_masks # (B, L, L) + elif static_chunk_size > 0: + num_left_chunks = num_decoding_left_chunks + chunk_masks = subsequent_chunk_mask(xs.size(1), static_chunk_size, + num_left_chunks, + xs.device) # (L, L) + chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L) + chunk_masks = masks & chunk_masks # (B, L, L) + else: + chunk_masks = masks + return chunk_masks + + +def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor: + """Make mask tensor containing indices of padded part. + + See description of make_non_pad_mask. + + Args: + lengths (torch.Tensor): Batch of lengths (B,). + Returns: + torch.Tensor: Mask tensor containing indices of padded part. + + Examples: + >>> lengths = [5, 3, 2] + >>> make_pad_mask(lengths) + masks = [[0, 0, 0, 0 ,0], + [0, 0, 0, 1, 1], + [0, 0, 1, 1, 1]] + """ + batch_size = lengths.size(0) + max_len = max_len if max_len > 0 else lengths.max().item() + seq_range = torch.arange(0, + max_len, + dtype=torch.int64, + device=lengths.device) + seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len) + seq_length_expand = lengths.unsqueeze(-1) + mask = seq_range_expand >= seq_length_expand + return mask diff --git a/cosyvoice/utils/scheduler.py b/cosyvoice/utils/scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..fbf4803f81bd7a3cee4af7bd8b6af2d3b46304d7 --- /dev/null +++ b/cosyvoice/utils/scheduler.py @@ -0,0 +1,739 @@ +# Copyright (c) 2020 Mobvoi Inc (Binbin Zhang) +# 2022 Ximalaya Inc (Yuguang Yang) +# 2024 Alibaba Inc (authors: Xiang Lyu) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
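# Illustrative aside (not part of the patch): the mask helpers from
# cosyvoice/utils/mask.py above, run on tiny inputs that match their
# docstring examples.
import torch
from cosyvoice.utils.mask import make_pad_mask, subsequent_chunk_mask

lengths = torch.tensor([5, 3, 2])
print(make_pad_mask(lengths).int())
# tensor([[0, 0, 0, 0, 0],
#         [0, 0, 0, 1, 1],
#         [0, 0, 1, 1, 1]])

print(subsequent_chunk_mask(4, 2).int())
# tensor([[1, 1, 0, 0],
#         [1, 1, 0, 0],
#         [1, 1, 1, 1],
#         [1, 1, 1, 1]])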
+# Modified from ESPnet(https://github.com/espnet/espnet) +# NeMo(https://github.com/NVIDIA/NeMo) + +from typing import Union + +import math +import warnings +import torch +from torch.optim.lr_scheduler import _LRScheduler + + +class WarmupLR(_LRScheduler): + """The WarmupLR scheduler + + This scheduler is almost same as NoamLR Scheduler except for following + difference: + + NoamLR: + lr = optimizer.lr * model_size ** -0.5 + * min(step ** -0.5, step * warmup_step ** -1.5) + WarmupLR: + lr = optimizer.lr * warmup_step ** 0.5 + * min(step ** -0.5, step * warmup_step ** -1.5) + + Note that the maximum lr equals to optimizer.lr in this scheduler. + + """ + + def __init__( + self, + optimizer: torch.optim.Optimizer, + warmup_steps: Union[int, float] = 25000, + last_epoch: int = -1, + ): + self.warmup_steps = warmup_steps + + # __init__() must be invoked before setting field + # because step() is also invoked in __init__() + super().__init__(optimizer, last_epoch) + + def __repr__(self): + return f"{self.__class__.__name__}(warmup_steps={self.warmup_steps})" + + def get_lr(self): + step_num = self.last_epoch + 1 + if self.warmup_steps == 0: + return [lr * step_num**-0.5 for lr in self.base_lrs] + else: + return [ + lr * self.warmup_steps**0.5 * + min(step_num**-0.5, step_num * self.warmup_steps**-1.5) + for lr in self.base_lrs + ] + + def set_step(self, step: int): + self.last_epoch = step + + +class WarmupPolicy(_LRScheduler): + """Adds warmup kwargs and warmup logic to lr policy. + All arguments should be passed as kwargs for clarity, + Args: + warmup_steps: Number of training steps in warmup stage + warmup_ratio: Ratio of warmup steps to total steps + max_steps: Total number of steps while training or `None` for + infinite training + """ + + def __init__(self, + optimizer, + *, + warmup_steps=None, + warmup_ratio=None, + max_steps=None, + min_lr=0.0, + last_epoch=-1): + assert not (warmup_steps is not None and warmup_ratio is not None),\ + "Either use particular number of step or ratio" + assert warmup_ratio is None or max_steps is not None, \ + "If there is a ratio, there should be a total steps" + + # It is necessary to assign all attributes *before* __init__, + # as class is wrapped by an inner class. + self.max_steps = max_steps + if warmup_steps is not None: + self.warmup_steps = warmup_steps + elif warmup_ratio is not None: + self.warmup_steps = int(warmup_ratio * max_steps) + else: + self.warmup_steps = 0 + + self.min_lr = min_lr + super().__init__(optimizer, last_epoch) + + def get_lr(self): + if not self._get_lr_called_within_step: + warnings.warn( + "To get the last learning rate computed " + "by the scheduler, please use `get_last_lr()`.", + UserWarning, + stacklevel=2) + + step = self.last_epoch + + if step <= self.warmup_steps and self.warmup_steps > 0: + return self._get_warmup_lr(step) + + if step > self.max_steps: + return [self.min_lr for _ in self.base_lrs] + + return self._get_lr(step) + + def _get_warmup_lr(self, step): + lr_val = (step + 1) / (self.warmup_steps + 1) + return [initial_lr * lr_val for initial_lr in self.base_lrs] + + def _get_lr(self, step): + """Simple const lr policy""" + return self.base_lrs + + +class SquareRootConstantPolicy(_LRScheduler): + """Adds warmup kwargs and warmup logic to lr policy. 
+ All arguments should be passed as kwargs for clarity, + Args: + warmup_steps: Number of training steps in warmup stage + warmup_ratio: Ratio of warmup steps to total steps + max_steps: Total number of steps while training or `None` for + infinite training + """ + + def __init__(self, + optimizer, + *, + constant_steps=None, + constant_ratio=None, + max_steps=None, + min_lr=0.0, + last_epoch=-1): + assert not (constant_steps is not None + and constant_ratio is not None), \ + "Either use particular number of step or ratio" + assert constant_ratio is None or max_steps is not None, \ + "If there is a ratio, there should be a total steps" + + # It is necessary to assign all attributes *before* __init__, + # as class is wrapped by an inner class. + self.max_steps = max_steps + if constant_steps is not None: + self.constant_steps = constant_steps + elif constant_ratio is not None: + self.constant_steps = int(constant_ratio * max_steps) + else: + self.constant_steps = 0 + + self.constant_lr = 1 / (constant_steps**0.5) + self.min_lr = min_lr + super().__init__(optimizer, last_epoch) + + def get_lr(self): + if not self._get_lr_called_within_step: + warnings.warn( + "To get the last learning rate computed " + "by the scheduler, please use `get_last_lr()`.", + UserWarning, + stacklevel=2) + + step = self.last_epoch + + if step <= self.constant_steps: + return [self.constant_lr for _ in self.base_lrs] + + if step > self.max_steps: + return [self.min_lr for _ in self.base_lrs] + + return self._get_lr(step) + + def _get_lr(self, step): + """Simple const lr policy""" + return self.base_lrs + + +class WarmupHoldPolicy(WarmupPolicy): + """Variant of WarmupPolicy which maintains high + learning rate for a defined number of steps. + All arguments should be passed as kwargs for clarity, + Args: + warmup_steps: Number of training steps in warmup stage + warmup_ratio: Ratio of warmup steps to total steps + hold_steps: Number of training steps to + hold the learning rate after warm up + hold_ratio: Ratio of hold steps to total steps + max_steps: Total number of steps while training or `None` for + infinite training + """ + + def __init__( + self, + optimizer, + *, + warmup_steps=None, + warmup_ratio=None, + hold_steps=None, + hold_ratio=None, + max_steps=None, + min_lr=0.0, + last_epoch=-1, + ): + assert not (hold_steps is not None and hold_ratio is not None), \ + "Either use particular number of step or ratio" + assert hold_ratio is None or max_steps is not None, \ + "If there is a ratio, there should be a total steps" + + self.min_lr = min_lr + self._last_warmup_lr = 0.0 + + # Necessary to duplicate as class attributes are hidden in inner class + self.max_steps = max_steps + if warmup_steps is not None: + self.warmup_steps = warmup_steps + elif warmup_ratio is not None: + self.warmup_steps = int(warmup_ratio * max_steps) + else: + self.warmup_steps = 0 + + if hold_steps is not None: + self.hold_steps = hold_steps + self.warmup_steps + elif hold_ratio is not None: + self.hold_steps = int(hold_ratio * max_steps) + self.warmup_steps + else: + self.hold_steps = 0 + + super().__init__( + optimizer, + warmup_steps=warmup_steps, + warmup_ratio=warmup_ratio, + max_steps=max_steps, + last_epoch=last_epoch, + min_lr=min_lr, + ) + + def get_lr(self): + if not self._get_lr_called_within_step: + warnings.warn( + "To get the last learning rate computed by the scheduler," + " " + "please use `get_last_lr()`.", + UserWarning, + stacklevel=2) + + step = self.last_epoch + + # Warmup phase + if step <= self.warmup_steps and 
self.warmup_steps > 0: + return self._get_warmup_lr(step) + + # Hold phase + if (step >= self.warmup_steps) and (step < self.hold_steps): + return self.base_lrs + + if step > self.max_steps: + return [self.min_lr for _ in self.base_lrs] + + return self._get_lr(step) + + +class WarmupAnnealHoldPolicy(_LRScheduler): + """Adds warmup kwargs and warmup logic to lr policy. + All arguments should be passed as kwargs for clarity, + Args: + warmup_steps: Number of training steps in warmup stage + warmup_ratio: Ratio of warmup steps to total steps + max_steps: Total number of steps while training or `None` for + infinite training + min_lr: Minimum lr to hold the learning rate after decay at. + constant_steps: Number of steps to keep lr constant at. + constant_ratio: Ratio of steps to keep lr constant. + """ + + def __init__( + self, + optimizer, + *, + warmup_steps=None, + warmup_ratio=None, + constant_steps=None, + constant_ratio=None, + max_steps=None, + min_lr=0.0, + last_epoch=-1, + ): + assert not (warmup_steps is not None + and warmup_ratio is not None), \ + "Either use particular number of step or ratio" + assert not (constant_steps is not None + and constant_ratio is not None), \ + "Either use constant_steps or constant_ratio" + assert warmup_ratio is None or max_steps is not None, \ + "If there is a ratio, there should be a total steps" + + # It is necessary to assign all attributes *before* __init__, + # as class is wrapped by an inner class. + self.max_steps = max_steps + + if warmup_steps is not None: + self.warmup_steps = warmup_steps + elif warmup_ratio is not None: + self.warmup_steps = int(warmup_ratio * max_steps) + else: + self.warmup_steps = 0 + + if constant_steps is not None: + self.constant_steps = constant_steps + elif constant_ratio is not None: + self.constant_steps = int(constant_ratio * max_steps) + else: + self.constant_steps = 0 + + self.decay_steps = max_steps - (self.constant_steps + + self.warmup_steps) + + self.min_lr = min_lr + super().__init__(optimizer, last_epoch) + + def get_lr(self): + if not self._get_lr_called_within_step: + warnings.warn( + "To get the last learning rate computed " + "by the scheduler, please use `get_last_lr()`.", + UserWarning, + stacklevel=2) + + step = self.last_epoch + + # Warmup steps + if self.warmup_steps > 0 and step <= self.warmup_steps: + return self._get_warmup_lr(step) + + # Constant steps after warmup and decay + if self.constant_steps > 0 and ( + self.warmup_steps + self.decay_steps) < step <= self.max_steps: + return self._get_constant_lr(step) + + # Min lr after max steps of updates + if step > self.max_steps: + return [self.min_lr for _ in self.base_lrs] + + return self._get_lr(step) + + def _get_warmup_lr(self, step): + lr_val = (step + 1) / (self.warmup_steps + 1) + return [initial_lr * lr_val for initial_lr in self.base_lrs] + + def _get_constant_lr(self, step): + return [self.min_lr for _ in self.base_lrs] + + def _get_lr(self, step): + """Simple const lr policy""" + return self.base_lrs + + +def _squareroot_annealing(initial_lr, step, max_steps, min_lr): + mult = ((max_steps - step) / max_steps)**0.5 + out_lr = initial_lr * mult + out_lr = max(out_lr, min_lr) + return out_lr + + +def _square_annealing(initial_lr, step, max_steps, min_lr): + mult = ((max_steps - step) / max_steps)**2 + out_lr = initial_lr * mult + out_lr = max(out_lr, min_lr) + return out_lr + + +def _cosine_annealing(initial_lr, step, max_steps, min_lr): + mult = 0.5 * (1 + math.cos(math.pi * step / max_steps)) + out_lr = (initial_lr - min_lr) * mult 
+ min_lr + return out_lr + + +def _linear_warmup_with_cosine_annealing(max_lr, warmup_steps, step, + decay_steps, min_lr): + assert max_lr > min_lr + # Use linear warmup for the initial part. + if warmup_steps > 0 and step <= warmup_steps: + return max_lr * float(step) / float(warmup_steps) + + # For any steps larger than `decay_steps`, use `min_lr`. + if step > warmup_steps + decay_steps: + return min_lr + + # If we are done with the warmup period, use the decay style. + num_steps_ = step - warmup_steps + decay_steps_ = decay_steps + decay_ratio = float(num_steps_) / float(decay_steps_) + assert decay_ratio >= 0.0 + assert decay_ratio <= 1.0 + delta_lr = max_lr - min_lr + + coeff = 0.5 * (math.cos(math.pi * decay_ratio) + 1.0) + + return min_lr + coeff * delta_lr + + +def _poly_decay(initial_lr, step, decay_steps, power, min_lr, cycle): + if cycle: + multiplier = 1.0 if step == 0 else math.ceil(step / decay_steps) + decay_steps *= multiplier + else: + step = min(step, decay_steps) + p = step / decay_steps + lr = (initial_lr - min_lr) * math.pow(1.0 - p, power) + lr += min_lr + return lr + + +def _noam_hold_annealing(initial_lr, step, warmup_steps, hold_steps, + decay_rate, min_lr): + # hold_steps = total number of steps + # to hold the LR, not the warmup + hold steps. + T_warmup_decay = max(1, warmup_steps**decay_rate) + T_hold_decay = max(1, (step - hold_steps)**decay_rate) + lr = (initial_lr * T_warmup_decay) / T_hold_decay + lr = max(lr, min_lr) + return lr + + +class SquareAnnealing(WarmupPolicy): + + def __init__(self, + optimizer, + *, + max_steps, + min_lr=1e-5, + last_epoch=-1, + **kwargs): + super().__init__(optimizer=optimizer, + max_steps=max_steps, + last_epoch=last_epoch, + min_lr=min_lr, + **kwargs) + + def _get_lr(self, step): + new_lrs = [ + _square_annealing( + initial_lr=initial_lr, + step=step - self.warmup_steps, + max_steps=self.max_steps - self.warmup_steps, + min_lr=self.min_lr, + ) for initial_lr in self.base_lrs + ] + return new_lrs + + +class SquareRootAnnealing(WarmupPolicy): + + def __init__(self, + optimizer, + *, + max_steps, + min_lr=0, + last_epoch=-1, + **kwargs): + super().__init__(optimizer=optimizer, + max_steps=max_steps, + last_epoch=last_epoch, + min_lr=min_lr, + **kwargs) + + def _get_lr(self, step): + new_lrs = [ + _squareroot_annealing(initial_lr=initial_lr, + step=step, + max_steps=self.max_steps, + min_lr=self.min_lr) + for initial_lr in self.base_lrs + ] + return new_lrs + + +class CosineAnnealing(WarmupAnnealHoldPolicy): + + def __init__(self, + optimizer, + *, + max_steps, + min_lr=0, + last_epoch=-1, + **kwargs): + super().__init__(optimizer=optimizer, + max_steps=max_steps, + last_epoch=last_epoch, + min_lr=min_lr, + **kwargs) + + def _get_lr(self, step): + for initial_lr in self.base_lrs: + if initial_lr < self.min_lr: + raise ValueError( + f"{self} received an initial learning rate " + f"that was lower than the minimum learning rate.") + + if self.constant_steps is None or self.constant_steps == 0: + new_lrs = [ + _cosine_annealing( + initial_lr=initial_lr, + step=step - self.warmup_steps, + max_steps=self.max_steps - self.warmup_steps, + min_lr=self.min_lr, + ) for initial_lr in self.base_lrs + ] + else: + new_lrs = self._get_linear_warmup_with_cosine_annealing_lr(step) + return new_lrs + + def _get_warmup_lr(self, step): + if self.constant_steps is None or self.constant_steps == 0: + return super()._get_warmup_lr(step) + else: + # Use linear warmup for the initial part. 
+ return self._get_linear_warmup_with_cosine_annealing_lr(step) + + def _get_constant_lr(self, step): + # Only called when `constant_steps` > 0. + return self._get_linear_warmup_with_cosine_annealing_lr(step) + + def _get_linear_warmup_with_cosine_annealing_lr(self, step): + # Cosine Schedule for Megatron LM, + # slightly different warmup schedule + constant LR at the end. + new_lrs = [ + _linear_warmup_with_cosine_annealing( + max_lr=self.base_lrs[0], + warmup_steps=self.warmup_steps, + step=step, + decay_steps=self.decay_steps, + min_lr=self.min_lr, + ) for _ in self.base_lrs + ] + return new_lrs + + +class NoamAnnealing(_LRScheduler): + + def __init__(self, + optimizer, + *, + d_model, + warmup_steps=None, + warmup_ratio=None, + max_steps=None, + min_lr=0.0, + last_epoch=-1): + self._normalize = d_model**(-0.5) + assert not (warmup_steps is not None + and warmup_ratio is not None), \ + "Either use particular number of step or ratio" + assert warmup_ratio is None or max_steps is not None, \ + "If there is a ratio, there should be a total steps" + + # It is necessary to assign all attributes *before* __init__, + # as class is wrapped by an inner class. + self.max_steps = max_steps + if warmup_steps is not None: + self.warmup_steps = warmup_steps + elif warmup_ratio is not None: + self.warmup_steps = int(warmup_ratio * max_steps) + else: + self.warmup_steps = 0 + + self.min_lr = min_lr + super().__init__(optimizer, last_epoch) + + def get_lr(self): + if not self._get_lr_called_within_step: + warnings.warn( + "To get the last learning rate computed " + "by the scheduler, please use `get_last_lr()`.", + UserWarning, + stacklevel=2) + + step = max(1, self.last_epoch) + + for initial_lr in self.base_lrs: + if initial_lr < self.min_lr: + raise ValueError( + f"{self} received an initial learning rate " + f"that was lower than the minimum learning rate.") + + new_lrs = [ + self._noam_annealing(initial_lr=initial_lr, step=step) + for initial_lr in self.base_lrs + ] + return new_lrs + + def _noam_annealing(self, initial_lr, step): + if self.warmup_steps > 0: + mult = self._normalize * min(step**(-0.5), + step * (self.warmup_steps**(-1.5))) + else: + mult = self._normalize * step**(-0.5) + + out_lr = initial_lr * mult + if step > self.warmup_steps: + out_lr = max(out_lr, self.min_lr) + return out_lr + + +class NoamHoldAnnealing(WarmupHoldPolicy): + + def __init__(self, + optimizer, + *, + max_steps, + decay_rate=0.5, + min_lr=0.0, + last_epoch=-1, + **kwargs): + """ + From Nemo: + Implementation of the Noam Hold Annealing policy + from the SqueezeFormer paper. + + Unlike NoamAnnealing, the peak learning rate + can be explicitly set for this scheduler. + The schedule first performs linear warmup, + then holds the peak LR, then decays with some schedule for + the remainder of the steps. + Therefore the min-lr is still dependent + on the hyper parameters selected. + + It's schedule is determined by three factors- + + Warmup Steps: Initial stage, where linear warmup + occurs uptil the peak LR is reached. Unlike NoamAnnealing, + the peak LR is explicitly stated here instead of a scaling factor. + + Hold Steps: Intermediate stage, where the peak LR + is maintained for some number of steps. In this region, + the high peak LR allows the model to converge faster + if training is stable. However the high LR + may also cause instability during training. + Should usually be a significant fraction of training + steps (around 30-40% of the entire training steps). 
+ + Decay Steps: Final stage, where the LR rapidly decays + with some scaling rate (set by decay rate). + To attain Noam decay, use 0.5, + for Squeezeformer recommended decay, use 1.0. + The fast decay after prolonged high LR during + hold phase allows for rapid convergence. + + References: + - [Squeezeformer: + An Efficient Transformer for Automatic Speech Recognition] + (https://arxiv.org/abs/2206.00888) + + Args: + optimizer: Pytorch compatible Optimizer object. + warmup_steps: Number of training steps in warmup stage + warmup_ratio: Ratio of warmup steps to total steps + hold_steps: Number of training steps to + hold the learning rate after warm up + hold_ratio: Ratio of hold steps to total steps + max_steps: Total number of steps while training or `None` for + infinite training + decay_rate: Float value describing the polynomial decay + after the hold period. Default value + of 0.5 corresponds to Noam decay. + min_lr: Minimum learning rate. + """ + self.decay_rate = decay_rate + super().__init__(optimizer=optimizer, + max_steps=max_steps, + last_epoch=last_epoch, + min_lr=min_lr, + **kwargs) + + def _get_lr(self, step): + if self.warmup_steps is None or self.warmup_steps == 0: + raise ValueError( + "Noam scheduler cannot be used without warmup steps") + + if self.hold_steps > 0: + hold_steps = self.hold_steps - self.warmup_steps + else: + hold_steps = 0 + + new_lrs = [ + _noam_hold_annealing( + initial_lr, + step=step, + warmup_steps=self.warmup_steps, + hold_steps=hold_steps, + decay_rate=self.decay_rate, + min_lr=self.min_lr, + ) for initial_lr in self.base_lrs + ] + return new_lrs + + def set_step(self, step: int): + self.last_epoch = step + + +class ConstantLR(_LRScheduler): + """The ConstantLR scheduler + + This scheduler keeps a constant lr + + """ + + def __init__( + self, + optimizer: torch.optim.Optimizer, + ): + # __init__() must be invoked before setting field + # because step() is also invoked in __init__() + super().__init__(optimizer) + + def get_lr(self): + return self.base_lrs + + def set_step(self, step: int): + self.last_epoch = step diff --git a/cosyvoice/utils/train_utils.py b/cosyvoice/utils/train_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..020005d001e1c386c84398981dfcc039b41fa89b --- /dev/null +++ b/cosyvoice/utils/train_utils.py @@ -0,0 +1,289 @@ +# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang) +# 2023 Horizon Inc. (authors: Xingchen Song) +# 2024 Alibaba Inc (authors: Xiang Lyu) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from contextlib import nullcontext +import logging +import os +import torch +import json +import re +import datetime +import yaml + +# import deepspeed +import torch.optim as optim +import torch.distributed as dist + +from torch.utils.tensorboard import SummaryWriter +from torch.utils.data import DataLoader +from torch.nn.utils import clip_grad_norm_ + +# from deepspeed.runtime.zero.stage_1_and_2 import estimate_zero2_model_states_mem_needs_all_live + +from cosyvoice.dataset.dataset import Dataset +from cosyvoice.utils.scheduler import WarmupLR, NoamHoldAnnealing, ConstantLR + + +def init_distributed(args): + world_size = int(os.environ.get('WORLD_SIZE', 1)) + local_rank = int(os.environ.get('LOCAL_RANK', 0)) + rank = int(os.environ.get('RANK', 0)) + logging.info('training on multiple gpus, this gpu {}'.format(local_rank) + + ', rank {}, world_size {}'.format(rank, world_size)) + if args.train_engine == 'torch_ddp': + torch.cuda.set_device(local_rank) + dist.init_process_group(args.dist_backend) + else: + deepspeed.init_distributed(dist_backend=args.dist_backend) + return world_size, local_rank, rank + + +def init_dataset_and_dataloader(args, configs): + train_dataset = Dataset(args.train_data, data_pipeline=configs['data_pipeline'], mode='train', shuffle=True, partition=True) + cv_dataset = Dataset(args.cv_data, data_pipeline=configs['data_pipeline'], mode='train', shuffle=False, partition=False) + + # do not use persistent_workers=True, as whisper tokenizer opens tiktoken file each time when the for loop starts + train_data_loader = DataLoader(train_dataset, + batch_size=None, + pin_memory=args.pin_memory, + num_workers=args.num_workers, + prefetch_factor=args.prefetch) + cv_data_loader = DataLoader(cv_dataset, + batch_size=None, + pin_memory=args.pin_memory, + num_workers=args.num_workers, + prefetch_factor=args.prefetch) + return train_dataset, cv_dataset, train_data_loader, cv_data_loader + + + +def check_modify_and_save_config(args, configs): + if args.train_engine == "torch_ddp": + configs['train_conf']["dtype"] = 'fp32' + else: + with open(args.deepspeed_config, 'r') as fin: + ds_configs = json.load(fin) + if "fp16" in ds_configs and ds_configs["fp16"]["enabled"]: + configs['train_conf']["dtype"] = "fp16" + elif "bf16" in ds_configs and ds_configs["bf16"]["enabled"]: + configs['train_conf']["dtype"] = "bf16" + else: + configs['train_conf']["dtype"] = "fp32" + assert ds_configs["train_micro_batch_size_per_gpu"] == 1 + # if use deepspeed, override ddp config + configs['train_conf']['save_per_step'] = int(configs['train_conf']['save_per_step'] * configs['train_conf']['accum_grad'] / ds_configs["gradient_accumulation_steps"]) + configs['train_conf']['accum_grad'] = ds_configs["gradient_accumulation_steps"] + configs['train_conf']['grad_clip'] = ds_configs["gradient_clipping"] + configs['train_conf']['log_interval'] = ds_configs["steps_per_print"] + return configs + + +def wrap_cuda_model(args, model): + local_world_size = int(os.environ.get('LOCAL_WORLD_SIZE', 1)) + world_size = int(os.environ.get('WORLD_SIZE', 1)) + if args.train_engine == "torch_ddp": # native pytorch ddp + assert (torch.cuda.is_available()) + model.cuda() + model = torch.nn.parallel.DistributedDataParallel(model, find_unused_parameters=True) + else: + if int(os.environ.get('RANK', 0)) == 0: + logging.info("Estimating model states memory needs (zero2)...") + estimate_zero2_model_states_mem_needs_all_live( + model, + num_gpus_per_node=local_world_size, + num_nodes=world_size // local_world_size) + return model + + 
+def init_optimizer_and_scheduler(args, configs, model): + if configs['train_conf']['optim'] == 'adam': + optimizer = optim.Adam(model.parameters(), **configs['train_conf']['optim_conf']) + elif configs['train_conf']['optim'] == 'adamw': + optimizer = optim.AdamW(model.parameters(), **configs['train_conf']['optim_conf']) + else: + raise ValueError("unknown optimizer: " + configs['train_conf']) + + if configs['train_conf']['scheduler'] == 'warmuplr': + scheduler_type = WarmupLR + scheduler = WarmupLR(optimizer, **configs['train_conf']['scheduler_conf']) + elif configs['train_conf']['scheduler'] == 'NoamHoldAnnealing': + scheduler_type = NoamHoldAnnealing + scheduler = NoamHoldAnnealing(optimizer, **configs['train_conf']['scheduler_conf']) + elif configs['train_conf']['scheduler'] == 'constantlr': + scheduler_type = ConstantLR + scheduler = ConstantLR(optimizer) + else: + raise ValueError("unknown scheduler: " + configs['train_conf']) + + # use deepspeed optimizer for speedup + if args.train_engine == "deepspeed": + def scheduler(opt): + return scheduler_type(opt, **configs['train_conf']['scheduler_conf']) + model, optimizer, _, scheduler = deepspeed.initialize( + args=args, + model=model, + optimizer=None, + lr_scheduler=scheduler, + model_parameters=model.parameters()) + + return model, optimizer, scheduler + + +def init_summarywriter(args): + writer = None + if int(os.environ.get('RANK', 0)) == 0: + os.makedirs(args.model_dir, exist_ok=True) + writer = SummaryWriter(args.tensorboard_dir) + return writer + + +def save_model(model, model_name, info_dict): + rank = int(os.environ.get('RANK', 0)) + model_dir = info_dict["model_dir"] + save_model_path = os.path.join(model_dir, '{}.pt'.format(model_name)) + + if info_dict["train_engine"] == "torch_ddp": + if rank == 0: + torch.save(model.module.state_dict(), save_model_path) + else: + with torch.no_grad(): + model.save_checkpoint(save_dir=model_dir, + tag=model_name, + client_state=info_dict) + if rank == 0: + info_path = re.sub('.pt$', '.yaml', save_model_path) + info_dict['save_time'] = datetime.datetime.now().strftime('%d/%m/%Y %H:%M:%S') + with open(info_path, 'w') as fout: + data = yaml.dump(info_dict) + fout.write(data) + logging.info('[Rank {}] Checkpoint: save to checkpoint {}'.format(rank, save_model_path)) + + +def cosyvoice_join(group_join, info_dict): + world_size = int(os.environ.get('WORLD_SIZE', 1)) + local_rank = int(os.environ.get('LOCAL_RANK', 0)) + rank = int(os.environ.get('RANK', 0)) + + if info_dict["batch_idx"] != 0: + # we try to join all rank in both ddp and deepspeed mode, in case different rank has different lr + try: + dist.monitored_barrier(group=group_join, + timeout=group_join.options._timeout) + return False + except RuntimeError as e: + logging.info("Detected uneven workload distribution: {}\n".format(e) + + "Break current worker to manually join all workers, " + + "world_size {}, current rank {}, current local_rank {}\n". 
+ format(world_size, rank, local_rank)) + return True + else: + return False + + +def batch_forward(model, batch, info_dict): + device = int(os.environ.get('LOCAL_RANK', 0)) + + dtype = info_dict["dtype"] + if dtype == "fp16": + dtype = torch.float16 + elif dtype == "bf16": + dtype = torch.bfloat16 + else: # fp32 + dtype = torch.float32 + + if info_dict['train_engine'] == 'torch_ddp': + autocast = nullcontext() + else: + autocast = torch.cuda.amp.autocast(enabled=True, dtype=dtype, cache_enabled=False) + + with autocast: + info_dict['loss_dict'] = model(batch, device) + return info_dict + + +def batch_backward(model, info_dict): + if info_dict["train_engine"] == "deepspeed": + scaled_loss = model.backward(info_dict['loss_dict']['loss']) + else: + scaled_loss = info_dict['loss_dict']['loss'] / info_dict['accum_grad'] + scaled_loss.backward() + + info_dict['loss_dict']['loss'] = scaled_loss + return info_dict + + +def update_parameter_and_lr(model, optimizer, scheduler, info_dict): + grad_norm = 0.0 + if info_dict['train_engine'] == "deepspeed": + info_dict["is_gradient_accumulation_boundary"] = model.is_gradient_accumulation_boundary() + model.step() + grad_norm = model.get_global_grad_norm() + elif (info_dict['batch_idx'] + 1) % info_dict["accum_grad"] == 0: + grad_norm = clip_grad_norm_(model.parameters(), info_dict['grad_clip']) + if torch.isfinite(grad_norm): + optimizer.step() + optimizer.zero_grad() + scheduler.step() + info_dict["lr"] = optimizer.param_groups[0]['lr'] + info_dict["grad_norm"] = grad_norm + return info_dict + + +def log_per_step(writer, info_dict): + tag = info_dict["tag"] + epoch = info_dict.get('epoch', 0) + step = info_dict["step"] + batch_idx = info_dict["batch_idx"] + loss_dict = info_dict['loss_dict'] + rank = int(os.environ.get('RANK', 0)) + + # only rank 0 write to tensorboard to avoid multi-process write + if writer is not None: + if (info_dict['train_engine'] == 'deepspeed' and info_dict['is_gradient_accumulation_boundary'] is True) or \ + (info_dict['train_engine'] == 'torch_ddp' and (info_dict['batch_idx'] + 1) % info_dict['accum_grad'] == 0): + for k in ['epoch', 'lr', 'grad_norm']: + writer.add_scalar('{}/{}'.format(tag, k), info_dict[k], step + 1) + for k, v in loss_dict.items(): + writer.add_scalar('{}/{}'.format(tag, k), v, step + 1) + + # TRAIN & CV, Shell log (stdout) + if (info_dict['batch_idx'] + 1) % info_dict['log_interval'] == 0: + log_str = '{} Batch {}/{} '.format(tag, epoch, batch_idx + 1) + for name, value in loss_dict.items(): + log_str += '{} {:.6f} '.format(name, value) + if tag == "TRAIN": + log_str += 'lr {:.8f} grad_norm {:.6f}'.format( + info_dict["lr"], info_dict['grad_norm']) + log_str += ' rank {}'.format(rank) + logging.debug(log_str) + + +def log_per_save(writer, info_dict): + tag = info_dict["tag"] + epoch = info_dict["epoch"] + step = info_dict["step"] + loss_dict = info_dict["loss_dict"] + lr = info_dict['lr'] + rank = int(os.environ.get('RANK', 0)) + logging.info( + 'Epoch {} Step {} CV info lr {} {} rank {}'.format( + epoch, step + 1, lr, rank, ' '.join(['{}_{}'.format(k, v) for k, v in loss_dict.items()]))) + + if writer is not None: + for k in ['epoch', 'lr']: + writer.add_scalar('{}/{}'.format(tag, k), info_dict[k], step + 1) + for k, v in loss_dict.items(): + writer.add_scalar('{}/{}'.format(tag, k), v, step + 1)
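The schedulers added in cosyvoice/utils/scheduler.py can be exercised on their own. Below is a minimal, hypothetical sketch (not part of this patch) that drives NoamHoldAnnealing with a dummy optimizer to show the warmup, hold, and decay phases. The constructor keywords (max_steps, warmup_steps, hold_steps, decay_rate, min_lr) are taken from the class docstring above; the model, step counts, and learning-rate values are placeholders chosen for illustration.

import torch
from cosyvoice.utils.scheduler import NoamHoldAnnealing

model = torch.nn.Linear(8, 8)                               # dummy parameters, illustration only
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)  # the optimizer lr acts as the peak LR

scheduler = NoamHoldAnnealing(
    optimizer,
    max_steps=10000,     # total training steps (placeholder)
    warmup_steps=500,    # linear warmup up to the peak LR
    hold_steps=2500,     # steps to hold the peak LR after warmup
    decay_rate=0.5,      # 0.5 = Noam-style decay, 1.0 = Squeezeformer-recommended decay
    min_lr=1e-5,
)

for step in range(10000):
    optimizer.step()     # in real training this follows loss.backward()
    scheduler.step()
    if step in (0, 499, 2999, 9999):
        print(step, scheduler.get_last_lr())  # inspect the LR at a few milestones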
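init_optimizer_and_scheduler() in cosyvoice/utils/train_utils.py reads the optimizer and scheduler description from configs['train_conf']. The sketch below shows the shape of that fragment with placeholder values: the key names ('optim', 'optim_conf', 'scheduler', 'scheduler_conf') come from the function itself, the actual recipes load this from their hyperpyyaml configs, and the warmup_steps keyword for WarmupLR is an assumption since WarmupLR's signature is not shown in this patch.

from types import SimpleNamespace
import torch
from cosyvoice.utils.train_utils import init_optimizer_and_scheduler

configs = {
    'train_conf': {
        'optim': 'adamw',                          # 'adam' or 'adamw'
        'optim_conf': {'lr': 1e-3},                # forwarded as torch.optim.AdamW(**optim_conf)
        'scheduler': 'warmuplr',                   # 'warmuplr', 'NoamHoldAnnealing' or 'constantlr'
        'scheduler_conf': {'warmup_steps': 2500},  # assumed WarmupLR keyword; placeholder value
    }
}

args = SimpleNamespace(train_engine='torch_ddp')   # the deepspeed branch is not exercised here
model = torch.nn.Linear(8, 8)                      # stand-in for the real CosyVoice model

model, optimizer, scheduler = init_optimizer_and_scheduler(args, configs, model)
print(type(optimizer).__name__, type(scheduler).__name__)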
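The torch_ddp update path above is split between batch_backward(), which scales the loss by accum_grad before calling backward(), and update_parameter_and_lr(), which clips gradients and steps the optimizer, scheduler, and zero_grad only on gradient-accumulation boundaries. The following self-contained sketch replays that same pattern with made-up data and hyperparameters, using the ConstantLR scheduler defined in this patch.

import torch
from torch.nn.utils import clip_grad_norm_
from cosyvoice.utils.scheduler import ConstantLR

accum_grad, grad_clip = 2, 5.0                         # stand-ins for train_conf values
model = torch.nn.Linear(8, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scheduler = ConstantLR(optimizer)

for batch_idx in range(8):
    x, y = torch.randn(4, 8), torch.randn(4, 1)
    loss = torch.nn.functional.mse_loss(model(x), y)
    (loss / accum_grad).backward()                     # batch_backward(): scale, then backward
    if (batch_idx + 1) % accum_grad == 0:              # update_parameter_and_lr(): boundary only
        grad_norm = clip_grad_norm_(model.parameters(), grad_clip)
        if torch.isfinite(grad_norm):                  # skip the update on inf/nan gradients
            optimizer.step()
        optimizer.zero_grad()
        scheduler.step()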