hhguo committed on
Commit
37ced70
1 Parent(s): 3ec218a
Files changed (42)
  1. fireredtts/fireredtts.py +163 -0
  2. fireredtts/modules/__init__.py +42 -0
  3. fireredtts/modules/bigvgan/__init__.py +6 -0
  4. fireredtts/modules/bigvgan/activations.py +126 -0
  5. fireredtts/modules/bigvgan/alias_free_cuda/__init__.py +0 -0
  6. fireredtts/modules/bigvgan/alias_free_cuda/activation1d.py +75 -0
  7. fireredtts/modules/bigvgan/alias_free_cuda/anti_alias_activation.cpp +48 -0
  8. fireredtts/modules/bigvgan/alias_free_cuda/anti_alias_activation_cuda.cu +314 -0
  9. fireredtts/modules/bigvgan/alias_free_cuda/compat.h +31 -0
  10. fireredtts/modules/bigvgan/alias_free_cuda/load.py +85 -0
  11. fireredtts/modules/bigvgan/alias_free_cuda/test_activation.py +64 -0
  12. fireredtts/modules/bigvgan/alias_free_cuda/test_activation_snake_beta.py +64 -0
  13. fireredtts/modules/bigvgan/alias_free_cuda/type_shim.h +97 -0
  14. fireredtts/modules/bigvgan/alias_free_torch/__init__.py +5 -0
  15. fireredtts/modules/bigvgan/alias_free_torch/act.py +29 -0
  16. fireredtts/modules/bigvgan/alias_free_torch/filter.py +98 -0
  17. fireredtts/modules/bigvgan/alias_free_torch/resample.py +57 -0
  18. fireredtts/modules/bigvgan/bigvgan.py +399 -0
  19. fireredtts/modules/codec/speaker.py +1052 -0
  20. fireredtts/modules/flow/__init__.py +24 -0
  21. fireredtts/modules/flow/codebook.npy +3 -0
  22. fireredtts/modules/flow/codec_embedding.py +30 -0
  23. fireredtts/modules/flow/conformer.py +730 -0
  24. fireredtts/modules/flow/decoder.py +396 -0
  25. fireredtts/modules/flow/flow_model.py +89 -0
  26. fireredtts/modules/flow/mel_encoder.py +170 -0
  27. fireredtts/modules/flow/mel_spectrogram.py +132 -0
  28. fireredtts/modules/flow/transformer.py +249 -0
  29. fireredtts/modules/flow/utils.py +30 -0
  30. fireredtts/modules/gpt/__init__.py +0 -0
  31. fireredtts/modules/gpt/gpt.py +356 -0
  32. fireredtts/modules/text_normalizer/__init__.py +0 -0
  33. fireredtts/modules/text_normalizer/normalize.py +178 -0
  34. fireredtts/modules/text_normalizer/regex_common.py +23 -0
  35. fireredtts/modules/text_normalizer/utils.py +121 -0
  36. fireredtts/modules/tokenizer/__init__.py +0 -0
  37. fireredtts/modules/tokenizer/assets/multilingual.tiktoken +0 -0
  38. fireredtts/modules/tokenizer/tokenizer.py +46 -0
  39. fireredtts/modules/tokenizer/whisper_tokenizer.py +456 -0
  40. fireredtts/utils/__init__.py +0 -0
  41. fireredtts/utils/utils.py +37 -0
  42. pretrained_models/README.md +3 -0
fireredtts/fireredtts.py ADDED
@@ -0,0 +1,163 @@
1
+ import os
2
+ import json
3
+ import torch
4
+ from fireredtts.modules.gpt.gpt import GPT
5
+ from fireredtts.modules import Token2Wav, MelSpectrogramExtractor
6
+ from fireredtts.modules.tokenizer.tokenizer import VoiceBpeTokenizer
7
+ from fireredtts.modules.codec.speaker import SpeakerEmbedddingExtractor
8
+ from fireredtts.utils.utils import load_audio
9
+
10
+ import time
11
+
12
+
13
+ class FireRedTTS:
14
+ def __init__(self, config_path, pretrained_path, device="cuda"):
15
+ self.device = device
16
+ self.config = json.load(open(config_path))
17
+ self.gpt_path = os.path.join(pretrained_path, "fireredtts_gpt.pt")
18
+ self.token2wav_path = os.path.join(pretrained_path, "fireredtts_token2wav.pt")
19
+ self.speaker_extractor_path = os.path.join(
20
+ pretrained_path, "fireredtts_speaker.bin"
21
+ )
22
+ assert os.path.exists(self.token2wav_path)
23
+ assert os.path.exists(self.gpt_path)
24
+ assert os.path.exists(self.speaker_extractor_path)
25
+
26
+ # text tokenizer
27
+ self.text_tokenizer = VoiceBpeTokenizer()
28
+
29
+ # speaker extractor
30
+ self.speaker_extractor = SpeakerEmbedddingExtractor(
31
+ ckpt_path=self.speaker_extractor_path, device=device
32
+ )
33
+
34
+ # load gpt model
35
+ self.gpt = GPT(
36
+ start_text_token=self.config["gpt"]["gpt_start_text_token"],
37
+ stop_text_token=self.config["gpt"]["gpt_stop_text_token"],
38
+ layers=self.config["gpt"]["gpt_layers"],
39
+ model_dim=self.config["gpt"]["gpt_n_model_channels"],
40
+ heads=self.config["gpt"]["gpt_n_heads"],
41
+ max_text_tokens=self.config["gpt"]["gpt_max_text_tokens"],
42
+ max_mel_tokens=self.config["gpt"]["gpt_max_audio_tokens"],
43
+ max_prompt_tokens=self.config["gpt"]["gpt_max_prompt_tokens"],
44
+ code_stride_len=self.config["gpt"]["gpt_code_stride_len"],
45
+ number_text_tokens=self.config["gpt"]["gpt_number_text_tokens"],
46
+ num_audio_tokens=self.config["gpt"]["gpt_num_audio_tokens"],
47
+ start_audio_token=self.config["gpt"]["gpt_start_audio_token"],
48
+ stop_audio_token=self.config["gpt"]["gpt_stop_audio_token"],
49
+ )
50
+
51
+ sd = torch.load(self.gpt_path, map_location=device)["model"]
52
+ self.gpt.load_state_dict(sd, strict=True)
53
+ self.gpt = self.gpt.to(device=device)
54
+ self.gpt.eval()
55
+ self.gpt.init_gpt_for_inference(kv_cache=True)
56
+
57
+ # mel-spectrogram extractor
58
+ self.mel_extractor = MelSpectrogramExtractor()
59
+
60
+ # load token2wav model
61
+ self.token2wav = Token2Wav.init_from_config(self.config)
62
+ sd = torch.load(self.token2wav_path, map_location="cpu")
63
+ self.token2wav.load_state_dict(sd, strict=True)
64
+ self.token2wav.generator.remove_weight_norm()
65
+ self.token2wav.eval()
66
+ self.token2wav = self.token2wav.to(device)
67
+
68
+ def extract_spk_embeddings(self, prompt_wav):
69
+ _, _, audio_resampled = load_audio(audiopath=prompt_wav, sampling_rate=16000)
70
+ audio_len = torch.tensor(
71
+ data=[audio_resampled.shape[1]], dtype=torch.long, requires_grad=False
72
+ )
73
+
74
+ # speaker embeddings [1,512]
75
+ spk_embeddings = self.speaker_extractor(
76
+ audio_resampled.to(device=self.device)
77
+ ).unsqueeze(0)
78
+
79
+ return spk_embeddings
80
+
81
+ def do_gpt_inference(self, spk_gpt, text_tokens):
82
+ """Run GPT sampling and return the selected acoustic token sequence.
83
+
84
+ Args:
85
+ spk_gpt (torch.Tensor): speaker conditioning latents for the GPT
86
+ text_tokens (torch.Tensor): encoded text tokens, shape [1, text_len]
87
+ """
88
+ with torch.no_grad():
89
+ gpt_codes = self.gpt.generate(
90
+ cond_latents=spk_gpt,
91
+ text_inputs=text_tokens,
92
+ input_tokens=None,
93
+ do_sample=True,
94
+ top_p=0.85,
95
+ top_k=30,
96
+ temperature=0.75,
97
+ num_return_sequences=9,
98
+ num_beams=1,
99
+ length_penalty=1.0,
100
+ repetition_penalty=2.0,
101
+ output_attentions=False,
102
+ )
103
+
104
+ seqs = []
105
+ EOS_TOKEN = self.config["gpt"]["gpt_stop_audio_token"]
106
+ for seq in gpt_codes:
107
+ index = (seq == EOS_TOKEN).nonzero(as_tuple=True)[0][0]
108
+ seq = seq[:index]
109
+ seqs.append(seq)
110
+
111
+ sorted_seqs = sorted(seqs, key=lambda i: len(i), reverse=False)
112
+ gpt_codes = sorted_seqs[2].unsqueeze(0) # [1, len]
113
+ # sorted_len = [len(l) for l in sorted_seqs]
114
+ # print("---sorted_len:", sorted_len)
115
+
116
+ return gpt_codes
117
+
118
+ def synthesize(self, prompt_wav, text, lang="auto"):
119
+ """Synthesize speech for the given text in the voice of the prompt audio.
120
+
121
+ Args:
122
+ prompt_wav (str): path to the prompt (reference) audio
123
+ text (str): text to synthesize
124
+ lang (str): language of the text, one of "zh", "en", or "auto"
125
+ """
126
+ # Currently only supports Chinese and English
127
+ assert lang in ["zh", "en", "auto"]
128
+ assert os.path.exists(prompt_wav)
129
+
130
+ # text to tokens
131
+ text_tokens = self.text_tokenizer.encode(text=text, lang=lang)
132
+ text_tokens = torch.IntTensor(text_tokens).unsqueeze(0).to(self.device)
133
+ assert text_tokens.shape[-1] < 400
134
+
135
+ # extract speaker embedding
136
+ spk_embeddings = self.extract_spk_embeddings(prompt_wav=prompt_wav).unsqueeze(0)
137
+ with torch.no_grad():
138
+ spk_gpt = self.gpt.reference_embedding(spk_embeddings)
139
+
140
+ # gpt inference
141
+ gpt_start_time = time.time()
142
+ gpt_codes = self.do_gpt_inference(spk_gpt=spk_gpt, text_tokens=text_tokens)
143
+ gpt_end_time = time.time()
144
+ gpt_dur = gpt_end_time - gpt_start_time
145
+
146
+ # prompt mel-spectrogram compute
147
+ prompt_mel = (
148
+ self.mel_extractor(wav_path=prompt_wav).unsqueeze(0).to(self.device)
149
+ )
150
+ # convert token to waveform (b=1, t)
151
+ voc_start_time = time.time()
152
+ rec_wavs = self.token2wav.inference(gpt_codes, prompt_mel, n_timesteps=10)
153
+ voc_end_time = time.time()
154
+ voc_dur = voc_end_time - voc_start_time
155
+ all_dur = voc_end_time - gpt_start_time
156
+
157
+ # rtf compute
158
+ # audio_dur = rec_wavs.shape[-1] / 24000
159
+ # rtf_gpt = gpt_dur / audio_dur
160
+ # rtf_voc = voc_dur / audio_dur
161
+ # rtf_all = all_dur / audio_dur
162
+
163
+ return rec_wavs
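For context, a minimal usage sketch of the FireRedTTS class defined above. The config filename, checkpoint directory, and audio paths are placeholders rather than names taken from this commit; the 24 kHz sample rate follows the commented RTF calculation in synthesize.

import torchaudio
from fireredtts.fireredtts import FireRedTTS

# Placeholder paths; pretrained_path must contain the three checkpoints asserted in __init__.
tts = FireRedTTS(config_path="config.json", pretrained_path="pretrained_models", device="cuda")

# Clone the prompt speaker's voice for new text (Chinese, English, or auto-detected).
wav = tts.synthesize(prompt_wav="prompt.wav", text="Hello from FireRedTTS.", lang="en")

# The returned waveform is (1, t) at 24 kHz per the commented RTF math above.
torchaudio.save("output.wav", wav.detach().cpu(), 24000)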
fireredtts/modules/__init__.py ADDED
@@ -0,0 +1,42 @@
1
+ import json
2
+ import torch
3
+ import torch.nn as nn
4
+ from fireredtts.modules.bigvgan import get_bigvgan_backend
5
+ from fireredtts.modules.flow import get_flow_frontend, MelSpectrogramExtractor
6
+
7
+
8
+ class Token2Wav(nn.Module):
9
+ def __init__(
10
+ self,
11
+ flow: nn.Module,
12
+ generator: nn.Module,
13
+ ):
14
+ super().__init__()
15
+ self.flow = flow
16
+ self.generator = generator
17
+
18
+ @torch.no_grad()
19
+ def inference(
20
+ self, tokens: torch.Tensor, prompt_mel: torch.Tensor, n_timesteps: int = 10
21
+ ) -> torch.Tensor:
22
+ token_len = torch.tensor([tokens.shape[1]], dtype=torch.long).to(tokens.device)
23
+ prompt_mel_len = torch.tensor([prompt_mel.shape[1]], dtype=torch.long).to(
24
+ prompt_mel.device
25
+ )
26
+ # flow
27
+ mel = self.flow.inference(
28
+ token=tokens,
29
+ token_len=token_len,
30
+ prompt_mel=prompt_mel,
31
+ prompt_mel_len=prompt_mel_len,
32
+ n_timesteps=n_timesteps,
33
+ )
34
+ # bigvgan
35
+ audio = self.generator(mel) # (b=1, 1, t)
36
+ return audio.squeeze(1)
37
+
38
+ @classmethod
39
+ def init_from_config(cls, config) -> "Token2Wav":
40
+ flow = get_flow_frontend(config["flow"])
41
+ bigvgan = get_bigvgan_backend(config["bigvgan"])
42
+ return cls(flow, bigvgan)
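A wiring sketch for Token2Wav on its own, assuming a JSON config that carries the "flow" and "bigvgan" sections consumed by init_from_config. The token values, lengths, and the 80 mel bins are illustrative assumptions; the only shape constraint taken from the code above is that dim 1 of prompt_mel is its time axis.

import json
import torch
from fireredtts.modules import Token2Wav

config = json.load(open("config.json"))                 # placeholder config path
token2wav = Token2Wav.init_from_config(config).eval()

tokens = torch.randint(0, 100, (1, 150))                # (b=1, token_len) placeholder acoustic codes
prompt_mel = torch.randn(1, 200, 80)                    # (b=1, t, n_mels); 80 bins is an assumption
audio = token2wav.inference(tokens, prompt_mel, n_timesteps=10)  # (b=1, t_wav)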
fireredtts/modules/bigvgan/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ from fireredtts.modules.bigvgan.bigvgan import BigVGAN
2
+
3
+
4
+ def get_bigvgan_backend(bigvgan_config):
5
+ generator = BigVGAN(**bigvgan_config)
6
+ return generator
fireredtts/modules/bigvgan/activations.py ADDED
@@ -0,0 +1,126 @@
1
+ # Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license.
2
+
3
+
4
+ import torch
5
+ from torch import nn, sin, pow
6
+ from torch.nn import Parameter
7
+
8
+
9
+ class Snake(nn.Module):
10
+ """
11
+ Implementation of a sine-based periodic activation function
12
+ Shape:
13
+ - Input: (B, C, T)
14
+ - Output: (B, C, T), same shape as the input
15
+ Parameters:
16
+ - alpha - trainable parameter
17
+ References:
18
+ - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
19
+ https://arxiv.org/abs/2006.08195
20
+ Examples:
21
+ >>> a1 = Snake(256)
22
+ >>> x = torch.randn(256)
23
+ >>> x = a1(x)
24
+ """
25
+
26
+ def __init__(
27
+ self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False
28
+ ):
29
+ """
30
+ Initialization.
31
+ INPUT:
32
+ - in_features: shape of the input
33
+ - alpha: trainable parameter
34
+ alpha is initialized to 1 by default, higher values = higher-frequency.
35
+ alpha will be trained along with the rest of your model.
36
+ """
37
+ super(Snake, self).__init__()
38
+ self.in_features = in_features
39
+
40
+ # initialize alpha
41
+ self.alpha_logscale = alpha_logscale
42
+ if self.alpha_logscale: # log scale alphas initialized to zeros
43
+ self.alpha = Parameter(torch.zeros(in_features) * alpha)
44
+ else: # linear scale alphas initialized to ones
45
+ self.alpha = Parameter(torch.ones(in_features) * alpha)
46
+
47
+ self.alpha.requires_grad = alpha_trainable
48
+
49
+ self.no_div_by_zero = 0.000000001
50
+
51
+ def forward(self, x):
52
+ """
53
+ Forward pass of the function.
54
+ Applies the function to the input elementwise.
55
+ Snake ∶= x + 1/a * sin^2 (xa)
56
+ """
57
+ alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
58
+ if self.alpha_logscale:
59
+ alpha = torch.exp(alpha)
60
+ x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
61
+
62
+ return x
63
+
64
+
65
+ class SnakeBeta(nn.Module):
66
+ """
67
+ A modified Snake function which uses separate parameters for the magnitude of the periodic components
68
+ Shape:
69
+ - Input: (B, C, T)
70
+ - Output: (B, C, T), same shape as the input
71
+ Parameters:
72
+ - alpha - trainable parameter that controls frequency
73
+ - beta - trainable parameter that controls magnitude
74
+ References:
75
+ - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
76
+ https://arxiv.org/abs/2006.08195
77
+ Examples:
78
+ >>> a1 = SnakeBeta(256)
79
+ >>> x = torch.randn(256)
80
+ >>> x = a1(x)
81
+ """
82
+
83
+ def __init__(
84
+ self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False
85
+ ):
86
+ """
87
+ Initialization.
88
+ INPUT:
89
+ - in_features: shape of the input
90
+ - alpha - trainable parameter that controls frequency
91
+ - beta - trainable parameter that controls magnitude
92
+ alpha is initialized to 1 by default, higher values = higher-frequency.
93
+ beta is initialized to 1 by default, higher values = higher-magnitude.
94
+ alpha will be trained along with the rest of your model.
95
+ """
96
+ super(SnakeBeta, self).__init__()
97
+ self.in_features = in_features
98
+
99
+ # initialize alpha
100
+ self.alpha_logscale = alpha_logscale
101
+ if self.alpha_logscale: # log scale alphas initialized to zeros
102
+ self.alpha = Parameter(torch.zeros(in_features) * alpha)
103
+ self.beta = Parameter(torch.zeros(in_features) * alpha)
104
+ else: # linear scale alphas initialized to ones
105
+ self.alpha = Parameter(torch.ones(in_features) * alpha)
106
+ self.beta = Parameter(torch.ones(in_features) * alpha)
107
+
108
+ self.alpha.requires_grad = alpha_trainable
109
+ self.beta.requires_grad = alpha_trainable
110
+
111
+ self.no_div_by_zero = 0.000000001
112
+
113
+ def forward(self, x):
114
+ """
115
+ Forward pass of the function.
116
+ Applies the function to the input elementwise.
117
+ SnakeBeta ∶= x + 1/b * sin^2 (xa)
118
+ """
119
+ alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
120
+ beta = self.beta.unsqueeze(0).unsqueeze(-1)
121
+ if self.alpha_logscale:
122
+ alpha = torch.exp(alpha)
123
+ beta = torch.exp(beta)
124
+ x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
125
+
126
+ return x
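A quick shape check for the two activations above, following the documented (B, C, T) convention; alpha and beta are per-channel parameters that forward broadcasts over batch and time.

import torch
from fireredtts.modules.bigvgan.activations import Snake, SnakeBeta

x = torch.randn(2, 256, 100)                     # (B, C, T)
y1 = Snake(256)(x)                               # x + (1/alpha) * sin^2(alpha * x)
y2 = SnakeBeta(256, alpha_logscale=True)(x)      # separate alpha (frequency) and beta (magnitude)
assert y1.shape == x.shape and y2.shape == x.shape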
fireredtts/modules/bigvgan/alias_free_cuda/__init__.py ADDED
File without changes
fireredtts/modules/bigvgan/alias_free_cuda/activation1d.py ADDED
@@ -0,0 +1,75 @@
1
+ # Copyright (c) 2024 NVIDIA CORPORATION.
2
+ # Licensed under the MIT license.
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ from fireredtts.modules.bigvgan.alias_free_torch.resample import UpSample1d, DownSample1d
7
+
8
+ # load fused CUDA kernel: this enables importing anti_alias_activation_cuda
9
+ from fireredtts.modules.bigvgan.alias_free_cuda import load
10
+
11
+ load.load()
12
+
13
+
14
+ class FusedAntiAliasActivation(torch.autograd.Function):
15
+ """
16
+ Assumes filter size 12, replication padding on upsampling, and logscale alpha/beta parameters as inputs
17
+ """
18
+
19
+ @staticmethod
20
+ def forward(ctx, inputs, ftr, alpha, beta):
21
+ import anti_alias_activation_cuda
22
+
23
+ activation_results = anti_alias_activation_cuda.forward(
24
+ inputs, ftr, alpha, beta
25
+ )
26
+ return activation_results
27
+
28
+ @staticmethod
29
+ def backward(ctx, output_grads):
30
+ # TODO: implement bwd pass
31
+ raise NotImplementedError
32
+ return output_grads, None, None
33
+
34
+
35
+ class Activation1d(nn.Module):
36
+ def __init__(
37
+ self,
38
+ activation,
39
+ up_ratio: int = 2,
40
+ down_ratio: int = 2,
41
+ up_kernel_size: int = 12,
42
+ down_kernel_size: int = 12,
43
+ fused: bool = True,
44
+ ):
45
+ super().__init__()
46
+ self.up_ratio = up_ratio
47
+ self.down_ratio = down_ratio
48
+ self.act = activation
49
+ self.upsample = UpSample1d(up_ratio, up_kernel_size)
50
+ self.downsample = DownSample1d(down_ratio, down_kernel_size)
51
+
52
+ self.fused = fused # whether to use fused CUDA kernel or not
53
+
54
+ def forward(self, x):
55
+ if not self.fused:
56
+ x = self.upsample(x)
57
+ x = self.act(x)
58
+ x = self.downsample(x)
59
+ return x
60
+ else:
61
+ if self.act.__class__.__name__ == "Snake":
62
+ beta = self.act.alpha.data # snake uses same params for alpha and beta
63
+ else:
64
+ beta = (
65
+ self.act.beta.data
66
+ ) # snakebeta uses different params for alpha and beta
67
+ alpha = self.act.alpha.data
68
+ if (
69
+ not self.act.alpha_logscale
70
+ ): # exp baked into cuda kernel, cancel it out with a log
71
+ alpha = torch.log(alpha)
72
+ beta = torch.log(beta)
73
+ x = FusedAntiAliasActivation.apply(x, self.upsample.filter, alpha, beta)
74
+ x = self.downsample(x)
75
+ return x
fireredtts/modules/bigvgan/alias_free_cuda/anti_alias_activation.cpp ADDED
@@ -0,0 +1,48 @@
1
+ /* coding=utf-8
2
+ * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
3
+ *
4
+ * Licensed under the Apache License, Version 2.0 (the "License");
5
+ * you may not use this file except in compliance with the License.
6
+ * You may obtain a copy of the License at
7
+ *
8
+ * http://www.apache.org/licenses/LICENSE-2.0
9
+ *
10
+ * Unless required by applicable law or agreed to in writing, software
11
+ * distributed under the License is distributed on an "AS IS" BASIS,
12
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ * See the License for the specific language governing permissions and
14
+ * limitations under the License.
15
+ */
16
+
17
+ #include <cuda_fp16.h>
18
+ #include <torch/extension.h>
19
+ #include <vector>
20
+
21
+ namespace anti_alias_activation {
22
+
23
+ torch::Tensor fwd_cuda(torch::Tensor const& input,
24
+ torch::Tensor const& filter,
25
+ torch::Tensor const& alpha,
26
+ torch::Tensor const& beta
27
+ );
28
+
29
+ torch::Tensor fwd(torch::Tensor const& input,
30
+ torch::Tensor const& filter,
31
+ torch::Tensor const& alpha,
32
+ torch::Tensor const& beta
33
+ ) {
34
+ AT_ASSERTM(input.dim() == 3, "expected 3D tensor");
35
+ //AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
36
+ // (input.scalar_type() == at::ScalarType::BFloat16),
37
+ // "Only fp16 and bf16 are supported");
38
+
39
+ return fwd_cuda(input, filter, alpha, beta);
40
+ }
41
+
42
+ } // end namespace anti_alias_activation
43
+
44
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
45
+ m.def("forward",
46
+ &anti_alias_activation::fwd,
47
+ "Anti Alias Activation -- Forward.");
48
+ }
fireredtts/modules/bigvgan/alias_free_cuda/anti_alias_activation_cuda.cu ADDED
@@ -0,0 +1,314 @@
1
+ /* coding=utf-8
2
+ * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
3
+ *
4
+ * Licensed under the Apache License, Version 2.0 (the "License");
5
+ * you may not use this file except in compliance with the License.
6
+ * You may obtain a copy of the License at
7
+ *
8
+ * http://www.apache.org/licenses/LICENSE-2.0
9
+ *
10
+ * Unless required by applicable law or agreed to in writing, software
11
+ * distributed under the License is distributed on an "AS IS" BASIS,
12
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ * See the License for the specific language governing permissions and
14
+ * limitations under the License.
15
+ */
16
+
17
+ #include <ATen/ATen.h>
18
+ #include <cuda.h>
19
+ #include <cuda_runtime.h>
20
+ #include <cuda_fp16.h>
21
+ #include <cuda_profiler_api.h>
22
+ #include <ATen/cuda/CUDAContext.h>
23
+ #include <torch/extension.h>
24
+ #include "type_shim.h"
25
+ #include <assert.h>
26
+ #include <cfloat>
27
+ #include <limits>
28
+ #include <stdint.h>
29
+ #include <c10/macros/Macros.h>
30
+
31
+ namespace {
32
+
33
+ /*
34
+ template <typename Datatype, int ELEMENTS_PER_LDG>
35
+ __device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src);
36
+
37
+ template <>
38
+ __device__ __inline__ void copy_vector<c10::BFloat16, 1>(c10::BFloat16 *dst, const c10::BFloat16 *src) { *dst = *src; }
39
+
40
+ template <>
41
+ __device__ __inline__ void copy_vector<c10::BFloat16, 4>(c10::BFloat16 *dst, const c10::BFloat16 *src) { *((float2*) dst) = *((float2*) src); }
42
+
43
+ template <>
44
+ __device__ __inline__ void copy_vector<c10::Half, 1>(c10::Half *dst, const c10::Half *src) { *dst = *src; }
45
+
46
+ template <>
47
+ __device__ __inline__ void copy_vector<c10::Half, 4>(c10::Half *dst, const c10::Half *src) { *((float2*) dst) = *((float2*) src); }
48
+
49
+ template <>
50
+ __device__ __inline__ void copy_vector<uint8_t, 1>(uint8_t *dst, const uint8_t *src) { *dst = *src; }
51
+
52
+ template <>
53
+ __device__ __inline__ void copy_vector<uint8_t, 4>(uint8_t *dst, const uint8_t *src) {*((half2*) dst) = *((half2*) src); }
54
+
55
+ int log2_ceil(int value) {
56
+ int log2_value = 0;
57
+ while ((1 << log2_value) < value) ++log2_value;
58
+ return log2_value;
59
+ }
60
+
61
+ template<typename T>
62
+ struct Add {
63
+ __device__ __forceinline__ T operator()(T a, T b) const {
64
+ return a + b;
65
+ }
66
+ };
67
+
68
+ template<typename T>
69
+ struct Max {
70
+ __device__ __forceinline__ T operator()(T a, T b) const {
71
+ return a < b ? b : a;
72
+ }
73
+ };
74
+
75
+ template <typename T>
76
+ __device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff)
77
+ {
78
+ #if CUDA_VERSION >= 9000
79
+ return __shfl_xor_sync(mask, value, laneMask, width);
80
+ #else
81
+ return __shfl_xor(value, laneMask, width);
82
+ #endif
83
+ }
84
+
85
+ template <typename acc_t, int WARP_BATCH, int WARP_SIZE, template<typename> class ReduceOp>
86
+ __device__ __forceinline__ void warp_reduce(acc_t* sum) {
87
+ ReduceOp<acc_t> r;
88
+ #pragma unroll
89
+ for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
90
+ #pragma unroll
91
+ for (int i = 0; i < WARP_BATCH; ++i) {
92
+ acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE);
93
+ sum[i] = r(sum[i], b);
94
+ }
95
+ }
96
+ }
97
+ */
98
+
99
+ template <typename input_t, typename output_t, typename acc_t>
100
+ __global__ void anti_alias_activation_forward(
101
+ output_t *dst,
102
+ const input_t *src,
103
+ const input_t *ftr,
104
+ const input_t *alpha,
105
+ const input_t *beta,
106
+ int batch_size,
107
+ int channels,
108
+ int seq_len)
109
+ {
110
+ // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
111
+ constexpr int ELEMENTS_PER_LDG_STG = 1; //(WARP_ITERATIONS < 4) ? 1 : 4;
112
+ constexpr int BUFFER_SIZE = 32;
113
+ constexpr int FILTER_SIZE = 12;
114
+ constexpr int HALF_FILTER_SIZE = 6;
115
+ constexpr int REPLICATION_PAD = 5; // 5 on each side
116
+
117
+ // blockDim/threadIdx = (128, 1, 1)
118
+ // gridDim/blockIdx = (seq_blocks, channels, batches)
119
+ int block_offset = (blockIdx.x * 128 * BUFFER_SIZE + seq_len * (blockIdx.y + gridDim.y * blockIdx.z));
120
+ int local_offset = threadIdx.x * BUFFER_SIZE;
121
+ int seq_offset = blockIdx.x * 128 * BUFFER_SIZE + local_offset;
122
+
123
+
124
+ //int intermediate_seq_len = seq_len * 2 - 1 + 4 * REPLICATION_PAD;
125
+ //int intermediate_block_offset = (blockIdx.x * 128 * BUFFER_SIZE * 2 + intermediate_seq_len * (blockIdx.y + gridDim.y * blockIdx.z));
126
+ //int intermediate_local_offset = threadIdx.x * BUFFER_SIZE * 2;
127
+
128
+ int output_seq_len = seq_len * 2 ; //
129
+ int output_block_offset = (blockIdx.x * 128 * BUFFER_SIZE * 2 + output_seq_len * (blockIdx.y + gridDim.y * blockIdx.z));
130
+ int output_local_offset = threadIdx.x * BUFFER_SIZE * 2;
131
+ int output_seq_offset = blockIdx.x * 128 * BUFFER_SIZE *2 + output_local_offset;
132
+ // get values needed for replication padding before moving pointer
133
+ const input_t *right_most_pntr = src + (seq_len * (blockIdx.y + gridDim.y * blockIdx.z));
134
+ input_t seq_left_most_value = right_most_pntr[0];
135
+ input_t seq_right_most_value = right_most_pntr[seq_len - 1];
136
+
137
+ src += block_offset + local_offset;
138
+ dst += output_block_offset + output_local_offset ;
139
+ alpha = alpha + blockIdx.y;
140
+ input_t alpha_val = expf(alpha[0]);
141
+ beta = beta + blockIdx.y;
142
+ input_t beta_val = expf(beta[0]);
143
+ // load data from global memory
144
+ input_t elements[2*FILTER_SIZE+2*BUFFER_SIZE] = {0};
145
+ input_t intermediates[2*FILTER_SIZE+2*BUFFER_SIZE] = {0};
146
+ //output_t output[2*BUFFER_SIZE];
147
+ input_t filter[FILTER_SIZE];
148
+ //input_t temp_data[ELEMENTS_PER_LDG_STG];
149
+ //uint8_t temp_mask[ELEMENTS_PER_LDG_STG];
150
+
151
+ #pragma unroll
152
+ for (int it = 0; it < FILTER_SIZE; it+=1) {
153
+ filter[it] = ftr[it];
154
+ }
155
+
156
+
157
+ #pragma unroll
158
+ for (int it = -HALF_FILTER_SIZE; it < BUFFER_SIZE + HALF_FILTER_SIZE ; it+=1) {
159
+ int element_index = seq_offset + it;
160
+ if ((element_index < 0) && (element_index >= -REPLICATION_PAD)) {
161
+ elements[2*(HALF_FILTER_SIZE+it)] = 2*seq_left_most_value;
162
+ }
163
+ if ((element_index >= seq_len) && (element_index < seq_len + REPLICATION_PAD)) {
164
+ elements[2*(HALF_FILTER_SIZE+it)] = 2*seq_right_most_value;
165
+ }
166
+ if ((element_index >= 0) && (element_index < seq_len)) {
167
+ elements[2*(HALF_FILTER_SIZE+it)] = 2*src[it];
168
+ }
169
+ }
170
+
171
+
172
+
173
+ // apply filter
174
+ #pragma unroll
175
+ for (int it = 0; it < (2 * BUFFER_SIZE + 2*FILTER_SIZE); it+=1) {
176
+ input_t acc = 0.0;
177
+
178
+ int element_index = output_seq_offset + it; // index for output
179
+ #pragma unroll
180
+ for (int f_idx = 0; f_idx < FILTER_SIZE; f_idx+=1){
181
+ if ((element_index + f_idx) >= 0){
182
+ acc += filter[f_idx] * elements[it+f_idx];
183
+ }
184
+ }
185
+ intermediates[it] = acc;
186
+ }
187
+
188
+ double no_div_by_zero = 0.000000001;
189
+ #pragma unroll
190
+ for (int it = 0; it < 12 + 2 * BUFFER_SIZE; it++) {
191
+ intermediates[it] += (1.0/(beta_val + no_div_by_zero)) * sinf(intermediates[it] * alpha_val) * sinf(intermediates[it] * alpha_val);
192
+ }
193
+
194
+
195
+ // now copy to output
196
+ #pragma unroll
197
+ for (int it = 0; it < 2*BUFFER_SIZE; it+=1){
198
+ int element_index = output_seq_offset + it;
199
+ if (element_index < output_seq_len) {
200
+ dst[it] = intermediates[it+6];
201
+ }
202
+ }
203
+
204
+
205
+
206
+ // for (int it = 0; it < BUFFER_SIZE; it+=ELEMENTS_PER_LDG_STG) {
207
+ // int element_index = seq_offset + it;
208
+ // if (element_index < seq_len) {
209
+ // dst[it] = output[it];
210
+ // }
211
+ // }
212
+
213
+
214
+ // // Upsample convolution
215
+ // for (int it = 0; it < 2 * BUFFER_SIZE + 12; it+=1) {
216
+ // input_t acc = 0.0;
217
+
218
+ // for (int f_idx = 0; f_idx < FILTER_SIZE; f_idx+=1){
219
+ // acc += filter[f_idx] * elements[it+f_idx];
220
+ // }
221
+ // intermediates[it] = acc;
222
+ // }
223
+
224
+ // // correct the corners of intermediates
225
+ // if (seq_offset == 0) {
226
+ // for (int it = 0; it < 6; it+=1)
227
+ // intermediates[it] = 0;
228
+ // }
229
+
230
+ // if (seq_offset + 32 >= seq_len) {
231
+ // int offset = seq_len % 32 == 0 ? 32 : seq_len % 32;
232
+
233
+ // for (int it = 0; it < 6; it++) {
234
+ // intermediates[6+2*offset+it] = 0;
235
+ // }
236
+ // }
237
+
238
+
239
+
240
+
241
+ // for (int it = 0; it < BUFFER_SIZE; it+=ELEMENTS_PER_LDG_STG) {
242
+ // int element_index = seq_offset + it;
243
+ // if (element_index < seq_len) {
244
+ // dst[it] = output[it];
245
+ // }
246
+ // }
247
+ }
248
+
249
+ template<typename input_t, typename output_t, typename acc_t>
250
+ void dispatch_anti_alias_activation_forward(
251
+ output_t *dst,
252
+ const input_t *src,
253
+ const input_t *ftr,
254
+ const input_t *alpha,
255
+ const input_t *beta,
256
+ int batch_size,
257
+ int channels,
258
+ int seq_len)
259
+ {
260
+ if (seq_len == 0) {
261
+ return;
262
+ } else {
263
+ // use 128 threads per block to maximize gpu utilization
264
+ constexpr int threads_per_block = 128;
265
+ constexpr int seq_len_per_block = 4096;
266
+ int blocks_per_seq_len = (seq_len + seq_len_per_block - 1) / seq_len_per_block;
267
+ dim3 blocks(blocks_per_seq_len, channels, batch_size);
268
+ dim3 threads(threads_per_block, 1, 1);
269
+
270
+ anti_alias_activation_forward<input_t, output_t, acc_t>
271
+ <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, ftr, alpha, beta, batch_size, channels, seq_len);
272
+ }
273
+ }
274
+ }
275
+
276
+ namespace anti_alias_activation {
277
+
278
+ torch::Tensor fwd_cuda(torch::Tensor const& input, torch::Tensor const& filter, torch::Tensor const& alpha, torch::Tensor const& beta)
279
+ {
280
+ // input is a 3D tensor with dimensions [batches, channels, seq_len]
281
+ const int batches = input.size(0);
282
+ const int channels = input.size(1);
283
+ const int seq_len = input.size(2);
284
+
285
+ // Output
286
+ auto act_options = input.options().requires_grad(false);
287
+ int output_seq_len = seq_len*2; // we'll be dilating between each element by interspersing with zeros
288
+
289
+ torch::Tensor anti_alias_activation_results =
290
+ torch::empty({batches, channels, output_seq_len}, act_options);
291
+
292
+ // Raw data pointers for the kernel dispatch
293
+ void* input_ptr = static_cast<void*>(input.data_ptr());
294
+ void* filter_ptr = static_cast<void*>(filter.data_ptr());
295
+ void* alpha_ptr = static_cast<void*>(alpha.data_ptr());
296
+ void* beta_ptr = static_cast<void*>(beta.data_ptr());
297
+ void* anti_alias_activation_results_ptr = static_cast<void*>(anti_alias_activation_results.data_ptr());
298
+
299
+ DISPATCH_FLOAT_HALF_AND_BFLOAT(
300
+ input.scalar_type(),
301
+ "dispatch anti alias activation_forward",
302
+ dispatch_anti_alias_activation_forward<scalar_t, scalar_t, float>(
303
+ reinterpret_cast<scalar_t*>(anti_alias_activation_results_ptr),
304
+ reinterpret_cast<const scalar_t*>(input_ptr),
305
+ reinterpret_cast<const scalar_t*>(filter_ptr),
306
+ reinterpret_cast<const scalar_t*>(alpha_ptr),
307
+ reinterpret_cast<const scalar_t*>(beta_ptr),
308
+ batches,
309
+ channels,
310
+ seq_len);
311
+ );
312
+ return anti_alias_activation_results;
313
+ }
314
+ }
fireredtts/modules/bigvgan/alias_free_cuda/compat.h ADDED
@@ -0,0 +1,31 @@
1
+ /* coding=utf-8
2
+ * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3
+ *
4
+ * Licensed under the Apache License, Version 2.0 (the "License");
5
+ * you may not use this file except in compliance with the License.
6
+ * You may obtain a copy of the License at
7
+ *
8
+ * http://www.apache.org/licenses/LICENSE-2.0
9
+ *
10
+ * Unless required by applicable law or agreed to in writing, software
11
+ * distributed under the License is distributed on an "AS IS" BASIS,
12
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ * See the License for the specific language governing permissions and
14
+ * limitations under the License.
15
+ */
16
+
17
+ /* This code is copied from NVIDIA apex:
18
+ * https://github.com/NVIDIA/apex
19
+ * with minor changes. */
20
+
21
+
22
+
23
+ #ifndef TORCH_CHECK
24
+ #define TORCH_CHECK AT_CHECK
25
+ #endif
26
+
27
+ #ifdef VERSION_GE_1_3
28
+ #define DATA_PTR data_ptr
29
+ #else
30
+ #define DATA_PTR data
31
+ #endif
fireredtts/modules/bigvgan/alias_free_cuda/load.py ADDED
@@ -0,0 +1,85 @@
1
+ # Copyright (c) 2024 NVIDIA CORPORATION.
2
+ # Licensed under the MIT license.
3
+
4
+ import os
5
+ import pathlib
6
+ import subprocess
7
+
8
+ from torch.utils import cpp_extension
9
+
10
+ # Setting this param to a list has a problem of generating different
11
+ # compilation commands (with different order of architectures) and
12
+ # leading to recompilation of fused kernels. Set it to empty string
13
+ # to avoid recompilation and assign arch flags explicitly in
14
+ # extra_cuda_cflags below
15
+ os.environ["TORCH_CUDA_ARCH_LIST"] = ""
16
+
17
+
18
+ def load():
19
+ # Check if cuda 11 is installed for compute capability 8.0
20
+ cc_flag = []
21
+ _, bare_metal_major, _ = _get_cuda_bare_metal_version(cpp_extension.CUDA_HOME)
22
+ if int(bare_metal_major) >= 11:
23
+ cc_flag.append("-gencode")
24
+ cc_flag.append("arch=compute_80,code=sm_80")
25
+
26
+ # Build path
27
+ srcpath = pathlib.Path(__file__).parent.absolute()
28
+ buildpath = srcpath / "build"
29
+ _create_build_dir(buildpath)
30
+
31
+ # Helper function to build the kernels.
32
+ def _cpp_extention_load_helper(name, sources, extra_cuda_flags):
33
+ return cpp_extension.load(
34
+ name=name,
35
+ sources=sources,
36
+ build_directory=buildpath,
37
+ extra_cflags=[
38
+ "-O3",
39
+ ],
40
+ extra_cuda_cflags=[
41
+ "-O3",
42
+ "-gencode",
43
+ "arch=compute_70,code=sm_70",
44
+ "--use_fast_math",
45
+ ]
46
+ + extra_cuda_flags
47
+ + cc_flag,
48
+ verbose=True,
49
+ )
50
+
51
+ extra_cuda_flags = [
52
+ "-U__CUDA_NO_HALF_OPERATORS__",
53
+ "-U__CUDA_NO_HALF_CONVERSIONS__",
54
+ "--expt-relaxed-constexpr",
55
+ "--expt-extended-lambda",
56
+ ]
57
+
58
+ sources = [
59
+ srcpath / "anti_alias_activation.cpp",
60
+ srcpath / "anti_alias_activation_cuda.cu",
61
+ ]
62
+ anti_alias_activation_cuda = _cpp_extention_load_helper(
63
+ "anti_alias_activation_cuda", sources, extra_cuda_flags
64
+ )
65
+
66
+
67
+ def _get_cuda_bare_metal_version(cuda_dir):
68
+ raw_output = subprocess.check_output(
69
+ [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True
70
+ )
71
+ output = raw_output.split()
72
+ release_idx = output.index("release") + 1
73
+ release = output[release_idx].split(".")
74
+ bare_metal_major = release[0]
75
+ bare_metal_minor = release[1][0]
76
+
77
+ return raw_output, bare_metal_major, bare_metal_minor
78
+
79
+
80
+ def _create_build_dir(buildpath):
81
+ try:
82
+ os.mkdir(buildpath)
83
+ except OSError:
84
+ if not os.path.isdir(buildpath):
85
+ print(f"Creation of the build directory {buildpath} failed")
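The expected call order, as also used in activation1d.py above: build (or reuse) the JIT extension first, then import the resulting module. This is a sketch that assumes a CUDA toolkit and a GPU build of PyTorch are available.

from fireredtts.modules.bigvgan.alias_free_cuda import load

load.load()                         # compiles into alias_free_cuda/build on first use, then reuses it
import anti_alias_activation_cuda   # importable once the extension has been built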
fireredtts/modules/bigvgan/alias_free_cuda/test_activation.py ADDED
@@ -0,0 +1,64 @@
1
+ # Copyright (c) 2024 NVIDIA CORPORATION.
2
+ # Licensed under the MIT license.
3
+
4
+ import math
5
+ import torch
6
+ import alias_free_cuda
7
+ from alias_free_cuda import activation1d
8
+ from activations import Snake, SnakeBeta
9
+
10
+
11
+ def test_load_fused_kernels():
12
+ try:
13
+ import alias_free_cuda
14
+ import torch
15
+
16
+ print("[Success] load_fused_kernels")
17
+ except ImportError as e:
18
+ print("[Fail] load_fused_kernels")
19
+ raise e
20
+
21
+
22
+ def test_anti_alias_activation():
23
+ data = torch.rand((10, 10, 50000), device="cuda")
24
+
25
+ # check activations.Snake cuda vs. torch
26
+ fused_anti_alias_activation = activation1d.Activation1d(
27
+ activation=Snake(10), fused=True
28
+ ).cuda()
29
+ fused_activation_output = fused_anti_alias_activation(data)
30
+
31
+ torch_anti_alias_activation = activation1d.Activation1d(
32
+ activation=Snake(10), fused=False
33
+ ).cuda()
34
+ torch_activation_output = torch_anti_alias_activation(data)
35
+
36
+ test_result = (fused_activation_output - torch_activation_output).abs()
37
+
38
+ while test_result.dim() != 1:
39
+ test_result = test_result.mean(dim=-1)
40
+
41
+ diff = test_result.mean(dim=-1)
42
+
43
+ if diff <= 1e-3:
44
+ print(
45
+ f"\n[Success] test_fused_anti_alias_activation"
46
+ f"\n > mean_difference={diff}"
47
+ f"\n > fused_values={fused_activation_output[-1][-1][-100:].tolist()}"
48
+ f"\n > torch_values={torch_activation_output[-1][-1][-100:].tolist()}"
49
+ )
50
+ else:
51
+ print(
52
+ f"\n[Fail] test_fused_anti_alias_activation"
53
+ f"\n > mean_difference={diff}, "
54
+ f"\n > fused_values={fused_activation_output[-1][-1][-30:].tolist()}, "
55
+ f"\n > torch_values={torch_activation_output[-1][-1][-30:].tolist()}"
56
+ )
57
+
58
+
59
+ if __name__ == "__main__":
60
+ from alias_free_cuda import load
61
+
62
+ load.load()
63
+ test_load_fused_kernels()
64
+ test_anti_alias_activation()
fireredtts/modules/bigvgan/alias_free_cuda/test_activation_snake_beta.py ADDED
@@ -0,0 +1,64 @@
1
+ # Copyright (c) 2024 NVIDIA CORPORATION.
2
+ # Licensed under the MIT license.
3
+
4
+ import math
5
+ import torch
6
+ import alias_free_cuda
7
+ from alias_free_cuda import activation1d
8
+ from activations import Snake, SnakeBeta
9
+
10
+
11
+ def test_load_fused_kernels():
12
+ try:
13
+ import alias_free_cuda
14
+ import torch
15
+
16
+ print("[Success] load_fused_kernels")
17
+ except ImportError as e:
18
+ print("[Fail] load_fused_kernels")
19
+ raise e
20
+
21
+
22
+ def test_anti_alias_activation():
23
+ data = torch.rand((10, 10, 50000), device="cuda")
24
+
25
+ # check activations.Snake cuda vs. torch
26
+ fused_anti_alias_activation = activation1d.Activation1d(
27
+ activation=SnakeBeta(10), fused=True
28
+ ).cuda()
29
+ fused_activation_output = fused_anti_alias_activation(data)
30
+
31
+ torch_anti_alias_activation = activation1d.Activation1d(
32
+ activation=SnakeBeta(10), fused=False
33
+ ).cuda()
34
+ torch_activation_output = torch_anti_alias_activation(data)
35
+
36
+ test_result = (fused_activation_output - torch_activation_output).abs()
37
+
38
+ while test_result.dim() != 1:
39
+ test_result = test_result.mean(dim=-1)
40
+
41
+ diff = test_result.mean(dim=-1)
42
+
43
+ if diff <= 1e-3:
44
+ print(
45
+ f"\n[Success] test_fused_anti_alias_activation"
46
+ f"\n > mean_difference={diff}"
47
+ f"\n > fused_values={fused_activation_output[-1][-1][-100:].tolist()}"
48
+ f"\n > torch_values={torch_activation_output[-1][-1][-100:].tolist()}"
49
+ )
50
+ else:
51
+ print(
52
+ f"\n[Fail] test_fused_anti_alias_activation"
53
+ f"\n > mean_difference={diff}, "
54
+ f"\n > fused_values={fused_activation_output[-1][-1][-30:].tolist()}, "
55
+ f"\n > torch_values={torch_activation_output[-1][-1][-30:].tolist()}"
56
+ )
57
+
58
+
59
+ if __name__ == "__main__":
60
+ from alias_free_cuda import load
61
+
62
+ load.load()
63
+ test_load_fused_kernels()
64
+ test_anti_alias_activation()
fireredtts/modules/bigvgan/alias_free_cuda/type_shim.h ADDED
@@ -0,0 +1,97 @@
1
+ /* coding=utf-8
2
+ * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3
+ *
4
+ * Licensed under the Apache License, Version 2.0 (the "License");
5
+ * you may not use this file except in compliance with the License.
6
+ * You may obtain a copy of the License at
7
+ *
8
+ * http://www.apache.org/licenses/LICENSE-2.0
9
+ *
10
+ * Unless required by applicable law or agreed to in writing, software
11
+ * distributed under the License is distributed on an "AS IS" BASIS,
12
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ * See the License for the specific language governing permissions and
14
+ * limitations under the License.
15
+ */
16
+
17
+
18
+ #include <ATen/ATen.h>
19
+ #include "compat.h"
20
+
21
+
22
+ #define DISPATCH_FLOAT_HALF_AND_BFLOAT(TYPE, NAME, ...) \
23
+ switch(TYPE) \
24
+ { \
25
+ case at::ScalarType::Float: \
26
+ { \
27
+ using scalar_t = float; \
28
+ __VA_ARGS__; \
29
+ break; \
30
+ } \
31
+ case at::ScalarType::Half: \
32
+ { \
33
+ using scalar_t = at::Half; \
34
+ __VA_ARGS__; \
35
+ break; \
36
+ } \
37
+ case at::ScalarType::BFloat16: \
38
+ { \
39
+ using scalar_t = at::BFloat16; \
40
+ __VA_ARGS__; \
41
+ break; \
42
+ } \
43
+ default: \
44
+ AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
45
+ }
46
+
47
+
48
+
49
+ #define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \
50
+ switch(TYPEIN) \
51
+ { \
52
+ case at::ScalarType::Float: \
53
+ { \
54
+ using scalar_t_in = float; \
55
+ switch(TYPEOUT) \
56
+ { \
57
+ case at::ScalarType::Float: \
58
+ { \
59
+ using scalar_t_out = float; \
60
+ __VA_ARGS__; \
61
+ break; \
62
+ } \
63
+ case at::ScalarType::Half: \
64
+ { \
65
+ using scalar_t_out = at::Half; \
66
+ __VA_ARGS__; \
67
+ break; \
68
+ } \
69
+ case at::ScalarType::BFloat16: \
70
+ { \
71
+ using scalar_t_out = at::BFloat16; \
72
+ __VA_ARGS__; \
73
+ break; \
74
+ } \
75
+ default: \
76
+ AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \
77
+ } \
78
+ break; \
79
+ } \
80
+ case at::ScalarType::Half: \
81
+ { \
82
+ using scalar_t_in = at::Half; \
83
+ using scalar_t_out = at::Half; \
84
+ __VA_ARGS__; \
85
+ break; \
86
+ } \
87
+ case at::ScalarType::BFloat16: \
88
+ { \
89
+ using scalar_t_in = at::BFloat16; \
90
+ using scalar_t_out = at::BFloat16; \
91
+ __VA_ARGS__; \
92
+ break; \
93
+ } \
94
+ default: \
95
+ AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'"); \
96
+ }
97
+
fireredtts/modules/bigvgan/alias_free_torch/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
2
+
3
+ from .filter import *
4
+ from .resample import *
5
+ from .act import *
fireredtts/modules/bigvgan/alias_free_torch/act.py ADDED
@@ -0,0 +1,29 @@
1
+ # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
2
+
3
+ import torch.nn as nn
4
+ from .resample import UpSample1d, DownSample1d
5
+
6
+
7
+ class Activation1d(nn.Module):
8
+ def __init__(
9
+ self,
10
+ activation,
11
+ up_ratio: int = 2,
12
+ down_ratio: int = 2,
13
+ up_kernel_size: int = 12,
14
+ down_kernel_size: int = 12,
15
+ ):
16
+ super().__init__()
17
+ self.up_ratio = up_ratio
18
+ self.down_ratio = down_ratio
19
+ self.act = activation
20
+ self.upsample = UpSample1d(up_ratio, up_kernel_size)
21
+ self.downsample = DownSample1d(down_ratio, down_kernel_size)
22
+
23
+ # x: [B,C,T]
24
+ def forward(self, x):
25
+ x = self.upsample(x)
26
+ x = self.act(x)
27
+ x = self.downsample(x)
28
+
29
+ return x
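A small sanity sketch for this torch-only Activation1d wrapper: upsampling by 2, applying the activation, and downsampling by 2 preserves the sequence length, so output and input shapes match.

import torch
from fireredtts.modules.bigvgan.activations import Snake
from fireredtts.modules.bigvgan.alias_free_torch.act import Activation1d

act = Activation1d(activation=Snake(64))
x = torch.randn(1, 64, 1000)        # [B, C, T]
assert act(x).shape == x.shape      # up x2 -> Snake -> down x2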
fireredtts/modules/bigvgan/alias_free_torch/filter.py ADDED
@@ -0,0 +1,98 @@
1
+ # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ import math
7
+
8
+ if "sinc" in dir(torch):
9
+ sinc = torch.sinc
10
+ else:
11
+ # This code is adopted from adefossez's julius.core.sinc under the MIT License
12
+ # https://adefossez.github.io/julius/julius/core.html
13
+ # LICENSE is in incl_licenses directory.
14
+ def sinc(x: torch.Tensor):
15
+ """
16
+ Implementation of sinc, i.e. sin(pi * x) / (pi * x)
17
+ __Warning__: Unlike julius.sinc, the input is multiplied by `pi`!
18
+ """
19
+ return torch.where(
20
+ x == 0,
21
+ torch.tensor(1.0, device=x.device, dtype=x.dtype),
22
+ torch.sin(math.pi * x) / math.pi / x,
23
+ )
24
+
25
+
26
+ # This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License
27
+ # https://adefossez.github.io/julius/julius/lowpass.html
28
+ # LICENSE is in incl_licenses directory.
29
+ def kaiser_sinc_filter1d(
30
+ cutoff, half_width, kernel_size
31
+ ): # return filter [1,1,kernel_size]
32
+ even = kernel_size % 2 == 0
33
+ half_size = kernel_size // 2
34
+
35
+ # For kaiser window
36
+ delta_f = 4 * half_width
37
+ A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
38
+ if A > 50.0:
39
+ beta = 0.1102 * (A - 8.7)
40
+ elif A >= 21.0:
41
+ beta = 0.5842 * (A - 21) ** 0.4 + 0.07886 * (A - 21.0)
42
+ else:
43
+ beta = 0.0
44
+ window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
45
+
46
+ # ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio
47
+ if even:
48
+ time = torch.arange(-half_size, half_size) + 0.5
49
+ else:
50
+ time = torch.arange(kernel_size) - half_size
51
+ if cutoff == 0:
52
+ filter_ = torch.zeros_like(time)
53
+ else:
54
+ filter_ = 2 * cutoff * window * sinc(2 * cutoff * time)
55
+ # Normalize filter to have sum = 1, otherwise we will have a small leakage
56
+ # of the constant component in the input signal.
57
+ filter_ /= filter_.sum()
58
+ filter = filter_.view(1, 1, kernel_size)
59
+
60
+ return filter
61
+
62
+
63
+ class LowPassFilter1d(nn.Module):
64
+ def __init__(
65
+ self,
66
+ cutoff=0.5,
67
+ half_width=0.6,
68
+ stride: int = 1,
69
+ padding: bool = True,
70
+ padding_mode: str = "replicate",
71
+ kernel_size: int = 12,
72
+ ):
73
+ # kernel_size should be even number for stylegan3 setup,
74
+ # in this implementation, odd number is also possible.
75
+ super().__init__()
76
+ if cutoff < -0.0:
77
+ raise ValueError("Minimum cutoff must be larger than zero.")
78
+ if cutoff > 0.5:
79
+ raise ValueError("A cutoff above 0.5 does not make sense.")
80
+ self.kernel_size = kernel_size
81
+ self.even = kernel_size % 2 == 0
82
+ self.pad_left = kernel_size // 2 - int(self.even)
83
+ self.pad_right = kernel_size // 2
84
+ self.stride = stride
85
+ self.padding = padding
86
+ self.padding_mode = padding_mode
87
+ filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
88
+ self.register_buffer("filter", filter)
89
+
90
+ # input [B, C, T]
91
+ def forward(self, x):
92
+ _, C, _ = x.shape
93
+
94
+ if self.padding:
95
+ x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode)
96
+ out = F.conv1d(x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)
97
+
98
+ return out
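A brief example of the filter above. cutoff is a normalized frequency (0.5 is Nyquist), and with the default stride of 1 plus replicate padding the output length equals the input length.

import torch
from fireredtts.modules.bigvgan.alias_free_torch.filter import LowPassFilter1d

lpf = LowPassFilter1d(cutoff=0.25, half_width=0.3, kernel_size=12)
x = torch.randn(1, 4, 500)          # [B, C, T]
y = lpf(x)                          # same length as x for stride=1, padding=True
assert y.shape == x.shape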
fireredtts/modules/bigvgan/alias_free_torch/resample.py ADDED
@@ -0,0 +1,57 @@
1
+ # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
2
+
3
+ import torch.nn as nn
4
+ from torch.nn import functional as F
5
+ from .filter import LowPassFilter1d
6
+ from .filter import kaiser_sinc_filter1d
7
+
8
+
9
+ class UpSample1d(nn.Module):
10
+ def __init__(self, ratio=2, kernel_size=None):
11
+ super().__init__()
12
+ self.ratio = ratio
13
+ self.kernel_size = (
14
+ int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
15
+ )
16
+ self.stride = ratio
17
+ self.pad = self.kernel_size // ratio - 1
18
+ self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
19
+ self.pad_right = (
20
+ self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
21
+ )
22
+ filter = kaiser_sinc_filter1d(
23
+ cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size
24
+ )
25
+ self.register_buffer("filter", filter)
26
+
27
+ # x: [B, C, T]
28
+ def forward(self, x):
29
+ _, C, _ = x.shape
30
+
31
+ x = F.pad(x, (self.pad, self.pad), mode="replicate")
32
+ x = self.ratio * F.conv_transpose1d(
33
+ x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C
34
+ )
35
+ x = x[..., self.pad_left : -self.pad_right]
36
+
37
+ return x
38
+
39
+
40
+ class DownSample1d(nn.Module):
41
+ def __init__(self, ratio=2, kernel_size=None):
42
+ super().__init__()
43
+ self.ratio = ratio
44
+ self.kernel_size = (
45
+ int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
46
+ )
47
+ self.lowpass = LowPassFilter1d(
48
+ cutoff=0.5 / ratio,
49
+ half_width=0.6 / ratio,
50
+ stride=ratio,
51
+ kernel_size=self.kernel_size,
52
+ )
53
+
54
+ def forward(self, x):
55
+ xx = self.lowpass(x)
56
+
57
+ return xx
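And the matching resamplers: with ratio 2 (default kernel size 12), UpSample1d doubles the time axis and DownSample1d halves it, so chaining the two returns the original length.

import torch
from fireredtts.modules.bigvgan.alias_free_torch.resample import UpSample1d, DownSample1d

x = torch.randn(1, 8, 256)          # [B, C, T]
up = UpSample1d(ratio=2)(x)         # -> [1, 8, 512]
down = DownSample1d(ratio=2)(up)    # -> [1, 8, 256]
assert up.shape[-1] == 2 * x.shape[-1] and down.shape == x.shape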
fireredtts/modules/bigvgan/bigvgan.py ADDED
@@ -0,0 +1,399 @@
1
+ import typing as tp
2
+ import torch
3
+ import torch.nn as nn
4
+ from torch.nn import Conv1d, ConvTranspose1d
5
+ from torch.nn.utils import weight_norm, remove_weight_norm
6
+
7
+ from fireredtts.modules.bigvgan.alias_free_torch import (
8
+ Activation1d as TorchActivation1d,
9
+ )
10
+ from fireredtts.modules.bigvgan.activations import Snake, SnakeBeta
11
+
12
+
13
+ def init_weights(m, mean=0.0, std=0.01):
14
+ classname = m.__class__.__name__
15
+ if classname.find("Conv") != -1:
16
+ m.weight.data.normal_(mean, std)
17
+
18
+
19
+ def get_padding(kernel_size, dilation=1):
20
+ return int((kernel_size * dilation - dilation) / 2)
21
+
22
+
23
+ class AMPBlock1(torch.nn.Module):
24
+ def __init__(
25
+ self,
26
+ channels,
27
+ kernel_size=3,
28
+ dilation=(1, 3, 5),
29
+ activation=None,
30
+ snake_logscale=True,
31
+ use_cuda_kernel=False,
32
+ ):
33
+ super(AMPBlock1, self).__init__()
34
+
35
+ self.convs1 = nn.ModuleList(
36
+ [
37
+ weight_norm(
38
+ Conv1d(
39
+ channels,
40
+ channels,
41
+ kernel_size,
42
+ 1,
43
+ dilation=dilation[0],
44
+ padding=get_padding(kernel_size, dilation[0]),
45
+ )
46
+ ),
47
+ weight_norm(
48
+ Conv1d(
49
+ channels,
50
+ channels,
51
+ kernel_size,
52
+ 1,
53
+ dilation=dilation[1],
54
+ padding=get_padding(kernel_size, dilation[1]),
55
+ )
56
+ ),
57
+ weight_norm(
58
+ Conv1d(
59
+ channels,
60
+ channels,
61
+ kernel_size,
62
+ 1,
63
+ dilation=dilation[2],
64
+ padding=get_padding(kernel_size, dilation[2]),
65
+ )
66
+ ),
67
+ ]
68
+ )
69
+ self.convs1.apply(init_weights)
70
+
71
+ self.convs2 = nn.ModuleList(
72
+ [
73
+ weight_norm(
74
+ Conv1d(
75
+ channels,
76
+ channels,
77
+ kernel_size,
78
+ 1,
79
+ dilation=1,
80
+ padding=get_padding(kernel_size, 1),
81
+ )
82
+ ),
83
+ weight_norm(
84
+ Conv1d(
85
+ channels,
86
+ channels,
87
+ kernel_size,
88
+ 1,
89
+ dilation=1,
90
+ padding=get_padding(kernel_size, 1),
91
+ )
92
+ ),
93
+ weight_norm(
94
+ Conv1d(
95
+ channels,
96
+ channels,
97
+ kernel_size,
98
+ 1,
99
+ dilation=1,
100
+ padding=get_padding(kernel_size, 1),
101
+ )
102
+ ),
103
+ ]
104
+ )
105
+ self.convs2.apply(init_weights)
106
+
107
+ self.num_layers = len(self.convs1) + len(
108
+ self.convs2
109
+ ) # total number of conv layers
110
+
111
+ # select which Activation1d, lazy-load cuda version to ensure backward compatibility
112
+ if use_cuda_kernel:
113
+ from fireredtts.modules.bigvgan.alias_free_cuda.activation1d import (
114
+ Activation1d as CudaActivation1d,
115
+ )
116
+
117
+ Activation1d = CudaActivation1d
118
+ else:
119
+ Activation1d = TorchActivation1d
120
+
121
+ if (
122
+ activation == "snake"
123
+ ): # periodic nonlinearity with snake function and anti-aliasing
124
+ self.activations = nn.ModuleList(
125
+ [
126
+ Activation1d(
127
+ activation=Snake(channels, alpha_logscale=snake_logscale)
128
+ )
129
+ for _ in range(self.num_layers)
130
+ ]
131
+ )
132
+ elif (
133
+ activation == "snakebeta"
134
+ ): # periodic nonlinearity with snakebeta function and anti-aliasing
135
+ self.activations = nn.ModuleList(
136
+ [
137
+ Activation1d(
138
+ activation=SnakeBeta(channels, alpha_logscale=snake_logscale)
139
+ )
140
+ for _ in range(self.num_layers)
141
+ ]
142
+ )
143
+ else:
144
+ raise NotImplementedError(
145
+ "activation incorrectly specified. check the config file and look for 'activation'."
146
+ )
147
+
148
+ def forward(self, x):
149
+ acts1, acts2 = self.activations[::2], self.activations[1::2]
150
+ for c1, c2, a1, a2 in zip(self.convs1, self.convs2, acts1, acts2):
151
+ xt = a1(x)
152
+ xt = c1(xt)
153
+ xt = a2(xt)
154
+ xt = c2(xt)
155
+ x = xt + x
156
+
157
+ return x
158
+
159
+ def remove_weight_norm(self):
160
+ for l in self.convs1:
161
+ remove_weight_norm(l)
162
+ for l in self.convs2:
163
+ remove_weight_norm(l)
164
+
165
+
166
+ class AMPBlock2(torch.nn.Module):
167
+ def __init__(
168
+ self,
169
+ channels,
170
+ kernel_size=3,
171
+ dilation=(1, 3),
172
+ activation=None,
173
+ snake_logscale=True,
174
+ use_cuda_kernel=False,
175
+ ):
176
+ super(AMPBlock2, self).__init__()
177
+
178
+ self.convs = nn.ModuleList(
179
+ [
180
+ weight_norm(
181
+ Conv1d(
182
+ channels,
183
+ channels,
184
+ kernel_size,
185
+ 1,
186
+ dilation=dilation[0],
187
+ padding=get_padding(kernel_size, dilation[0]),
188
+ )
189
+ ),
190
+ weight_norm(
191
+ Conv1d(
192
+ channels,
193
+ channels,
194
+ kernel_size,
195
+ 1,
196
+ dilation=dilation[1],
197
+ padding=get_padding(kernel_size, dilation[1]),
198
+ )
199
+ ),
200
+ ]
201
+ )
202
+ self.convs.apply(init_weights)
203
+
204
+ self.num_layers = len(self.convs) # total number of conv layers
205
+
206
+ # select which Activation1d, lazy-load cuda version to ensure backward compatibility
207
+ if use_cuda_kernel:
208
+ from fireredtts.modules.bigvgan.alias_free_cuda.activation1d import (
209
+ Activation1d as CudaActivation1d,
210
+ )
211
+
212
+ Activation1d = CudaActivation1d
213
+ else:
214
+ Activation1d = TorchActivation1d
215
+
216
+ if (
217
+ activation == "snake"
218
+ ): # periodic nonlinearity with snake function and anti-aliasing
219
+ self.activations = nn.ModuleList(
220
+ [
221
+ Activation1d(
222
+ activation=Snake(channels, alpha_logscale=snake_logscale)
223
+ )
224
+ for _ in range(self.num_layers)
225
+ ]
226
+ )
227
+ elif (
228
+ activation == "snakebeta"
229
+ ): # periodic nonlinearity with snakebeta function and anti-aliasing
230
+ self.activations = nn.ModuleList(
231
+ [
232
+ Activation1d(
233
+ activation=SnakeBeta(channels, alpha_logscale=snake_logscale)
234
+ )
235
+ for _ in range(self.num_layers)
236
+ ]
237
+ )
238
+ else:
239
+ raise NotImplementedError(
240
+ "activation incorrectly specified. check the config file and look for 'activation'."
241
+ )
242
+
243
+ def forward(self, x):
244
+ for c, a in zip(self.convs, self.activations):
245
+ xt = a(x)
246
+ xt = c(xt)
247
+ x = xt + x
248
+
249
+ return x
250
+
251
+ def remove_weight_norm(self):
252
+ for l in self.convs:
253
+ remove_weight_norm(l)
254
+
255
+
256
+ class BigVGAN(torch.nn.Module):
257
+ # this is our main BigVGAN model. Applies anti-aliased periodic activation for resblocks.
258
+ def __init__(
259
+ self,
260
+ num_mels: int,
261
+ upsample_initial_channel: int,
262
+ resblock_kernel_sizes: tp.List[int],
263
+ resblock_dilation_sizes: tp.List[tp.List[int]],
264
+ upsample_rates: tp.List[int],
265
+ upsample_kernel_sizes: tp.List[int],
266
+ resblock_type: str = "1",
267
+ snake_logscale: bool = True,
268
+ activation: str = "snakebeta",
269
+ use_tanh_at_final: bool = False,
270
+ use_bias_at_final: bool = False,
271
+ use_cuda_kernel: bool = False,
272
+ ):
273
+ super(BigVGAN, self).__init__()
274
+
275
+ self.num_kernels = len(resblock_kernel_sizes)
276
+ self.num_upsamples = len(upsample_rates)
277
+
278
+ # pre conv
279
+ self.conv_pre = weight_norm(
280
+ Conv1d(num_mels, upsample_initial_channel, 7, 1, padding=3)
281
+ )
282
+
283
+ # define which AMPBlock to use. BigVGAN uses AMPBlock1 as default
284
+ resblock = AMPBlock1 if resblock_type == "1" else AMPBlock2
285
+
286
+ # transposed conv-based upsamplers. does not apply anti-aliasing
287
+ self.ups = nn.ModuleList()
288
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
289
+ self.ups.append(
290
+ nn.ModuleList(
291
+ [
292
+ weight_norm(
293
+ ConvTranspose1d(
294
+ upsample_initial_channel // (2**i),
295
+ upsample_initial_channel // (2 ** (i + 1)),
296
+ k,
297
+ u,
298
+ padding=(k - u) // 2,
299
+ )
300
+ )
301
+ ]
302
+ )
303
+ )
304
+
305
+ # residual blocks using anti-aliased multi-periodicity composition modules (AMP)
306
+ self.resblocks = nn.ModuleList()
307
+ for i in range(len(self.ups)):
308
+ ch = upsample_initial_channel // (2 ** (i + 1))
309
+ for j, (k, d) in enumerate(
310
+ zip(resblock_kernel_sizes, resblock_dilation_sizes)
311
+ ):
312
+ self.resblocks.append(
313
+ resblock(
314
+ ch,
315
+ k,
316
+ d,
317
+ activation=activation,
318
+ snake_logscale=snake_logscale,
319
+ use_cuda_kernel=use_cuda_kernel,
320
+ )
321
+ )
322
+
323
+ # select which Activation1d, lazy-load cuda version to ensure backward compatibility
324
+ if use_cuda_kernel:
325
+ from fireredtts.modules.bigvgan.alias_free_cuda.activation1d import (
326
+ Activation1d as CudaActivation1d,
327
+ )
328
+
329
+ Activation1d = CudaActivation1d
330
+ else:
331
+ Activation1d = TorchActivation1d
332
+
333
+ # post conv
334
+ if (
335
+ activation == "snake"
336
+ ): # periodic nonlinearity with snake function and anti-aliasing
337
+ activation_post = Snake(ch, alpha_logscale=snake_logscale)
338
+ self.activation_post = Activation1d(activation=activation_post)
339
+ elif (
340
+ activation == "snakebeta"
341
+ ): # periodic nonlinearity with snakebeta function and anti-aliasing
342
+ activation_post = SnakeBeta(ch, alpha_logscale=snake_logscale)
343
+ self.activation_post = Activation1d(activation=activation_post)
344
+ else:
345
+ raise NotImplementedError(
346
+ "activation incorrectly specified. check the config file and look for 'activation'."
347
+ )
348
+
349
+ # whether to use bias for the final conv_post. Defaults to True for backward compatibility
350
+ self.use_bias_at_final = use_bias_at_final
351
+ self.conv_post = weight_norm(
352
+ Conv1d(ch, 1, 7, 1, padding=3, bias=self.use_bias_at_final)
353
+ )
354
+
355
+ # weight initialization
356
+ for i in range(len(self.ups)):
357
+ self.ups[i].apply(init_weights)
358
+ self.conv_post.apply(init_weights)
359
+
360
+ # final tanh activation. Defaults to True for backward compatibility
361
+ self.use_tanh_at_final = use_tanh_at_final
362
+
363
+ def forward(self, x):
364
+ # pre conv
365
+ x = self.conv_pre(x)
366
+
367
+ for i in range(self.num_upsamples):
368
+ # upsampling
369
+ for i_up in range(len(self.ups[i])):
370
+ x = self.ups[i][i_up](x)
371
+ # AMP blocks
372
+ xs = None
373
+ for j in range(self.num_kernels):
374
+ if xs is None:
375
+ xs = self.resblocks[i * self.num_kernels + j](x)
376
+ else:
377
+ xs += self.resblocks[i * self.num_kernels + j](x)
378
+ x = xs / self.num_kernels
379
+
380
+ # post conv
381
+ x = self.activation_post(x)
382
+ x = self.conv_post(x)
383
+ # final tanh activation
384
+ if self.use_tanh_at_final:
385
+ x = torch.tanh(x)
386
+ else:
387
+ x = torch.clamp(x, min=-1.0, max=1.0) # bound the output to [-1, 1]
388
+
389
+ return x
390
+
391
+ def remove_weight_norm(self):
392
+ print("Removing weight norm...")
393
+ for l in self.ups:
394
+ for l_i in l:
395
+ remove_weight_norm(l_i)
396
+ for l in self.resblocks:
397
+ l.remove_weight_norm()
398
+ remove_weight_norm(self.conv_pre)
399
+ remove_weight_norm(self.conv_post)
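A minimal usage sketch for the BigVGAN vocoder defined above; the hyperparameters here are illustrative assumptions, not the values shipped with the FireRedTTS checkpoints.

import torch
from fireredtts.modules.bigvgan.bigvgan import BigVGAN

# Illustrative hyperparameters (assumed, not the released config).
vocoder = BigVGAN(
    num_mels=80,
    upsample_initial_channel=512,
    resblock_kernel_sizes=[3, 7, 11],
    resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
    upsample_rates=[5, 4, 4, 3, 2],
    upsample_kernel_sizes=[10, 8, 8, 6, 4],
)
vocoder.eval()
vocoder.remove_weight_norm()   # usually done once before inference

mel = torch.randn(1, 80, 200)  # (batch, num_mels, frames)
with torch.no_grad():
    wav = vocoder(mel)         # (batch, 1, frames * prod(upsample_rates))
print(wav.shape)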
fireredtts/modules/codec/speaker.py ADDED
@@ -0,0 +1,1052 @@
1
+ from librosa.filters import mel as librosa_mel_fn
2
+ from torch import nn
3
+ from torch.nn import functional as F
4
+
5
+ import math
6
+ import numpy as np
7
+ import torch
8
+ import torchaudio
9
+
10
+
11
+ def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
12
+ return torch.log(torch.clamp(x, min=clip_val) * C)
13
+
14
+
15
+ def spectral_normalize_torch(magnitudes):
16
+ output = dynamic_range_compression_torch(magnitudes)
17
+ return output
18
+
19
+
20
+ class TorchMelSpectrogram(nn.Module):
21
+
22
+ def __init__(
23
+ self,
24
+ filter_length=1024,
25
+ hop_length=160,
26
+ win_length=640,
27
+ n_mel_channels=80,
28
+ mel_fmin=0,
29
+ mel_fmax=8000,
30
+ sampling_rate=16000,
31
+ ):
32
+
33
+ super().__init__()
34
+ self.filter_length = filter_length
35
+ self.hop_length = hop_length
36
+ self.win_length = win_length
37
+ self.n_mel_channels = n_mel_channels
38
+ self.mel_fmin = mel_fmin
39
+ self.mel_fmax = mel_fmax
40
+ self.sampling_rate = sampling_rate
41
+
42
+ self.mel_basis = {}
43
+ self.hann_window = {}
44
+
45
+ def forward(self, inp, length=None):
46
+ if len(inp.shape) == 3:
47
+ inp = inp.squeeze(1) if inp.shape[1] == 1 else inp.squeeze(2)
48
+ assert len(inp.shape) == 2
49
+
50
+ y = inp
51
+ if len(list(self.mel_basis.keys())) == 0:
52
+ mel = librosa_mel_fn(
53
+ sr=self.sampling_rate,
54
+ n_fft=self.filter_length,
55
+ n_mels=self.n_mel_channels,
56
+ fmin=self.mel_fmin,
57
+ fmax=self.mel_fmax,
58
+ )
59
+ self.mel_basis[str(self.mel_fmax) + "_" + str(y.device)] = (
60
+ torch.from_numpy(mel).float().to(y.device)
61
+ )
62
+ self.hann_window[str(y.device)] = torch.hann_window(self.win_length).to(
63
+ y.device
64
+ )
65
+
66
+ y = torch.nn.functional.pad(
67
+ y.unsqueeze(1),
68
+ (
69
+ int((self.filter_length - self.hop_length) / 2),
70
+ int((self.filter_length - self.hop_length) / 2),
71
+ ),
72
+ mode="reflect",
73
+ )
74
+ y = y.squeeze(1)
75
+
76
+ # complex tensor as default, then use view_as_real for future pytorch compatibility
77
+ spec = torch.stft(
78
+ y,
79
+ self.filter_length,
80
+ hop_length=self.hop_length,
81
+ win_length=self.win_length,
82
+ window=self.hann_window[str(y.device)],
83
+ center=False,
84
+ pad_mode="reflect",
85
+ normalized=False,
86
+ onesided=True,
87
+ return_complex=True,
88
+ )
89
+ spec = torch.view_as_real(spec)
90
+ spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
91
+
92
+ spec = torch.matmul(
93
+ self.mel_basis[str(self.mel_fmax) + "_" + str(y.device)], spec
94
+ )
95
+ spec = spectral_normalize_torch(spec)
96
+
97
+ max_mel_length = math.ceil(y.shape[-1] / self.hop_length)
98
+ spec = spec[..., :max_mel_length].transpose(1, 2)
99
+
100
+ if length is None:
101
+ return spec
102
+ else:
103
+ spec_len = torch.ceil(length / self.hop_length).clamp(max=spec.shape[1])
104
+ return spec, spec_len
105
+
106
+
107
+ def length_to_mask(length, max_len=None, dtype=None, device=None):
108
+ """Creates a binary mask for each sequence.
109
+
110
+ Reference: https://discuss.pytorch.org/t/how-to-generate-variable-length-mask/23397/3
111
+
112
+ Arguments
113
+ ---------
114
+ length : torch.LongTensor
115
+ Containing the length of each sequence in the batch. Must be 1D.
116
+ max_len : int
117
+ Max length for the mask, also the size of the second dimension.
118
+ dtype : torch.dtype, default: None
119
+ The dtype of the generated mask.
120
+ device: torch.device, default: None
121
+ The device to put the mask variable.
122
+
123
+ Returns
124
+ -------
125
+ mask : tensor
126
+ The binary mask.
127
+
128
+ Example
129
+ -------
130
+ >>> length=torch.Tensor([1,2,3])
131
+ >>> mask=length_to_mask(length)
132
+ >>> mask
133
+ tensor([[1., 0., 0.],
134
+ [1., 1., 0.],
135
+ [1., 1., 1.]])
136
+ """
137
+ assert len(length.shape) == 1
138
+
139
+ if max_len is None:
140
+ max_len = length.max().long().item() # using arange to generate mask
141
+ mask = torch.arange(max_len, device=length.device, dtype=length.dtype).expand(
142
+ len(length), max_len
143
+ ) < length.unsqueeze(1)
144
+
145
+ if dtype is None:
146
+ dtype = length.dtype
147
+
148
+ if device is None:
149
+ device = length.device
150
+
151
+ mask = torch.as_tensor(mask, dtype=dtype, device=device)
152
+ return mask
153
+
154
+
155
+ def get_padding_elem(L_in: int, stride: int, kernel_size: int, dilation: int):
156
+ """This function computes the number of elements to add for zero-padding.
157
+
158
+ Arguments
159
+ ---------
160
+ L_in : int
161
+ stride: int
162
+ kernel_size : int
163
+ dilation : int
164
+ """
165
+ if stride > 1:
166
+ n_steps = math.ceil(((L_in - kernel_size * dilation) / stride) + 1)
167
+ L_out = stride * (n_steps - 1) + kernel_size * dilation
168
+ padding = [kernel_size // 2, kernel_size // 2]
169
+
170
+ else:
171
+ L_out = (L_in - dilation * (kernel_size - 1) - 1) // stride + 1
172
+
173
+ padding = [(L_in - L_out) // 2, (L_in - L_out) // 2]
174
+ return padding
175
+
176
+
177
+ class Conv1d(nn.Module):
178
+ """This function implements 1d convolution.
179
+
180
+ Arguments
181
+ ---------
182
+ out_channels : int
183
+ It is the number of output channels.
184
+ kernel_size : int
185
+ Kernel size of the convolutional filters.
186
+ input_shape : tuple
187
+ The shape of the input. Alternatively use ``in_channels``.
188
+ in_channels : int
189
+ The number of input channels. Alternatively use ``input_shape``.
190
+ stride : int
191
+ Stride factor of the convolutional filters. When the stride factor > 1,
192
+ a decimation in time is performed.
193
+ dilation : int
194
+ Dilation factor of the convolutional filters.
195
+ padding : str
196
+ (same, valid, causal). If "valid", no padding is performed.
197
+ If "same" and stride is 1, output shape is the same as the input shape.
198
+ "causal" results in causal (dilated) convolutions.
199
+ padding_mode : str
200
+ This flag specifies the type of padding. See torch.nn documentation
201
+ for more information.
202
+ skip_transpose : bool
203
+ If False, uses batch x time x channel convention of speechbrain.
204
+ If True, uses batch x channel x time convention.
205
+
206
+ Example
207
+ -------
208
+ >>> inp_tensor = torch.rand([10, 40, 16])
209
+ >>> cnn_1d = Conv1d(
210
+ ... input_shape=inp_tensor.shape, out_channels=8, kernel_size=5
211
+ ... )
212
+ >>> out_tensor = cnn_1d(inp_tensor)
213
+ >>> out_tensor.shape
214
+ torch.Size([10, 40, 8])
215
+ """
216
+
217
+ def __init__(
218
+ self,
219
+ out_channels,
220
+ kernel_size,
221
+ input_shape=None,
222
+ in_channels=None,
223
+ stride=1,
224
+ dilation=1,
225
+ padding="same",
226
+ groups=1,
227
+ bias=True,
228
+ padding_mode="reflect",
229
+ skip_transpose=True,
230
+ ):
231
+ super().__init__()
232
+ self.kernel_size = kernel_size
233
+ self.stride = stride
234
+ self.dilation = dilation
235
+ self.padding = padding
236
+ self.padding_mode = padding_mode
237
+ self.unsqueeze = False
238
+ self.skip_transpose = skip_transpose
239
+
240
+ if input_shape is None and in_channels is None:
241
+ raise ValueError("Must provide one of input_shape or in_channels")
242
+
243
+ if in_channels is None:
244
+ in_channels = self._check_input_shape(input_shape)
245
+
246
+ self.conv = nn.Conv1d(
247
+ in_channels,
248
+ out_channels,
249
+ self.kernel_size,
250
+ stride=self.stride,
251
+ dilation=self.dilation,
252
+ padding=0,
253
+ groups=groups,
254
+ bias=bias,
255
+ )
256
+
257
+ def forward(self, x):
258
+ """Returns the output of the convolution.
259
+
260
+ Arguments
261
+ ---------
262
+ x : torch.Tensor (batch, time, channel)
263
+ input to convolve. 2d or 4d tensors are expected.
264
+ """
265
+
266
+ if not self.skip_transpose:
267
+ x = x.transpose(1, -1)
268
+
269
+ if self.unsqueeze:
270
+ x = x.unsqueeze(1)
271
+
272
+ if self.padding == "same":
273
+ x = self._manage_padding(x, self.kernel_size, self.dilation, self.stride)
274
+
275
+ elif self.padding == "causal":
276
+ num_pad = (self.kernel_size - 1) * self.dilation
277
+ x = F.pad(x, (num_pad, 0))
278
+
279
+ elif self.padding == "valid":
280
+ pass
281
+
282
+ else:
283
+ raise ValueError(
284
+ "Padding must be 'same', 'valid' or 'causal'. Got " + self.padding
285
+ )
286
+
287
+ wx = self.conv(x)
288
+
289
+ if self.unsqueeze:
290
+ wx = wx.squeeze(1)
291
+
292
+ if not self.skip_transpose:
293
+ wx = wx.transpose(1, -1)
294
+
295
+ return wx
296
+
297
+ def _manage_padding(
298
+ self,
299
+ x,
300
+ kernel_size: int,
301
+ dilation: int,
302
+ stride: int,
303
+ ):
304
+ """This function performs zero-padding on the time axis
305
+ such that their lengths is unchanged after the convolution.
306
+
307
+ Arguments
308
+ ---------
309
+ x : torch.Tensor
310
+ Input tensor.
311
+ kernel_size : int
312
+ Size of kernel.
313
+ dilation : int
314
+ Dilation used.
315
+ stride : int
316
+ Stride.
317
+ """
318
+
319
+ # Detecting input shape
320
+ L_in = x.shape[-1]
321
+
322
+ # Time padding
323
+ padding = get_padding_elem(L_in, stride, kernel_size, dilation)
324
+
325
+ # Applying padding
326
+ x = F.pad(x, padding, mode=self.padding_mode)
327
+
328
+ return x
329
+
330
+ def _check_input_shape(self, shape):
331
+ """Checks the input shape and returns the number of input channels."""
332
+
333
+ if len(shape) == 2:
334
+ self.unsqueeze = True
335
+ in_channels = 1
336
+ elif self.skip_transpose:
337
+ in_channels = shape[1]
338
+ elif len(shape) == 3:
339
+ in_channels = shape[2]
340
+ else:
341
+ raise ValueError("conv1d expects 2d, 3d inputs. Got " + str(len(shape)))
342
+
343
+ # Kernel size must be odd
344
+ if self.kernel_size % 2 == 0:
345
+ raise ValueError(
346
+ "The field kernel size must be an odd number. Got %s."
347
+ % (self.kernel_size)
348
+ )
349
+ return in_channels
350
+
351
+
352
+ class Fp32BatchNorm(nn.Module):
353
+ def __init__(self, sync=True, *args, **kwargs):
354
+ super().__init__()
355
+
356
+ if (
357
+ not torch.distributed.is_initialized()
358
+ or torch.distributed.get_world_size() == 1
359
+ ):
360
+ sync = False
361
+
362
+ if sync:
363
+ self.bn = nn.SyncBatchNorm(*args, **kwargs)
364
+ else:
365
+ self.bn = nn.BatchNorm1d(*args, **kwargs)
366
+
367
+ self.sync = sync
368
+
369
+ def forward(self, input):
370
+ if self.bn.running_mean.dtype != torch.float:
371
+ if self.sync:
372
+ self.bn.running_mean = self.bn.running_mean.float()
373
+ self.bn.running_var = self.bn.running_var.float()
374
+ if self.bn.affine:
375
+ try:
376
+ self.bn.weight = self.bn.weight.float()
377
+ self.bn.bias = self.bn.bias.float()
378
+ except:
379
+ self.bn.float()
380
+ else:
381
+ self.bn.float()
382
+
383
+ output = self.bn(input.float())
384
+ return output.type_as(input)
385
+
386
+
387
+ class BatchNorm1d(nn.Module):
388
+ """Applies 1d batch normalization to the input tensor.
389
+
390
+ Arguments
391
+ ---------
392
+ input_shape : tuple
393
+ The expected shape of the input. Alternatively, use ``input_size``.
394
+ input_size : int
395
+ The expected size of the input. Alternatively, use ``input_shape``.
396
+ eps : float
397
+ This value is added to std deviation estimation to improve the numerical
398
+ stability.
399
+ momentum : float
400
+ It is a value used for the running_mean and running_var computation.
401
+ affine : bool
402
+ When set to True, the affine parameters are learned.
403
+ track_running_stats : bool
404
+ When set to True, this module tracks the running mean and variance,
405
+ and when set to False, this module does not track such statistics.
406
+ combine_batch_time : bool
407
+ When true, it combines the batch and time axes.
408
+
409
+
410
+ Example
411
+ -------
412
+ >>> input = torch.randn(100, 10)
413
+ >>> norm = BatchNorm1d(input_shape=input.shape)
414
+ >>> output = norm(input)
415
+ >>> output.shape
416
+ torch.Size([100, 10])
417
+ """
418
+
419
+ def __init__(
420
+ self,
421
+ input_shape=None,
422
+ input_size=None,
423
+ eps=1e-05,
424
+ momentum=0.1,
425
+ affine=True,
426
+ track_running_stats=True,
427
+ combine_batch_time=False,
428
+ skip_transpose=True,
429
+ enabled=True,
430
+ ):
431
+ super().__init__()
432
+ self.combine_batch_time = combine_batch_time
433
+ self.skip_transpose = skip_transpose
434
+
435
+ if input_size is None and skip_transpose:
436
+ input_size = input_shape[1]
437
+ elif input_size is None:
438
+ input_size = input_shape[-1]
439
+
440
+ if enabled:
441
+ self.norm = Fp32BatchNorm(
442
+ num_features=input_size,
443
+ eps=eps,
444
+ momentum=momentum,
445
+ affine=affine,
446
+ track_running_stats=track_running_stats,
447
+ )
448
+ else:
449
+ self.norm = nn.Identity()
450
+
451
+ def forward(self, x):
452
+ """Returns the normalized input tensor.
453
+
454
+ Arguments
455
+ ---------
456
+ x : torch.Tensor (batch, time, [channels])
457
+ input to normalize. 2d or 3d tensors are expected in input
458
+ 4d tensors can be used when combine_dims=True.
459
+ """
460
+ shape_or = x.shape
461
+ if self.combine_batch_time:
462
+ if x.ndim == 3:
463
+ x = x.reshape(shape_or[0] * shape_or[1], shape_or[2])
464
+ else:
465
+ x = x.reshape(shape_or[0] * shape_or[1], shape_or[3], shape_or[2])
466
+
467
+ elif not self.skip_transpose:
468
+ x = x.transpose(-1, 1)
469
+
470
+ x_n = self.norm(x)
471
+
472
+ if self.combine_batch_time:
473
+ x_n = x_n.reshape(shape_or)
474
+ elif not self.skip_transpose:
475
+ x_n = x_n.transpose(1, -1)
476
+
477
+ return x_n
478
+
479
+
480
+ class Linear(torch.nn.Module):
481
+ """Computes a linear transformation y = wx + b.
482
+
483
+ Arguments
484
+ ---------
485
+ n_neurons : int
486
+ It is the number of output neurons (i.e, the dimensionality of the
487
+ output).
488
+ bias : bool
489
+ If True, the additive bias b is adopted.
490
+ combine_dims : bool
491
+ If True and the input is 4D, combine 3rd and 4th dimensions of input.
492
+
493
+ Example
494
+ -------
495
+ >>> inputs = torch.rand(10, 50, 40)
496
+ >>> lin_t = Linear(input_shape=(10, 50, 40), n_neurons=100)
497
+ >>> output = lin_t(inputs)
498
+ >>> output.shape
499
+ torch.Size([10, 50, 100])
500
+ """
501
+
502
+ def __init__(
503
+ self,
504
+ n_neurons,
505
+ input_shape=None,
506
+ input_size=None,
507
+ bias=True,
508
+ combine_dims=False,
509
+ ):
510
+ super().__init__()
511
+ self.combine_dims = combine_dims
512
+
513
+ if input_shape is None and input_size is None:
514
+ raise ValueError("Expected one of input_shape or input_size")
515
+
516
+ if input_size is None:
517
+ input_size = input_shape[-1]
518
+ if len(input_shape) == 4 and self.combine_dims:
519
+ input_size = input_shape[2] * input_shape[3]
520
+
521
+ # Weights are initialized following pytorch approach
522
+ self.w = nn.Linear(input_size, n_neurons, bias=bias)
523
+
524
+ def forward(self, x):
525
+ """Returns the linear transformation of input tensor.
526
+
527
+ Arguments
528
+ ---------
529
+ x : torch.Tensor
530
+ Input to transform linearly.
531
+ """
532
+ if x.ndim == 4 and self.combine_dims:
533
+ x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3])
534
+
535
+ wx = self.w(x)
536
+
537
+ return wx
538
+
539
+
540
+ class TDNNBlock(nn.Module):
541
+ """An implementation of TDNN.
542
+
543
+ Arguments
544
+ ----------
545
+ in_channels : int
546
+ Number of input channels.
547
+ out_channels : int
548
+ The number of output channels.
549
+ kernel_size : int
550
+ The kernel size of the TDNN blocks.
551
+ dilation : int
552
+ The dilation of the Res2Net block.
553
+ activation : torch class
554
+ A class for constructing the activation layers.
555
+
556
+ Example
557
+ -------
558
+ >>> inp_tensor = torch.rand([8, 120, 64]).transpose(1, 2)
559
+ >>> layer = TDNNBlock(64, 64, kernel_size=3, dilation=1)
560
+ >>> out_tensor = layer(inp_tensor).transpose(1, 2)
561
+ >>> out_tensor.shape
562
+ torch.Size([8, 120, 64])
563
+ """
564
+
565
+ def __init__(
566
+ self,
567
+ in_channels,
568
+ out_channels,
569
+ kernel_size,
570
+ dilation,
571
+ activation=nn.ReLU,
572
+ batch_norm=True,
573
+ ):
574
+ super(TDNNBlock, self).__init__()
575
+ self.conv = Conv1d(
576
+ in_channels=in_channels,
577
+ out_channels=out_channels,
578
+ kernel_size=kernel_size,
579
+ dilation=dilation,
580
+ )
581
+ self.activation = activation()
582
+ self.norm = BatchNorm1d(input_size=out_channels, enabled=batch_norm)
583
+
584
+ def forward(self, x):
585
+ return self.norm(self.activation(self.conv(x)))
586
+
587
+
588
+ class Res2NetBlock(torch.nn.Module):
589
+ """An implementation of Res2NetBlock w/ dilation.
590
+
591
+ Arguments
592
+ ---------
593
+ in_channels : int
594
+ The number of channels expected in the input.
595
+ out_channels : int
596
+ The number of output channels.
597
+ scale : int
598
+ The scale of the Res2Net block.
599
+ kernel_size: int
600
+ The kernel size of the Res2Net block.
601
+ dilation : int
602
+ The dilation of the Res2Net block.
603
+
604
+ Example
605
+ -------
606
+ >>> inp_tensor = torch.rand([8, 120, 64]).transpose(1, 2)
607
+ >>> layer = Res2NetBlock(64, 64, scale=4, dilation=3)
608
+ >>> out_tensor = layer(inp_tensor).transpose(1, 2)
609
+ >>> out_tensor.shape
610
+ torch.Size([8, 120, 64])
611
+ """
612
+
613
+ def __init__(
614
+ self,
615
+ in_channels,
616
+ out_channels,
617
+ scale=8,
618
+ kernel_size=3,
619
+ dilation=1,
620
+ batch_norm=True,
621
+ ):
622
+ super(Res2NetBlock, self).__init__()
623
+ assert in_channels % scale == 0
624
+ assert out_channels % scale == 0
625
+
626
+ in_channel = in_channels // scale
627
+ hidden_channel = out_channels // scale
628
+
629
+ self.blocks = nn.ModuleList(
630
+ [
631
+ TDNNBlock(
632
+ in_channel,
633
+ hidden_channel,
634
+ kernel_size=kernel_size,
635
+ dilation=dilation,
636
+ batch_norm=batch_norm,
637
+ )
638
+ for i in range(scale - 1)
639
+ ]
640
+ )
641
+ self.scale = scale
642
+
643
+ def forward(self, x):
644
+ y = []
645
+ for i, x_i in enumerate(torch.chunk(x, self.scale, dim=1)):
646
+ if i == 0:
647
+ y_i = x_i
648
+ elif i == 1:
649
+ y_i = self.blocks[i - 1](x_i)
650
+ else:
651
+ y_i = self.blocks[i - 1](x_i + y_i)
652
+ y.append(y_i)
653
+ y = torch.cat(y, dim=1)
654
+ return y
655
+
656
+
657
+ class SEBlock(nn.Module):
658
+ """An implementation of squeeze-and-excitation block.
659
+
660
+ Arguments
661
+ ---------
662
+ in_channels : int
663
+ The number of input channels.
664
+ se_channels : int
665
+ The number of output channels after squeeze.
666
+ out_channels : int
667
+ The number of output channels.
668
+
669
+ Example
670
+ -------
671
+ >>> inp_tensor = torch.rand([8, 120, 64]).transpose(1, 2)
672
+ >>> se_layer = SEBlock(64, 16, 64)
673
+ >>> lengths = torch.rand((8,))
674
+ >>> out_tensor = se_layer(inp_tensor, lengths).transpose(1, 2)
675
+ >>> out_tensor.shape
676
+ torch.Size([8, 120, 64])
677
+ """
678
+
679
+ def __init__(self, in_channels, se_channels, out_channels):
680
+ super(SEBlock, self).__init__()
681
+
682
+ self.conv1 = Conv1d(
683
+ in_channels=in_channels, out_channels=se_channels, kernel_size=1
684
+ )
685
+ self.relu = torch.nn.ReLU(inplace=True)
686
+ self.conv2 = Conv1d(
687
+ in_channels=se_channels, out_channels=out_channels, kernel_size=1
688
+ )
689
+ self.sigmoid = torch.nn.Sigmoid()
690
+
691
+ def forward(self, x, lengths=None):
692
+ L = x.shape[-1]
693
+ if lengths is not None:
694
+ mask = length_to_mask(lengths * L, max_len=L, device=x.device)
695
+ mask = mask.unsqueeze(1)
696
+ total = mask.sum(dim=2, keepdim=True)
697
+ s = (x * mask).sum(dim=2, keepdim=True) / total
698
+ else:
699
+ s = x.mean(dim=2, keepdim=True)
700
+
701
+ s = self.relu(self.conv1(s))
702
+ s = self.sigmoid(self.conv2(s))
703
+
704
+ return s * x
705
+
706
+
707
+ class AttentiveStatisticsPooling(nn.Module):
708
+ """This class implements an attentive statistic pooling layer for each channel.
709
+ It returns the concatenated mean and std of the input tensor.
710
+
711
+ Arguments
712
+ ---------
713
+ channels: int
714
+ The number of input channels.
715
+ attention_channels: int
716
+ The number of attention channels.
717
+
718
+ Example
719
+ -------
720
+ >>> inp_tensor = torch.rand([8, 120, 64]).transpose(1, 2)
721
+ >>> asp_layer = AttentiveStatisticsPooling(64)
722
+ >>> lengths = torch.rand((8,))
723
+ >>> out_tensor = asp_layer(inp_tensor, lengths).transpose(1, 2)
724
+ >>> out_tensor.shape
725
+ torch.Size([8, 1, 128])
726
+ """
727
+
728
+ def __init__(
729
+ self, channels, attention_channels=128, global_context=True, batch_norm=True
730
+ ):
731
+ super().__init__()
732
+
733
+ self.eps = 1e-12
734
+ self.global_context = global_context
735
+ if global_context:
736
+ self.tdnn = TDNNBlock(
737
+ channels * 3, attention_channels, 1, 1, batch_norm=batch_norm
738
+ )
739
+ else:
740
+ self.tdnn = TDNNBlock(
741
+ channels, attention_channels, 1, 1, batch_norm=batch_norm
742
+ )
743
+ self.tanh = nn.Tanh()
744
+ self.conv = Conv1d(
745
+ in_channels=attention_channels, out_channels=channels, kernel_size=1
746
+ )
747
+
748
+ def forward(self, x, lengths=None):
749
+ """Calculates mean and std for a batch (input tensor).
750
+
751
+ Arguments
752
+ ---------
753
+ x : torch.Tensor
754
+ Tensor of shape [N, C, L].
755
+ """
756
+ L = x.shape[-1]
757
+
758
+ def _compute_statistics(x, m, dim=2, eps=self.eps):
759
+ mean = (m * x).sum(dim)
760
+ std = torch.sqrt((m * (x - mean.unsqueeze(dim)).pow(2)).sum(dim).clamp(eps))
761
+ return mean, std
762
+
763
+ if lengths is None:
764
+ lengths = torch.ones(x.shape[0], device=x.device)
765
+
766
+ # Make binary mask of shape [N, 1, L]
767
+ mask = length_to_mask(lengths * L, max_len=L, device=x.device)
768
+ mask = mask.unsqueeze(1)
769
+
770
+ # Expand the temporal context of the pooling layer by allowing the
771
+ # self-attention to look at global properties of the utterance.
772
+ if self.global_context:
773
+ # torch.std is unstable for backward computation
774
+ # https://github.com/pytorch/pytorch/issues/4320
775
+ total = mask.sum(dim=2, keepdim=True).float()
776
+ mean, std = _compute_statistics(x, mask / total)
777
+ mean = mean.unsqueeze(2).repeat(1, 1, L)
778
+ std = std.unsqueeze(2).repeat(1, 1, L)
779
+ attn = torch.cat([x, mean, std], dim=1)
780
+ else:
781
+ attn = x
782
+
783
+ # Apply layers
784
+ attn = self.conv(self.tanh(self.tdnn(attn)))
785
+
786
+ # Filter out zero-paddings
787
+ attn = attn.masked_fill(mask == 0, float("-inf"))
788
+
789
+ attn = F.softmax(attn, dim=2)
790
+ mean, std = _compute_statistics(x, attn)
791
+ # Append mean and std of the batch
792
+ pooled_stats = torch.cat((mean, std), dim=1)
793
+ pooled_stats = pooled_stats.unsqueeze(2)
794
+
795
+ return pooled_stats
796
+
797
+
798
+ class SERes2NetBlock(nn.Module):
799
+ """An implementation of building block in ECAPA-TDNN, i.e.,
800
+ TDNN-Res2Net-TDNN-SEBlock.
801
+
802
+ Arguments
803
+ ----------
804
+ out_channels: int
805
+ The number of output channels.
806
+ res2net_scale: int
807
+ The scale of the Res2Net block.
808
+ kernel_size: int
809
+ The kernel size of the TDNN blocks.
810
+ dilation: int
811
+ The dilation of the Res2Net block.
812
+ activation : torch class
813
+ A class for constructing the activation layers.
814
+
815
+ Example
816
+ -------
817
+ >>> x = torch.rand(8, 120, 64).transpose(1, 2)
818
+ >>> conv = SERes2NetBlock(64, 64, res2net_scale=4)
819
+ >>> out = conv(x).transpose(1, 2)
820
+ >>> out.shape
821
+ torch.Size([8, 120, 64])
822
+ """
823
+
824
+ def __init__(
825
+ self,
826
+ in_channels,
827
+ out_channels,
828
+ res2net_scale=8,
829
+ se_channels=128,
830
+ kernel_size=1,
831
+ dilation=1,
832
+ activation=torch.nn.ReLU,
833
+ batch_norm=True,
834
+ ):
835
+ super().__init__()
836
+ self.out_channels = out_channels
837
+ self.tdnn1 = TDNNBlock(
838
+ in_channels,
839
+ out_channels,
840
+ kernel_size=1,
841
+ dilation=1,
842
+ activation=activation,
843
+ batch_norm=batch_norm,
844
+ )
845
+ self.res2net_block = Res2NetBlock(
846
+ out_channels,
847
+ out_channels,
848
+ res2net_scale,
849
+ kernel_size,
850
+ dilation,
851
+ batch_norm=batch_norm,
852
+ )
853
+ self.tdnn2 = TDNNBlock(
854
+ out_channels,
855
+ out_channels,
856
+ kernel_size=1,
857
+ dilation=1,
858
+ activation=activation,
859
+ batch_norm=batch_norm,
860
+ )
861
+ self.se_block = SEBlock(out_channels, se_channels, out_channels)
862
+
863
+ self.shortcut = None
864
+ if in_channels != out_channels:
865
+ self.shortcut = Conv1d(
866
+ in_channels=in_channels,
867
+ out_channels=out_channels,
868
+ kernel_size=1,
869
+ )
870
+
871
+ def forward(self, x, lengths=None):
872
+ residual = x
873
+ if self.shortcut:
874
+ residual = self.shortcut(x)
875
+
876
+ x = self.tdnn1(x)
877
+ x = self.res2net_block(x)
878
+ x = self.tdnn2(x)
879
+ x = self.se_block(x, lengths)
880
+
881
+ return x + residual
882
+
883
+
884
+ class ECAPA_TDNN(torch.nn.Module):
885
+ """An implementation of the speaker embedding model in a paper.
886
+ "ECAPA-TDNN: Emphasized Channel Attention, Propagation and Aggregation in
887
+ TDNN Based Speaker Verification" (https://arxiv.org/abs/2005.07143).
888
+
889
+ Arguments
890
+ ---------
891
+ device : str
892
+ Device used, e.g., "cpu" or "cuda".
893
+ activation : torch class
894
+ A class for constructing the activation layers.
895
+ channels : list of ints
896
+ Output channels for TDNN/SERes2Net layer.
897
+ kernel_sizes : list of ints
898
+ List of kernel sizes for each layer.
899
+ dilations : list of ints
900
+ List of dilations for kernels in each layer.
901
+ lin_neurons : int
902
+ Number of neurons in linear layers.
903
+
904
+ Example
905
+ -------
906
+ >>> input_feats = torch.rand([5, 120, 80])
907
+ >>> compute_embedding = ECAPA_TDNN(80, lin_neurons=192)
908
+ >>> outputs = compute_embedding(input_feats)
909
+ >>> outputs.shape
910
+ torch.Size([5, 1, 192])
911
+ """
912
+
913
+ def __init__(
914
+ self,
915
+ input_size,
916
+ lin_neurons=192,
917
+ activation=torch.nn.ReLU,
918
+ channels=[512, 512, 512, 512, 1536],
919
+ kernel_sizes=[5, 3, 3, 3, 1],
920
+ dilations=[1, 2, 3, 4, 1],
921
+ attention_channels=128,
922
+ res2net_scale=8,
923
+ se_channels=128,
924
+ global_context=True,
925
+ batch_norm=True,
926
+ ):
927
+
928
+ super().__init__()
929
+ assert len(channels) == len(kernel_sizes)
930
+ assert len(channels) == len(dilations)
931
+ self.channels = channels
932
+ self.blocks = nn.ModuleList()
933
+
934
+ # The initial TDNN layer
935
+ self.blocks.append(
936
+ TDNNBlock(
937
+ input_size,
938
+ channels[0],
939
+ kernel_sizes[0],
940
+ dilations[0],
941
+ activation,
942
+ batch_norm=batch_norm,
943
+ )
944
+ )
945
+
946
+ # SE-Res2Net layers
947
+ for i in range(1, len(channels) - 1):
948
+ self.blocks.append(
949
+ SERes2NetBlock(
950
+ channels[i - 1],
951
+ channels[i],
952
+ res2net_scale=res2net_scale,
953
+ se_channels=se_channels,
954
+ kernel_size=kernel_sizes[i],
955
+ dilation=dilations[i],
956
+ activation=activation,
957
+ batch_norm=batch_norm,
958
+ )
959
+ )
960
+
961
+ # Multi-layer feature aggregation
962
+ self.mfa = TDNNBlock(
963
+ channels[-1],
964
+ channels[-1],
965
+ kernel_sizes[-1],
966
+ dilations[-1],
967
+ activation,
968
+ batch_norm=batch_norm,
969
+ )
970
+
971
+ # Attentive Statistical Pooling
972
+ self.asp = AttentiveStatisticsPooling(
973
+ channels[-1],
974
+ attention_channels=attention_channels,
975
+ global_context=global_context,
976
+ batch_norm=batch_norm,
977
+ )
978
+ self.asp_bn = BatchNorm1d(input_size=channels[-1] * 2, enabled=batch_norm)
979
+
980
+ # Final linear transformation
981
+ self.fc = Conv1d(
982
+ in_channels=channels[-1] * 2,
983
+ out_channels=lin_neurons,
984
+ kernel_size=1,
985
+ )
986
+
987
+ # @torch.cuda.amp.autocast(enabled=True, dtype=torch.float32)
988
+ def forward(self, x, lengths=None):
989
+ """Returns the embedding vector.
990
+
991
+ Arguments
992
+ ---------
993
+ x : torch.Tensor
994
+ Tensor of shape (batch, time, channel).
995
+ """
996
+ # Minimize transpose for efficiency
997
+ x = x.transpose(1, 2)
998
+
999
+ xl = []
1000
+ for layer in self.blocks:
1001
+ try:
1002
+ x = layer(x, lengths=lengths)
1003
+ except TypeError:
1004
+ x = layer(x)
1005
+ xl.append(x)
1006
+
1007
+ # Multi-layer feature aggregation
1008
+ x = torch.cat(xl[1:], dim=1)
1009
+ x = self.mfa(x)
1010
+
1011
+ # Attentive Statistical Pooling
1012
+ x = self.asp(x, lengths=lengths)
1013
+ x = self.asp_bn(x)
1014
+
1015
+ # Final linear transformation
1016
+ x = self.fc(x)
1017
+
1018
+ x = x.squeeze(-1)
1019
+ return x
1020
+
1021
+
1022
+ class SpeakerEmbedddingExtractor(object):
1023
+
1024
+ def __init__(self, ckpt_path, device="cuda"):
1025
+ # NOTE: The sampling rate is 16000
1026
+ self.mel_extractor = TorchMelSpectrogram()
1027
+ self.mel_extractor.to(device)
1028
+ model = ECAPA_TDNN(
1029
+ 80,
1030
+ 512,
1031
+ channels=[512, 512, 512, 512, 1536],
1032
+ kernel_sizes=[5, 3, 3, 3, 1],
1033
+ dilations=[1, 2, 3, 4, 1],
1034
+ attention_channels=128,
1035
+ res2net_scale=4,
1036
+ se_channels=128,
1037
+ global_context=True,
1038
+ batch_norm=True,
1039
+ )
1040
+ model.load_state_dict(torch.load(ckpt_path), strict=True)
1041
+ model.eval()
1042
+ self.model = model
1043
+ self.model.to(device)
1044
+
1045
+ def __call__(self, wav):
1046
+ # wav, sr = torchaudio.load(audio_path)
1047
+ # assert sr == 16000, f"The sampling rate is not 16000"
1048
+ # print(wav.shape)
1049
+ mel = self.mel_extractor(wav.unsqueeze(0))
1050
+ spk = self.model(mel)
1051
+ spk = spk[0]
1052
+ return spk
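A minimal usage sketch for the speaker embedding extractor above; the checkpoint and audio paths are placeholders, and the 512-dimensional output follows from the hard-coded ECAPA_TDNN settings in __init__.

import torchaudio
from fireredtts.modules.codec.speaker import SpeakerEmbedddingExtractor

# "speaker_ckpt.pt" is a placeholder; the real file name depends on the
# pretrained_models release.
extractor = SpeakerEmbedddingExtractor(
    ckpt_path="pretrained_models/speaker_ckpt.pt", device="cpu"
)

wav, sr = torchaudio.load("prompt.wav")  # the extractor expects 16 kHz audio
assert sr == 16000
spk = extractor(wav[0])                  # 1-D waveform in, speaker embedding out
print(spk.shape)                         # torch.Size([512]) with the settings above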
fireredtts/modules/flow/__init__.py ADDED
@@ -0,0 +1,24 @@
1
+ from fireredtts.modules.flow.codec_embedding import HHGCodecEmbedding
2
+ from fireredtts.modules.flow.conformer import ConformerDecoderV2
3
+ from fireredtts.modules.flow.mel_encoder import MelReduceEncoder
4
+ from fireredtts.modules.flow.decoder import ConditionalCFM, ConditionalDecoder
5
+ from fireredtts.modules.flow.flow_model import InterpolateRegulator, CrossAttnFlowMatching
6
+ from fireredtts.modules.flow.mel_spectrogram import MelSpectrogramExtractor
7
+
8
+
9
+ def get_flow_frontend(flow_config):
10
+ flow = CrossAttnFlowMatching(
11
+ output_size=flow_config["output_size"],
12
+ input_embedding=HHGCodecEmbedding(**flow_config["input_embedding"]),
13
+ encoder=ConformerDecoderV2(**flow_config["encoder"]),
14
+ length_regulator=InterpolateRegulator(**flow_config["length_regulator"]),
15
+ mel_encoder=MelReduceEncoder(**flow_config["mel_encoder"]),
16
+ decoder=ConditionalCFM(
17
+ estimator=ConditionalDecoder(**flow_config["decoder"]["estimator"]),
18
+ t_scheduler=flow_config["decoder"]["t_scheduler"],
19
+ inference_cfg_rate=flow_config["decoder"]["inference_cfg_rate"]
20
+ )
21
+ )
22
+ return flow
23
+
24
+
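For orientation, get_flow_frontend above reads a nested dict with the keys referenced in the code. The skeleton below sketches that layout only; the scalar values are assumptions, and each empty sub-dict stands in for keyword arguments that are forwarded verbatim to the corresponding constructor.

# Skeleton of the flow_config consumed by get_flow_frontend (illustrative only).
flow_config = {
    "output_size": 80,               # target mel dimension (assumed)
    "input_embedding": {},           # kwargs for HHGCodecEmbedding
    "encoder": {},                   # kwargs for ConformerDecoderV2
    "length_regulator": {},          # kwargs for InterpolateRegulator
    "mel_encoder": {},               # kwargs for MelReduceEncoder
    "decoder": {
        "estimator": {},             # kwargs for ConditionalDecoder
        "t_scheduler": "cosine",     # assumed scheduler name
        "inference_cfg_rate": 0.7,   # assumed classifier-free guidance rate
    },
}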
fireredtts/modules/flow/codebook.npy ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f32d6280c561e4d0bf741e3ea07d651aee71da84c17fc55b686cb825bd677656
3
+ size 131200
fireredtts/modules/flow/codec_embedding.py ADDED
@@ -0,0 +1,30 @@
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+
5
+
6
+ class HHGCodecEmbedding(nn.Module):
7
+ def __init__(self, out_channels, codebook_path:str, freeze=True):
8
+ super().__init__()
9
+ # (2, 128, 128)
10
+ codebook = torch.from_numpy(np.load(codebook_path).copy())
11
+ assert codebook.shape[0] == 2 and codebook.shape[1] == 128
12
+ self.codebook_dim = codebook.shape[2]
13
+
14
+ self.codebook = torch.nn.ModuleList([
15
+ torch.nn.Embedding.from_pretrained(codebook[i], freeze=freeze)
16
+ for i in range(codebook.shape[0])]
17
+ )
18
+ if self.codebook_dim * 2 != out_channels:
19
+ self.proj = nn.Linear(self.codebook_dim * 2, out_channels)
20
+ else:
21
+ self.proj = nn.Identity()
22
+
23
+ def forward(self, tokens):
24
+ token_embs = torch.cat([
25
+ self.codebook[0](tokens % 128),
26
+ self.codebook[1](tokens // 128)
27
+ ], dim=-1)
28
+ token_embs = self.proj(token_embs)
29
+ return token_embs
30
+
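A short sketch of how HHGCodecEmbedding is used: each acoustic token id in [0, 16384) is split into two 128-way indices (id % 128 and id // 128), looked up in the two frozen codebooks, concatenated, and projected to out_channels. The out_channels value below is an assumption; the codebook path points at the codebook.npy file added in this commit.

import torch
from fireredtts.modules.flow.codec_embedding import HHGCodecEmbedding

emb = HHGCodecEmbedding(
    out_channels=512,  # assumed encoder input size
    codebook_path="fireredtts/modules/flow/codebook.npy",
)
tokens = torch.randint(0, 128 * 128, (1, 50))  # (batch, num_tokens)
out = emb(tokens)
print(out.shape)                               # torch.Size([1, 50, 512])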
fireredtts/modules/flow/conformer.py ADDED
@@ -0,0 +1,730 @@
1
+ import typing as tp
2
+ import math
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+ from fireredtts.modules.flow.utils import make_pad_mask
8
+
9
+
10
+ class MultiHeadedAttention(nn.Module):
11
+ """Multi-Head Attention layer.
12
+
13
+ Args:
14
+ n_head (int): The number of heads.
15
+ n_feat (int): The number of features.
16
+ dropout_rate (float): Dropout rate.
17
+
18
+ """
19
+
20
+ def __init__(self,
21
+ n_head: int,
22
+ n_feat: int,
23
+ dropout_rate: float,
24
+ key_bias: bool = True):
25
+ """Construct an MultiHeadedAttention object."""
26
+ super().__init__()
27
+ assert n_feat % n_head == 0
28
+ # We assume d_v always equals d_k
29
+ self.d_k = n_feat // n_head
30
+ self.h = n_head
31
+ self.linear_q = nn.Linear(n_feat, n_feat)
32
+ self.linear_k = nn.Linear(n_feat, n_feat, bias=key_bias)
33
+ self.linear_v = nn.Linear(n_feat, n_feat)
34
+ self.linear_out = nn.Linear(n_feat, n_feat)
35
+ self.dropout = nn.Dropout(p=dropout_rate)
36
+
37
+ def forward_qkv(
38
+ self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
39
+ ) -> tp.Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
40
+ """Transform query, key and value.
41
+
42
+ Args:
43
+ query (torch.Tensor): Query tensor (#batch, time1, size).
44
+ key (torch.Tensor): Key tensor (#batch, time2, size).
45
+ value (torch.Tensor): Value tensor (#batch, time2, size).
46
+
47
+ Returns:
48
+ torch.Tensor: Transformed query tensor, size
49
+ (#batch, n_head, time1, d_k).
50
+ torch.Tensor: Transformed key tensor, size
51
+ (#batch, n_head, time2, d_k).
52
+ torch.Tensor: Transformed value tensor, size
53
+ (#batch, n_head, time2, d_k).
54
+
55
+ """
56
+ n_batch = query.size(0)
57
+ q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
58
+ k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
59
+ v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
60
+ q = q.transpose(1, 2) # (batch, head, time1, d_k)
61
+ k = k.transpose(1, 2) # (batch, head, time2, d_k)
62
+ v = v.transpose(1, 2) # (batch, head, time2, d_k)
63
+
64
+ return q, k, v
65
+
66
+ def forward_attention(
67
+ self,
68
+ value: torch.Tensor,
69
+ scores: torch.Tensor,
70
+ mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool)
71
+ ) -> torch.Tensor:
72
+ """Compute attention context vector.
73
+
74
+ Args:
75
+ value (torch.Tensor): Transformed value, size
76
+ (#batch, n_head, time2, d_k).
77
+ scores (torch.Tensor): Attention score, size
78
+ (#batch, n_head, time1, time2).
79
+ mask (torch.Tensor): Mask, size (#batch, 1, time2) or
80
+ (#batch, time1, time2), (0, 0, 0) means fake mask.
81
+
82
+ Returns:
83
+ torch.Tensor: Transformed value (#batch, time1, d_model)
84
+ weighted by the attention score (#batch, time1, time2).
85
+
86
+ """
87
+ n_batch = value.size(0)
88
+ # NOTE(xcsong): When will `if mask.size(2) > 0` be True?
89
+ # 1. onnx(16/4) [WHY? Because we feed real cache & real mask for the
90
+ # 1st chunk to ease the onnx export.]
91
+ # 2. pytorch training
92
+ if mask.size(2) > 0: # time2 > 0
93
+ mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
94
+ # For last chunk, time2 might be larger than scores.size(-1)
95
+ mask = mask[:, :, :, :scores.size(-1)] # (batch, 1, *, time2)
96
+ scores = scores.masked_fill(mask, -float('inf'))
97
+ attn = torch.softmax(scores, dim=-1).masked_fill(
98
+ mask, 0.0) # (batch, head, time1, time2)
99
+ # NOTE(xcsong): When will `if mask.size(2) > 0` be False?
100
+ # 1. onnx(16/-1, -1/-1, 16/0)
101
+ # 2. jit (16/-1, -1/-1, 16/0, 16/4)
102
+ else:
103
+ attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
104
+
105
+ p_attn = self.dropout(attn)
106
+ x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
107
+ x = (x.transpose(1, 2).contiguous().view(n_batch, -1,
108
+ self.h * self.d_k)
109
+ ) # (batch, time1, d_model)
110
+
111
+ return self.linear_out(x) # (batch, time1, d_model)
112
+
113
+ def forward(
114
+ self,
115
+ query: torch.Tensor,
116
+ key: torch.Tensor,
117
+ value: torch.Tensor,
118
+ mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
119
+ pos_emb: torch.Tensor = torch.empty(0),
120
+ cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
121
+ ) -> tp.Tuple[torch.Tensor, torch.Tensor]:
122
+ """Compute scaled dot product attention.
123
+
124
+ Args:
125
+ query (torch.Tensor): Query tensor (#batch, time1, size).
126
+ key (torch.Tensor): Key tensor (#batch, time2, size).
127
+ value (torch.Tensor): Value tensor (#batch, time2, size).
128
+ mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
129
+ (#batch, time1, time2).
130
+ 1.When applying cross attention between decoder and encoder,
131
+ the batch padding mask for input is in (#batch, 1, T) shape.
132
+ 2.When applying self attention of encoder,
133
+ the mask is in (#batch, T, T) shape.
134
+ cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
135
+ where `cache_t == chunk_size * num_decoding_left_chunks`
136
+ and `head * d_k == size`
137
+
138
+
139
+ Returns:
140
+ torch.Tensor: Output tensor (#batch, time1, d_model).
141
+ torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
142
+ where `cache_t == chunk_size * num_decoding_left_chunks`
143
+ and `head * d_k == size`
144
+
145
+ """
146
+ q, k, v = self.forward_qkv(query, key, value)
147
+
148
+ # NOTE(xcsong):
149
+ # when export onnx model, for 1st chunk, we feed
150
+ # cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
151
+ # or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
152
+ # In all modes, `if cache.size(0) > 0` will always be `True`
153
+ # and we will always do splitting and
154
+ # concatenation (this will simplify onnx export). Note that
155
+ # it's OK to concat & split zero-shaped tensors(see code below).
156
+ # when export jit model, for 1st chunk, we always feed
157
+ # cache(0, 0, 0, 0) since jit supports dynamic if-branch.
158
+ # >>> a = torch.ones((1, 2, 0, 4))
159
+ # >>> b = torch.ones((1, 2, 3, 4))
160
+ # >>> c = torch.cat((a, b), dim=2)
161
+ # >>> torch.equal(b, c) # True
162
+ # >>> d = torch.split(a, 2, dim=-1)
163
+ # >>> torch.equal(d[0], d[1]) # True
164
+ if cache.size(0) > 0:
165
+ key_cache, value_cache = torch.split(cache,
166
+ cache.size(-1) // 2,
167
+ dim=-1)
168
+ k = torch.cat([key_cache, k], dim=2)
169
+ v = torch.cat([value_cache, v], dim=2)
170
+ # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
171
+ # non-trivial to calculate `next_cache_start` here.
172
+ new_cache = torch.cat((k, v), dim=-1)
173
+
174
+ scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
175
+ return self.forward_attention(v, scores, mask), new_cache
176
+
177
+
178
+ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
179
+ """Multi-Head Attention layer with relative position encoding.
180
+ Paper: https://arxiv.org/abs/1901.02860
181
+ Args:
182
+ n_head (int): The number of heads.
183
+ n_feat (int): The number of features.
184
+ dropout_rate (float): Dropout rate.
185
+ """
186
+
187
+ def __init__(self,
188
+ n_head: int,
189
+ n_feat: int,
190
+ dropout_rate: float,
191
+ key_bias: bool = True):
192
+ """Construct an RelPositionMultiHeadedAttention object."""
193
+ super().__init__(n_head, n_feat, dropout_rate, key_bias)
194
+ # linear transformation for positional encoding
195
+ self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
196
+ # these two learnable bias are used in matrix c and matrix d
197
+ # as described in https://arxiv.org/abs/1901.02860 Section 3.3
198
+ self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
199
+ self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
200
+ torch.nn.init.xavier_uniform_(self.pos_bias_u)
201
+ torch.nn.init.xavier_uniform_(self.pos_bias_v)
202
+
203
+ def rel_shift(self, x):
204
+ """Compute relative positional encoding.
205
+
206
+ Args:
207
+ x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1).
208
+ time1 means the length of query vector.
209
+
210
+ Returns:
211
+ torch.Tensor: Output tensor.
212
+
213
+ """
214
+ zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype)
215
+ x_padded = torch.cat([zero_pad, x], dim=-1)
216
+
217
+ x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
218
+ x = x_padded[:, :, 1:].view_as(x)[
219
+ :, :, :, : x.size(-1) // 2 + 1
220
+ ] # only keep the positions from 0 to time2
221
+ return x
222
+
223
+ def forward(
224
+ self,
225
+ query: torch.Tensor,
226
+ key: torch.Tensor,
227
+ value: torch.Tensor,
228
+ mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
229
+ pos_emb: torch.Tensor = torch.empty(0),
230
+ cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
231
+ ) -> tp.Tuple[torch.Tensor, torch.Tensor]:
232
+ """Compute 'Scaled Dot Product Attention' with rel. positional encoding.
233
+ Args:
234
+ query (torch.Tensor): Query tensor (#batch, time1, size).
235
+ key (torch.Tensor): Key tensor (#batch, time2, size).
236
+ value (torch.Tensor): Value tensor (#batch, time2, size).
237
+ mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
238
+ (#batch, time1, time2), (0, 0, 0) means fake mask.
239
+ pos_emb (torch.Tensor): Positional embedding tensor
240
+ (#batch, time2, size).
241
+ cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
242
+ where `cache_t == chunk_size * num_decoding_left_chunks`
243
+ and `head * d_k == size`
244
+ Returns:
245
+ torch.Tensor: Output tensor (#batch, time1, d_model).
246
+ torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
247
+ where `cache_t == chunk_size * num_decoding_left_chunks`
248
+ and `head * d_k == size`
249
+ """
250
+ q, k, v = self.forward_qkv(query, key, value)
251
+ q = q.transpose(1, 2) # (batch, time1, head, d_k)
252
+
253
+ # NOTE(xcsong):
254
+ # when export onnx model, for 1st chunk, we feed
255
+ # cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
256
+ # or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
257
+ # In all modes, `if cache.size(0) > 0` will always be `True`
258
+ # and we will always do splitting and
259
+ # concatenation (this will simplify onnx export). Note that
260
+ # it's OK to concat & split zero-shaped tensors(see code below).
261
+ # when export jit model, for 1st chunk, we always feed
262
+ # cache(0, 0, 0, 0) since jit supports dynamic if-branch.
263
+ # >>> a = torch.ones((1, 2, 0, 4))
264
+ # >>> b = torch.ones((1, 2, 3, 4))
265
+ # >>> c = torch.cat((a, b), dim=2)
266
+ # >>> torch.equal(b, c) # True
267
+ # >>> d = torch.split(a, 2, dim=-1)
268
+ # >>> torch.equal(d[0], d[1]) # True
269
+ if cache.size(0) > 0:
270
+ key_cache, value_cache = torch.split(cache,
271
+ cache.size(-1) // 2,
272
+ dim=-1)
273
+ k = torch.cat([key_cache, k], dim=2)
274
+ v = torch.cat([value_cache, v], dim=2)
275
+ # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
276
+ # non-trivial to calculate `next_cache_start` here.
277
+ new_cache = torch.cat((k, v), dim=-1)
278
+
279
+ n_batch_pos = pos_emb.size(0)
280
+ p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
281
+ p = p.transpose(1, 2) # (batch, head, time1, d_k)
282
+
283
+ # (batch, head, time1, d_k)
284
+ q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
285
+ # (batch, head, time1, d_k)
286
+ q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
287
+
288
+ # compute attention score
289
+ # first compute matrix a and matrix c
290
+ # as described in https://arxiv.org/abs/1901.02860 Section 3.3
291
+ # (batch, head, time1, time2)
292
+ matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
293
+
294
+ # compute matrix b and matrix d
295
+ # (batch, head, time1, time2)
296
+ matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
297
+ # NOTE(Xiang Lyu): Keep rel_shift since espnet rel_pos_emb is used
298
+ if matrix_ac.shape != matrix_bd.shape:
299
+ matrix_bd = self.rel_shift(matrix_bd)
300
+
301
+ scores = (matrix_ac + matrix_bd) / math.sqrt(
302
+ self.d_k) # (batch, head, time1, time2)
303
+
304
+ return self.forward_attention(v, scores, mask), new_cache
305
+
306
+
307
+ class PositionwiseFeedForward(torch.nn.Module):
308
+ """Positionwise feed forward layer.
309
+
310
+ FeedForward is applied on each position of the sequence.
311
+ The output dim is the same as the input dim.
312
+
313
+ Args:
314
+ idim (int): Input dimension.
315
+ hidden_units (int): The number of hidden units.
316
+ dropout_rate (float): Dropout rate.
317
+ activation (torch.nn.Module): Activation function
318
+ """
319
+
320
+ def __init__(
321
+ self,
322
+ idim: int,
323
+ hidden_units: int,
324
+ dropout_rate: float,
325
+ activation: torch.nn.Module = torch.nn.ReLU(),
326
+ ):
327
+ """Construct a PositionwiseFeedForward object."""
328
+ super(PositionwiseFeedForward, self).__init__()
329
+ self.w_1 = torch.nn.Linear(idim, hidden_units)
330
+ self.activation = activation
331
+ self.dropout = torch.nn.Dropout(dropout_rate)
332
+ self.w_2 = torch.nn.Linear(hidden_units, idim)
333
+
334
+ def forward(self, xs: torch.Tensor) -> torch.Tensor:
335
+ """Forward function.
336
+
337
+ Args:
338
+ xs: input tensor (B, L, D)
339
+ Returns:
340
+ output tensor, (B, L, D)
341
+ """
342
+ return self.w_2(self.dropout(self.activation(self.w_1(xs))))
343
+
344
+
345
+ class ConformerDecoderLayer(nn.Module):
346
+ """Conformer decoder layer module.
347
+ Args:
348
+ size (int): Input dimension.
349
+ self_attn (torch.nn.Module): Self-attention module instance.
350
+ `MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
351
+ instance can be used as the argument.
352
+ src_attn (torch.nn.Module): Cross-attention module instance.
353
+ `MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
354
+ instance can be used as the argument.
355
+ feed_forward (torch.nn.Module): Feed-forward module instance.
356
+ `PositionwiseFeedForward` instance can be used as the argument.
357
+ feed_forward_macaron (torch.nn.Module): Additional feed-forward module
358
+ instance.
359
+ `PositionwiseFeedForward` instance can be used as the argument.
360
+ conv_module (torch.nn.Module): Convolution module instance.
361
+ `ConvolutionModule` instance can be used as the argument.
362
+ dropout_rate (float): Dropout rate.
363
+ normalize_before (bool):
364
+ True: use layer_norm before each sub-block.
365
+ False: use layer_norm after each sub-block.
366
+ """
367
+
368
+ def __init__(
369
+ self,
370
+ size: int,
371
+ self_attn: torch.nn.Module,
372
+ src_attn: tp.Optional[torch.nn.Module] = None,
373
+ feed_forward: tp.Optional[nn.Module] = None,
374
+ feed_forward_macaron: tp.Optional[nn.Module] = None,
375
+ conv_module: tp.Optional[nn.Module] = None,
376
+ dropout_rate: float = 0.1,
377
+ normalize_before: bool = True,
378
+ ):
379
+ """Construct an EncoderLayer object."""
380
+ super().__init__()
381
+ self.self_attn = self_attn
382
+ self.src_attn = src_attn
383
+ self.feed_forward = feed_forward
384
+ self.feed_forward_macaron = feed_forward_macaron
385
+ self.conv_module = conv_module
386
+ self.norm_ff = nn.LayerNorm(size, eps=1e-5) # for the FNN module
387
+ self.norm_mha = nn.LayerNorm(size, eps=1e-5) # for the MHA module
388
+ if src_attn is not None:
389
+ self.norm_mha2 = nn.LayerNorm(size, eps=1e-5) # for the MHA module(src_attn)
390
+ if feed_forward_macaron is not None:
391
+ self.norm_ff_macaron = nn.LayerNorm(size, eps=1e-5)
392
+ self.ff_scale = 0.5
393
+ else:
394
+ self.ff_scale = 1.0
395
+ if self.conv_module is not None:
396
+ self.norm_conv = nn.LayerNorm(size, eps=1e-5) # for the CNN module
397
+ self.norm_final = nn.LayerNorm(
398
+ size, eps=1e-5) # for the final output of the block
399
+ self.dropout = nn.Dropout(dropout_rate)
400
+ self.size = size
401
+ self.normalize_before = normalize_before
402
+
403
+ def forward(
404
+ self,
405
+ x: torch.Tensor,
406
+ mask: torch.Tensor,
407
+ # src-attention
408
+ memory: torch.Tensor,
409
+ memory_mask: torch.Tensor,
410
+ pos_emb: torch.Tensor,
411
+ mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
412
+ att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
413
+ cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
414
+ ) -> tp.Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
415
+ """Compute encoded features.
416
+
417
+ Args:
418
+ x (torch.Tensor): (#batch, time, size)
419
+ mask (torch.Tensor): Mask tensor for the input (#batch, time, time),
420
+ (0, 0, 0) means fake mask.
421
+ pos_emb (torch.Tensor): positional encoding, must not be None
422
+ for ConformerEncoderLayer.
423
+ mask_pad (torch.Tensor): batch padding mask used for conv module.
424
+ (#batch, 1, time), (0, 0, 0) means fake mask.
425
+ att_cache (torch.Tensor): Cache tensor of the KEY & VALUE
426
+ (#batch=1, head, cache_t1, d_k * 2), head * d_k == size.
427
+ cnn_cache (torch.Tensor): Convolution cache in conformer layer
428
+ (#batch=1, size, cache_t2)
429
+ Returns:
430
+ torch.Tensor: Output tensor (#batch, time, size).
431
+ torch.Tensor: Mask tensor (#batch, time, time).
432
+ torch.Tensor: att_cache tensor,
433
+ (#batch=1, head, cache_t1 + time, d_k * 2).
434
+ torch.Tensor: cnn_cache tensor (#batch, size, cache_t2).
435
+ """
436
+
437
+ # whether to use macaron style
438
+ if self.feed_forward_macaron is not None:
439
+ residual = x
440
+ if self.normalize_before:
441
+ x = self.norm_ff_macaron(x)
442
+ x = residual + self.ff_scale * self.dropout(
443
+ self.feed_forward_macaron(x))
444
+ if not self.normalize_before:
445
+ x = self.norm_ff_macaron(x)
446
+
447
+ # multi-headed self-attention module
448
+ residual = x
449
+ if self.normalize_before:
450
+ x = self.norm_mha(x)
451
+ x_att, new_att_cache = self.self_attn(x, x, x, mask, pos_emb,
452
+ att_cache)
453
+ x = residual + self.dropout(x_att)
454
+ if not self.normalize_before:
455
+ x = self.norm_mha(x)
456
+
457
+ # multi-headed cross-attention module
458
+ if self.src_attn is not None:
459
+ residual = x
460
+ if self.normalize_before:
461
+ x = self.norm_mha2(x)
462
+ x_att, _ = self.src_attn(x, memory, memory, memory_mask)
463
+ x = residual + self.dropout(x_att)
464
+ if not self.normalize_before:
465
+ x = self.norm_mha2(x)
466
+
467
+ # convolution module
468
+ # Fake new cnn cache here, and then change it in conv_module
469
+ new_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
470
+ if self.conv_module is not None:
471
+ residual = x
472
+ if self.normalize_before:
473
+ x = self.norm_conv(x)
474
+ x, new_cnn_cache = self.conv_module(x, mask_pad, cnn_cache)
475
+ x = residual + self.dropout(x)
476
+
477
+ if not self.normalize_before:
478
+ x = self.norm_conv(x)
479
+
480
+ # feed forward module
481
+ residual = x
482
+ if self.normalize_before:
483
+ x = self.norm_ff(x)
484
+
485
+ x = residual + self.ff_scale * self.dropout(self.feed_forward(x))
486
+ if not self.normalize_before:
487
+ x = self.norm_ff(x)
488
+
489
+ if self.conv_module is not None:
490
+ x = self.norm_final(x)
491
+
492
+ return x, mask, new_att_cache, new_cnn_cache
493
+
494
+
495
+ class EspnetRelPositionalEncoding(torch.nn.Module):
496
+ """Relative positional encoding module (new implementation).
497
+
498
+ Details can be found in https://github.com/espnet/espnet/pull/2816.
499
+
500
+ See : Appendix B in https://arxiv.org/abs/1901.02860
501
+
502
+ Args:
503
+ d_model (int): Embedding dimension.
504
+ dropout_rate (float): Dropout rate.
505
+ max_len (int): Maximum input length.
506
+
507
+ """
508
+
509
+ def __init__(self, d_model, dropout_rate, max_len=5000):
510
+ """Construct a PositionalEncoding object."""
511
+ super(EspnetRelPositionalEncoding, self).__init__()
512
+ self.d_model = d_model
513
+ self.xscale = math.sqrt(self.d_model)
514
+ self.dropout = torch.nn.Dropout(p=dropout_rate)
515
+ self.pe = None
516
+ self.extend_pe(torch.tensor(0.0).expand(1, max_len))
517
+
518
+ def extend_pe(self, x):
519
+ """Reset the positional encodings."""
520
+ if self.pe is not None:
521
+ # self.pe contains both positive and negative parts
522
+ # the length of self.pe is 2 * input_len - 1
523
+ if self.pe.size(1) >= x.size(1) * 2 - 1:
524
+ if self.pe.dtype != x.dtype or self.pe.device != x.device:
525
+ self.pe = self.pe.to(dtype=x.dtype, device=x.device)
526
+ return
527
+ # Suppose `i` is the position of the query vector and `j` is the
528
+ # position of the key vector. We use positive relative positions when keys
529
+ # are to the left (i>j) and negative relative positions otherwise (i<j).
530
+ pe_positive = torch.zeros(x.size(1), self.d_model)
531
+ pe_negative = torch.zeros(x.size(1), self.d_model)
532
+ position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
533
+ div_term = torch.exp(
534
+ torch.arange(0, self.d_model, 2, dtype=torch.float32)
535
+ * -(math.log(10000.0) / self.d_model)
536
+ )
537
+ pe_positive[:, 0::2] = torch.sin(position * div_term)
538
+ pe_positive[:, 1::2] = torch.cos(position * div_term)
539
+ pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
540
+ pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
541
+
542
+ # Reverse the order of positive indices and concat both positive and
543
+ # negative indices. This is used to support the shifting trick
544
+ # as in https://arxiv.org/abs/1901.02860
545
+ pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
546
+ pe_negative = pe_negative[1:].unsqueeze(0)
547
+ pe = torch.cat([pe_positive, pe_negative], dim=1)
548
+ self.pe = pe.to(device=x.device, dtype=x.dtype)
549
+
550
+ def forward(self, x: torch.Tensor, offset: tp.Union[int, torch.Tensor] = 0):
551
+ """Add positional encoding.
552
+
553
+ Args:
554
+ x (torch.Tensor): Input tensor (batch, time, `*`).
555
+
556
+ Returns:
557
+ torch.Tensor: Encoded tensor (batch, time, `*`).
558
+
559
+ """
560
+ self.extend_pe(x)
561
+ x = x * self.xscale
562
+ pos_emb = self.position_encoding(size=x.size(1), offset=offset)
563
+ return self.dropout(x), self.dropout(pos_emb)
564
+
565
+ def position_encoding(self,
566
+ offset: tp.Union[int, torch.Tensor],
567
+ size: int) -> torch.Tensor:
568
+ """ For getting encoding in a streaming fashion
569
+
570
+ Attention:
571
+ dropout is applied only once at the whole-utterance level in the
572
+ non-streaming case, but this function is called several times with
573
+ increasing input size in a streaming scenario, so dropout will
574
+ be applied several times.
575
+
576
+ Args:
577
+ offset (int or torch.tensor): start offset
578
+ size (int): required size of position encoding
579
+
580
+ Returns:
581
+ torch.Tensor: Corresponding encoding
582
+ """
583
+ pos_emb = self.pe[
584
+ :,
585
+ self.pe.size(1) // 2 - size + 1 : self.pe.size(1) // 2 + size,
586
+ ]
587
+ return pos_emb
588
+
589
+
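A quick shape check of the relative positional encoding above (an illustrative sketch; the import path assumes this file ships as fireredtts/modules/flow/conformer.py):

import torch
from fireredtts.modules.flow.conformer import EspnetRelPositionalEncoding  # assumed path

pos_enc = EspnetRelPositionalEncoding(d_model=512, dropout_rate=0.0, max_len=5000)
x = torch.randn(1, 100, 512)
x_scaled, pos_emb = pos_enc(x)
# self.pe stores both positive and negative positions: 2 * max_len - 1 entries
print(pos_enc.pe.shape)  # torch.Size([1, 9999, 512])
print(pos_emb.shape)     # torch.Size([1, 199, 512]), i.e. 2 * 100 - 1 relative positions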
590
+ class LinearNoSubsampling(torch.nn.Module):
591
+ """Linear transform the input without subsampling
592
+
593
+ Args:
594
+ idim (int): Input dimension.
595
+ odim (int): Output dimension.
596
+ dropout_rate (float): Dropout rate.
597
+
598
+ """
599
+
600
+ def __init__(self, idim: int, odim: int, dropout_rate: float,
601
+ pos_enc_class: torch.nn.Module):
602
+ """Construct a linear object."""
603
+ super().__init__()
604
+ self.out = torch.nn.Sequential(
605
+ torch.nn.Linear(idim, odim),
606
+ torch.nn.LayerNorm(odim, eps=1e-5),
607
+ torch.nn.Dropout(dropout_rate),
608
+ )
609
+ self.pos_enc = pos_enc_class
610
+ self.right_context = 0
611
+ self.subsampling_rate = 1
612
+
613
+ def forward(
614
+ self,
615
+ x: torch.Tensor,
616
+ x_mask: torch.Tensor,
617
+ offset: tp.Union[int, torch.Tensor] = 0
618
+ ) -> tp.Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
619
+ """Input x.
620
+
621
+ Args:
622
+ x (torch.Tensor): Input tensor (#batch, time, idim).
623
+ x_mask (torch.Tensor): Input mask (#batch, 1, time).
624
+
625
+ Returns:
626
+ torch.Tensor: linear input tensor (#batch, time', odim),
627
+ where time' = time .
628
+ torch.Tensor: linear input mask (#batch, 1, time'),
629
+ where time' = time .
630
+
631
+ """
632
+ x = self.out(x)
633
+ x, pos_emb = self.pos_enc(x, offset)
634
+ return x, pos_emb, x_mask
635
+
636
+
637
+ class ConformerDecoderV2(nn.Module):
638
+ def __init__(self,
639
+ input_size: int = 512,
640
+ output_size: int = 512,
641
+ attention_heads: int = 8,
642
+ linear_units: int = 2048,
643
+ num_blocks: int = 6,
644
+ dropout_rate: float = 0.01,
645
+ srcattention_start_index: int = 0,
646
+ srcattention_end_index: int = 2,
647
+ attention_dropout_rate: float = 0.01,
648
+ positional_dropout_rate: float = 0.01,
649
+ key_bias: bool = True,
650
+ normalize_before: bool = True,
651
+ ):
652
+ super().__init__()
653
+ self.num_blocks = num_blocks
654
+ self.normalize_before = normalize_before
655
+ self.output_size = output_size
656
+
657
+ self.embed = LinearNoSubsampling(
658
+ input_size,
659
+ output_size,
660
+ dropout_rate,
661
+ EspnetRelPositionalEncoding(output_size, positional_dropout_rate),
662
+ )
663
+
664
+ self.encoders = torch.nn.ModuleList()
665
+ for i in range(self.num_blocks):
666
+ # construct src attention
667
+ if srcattention_start_index <= i <= srcattention_end_index:
668
+ srcattention_layer = MultiHeadedAttention(
669
+ attention_heads,
670
+ output_size,
671
+ attention_dropout_rate,
672
+ key_bias
673
+ )
674
+ else:
675
+ srcattention_layer = None
676
+ # construct self attention
677
+ selfattention_layer = RelPositionMultiHeadedAttention(
678
+ attention_heads,
679
+ output_size,
680
+ attention_dropout_rate,
681
+ key_bias
682
+ )
683
+ # construct ffn
684
+ ffn_layer = PositionwiseFeedForward(
685
+ output_size,
686
+ linear_units,
687
+ dropout_rate,
688
+ torch.nn.SiLU()
689
+ )
690
+ self.encoders.append(
691
+ ConformerDecoderLayer(
692
+ output_size,
693
+ selfattention_layer,
694
+ srcattention_layer,
695
+ ffn_layer,
696
+ None,
697
+ None,
698
+ dropout_rate,
699
+ normalize_before=normalize_before
700
+ )
701
+ )
702
+ self.after_norm = torch.nn.LayerNorm(output_size, eps=1e-5)
703
+
704
+ def forward_layers(self, xs: torch.Tensor, chunk_masks: torch.Tensor,
705
+ memory: torch.Tensor, memory_masks: torch.Tensor,
706
+ pos_emb: torch.Tensor, mask_pad: torch.Tensor) -> torch.Tensor:
707
+ for layer in self.encoders:
708
+ xs, chunk_masks, _, _ = layer(xs, chunk_masks, memory, memory_masks, pos_emb, mask_pad)
709
+ return xs
710
+
711
+ def forward(self,
712
+ xs:torch.Tensor,
713
+ xs_lens:torch.Tensor,
714
+ memory:torch.Tensor,
715
+ memory_lens: torch.Tensor,
716
+ ):
717
+ T = xs.size(1)
718
+ masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T)
719
+ T2 = memory.size(1)
720
+ memory_masks = ~make_pad_mask(memory_lens, T2).unsqueeze(1) # (B, 1, T2)
721
+
722
+ xs, pos_emb, masks = self.embed(xs, masks)
723
+
724
+ xs = self.forward_layers(xs, masks, memory, memory_masks, pos_emb, masks)
725
+
726
+ if self.normalize_before:
727
+ xs = self.after_norm(xs)
728
+
729
+ return xs, masks
730
+
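For orientation, a minimal smoke test of ConformerDecoderV2 with dummy tensors (a sketch; tensor sizes are illustrative and the import path assumes this file ships as fireredtts/modules/flow/conformer.py):

import torch
from fireredtts.modules.flow.conformer import ConformerDecoderV2  # assumed path

decoder = ConformerDecoderV2(input_size=512, output_size=512, num_blocks=6)
xs = torch.randn(2, 100, 512)        # (B, T, input_size) token features
xs_lens = torch.tensor([100, 80])    # valid lengths per batch item
memory = torch.randn(2, 50, 512)     # cross-attention memory, e.g. prompt mel features
memory_lens = torch.tensor([50, 40])
out, masks = decoder(xs, xs_lens, memory, memory_lens)
print(out.shape)    # torch.Size([2, 100, 512])
print(masks.shape)  # torch.Size([2, 1, 100])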
fireredtts/modules/flow/decoder.py ADDED
@@ -0,0 +1,396 @@
1
+ import math
2
+ from typing import Optional
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from einops import pack, rearrange, repeat
7
+ from diffusers.models.activations import get_activation
8
+
9
+ from fireredtts.modules.flow.transformer import BasicTransformerBlock
10
+
11
+
12
+ class SinusoidalPosEmb(torch.nn.Module):
13
+ def __init__(self, dim):
14
+ super().__init__()
15
+ self.dim = dim
16
+ assert self.dim % 2 == 0, "SinusoidalPosEmb requires dim to be even"
17
+
18
+ def forward(self, x, scale=1000):
19
+ if x.ndim < 1:
20
+ x = x.unsqueeze(0)
21
+ device = x.device
22
+ half_dim = self.dim // 2
23
+ emb = math.log(10000) / (half_dim - 1)
24
+ emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
25
+ emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
26
+ emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
27
+ return emb
28
+
29
+
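A small check of the sinusoidal timestep embedding above (a sketch; the import assumes the package layout used elsewhere in this commit):

import torch
from fireredtts.modules.flow.decoder import SinusoidalPosEmb  # assumed path

pos_emb = SinusoidalPosEmb(dim=256)
t = torch.rand(4)     # flow-matching timesteps in [0, 1]
emb = pos_emb(t)      # t is scaled by 1000 before the sin/cos projection
print(emb.shape)      # torch.Size([4, 256])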
30
+ class Block1D(torch.nn.Module):
31
+ def __init__(self, dim, dim_out, groups=8):
32
+ super().__init__()
33
+ self.block = torch.nn.Sequential(
34
+ torch.nn.Conv1d(dim, dim_out, 3, padding=1),
35
+ torch.nn.GroupNorm(groups, dim_out),
36
+ nn.Mish(),
37
+ )
38
+
39
+ def forward(self, x, mask):
40
+ output = self.block(x * mask)
41
+ return output * mask
42
+
43
+
44
+ class ResnetBlock1D(torch.nn.Module):
45
+ def __init__(self, dim, dim_out, time_emb_dim, groups=8):
46
+ super().__init__()
47
+ self.mlp = torch.nn.Sequential(nn.Mish(), torch.nn.Linear(time_emb_dim, dim_out))
48
+
49
+ self.block1 = Block1D(dim, dim_out, groups=groups)
50
+ self.block2 = Block1D(dim_out, dim_out, groups=groups)
51
+
52
+ self.res_conv = torch.nn.Conv1d(dim, dim_out, 1)
53
+
54
+ def forward(self, x, mask, time_emb):
55
+ h = self.block1(x, mask)
56
+ h += self.mlp(time_emb).unsqueeze(-1)
57
+ h = self.block2(h, mask)
58
+ output = h + self.res_conv(x * mask)
59
+ return output
60
+
61
+
62
+ class Downsample1D(nn.Module):
63
+ def __init__(self, dim):
64
+ super().__init__()
65
+ self.conv = torch.nn.Conv1d(dim, dim, 3, 2, 1)
66
+
67
+ def forward(self, x):
68
+ return self.conv(x)
69
+
70
+
71
+ class TimestepEmbedding(nn.Module):
72
+ def __init__(
73
+ self,
74
+ in_channels: int,
75
+ time_embed_dim: int,
76
+ act_fn: str = "silu",
77
+ out_dim: int = None,
78
+ post_act_fn: Optional[str] = None,
79
+ cond_proj_dim=None,
80
+ ):
81
+ super().__init__()
82
+
83
+ self.linear_1 = nn.Linear(in_channels, time_embed_dim)
84
+
85
+ if cond_proj_dim is not None:
86
+ self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
87
+ else:
88
+ self.cond_proj = None
89
+
90
+ self.act = get_activation(act_fn)
91
+
92
+ if out_dim is not None:
93
+ time_embed_dim_out = out_dim
94
+ else:
95
+ time_embed_dim_out = time_embed_dim
96
+ self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out)
97
+
98
+ if post_act_fn is None:
99
+ self.post_act = None
100
+ else:
101
+ self.post_act = get_activation(post_act_fn)
102
+
103
+ def forward(self, sample, condition=None):
104
+ if condition is not None:
105
+ sample = sample + self.cond_proj(condition)
106
+ sample = self.linear_1(sample)
107
+
108
+ if self.act is not None:
109
+ sample = self.act(sample)
110
+
111
+ sample = self.linear_2(sample)
112
+
113
+ if self.post_act is not None:
114
+ sample = self.post_act(sample)
115
+ return sample
116
+
117
+
118
+ class Upsample1D(nn.Module):
119
+ """A 1D upsampling layer with an optional convolution.
120
+
121
+ Parameters:
122
+ channels (`int`):
123
+ number of channels in the inputs and outputs.
124
+ use_conv (`bool`, default `False`):
125
+ option to use a convolution.
126
+ use_conv_transpose (`bool`, default `False`):
127
+ option to use a convolution transpose.
128
+ out_channels (`int`, optional):
129
+ number of output channels. Defaults to `channels`.
130
+ """
131
+
132
+ def __init__(self, channels, use_conv=False, use_conv_transpose=True, out_channels=None, name="conv"):
133
+ super().__init__()
134
+ self.channels = channels
135
+ self.out_channels = out_channels or channels
136
+ self.use_conv = use_conv
137
+ self.use_conv_transpose = use_conv_transpose
138
+ self.name = name
139
+
140
+ self.conv = None
141
+ if use_conv_transpose:
142
+ self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1)
143
+ elif use_conv:
144
+ self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1)
145
+
146
+ def forward(self, inputs):
147
+ assert inputs.shape[1] == self.channels
148
+ if self.use_conv_transpose:
149
+ return self.conv(inputs)
150
+
151
+ outputs = F.interpolate(inputs, scale_factor=2.0, mode="nearest")
152
+
153
+ if self.use_conv:
154
+ outputs = self.conv(outputs)
155
+
156
+ return outputs
157
+
158
+
159
+ class ConditionalDecoder(nn.Module):
160
+ def __init__(
161
+ self,
162
+ in_channels,
163
+ out_channels,
164
+ channels=(256, 256),
165
+ dropout=0.0,
166
+ attention_head_dim=64,
167
+ n_blocks=4,
168
+ num_mid_blocks=12,
169
+ num_heads=8,
170
+ act_fn="gelu",
171
+ ):
172
+ """
173
+ This decoder requires an input with the same shape of the target. So, if your text content
174
+ is shorter or longer than the outputs, please re-sampling it before feeding to the decoder.
175
+ """
176
+ super().__init__()
177
+ channels = tuple(channels)
178
+ self.in_channels = in_channels
179
+ self.out_channels = out_channels
180
+
181
+ self.time_embeddings = SinusoidalPosEmb(in_channels)
182
+ time_embed_dim = channels[0] * 4
183
+ self.time_mlp = TimestepEmbedding(
184
+ in_channels=in_channels,
185
+ time_embed_dim=time_embed_dim,
186
+ act_fn="silu",
187
+ )
188
+ self.down_blocks = nn.ModuleList([])
189
+ self.mid_blocks = nn.ModuleList([])
190
+ self.up_blocks = nn.ModuleList([])
191
+
192
+ output_channel = in_channels
193
+ for i in range(len(channels)): # pylint: disable=consider-using-enumerate
194
+ input_channel = output_channel
195
+ output_channel = channels[i]
196
+ is_last = i == len(channels) - 1
197
+ resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
198
+ transformer_blocks = nn.ModuleList(
199
+ [
200
+ BasicTransformerBlock(
201
+ dim=output_channel,
202
+ num_attention_heads=num_heads,
203
+ attention_head_dim=attention_head_dim,
204
+ dropout=dropout,
205
+ activation_fn=act_fn,
206
+ )
207
+ for _ in range(n_blocks)
208
+ ]
209
+ )
210
+ downsample = (
211
+ Downsample1D(output_channel) if not is_last else nn.Conv1d(output_channel, output_channel, 3, padding=1)
212
+ )
213
+ self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))
214
+
215
+ for i in range(num_mid_blocks):
216
+ input_channel = channels[-1]
217
+ out_channels = channels[-1]
218
+ resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
219
+
220
+ transformer_blocks = nn.ModuleList(
221
+ [
222
+ BasicTransformerBlock(
223
+ dim=output_channel,
224
+ num_attention_heads=num_heads,
225
+ attention_head_dim=attention_head_dim,
226
+ dropout=dropout,
227
+ activation_fn=act_fn,
228
+ )
229
+ for _ in range(n_blocks)
230
+ ]
231
+ )
232
+
233
+ self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))
234
+
235
+ channels = channels[::-1] + (channels[0],)
236
+ for i in range(len(channels) - 1):
237
+ input_channel = channels[i] * 2
238
+ output_channel = channels[i + 1]
239
+ is_last = i == len(channels) - 2
240
+ resnet = ResnetBlock1D(
241
+ dim=input_channel,
242
+ dim_out=output_channel,
243
+ time_emb_dim=time_embed_dim,
244
+ )
245
+ transformer_blocks = nn.ModuleList(
246
+ [
247
+ BasicTransformerBlock(
248
+ dim=output_channel,
249
+ num_attention_heads=num_heads,
250
+ attention_head_dim=attention_head_dim,
251
+ dropout=dropout,
252
+ activation_fn=act_fn,
253
+ )
254
+ for _ in range(n_blocks)
255
+ ]
256
+ )
257
+ upsample = (
258
+ Upsample1D(output_channel, use_conv_transpose=True)
259
+ if not is_last
260
+ else nn.Conv1d(output_channel, output_channel, 3, padding=1)
261
+ )
262
+ self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
263
+ self.final_block = Block1D(channels[-1], channels[-1])
264
+ self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
265
+ self.initialize_weights()
266
+
267
+
268
+ def initialize_weights(self):
269
+ for m in self.modules():
270
+ if isinstance(m, nn.Conv1d):
271
+ nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
272
+ if m.bias is not None:
273
+ nn.init.constant_(m.bias, 0)
274
+ elif isinstance(m, nn.GroupNorm):
275
+ nn.init.constant_(m.weight, 1)
276
+ nn.init.constant_(m.bias, 0)
277
+ elif isinstance(m, nn.Linear):
278
+ nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
279
+ if m.bias is not None:
280
+ nn.init.constant_(m.bias, 0)
281
+
282
+ def forward(self, x, mask, mu, t):
283
+ """Forward pass of the UNet1DConditional model.
284
+
285
+ Args:
286
+ x (torch.Tensor): shape (batch_size, in_channels, time)
287
+ mask (_type_): shape (batch_size, 1, time)
288
+ t (_type_): shape (batch_size)
289
+ spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None.
290
+ cond (_type_, optional): placeholder for future use. Defaults to None.
291
+
292
+ Raises:
293
+ ValueError: _description_
294
+ ValueError: _description_
295
+
296
+ Returns:
297
+ _type_: _description_
298
+ """
299
+
300
+ t = self.time_embeddings(t)
301
+ t = self.time_mlp(t)
302
+
303
+ x = pack([x, mu], "b * t")[0]
304
+
305
+ hiddens = []
306
+ masks = [mask]
307
+ for resnet, transformer_blocks, downsample in self.down_blocks:
308
+ mask_down = masks[-1]
309
+ x = resnet(x, mask_down, t)
310
+ x = rearrange(x, "b c t -> b t c").contiguous()
311
+ attn_mask = torch.matmul(mask_down.transpose(1, 2).contiguous(), mask_down)
312
+ for transformer_block in transformer_blocks:
313
+ x = transformer_block(
314
+ hidden_states=x,
315
+ attention_mask=attn_mask,
316
+ timestep=t,
317
+ )
318
+ x = rearrange(x, "b t c -> b c t").contiguous()
319
+ hiddens.append(x) # Save hidden states for skip connections
320
+ x = downsample(x * mask_down)
321
+ masks.append(mask_down[:, :, ::2])
322
+ masks = masks[:-1]
323
+ mask_mid = masks[-1]
324
+
325
+ for resnet, transformer_blocks in self.mid_blocks:
326
+ x = resnet(x, mask_mid, t)
327
+ x = rearrange(x, "b c t -> b t c").contiguous()
328
+ attn_mask = torch.matmul(mask_mid.transpose(1, 2).contiguous(), mask_mid)
329
+ for transformer_block in transformer_blocks:
330
+ x = transformer_block(
331
+ hidden_states=x,
332
+ attention_mask=attn_mask,
333
+ timestep=t,
334
+ )
335
+ x = rearrange(x, "b t c -> b c t").contiguous()
336
+
337
+ for resnet, transformer_blocks, upsample in self.up_blocks:
338
+ mask_up = masks.pop()
339
+ skip = hiddens.pop()
340
+ x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
341
+ x = resnet(x, mask_up, t)
342
+ x = rearrange(x, "b c t -> b t c").contiguous()
343
+ attn_mask = torch.matmul(mask_up.transpose(1, 2).contiguous(), mask_up)
344
+ for transformer_block in transformer_blocks:
345
+ x = transformer_block(
346
+ hidden_states=x,
347
+ attention_mask=attn_mask,
348
+ timestep=t,
349
+ )
350
+ x = rearrange(x, "b t c -> b c t").contiguous()
351
+ x = upsample(x * mask_up)
352
+ x = self.final_block(x, mask_up)
353
+ output = self.final_proj(x * mask_up)
354
+ return output * mask
355
+
356
+
357
+ class ConditionalCFM(nn.Module):
358
+ def __init__(self,
359
+ estimator: nn.Module,
360
+ t_scheduler: str = "cosine",
361
+ inference_cfg_rate: float = 0.7,
362
+ ):
363
+ super().__init__()
364
+ self.estimator = estimator
365
+ self.t_scheduler = t_scheduler
366
+ self.inference_cfg_rate = inference_cfg_rate
367
+
368
+ def solve_euler(self, x, t_span, mu, mask):
369
+ t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
370
+
371
+ # I am storing this because I can later plot it by putting a debugger here and saving it to a file
372
+ # Or in future might add like a return_all_steps flag
373
+ sol = []
374
+
375
+ for step in range(1, len(t_span)):
376
+ dphi_dt = self.estimator(x, mask, mu, t)
377
+ # Classifier-Free Guidance inference introduced in VoiceBox
378
+ if self.inference_cfg_rate > 0:
379
+ cfg_dphi_dt = self.estimator(x, mask, torch.zeros_like(mu), t)
380
+ dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt -
381
+ self.inference_cfg_rate * cfg_dphi_dt)
382
+ x = x + dt * dphi_dt
383
+ t = t + dt
384
+ sol.append(x)
385
+ if step < len(t_span) - 1:
386
+ dt = t_span[step + 1] - t
387
+
388
+ return sol[-1]
389
+
390
+ def inference(self, mu, mask, n_timesteps, temperature: float=1.0):
391
+ z = torch.randn_like(mu) * temperature
392
+ t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
393
+ if self.t_scheduler == 'cosine':
394
+ t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
395
+ return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask)
396
+
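Putting the two classes above together, a rough sketch of flow-matching inference with classifier-free guidance (hyperparameters here are illustrative, not the values shipped in the pretrained config; note that in_channels must equal twice the channel count of mu, since x and mu are packed along the channel axis):

import torch
from fireredtts.modules.flow.decoder import ConditionalDecoder, ConditionalCFM  # assumed paths

estimator = ConditionalDecoder(in_channels=200, out_channels=100, channels=(256, 256))
cfm = ConditionalCFM(estimator=estimator, t_scheduler="cosine", inference_cfg_rate=0.7)

mu = torch.randn(1, 100, 240)   # (B, out_channels, T) conditioning features, T even
mask = torch.ones(1, 1, 240)    # all frames valid
mel = cfm.inference(mu, mask, n_timesteps=10)
print(mel.shape)                # torch.Size([1, 100, 240])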
fireredtts/modules/flow/flow_model.py ADDED
@@ -0,0 +1,89 @@
1
+ import time
2
+ import torch
3
+ import torch.nn as nn
4
+ from torch.nn import functional as F
5
+
6
+ from fireredtts.modules.flow.utils import make_pad_mask
7
+
8
+
9
+ class InterpolateRegulator(nn.Module):
10
+ def __init__(
11
+ self,
12
+ channels: int = 512,
13
+ num_blocks: int = 4,
14
+ groups: int = 1,
15
+ ):
16
+ super().__init__()
17
+ model = []
18
+ for _ in range(num_blocks):
19
+ model.extend([
20
+ nn.Conv1d(channels, channels, 3, 1, 1),
21
+ nn.GroupNorm(groups, channels),
22
+ nn.Mish(),
23
+ ])
24
+ model.append(
25
+ nn.Conv1d(channels, channels, 1, 1)
26
+ )
27
+ self.model = nn.Sequential(*model)
28
+
29
+ def forward(self, x, ylens=None):
30
+ # x in (B, T, D)
31
+ mask = (~make_pad_mask(ylens)).to(x).unsqueeze(-1)
32
+ x = F.interpolate(x.transpose(1, 2).contiguous(), size=ylens.max(), mode='nearest')
33
+ out = self.model(x).transpose(1, 2).contiguous()
34
+ olens = ylens
35
+ return out * mask, olens
36
+
37
+
38
+ class CrossAttnFlowMatching(nn.Module):
39
+ def __init__(self,
40
+ output_size: int,
41
+ input_embedding: nn.Module,
42
+ encoder: nn.Module,
43
+ length_regulator: nn.Module,
44
+ mel_encoder: nn.Module,
45
+ decoder: nn.Module,
46
+ ):
47
+ super().__init__()
48
+ self.input_embedding = input_embedding
49
+ self.encoder = encoder
50
+ self.length_regulator = length_regulator
51
+ self.encoder_proj = torch.nn.Linear(self.encoder.output_size, output_size)
52
+ self.prompt_prenet = mel_encoder
53
+ self.decoder = decoder
54
+
55
+ def inference(self,
56
+ token: torch.Tensor,
57
+ token_len: torch.Tensor,
58
+ prompt_mel: torch.Tensor,
59
+ prompt_mel_len: torch.Tensor,
60
+ n_timesteps:int=10,
61
+ ):
62
+ # prompt projection
63
+ prompt_feat = self.prompt_prenet(prompt_mel)
64
+ prompt_feat_len = torch.ceil(prompt_mel_len/self.prompt_prenet.reduction_rate).long()
65
+
66
+ # concat text and prompt_text
67
+ mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(token_len.device)
68
+ token = self.input_embedding(torch.clamp(token, min=0)) * mask
69
+
70
+ # 40ms shift to 10ms shift
71
+ feat_len = (token_len *4).int()
72
+
73
+ # first encoder
74
+ h, _ = self.encoder(token, token_len, prompt_feat, prompt_feat_len)
75
+ # length regulate
76
+ h, _ = self.length_regulator(h, feat_len)
77
+ # final projection
78
+ h = self.encoder_proj(h)
79
+
80
+ mask = (~make_pad_mask(feat_len)).to(h)
81
+
82
+ feat = self.decoder.inference(
83
+ mu=h.transpose(1, 2).contiguous(),
84
+ mask=mask.unsqueeze(1),
85
+ n_timesteps=n_timesteps,
86
+ )
87
+ return feat
88
+
89
+
fireredtts/modules/flow/mel_encoder.py ADDED
@@ -0,0 +1,170 @@
 
1
+ import typing as tp
2
+ import torch
3
+ import torch.nn as nn
4
+
5
+
6
+ class ConvLayer(nn.Module):
7
+ def __init__(self,
8
+ in_channels:int,
9
+ out_channels:int,
10
+ kernel_size:int,
11
+ stride:int,
12
+ activation:str="GELU",
13
+ dropout_rate:float=0.0,
14
+ ):
15
+ super().__init__()
16
+ self.conv = nn.Conv1d(
17
+ in_channels=in_channels,
18
+ out_channels=out_channels,
19
+ kernel_size=kernel_size,
20
+ stride=stride,
21
+ padding=(kernel_size-stride)//2,
22
+ )
23
+ self.drop = nn.Dropout(dropout_rate)
24
+ self.norm = nn.LayerNorm(out_channels)
25
+ self.activ = getattr(nn, activation)()
26
+
27
+ def forward(self, x:torch.Tensor):
28
+ """
29
+ Args:
30
+ x: (b, t, c)
31
+ Return:
32
+ x: (b, t, c)
33
+ """
34
+ x = x.transpose(2, 1)
35
+ x = self.conv(x)
36
+ x = x.transpose(2, 1)
37
+ x = self.drop(x)
38
+ x = self.norm(x)
39
+ x = self.activ(x)
40
+ return x
41
+
42
+
43
+ class ResidualConvLayer(nn.Module):
44
+ def __init__(self,
45
+ hidden_channels:int,
46
+ n_layers:int=2,
47
+ kernel_size:int=5,
48
+ activation:str="GELU",
49
+ dropout_rate:float=0.0,
50
+ ):
51
+ super().__init__()
52
+ layers = [
53
+ ConvLayer(hidden_channels, hidden_channels, kernel_size, 1, activation, dropout_rate)
54
+ for _ in range(n_layers)
55
+ ]
56
+ self.layers = nn.Sequential(*layers)
57
+
58
+ def forward(self, x:torch.Tensor):
59
+ """
60
+ Args:
61
+ x: (b, t, c)
62
+ Returns:
63
+ x: (b, t, c)
64
+ """
65
+ return x + self.layers(x)
66
+
67
+
68
+ class ResidualConvBlock(nn.Module):
69
+ def __init__(self,
70
+ in_channels:int,
71
+ hidden_channels:int,
72
+ out_channels:int,
73
+ n_layers:int=2,
74
+ n_blocks:int=5,
75
+ middle_layer:tp.Optional[nn.Module]=None,
76
+ kernel_size:int=5,
77
+ activation:str="GELU",
78
+ dropout_rate:float=0.0,
79
+ ):
80
+ super().__init__()
81
+ self.in_proj = nn.Conv1d(
82
+ in_channels,
83
+ hidden_channels,
84
+ kernel_size=kernel_size,
85
+ stride=1,
86
+ padding=(kernel_size-1)//2,
87
+ ) if in_channels != hidden_channels else nn.Identity()
88
+
89
+ self.conv1 = nn.Sequential(*[
90
+ ResidualConvLayer(hidden_channels, n_layers, kernel_size, activation, dropout_rate)
91
+ for _ in range(n_blocks)
92
+ ])
93
+
94
+ if middle_layer is None:
95
+ self.middle_layer = nn.Identity()
96
+ elif isinstance(middle_layer, nn.Module):
97
+ self.middle_layer = middle_layer
98
+ else:
99
+ raise TypeError("unknown middle layer type:{}".format(type(middle_layer)))
100
+
101
+ self.conv2 = nn.Sequential(*[
102
+ ResidualConvLayer(hidden_channels, n_layers, kernel_size, activation, dropout_rate)
103
+ for _ in range(n_blocks)
104
+ ])
105
+
106
+ self.out_proj = nn.Conv1d(
107
+ hidden_channels,
108
+ out_channels,
109
+ kernel_size=kernel_size,
110
+ stride=1,
111
+ padding=(kernel_size-1)//2,
112
+ ) if out_channels != hidden_channels else nn.Identity()
113
+
114
+ def forward(self, x:torch.Tensor, **middle_layer_kwargs):
115
+ """
116
+ Args:
117
+ x: (b, t1, c)
118
+ Return:
119
+ x: (b, t2, c)
120
+ """
121
+ x = self.in_proj(x.transpose(2, 1)).transpose(2, 1)
122
+ x = self.conv1(x)
123
+ if isinstance(self.middle_layer, nn.MaxPool1d) or isinstance(self.middle_layer, nn.Conv1d):
124
+ x = self.middle_layer(x.transpose(2, 1)).transpose(2, 1)
125
+ elif isinstance(self.middle_layer, nn.Identity):
126
+ x = self.middle_layer(x)
127
+ else:
128
+ # in case of a phoneme-pooling layer
129
+ x = self.middle_layer(x, **middle_layer_kwargs)
130
+ x = self.conv2(x)
131
+ x = self.out_proj(x.transpose(2, 1)).transpose(2, 1)
132
+ return x
133
+
134
+
135
+ class MelReduceEncoder(nn.Module):
136
+ def __init__(self,
137
+ in_channels:int,
138
+ out_channels:int,
139
+ hidden_channels:int=384,
140
+ reduction_rate:int=4,
141
+ n_layers:int=2,
142
+ n_blocks:int=5,
143
+ kernel_size:int=3,
144
+ activation:str="GELU",
145
+ dropout:float=0.0,
146
+ ):
147
+ super().__init__()
148
+ self.reduction_rate = reduction_rate
149
+ middle_conv = nn.Conv1d(
150
+ in_channels=hidden_channels,
151
+ out_channels=hidden_channels,
152
+ kernel_size=reduction_rate,
153
+ stride=reduction_rate,
154
+ padding=0
155
+ )
156
+ self.encoder = ResidualConvBlock(
157
+ in_channels=in_channels,
158
+ hidden_channels=hidden_channels,
159
+ out_channels=out_channels,
160
+ n_layers=n_layers,
161
+ n_blocks=n_blocks,
162
+ middle_layer=middle_conv,
163
+ kernel_size=kernel_size,
164
+ activation=activation,
165
+ dropout_rate=dropout
166
+ )
167
+
168
+ def forward(self, x:torch.Tensor):
169
+ return self.encoder(x)
170
+
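A shape sketch of the prompt prenet above, which downsamples the mel sequence by reduction_rate (illustrative sizes; assumed import path):

import torch
from fireredtts.modules.flow.mel_encoder import MelReduceEncoder  # assumed path

enc = MelReduceEncoder(in_channels=100, out_channels=512, reduction_rate=4)
mel = torch.randn(1, 240, 100)   # (B, T, num_mels)
feat = enc(mel)
print(feat.shape)                # torch.Size([1, 60, 512]), i.e. T / 4 frames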
fireredtts/modules/flow/mel_spectrogram.py ADDED
@@ -0,0 +1,132 @@
1
+ from functools import partial
2
+ import torch
3
+ import numpy as np
4
+ import librosa
5
+ from librosa.filters import mel as librosa_mel_fn
6
+ from torchaudio.functional import resample as ta_resample_fn
7
+
8
+ MAX_WAV_VALUE = 32767.0 # NOTE: 32768.0 -1 to prevent int16 overflow (results in popping sound in corner cases)
9
+
10
+
11
+ def dynamic_range_compression(x, C=1, clip_val=1e-5):
12
+ return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
13
+
14
+
15
+ def dynamic_range_decompression(x, C=1):
16
+ return np.exp(x) / C
17
+
18
+
19
+ def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
20
+ return torch.log(torch.clamp(x, min=clip_val) * C)
21
+
22
+
23
+ def dynamic_range_decompression_torch(x, C=1):
24
+ return torch.exp(x) / C
25
+
26
+
27
+ def spectral_normalize_torch(magnitudes):
28
+ output = dynamic_range_compression_torch(magnitudes)
29
+ return output
30
+
31
+
32
+ def spectral_de_normalize_torch(magnitudes):
33
+ output = dynamic_range_decompression_torch(magnitudes)
34
+ return output
35
+
36
+
37
+ mel_basis = {}
38
+ hann_window = {}
39
+
40
+
41
+ def mel_spectrogram(
42
+ y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False
43
+ ):
44
+ global mel_basis, hann_window
45
+ str_key_mel_basis = str(fmax) + "_" + str(y.device)  # cache key per (fmax, device)
46
+ if str_key_mel_basis not in mel_basis:
47
+ mel = librosa_mel_fn(
48
+ sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
49
+ )
50
+ mel_basis[str_key_mel_basis] = torch.from_numpy(mel).float().to(y.device)
51
+ hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)
52
+
53
+ y = torch.nn.functional.pad(
54
+ y.unsqueeze(1),
55
+ (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
56
+ mode="reflect",
57
+ )
58
+ y = y.squeeze(1)
59
+
60
+ # complex tensor as default, then use view_as_real for future pytorch compatibility
61
+ spec = torch.stft(
62
+ y,
63
+ n_fft,
64
+ hop_length=hop_size,
65
+ win_length=win_size,
66
+ window=hann_window[str(y.device)],
67
+ center=center,
68
+ pad_mode="reflect",
69
+ normalized=False,
70
+ onesided=True,
71
+ return_complex=True,
72
+ )
73
+ spec = torch.view_as_real(spec)
74
+ spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
75
+
76
+ spec = torch.matmul(mel_basis[str_key_mel_basis], spec)
77
+ spec = spectral_normalize_torch(spec)
78
+
79
+ return spec
80
+
81
+
82
+ kaiser_best_resampling_fn = partial(
83
+ ta_resample_fn,
84
+ resampling_method="sinc_interp_kaiser", # DO NOT CHANGE!
85
+ rolloff=0.917347, # DO NOT CHANGE!
86
+ beta=12.9846, # DO NOT CHANGE!
87
+ lowpass_filter_width=50, # DO NOT CHANGE!
88
+ )
89
+
90
+
91
+ class MelSpectrogramExtractor(object):
92
+ def __init__(
93
+ self,
94
+ n_fft=1024,
95
+ win_size=1024,
96
+ num_mels=100,
97
+ hop_size=160,
98
+ sampling_rate=16000,
99
+ fmin=0,
100
+ fmax=None,
101
+ ):
102
+ self.n_fft = n_fft
103
+ self.win_size = win_size
104
+ self.num_mels = num_mels
105
+ self.hop_size = hop_size
106
+ self.sampling_rate = sampling_rate
107
+ self.fmin = fmin
108
+ self.fmax = fmax
109
+
110
+ def __call__(self, wav_path) -> torch.Tensor:
111
+ wav_data, wav_sr = librosa.load(wav_path, sr=None, mono=True)
112
+ wav_data = torch.from_numpy(wav_data.copy()).unsqueeze(0)
113
+ # for 16k wavs, up- then down-sample to reduce artifacts
114
+ if wav_sr == self.sampling_rate:
115
+ wav_data = kaiser_best_resampling_fn(wav_data, orig_freq=wav_sr, new_freq=24000)
116
+ wav_data = kaiser_best_resampling_fn(wav_data, orig_freq=24000, new_freq=self.sampling_rate)
117
+ else:
118
+ wav_data = kaiser_best_resampling_fn(wav_data, orig_freq=wav_sr, new_freq=self.sampling_rate)
119
+
120
+ # (1, num_mels, t)
121
+ mel = mel_spectrogram(
122
+ wav_data,
123
+ self.n_fft,
124
+ self.num_mels,
125
+ self.sampling_rate,
126
+ self.hop_size,
127
+ self.win_size,
128
+ self.fmin,
129
+ self.fmax,
130
+ )
131
+ mel = mel.squeeze(0).transpose(1, 0)
132
+ return mel # (t, num_mels)
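Typical use of the extractor above (a sketch; "example.wav" is a placeholder path):

from fireredtts.modules.flow.mel_spectrogram import MelSpectrogramExtractor  # assumed path

extractor = MelSpectrogramExtractor()   # 16 kHz, 100 mel bins, 10 ms hop by default
mel = extractor("example.wav")          # placeholder path to any mono wav file
print(mel.shape)                        # (num_frames, 100), roughly one frame per 10 ms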
fireredtts/modules/flow/transformer.py ADDED
@@ -0,0 +1,249 @@
1
+ from typing import Any, Dict, Optional
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from diffusers.models.attention import (
6
+ GEGLU,
7
+ GELU,
8
+ AdaLayerNorm,
9
+ AdaLayerNormZero,
10
+ ApproximateGELU,
11
+ )
12
+ from diffusers.models.attention_processor import Attention
13
+ # from diffusers.models.lora import LoRACompatibleLinear
14
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
15
+
16
+
17
+ class FeedForward(nn.Module):
18
+ r"""
19
+ A feed-forward layer.
20
+
21
+ Parameters:
22
+ dim (`int`): The number of channels in the input.
23
+ dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
24
+ mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
25
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
26
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
27
+ final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
28
+ """
29
+
30
+ def __init__(
31
+ self,
32
+ dim: int,
33
+ dim_out: Optional[int] = None,
34
+ mult: int = 4,
35
+ dropout: float = 0.0,
36
+ activation_fn: str = "geglu",
37
+ final_dropout: bool = False,
38
+ ):
39
+ super().__init__()
40
+ inner_dim = int(dim * mult)
41
+ dim_out = dim_out if dim_out is not None else dim
42
+
43
+ if activation_fn == "gelu":
44
+ act_fn = GELU(dim, inner_dim)
45
+ if activation_fn == "gelu-approximate":
46
+ act_fn = GELU(dim, inner_dim, approximate="tanh")
47
+ elif activation_fn == "geglu":
48
+ act_fn = GEGLU(dim, inner_dim)
49
+ elif activation_fn == "geglu-approximate":
50
+ act_fn = ApproximateGELU(dim, inner_dim)
51
+
52
+ self.net = nn.ModuleList([])
53
+ # project in
54
+ self.net.append(act_fn)
55
+ # project dropout
56
+ self.net.append(nn.Dropout(dropout))
57
+ # project out
58
+ self.net.append(nn.Linear(inner_dim, dim_out))
59
+ # self.net.append(LoRACompatibleLinear(inner_dim, dim_out))
60
+ # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
61
+ if final_dropout:
62
+ self.net.append(nn.Dropout(dropout))
63
+
64
+ def forward(self, hidden_states):
65
+ for module in self.net:
66
+ hidden_states = module(hidden_states)
67
+ return hidden_states
68
+
69
+
70
+ @maybe_allow_in_graph
71
+ class BasicTransformerBlock(nn.Module):
72
+ r"""
73
+ A basic Transformer block.
74
+
75
+ Parameters:
76
+ dim (`int`): The number of channels in the input and output.
77
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
78
+ attention_head_dim (`int`): The number of channels in each head.
79
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
80
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
81
+ only_cross_attention (`bool`, *optional*):
82
+ Whether to use only cross-attention layers. In this case two cross attention layers are used.
83
+ double_self_attention (`bool`, *optional*):
84
+ Whether to use two self-attention layers. In this case no cross attention layers are used.
85
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
86
+ num_embeds_ada_norm (:
87
+ obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
88
+ attention_bias (:
89
+ obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
90
+ """
91
+
92
+ def __init__(
93
+ self,
94
+ dim: int,
95
+ num_attention_heads: int,
96
+ attention_head_dim: int,
97
+ dropout=0.0,
98
+ cross_attention_dim: Optional[int] = None,
99
+ activation_fn: str = "geglu",
100
+ num_embeds_ada_norm: Optional[int] = None,
101
+ attention_bias: bool = False,
102
+ only_cross_attention: bool = False,
103
+ double_self_attention: bool = False,
104
+ upcast_attention: bool = False,
105
+ norm_elementwise_affine: bool = True,
106
+ norm_type: str = "layer_norm",
107
+ final_dropout: bool = False,
108
+ ):
109
+ super().__init__()
110
+ self.only_cross_attention = only_cross_attention
111
+
112
+ self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
113
+ self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
114
+
115
+ if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
116
+ raise ValueError(
117
+ f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
118
+ f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
119
+ )
120
+
121
+ # Define 3 blocks. Each block has its own normalization layer.
122
+ # 1. Self-Attn
123
+ if self.use_ada_layer_norm:
124
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
125
+ elif self.use_ada_layer_norm_zero:
126
+ self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
127
+ else:
128
+ self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
129
+ self.attn1 = Attention(
130
+ query_dim=dim,
131
+ heads=num_attention_heads,
132
+ dim_head=attention_head_dim,
133
+ dropout=dropout,
134
+ bias=attention_bias,
135
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
136
+ upcast_attention=upcast_attention,
137
+ )
138
+
139
+ # 2. Cross-Attn
140
+ if cross_attention_dim is not None or double_self_attention:
141
+ # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
142
+ # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
143
+ # the second cross attention block.
144
+ self.norm2 = (
145
+ AdaLayerNorm(dim, num_embeds_ada_norm)
146
+ if self.use_ada_layer_norm
147
+ else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
148
+ )
149
+ self.attn2 = Attention(
150
+ query_dim=dim,
151
+ cross_attention_dim=cross_attention_dim if not double_self_attention else None,
152
+ heads=num_attention_heads,
153
+ dim_head=attention_head_dim,
154
+ dropout=dropout,
155
+ bias=attention_bias,
156
+ upcast_attention=upcast_attention,
157
+ # scale_qk=False, # uncomment this to not to use flash attention
158
+ ) # is self-attn if encoder_hidden_states is none
159
+ else:
160
+ self.norm2 = None
161
+ self.attn2 = None
162
+
163
+ # 3. Feed-forward
164
+ self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
165
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
166
+
167
+ # let chunk size default to None
168
+ self._chunk_size = None
169
+ self._chunk_dim = 0
170
+
171
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
172
+ # Sets chunk feed-forward
173
+ self._chunk_size = chunk_size
174
+ self._chunk_dim = dim
175
+
176
+ def forward(
177
+ self,
178
+ hidden_states: torch.FloatTensor,
179
+ attention_mask: Optional[torch.FloatTensor] = None,
180
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
181
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
182
+ timestep: Optional[torch.LongTensor] = None,
183
+ cross_attention_kwargs: Dict[str, Any] = None,
184
+ class_labels: Optional[torch.LongTensor] = None,
185
+ ):
186
+ # Notice that normalization is always applied before the real computation in the following blocks.
187
+ # 1. Self-Attention
188
+ if self.use_ada_layer_norm:
189
+ norm_hidden_states = self.norm1(hidden_states, timestep)
190
+ elif self.use_ada_layer_norm_zero:
191
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
192
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
193
+ )
194
+ else:
195
+ norm_hidden_states = self.norm1(hidden_states)
196
+
197
+ cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
198
+
199
+ attn_output = self.attn1(
200
+ norm_hidden_states,
201
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
202
+ attention_mask=encoder_attention_mask if self.only_cross_attention else attention_mask,
203
+ **cross_attention_kwargs,
204
+ )
205
+ if self.use_ada_layer_norm_zero:
206
+ attn_output = gate_msa.unsqueeze(1) * attn_output
207
+ hidden_states = attn_output + hidden_states
208
+
209
+ # 2. Cross-Attention
210
+ if self.attn2 is not None:
211
+ norm_hidden_states = (
212
+ self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
213
+ )
214
+
215
+ attn_output = self.attn2(
216
+ norm_hidden_states,
217
+ encoder_hidden_states=encoder_hidden_states,
218
+ attention_mask=encoder_attention_mask,
219
+ **cross_attention_kwargs,
220
+ )
221
+ hidden_states = attn_output + hidden_states
222
+
223
+ # 3. Feed-forward
224
+ norm_hidden_states = self.norm3(hidden_states)
225
+
226
+ if self.use_ada_layer_norm_zero:
227
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
228
+
229
+ if self._chunk_size is not None:
230
+ # "feed_forward_chunk_size" can be used to save memory
231
+ if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
232
+ raise ValueError(
233
+ f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
234
+ )
235
+
236
+ num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
237
+ ff_output = torch.cat(
238
+ [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)],
239
+ dim=self._chunk_dim,
240
+ )
241
+ else:
242
+ ff_output = self.ff(norm_hidden_states)
243
+
244
+ if self.use_ada_layer_norm_zero:
245
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
246
+
247
+ hidden_states = ff_output + hidden_states
248
+
249
+ return hidden_states
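A self-attention-only instantiation of the block above, as it is used inside the ConditionalDecoder (a sketch with illustrative sizes):

import torch
from fireredtts.modules.flow.transformer import BasicTransformerBlock  # assumed path

block = BasicTransformerBlock(
    dim=256, num_attention_heads=8, attention_head_dim=64, activation_fn="gelu"
)
hidden = torch.randn(1, 240, 256)   # (B, T, dim)
out = block(hidden_states=hidden)   # no cross-attention, no AdaLayerNorm in this configuration
print(out.shape)                    # torch.Size([1, 240, 256])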
fireredtts/modules/flow/utils.py ADDED
@@ -0,0 +1,30 @@
1
+ import torch
2
+
3
+
4
+ def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
5
+ """Make mask tensor containing indices of padded part.
6
+
7
+ See description of make_non_pad_mask.
8
+
9
+ Args:
10
+ lengths (torch.Tensor): Batch of lengths (B,).
11
+ Returns:
12
+ torch.Tensor: Mask tensor containing indices of padded part.
13
+
14
+ Examples:
15
+ >>> lengths = [5, 3, 2]
16
+ >>> make_pad_mask(lengths)
17
+ masks = [[0, 0, 0, 0 ,0],
18
+ [0, 0, 0, 1, 1],
19
+ [0, 0, 1, 1, 1]]
20
+ """
21
+ batch_size = lengths.size(0)
22
+ max_len = max_len if max_len > 0 else lengths.max().item()
23
+ seq_range = torch.arange(0,
24
+ max_len,
25
+ dtype=torch.int64,
26
+ device=lengths.device)
27
+ seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
28
+ seq_length_expand = lengths.unsqueeze(-1)
29
+ mask = seq_range_expand >= seq_length_expand
30
+ return mask
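A quick worked example of make_pad_mask (the import path matches the one used in flow_model.py above):

import torch
from fireredtts.modules.flow.utils import make_pad_mask

lengths = torch.tensor([5, 3, 2])
print(make_pad_mask(lengths))
# tensor([[False, False, False, False, False],
#         [False, False, False,  True,  True],
#         [False, False,  True,  True,  True]])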
fireredtts/modules/gpt/__init__.py ADDED
File without changes
fireredtts/modules/gpt/gpt.py ADDED
@@ -0,0 +1,356 @@
1
+ # ported from: https://github.com/neonbjb/tortoise-tts
2
+ # ported from: https://github.com/coqui-ai/TTS/blob/dev/TTS/tts/layers/xtts/gpt.py
3
+
4
+ import functools
5
+ import math
6
+ import random
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ from transformers import GPT2Config, GPT2Model, GPT2PreTrainedModel
12
+ from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
13
+
14
+
15
+ class GPT2InferenceModel(GPT2PreTrainedModel):
16
+ """Override GPT2LMHeadModel to allow for prefix conditioning."""
17
+
18
+ def __init__(self, config, gpt, pos_emb, embeddings, norm, linear, kv_cache):
19
+ super().__init__(config)
20
+ self.transformer = gpt
21
+ self.pos_embedding = pos_emb
22
+ self.embeddings = embeddings
23
+ self.final_norm = norm
24
+ self.lm_head = nn.Sequential(norm, linear)
25
+ self.kv_cache = kv_cache
26
+
27
+ def store_prefix_emb(self, prefix_emb):
28
+ self.cached_prefix_emb = prefix_emb
29
+
30
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
31
+ token_type_ids = kwargs.get("token_type_ids", None) # usually None
32
+ if not self.kv_cache:
33
+ past_key_values = None
34
+
35
+ # only last token for inputs_ids if past is defined in kwargs
36
+ if past_key_values is not None:
37
+ input_ids = input_ids[:, -1].unsqueeze(-1)
38
+ if token_type_ids is not None:
39
+ token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
40
+
41
+ attention_mask = kwargs.get("attention_mask", None)
42
+ position_ids = kwargs.get("position_ids", None)
43
+
44
+ if attention_mask is not None and position_ids is None:
45
+ # create position_ids on the fly for batch generation
46
+ position_ids = attention_mask.long().cumsum(-1) - 1
47
+ position_ids.masked_fill_(attention_mask == 0, 1)
48
+ if past_key_values is not None:
49
+ position_ids = position_ids[:, -1].unsqueeze(-1)
50
+ else:
51
+ position_ids = None
52
+ return {
53
+ "input_ids": input_ids,
54
+ "past_key_values": past_key_values,
55
+ "use_cache": kwargs.get("use_cache"),
56
+ "position_ids": position_ids,
57
+ "attention_mask": attention_mask,
58
+ "token_type_ids": token_type_ids,
59
+ }
60
+
61
+ def forward(
62
+ self,
63
+ input_ids=None,
64
+ past_key_values=None,
65
+ attention_mask=None,
66
+ token_type_ids=None,
67
+ position_ids=None,
68
+ head_mask=None,
69
+ inputs_embeds=None,
70
+ encoder_hidden_states=None,
71
+ encoder_attention_mask=None,
72
+ labels=None,
73
+ use_cache=None,
74
+ output_attentions=None,
75
+ output_hidden_states=None,
76
+ return_dict=None,
77
+ ):
78
+ assert self.cached_prefix_emb is not None
79
+ assert inputs_embeds is None # Not supported by this inference model.
80
+ assert labels is None # Training not supported by this inference model.
81
+ return_dict = (
82
+ return_dict if return_dict is not None else self.config.use_return_dict
83
+ )
84
+
85
+ # Create embedding
86
+ prefix_len = self.cached_prefix_emb.shape[1]
87
+ if input_ids.shape[1] != 1:
88
+ gen_inputs = input_ids[:, prefix_len:]
89
+ gen_emb = self.embeddings(gen_inputs)
90
+ gen_emb = gen_emb + self.pos_embedding(gen_emb)
91
+ if self.cached_prefix_emb.shape[0] != gen_emb.shape[0]:
92
+ prefix_emb = self.cached_prefix_emb.repeat_interleave(
93
+ gen_emb.shape[0] // self.cached_prefix_emb.shape[0], 0
94
+ )
95
+ else:
96
+ prefix_emb = self.cached_prefix_emb.to(gen_emb.dtype)
97
+ emb = torch.cat([prefix_emb, gen_emb], dim=1)
98
+ else:
99
+ emb = self.embeddings(input_ids)
100
+ emb = emb + self.pos_embedding.get_fixed_embedding(
101
+ attention_mask.shape[1] - (prefix_len + 1), attention_mask.device
102
+ )
103
+ transformer_outputs = self.transformer(
104
+ inputs_embeds=emb,
105
+ past_key_values=past_key_values,
106
+ attention_mask=attention_mask,
107
+ token_type_ids=token_type_ids,
108
+ position_ids=position_ids,
109
+ head_mask=head_mask,
110
+ encoder_hidden_states=encoder_hidden_states,
111
+ encoder_attention_mask=encoder_attention_mask,
112
+ use_cache=use_cache,
113
+ output_attentions=output_attentions,
114
+ output_hidden_states=output_hidden_states,
115
+ return_dict=return_dict,
116
+ )
117
+ hidden_states = transformer_outputs[0]
118
+ lm_logits = self.lm_head(hidden_states)
119
+
120
+ if not return_dict:
121
+ return (lm_logits,) + transformer_outputs[1:]
122
+
123
+ return CausalLMOutputWithCrossAttentions(
124
+ loss=None,
125
+ logits=lm_logits,
126
+ past_key_values=transformer_outputs.past_key_values,
127
+ hidden_states=transformer_outputs.hidden_states,
128
+ attentions=transformer_outputs.attentions,
129
+ cross_attentions=transformer_outputs.cross_attentions,
130
+ )
131
+
132
+ @staticmethod
133
+ def _reorder_cache(past, beam_idx):
134
+ """
135
+ This function is used to re-order the :obj:`past_key_values` cache if
136
+ :meth:`~transformers.PreTrainedModel.beam_search` or :meth:`~transformers.PreTrainedModel.beam_sample` is
137
+ called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step.
138
+ """
139
+ return tuple(
140
+ tuple(
141
+ past_state.index_select(0, beam_idx.to(past_state.device))
142
+ for past_state in layer_past
143
+ )
144
+ for layer_past in past
145
+ )
146
+
147
+
148
+ def null_position_embeddings(range, dim):
149
+ return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device)
150
+
151
+
152
+ class LearnedPositionEmbeddings(nn.Module):
153
+ def __init__(self, seq_len, model_dim, init=0.02):
154
+ super().__init__()
155
+ self.emb = torch.nn.Embedding(seq_len, model_dim)
156
+ # Initializing this way is standard for GPT-2
157
+ self.emb.weight.data.normal_(mean=0.0, std=init)
158
+
159
+ def forward(self, x):
160
+ sl = x.shape[1]
161
+ return self.emb(torch.arange(0, sl, device=x.device))
162
+
163
+ def get_fixed_embedding(self, ind, dev):
164
+ return self.emb(torch.tensor([ind], device=dev)).unsqueeze(0)
165
+
166
+
167
+ def build_hf_gpt_transformer(
168
+ layers,
169
+ model_dim,
170
+ heads,
171
+ max_mel_seq_len,
172
+ max_text_seq_len,
173
+ max_prompt_len,
174
+ checkpointing,
175
+ ):
176
+ """
177
+ GPT-2 implemented by the HuggingFace library.
178
+ """
179
+ gpt_config = GPT2Config(
180
+ vocab_size=256, # Unused.
181
+ n_positions=max_mel_seq_len + max_text_seq_len + max_prompt_len,
182
+ n_ctx=max_mel_seq_len + max_text_seq_len + max_prompt_len,
183
+ n_embd=model_dim,
184
+ n_layer=layers,
185
+ n_head=heads,
186
+ gradient_checkpointing=checkpointing,
187
+ use_cache=not checkpointing,
188
+ )
189
+ gpt = GPT2Model(gpt_config)
190
+ # Override the built in positional embeddings
191
+ del gpt.wpe
192
+ gpt.wpe = functools.partial(null_position_embeddings, dim=model_dim)
193
+ # Built-in token embeddings are unused.
194
+ del gpt.wte
195
+
196
+ mel_pos_emb = (
197
+ LearnedPositionEmbeddings(max_mel_seq_len, model_dim)
198
+ if max_mel_seq_len != -1
199
+ else functools.partial(null_position_embeddings, dim=model_dim)
200
+ )
201
+ text_pos_emb = (
202
+ LearnedPositionEmbeddings(max_text_seq_len, model_dim)
203
+ if max_mel_seq_len != -1
204
+ else functools.partial(null_position_embeddings, dim=model_dim)
205
+ )
206
+ return gpt, mel_pos_emb, text_pos_emb, None, None
207
+
208
+
209
+ class GPT(nn.Module):
210
+ def __init__(
211
+ self,
212
+ start_text_token=261,
213
+ stop_text_token=0,
214
+ layers=8,
215
+ model_dim=512,
216
+ heads=8,
217
+ max_text_tokens=120,
218
+ max_mel_tokens=250,
219
+ max_prompt_tokens=70,
220
+ max_conditioning_inputs=1,
221
+ code_stride_len=1024,
222
+ number_text_tokens=256,
223
+ num_audio_tokens=8194,
224
+ start_audio_token=8192,
225
+ stop_audio_token=8193,
226
+ checkpointing=False,
227
+ label_smoothing=0.0,
228
+ ):
229
+ """
230
+ Args:
231
+
232
+ """
233
+ super().__init__()
234
+
235
+ self.label_smoothing = label_smoothing
236
+ self.number_text_tokens = number_text_tokens
237
+ self.start_text_token = start_text_token
238
+ self.stop_text_token = stop_text_token
239
+ self.num_audio_tokens = num_audio_tokens
240
+ self.start_audio_token = start_audio_token
241
+ self.stop_audio_token = stop_audio_token
242
+ self.start_prompt_token = start_audio_token
243
+ self.stop_prompt_token = stop_audio_token
244
+ self.layers = layers
245
+ self.heads = heads
246
+ self.model_dim = model_dim
247
+ self.max_conditioning_inputs = max_conditioning_inputs
248
+ self.max_gen_mel_tokens = max_mel_tokens - self.max_conditioning_inputs - 2
249
+ self.max_mel_tokens = (
250
+ -1
251
+ if max_mel_tokens == -1
252
+ else max_mel_tokens + 2 + self.max_conditioning_inputs
253
+ )
254
+ self.max_text_tokens = -1 if max_text_tokens == -1 else max_text_tokens + 2
255
+ self.max_prompt_tokens = max_prompt_tokens
256
+ self.code_stride_len = code_stride_len
257
+ self.conditioning_dropout = nn.Dropout1d(0.1)
258
+ self.text_embedding = nn.Embedding(self.number_text_tokens, model_dim)
259
+ self.mel_embedding = nn.Embedding(self.num_audio_tokens, model_dim)
260
+
261
+ (
262
+ self.gpt,
263
+ self.mel_pos_embedding,
264
+ self.text_pos_embedding,
265
+ self.mel_layer_pos_embedding,
266
+ self.text_layer_pos_embedding,
267
+ ) = build_hf_gpt_transformer(
268
+ layers,
269
+ model_dim,
270
+ heads,
271
+ self.max_mel_tokens,
272
+ self.max_text_tokens,
273
+ self.max_prompt_tokens,
274
+ checkpointing,
275
+ )
276
+
277
+ self.final_norm = nn.LayerNorm(model_dim)
278
+ self.text_head = nn.Linear(model_dim, self.number_text_tokens)
279
+ self.mel_head = nn.Linear(model_dim, self.num_audio_tokens)
280
+
281
+ # reference_embedding
282
+ self.reference_embedding = nn.Sequential(
283
+ nn.Linear(512, 256),
284
+ nn.Tanh(),
285
+ nn.Linear(256, self.model_dim),
286
+ )
287
+
288
+ def init_gpt_for_inference(self, kv_cache=True, use_deepspeed=False):
289
+ seq_length = (
290
+ self.max_prompt_tokens + self.max_mel_tokens + self.max_text_tokens + 1
291
+ )
292
+ gpt_config = GPT2Config(
293
+ vocab_size=self.max_mel_tokens,
294
+ n_positions=seq_length,
295
+ n_ctx=seq_length,
296
+ n_embd=self.model_dim,
297
+ n_layer=self.layers,
298
+ n_head=self.heads,
299
+ gradient_checkpointing=False,
300
+ use_cache=True,
301
+ )
302
+ self.gpt_inference = GPT2InferenceModel(
303
+ gpt_config,
304
+ self.gpt,
305
+ self.mel_pos_embedding,
306
+ self.mel_embedding,
307
+ self.final_norm,
308
+ self.mel_head,
309
+ kv_cache=kv_cache,
310
+ )
311
+ self.gpt.wte = self.mel_embedding
312
+
313
+ def inference(self, cond_latents, text_inputs, **hf_generate_kwargs):
314
+ self.compute_embeddings(cond_latents, text_inputs)
315
+ return self.generate(cond_latents, text_inputs, **hf_generate_kwargs)
316
+
317
+ def compute_embeddings(
318
+ self,
319
+ cond_latents,
320
+ text_inputs,
321
+ ):
322
+ text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token)
323
+ text_inputs = F.pad(text_inputs, (1, 0), value=self.start_text_token)
324
+ emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
325
+ emb = torch.cat([cond_latents, emb], dim=1)
326
+ self.gpt_inference.store_prefix_emb(emb)
327
+ gpt_inputs = torch.full(
328
+ (
329
+ emb.shape[0],
330
+ emb.shape[1] + 1, # +1 for the start_audio_token
331
+ ),
332
+ fill_value=1,
333
+ dtype=torch.long,
334
+ device=text_inputs.device,
335
+ )
336
+ gpt_inputs[:, -1] = self.start_audio_token
337
+ return gpt_inputs
338
+
339
+ def generate(
340
+ self,
341
+ cond_latents,
342
+ text_inputs,
343
+ **hf_generate_kwargs,
344
+ ):
345
+ gpt_inputs = self.compute_embeddings(cond_latents, text_inputs)
346
+ gen = self.gpt_inference.generate(
347
+ gpt_inputs,
348
+ bos_token_id=self.start_audio_token,
349
+ pad_token_id=self.stop_audio_token,
350
+ eos_token_id=self.stop_audio_token,
351
+ max_length=self.max_gen_mel_tokens + gpt_inputs.shape[-1],
352
+ **hf_generate_kwargs,
353
+ )
354
+ if "return_dict_in_generate" in hf_generate_kwargs:
355
+ return gen.sequences[:, gpt_inputs.shape[1] :], gen
356
+ return gen[:, gpt_inputs.shape[1] :]
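A minimal usage sketch of the model above (not part of the diff): it wires `init_gpt_for_inference` and `inference` together with dummy tensors to illustrate the expected shapes. The sampling arguments are ordinary Hugging Face `generate` kwargs; in practice the conditioning latents come from the speaker encoder rather than `torch.randn`.

```python
import torch
from fireredtts.modules.gpt.gpt import GPT

gpt = GPT()                                # defaults: 8 layers, model_dim=512, 8 heads
gpt.init_gpt_for_inference(kv_cache=True)  # builds self.gpt_inference
gpt.eval()

# Conditioning latents already projected to model_dim: (batch, n_cond, 512).
cond_latents = torch.randn(1, 1, 512)
# Text BPE ids, each < number_text_tokens (256): (batch, n_text).
text_inputs = torch.randint(1, 255, (1, 24), dtype=torch.long)

with torch.no_grad():
    codes = gpt.inference(cond_latents, text_inputs, do_sample=True, top_p=0.85)
# `codes` holds the generated audio token ids, terminated by stop_audio_token (8193).
print(codes.shape)
```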
fireredtts/modules/text_normalizer/__init__.py ADDED
File without changes
fireredtts/modules/text_normalizer/normalize.py ADDED
@@ -0,0 +1,178 @@
1
+ import re
2
+ import regex
3
+ import inflect
4
+ import unicodedata
5
+ from lingua import Language, LanguageDetectorBuilder
6
+ from builtins import str as unicode
7
+
8
+ from tn.chinese.normalizer import Normalizer as ZhNormalizer
9
+ from tn.english.normalizer import Normalizer as EnNormalizer
10
+
11
+ from fireredtts.modules.text_normalizer.regex_common import *
12
+ from fireredtts.modules.text_normalizer.utils import *
13
+
14
+
15
+ def preprocess_text(sentence):
16
+ # preprocessing
17
+ sentence = bytes(sentence, "utf-8").decode("utf-8", "ignore")
18
+ sentence = regex.sub("[\p{Cf}--[\u200d]]", "", sentence, flags=regex.V1)
19
+ sentence = regex.sub("\p{Co}", "", sentence)
20
+ sentence = sentence.replace("\u00a0", " ")
21
+ sentence = sentence.replace("\ufffd", "")
22
+ sentence = regex.sub("\p{Zl}", "\n", sentence)
23
+ sentence = regex.sub("\p{Zp}", "\n", sentence)
24
+
25
+ sentence = unicode(sentence)
26
+ sentence = "".join(
27
+ char
28
+ for char in unicodedata.normalize("NFD", sentence)
29
+ if unicodedata.category(char) != "Mn"
30
+ ) # Strip accents
31
+
32
+ sentence = strip_kaomoji(sentence)
33
+ # full to half with exemption (to be converted after number TN): 。,:
34
+ sentence = f2b(sentence, exemption="。,:")
35
+
36
+ # clean spaces
37
+ sentence = sentence.replace("\n", ",")
38
+ sentence = sentence.replace("\t", ",")
39
+ sentence = sentence.replace("\r", ",")
40
+ sentence = re.sub(r"[。.]{3,}", "…", sentence)
41
+ sentence = re.sub(r"[…⋯]{1,}", "…", sentence)
42
+ sentence = re.sub(r"[ ]+", " ", sentence)
43
+ sentence = sentence.strip()
44
+
45
+ # punctuation reduction
46
+ result = ""
47
+ for idx, char in enumerate(sentence):
48
+ if char in symbol_reduction:
49
+ char = symbol_reduction[char]
50
+
51
+ if char == " ":
52
+ if idx == 0:
53
+ continue
54
+ if is_chinese(sentence[idx + 1]) and (
55
+ is_chinese(sentence[idx - 1]) or sentence[idx - 1] in '") '
56
+ ):
57
+ result += ","
58
+ else:
59
+ result += " "
60
+ continue
61
+
62
+ if is_valid_char(char):
63
+ result += char
64
+ result = re.sub(r"[ ]+", " ", result)
65
+ return result
66
+
67
+
68
+ def rettt(sentence):
69
+ # handle abbreviations for all languages
70
+ sentence = sentence.replace("&nd", "and")
71
+ sentence = sentence.replace("Jan.", "january")
72
+ sentence = sentence.replace("Feb.", "febrary")
73
+ sentence = sentence.replace("Mar.", "march")
74
+ sentence = sentence.replace("Apr.", "april")
75
+ sentence = sentence.replace("May.", "may")
76
+ sentence = sentence.replace("Jun.", "june")
77
+ sentence = sentence.replace("Jul.", "july")
78
+ sentence = sentence.replace("Aug.", "august")
79
+ sentence = sentence.replace("Sept.", "september")
80
+ sentence = sentence.replace("Sep.", "september")
81
+ sentence = sentence.replace("Oct.", "october")
82
+ sentence = sentence.replace("Nov.", "november")
83
+ sentence = sentence.replace("Dec.", "december")
84
+ sentence = sentence.replace("Mon.", "monday")
85
+ sentence = sentence.replace("Tues.", "tuesday")
86
+ sentence = sentence.replace("Wed.", "wednesday")
87
+ sentence = sentence.replace("Thur.", "thursday")
88
+ sentence = sentence.replace("Fri.", "friday")
89
+ sentence = sentence.replace("Sat.", "saturday")
90
+ if sentence != "Sun.":
91
+ sentence = sentence.replace("Sun.", "sunday")
92
+ sentence = re.sub(r" St\. ([A-Z])", r" saint \1", sentence)
93
+ sentence = re.sub(r" St\.", " street", sentence)
94
+ sentence = re.sub(r" Rd\.", " road", sentence)
95
+ sentence = re.sub(r"[Aa]\.[Mm]\.", "A_M", sentence)
96
+ sentence = re.sub(r"[Pp]\.[Mm]\.", "P_M", sentence)
97
+ sentence = re.sub(r"[Bb]\.[Cc]\.", "B_C", sentence)
98
+ sentence = re.sub(r"[Ad]\.[Dd]\.", "A_D", sentence)
99
+ sentence = sentence.replace("Mr.", "mister")
100
+ sentence = sentence.replace("Ms.", "miss")
101
+ sentence = sentence.replace("Mrs.", "misses")
102
+ sentence = sentence.replace("Ph.D", "P_H_D")
103
+ sentence = sentence.replace("i.e.", "that is")
104
+ sentence = sentence.replace("e.g.", "for example")
105
+ sentence = sentence.replace("btw.", "by the way")
106
+ sentence = sentence.replace("btw", "by the way")
107
+ sentence = sentence.replace("b.t.w.", "by the way")
108
+ sentence = sentence.replace("@", " at ")
109
+ return sentence
110
+
111
+
112
+ class TextNormalizer:
113
+ def __init__(self):
114
+ self.language_detector = LanguageDetectorBuilder.from_languages(
115
+ Language.ENGLISH, Language.CHINESE
116
+ ).build()
117
+ self.zh_normalizer = ZhNormalizer()
118
+ self.en_normalizer = EnNormalizer()
119
+ self.inflect_parser = inflect.engine()
120
+ self.lang2token = {Language.ENGLISH: "en", Language.CHINESE: "zh"}
121
+
122
+ def tn(self, text):
123
+ text = preprocess_text(text)
124
+ text = rettt(text) # regex replacements
125
+ # for non chinese languages
126
+ language = self.language_detector.detect_language_of(text)
127
+ # enforce chinese if text contains any chinese character
128
+ if contains_chinese(text):
129
+ language = Language.CHINESE
130
+ text_lang = self.lang2token.get(language, "zh")
131
+
132
+ if is_upper_eng_and_digit(text):
133
+ language = Language.CHINESE
134
+
135
+ if language == Language.CHINESE:
136
+ text = self.zh_normalizer.normalize(text)
137
+ text = text.replace("\n", "")
138
+ text = re.sub(r"[,,]+$", "。", text)
139
+ else:
140
+ text = re.sub(r"[^ 0-9A-Za-z\[\]'.,:?!_\-]", "", text)
141
+ text = self.en_normalizer.normalize(text)
142
+ # fallback number normalization
143
+ pieces = re.split(r"(\d+)", text)
144
+ text = "".join(
145
+ [
146
+ self.inflect_parser.number_to_words(p) if p.isnumeric() else p
147
+ for p in pieces
148
+ if len(p) > 0
149
+ ]
150
+ )
151
+
152
+ # cleanup
153
+ text = text.replace("_", " ")
154
+ text = re.sub(r"[ ]+", " ", text)
155
+
156
+ # spell out capitalized words (2-4 uppercase letters) letter by letter
157
+ pieces = re.split(r"([A-Z]{2,4}|[ ])", text)
158
+ for idx, p in enumerate(pieces):
159
+ if re.match("[A-Z]{2,4}", p):
160
+ pieces[idx] = " ".join(p)
161
+ text = " ".join([p for p in pieces if p != " "])
162
+
163
+ # post TN full to half
164
+ text = text.replace("。", ".")
165
+ text = text.replace(",", ",")
166
+ text = text.replace(":", ":")
167
+
168
+ # model limitations
169
+ text = text.lower().strip()
170
+ text = text.replace('"', "")
171
+ text = text.replace("·", " ")
172
+ text = re.sub("[…~、!,?:;!?:;]+", ",", text)
173
+ text = re.sub("[,]+", ",", text)
174
+ text = re.sub(r"[,. ]+$", ".", text)
175
+ if len(text) > 0 and text[-1] != ".":
176
+ text = text + "."
177
+
178
+ return text, text_lang
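For orientation, a short hypothetical call of the normalizer defined above (the sample sentence is arbitrary): `tn` returns the cleaned, lower-cased text, which always ends with a period, together with the detected language tag, "zh" or "en".

```python
from fireredtts.modules.text_normalizer.normalize import TextNormalizer

normalizer = TextNormalizer()
text, lang = normalizer.tn("FireRedTTS发布于2024年9月!")
print(text, lang)   # normalized Chinese sentence ending with ".", lang == "zh"
```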
fireredtts/modules/text_normalizer/regex_common.py ADDED
@@ -0,0 +1,23 @@
1
+ import re
2
+
3
+ kaomoji_regex = re.compile(
4
+ r"[oヽwΣ┗╰O︿Ψ凸]?[(|≡*(].{0,4}[Д✿_▽→≧﹏`∩⊙∇☆≡๑〃′エ≦▔@﹁εヘ•́ω益‿≖ฺ皿•̀艹 ̄△|゚].{0,5}[|≡*))][┛ブ凸cdd︴oOΨ︿w╯ノ]?"
5
+ )
6
+ chinese_regex = re.compile(r"[\u4e00-\u9fa5]")
7
+ digit_regex = re.compile(r"(\d+)(\.\d+)?", re.UNICODE)
8
+
9
+ chinese_char_regex = re.compile(r"^[\u4e00-\u9fa5]$", re.UNICODE)
10
+ eng_and_digit_char_regex = re.compile(r"^[0-9.,A-Za-z]+$", re.UNICODE)
11
+ upper_eng_and_digit_regex = re.compile(r"^[ 0-9A-Z\"'.,:?!\-]+$", re.UNICODE)
12
+ valid_char_regex = re.compile(
13
+ r"[\t\r\n ]|"
14
+ r"[\u4e00-\u9fa5]|"
15
+ r"\u0080|[\u20a0-\u20bf]|\u00a2|\u00a3|\u00a5|\uffe0|\uffe1|\uffe5|\uffe6|"
16
+ r"\u3000|\u3002|\u00b7|\u2014|\u2019|\u2026|\uff01|\uff1f|\uff0e|\uff1a|\uff1b|\uff0b|\uff0c|\uff0d|\uff0f|[\ufe10-\ufe16]|[\ufe50-\ufe51]|[\ufe55-\ufe57]|\ufe6a|"
17
+ r"[\u0030-\u0039]|"
18
+ r"[\u0391-\u03c9]|"
19
+ r"[\u00b0-\u00b3]|[\u2015-\u2018]|[\u3000-\u303f]|"
20
+ r"[\u0022-\u002f\u003a-\u003e\u0040\u005b-\u0060\u007b-\u007e]|"
21
+ r"[\uff21-\uff3a]|[\uff41-\uff5a]|[\u0041-\u005a]|[\u0061-\u007a]",
22
+ re.UNICODE,
23
+ )
fireredtts/modules/text_normalizer/utils.py ADDED
@@ -0,0 +1,121 @@
1
+ from fireredtts.modules.text_normalizer.regex_common import *
2
+
3
+
4
+ def contains_chinese(text):
5
+ return bool(chinese_regex.search(text))
6
+
7
+
8
+ def strip_kaomoji(text):
9
+ return kaomoji_regex.sub(" ", text)
10
+
11
+
12
+ def is_chinese(char):
13
+ return chinese_char_regex.match(char)
14
+
15
+
16
+ def is_eng_and_digit(char):
17
+ return eng_and_digit_char_regex.match(char)
18
+
19
+
20
+ def is_upper_eng_and_digit(text):
21
+ return upper_eng_and_digit_regex.match(text)
22
+
23
+
24
+ def is_valid_char(char):
25
+ return valid_char_regex.match(char)
26
+
27
+
28
+ def is_digit(text):
29
+ return digit_regex.match(text)
30
+
31
+
32
+ def contains_chinese(text):
33
+ return bool(chinese_regex.search(text))
34
+
35
+
36
+ def f2b(ustr, exemption="。,:"):
37
+ half = []
38
+ for u in ustr:
39
+ num = ord(u)
40
+ if num == 0x3000:
41
+ half.append(" ")
42
+ elif u in exemption: # exemption
43
+ half.append(u)
44
+ elif 0xFF01 <= num <= 0xFF5E:
45
+ num -= 0xFEE0
46
+ half.append(chr(num))
47
+ else:
48
+ half.append(u)
49
+ return "".join(half)
50
+
51
+
52
+ symbol_reduction = {
53
+ "「": '"',
54
+ "」": '"',
55
+ "`": '"',
56
+ "〝": '"',
57
+ "〞": '"',
58
+ "‟": '"',
59
+ "„": '"',
60
+ "{": "(",
61
+ "}": ")",
62
+ "【": "(",
63
+ "】": ")",
64
+ "〖": "(",
65
+ "〗": ")",
66
+ "〔": "(",
67
+ "〕": ")",
68
+ "〘": "(",
69
+ "〙": ")",
70
+ "《": "(",
71
+ "》": ")",
72
+ "⦅": "(",
73
+ "⦆": ")",
74
+ "〚": "(",
75
+ "〛": ")",
76
+ "『": '"',
77
+ "』": '"',
78
+ "「": '"',
79
+ "」": '"',
80
+ "{": "(",
81
+ "}": ")",
82
+ "〈": "(",
83
+ "〉": ")",
84
+ "•": "·",
85
+ "‧": "·",
86
+ "〰": "…",
87
+ "﹏": "…",
88
+ "〜": "~",
89
+ "~": "~",
90
+ "+": "+",
91
+ "、": "、",
92
+ "。": "。",
93
+ "︐": ",",
94
+ "﹐": ",",
95
+ "︑": "、",
96
+ "﹑": "、",
97
+ "︒": "。",
98
+ "︓": ":",
99
+ "﹕": ":",
100
+ "︔": ";",
101
+ "﹔": ";",
102
+ "︕": "!",
103
+ "﹗": "!",
104
+ "︖": "?",
105
+ "﹖": "?",
106
+ "﹙": "(",
107
+ "﹚": ")",
108
+ "﹪": "%",
109
+ "﹠": "&",
110
+ ">": ">",
111
+ "|": "、",
112
+ "=": "=",
113
+ "‐": "-",
114
+ "‑": "-",
115
+ "‒": "-",
116
+ "–": "-",
117
+ "—": "-",
118
+ "―": "-",
119
+ "%": "%",
120
+ "μ": "u",
121
+ }
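A tiny illustrative check of the `f2b` helper above: full-width ASCII is mapped to half-width, while characters in the default exemption string (。,:) are left untouched so they can be handled after number normalization.

```python
from fireredtts.modules.text_normalizer.utils import f2b

print(f2b("Hello,World!123"))  # -> "Hello,World!123" (the full-width comma is exempted)
```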
fireredtts/modules/tokenizer/__init__.py ADDED
File without changes
fireredtts/modules/tokenizer/assets/multilingual.tiktoken ADDED
The diff for this file is too large to render. See raw diff
 
fireredtts/modules/tokenizer/tokenizer.py ADDED
@@ -0,0 +1,46 @@
1
+ import os
2
+ import torch
3
+ from fireredtts.modules.tokenizer.whisper_tokenizer import get_tokenizer
4
+ from fireredtts.modules.text_normalizer.normalize import TextNormalizer
5
+
6
+
7
+ DEFAULT_VOCAB_FILE = os.path.join(
8
+ os.path.dirname(os.path.realpath(__file__)), "../data/tokenizer.json"
9
+ )
10
+
11
+
12
+ class VoiceBpeTokenizer:
13
+ def __init__(self):
14
+ self.tokenizer = get_tokenizer(multilingual=True)
15
+ self.tn_engine = TextNormalizer()
16
+
17
+ def redtts_text_cleaner(self, text):
18
+ text = text.strip()
19
+ text, text_lang = self.tn_engine.tn(text)
20
+ # print("---text after tn:", text)
21
+ return text, text_lang
22
+
23
+ def encode(self, text, lang="auto"):
24
+ text, text_lang = self.redtts_text_cleaner(text=text)
25
+ if lang == "auto":
26
+ lang = text_lang
27
+ text = f"[{lang}]{text}"
28
+ return self.tokenizer.encode(text)
29
+
30
+ def decode(self, seq):
31
+ if isinstance(seq, torch.Tensor):
32
+ seq = seq.cpu().numpy()
33
+ text = self.tokenizer.decode(seq)
34
+ return text
35
+
36
+ def __len__(self):
37
+ return self.tokenizer.get_vocab_size()
38
+
39
+ def get_number_tokens(self):
40
+ return self.tokenizer.get_vocab_size()
41
+
42
+
43
+ if __name__ == "__main__":
44
+ tok = VoiceBpeTokenizer()
45
+ codes = tok.encode("我、真是hello USA啊?谢谢你world!")
46
+ print([tok.decode([c]) for c in codes])
fireredtts/modules/tokenizer/whisper_tokenizer.py ADDED
@@ -0,0 +1,456 @@
1
+ # adapted from https://github.com/openai/whisper/blob/main/whisper/tokenizer.py
2
+ # Copyright (c) 2022 OpenAI
3
+
4
+ # MIT License for this file
5
+
6
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
7
+ # of this software and associated documentation files (the "Software"), to deal
8
+ # in the Software without restriction, including without limitation the rights
9
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10
+ # copies of the Software, and to permit persons to whom the Software is
11
+ # furnished to do so, subject to the following conditions:
12
+
13
+ # The above copyright notice and this permission notice shall be included in all
14
+ # copies or substantial portions of the Software.
15
+
16
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22
+ # SOFTWARE.
23
+
24
+ import base64
25
+ import os
26
+ import string
27
+ from dataclasses import dataclass, field
28
+ from functools import cached_property, lru_cache
29
+ from typing import Dict, List, Optional, Tuple
30
+
31
+ import tiktoken
32
+
33
+ LANGUAGES = {
34
+ "en": "english",
35
+ "zh": "chinese",
36
+ "de": "german",
37
+ "es": "spanish",
38
+ "ru": "russian",
39
+ "ko": "korean",
40
+ "fr": "french",
41
+ "ja": "japanese",
42
+ "pt": "portuguese",
43
+ "tr": "turkish",
44
+ "pl": "polish",
45
+ "ca": "catalan",
46
+ "nl": "dutch",
47
+ "ar": "arabic",
48
+ "sv": "swedish",
49
+ "it": "italian",
50
+ "id": "indonesian",
51
+ "hi": "hindi",
52
+ "fi": "finnish",
53
+ "vi": "vietnamese",
54
+ "he": "hebrew",
55
+ "uk": "ukrainian",
56
+ "el": "greek",
57
+ "ms": "malay",
58
+ "cs": "czech",
59
+ "ro": "romanian",
60
+ "da": "danish",
61
+ "hu": "hungarian",
62
+ "ta": "tamil",
63
+ "no": "norwegian",
64
+ "th": "thai",
65
+ "ur": "urdu",
66
+ "hr": "croatian",
67
+ "bg": "bulgarian",
68
+ "lt": "lithuanian",
69
+ "la": "latin",
70
+ "mi": "maori",
71
+ "ml": "malayalam",
72
+ "cy": "welsh",
73
+ "sk": "slovak",
74
+ "te": "telugu",
75
+ "fa": "persian",
76
+ "lv": "latvian",
77
+ "bn": "bengali",
78
+ "sr": "serbian",
79
+ "az": "azerbaijani",
80
+ "sl": "slovenian",
81
+ "kn": "kannada",
82
+ "et": "estonian",
83
+ "mk": "macedonian",
84
+ "br": "breton",
85
+ "eu": "basque",
86
+ "is": "icelandic",
87
+ "hy": "armenian",
88
+ "ne": "nepali",
89
+ "mn": "mongolian",
90
+ "bs": "bosnian",
91
+ "kk": "kazakh",
92
+ "sq": "albanian",
93
+ "sw": "swahili",
94
+ "gl": "galician",
95
+ "mr": "marathi",
96
+ "pa": "punjabi",
97
+ "si": "sinhala",
98
+ "km": "khmer",
99
+ "sn": "shona",
100
+ "yo": "yoruba",
101
+ "so": "somali",
102
+ "af": "afrikaans",
103
+ "oc": "occitan",
104
+ "ka": "georgian",
105
+ "be": "belarusian",
106
+ "tg": "tajik",
107
+ "sd": "sindhi",
108
+ "gu": "gujarati",
109
+ "am": "amharic",
110
+ "yi": "yiddish",
111
+ "lo": "lao",
112
+ "uz": "uzbek",
113
+ "fo": "faroese",
114
+ "ht": "haitian creole",
115
+ "ps": "pashto",
116
+ "tk": "turkmen",
117
+ "nn": "nynorsk",
118
+ "mt": "maltese",
119
+ "sa": "sanskrit",
120
+ "lb": "luxembourgish",
121
+ "my": "myanmar",
122
+ "bo": "tibetan",
123
+ "tl": "tagalog",
124
+ "mg": "malagasy",
125
+ "as": "assamese",
126
+ "tt": "tatar",
127
+ "haw": "hawaiian",
128
+ "ln": "lingala",
129
+ "ha": "hausa",
130
+ "ba": "bashkir",
131
+ "jw": "javanese",
132
+ "su": "sundanese",
133
+ "yue": "cantonese",
134
+ }
135
+
136
+ # language code lookup by name, with a few language aliases
137
+ TO_LANGUAGE_CODE = {
138
+ **{language: code for code, language in LANGUAGES.items()},
139
+ "burmese": "my",
140
+ "valencian": "ca",
141
+ "flemish": "nl",
142
+ "haitian": "ht",
143
+ "letzeburgesch": "lb",
144
+ "pushto": "ps",
145
+ "panjabi": "pa",
146
+ "moldavian": "ro",
147
+ "moldovan": "ro",
148
+ "sinhalese": "si",
149
+ "castilian": "es",
150
+ "mandarin": "zh",
151
+ }
152
+
153
+
154
+ @dataclass
155
+ class Tokenizer:
156
+ """A thin wrapper around `tiktoken` providing quick access to special tokens"""
157
+
158
+ encoding: tiktoken.Encoding
159
+ num_languages: int
160
+ language: Optional[str] = None
161
+ task: Optional[str] = None
162
+ sot_sequence: Tuple[int] = ()
163
+ special_tokens: Dict[str, int] = field(default_factory=dict)
164
+
165
+ def __post_init__(self):
166
+ for special in self.encoding.special_tokens_set:
167
+ special_token = self.encoding.encode_single_token(special)
168
+ self.special_tokens[special] = special_token
169
+
170
+ sot: int = self.special_tokens["[startoftranscript]"]
171
+ translate: int = self.special_tokens["[translate]"]
172
+ transcribe: int = self.special_tokens["[transcribe]"]
173
+
174
+ langs = tuple(LANGUAGES.keys())[: self.num_languages]
175
+ sot_sequence = [sot]
176
+ if self.language is not None:
177
+ sot_sequence.append(sot + 1 + langs.index(self.language))
178
+ if self.task is not None:
179
+ task_token: int = transcribe if self.task == "transcribe" else translate
180
+ sot_sequence.append(task_token)
181
+
182
+ self.sot_sequence = tuple(sot_sequence)
183
+
184
+ def get_vocab_size(self):
185
+ return self.encoding.n_vocab
186
+
187
+ def encode(self, text):
188
+ return self.encoding.encode(text, allowed_special="all")
189
+
190
+ def decode(self, token_ids: List[int], **kwargs) -> str:
191
+ return self.encoding.decode(token_ids, **kwargs)
192
+
193
+ @cached_property
194
+ def eot(self) -> int:
195
+ return self.encoding.eot_token
196
+
197
+ @cached_property
198
+ def stop(self) -> int:
199
+ return self.special_tokens["[STOP]"]
200
+
201
+ @cached_property
202
+ def start(self) -> int:
203
+ return self.special_tokens["[START]"]
204
+
205
+ @cached_property
206
+ def transcribe(self) -> int:
207
+ return self.special_tokens["[transcribe]"]
208
+
209
+ @cached_property
210
+ def translate(self) -> int:
211
+ return self.special_tokens["[translate]"]
212
+
213
+ @cached_property
214
+ def sot(self) -> int:
215
+ return self.special_tokens["[startoftranscript]"]
216
+
217
+ @cached_property
218
+ def sot_lm(self) -> int:
219
+ return self.special_tokens["[startoflm]"]
220
+
221
+ @cached_property
222
+ def sot_prev(self) -> int:
223
+ return self.special_tokens["[startofprev]"]
224
+
225
+ @cached_property
226
+ def no_speech(self) -> int:
227
+ return self.special_tokens["[nospeech]"]
228
+
229
+ @cached_property
230
+ def language_token(self) -> int:
231
+ """Returns the token id corresponding to the value of the `language` field"""
232
+ if self.language is None:
233
+ raise ValueError("This tokenizer does not have language token configured")
234
+
235
+ return self.to_language_token(self.language)
236
+
237
+ def to_language_token(self, language):
238
+ if token := self.special_tokens.get(f"[{language}]", None):
239
+ return token
240
+
241
+ raise KeyError(f"Language {language} not found in tokenizer.")
242
+
243
+ @cached_property
244
+ def all_language_tokens(self) -> Tuple[int]:
245
+ result = []
246
+ for token, token_id in self.special_tokens.items():
247
+ if token.strip("[]") in LANGUAGES:
248
+ result.append(token_id)
249
+ return tuple(result)[: self.num_languages]
250
+
251
+ @cached_property
252
+ def all_language_codes(self) -> Tuple[str]:
253
+ return tuple(self.decode([_l]).strip("[]") for _l in self.all_language_tokens)
254
+
255
+ @cached_property
256
+ def non_speech_tokens(self) -> Tuple[int]:
257
+ """
258
+ Returns the list of tokens to suppress in order to avoid any speaker tags or non-speech
259
+ annotations, to prevent sampling texts that are not actually spoken in the audio, e.g.
260
+
261
+ - ♪♪♪
262
+ - ( SPEAKING FOREIGN LANGUAGE )
263
+ - [DAVID] Hey there,
264
+
265
+ keeping basic punctuations like commas, periods, question marks, exclamation points, etc.
266
+ """
267
+ symbols = list('"#()*+/:;<=>@[\\]^_`{|}~「」『』')
268
+ symbols += (
269
+ "<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split()
270
+ )
271
+
272
+ # symbols that may be a single token or multiple tokens depending on the tokenizer.
273
+ # In case they're multiple tokens, suppress the first token, which is safe because:
274
+ # These are between U+2640 and U+267F miscellaneous symbols that are okay to suppress
275
+ # in generations, and in the 3-byte UTF-8 representation they share the first two bytes.
276
+ miscellaneous = set("♩♪♫♬♭♮♯")
277
+ assert all(0x2640 <= ord(c) <= 0x267F for c in miscellaneous)
278
+
279
+ # allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
280
+ result = {self.encoding.encode(" -")[0], self.encoding.encode(" '")[0]}
281
+ for symbol in symbols + list(miscellaneous):
282
+ for tokens in [
283
+ self.encoding.encode(symbol),
284
+ self.encoding.encode(" " + symbol),
285
+ ]:
286
+ if len(tokens) == 1 or symbol in miscellaneous:
287
+ result.add(tokens[0])
288
+
289
+ return tuple(sorted(result))
290
+
291
+ def split_to_word_tokens(self, tokens: List[int]):
292
+ if self.language in {"zh", "ja", "th", "lo", "my", "yue"}:
293
+ # These languages don't typically use spaces, so it is difficult to split words
294
+ # without morpheme analysis. Here, we instead split words at any
295
+ # position where the tokens are decoded as valid unicode points
296
+ return self.split_tokens_on_unicode(tokens)
297
+
298
+ return self.split_tokens_on_spaces(tokens)
299
+
300
+ def split_tokens_on_unicode(self, tokens: List[int]):
301
+ decoded_full = self.decode(tokens)
302
+ replacement_char = "\ufffd"
303
+
304
+ words = []
305
+ word_tokens = []
306
+ current_tokens = []
307
+ unicode_offset = 0
308
+
309
+ for token in tokens:
310
+ current_tokens.append(token)
311
+ decoded = self.decode(current_tokens)
312
+
313
+ if (
314
+ replacement_char not in decoded
315
+ or decoded_full[unicode_offset + decoded.index(replacement_char)]
316
+ == replacement_char
317
+ ):
318
+ words.append(decoded)
319
+ word_tokens.append(current_tokens)
320
+ current_tokens = []
321
+ unicode_offset += len(decoded)
322
+
323
+ return words, word_tokens
324
+
325
+ def split_tokens_on_spaces(self, tokens: List[int]):
326
+ subwords, subword_tokens_list = self.split_tokens_on_unicode(tokens)
327
+ words = []
328
+ word_tokens = []
329
+
330
+ for subword, subword_tokens in zip(subwords, subword_tokens_list):
331
+ special = subword_tokens[0] >= self.eot
332
+ with_space = subword.startswith(" ")
333
+ punctuation = subword.strip() in string.punctuation
334
+ if special or with_space or punctuation or len(words) == 0:
335
+ words.append(subword)
336
+ word_tokens.append(subword_tokens)
337
+ else:
338
+ words[-1] = words[-1] + subword
339
+ word_tokens[-1].extend(subword_tokens)
340
+
341
+ return words, word_tokens
342
+
343
+
344
+ @lru_cache(maxsize=None)
345
+ def get_encoding(name: str = "multilingual", num_languages: int = 100):
346
+ vocab_path = os.path.join(os.path.dirname(__file__), "assets", f"{name}.tiktoken")
347
+ ranks = {
348
+ base64.b64decode(token): int(rank)
349
+ for token, rank in (line.split() for line in open(vocab_path) if line)
350
+ }
351
+ n_vocab = len(ranks)
352
+ special_tokens = {}
353
+
354
+ specials = [
355
+ "[STOP]",
356
+ "[UNK]",
357
+ "[SPACE]",
358
+ "[START]",
359
+ "[nospk]",
360
+ "[spkemb]",
361
+ "[emotionemb]",
362
+ "[contextemb]",
363
+ "[sbreak]",
364
+ "[pbreak]",
365
+ "[uvbreak]",
366
+ "[bsing]",
367
+ "[esing]",
368
+ "[sing]",
369
+ "[hum]",
370
+ "[laugh]",
371
+ "[break]",
372
+ "[breath]",
373
+ "[oralsii]",
374
+ "[oralze]",
375
+ "[prolong]",
376
+ "[stress]",
377
+ "[bstrong]",
378
+ "[estrong]",
379
+ "[hiccup]",
380
+ "[inhale]",
381
+ "[exhale]",
382
+ "[emounknown]",
383
+ "[happy]",
384
+ "[neutral]",
385
+ "[sad]",
386
+ "[surprise]",
387
+ "[angry]",
388
+ "[disgust]",
389
+ "[emo]",
390
+ "[laugha]",
391
+ "[laughb]",
392
+ "[laughc]",
393
+ "[orala]",
394
+ "[oralb]",
395
+ "[oralc]",
396
+ "[orald]",
397
+ "[orale]",
398
+ "[breaka]",
399
+ "[breakb]",
400
+ "[breakc]",
401
+ "[breakd]",
402
+ "[breake]",
403
+ "[breakf]",
404
+ "[endoftext]",
405
+ "[startoftranscript]",
406
+ *[f"[{lang}]" for lang in list(LANGUAGES.keys())[:num_languages]],
407
+ "[translate]",
408
+ "[transcribe]",
409
+ "[startoflm]",
410
+ "[startofprev]",
411
+ "[nospeech]",
412
+ ]
413
+
414
+ for token in specials:
415
+ special_tokens[token] = n_vocab
416
+ n_vocab += 1
417
+
418
+ return tiktoken.Encoding(
419
+ name=os.path.basename(vocab_path),
420
+ explicit_n_vocab=n_vocab,
421
+ pat_str=r"""'s|'t|'re|'ve|'m|'ll|'d|\[[A-Z]+\]|\[[a-z]+\]|[\x{4e00}-\x{9df5}]| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
422
+ mergeable_ranks=ranks,
423
+ special_tokens=special_tokens,
424
+ )
425
+
426
+
427
+ @lru_cache(maxsize=None)
428
+ def get_tokenizer(
429
+ multilingual: bool,
430
+ *,
431
+ num_languages: int = 100,
432
+ language: Optional[str] = None,
433
+ task: Optional[str] = None, # Literal["transcribe", "translate", None]
434
+ ) -> Tokenizer:
435
+ if language is not None:
436
+ language = language.lower()
437
+ if language not in LANGUAGES:
438
+ if language in TO_LANGUAGE_CODE:
439
+ language = TO_LANGUAGE_CODE[language]
440
+ else:
441
+ raise ValueError(f"Unsupported language: {language}")
442
+
443
+ if multilingual:
444
+ encoding_name = "multilingual"
445
+ language = language or "en"
446
+ task = task or "transcribe"
447
+ else:
448
+ encoding_name = "gpt2"
449
+ language = None
450
+ task = None
451
+
452
+ encoding = get_encoding(name=encoding_name, num_languages=num_languages)
453
+
454
+ return Tokenizer(
455
+ encoding=encoding, num_languages=num_languages, language=language, task=task
456
+ )
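The snippet below (illustrative, not part of the diff) shows how the tiktoken-based tokenizer above is obtained and how the bracketed language specials are used; it assumes the bundled `assets/multilingual.tiktoken` vocabulary is present.

```python
from fireredtts.modules.tokenizer.whisper_tokenizer import get_tokenizer

tok = get_tokenizer(multilingual=True)
ids = tok.encode("[zh]你好,世界.")   # "[zh]" is one of the registered special tokens
print(len(ids), tok.get_vocab_size())
print(tok.decode(ids))
```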
fireredtts/utils/__init__.py ADDED
File without changes
fireredtts/utils/utils.py ADDED
@@ -0,0 +1,37 @@
1
+ import os
2
+ from pathlib import Path
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ import torchaudio
7
+
8
+
9
+ def load_audio(audiopath, sampling_rate):
10
+ """_summary_
11
+
12
+ Args:
13
+ audiopath (_type_): audio_path
14
+ sampling_rate (_type_): sampling_rate
15
+
16
+ Returns:
17
+ _type_: _description_
18
+ """
19
+ audio, lsr = torchaudio.load(audiopath)
20
+
21
+ # stereo to mono if needed
22
+ if audio.size(0) != 1:
23
+ audio = torch.mean(audio, dim=0, keepdim=True)
24
+
25
+ # resample
26
+ audio_resampled = torchaudio.functional.resample(audio, lsr, sampling_rate)
27
+ if torch.any(audio > 10) or not torch.any(audio < 0):
28
+ print(f"Error with {audiopath}. Max={audio.max()} min={audio.min()}")
29
+
30
+ if torch.any(audio_resampled > 10) or not torch.any(audio_resampled < 0):
31
+ print(
32
+ f"Error with {audiopath}. Max={audio_resampled.max()} min={audio_resampled.min()}"
33
+ )
34
+ # clip audio invalid values
35
+ audio.clip_(-1, 1)
36
+ audio_resampled.clip_(-1, 1)
37
+ return audio, lsr, audio_resampled
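A hypothetical call of `load_audio` (the path is a placeholder): it returns the original mono waveform, its original sample rate, and the resampled waveform, all clipped to [-1, 1].

```python
from fireredtts.utils.utils import load_audio

audio, orig_sr, audio_16k = load_audio("examples/prompt.wav", sampling_rate=16000)
print(audio.shape, orig_sr, audio_16k.shape)   # tensors shaped (1, num_samples)
```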
pretrained_models/README.md ADDED
@@ -0,0 +1,3 @@
1
+ ## Pretrained Models
2
+
3
+ Download the required model files and place them in the `pretrained_models` folder.
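As a convenience, one way to fetch the checkpoints is with `huggingface_hub`; the repo id below is a placeholder, so substitute the actual model repository.

```python
from huggingface_hub import snapshot_download

snapshot_download(repo_id="<org>/<FireRedTTS-repo>", local_dir="pretrained_models")
```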