cydxg commited on
Commit
0a948c1
1 Parent(s): f82f9a8

Upload 73 files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. cosyvoice/__init__.py +0 -0
  2. cosyvoice/__pycache__/__init__.cpython-310.pyc +0 -0
  3. cosyvoice/bin/inference.py +114 -0
  4. cosyvoice/bin/train.py +140 -0
  5. cosyvoice/cli/__init__.py +0 -0
  6. cosyvoice/cli/cosyvoice.py +83 -0
  7. cosyvoice/cli/frontend.py +168 -0
  8. cosyvoice/cli/model.py +95 -0
  9. cosyvoice/dataset/__init__.py +0 -0
  10. cosyvoice/dataset/dataset.py +160 -0
  11. cosyvoice/dataset/processor.py +965 -0
  12. cosyvoice/flow/__pycache__/decoder.cpython-310.pyc +0 -0
  13. cosyvoice/flow/__pycache__/flow.cpython-310.pyc +0 -0
  14. cosyvoice/flow/__pycache__/flow_matching.cpython-310.pyc +0 -0
  15. cosyvoice/flow/__pycache__/length_regulator.cpython-310.pyc +0 -0
  16. cosyvoice/flow/decoder.py +222 -0
  17. cosyvoice/flow/flow.py +144 -0
  18. cosyvoice/flow/flow_gradtts.py +142 -0
  19. cosyvoice/flow/flow_matching.py +142 -0
  20. cosyvoice/flow/flow_matching_dit.py +180 -0
  21. cosyvoice/flow/length_regulator.py +49 -0
  22. cosyvoice/flow/stable/adp.py +1591 -0
  23. cosyvoice/flow/stable/blocks.py +339 -0
  24. cosyvoice/flow/stable/dit.py +415 -0
  25. cosyvoice/flow/stable/dit_v2.py +307 -0
  26. cosyvoice/flow/stable/sampling.py +232 -0
  27. cosyvoice/flow/stable/stable_diffusion.py +109 -0
  28. cosyvoice/flow/stable/stable_diffusion_test.py +104 -0
  29. cosyvoice/flow/stable/transformer.py +816 -0
  30. cosyvoice/flow/stable/transformer_use_mask.py +845 -0
  31. cosyvoice/hifigan/__pycache__/f0_predictor.cpython-310.pyc +0 -0
  32. cosyvoice/hifigan/__pycache__/generator.cpython-310.pyc +0 -0
  33. cosyvoice/hifigan/f0_predictor.py +55 -0
  34. cosyvoice/hifigan/generator.py +398 -0
  35. cosyvoice/llm/__pycache__/llm.cpython-310.pyc +0 -0
  36. cosyvoice/llm/llm.py +206 -0
  37. cosyvoice/transformer/__init__.py +0 -0
  38. cosyvoice/transformer/__pycache__/__init__.cpython-310.pyc +0 -0
  39. cosyvoice/transformer/__pycache__/activation.cpython-310.pyc +0 -0
  40. cosyvoice/transformer/__pycache__/attention.cpython-310.pyc +0 -0
  41. cosyvoice/transformer/__pycache__/convolution.cpython-310.pyc +0 -0
  42. cosyvoice/transformer/__pycache__/embedding.cpython-310.pyc +0 -0
  43. cosyvoice/transformer/__pycache__/encoder.cpython-310.pyc +0 -0
  44. cosyvoice/transformer/__pycache__/encoder_layer.cpython-310.pyc +0 -0
  45. cosyvoice/transformer/__pycache__/label_smoothing_loss.cpython-310.pyc +0 -0
  46. cosyvoice/transformer/__pycache__/positionwise_feed_forward.cpython-310.pyc +0 -0
  47. cosyvoice/transformer/__pycache__/subsampling.cpython-310.pyc +0 -0
  48. cosyvoice/transformer/activation.py +84 -0
  49. cosyvoice/transformer/attention.py +612 -0
  50. cosyvoice/transformer/convolution.py +145 -0
cosyvoice/__init__.py ADDED
File without changes
cosyvoice/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (132 Bytes). View file
 
cosyvoice/bin/inference.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import print_function
16
+
17
+ import argparse
18
+ import logging
19
+ logging.getLogger('matplotlib').setLevel(logging.WARNING)
20
+ import os
21
+
22
+ import torch
23
+ from torch.utils.data import DataLoader
24
+ import torchaudio
25
+ from hyperpyyaml import load_hyperpyyaml
26
+ from tqdm import tqdm
27
+ from cosyvoice.cli.model import CosyVoiceModel
28
+
29
+ from cosyvoice.dataset.dataset import Dataset
30
+
31
+ def get_args():
32
+ parser = argparse.ArgumentParser(description='inference with your model')
33
+ parser.add_argument('--config', required=True, help='config file')
34
+ parser.add_argument('--prompt_data', required=True, help='prompt data file')
35
+ parser.add_argument('--prompt_utt2data', required=True, help='prompt data file')
36
+ parser.add_argument('--tts_text', required=True, help='tts input file')
37
+ parser.add_argument('--llm_model', required=True, help='llm model file')
38
+ parser.add_argument('--flow_model', required=True, help='flow model file')
39
+ parser.add_argument('--hifigan_model', required=True, help='hifigan model file')
40
+ parser.add_argument('--gpu',
41
+ type=int,
42
+ default=-1,
43
+ help='gpu id for this rank, -1 for cpu')
44
+ parser.add_argument('--mode',
45
+ default='sft',
46
+ choices=['sft', 'zero_shot'],
47
+ help='inference mode')
48
+ parser.add_argument('--result_dir', required=True, help='asr result file')
49
+ args = parser.parse_args()
50
+ print(args)
51
+ return args
52
+
53
+
54
+ def main():
55
+ args = get_args()
56
+ logging.basicConfig(level=logging.DEBUG,
57
+ format='%(asctime)s %(levelname)s %(message)s')
58
+ os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
59
+
60
+ # Init cosyvoice models from configs
61
+ use_cuda = args.gpu >= 0 and torch.cuda.is_available()
62
+ device = torch.device('cuda' if use_cuda else 'cpu')
63
+ with open(args.config, 'r') as f:
64
+ configs = load_hyperpyyaml(f)
65
+
66
+ model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'])
67
+ model.load(args.llm_model, args.flow_model, args.hifigan_model)
68
+
69
+ test_dataset = Dataset(args.prompt_data, data_pipeline=configs['data_pipeline'], mode='inference', shuffle=False, partition=False, tts_file=args.tts_text, prompt_utt2data=args.prompt_utt2data)
70
+ test_data_loader = DataLoader(test_dataset, batch_size=None, num_workers=0)
71
+
72
+ del configs
73
+ os.makedirs(args.result_dir, exist_ok=True)
74
+ fn = os.path.join(args.result_dir, 'wav.scp')
75
+ f = open(fn, 'w')
76
+ with torch.no_grad():
77
+ for batch_idx, batch in tqdm(enumerate(test_data_loader)):
78
+ utts = batch["utts"]
79
+ assert len(utts) == 1, "inference mode only support batchsize 1"
80
+ text = batch["text"]
81
+ text_token = batch["text_token"].to(device)
82
+ text_token_len = batch["text_token_len"].to(device)
83
+ tts_text = batch["tts_text"]
84
+ tts_index = batch["tts_index"]
85
+ tts_text_token = batch["tts_text_token"].to(device)
86
+ tts_text_token_len = batch["tts_text_token_len"].to(device)
87
+ speech_token = batch["speech_token"].to(device)
88
+ speech_token_len = batch["speech_token_len"].to(device)
89
+ speech_feat = batch["speech_feat"].to(device)
90
+ speech_feat_len = batch["speech_feat_len"].to(device)
91
+ utt_embedding = batch["utt_embedding"].to(device)
92
+ spk_embedding = batch["spk_embedding"].to(device)
93
+ if args.mode == 'sft':
94
+ model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
95
+ 'llm_embedding': spk_embedding, 'flow_embedding': spk_embedding}
96
+ else:
97
+ model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
98
+ 'prompt_text': text_token, 'prompt_text_len': text_token_len,
99
+ 'llm_prompt_speech_token': speech_token, 'llm_prompt_speech_token_len': speech_token_len,
100
+ 'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
101
+ 'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
102
+ 'llm_embedding': utt_embedding, 'flow_embedding': utt_embedding}
103
+ model_output = model.inference(**model_input)
104
+ tts_key = '{}_{}'.format(utts[0], tts_index[0])
105
+ tts_fn = os.path.join(args.result_dir, '{}.wav'.format(tts_key))
106
+ torchaudio.save(tts_fn, model_output['tts_speech'], sample_rate=22050)
107
+ f.write('{} {}\n'.format(tts_key, tts_fn))
108
+ f.flush()
109
+ f.close()
110
+ logging.info('Result wav.scp saved in {}'.format(fn))
111
+
112
+
113
+ if __name__ == '__main__':
114
+ main()
cosyvoice/bin/train.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import print_function
16
+ import argparse
17
+ import datetime
18
+ import logging
19
+ logging.getLogger('matplotlib').setLevel(logging.WARNING)
20
+ from copy import deepcopy
21
+ import torch
22
+ import torch.distributed as dist
23
+ # import deepspeed
24
+ import pdb
25
+ from hyperpyyaml import load_hyperpyyaml
26
+
27
+ from torch.distributed.elastic.multiprocessing.errors import record
28
+
29
+ from cosyvoice.utils.executor import Executor
30
+ from cosyvoice.utils.train_utils import (
31
+ init_distributed,
32
+ init_dataset_and_dataloader,
33
+ init_optimizer_and_scheduler,
34
+ init_summarywriter, save_model,
35
+ wrap_cuda_model, check_modify_and_save_config)
36
+
37
+
38
+ def get_args():
39
+ parser = argparse.ArgumentParser(description='training your network')
40
+ parser.add_argument('--train_engine',
41
+ default='torch_ddp',
42
+ choices=['torch_ddp', 'deepspeed'],
43
+ help='Engine for paralleled training')
44
+ parser.add_argument('--model', required=True, help='model which will be trained')
45
+ parser.add_argument('--config', required=True, help='config file')
46
+ parser.add_argument('--train_data', required=True, help='train data file')
47
+ parser.add_argument('--cv_data', required=True, help='cv data file')
48
+ parser.add_argument('--checkpoint', help='checkpoint model')
49
+ parser.add_argument('--model_dir', required=True, help='save model dir')
50
+ parser.add_argument('--tensorboard_dir',
51
+ default='tensorboard',
52
+ help='tensorboard log dir')
53
+ parser.add_argument('--ddp.dist_backend',
54
+ dest='dist_backend',
55
+ default='nccl',
56
+ choices=['nccl', 'gloo'],
57
+ help='distributed backend')
58
+ parser.add_argument('--num_workers',
59
+ default=0,
60
+ type=int,
61
+ help='num of subprocess workers for reading')
62
+ parser.add_argument('--prefetch',
63
+ default=100,
64
+ type=int,
65
+ help='prefetch number')
66
+ parser.add_argument('--pin_memory',
67
+ action='store_true',
68
+ default=False,
69
+ help='Use pinned memory buffers used for reading')
70
+ parser.add_argument('--deepspeed.save_states',
71
+ dest='save_states',
72
+ default='model_only',
73
+ choices=['model_only', 'model+optimizer'],
74
+ help='save model/optimizer states')
75
+ parser.add_argument('--timeout',
76
+ default=30,
77
+ type=int,
78
+ help='timeout (in seconds) of cosyvoice_join.')
79
+ # parser = deepspeed.add_config_arguments(parser)
80
+ args = parser.parse_args()
81
+ return args
82
+
83
+
84
+ @record
85
+ def main():
86
+ args = get_args()
87
+ logging.basicConfig(level=logging.DEBUG,
88
+ format='%(asctime)s %(levelname)s %(message)s')
89
+
90
+ override_dict = {k: None for k in ['llm', 'flow', 'hift'] if k != args.model}
91
+ with open(args.config, 'r') as f:
92
+ configs = load_hyperpyyaml(f, overrides=override_dict)
93
+ configs['train_conf'].update(vars(args))
94
+
95
+ # Init env for ddp
96
+ init_distributed(args)
97
+
98
+ # Get dataset & dataloader
99
+ train_dataset, cv_dataset, train_data_loader, cv_data_loader = \
100
+ init_dataset_and_dataloader(args, configs)
101
+
102
+ # Do some sanity checks and save config to arsg.model_dir
103
+ configs = check_modify_and_save_config(args, configs)
104
+
105
+ # Tensorboard summary
106
+ writer = init_summarywriter(args)
107
+
108
+ # load checkpoint
109
+ model = configs[args.model]
110
+ if args.checkpoint is not None:
111
+ model.load_state_dict(torch.load(args.checkpoint, map_location='cpu'))
112
+
113
+ # Dispatch model from cpu to gpu
114
+ model = wrap_cuda_model(args, model)
115
+
116
+ # Get optimizer & scheduler
117
+ model, optimizer, scheduler = init_optimizer_and_scheduler(args, configs, model)
118
+ # pdb.set_trace()
119
+ # Save init checkpoints
120
+ info_dict = deepcopy(configs['train_conf'])
121
+ save_model(model, 'init', info_dict)
122
+
123
+ # Get executor
124
+ executor = Executor()
125
+
126
+ # Start training loop
127
+ for epoch in range(info_dict['max_epoch']):
128
+ executor.epoch = epoch
129
+ train_dataset.set_epoch(epoch)
130
+ dist.barrier()
131
+ # try:
132
+ # dist.barrier()
133
+ # except RuntimeError as e:
134
+ # logging.info('except RuntimeError as e: {}'.format(e))
135
+ group_join = dist.new_group(backend="gloo", timeout=datetime.timedelta(seconds=args.timeout))
136
+ executor.train_one_epoc(model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, group_join)
137
+ dist.destroy_process_group(group_join)
138
+
139
+ if __name__ == '__main__':
140
+ main()
cosyvoice/cli/__init__.py ADDED
File without changes
cosyvoice/cli/cosyvoice.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import os
15
+ import torch
16
+ from hyperpyyaml import load_hyperpyyaml
17
+ from modelscope import snapshot_download
18
+ from cosyvoice.cli.frontend import CosyVoiceFrontEnd
19
+ from cosyvoice.cli.model import CosyVoiceModel
20
+
21
+ class CosyVoice:
22
+
23
+ def __init__(self, model_dir):
24
+ instruct = True if '-Instruct' in model_dir else False
25
+ self.model_dir = model_dir
26
+ if not os.path.exists(model_dir):
27
+ model_dir = snapshot_download(model_dir)
28
+ with open('{}/cosyvoice.yaml'.format(model_dir), 'r') as f:
29
+ configs = load_hyperpyyaml(f)
30
+ self.frontend = CosyVoiceFrontEnd(configs['get_tokenizer'],
31
+ configs['feat_extractor'],
32
+ '{}/campplus.onnx'.format(model_dir),
33
+ '{}/speech_tokenizer_v1.onnx'.format(model_dir),
34
+ '{}/spk2info.pt'.format(model_dir),
35
+ instruct,
36
+ configs['allowed_special'])
37
+ self.model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'])
38
+ self.model.load('{}/llm.pt'.format(model_dir),
39
+ '{}/flow.pt'.format(model_dir),
40
+ '{}/hift.pt'.format(model_dir))
41
+ del configs
42
+
43
+ def list_avaliable_spks(self):
44
+ spks = list(self.frontend.spk2info.keys())
45
+ return spks
46
+
47
+ def inference_sft(self, tts_text, spk_id):
48
+ tts_speeches = []
49
+ for i in self.frontend.text_normalize(tts_text, split=True):
50
+ model_input = self.frontend.frontend_sft(i, spk_id)
51
+ model_output = self.model.inference(**model_input)
52
+ tts_speeches.append(model_output['tts_speech'])
53
+ return {'tts_speech': torch.concat(tts_speeches, dim=1)}
54
+
55
+ def inference_zero_shot(self, tts_text, prompt_text, prompt_speech_16k):
56
+ prompt_text = self.frontend.text_normalize(prompt_text, split=False)
57
+ tts_speeches = []
58
+ for i in self.frontend.text_normalize(tts_text, split=True):
59
+ model_input = self.frontend.frontend_zero_shot(i, prompt_text, prompt_speech_16k)
60
+ model_output = self.model.inference(**model_input)
61
+ tts_speeches.append(model_output['tts_speech'])
62
+ return {'tts_speech': torch.concat(tts_speeches, dim=1)}
63
+
64
+ def inference_cross_lingual(self, tts_text, prompt_speech_16k):
65
+ if self.frontend.instruct is True:
66
+ raise ValueError('{} do not support cross_lingual inference'.format(self.model_dir))
67
+ tts_speeches = []
68
+ for i in self.frontend.text_normalize(tts_text, split=True):
69
+ model_input = self.frontend.frontend_cross_lingual(i, prompt_speech_16k)
70
+ model_output = self.model.inference(**model_input)
71
+ tts_speeches.append(model_output['tts_speech'])
72
+ return {'tts_speech': torch.concat(tts_speeches, dim=1)}
73
+
74
+ def inference_instruct(self, tts_text, spk_id, instruct_text):
75
+ if self.frontend.instruct is False:
76
+ raise ValueError('{} do not support instruct inference'.format(self.model_dir))
77
+ instruct_text = self.frontend.text_normalize(instruct_text, split=False)
78
+ tts_speeches = []
79
+ for i in self.frontend.text_normalize(tts_text, split=True):
80
+ model_input = self.frontend.frontend_instruct(i, spk_id, instruct_text)
81
+ model_output = self.model.inference(**model_input)
82
+ tts_speeches.append(model_output['tts_speech'])
83
+ return {'tts_speech': torch.concat(tts_speeches, dim=1)}
cosyvoice/cli/frontend.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from functools import partial
15
+ import onnxruntime
16
+ import torch
17
+ import numpy as np
18
+ import whisper
19
+ from typing import Callable
20
+ import torchaudio.compliance.kaldi as kaldi
21
+ import torchaudio
22
+ import os
23
+ import re
24
+ import inflect
25
+ try:
26
+ import ttsfrd
27
+ use_ttsfrd = True
28
+ except ImportError:
29
+ print("failed to import ttsfrd, use WeTextProcessing instead")
30
+ from tn.chinese.normalizer import Normalizer as ZhNormalizer
31
+ from tn.english.normalizer import Normalizer as EnNormalizer
32
+ use_ttsfrd = False
33
+ from cosyvoice.utils.frontend_utils import contains_chinese, replace_blank, replace_corner_mark, remove_bracket, spell_out_number, split_paragraph
34
+
35
+
36
+ class CosyVoiceFrontEnd:
37
+
38
+ def __init__(self,
39
+ get_tokenizer: Callable,
40
+ feat_extractor: Callable,
41
+ campplus_model: str,
42
+ speech_tokenizer_model: str,
43
+ spk2info: str = '',
44
+ instruct: bool = False,
45
+ allowed_special: str = 'all'):
46
+ self.tokenizer = get_tokenizer()
47
+ self.feat_extractor = feat_extractor
48
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
49
+ option = onnxruntime.SessionOptions()
50
+ option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
51
+ option.intra_op_num_threads = 1
52
+ self.campplus_session = onnxruntime.InferenceSession(campplus_model, sess_options=option, providers=["CPUExecutionProvider"])
53
+ self.speech_tokenizer_session = onnxruntime.InferenceSession(speech_tokenizer_model, sess_options=option, providers=["CUDAExecutionProvider"if torch.cuda.is_available() else "CPUExecutionProvider"])
54
+ if os.path.exists(spk2info):
55
+ self.spk2info = torch.load(spk2info, map_location=self.device)
56
+ self.instruct = instruct
57
+ self.allowed_special = allowed_special
58
+ self.inflect_parser = inflect.engine()
59
+ self.use_ttsfrd = use_ttsfrd
60
+ if self.use_ttsfrd:
61
+ self.frd = ttsfrd.TtsFrontendEngine()
62
+ ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
63
+ assert self.frd.initialize('{}/../../pretrained_models/CosyVoice-ttsfrd/resource'.format(ROOT_DIR)) is True, 'failed to initialize ttsfrd resource'
64
+ self.frd.set_lang_type('pinyin')
65
+ self.frd.enable_pinyin_mix(True)
66
+ self.frd.set_breakmodel_index(1)
67
+ else:
68
+ self.zh_tn_model = ZhNormalizer(remove_erhua=False, full_to_half=False)
69
+ self.en_tn_model = EnNormalizer()
70
+
71
+ def _extract_text_token(self, text):
72
+ text_token = self.tokenizer.encode(text, allowed_special=self.allowed_special)
73
+ text_token = torch.tensor([text_token], dtype=torch.int32).to(self.device)
74
+ text_token_len = torch.tensor([text_token.shape[1]], dtype=torch.int32).to(self.device)
75
+ return text_token, text_token_len
76
+
77
+ def _extract_speech_token(self, speech):
78
+ feat = whisper.log_mel_spectrogram(speech, n_mels=128)
79
+ speech_token = self.speech_tokenizer_session.run(None, {self.speech_tokenizer_session.get_inputs()[0].name: feat.detach().cpu().numpy(),
80
+ self.speech_tokenizer_session.get_inputs()[1].name: np.array([feat.shape[2]], dtype=np.int32)})[0].flatten().tolist()
81
+ speech_token = torch.tensor([speech_token], dtype=torch.int32).to(self.device)
82
+ speech_token_len = torch.tensor([speech_token.shape[1]], dtype=torch.int32).to(self.device)
83
+ return speech_token, speech_token_len
84
+
85
+ def _extract_spk_embedding(self, speech):
86
+ feat = kaldi.fbank(speech,
87
+ num_mel_bins=80,
88
+ dither=0,
89
+ sample_frequency=16000)
90
+ feat = feat - feat.mean(dim=0, keepdim=True)
91
+ embedding = self.campplus_session.run(None, {self.campplus_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist()
92
+ embedding = torch.tensor([embedding]).to(self.device)
93
+ return embedding
94
+
95
+ def _extract_speech_feat(self, speech):
96
+ speech_feat = self.feat_extractor(speech).squeeze(dim=0).transpose(0, 1).to(self.device)
97
+ speech_feat = speech_feat.unsqueeze(dim=0)
98
+ speech_feat_len = torch.tensor([speech_feat.shape[1]], dtype=torch.int32).to(self.device)
99
+ return speech_feat, speech_feat_len
100
+
101
+ def text_normalize(self, text, split=True):
102
+ text = text.strip()
103
+ if contains_chinese(text):
104
+ if self.use_ttsfrd:
105
+ text = self.frd.get_frd_extra_info(text, 'input')
106
+ else:
107
+ text = self.zh_tn_model.normalize(text)
108
+ text = text.replace("\n", "")
109
+ text = replace_blank(text)
110
+ text = replace_corner_mark(text)
111
+ text = text.replace(".", "、")
112
+ text = text.replace(" - ", ",")
113
+ text = remove_bracket(text)
114
+ text = re.sub(r'[,,]+$', '。', text)
115
+ texts = [i for i in split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "zh", token_max_n=80,
116
+ token_min_n=60, merge_len=20,
117
+ comma_split=False)]
118
+ else:
119
+ if self.use_ttsfrd:
120
+ text = self.frd.get_frd_extra_info(text, 'input')
121
+ else:
122
+ text = self.en_tn_model.normalize(text)
123
+ text = spell_out_number(text, self.inflect_parser)
124
+ texts = [i for i in split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "en", token_max_n=80,
125
+ token_min_n=60, merge_len=20,
126
+ comma_split=False)]
127
+ if split is False:
128
+ return text
129
+ return texts
130
+
131
+ def frontend_sft(self, tts_text, spk_id):
132
+ tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
133
+ embedding = self.spk2info[spk_id]['embedding']
134
+ model_input = {'text': tts_text_token, 'text_len': tts_text_token_len, 'llm_embedding': embedding, 'flow_embedding': embedding}
135
+ return model_input
136
+
137
+ def frontend_zero_shot(self, tts_text, prompt_text, prompt_speech_16k):
138
+ tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
139
+ prompt_text_token, prompt_text_token_len = self._extract_text_token(prompt_text)
140
+ prompt_speech_22050 = torchaudio.transforms.Resample(orig_freq=16000, new_freq=22050)(prompt_speech_16k)
141
+ speech_feat, speech_feat_len = self._extract_speech_feat(prompt_speech_22050)
142
+ speech_token, speech_token_len = self._extract_speech_token(prompt_speech_16k)
143
+ embedding = self._extract_spk_embedding(prompt_speech_16k)
144
+ model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
145
+ 'prompt_text': prompt_text_token, 'prompt_text_len': prompt_text_token_len,
146
+ 'llm_prompt_speech_token': speech_token, 'llm_prompt_speech_token_len': speech_token_len,
147
+ 'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
148
+ 'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
149
+ 'llm_embedding': embedding, 'flow_embedding': embedding}
150
+ return model_input
151
+
152
+ def frontend_cross_lingual(self, tts_text, prompt_speech_16k):
153
+ model_input = self.frontend_zero_shot(tts_text, '', prompt_speech_16k)
154
+ # in cross lingual mode, we remove prompt in llm
155
+ del model_input['prompt_text']
156
+ del model_input['prompt_text_len']
157
+ del model_input['llm_prompt_speech_token']
158
+ del model_input['llm_prompt_speech_token_len']
159
+ return model_input
160
+
161
+ def frontend_instruct(self, tts_text, spk_id, instruct_text):
162
+ model_input = self.frontend_sft(tts_text, spk_id)
163
+ # in instruct mode, we remove spk_embedding in llm due to information leakage
164
+ del model_input['llm_embedding']
165
+ instruct_text_token, instruct_text_token_len = self._extract_text_token(instruct_text + '<endofprompt>')
166
+ model_input['prompt_text'] = instruct_text_token
167
+ model_input['prompt_text_len'] = instruct_text_token_len
168
+ return model_input
cosyvoice/cli/model.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import torch
15
+
16
+ class CosyVoiceModel:
17
+
18
+ def __init__(self,
19
+ llm: torch.nn.Module,
20
+ flow: torch.nn.Module,
21
+ hift: torch.nn.Module):
22
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
23
+ self.llm = llm
24
+ self.flow = flow
25
+ self.hift = hift
26
+
27
+ def load(self, llm_model, flow_model, hift_model):
28
+ self.llm.load_state_dict(torch.load(llm_model, map_location=self.device))
29
+ self.llm.to(self.device).eval()
30
+ self.flow.load_state_dict(torch.load(flow_model, map_location=self.device))
31
+ self.flow.to(self.device).eval()
32
+ self.hift.load_state_dict(torch.load(hift_model, map_location=self.device))
33
+ self.hift.to(self.device).eval()
34
+
35
+ def inference(self, text, text_len, flow_embedding, llm_embedding=torch.zeros(0, 192),
36
+ prompt_text=torch.zeros(1, 0, dtype=torch.int32), prompt_text_len=torch.zeros(1, dtype=torch.int32),
37
+ llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32), llm_prompt_speech_token_len=torch.zeros(1, dtype=torch.int32),
38
+ flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32), flow_prompt_speech_token_len=torch.zeros(1, dtype=torch.int32),
39
+ prompt_speech_feat=torch.zeros(1, 0, 80), prompt_speech_feat_len=torch.zeros(1, dtype=torch.int32)):
40
+ tts_speech_token = self.llm.inference(text=text.to(self.device),
41
+ text_len=text_len.to(self.device),
42
+ prompt_text=prompt_text.to(self.device),
43
+ prompt_text_len=prompt_text_len.to(self.device),
44
+ prompt_speech_token=llm_prompt_speech_token.to(self.device),
45
+ prompt_speech_token_len=llm_prompt_speech_token_len.to(self.device),
46
+ embedding=llm_embedding.to(self.device),
47
+ beam_size=1,
48
+ sampling=25,
49
+ max_token_text_ratio=30,
50
+ min_token_text_ratio=3)
51
+ tts_mel = self.flow.inference(token=tts_speech_token,
52
+ token_len=torch.tensor([tts_speech_token.size(1)], dtype=torch.int32).to(self.device),
53
+ prompt_token=flow_prompt_speech_token.to(self.device),
54
+ prompt_token_len=flow_prompt_speech_token_len.to(self.device),
55
+ prompt_feat=prompt_speech_feat.to(self.device),
56
+ prompt_feat_len=prompt_speech_feat_len.to(self.device),
57
+ embedding=flow_embedding.to(self.device))
58
+ tts_speech = self.hift.inference(mel=tts_mel).cpu()
59
+ torch.cuda.empty_cache()
60
+ return {'tts_speech': tts_speech}
61
+
62
+ def text_to_token(self, text, text_len, flow_embedding, llm_embedding=torch.zeros(0, 192),
63
+ prompt_text=torch.zeros(1, 0, dtype=torch.int32), prompt_text_len=torch.zeros(1, dtype=torch.int32),
64
+ llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32), llm_prompt_speech_token_len=torch.zeros(1, dtype=torch.int32),
65
+ flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32), flow_prompt_speech_token_len=torch.zeros(1, dtype=torch.int32),
66
+ prompt_speech_feat=torch.zeros(1, 0, 80), prompt_speech_feat_len=torch.zeros(1, dtype=torch.int32)):
67
+ tts_speech_token = self.llm.inference(text=text.to(self.device),
68
+ text_len=text_len.to(self.device),
69
+ prompt_text=prompt_text.to(self.device),
70
+ prompt_text_len=prompt_text_len.to(self.device),
71
+ prompt_speech_token=llm_prompt_speech_token.to(self.device),
72
+ prompt_speech_token_len=llm_prompt_speech_token_len.to(self.device),
73
+ embedding=llm_embedding.to(self.device),
74
+ beam_size=1,
75
+ sampling=25,
76
+ max_token_text_ratio=30,
77
+ min_token_text_ratio=3)
78
+ return tts_speech_token
79
+
80
+ def token_to_speech(self, tts_speech_token, flow_embedding, llm_embedding=torch.zeros(0, 192),
81
+ prompt_text=torch.zeros(1, 0, dtype=torch.int32), prompt_text_len=torch.zeros(1, dtype=torch.int32),
82
+ llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32), llm_prompt_speech_token_len=torch.zeros(1, dtype=torch.int32),
83
+ flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32), flow_prompt_speech_token_len=torch.zeros(1, dtype=torch.int32),
84
+ prompt_speech_feat=torch.zeros(1, 0, 80), prompt_speech_feat_len=torch.zeros(1, dtype=torch.int32)):
85
+
86
+ tts_mel = self.flow.inference(token=tts_speech_token,
87
+ token_len=torch.tensor([tts_speech_token.size(1)], dtype=torch.int32).to(self.device),
88
+ prompt_token=flow_prompt_speech_token.to(self.device),
89
+ prompt_token_len=flow_prompt_speech_token_len.to(self.device),
90
+ prompt_feat=prompt_speech_feat.to(self.device),
91
+ prompt_feat_len=prompt_speech_feat_len.to(self.device),
92
+ embedding=flow_embedding.to(self.device))
93
+ tts_speech = self.hift.inference(mel=tts_mel).cpu()
94
+ torch.cuda.empty_cache()
95
+ return {'tts_speech': tts_speech}
cosyvoice/dataset/__init__.py ADDED
File without changes
cosyvoice/dataset/dataset.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
2
+ # 2024 Alibaba Inc (authors: Xiang Lyu)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import random
17
+ import json
18
+ import math
19
+ from functools import partial
20
+
21
+ import torch
22
+ import torch.distributed as dist
23
+ from torch.utils.data import IterableDataset
24
+ from cosyvoice.utils.file_utils import read_lists, read_json_lists
25
+
26
+
27
+ class Processor(IterableDataset):
28
+
29
+ def __init__(self, source, f, *args, **kw):
30
+ assert callable(f)
31
+ self.source = source
32
+ self.f = f
33
+ self.args = args
34
+ self.kw = kw
35
+
36
+ def set_epoch(self, epoch):
37
+ self.source.set_epoch(epoch)
38
+
39
+ def __iter__(self):
40
+ """ Return an iterator over the source dataset processed by the
41
+ given processor.
42
+ """
43
+ assert self.source is not None
44
+ assert callable(self.f)
45
+ return self.f(iter(self.source), *self.args, **self.kw)
46
+
47
+ def apply(self, f):
48
+ assert callable(f)
49
+ return Processor(self, f, *self.args, **self.kw)
50
+
51
+
52
+ class DistributedSampler:
53
+
54
+ def __init__(self, shuffle=True, partition=True):
55
+ self.epoch = -1
56
+ self.update()
57
+ self.shuffle = shuffle
58
+ self.partition = partition
59
+
60
+ def update(self):
61
+ assert dist.is_available()
62
+ if dist.is_initialized():
63
+ self.rank = dist.get_rank()
64
+ self.world_size = dist.get_world_size()
65
+ else:
66
+ self.rank = 0
67
+ self.world_size = 1
68
+ worker_info = torch.utils.data.get_worker_info()
69
+ if worker_info is None:
70
+ self.worker_id = 0
71
+ self.num_workers = 1
72
+ else:
73
+ self.worker_id = worker_info.id
74
+ self.num_workers = worker_info.num_workers
75
+ return dict(rank=self.rank,
76
+ world_size=self.world_size,
77
+ worker_id=self.worker_id,
78
+ num_workers=self.num_workers)
79
+
80
+ def set_epoch(self, epoch):
81
+ self.epoch = epoch
82
+
83
+ def sample(self, data):
84
+ """ Sample data according to rank/world_size/num_workers
85
+
86
+ Args:
87
+ data(List): input data list
88
+
89
+ Returns:
90
+ List: data list after sample
91
+ """
92
+ data = list(range(len(data)))
93
+ # force datalist even
94
+ if self.partition:
95
+ if self.shuffle:
96
+ random.Random(self.epoch).shuffle(data)
97
+ if len(data) < self.world_size:
98
+ data = data * math.ceil(self.world_size / len(data))
99
+ data = data[:self.world_size]
100
+ data = data[self.rank::self.world_size]
101
+ if len(data) < self.num_workers:
102
+ data = data * math.ceil(self.num_workers / len(data))
103
+ data = data[:self.num_workers]
104
+ data = data[self.worker_id::self.num_workers]
105
+ return data
106
+
107
+
108
+ class DataList(IterableDataset):
109
+
110
+ def __init__(self, lists, shuffle=True, partition=True):
111
+ self.lists = lists
112
+ self.sampler = DistributedSampler(shuffle, partition)
113
+
114
+ def set_epoch(self, epoch):
115
+ self.sampler.set_epoch(epoch)
116
+
117
+ def __iter__(self):
118
+ sampler_info = self.sampler.update()
119
+ indexes = self.sampler.sample(self.lists)
120
+ for index in indexes:
121
+ data = dict(src=self.lists[index])
122
+ data.update(sampler_info)
123
+ yield data
124
+
125
+
126
+ def Dataset(data_list_file,
127
+ data_pipeline,
128
+ mode='train',
129
+ shuffle=True,
130
+ partition=True,
131
+ tts_file='',
132
+ prompt_utt2data=''):
133
+ """ Construct dataset from arguments
134
+
135
+ We have two shuffle stage in the Dataset. The first is global
136
+ shuffle at shards tar/raw file level. The second is global shuffle
137
+ at training samples level.
138
+
139
+ Args:
140
+ data_type(str): raw/shard
141
+ tokenizer (BaseTokenizer): tokenizer to tokenize
142
+ partition(bool): whether to do data partition in terms of rank
143
+ """
144
+ assert mode in ['train', 'inference']
145
+ lists = read_lists(data_list_file)
146
+ # import pdb
147
+ # pdb.set_trace()
148
+ if mode == 'inference':
149
+ with open(tts_file) as f:
150
+ tts_data = json.load(f)
151
+ utt2lists = read_json_lists(prompt_utt2data)
152
+ # filter unnecessary file in inference mode
153
+ lists = list(set([utt2lists[utt] for utt in tts_data.keys() if utt2lists[utt] in lists]))
154
+ dataset = DataList(lists,shuffle=shuffle,partition=partition)
155
+ if mode == 'inference':
156
+ # map partial arg tts_data in inference mode
157
+ data_pipeline[0] = partial(data_pipeline[0], tts_data=tts_data)
158
+ for func in data_pipeline:
159
+ dataset = Processor(dataset, func, mode=mode)
160
+ return dataset
cosyvoice/dataset/processor.py ADDED
@@ -0,0 +1,965 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import logging
15
+ import random
16
+ import json
17
+ import tarfile
18
+ import json
19
+ import io
20
+ import pyarrow.parquet as pq
21
+ from io import BytesIO
22
+ import torch
23
+ import torchaudio
24
+ from torch.nn.utils.rnn import pad_sequence
25
+ import torch.nn.functional as F
26
+ import tarfile
27
+ import json
28
+ import io
29
+ import wave
30
+ import numpy as np
31
+ import torchaudio
32
+ import os
33
+ import sys
34
+ import json
35
+ import random
36
+ import pickle
37
+ import argparse
38
+ import itertools
39
+ import mmap
40
+ import struct
41
+ import collections
42
+
43
+
44
+
45
+ import shutil
46
+ import multiprocessing as mp
47
+ from pathlib import Path
48
+
49
+ from tqdm import tqdm
50
+ from collections import defaultdict
51
+ from copy import deepcopy
52
+ from datetime import datetime
53
+ import pickle
54
+
55
+ from wids import wids
56
+ import math
57
+
58
+ torchaudio.set_audio_backend('soundfile')
59
+
60
+ AUDIO_FORMAT_SETS = set(['flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'])
61
+
62
+ try:
63
+ MAIN_SPK_EMBEDDING=torch.load("/workspace/audio_checkpoints/flow_model/spk_embedding/0909/mean_embedding.pt")
64
+ GPT_SPK_EMBEDDING=torch.load("/workspace/audio_checkpoints/flow_model/spk_embedding/0909/spk_mean_embeddings.pt")
65
+ except:
66
+ MAIN_SPK_EMBEDDING=torch.zeros(1,192)
67
+ GPT_SPK_EMBEDDING=torch.zeros(1,192)
68
+
69
+ def parquet_opener(data, mode='train', tts_data={}):
70
+ """ Give url or local file, return file descriptor
71
+ Inplace operation.
72
+
73
+ Args:
74
+ data(Iterable[str]): url or local file list
75
+
76
+ Returns:
77
+ Iterable[{src, stream}]
78
+ """
79
+ for sample in data:
80
+ assert 'src' in sample
81
+ url = sample['src']
82
+ try:
83
+ df = pq.read_table(url).to_pandas()
84
+ for i in range(len(df)):
85
+ if mode == 'inference' and df.loc[i, 'utt'] not in tts_data:
86
+ continue
87
+ sample.update(dict(df.loc[i]))
88
+ if mode == 'train':
89
+ # NOTE do not return sample directly, must initialize a new dict
90
+ yield {**sample}
91
+ else:
92
+ for index, text in enumerate(tts_data[df.loc[i, 'utt']]):
93
+ yield {**sample, 'tts_index': index, 'tts_text': text}
94
+ except Exception as ex:
95
+ logging.warning('Failed to open {}, ex info {}'.format(url, ex))
96
+
97
+
98
+
99
+
100
+ def parse_tar_header(header_bytes):
101
+ header = struct.unpack("!100s8s8s8s12s12s8s1s100s6s2s32s32s8s8s155s", header_bytes)
102
+ return TarHeader(*header)
103
+
104
+ TarHeader = collections.namedtuple(
105
+ "TarHeader",
106
+ [
107
+ "name",
108
+ "mode",
109
+ "uid",
110
+ "gid",
111
+ "size",
112
+ "mtime",
113
+ "chksum",
114
+ "typeflag",
115
+ "linkname",
116
+ "magic",
117
+ "version",
118
+ "uname",
119
+ "gname",
120
+ "devmajor",
121
+ "devminor",
122
+ "prefix",
123
+ ],
124
+ )
125
+
126
+ class MMTar:
127
+ def __init__(self, file_path: Path | str):
128
+ self.stream = open(file_path, "rb")
129
+ self.mmap = mmap.mmap(self.stream.fileno(), 0, access=mmap.ACCESS_READ)
130
+
131
+ def __del__(self):
132
+ try:
133
+ self.mmap.close()
134
+ self.stream.close()
135
+ except: # noqa
136
+ pass
137
+
138
+ def get_at_offset(self, offset) -> tuple[str, bytes]:
139
+ header = parse_tar_header(self.mmap[offset : offset + 500])
140
+ name = header.name.decode("utf-8").strip("\x00")
141
+ start = offset + 512
142
+ end = start + int(header.size.decode("utf-8")[:-1], 8)
143
+ return name, self.mmap[start:end]
144
+
145
+
146
+ class Tar:
147
+ def __init__(self, path: Path):
148
+ self.tar = MMTar(path)
149
+ indices_path = path.with_suffix(".index")
150
+ self.index = pickle.loads(indices_path.read_bytes())
151
+ self.name_mapping = {}
152
+ for name, offset, _ in self.index:
153
+ self.name_mapping[name] = offset
154
+
155
+ def read(self, name: str) -> bytes:
156
+ return self.tar.get_at_offset(self.name_mapping[name])[1]
157
+
158
+ def cosy_jsonl_opener(data, mode='train', tts_data={}):
159
+ """ Give url or local file, return file descriptor
160
+ Inplace operation.
161
+
162
+ Args:
163
+ data(Iterable[str]): url or local file list
164
+
165
+ Returns:
166
+ Iterable[{src, stream}]
167
+ """
168
+ for sample in data:
169
+ assert 'src' in sample
170
+ cosy_jsonl_path = sample['src']
171
+ tar_file_path=cosy_jsonl_path.replace(".vq0907.jsonl",".tar")
172
+ try:
173
+ tar_data=Tar(Path(tar_file_path))
174
+ with open(cosy_jsonl_path, 'r') as f:
175
+ for line in f:
176
+ item=json.loads(line)
177
+ cosy_token = item['cosy_token']
178
+ sample['speech_token']=torch.tensor(cosy_token)
179
+ sample['speech'], sample['sample_rate']= torchaudio.load(io.BytesIO(tar_data.read(item['filename'])))
180
+ # print(item['filename'])
181
+ yield {**sample}
182
+
183
+ except Exception as ex:
184
+ logging.warning('Failed to open {}, ex info {}'.format(cosy_jsonl_path, ex))
185
+
186
+
187
+ def cosy_jsonl_opener_vq0918_nopool(data, mode='train', tts_data={}):
188
+ """ Give url or local file, return file descriptor
189
+ Inplace operation.
190
+
191
+ Args:
192
+ data(Iterable[str]): url or local file list
193
+
194
+ Returns:
195
+ Iterable[{src, stream}]
196
+ """
197
+ for sample in data:
198
+ assert 'src' in sample
199
+ cosy_jsonl_path = sample['src']
200
+ tar_file_path=cosy_jsonl_path.replace(".vq0918-nopool.jsonl",".tar")
201
+
202
+
203
+ try:
204
+ tar_data=Tar(Path(tar_file_path))
205
+ with open(cosy_jsonl_path, 'r') as f:
206
+ # cosy_data = [json.loads(line) for line in f]
207
+ for line in f:
208
+ item=json.loads(line)
209
+ cosy_token = item['cosy_token']
210
+ sample['speech_token']=torch.tensor(cosy_token)
211
+ sample['speech'], sample['sample_rate']= torchaudio.load(io.BytesIO(tar_data.read(item['filename'])))
212
+ # print(item['filename'])
213
+ yield {**sample}
214
+
215
+ except Exception as ex:
216
+ logging.warning('Failed to open {}, ex info {}'.format(cosy_jsonl_path, ex))
217
+
218
+
219
+
220
+ def cosy_jsonl_opener_vq0918_pool2(data, mode='train', tts_data={}):
221
+ """ Give url or local file, return file descriptor
222
+ Inplace operation.
223
+
224
+ Args:
225
+ data(Iterable[str]): url or local file list
226
+
227
+ Returns:
228
+ Iterable[{src, stream}]
229
+ """
230
+ for sample in data:
231
+ assert 'src' in sample
232
+ cosy_jsonl_path = sample['src']
233
+ tar_file_path=cosy_jsonl_path.replace(".vq0918-pool2.jsonl",".tar")
234
+
235
+ try:
236
+ tar_data=Tar(Path(tar_file_path))
237
+ with open(cosy_jsonl_path, 'r') as f:
238
+ for line in f:
239
+ item=json.loads(line)
240
+ cosy_token = item['cosy_token']
241
+ sample['speech_token']=torch.tensor(cosy_token)
242
+ sample['speech'], sample['sample_rate']= torchaudio.load(io.BytesIO(tar_data.read(item['filename'])))
243
+
244
+ yield {**sample}
245
+
246
+ except Exception as ex:
247
+ logging.warning('Failed to open {}, ex info {}'.format(cosy_jsonl_path, ex))
248
+
249
+
250
+ def cosy_jsonl_opener_vq0918_pool4(data, mode='train', tts_data={}):
251
+ """ Give url or local file, return file descriptor
252
+ Inplace operation.
253
+
254
+ Args:
255
+ data(Iterable[str]): url or local file list
256
+
257
+ Returns:
258
+ Iterable[{src, stream}]
259
+ """
260
+ for sample in data:
261
+ assert 'src' in sample
262
+ cosy_jsonl_path = sample['src']
263
+ tar_file_path=cosy_jsonl_path.replace(".vq0918-pool4.jsonl",".tar")
264
+ try:
265
+ tar_data=Tar(Path(tar_file_path))
266
+ with open(cosy_jsonl_path, 'r') as f:
267
+ # cosy_data = [json.loads(line) for line in f]
268
+ for line in f:
269
+ item=json.loads(line)
270
+ cosy_token = item['cosy_token']
271
+ sample['speech_token']=torch.tensor(cosy_token)
272
+ sample['speech'], sample['sample_rate']= torchaudio.load(io.BytesIO(tar_data.read(item['filename'])))
273
+ # print(item['filename'])
274
+ yield {**sample}
275
+
276
+ except Exception as ex:
277
+ logging.warning('Failed to open {}, ex info {}'.format(cosy_jsonl_path, ex))
278
+
279
+
280
+ def cosy_jsonl_opener_vq0918_pool8(data, mode='train', tts_data={}):
281
+ """ Give url or local file, return file descriptor
282
+ Inplace operation.
283
+
284
+ Args:
285
+ data(Iterable[str]): url or local file list
286
+
287
+ Returns:
288
+ Iterable[{src, stream}]
289
+ """
290
+ for sample in data:
291
+ assert 'src' in sample
292
+ cosy_jsonl_path = sample['src']
293
+ tar_file_path=cosy_jsonl_path.replace(".vq0918-pool8.jsonl",".tar")
294
+
295
+ try:
296
+ tar_data=Tar(Path(tar_file_path))
297
+ with open(cosy_jsonl_path, 'r') as f:
298
+ # cosy_data = [json.loads(line) for line in f]
299
+ for line in f:
300
+ item=json.loads(line)
301
+ cosy_token = item['cosy_token']
302
+ sample['speech_token']=torch.tensor(cosy_token)
303
+ sample['speech'], sample['sample_rate']= torchaudio.load(io.BytesIO(tar_data.read(item['filename'])))
304
+ # print(item['filename'])
305
+ yield {**sample}
306
+
307
+ except Exception as ex:
308
+ logging.warning('Failed to open {}, ex info {}'.format(cosy_jsonl_path, ex))
309
+
310
+
311
+
312
+ def process_sft_vq0918_pool4(data, mode='train', tts_data={}):
313
+ for sample in data:
314
+ assert 'src' in sample
315
+
316
+ token_npy_path = sample['src']
317
+ wav_path=token_npy_path.replace(".vq0918-pool4.npy","")
318
+
319
+ # wav_path,token_npy_path=sample['src'].split(' ')
320
+ try:
321
+ sample['speech_token']=torch.tensor(np.load(token_npy_path))
322
+ sample['speech'], sample['sample_rate']= torchaudio.load(wav_path)
323
+ if sample['speech'].shape[0] > 1:
324
+ sample['speech'] = sample['speech'].mean(dim=0, keepdim=True)
325
+ sample['spk_embedding']=torch.zeros_like(MAIN_SPK_EMBEDDING)
326
+ yield {**sample}
327
+ except Exception as ex:
328
+ logging.warning('Failed to open {}, ex info {}'.format(wav_path, ex))
329
+ logging.warning('Failed to open {}'.format(wav_path))
330
+
331
+
332
+ def process_sft_vq0918_pool4_split(data, mode='train',split_token=25, tts_data={}):
333
+ for sample in data:
334
+ assert 'src' in sample
335
+
336
+ token_npy_path = sample['src']
337
+ wav_path=token_npy_path.replace(".vq0918-pool4.npy","")
338
+
339
+ # wav_path,token_npy_path=sample['src'].split(' ')
340
+ try:
341
+ # sample['speech_token']=torch.tensor(np.load(token_npy_path))
342
+ # sample['speech'], sample['sample_rate']= torchaudio.load(wav_path)
343
+ # if sample['speech'].shape[0] > 1:
344
+ # sample['speech'] = sample['speech'].mean(dim=0, keepdim=True)
345
+
346
+
347
+ # sample['spk_embedding']=torch.zeros_like(MAIN_SPK_EMBEDDING)
348
+
349
+
350
+ speech_token=torch.tensor(np.load(token_npy_path))
351
+ speech,sample_rate= torchaudio.load(wav_path)
352
+ # split_speech=int(split_token / 12.5 * sample_rate)
353
+ if speech.shape[0] > 1:
354
+ speech = speech.mean(dim=0, keepdim=True)
355
+
356
+ sample['spk_embedding']=torch.zeros_like(MAIN_SPK_EMBEDDING)
357
+ sample['sample_rate']=sample_rate
358
+
359
+ num_splits = (speech_token.size(0) + split_token - 1) // split_token
360
+
361
+ for split_id in range(num_splits):
362
+ end_token_idx = min((split_id + 1) * split_token, speech_token.size(0))
363
+ end_speech_idx=int(np.ceil(end_token_idx / 12.5 * sample_rate))
364
+ sample['speech_token']=speech_token[:end_token_idx]
365
+ sample['speech']=speech[:,:end_speech_idx]
366
+ print(sample['speech_token'].size(),sample['speech'].size())
367
+ yield {**sample}
368
+ except Exception as ex:
369
+ logging.warning('Failed to open {}, ex info {}'.format(wav_path, ex))
370
+ logging.warning('Failed to open {}'.format(wav_path))
371
+
372
+
373
+ def process_sft_vq0918_pool2(data, mode='train', tts_data={}):
374
+ for sample in data:
375
+ assert 'src' in sample
376
+
377
+ token_npy_path = sample['src'].replace(".vq0918-pool4.npy",".vq0918-pool2.npy")
378
+ wav_path=token_npy_path.replace(".vq0918-pool2.npy","")
379
+
380
+ # wav_path,token_npy_path=sample['src'].split(' ')
381
+ try:
382
+ sample['speech_token']=torch.tensor(np.load(token_npy_path))
383
+ sample['speech'], sample['sample_rate']= torchaudio.load(wav_path)
384
+ if sample['speech'].shape[0] > 1:
385
+ sample['speech'] = sample['speech'].mean(dim=0, keepdim=True)
386
+
387
+ sample['spk_embedding']=torch.zeros_like(MAIN_SPK_EMBEDDING)
388
+ yield {**sample}
389
+ except Exception as ex:
390
+ logging.warning('Failed to open {}, ex info {}'.format(wav_path, ex))
391
+ logging.warning('Failed to open {}'.format(wav_path))
392
+
393
+
394
+ def process_sft_vq0918_pool2_split(data, mode='train',split_token=50, tts_data={}):
395
+ for sample in data:
396
+ assert 'src' in sample
397
+
398
+ token_npy_path = sample['src']
399
+ wav_path=token_npy_path.replace(".vq0918-pool2.npy","")
400
+
401
+ # wav_path,token_npy_path=sample['src'].split(' ')
402
+ try:
403
+ # sample['speech_token']=torch.tensor(np.load(token_npy_path))
404
+ # sample['speech'], sample['sample_rate']= torchaudio.load(wav_path)
405
+ # if sample['speech'].shape[0] > 1:
406
+ # sample['speech'] = sample['speech'].mean(dim=0, keepdim=True)
407
+
408
+
409
+ # sample['spk_embedding']=torch.zeros_like(MAIN_SPK_EMBEDDING)
410
+
411
+
412
+ speech_token=torch.tensor(np.load(token_npy_path))
413
+ speech,sample_rate= torchaudio.load(wav_path)
414
+ # split_speech=int(split_token / 12.5 * sample_rate)
415
+ if speech.shape[0] > 1:
416
+ speech = speech.mean(dim=0, keepdim=True)
417
+
418
+ sample['spk_embedding']=torch.zeros_like(MAIN_SPK_EMBEDDING)
419
+ sample['sample_rate']=sample_rate
420
+
421
+ num_splits = (speech_token.size(0) + split_token - 1) // split_token
422
+
423
+ for split_id in range(num_splits):
424
+ end_token_idx = min((split_id + 1) * split_token, speech_token.size(0))
425
+ end_speech_idx=int(np.ceil(end_token_idx / 25 * sample_rate))
426
+ sample['speech_token']=speech_token[:end_token_idx]
427
+ sample['speech']=speech[:,:end_speech_idx]
428
+ print(sample['speech_token'].size(),sample['speech'].size())
429
+ yield {**sample}
430
+ except Exception as ex:
431
+ logging.warning('Failed to open {}, ex info {}'.format(wav_path, ex))
432
+ logging.warning('Failed to open {}'.format(wav_path))
433
+
434
+ def process_sft_vq0918_pool4_gpt(data, mode='train', tts_data={}):
435
+ for sample in data:
436
+ assert 'src' in sample
437
+ try:
438
+ entry=json.loads(sample['src'])
439
+ sample['spk_embedding']=torch.zeros_like(MAIN_SPK_EMBEDDING)
440
+
441
+ for conv in entry["conversations"]:
442
+ if "response_wav" in conv:
443
+ wav_path = f"/workspace/audio_data/sft/{conv['response_wav']}"
444
+ token_npy_path=wav_path.replace(".wav",".wav.vq0918-pool4.npy")
445
+ sample['speech_token']=torch.tensor(np.load(token_npy_path))
446
+ sample['speech'], sample['sample_rate']= torchaudio.load(wav_path)
447
+ if sample['speech'].shape[0] > 1:
448
+ sample['speech'] = sample['speech'].mean(dim=0, keepdim=True)
449
+ sample['spk_embedding']=spk_embedding
450
+ yield {**sample}
451
+ except Exception as ex:
452
+ # logging.warning('Failed to open {}, ex info {}'.format(wav_path, ex))
453
+ logging.warning('Failed to open {}'.format(wav_path))
454
+
455
+
456
+ def process_sft_vq0918_pool4_gpt_1010(data, mode='train', tts_data={}):
457
+ for sample in data:
458
+ assert 'src' in sample
459
+ try:
460
+ entry=json.loads(sample['src'])
461
+ sample['spk_embedding']=torch.zeros_like(MAIN_SPK_EMBEDDING)
462
+
463
+ for conv in entry["conversations"]:
464
+ if "response_wav" in conv:
465
+ wav_path = f"/workspace/audio_data/sft/{conv['response_wav']}"
466
+ token_npy_path=wav_path.replace(".wav",".wav.vq0918-pool4.npy")
467
+ sample['speech_token']=torch.tensor(np.load(token_npy_path))
468
+ sample['speech'], sample['sample_rate']= torchaudio.load(wav_path)
469
+ if sample['speech'].shape[0] > 1:
470
+ sample['speech'] = sample['speech'].mean(dim=0, keepdim=True)
471
+ sample['spk_embedding']=spk_embedding
472
+ yield {**sample}
473
+ if "prompt_wav" in conv:
474
+ wav_path = f"/workspace/audio_data/sft/{conv['response_wav']}"
475
+ token_npy_path=wav_path.replace(".wav",".wav.vq0918-pool4.npy")
476
+ sample['speech_token']=torch.tensor(np.load(token_npy_path))
477
+ sample['speech'], sample['sample_rate']= torchaudio.load(wav_path)
478
+ if sample['speech'].shape[0] > 1:
479
+ sample['speech'] = sample['speech'].mean(dim=0, keepdim=True)
480
+ sample['spk_embedding']=spk_embedding
481
+ yield {**sample}
482
+ except Exception as ex:
483
+ # logging.warning('Failed to open {}, ex info {}'.format(wav_path, ex))
484
+ logging.warning('Failed to open {}'.format(wav_path))
485
+
486
+
487
+ def filter(data,
488
+ max_length=10240,
489
+ min_length=10,
490
+ token_max_length=200,
491
+ token_min_length=1,
492
+ min_output_input_ratio=0.0005,
493
+ max_output_input_ratio=1,
494
+ mode='train'):
495
+ """ Filter sample according to feature and label length
496
+ Inplace operation.
497
+
498
+ Args::
499
+ data: Iterable[{key, wav, label, sample_rate}]
500
+ max_length: drop utterance which is greater than max_length(10ms)
501
+ min_length: drop utterance which is less than min_length(10ms)
502
+ token_max_length: drop utterance which is greater than
503
+ token_max_length, especially when use char unit for
504
+ english modeling
505
+ token_min_length: drop utterance which is
506
+ less than token_max_length
507
+ min_output_input_ratio: minimal ration of
508
+ token_length / feats_length(10ms)
509
+ max_output_input_ratio: maximum ration of
510
+ token_length / feats_length(10ms)
511
+
512
+ Returns:
513
+ Iterable[{key, wav, label, sample_rate}]
514
+ """
515
+ for sample in data:
516
+ # sample['speech'], sample['sample_rate'] = torchaudio.load(BytesIO(sample['audio_data']))
517
+ # del sample['audio_data']
518
+ # sample['wav'] is torch.Tensor, we have 100 frames every second
519
+ num_frames = sample['speech'].size(1) / sample['sample_rate'] * 100
520
+ if num_frames < min_length:
521
+ continue
522
+ if num_frames > max_length:
523
+ continue
524
+ if len(sample['text_token']) < token_min_length:
525
+ continue
526
+ if len(sample['text_token']) > token_max_length:
527
+ continue
528
+ if len(sample['speech_token']) == 0:
529
+ continue
530
+ if num_frames != 0:
531
+ if len(sample['text_token']) / num_frames < min_output_input_ratio:
532
+ continue
533
+ if len(sample['text_token']) / num_frames > max_output_input_ratio:
534
+ continue
535
+ yield sample
536
+
537
+
538
+ def filter_speech_token(data,
539
+ max_length=10240,
540
+ min_length=10,
541
+ token_max_length=5000,
542
+ token_min_length=1,
543
+ min_output_input_ratio=0.0005,
544
+ max_output_input_ratio=30,
545
+ mode='train'):
546
+ """ Filter sample according to feature and label length
547
+ Inplace operation.
548
+
549
+ Args::
550
+ data: Iterable[{key, wav, label, sample_rate}]
551
+ max_length: drop utterance which is greater than max_length(10ms)
552
+ min_length: drop utterance which is less than min_length(10ms)
553
+ token_max_length: drop utterance which is greater than
554
+ token_max_length, especially when use char unit for
555
+ english modeling
556
+ token_min_length: drop utterance which is
557
+ less than token_max_length
558
+ min_output_input_ratio: minimal ration of
559
+ token_length / feats_length(10ms)
560
+ max_output_input_ratio: maximum ration of
561
+ token_length / feats_length(10ms)
562
+
563
+ Returns:
564
+ Iterable[{key, wav, label, sample_rate}]
565
+ """
566
+ for sample in data:
567
+ # sample['speech'], sample['sample_rate'] = torchaudio.load(BytesIO(sample['audio_data']))
568
+ # del sample['audio_data']
569
+ # sample['wav'] is torch.Tensor, we have 100 frames every second
570
+ num_frames = sample['speech'].size(1) / sample['sample_rate'] * 100
571
+ if num_frames < min_length:
572
+ continue
573
+ if num_frames > max_length:
574
+ continue
575
+ if len(sample['speech_token']) < token_min_length:
576
+ continue
577
+ if len(sample['speech_token']) > token_max_length:
578
+ continue
579
+ if len(sample['speech_token']) == 0:
580
+ continue
581
+ if num_frames != 0:
582
+ if len(sample['speech_token']) / num_frames < min_output_input_ratio:
583
+ continue
584
+ if len(sample['speech_token']) / num_frames > max_output_input_ratio:
585
+ continue
586
+ yield sample
587
+
588
+
589
+ def resample(data, resample_rate=22050, min_sample_rate=16000, mode='train'):
590
+ """ Resample data.
591
+ Inplace operation.
592
+
593
+ Args:
594
+ data: Iterable[{key, wav, label, sample_rate}]
595
+ resample_rate: target resample rate
596
+
597
+ Returns:
598
+ Iterable[{key, wav, label, sample_rate}]
599
+ """
600
+ for sample in data:
601
+ assert 'sample_rate' in sample
602
+ assert 'speech' in sample
603
+ sample_rate = sample['sample_rate']
604
+ waveform = sample['speech']
605
+ if sample_rate != resample_rate:
606
+ if sample_rate < min_sample_rate:
607
+ continue
608
+ sample['sample_rate'] = resample_rate
609
+ sample['speech'] = torchaudio.transforms.Resample(
610
+ orig_freq=sample_rate, new_freq=resample_rate)(waveform)
611
+ max_val = sample['speech'].abs().max()
612
+ if max_val > 1:
613
+ sample['speech'] /= max_val
614
+ yield sample
615
+
616
+
617
+ def compute_fbank(data,
618
+ feat_extractor,
619
+ mode='train'):
620
+ """ Extract fbank
621
+
622
+ Args:
623
+ data: Iterable[{key, wav, label, sample_rate}]
624
+
625
+ Returns:
626
+ Iterable[{key, feat, label}]
627
+ """
628
+ for sample in data:
629
+ assert 'sample_rate' in sample
630
+ assert 'speech' in sample
631
+ # assert 'utt' in sample
632
+ # assert 'text_token' in sample
633
+ waveform = sample['speech']
634
+ mat = feat_extractor(waveform).squeeze(dim=0).transpose(0, 1)
635
+ sample['speech_feat'] = mat
636
+ del sample['speech']
637
+ yield sample
638
+
639
+
640
+ def parse_embedding(data, normalize, mode='train'):
641
+ """ Parse utt_embedding/spk_embedding
642
+
643
+ Args:
644
+ data: Iterable[{key, wav, label, sample_rate}]
645
+
646
+ Returns:
647
+ Iterable[{key, feat, label}]
648
+ """
649
+ for sample in data:
650
+ sample['utt_embedding'] = torch.tensor(sample['utt_embedding'], dtype=torch.float32)
651
+ sample['spk_embedding'] = torch.tensor(sample['spk_embedding'], dtype=torch.float32)
652
+ if normalize:
653
+ sample['utt_embedding'] = F.normalize(sample['utt_embedding'], dim=0)
654
+ sample['spk_embedding'] = F.normalize(sample['spk_embedding'], dim=0)
655
+ yield sample
656
+
657
+
658
+ def tokenize(data, get_tokenizer, allowed_special, mode='train'):
659
+ """ Decode text to chars or BPE
660
+ Inplace operation
661
+
662
+ Args:
663
+ data: Iterable[{key, wav, txt, sample_rate}]
664
+
665
+ Returns:
666
+ Iterable[{key, wav, txt, tokens, label, sample_rate}]
667
+ """
668
+ tokenizer = get_tokenizer()
669
+ for sample in data:
670
+ assert 'text' in sample
671
+ sample['text_token'] = tokenizer.encode(sample['text'], allowed_special=allowed_special)
672
+ if mode == 'inference':
673
+ sample['tts_text_token'] = tokenizer.encode(sample['tts_text'], allowed_special=allowed_special)
674
+ yield sample
675
+
676
+
677
+ def shuffle(data, shuffle_size=10000, mode='train'):
678
+ """ Local shuffle the data
679
+
680
+ Args:
681
+ data: Iterable[{key, feat, label}]
682
+ shuffle_size: buffer size for shuffle
683
+
684
+ Returns:
685
+ Iterable[{key, feat, label}]
686
+ """
687
+ buf = []
688
+ for sample in data:
689
+ buf.append(sample)
690
+ if len(buf) >= shuffle_size:
691
+ random.shuffle(buf)
692
+ for x in buf:
693
+ yield x
694
+ buf = []
695
+ # The sample left over
696
+ random.shuffle(buf)
697
+ for x in buf:
698
+ yield x
699
+
700
+
701
+ def sort(data, sort_size=500, mode='train'):
702
+ """ Sort the data by feature length.
703
+ Sort is used after shuffle and before batch, so we can group
704
+ utts with similar lengths into a batch, and `sort_size` should
705
+ be less than `shuffle_size`
706
+
707
+ Args:
708
+ data: Iterable[{key, feat, label}]
709
+ sort_size: buffer size for sort
710
+
711
+ Returns:
712
+ Iterable[{key, feat, label}]
713
+ """
714
+
715
+ buf = []
716
+ for sample in data:
717
+ buf.append(sample)
718
+ if len(buf) >= sort_size:
719
+ buf.sort(key=lambda x: x['speech_feat'].size(0))
720
+ for x in buf:
721
+ yield x
722
+ buf = []
723
+ # The sample left over
724
+ buf.sort(key=lambda x: x['speech_feat'].size(0))
725
+ for x in buf:
726
+ yield x
727
+
728
+
729
+ def static_batch(data, batch_size=16):
730
+ """ Static batch the data by `batch_size`
731
+
732
+ Args:
733
+ data: Iterable[{key, feat, label}]
734
+ batch_size: batch size
735
+
736
+ Returns:
737
+ Iterable[List[{key, feat, label}]]
738
+ """
739
+ buf = []
740
+ for sample in data:
741
+ buf.append(sample)
742
+ if len(buf) >= batch_size:
743
+ yield buf
744
+ buf = []
745
+ if len(buf) > 0:
746
+ yield buf
747
+
748
+
749
+ def dynamic_batch(data, max_frames_in_batch=12000, mode='train'):
750
+ """ Dynamic batch the data until the total frames in batch
751
+ reach `max_frames_in_batch`
752
+
753
+ Args:
754
+ data: Iterable[{key, feat, label}]
755
+ max_frames_in_batch: max_frames in one batch
756
+
757
+ Returns:
758
+ Iterable[List[{key, feat, label}]]
759
+ """
760
+ buf = []
761
+ longest_frames = 0
762
+ for sample in data:
763
+ assert 'speech_feat' in sample
764
+ assert isinstance(sample['speech_feat'], torch.Tensor)
765
+ new_sample_frames = sample['speech_feat'].size(0)
766
+ longest_frames = max(longest_frames, new_sample_frames)
767
+ frames_after_padding = longest_frames * (len(buf) + 1)
768
+ if frames_after_padding > max_frames_in_batch:
769
+ yield buf
770
+ buf = [sample]
771
+ longest_frames = new_sample_frames
772
+ else:
773
+ buf.append(sample)
774
+ if len(buf) > 0:
775
+ yield buf
776
+
777
+
778
+ def batch(data, batch_type='static', batch_size=16, max_frames_in_batch=12000, mode='train'):
779
+ """ Wrapper for static/dynamic batch
780
+ """
781
+ if mode == 'inference':
782
+ return static_batch(data, 1)
783
+ else:
784
+ if batch_type == 'static':
785
+ return static_batch(data, batch_size)
786
+ elif batch_type == 'dynamic':
787
+ return dynamic_batch(data, max_frames_in_batch)
788
+ else:
789
+ logging.fatal('Unsupported batch type {}'.format(batch_type))
790
+
791
+
792
+ def padding(data, use_spk_embedding, mode='train'):
793
+ """ Padding the data into training data
794
+
795
+ Args:
796
+ data: Iterable[List[{key, feat, label}]]
797
+
798
+ Returns:
799
+ Iterable[Tuple(keys, feats, labels, feats lengths, label lengths)]
800
+ """
801
+ for sample in data:
802
+ assert isinstance(sample, list)
803
+ speech_feat_len = torch.tensor([x['speech_feat'].size(1) for x in sample],
804
+ dtype=torch.int32)
805
+ order = torch.argsort(speech_feat_len, descending=True)
806
+
807
+ utts = [sample[i]['utt'] for i in order]
808
+ speech_token = [torch.tensor(sample[i]['speech_token']) for i in order]
809
+ speech_token_len = torch.tensor([i.size(0) for i in speech_token], dtype=torch.int32)
810
+ speech_token = pad_sequence(speech_token,
811
+ batch_first=True,
812
+ padding_value=0)
813
+ speech_feat = [sample[i]['speech_feat'] for i in order]
814
+ speech_feat_len = torch.tensor([i.size(0) for i in speech_feat], dtype=torch.int32)
815
+ speech_feat = pad_sequence(speech_feat,
816
+ batch_first=True,
817
+ padding_value=0)
818
+ text = [sample[i]['text'] for i in order]
819
+ text_token = [torch.tensor(sample[i]['text_token']) for i in order]
820
+ text_token_len = torch.tensor([i.size(0) for i in text_token], dtype=torch.int32)
821
+ text_token = pad_sequence(text_token, batch_first=True, padding_value=0)
822
+ utt_embedding = torch.stack([sample[i]['utt_embedding'] for i in order], dim=0)
823
+ spk_embedding = torch.stack([sample[i]['spk_embedding'] for i in order], dim=0)
824
+ batch = {
825
+ "utts": utts,
826
+ "speech_token": speech_token,
827
+ "speech_token_len": speech_token_len,
828
+ "speech_feat": speech_feat,
829
+ "speech_feat_len": speech_feat_len,
830
+ "text": text,
831
+ "text_token": text_token,
832
+ "text_token_len": text_token_len,
833
+ "utt_embedding": utt_embedding,
834
+ "spk_embedding": spk_embedding,
835
+ }
836
+ if mode == 'inference':
837
+ tts_text = [sample[i]['tts_text'] for i in order]
838
+ tts_index = [sample[i]['tts_index'] for i in order]
839
+ tts_text_token = [torch.tensor(sample[i]['tts_text_token']) for i in order]
840
+ tts_text_token_len = torch.tensor([i.size(0) for i in tts_text_token], dtype=torch.int32)
841
+ tts_text_token = pad_sequence(tts_text_token, batch_first=True, padding_value=-1)
842
+ batch.update({'tts_text': tts_text,
843
+ 'tts_index': tts_index,
844
+ 'tts_text_token': tts_text_token,
845
+ 'tts_text_token_len': tts_text_token_len})
846
+ if use_spk_embedding is True:
847
+ batch["embedding"] = batch["spk_embedding"]
848
+ else:
849
+ batch["embedding"] = batch["utt_embedding"]
850
+ yield batch
851
+
852
+
853
+
854
+ def padding_speech_token(data, use_spk_embedding, mode='train'):
855
+ """ Padding the data into training data
856
+
857
+ Args:
858
+ data: Iterable[List[{key, feat, label}]]
859
+
860
+ Returns:
861
+ Iterable[Tuple(keys, feats, labels, feats lengths, label lengths)]
862
+ """
863
+ for sample in data:
864
+ assert isinstance(sample, list)
865
+ speech_feat_len = torch.tensor([x['speech_feat'].size(1) for x in sample],
866
+ dtype=torch.int32)
867
+ order = torch.argsort(speech_feat_len, descending=True)
868
+
869
+ # utts = [sample[i]['utt'] for i in order]
870
+ # speech_token = [torch.tensor(sample[i]['speech_token']) for i in order]
871
+ try:
872
+ speech_token = [sample[i]['speech_token'].clone().detach() for i in order]
873
+ speech_token_len = torch.tensor([i.size(0) for i in speech_token], dtype=torch.int32)
874
+ speech_token = pad_sequence(speech_token,
875
+ batch_first=True,
876
+ padding_value=0)
877
+ speech_feat = [sample[i]['speech_feat'] for i in order]
878
+ speech_feat_len = torch.tensor([i.size(0) for i in speech_feat], dtype=torch.int32)
879
+ speech_feat = pad_sequence(speech_feat,
880
+ batch_first=True,
881
+ padding_value=0)
882
+ batch = {
883
+ "speech_token": speech_token,
884
+ "speech_token_len": speech_token_len,
885
+ "speech_feat": speech_feat,
886
+ "speech_feat_len": speech_feat_len,
887
+ }
888
+ if mode == 'inference':
889
+ tts_text = [sample[i]['tts_text'] for i in order]
890
+ tts_index = [sample[i]['tts_index'] for i in order]
891
+ tts_text_token = [torch.tensor(sample[i]['tts_text_token']) for i in order]
892
+ tts_text_token_len = torch.tensor([i.size(0) for i in tts_text_token], dtype=torch.int32)
893
+ tts_text_token = pad_sequence(tts_text_token, batch_first=True, padding_value=-1)
894
+ batch.update({'tts_text': tts_text,
895
+ 'tts_index': tts_index,
896
+ 'tts_text_token': tts_text_token,
897
+ 'tts_text_token_len': tts_text_token_len})
898
+ # if use_spk_embedding is True:
899
+ # batch["embedding"] = batch["spk_embedding"]
900
+ # else:
901
+ # batch["embedding"] = batch["utt_embedding"]
902
+ batch["embedding"]=torch.zeros((batch["speech_feat"].size(0),192),device=batch["speech_feat"].device)
903
+ yield batch
904
+ except Exception as ex:
905
+ logging.warning(' ex info {}'.format(ex))
906
+ # assert False
907
+
908
+
909
+
910
+ def padding_speech_token_spk(data, use_spk_embedding, mode='train'):
911
+ """ Padding the data into training data
912
+
913
+ Args:
914
+ data: Iterable[List[{key, feat, label}]]
915
+
916
+ Returns:
917
+ Iterable[Tuple(keys, feats, labels, feats lengths, label lengths)]
918
+ """
919
+ for sample in data:
920
+ assert isinstance(sample, list)
921
+ speech_feat_len = torch.tensor([x['speech_feat'].size(1) for x in sample],
922
+ dtype=torch.int32)
923
+ order = torch.argsort(speech_feat_len, descending=True)
924
+
925
+ # utts = [sample[i]['utt'] for i in order]
926
+ # speech_token = [torch.tensor(sample[i]['speech_token']) for i in order]
927
+ try:
928
+ speech_token = [sample[i]['speech_token'].clone().detach() for i in order]
929
+ speech_token_len = torch.tensor([i.size(0) for i in speech_token], dtype=torch.int32)
930
+ speech_token = pad_sequence(speech_token,
931
+ batch_first=True,
932
+ padding_value=0)
933
+ speech_feat = [sample[i]['speech_feat'] for i in order]
934
+ speech_feat_len = torch.tensor([i.size(0) for i in speech_feat], dtype=torch.int32)
935
+ speech_feat = pad_sequence(speech_feat,
936
+ batch_first=True,
937
+ padding_value=0)
938
+ spk_embedding = torch.stack([sample[i]['spk_embedding'] for i in order], dim=0)
939
+ batch = {
940
+ "speech_token": speech_token,
941
+ "speech_token_len": speech_token_len,
942
+ "speech_feat": speech_feat,
943
+ "speech_feat_len": speech_feat_len,
944
+ "spk_embedding": spk_embedding,
945
+ }
946
+ if mode == 'inference':
947
+ tts_text = [sample[i]['tts_text'] for i in order]
948
+ tts_index = [sample[i]['tts_index'] for i in order]
949
+ tts_text_token = [torch.tensor(sample[i]['tts_text_token']) for i in order]
950
+ tts_text_token_len = torch.tensor([i.size(0) for i in tts_text_token], dtype=torch.int32)
951
+ tts_text_token = pad_sequence(tts_text_token, batch_first=True, padding_value=-1)
952
+ batch.update({'tts_text': tts_text,
953
+ 'tts_index': tts_index,
954
+ 'tts_text_token': tts_text_token,
955
+ 'tts_text_token_len': tts_text_token_len})
956
+ # if use_spk_embedding is True:
957
+ # batch["embedding"] = batch["spk_embedding"]
958
+ # else:
959
+ # batch["embedding"] = batch["utt_embedding"]
960
+ # batch["embedding"]=torch.zeros((batch["speech_feat"].size(0),192),device=batch["speech_feat"].device)
961
+ batch["embedding"] = batch["spk_embedding"]
962
+ yield batch
963
+ except Exception as ex:
964
+ logging.warning(' ex info {}'.format(ex))
965
+ # assert False
cosyvoice/flow/__pycache__/decoder.cpython-310.pyc ADDED
Binary file (5.11 kB). View file
 
cosyvoice/flow/__pycache__/flow.cpython-310.pyc ADDED
Binary file (4.12 kB). View file
 
cosyvoice/flow/__pycache__/flow_matching.cpython-310.pyc ADDED
Binary file (4.55 kB). View file
 
cosyvoice/flow/__pycache__/length_regulator.cpython-310.pyc ADDED
Binary file (1.45 kB). View file
 
cosyvoice/flow/decoder.py ADDED
@@ -0,0 +1,222 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import torch
15
+ import torch.nn as nn
16
+ from einops import pack, rearrange, repeat
17
+ from matcha.models.components.decoder import SinusoidalPosEmb, Block1D, ResnetBlock1D, Downsample1D, TimestepEmbedding, Upsample1D
18
+ from matcha.models.components.transformer import BasicTransformerBlock
19
+
20
+
21
+ class ConditionalDecoder(nn.Module):
22
+ def __init__(
23
+ self,
24
+ in_channels,
25
+ out_channels,
26
+ channels=(256, 256),
27
+ dropout=0.05,
28
+ attention_head_dim=64,
29
+ n_blocks=1,
30
+ num_mid_blocks=2,
31
+ num_heads=4,
32
+ act_fn="snake",
33
+ ):
34
+ """
35
+ This decoder requires an input with the same shape of the target. So, if your text content
36
+ is shorter or longer than the outputs, please re-sampling it before feeding to the decoder.
37
+ """
38
+ super().__init__()
39
+ channels = tuple(channels)
40
+ self.in_channels = in_channels
41
+ self.out_channels = out_channels
42
+
43
+ self.time_embeddings = SinusoidalPosEmb(in_channels)
44
+ time_embed_dim = channels[0] * 4
45
+ self.time_mlp = TimestepEmbedding(
46
+ in_channels=in_channels,
47
+ time_embed_dim=time_embed_dim,
48
+ act_fn="silu",
49
+ )
50
+ self.down_blocks = nn.ModuleList([])
51
+ self.mid_blocks = nn.ModuleList([])
52
+ self.up_blocks = nn.ModuleList([])
53
+
54
+ output_channel = in_channels
55
+ for i in range(len(channels)): # pylint: disable=consider-using-enumerate
56
+ input_channel = output_channel
57
+ output_channel = channels[i]
58
+ is_last = i == len(channels) - 1
59
+ resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
60
+ transformer_blocks = nn.ModuleList(
61
+ [
62
+ BasicTransformerBlock(
63
+ dim=output_channel,
64
+ num_attention_heads=num_heads,
65
+ attention_head_dim=attention_head_dim,
66
+ dropout=dropout,
67
+ activation_fn=act_fn,
68
+ )
69
+ for _ in range(n_blocks)
70
+ ]
71
+ )
72
+ downsample = (
73
+ Downsample1D(output_channel) if not is_last else nn.Conv1d(output_channel, output_channel, 3, padding=1)
74
+ )
75
+ self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))
76
+
77
+ for i in range(num_mid_blocks):
78
+ input_channel = channels[-1]
79
+ out_channels = channels[-1]
80
+ resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
81
+
82
+ transformer_blocks = nn.ModuleList(
83
+ [
84
+ BasicTransformerBlock(
85
+ dim=output_channel,
86
+ num_attention_heads=num_heads,
87
+ attention_head_dim=attention_head_dim,
88
+ dropout=dropout,
89
+ activation_fn=act_fn,
90
+ )
91
+ for _ in range(n_blocks)
92
+ ]
93
+ )
94
+
95
+ self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))
96
+
97
+ channels = channels[::-1] + (channels[0],)
98
+ for i in range(len(channels) - 1):
99
+ input_channel = channels[i] * 2
100
+ output_channel = channels[i + 1]
101
+ is_last = i == len(channels) - 2
102
+ resnet = ResnetBlock1D(
103
+ dim=input_channel,
104
+ dim_out=output_channel,
105
+ time_emb_dim=time_embed_dim,
106
+ )
107
+ transformer_blocks = nn.ModuleList(
108
+ [
109
+ BasicTransformerBlock(
110
+ dim=output_channel,
111
+ num_attention_heads=num_heads,
112
+ attention_head_dim=attention_head_dim,
113
+ dropout=dropout,
114
+ activation_fn=act_fn,
115
+ )
116
+ for _ in range(n_blocks)
117
+ ]
118
+ )
119
+ upsample = (
120
+ Upsample1D(output_channel, use_conv_transpose=True)
121
+ if not is_last
122
+ else nn.Conv1d(output_channel, output_channel, 3, padding=1)
123
+ )
124
+ self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
125
+ self.final_block = Block1D(channels[-1], channels[-1])
126
+ self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
127
+ self.initialize_weights()
128
+
129
+
130
+ def initialize_weights(self):
131
+ for m in self.modules():
132
+ if isinstance(m, nn.Conv1d):
133
+ nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
134
+ if m.bias is not None:
135
+ nn.init.constant_(m.bias, 0)
136
+ elif isinstance(m, nn.GroupNorm):
137
+ nn.init.constant_(m.weight, 1)
138
+ nn.init.constant_(m.bias, 0)
139
+ elif isinstance(m, nn.Linear):
140
+ nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
141
+ if m.bias is not None:
142
+ nn.init.constant_(m.bias, 0)
143
+
144
+ def forward(self, x, mask, mu, t, spks=None, cond=None):
145
+ """Forward pass of the UNet1DConditional model.
146
+
147
+ Args:
148
+ x (torch.Tensor): shape (batch_size, in_channels, time)
149
+ mask (_type_): shape (batch_size, 1, time)
150
+ t (_type_): shape (batch_size)
151
+ spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None.
152
+ cond (_type_, optional): placeholder for future use. Defaults to None.
153
+
154
+ Raises:
155
+ ValueError: _description_
156
+ ValueError: _description_
157
+
158
+ Returns:
159
+ _type_: _description_
160
+ """
161
+
162
+ t = self.time_embeddings(t)
163
+ t = self.time_mlp(t)
164
+
165
+ x = pack([x, mu], "b * t")[0]
166
+
167
+ if spks is not None:
168
+ spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
169
+ x = pack([x, spks], "b * t")[0]
170
+ if cond is not None:
171
+ x = pack([x, cond], "b * t")[0]
172
+
173
+ hiddens = []
174
+ masks = [mask]
175
+ for resnet, transformer_blocks, downsample in self.down_blocks:
176
+ mask_down = masks[-1]
177
+ x = resnet(x, mask_down, t)
178
+ x = rearrange(x, "b c t -> b t c").contiguous()
179
+ attn_mask = torch.matmul(mask_down.transpose(1, 2).contiguous(), mask_down)
180
+ for transformer_block in transformer_blocks:
181
+ x = transformer_block(
182
+ hidden_states=x,
183
+ attention_mask=attn_mask,
184
+ timestep=t,
185
+ )
186
+ x = rearrange(x, "b t c -> b c t").contiguous()
187
+ hiddens.append(x) # Save hidden states for skip connections
188
+ x = downsample(x * mask_down)
189
+ masks.append(mask_down[:, :, ::2])
190
+ masks = masks[:-1]
191
+ mask_mid = masks[-1]
192
+
193
+ for resnet, transformer_blocks in self.mid_blocks:
194
+ x = resnet(x, mask_mid, t)
195
+ x = rearrange(x, "b c t -> b t c").contiguous()
196
+ attn_mask = torch.matmul(mask_mid.transpose(1, 2).contiguous(), mask_mid)
197
+ for transformer_block in transformer_blocks:
198
+ x = transformer_block(
199
+ hidden_states=x,
200
+ attention_mask=attn_mask,
201
+ timestep=t,
202
+ )
203
+ x = rearrange(x, "b t c -> b c t").contiguous()
204
+
205
+ for resnet, transformer_blocks, upsample in self.up_blocks:
206
+ mask_up = masks.pop()
207
+ skip = hiddens.pop()
208
+ x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
209
+ x = resnet(x, mask_up, t)
210
+ x = rearrange(x, "b c t -> b t c").contiguous()
211
+ attn_mask = torch.matmul(mask_up.transpose(1, 2).contiguous(), mask_up)
212
+ for transformer_block in transformer_blocks:
213
+ x = transformer_block(
214
+ hidden_states=x,
215
+ attention_mask=attn_mask,
216
+ timestep=t,
217
+ )
218
+ x = rearrange(x, "b t c -> b c t").contiguous()
219
+ x = upsample(x * mask_up)
220
+ x = self.final_block(x, mask_up)
221
+ output = self.final_proj(x * mask_up)
222
+ return output * mask
cosyvoice/flow/flow.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import logging
15
+ import random
16
+ from typing import Dict, Optional
17
+ import torch
18
+ import torch.nn as nn
19
+ from torch.nn import functional as F
20
+ from omegaconf import DictConfig
21
+ from cosyvoice.utils.mask import make_pad_mask
22
+
23
+
24
+ class MaskedDiffWithXvec(torch.nn.Module):
25
+ def __init__(self,
26
+ input_size: int = 512,
27
+ output_size: int = 80,
28
+ spk_embed_dim: int = 192,
29
+ output_type: str = "mel",
30
+ vocab_size: int = 4096,
31
+ input_frame_rate: int = 50,
32
+ only_mask_loss: bool = True,
33
+ encoder: torch.nn.Module = None,
34
+ length_regulator: torch.nn.Module = None,
35
+ decoder: torch.nn.Module = None,
36
+ decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1, 'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine', 'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}), 'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64, 'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}},
37
+ mel_feat_conf: Dict = {'n_fft': 1024, 'num_mels': 80, 'sampling_rate': 22050, 'hop_size': 256, 'win_size': 1024, 'fmin': 0, 'fmax': 8000}):
38
+ super().__init__()
39
+ self.input_size = input_size
40
+ self.output_size = output_size
41
+ self.decoder_conf = decoder_conf
42
+ self.mel_feat_conf = mel_feat_conf
43
+ self.vocab_size = vocab_size
44
+ self.output_type = output_type
45
+ self.input_frame_rate = input_frame_rate
46
+ logging.info(f"input frame rate={self.input_frame_rate}")
47
+ self.input_embedding = nn.Embedding(vocab_size, input_size)
48
+ self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
49
+ self.encoder = encoder
50
+ self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size)
51
+ self.decoder = decoder
52
+ self.length_regulator = length_regulator
53
+ self.only_mask_loss = only_mask_loss
54
+
55
+ def forward(
56
+ self,
57
+ batch: dict,
58
+ device: torch.device,
59
+ ) -> Dict[str, Optional[torch.Tensor]]:
60
+ token = batch['speech_token'].to(device)
61
+ token_len = batch['speech_token_len'].to(device)
62
+ feat = batch['speech_feat'].to(device)
63
+ feat_len = batch['speech_feat_len'].to(device)
64
+ embedding = batch['embedding'].to(device)
65
+
66
+ # xvec projection
67
+ embedding = F.normalize(embedding, dim=1)
68
+ embedding = self.spk_embed_affine_layer(embedding)
69
+ # embedding=None
70
+
71
+ # concat text and prompt_text
72
+ mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(device)
73
+ # print(token.max(),self.input_embedding)
74
+ token = self.input_embedding(torch.clamp(token, min=0)) * mask
75
+
76
+
77
+ # text encode
78
+ h, h_lengths = self.encoder(token, token_len)
79
+ h = self.encoder_proj(h)
80
+ h, h_lengths = self.length_regulator(h, feat_len)
81
+
82
+ # get conditions
83
+ conds = torch.zeros(feat.shape, device=token.device)
84
+ for i, j in enumerate(feat_len):
85
+ if random.random() < 0.5:
86
+ continue
87
+ index = random.randint(0, int(0.8 * j))
88
+ conds[i, :index] = feat[i, :index]
89
+ conds = conds.transpose(1, 2)
90
+
91
+ mask = (~make_pad_mask(feat_len)).to(h)
92
+ feat = F.interpolate(feat.unsqueeze(dim=1), size=h.shape[1:], mode="nearest").squeeze(dim=1)
93
+ loss, _ = self.decoder.compute_loss(
94
+ feat.transpose(1, 2).contiguous(),
95
+ mask.unsqueeze(1),
96
+ h.transpose(1, 2).contiguous(),
97
+ embedding,
98
+ cond=conds
99
+ )
100
+ return {'loss': loss}
101
+
102
+ @torch.inference_mode()
103
+ def inference(self,
104
+ token,
105
+ token_len,
106
+ prompt_token,
107
+ prompt_token_len,
108
+ prompt_feat,
109
+ prompt_feat_len,
110
+ embedding):
111
+ assert token.shape[0] == 1
112
+ # xvec projection
113
+ embedding = F.normalize(embedding, dim=1)
114
+ embedding = self.spk_embed_affine_layer(embedding)
115
+
116
+ # concat text and prompt_text
117
+ token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
118
+ mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(embedding)
119
+ token = self.input_embedding(torch.clamp(token, min=0)) * mask
120
+
121
+ # text encode
122
+ h, h_lengths = self.encoder(token, token_len)
123
+ h = self.encoder_proj(h)
124
+ feat_len = (token_len / self.input_frame_rate * 22050 / 256).int()
125
+ h, h_lengths = self.length_regulator(h, feat_len)
126
+
127
+ # get conditions
128
+ conds = torch.zeros([1, feat_len.max().item(), self.output_size], device=token.device)
129
+ if prompt_feat.shape[1] != 0:
130
+ for i, j in enumerate(prompt_feat_len):
131
+ conds[i, :j] = prompt_feat[i]
132
+ conds = conds.transpose(1, 2)
133
+
134
+ mask = (~make_pad_mask(feat_len)).to(h)
135
+ feat = self.decoder(
136
+ mu=h.transpose(1, 2).contiguous(),
137
+ mask=mask.unsqueeze(1),
138
+ spks=embedding,
139
+ cond=conds,
140
+ n_timesteps=10
141
+ )
142
+ if prompt_feat.shape[1] != 0:
143
+ feat = feat[:, :, prompt_feat.shape[1]:]
144
+ return feat
cosyvoice/flow/flow_gradtts.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import logging
15
+ import random
16
+ from typing import Dict, Optional
17
+ import torch
18
+ import torch.nn as nn
19
+ from torch.nn import functional as F
20
+ from omegaconf import DictConfig
21
+ from cosyvoice.utils.mask import make_pad_mask
22
+
23
+
24
+ class MaskedDiffWithXvec(torch.nn.Module):
25
+ def __init__(self,
26
+ input_size: int = 512,
27
+ output_size: int = 80,
28
+ spk_embed_dim: int = 192,
29
+ output_type: str = "mel",
30
+ vocab_size: int = 4096,
31
+ input_frame_rate: int = 50,
32
+ only_mask_loss: bool = True,
33
+ encoder: torch.nn.Module = None,
34
+ length_regulator: torch.nn.Module = None,
35
+ decoder: torch.nn.Module = None,
36
+ decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1, 'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine', 'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}), 'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64, 'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}},
37
+ mel_feat_conf: Dict = {'n_fft': 1024, 'num_mels': 80, 'sampling_rate': 22050, 'hop_size': 256, 'win_size': 1024, 'fmin': 0, 'fmax': 8000}):
38
+ super().__init__()
39
+ self.input_size = input_size
40
+ self.output_size = output_size
41
+ self.decoder_conf = decoder_conf
42
+ self.mel_feat_conf = mel_feat_conf
43
+ self.vocab_size = vocab_size
44
+ self.output_type = output_type
45
+ self.input_frame_rate = input_frame_rate
46
+ logging.info(f"input frame rate={self.input_frame_rate}")
47
+ self.input_embedding = nn.Embedding(vocab_size, input_size)
48
+ self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
49
+ self.encoder = encoder
50
+ self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size)
51
+ self.decoder = decoder
52
+ self.length_regulator = length_regulator
53
+ self.only_mask_loss = only_mask_loss
54
+
55
+ def forward(
56
+ self,
57
+ batch: dict,
58
+ device: torch.device,
59
+ ) -> Dict[str, Optional[torch.Tensor]]:
60
+ token = batch['speech_token'].to(device)
61
+ token_len = batch['speech_token_len'].to(device)
62
+ feat = batch['speech_feat'].to(device)
63
+ feat_len = batch['speech_feat_len'].to(device)
64
+ embedding = batch['embedding'].to(device)
65
+
66
+ # xvec projection
67
+ embedding = F.normalize(embedding, dim=1)
68
+ embedding = self.spk_embed_affine_layer(embedding)
69
+ # embedding=None
70
+
71
+ # concat text and prompt_text
72
+ mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(device)
73
+ token = self.input_embedding(torch.clamp(token, min=0)) * mask
74
+
75
+ # text encode
76
+ h, h_lengths = self.encoder(token, token_len)
77
+ h = self.encoder_proj(h)
78
+ h, h_lengths = self.length_regulator(h, feat_len)
79
+
80
+ # get conditions
81
+ conds = torch.zeros(feat.shape, device=token.device)
82
+ # for i, j in enumerate(feat_len):
83
+ # if random.random() < 0.5:
84
+ # continue
85
+ # index = random.randint(0, int(0.3 * j))
86
+ # conds[i, :index] = feat[i, :index]
87
+ conds = conds.transpose(1, 2)
88
+
89
+ mask = (~make_pad_mask(feat_len)).to(h)
90
+ feat = F.interpolate(feat.unsqueeze(dim=1), size=h.shape[1:], mode="nearest").squeeze(dim=1)
91
+ loss, _ = self.decoder.compute_loss(
92
+ feat.transpose(1, 2).contiguous(),
93
+ mask.unsqueeze(1),
94
+ h.transpose(1, 2).contiguous(),
95
+ embedding,
96
+ cond=conds
97
+ )
98
+ return {'loss': loss}
99
+
100
+ @torch.inference_mode()
101
+ def inference(self,
102
+ token,
103
+ token_len,
104
+ prompt_token,
105
+ prompt_token_len,
106
+ prompt_feat,
107
+ prompt_feat_len,
108
+ embedding):
109
+ assert token.shape[0] == 1
110
+ # xvec projection
111
+ embedding = F.normalize(embedding, dim=1)
112
+ embedding = self.spk_embed_affine_layer(embedding)
113
+
114
+ # concat text and prompt_text
115
+ token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
116
+ mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(embedding)
117
+ token = self.input_embedding(torch.clamp(token, min=0)) * mask
118
+
119
+ # text encode
120
+ h, h_lengths = self.encoder(token, token_len)
121
+ h = self.encoder_proj(h)
122
+ feat_len = (token_len / self.input_frame_rate * 22050 / 256).int()
123
+ h, h_lengths = self.length_regulator(h, feat_len)
124
+
125
+ # get conditions
126
+ conds = torch.zeros([1, feat_len.max().item(), self.output_size], device=token.device)
127
+ if prompt_feat.shape[1] != 0:
128
+ for i, j in enumerate(prompt_feat_len):
129
+ conds[i, :j] = prompt_feat[i]
130
+ conds = conds.transpose(1, 2)
131
+
132
+ mask = (~make_pad_mask(feat_len)).to(h)
133
+ feat = self.decoder(
134
+ mu=h.transpose(1, 2).contiguous(),
135
+ mask=mask.unsqueeze(1),
136
+ spks=embedding,
137
+ cond=conds,
138
+ n_timesteps=10
139
+ )
140
+ if prompt_feat.shape[1] != 0:
141
+ feat = feat[:, :, prompt_feat.shape[1]:]
142
+ return feat
cosyvoice/flow/flow_matching.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import torch
15
+ import torch.nn.functional as F
16
+ from matcha.models.components.flow_matching import BASECFM
17
+
18
+ class ConditionalCFM(BASECFM):
19
+ def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None):
20
+ super().__init__(
21
+ n_feats=in_channels,
22
+ cfm_params=cfm_params,
23
+ n_spks=n_spks,
24
+ spk_emb_dim=spk_emb_dim,
25
+ )
26
+ self.t_scheduler = cfm_params.t_scheduler
27
+ self.training_cfg_rate = cfm_params.training_cfg_rate
28
+ self.inference_cfg_rate = cfm_params.inference_cfg_rate
29
+ in_channels = in_channels + (spk_emb_dim if n_spks > 0 else 0)
30
+ # Just change the architecture of the estimator here
31
+ self.estimator = estimator
32
+
33
+ @torch.inference_mode()
34
+ def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
35
+ """Forward diffusion
36
+
37
+ Args:
38
+ mu (torch.Tensor): output of encoder
39
+ shape: (batch_size, n_feats, mel_timesteps)
40
+ mask (torch.Tensor): output_mask
41
+ shape: (batch_size, 1, mel_timesteps)
42
+ n_timesteps (int): number of diffusion steps
43
+ temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
44
+ spks (torch.Tensor, optional): speaker ids. Defaults to None.
45
+ shape: (batch_size, spk_emb_dim)
46
+ cond: Not used but kept for future purposes
47
+
48
+ Returns:
49
+ sample: generated mel-spectrogram
50
+ shape: (batch_size, n_feats, mel_timesteps)
51
+ """
52
+ torch.manual_seed(42)
53
+
54
+ z = torch.randn_like(mu) * temperature
55
+
56
+ t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
57
+ if self.t_scheduler == 'cosine':
58
+ t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
59
+ return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond)
60
+
61
+ def solve_euler(self, x, t_span, mu, mask, spks, cond):
62
+ """
63
+ Fixed euler solver for ODEs.
64
+ Args:
65
+ x (torch.Tensor): random noise
66
+ t_span (torch.Tensor): n_timesteps interpolated
67
+ shape: (n_timesteps + 1,)
68
+ mu (torch.Tensor): output of encoder
69
+ shape: (batch_size, n_feats, mel_timesteps)
70
+ mask (torch.Tensor): output_mask
71
+ shape: (batch_size, 1, mel_timesteps)
72
+ spks (torch.Tensor, optional): speaker ids. Defaults to None.
73
+ shape: (batch_size, spk_emb_dim)
74
+ cond: Not used but kept for future purposes
75
+ """
76
+ t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
77
+
78
+ # I am storing this because I can later plot it by putting a debugger here and saving it to a file
79
+ # Or in future might add like a return_all_steps flag
80
+ sol = []
81
+
82
+ for step in range(1, len(t_span)):
83
+ dphi_dt = self.estimator(x, mask, mu, t, spks, cond)
84
+ # Classifier-Free Guidance inference introduced in VoiceBox
85
+ if self.inference_cfg_rate > 0:
86
+ cfg_dphi_dt = self.estimator(
87
+ x, mask,
88
+ torch.zeros_like(mu), t,
89
+ torch.zeros_like(spks) if spks is not None else None,
90
+ torch.zeros_like(cond)
91
+ )
92
+ dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt -
93
+ self.inference_cfg_rate * cfg_dphi_dt)
94
+ x = x + dt * dphi_dt
95
+ t = t + dt
96
+
97
+ sol.append(x)
98
+ if step < len(t_span) - 1:
99
+ dt = t_span[step + 1] - t
100
+
101
+ return sol[-1]
102
+
103
+ def compute_loss(self, x1, mask, mu, spks=None, cond=None):
104
+ """Computes diffusion loss
105
+
106
+ Args:
107
+ x1 (torch.Tensor): Target
108
+ shape: (batch_size, n_feats, mel_timesteps)
109
+ mask (torch.Tensor): target mask
110
+ shape: (batch_size, 1, mel_timesteps)
111
+ mu (torch.Tensor): output of encoder
112
+ shape: (batch_size, n_feats, mel_timesteps)
113
+ spks (torch.Tensor, optional): speaker embedding. Defaults to None.
114
+ shape: (batch_size, spk_emb_dim)
115
+
116
+ Returns:
117
+ loss: conditional flow matching loss
118
+ y: conditional flow
119
+ shape: (batch_size, n_feats, mel_timesteps)
120
+ """
121
+ b, _, t = mu.shape
122
+
123
+ # random timestep
124
+ t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
125
+ if self.t_scheduler == 'cosine':
126
+ t = 1 - torch.cos(t * 0.5 * torch.pi)
127
+ # sample noise p(x_0)
128
+ z = torch.randn_like(x1)
129
+
130
+ y = (1 - (1 - self.sigma_min) * t) * z + t * x1
131
+ u = x1 - (1 - self.sigma_min) * z
132
+
133
+ # during training, we randomly drop condition to trade off mode coverage and sample fidelity
134
+ if self.training_cfg_rate > 0:
135
+ cfg_mask = torch.rand(b, device=x1.device) > self.training_cfg_rate
136
+ mu = mu * cfg_mask.view(-1, 1, 1)
137
+ spks = spks * cfg_mask.view(-1, 1)
138
+ cond = cond * cfg_mask.view(-1, 1, 1)
139
+
140
+ pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond)
141
+ loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1])
142
+ return loss, y
cosyvoice/flow/flow_matching_dit.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import pdb
15
+
16
+ import torch
17
+ import torch.nn.functional as F
18
+ from matcha.models.components.flow_matching import BASECFM
19
+
20
+
21
+ class ConditionalCFM(BASECFM):
22
+ def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None):
23
+ super().__init__(
24
+ n_feats=in_channels,
25
+ cfm_params=cfm_params,
26
+ n_spks=n_spks,
27
+ spk_emb_dim=spk_emb_dim,
28
+ )
29
+ self.t_scheduler = cfm_params.t_scheduler
30
+ self.training_cfg_rate = cfm_params.training_cfg_rate
31
+ self.inference_cfg_rate = cfm_params.inference_cfg_rate
32
+ in_channels = in_channels + (spk_emb_dim if n_spks > 0 else 0)
33
+ # Just change the architecture of the estimator here
34
+
35
+ io_channels = 80
36
+ input_concat_dim = 80
37
+ embed_dim = 768
38
+ depth = 24
39
+ num_heads = 24
40
+ project_cond_tokens = False
41
+ transformer_type = "continuous_transformer"
42
+ self.estimator = estimator
43
+
44
+ @torch.inference_mode()
45
+ def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
46
+ """Forward diffusion
47
+
48
+ Args:
49
+ mu (torch.Tensor): output of encoder
50
+ shape: (batch_size, n_feats, mel_timesteps)
51
+ mask (torch.Tensor): output_mask
52
+ shape: (batch_size, 1, mel_timesteps)
53
+ n_timesteps (int): number of diffusion steps
54
+ temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
55
+ spks (torch.Tensor, optional): speaker ids. Defaults to None.
56
+ shape: (batch_size, spk_emb_dim)
57
+ cond: Not used but kept for future purposes
58
+
59
+ Returns:
60
+ sample: generated mel-spectrogram
61
+ shape: (batch_size, n_feats, mel_timesteps)
62
+ """
63
+ z = torch.randn_like(mu) * temperature
64
+ t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
65
+ if self.t_scheduler == 'cosine':
66
+ t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
67
+ return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond)
68
+
69
+ def solve_euler(self, x, t_span, mu, mask, spks, cond):
70
+ """
71
+ Fixed euler solver for ODEs.
72
+ Args:
73
+ x (torch.Tensor): random noise torch.Size([1, 80, 621])
74
+ t_span (torch.Tensor): n_timesteps interpolated
75
+ shape: (n_timesteps + 1,)
76
+ mu (torch.Tensor): output of encoder
77
+ shape: (batch_size, n_feats, mel_timesteps)
78
+ mask (torch.Tensor): output_mask
79
+ shape: (batch_size, 1, mel_timesteps)
80
+ spks (torch.Tensor, optional): speaker ids. Defaults to None.
81
+ shape: (batch_size, spk_emb_dim)
82
+ cond: Not used but kept for future purposes
83
+ """
84
+ t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
85
+
86
+ # I am storing this because I can later plot it by putting a debugger here and saving it to a file
87
+ # Or in future might add like a return_all_steps flag
88
+ sol = []
89
+
90
+ cfg_dropout_prob = 0.1
91
+ cfg_scale = 1.0
92
+
93
+ # cfg_dropout_prob = 0.0
94
+ # cfg_scale = 3.0
95
+
96
+ for step in range(1, len(t_span)):
97
+ # dphi_dt = self.estimator(x, mask, mu, t, spks, cond)
98
+ # pdb.set_trace()
99
+ dphi_dt = self.estimator(x, # [bs, 80, 229]
100
+ t[None], # (bs,)
101
+ global_embed=spks,
102
+ input_concat_cond=mu,
103
+ mask=mask[0], # [bs, 229]
104
+ cfg_dropout_prob=cfg_dropout_prob, cfg_scale=cfg_scale)
105
+
106
+ # Classifier-Free Guidance inference introduced in VoiceBox
107
+ if self.inference_cfg_rate > 0:
108
+ # cfg_dphi_dt = self.estimator(
109
+ # x, mask,
110
+ # torch.zeros_like(mu), t,
111
+ # torch.zeros_like(spks) if spks is not None else None,
112
+ # torch.zeros_like(cond)
113
+ # )
114
+ cfg_dphi_dt = self.estimator(x, # [bs, 80, 229]
115
+ t[None], # (bs,)
116
+ global_embed=torch.zeros_like(spks) if spks is not None else None,
117
+ input_concat_cond=torch.zeros_like(mu),
118
+ mask=mask[0], # [bs, 229]
119
+ cfg_dropout_prob=cfg_dropout_prob, cfg_scale=cfg_scale)
120
+
121
+ dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt -
122
+ self.inference_cfg_rate * cfg_dphi_dt)
123
+ x = x + dt * dphi_dt
124
+ t = t + dt
125
+ sol.append(x)
126
+ if step < len(t_span) - 1:
127
+ dt = t_span[step + 1] - t
128
+
129
+ return sol[-1]
130
+
131
+ def compute_loss(self, x1, mask, mu, spks=None, cond=None):
132
+ """Computes diffusion loss
133
+
134
+ Args:
135
+ x1 (torch.Tensor): Target
136
+ shape: (batch_size, n_feats, mel_timesteps)
137
+ mask (torch.Tensor): target mask
138
+ shape: (batch_size, 1, mel_timesteps)
139
+ mu (torch.Tensor): output of encoder
140
+ shape: (batch_size, n_feats, mel_timesteps)
141
+ spks (torch.Tensor, optional): speaker embedding. Defaults to None.
142
+ shape: (batch_size, spk_emb_dim)
143
+
144
+ Returns:
145
+ loss: conditional flow matching loss
146
+ y: conditional flow
147
+ shape: (batch_size, n_feats, mel_timesteps)
148
+ """
149
+ b, _, t = mu.shape
150
+
151
+ # random timestep
152
+ t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
153
+ if self.t_scheduler == 'cosine':
154
+ t = 1 - torch.cos(t * 0.5 * torch.pi)
155
+ # sample noise p(x_0)
156
+ z = torch.randn_like(x1)
157
+
158
+ y = (1 - (1 - self.sigma_min) * t) * z + t * x1
159
+ u = x1 - (1 - self.sigma_min) * z
160
+
161
+ # during training, we randomly drop condition to trade off mode coverage and sample fidelity
162
+ if self.training_cfg_rate > 0:
163
+ cfg_mask = torch.rand(b, device=x1.device) > self.training_cfg_rate
164
+ mu = mu * cfg_mask.view(-1, 1, 1)
165
+ spks = spks * cfg_mask.view(-1, 1)
166
+ cond = cond * cfg_mask.view(-1, 1, 1)
167
+
168
+ # pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond)
169
+ pred = self.estimator(y, # [bs, 80, 229]
170
+ t.squeeze(1, 2), # (bs,)
171
+ global_embed=spks,
172
+ input_concat_cond=mu,
173
+ mask=mask.squeeze(1), # [bs, 229]
174
+ cfg_dropout_prob=0.1)
175
+
176
+ loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1])
177
+ return loss, y
178
+
179
+ # def estimator_trans(self):
180
+ # pass
cosyvoice/flow/length_regulator.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Tuple
15
+ import torch.nn as nn
16
+ from torch.nn import functional as F
17
+ from cosyvoice.utils.mask import make_pad_mask
18
+
19
+
20
+ class InterpolateRegulator(nn.Module):
21
+ def __init__(
22
+ self,
23
+ channels: int,
24
+ sampling_ratios: Tuple,
25
+ out_channels: int = None,
26
+ groups: int = 1,
27
+ ):
28
+ super().__init__()
29
+ self.sampling_ratios = sampling_ratios
30
+ out_channels = out_channels or channels
31
+ model = nn.ModuleList([])
32
+ if len(sampling_ratios) > 0:
33
+ for _ in sampling_ratios:
34
+ module = nn.Conv1d(channels, channels, 3, 1, 1)
35
+ norm = nn.GroupNorm(groups, channels)
36
+ act = nn.Mish()
37
+ model.extend([module, norm, act])
38
+ model.append(
39
+ nn.Conv1d(channels, out_channels, 1, 1)
40
+ )
41
+ self.model = nn.Sequential(*model)
42
+
43
+ def forward(self, x, ylens=None):
44
+ # x in (B, T, D)
45
+ mask = (~make_pad_mask(ylens)).to(x).unsqueeze(-1)
46
+ x = F.interpolate(x.transpose(1, 2).contiguous(), size=ylens.max(), mode='nearest')
47
+ out = self.model(x).transpose(1, 2).contiguous()
48
+ olens = ylens
49
+ return out * mask, olens
cosyvoice/flow/stable/adp.py ADDED
@@ -0,0 +1,1591 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copied and modified from https://github.com/archinetai/audio-diffusion-pytorch/blob/v0.0.94/audio_diffusion_pytorch/modules.py under MIT License
2
+ # License can be found in LICENSES/LICENSE_ADP.txt
3
+
4
+ import math
5
+ from inspect import isfunction
6
+ from math import ceil, floor, log, pi, log2
7
+ from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union
8
+ from packaging import version
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ from einops import rearrange, reduce, repeat
13
+ from einops.layers.torch import Rearrange
14
+ from einops_exts import rearrange_many
15
+ from torch import Tensor, einsum
16
+ from torch.backends.cuda import sdp_kernel
17
+ from torch.nn import functional as F
18
+ from dac.nn.layers import Snake1d
19
+ import pdb
20
+ """
21
+ Utils
22
+ """
23
+
24
+
25
+ class ConditionedSequential(nn.Module):
26
+ def __init__(self, *modules):
27
+ super().__init__()
28
+ self.module_list = nn.ModuleList(*modules)
29
+
30
+ def forward(self, x: Tensor, mapping: Optional[Tensor] = None):
31
+ for module in self.module_list:
32
+ x = module(x, mapping)
33
+ return x
34
+
35
+ T = TypeVar("T")
36
+
37
+ def default(val: Optional[T], d: Union[Callable[..., T], T]) -> T:
38
+ if exists(val):
39
+ return val
40
+ return d() if isfunction(d) else d
41
+
42
+ def exists(val: Optional[T]) -> T:
43
+ return val is not None
44
+
45
+ def closest_power_2(x: float) -> int:
46
+ exponent = log2(x)
47
+ distance_fn = lambda z: abs(x - 2 ** z) # noqa
48
+ exponent_closest = min((floor(exponent), ceil(exponent)), key=distance_fn)
49
+ return 2 ** int(exponent_closest)
50
+
51
+ def group_dict_by_prefix(prefix: str, d: Dict) -> Tuple[Dict, Dict]:
52
+ return_dicts: Tuple[Dict, Dict] = ({}, {})
53
+ for key in d.keys():
54
+ no_prefix = int(not key.startswith(prefix))
55
+ return_dicts[no_prefix][key] = d[key]
56
+ return return_dicts
57
+
58
+ def groupby(prefix: str, d: Dict, keep_prefix: bool = False) -> Tuple[Dict, Dict]:
59
+ kwargs_with_prefix, kwargs = group_dict_by_prefix(prefix, d)
60
+ if keep_prefix:
61
+ return kwargs_with_prefix, kwargs
62
+ kwargs_no_prefix = {k[len(prefix) :]: v for k, v in kwargs_with_prefix.items()}
63
+ return kwargs_no_prefix, kwargs
64
+
65
+ """
66
+ Convolutional Blocks
67
+ """
68
+ import typing as tp
69
+
70
+ # Copied from https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/modules/conv.py under MIT License
71
+ # License available in LICENSES/LICENSE_META.txt
72
+
73
+ def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int,
74
+ padding_total: int = 0) -> int:
75
+ """See `pad_for_conv1d`."""
76
+ length = x.shape[-1]
77
+ n_frames = (length - kernel_size + padding_total) / stride + 1
78
+ ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
79
+ return ideal_length - length
80
+
81
+
82
+ def pad_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0):
83
+ """Pad for a convolution to make sure that the last window is full.
84
+ Extra padding is added at the end. This is required to ensure that we can rebuild
85
+ an output of the same length, as otherwise, even with padding, some time steps
86
+ might get removed.
87
+ For instance, with total padding = 4, kernel size = 4, stride = 2:
88
+ 0 0 1 2 3 4 5 0 0 # (0s are padding)
89
+ 1 2 3 # (output frames of a convolution, last 0 is never used)
90
+ 0 0 1 2 3 4 5 0 # (output of tr. conv., but pos. 5 is going to get removed as padding)
91
+ 1 2 3 4 # once you removed padding, we are missing one time step !
92
+ """
93
+ extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
94
+ return F.pad(x, (0, extra_padding))
95
+
96
+
97
+ def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'constant', value: float = 0.):
98
+ """Tiny wrapper around F.pad, just to allow for reflect padding on small input.
99
+ If this is the case, we insert extra 0 padding to the right before the reflection happen.
100
+ """
101
+ length = x.shape[-1]
102
+ padding_left, padding_right = paddings
103
+ assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
104
+ if mode == 'reflect':
105
+ max_pad = max(padding_left, padding_right)
106
+ extra_pad = 0
107
+ if length <= max_pad:
108
+ extra_pad = max_pad - length + 1
109
+ x = F.pad(x, (0, extra_pad))
110
+ padded = F.pad(x, paddings, mode, value)
111
+ end = padded.shape[-1] - extra_pad
112
+ return padded[..., :end]
113
+ else:
114
+ return F.pad(x, paddings, mode, value)
115
+
116
+
117
+ def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
118
+ """Remove padding from x, handling properly zero padding. Only for 1d!"""
119
+ padding_left, padding_right = paddings
120
+ assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
121
+ assert (padding_left + padding_right) <= x.shape[-1]
122
+ end = x.shape[-1] - padding_right
123
+ return x[..., padding_left: end]
124
+
125
+
126
+ class Conv1d(nn.Conv1d):
127
+ def __init__(self, *args, **kwargs):
128
+ super().__init__(*args, **kwargs)
129
+
130
+ def forward(self, x: Tensor, causal=False) -> Tensor:
131
+ kernel_size = self.kernel_size[0]
132
+ stride = self.stride[0]
133
+ dilation = self.dilation[0]
134
+ kernel_size = (kernel_size - 1) * dilation + 1 # effective kernel size with dilations
135
+ padding_total = kernel_size - stride
136
+ extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
137
+ if causal:
138
+ # Left padding for causal
139
+ x = pad1d(x, (padding_total, extra_padding))
140
+ else:
141
+ # Asymmetric padding required for odd strides
142
+ padding_right = padding_total // 2
143
+ padding_left = padding_total - padding_right
144
+ x = pad1d(x, (padding_left, padding_right + extra_padding))
145
+ return super().forward(x)
146
+
147
+ class ConvTranspose1d(nn.ConvTranspose1d):
148
+ def __init__(self, *args, **kwargs):
149
+ super().__init__(*args, **kwargs)
150
+
151
+ def forward(self, x: Tensor, causal=False) -> Tensor:
152
+ kernel_size = self.kernel_size[0]
153
+ stride = self.stride[0]
154
+ padding_total = kernel_size - stride
155
+
156
+ y = super().forward(x)
157
+
158
+ # We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be
159
+ # removed at the very end, when keeping only the right length for the output,
160
+ # as removing it here would require also passing the length at the matching layer
161
+ # in the encoder.
162
+ if causal:
163
+ padding_right = ceil(padding_total)
164
+ padding_left = padding_total - padding_right
165
+ y = unpad1d(y, (padding_left, padding_right))
166
+ else:
167
+ # Asymmetric padding required for odd strides
168
+ padding_right = padding_total // 2
169
+ padding_left = padding_total - padding_right
170
+ y = unpad1d(y, (padding_left, padding_right))
171
+ return y
172
+
173
+
174
+ def Downsample1d(
175
+ in_channels: int, out_channels: int, factor: int, kernel_multiplier: int = 2
176
+ ) -> nn.Module:
177
+ assert kernel_multiplier % 2 == 0, "Kernel multiplier must be even"
178
+
179
+ return Conv1d(
180
+ in_channels=in_channels,
181
+ out_channels=out_channels,
182
+ kernel_size=factor * kernel_multiplier + 1,
183
+ stride=factor
184
+ )
185
+
186
+
187
+ def Upsample1d(
188
+ in_channels: int, out_channels: int, factor: int, use_nearest: bool = False
189
+ ) -> nn.Module:
190
+
191
+ if factor == 1:
192
+ return Conv1d(
193
+ in_channels=in_channels, out_channels=out_channels, kernel_size=3
194
+ )
195
+
196
+ if use_nearest:
197
+ return nn.Sequential(
198
+ nn.Upsample(scale_factor=factor, mode="nearest"),
199
+ Conv1d(
200
+ in_channels=in_channels,
201
+ out_channels=out_channels,
202
+ kernel_size=3
203
+ ),
204
+ )
205
+ else:
206
+ return ConvTranspose1d(
207
+ in_channels=in_channels,
208
+ out_channels=out_channels,
209
+ kernel_size=factor * 2,
210
+ stride=factor
211
+ )
212
+
213
+
214
+ class ConvBlock1d(nn.Module):
215
+ def __init__(
216
+ self,
217
+ in_channels: int,
218
+ out_channels: int,
219
+ *,
220
+ kernel_size: int = 3,
221
+ stride: int = 1,
222
+ dilation: int = 1,
223
+ num_groups: int = 8,
224
+ use_norm: bool = True,
225
+ use_snake: bool = False
226
+ ) -> None:
227
+ super().__init__()
228
+
229
+ self.groupnorm = (
230
+ nn.GroupNorm(num_groups=num_groups, num_channels=in_channels)
231
+ if use_norm
232
+ else nn.Identity()
233
+ )
234
+
235
+ if use_snake:
236
+ self.activation = Snake1d(in_channels)
237
+ else:
238
+ self.activation = nn.SiLU()
239
+
240
+ self.project = Conv1d(
241
+ in_channels=in_channels,
242
+ out_channels=out_channels,
243
+ kernel_size=kernel_size,
244
+ stride=stride,
245
+ dilation=dilation,
246
+ )
247
+
248
+ def forward(
249
+ self, x: Tensor, scale_shift: Optional[Tuple[Tensor, Tensor]] = None, causal=False
250
+ ) -> Tensor:
251
+ x = self.groupnorm(x)
252
+ if exists(scale_shift):
253
+ scale, shift = scale_shift
254
+ x = x * (scale + 1) + shift
255
+ x = self.activation(x)
256
+ return self.project(x, causal=causal)
257
+
258
+
259
+ class MappingToScaleShift(nn.Module):
260
+ def __init__(
261
+ self,
262
+ features: int,
263
+ channels: int,
264
+ ):
265
+ super().__init__()
266
+
267
+ self.to_scale_shift = nn.Sequential(
268
+ nn.SiLU(),
269
+ nn.Linear(in_features=features, out_features=channels * 2),
270
+ )
271
+
272
+ def forward(self, mapping: Tensor) -> Tuple[Tensor, Tensor]:
273
+ scale_shift = self.to_scale_shift(mapping)
274
+ scale_shift = rearrange(scale_shift, "b c -> b c 1")
275
+ scale, shift = scale_shift.chunk(2, dim=1)
276
+ return scale, shift
277
+
278
+
279
+ class ResnetBlock1d(nn.Module):
280
+ def __init__(
281
+ self,
282
+ in_channels: int,
283
+ out_channels: int,
284
+ *,
285
+ kernel_size: int = 3,
286
+ stride: int = 1,
287
+ dilation: int = 1,
288
+ use_norm: bool = True,
289
+ use_snake: bool = False,
290
+ num_groups: int = 8,
291
+ context_mapping_features: Optional[int] = None,
292
+ ) -> None:
293
+ super().__init__()
294
+
295
+ self.use_mapping = exists(context_mapping_features)
296
+
297
+ self.block1 = ConvBlock1d(
298
+ in_channels=in_channels,
299
+ out_channels=out_channels,
300
+ kernel_size=kernel_size,
301
+ stride=stride,
302
+ dilation=dilation,
303
+ use_norm=use_norm,
304
+ num_groups=num_groups,
305
+ use_snake=use_snake
306
+ )
307
+
308
+ if self.use_mapping:
309
+ assert exists(context_mapping_features)
310
+ self.to_scale_shift = MappingToScaleShift(
311
+ features=context_mapping_features, channels=out_channels
312
+ )
313
+
314
+ self.block2 = ConvBlock1d(
315
+ in_channels=out_channels,
316
+ out_channels=out_channels,
317
+ use_norm=use_norm,
318
+ num_groups=num_groups,
319
+ use_snake=use_snake
320
+ )
321
+
322
+ self.to_out = (
323
+ Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=1)
324
+ if in_channels != out_channels
325
+ else nn.Identity()
326
+ )
327
+
328
+ def forward(self, x: Tensor, mapping: Optional[Tensor] = None, causal=False) -> Tensor:
329
+ assert_message = "context mapping required if context_mapping_features > 0"
330
+ assert not (self.use_mapping ^ exists(mapping)), assert_message
331
+
332
+ h = self.block1(x, causal=causal)
333
+
334
+ scale_shift = None
335
+ if self.use_mapping:
336
+ scale_shift = self.to_scale_shift(mapping)
337
+
338
+ h = self.block2(h, scale_shift=scale_shift, causal=causal)
339
+
340
+ return h + self.to_out(x)
341
+
342
+
343
+ class Patcher(nn.Module):
344
+ def __init__(
345
+ self,
346
+ in_channels: int,
347
+ out_channels: int,
348
+ patch_size: int,
349
+ context_mapping_features: Optional[int] = None,
350
+ use_snake: bool = False,
351
+ ):
352
+ super().__init__()
353
+ assert_message = f"out_channels must be divisible by patch_size ({patch_size})"
354
+ assert out_channels % patch_size == 0, assert_message
355
+ self.patch_size = patch_size
356
+
357
+ self.block = ResnetBlock1d(
358
+ in_channels=in_channels,
359
+ out_channels=out_channels // patch_size,
360
+ num_groups=1,
361
+ context_mapping_features=context_mapping_features,
362
+ use_snake=use_snake
363
+ )
364
+
365
+ def forward(self, x: Tensor, mapping: Optional[Tensor] = None, causal=False) -> Tensor:
366
+ x = self.block(x, mapping, causal=causal)
367
+ x = rearrange(x, "b c (l p) -> b (c p) l", p=self.patch_size)
368
+ return x
369
+
370
+
371
+ class Unpatcher(nn.Module):
372
+ def __init__(
373
+ self,
374
+ in_channels: int,
375
+ out_channels: int,
376
+ patch_size: int,
377
+ context_mapping_features: Optional[int] = None,
378
+ use_snake: bool = False
379
+ ):
380
+ super().__init__()
381
+ assert_message = f"in_channels must be divisible by patch_size ({patch_size})"
382
+ assert in_channels % patch_size == 0, assert_message
383
+ self.patch_size = patch_size
384
+
385
+ self.block = ResnetBlock1d(
386
+ in_channels=in_channels // patch_size,
387
+ out_channels=out_channels,
388
+ num_groups=1,
389
+ context_mapping_features=context_mapping_features,
390
+ use_snake=use_snake
391
+ )
392
+
393
+ def forward(self, x: Tensor, mapping: Optional[Tensor] = None, causal=False) -> Tensor:
394
+ x = rearrange(x, " b (c p) l -> b c (l p) ", p=self.patch_size)
395
+ x = self.block(x, mapping, causal=causal)
396
+ return x
397
+
398
+
399
+ """
400
+ Attention Components
401
+ """
402
+ def FeedForward(features: int, multiplier: int) -> nn.Module:
403
+ mid_features = features * multiplier
404
+ return nn.Sequential(
405
+ nn.Linear(in_features=features, out_features=mid_features),
406
+ nn.GELU(),
407
+ nn.Linear(in_features=mid_features, out_features=features),
408
+ )
409
+
410
+ def add_mask(sim: Tensor, mask: Tensor) -> Tensor:
411
+ b, ndim = sim.shape[0], mask.ndim
412
+ if ndim == 3:
413
+ mask = rearrange(mask, "b n m -> b 1 n m")
414
+ if ndim == 2:
415
+ mask = repeat(mask, "n m -> b 1 n m", b=b)
416
+ max_neg_value = -torch.finfo(sim.dtype).max
417
+ sim = sim.masked_fill(~mask, max_neg_value)
418
+ return sim
419
+
420
+ def causal_mask(q: Tensor, k: Tensor) -> Tensor:
421
+ b, i, j, device = q.shape[0], q.shape[-2], k.shape[-2], q.device
422
+ mask = ~torch.ones((i, j), dtype=torch.bool, device=device).triu(j - i + 1)
423
+ mask = repeat(mask, "n m -> b n m", b=b)
424
+ return mask
425
+
426
+ class AttentionBase(nn.Module):
427
+ def __init__(
428
+ self,
429
+ features: int,
430
+ *,
431
+ head_features: int,
432
+ num_heads: int,
433
+ out_features: Optional[int] = None,
434
+ ):
435
+ super().__init__()
436
+ self.scale = head_features**-0.5
437
+ self.num_heads = num_heads
438
+ mid_features = head_features * num_heads
439
+ out_features = default(out_features, features)
440
+
441
+ self.to_out = nn.Linear(
442
+ in_features=mid_features, out_features=out_features
443
+ )
444
+
445
+ self.use_flash = torch.cuda.is_available() and version.parse(torch.__version__) >= version.parse('2.0.0')
446
+
447
+ if not self.use_flash:
448
+ return
449
+
450
+ device_properties = torch.cuda.get_device_properties(torch.device('cuda'))
451
+
452
+ if device_properties.major == 8 and device_properties.minor == 0:
453
+ # Use flash attention for A100 GPUs
454
+ self.sdp_kernel_config = (True, False, False)
455
+ else:
456
+ # Don't use flash attention for other GPUs
457
+ self.sdp_kernel_config = (False, True, True)
458
+
459
+ def forward(
460
+ self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None, is_causal: bool = False
461
+ ) -> Tensor:
462
+ # Split heads
463
+ q, k, v = rearrange_many((q, k, v), "b n (h d) -> b h n d", h=self.num_heads)
464
+
465
+ if not self.use_flash:
466
+ if is_causal and not mask:
467
+ # Mask out future tokens for causal attention
468
+ mask = causal_mask(q, k)
469
+
470
+ # Compute similarity matrix and add eventual mask
471
+ sim = einsum("... n d, ... m d -> ... n m", q, k) * self.scale
472
+ sim = add_mask(sim, mask) if exists(mask) else sim
473
+
474
+ # Get attention matrix with softmax
475
+ attn = sim.softmax(dim=-1, dtype=torch.float32)
476
+
477
+ # Compute values
478
+ out = einsum("... n m, ... m d -> ... n d", attn, v)
479
+ else:
480
+ with sdp_kernel(*self.sdp_kernel_config):
481
+ out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, is_causal=is_causal)
482
+
483
+ out = rearrange(out, "b h n d -> b n (h d)")
484
+ return self.to_out(out)
485
+
486
+ class Attention(nn.Module):
487
+ def __init__(
488
+ self,
489
+ features: int,
490
+ *,
491
+ head_features: int,
492
+ num_heads: int,
493
+ out_features: Optional[int] = None,
494
+ context_features: Optional[int] = None,
495
+ causal: bool = False,
496
+ ):
497
+ super().__init__()
498
+ self.context_features = context_features
499
+ self.causal = causal
500
+ mid_features = head_features * num_heads
501
+ context_features = default(context_features, features)
502
+
503
+ self.norm = nn.LayerNorm(features)
504
+ self.norm_context = nn.LayerNorm(context_features)
505
+ self.to_q = nn.Linear(
506
+ in_features=features, out_features=mid_features, bias=False
507
+ )
508
+ self.to_kv = nn.Linear(
509
+ in_features=context_features, out_features=mid_features * 2, bias=False
510
+ )
511
+ self.attention = AttentionBase(
512
+ features,
513
+ num_heads=num_heads,
514
+ head_features=head_features,
515
+ out_features=out_features,
516
+ )
517
+
518
+ def forward(
519
+ self,
520
+ x: Tensor, # [b, n, c]
521
+ context: Optional[Tensor] = None, # [b, m, d]
522
+ context_mask: Optional[Tensor] = None, # [b, m], false is masked,
523
+ causal: Optional[bool] = False,
524
+ ) -> Tensor:
525
+ assert_message = "You must provide a context when using context_features"
526
+ assert not self.context_features or exists(context), assert_message
527
+ # Use context if provided
528
+ context = default(context, x)
529
+ # Normalize then compute q from input and k,v from context
530
+ x, context = self.norm(x), self.norm_context(context)
531
+
532
+ q, k, v = (self.to_q(x), *torch.chunk(self.to_kv(context), chunks=2, dim=-1))
533
+
534
+ if exists(context_mask):
535
+ # Mask out cross-attention for padding tokens
536
+ mask = repeat(context_mask, "b m -> b m d", d=v.shape[-1])
537
+ k, v = k * mask, v * mask
538
+
539
+ # Compute and return attention
540
+ return self.attention(q, k, v, is_causal=self.causal or causal)
541
+
542
+
543
+ def FeedForward(features: int, multiplier: int) -> nn.Module:
544
+ mid_features = features * multiplier
545
+ return nn.Sequential(
546
+ nn.Linear(in_features=features, out_features=mid_features),
547
+ nn.GELU(),
548
+ nn.Linear(in_features=mid_features, out_features=features),
549
+ )
550
+
551
+ """
552
+ Transformer Blocks
553
+ """
554
+
555
+
556
+ class TransformerBlock(nn.Module):
557
+ def __init__(
558
+ self,
559
+ features: int,
560
+ num_heads: int,
561
+ head_features: int,
562
+ multiplier: int,
563
+ context_features: Optional[int] = None,
564
+ ):
565
+ super().__init__()
566
+
567
+ self.use_cross_attention = exists(context_features) and context_features > 0
568
+
569
+ self.attention = Attention(
570
+ features=features,
571
+ num_heads=num_heads,
572
+ head_features=head_features
573
+ )
574
+
575
+ if self.use_cross_attention:
576
+ self.cross_attention = Attention(
577
+ features=features,
578
+ num_heads=num_heads,
579
+ head_features=head_features,
580
+ context_features=context_features
581
+ )
582
+
583
+ self.feed_forward = FeedForward(features=features, multiplier=multiplier)
584
+
585
+ def forward(self, x: Tensor, *, context: Optional[Tensor] = None, context_mask: Optional[Tensor] = None, causal: Optional[bool] = False) -> Tensor:
586
+ x = self.attention(x, causal=causal) + x
587
+ if self.use_cross_attention:
588
+ x = self.cross_attention(x, context=context, context_mask=context_mask) + x
589
+ x = self.feed_forward(x) + x
590
+ return x
591
+
592
+
593
+ """
594
+ Transformers
595
+ """
596
+
597
+
598
+ class Transformer1d(nn.Module):
599
+ def __init__(
600
+ self,
601
+ num_layers: int,
602
+ channels: int,
603
+ num_heads: int,
604
+ head_features: int,
605
+ multiplier: int,
606
+ context_features: Optional[int] = None,
607
+ ):
608
+ super().__init__()
609
+
610
+ self.to_in = nn.Sequential(
611
+ nn.GroupNorm(num_groups=32, num_channels=channels, eps=1e-6, affine=True),
612
+ Conv1d(
613
+ in_channels=channels,
614
+ out_channels=channels,
615
+ kernel_size=1,
616
+ ),
617
+ Rearrange("b c t -> b t c"),
618
+ )
619
+
620
+ self.blocks = nn.ModuleList(
621
+ [
622
+ TransformerBlock(
623
+ features=channels,
624
+ head_features=head_features,
625
+ num_heads=num_heads,
626
+ multiplier=multiplier,
627
+ context_features=context_features,
628
+ )
629
+ for i in range(num_layers)
630
+ ]
631
+ )
632
+
633
+ self.to_out = nn.Sequential(
634
+ Rearrange("b t c -> b c t"),
635
+ Conv1d(
636
+ in_channels=channels,
637
+ out_channels=channels,
638
+ kernel_size=1,
639
+ ),
640
+ )
641
+
642
+ def forward(self, x: Tensor, *, context: Optional[Tensor] = None, context_mask: Optional[Tensor] = None, causal=False) -> Tensor:
643
+ x = self.to_in(x)
644
+ for block in self.blocks:
645
+ x = block(x, context=context, context_mask=context_mask, causal=causal)
646
+ x = self.to_out(x)
647
+ return x
648
+
649
+
650
+ """
651
+ Time Embeddings
652
+ """
653
+
654
+
655
+ class SinusoidalEmbedding(nn.Module):
656
+ def __init__(self, dim: int):
657
+ super().__init__()
658
+ self.dim = dim
659
+
660
+ def forward(self, x: Tensor) -> Tensor:
661
+ device, half_dim = x.device, self.dim // 2
662
+ emb = torch.tensor(log(10000) / (half_dim - 1), device=device)
663
+ emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
664
+ emb = rearrange(x, "i -> i 1") * rearrange(emb, "j -> 1 j")
665
+ return torch.cat((emb.sin(), emb.cos()), dim=-1)
666
+
667
+
668
+ class LearnedPositionalEmbedding(nn.Module):
669
+ """Used for continuous time"""
670
+
671
+ def __init__(self, dim: int):
672
+ super().__init__()
673
+ assert (dim % 2) == 0
674
+ half_dim = dim // 2
675
+ self.weights = nn.Parameter(torch.randn(half_dim))
676
+
677
+ def forward(self, x: Tensor) -> Tensor:
678
+ x = rearrange(x, "b -> b 1")
679
+ freqs = x * rearrange(self.weights, "d -> 1 d") * 2 * pi
680
+ fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1)
681
+ fouriered = torch.cat((x, fouriered), dim=-1)
682
+ return fouriered
683
+
684
+
685
+ def TimePositionalEmbedding(dim: int, out_features: int) -> nn.Module:
686
+ return nn.Sequential(
687
+ LearnedPositionalEmbedding(dim),
688
+ nn.Linear(in_features=dim + 1, out_features=out_features),
689
+ )
690
+
691
+
692
+ """
693
+ Encoder/Decoder Components
694
+ """
695
+
696
+
697
+ class DownsampleBlock1d(nn.Module):
698
+ def __init__(
699
+ self,
700
+ in_channels: int,
701
+ out_channels: int,
702
+ *,
703
+ factor: int,
704
+ num_groups: int,
705
+ num_layers: int,
706
+ kernel_multiplier: int = 2,
707
+ use_pre_downsample: bool = True,
708
+ use_skip: bool = False,
709
+ use_snake: bool = False,
710
+ extract_channels: int = 0,
711
+ context_channels: int = 0,
712
+ num_transformer_blocks: int = 0,
713
+ attention_heads: Optional[int] = None,
714
+ attention_features: Optional[int] = None,
715
+ attention_multiplier: Optional[int] = None,
716
+ context_mapping_features: Optional[int] = None,
717
+ context_embedding_features: Optional[int] = None,
718
+ ):
719
+ super().__init__()
720
+ self.use_pre_downsample = use_pre_downsample
721
+ self.use_skip = use_skip
722
+ self.use_transformer = num_transformer_blocks > 0
723
+ self.use_extract = extract_channels > 0
724
+ self.use_context = context_channels > 0
725
+
726
+ channels = out_channels if use_pre_downsample else in_channels
727
+
728
+ self.downsample = Downsample1d(
729
+ in_channels=in_channels,
730
+ out_channels=out_channels,
731
+ factor=factor,
732
+ kernel_multiplier=kernel_multiplier,
733
+ )
734
+
735
+ self.blocks = nn.ModuleList(
736
+ [
737
+ ResnetBlock1d(
738
+ in_channels=channels + context_channels if i == 0 else channels,
739
+ out_channels=channels,
740
+ num_groups=num_groups,
741
+ context_mapping_features=context_mapping_features,
742
+ use_snake=use_snake
743
+ )
744
+ for i in range(num_layers)
745
+ ]
746
+ )
747
+
748
+ if self.use_transformer:
749
+ assert (
750
+ (exists(attention_heads) or exists(attention_features))
751
+ and exists(attention_multiplier)
752
+ )
753
+
754
+ if attention_features is None and attention_heads is not None:
755
+ attention_features = channels // attention_heads
756
+
757
+ if attention_heads is None and attention_features is not None:
758
+ attention_heads = channels // attention_features
759
+
760
+ self.transformer = Transformer1d(
761
+ num_layers=num_transformer_blocks,
762
+ channels=channels,
763
+ num_heads=attention_heads,
764
+ head_features=attention_features,
765
+ multiplier=attention_multiplier,
766
+ context_features=context_embedding_features
767
+ )
768
+
769
+ if self.use_extract:
770
+ num_extract_groups = min(num_groups, extract_channels)
771
+ self.to_extracted = ResnetBlock1d(
772
+ in_channels=out_channels,
773
+ out_channels=extract_channels,
774
+ num_groups=num_extract_groups,
775
+ use_snake=use_snake
776
+ )
777
+
778
+ def forward(
779
+ self,
780
+ x: Tensor,
781
+ *,
782
+ mapping: Optional[Tensor] = None,
783
+ channels: Optional[Tensor] = None,
784
+ embedding: Optional[Tensor] = None,
785
+ embedding_mask: Optional[Tensor] = None,
786
+ causal: Optional[bool] = False
787
+ ) -> Union[Tuple[Tensor, List[Tensor]], Tensor]:
788
+
789
+ if self.use_pre_downsample:
790
+ x = self.downsample(x)
791
+
792
+ if self.use_context and exists(channels):
793
+ x = torch.cat([x, channels], dim=1)
794
+
795
+ skips = []
796
+ for block in self.blocks:
797
+ x = block(x, mapping=mapping, causal=causal)
798
+ skips += [x] if self.use_skip else []
799
+
800
+ if self.use_transformer:
801
+ x = self.transformer(x, context=embedding, context_mask=embedding_mask, causal=causal)
802
+ skips += [x] if self.use_skip else []
803
+
804
+ if not self.use_pre_downsample:
805
+ x = self.downsample(x)
806
+
807
+ if self.use_extract:
808
+ extracted = self.to_extracted(x)
809
+ return x, extracted
810
+
811
+ return (x, skips) if self.use_skip else x
812
+
813
+
814
+ class UpsampleBlock1d(nn.Module):
815
+ def __init__(
816
+ self,
817
+ in_channels: int,
818
+ out_channels: int,
819
+ *,
820
+ factor: int,
821
+ num_layers: int,
822
+ num_groups: int,
823
+ use_nearest: bool = False,
824
+ use_pre_upsample: bool = False,
825
+ use_skip: bool = False,
826
+ use_snake: bool = False,
827
+ skip_channels: int = 0,
828
+ use_skip_scale: bool = False,
829
+ extract_channels: int = 0,
830
+ num_transformer_blocks: int = 0,
831
+ attention_heads: Optional[int] = None,
832
+ attention_features: Optional[int] = None,
833
+ attention_multiplier: Optional[int] = None,
834
+ context_mapping_features: Optional[int] = None,
835
+ context_embedding_features: Optional[int] = None,
836
+ ):
837
+ super().__init__()
838
+
839
+ self.use_extract = extract_channels > 0
840
+ self.use_pre_upsample = use_pre_upsample
841
+ self.use_transformer = num_transformer_blocks > 0
842
+ self.use_skip = use_skip
843
+ self.skip_scale = 2 ** -0.5 if use_skip_scale else 1.0
844
+
845
+ channels = out_channels if use_pre_upsample else in_channels
846
+
847
+ self.blocks = nn.ModuleList(
848
+ [
849
+ ResnetBlock1d(
850
+ in_channels=channels + skip_channels,
851
+ out_channels=channels,
852
+ num_groups=num_groups,
853
+ context_mapping_features=context_mapping_features,
854
+ use_snake=use_snake
855
+ )
856
+ for _ in range(num_layers)
857
+ ]
858
+ )
859
+
860
+ if self.use_transformer:
861
+ assert (
862
+ (exists(attention_heads) or exists(attention_features))
863
+ and exists(attention_multiplier)
864
+ )
865
+
866
+ if attention_features is None and attention_heads is not None:
867
+ attention_features = channels // attention_heads
868
+
869
+ if attention_heads is None and attention_features is not None:
870
+ attention_heads = channels // attention_features
871
+
872
+ self.transformer = Transformer1d(
873
+ num_layers=num_transformer_blocks,
874
+ channels=channels,
875
+ num_heads=attention_heads,
876
+ head_features=attention_features,
877
+ multiplier=attention_multiplier,
878
+ context_features=context_embedding_features,
879
+ )
880
+
881
+ self.upsample = Upsample1d(
882
+ in_channels=in_channels,
883
+ out_channels=out_channels,
884
+ factor=factor,
885
+ use_nearest=use_nearest,
886
+ )
887
+
888
+ if self.use_extract:
889
+ num_extract_groups = min(num_groups, extract_channels)
890
+ self.to_extracted = ResnetBlock1d(
891
+ in_channels=out_channels,
892
+ out_channels=extract_channels,
893
+ num_groups=num_extract_groups,
894
+ use_snake=use_snake
895
+ )
896
+
897
+ def add_skip(self, x: Tensor, skip: Tensor) -> Tensor:
898
+ return torch.cat([x, skip * self.skip_scale], dim=1)
899
+
900
+ def forward(
901
+ self,
902
+ x: Tensor,
903
+ *,
904
+ skips: Optional[List[Tensor]] = None,
905
+ mapping: Optional[Tensor] = None,
906
+ embedding: Optional[Tensor] = None,
907
+ embedding_mask: Optional[Tensor] = None,
908
+ causal: Optional[bool] = False
909
+ ) -> Union[Tuple[Tensor, Tensor], Tensor]:
910
+
911
+ if self.use_pre_upsample:
912
+ x = self.upsample(x)
913
+
914
+ for block in self.blocks:
915
+ x = self.add_skip(x, skip=skips.pop()) if exists(skips) else x
916
+ x = block(x, mapping=mapping, causal=causal)
917
+
918
+ if self.use_transformer:
919
+ x = self.transformer(x, context=embedding, context_mask=embedding_mask, causal=causal)
920
+
921
+ if not self.use_pre_upsample:
922
+ x = self.upsample(x)
923
+
924
+ if self.use_extract:
925
+ extracted = self.to_extracted(x)
926
+ return x, extracted
927
+
928
+ return x
929
+
930
+
931
+ class BottleneckBlock1d(nn.Module):
932
+ def __init__(
933
+ self,
934
+ channels: int,
935
+ *,
936
+ num_groups: int,
937
+ num_transformer_blocks: int = 0,
938
+ attention_heads: Optional[int] = None,
939
+ attention_features: Optional[int] = None,
940
+ attention_multiplier: Optional[int] = None,
941
+ context_mapping_features: Optional[int] = None,
942
+ context_embedding_features: Optional[int] = None,
943
+ use_snake: bool = False,
944
+ ):
945
+ super().__init__()
946
+ self.use_transformer = num_transformer_blocks > 0
947
+
948
+ self.pre_block = ResnetBlock1d(
949
+ in_channels=channels,
950
+ out_channels=channels,
951
+ num_groups=num_groups,
952
+ context_mapping_features=context_mapping_features,
953
+ use_snake=use_snake
954
+ )
955
+
956
+ if self.use_transformer:
957
+ assert (
958
+ (exists(attention_heads) or exists(attention_features))
959
+ and exists(attention_multiplier)
960
+ )
961
+
962
+ if attention_features is None and attention_heads is not None:
963
+ attention_features = channels // attention_heads
964
+
965
+ if attention_heads is None and attention_features is not None:
966
+ attention_heads = channels // attention_features
967
+
968
+ self.transformer = Transformer1d(
969
+ num_layers=num_transformer_blocks,
970
+ channels=channels,
971
+ num_heads=attention_heads,
972
+ head_features=attention_features,
973
+ multiplier=attention_multiplier,
974
+ context_features=context_embedding_features,
975
+ )
976
+
977
+ self.post_block = ResnetBlock1d(
978
+ in_channels=channels,
979
+ out_channels=channels,
980
+ num_groups=num_groups,
981
+ context_mapping_features=context_mapping_features,
982
+ use_snake=use_snake
983
+ )
984
+
985
+ def forward(
986
+ self,
987
+ x: Tensor,
988
+ *,
989
+ mapping: Optional[Tensor] = None,
990
+ embedding: Optional[Tensor] = None,
991
+ embedding_mask: Optional[Tensor] = None,
992
+ causal: Optional[bool] = False
993
+ ) -> Tensor:
994
+ x = self.pre_block(x, mapping=mapping, causal=causal)
995
+ if self.use_transformer:
996
+ x = self.transformer(x, context=embedding, context_mask=embedding_mask, causal=causal)
997
+ x = self.post_block(x, mapping=mapping, causal=causal)
998
+ return x
999
+
1000
+
1001
+ """
1002
+ UNet
1003
+ """
1004
+
1005
+
1006
+ class UNet1d(nn.Module):
1007
+ def __init__(
1008
+ self,
1009
+ in_channels: int,
1010
+ channels: int,
1011
+ multipliers: Sequence[int],
1012
+ factors: Sequence[int],
1013
+ num_blocks: Sequence[int],
1014
+ attentions: Sequence[int],
1015
+ patch_size: int = 1,
1016
+ resnet_groups: int = 8,
1017
+ use_context_time: bool = True,
1018
+ kernel_multiplier_downsample: int = 2,
1019
+ use_nearest_upsample: bool = False,
1020
+ use_skip_scale: bool = True,
1021
+ use_snake: bool = False,
1022
+ use_stft: bool = False,
1023
+ use_stft_context: bool = False,
1024
+ out_channels: Optional[int] = None,
1025
+ context_features: Optional[int] = None,
1026
+ context_features_multiplier: int = 4,
1027
+ context_channels: Optional[Sequence[int]] = None,
1028
+ context_embedding_features: Optional[int] = None,
1029
+ **kwargs,
1030
+ ):
1031
+ super().__init__()
1032
+ out_channels = default(out_channels, in_channels)
1033
+ context_channels = list(default(context_channels, []))
1034
+ num_layers = len(multipliers) - 1
1035
+ use_context_features = exists(context_features)
1036
+ use_context_channels = len(context_channels) > 0
1037
+ context_mapping_features = None
1038
+
1039
+ attention_kwargs, kwargs = groupby("attention_", kwargs, keep_prefix=True)
1040
+
1041
+ self.num_layers = num_layers
1042
+ self.use_context_time = use_context_time
1043
+ self.use_context_features = use_context_features
1044
+ self.use_context_channels = use_context_channels
1045
+ self.use_stft = use_stft
1046
+ self.use_stft_context = use_stft_context
1047
+
1048
+ self.context_features = context_features
1049
+ context_channels_pad_length = num_layers + 1 - len(context_channels)
1050
+ context_channels = context_channels + [0] * context_channels_pad_length
1051
+ self.context_channels = context_channels
1052
+ self.context_embedding_features = context_embedding_features
1053
+
1054
+ if use_context_channels:
1055
+ has_context = [c > 0 for c in context_channels]
1056
+ self.has_context = has_context
1057
+ self.channels_ids = [sum(has_context[:i]) for i in range(len(has_context))]
1058
+
1059
+ assert (
1060
+ len(factors) == num_layers
1061
+ and len(attentions) >= num_layers
1062
+ and len(num_blocks) == num_layers
1063
+ )
1064
+
1065
+ if use_context_time or use_context_features:
1066
+ context_mapping_features = channels * context_features_multiplier
1067
+
1068
+ self.to_mapping = nn.Sequential(
1069
+ nn.Linear(context_mapping_features, context_mapping_features),
1070
+ nn.GELU(),
1071
+ nn.Linear(context_mapping_features, context_mapping_features),
1072
+ nn.GELU(),
1073
+ )
1074
+
1075
+ if use_context_time:
1076
+ assert exists(context_mapping_features)
1077
+ self.to_time = nn.Sequential(
1078
+ TimePositionalEmbedding(
1079
+ dim=channels, out_features=context_mapping_features
1080
+ ),
1081
+ nn.GELU(),
1082
+ )
1083
+
1084
+ if use_context_features:
1085
+ assert exists(context_features) and exists(context_mapping_features)
1086
+ self.to_features = nn.Sequential(
1087
+ nn.Linear(
1088
+ in_features=context_features, out_features=context_mapping_features
1089
+ ),
1090
+ nn.GELU(),
1091
+ )
1092
+
1093
+ if use_stft:
1094
+ stft_kwargs, kwargs = groupby("stft_", kwargs)
1095
+ assert "num_fft" in stft_kwargs, "stft_num_fft required if use_stft=True"
1096
+ stft_channels = (stft_kwargs["num_fft"] // 2 + 1) * 2
1097
+ in_channels *= stft_channels
1098
+ out_channels *= stft_channels
1099
+ context_channels[0] *= stft_channels if use_stft_context else 1
1100
+ assert exists(in_channels) and exists(out_channels)
1101
+ self.stft = STFT(**stft_kwargs)
1102
+
1103
+ assert not kwargs, f"Unknown arguments: {', '.join(list(kwargs.keys()))}"
1104
+
1105
+ self.to_in = Patcher(
1106
+ in_channels=in_channels + context_channels[0],
1107
+ out_channels=channels * multipliers[0],
1108
+ patch_size=patch_size,
1109
+ context_mapping_features=context_mapping_features,
1110
+ use_snake=use_snake
1111
+ )
1112
+
1113
+ self.downsamples = nn.ModuleList(
1114
+ [
1115
+ DownsampleBlock1d(
1116
+ in_channels=channels * multipliers[i],
1117
+ out_channels=channels * multipliers[i + 1],
1118
+ context_mapping_features=context_mapping_features,
1119
+ context_channels=context_channels[i + 1],
1120
+ context_embedding_features=context_embedding_features,
1121
+ num_layers=num_blocks[i],
1122
+ factor=factors[i],
1123
+ kernel_multiplier=kernel_multiplier_downsample,
1124
+ num_groups=resnet_groups,
1125
+ use_pre_downsample=True,
1126
+ use_skip=True,
1127
+ use_snake=use_snake,
1128
+ num_transformer_blocks=attentions[i],
1129
+ **attention_kwargs,
1130
+ )
1131
+ for i in range(num_layers)
1132
+ ]
1133
+ )
1134
+
1135
+ self.bottleneck = BottleneckBlock1d(
1136
+ channels=channels * multipliers[-1],
1137
+ context_mapping_features=context_mapping_features,
1138
+ context_embedding_features=context_embedding_features,
1139
+ num_groups=resnet_groups,
1140
+ num_transformer_blocks=attentions[-1],
1141
+ use_snake=use_snake,
1142
+ **attention_kwargs,
1143
+ )
1144
+
1145
+ self.upsamples = nn.ModuleList(
1146
+ [
1147
+ UpsampleBlock1d(
1148
+ in_channels=channels * multipliers[i + 1],
1149
+ out_channels=channels * multipliers[i],
1150
+ context_mapping_features=context_mapping_features,
1151
+ context_embedding_features=context_embedding_features,
1152
+ num_layers=num_blocks[i] + (1 if attentions[i] else 0),
1153
+ factor=factors[i],
1154
+ use_nearest=use_nearest_upsample,
1155
+ num_groups=resnet_groups,
1156
+ use_skip_scale=use_skip_scale,
1157
+ use_pre_upsample=False,
1158
+ use_skip=True,
1159
+ use_snake=use_snake,
1160
+ skip_channels=channels * multipliers[i + 1],
1161
+ num_transformer_blocks=attentions[i],
1162
+ **attention_kwargs,
1163
+ )
1164
+ for i in reversed(range(num_layers))
1165
+ ]
1166
+ )
1167
+
1168
+ self.to_out = Unpatcher(
1169
+ in_channels=channels * multipliers[0],
1170
+ out_channels=out_channels,
1171
+ patch_size=patch_size,
1172
+ context_mapping_features=context_mapping_features,
1173
+ use_snake=use_snake
1174
+ )
1175
+
1176
+ def get_channels(
1177
+ self, channels_list: Optional[Sequence[Tensor]] = None, layer: int = 0
1178
+ ) -> Optional[Tensor]:
1179
+ """Gets context channels at `layer` and checks that shape is correct"""
1180
+ use_context_channels = self.use_context_channels and self.has_context[layer]
1181
+ if not use_context_channels:
1182
+ return None
1183
+ assert exists(channels_list), "Missing context"
1184
+ # Get channels index (skipping zero channel contexts)
1185
+ channels_id = self.channels_ids[layer]
1186
+ # Get channels
1187
+ channels = channels_list[channels_id]
1188
+ message = f"Missing context for layer {layer} at index {channels_id}"
1189
+ assert exists(channels), message
1190
+ # Check channels
1191
+ num_channels = self.context_channels[layer]
1192
+ message = f"Expected context with {num_channels} channels at idx {channels_id}"
1193
+ assert channels.shape[1] == num_channels, message
1194
+ # STFT channels if requested
1195
+ channels = self.stft.encode1d(channels) if self.use_stft_context else channels # type: ignore # noqa
1196
+ return channels
1197
+
1198
+ def get_mapping(
1199
+ self, time: Optional[Tensor] = None, features: Optional[Tensor] = None
1200
+ ) -> Optional[Tensor]:
1201
+ """Combines context time features and features into mapping"""
1202
+ items, mapping = [], None
1203
+ # Compute time features
1204
+ if self.use_context_time:
1205
+ assert_message = "use_context_time=True but no time features provided"
1206
+ assert exists(time), assert_message
1207
+ items += [self.to_time(time)]
1208
+ # Compute features
1209
+ if self.use_context_features:
1210
+ assert_message = "context_features exists but no features provided"
1211
+ assert exists(features), assert_message
1212
+ items += [self.to_features(features)]
1213
+ # Compute joint mapping
1214
+ if self.use_context_time or self.use_context_features:
1215
+ mapping = reduce(torch.stack(items), "n b m -> b m", "sum")
1216
+ mapping = self.to_mapping(mapping)
1217
+ return mapping
1218
+
1219
+ def forward(
1220
+ self,
1221
+ x: Tensor,
1222
+ time: Optional[Tensor] = None,
1223
+ *,
1224
+ features: Optional[Tensor] = None,
1225
+ channels_list: Optional[Sequence[Tensor]] = None,
1226
+ embedding: Optional[Tensor] = None,
1227
+ embedding_mask: Optional[Tensor] = None,
1228
+ causal: Optional[bool] = False,
1229
+ ) -> Tensor:
1230
+ channels = self.get_channels(channels_list, layer=0)
1231
+ # Apply stft if required
1232
+ print(x.shape)
1233
+ x = self.stft.encode1d(x) if self.use_stft else x # type: ignore
1234
+ print(x.shape)
1235
+ # Concat context channels at layer 0 if provided
1236
+ x = torch.cat([x, channels], dim=1) if exists(channels) else x
1237
+ print(x.shape)
1238
+ # Compute mapping from time and features
1239
+ mapping = self.get_mapping(time, features)
1240
+ x = self.to_in(x, mapping, causal=causal)
1241
+ print(x.shape)
1242
+ skips_list = [x]
1243
+
1244
+ for i, downsample in enumerate(self.downsamples):
1245
+ channels = self.get_channels(channels_list, layer=i + 1)
1246
+ x, skips = downsample(
1247
+ x, mapping=mapping, channels=channels, embedding=embedding, embedding_mask=embedding_mask, causal=causal
1248
+ )
1249
+ skips_list += [skips]
1250
+
1251
+ x = self.bottleneck(x, mapping=mapping, embedding=embedding, embedding_mask=embedding_mask, causal=causal)
1252
+ for i, upsample in enumerate(self.upsamples):
1253
+ skips = skips_list.pop()
1254
+ x = upsample(x, skips=skips, mapping=mapping, embedding=embedding, embedding_mask=embedding_mask, causal=causal)
1255
+
1256
+ x += skips_list.pop()
1257
+ x = self.to_out(x, mapping, causal=causal)
1258
+ x = self.stft.decode1d(x) if self.use_stft else x
1259
+
1260
+ return x
1261
+
1262
+
1263
+ """ Conditioning Modules """
1264
+
1265
+
1266
+ class FixedEmbedding(nn.Module):
1267
+ def __init__(self, max_length: int, features: int):
1268
+ super().__init__()
1269
+ self.max_length = max_length
1270
+ self.embedding = nn.Embedding(max_length, features)
1271
+
1272
+ def forward(self, x: Tensor) -> Tensor:
1273
+ batch_size, length, device = *x.shape[0:2], x.device
1274
+ assert_message = "Input sequence length must be <= max_length"
1275
+ assert length <= self.max_length, assert_message
1276
+ position = torch.arange(length, device=device)
1277
+ fixed_embedding = self.embedding(position)
1278
+ fixed_embedding = repeat(fixed_embedding, "n d -> b n d", b=batch_size)
1279
+ return fixed_embedding
1280
+
1281
+
1282
+ def rand_bool(shape: Any, proba: float, device: Any = None) -> Tensor:
1283
+ if proba == 1:
1284
+ return torch.ones(shape, device=device, dtype=torch.bool)
1285
+ elif proba == 0:
1286
+ return torch.zeros(shape, device=device, dtype=torch.bool)
1287
+ else:
1288
+ return torch.bernoulli(torch.full(shape, proba, device=device)).to(torch.bool)
1289
+
1290
+
1291
+ class UNetCFG1d(UNet1d):
1292
+
1293
+ """UNet1d with Classifier-Free Guidance"""
1294
+
1295
+ def __init__(
1296
+ self,
1297
+ context_embedding_max_length: int,
1298
+ context_embedding_features: int,
1299
+ use_xattn_time: bool = False,
1300
+ **kwargs,
1301
+ ):
1302
+ super().__init__(
1303
+ context_embedding_features=context_embedding_features, **kwargs
1304
+ )
1305
+
1306
+ self.use_xattn_time = use_xattn_time
1307
+
1308
+ if use_xattn_time:
1309
+ assert exists(context_embedding_features)
1310
+ self.to_time_embedding = nn.Sequential(
1311
+ TimePositionalEmbedding(
1312
+ dim=kwargs["channels"], out_features=context_embedding_features
1313
+ ),
1314
+ nn.GELU(),
1315
+ )
1316
+
1317
+ context_embedding_max_length += 1 # Add one for time embedding
1318
+
1319
+ self.fixed_embedding = FixedEmbedding(
1320
+ max_length=context_embedding_max_length, features=context_embedding_features
1321
+ )
1322
+
1323
+ def forward( # type: ignore
1324
+ self,
1325
+ x: Tensor,
1326
+ time: Tensor,
1327
+ *,
1328
+ embedding: Tensor,
1329
+ embedding_mask: Optional[Tensor] = None,
1330
+ embedding_scale: float = 1.0,
1331
+ embedding_mask_proba: float = 0.0,
1332
+ batch_cfg: bool = False,
1333
+ rescale_cfg: bool = False,
1334
+ scale_phi: float = 0.4,
1335
+ negative_embedding: Optional[Tensor] = None,
1336
+ negative_embedding_mask: Optional[Tensor] = None,
1337
+ **kwargs,
1338
+ ) -> Tensor:
1339
+ b, device = embedding.shape[0], embedding.device
1340
+
1341
+ if self.use_xattn_time:
1342
+ embedding = torch.cat([embedding, self.to_time_embedding(time).unsqueeze(1)], dim=1)
1343
+
1344
+ if embedding_mask is not None:
1345
+ embedding_mask = torch.cat([embedding_mask, torch.ones((b, 1), device=device)], dim=1)
1346
+
1347
+ fixed_embedding = self.fixed_embedding(embedding)
1348
+
1349
+ if embedding_mask_proba > 0.0:
1350
+ # Randomly mask embedding
1351
+ batch_mask = rand_bool(
1352
+ shape=(b, 1, 1), proba=embedding_mask_proba, device=device
1353
+ )
1354
+ embedding = torch.where(batch_mask, fixed_embedding, embedding)
1355
+
1356
+ if embedding_scale != 1.0:
1357
+ if batch_cfg:
1358
+ batch_x = torch.cat([x, x], dim=0)
1359
+ batch_time = torch.cat([time, time], dim=0)
1360
+
1361
+ if negative_embedding is not None:
1362
+ if negative_embedding_mask is not None:
1363
+ negative_embedding_mask = negative_embedding_mask.to(torch.bool).unsqueeze(2)
1364
+
1365
+ negative_embedding = torch.where(negative_embedding_mask, negative_embedding, fixed_embedding)
1366
+
1367
+ batch_embed = torch.cat([embedding, negative_embedding], dim=0)
1368
+
1369
+ else:
1370
+ batch_embed = torch.cat([embedding, fixed_embedding], dim=0)
1371
+
1372
+ batch_mask = None
1373
+ if embedding_mask is not None:
1374
+ batch_mask = torch.cat([embedding_mask, embedding_mask], dim=0)
1375
+
1376
+ batch_features = None
1377
+ features = kwargs.pop("features", None)
1378
+ if self.use_context_features:
1379
+ batch_features = torch.cat([features, features], dim=0)
1380
+
1381
+ batch_channels = None
1382
+ channels_list = kwargs.pop("channels_list", None)
1383
+ if self.use_context_channels:
1384
+ batch_channels = []
1385
+ for channels in channels_list:
1386
+ batch_channels += [torch.cat([channels, channels], dim=0)]
1387
+
1388
+ # Compute both normal and fixed embedding outputs
1389
+ batch_out = super().forward(batch_x, batch_time, embedding=batch_embed, embedding_mask=batch_mask, features=batch_features, channels_list=batch_channels, **kwargs)
1390
+ out, out_masked = batch_out.chunk(2, dim=0)
1391
+
1392
+ else:
1393
+ # Compute both normal and fixed embedding outputs
1394
+ out = super().forward(x, time, embedding=embedding, embedding_mask=embedding_mask, **kwargs)
1395
+ out_masked = super().forward(x, time, embedding=fixed_embedding, embedding_mask=embedding_mask, **kwargs)
1396
+
1397
+ out_cfg = out_masked + (out - out_masked) * embedding_scale
1398
+
1399
+ if rescale_cfg:
1400
+
1401
+ out_std = out.std(dim=1, keepdim=True)
1402
+ out_cfg_std = out_cfg.std(dim=1, keepdim=True)
1403
+
1404
+ return scale_phi * (out_cfg * (out_std/out_cfg_std)) + (1-scale_phi) * out_cfg
1405
+
1406
+ else:
1407
+
1408
+ return out_cfg
1409
+
1410
+ else:
1411
+ return super().forward(x, time, embedding=embedding, embedding_mask=embedding_mask, **kwargs)
1412
+
1413
+
1414
+ class UNetNCCA1d(UNet1d):
1415
+
1416
+ """UNet1d with Noise Channel Conditioning Augmentation"""
1417
+
1418
+ def __init__(self, context_features: int, **kwargs):
1419
+ super().__init__(context_features=context_features, **kwargs)
1420
+ self.embedder = NumberEmbedder(features=context_features)
1421
+
1422
+ def expand(self, x: Any, shape: Tuple[int, ...]) -> Tensor:
1423
+ x = x if torch.is_tensor(x) else torch.tensor(x)
1424
+ return x.expand(shape)
1425
+
1426
+ def forward( # type: ignore
1427
+ self,
1428
+ x: Tensor,
1429
+ time: Tensor,
1430
+ *,
1431
+ channels_list: Sequence[Tensor],
1432
+ channels_augmentation: Union[
1433
+ bool, Sequence[bool], Sequence[Sequence[bool]], Tensor
1434
+ ] = False,
1435
+ channels_scale: Union[
1436
+ float, Sequence[float], Sequence[Sequence[float]], Tensor
1437
+ ] = 0,
1438
+ **kwargs,
1439
+ ) -> Tensor:
1440
+ b, n = x.shape[0], len(channels_list)
1441
+ channels_augmentation = self.expand(channels_augmentation, shape=(b, n)).to(x)
1442
+ channels_scale = self.expand(channels_scale, shape=(b, n)).to(x)
1443
+
1444
+ # Augmentation (for each channel list item)
1445
+ for i in range(n):
1446
+ scale = channels_scale[:, i] * channels_augmentation[:, i]
1447
+ scale = rearrange(scale, "b -> b 1 1")
1448
+ item = channels_list[i]
1449
+ channels_list[i] = torch.randn_like(item) * scale + item * (1 - scale) # type: ignore # noqa
1450
+
1451
+ # Scale embedding (sum reduction if more than one channel list item)
1452
+ channels_scale_emb = self.embedder(channels_scale)
1453
+ channels_scale_emb = reduce(channels_scale_emb, "b n d -> b d", "sum")
1454
+
1455
+ return super().forward(
1456
+ x=x,
1457
+ time=time,
1458
+ channels_list=channels_list,
1459
+ features=channels_scale_emb,
1460
+ **kwargs,
1461
+ )
1462
+
1463
+
1464
+ class UNetAll1d(UNetCFG1d, UNetNCCA1d):
1465
+ def __init__(self, *args, **kwargs):
1466
+ super().__init__(*args, **kwargs)
1467
+
1468
+ def forward(self, *args, **kwargs): # type: ignore
1469
+ return UNetCFG1d.forward(self, *args, **kwargs)
1470
+
1471
+
1472
+ def XUNet1d(type: str = "base", **kwargs) -> UNet1d:
1473
+ if type == "base":
1474
+ return UNet1d(**kwargs)
1475
+ elif type == "all":
1476
+ return UNetAll1d(**kwargs)
1477
+ elif type == "cfg":
1478
+ return UNetCFG1d(**kwargs)
1479
+ elif type == "ncca":
1480
+ return UNetNCCA1d(**kwargs)
1481
+ else:
1482
+ raise ValueError(f"Unknown XUNet1d type: {type}")
1483
+
1484
+ class NumberEmbedder(nn.Module):
1485
+ def __init__(
1486
+ self,
1487
+ features: int,
1488
+ dim: int = 256,
1489
+ ):
1490
+ super().__init__()
1491
+ self.features = features
1492
+ self.embedding = TimePositionalEmbedding(dim=dim, out_features=features)
1493
+
1494
+ def forward(self, x: Union[List[float], Tensor]) -> Tensor:
1495
+ if not torch.is_tensor(x):
1496
+ device = next(self.embedding.parameters()).device
1497
+ x = torch.tensor(x, device=device)
1498
+ assert isinstance(x, Tensor)
1499
+ shape = x.shape
1500
+ x = rearrange(x, "... -> (...)")
1501
+ embedding = self.embedding(x)
1502
+ x = embedding.view(*shape, self.features)
1503
+ return x # type: ignore
1504
+
1505
+
1506
+ """
1507
+ Audio Transforms
1508
+ """
1509
+
1510
+
1511
+ class STFT(nn.Module):
1512
+ """Helper for torch stft and istft"""
1513
+
1514
+ def __init__(
1515
+ self,
1516
+ num_fft: int = 1023,
1517
+ hop_length: int = 256,
1518
+ window_length: Optional[int] = None,
1519
+ length: Optional[int] = None,
1520
+ use_complex: bool = False,
1521
+ ):
1522
+ super().__init__()
1523
+ self.num_fft = num_fft
1524
+ self.hop_length = default(hop_length, floor(num_fft // 4))
1525
+ self.window_length = default(window_length, num_fft)
1526
+ self.length = length
1527
+ self.register_buffer("window", torch.hann_window(self.window_length))
1528
+ self.use_complex = use_complex
1529
+
1530
+ def encode(self, wave: Tensor) -> Tuple[Tensor, Tensor]:
1531
+ b = wave.shape[0]
1532
+ wave = rearrange(wave, "b c t -> (b c) t")
1533
+
1534
+ stft = torch.stft(
1535
+ wave,
1536
+ n_fft=self.num_fft,
1537
+ hop_length=self.hop_length,
1538
+ win_length=self.window_length,
1539
+ window=self.window, # type: ignore
1540
+ return_complex=True,
1541
+ normalized=True,
1542
+ )
1543
+
1544
+ if self.use_complex:
1545
+ # Returns real and imaginary
1546
+ stft_a, stft_b = stft.real, stft.imag
1547
+ else:
1548
+ # Returns magnitude and phase matrices
1549
+ magnitude, phase = torch.abs(stft), torch.angle(stft)
1550
+ stft_a, stft_b = magnitude, phase
1551
+
1552
+ return rearrange_many((stft_a, stft_b), "(b c) f l -> b c f l", b=b)
1553
+
1554
+ def decode(self, stft_a: Tensor, stft_b: Tensor) -> Tensor:
1555
+ b, l = stft_a.shape[0], stft_a.shape[-1] # noqa
1556
+ length = closest_power_2(l * self.hop_length)
1557
+
1558
+ stft_a, stft_b = rearrange_many((stft_a, stft_b), "b c f l -> (b c) f l")
1559
+
1560
+ if self.use_complex:
1561
+ real, imag = stft_a, stft_b
1562
+ else:
1563
+ magnitude, phase = stft_a, stft_b
1564
+ real, imag = magnitude * torch.cos(phase), magnitude * torch.sin(phase)
1565
+
1566
+ stft = torch.stack([real, imag], dim=-1)
1567
+
1568
+ wave = torch.istft(
1569
+ stft,
1570
+ n_fft=self.num_fft,
1571
+ hop_length=self.hop_length,
1572
+ win_length=self.window_length,
1573
+ window=self.window, # type: ignore
1574
+ length=default(self.length, length),
1575
+ normalized=True,
1576
+ )
1577
+
1578
+ return rearrange(wave, "(b c) t -> b c t", b=b)
1579
+
1580
+ def encode1d(
1581
+ self, wave: Tensor, stacked: bool = True
1582
+ ) -> Union[Tensor, Tuple[Tensor, Tensor]]:
1583
+ stft_a, stft_b = self.encode(wave)
1584
+ stft_a, stft_b = rearrange_many((stft_a, stft_b), "b c f l -> b (c f) l")
1585
+ return torch.cat((stft_a, stft_b), dim=1) if stacked else (stft_a, stft_b)
1586
+
1587
+ def decode1d(self, stft_pair: Tensor) -> Tensor:
1588
+ f = self.num_fft // 2 + 1
1589
+ stft_a, stft_b = stft_pair.chunk(chunks=2, dim=1)
1590
+ stft_a, stft_b = rearrange_many((stft_a, stft_b), "b (c f) l -> b c f l", f=f)
1591
+ return self.decode(stft_a, stft_b)
cosyvoice/flow/stable/blocks.py ADDED
@@ -0,0 +1,339 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import reduce
2
+ import math
3
+ import numpy as np
4
+ import torch
5
+ from torch import nn
6
+ from torch.nn import functional as F
7
+
8
+ from torch.backends.cuda import sdp_kernel
9
+ from packaging import version
10
+
11
+ from dac.nn.layers import Snake1d
12
+
13
+ class ResidualBlock(nn.Module):
14
+ def __init__(self, main, skip=None):
15
+ super().__init__()
16
+ self.main = nn.Sequential(*main)
17
+ self.skip = skip if skip else nn.Identity()
18
+
19
+ def forward(self, input):
20
+ return self.main(input) + self.skip(input)
21
+
22
+ class ResConvBlock(ResidualBlock):
23
+ def __init__(self, c_in, c_mid, c_out, is_last=False, kernel_size=5, conv_bias=True, use_snake=False):
24
+ skip = None if c_in == c_out else nn.Conv1d(c_in, c_out, 1, bias=False)
25
+ super().__init__([
26
+ nn.Conv1d(c_in, c_mid, kernel_size, padding=kernel_size//2, bias=conv_bias),
27
+ nn.GroupNorm(1, c_mid),
28
+ Snake1d(c_mid) if use_snake else nn.GELU(),
29
+ nn.Conv1d(c_mid, c_out, kernel_size, padding=kernel_size//2, bias=conv_bias),
30
+ nn.GroupNorm(1, c_out) if not is_last else nn.Identity(),
31
+ (Snake1d(c_out) if use_snake else nn.GELU()) if not is_last else nn.Identity(),
32
+ ], skip)
33
+
34
+ class SelfAttention1d(nn.Module):
35
+ def __init__(self, c_in, n_head=1, dropout_rate=0.):
36
+ super().__init__()
37
+ assert c_in % n_head == 0
38
+ self.norm = nn.GroupNorm(1, c_in)
39
+ self.n_head = n_head
40
+ self.qkv_proj = nn.Conv1d(c_in, c_in * 3, 1)
41
+ self.out_proj = nn.Conv1d(c_in, c_in, 1)
42
+ self.dropout = nn.Dropout(dropout_rate, inplace=True)
43
+
44
+ self.use_flash = torch.cuda.is_available() and version.parse(torch.__version__) >= version.parse('2.0.0')
45
+
46
+ if not self.use_flash:
47
+ return
48
+
49
+ device_properties = torch.cuda.get_device_properties(torch.device('cuda'))
50
+
51
+ if device_properties.major == 8 and device_properties.minor == 0:
52
+ # Use flash attention for A100 GPUs
53
+ self.sdp_kernel_config = (True, False, False)
54
+ else:
55
+ # Don't use flash attention for other GPUs
56
+ self.sdp_kernel_config = (False, True, True)
57
+
58
+ def forward(self, input):
59
+ n, c, s = input.shape
60
+ qkv = self.qkv_proj(self.norm(input))
61
+ qkv = qkv.view(
62
+ [n, self.n_head * 3, c // self.n_head, s]).transpose(2, 3)
63
+ q, k, v = qkv.chunk(3, dim=1)
64
+ scale = k.shape[3]**-0.25
65
+
66
+ if self.use_flash:
67
+ with sdp_kernel(*self.sdp_kernel_config):
68
+ y = F.scaled_dot_product_attention(q, k, v, is_causal=False).contiguous().view([n, c, s])
69
+ else:
70
+ att = ((q * scale) @ (k.transpose(2, 3) * scale)).softmax(3)
71
+ y = (att @ v).transpose(2, 3).contiguous().view([n, c, s])
72
+
73
+
74
+ return input + self.dropout(self.out_proj(y))
75
+
76
+ class SkipBlock(nn.Module):
77
+ def __init__(self, *main):
78
+ super().__init__()
79
+ self.main = nn.Sequential(*main)
80
+
81
+ def forward(self, input):
82
+ return torch.cat([self.main(input), input], dim=1)
83
+
84
+ class FourierFeatures(nn.Module):
85
+ def __init__(self, in_features, out_features, std=1.):
86
+ super().__init__()
87
+ assert out_features % 2 == 0
88
+ self.weight = nn.Parameter(torch.randn(
89
+ [out_features // 2, in_features]) * std)
90
+
91
+ def forward(self, input):
92
+ f = 2 * math.pi * input @ self.weight.T
93
+ return torch.cat([f.cos(), f.sin()], dim=-1)
94
+
95
+ def expand_to_planes(input, shape):
96
+ return input[..., None].repeat([1, 1, shape[2]])
97
+
98
+ _kernels = {
99
+ 'linear':
100
+ [1 / 8, 3 / 8, 3 / 8, 1 / 8],
101
+ 'cubic':
102
+ [-0.01171875, -0.03515625, 0.11328125, 0.43359375,
103
+ 0.43359375, 0.11328125, -0.03515625, -0.01171875],
104
+ 'lanczos3':
105
+ [0.003689131001010537, 0.015056144446134567, -0.03399861603975296,
106
+ -0.066637322306633, 0.13550527393817902, 0.44638532400131226,
107
+ 0.44638532400131226, 0.13550527393817902, -0.066637322306633,
108
+ -0.03399861603975296, 0.015056144446134567, 0.003689131001010537]
109
+ }
110
+
111
+ class Downsample1d(nn.Module):
112
+ def __init__(self, kernel='linear', pad_mode='reflect', channels_last=False):
113
+ super().__init__()
114
+ self.pad_mode = pad_mode
115
+ kernel_1d = torch.tensor(_kernels[kernel])
116
+ self.pad = kernel_1d.shape[0] // 2 - 1
117
+ self.register_buffer('kernel', kernel_1d)
118
+ self.channels_last = channels_last
119
+
120
+ def forward(self, x):
121
+ if self.channels_last:
122
+ x = x.permute(0, 2, 1)
123
+ x = F.pad(x, (self.pad,) * 2, self.pad_mode)
124
+ weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0]])
125
+ indices = torch.arange(x.shape[1], device=x.device)
126
+ weight[indices, indices] = self.kernel.to(weight)
127
+ x = F.conv1d(x, weight, stride=2)
128
+ if self.channels_last:
129
+ x = x.permute(0, 2, 1)
130
+ return x
131
+
132
+
133
+ class Upsample1d(nn.Module):
134
+ def __init__(self, kernel='linear', pad_mode='reflect', channels_last=False):
135
+ super().__init__()
136
+ self.pad_mode = pad_mode
137
+ kernel_1d = torch.tensor(_kernels[kernel]) * 2
138
+ self.pad = kernel_1d.shape[0] // 2 - 1
139
+ self.register_buffer('kernel', kernel_1d)
140
+ self.channels_last = channels_last
141
+
142
+ def forward(self, x):
143
+ if self.channels_last:
144
+ x = x.permute(0, 2, 1)
145
+ x = F.pad(x, ((self.pad + 1) // 2,) * 2, self.pad_mode)
146
+ weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0]])
147
+ indices = torch.arange(x.shape[1], device=x.device)
148
+ weight[indices, indices] = self.kernel.to(weight)
149
+ x = F.conv_transpose1d(x, weight, stride=2, padding=self.pad * 2 + 1)
150
+ if self.channels_last:
151
+ x = x.permute(0, 2, 1)
152
+ return x
153
+
154
+ def Downsample1d_2(
155
+ in_channels: int, out_channels: int, factor: int, kernel_multiplier: int = 2
156
+ ) -> nn.Module:
157
+ assert kernel_multiplier % 2 == 0, "Kernel multiplier must be even"
158
+
159
+ return nn.Conv1d(
160
+ in_channels=in_channels,
161
+ out_channels=out_channels,
162
+ kernel_size=factor * kernel_multiplier + 1,
163
+ stride=factor,
164
+ padding=factor * (kernel_multiplier // 2),
165
+ )
166
+
167
+
168
+ def Upsample1d_2(
169
+ in_channels: int, out_channels: int, factor: int, use_nearest: bool = False
170
+ ) -> nn.Module:
171
+
172
+ if factor == 1:
173
+ return nn.Conv1d(
174
+ in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1
175
+ )
176
+
177
+ if use_nearest:
178
+ return nn.Sequential(
179
+ nn.Upsample(scale_factor=factor, mode="nearest"),
180
+ nn.Conv1d(
181
+ in_channels=in_channels,
182
+ out_channels=out_channels,
183
+ kernel_size=3,
184
+ padding=1,
185
+ ),
186
+ )
187
+ else:
188
+ return nn.ConvTranspose1d(
189
+ in_channels=in_channels,
190
+ out_channels=out_channels,
191
+ kernel_size=factor * 2,
192
+ stride=factor,
193
+ padding=factor // 2 + factor % 2,
194
+ output_padding=factor % 2,
195
+ )
196
+
197
+ def zero_init(layer):
198
+ nn.init.zeros_(layer.weight)
199
+ if layer.bias is not None:
200
+ nn.init.zeros_(layer.bias)
201
+ return layer
202
+
203
+ def rms_norm(x, scale, eps):
204
+ dtype = reduce(torch.promote_types, (x.dtype, scale.dtype, torch.float32))
205
+ mean_sq = torch.mean(x.to(dtype)**2, dim=-1, keepdim=True)
206
+ scale = scale.to(dtype) * torch.rsqrt(mean_sq + eps)
207
+ return x * scale.to(x.dtype)
208
+
209
+ #rms_norm = torch.compile(rms_norm)
210
+
211
+ class AdaRMSNorm(nn.Module):
212
+ def __init__(self, features, cond_features, eps=1e-6):
213
+ super().__init__()
214
+ self.eps = eps
215
+ self.linear = zero_init(nn.Linear(cond_features, features, bias=False))
216
+
217
+ def extra_repr(self):
218
+ return f"eps={self.eps},"
219
+
220
+ def forward(self, x, cond):
221
+ return rms_norm(x, self.linear(cond)[:, None, :] + 1, self.eps)
222
+
223
+ def normalize(x, eps=1e-4):
224
+ dim = list(range(1, x.ndim))
225
+ n = torch.linalg.vector_norm(x, dim=dim, keepdim=True)
226
+ alpha = np.sqrt(n.numel() / x.numel())
227
+ return x / torch.add(eps, n, alpha=alpha)
228
+
229
+ class ForcedWNConv1d(nn.Module):
230
+ def __init__(self, in_channels, out_channels, kernel_size=1):
231
+ super().__init__()
232
+ self.weight = nn.Parameter(torch.randn([out_channels, in_channels, kernel_size]))
233
+
234
+ def forward(self, x):
235
+ if self.training:
236
+ with torch.no_grad():
237
+ self.weight.copy_(normalize(self.weight))
238
+
239
+ fan_in = self.weight[0].numel()
240
+
241
+ w = normalize(self.weight) / math.sqrt(fan_in)
242
+
243
+ return F.conv1d(x, w, padding='same')
244
+
245
+ # Kernels
246
+
247
+ use_compile = True
248
+
249
+ def compile(function, *args, **kwargs):
250
+ if not use_compile:
251
+ return function
252
+ try:
253
+ return torch.compile(function, *args, **kwargs)
254
+ except RuntimeError:
255
+ return function
256
+
257
+
258
+ @compile
259
+ def linear_geglu(x, weight, bias=None):
260
+ x = x @ weight.mT
261
+ if bias is not None:
262
+ x = x + bias
263
+ x, gate = x.chunk(2, dim=-1)
264
+ return x * F.gelu(gate)
265
+
266
+
267
+ @compile
268
+ def rms_norm(x, scale, eps):
269
+ dtype = reduce(torch.promote_types, (x.dtype, scale.dtype, torch.float32))
270
+ mean_sq = torch.mean(x.to(dtype)**2, dim=-1, keepdim=True)
271
+ scale = scale.to(dtype) * torch.rsqrt(mean_sq + eps)
272
+ return x * scale.to(x.dtype)
273
+
274
+ # Layers
275
+
276
+ class LinearGEGLU(nn.Linear):
277
+ def __init__(self, in_features, out_features, bias=True):
278
+ super().__init__(in_features, out_features * 2, bias=bias)
279
+ self.out_features = out_features
280
+
281
+ def forward(self, x):
282
+ return linear_geglu(x, self.weight, self.bias)
283
+
284
+
285
+ class RMSNorm(nn.Module):
286
+ def __init__(self, shape, fix_scale = False, eps=1e-6):
287
+ super().__init__()
288
+ self.eps = eps
289
+
290
+ if fix_scale:
291
+ self.register_buffer("scale", torch.ones(shape))
292
+ else:
293
+ self.scale = nn.Parameter(torch.ones(shape))
294
+
295
+ def extra_repr(self):
296
+ return f"shape={tuple(self.scale.shape)}, eps={self.eps}"
297
+
298
+ def forward(self, x):
299
+ return rms_norm(x, self.scale, self.eps)
300
+
301
+ def snake_beta(x, alpha, beta):
302
+ return x + (1.0 / (beta + 0.000000001)) * pow(torch.sin(x * alpha), 2)
303
+
304
+ # try:
305
+ # snake_beta = torch.compile(snake_beta)
306
+ # except RuntimeError:
307
+ # pass
308
+
309
+ # Adapted from https://github.com/NVIDIA/BigVGAN/blob/main/activations.py under MIT license
310
+ # License available in LICENSES/LICENSE_NVIDIA.txt
311
+ class SnakeBeta(nn.Module):
312
+
313
+ def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True):
314
+ super(SnakeBeta, self).__init__()
315
+ self.in_features = in_features
316
+
317
+ # initialize alpha
318
+ self.alpha_logscale = alpha_logscale
319
+ if self.alpha_logscale: # log scale alphas initialized to zeros
320
+ self.alpha = nn.Parameter(torch.zeros(in_features) * alpha)
321
+ self.beta = nn.Parameter(torch.zeros(in_features) * alpha)
322
+ else: # linear scale alphas initialized to ones
323
+ self.alpha = nn.Parameter(torch.ones(in_features) * alpha)
324
+ self.beta = nn.Parameter(torch.ones(in_features) * alpha)
325
+
326
+ self.alpha.requires_grad = alpha_trainable
327
+ self.beta.requires_grad = alpha_trainable
328
+
329
+ self.no_div_by_zero = 0.000000001
330
+
331
+ def forward(self, x):
332
+ alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
333
+ beta = self.beta.unsqueeze(0).unsqueeze(-1)
334
+ if self.alpha_logscale:
335
+ alpha = torch.exp(alpha)
336
+ beta = torch.exp(beta)
337
+ x = snake_beta(x, alpha, beta)
338
+
339
+ return x
cosyvoice/flow/stable/dit.py ADDED
@@ -0,0 +1,415 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import typing as tp
2
+
3
+ import torch
4
+
5
+ from einops import rearrange
6
+ from torch import nn
7
+ from torch.nn import functional as F
8
+ from x_transformers import ContinuousTransformerWrapper, Encoder
9
+
10
+ from .blocks import FourierFeatures
11
+ from .transformer import ContinuousTransformer
12
+ from .transformer_use_mask import ContinuousTransformer as ContinuousTransformer_mask
13
+
14
+
15
+ class DiffusionTransformer(nn.Module):
16
+ def __init__(self,
17
+ io_channels=32,
18
+ patch_size=1,
19
+ embed_dim=768,
20
+ cond_token_dim=0,
21
+ project_cond_tokens=True,
22
+ global_cond_dim=0,
23
+ project_global_cond=True,
24
+ input_concat_dim=0,
25
+ prepend_cond_dim=0,
26
+ depth=12,
27
+ num_heads=8,
28
+ transformer_type: tp.Literal["x-transformers", "continuous_transformer"] = "x-transformers",
29
+ global_cond_type: tp.Literal["prepend", "adaLN"] = "prepend",
30
+ **kwargs):
31
+
32
+ super().__init__()
33
+
34
+ self.cond_token_dim = cond_token_dim
35
+
36
+ # Timestep embeddings
37
+ timestep_features_dim = 256
38
+
39
+ self.timestep_features = FourierFeatures(1, timestep_features_dim)
40
+
41
+ self.to_timestep_embed = nn.Sequential(
42
+ nn.Linear(timestep_features_dim, embed_dim, bias=True),
43
+ nn.SiLU(),
44
+ nn.Linear(embed_dim, embed_dim, bias=True),
45
+ )
46
+
47
+ if cond_token_dim > 0:
48
+ # Conditioning tokens
49
+
50
+ cond_embed_dim = cond_token_dim if not project_cond_tokens else embed_dim
51
+ self.to_cond_embed = nn.Sequential(
52
+ nn.Linear(cond_token_dim, cond_embed_dim, bias=False),
53
+ nn.SiLU(),
54
+ nn.Linear(cond_embed_dim, cond_embed_dim, bias=False)
55
+ )
56
+ else:
57
+ cond_embed_dim = 0
58
+ self.to_cond_embed = nn.Identity()
59
+
60
+ if global_cond_dim > 0:
61
+ # Global conditioning
62
+ global_embed_dim = global_cond_dim if not project_global_cond else embed_dim
63
+ self.to_global_embed = nn.Sequential(
64
+ nn.Linear(global_cond_dim, global_embed_dim, bias=False),
65
+ nn.SiLU(),
66
+ nn.Linear(global_embed_dim, global_embed_dim, bias=False)
67
+ )
68
+
69
+ if prepend_cond_dim > 0:
70
+ # Prepend conditioning
71
+ self.to_prepend_embed = nn.Sequential(
72
+ nn.Linear(prepend_cond_dim, embed_dim, bias=False),
73
+ nn.SiLU(),
74
+ nn.Linear(embed_dim, embed_dim, bias=False)
75
+ )
76
+
77
+ self.input_concat_dim = input_concat_dim
78
+
79
+ dim_in = io_channels + self.input_concat_dim
80
+
81
+ self.patch_size = patch_size
82
+
83
+ # Transformer
84
+
85
+ self.transformer_type = transformer_type
86
+
87
+ self.global_cond_type = global_cond_type
88
+
89
+ if self.transformer_type == "x-transformers":
90
+ self.transformer = ContinuousTransformerWrapper(
91
+ dim_in=dim_in * patch_size,
92
+ dim_out=io_channels * patch_size,
93
+ max_seq_len=0, # Not relevant without absolute positional embeds
94
+ attn_layers=Encoder(
95
+ dim=embed_dim,
96
+ depth=depth,
97
+ heads=num_heads,
98
+ attn_flash=True,
99
+ cross_attend=cond_token_dim > 0,
100
+ dim_context=None if cond_embed_dim == 0 else cond_embed_dim,
101
+ zero_init_branch_output=True,
102
+ use_abs_pos_emb=False,
103
+ rotary_pos_emb=True,
104
+ ff_swish=True,
105
+ ff_glu=True,
106
+ **kwargs
107
+ )
108
+ )
109
+
110
+ elif self.transformer_type == "continuous_transformer":
111
+
112
+ global_dim = None
113
+
114
+ if self.global_cond_type == "adaLN":
115
+ # The global conditioning is projected to the embed_dim already at this point
116
+ global_dim = embed_dim
117
+
118
+ self.transformer = ContinuousTransformer(
119
+ dim=embed_dim,
120
+ depth=depth,
121
+ dim_heads=embed_dim // num_heads,
122
+ dim_in=dim_in * patch_size,
123
+ dim_out=io_channels * patch_size,
124
+ cross_attend=cond_token_dim > 0,
125
+ cond_token_dim=cond_embed_dim,
126
+ global_cond_dim=global_dim,
127
+ **kwargs
128
+ )
129
+ elif self.transformer_type == "continuous_transformer_with_mask":
130
+
131
+ global_dim = None
132
+
133
+ if self.global_cond_type == "adaLN":
134
+ # The global conditioning is projected to the embed_dim already at this point
135
+ global_dim = embed_dim
136
+
137
+ self.transformer = ContinuousTransformer_mask(
138
+ dim=embed_dim,
139
+ depth=depth,
140
+ dim_heads=embed_dim // num_heads,
141
+ dim_in=dim_in * patch_size,
142
+ dim_out=io_channels * patch_size,
143
+ cross_attend=cond_token_dim > 0,
144
+ cond_token_dim=cond_embed_dim,
145
+ global_cond_dim=global_dim,
146
+ **kwargs
147
+ )
148
+
149
+ else:
150
+ raise ValueError(f"Unknown transformer type: {self.transformer_type}")
151
+
152
+ self.preprocess_conv = nn.Conv1d(dim_in, dim_in, 1, bias=False)
153
+ nn.init.zeros_(self.preprocess_conv.weight)
154
+ self.postprocess_conv = nn.Conv1d(io_channels, io_channels, 1, bias=False)
155
+ nn.init.zeros_(self.postprocess_conv.weight)
156
+
157
+ def _forward(
158
+ self,
159
+ x,
160
+ t,
161
+ mask=None,
162
+ cross_attn_cond=None,
163
+ cross_attn_cond_mask=None,
164
+ input_concat_cond=None,
165
+ global_embed=None,
166
+ prepend_cond=None,
167
+ prepend_cond_mask=None,
168
+ return_info=False,
169
+ **kwargs):
170
+ ### 1. 需要重新写过以适应不同长度的con
171
+ if cross_attn_cond is not None:
172
+ cross_attn_cond = self.to_cond_embed(cross_attn_cond)
173
+
174
+ if global_embed is not None:
175
+ # Project the global conditioning to the embedding dimension
176
+ global_embed = self.to_global_embed(global_embed)
177
+
178
+ prepend_inputs = None
179
+ prepend_mask = None
180
+ prepend_length = 0
181
+ if prepend_cond is not None:
182
+ # Project the prepend conditioning to the embedding dimension
183
+ prepend_cond = self.to_prepend_embed(prepend_cond)
184
+
185
+ prepend_inputs = prepend_cond
186
+ if prepend_cond_mask is not None:
187
+ prepend_mask = prepend_cond_mask
188
+
189
+ if input_concat_cond is not None:
190
+
191
+ # Interpolate input_concat_cond to the same length as x
192
+ if input_concat_cond.shape[2] != x.shape[2]:
193
+ input_concat_cond = F.interpolate(input_concat_cond, (x.shape[2],), mode='nearest')
194
+
195
+ x = torch.cat([x, input_concat_cond], dim=1)
196
+
197
+ # Get the batch of timestep embeddings
198
+ try:
199
+ timestep_embed = self.to_timestep_embed(self.timestep_features(t[:, None])) # (b, embed_dim)
200
+ except Exception as e:
201
+ print("t.shape:", t.shape, "x.shape", x.shape)
202
+ print("t:", t)
203
+ raise e
204
+
205
+ # Timestep embedding is considered a global embedding. Add to the global conditioning if it exists
206
+ if global_embed is not None:
207
+ global_embed = global_embed + timestep_embed
208
+ else:
209
+ global_embed = timestep_embed
210
+
211
+ # Add the global_embed to the prepend inputs if there is no global conditioning support in the transformer
212
+ if self.global_cond_type == "prepend":
213
+ if prepend_inputs is None:
214
+ # Prepend inputs are just the global embed, and the mask is all ones
215
+ prepend_inputs = global_embed.unsqueeze(1)
216
+ prepend_mask = torch.ones((x.shape[0], 1), device=x.device, dtype=torch.bool)
217
+ else:
218
+ # Prepend inputs are the prepend conditioning + the global embed
219
+ prepend_inputs = torch.cat([prepend_inputs, global_embed.unsqueeze(1)], dim=1)
220
+ prepend_mask = torch.cat([prepend_mask, torch.ones((x.shape[0], 1), device=x.device, dtype=torch.bool)],
221
+ dim=1)
222
+
223
+ prepend_length = prepend_inputs.shape[1]
224
+
225
+ x = self.preprocess_conv(x) + x
226
+
227
+ x = rearrange(x, "b c t -> b t c")
228
+
229
+ extra_args = {}
230
+
231
+ if self.global_cond_type == "adaLN":
232
+ extra_args["global_cond"] = global_embed
233
+
234
+ if self.patch_size > 1:
235
+ x = rearrange(x, "b (t p) c -> b t (c p)", p=self.patch_size)
236
+
237
+ if self.transformer_type == "x-transformers":
238
+ output = self.transformer(x, prepend_embeds=prepend_inputs, context=cross_attn_cond,
239
+ context_mask=cross_attn_cond_mask, mask=mask, prepend_mask=prepend_mask,
240
+ **extra_args, **kwargs)
241
+ elif self.transformer_type in ["continuous_transformer","continuous_transformer_with_mask"] :
242
+ output = self.transformer(x, prepend_embeds=prepend_inputs, context=cross_attn_cond,
243
+ context_mask=cross_attn_cond_mask, mask=mask, prepend_mask=prepend_mask,
244
+ return_info=return_info, **extra_args, **kwargs)
245
+
246
+ if return_info:
247
+ output, info = output
248
+ elif self.transformer_type == "mm_transformer":
249
+ output = self.transformer(x, context=cross_attn_cond, mask=mask, context_mask=cross_attn_cond_mask,
250
+ **extra_args, **kwargs)
251
+
252
+ output = rearrange(output, "b t c -> b c t")[:, :, prepend_length:]
253
+
254
+ if self.patch_size > 1:
255
+ output = rearrange(output, "b (c p) t -> b c (t p)", p=self.patch_size)
256
+
257
+ output = self.postprocess_conv(output) + output
258
+
259
+ if return_info:
260
+ return output, info
261
+
262
+ return output
263
+
264
+ def forward(
265
+ self,
266
+ x,
267
+ t,
268
+ cross_attn_cond=None,
269
+ cross_attn_cond_mask=None,
270
+ negative_cross_attn_cond=None,
271
+ negative_cross_attn_mask=None,
272
+ input_concat_cond=None,
273
+ global_embed=None,
274
+ negative_global_embed=None,
275
+ prepend_cond=None,
276
+ prepend_cond_mask=None,
277
+ cfg_scale=1.0,
278
+ cfg_dropout_prob=0.0,
279
+ causal=False,
280
+ scale_phi=0.0,
281
+ mask=None,
282
+ return_info=False,
283
+ **kwargs):
284
+
285
+ assert causal == False, "Causal mode is not supported for DiffusionTransformer"
286
+
287
+ if cross_attn_cond_mask is not None:
288
+ cross_attn_cond_mask = cross_attn_cond_mask.bool()
289
+
290
+ cross_attn_cond_mask = None # Temporarily disabling conditioning masks due to kernel issue for flash attention
291
+
292
+ if prepend_cond_mask is not None:
293
+ prepend_cond_mask = prepend_cond_mask.bool()
294
+
295
+ # CFG dropout
296
+ if cfg_dropout_prob > 0.0:
297
+ if cross_attn_cond is not None:
298
+ null_embed = torch.zeros_like(cross_attn_cond, device=cross_attn_cond.device)
299
+ dropout_mask = torch.bernoulli(
300
+ torch.full((cross_attn_cond.shape[0], 1, 1), cfg_dropout_prob, device=cross_attn_cond.device)).to(
301
+ torch.bool)
302
+ cross_attn_cond = torch.where(dropout_mask, null_embed, cross_attn_cond)
303
+
304
+ if prepend_cond is not None:
305
+ null_embed = torch.zeros_like(prepend_cond, device=prepend_cond.device)
306
+ dropout_mask = torch.bernoulli(
307
+ torch.full((prepend_cond.shape[0], 1, 1), cfg_dropout_prob, device=prepend_cond.device)).to(
308
+ torch.bool)
309
+ prepend_cond = torch.where(dropout_mask, null_embed, prepend_cond)
310
+
311
+ if cfg_scale != 1.0 and (cross_attn_cond is not None or prepend_cond is not None):
312
+ # Classifier-free guidance
313
+ # Concatenate conditioned and unconditioned inputs on the batch dimension
314
+ batch_inputs = torch.cat([x, x], dim=0)
315
+ batch_timestep = torch.cat([t, t], dim=0)
316
+
317
+ if global_embed is not None:
318
+ batch_global_cond = torch.cat([global_embed, global_embed], dim=0)
319
+ else:
320
+ batch_global_cond = None
321
+
322
+ if input_concat_cond is not None:
323
+ batch_input_concat_cond = torch.cat([input_concat_cond, input_concat_cond], dim=0)
324
+ else:
325
+ batch_input_concat_cond = None
326
+
327
+ batch_cond = None
328
+ batch_cond_masks = None
329
+
330
+ # Handle CFG for cross-attention conditioning
331
+ if cross_attn_cond is not None:
332
+
333
+ null_embed = torch.zeros_like(cross_attn_cond, device=cross_attn_cond.device)
334
+
335
+ # For negative cross-attention conditioning, replace the null embed with the negative cross-attention conditioning
336
+ if negative_cross_attn_cond is not None:
337
+
338
+ # If there's a negative cross-attention mask, set the masked tokens to the null embed
339
+ if negative_cross_attn_mask is not None:
340
+ negative_cross_attn_mask = negative_cross_attn_mask.to(torch.bool).unsqueeze(2)
341
+
342
+ negative_cross_attn_cond = torch.where(negative_cross_attn_mask, negative_cross_attn_cond,
343
+ null_embed)
344
+
345
+ batch_cond = torch.cat([cross_attn_cond, negative_cross_attn_cond], dim=0)
346
+
347
+ else:
348
+ batch_cond = torch.cat([cross_attn_cond, null_embed], dim=0)
349
+
350
+ if cross_attn_cond_mask is not None:
351
+ batch_cond_masks = torch.cat([cross_attn_cond_mask, cross_attn_cond_mask], dim=0)
352
+
353
+ batch_prepend_cond = None
354
+ batch_prepend_cond_mask = None
355
+
356
+ if prepend_cond is not None:
357
+
358
+ null_embed = torch.zeros_like(prepend_cond, device=prepend_cond.device)
359
+
360
+ batch_prepend_cond = torch.cat([prepend_cond, null_embed], dim=0)
361
+
362
+ if prepend_cond_mask is not None:
363
+ batch_prepend_cond_mask = torch.cat([prepend_cond_mask, prepend_cond_mask], dim=0)
364
+
365
+ if mask is not None:
366
+ batch_masks = torch.cat([mask, mask], dim=0)
367
+ else:
368
+ batch_masks = None
369
+
370
+ batch_output = self._forward(
371
+ batch_inputs,
372
+ batch_timestep,
373
+ cross_attn_cond=batch_cond,
374
+ cross_attn_cond_mask=batch_cond_masks,
375
+ mask=batch_masks,
376
+ input_concat_cond=batch_input_concat_cond,
377
+ global_embed=batch_global_cond,
378
+ prepend_cond=batch_prepend_cond,
379
+ prepend_cond_mask=batch_prepend_cond_mask,
380
+ return_info=return_info,
381
+ **kwargs)
382
+
383
+ if return_info:
384
+ batch_output, info = batch_output
385
+
386
+ cond_output, uncond_output = torch.chunk(batch_output, 2, dim=0)
387
+ cfg_output = uncond_output + (cond_output - uncond_output) * cfg_scale
388
+
389
+ # CFG Rescale
390
+ if scale_phi != 0.0:
391
+ cond_out_std = cond_output.std(dim=1, keepdim=True)
392
+ out_cfg_std = cfg_output.std(dim=1, keepdim=True)
393
+ output = scale_phi * (cfg_output * (cond_out_std / out_cfg_std)) + (1 - scale_phi) * cfg_output
394
+ else:
395
+ output = cfg_output
396
+
397
+ if return_info:
398
+ return output, info
399
+
400
+ return output
401
+
402
+ else:
403
+ return self._forward(
404
+ x,
405
+ t,
406
+ cross_attn_cond=cross_attn_cond,
407
+ cross_attn_cond_mask=cross_attn_cond_mask,
408
+ input_concat_cond=input_concat_cond,
409
+ global_embed=global_embed,
410
+ prepend_cond=prepend_cond,
411
+ prepend_cond_mask=prepend_cond_mask,
412
+ mask=mask,
413
+ return_info=return_info,
414
+ **kwargs
415
+ )
cosyvoice/flow/stable/dit_v2.py ADDED
@@ -0,0 +1,307 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import typing as tp
2
+
3
+ import torch
4
+
5
+ from einops import rearrange
6
+ from torch import nn
7
+ from torch.nn import functional as F
8
+ from x_transformers import ContinuousTransformerWrapper, Encoder
9
+
10
+ from .blocks import FourierFeatures
11
+ from .transformer import ContinuousTransformer
12
+ from model.stable import transformer_use_mask
13
+
14
+
15
+ class DiffusionTransformerV2(nn.Module):
16
+ def __init__(self,
17
+ io_channels=32,
18
+ patch_size=1,
19
+ embed_dim=768,
20
+ cond_token_dim=0,
21
+ project_cond_tokens=True,
22
+ global_cond_dim=0,
23
+ project_global_cond=True,
24
+ input_concat_dim=0,
25
+ prepend_cond_dim=0,
26
+ depth=12,
27
+ num_heads=8,
28
+ transformer_type: tp.Literal["x-transformers", "continuous_transformer"] = "x-transformers",
29
+ global_cond_type: tp.Literal["prepend", "adaLN"] = "prepend",
30
+ **kwargs):
31
+
32
+ super().__init__()
33
+ d_model = embed_dim
34
+ n_head = num_heads
35
+ n_layers = depth
36
+ encoder_layer = torch.nn.TransformerEncoderLayer(batch_first=True,
37
+ norm_first=True,
38
+ d_model=d_model,
39
+ nhead=n_head)
40
+ self.transformer = torch.nn.TransformerEncoder(encoder_layer, num_layers=n_layers)
41
+
42
+ # ===================================== timestep embedding
43
+ timestep_features_dim = 256
44
+ self.timestep_features = FourierFeatures(1, timestep_features_dim)
45
+ self.to_timestep_embed = nn.Sequential(
46
+ nn.Linear(timestep_features_dim, embed_dim, bias=True),
47
+ nn.SiLU(),
48
+ nn.Linear(embed_dim, embed_dim, bias=True),
49
+ )
50
+
51
+
52
+ def _forward(
53
+ self,
54
+ Xt_btd,
55
+ t, #(1d)
56
+ mu_btd,
57
+ ):
58
+
59
+ timestep_embed = self.to_timestep_embed(self.timestep_features(t[:, None])) # (b, embed_dim)
60
+ cated_input = torch.cat([t,mu,x_t])
61
+
62
+ ### 1. 需要重新写过以适应不同长度的con
63
+ if cross_attn_cond is not None:
64
+ cross_attn_cond = self.to_cond_embed(cross_attn_cond)
65
+
66
+ if global_embed is not None:
67
+ # Project the global conditioning to the embedding dimension
68
+ global_embed = self.to_global_embed(global_embed)
69
+
70
+ prepend_inputs = None
71
+ prepend_mask = None
72
+ prepend_length = 0
73
+ if prepend_cond is not None:
74
+ # Project the prepend conditioning to the embedding dimension
75
+ prepend_cond = self.to_prepend_embed(prepend_cond)
76
+
77
+ prepend_inputs = prepend_cond
78
+ if prepend_cond_mask is not None:
79
+ prepend_mask = prepend_cond_mask
80
+
81
+ if input_concat_cond is not None:
82
+
83
+ # Interpolate input_concat_cond to the same length as x
84
+ if input_concat_cond.shape[2] != x.shape[2]:
85
+ input_concat_cond = F.interpolate(input_concat_cond, (x.shape[2],), mode='nearest')
86
+
87
+ x = torch.cat([x, input_concat_cond], dim=1)
88
+
89
+ # Get the batch of timestep embeddings
90
+ try:
91
+ timestep_embed = self.to_timestep_embed(self.timestep_features(t[:, None])) # (b, embed_dim)
92
+ except Exception as e:
93
+ print("t.shape:", t.shape, "x.shape", x.shape)
94
+ print("t:", t)
95
+ raise e
96
+
97
+ # Timestep embedding is considered a global embedding. Add to the global conditioning if it exists
98
+ if global_embed is not None:
99
+ global_embed = global_embed + timestep_embed
100
+ else:
101
+ global_embed = timestep_embed
102
+
103
+ # Add the global_embed to the prepend inputs if there is no global conditioning support in the transformer
104
+ if self.global_cond_type == "prepend":
105
+ if prepend_inputs is None:
106
+ # Prepend inputs are just the global embed, and the mask is all ones
107
+ prepend_inputs = global_embed.unsqueeze(1)
108
+ prepend_mask = torch.ones((x.shape[0], 1), device=x.device, dtype=torch.bool)
109
+ else:
110
+ # Prepend inputs are the prepend conditioning + the global embed
111
+ prepend_inputs = torch.cat([prepend_inputs, global_embed.unsqueeze(1)], dim=1)
112
+ prepend_mask = torch.cat([prepend_mask, torch.ones((x.shape[0], 1), device=x.device, dtype=torch.bool)],
113
+ dim=1)
114
+
115
+ prepend_length = prepend_inputs.shape[1]
116
+
117
+ x = self.preprocess_conv(x) + x
118
+
119
+ x = rearrange(x, "b c t -> b t c")
120
+
121
+ extra_args = {}
122
+
123
+ if self.global_cond_type == "adaLN":
124
+ extra_args["global_cond"] = global_embed
125
+
126
+ if self.patch_size > 1:
127
+ x = rearrange(x, "b (t p) c -> b t (c p)", p=self.patch_size)
128
+
129
+ if self.transformer_type == "x-transformers":
130
+ output = self.transformer(x, prepend_embeds=prepend_inputs, context=cross_attn_cond,
131
+ context_mask=cross_attn_cond_mask, mask=mask, prepend_mask=prepend_mask,
132
+ **extra_args, **kwargs)
133
+ elif self.transformer_type in ["continuous_transformer", "continuous_transformer_with_mask"]:
134
+ output = self.transformer(x, prepend_embeds=prepend_inputs, context=cross_attn_cond,
135
+ context_mask=cross_attn_cond_mask, mask=mask, prepend_mask=prepend_mask,
136
+ return_info=return_info, **extra_args, **kwargs)
137
+
138
+ if return_info:
139
+ output, info = output
140
+ elif self.transformer_type == "mm_transformer":
141
+ output = self.transformer(x, context=cross_attn_cond, mask=mask, context_mask=cross_attn_cond_mask,
142
+ **extra_args, **kwargs)
143
+
144
+ output = rearrange(output, "b t c -> b c t")[:, :, prepend_length:]
145
+
146
+ if self.patch_size > 1:
147
+ output = rearrange(output, "b (c p) t -> b c (t p)", p=self.patch_size)
148
+
149
+ output = self.postprocess_conv(output) + output
150
+
151
+ if return_info:
152
+ return output, info
153
+
154
+ return output
155
+
156
+ def forward(
157
+ self,
158
+ x,
159
+ t,
160
+ cross_attn_cond=None,
161
+ cross_attn_cond_mask=None,
162
+ negative_cross_attn_cond=None,
163
+ negative_cross_attn_mask=None,
164
+ input_concat_cond=None,
165
+ global_embed=None,
166
+ negative_global_embed=None,
167
+ prepend_cond=None,
168
+ prepend_cond_mask=None,
169
+ cfg_scale=1.0,
170
+ cfg_dropout_prob=0.0,
171
+ causal=False,
172
+ scale_phi=0.0,
173
+ mask=None,
174
+ return_info=False,
175
+ **kwargs):
176
+
177
+ assert causal == False, "Causal mode is not supported for DiffusionTransformer"
178
+
179
+ if cross_attn_cond_mask is not None:
180
+ cross_attn_cond_mask = cross_attn_cond_mask.bool()
181
+
182
+ cross_attn_cond_mask = None # Temporarily disabling conditioning masks due to kernel issue for flash attention
183
+
184
+ if prepend_cond_mask is not None:
185
+ prepend_cond_mask = prepend_cond_mask.bool()
186
+
187
+ # CFG dropout
188
+ if cfg_dropout_prob > 0.0:
189
+ if cross_attn_cond is not None:
190
+ null_embed = torch.zeros_like(cross_attn_cond, device=cross_attn_cond.device)
191
+ dropout_mask = torch.bernoulli(
192
+ torch.full((cross_attn_cond.shape[0], 1, 1), cfg_dropout_prob, device=cross_attn_cond.device)).to(
193
+ torch.bool)
194
+ cross_attn_cond = torch.where(dropout_mask, null_embed, cross_attn_cond)
195
+
196
+ if prepend_cond is not None:
197
+ null_embed = torch.zeros_like(prepend_cond, device=prepend_cond.device)
198
+ dropout_mask = torch.bernoulli(
199
+ torch.full((prepend_cond.shape[0], 1, 1), cfg_dropout_prob, device=prepend_cond.device)).to(
200
+ torch.bool)
201
+ prepend_cond = torch.where(dropout_mask, null_embed, prepend_cond)
202
+
203
+ if cfg_scale != 1.0 and (cross_attn_cond is not None or prepend_cond is not None):
204
+ # Classifier-free guidance
205
+ # Concatenate conditioned and unconditioned inputs on the batch dimension
206
+ batch_inputs = torch.cat([x, x], dim=0)
207
+ batch_timestep = torch.cat([t, t], dim=0)
208
+
209
+ if global_embed is not None:
210
+ batch_global_cond = torch.cat([global_embed, global_embed], dim=0)
211
+ else:
212
+ batch_global_cond = None
213
+
214
+ if input_concat_cond is not None:
215
+ batch_input_concat_cond = torch.cat([input_concat_cond, input_concat_cond], dim=0)
216
+ else:
217
+ batch_input_concat_cond = None
218
+
219
+ batch_cond = None
220
+ batch_cond_masks = None
221
+
222
+ # Handle CFG for cross-attention conditioning
223
+ if cross_attn_cond is not None:
224
+
225
+ null_embed = torch.zeros_like(cross_attn_cond, device=cross_attn_cond.device)
226
+
227
+ # For negative cross-attention conditioning, replace the null embed with the negative cross-attention conditioning
228
+ if negative_cross_attn_cond is not None:
229
+
230
+ # If there's a negative cross-attention mask, set the masked tokens to the null embed
231
+ if negative_cross_attn_mask is not None:
232
+ negative_cross_attn_mask = negative_cross_attn_mask.to(torch.bool).unsqueeze(2)
233
+
234
+ negative_cross_attn_cond = torch.where(negative_cross_attn_mask, negative_cross_attn_cond,
235
+ null_embed)
236
+
237
+ batch_cond = torch.cat([cross_attn_cond, negative_cross_attn_cond], dim=0)
238
+
239
+ else:
240
+ batch_cond = torch.cat([cross_attn_cond, null_embed], dim=0)
241
+
242
+ if cross_attn_cond_mask is not None:
243
+ batch_cond_masks = torch.cat([cross_attn_cond_mask, cross_attn_cond_mask], dim=0)
244
+
245
+ batch_prepend_cond = None
246
+ batch_prepend_cond_mask = None
247
+
248
+ if prepend_cond is not None:
249
+
250
+ null_embed = torch.zeros_like(prepend_cond, device=prepend_cond.device)
251
+
252
+ batch_prepend_cond = torch.cat([prepend_cond, null_embed], dim=0)
253
+
254
+ if prepend_cond_mask is not None:
255
+ batch_prepend_cond_mask = torch.cat([prepend_cond_mask, prepend_cond_mask], dim=0)
256
+
257
+ if mask is not None:
258
+ batch_masks = torch.cat([mask, mask], dim=0)
259
+ else:
260
+ batch_masks = None
261
+
262
+ batch_output = self._forward(
263
+ batch_inputs,
264
+ batch_timestep,
265
+ cross_attn_cond=batch_cond,
266
+ cross_attn_cond_mask=batch_cond_masks,
267
+ mask=batch_masks,
268
+ input_concat_cond=batch_input_concat_cond,
269
+ global_embed=batch_global_cond,
270
+ prepend_cond=batch_prepend_cond,
271
+ prepend_cond_mask=batch_prepend_cond_mask,
272
+ return_info=return_info,
273
+ **kwargs)
274
+
275
+ if return_info:
276
+ batch_output, info = batch_output
277
+
278
+ cond_output, uncond_output = torch.chunk(batch_output, 2, dim=0)
279
+ cfg_output = uncond_output + (cond_output - uncond_output) * cfg_scale
280
+
281
+ # CFG Rescale
282
+ if scale_phi != 0.0:
283
+ cond_out_std = cond_output.std(dim=1, keepdim=True)
284
+ out_cfg_std = cfg_output.std(dim=1, keepdim=True)
285
+ output = scale_phi * (cfg_output * (cond_out_std / out_cfg_std)) + (1 - scale_phi) * cfg_output
286
+ else:
287
+ output = cfg_output
288
+
289
+ if return_info:
290
+ return output, info
291
+
292
+ return output
293
+
294
+ else:
295
+ return self._forward(
296
+ x,
297
+ t,
298
+ cross_attn_cond=cross_attn_cond,
299
+ cross_attn_cond_mask=cross_attn_cond_mask,
300
+ input_concat_cond=input_concat_cond,
301
+ global_embed=global_embed,
302
+ prepend_cond=prepend_cond,
303
+ prepend_cond_mask=prepend_cond_mask,
304
+ mask=mask,
305
+ return_info=return_info,
306
+ **kwargs
307
+ )
cosyvoice/flow/stable/sampling.py ADDED
@@ -0,0 +1,232 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import math
3
+ from tqdm import trange, tqdm
4
+
5
+ import k_diffusion as K
6
+
7
+ # Define the noise schedule and sampling loop
8
+ def get_alphas_sigmas(t):
9
+ """Returns the scaling factors for the clean image (alpha) and for the
10
+ noise (sigma), given a timestep."""
11
+ return torch.cos(t * math.pi / 2), torch.sin(t * math.pi / 2)
12
+
13
+ def alpha_sigma_to_t(alpha, sigma):
14
+ """Returns a timestep, given the scaling factors for the clean image and for
15
+ the noise."""
16
+ return torch.atan2(sigma, alpha) / math.pi * 2
17
+
18
+ def t_to_alpha_sigma(t):
19
+ """Returns the scaling factors for the clean image and for the noise, given
20
+ a timestep."""
21
+ return torch.cos(t * math.pi / 2), torch.sin(t * math.pi / 2)
22
+
23
+
24
+ @torch.no_grad()
25
+ def sample_discrete_euler(model, x, steps, sigma_max=1, **extra_args):
26
+ """Draws samples from a model given starting noise. Euler method"""
27
+
28
+ # Make tensor of ones to broadcast the single t values
29
+ ts = x.new_ones([x.shape[0]])
30
+
31
+ # Create the noise schedule
32
+ t = torch.linspace(sigma_max, 0, steps + 1)
33
+
34
+ #alphas, sigmas = 1-t, t
35
+
36
+ for t_curr, t_prev in tqdm(zip(t[:-1], t[1:])):
37
+ # Broadcast the current timestep to the correct shape
38
+ t_curr_tensor = t_curr * torch.ones(
39
+ (x.shape[0],), dtype=x.dtype, device=x.device
40
+ )
41
+ dt = t_prev - t_curr # we solve backwards in our formulation
42
+ x = x + dt * model(x, t_curr_tensor, **extra_args) #.denoise(x, denoiser, t_curr_tensor, cond, uc)
43
+
44
+ # If we are on the last timestep, output the denoised image
45
+ return x
46
+
47
+ @torch.no_grad()
48
+ def sample(model, x, steps, eta, **extra_args):
49
+ """Draws samples from a model given starting noise. v-diffusion"""
50
+ ts = x.new_ones([x.shape[0]])
51
+
52
+ # Create the noise schedule
53
+ t = torch.linspace(1, 0, steps + 1)[:-1]
54
+
55
+ alphas, sigmas = get_alphas_sigmas(t)
56
+
57
+ # The sampling loop
58
+ for i in trange(steps):
59
+
60
+ # Get the model output (v, the predicted velocity)
61
+ with torch.cuda.amp.autocast():
62
+ v = model(x, ts * t[i], **extra_args).float()
63
+
64
+ # Predict the noise and the denoised image
65
+ pred = x * alphas[i] - v * sigmas[i]
66
+ eps = x * sigmas[i] + v * alphas[i]
67
+
68
+ # If we are not on the last timestep, compute the noisy image for the
69
+ # next timestep.
70
+ if i < steps - 1:
71
+ # If eta > 0, adjust the scaling factor for the predicted noise
72
+ # downward according to the amount of additional noise to add
73
+ ddim_sigma = eta * (sigmas[i + 1]**2 / sigmas[i]**2).sqrt() * \
74
+ (1 - alphas[i]**2 / alphas[i + 1]**2).sqrt()
75
+ adjusted_sigma = (sigmas[i + 1]**2 - ddim_sigma**2).sqrt()
76
+
77
+ # Recombine the predicted noise and predicted denoised image in the
78
+ # correct proportions for the next step
79
+ x = pred * alphas[i + 1] + eps * adjusted_sigma
80
+
81
+ # Add the correct amount of fresh noise
82
+ if eta:
83
+ x += torch.randn_like(x) * ddim_sigma
84
+
85
+ # If we are on the last timestep, output the denoised image
86
+ return pred
87
+
88
+ # Soft mask inpainting is just shrinking hard (binary) mask inpainting
89
+ # Given a float-valued soft mask (values between 0 and 1), get the binary mask for this particular step
90
+ def get_bmask(i, steps, mask):
91
+ strength = (i+1)/(steps)
92
+ # convert to binary mask
93
+ bmask = torch.where(mask<=strength,1,0)
94
+ return bmask
95
+
96
+ def make_cond_model_fn(model, cond_fn):
97
+ def cond_model_fn(x, sigma, **kwargs):
98
+ with torch.enable_grad():
99
+ x = x.detach().requires_grad_()
100
+ denoised = model(x, sigma, **kwargs)
101
+ cond_grad = cond_fn(x, sigma, denoised=denoised, **kwargs).detach()
102
+ cond_denoised = denoised.detach() + cond_grad * K.utils.append_dims(sigma**2, x.ndim)
103
+ return cond_denoised
104
+ return cond_model_fn
105
+
106
+ # Uses k-diffusion from https://github.com/crowsonkb/k-diffusion
107
+ # init_data is init_audio as latents (if this is latent diffusion)
108
+ # For sampling, set both init_data and mask to None
109
+ # For variations, set init_data
110
+ # For inpainting, set both init_data & mask
111
+ def sample_k(
112
+ model_fn,
113
+ noise,
114
+ init_data=None,
115
+ mask=None,
116
+ steps=100,
117
+ sampler_type="dpmpp-2m-sde",
118
+ sigma_min=0.5,
119
+ sigma_max=50,
120
+ rho=1.0, device="cuda",
121
+ callback=None,
122
+ cond_fn=None,
123
+ **extra_args
124
+ ):
125
+
126
+ denoiser = K.external.VDenoiser(model_fn)
127
+
128
+ if cond_fn is not None:
129
+ denoiser = make_cond_model_fn(denoiser, cond_fn)
130
+
131
+ # Make the list of sigmas. Sigma values are scalars related to the amount of noise each denoising step has
132
+ sigmas = K.sampling.get_sigmas_polyexponential(steps, sigma_min, sigma_max, rho, device=device)
133
+ # Scale the initial noise by sigma
134
+ noise = noise * sigmas[0]
135
+
136
+ wrapped_callback = callback
137
+
138
+ if mask is None and init_data is not None:
139
+ # VARIATION (no inpainting)
140
+ # set the initial latent to the init_data, and noise it with initial sigma
141
+ x = init_data + noise
142
+ elif mask is not None and init_data is not None:
143
+ # INPAINTING
144
+ bmask = get_bmask(0, steps, mask)
145
+ # initial noising
146
+ input_noised = init_data + noise
147
+ # set the initial latent to a mix of init_data and noise, based on step 0's binary mask
148
+ x = input_noised * bmask + noise * (1-bmask)
149
+ # define the inpainting callback function (Note: side effects, it mutates x)
150
+ # See https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py#L596C13-L596C105
151
+ # callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
152
+ # This is called immediately after `denoised = model(x, sigmas[i] * s_in, **extra_args)`
153
+ def inpainting_callback(args):
154
+ i = args["i"]
155
+ x = args["x"]
156
+ sigma = args["sigma"]
157
+ #denoised = args["denoised"]
158
+ # noise the init_data input with this step's appropriate amount of noise
159
+ input_noised = init_data + torch.randn_like(init_data) * sigma
160
+ # shrinking hard mask
161
+ bmask = get_bmask(i, steps, mask)
162
+ # mix input_noise with x, using binary mask
163
+ new_x = input_noised * bmask + x * (1-bmask)
164
+ # mutate x
165
+ x[:,:,:] = new_x[:,:,:]
166
+ # wrap together the inpainting callback and the user-submitted callback.
167
+ if callback is None:
168
+ wrapped_callback = inpainting_callback
169
+ else:
170
+ wrapped_callback = lambda args: (inpainting_callback(args), callback(args))
171
+ else:
172
+ # SAMPLING
173
+ # set the initial latent to noise
174
+ x = noise
175
+
176
+
177
+ with torch.cuda.amp.autocast():
178
+ if sampler_type == "k-heun":
179
+ return K.sampling.sample_heun(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
180
+ elif sampler_type == "k-lms":
181
+ return K.sampling.sample_lms(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
182
+ elif sampler_type == "k-dpmpp-2s-ancestral":
183
+ return K.sampling.sample_dpmpp_2s_ancestral(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
184
+ elif sampler_type == "k-dpm-2":
185
+ return K.sampling.sample_dpm_2(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
186
+ elif sampler_type == "k-dpm-fast":
187
+ return K.sampling.sample_dpm_fast(denoiser, x, sigma_min, sigma_max, steps, disable=False, callback=wrapped_callback, extra_args=extra_args)
188
+ elif sampler_type == "k-dpm-adaptive":
189
+ return K.sampling.sample_dpm_adaptive(denoiser, x, sigma_min, sigma_max, rtol=0.01, atol=0.01, disable=False, callback=wrapped_callback, extra_args=extra_args)
190
+ elif sampler_type == "dpmpp-2m-sde":
191
+ return K.sampling.sample_dpmpp_2m_sde(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
192
+ elif sampler_type == "dpmpp-3m-sde":
193
+ return K.sampling.sample_dpmpp_3m_sde(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
194
+
195
+ # Uses discrete Euler sampling for rectified flow models
196
+ # init_data is init_audio as latents (if this is latent diffusion)
197
+ # For sampling, set both init_data and mask to None
198
+ # For variations, set init_data
199
+ # For inpainting, set both init_data & mask
200
+ def sample_rf(
201
+ model_fn,
202
+ noise,
203
+ init_data=None,
204
+ steps=100,
205
+ sigma_max=1,
206
+ device="cuda",
207
+ callback=None,
208
+ cond_fn=None,
209
+ **extra_args
210
+ ):
211
+
212
+ if sigma_max > 1:
213
+ sigma_max = 1
214
+
215
+ if cond_fn is not None:
216
+ denoiser = make_cond_model_fn(denoiser, cond_fn)
217
+
218
+ wrapped_callback = callback
219
+
220
+ if init_data is not None:
221
+ # VARIATION (no inpainting)
222
+ # Interpolate the init data and the noise for init audio
223
+ x = init_data * (1 - sigma_max) + noise * sigma_max
224
+ else:
225
+ # SAMPLING
226
+ # set the initial latent to noise
227
+ x = noise
228
+
229
+ with torch.cuda.amp.autocast():
230
+ # TODO: Add callback support
231
+ #return sample_discrete_euler(model_fn, x, steps, sigma_max, callback=wrapped_callback, **extra_args)
232
+ return sample_discrete_euler(model_fn, x, steps, sigma_max, **extra_args)
cosyvoice/flow/stable/stable_diffusion.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.nn import functional as F
3
+ from .dit import DiffusionTransformer
4
+ from .adp import UNet1d
5
+ from .sampling import sample
6
+ import math
7
+ from model.base import BaseModule
8
+ import pdb
9
+
10
+ target_length = 1536
11
+
12
+
13
+ def pad_and_create_mask(matrix, target_length):
14
+ T = matrix.shape[2]
15
+ if T > target_length:
16
+ raise ValueError("The third dimension length %s should not exceed %s" % (T, target_length))
17
+
18
+ padding_size = target_length - T
19
+
20
+ padded_matrix = F.pad(matrix, (0, padding_size), "constant", 0)
21
+
22
+ mask = torch.ones((1, target_length))
23
+ mask[:, T:] = 0 # Set the padding part to 0
24
+
25
+ return padded_matrix.to(matrix.device), mask.to(matrix.device)
26
+
27
+
28
+ class Stable_Diffusion(BaseModule):
29
+ def __init__(self, io_channels, input_concat_dim=None, embed_dim=768, depth=24, num_heads=24,
30
+ project_cond_tokens=False, transformer_type="continuous_transformer"):
31
+ super(Stable_Diffusion, self).__init__()
32
+ self.diffusion = DiffusionTransformer(
33
+ io_channels=io_channels,
34
+ input_concat_dim=input_concat_dim,
35
+ embed_dim=embed_dim,
36
+ # cond_token_dim=target_length,
37
+ depth=depth,
38
+ num_heads=num_heads,
39
+ project_cond_tokens=project_cond_tokens,
40
+ transformer_type=transformer_type,
41
+ )
42
+ # self.diffusion = UNet1d(
43
+ # in_channels=80,
44
+ # channels=256,
45
+ # resnet_groups=16,
46
+ # kernel_multiplier_downsample=2,
47
+ # multipliers=[4, 4, 4, 5, 5],
48
+ # factors=[1, 2, 2, 4], # 输入长度不一致卷积缩短
49
+ # num_blocks=[2, 2, 2, 2],
50
+ # attentions=[1, 3, 3, 3, 3],
51
+ # attention_heads=16,
52
+ # attention_multiplier=4,
53
+ # use_nearest_upsample=False,
54
+ # use_skip_scale=True,
55
+ # use_context_time=True
56
+ # )
57
+ self.rng = torch.quasirandom.SobolEngine(1, scramble=True)
58
+
59
+ @torch.no_grad()
60
+ def forward(self, mu, mask, n_timesteps):
61
+ # pdb.set_trace()
62
+ mask = mask.squeeze(1)
63
+ noise = torch.randn_like(mu).to(mu.device)
64
+ # mu_pad, mu_pad_mask = pad_and_create_mask(mu, target_length)
65
+ # extra_args = {"cross_attn_cond": mu, "cross_attn_cond_mask": mask, "mask": mask}
66
+ extra_args = {"input_concat_cond": mu, "mask": mask}
67
+ fakes = sample(self.diffusion, noise, n_timesteps, 0, **extra_args)
68
+
69
+ return fakes
70
+
71
+ def compute_loss(self, x0, mask, mu):
72
+
73
+ # pdb.set_trace()
74
+ t = self.rng.draw(x0.shape[0])[:, 0].to(x0.device)
75
+ alphas, sigmas = torch.cos(t * math.pi / 2), torch.sin(t * math.pi / 2)
76
+
77
+ alphas = alphas[:, None, None]
78
+ sigmas = sigmas[:, None, None]
79
+ noise = torch.randn_like(x0)
80
+ noised_inputs = x0 * alphas + noise * sigmas
81
+ targets = noise * alphas - x0 * sigmas
82
+ mask = mask.squeeze(1)
83
+ # mu_pad, mu_pad_mask = pad_and_create_mask(mu, target_length)
84
+ # output = self.diffusion(noised_inputs, t, cross_attn_cond=mu,
85
+ # cross_attn_cond_mask=mask, mask=mask, cfg_dropout_prob=0.1)
86
+ # pdb.set_trace()
87
+ output = self.diffusion(noised_inputs, # [bs, 80, 229]
88
+ t, # (bs,)
89
+ input_concat_cond=mu,
90
+ mask=mask, # [bs, 229]
91
+ cfg_dropout_prob=0.1)
92
+
93
+ return self.mse_loss(output, targets, mask), output
94
+
95
+ def mse_loss(self, output, targets, mask):
96
+
97
+ mse_loss = F.mse_loss(output, targets, reduction='none')
98
+
99
+ if mask.ndim == 2 and mse_loss.ndim == 3:
100
+ mask = mask.unsqueeze(1)
101
+
102
+ if mask.shape[1] != mse_loss.shape[1]:
103
+ mask = mask.repeat(1, mse_loss.shape[1], 1)
104
+
105
+ mse_loss = mse_loss * mask
106
+
107
+ mse_loss = mse_loss.mean()
108
+
109
+ return mse_loss
cosyvoice/flow/stable/stable_diffusion_test.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.nn import functional as F
3
+ from .dit import DiffusionTransformer
4
+ from .adp import UNet1d
5
+ from .sampling import sample
6
+ import math
7
+ from model.base import BaseModule
8
+ import pdb
9
+
10
+ target_length = 1536
11
+ def pad_and_create_mask(matrix, target_length):
12
+
13
+ T = matrix.shape[2]
14
+ if T > target_length:
15
+ raise ValueError("The third dimension length %s should not exceed %s"%(T, target_length))
16
+
17
+ padding_size = target_length - T
18
+
19
+ padded_matrix = F.pad(matrix, (0, padding_size), "constant", 0)
20
+
21
+ mask = torch.ones((1, target_length))
22
+ mask[:, T:] = 0 # Set the padding part to 0
23
+
24
+ return padded_matrix.to(matrix.device), mask.to(matrix.device)
25
+
26
+
27
+ class Stable_Diffusion(BaseModule):
28
+ def __init__(self):
29
+ super(Stable_Diffusion, self).__init__()
30
+ self.diffusion = DiffusionTransformer(
31
+ io_channels=80,
32
+ # input_concat_dim=80,
33
+ embed_dim=768,
34
+ # cond_token_dim=target_length,
35
+ depth=24,
36
+ num_heads=24,
37
+ project_cond_tokens=False,
38
+ transformer_type="continuous_transformer",
39
+ )
40
+ # self.diffusion = UNet1d(
41
+ # in_channels=80,
42
+ # channels=256,
43
+ # resnet_groups=16,
44
+ # kernel_multiplier_downsample=2,
45
+ # multipliers=[4, 4, 4, 5, 5],
46
+ # factors=[1, 2, 2, 4], # 输入长度不一致卷积缩短
47
+ # num_blocks=[2, 2, 2, 2],
48
+ # attentions=[1, 3, 3, 3, 3],
49
+ # attention_heads=16,
50
+ # attention_multiplier=4,
51
+ # use_nearest_upsample=False,
52
+ # use_skip_scale=True,
53
+ # use_context_time=True
54
+ # )
55
+ self.rng = torch.quasirandom.SobolEngine(1, scramble=True)
56
+
57
+ @torch.no_grad()
58
+ def forward(self, mu, mask, n_timesteps):
59
+ # pdb.set_trace()
60
+ mask = mask.squeeze(1)
61
+ # noise = torch.randn_like(mu).to(mu.device)
62
+ # mu_pad, mu_pad_mask = pad_and_create_mask(mu, target_length)
63
+ # extra_args = {"cross_attn_cond": mu, "cross_attn_cond_mask": mask, "mask": mask}
64
+ extra_args = {"mask": mask}
65
+ fakes = sample(self.diffusion, mu, n_timesteps, 0, **extra_args)
66
+
67
+ return fakes
68
+
69
+
70
+ def compute_loss(self, x0, mask, mu):
71
+
72
+ # pdb.set_trace()
73
+ t = self.rng.draw(x0.shape[0])[:, 0].to(x0.device)
74
+ alphas, sigmas = torch.cos(t * math.pi / 2), torch.sin(t * math.pi / 2)
75
+
76
+ alphas = alphas[:, None, None]
77
+ sigmas = sigmas[:, None, None]
78
+ noise = torch.randn_like(x0)
79
+ noised_inputs = x0 * alphas + noise * sigmas
80
+ targets = mu * alphas - x0 * sigmas
81
+ mask = mask.squeeze(1)
82
+ # mu_pad, mu_pad_mask = pad_and_create_mask(mu, target_length)
83
+ # output = self.diffusion(noised_inputs, t, cross_attn_cond=mu,
84
+ # cross_attn_cond_mask=mask, mask=mask, cfg_dropout_prob=0.1)
85
+ output = self.diffusion(noised_inputs, t, mask=mask, cfg_dropout_prob=0.1)
86
+
87
+ return self.mse_loss(output, targets, mask), output
88
+
89
+
90
+ def mse_loss(self, output, targets, mask):
91
+
92
+ mse_loss = F.mse_loss(output, targets, reduction='none')
93
+
94
+ if mask.ndim == 2 and mse_loss.ndim == 3:
95
+ mask = mask.unsqueeze(1)
96
+
97
+ if mask.shape[1] != mse_loss.shape[1]:
98
+ mask = mask.repeat(1, mse_loss.shape[1], 1)
99
+
100
+ mse_loss = mse_loss[mask]
101
+
102
+ mse_loss = mse_loss.mean()
103
+
104
+ return mse_loss
cosyvoice/flow/stable/transformer.py ADDED
@@ -0,0 +1,816 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pdb
2
+ from functools import reduce, partial
3
+ from packaging import version
4
+
5
+ from einops import rearrange, repeat
6
+ from einops.layers.torch import Rearrange
7
+ import torch
8
+ import torch.nn.functional as F
9
+ from torch import nn, einsum
10
+ from torch.cuda.amp import autocast
11
+ from typing import Callable, Literal
12
+
13
+ try:
14
+ from flash_attn import flash_attn_func, flash_attn_kvpacked_func
15
+ except ImportError as e:
16
+ print(e)
17
+ print('flash_attn not installed, disabling Flash Attention')
18
+ flash_attn_kvpacked_func = None
19
+ flash_attn_func = None
20
+
21
+ try:
22
+ import natten
23
+ except ImportError:
24
+ natten = None
25
+
26
+ def checkpoint(function, *args, **kwargs):
27
+ kwargs.setdefault("use_reentrant", False)
28
+ return torch.utils.checkpoint.checkpoint(function, *args, **kwargs)
29
+
30
+
31
+ # Copied and modified from https://github.com/lucidrains/x-transformers/blob/main/x_transformers/attend.py under MIT License
32
+ # License can be found in LICENSES/LICENSE_XTRANSFORMERS.txt
33
+
34
+ def create_causal_mask(i, j, device):
35
+ return torch.ones((i, j), device = device, dtype = torch.bool).triu(j - i + 1)
36
+
37
+ def or_reduce(masks):
38
+ head, *body = masks
39
+ for rest in body:
40
+ head = head | rest
41
+ return head
42
+
43
+ # positional embeddings
44
+
45
+ class AbsolutePositionalEmbedding(nn.Module):
46
+ def __init__(self, dim, max_seq_len):
47
+ super().__init__()
48
+ self.scale = dim ** -0.5
49
+ self.max_seq_len = max_seq_len
50
+ self.emb = nn.Embedding(max_seq_len, dim)
51
+
52
+ def forward(self, x, pos = None, seq_start_pos = None):
53
+ seq_len, device = x.shape[1], x.device
54
+ assert seq_len <= self.max_seq_len, f'you are passing in a sequence length of {seq_len} but your absolute positional embedding has a max sequence length of {self.max_seq_len}'
55
+
56
+ if pos is None:
57
+ pos = torch.arange(seq_len, device = device)
58
+
59
+ if seq_start_pos is not None:
60
+ pos = (pos - seq_start_pos[..., None]).clamp(min = 0)
61
+
62
+ pos_emb = self.emb(pos)
63
+ pos_emb = pos_emb * self.scale
64
+ return pos_emb
65
+
66
+ class ScaledSinusoidalEmbedding(nn.Module):
67
+ def __init__(self, dim, theta = 10000):
68
+ super().__init__()
69
+ assert (dim % 2) == 0, 'dimension must be divisible by 2'
70
+ self.scale = nn.Parameter(torch.ones(1) * dim ** -0.5)
71
+
72
+ half_dim = dim // 2
73
+ freq_seq = torch.arange(half_dim).float() / half_dim
74
+ inv_freq = theta ** -freq_seq
75
+ self.register_buffer('inv_freq', inv_freq, persistent = False)
76
+
77
+ def forward(self, x, pos = None, seq_start_pos = None):
78
+ seq_len, device = x.shape[1], x.device
79
+
80
+ if pos is None:
81
+ pos = torch.arange(seq_len, device = device)
82
+
83
+ if seq_start_pos is not None:
84
+ pos = pos - seq_start_pos[..., None]
85
+
86
+ emb = einsum('i, j -> i j', pos, self.inv_freq)
87
+ emb = torch.cat((emb.sin(), emb.cos()), dim = -1)
88
+ return emb * self.scale
89
+
90
+ class RotaryEmbedding(nn.Module):
91
+ def __init__(
92
+ self,
93
+ dim,
94
+ use_xpos = False,
95
+ scale_base = 512,
96
+ interpolation_factor = 1.,
97
+ base = 10000,
98
+ base_rescale_factor = 1.
99
+ ):
100
+ super().__init__()
101
+ # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
102
+ # has some connection to NTK literature
103
+ # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
104
+ base *= base_rescale_factor ** (dim / (dim - 2))
105
+
106
+ inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
107
+ self.register_buffer('inv_freq', inv_freq)
108
+
109
+ assert interpolation_factor >= 1.
110
+ self.interpolation_factor = interpolation_factor
111
+
112
+ if not use_xpos:
113
+ self.register_buffer('scale', None)
114
+ return
115
+
116
+ scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
117
+
118
+ self.scale_base = scale_base
119
+ self.register_buffer('scale', scale)
120
+
121
+ def forward_from_seq_len(self, seq_len):
122
+ device = self.inv_freq.device
123
+
124
+ t = torch.arange(seq_len, device = device)
125
+ return self.forward(t)
126
+
127
+ @autocast(enabled = False)
128
+ def forward(self, t):
129
+ device = self.inv_freq.device
130
+
131
+ t = t.to(torch.float32)
132
+
133
+ t = t / self.interpolation_factor
134
+
135
+ freqs = torch.einsum('i , j -> i j', t, self.inv_freq)
136
+ freqs = torch.cat((freqs, freqs), dim = -1)
137
+
138
+ if self.scale is None:
139
+ return freqs, 1.
140
+
141
+ power = (torch.arange(seq_len, device = device) - (seq_len // 2)) / self.scale_base
142
+ scale = self.scale ** rearrange(power, 'n -> n 1')
143
+ scale = torch.cat((scale, scale), dim = -1)
144
+
145
+ return freqs, scale
146
+
147
+ def rotate_half(x):
148
+ x = rearrange(x, '... (j d) -> ... j d', j = 2)
149
+ x1, x2 = x.unbind(dim = -2)
150
+ return torch.cat((-x2, x1), dim = -1)
151
+
152
+ @autocast(enabled = False)
153
+ def apply_rotary_pos_emb(t, freqs, scale = 1):
154
+ out_dtype = t.dtype
155
+
156
+ # cast to float32 if necessary for numerical stability
157
+ dtype = reduce(torch.promote_types, (t.dtype, freqs.dtype, torch.float32))
158
+ rot_dim, seq_len = freqs.shape[-1], t.shape[-2]
159
+ freqs, t = freqs.to(dtype), t.to(dtype)
160
+ freqs = freqs[-seq_len:, :]
161
+
162
+ if t.ndim == 4 and freqs.ndim == 3:
163
+ freqs = rearrange(freqs, 'b n d -> b 1 n d')
164
+
165
+ # partial rotary embeddings, Wang et al. GPT-J
166
+ t, t_unrotated = t[..., :rot_dim], t[..., rot_dim:]
167
+ t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale)
168
+
169
+ t, t_unrotated = t.to(out_dtype), t_unrotated.to(out_dtype)
170
+
171
+ return torch.cat((t, t_unrotated), dim = -1)
172
+
173
+ # norms
174
+ class LayerNorm(nn.Module):
175
+ def __init__(self, dim, bias=False, fix_scale=False):
176
+ """
177
+ bias-less layernorm has been shown to be more stable. most newer models have moved towards rmsnorm, also bias-less
178
+ """
179
+ super().__init__()
180
+
181
+ if fix_scale:
182
+ self.register_buffer("gamma", torch.ones(dim))
183
+ else:
184
+ self.gamma = nn.Parameter(torch.ones(dim))
185
+
186
+ if bias:
187
+ self.beta = nn.Parameter(torch.zeros(dim))
188
+ else:
189
+ self.register_buffer("beta", torch.zeros(dim))
190
+
191
+
192
+ def forward(self, x):
193
+ return F.layer_norm(x, x.shape[-1:], weight=self.gamma, bias=self.beta)
194
+
195
+ # feedforward
196
+
197
+ class GLU(nn.Module):
198
+ def __init__(
199
+ self,
200
+ dim_in,
201
+ dim_out,
202
+ activation: Callable,
203
+ use_conv = False,
204
+ conv_kernel_size = 3,
205
+ ):
206
+ super().__init__()
207
+ self.act = activation
208
+ self.proj = nn.Linear(dim_in, dim_out * 2) if not use_conv else nn.Conv1d(dim_in, dim_out * 2, conv_kernel_size, padding = (conv_kernel_size // 2))
209
+ self.use_conv = use_conv
210
+
211
+ def forward(self, x):
212
+ if self.use_conv:
213
+ x = rearrange(x, 'b n d -> b d n')
214
+ x = self.proj(x)
215
+ x = rearrange(x, 'b d n -> b n d')
216
+ else:
217
+ x = self.proj(x)
218
+
219
+ x, gate = x.chunk(2, dim = -1)
220
+ return x * self.act(gate)
221
+
222
+ class FeedForward(nn.Module):
223
+ def __init__(
224
+ self,
225
+ dim,
226
+ dim_out = None,
227
+ mult = 4,
228
+ no_bias = False,
229
+ glu = True,
230
+ use_conv = False,
231
+ conv_kernel_size = 3,
232
+ zero_init_output = True,
233
+ ):
234
+ super().__init__()
235
+ inner_dim = int(dim * mult)
236
+
237
+ # Default to SwiGLU
238
+
239
+ activation = nn.SiLU()
240
+
241
+ dim_out = dim if dim_out is None else dim_out
242
+
243
+ if glu:
244
+ linear_in = GLU(dim, inner_dim, activation)
245
+ else:
246
+ linear_in = nn.Sequential(
247
+ Rearrange('b n d -> b d n') if use_conv else nn.Identity(),
248
+ nn.Linear(dim, inner_dim, bias = not no_bias) if not use_conv else nn.Conv1d(dim, inner_dim, conv_kernel_size, padding = (conv_kernel_size // 2), bias = not no_bias),
249
+ Rearrange('b n d -> b d n') if use_conv else nn.Identity(),
250
+ activation
251
+ )
252
+
253
+ linear_out = nn.Linear(inner_dim, dim_out, bias = not no_bias) if not use_conv else nn.Conv1d(inner_dim, dim_out, conv_kernel_size, padding = (conv_kernel_size // 2), bias = not no_bias)
254
+
255
+ # init last linear layer to 0
256
+ if zero_init_output:
257
+ nn.init.zeros_(linear_out.weight)
258
+ if not no_bias:
259
+ nn.init.zeros_(linear_out.bias)
260
+
261
+
262
+ self.ff = nn.Sequential(
263
+ linear_in,
264
+ Rearrange('b d n -> b n d') if use_conv else nn.Identity(),
265
+ linear_out,
266
+ Rearrange('b n d -> b d n') if use_conv else nn.Identity(),
267
+ )
268
+
269
+ def forward(self, x):
270
+ return self.ff(x)
271
+
272
+ class Attention(nn.Module):
273
+ def __init__(
274
+ self,
275
+ dim,
276
+ dim_heads = 64,
277
+ dim_context = None,
278
+ causal = False,
279
+ zero_init_output=True,
280
+ qk_norm: Literal['l2', 'ln', 'none'] = 'none',
281
+ natten_kernel_size = None
282
+ ):
283
+ super().__init__()
284
+ self.dim = dim
285
+ self.dim_heads = dim_heads
286
+ self.causal = causal
287
+
288
+ dim_kv = dim_context if dim_context is not None else dim
289
+
290
+ self.num_heads = dim // dim_heads
291
+ self.kv_heads = dim_kv // dim_heads
292
+
293
+ if dim_context is not None:
294
+ self.to_q = nn.Linear(dim, dim, bias=False)
295
+ self.to_kv = nn.Linear(dim_kv, dim_kv * 2, bias=False)
296
+ else:
297
+ self.to_qkv = nn.Linear(dim, dim * 3, bias=False)
298
+
299
+ self.to_out = nn.Linear(dim, dim, bias=False)
300
+
301
+ if zero_init_output:
302
+ nn.init.zeros_(self.to_out.weight)
303
+
304
+ self.qk_norm = qk_norm
305
+
306
+ if self.qk_norm == "ln":
307
+ self.q_norm = nn.LayerNorm(dim_heads, elementwise_affine=True, eps=1.0e-6)
308
+ self.k_norm = nn.LayerNorm(dim_heads, elementwise_affine=True, eps=1.0e-6)
309
+
310
+ # Using 1d neighborhood attention
311
+ self.natten_kernel_size = natten_kernel_size
312
+ if natten_kernel_size is not None:
313
+ return
314
+
315
+ self.use_pt_flash = torch.cuda.is_available() and version.parse(torch.__version__) >= version.parse('2.0.0')
316
+
317
+ self.use_fa_flash = torch.cuda.is_available() and flash_attn_func is not None
318
+ # pdb.set_trace()
319
+ self.use_fa_flash = False
320
+
321
+ self.sdp_kwargs = dict(
322
+ enable_flash = True,
323
+ enable_math = True,
324
+ enable_mem_efficient = True
325
+ )
326
+
327
+ def flash_attn(
328
+ self,
329
+ q,
330
+ k,
331
+ v,
332
+ mask = None,
333
+ causal = None
334
+ ):
335
+ batch, heads, q_len, _, k_len, device = *q.shape, k.shape[-2], q.device
336
+ kv_heads = k.shape[1]
337
+ # Recommended for multi-query single-key-value attention by Tri Dao
338
+ # kv shape torch.Size([1, 512, 64]) -> torch.Size([1, 8, 512, 64])
339
+
340
+ if heads != kv_heads:
341
+ # Repeat interleave kv_heads to match q_heads
342
+ heads_per_kv_head = heads // kv_heads
343
+ k, v = map(lambda t: t.repeat_interleave(heads_per_kv_head, dim = 1), (k, v))
344
+
345
+ if k.ndim == 3:
346
+ k = rearrange(k, 'b ... -> b 1 ...').expand_as(q)
347
+
348
+ if v.ndim == 3:
349
+ v = rearrange(v, 'b ... -> b 1 ...').expand_as(q)
350
+
351
+ causal = self.causal if causal is None else causal
352
+
353
+ if q_len == 1 and causal:
354
+ causal = False
355
+
356
+ if mask is not None:
357
+ assert mask.ndim == 4
358
+ mask = mask.expand(batch, heads, q_len, k_len)
359
+
360
+ # handle kv cache - this should be bypassable in updated flash attention 2
361
+
362
+ if k_len > q_len and causal:
363
+ causal_mask = self.create_causal_mask(q_len, k_len, device = device)
364
+ if mask is None:
365
+ mask = ~causal_mask
366
+ else:
367
+ mask = mask & ~causal_mask
368
+ causal = False
369
+
370
+ # manually handle causal mask, if another mask was given
371
+
372
+ row_is_entirely_masked = None
373
+
374
+ if mask is not None and causal:
375
+ causal_mask = self.create_causal_mask(q_len, k_len, device = device)
376
+ mask = mask & ~causal_mask
377
+
378
+ # protect against an entire row being masked out
379
+
380
+ row_is_entirely_masked = ~mask.any(dim = -1)
381
+ mask[..., 0] = mask[..., 0] | row_is_entirely_masked
382
+
383
+ causal = False
384
+
385
+ with torch.backends.cuda.sdp_kernel(**self.sdp_kwargs):
386
+ out = F.scaled_dot_product_attention(
387
+ q, k, v,
388
+ attn_mask = mask,
389
+ is_causal = causal
390
+ )
391
+
392
+ # for a row that is entirely masked out, should zero out the output of that row token
393
+
394
+ if row_is_entirely_masked is not None:
395
+ out = out.masked_fill(row_is_entirely_masked[..., None], 0.)
396
+
397
+ return out
398
+
399
+ def forward(
400
+ self,
401
+ x,
402
+ context = None,
403
+ mask = None,
404
+ context_mask = None,
405
+ rotary_pos_emb = None,
406
+ causal = None
407
+ ):
408
+ h, kv_h, has_context = self.num_heads, self.kv_heads, context is not None
409
+
410
+ kv_input = context if has_context else x
411
+
412
+ if hasattr(self, 'to_q'):
413
+ # Use separate linear projections for q and k/v
414
+ q = self.to_q(x)
415
+ q = rearrange(q, 'b n (h d) -> b h n d', h = h)
416
+
417
+ k, v = self.to_kv(kv_input).chunk(2, dim=-1)
418
+
419
+ k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = kv_h), (k, v))
420
+ else:
421
+ # Use fused linear projection
422
+ q, k, v = self.to_qkv(x).chunk(3, dim=-1)
423
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
424
+
425
+ # Normalize q and k for cosine sim attention
426
+ if self.qk_norm == "l2":
427
+ q = F.normalize(q, dim=-1)
428
+ k = F.normalize(k, dim=-1)
429
+ elif self.qk_norm == "ln":
430
+ q = self.q_norm(q)
431
+ k = self.k_norm(k)
432
+
433
+ if rotary_pos_emb is not None and not has_context:
434
+ freqs, _ = rotary_pos_emb
435
+
436
+ q_dtype = q.dtype
437
+ k_dtype = k.dtype
438
+
439
+ q = q.to(torch.float32)
440
+ k = k.to(torch.float32)
441
+ freqs = freqs.to(torch.float32)
442
+
443
+ q = apply_rotary_pos_emb(q, freqs)
444
+ k = apply_rotary_pos_emb(k, freqs)
445
+
446
+ q = q.to(q_dtype)
447
+ k = k.to(k_dtype)
448
+
449
+ input_mask = context_mask
450
+
451
+ if input_mask is None and not has_context:
452
+ input_mask = mask
453
+
454
+ # determine masking
455
+ masks = []
456
+ final_attn_mask = None # The mask that will be applied to the attention matrix, taking all masks into account
457
+
458
+ if input_mask is not None:
459
+ input_mask = rearrange(input_mask, 'b j -> b 1 1 j')
460
+ masks.append(~input_mask)
461
+
462
+ # Other masks will be added here later
463
+
464
+ if len(masks) > 0:
465
+ final_attn_mask = ~or_reduce(masks)
466
+
467
+ n, device = q.shape[-2], q.device
468
+
469
+ causal = self.causal if causal is None else causal
470
+
471
+ if n == 1 and causal:
472
+ causal = False
473
+
474
+ if self.natten_kernel_size is not None:
475
+ if natten is None:
476
+ raise ImportError('natten not installed, please install natten to use neighborhood attention')
477
+
478
+ dtype_in = q.dtype
479
+ q, k, v = map(lambda t: t.to(torch.float32), (q, k, v))
480
+
481
+ attn = natten.functional.natten1dqk(q, k, kernel_size = self.natten_kernel_size, dilation=1)
482
+
483
+ if final_attn_mask is not None:
484
+ attn = attn.masked_fill(final_attn_mask, -torch.finfo(attn.dtype).max)
485
+
486
+ attn = F.softmax(attn, dim=-1, dtype=torch.float32)
487
+
488
+ out = natten.functional.natten1dav(attn, v, kernel_size = self.natten_kernel_size, dilation=1).to(dtype_in)
489
+
490
+ # Prioritize Flash Attention 2
491
+ elif self.use_fa_flash:
492
+ # pdb.set_trace()
493
+ assert final_attn_mask is None, 'masking not yet supported for Flash Attention 2'
494
+ # Flash Attention 2 requires FP16 inputs
495
+ fa_dtype_in = q.dtype
496
+ q, k, v = map(lambda t: rearrange(t, 'b h n d -> b n h d').to(torch.float16), (q, k, v))
497
+
498
+ out = flash_attn_func(q, k, v, causal = causal)
499
+
500
+ out = rearrange(out.to(fa_dtype_in), 'b n h d -> b h n d')
501
+
502
+ # Fall back to PyTorch implementation
503
+ elif self.use_pt_flash:
504
+ out = self.flash_attn(q, k, v, causal = causal, mask = final_attn_mask)
505
+
506
+ else:
507
+ # Fall back to custom implementation
508
+
509
+ if h != kv_h:
510
+ # Repeat interleave kv_heads to match q_heads
511
+ heads_per_kv_head = h // kv_h
512
+ k, v = map(lambda t: t.repeat_interleave(heads_per_kv_head, dim = 1), (k, v))
513
+
514
+ scale = 1. / (q.shape[-1] ** 0.5)
515
+
516
+ kv_einsum_eq = 'b j d' if k.ndim == 3 else 'b h j d'
517
+
518
+ dots = einsum(f'b h i d, {kv_einsum_eq} -> b h i j', q, k) * scale
519
+
520
+ i, j, dtype = *dots.shape[-2:], dots.dtype
521
+
522
+ mask_value = -torch.finfo(dots.dtype).max
523
+
524
+ if final_attn_mask is not None:
525
+ dots = dots.masked_fill(~final_attn_mask, mask_value)
526
+
527
+ if causal:
528
+ causal_mask = self.create_causal_mask(i, j, device = device)
529
+ dots = dots.masked_fill(causal_mask, mask_value)
530
+
531
+ attn = F.softmax(dots, dim=-1, dtype=torch.float32)
532
+ attn = attn.type(dtype)
533
+
534
+ out = einsum(f'b h i j, {kv_einsum_eq} -> b h i d', attn, v)
535
+
536
+ # merge heads
537
+ out = rearrange(out, ' b h n d -> b n (h d)')
538
+
539
+ # Communicate between heads
540
+
541
+ # with autocast(enabled = False):
542
+ # out_dtype = out.dtype
543
+ # out = out.to(torch.float32)
544
+ # out = self.to_out(out).to(out_dtype)
545
+ out = self.to_out(out)
546
+
547
+ if mask is not None:
548
+ mask = rearrange(mask, 'b n -> b n 1')
549
+ out = out.masked_fill(~mask, 0.)
550
+
551
+ return out
552
+
553
+ class ConformerModule(nn.Module):
554
+ def __init__(
555
+ self,
556
+ dim,
557
+ norm_kwargs = {},
558
+ ):
559
+
560
+ super().__init__()
561
+
562
+ self.dim = dim
563
+
564
+ self.in_norm = LayerNorm(dim, **norm_kwargs)
565
+ self.pointwise_conv = nn.Conv1d(dim, dim, kernel_size=1, bias=False)
566
+ self.glu = GLU(dim, dim, nn.SiLU())
567
+ self.depthwise_conv = nn.Conv1d(dim, dim, kernel_size=17, groups=dim, padding=8, bias=False)
568
+ self.mid_norm = LayerNorm(dim, **norm_kwargs) # This is a batch norm in the original but I don't like batch norm
569
+ self.swish = nn.SiLU()
570
+ self.pointwise_conv_2 = nn.Conv1d(dim, dim, kernel_size=1, bias=False)
571
+
572
+ def forward(self, x):
573
+ x = self.in_norm(x)
574
+ x = rearrange(x, 'b n d -> b d n')
575
+ x = self.pointwise_conv(x)
576
+ x = rearrange(x, 'b d n -> b n d')
577
+ x = self.glu(x)
578
+ x = rearrange(x, 'b n d -> b d n')
579
+ x = self.depthwise_conv(x)
580
+ x = rearrange(x, 'b d n -> b n d')
581
+ x = self.mid_norm(x)
582
+ x = self.swish(x)
583
+ x = rearrange(x, 'b n d -> b d n')
584
+ x = self.pointwise_conv_2(x)
585
+ x = rearrange(x, 'b d n -> b n d')
586
+
587
+ return x
588
+
589
+ class TransformerBlock(nn.Module):
590
+ def __init__(
591
+ self,
592
+ dim,
593
+ dim_heads = 64,
594
+ cross_attend = False,
595
+ dim_context = None,
596
+ global_cond_dim = None,
597
+ causal = False,
598
+ zero_init_branch_outputs = True,
599
+ conformer = False,
600
+ layer_ix = -1,
601
+ remove_norms = False,
602
+ attn_kwargs = {},
603
+ ff_kwargs = {},
604
+ norm_kwargs = {}
605
+ ):
606
+
607
+ super().__init__()
608
+ self.dim = dim
609
+ self.dim_heads = dim_heads
610
+ self.cross_attend = cross_attend
611
+ self.dim_context = dim_context
612
+ self.causal = causal
613
+
614
+ self.pre_norm = LayerNorm(dim, **norm_kwargs) if not remove_norms else nn.Identity()
615
+
616
+ self.self_attn = Attention(
617
+ dim,
618
+ dim_heads = dim_heads,
619
+ causal = causal,
620
+ zero_init_output=zero_init_branch_outputs,
621
+ **attn_kwargs
622
+ )
623
+ ### 2. 主要是这边需要修改
624
+ if cross_attend:
625
+ self.cross_attend_norm = LayerNorm(dim, **norm_kwargs) if not remove_norms else nn.Identity()
626
+ self.cross_attn = Attention(
627
+ dim,
628
+ dim_heads = dim_heads,
629
+ dim_context=dim_context,
630
+ causal = causal,
631
+ zero_init_output=zero_init_branch_outputs,
632
+ **attn_kwargs
633
+ )
634
+
635
+ self.ff_norm = LayerNorm(dim, **norm_kwargs) if not remove_norms else nn.Identity()
636
+ self.ff = FeedForward(dim, zero_init_output=zero_init_branch_outputs, **ff_kwargs)
637
+
638
+ self.layer_ix = layer_ix
639
+
640
+ self.conformer = ConformerModule(dim, norm_kwargs=norm_kwargs) if conformer else None
641
+
642
+ self.global_cond_dim = global_cond_dim
643
+
644
+ if global_cond_dim is not None:
645
+ self.to_scale_shift_gate = nn.Sequential(
646
+ nn.SiLU(),
647
+ nn.Linear(global_cond_dim, dim * 6, bias=False)
648
+ )
649
+
650
+ nn.init.zeros_(self.to_scale_shift_gate[1].weight)
651
+ #nn.init.zeros_(self.to_scale_shift_gate_self[1].bias)
652
+
653
+ def forward(
654
+ self,
655
+ x,
656
+ context = None,
657
+ global_cond=None,
658
+ mask = None,
659
+ context_mask = None,
660
+ rotary_pos_emb = None
661
+ ):
662
+ if self.global_cond_dim is not None and self.global_cond_dim > 0 and global_cond is not None:
663
+
664
+ scale_self, shift_self, gate_self, scale_ff, shift_ff, gate_ff = self.to_scale_shift_gate(global_cond).unsqueeze(1).chunk(6, dim = -1)
665
+
666
+ # self-attention with adaLN
667
+ residual = x
668
+ x = self.pre_norm(x)
669
+ x = x * (1 + scale_self) + shift_self
670
+ x = self.self_attn(x, mask = mask, rotary_pos_emb = rotary_pos_emb)
671
+ x = x * torch.sigmoid(1 - gate_self)
672
+ x = x + residual
673
+
674
+ if context is not None:
675
+ x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask)
676
+
677
+ if self.conformer is not None:
678
+ x = x + self.conformer(x)
679
+
680
+ # feedforward with adaLN
681
+ residual = x
682
+ x = self.ff_norm(x)
683
+ x = x * (1 + scale_ff) + shift_ff
684
+ x = self.ff(x)
685
+ x = x * torch.sigmoid(1 - gate_ff)
686
+ x = x + residual
687
+
688
+ else:
689
+ x = x + self.self_attn(self.pre_norm(x), mask = mask, rotary_pos_emb = rotary_pos_emb)
690
+
691
+ if context is not None:
692
+ x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask)
693
+
694
+ if self.conformer is not None:
695
+ x = x + self.conformer(x)
696
+
697
+ x = x + self.ff(self.ff_norm(x))
698
+
699
+ return x
700
+
701
+ class ContinuousTransformer(nn.Module):
702
+ def __init__(
703
+ self,
704
+ dim,
705
+ depth,
706
+ *,
707
+ dim_in = None,
708
+ dim_out = None,
709
+ dim_heads = 64,
710
+ cross_attend=False,
711
+ cond_token_dim=None,
712
+ global_cond_dim=None,
713
+ causal=False,
714
+ rotary_pos_emb=True,
715
+ zero_init_branch_outputs=True,
716
+ conformer=False,
717
+ use_sinusoidal_emb=False,
718
+ use_abs_pos_emb=False,
719
+ abs_pos_emb_max_length=10000,
720
+ **kwargs
721
+ ):
722
+
723
+ super().__init__()
724
+
725
+ self.dim = dim
726
+ self.depth = depth
727
+ self.causal = causal
728
+ self.layers = nn.ModuleList([])
729
+
730
+ self.project_in = nn.Linear(dim_in, dim, bias=False) if dim_in is not None else nn.Identity()
731
+ self.project_out = nn.Linear(dim, dim_out, bias=False) if dim_out is not None else nn.Identity()
732
+
733
+ if rotary_pos_emb:
734
+ self.rotary_pos_emb = RotaryEmbedding(max(dim_heads // 2, 32))
735
+ else:
736
+ self.rotary_pos_emb = None
737
+
738
+ self.use_sinusoidal_emb = use_sinusoidal_emb
739
+ if use_sinusoidal_emb:
740
+ self.pos_emb = ScaledSinusoidalEmbedding(dim)
741
+
742
+ self.use_abs_pos_emb = use_abs_pos_emb
743
+ if use_abs_pos_emb:
744
+ self.pos_emb = AbsolutePositionalEmbedding(dim, abs_pos_emb_max_length)
745
+
746
+ for i in range(depth):
747
+ self.layers.append(
748
+ TransformerBlock(
749
+ dim,
750
+ dim_heads = dim_heads,
751
+ cross_attend = cross_attend,
752
+ dim_context = cond_token_dim,
753
+ global_cond_dim = global_cond_dim,
754
+ causal = causal,
755
+ zero_init_branch_outputs = zero_init_branch_outputs,
756
+ conformer=conformer,
757
+ layer_ix=i,
758
+ **kwargs
759
+ )
760
+ )
761
+
762
+ def forward(
763
+ self,
764
+ x,
765
+ mask = None,
766
+ prepend_embeds = None,
767
+ prepend_mask = None,
768
+ global_cond = None,
769
+ return_info = False,
770
+ **kwargs
771
+ ):
772
+ batch, seq, device = *x.shape[:2], x.device
773
+
774
+ info = {
775
+ "hidden_states": [],
776
+ }
777
+
778
+ x = self.project_in(x)
779
+ if prepend_embeds is not None:
780
+ prepend_length, prepend_dim = prepend_embeds.shape[1:]
781
+
782
+ assert prepend_dim == x.shape[-1], 'prepend dimension must match sequence dimension'
783
+
784
+ x = torch.cat((prepend_embeds, x), dim = -2)
785
+
786
+ if prepend_mask is not None or mask is not None:
787
+ mask = mask if mask is not None else torch.ones((batch, seq), device = device, dtype = torch.bool)
788
+ prepend_mask = prepend_mask if prepend_mask is not None else torch.ones((batch, prepend_length), device = device, dtype = torch.bool)
789
+
790
+ mask = torch.cat((prepend_mask, mask), dim = -1)
791
+
792
+ # Attention layers
793
+
794
+ if self.rotary_pos_emb is not None:
795
+ rotary_pos_emb = self.rotary_pos_emb.forward_from_seq_len(x.shape[1])
796
+ else:
797
+ rotary_pos_emb = None
798
+
799
+ if self.use_sinusoidal_emb or self.use_abs_pos_emb:
800
+ x = x + self.pos_emb(x)
801
+
802
+ # Iterate over the transformer layers
803
+ for layer in self.layers:
804
+ #x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, **kwargs)
805
+ # pdb.set_trace()
806
+ x = checkpoint(layer, x, mask=mask.bool(),rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, **kwargs)
807
+
808
+ if return_info:
809
+ info["hidden_states"].append(x)
810
+
811
+ x = self.project_out(x)
812
+
813
+ if return_info:
814
+ return x, info
815
+
816
+ return x
cosyvoice/flow/stable/transformer_use_mask.py ADDED
@@ -0,0 +1,845 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pdb
2
+ from functools import reduce, partial
3
+ from packaging import version
4
+
5
+ from einops import rearrange, repeat
6
+ from einops.layers.torch import Rearrange
7
+ import torch
8
+ import torch.nn.functional as F
9
+ from torch import nn, einsum
10
+ from torch.cuda.amp import autocast
11
+ from typing import Callable, Literal
12
+
13
+ try:
14
+ from flash_attn import flash_attn_func, flash_attn_kvpacked_func
15
+ except ImportError as e:
16
+ print(e)
17
+ print('flash_attn not installed, disabling Flash Attention')
18
+ flash_attn_kvpacked_func = None
19
+ flash_attn_func = None
20
+
21
+ try:
22
+ import natten
23
+ except ImportError:
24
+ natten = None
25
+
26
+
27
+ def checkpoint(function, *args, **kwargs):
28
+ kwargs.setdefault("use_reentrant", False)
29
+ return torch.utils.checkpoint.checkpoint(function, *args, **kwargs)
30
+
31
+
32
+ # Copied and modified from https://github.com/lucidrains/x-transformers/blob/main/x_transformers/attend.py under MIT License
33
+ # License can be found in LICENSES/LICENSE_XTRANSFORMERS.txt
34
+
35
+ def create_causal_mask(i, j, device):
36
+ return torch.ones((i, j), device=device, dtype=torch.bool).triu(j - i + 1)
37
+
38
+
39
+ def or_reduce(masks):
40
+ head, *body = masks
41
+ for rest in body:
42
+ head = head | rest
43
+ return head
44
+
45
+
46
+ # positional embeddings
47
+
48
+ class AbsolutePositionalEmbedding(nn.Module):
49
+ def __init__(self, dim, max_seq_len):
50
+ super().__init__()
51
+ self.scale = dim ** -0.5
52
+ self.max_seq_len = max_seq_len
53
+ self.emb = nn.Embedding(max_seq_len, dim)
54
+
55
+ def forward(self, x, pos=None, seq_start_pos=None):
56
+ seq_len, device = x.shape[1], x.device
57
+ assert seq_len <= self.max_seq_len, f'you are passing in a sequence length of {seq_len} but your absolute positional embedding has a max sequence length of {self.max_seq_len}'
58
+
59
+ if pos is None:
60
+ pos = torch.arange(seq_len, device=device)
61
+
62
+ if seq_start_pos is not None:
63
+ pos = (pos - seq_start_pos[..., None]).clamp(min=0)
64
+
65
+ pos_emb = self.emb(pos)
66
+ pos_emb = pos_emb * self.scale
67
+ return pos_emb
68
+
69
+
70
+ class ScaledSinusoidalEmbedding(nn.Module):
71
+ def __init__(self, dim, theta=10000):
72
+ super().__init__()
73
+ assert (dim % 2) == 0, 'dimension must be divisible by 2'
74
+ self.scale = nn.Parameter(torch.ones(1) * dim ** -0.5)
75
+
76
+ half_dim = dim // 2
77
+ freq_seq = torch.arange(half_dim).float() / half_dim
78
+ inv_freq = theta ** -freq_seq
79
+ self.register_buffer('inv_freq', inv_freq, persistent=False)
80
+
81
+ def forward(self, x, pos=None, seq_start_pos=None):
82
+ seq_len, device = x.shape[1], x.device
83
+
84
+ if pos is None:
85
+ pos = torch.arange(seq_len, device=device)
86
+
87
+ if seq_start_pos is not None:
88
+ pos = pos - seq_start_pos[..., None]
89
+
90
+ emb = einsum('i, j -> i j', pos, self.inv_freq)
91
+ emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
92
+ return emb * self.scale
93
+
94
+
95
+ class RotaryEmbedding(nn.Module):
96
+ def __init__(
97
+ self,
98
+ dim,
99
+ use_xpos=False,
100
+ scale_base=512,
101
+ interpolation_factor=1.,
102
+ base=10000,
103
+ base_rescale_factor=1.
104
+ ):
105
+ super().__init__()
106
+ # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
107
+ # has some connection to NTK literature
108
+ # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
109
+ base *= base_rescale_factor ** (dim / (dim - 2))
110
+
111
+ inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
112
+ self.register_buffer('inv_freq', inv_freq)
113
+
114
+ assert interpolation_factor >= 1.
115
+ self.interpolation_factor = interpolation_factor
116
+
117
+ if not use_xpos:
118
+ self.register_buffer('scale', None)
119
+ return
120
+
121
+ scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
122
+
123
+ self.scale_base = scale_base
124
+ self.register_buffer('scale', scale)
125
+
126
+ def forward_from_seq_len(self, seq_len):
127
+ device = self.inv_freq.device
128
+
129
+ t = torch.arange(seq_len, device=device)
130
+ return self.forward(t)
131
+
132
+ @autocast(enabled=False)
133
+ def forward(self, t):
134
+ device = self.inv_freq.device
135
+
136
+ t = t.to(torch.float32)
137
+
138
+ t = t / self.interpolation_factor
139
+
140
+ freqs = torch.einsum('i , j -> i j', t, self.inv_freq)
141
+ freqs = torch.cat((freqs, freqs), dim=-1)
142
+
143
+ if self.scale is None:
144
+ return freqs, 1.
145
+
146
+ power = (torch.arange(seq_len, device=device) - (seq_len // 2)) / self.scale_base
147
+ scale = self.scale ** rearrange(power, 'n -> n 1')
148
+ scale = torch.cat((scale, scale), dim=-1)
149
+
150
+ return freqs, scale
151
+
152
+
153
+ def rotate_half(x):
154
+ x = rearrange(x, '... (j d) -> ... j d', j=2)
155
+ x1, x2 = x.unbind(dim=-2)
156
+ return torch.cat((-x2, x1), dim=-1)
157
+
158
+
159
+ @autocast(enabled=False)
160
+ def apply_rotary_pos_emb(t, freqs, scale=1):
161
+ out_dtype = t.dtype
162
+
163
+ # cast to float32 if necessary for numerical stability
164
+ dtype = reduce(torch.promote_types, (t.dtype, freqs.dtype, torch.float32))
165
+ rot_dim, seq_len = freqs.shape[-1], t.shape[-2]
166
+ freqs, t = freqs.to(dtype), t.to(dtype)
167
+ freqs = freqs[-seq_len:, :]
168
+
169
+ if t.ndim == 4 and freqs.ndim == 3:
170
+ freqs = rearrange(freqs, 'b n d -> b 1 n d')
171
+
172
+ # partial rotary embeddings, Wang et al. GPT-J
173
+ t, t_unrotated = t[..., :rot_dim], t[..., rot_dim:]
174
+ t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale)
175
+
176
+ t, t_unrotated = t.to(out_dtype), t_unrotated.to(out_dtype)
177
+
178
+ return torch.cat((t, t_unrotated), dim=-1)
179
+
180
+
181
+ # norms
182
+ class LayerNorm(nn.Module):
183
+ def __init__(self, dim, bias=False, fix_scale=False):
184
+ """
185
+ bias-less layernorm has been shown to be more stable. most newer models have moved towards rmsnorm, also bias-less
186
+ """
187
+ super().__init__()
188
+
189
+ if fix_scale:
190
+ self.register_buffer("gamma", torch.ones(dim))
191
+ else:
192
+ self.gamma = nn.Parameter(torch.ones(dim))
193
+
194
+ if bias:
195
+ self.beta = nn.Parameter(torch.zeros(dim))
196
+ else:
197
+ self.register_buffer("beta", torch.zeros(dim))
198
+
199
+ def forward(self, x):
200
+ return F.layer_norm(x, x.shape[-1:], weight=self.gamma, bias=self.beta)
201
+
202
+
203
+ # feedforward
204
+
205
+ class GLU(nn.Module):
206
+ def __init__(
207
+ self,
208
+ dim_in,
209
+ dim_out,
210
+ activation: Callable,
211
+ use_conv=False,
212
+ conv_kernel_size=3,
213
+ ):
214
+ super().__init__()
215
+ self.act = activation
216
+ self.proj = nn.Linear(dim_in, dim_out * 2) if not use_conv else nn.Conv1d(dim_in, dim_out * 2, conv_kernel_size,
217
+ padding=(conv_kernel_size // 2))
218
+ self.use_conv = use_conv
219
+
220
+ def forward(self, x):
221
+ if self.use_conv:
222
+ x = rearrange(x, 'b n d -> b d n')
223
+ x = self.proj(x)
224
+ x = rearrange(x, 'b d n -> b n d')
225
+ else:
226
+ x = self.proj(x)
227
+
228
+ x, gate = x.chunk(2, dim=-1)
229
+ return x * self.act(gate)
230
+
231
+
232
+ class FeedForward(nn.Module):
233
+ def __init__(
234
+ self,
235
+ dim,
236
+ dim_out=None,
237
+ mult=4,
238
+ no_bias=False,
239
+ glu=True,
240
+ use_conv=False,
241
+ conv_kernel_size=3,
242
+ zero_init_output=True,
243
+ ):
244
+ super().__init__()
245
+ inner_dim = int(dim * mult)
246
+
247
+ # Default to SwiGLU
248
+
249
+ activation = nn.SiLU()
250
+
251
+ dim_out = dim if dim_out is None else dim_out
252
+
253
+ if glu:
254
+ linear_in = GLU(dim, inner_dim, activation)
255
+ else:
256
+ linear_in = nn.Sequential(
257
+ Rearrange('b n d -> b d n') if use_conv else nn.Identity(),
258
+ nn.Linear(dim, inner_dim, bias=not no_bias) if not use_conv else nn.Conv1d(dim, inner_dim,
259
+ conv_kernel_size, padding=(
260
+ conv_kernel_size // 2), bias=not no_bias),
261
+ Rearrange('b n d -> b d n') if use_conv else nn.Identity(),
262
+ activation
263
+ )
264
+
265
+ linear_out = nn.Linear(inner_dim, dim_out, bias=not no_bias) if not use_conv else nn.Conv1d(inner_dim, dim_out,
266
+ conv_kernel_size,
267
+ padding=(
268
+ conv_kernel_size // 2),
269
+ bias=not no_bias)
270
+
271
+ # init last linear layer to 0
272
+ if zero_init_output:
273
+ nn.init.zeros_(linear_out.weight)
274
+ if not no_bias:
275
+ nn.init.zeros_(linear_out.bias)
276
+
277
+ self.ff = nn.Sequential(
278
+ linear_in,
279
+ Rearrange('b d n -> b n d') if use_conv else nn.Identity(),
280
+ linear_out,
281
+ Rearrange('b n d -> b d n') if use_conv else nn.Identity(),
282
+ )
283
+
284
+ def forward(self, x):
285
+ return self.ff(x)
286
+
287
+
288
+ class Attention(nn.Module):
289
+ def __init__(
290
+ self,
291
+ dim,
292
+ dim_heads=64,
293
+ dim_context=None,
294
+ causal=False,
295
+ zero_init_output=True,
296
+ qk_norm: Literal['l2', 'ln', 'none'] = 'none',
297
+ natten_kernel_size=None
298
+ ):
299
+ super().__init__()
300
+ self.dim = dim
301
+ self.dim_heads = dim_heads
302
+ self.causal = causal
303
+
304
+ dim_kv = dim_context if dim_context is not None else dim
305
+
306
+ self.num_heads = dim // dim_heads
307
+ self.kv_heads = dim_kv // dim_heads
308
+
309
+ if dim_context is not None:
310
+ self.to_q = nn.Linear(dim, dim, bias=False)
311
+ self.to_kv = nn.Linear(dim_kv, dim_kv * 2, bias=False)
312
+ else:
313
+ self.to_qkv = nn.Linear(dim, dim * 3, bias=False)
314
+
315
+ self.to_out = nn.Linear(dim, dim, bias=False)
316
+
317
+ if zero_init_output:
318
+ nn.init.zeros_(self.to_out.weight)
319
+
320
+ self.qk_norm = qk_norm
321
+
322
+ if self.qk_norm == "ln":
323
+ self.q_norm = nn.LayerNorm(dim_heads, elementwise_affine=True, eps=1.0e-6)
324
+ self.k_norm = nn.LayerNorm(dim_heads, elementwise_affine=True, eps=1.0e-6)
325
+
326
+ # Using 1d neighborhood attention
327
+ self.natten_kernel_size = natten_kernel_size
328
+ if natten_kernel_size is not None:
329
+ return
330
+
331
+ self.use_pt_flash = torch.cuda.is_available() and version.parse(torch.__version__) >= version.parse('2.0.0')
332
+
333
+ self.use_fa_flash = torch.cuda.is_available() and flash_attn_func is not None
334
+ # pdb.set_trace()
335
+ self.use_fa_flash = False
336
+
337
+ self.sdp_kwargs = dict(
338
+ enable_flash=True,
339
+ enable_math=True,
340
+ enable_mem_efficient=True
341
+ )
342
+
343
+ def flash_attn(
344
+ self,
345
+ q,
346
+ k,
347
+ v,
348
+ mask=None,
349
+ causal=None
350
+ ):
351
+ batch, heads, q_len, _, k_len, device = *q.shape, k.shape[-2], q.device
352
+ kv_heads = k.shape[1]
353
+ # Recommended for multi-query single-key-value attention by Tri Dao
354
+ # kv shape torch.Size([1, 512, 64]) -> torch.Size([1, 8, 512, 64])
355
+
356
+ if heads != kv_heads:
357
+ # Repeat interleave kv_heads to match q_heads
358
+ heads_per_kv_head = heads // kv_heads
359
+ k, v = map(lambda t: t.repeat_interleave(heads_per_kv_head, dim=1), (k, v))
360
+
361
+ if k.ndim == 3:
362
+ k = rearrange(k, 'b ... -> b 1 ...').expand_as(q)
363
+
364
+ if v.ndim == 3:
365
+ v = rearrange(v, 'b ... -> b 1 ...').expand_as(q)
366
+
367
+ causal = self.causal if causal is None else causal
368
+
369
+ if q_len == 1 and causal:
370
+ causal = False
371
+
372
+ if mask is not None:
373
+ assert mask.ndim == 4
374
+ mask = mask.expand(batch, heads, q_len, k_len)
375
+
376
+ assert causal
377
+ # handle kv cache - this should be bypassable in updated flash attention 2
378
+ if k_len > q_len and causal:
379
+ causal_mask = create_causal_mask(q_len, k_len, device=device)
380
+ if mask is None:
381
+ mask = ~causal_mask
382
+ else:
383
+ mask = mask & ~causal_mask
384
+ causal = False
385
+
386
+ # manually handle causal mask, if another mask was given
387
+
388
+ row_is_entirely_masked = None
389
+
390
+ if mask is not None and causal:
391
+ causal_mask = create_causal_mask(q_len, k_len, device=device)
392
+ mask = mask & ~causal_mask
393
+
394
+ # protect against an entire row being masked out
395
+
396
+ row_is_entirely_masked = ~mask.any(dim=-1)
397
+ mask[..., 0] = mask[..., 0] | row_is_entirely_masked
398
+
399
+ causal = False
400
+
401
+ with torch.backends.cuda.sdp_kernel(**self.sdp_kwargs):
402
+ out = F.scaled_dot_product_attention(
403
+ q, k, v,
404
+ attn_mask=mask,
405
+ is_causal=causal
406
+ )
407
+
408
+ # for a row that is entirely masked out, should zero out the output of that row token
409
+
410
+ if row_is_entirely_masked is not None:
411
+ out = out.masked_fill(row_is_entirely_masked[..., None], 0.)
412
+
413
+ return out
414
+
415
+ def forward(
416
+ self,
417
+ x,
418
+ context=None,
419
+ mask=None,
420
+ context_mask=None,
421
+ rotary_pos_emb=None,
422
+ causal=None
423
+ ):
424
+ h, kv_h, has_context = self.num_heads, self.kv_heads, context is not None
425
+
426
+ kv_input = context if has_context else x
427
+
428
+ if hasattr(self, 'to_q'):
429
+ # Use separate linear projections for q and k/v
430
+ q = self.to_q(x)
431
+ q = rearrange(q, 'b n (h d) -> b h n d', h=h)
432
+
433
+ k, v = self.to_kv(kv_input).chunk(2, dim=-1)
434
+
435
+ k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=kv_h), (k, v))
436
+ else:
437
+ # Use fused linear projection
438
+ q, k, v = self.to_qkv(x).chunk(3, dim=-1)
439
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v))
440
+
441
+ # Normalize q and k for cosine sim attention
442
+ if self.qk_norm == "l2":
443
+ q = F.normalize(q, dim=-1)
444
+ k = F.normalize(k, dim=-1)
445
+ elif self.qk_norm == "ln":
446
+ q = self.q_norm(q)
447
+ k = self.k_norm(k)
448
+
449
+ if rotary_pos_emb is not None and not has_context:
450
+ freqs, _ = rotary_pos_emb
451
+
452
+ q_dtype = q.dtype
453
+ k_dtype = k.dtype
454
+
455
+ q = q.to(torch.float32)
456
+ k = k.to(torch.float32)
457
+ freqs = freqs.to(torch.float32)
458
+
459
+ q = apply_rotary_pos_emb(q, freqs)
460
+ k = apply_rotary_pos_emb(k, freqs)
461
+
462
+ q = q.to(q_dtype)
463
+ k = k.to(k_dtype)
464
+
465
+ input_mask = context_mask
466
+
467
+ if input_mask is None and not has_context:
468
+ input_mask = mask
469
+
470
+ # determine masking
471
+ masks = []
472
+ final_attn_mask = None # The mask that will be applied to the attention matrix, taking all masks into account
473
+
474
+ if input_mask is not None:
475
+ input_mask = rearrange(input_mask, 'b j -> b 1 1 j')
476
+ masks.append(~input_mask)
477
+
478
+ # Other masks will be added here later
479
+
480
+ if len(masks) > 0:
481
+ final_attn_mask = ~or_reduce(masks)
482
+
483
+ n, device = q.shape[-2], q.device
484
+
485
+ causal = self.causal if causal is None else causal
486
+
487
+ if n == 1 and causal:
488
+ causal = False
489
+
490
+ if self.natten_kernel_size is not None:
491
+ if natten is None:
492
+ raise ImportError('natten not installed, please install natten to use neighborhood attention')
493
+
494
+ dtype_in = q.dtype
495
+ q, k, v = map(lambda t: t.to(torch.float32), (q, k, v))
496
+
497
+ attn = natten.functional.natten1dqk(q, k, kernel_size=self.natten_kernel_size, dilation=1)
498
+
499
+ if final_attn_mask is not None:
500
+ attn = attn.masked_fill(final_attn_mask, -torch.finfo(attn.dtype).max)
501
+
502
+ attn = F.softmax(attn, dim=-1, dtype=torch.float32)
503
+
504
+ out = natten.functional.natten1dav(attn, v, kernel_size=self.natten_kernel_size, dilation=1).to(dtype_in)
505
+
506
+ # Prioritize Flash Attention 2
507
+ elif self.use_fa_flash:
508
+ assert final_attn_mask is None, 'masking not yet supported for Flash Attention 2'
509
+ # Flash Attention 2 requires FP16 inputs
510
+ fa_dtype_in = q.dtype
511
+ q, k, v = map(lambda t: rearrange(t, 'b h n d -> b n h d').to(torch.float16), (q, k, v))
512
+
513
+ out = flash_attn_func(q, k, v, causal=causal)
514
+
515
+ out = rearrange(out.to(fa_dtype_in), 'b n h d -> b h n d')
516
+
517
+ # Fall back to PyTorch implementation
518
+ elif self.use_pt_flash:
519
+ # causal=False
520
+ # final_attn_mask:[64, 1, 1, 348]
521
+ out = self.flash_attn(q, k, v, causal=True, mask=final_attn_mask)
522
+
523
+ else:
524
+ # Fall back to custom implementation
525
+
526
+ if h != kv_h:
527
+ # Repeat interleave kv_heads to match q_heads
528
+ heads_per_kv_head = h // kv_h
529
+ k, v = map(lambda t: t.repeat_interleave(heads_per_kv_head, dim=1), (k, v))
530
+
531
+ scale = 1. / (q.shape[-1] ** 0.5)
532
+
533
+ kv_einsum_eq = 'b j d' if k.ndim == 3 else 'b h j d'
534
+
535
+ dots = einsum(f'b h i d, {kv_einsum_eq} -> b h i j', q, k) * scale
536
+
537
+ i, j, dtype = *dots.shape[-2:], dots.dtype
538
+
539
+ mask_value = -torch.finfo(dots.dtype).max
540
+
541
+ if final_attn_mask is not None:
542
+ dots = dots.masked_fill(~final_attn_mask, mask_value)
543
+
544
+ if causal:
545
+ causal_mask = create_causal_mask(i, j, device=device)
546
+ dots = dots.masked_fill(causal_mask, mask_value)
547
+
548
+ attn = F.softmax(dots, dim=-1, dtype=torch.float32)
549
+ attn = attn.type(dtype)
550
+
551
+ out = einsum(f'b h i j, {kv_einsum_eq} -> b h i d', attn, v)
552
+
553
+ # merge heads
554
+ out = rearrange(out, ' b h n d -> b n (h d)')
555
+
556
+ # Communicate between heads
557
+
558
+ # with autocast(enabled = False):
559
+ # out_dtype = out.dtype
560
+ # out = out.to(torch.float32)
561
+ # out = self.to_out(out).to(out_dtype)
562
+ out = self.to_out(out)
563
+
564
+ if mask is not None:
565
+ mask = rearrange(mask, 'b n -> b n 1')
566
+ out = out.masked_fill(~mask, 0.)
567
+
568
+ return out
569
+
570
+
571
+ class ConformerModule(nn.Module):
572
+ def __init__(
573
+ self,
574
+ dim,
575
+ norm_kwargs={},
576
+ ):
577
+ super().__init__()
578
+
579
+ self.dim = dim
580
+
581
+ self.in_norm = LayerNorm(dim, **norm_kwargs)
582
+ self.pointwise_conv = nn.Conv1d(dim, dim, kernel_size=1, bias=False)
583
+ self.glu = GLU(dim, dim, nn.SiLU())
584
+ self.depthwise_conv = nn.Conv1d(dim, dim, kernel_size=17, groups=dim, padding=8, bias=False)
585
+ self.mid_norm = LayerNorm(dim,
586
+ **norm_kwargs) # This is a batch norm in the original but I don't like batch norm
587
+ self.swish = nn.SiLU()
588
+ self.pointwise_conv_2 = nn.Conv1d(dim, dim, kernel_size=1, bias=False)
589
+
590
+ def forward(self, x):
591
+ x = self.in_norm(x)
592
+ x = rearrange(x, 'b n d -> b d n')
593
+ x = self.pointwise_conv(x)
594
+ x = rearrange(x, 'b d n -> b n d')
595
+ x = self.glu(x)
596
+ x = rearrange(x, 'b n d -> b d n')
597
+ x = self.depthwise_conv(x)
598
+ x = rearrange(x, 'b d n -> b n d')
599
+ x = self.mid_norm(x)
600
+ x = self.swish(x)
601
+ x = rearrange(x, 'b n d -> b d n')
602
+ x = self.pointwise_conv_2(x)
603
+ x = rearrange(x, 'b d n -> b n d')
604
+
605
+ return x
606
+
607
+
608
+ class TransformerBlock(nn.Module):
609
+ def __init__(
610
+ self,
611
+ dim,
612
+ dim_heads=64,
613
+ cross_attend=False,
614
+ dim_context=None,
615
+ global_cond_dim=None,
616
+ causal=False,
617
+ zero_init_branch_outputs=True,
618
+ conformer=False,
619
+ layer_ix=-1,
620
+ remove_norms=False,
621
+ attn_kwargs={},
622
+ ff_kwargs={},
623
+ norm_kwargs={}
624
+ ):
625
+
626
+ super().__init__()
627
+ self.dim = dim
628
+ self.dim_heads = dim_heads
629
+ self.cross_attend = cross_attend
630
+ self.dim_context = dim_context
631
+ self.causal = causal
632
+
633
+ self.pre_norm = LayerNorm(dim, **norm_kwargs) if not remove_norms else nn.Identity()
634
+
635
+ self.self_attn = Attention(
636
+ dim,
637
+ dim_heads=dim_heads,
638
+ causal=causal,
639
+ zero_init_output=zero_init_branch_outputs,
640
+ **attn_kwargs
641
+ )
642
+ ### 2. 主要是这边需要修改
643
+ if cross_attend:
644
+ self.cross_attend_norm = LayerNorm(dim, **norm_kwargs) if not remove_norms else nn.Identity()
645
+ self.cross_attn = Attention(
646
+ dim,
647
+ dim_heads=dim_heads,
648
+ dim_context=dim_context,
649
+ causal=causal,
650
+ zero_init_output=zero_init_branch_outputs,
651
+ **attn_kwargs
652
+ )
653
+
654
+ self.ff_norm = LayerNorm(dim, **norm_kwargs) if not remove_norms else nn.Identity()
655
+ self.ff = FeedForward(dim, zero_init_output=zero_init_branch_outputs, **ff_kwargs)
656
+
657
+ self.layer_ix = layer_ix
658
+
659
+ self.conformer = ConformerModule(dim, norm_kwargs=norm_kwargs) if conformer else None
660
+
661
+ self.global_cond_dim = global_cond_dim
662
+
663
+ if global_cond_dim is not None:
664
+ self.to_scale_shift_gate = nn.Sequential(
665
+ nn.SiLU(),
666
+ nn.Linear(global_cond_dim, dim * 6, bias=False)
667
+ )
668
+
669
+ nn.init.zeros_(self.to_scale_shift_gate[1].weight)
670
+ # nn.init.zeros_(self.to_scale_shift_gate_self[1].bias)
671
+
672
+ def forward(
673
+ self,
674
+ x,
675
+ context=None,
676
+ global_cond=None,
677
+ mask=None,
678
+ context_mask=None,
679
+ rotary_pos_emb=None
680
+ ):
681
+ if self.global_cond_dim is not None and self.global_cond_dim > 0 and global_cond is not None:
682
+
683
+ scale_self, shift_self, gate_self, scale_ff, shift_ff, gate_ff = self.to_scale_shift_gate(
684
+ global_cond).unsqueeze(1).chunk(6, dim=-1)
685
+
686
+ # self-attention with adaLN
687
+ residual = x
688
+ x = self.pre_norm(x)
689
+ x = x * (1 + scale_self) + shift_self
690
+ x = self.self_attn(x, mask=mask, rotary_pos_emb=rotary_pos_emb)
691
+ x = x * torch.sigmoid(1 - gate_self)
692
+ x = x + residual
693
+
694
+ if context is not None:
695
+ x = x + self.cross_attn(self.cross_attend_norm(x), context=context, context_mask=context_mask)
696
+
697
+ if self.conformer is not None:
698
+ x = x + self.conformer(x)
699
+
700
+ # feedforward with adaLN
701
+ residual = x
702
+ x = self.ff_norm(x)
703
+ x = x * (1 + scale_ff) + shift_ff
704
+ x = self.ff(x)
705
+ x = x * torch.sigmoid(1 - gate_ff)
706
+ x = x + residual
707
+
708
+ else:
709
+ x = x + self.self_attn(self.pre_norm(x), mask=mask, rotary_pos_emb=rotary_pos_emb)
710
+
711
+ if context is not None:
712
+ x = x + self.cross_attn(self.cross_attend_norm(x), context=context, context_mask=context_mask)
713
+
714
+ if self.conformer is not None:
715
+ x = x + self.conformer(x)
716
+
717
+ x = x + self.ff(self.ff_norm(x))
718
+
719
+ return x
720
+
721
+
722
+ class ContinuousTransformer(nn.Module):
723
+ def __init__(
724
+ self,
725
+ dim,
726
+ depth,
727
+ *,
728
+ dim_in=None,
729
+ dim_out=None,
730
+ dim_heads=64,
731
+ cross_attend=False,
732
+ cond_token_dim=None,
733
+ global_cond_dim=None,
734
+ causal=False,
735
+ rotary_pos_emb=True,
736
+ zero_init_branch_outputs=True,
737
+ conformer=False,
738
+ use_sinusoidal_emb=False,
739
+ use_abs_pos_emb=False,
740
+ abs_pos_emb_max_length=10000,
741
+ **kwargs
742
+ ):
743
+
744
+ super().__init__()
745
+
746
+ self.dim = dim
747
+ self.depth = depth
748
+ self.causal = causal
749
+ self.layers = nn.ModuleList([])
750
+
751
+ self.project_in = nn.Linear(dim_in, dim, bias=False) if dim_in is not None else nn.Identity()
752
+ self.project_out = nn.Linear(dim, dim_out, bias=False) if dim_out is not None else nn.Identity()
753
+
754
+ if rotary_pos_emb:
755
+ self.rotary_pos_emb = RotaryEmbedding(max(dim_heads // 2, 32))
756
+ else:
757
+ self.rotary_pos_emb = None
758
+
759
+ self.use_sinusoidal_emb = use_sinusoidal_emb
760
+ if use_sinusoidal_emb:
761
+ self.pos_emb = ScaledSinusoidalEmbedding(dim)
762
+
763
+ self.use_abs_pos_emb = use_abs_pos_emb
764
+ if use_abs_pos_emb:
765
+ self.pos_emb = AbsolutePositionalEmbedding(dim, abs_pos_emb_max_length)
766
+
767
+ for i in range(depth):
768
+ self.layers.append(
769
+ TransformerBlock(
770
+ dim,
771
+ dim_heads=dim_heads,
772
+ cross_attend=cross_attend,
773
+ dim_context=cond_token_dim,
774
+ global_cond_dim=global_cond_dim,
775
+ causal=causal,
776
+ zero_init_branch_outputs=zero_init_branch_outputs,
777
+ conformer=conformer,
778
+ layer_ix=i,
779
+ **kwargs
780
+ )
781
+ )
782
+
783
+ def forward(
784
+ self,
785
+ x,
786
+ mask=None,
787
+ prepend_embeds=None,
788
+ prepend_mask=None,
789
+ global_cond=None,
790
+ return_info=False,
791
+ **kwargs
792
+ ):
793
+ batch, seq, device = *x.shape[:2], x.device
794
+
795
+ info = {
796
+ "hidden_states": [],
797
+ }
798
+
799
+ x = self.project_in(x)
800
+ if prepend_embeds is not None:
801
+ prepend_length, prepend_dim = prepend_embeds.shape[1:]
802
+
803
+ assert prepend_dim == x.shape[-1], 'prepend dimension must match sequence dimension'
804
+
805
+ x = torch.cat((prepend_embeds, x), dim=-2)
806
+
807
+ if prepend_mask is not None or mask is not None:
808
+ mask = mask if mask is not None else torch.ones((batch, seq), device=device, dtype=torch.bool)
809
+ prepend_mask = prepend_mask if prepend_mask is not None else torch.ones((batch, prepend_length),
810
+ device=device, dtype=torch.bool)
811
+
812
+ mask = torch.cat((prepend_mask, mask), dim=-1)
813
+
814
+ # Attention layers
815
+
816
+ if self.rotary_pos_emb is not None:
817
+ rotary_pos_emb = self.rotary_pos_emb.forward_from_seq_len(x.shape[1])
818
+ else:
819
+ rotary_pos_emb = None
820
+
821
+ if self.use_sinusoidal_emb or self.use_abs_pos_emb:
822
+ x = x + self.pos_emb(x)
823
+
824
+ # Iterate over the transformer layers
825
+ mask = self.refine_mask(mask)
826
+ for layer in self.layers:
827
+ # x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, **kwargs)
828
+ # pdb.set_trace()
829
+ x = checkpoint(layer, x, mask=mask.bool(), rotary_pos_emb=rotary_pos_emb, global_cond=global_cond, **kwargs)
830
+
831
+ if return_info:
832
+ info["hidden_states"].append(x)
833
+
834
+ x = self.project_out(x)
835
+
836
+ if return_info:
837
+ return x, info
838
+
839
+ return x
840
+
841
+ def refine_mask(self, mask):
842
+ return mask
843
+ # pdb.set_trace()
844
+ # mask = 1 - torch.triu(torch.ones(seq_length, seq_length), diagonal=1)
845
+ # return mask
cosyvoice/hifigan/__pycache__/f0_predictor.cpython-310.pyc ADDED
Binary file (1.34 kB). View file
 
cosyvoice/hifigan/__pycache__/generator.cpython-310.pyc ADDED
Binary file (11.5 kB). View file
 
cosyvoice/hifigan/f0_predictor.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Kai Hu)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import torch
15
+ import torch.nn as nn
16
+ from torch.nn.utils import weight_norm
17
+
18
+
19
+ class ConvRNNF0Predictor(nn.Module):
20
+ def __init__(self,
21
+ num_class: int = 1,
22
+ in_channels: int = 80,
23
+ cond_channels: int = 512
24
+ ):
25
+ super().__init__()
26
+
27
+ self.num_class = num_class
28
+ self.condnet = nn.Sequential(
29
+ weight_norm(
30
+ nn.Conv1d(in_channels, cond_channels, kernel_size=3, padding=1)
31
+ ),
32
+ nn.ELU(),
33
+ weight_norm(
34
+ nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
35
+ ),
36
+ nn.ELU(),
37
+ weight_norm(
38
+ nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
39
+ ),
40
+ nn.ELU(),
41
+ weight_norm(
42
+ nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
43
+ ),
44
+ nn.ELU(),
45
+ weight_norm(
46
+ nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
47
+ ),
48
+ nn.ELU(),
49
+ )
50
+ self.classifier = nn.Linear(in_features=cond_channels, out_features=self.num_class)
51
+
52
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
53
+ x = self.condnet(x)
54
+ x = x.transpose(1, 2)
55
+ return torch.abs(self.classifier(x).squeeze(-1))
cosyvoice/hifigan/generator.py ADDED
@@ -0,0 +1,398 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Kai Hu)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """HIFI-GAN"""
16
+
17
+ import typing as tp
18
+ import numpy as np
19
+ from scipy.signal import get_window
20
+ import torch
21
+ import torch.nn as nn
22
+ import torch.nn.functional as F
23
+ from torch.nn import Conv1d
24
+ from torch.nn import ConvTranspose1d
25
+ from torch.nn.utils import remove_weight_norm
26
+ from torch.nn.utils import weight_norm
27
+ from torch.distributions.uniform import Uniform
28
+
29
+ from cosyvoice.transformer.activation import Snake
30
+ from cosyvoice.utils.common import get_padding
31
+ from cosyvoice.utils.common import init_weights
32
+
33
+
34
+ """hifigan based generator implementation.
35
+
36
+ This code is modified from https://github.com/jik876/hifi-gan
37
+ ,https://github.com/kan-bayashi/ParallelWaveGAN and
38
+ https://github.com/NVIDIA/BigVGAN
39
+
40
+ """
41
+
42
+
43
+ class ResBlock(torch.nn.Module):
44
+ """Residual block module in HiFiGAN/BigVGAN."""
45
+ def __init__(
46
+ self,
47
+ channels: int = 512,
48
+ kernel_size: int = 3,
49
+ dilations: tp.List[int] = [1, 3, 5],
50
+ ):
51
+ super(ResBlock, self).__init__()
52
+ self.convs1 = nn.ModuleList()
53
+ self.convs2 = nn.ModuleList()
54
+
55
+ for dilation in dilations:
56
+ self.convs1.append(
57
+ weight_norm(
58
+ Conv1d(
59
+ channels,
60
+ channels,
61
+ kernel_size,
62
+ 1,
63
+ dilation=dilation,
64
+ padding=get_padding(kernel_size, dilation)
65
+ )
66
+ )
67
+ )
68
+ self.convs2.append(
69
+ weight_norm(
70
+ Conv1d(
71
+ channels,
72
+ channels,
73
+ kernel_size,
74
+ 1,
75
+ dilation=1,
76
+ padding=get_padding(kernel_size, 1)
77
+ )
78
+ )
79
+ )
80
+ self.convs1.apply(init_weights)
81
+ self.convs2.apply(init_weights)
82
+ self.activations1 = nn.ModuleList([
83
+ Snake(channels, alpha_logscale=False)
84
+ for _ in range(len(self.convs1))
85
+ ])
86
+ self.activations2 = nn.ModuleList([
87
+ Snake(channels, alpha_logscale=False)
88
+ for _ in range(len(self.convs2))
89
+ ])
90
+
91
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
92
+ for idx in range(len(self.convs1)):
93
+ xt = self.activations1[idx](x)
94
+ xt = self.convs1[idx](xt)
95
+ xt = self.activations2[idx](xt)
96
+ xt = self.convs2[idx](xt)
97
+ x = xt + x
98
+ return x
99
+
100
+ def remove_weight_norm(self):
101
+ for idx in range(len(self.convs1)):
102
+ remove_weight_norm(self.convs1[idx])
103
+ remove_weight_norm(self.convs2[idx])
104
+
105
+
106
+ class SineGen(torch.nn.Module):
107
+ """ Definition of sine generator
108
+ SineGen(samp_rate, harmonic_num = 0,
109
+ sine_amp = 0.1, noise_std = 0.003,
110
+ voiced_threshold = 0,
111
+ flag_for_pulse=False)
112
+ samp_rate: sampling rate in Hz
113
+ harmonic_num: number of harmonic overtones (default 0)
114
+ sine_amp: amplitude of sine-wavefrom (default 0.1)
115
+ noise_std: std of Gaussian noise (default 0.003)
116
+ voiced_thoreshold: F0 threshold for U/V classification (default 0)
117
+ flag_for_pulse: this SinGen is used inside PulseGen (default False)
118
+ Note: when flag_for_pulse is True, the first time step of a voiced
119
+ segment is always sin(np.pi) or cos(0)
120
+ """
121
+
122
+ def __init__(self, samp_rate, harmonic_num=0,
123
+ sine_amp=0.1, noise_std=0.003,
124
+ voiced_threshold=0):
125
+ super(SineGen, self).__init__()
126
+ self.sine_amp = sine_amp
127
+ self.noise_std = noise_std
128
+ self.harmonic_num = harmonic_num
129
+ self.sampling_rate = samp_rate
130
+ self.voiced_threshold = voiced_threshold
131
+
132
+ def _f02uv(self, f0):
133
+ # generate uv signal
134
+ uv = (f0 > self.voiced_threshold).type(torch.float32)
135
+ return uv
136
+
137
+ @torch.no_grad()
138
+ def forward(self, f0):
139
+ """
140
+ :param f0: [B, 1, sample_len], Hz
141
+ :return: [B, 1, sample_len]
142
+ """
143
+
144
+ F_mat = torch.zeros((f0.size(0), self.harmonic_num + 1, f0.size(-1))).to(f0.device)
145
+ for i in range(self.harmonic_num + 1):
146
+ F_mat[:, i: i + 1, :] = f0 * (i + 1) / self.sampling_rate
147
+
148
+ theta_mat = 2 * np.pi * (torch.cumsum(F_mat, dim=-1) % 1)
149
+ u_dist = Uniform(low=-np.pi, high=np.pi)
150
+ phase_vec = u_dist.sample(sample_shape=(f0.size(0), self.harmonic_num + 1, 1)).to(F_mat.device)
151
+ phase_vec[:, 0, :] = 0
152
+
153
+ # generate sine waveforms
154
+ sine_waves = self.sine_amp * torch.sin(theta_mat + phase_vec)
155
+
156
+ # generate uv signal
157
+ uv = self._f02uv(f0)
158
+
159
+ # noise: for unvoiced should be similar to sine_amp
160
+ # std = self.sine_amp/3 -> max value ~ self.sine_amp
161
+ # . for voiced regions is self.noise_std
162
+ noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
163
+ noise = noise_amp * torch.randn_like(sine_waves)
164
+
165
+ # first: set the unvoiced part to 0 by uv
166
+ # then: additive noise
167
+ sine_waves = sine_waves * uv + noise
168
+ return sine_waves, uv, noise
169
+
170
+
171
+ class SourceModuleHnNSF(torch.nn.Module):
172
+ """ SourceModule for hn-nsf
173
+ SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
174
+ add_noise_std=0.003, voiced_threshod=0)
175
+ sampling_rate: sampling_rate in Hz
176
+ harmonic_num: number of harmonic above F0 (default: 0)
177
+ sine_amp: amplitude of sine source signal (default: 0.1)
178
+ add_noise_std: std of additive Gaussian noise (default: 0.003)
179
+ note that amplitude of noise in unvoiced is decided
180
+ by sine_amp
181
+ voiced_threshold: threhold to set U/V given F0 (default: 0)
182
+ Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
183
+ F0_sampled (batchsize, length, 1)
184
+ Sine_source (batchsize, length, 1)
185
+ noise_source (batchsize, length 1)
186
+ uv (batchsize, length, 1)
187
+ """
188
+
189
+ def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
190
+ add_noise_std=0.003, voiced_threshod=0):
191
+ super(SourceModuleHnNSF, self).__init__()
192
+
193
+ self.sine_amp = sine_amp
194
+ self.noise_std = add_noise_std
195
+
196
+ # to produce sine waveforms
197
+ self.l_sin_gen = SineGen(sampling_rate, harmonic_num,
198
+ sine_amp, add_noise_std, voiced_threshod)
199
+
200
+ # to merge source harmonics into a single excitation
201
+ self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
202
+ self.l_tanh = torch.nn.Tanh()
203
+
204
+ def forward(self, x):
205
+ """
206
+ Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
207
+ F0_sampled (batchsize, length, 1)
208
+ Sine_source (batchsize, length, 1)
209
+ noise_source (batchsize, length 1)
210
+ """
211
+ # source for harmonic branch
212
+ with torch.no_grad():
213
+ sine_wavs, uv, _ = self.l_sin_gen(x.transpose(1, 2))
214
+ sine_wavs = sine_wavs.transpose(1, 2)
215
+ uv = uv.transpose(1, 2)
216
+ sine_merge = self.l_tanh(self.l_linear(sine_wavs))
217
+
218
+ # source for noise branch, in the same shape as uv
219
+ noise = torch.randn_like(uv) * self.sine_amp / 3
220
+ return sine_merge, noise, uv
221
+
222
+
223
+ class HiFTGenerator(nn.Module):
224
+ """
225
+ HiFTNet Generator: Neural Source Filter + ISTFTNet
226
+ https://arxiv.org/abs/2309.09493
227
+ """
228
+ def __init__(
229
+ self,
230
+ in_channels: int = 80,
231
+ base_channels: int = 512,
232
+ nb_harmonics: int = 8,
233
+ sampling_rate: int = 22050,
234
+ nsf_alpha: float = 0.1,
235
+ nsf_sigma: float = 0.003,
236
+ nsf_voiced_threshold: float = 10,
237
+ upsample_rates: tp.List[int] = [8, 8],
238
+ upsample_kernel_sizes: tp.List[int] = [16, 16],
239
+ istft_params: tp.Dict[str, int] = {"n_fft": 16, "hop_len": 4},
240
+ resblock_kernel_sizes: tp.List[int] = [3, 7, 11],
241
+ resblock_dilation_sizes: tp.List[tp.List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
242
+ source_resblock_kernel_sizes: tp.List[int] = [7, 11],
243
+ source_resblock_dilation_sizes: tp.List[tp.List[int]] = [[1, 3, 5], [1, 3, 5]],
244
+ lrelu_slope: float = 0.1,
245
+ audio_limit: float = 0.99,
246
+ f0_predictor: torch.nn.Module = None,
247
+ ):
248
+ super(HiFTGenerator, self).__init__()
249
+
250
+ self.out_channels = 1
251
+ self.nb_harmonics = nb_harmonics
252
+ self.sampling_rate = sampling_rate
253
+ self.istft_params = istft_params
254
+ self.lrelu_slope = lrelu_slope
255
+ self.audio_limit = audio_limit
256
+
257
+ self.num_kernels = len(resblock_kernel_sizes)
258
+ self.num_upsamples = len(upsample_rates)
259
+ self.m_source = SourceModuleHnNSF(
260
+ sampling_rate=sampling_rate,
261
+ upsample_scale=np.prod(upsample_rates) * istft_params["hop_len"],
262
+ harmonic_num=nb_harmonics,
263
+ sine_amp=nsf_alpha,
264
+ add_noise_std=nsf_sigma,
265
+ voiced_threshod=nsf_voiced_threshold)
266
+ self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates) * istft_params["hop_len"])
267
+
268
+ self.conv_pre = weight_norm(
269
+ Conv1d(in_channels, base_channels, 7, 1, padding=3)
270
+ )
271
+
272
+ # Up
273
+ self.ups = nn.ModuleList()
274
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
275
+ self.ups.append(
276
+ weight_norm(
277
+ ConvTranspose1d(
278
+ base_channels // (2**i),
279
+ base_channels // (2**(i + 1)),
280
+ k,
281
+ u,
282
+ padding=(k - u) // 2,
283
+ )
284
+ )
285
+ )
286
+
287
+ # Down
288
+ self.source_downs = nn.ModuleList()
289
+ self.source_resblocks = nn.ModuleList()
290
+ downsample_rates = [1] + upsample_rates[::-1][:-1]
291
+ downsample_cum_rates = np.cumprod(downsample_rates)
292
+ for i, (u, k, d) in enumerate(zip(downsample_cum_rates[::-1], source_resblock_kernel_sizes, source_resblock_dilation_sizes)):
293
+ if u == 1:
294
+ self.source_downs.append(
295
+ Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), 1, 1)
296
+ )
297
+ else:
298
+ self.source_downs.append(
299
+ Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), u * 2, u, padding=(u // 2))
300
+ )
301
+
302
+ self.source_resblocks.append(
303
+ ResBlock(base_channels // (2 ** (i + 1)), k, d)
304
+ )
305
+
306
+ self.resblocks = nn.ModuleList()
307
+ for i in range(len(self.ups)):
308
+ ch = base_channels // (2**(i + 1))
309
+ for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
310
+ self.resblocks.append(ResBlock(ch, k, d))
311
+
312
+ self.conv_post = weight_norm(Conv1d(ch, istft_params["n_fft"] + 2, 7, 1, padding=3))
313
+ self.ups.apply(init_weights)
314
+ self.conv_post.apply(init_weights)
315
+ self.reflection_pad = nn.ReflectionPad1d((1, 0))
316
+ self.stft_window = torch.from_numpy(get_window("hann", istft_params["n_fft"], fftbins=True).astype(np.float32))
317
+ self.f0_predictor = f0_predictor
318
+
319
+ def _f02source(self, f0: torch.Tensor) -> torch.Tensor:
320
+ f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
321
+
322
+ har_source, _, _ = self.m_source(f0)
323
+ return har_source.transpose(1, 2)
324
+
325
+ def _stft(self, x):
326
+ spec = torch.stft(
327
+ x,
328
+ self.istft_params["n_fft"], self.istft_params["hop_len"], self.istft_params["n_fft"], window=self.stft_window.to(x.device),
329
+ return_complex=True)
330
+ spec = torch.view_as_real(spec) # [B, F, TT, 2]
331
+ return spec[..., 0], spec[..., 1]
332
+
333
+ def _istft(self, magnitude, phase):
334
+ magnitude = torch.clip(magnitude, max=1e2)
335
+ real = magnitude * torch.cos(phase)
336
+ img = magnitude * torch.sin(phase)
337
+ inverse_transform = torch.istft(torch.complex(real, img), self.istft_params["n_fft"], self.istft_params["hop_len"],
338
+ self.istft_params["n_fft"], window=self.stft_window.to(magnitude.device))
339
+ return inverse_transform
340
+
341
+ def forward(self, x: torch.Tensor, cache_source: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
342
+ f0 = self.f0_predictor(x)
343
+ s = self._f02source(f0)
344
+
345
+ # use cache_source to avoid glitch
346
+ if cache_source.shape[2] != 0:
347
+ s[:, :, :cache_source.shape[2]] = cache_source
348
+
349
+ s_stft_real, s_stft_imag = self._stft(s.squeeze(1))
350
+ s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1)
351
+
352
+ x = self.conv_pre(x)
353
+ for i in range(self.num_upsamples):
354
+ x = F.leaky_relu(x, self.lrelu_slope)
355
+ x = self.ups[i](x)
356
+
357
+ if i == self.num_upsamples - 1:
358
+ x = self.reflection_pad(x)
359
+
360
+ # fusion
361
+ si = self.source_downs[i](s_stft)
362
+ si = self.source_resblocks[i](si)
363
+ x = x + si
364
+
365
+ xs = None
366
+ for j in range(self.num_kernels):
367
+ if xs is None:
368
+ xs = self.resblocks[i * self.num_kernels + j](x)
369
+ else:
370
+ xs += self.resblocks[i * self.num_kernels + j](x)
371
+ x = xs / self.num_kernels
372
+
373
+ x = F.leaky_relu(x)
374
+ x = self.conv_post(x)
375
+ magnitude = torch.exp(x[:, :self.istft_params["n_fft"] // 2 + 1, :])
376
+ phase = torch.sin(x[:, self.istft_params["n_fft"] // 2 + 1:, :]) # actually, sin is redundancy
377
+
378
+ x = self._istft(magnitude, phase)
379
+ x = torch.clamp(x, -self.audio_limit, self.audio_limit)
380
+ return x, s
381
+
382
+ def remove_weight_norm(self):
383
+ print('Removing weight norm...')
384
+ for l in self.ups:
385
+ remove_weight_norm(l)
386
+ for l in self.resblocks:
387
+ l.remove_weight_norm()
388
+ remove_weight_norm(self.conv_pre)
389
+ remove_weight_norm(self.conv_post)
390
+ self.source_module.remove_weight_norm()
391
+ for l in self.source_downs:
392
+ remove_weight_norm(l)
393
+ for l in self.source_resblocks:
394
+ l.remove_weight_norm()
395
+
396
+ @torch.inference_mode()
397
+ def inference(self, mel: torch.Tensor, cache_source: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
398
+ return self.forward(x=mel, cache_source=cache_source)
cosyvoice/llm/__pycache__/llm.cpython-310.pyc ADDED
Binary file (6.29 kB). View file
 
cosyvoice/llm/llm.py ADDED
@@ -0,0 +1,206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Dict, Optional, Union
15
+ import torch
16
+ from torch import nn
17
+ import torch.nn.functional as F
18
+ from torch.nn.utils.rnn import pad_sequence, unpad_sequence
19
+ from cosyvoice.utils.common import IGNORE_ID
20
+ from cosyvoice.transformer.label_smoothing_loss import LabelSmoothingLoss
21
+ from cosyvoice.utils.common import th_accuracy
22
+
23
+
24
+ class TransformerLM(torch.nn.Module):
25
+ def __init__(
26
+ self,
27
+ text_encoder_input_size: int,
28
+ llm_input_size: int,
29
+ llm_output_size: int,
30
+ text_token_size: int,
31
+ speech_token_size: int,
32
+ text_encoder: torch.nn.Module,
33
+ llm: torch.nn.Module,
34
+ length_normalized_loss: bool = True,
35
+ lsm_weight: float = 0.0,
36
+ spk_embed_dim: int = 192,
37
+ ):
38
+ super().__init__()
39
+ self.llm_input_size = llm_input_size
40
+ self.speech_token_size = speech_token_size
41
+ # 1. build text token inputs related modules
42
+ self.text_embedding = torch.nn.Embedding(text_token_size, text_encoder_input_size)
43
+ self.text_encoder = text_encoder
44
+ self.text_encoder_affine_layer = nn.Linear(
45
+ self.text_encoder.output_size(),
46
+ llm_input_size
47
+ )
48
+
49
+ # 2. build speech token language model related modules
50
+ self.sos_eos = 0
51
+ self.task_id = 1
52
+ self.llm_embedding = torch.nn.Embedding(2, llm_input_size)
53
+ self.llm = llm
54
+ self.llm_decoder = nn.Linear(llm_output_size, speech_token_size + 1)
55
+ self.criterion_ce = LabelSmoothingLoss(
56
+ size=speech_token_size + 1,
57
+ padding_idx=IGNORE_ID,
58
+ smoothing=lsm_weight,
59
+ normalize_length=length_normalized_loss,
60
+ )
61
+
62
+ # 3. [Optional] build speech token related modules
63
+ self.speech_embedding = torch.nn.Embedding(speech_token_size, llm_input_size)
64
+ self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, llm_input_size)
65
+
66
+ def encode(
67
+ self,
68
+ text: torch.Tensor,
69
+ text_lengths: torch.Tensor,
70
+ ):
71
+ encoder_out, encoder_mask = self.text_encoder(text, text_lengths, decoding_chunk_size=1, num_decoding_left_chunks=-1)
72
+ encoder_out_lens = encoder_mask.squeeze(1).sum(1)
73
+ encoder_out = self.text_encoder_affine_layer(encoder_out)
74
+ return encoder_out, encoder_out_lens
75
+
76
+ def pad_unpad_sequence(self, sos_eos_emb, embedding, text_token, text_token_len, task_id_emb, speech_token, speech_token_len):
77
+ text_token = unpad_sequence(text_token, text_token_len.cpu(), batch_first=True)
78
+ speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True)
79
+ lm_input = [torch.concat([sos_eos_emb.squeeze(dim=0), embedding[i], text_token[i], task_id_emb.squeeze(dim=0), speech_token[i]], dim=0) for i in range(len(text_token))]
80
+ lm_input_len = torch.tensor([i.size(0) for i in lm_input], dtype=torch.int32)
81
+ lm_input = pad_sequence(lm_input, batch_first=True, padding_value=IGNORE_ID)
82
+ return lm_input, lm_input_len
83
+
84
+ def forward(
85
+ self,
86
+ batch: dict,
87
+ device: torch.device,
88
+ ) -> Dict[str, Optional[torch.Tensor]]:
89
+ """
90
+ Args:
91
+ text: (B, L, D)
92
+ text_lengths: (B,)
93
+ audio: (B, T, N) or (B, T)
94
+ audio_lengths: (B,)
95
+ """
96
+ text_token = batch['text_token'].to(device)
97
+ text_token_len = batch['text_token_len'].to(device)
98
+ speech_token = batch['speech_token'].to(device)
99
+ speech_token_len = batch['speech_token_len'].to(device)
100
+ embedding = batch['embedding'].to(device)
101
+
102
+ # 1. prepare llm_target
103
+ lm_target = [torch.tensor([IGNORE_ID] * (2 + text_token_len[i]) + speech_token[i, :speech_token_len[i]].tolist() + [self.speech_token_size]) for i in range(text_token.size(0))]
104
+ lm_target = pad_sequence(lm_target, batch_first=True, padding_value=IGNORE_ID).to(device)
105
+
106
+ # 1. encode text_token
107
+ text_token = self.text_embedding(text_token)
108
+ text_token, text_token_len = self.encode(text_token, text_token_len)
109
+
110
+ # 2. embedding projection
111
+ embedding = F.normalize(embedding, dim=1)
112
+ embedding = self.spk_embed_affine_layer(embedding)
113
+ embedding = embedding.unsqueeze(1)
114
+
115
+ # 3. eos and task_id
116
+ sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
117
+ task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
118
+
119
+ # 4. encode speech_token
120
+ speech_token = self.speech_embedding(speech_token)
121
+
122
+ # 5. unpad and pad
123
+ lm_input, lm_input_len = self.pad_unpad_sequence(sos_eos_emb, embedding, text_token, text_token_len, task_id_emb, speech_token, speech_token_len)
124
+
125
+ # 6. run lm forward
126
+ lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device))
127
+ logits = self.llm_decoder(lm_output)
128
+ loss = self.criterion_ce(logits, lm_target)
129
+ acc = th_accuracy(logits.view(-1, self.speech_token_size + 1), lm_target, ignore_label=IGNORE_ID)
130
+ return {'loss': loss, 'acc': acc}
131
+
132
+ def sampling_ids(
133
+ self,
134
+ weighted_scores: torch.Tensor,
135
+ sampling: Union[bool, int, float] = True,
136
+ beam_size: int = 1,
137
+ ignore_eos: bool = True,
138
+ ):
139
+ while True:
140
+ prob, indices = weighted_scores.softmax(dim=-1).topk(sampling)
141
+ top_ids = prob.multinomial(beam_size, replacement=True)
142
+ top_ids = indices[top_ids]
143
+ if (not ignore_eos) or (self.speech_token_size not in top_ids):
144
+ break
145
+ return top_ids
146
+
147
+ @torch.inference_mode()
148
+ def inference(
149
+ self,
150
+ text: torch.Tensor,
151
+ text_len: torch.Tensor,
152
+ prompt_text: torch.Tensor,
153
+ prompt_text_len: torch.Tensor,
154
+ prompt_speech_token: torch.Tensor,
155
+ prompt_speech_token_len: torch.Tensor,
156
+ embedding: torch.Tensor,
157
+ beam_size: int = 1,
158
+ sampling: int = 25,
159
+ max_token_text_ratio: float = 20,
160
+ min_token_text_ratio: float = 2,
161
+ ) -> torch.Tensor:
162
+ device = text.device
163
+ text = torch.concat([prompt_text, text], dim=1)
164
+ text_len += prompt_text_len
165
+ text = self.text_embedding(text)
166
+
167
+ # 1. encode text
168
+ text, text_len = self.encode(text, text_len)
169
+
170
+ # 2. encode embedding
171
+ if embedding.shape[0] != 0:
172
+ embedding = F.normalize(embedding, dim=1)
173
+ embedding = self.spk_embed_affine_layer(embedding)
174
+ embedding = embedding.unsqueeze(dim=1)
175
+ else:
176
+ embedding = torch.zeros(1, 0, self.llm_input_size).to(device)
177
+
178
+ # 3. concat llm_input
179
+ sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
180
+ task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
181
+ if prompt_speech_token_len != 0:
182
+ prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
183
+ else:
184
+ prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size).to(device)
185
+ lm_input = torch.concat([sos_eos_emb, embedding, text, task_id_emb, prompt_speech_token_emb], dim=1)
186
+
187
+ # 4. cal min/max_length
188
+ min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
189
+ max_len = int((text_len - prompt_text_len) * max_token_text_ratio)
190
+
191
+ # 5. step by step decode
192
+ out_tokens = []
193
+ offset = 0
194
+ att_cache, cnn_cache = torch.zeros((0, 0, 0, 0), device=lm_input.device), torch.zeros((0, 0, 0, 0), device=lm_input.device)
195
+ for i in range(max_len):
196
+ y_pred, att_cache, cnn_cache = self.llm.forward_chunk(lm_input, offset=0, required_cache_size=-1, att_cache=att_cache, cnn_cache=cnn_cache,
197
+ att_mask=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]), device=lm_input.device)).to(torch.bool))
198
+ logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
199
+ top_ids = self.sampling_ids(logp.squeeze(dim=0), sampling, beam_size, ignore_eos=True if i < min_len else False).item()
200
+ if top_ids == self.speech_token_size:
201
+ break
202
+ out_tokens.append(top_ids)
203
+ offset += lm_input.size(1)
204
+ lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
205
+
206
+ return torch.tensor([out_tokens], dtype=torch.int64, device=device)
cosyvoice/transformer/__init__.py ADDED
File without changes
cosyvoice/transformer/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (144 Bytes). View file
 
cosyvoice/transformer/__pycache__/activation.cpython-310.pyc ADDED
Binary file (2.47 kB). View file
 
cosyvoice/transformer/__pycache__/attention.cpython-310.pyc ADDED
Binary file (11.5 kB). View file
 
cosyvoice/transformer/__pycache__/convolution.cpython-310.pyc ADDED
Binary file (3.05 kB). View file
 
cosyvoice/transformer/__pycache__/embedding.cpython-310.pyc ADDED
Binary file (9.44 kB). View file
 
cosyvoice/transformer/__pycache__/encoder.cpython-310.pyc ADDED
Binary file (16.7 kB). View file
 
cosyvoice/transformer/__pycache__/encoder_layer.cpython-310.pyc ADDED
Binary file (7.32 kB). View file
 
cosyvoice/transformer/__pycache__/label_smoothing_loss.cpython-310.pyc ADDED
Binary file (2.87 kB). View file
 
cosyvoice/transformer/__pycache__/positionwise_feed_forward.cpython-310.pyc ADDED
Binary file (3.78 kB). View file
 
cosyvoice/transformer/__pycache__/subsampling.cpython-310.pyc ADDED
Binary file (9.82 kB). View file
 
cosyvoice/transformer/activation.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2020 Johns Hopkins University (Shinji Watanabe)
2
+ # 2020 Northwestern Polytechnical University (Pengcheng Guo)
3
+ # 2020 Mobvoi Inc (Binbin Zhang)
4
+ # 2024 Alibaba Inc (Xiang Lyu)
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+ """Swish() activation function for Conformer."""
18
+
19
+ import torch
20
+ from torch import nn, sin, pow
21
+ from torch.nn import Parameter
22
+
23
+
24
+ class Swish(torch.nn.Module):
25
+ """Construct an Swish object."""
26
+
27
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
28
+ """Return Swish activation function."""
29
+ return x * torch.sigmoid(x)
30
+
31
+
32
+ # Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license.
33
+ # LICENSE is in incl_licenses directory.
34
+ class Snake(nn.Module):
35
+ '''
36
+ Implementation of a sine-based periodic activation function
37
+ Shape:
38
+ - Input: (B, C, T)
39
+ - Output: (B, C, T), same shape as the input
40
+ Parameters:
41
+ - alpha - trainable parameter
42
+ References:
43
+ - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
44
+ https://arxiv.org/abs/2006.08195
45
+ Examples:
46
+ >>> a1 = snake(256)
47
+ >>> x = torch.randn(256)
48
+ >>> x = a1(x)
49
+ '''
50
+ def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
51
+ '''
52
+ Initialization.
53
+ INPUT:
54
+ - in_features: shape of the input
55
+ - alpha: trainable parameter
56
+ alpha is initialized to 1 by default, higher values = higher-frequency.
57
+ alpha will be trained along with the rest of your model.
58
+ '''
59
+ super(Snake, self).__init__()
60
+ self.in_features = in_features
61
+
62
+ # initialize alpha
63
+ self.alpha_logscale = alpha_logscale
64
+ if self.alpha_logscale: # log scale alphas initialized to zeros
65
+ self.alpha = Parameter(torch.zeros(in_features) * alpha)
66
+ else: # linear scale alphas initialized to ones
67
+ self.alpha = Parameter(torch.ones(in_features) * alpha)
68
+
69
+ self.alpha.requires_grad = alpha_trainable
70
+
71
+ self.no_div_by_zero = 0.000000001
72
+
73
+ def forward(self, x):
74
+ '''
75
+ Forward pass of the function.
76
+ Applies the function to the input elementwise.
77
+ Snake ∶= x + 1/a * sin^2 (xa)
78
+ '''
79
+ alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
80
+ if self.alpha_logscale:
81
+ alpha = torch.exp(alpha)
82
+ x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
83
+
84
+ return x
cosyvoice/transformer/attention.py ADDED
@@ -0,0 +1,612 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2019 Shigeki Karita
2
+ # 2020 Mobvoi Inc (Binbin Zhang)
3
+ # 2022 Xingchen Song (sxc19@mails.tsinghua.edu.cn)
4
+ # 2024 Alibaba Inc (Xiang Lyu)
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+ """Multi-Head Attention layer definition."""
18
+
19
+ import math
20
+ from typing import Tuple
21
+
22
+ import torch
23
+ from torch import nn
24
+
25
+
26
+ class MultiHeadedAttention(nn.Module):
27
+ """Multi-Head Attention layer.
28
+
29
+ Args:
30
+ n_head (int): The number of heads.
31
+ n_feat (int): The number of features.
32
+ dropout_rate (float): Dropout rate.
33
+
34
+ """
35
+
36
+ def __init__(self,
37
+ n_head: int,
38
+ n_feat: int,
39
+ dropout_rate: float,
40
+ key_bias: bool = True):
41
+ """Construct an MultiHeadedAttention object."""
42
+ super().__init__()
43
+ assert n_feat % n_head == 0
44
+ # We assume d_v always equals d_k
45
+ self.d_k = n_feat // n_head
46
+ self.h = n_head
47
+ self.linear_q = nn.Linear(n_feat, n_feat)
48
+ self.linear_k = nn.Linear(n_feat, n_feat, bias=key_bias)
49
+ self.linear_v = nn.Linear(n_feat, n_feat)
50
+ self.linear_out = nn.Linear(n_feat, n_feat)
51
+ self.dropout = nn.Dropout(p=dropout_rate)
52
+
53
+ def forward_qkv(
54
+ self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
55
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
56
+ """Transform query, key and value.
57
+
58
+ Args:
59
+ query (torch.Tensor): Query tensor (#batch, time1, size).
60
+ key (torch.Tensor): Key tensor (#batch, time2, size).
61
+ value (torch.Tensor): Value tensor (#batch, time2, size).
62
+
63
+ Returns:
64
+ torch.Tensor: Transformed query tensor, size
65
+ (#batch, n_head, time1, d_k).
66
+ torch.Tensor: Transformed key tensor, size
67
+ (#batch, n_head, time2, d_k).
68
+ torch.Tensor: Transformed value tensor, size
69
+ (#batch, n_head, time2, d_k).
70
+
71
+ """
72
+ n_batch = query.size(0)
73
+ q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
74
+ k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
75
+ v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
76
+ q = q.transpose(1, 2) # (batch, head, time1, d_k)
77
+ k = k.transpose(1, 2) # (batch, head, time2, d_k)
78
+ v = v.transpose(1, 2) # (batch, head, time2, d_k)
79
+
80
+ return q, k, v
81
+
82
+ def forward_attention(
83
+ self,
84
+ value: torch.Tensor,
85
+ scores: torch.Tensor,
86
+ mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool)
87
+ ) -> torch.Tensor:
88
+ """Compute attention context vector.
89
+
90
+ Args:
91
+ value (torch.Tensor): Transformed value, size
92
+ (#batch, n_head, time2, d_k).
93
+ scores (torch.Tensor): Attention score, size
94
+ (#batch, n_head, time1, time2).
95
+ mask (torch.Tensor): Mask, size (#batch, 1, time2) or
96
+ (#batch, time1, time2), (0, 0, 0) means fake mask.
97
+
98
+ Returns:
99
+ torch.Tensor: Transformed value (#batch, time1, d_model)
100
+ weighted by the attention score (#batch, time1, time2).
101
+
102
+ """
103
+ n_batch = value.size(0)
104
+ # NOTE(xcsong): When will `if mask.size(2) > 0` be True?
105
+ # 1. onnx(16/4) [WHY? Because we feed real cache & real mask for the
106
+ # 1st chunk to ease the onnx export.]
107
+ # 2. pytorch training
108
+ if mask.size(2) > 0: # time2 > 0
109
+ mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
110
+ # For last chunk, time2 might be larger than scores.size(-1)
111
+ mask = mask[:, :, :, :scores.size(-1)] # (batch, 1, *, time2)
112
+ scores = scores.masked_fill(mask, -float('inf'))
113
+ attn = torch.softmax(scores, dim=-1).masked_fill(
114
+ mask, 0.0) # (batch, head, time1, time2)
115
+ # NOTE(xcsong): When will `if mask.size(2) > 0` be False?
116
+ # 1. onnx(16/-1, -1/-1, 16/0)
117
+ # 2. jit (16/-1, -1/-1, 16/0, 16/4)
118
+ else:
119
+ attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
120
+
121
+ p_attn = self.dropout(attn)
122
+ x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
123
+ x = (x.transpose(1, 2).contiguous().view(n_batch, -1,
124
+ self.h * self.d_k)
125
+ ) # (batch, time1, d_model)
126
+
127
+ return self.linear_out(x) # (batch, time1, d_model)
128
+
129
+ def forward(
130
+ self,
131
+ query: torch.Tensor,
132
+ key: torch.Tensor,
133
+ value: torch.Tensor,
134
+ mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
135
+ pos_emb: torch.Tensor = torch.empty(0),
136
+ cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
137
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
138
+ """Compute scaled dot product attention.
139
+
140
+ Args:
141
+ query (torch.Tensor): Query tensor (#batch, time1, size).
142
+ key (torch.Tensor): Key tensor (#batch, time2, size).
143
+ value (torch.Tensor): Value tensor (#batch, time2, size).
144
+ mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
145
+ (#batch, time1, time2).
146
+ 1.When applying cross attention between decoder and encoder,
147
+ the batch padding mask for input is in (#batch, 1, T) shape.
148
+ 2.When applying self attention of encoder,
149
+ the mask is in (#batch, T, T) shape.
150
+ 3.When applying self attention of decoder,
151
+ the mask is in (#batch, L, L) shape.
152
+ 4.If the different position in decoder see different block
153
+ of the encoder, such as Mocha, the passed in mask could be
154
+ in (#batch, L, T) shape. But there is no such case in current
155
+ CosyVoice.
156
+ cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
157
+ where `cache_t == chunk_size * num_decoding_left_chunks`
158
+ and `head * d_k == size`
159
+
160
+
161
+ Returns:
162
+ torch.Tensor: Output tensor (#batch, time1, d_model).
163
+ torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
164
+ where `cache_t == chunk_size * num_decoding_left_chunks`
165
+ and `head * d_k == size`
166
+
167
+ """
168
+ q, k, v = self.forward_qkv(query, key, value)
169
+
170
+ # NOTE(xcsong):
171
+ # when export onnx model, for 1st chunk, we feed
172
+ # cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
173
+ # or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
174
+ # In all modes, `if cache.size(0) > 0` will alwayse be `True`
175
+ # and we will always do splitting and
176
+ # concatnation(this will simplify onnx export). Note that
177
+ # it's OK to concat & split zero-shaped tensors(see code below).
178
+ # when export jit model, for 1st chunk, we always feed
179
+ # cache(0, 0, 0, 0) since jit supports dynamic if-branch.
180
+ # >>> a = torch.ones((1, 2, 0, 4))
181
+ # >>> b = torch.ones((1, 2, 3, 4))
182
+ # >>> c = torch.cat((a, b), dim=2)
183
+ # >>> torch.equal(b, c) # True
184
+ # >>> d = torch.split(a, 2, dim=-1)
185
+ # >>> torch.equal(d[0], d[1]) # True
186
+ if cache.size(0) > 0:
187
+ key_cache, value_cache = torch.split(cache,
188
+ cache.size(-1) // 2,
189
+ dim=-1)
190
+ k = torch.cat([key_cache, k], dim=2)
191
+ v = torch.cat([value_cache, v], dim=2)
192
+ # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
193
+ # non-trivial to calculate `next_cache_start` here.
194
+ new_cache = torch.cat((k, v), dim=-1)
195
+
196
+ scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
197
+ return self.forward_attention(v, scores, mask), new_cache
198
+
199
+
200
+ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
201
+ """Multi-Head Attention layer with relative position encoding.
202
+ Paper: https://arxiv.org/abs/1901.02860
203
+ Args:
204
+ n_head (int): The number of heads.
205
+ n_feat (int): The number of features.
206
+ dropout_rate (float): Dropout rate.
207
+ """
208
+
209
+ def __init__(self,
210
+ n_head: int,
211
+ n_feat: int,
212
+ dropout_rate: float,
213
+ key_bias: bool = True):
214
+ """Construct an RelPositionMultiHeadedAttention object."""
215
+ super().__init__(n_head, n_feat, dropout_rate, key_bias)
216
+ # linear transformation for positional encoding
217
+ self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
218
+ # these two learnable bias are used in matrix c and matrix d
219
+ # as described in https://arxiv.org/abs/1901.02860 Section 3.3
220
+ self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
221
+ self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
222
+ torch.nn.init.xavier_uniform_(self.pos_bias_u)
223
+ torch.nn.init.xavier_uniform_(self.pos_bias_v)
224
+
225
+ def rel_shift(self, x):
226
+ """Compute relative positional encoding.
227
+
228
+ Args:
229
+ x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1).
230
+ time1 means the length of query vector.
231
+
232
+ Returns:
233
+ torch.Tensor: Output tensor.
234
+
235
+ """
236
+ zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype)
237
+ x_padded = torch.cat([zero_pad, x], dim=-1)
238
+
239
+ x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
240
+ x = x_padded[:, :, 1:].view_as(x)[
241
+ :, :, :, : x.size(-1) // 2 + 1
242
+ ] # only keep the positions from 0 to time2
243
+ return x
244
+
245
+ def forward(
246
+ self,
247
+ query: torch.Tensor,
248
+ key: torch.Tensor,
249
+ value: torch.Tensor,
250
+ mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
251
+ pos_emb: torch.Tensor = torch.empty(0),
252
+ cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
253
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
254
+ """Compute 'Scaled Dot Product Attention' with rel. positional encoding.
255
+ Args:
256
+ query (torch.Tensor): Query tensor (#batch, time1, size).
257
+ key (torch.Tensor): Key tensor (#batch, time2, size).
258
+ value (torch.Tensor): Value tensor (#batch, time2, size).
259
+ mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
260
+ (#batch, time1, time2), (0, 0, 0) means fake mask.
261
+ pos_emb (torch.Tensor): Positional embedding tensor
262
+ (#batch, time2, size).
263
+ cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
264
+ where `cache_t == chunk_size * num_decoding_left_chunks`
265
+ and `head * d_k == size`
266
+ Returns:
267
+ torch.Tensor: Output tensor (#batch, time1, d_model).
268
+ torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
269
+ where `cache_t == chunk_size * num_decoding_left_chunks`
270
+ and `head * d_k == size`
271
+ """
272
+ q, k, v = self.forward_qkv(query, key, value)
273
+ q = q.transpose(1, 2) # (batch, time1, head, d_k)
274
+
275
+ # NOTE(xcsong):
276
+ # when export onnx model, for 1st chunk, we feed
277
+ # cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
278
+ # or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
279
+ # In all modes, `if cache.size(0) > 0` will alwayse be `True`
280
+ # and we will always do splitting and
281
+ # concatnation(this will simplify onnx export). Note that
282
+ # it's OK to concat & split zero-shaped tensors(see code below).
283
+ # when export jit model, for 1st chunk, we always feed
284
+ # cache(0, 0, 0, 0) since jit supports dynamic if-branch.
285
+ # >>> a = torch.ones((1, 2, 0, 4))
286
+ # >>> b = torch.ones((1, 2, 3, 4))
287
+ # >>> c = torch.cat((a, b), dim=2)
288
+ # >>> torch.equal(b, c) # True
289
+ # >>> d = torch.split(a, 2, dim=-1)
290
+ # >>> torch.equal(d[0], d[1]) # True
291
+ if cache.size(0) > 0:
292
+ key_cache, value_cache = torch.split(cache,
293
+ cache.size(-1) // 2,
294
+ dim=-1)
295
+ k = torch.cat([key_cache, k], dim=2)
296
+ v = torch.cat([value_cache, v], dim=2)
297
+ # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
298
+ # non-trivial to calculate `next_cache_start` here.
299
+ new_cache = torch.cat((k, v), dim=-1)
300
+
301
+ n_batch_pos = pos_emb.size(0)
302
+ p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
303
+ p = p.transpose(1, 2) # (batch, head, time1, d_k)
304
+
305
+ # (batch, head, time1, d_k)
306
+ q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
307
+ # (batch, head, time1, d_k)
308
+ q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
309
+
310
+ # compute attention score
311
+ # first compute matrix a and matrix c
312
+ # as described in https://arxiv.org/abs/1901.02860 Section 3.3
313
+ # (batch, head, time1, time2)
314
+ matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
315
+
316
+ # compute matrix b and matrix d
317
+ # (batch, head, time1, time2)
318
+ matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
319
+ # NOTE(Xiang Lyu): Keep rel_shift since espnet rel_pos_emb is used
320
+ if matrix_ac.shape != matrix_bd.shape:
321
+ matrix_bd = self.rel_shift(matrix_bd)
322
+
323
+ scores = (matrix_ac + matrix_bd) / math.sqrt(
324
+ self.d_k) # (batch, head, time1, time2)
325
+
326
+ return self.forward_attention(v, scores, mask), new_cache
327
+
328
+
329
+
330
+
331
+ # class BlockRelPositionMultiHeadedAttention(MultiHeadedAttention):
332
+ # """Multi-Head Attention layer with relative position encoding.
333
+ # Paper: https://arxiv.org/abs/1901.02860
334
+ # Args:
335
+ # n_head (int): The number of heads.
336
+ # n_feat (int): The number of features.
337
+ # dropout_rate (float): Dropout rate.
338
+ # """
339
+
340
+ # def __init__(self,
341
+ # n_head: int,
342
+ # n_feat: int,
343
+ # dropout_rate: float,
344
+ # key_bias: bool = True,
345
+ # block_size=25):
346
+ # """Construct an RelPositionMultiHeadedAttention object."""
347
+ # super().__init__(n_head, n_feat, dropout_rate, key_bias)
348
+ # # linear transformation for positional encoding
349
+ # self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
350
+ # # these two learnable bias are used in matrix c and matrix d
351
+ # # as described in https://arxiv.org/abs/1901.02860 Section 3.3
352
+ # self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
353
+ # self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
354
+ # torch.nn.init.xavier_uniform_(self.pos_bias_u)
355
+ # torch.nn.init.xavier_uniform_(self.pos_bias_v)
356
+ # self.block_size=block_size
357
+
358
+ # def rel_shift(self, x):
359
+ # """Compute relative positional encoding.
360
+
361
+ # Args:
362
+ # x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1).
363
+ # time1 means the length of query vector.
364
+
365
+ # Returns:
366
+ # torch.Tensor: Output tensor.
367
+
368
+ # """
369
+ # zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype)
370
+ # x_padded = torch.cat([zero_pad, x], dim=-1)
371
+
372
+ # x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
373
+ # x = x_padded[:, :, 1:].view_as(x)[
374
+ # :, :, :, : x.size(-1) // 2 + 1
375
+ # ] # only keep the positions from 0 to time2
376
+ # return x
377
+
378
+ # def forward(
379
+ # self,
380
+ # query: torch.Tensor,
381
+ # key: torch.Tensor,
382
+ # value: torch.Tensor,
383
+ # mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
384
+ # pos_emb: torch.Tensor = torch.empty(0),
385
+ # cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
386
+ # ) -> Tuple[torch.Tensor, torch.Tensor]:
387
+ # """Compute 'Scaled Dot Product Attention' with rel. positional encoding.
388
+ # Args:
389
+ # query (torch.Tensor): Query tensor (#batch, time1, size).
390
+ # key (torch.Tensor): Key tensor (#batch, time2, size).
391
+ # value (torch.Tensor): Value tensor (#batch, time2, size).
392
+ # mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
393
+ # (#batch, time1, time2), (0, 0, 0) means fake mask.
394
+ # pos_emb (torch.Tensor): Positional embedding tensor
395
+ # (#batch, time2, size).
396
+ # cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
397
+ # where `cache_t == chunk_size * num_decoding_left_chunks`
398
+ # and `head * d_k == size`
399
+ # Returns:
400
+ # torch.Tensor: Output tensor (#batch, time1, d_model).
401
+ # torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
402
+ # where `cache_t == chunk_size * num_decoding_left_chunks`
403
+ # and `head * d_k == size`
404
+ # """
405
+ # q, k, v = self.forward_qkv(query, key, value)
406
+ # q = q.transpose(1, 2) # (batch, time1, head, d_k)
407
+
408
+ # # NOTE(xcsong):
409
+ # # when export onnx model, for 1st chunk, we feed
410
+ # # cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
411
+ # # or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
412
+ # # In all modes, `if cache.size(0) > 0` will alwayse be `True`
413
+ # # and we will always do splitting and
414
+ # # concatnation(this will simplify onnx export). Note that
415
+ # # it's OK to concat & split zero-shaped tensors(see code below).
416
+ # # when export jit model, for 1st chunk, we always feed
417
+ # # cache(0, 0, 0, 0) since jit supports dynamic if-branch.
418
+ # # >>> a = torch.ones((1, 2, 0, 4))
419
+ # # >>> b = torch.ones((1, 2, 3, 4))
420
+ # # >>> c = torch.cat((a, b), dim=2)
421
+ # # >>> torch.equal(b, c) # True
422
+ # # >>> d = torch.split(a, 2, dim=-1)
423
+ # # >>> torch.equal(d[0], d[1]) # True
424
+ # if cache.size(0) > 0:
425
+ # key_cache, value_cache = torch.split(cache,
426
+ # cache.size(-1) // 2,
427
+ # dim=-1)
428
+ # k = torch.cat([key_cache, k], dim=2)
429
+ # v = torch.cat([value_cache, v], dim=2)
430
+ # # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
431
+ # # non-trivial to calculate `next_cache_start` here.
432
+ # new_cache = torch.cat((k, v), dim=-1)
433
+
434
+ # n_batch_pos = pos_emb.size(0)
435
+ # p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
436
+ # p = p.transpose(1, 2) # (batch, head, time1, d_k)
437
+
438
+ # # (batch, head, time1, d_k)
439
+ # q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
440
+ # # (batch, head, time1, d_k)
441
+ # q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
442
+
443
+ # # compute attention score
444
+ # # first compute matrix a and matrix c
445
+ # # as described in https://arxiv.org/abs/1901.02860 Section 3.3
446
+ # # (batch, head, time1, time2)
447
+
448
+ # # Compute matrix ac and bd
449
+ # matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1)) # (batch, head, time1, time2)
450
+ # matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1)) # (batch, head, time1, time2)
451
+
452
+ # batch_size, num_heads, seq_len, _ = matrix_ac.shape
453
+
454
+ # # Create block causal mask
455
+ # block_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=self.block_size).to(matrix_ac.device).bool()
456
+ # # mask = mask.masked_fill(mask == 1, float('-inf')) # mask upper triangular matrix beyond block
457
+
458
+ # # Apply relative shift if necessary
459
+ # if matrix_ac.shape != matrix_bd.shape:
460
+ # matrix_bd = self.rel_shift(matrix_bd)
461
+
462
+ # # Combine ac and bd and apply the block causal mask
463
+ # scores = (matrix_ac + matrix_bd) / math.sqrt(self.d_k) # (batch, head, time1, time2)
464
+ # scores = scores.masked_fill(block_mask.unsqueeze(0).unsqueeze(0), float('-inf')) # apply the block mask
465
+
466
+ # # Forward attention
467
+ # return self.forward_attention(v, scores, mask), new_cache
468
+
469
+
470
+
471
+ from cosyvoice.utils import block_mask_util
472
+ class BlockRelPositionMultiHeadedAttention(MultiHeadedAttention):
473
+ """Multi-Head Attention layer with relative position encoding.
474
+ Paper: https://arxiv.org/abs/1901.02860
475
+ Args:
476
+ n_head (int): The number of heads.
477
+ n_feat (int): The number of features.
478
+ dropout_rate (float): Dropout rate.
479
+ """
480
+
481
+ def __init__(self,
482
+ n_head: int,
483
+ n_feat: int,
484
+ dropout_rate: float,
485
+ key_bias: bool = True, block_size=25):
486
+ """Construct an RelPositionMultiHeadedAttention object."""
487
+ super().__init__(n_head, n_feat, dropout_rate, key_bias)
488
+ # linear transformation for positional encoding
489
+ self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
490
+ # these two learnable bias are used in matrix c and matrix d
491
+ # as described in https://arxiv.org/abs/1901.02860 Section 3.3
492
+ self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
493
+ self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
494
+ torch.nn.init.xavier_uniform_(self.pos_bias_u)
495
+ torch.nn.init.xavier_uniform_(self.pos_bias_v)
496
+ self.block_size = block_size
497
+
498
+ def rel_shift(self, x: torch.Tensor) -> torch.Tensor:
499
+ """Compute relative positional encoding.
500
+
501
+ Args:
502
+ x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1).
503
+ time1 means the length of query vector.
504
+
505
+ Returns:
506
+ torch.Tensor: Output tensor.
507
+
508
+ """
509
+ zero_pad = torch.zeros((x.size()[0], x.size()[1], x.size()[2], 1),
510
+ device=x.device,
511
+ dtype=x.dtype)
512
+ x_padded = torch.cat([zero_pad, x], dim=-1)
513
+
514
+ x_padded = x_padded.view(x.size()[0],
515
+ x.size()[1],
516
+ x.size(3) + 1, x.size(2))
517
+ x = x_padded[:, :, 1:].view_as(x)[
518
+ :, :, :, : x.size(-1) // 2 + 1
519
+ ] # only keep the positions from 0 to time2
520
+ return x
521
+
522
+ def forward(
523
+ self,
524
+ query: torch.Tensor,
525
+ key: torch.Tensor,
526
+ value: torch.Tensor,
527
+ mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
528
+ pos_emb: torch.Tensor = torch.empty(0),
529
+ cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
530
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
531
+ """Compute 'Scaled Dot Product Attention' with rel. positional encoding.
532
+ Args:
533
+ query (torch.Tensor): Query tensor (#batch, time1, size).
534
+ key (torch.Tensor): Key tensor (#batch, time2, size).
535
+ value (torch.Tensor): Value tensor (#batch, time2, size).
536
+ mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
537
+ (#batch, time1, time2), (0, 0, 0) means fake mask.
538
+ pos_emb (torch.Tensor): Positional embedding tensor
539
+ (#batch, time2, size).
540
+ cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
541
+ where `cache_t == chunk_size * num_decoding_left_chunks`
542
+ and `head * d_k == size`
543
+ Returns:
544
+ torch.Tensor: Output tensor (#batch, time1, d_model).
545
+ torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
546
+ where `cache_t == chunk_size * num_decoding_left_chunks`
547
+ and `head * d_k == size`
548
+ """
549
+ q, k, v = self.forward_qkv(query, key, value)
550
+ q = q.transpose(1, 2) # (batch, time1, head, d_k)
551
+
552
+ # 0代表被mask的位置
553
+ bs, time_len, _ = query.shape
554
+ # mask = torch.tril(torch.ones(time_len, time_len).to(mask), diagonal=0).int()
555
+ # block_size = self.block_size
556
+ # mask[:, 0:block_size] = 1
557
+ block_mask = block_mask_util.create_grid_mask(time_len,self.block_size,fill_triangle=True).to(query).int()
558
+ block_mask = block_mask[None].repeat(bs, 1, 1)
559
+ mask=mask*block_mask
560
+
561
+ # NOTE(xcsong):
562
+ # when export onnx model, for 1st chunk, we feed
563
+ # cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
564
+ # or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
565
+ # In all modes, `if cache.size(0) > 0` will alwayse be `True`
566
+ # and we will always do splitting and
567
+ # concatnation(this will simplify onnx export). Note that
568
+ # it's OK to concat & split zero-shaped tensors(see code below).
569
+ # when export jit model, for 1st chunk, we always feed
570
+ # cache(0, 0, 0, 0) since jit supports dynamic if-branch.
571
+ # >>> a = torch.ones((1, 2, 0, 4))
572
+ # >>> b = torch.ones((1, 2, 3, 4))
573
+ # >>> c = torch.cat((a, b), dim=2)
574
+ # >>> torch.equal(b, c) # True
575
+ # >>> d = torch.split(a, 2, dim=-1)
576
+ # >>> torch.equal(d[0], d[1]) # True
577
+ if cache.size(0) > 0:
578
+ key_cache, value_cache = torch.split(cache,
579
+ cache.size(-1) // 2,
580
+ dim=-1)
581
+ k = torch.cat([key_cache, k], dim=2)
582
+ v = torch.cat([value_cache, v], dim=2)
583
+ # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
584
+ # non-trivial to calculate `next_cache_start` here.
585
+ new_cache = torch.cat((k, v), dim=-1)
586
+
587
+ n_batch_pos = pos_emb.size(0)
588
+ p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
589
+ p = p.transpose(1, 2) # (batch, head, time1, d_k)
590
+
591
+ # (batch, head, time1, d_k)
592
+ q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
593
+ # (batch, head, time1, d_k)
594
+ q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
595
+
596
+ # compute attention score
597
+ # first compute matrix a and matrix c
598
+ # as described in https://arxiv.org/abs/1901.02860 Section 3.3
599
+ # (batch, head, time1, time2)
600
+ matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
601
+
602
+ # compute matrix b and matrix d
603
+ # (batch, head, time1, time2)
604
+ matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
605
+ # NOTE(Xiang Lyu): Keep rel_shift since espnet rel_pos_emb is used
606
+ if matrix_ac.shape != matrix_bd.shape:
607
+ matrix_bd = self.rel_shift(matrix_bd)
608
+
609
+ scores = (matrix_ac + matrix_bd) / math.sqrt(
610
+ self.d_k) # (batch, head, time1, time2)
611
+
612
+ return self.forward_attention(v, scores, mask), new_cache
cosyvoice/transformer/convolution.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu)
2
+ # 2024 Alibaba Inc (Xiang Lyu)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # Modified from ESPnet(https://github.com/espnet/espnet)
16
+ """ConvolutionModule definition."""
17
+
18
+ from typing import Tuple
19
+
20
+ import torch
21
+ from torch import nn
22
+
23
+
24
+ class ConvolutionModule(nn.Module):
25
+ """ConvolutionModule in Conformer model."""
26
+
27
+ def __init__(self,
28
+ channels: int,
29
+ kernel_size: int = 15,
30
+ activation: nn.Module = nn.ReLU(),
31
+ norm: str = "batch_norm",
32
+ causal: bool = False,
33
+ bias: bool = True):
34
+ """Construct an ConvolutionModule object.
35
+ Args:
36
+ channels (int): The number of channels of conv layers.
37
+ kernel_size (int): Kernel size of conv layers.
38
+ causal (int): Whether use causal convolution or not
39
+ """
40
+ super().__init__()
41
+
42
+ self.pointwise_conv1 = nn.Conv1d(
43
+ channels,
44
+ 2 * channels,
45
+ kernel_size=1,
46
+ stride=1,
47
+ padding=0,
48
+ bias=bias,
49
+ )
50
+ # self.lorder is used to distinguish if it's a causal convolution,
51
+ # if self.lorder > 0: it's a causal convolution, the input will be
52
+ # padded with self.lorder frames on the left in forward.
53
+ # else: it's a symmetrical convolution
54
+ if causal:
55
+ padding = 0
56
+ self.lorder = kernel_size - 1
57
+ else:
58
+ # kernel_size should be an odd number for none causal convolution
59
+ assert (kernel_size - 1) % 2 == 0
60
+ padding = (kernel_size - 1) // 2
61
+ self.lorder = 0
62
+ self.depthwise_conv = nn.Conv1d(
63
+ channels,
64
+ channels,
65
+ kernel_size,
66
+ stride=1,
67
+ padding=padding,
68
+ groups=channels,
69
+ bias=bias,
70
+ )
71
+
72
+ assert norm in ['batch_norm', 'layer_norm']
73
+ if norm == "batch_norm":
74
+ self.use_layer_norm = False
75
+ self.norm = nn.BatchNorm1d(channels)
76
+ else:
77
+ self.use_layer_norm = True
78
+ self.norm = nn.LayerNorm(channels)
79
+
80
+ self.pointwise_conv2 = nn.Conv1d(
81
+ channels,
82
+ channels,
83
+ kernel_size=1,
84
+ stride=1,
85
+ padding=0,
86
+ bias=bias,
87
+ )
88
+ self.activation = activation
89
+
90
+ def forward(
91
+ self,
92
+ x: torch.Tensor,
93
+ mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
94
+ cache: torch.Tensor = torch.zeros((0, 0, 0)),
95
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
96
+ """Compute convolution module.
97
+ Args:
98
+ x (torch.Tensor): Input tensor (#batch, time, channels).
99
+ mask_pad (torch.Tensor): used for batch padding (#batch, 1, time),
100
+ (0, 0, 0) means fake mask.
101
+ cache (torch.Tensor): left context cache, it is only
102
+ used in causal convolution (#batch, channels, cache_t),
103
+ (0, 0, 0) meas fake cache.
104
+ Returns:
105
+ torch.Tensor: Output tensor (#batch, time, channels).
106
+ """
107
+ # exchange the temporal dimension and the feature dimension
108
+ x = x.transpose(1, 2) # (#batch, channels, time)
109
+
110
+ # mask batch padding
111
+ if mask_pad.size(2) > 0: # time > 0
112
+ x.masked_fill_(~mask_pad, 0.0)
113
+
114
+ if self.lorder > 0:
115
+ if cache.size(2) == 0: # cache_t == 0
116
+ x = nn.functional.pad(x, (self.lorder, 0), 'constant', 0.0)
117
+ else:
118
+ assert cache.size(0) == x.size(0) # equal batch
119
+ assert cache.size(1) == x.size(1) # equal channel
120
+ x = torch.cat((cache, x), dim=2)
121
+ assert (x.size(2) > self.lorder)
122
+ new_cache = x[:, :, -self.lorder:]
123
+ else:
124
+ # It's better we just return None if no cache is required,
125
+ # However, for JIT export, here we just fake one tensor instead of
126
+ # None.
127
+ new_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
128
+
129
+ # GLU mechanism
130
+ x = self.pointwise_conv1(x) # (batch, 2*channel, dim)
131
+ x = nn.functional.glu(x, dim=1) # (batch, channel, dim)
132
+
133
+ # 1D Depthwise Conv
134
+ x = self.depthwise_conv(x)
135
+ if self.use_layer_norm:
136
+ x = x.transpose(1, 2)
137
+ x = self.activation(self.norm(x))
138
+ if self.use_layer_norm:
139
+ x = x.transpose(1, 2)
140
+ x = self.pointwise_conv2(x)
141
+ # mask batch padding
142
+ if mask_pad.size(2) > 0: # time > 0
143
+ x.masked_fill_(~mask_pad, 0.0)
144
+
145
+ return x.transpose(1, 2), new_cache