Spaces: Running on Zero

Commit: update

Browse files
- fireredtts/fireredtts.py +163 -0
- fireredtts/modules/__init__.py +42 -0
- fireredtts/modules/bigvgan/__init__.py +6 -0
- fireredtts/modules/bigvgan/activations.py +126 -0
- fireredtts/modules/bigvgan/alias_free_cuda/__init__.py +0 -0
- fireredtts/modules/bigvgan/alias_free_cuda/activation1d.py +75 -0
- fireredtts/modules/bigvgan/alias_free_cuda/anti_alias_activation.cpp +48 -0
- fireredtts/modules/bigvgan/alias_free_cuda/anti_alias_activation_cuda.cu +314 -0
- fireredtts/modules/bigvgan/alias_free_cuda/compat.h +31 -0
- fireredtts/modules/bigvgan/alias_free_cuda/load.py +85 -0
- fireredtts/modules/bigvgan/alias_free_cuda/test_activation.py +64 -0
- fireredtts/modules/bigvgan/alias_free_cuda/test_activation_snake_beta.py +64 -0
- fireredtts/modules/bigvgan/alias_free_cuda/type_shim.h +97 -0
- fireredtts/modules/bigvgan/alias_free_torch/__init__.py +5 -0
- fireredtts/modules/bigvgan/alias_free_torch/act.py +29 -0
- fireredtts/modules/bigvgan/alias_free_torch/filter.py +98 -0
- fireredtts/modules/bigvgan/alias_free_torch/resample.py +57 -0
- fireredtts/modules/bigvgan/bigvgan.py +399 -0
- fireredtts/modules/codec/speaker.py +1052 -0
- fireredtts/modules/flow/__init__.py +24 -0
- fireredtts/modules/flow/codebook.npy +3 -0
- fireredtts/modules/flow/codec_embedding.py +30 -0
- fireredtts/modules/flow/conformer.py +730 -0
- fireredtts/modules/flow/decoder.py +396 -0
- fireredtts/modules/flow/flow_model.py +89 -0
- fireredtts/modules/flow/mel_encoder.py +170 -0
- fireredtts/modules/flow/mel_spectrogram.py +132 -0
- fireredtts/modules/flow/transformer.py +249 -0
- fireredtts/modules/flow/utils.py +30 -0
- fireredtts/modules/gpt/__init__.py +0 -0
- fireredtts/modules/gpt/gpt.py +356 -0
- fireredtts/modules/text_normalizer/__init__.py +0 -0
- fireredtts/modules/text_normalizer/normalize.py +178 -0
- fireredtts/modules/text_normalizer/regex_common.py +23 -0
- fireredtts/modules/text_normalizer/utils.py +121 -0
- fireredtts/modules/tokenizer/__init__.py +0 -0
- fireredtts/modules/tokenizer/assets/multilingual.tiktoken +0 -0
- fireredtts/modules/tokenizer/tokenizer.py +46 -0
- fireredtts/modules/tokenizer/whisper_tokenizer.py +456 -0
- fireredtts/utils/__init__.py +0 -0
- fireredtts/utils/utils.py +37 -0
- pretrained_models/README.md +3 -0
fireredtts/fireredtts.py
ADDED @@ -0,0 +1,163 @@
```python
import os
import json
import torch
from fireredtts.modules.gpt.gpt import GPT
from fireredtts.modules import Token2Wav, MelSpectrogramExtractor
from fireredtts.modules.tokenizer.tokenizer import VoiceBpeTokenizer
from fireredtts.modules.codec.speaker import SpeakerEmbedddingExtractor
from fireredtts.utils.utils import load_audio

import time


class FireRedTTS:
    def __init__(self, config_path, pretrained_path, device="cuda"):
        self.device = device
        self.config = json.load(open(config_path))
        self.gpt_path = os.path.join(pretrained_path, "fireredtts_gpt.pt")
        self.token2wav_path = os.path.join(pretrained_path, "fireredtts_token2wav.pt")
        self.speaker_extractor_path = os.path.join(
            pretrained_path, "fireredtts_speaker.bin"
        )
        assert os.path.exists(self.token2wav_path)
        assert os.path.exists(self.gpt_path)
        assert os.path.exists(self.speaker_extractor_path)

        # tokenizer
        self.text_tokenizer = VoiceBpeTokenizer()

        # speaker extractor
        self.speaker_extractor = SpeakerEmbedddingExtractor(
            ckpt_path=self.speaker_extractor_path, device=device
        )

        # load gpt model
        self.gpt = GPT(
            start_text_token=self.config["gpt"]["gpt_start_text_token"],
            stop_text_token=self.config["gpt"]["gpt_stop_text_token"],
            layers=self.config["gpt"]["gpt_layers"],
            model_dim=self.config["gpt"]["gpt_n_model_channels"],
            heads=self.config["gpt"]["gpt_n_heads"],
            max_text_tokens=self.config["gpt"]["gpt_max_text_tokens"],
            max_mel_tokens=self.config["gpt"]["gpt_max_audio_tokens"],
            max_prompt_tokens=self.config["gpt"]["gpt_max_prompt_tokens"],
            code_stride_len=self.config["gpt"]["gpt_code_stride_len"],
            number_text_tokens=self.config["gpt"]["gpt_number_text_tokens"],
            num_audio_tokens=self.config["gpt"]["gpt_num_audio_tokens"],
            start_audio_token=self.config["gpt"]["gpt_start_audio_token"],
            stop_audio_token=self.config["gpt"]["gpt_stop_audio_token"],
        )

        sd = torch.load(self.gpt_path, map_location=device)["model"]
        self.gpt.load_state_dict(sd, strict=True)
        self.gpt = self.gpt.to(device=device)
        self.gpt.eval()
        self.gpt.init_gpt_for_inference(kv_cache=True)

        # mel-spectrogram extractor
        self.mel_extractor = MelSpectrogramExtractor()

        # load token2wav model
        self.token2wav = Token2Wav.init_from_config(self.config)
        sd = torch.load(self.token2wav_path, map_location="cpu")
        self.token2wav.load_state_dict(sd, strict=True)
        self.token2wav.generator.remove_weight_norm()
        self.token2wav.eval()
        self.token2wav = self.token2wav.to(device)

    def extract_spk_embeddings(self, prompt_wav):
        _, _, audio_resampled = load_audio(audiopath=prompt_wav, sampling_rate=16000)
        audio_len = torch.tensor(
            data=[audio_resampled.shape[1]], dtype=torch.long, requires_grad=False
        )

        # speaker embeddings [1, 512]
        spk_embeddings = self.speaker_extractor(
            audio_resampled.to(device="cuda")
        ).unsqueeze(0)

        return spk_embeddings

    def do_gpt_inference(self, spk_gpt, text_tokens):
        """Generate audio token codes with the GPT model.

        Args:
            spk_gpt: speaker embedding projected into the GPT conditioning space
            text_tokens: tokenized input text
        """
        with torch.no_grad():
            gpt_codes = self.gpt.generate(
                cond_latents=spk_gpt,
                text_inputs=text_tokens,
                input_tokens=None,
                do_sample=True,
                top_p=0.85,
                top_k=30,
                temperature=0.75,
                num_return_sequences=9,
                num_beams=1,
                length_penalty=1.0,
                repetition_penalty=2.0,
                output_attentions=False,
            )

        seqs = []
        EOS_TOKEN = self.config["gpt"]["gpt_stop_audio_token"]
        for seq in gpt_codes:
            index = (seq == EOS_TOKEN).nonzero(as_tuple=True)[0][0]
            seq = seq[:index]
            seqs.append(seq)

        sorted_seqs = sorted(seqs, key=lambda i: len(i), reverse=False)
        gpt_codes = sorted_seqs[2].unsqueeze(0)  # [1, len]
        # sorted_len = [len(l) for l in sorted_seqs]
        # print("---sorted_len:", sorted_len)

        return gpt_codes

    def synthesize(self, prompt_wav, text, lang="auto"):
        """Synthesize speech for `text` in the voice of the prompt audio.

        Args:
            prompt_wav: path to the prompt wav file
            text: text to synthesize
            lang: language of the text ("zh", "en", or "auto")
        """
        # Currently only supports Chinese and English
        assert lang in ["zh", "en", "auto"]
        assert os.path.exists(prompt_wav)

        # text to tokens
        text_tokens = self.text_tokenizer.encode(text=text, lang=lang)
        text_tokens = torch.IntTensor(text_tokens).unsqueeze(0).to(self.device)
        assert text_tokens.shape[-1] < 400

        # extract speaker embedding
        spk_embeddings = self.extract_spk_embeddings(prompt_wav=prompt_wav).unsqueeze(0)
        with torch.no_grad():
            spk_gpt = self.gpt.reference_embedding(spk_embeddings)

        # gpt inference
        gpt_start_time = time.time()
        gpt_codes = self.do_gpt_inference(spk_gpt=spk_gpt, text_tokens=text_tokens)
        gpt_end_time = time.time()
        gpt_dur = gpt_end_time - gpt_start_time

        # prompt mel-spectrogram compute
        prompt_mel = (
            self.mel_extractor(wav_path=prompt_wav).unsqueeze(0).to(self.device)
        )
        # convert token to waveform (b=1, t)
        voc_start_time = time.time()
        rec_wavs = self.token2wav.inference(gpt_codes, prompt_mel, n_timesteps=10)
        voc_end_time = time.time()
        voc_dur = voc_end_time - voc_start_time
        all_dur = voc_end_time - gpt_start_time

        # rtf compute
        # audio_dur = rec_wavs.shape[-1] / 24000
        # rtf_gpt = gpt_dur / audio_dur
        # rtf_voc = voc_dur / audio_dur
        # rtf_all = all_dur / audio_dur

        return rec_wavs
```
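For orientation, a minimal usage sketch of the class above. The config and checkpoint paths are placeholders, and the call to `torchaudio.save` with a 24 kHz rate is an assumption taken from the commented-out RTF computation, not something stated elsewhere in this commit.

```python
# Hypothetical usage sketch; paths are placeholders, and the 24 kHz output
# rate is assumed from the commented-out RTF code in synthesize().
import torchaudio
from fireredtts.fireredtts import FireRedTTS

tts = FireRedTTS(
    config_path="configs/config.json",    # placeholder config path
    pretrained_path="pretrained_models",  # must contain the gpt/token2wav/speaker checkpoints checked in __init__
    device="cuda",
)
wav = tts.synthesize(
    prompt_wav="examples/prompt.wav",     # placeholder prompt audio
    text="Hello, this is a test.",
    lang="en",
)
# rec_wavs has shape (1, t), which torchaudio.save accepts as (channels, time)
torchaudio.save("output.wav", wav.detach().cpu(), sample_rate=24000)
```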
fireredtts/modules/__init__.py
ADDED @@ -0,0 +1,42 @@
```python
import json
import torch
import torch.nn as nn
from fireredtts.modules.bigvgan import get_bigvgan_backend
from fireredtts.modules.flow import get_flow_frontend, MelSpectrogramExtractor


class Token2Wav(nn.Module):
    def __init__(
        self,
        flow: nn.Module,
        generator: nn.Module,
    ):
        super().__init__()
        self.flow = flow
        self.generator = generator

    @torch.no_grad()
    def inference(
        self, tokens: torch.Tensor, prompt_mel: torch.Tensor, n_timesteps: int = 10
    ) -> torch.Tensor:
        token_len = torch.tensor([tokens.shape[1]], dtype=torch.long).to(tokens.device)
        prompt_mel_len = torch.tensor([prompt_mel.shape[1]], dtype=torch.long).to(
            prompt_mel.device
        )
        # flow
        mel = self.flow.inference(
            token=tokens,
            token_len=token_len,
            prompt_mel=prompt_mel,
            prompt_mel_len=prompt_mel_len,
            n_timesteps=n_timesteps,
        )
        # bigvgan
        audio = self.generator(mel)  # (b=1, 1, t)
        return audio.squeeze(1)

    @classmethod
    def init_from_config(cls, config) -> "Token2Wav":
        flow = get_flow_frontend(config["flow"])
        bigvgan = get_bigvgan_backend(config["bigvgan"])
        return cls(flow, bigvgan)
```
fireredtts/modules/bigvgan/__init__.py
ADDED @@ -0,0 +1,6 @@
```python
from fireredtts.modules.bigvgan.bigvgan import BigVGAN


def get_bigvgan_backend(bigvgan_config):
    generator = BigVGAN(**bigvgan_config)
    return generator
```
fireredtts/modules/bigvgan/activations.py
ADDED @@ -0,0 +1,126 @@
```python
# Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license.


import torch
from torch import nn, sin, pow
from torch.nn import Parameter


class Snake(nn.Module):
    """
    Implementation of a sine-based periodic activation function
    Shape:
        - Input: (B, C, T)
        - Output: (B, C, T), same shape as the input
    Parameters:
        - alpha - trainable parameter
    References:
        - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
        https://arxiv.org/abs/2006.08195
    Examples:
        >>> a1 = snake(256)
        >>> x = torch.randn(256)
        >>> x = a1(x)
    """

    def __init__(
        self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False
    ):
        """
        Initialization.
        INPUT:
            - in_features: shape of the input
            - alpha: trainable parameter
            alpha is initialized to 1 by default, higher values = higher-frequency.
            alpha will be trained along with the rest of your model.
        """
        super(Snake, self).__init__()
        self.in_features = in_features

        # initialize alpha
        self.alpha_logscale = alpha_logscale
        if self.alpha_logscale:  # log scale alphas initialized to zeros
            self.alpha = Parameter(torch.zeros(in_features) * alpha)
        else:  # linear scale alphas initialized to ones
            self.alpha = Parameter(torch.ones(in_features) * alpha)

        self.alpha.requires_grad = alpha_trainable

        self.no_div_by_zero = 0.000000001

    def forward(self, x):
        """
        Forward pass of the function.
        Applies the function to the input elementwise.
        Snake := x + 1/a * sin^2 (xa)
        """
        alpha = self.alpha.unsqueeze(0).unsqueeze(-1)  # line up with x to [B, C, T]
        if self.alpha_logscale:
            alpha = torch.exp(alpha)
        x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)

        return x


class SnakeBeta(nn.Module):
    """
    A modified Snake function which uses separate parameters for the magnitude of the periodic components
    Shape:
        - Input: (B, C, T)
        - Output: (B, C, T), same shape as the input
    Parameters:
        - alpha - trainable parameter that controls frequency
        - beta - trainable parameter that controls magnitude
    References:
        - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
        https://arxiv.org/abs/2006.08195
    Examples:
        >>> a1 = snakebeta(256)
        >>> x = torch.randn(256)
        >>> x = a1(x)
    """

    def __init__(
        self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False
    ):
        """
        Initialization.
        INPUT:
            - in_features: shape of the input
            - alpha - trainable parameter that controls frequency
            - beta - trainable parameter that controls magnitude
            alpha is initialized to 1 by default, higher values = higher-frequency.
            beta is initialized to 1 by default, higher values = higher-magnitude.
            alpha will be trained along with the rest of your model.
        """
        super(SnakeBeta, self).__init__()
        self.in_features = in_features

        # initialize alpha
        self.alpha_logscale = alpha_logscale
        if self.alpha_logscale:  # log scale alphas initialized to zeros
            self.alpha = Parameter(torch.zeros(in_features) * alpha)
            self.beta = Parameter(torch.zeros(in_features) * alpha)
        else:  # linear scale alphas initialized to ones
            self.alpha = Parameter(torch.ones(in_features) * alpha)
            self.beta = Parameter(torch.ones(in_features) * alpha)

        self.alpha.requires_grad = alpha_trainable
        self.beta.requires_grad = alpha_trainable

        self.no_div_by_zero = 0.000000001

    def forward(self, x):
        """
        Forward pass of the function.
        Applies the function to the input elementwise.
        SnakeBeta := x + 1/b * sin^2 (xa)
        """
        alpha = self.alpha.unsqueeze(0).unsqueeze(-1)  # line up with x to [B, C, T]
        beta = self.beta.unsqueeze(0).unsqueeze(-1)
        if self.alpha_logscale:
            alpha = torch.exp(alpha)
            beta = torch.exp(beta)
        x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2)

        return x
```
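The docstring examples above use a lowercase `snake(256)` shorthand; a minimal sketch of how these modules are actually applied to a (B, C, T) feature map, using only the classes defined in this file:

```python
# Minimal sketch: apply Snake / SnakeBeta to a (B, C, T) feature map.
import torch
from fireredtts.modules.bigvgan.activations import Snake, SnakeBeta

x = torch.randn(2, 256, 100)                    # (batch, channels, time)

act = Snake(in_features=256, alpha_logscale=True)
y = act(x)                                      # same shape as x
print(y.shape)                                  # torch.Size([2, 256, 100])

act_beta = SnakeBeta(in_features=256, alpha_logscale=True)
z = act_beta(x)                                 # separate alpha/beta parameters per channel
```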
fireredtts/modules/bigvgan/alias_free_cuda/__init__.py
ADDED (empty file, no content)
fireredtts/modules/bigvgan/alias_free_cuda/activation1d.py
ADDED @@ -0,0 +1,75 @@
```python
# Copyright (c) 2024 NVIDIA CORPORATION.
# Licensed under the MIT license.

import torch
import torch.nn as nn
from token2wav.alias_free_torch.resample import UpSample1d, DownSample1d

# load fused CUDA kernel: this enables importing anti_alias_activation_cuda
from token2wav.alias_free_cuda import load

load.load()


class FusedAntiAliasActivation(torch.autograd.Function):
    """
    Assumes filter size 12, replication padding on upsampling, and logscale alpha/beta parameters as inputs
    """

    @staticmethod
    def forward(ctx, inputs, ftr, alpha, beta):
        import anti_alias_activation_cuda

        activation_results = anti_alias_activation_cuda.forward(
            inputs, ftr, alpha, beta
        )
        return activation_results

    @staticmethod
    def backward(ctx, output_grads):
        # TODO: implement bwd pass
        raise NotImplementedError
        return output_grads, None, None


class Activation1d(nn.Module):
    def __init__(
        self,
        activation,
        up_ratio: int = 2,
        down_ratio: int = 2,
        up_kernel_size: int = 12,
        down_kernel_size: int = 12,
        fused: bool = True,
    ):
        super().__init__()
        self.up_ratio = up_ratio
        self.down_ratio = down_ratio
        self.act = activation
        self.upsample = UpSample1d(up_ratio, up_kernel_size)
        self.downsample = DownSample1d(down_ratio, down_kernel_size)

        self.fused = fused  # whether to use fused CUDA kernel or not

    def forward(self, x):
        if not self.fused:
            x = self.upsample(x)
            x = self.act(x)
            x = self.downsample(x)
            return x
        else:
            if self.act.__class__.__name__ == "Snake":
                beta = self.act.alpha.data  # snake uses same params for alpha and beta
            else:
                beta = (
                    self.act.beta.data
                )  # snakebeta uses different params for alpha and beta
            alpha = self.act.alpha.data
            if (
                not self.act.alpha_logscale
            ):  # exp baked into cuda kernel, cancel it out with a log
                alpha = torch.log(alpha)
                beta = torch.log(beta)
            x = FusedAntiAliasActivation.apply(x, self.upsample.filter, alpha, beta)
            x = self.downsample(x)
            return x
```
fireredtts/modules/bigvgan/alias_free_cuda/anti_alias_activation.cpp
ADDED @@ -0,0 +1,48 @@
```cpp
/* coding=utf-8
 * Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <cuda_fp16.h>
#include <torch/extension.h>
#include <vector>

namespace anti_alias_activation {

torch::Tensor fwd_cuda(torch::Tensor const& input,
                       torch::Tensor const& filter,
                       torch::Tensor const& alpha,
                       torch::Tensor const& beta
                       );

torch::Tensor fwd(torch::Tensor const& input,
                  torch::Tensor const& filter,
                  torch::Tensor const& alpha,
                  torch::Tensor const& beta
                  ) {
  AT_ASSERTM(input.dim() == 3, "expected 3D tensor");
  //AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
  //           (input.scalar_type() == at::ScalarType::BFloat16),
  //           "Only fp16 and bf16 are supported");

  return fwd_cuda(input, filter, alpha, beta);
}

} // end namespace anti_alias_activation

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward",
        &anti_alias_activation::fwd,
        "Anti Alias Activation -- Forward.");
}
```
fireredtts/modules/bigvgan/alias_free_cuda/anti_alias_activation_cuda.cu
ADDED @@ -0,0 +1,314 @@
```cuda
/* coding=utf-8
 * Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include "type_shim.h"
#include <assert.h>
#include <cfloat>
#include <limits>
#include <stdint.h>
#include <c10/macros/Macros.h>

namespace {

/*
template <typename Datatype, int ELEMENTS_PER_LDG>
__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src);

template <>
__device__ __inline__ void copy_vector<c10::BFloat16, 1>(c10::BFloat16 *dst, const c10::BFloat16 *src) { *dst = *src; }

template <>
__device__ __inline__ void copy_vector<c10::BFloat16, 4>(c10::BFloat16 *dst, const c10::BFloat16 *src) { *((float2*) dst) = *((float2*) src); }

template <>
__device__ __inline__ void copy_vector<c10::Half, 1>(c10::Half *dst, const c10::Half *src) { *dst = *src; }

template <>
__device__ __inline__ void copy_vector<c10::Half, 4>(c10::Half *dst, const c10::Half *src) { *((float2*) dst) = *((float2*) src); }

template <>
__device__ __inline__ void copy_vector<uint8_t, 1>(uint8_t *dst, const uint8_t *src) { *dst = *src; }

template <>
__device__ __inline__ void copy_vector<uint8_t, 4>(uint8_t *dst, const uint8_t *src) {*((half2*) dst) = *((half2*) src); }

int log2_ceil(int value) {
    int log2_value = 0;
    while ((1 << log2_value) < value) ++log2_value;
    return log2_value;
}

template<typename T>
struct Add {
  __device__ __forceinline__ T operator()(T a, T b) const {
    return a + b;
  }
};

template<typename T>
struct Max {
  __device__ __forceinline__ T operator()(T a, T b) const {
    return a < b ? b : a;
  }
};

template <typename T>
__device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff)
{
#if CUDA_VERSION >= 9000
    return __shfl_xor_sync(mask, value, laneMask, width);
#else
    return __shfl_xor(value, laneMask, width);
#endif
}

template <typename acc_t, int WARP_BATCH, int WARP_SIZE, template<typename> class ReduceOp>
__device__ __forceinline__ void warp_reduce(acc_t* sum) {
    ReduceOp<acc_t> r;
    #pragma unroll
    for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
        #pragma unroll
        for (int i = 0;  i < WARP_BATCH;  ++i) {
            acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE);
            sum[i] = r(sum[i], b);
        }
    }
}
*/

template <typename input_t, typename output_t, typename acc_t>
__global__ void anti_alias_activation_forward(
    output_t *dst,
    const input_t *src,
    const input_t *ftr,
    const input_t *alpha,
    const input_t *beta,
    int batch_size,
    int channels,
    int seq_len)
{
    // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
    constexpr int ELEMENTS_PER_LDG_STG = 1; //(WARP_ITERATIONS < 4) ? 1 : 4;
    constexpr int BUFFER_SIZE = 32;
    constexpr int FILTER_SIZE = 12;
    constexpr int HALF_FILTER_SIZE = 6;
    constexpr int REPLICATION_PAD = 5; // 5 on each side

    // blockDim/threadIdx = (128, 1, 1)
    // gridDim/blockIdx = (seq_blocks, channels, batches)
    int block_offset = (blockIdx.x * 128 * BUFFER_SIZE + seq_len * (blockIdx.y + gridDim.y * blockIdx.z));
    int local_offset = threadIdx.x * BUFFER_SIZE;
    int seq_offset = blockIdx.x * 128 * BUFFER_SIZE + local_offset;


    //int intermediate_seq_len = seq_len * 2 - 1 + 4 * REPLICATION_PAD;
    //int intermediate_block_offset = (blockIdx.x * 128 * BUFFER_SIZE * 2 + intermediate_seq_len * (blockIdx.y + gridDim.y * blockIdx.z));
    //int intermediate_local_offset = threadIdx.x * BUFFER_SIZE * 2;

    int output_seq_len = seq_len * 2 ; //
    int output_block_offset = (blockIdx.x * 128 * BUFFER_SIZE * 2 + output_seq_len * (blockIdx.y + gridDim.y * blockIdx.z));
    int output_local_offset = threadIdx.x * BUFFER_SIZE * 2;
    int output_seq_offset = blockIdx.x * 128 * BUFFER_SIZE * 2 + output_local_offset;
    // get values needed for replication padding before moving pointer
    const input_t *right_most_pntr = src + (seq_len * (blockIdx.y + gridDim.y * blockIdx.z));
    input_t seq_left_most_value = right_most_pntr[0];
    input_t seq_right_most_value = right_most_pntr[seq_len - 1];

    src += block_offset + local_offset;
    dst += output_block_offset + output_local_offset;
    alpha = alpha + blockIdx.y;
    input_t alpha_val = expf(alpha[0]);
    beta = beta + blockIdx.y;
    input_t beta_val = expf(beta[0]);
    // load data from global memory
    input_t elements[2*FILTER_SIZE+2*BUFFER_SIZE] = {0};
    input_t intermediates[2*FILTER_SIZE+2*BUFFER_SIZE] = {0};
    //output_t output[2*BUFFER_SIZE];
    input_t filter[FILTER_SIZE];
    //input_t temp_data[ELEMENTS_PER_LDG_STG];
    //uint8_t temp_mask[ELEMENTS_PER_LDG_STG];

    #pragma unroll
    for (int it = 0; it < FILTER_SIZE; it+=1) {
        filter[it] = ftr[it];
    }


    #pragma unroll
    for (int it = -HALF_FILTER_SIZE; it < BUFFER_SIZE + HALF_FILTER_SIZE ; it+=1) {
        int element_index = seq_offset + it;
        if ((element_index < 0) && (element_index >= -REPLICATION_PAD)) {
            elements[2*(HALF_FILTER_SIZE+it)] = 2*seq_left_most_value;
        }
        if ((element_index >= seq_len) && (element_index < seq_len + REPLICATION_PAD)) {
            elements[2*(HALF_FILTER_SIZE+it)] = 2*seq_right_most_value;
        }
        if ((element_index >= 0) && (element_index < seq_len)) {
            elements[2*(HALF_FILTER_SIZE+it)] = 2*src[it];
        }
    }



    // apply filter
    #pragma unroll
    for (int it = 0; it < (2 * BUFFER_SIZE + 2*FILTER_SIZE); it+=1) {
        input_t acc = 0.0;

        int element_index = output_seq_offset + it; // index for output
        #pragma unroll
        for (int f_idx = 0; f_idx < FILTER_SIZE; f_idx+=1){
            if ((element_index + f_idx) >= 0){
                acc += filter[f_idx] * elements[it+f_idx];
            }
        }
        intermediates[it] = acc;
    }

    double no_div_by_zero = 0.000000001;
    #pragma unroll
    for (int it = 0; it < 12 + 2 * BUFFER_SIZE; it++) {
        intermediates[it] += (1.0/(beta_val + no_div_by_zero)) * sinf(intermediates[it] * alpha_val) * sinf(intermediates[it] * alpha_val);
    }


    // now copy to output
    #pragma unroll
    for (int it = 0; it < 2*BUFFER_SIZE; it+=1){
        int element_index = output_seq_offset + it;
        if (element_index < output_seq_len) {
            dst[it] = intermediates[it+6];
        }
    }



    // for (int it = 0; it < BUFFER_SIZE; it+=ELEMENTS_PER_LDG_STG) {
    //     int element_index = seq_offset + it;
    //     if (element_index < seq_len) {
    //         dst[it] = output[it];
    //     }
    // }


    // // Upsample convolution
    // for (int it = 0; it < 2 * BUFFER_SIZE + 12; it+=1) {
    //     input_t acc = 0.0;

    //     for (int f_idx = 0; f_idx < FILTER_SIZE; f_idx+=1){
    //         acc += filter[f_idx] * elements[it+f_idx];
    //     }
    //     intermediates[it] = acc;
    // }

    // // correct the corners of intermediates
    // if (seq_offset == 0) {
    //     for (int it = 0; it < 6; it+=1)
    //         intermediates[it] = 0;
    // }

    // if (seq_offset + 32 >= seq_len) {
    //     int offset = seq_len % 32 == 0 ? 32 : seq_len % 32;

    //     for (int it = 0; it < 6; it++) {
    //         intermediates[6+2*offset+it] = 0;
    //     }
    // }




    // for (int it = 0; it < BUFFER_SIZE; it+=ELEMENTS_PER_LDG_STG) {
    //     int element_index = seq_offset + it;
    //     if (element_index < seq_len) {
    //         dst[it] = output[it];
    //     }
    // }
}

template<typename input_t, typename output_t, typename acc_t>
void dispatch_anti_alias_activation_forward(
    output_t *dst,
    const input_t *src,
    const input_t *ftr,
    const input_t *alpha,
    const input_t *beta,
    int batch_size,
    int channels,
    int seq_len)
{
    if (seq_len == 0) {
        return;
    } else {
        // use 128 threads per block to maximimize gpu utilization
        constexpr int threads_per_block = 128;
        constexpr int seq_len_per_block = 4096;
        int blocks_per_seq_len = (seq_len + seq_len_per_block - 1) / seq_len_per_block;
        dim3 blocks(blocks_per_seq_len, channels, batch_size);
        dim3 threads(threads_per_block, 1, 1);

        anti_alias_activation_forward<input_t, output_t, acc_t>
            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, ftr, alpha, beta, batch_size, channels, seq_len);
    }
}
}

namespace anti_alias_activation {

torch::Tensor fwd_cuda(torch::Tensor const& input, torch::Tensor const& filter, torch::Tensor const& alpha, torch::Tensor const& beta)
{
    // input is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len]
    const int batches = input.size(0);
    const int channels = input.size(1);
    const int seq_len = input.size(2);

    // Output
    auto act_options = input.options().requires_grad(false);
    int output_seq_len = seq_len*2; // we'll be dilating between each element by interspersing with zeros

    torch::Tensor anti_alias_activation_results =
        torch::empty({batches, channels, output_seq_len}, act_options);

    // Softmax Intermediate Result Ptr
    void* input_ptr = static_cast<void*>(input.data_ptr());
    void* filter_ptr = static_cast<void*>(filter.data_ptr());
    void* alpha_ptr = static_cast<void*>(alpha.data_ptr());
    void* beta_ptr = static_cast<void*>(beta.data_ptr());
    void* anti_alias_activation_results_ptr = static_cast<void*>(anti_alias_activation_results.data_ptr());

    DISPATCH_FLOAT_HALF_AND_BFLOAT(
        input.scalar_type(),
        "dispatch anti alias activation_forward",
        dispatch_anti_alias_activation_forward<scalar_t, scalar_t, float>(
            reinterpret_cast<scalar_t*>(anti_alias_activation_results_ptr),
            reinterpret_cast<const scalar_t*>(input_ptr),
            reinterpret_cast<const scalar_t*>(filter_ptr),
            reinterpret_cast<const scalar_t*>(alpha_ptr),
            reinterpret_cast<const scalar_t*>(beta_ptr),
            batches,
            channels,
            seq_len);
    );
    return anti_alias_activation_results;
}
}
```
fireredtts/modules/bigvgan/alias_free_cuda/compat.h
ADDED @@ -0,0 +1,31 @@
```cpp
/* coding=utf-8
 * Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/*This code is copied fron NVIDIA apex:
 *     https://github.com/NVIDIA/apex
 *     with minor changes. */



#ifndef TORCH_CHECK
#define TORCH_CHECK AT_CHECK
#endif

#ifdef VERSION_GE_1_3
#define DATA_PTR data_ptr
#else
#define DATA_PTR data
#endif
```
fireredtts/modules/bigvgan/alias_free_cuda/load.py
ADDED @@ -0,0 +1,85 @@
```python
# Copyright (c) 2024 NVIDIA CORPORATION.
# Licensed under the MIT license.

import os
import pathlib
import subprocess

from torch.utils import cpp_extension

# Setting this param to a list has a problem of generating different
# compilation commands (with diferent order of architectures) and
# leading to recompilation of fused kernels. Set it to empty string
# to avoid recompilation and assign arch flags explicity in
# extra_cuda_cflags below
os.environ["TORCH_CUDA_ARCH_LIST"] = ""


def load():
    # Check if cuda 11 is installed for compute capability 8.0
    cc_flag = []
    _, bare_metal_major, _ = _get_cuda_bare_metal_version(cpp_extension.CUDA_HOME)
    if int(bare_metal_major) >= 11:
        cc_flag.append("-gencode")
        cc_flag.append("arch=compute_80,code=sm_80")

    # Build path
    srcpath = pathlib.Path(__file__).parent.absolute()
    buildpath = srcpath / "build"
    _create_build_dir(buildpath)

    # Helper function to build the kernels.
    def _cpp_extention_load_helper(name, sources, extra_cuda_flags):
        return cpp_extension.load(
            name=name,
            sources=sources,
            build_directory=buildpath,
            extra_cflags=[
                "-O3",
            ],
            extra_cuda_cflags=[
                "-O3",
                "-gencode",
                "arch=compute_70,code=sm_70",
                "--use_fast_math",
            ]
            + extra_cuda_flags
            + cc_flag,
            verbose=True,
        )

    extra_cuda_flags = [
        "-U__CUDA_NO_HALF_OPERATORS__",
        "-U__CUDA_NO_HALF_CONVERSIONS__",
        "--expt-relaxed-constexpr",
        "--expt-extended-lambda",
    ]

    sources = [
        srcpath / "anti_alias_activation.cpp",
        srcpath / "anti_alias_activation_cuda.cu",
    ]
    anti_alias_activation_cuda = _cpp_extention_load_helper(
        "anti_alias_activation_cuda", sources, extra_cuda_flags
    )


def _get_cuda_bare_metal_version(cuda_dir):
    raw_output = subprocess.check_output(
        [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True
    )
    output = raw_output.split()
    release_idx = output.index("release") + 1
    release = output[release_idx].split(".")
    bare_metal_major = release[0]
    bare_metal_minor = release[1][0]

    return raw_output, bare_metal_major, bare_metal_minor


def _create_build_dir(buildpath):
    try:
        os.mkdir(buildpath)
    except OSError:
        if not os.path.isdir(buildpath):
            print(f"Creation of the build directory {buildpath} failed")
```
fireredtts/modules/bigvgan/alias_free_cuda/test_activation.py
ADDED @@ -0,0 +1,64 @@
```python
# Copyright (c) 2024 NVIDIA CORPORATION.
# Licensed under the MIT license.

import math
import torch
import alias_free_cuda
from alias_free_cuda import activation1d
from activations import Snake, SnakeBeta


def test_load_fused_kernels():
    try:
        import alias_free_cuda
        import torch

        print("[Success] load_fused_kernels")
    except ImportError as e:
        print("[Fail] load_fused_kernels")
        raise e


def test_anti_alias_activation():
    data = torch.rand((10, 10, 50000), device="cuda")

    # check activations.Snake cuda vs. torch
    fused_anti_alias_activation = activation1d.Activation1d(
        activation=Snake(10), fused=True
    ).cuda()
    fused_activation_output = fused_anti_alias_activation(data)

    torch_anti_alias_activation = activation1d.Activation1d(
        activation=Snake(10), fused=False
    ).cuda()
    torch_activation_output = torch_anti_alias_activation(data)

    test_result = (fused_activation_output - torch_activation_output).abs()

    while test_result.dim() != 1:
        test_result = test_result.mean(dim=-1)

    diff = test_result.mean(dim=-1)

    if diff <= 1e-3:
        print(
            f"\n[Success] test_fused_anti_alias_activation"
            f"\n > mean_difference={diff}"
            f"\n > fused_values={fused_activation_output[-1][-1][-100:].tolist()}"
            f"\n > torch_values={torch_activation_output[-1][-1][-100:].tolist()}"
        )
    else:
        print(
            f"\n[Fail] test_fused_anti_alias_activation"
            f"\n > mean_difference={diff}, "
            f"\n > fused_values={fused_activation_output[-1][-1][-30:].tolist()}, "
            f"\n > torch_values={torch_activation_output[-1][-1][-30:].tolist()}"
        )


if __name__ == "__main__":
    from alias_free_cuda import load

    load.load()
    test_load_fused_kernels()
    test_anti_alias_activation()
```
fireredtts/modules/bigvgan/alias_free_cuda/test_activation_snake_beta.py
ADDED @@ -0,0 +1,64 @@
```python
# Copyright (c) 2024 NVIDIA CORPORATION.
# Licensed under the MIT license.

import math
import torch
import alias_free_cuda
from alias_free_cuda import activation1d
from activations import Snake, SnakeBeta


def test_load_fused_kernels():
    try:
        import alias_free_cuda
        import torch

        print("[Success] load_fused_kernels")
    except ImportError as e:
        print("[Fail] load_fused_kernels")
        raise e


def test_anti_alias_activation():
    data = torch.rand((10, 10, 50000), device="cuda")

    # check activations.Snake cuda vs. torch
    fused_anti_alias_activation = activation1d.Activation1d(
        activation=SnakeBeta(10), fused=True
    ).cuda()
    fused_activation_output = fused_anti_alias_activation(data)

    torch_anti_alias_activation = activation1d.Activation1d(
        activation=SnakeBeta(10), fused=False
    ).cuda()
    torch_activation_output = torch_anti_alias_activation(data)

    test_result = (fused_activation_output - torch_activation_output).abs()

    while test_result.dim() != 1:
        test_result = test_result.mean(dim=-1)

    diff = test_result.mean(dim=-1)

    if diff <= 1e-3:
        print(
            f"\n[Success] test_fused_anti_alias_activation"
            f"\n > mean_difference={diff}"
            f"\n > fused_values={fused_activation_output[-1][-1][-100:].tolist()}"
            f"\n > torch_values={torch_activation_output[-1][-1][-100:].tolist()}"
        )
    else:
        print(
            f"\n[Fail] test_fused_anti_alias_activation"
            f"\n > mean_difference={diff}, "
            f"\n > fused_values={fused_activation_output[-1][-1][-30:].tolist()}, "
            f"\n > torch_values={torch_activation_output[-1][-1][-30:].tolist()}"
        )


if __name__ == "__main__":
    from alias_free_cuda import load

    load.load()
    test_load_fused_kernels()
    test_anti_alias_activation()
```
fireredtts/modules/bigvgan/alias_free_cuda/type_shim.h
ADDED @@ -0,0 +1,97 @@
```cpp
/* coding=utf-8
 * Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */


#include <ATen/ATen.h>
#include "compat.h"


#define DISPATCH_FLOAT_HALF_AND_BFLOAT(TYPE, NAME, ...) \
    switch(TYPE) \
    { \
        case at::ScalarType::Float: \
        { \
            using scalar_t = float; \
            __VA_ARGS__; \
            break; \
        } \
        case at::ScalarType::Half: \
        { \
            using scalar_t = at::Half; \
            __VA_ARGS__; \
            break; \
        } \
        case at::ScalarType::BFloat16: \
        { \
            using scalar_t = at::BFloat16; \
            __VA_ARGS__; \
            break; \
        } \
        default: \
            AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
    }



#define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \
    switch(TYPEIN) \
    { \
        case at::ScalarType::Float: \
        { \
            using scalar_t_in = float; \
            switch(TYPEOUT) \
            { \
                case at::ScalarType::Float: \
                { \
                    using scalar_t_out = float; \
                    __VA_ARGS__; \
                    break; \
                } \
                case at::ScalarType::Half: \
                { \
                    using scalar_t_out = at::Half; \
                    __VA_ARGS__; \
                    break; \
                } \
                case at::ScalarType::BFloat16: \
                { \
                    using scalar_t_out = at::BFloat16; \
                    __VA_ARGS__; \
                    break; \
                } \
                default: \
                    AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \
            } \
            break; \
        } \
        case at::ScalarType::Half: \
        { \
            using scalar_t_in = at::Half; \
            using scalar_t_out = at::Half; \
            __VA_ARGS__; \
            break; \
        } \
        case at::ScalarType::BFloat16: \
        { \
            using scalar_t_in = at::BFloat16; \
            using scalar_t_out = at::BFloat16; \
            __VA_ARGS__; \
            break; \
        } \
        default: \
            AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'"); \
    }
```
fireredtts/modules/bigvgan/alias_free_torch/__init__.py
ADDED @@ -0,0 +1,5 @@
```python
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0

from .filter import *
from .resample import *
from .act import *
```
fireredtts/modules/bigvgan/alias_free_torch/act.py
ADDED @@ -0,0 +1,29 @@
```python
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0

import torch.nn as nn
from .resample import UpSample1d, DownSample1d


class Activation1d(nn.Module):
    def __init__(
        self,
        activation,
        up_ratio: int = 2,
        down_ratio: int = 2,
        up_kernel_size: int = 12,
        down_kernel_size: int = 12,
    ):
        super().__init__()
        self.up_ratio = up_ratio
        self.down_ratio = down_ratio
        self.act = activation
        self.upsample = UpSample1d(up_ratio, up_kernel_size)
        self.downsample = DownSample1d(down_ratio, down_kernel_size)

    # x: [B,C,T]
    def forward(self, x):
        x = self.upsample(x)
        x = self.act(x)
        x = self.downsample(x)

        return x
```
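A small sketch of how this pure-PyTorch wrapper is meant to be used: the activation is sandwiched between a 2x upsampler and a 2x downsampler, so the output keeps the input's time resolution while suppressing the aliasing the nonlinearity would otherwise introduce.

```python
# Sketch: anti-aliased activation = 2x upsample -> Snake -> 2x downsample.
import torch
from fireredtts.modules.bigvgan.activations import Snake
from fireredtts.modules.bigvgan.alias_free_torch.act import Activation1d

x = torch.randn(1, 64, 8000)            # (B, C, T)
act = Activation1d(activation=Snake(64))
y = act(x)
print(y.shape)                          # torch.Size([1, 64, 8000]) -- same T as the input
```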
fireredtts/modules/bigvgan/alias_free_torch/filter.py
ADDED @@ -0,0 +1,98 @@
```python
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

if "sinc" in dir(torch):
    sinc = torch.sinc
else:
    # This code is adopted from adefossez's julius.core.sinc under the MIT License
    # https://adefossez.github.io/julius/julius/core.html
    # LICENSE is in incl_licenses directory.
    def sinc(x: torch.Tensor):
        """
        Implementation of sinc, i.e. sin(pi * x) / (pi * x)
        __Warning__: Different to julius.sinc, the input is multiplied by `pi`!
        """
        return torch.where(
            x == 0,
            torch.tensor(1.0, device=x.device, dtype=x.dtype),
            torch.sin(math.pi * x) / math.pi / x,
        )


# This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License
# https://adefossez.github.io/julius/julius/lowpass.html
# LICENSE is in incl_licenses directory.
def kaiser_sinc_filter1d(
    cutoff, half_width, kernel_size
):  # return filter [1,1,kernel_size]
    even = kernel_size % 2 == 0
    half_size = kernel_size // 2

    # For kaiser window
    delta_f = 4 * half_width
    A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
    if A > 50.0:
        beta = 0.1102 * (A - 8.7)
    elif A >= 21.0:
        beta = 0.5842 * (A - 21) ** 0.4 + 0.07886 * (A - 21.0)
    else:
        beta = 0.0
    window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)

    # ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio
    if even:
        time = torch.arange(-half_size, half_size) + 0.5
    else:
        time = torch.arange(kernel_size) - half_size
    if cutoff == 0:
        filter_ = torch.zeros_like(time)
    else:
        filter_ = 2 * cutoff * window * sinc(2 * cutoff * time)
        # Normalize filter to have sum = 1, otherwise we will have a small leakage
        # of the constant component in the input signal.
        filter_ /= filter_.sum()
        filter = filter_.view(1, 1, kernel_size)

    return filter


class LowPassFilter1d(nn.Module):
    def __init__(
        self,
        cutoff=0.5,
        half_width=0.6,
        stride: int = 1,
        padding: bool = True,
        padding_mode: str = "replicate",
        kernel_size: int = 12,
    ):
        # kernel_size should be even number for stylegan3 setup,
        # in this implementation, odd number is also possible.
        super().__init__()
        if cutoff < -0.0:
            raise ValueError("Minimum cutoff must be larger than zero.")
        if cutoff > 0.5:
            raise ValueError("A cutoff above 0.5 does not make sense.")
        self.kernel_size = kernel_size
        self.even = kernel_size % 2 == 0
        self.pad_left = kernel_size // 2 - int(self.even)
        self.pad_right = kernel_size // 2
        self.stride = stride
        self.padding = padding
        self.padding_mode = padding_mode
        filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
        self.register_buffer("filter", filter)

    # input [B, C, T]
    def forward(self, x):
        _, C, _ = x.shape

        if self.padding:
            x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode)
        out = F.conv1d(x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)

        return out
```
fireredtts/modules/bigvgan/alias_free_torch/resample.py
ADDED @@ -0,0 +1,57 @@
```python
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0

import torch.nn as nn
from torch.nn import functional as F
from .filter import LowPassFilter1d
from .filter import kaiser_sinc_filter1d


class UpSample1d(nn.Module):
    def __init__(self, ratio=2, kernel_size=None):
        super().__init__()
        self.ratio = ratio
        self.kernel_size = (
            int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
        )
        self.stride = ratio
        self.pad = self.kernel_size // ratio - 1
        self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
        self.pad_right = (
            self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
        )
        filter = kaiser_sinc_filter1d(
            cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size
        )
        self.register_buffer("filter", filter)

    # x: [B, C, T]
    def forward(self, x):
        _, C, _ = x.shape

        x = F.pad(x, (self.pad, self.pad), mode="replicate")
        x = self.ratio * F.conv_transpose1d(
            x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C
        )
        x = x[..., self.pad_left : -self.pad_right]

        return x


class DownSample1d(nn.Module):
    def __init__(self, ratio=2, kernel_size=None):
        super().__init__()
        self.ratio = ratio
        self.kernel_size = (
            int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
        )
        self.lowpass = LowPassFilter1d(
            cutoff=0.5 / ratio,
            half_width=0.6 / ratio,
            stride=ratio,
            kernel_size=self.kernel_size,
        )

    def forward(self, x):
        xx = self.lowpass(x)

        return xx
```
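A quick sketch of the resamplers above, using only this module: both change the time axis only, with `UpSample1d(2)` doubling T via a transposed convolution against the Kaiser-windowed sinc filter and `DownSample1d(2)` halving it again through the low-pass filter.

```python
# Sketch: round-trip through the alias-free resamplers.
import torch
from fireredtts.modules.bigvgan.alias_free_torch.resample import UpSample1d, DownSample1d

x = torch.randn(1, 32, 1000)     # (B, C, T)
up = UpSample1d(ratio=2)
down = DownSample1d(ratio=2)

print(up(x).shape)               # torch.Size([1, 32, 2000])
print(down(up(x)).shape)         # torch.Size([1, 32, 1000])
```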
fireredtts/modules/bigvgan/bigvgan.py
ADDED @@ -0,0 +1,399 @@
```python
import typing as tp
import torch
import torch.nn as nn
from torch.nn import Conv1d, ConvTranspose1d
from torch.nn.utils import weight_norm, remove_weight_norm

from fireredtts.modules.bigvgan.alias_free_torch import (
    Activation1d as TorchActivation1d,
)
from fireredtts.modules.bigvgan.activations import Snake, SnakeBeta


def init_weights(m, mean=0.0, std=0.01):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        m.weight.data.normal_(mean, std)


def get_padding(kernel_size, dilation=1):
    return int((kernel_size * dilation - dilation) / 2)


class AMPBlock1(torch.nn.Module):
    def __init__(
        self,
        channels,
        kernel_size=3,
        dilation=(1, 3, 5),
        activation=None,
        snake_logscale=True,
        use_cuda_kernel=False,
    ):
        super(AMPBlock1, self).__init__()

        self.convs1 = nn.ModuleList(
            [
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=dilation[0],
                        padding=get_padding(kernel_size, dilation[0]),
                    )
                ),
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=dilation[1],
                        padding=get_padding(kernel_size, dilation[1]),
                    )
                ),
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=dilation[2],
                        padding=get_padding(kernel_size, dilation[2]),
                    )
                ),
            ]
        )
        self.convs1.apply(init_weights)

        self.convs2 = nn.ModuleList(
            [
                weight_norm(
                    Conv1d(
```
75 |
+
channels,
|
76 |
+
channels,
|
77 |
+
kernel_size,
|
78 |
+
1,
|
79 |
+
dilation=1,
|
80 |
+
padding=get_padding(kernel_size, 1),
|
81 |
+
)
|
82 |
+
),
|
83 |
+
weight_norm(
|
84 |
+
Conv1d(
|
85 |
+
channels,
|
86 |
+
channels,
|
87 |
+
kernel_size,
|
88 |
+
1,
|
89 |
+
dilation=1,
|
90 |
+
padding=get_padding(kernel_size, 1),
|
91 |
+
)
|
92 |
+
),
|
93 |
+
weight_norm(
|
94 |
+
Conv1d(
|
95 |
+
channels,
|
96 |
+
channels,
|
97 |
+
kernel_size,
|
98 |
+
1,
|
99 |
+
dilation=1,
|
100 |
+
padding=get_padding(kernel_size, 1),
|
101 |
+
)
|
102 |
+
),
|
103 |
+
]
|
104 |
+
)
|
105 |
+
self.convs2.apply(init_weights)
|
106 |
+
|
107 |
+
self.num_layers = len(self.convs1) + len(
|
108 |
+
self.convs2
|
109 |
+
) # total number of conv layers
|
110 |
+
|
111 |
+
# select which Activation1d, lazy-load cuda version to ensure backward compatibility
|
112 |
+
if use_cuda_kernel:
|
113 |
+
from modules.bigvgan.alias_free_cuda.activation1d import (
|
114 |
+
Activation1d as CudaActivation1d,
|
115 |
+
)
|
116 |
+
|
117 |
+
Activation1d = CudaActivation1d
|
118 |
+
else:
|
119 |
+
Activation1d = TorchActivation1d
|
120 |
+
|
121 |
+
if (
|
122 |
+
activation == "snake"
|
123 |
+
): # periodic nonlinearity with snake function and anti-aliasing
|
124 |
+
self.activations = nn.ModuleList(
|
125 |
+
[
|
126 |
+
Activation1d(
|
127 |
+
activation=Snake(channels, alpha_logscale=snake_logscale)
|
128 |
+
)
|
129 |
+
for _ in range(self.num_layers)
|
130 |
+
]
|
131 |
+
)
|
132 |
+
elif (
|
133 |
+
activation == "snakebeta"
|
134 |
+
): # periodic nonlinearity with snakebeta function and anti-aliasing
|
135 |
+
self.activations = nn.ModuleList(
|
136 |
+
[
|
137 |
+
Activation1d(
|
138 |
+
activation=SnakeBeta(channels, alpha_logscale=snake_logscale)
|
139 |
+
)
|
140 |
+
for _ in range(self.num_layers)
|
141 |
+
]
|
142 |
+
)
|
143 |
+
else:
|
144 |
+
raise NotImplementedError(
|
145 |
+
"activation incorrectly specified. check the config file and look for 'activation'."
|
146 |
+
)
|
147 |
+
|
148 |
+
def forward(self, x):
|
149 |
+
acts1, acts2 = self.activations[::2], self.activations[1::2]
|
150 |
+
for c1, c2, a1, a2 in zip(self.convs1, self.convs2, acts1, acts2):
|
151 |
+
xt = a1(x)
|
152 |
+
xt = c1(xt)
|
153 |
+
xt = a2(xt)
|
154 |
+
xt = c2(xt)
|
155 |
+
x = xt + x
|
156 |
+
|
157 |
+
return x
|
158 |
+
|
159 |
+
def remove_weight_norm(self):
|
160 |
+
for l in self.convs1:
|
161 |
+
remove_weight_norm(l)
|
162 |
+
for l in self.convs2:
|
163 |
+
remove_weight_norm(l)
|
164 |
+
|
165 |
+
|
166 |
+
class AMPBlock2(torch.nn.Module):
|
167 |
+
def __init__(
|
168 |
+
self,
|
169 |
+
channels,
|
170 |
+
kernel_size=3,
|
171 |
+
dilation=(1, 3),
|
172 |
+
activation=None,
|
173 |
+
snake_logscale=True,
|
174 |
+
use_cuda_kernel=False,
|
175 |
+
):
|
176 |
+
super(AMPBlock2, self).__init__()
|
177 |
+
|
178 |
+
self.convs = nn.ModuleList(
|
179 |
+
[
|
180 |
+
weight_norm(
|
181 |
+
Conv1d(
|
182 |
+
channels,
|
183 |
+
channels,
|
184 |
+
kernel_size,
|
185 |
+
1,
|
186 |
+
dilation=dilation[0],
|
187 |
+
padding=get_padding(kernel_size, dilation[0]),
|
188 |
+
)
|
189 |
+
),
|
190 |
+
weight_norm(
|
191 |
+
Conv1d(
|
192 |
+
channels,
|
193 |
+
channels,
|
194 |
+
kernel_size,
|
195 |
+
1,
|
196 |
+
dilation=dilation[1],
|
197 |
+
padding=get_padding(kernel_size, dilation[1]),
|
198 |
+
)
|
199 |
+
),
|
200 |
+
]
|
201 |
+
)
|
202 |
+
self.convs.apply(init_weights)
|
203 |
+
|
204 |
+
self.num_layers = len(self.convs) # total number of conv layers
|
205 |
+
|
206 |
+
# select which Activation1d, lazy-load cuda version to ensure backward compatibility
|
207 |
+
if use_cuda_kernel:
|
208 |
+
from modules.bigvgan.alias_free_cuda.activation1d import (
|
209 |
+
Activation1d as CudaActivation1d,
|
210 |
+
)
|
211 |
+
|
212 |
+
Activation1d = CudaActivation1d
|
213 |
+
else:
|
214 |
+
Activation1d = TorchActivation1d
|
215 |
+
|
216 |
+
if (
|
217 |
+
activation == "snake"
|
218 |
+
): # periodic nonlinearity with snake function and anti-aliasing
|
219 |
+
self.activations = nn.ModuleList(
|
220 |
+
[
|
221 |
+
Activation1d(
|
222 |
+
activation=Snake(channels, alpha_logscale=snake_logscale)
|
223 |
+
)
|
224 |
+
for _ in range(self.num_layers)
|
225 |
+
]
|
226 |
+
)
|
227 |
+
elif (
|
228 |
+
activation == "snakebeta"
|
229 |
+
): # periodic nonlinearity with snakebeta function and anti-aliasing
|
230 |
+
self.activations = nn.ModuleList(
|
231 |
+
[
|
232 |
+
Activation1d(
|
233 |
+
activation=SnakeBeta(channels, alpha_logscale=snake_logscale)
|
234 |
+
)
|
235 |
+
for _ in range(self.num_layers)
|
236 |
+
]
|
237 |
+
)
|
238 |
+
else:
|
239 |
+
raise NotImplementedError(
|
240 |
+
"activation incorrectly specified. check the config file and look for 'activation'."
|
241 |
+
)
|
242 |
+
|
243 |
+
def forward(self, x):
|
244 |
+
for c, a in zip(self.convs, self.activations):
|
245 |
+
xt = a(x)
|
246 |
+
xt = c(xt)
|
247 |
+
x = xt + x
|
248 |
+
|
249 |
+
return x
|
250 |
+
|
251 |
+
def remove_weight_norm(self):
|
252 |
+
for l in self.convs:
|
253 |
+
remove_weight_norm(l)
|
254 |
+
|
255 |
+
|
256 |
+
class BigVGAN(torch.nn.Module):
|
257 |
+
# this is our main BigVGAN model. Applies anti-aliased periodic activation for resblocks.
|
258 |
+
def __init__(
|
259 |
+
self,
|
260 |
+
num_mels: int,
|
261 |
+
upsample_initial_channel: int,
|
262 |
+
resblock_kernel_sizes: tp.List[int],
|
263 |
+
resblock_dilation_sizes: tp.List[tp.List[int]],
|
264 |
+
upsample_rates: tp.List[int],
|
265 |
+
upsample_kernel_sizes: tp.List[int],
|
266 |
+
resblock_type: str = "1",
|
267 |
+
snake_logscale: bool = True,
|
268 |
+
activation: str = "snakebeta",
|
269 |
+
use_tanh_at_final: bool = False,
|
270 |
+
use_bias_at_final: bool = False,
|
271 |
+
use_cuda_kernel: bool = False,
|
272 |
+
):
|
273 |
+
super(BigVGAN, self).__init__()
|
274 |
+
|
275 |
+
self.num_kernels = len(resblock_kernel_sizes)
|
276 |
+
self.num_upsamples = len(upsample_rates)
|
277 |
+
|
278 |
+
# pre conv
|
279 |
+
self.conv_pre = weight_norm(
|
280 |
+
Conv1d(num_mels, upsample_initial_channel, 7, 1, padding=3)
|
281 |
+
)
|
282 |
+
|
283 |
+
# define which AMPBlock to use. BigVGAN uses AMPBlock1 as default
|
284 |
+
resblock = AMPBlock1 if resblock_type == "1" else AMPBlock2
|
285 |
+
|
286 |
+
# transposed conv-based upsamplers. does not apply anti-aliasing
|
287 |
+
self.ups = nn.ModuleList()
|
288 |
+
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
|
289 |
+
self.ups.append(
|
290 |
+
nn.ModuleList(
|
291 |
+
[
|
292 |
+
weight_norm(
|
293 |
+
ConvTranspose1d(
|
294 |
+
upsample_initial_channel // (2**i),
|
295 |
+
upsample_initial_channel // (2 ** (i + 1)),
|
296 |
+
k,
|
297 |
+
u,
|
298 |
+
padding=(k - u) // 2,
|
299 |
+
)
|
300 |
+
)
|
301 |
+
]
|
302 |
+
)
|
303 |
+
)
|
304 |
+
|
305 |
+
# residual blocks using anti-aliased multi-periodicity composition modules (AMP)
|
306 |
+
self.resblocks = nn.ModuleList()
|
307 |
+
for i in range(len(self.ups)):
|
308 |
+
ch = upsample_initial_channel // (2 ** (i + 1))
|
309 |
+
for j, (k, d) in enumerate(
|
310 |
+
zip(resblock_kernel_sizes, resblock_dilation_sizes)
|
311 |
+
):
|
312 |
+
self.resblocks.append(
|
313 |
+
resblock(
|
314 |
+
ch,
|
315 |
+
k,
|
316 |
+
d,
|
317 |
+
activation=activation,
|
318 |
+
snake_logscale=snake_logscale,
|
319 |
+
use_cuda_kernel=use_cuda_kernel,
|
320 |
+
)
|
321 |
+
)
|
322 |
+
|
323 |
+
# select which Activation1d, lazy-load cuda version to ensure backward compatibility
|
324 |
+
if use_cuda_kernel:
|
325 |
+
from modules.bigvgan.alias_free_cuda.activation1d import (
|
326 |
+
Activation1d as CudaActivation1d,
|
327 |
+
)
|
328 |
+
|
329 |
+
Activation1d = CudaActivation1d
|
330 |
+
else:
|
331 |
+
Activation1d = TorchActivation1d
|
332 |
+
|
333 |
+
# post conv
|
334 |
+
if (
|
335 |
+
activation == "snake"
|
336 |
+
): # periodic nonlinearity with snake function and anti-aliasing
|
337 |
+
activation_post = Snake(ch, alpha_logscale=snake_logscale)
|
338 |
+
self.activation_post = Activation1d(activation=activation_post)
|
339 |
+
elif (
|
340 |
+
activation == "snakebeta"
|
341 |
+
): # periodic nonlinearity with snakebeta function and anti-aliasing
|
342 |
+
activation_post = SnakeBeta(ch, alpha_logscale=snake_logscale)
|
343 |
+
self.activation_post = Activation1d(activation=activation_post)
|
344 |
+
else:
|
345 |
+
raise NotImplementedError(
|
346 |
+
"activation incorrectly specified. check the config file and look for 'activation'."
|
347 |
+
)
|
348 |
+
|
349 |
+
# whether to use bias for the final conv_post. Defaults to True for backward compatibility
|
350 |
+
self.use_bias_at_final = use_bias_at_final
|
351 |
+
self.conv_post = weight_norm(
|
352 |
+
Conv1d(ch, 1, 7, 1, padding=3, bias=self.use_bias_at_final)
|
353 |
+
)
|
354 |
+
|
355 |
+
# weight initialization
|
356 |
+
for i in range(len(self.ups)):
|
357 |
+
self.ups[i].apply(init_weights)
|
358 |
+
self.conv_post.apply(init_weights)
|
359 |
+
|
360 |
+
# final tanh activation. Defaults to True for backward compatibility
|
361 |
+
self.use_tanh_at_final = use_tanh_at_final
|
362 |
+
|
363 |
+
def forward(self, x):
|
364 |
+
# pre conv
|
365 |
+
x = self.conv_pre(x)
|
366 |
+
|
367 |
+
for i in range(self.num_upsamples):
|
368 |
+
# upsampling
|
369 |
+
for i_up in range(len(self.ups[i])):
|
370 |
+
x = self.ups[i][i_up](x)
|
371 |
+
# AMP blocks
|
372 |
+
xs = None
|
373 |
+
for j in range(self.num_kernels):
|
374 |
+
if xs is None:
|
375 |
+
xs = self.resblocks[i * self.num_kernels + j](x)
|
376 |
+
else:
|
377 |
+
xs += self.resblocks[i * self.num_kernels + j](x)
|
378 |
+
x = xs / self.num_kernels
|
379 |
+
|
380 |
+
# post conv
|
381 |
+
x = self.activation_post(x)
|
382 |
+
x = self.conv_post(x)
|
383 |
+
# final tanh activation
|
384 |
+
if self.use_tanh_at_final:
|
385 |
+
x = torch.tanh(x)
|
386 |
+
else:
|
387 |
+
x = torch.clamp(x, min=-1.0, max=1.0) # bound the output to [-1, 1]
|
388 |
+
|
389 |
+
return x
|
390 |
+
|
391 |
+
def remove_weight_norm(self):
|
392 |
+
print("Removing weight norm...")
|
393 |
+
for l in self.ups:
|
394 |
+
for l_i in l:
|
395 |
+
remove_weight_norm(l_i)
|
396 |
+
for l in self.resblocks:
|
397 |
+
l.remove_weight_norm()
|
398 |
+
remove_weight_norm(self.conv_pre)
|
399 |
+
remove_weight_norm(self.conv_post)
|
fireredtts/modules/codec/speaker.py
ADDED
@@ -0,0 +1,1052 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from librosa.filters import mel as librosa_mel_fn
|
2 |
+
from torch import nn
|
3 |
+
from torch.nn import functional as F
|
4 |
+
|
5 |
+
import math
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
import torchaudio
|
9 |
+
|
10 |
+
|
11 |
+
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
|
12 |
+
return torch.log(torch.clamp(x, min=clip_val) * C)
|
13 |
+
|
14 |
+
|
15 |
+
def spectral_normalize_torch(magnitudes):
|
16 |
+
output = dynamic_range_compression_torch(magnitudes)
|
17 |
+
return output
|
18 |
+
|
19 |
+
|
20 |
+
class TorchMelSpectrogram(nn.Module):
|
21 |
+
|
22 |
+
def __init__(
|
23 |
+
self,
|
24 |
+
filter_length=1024,
|
25 |
+
hop_length=160,
|
26 |
+
win_length=640,
|
27 |
+
n_mel_channels=80,
|
28 |
+
mel_fmin=0,
|
29 |
+
mel_fmax=8000,
|
30 |
+
sampling_rate=16000,
|
31 |
+
):
|
32 |
+
|
33 |
+
super().__init__()
|
34 |
+
self.filter_length = filter_length
|
35 |
+
self.hop_length = hop_length
|
36 |
+
self.win_length = win_length
|
37 |
+
self.n_mel_channels = n_mel_channels
|
38 |
+
self.mel_fmin = mel_fmin
|
39 |
+
self.mel_fmax = mel_fmax
|
40 |
+
self.sampling_rate = sampling_rate
|
41 |
+
|
42 |
+
self.mel_basis = {}
|
43 |
+
self.hann_window = {}
|
44 |
+
|
45 |
+
def forward(self, inp, length=None):
|
46 |
+
if len(inp.shape) == 3:
|
47 |
+
inp = inp.squeeze(1) if inp.shape[1] == 1 else inp.squeeze(2)
|
48 |
+
assert len(inp.shape) == 2
|
49 |
+
|
50 |
+
y = inp
|
51 |
+
if len(list(self.mel_basis.keys())) == 0:
|
52 |
+
mel = librosa_mel_fn(
|
53 |
+
sr=self.sampling_rate,
|
54 |
+
n_fft=self.filter_length,
|
55 |
+
n_mels=self.n_mel_channels,
|
56 |
+
fmin=self.mel_fmin,
|
57 |
+
fmax=self.mel_fmax,
|
58 |
+
)
|
59 |
+
self.mel_basis[str(self.mel_fmax) + "_" + str(y.device)] = (
|
60 |
+
torch.from_numpy(mel).float().to(y.device)
|
61 |
+
)
|
62 |
+
self.hann_window[str(y.device)] = torch.hann_window(self.win_length).to(
|
63 |
+
y.device
|
64 |
+
)
|
65 |
+
|
66 |
+
y = torch.nn.functional.pad(
|
67 |
+
y.unsqueeze(1),
|
68 |
+
(
|
69 |
+
int((self.filter_length - self.hop_length) / 2),
|
70 |
+
int((self.filter_length - self.hop_length) / 2),
|
71 |
+
),
|
72 |
+
mode="reflect",
|
73 |
+
)
|
74 |
+
y = y.squeeze(1)
|
75 |
+
|
76 |
+
# complex tensor as default, then use view_as_real for future pytorch compatibility
|
77 |
+
spec = torch.stft(
|
78 |
+
y,
|
79 |
+
self.filter_length,
|
80 |
+
hop_length=self.hop_length,
|
81 |
+
win_length=self.win_length,
|
82 |
+
window=self.hann_window[str(y.device)],
|
83 |
+
center=False,
|
84 |
+
pad_mode="reflect",
|
85 |
+
normalized=False,
|
86 |
+
onesided=True,
|
87 |
+
return_complex=True,
|
88 |
+
)
|
89 |
+
spec = torch.view_as_real(spec)
|
90 |
+
spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
|
91 |
+
|
92 |
+
spec = torch.matmul(
|
93 |
+
self.mel_basis[str(self.mel_fmax) + "_" + str(y.device)], spec
|
94 |
+
)
|
95 |
+
spec = spectral_normalize_torch(spec)
|
96 |
+
|
97 |
+
max_mel_length = math.ceil(y.shape[-1] / self.hop_length)
|
98 |
+
spec = spec[..., :max_mel_length].transpose(1, 2)
|
99 |
+
|
100 |
+
if length is None:
|
101 |
+
return spec
|
102 |
+
else:
|
103 |
+
spec_len = torch.ceil(length / self.hop_length).clamp(max=spec.shape[1])
|
104 |
+
return spec, spec_len
|
105 |
+
|
106 |
+
|
107 |
+
def length_to_mask(length, max_len=None, dtype=None, device=None):
|
108 |
+
"""Creates a binary mask for each sequence.
|
109 |
+
|
110 |
+
Reference: https://discuss.pytorch.org/t/how-to-generate-variable-length-mask/23397/3
|
111 |
+
|
112 |
+
Arguments
|
113 |
+
---------
|
114 |
+
length : torch.LongTensor
|
115 |
+
Containing the length of each sequence in the batch. Must be 1D.
|
116 |
+
max_len : int
|
117 |
+
Max length for the mask, also the size of the second dimension.
|
118 |
+
dtype : torch.dtype, default: None
|
119 |
+
The dtype of the generated mask.
|
120 |
+
device: torch.device, default: None
|
121 |
+
The device to put the mask variable.
|
122 |
+
|
123 |
+
Returns
|
124 |
+
-------
|
125 |
+
mask : tensor
|
126 |
+
The binary mask.
|
127 |
+
|
128 |
+
Example
|
129 |
+
-------
|
130 |
+
>>> length=torch.Tensor([1,2,3])
|
131 |
+
>>> mask=length_to_mask(length)
|
132 |
+
>>> mask
|
133 |
+
tensor([[1., 0., 0.],
|
134 |
+
[1., 1., 0.],
|
135 |
+
[1., 1., 1.]])
|
136 |
+
"""
|
137 |
+
assert len(length.shape) == 1
|
138 |
+
|
139 |
+
if max_len is None:
|
140 |
+
max_len = length.max().long().item() # using arange to generate mask
|
141 |
+
mask = torch.arange(max_len, device=length.device, dtype=length.dtype).expand(
|
142 |
+
len(length), max_len
|
143 |
+
) < length.unsqueeze(1)
|
144 |
+
|
145 |
+
if dtype is None:
|
146 |
+
dtype = length.dtype
|
147 |
+
|
148 |
+
if device is None:
|
149 |
+
device = length.device
|
150 |
+
|
151 |
+
mask = torch.as_tensor(mask, dtype=dtype, device=device)
|
152 |
+
return mask
|
153 |
+
|
154 |
+
|
155 |
+
def get_padding_elem(L_in: int, stride: int, kernel_size: int, dilation: int):
|
156 |
+
"""This function computes the number of elements to add for zero-padding.
|
157 |
+
|
158 |
+
Arguments
|
159 |
+
---------
|
160 |
+
L_in : int
|
161 |
+
stride: int
|
162 |
+
kernel_size : int
|
163 |
+
dilation : int
|
164 |
+
"""
|
165 |
+
if stride > 1:
|
166 |
+
n_steps = math.ceil(((L_in - kernel_size * dilation) / stride) + 1)
|
167 |
+
L_out = stride * (n_steps - 1) + kernel_size * dilation
|
168 |
+
padding = [kernel_size // 2, kernel_size // 2]
|
169 |
+
|
170 |
+
else:
|
171 |
+
L_out = (L_in - dilation * (kernel_size - 1) - 1) // stride + 1
|
172 |
+
|
173 |
+
padding = [(L_in - L_out) // 2, (L_in - L_out) // 2]
|
174 |
+
return padding
|
175 |
+
|
176 |
+
|
177 |
+
class Conv1d(nn.Module):
|
178 |
+
"""This function implements 1d convolution.
|
179 |
+
|
180 |
+
Arguments
|
181 |
+
---------
|
182 |
+
out_channels : int
|
183 |
+
It is the number of output channels.
|
184 |
+
kernel_size : int
|
185 |
+
Kernel size of the convolutional filters.
|
186 |
+
input_shape : tuple
|
187 |
+
The shape of the input. Alternatively use ``in_channels``.
|
188 |
+
in_channels : int
|
189 |
+
The number of input channels. Alternatively use ``input_shape``.
|
190 |
+
stride : int
|
191 |
+
Stride factor of the convolutional filters. When the stride factor > 1,
|
192 |
+
a decimation in time is performed.
|
193 |
+
dilation : int
|
194 |
+
Dilation factor of the convolutional filters.
|
195 |
+
padding : str
|
196 |
+
(same, valid, causal). If "valid", no padding is performed.
|
197 |
+
If "same" and stride is 1, output shape is the same as the input shape.
|
198 |
+
"causal" results in causal (dilated) convolutions.
|
199 |
+
padding_mode : str
|
200 |
+
This flag specifies the type of padding. See torch.nn documentation
|
201 |
+
for more information.
|
202 |
+
skip_transpose : bool
|
203 |
+
If False, uses batch x time x channel convention of speechbrain.
|
204 |
+
If True, uses batch x channel x time convention.
|
205 |
+
|
206 |
+
Example
|
207 |
+
-------
|
208 |
+
>>> inp_tensor = torch.rand([10, 40, 16])
|
209 |
+
>>> cnn_1d = Conv1d(
|
210 |
+
... input_shape=inp_tensor.shape, out_channels=8, kernel_size=5
|
211 |
+
... )
|
212 |
+
>>> out_tensor = cnn_1d(inp_tensor)
|
213 |
+
>>> out_tensor.shape
|
214 |
+
torch.Size([10, 40, 8])
|
215 |
+
"""
|
216 |
+
|
217 |
+
def __init__(
|
218 |
+
self,
|
219 |
+
out_channels,
|
220 |
+
kernel_size,
|
221 |
+
input_shape=None,
|
222 |
+
in_channels=None,
|
223 |
+
stride=1,
|
224 |
+
dilation=1,
|
225 |
+
padding="same",
|
226 |
+
groups=1,
|
227 |
+
bias=True,
|
228 |
+
padding_mode="reflect",
|
229 |
+
skip_transpose=True,
|
230 |
+
):
|
231 |
+
super().__init__()
|
232 |
+
self.kernel_size = kernel_size
|
233 |
+
self.stride = stride
|
234 |
+
self.dilation = dilation
|
235 |
+
self.padding = padding
|
236 |
+
self.padding_mode = padding_mode
|
237 |
+
self.unsqueeze = False
|
238 |
+
self.skip_transpose = skip_transpose
|
239 |
+
|
240 |
+
if input_shape is None and in_channels is None:
|
241 |
+
raise ValueError("Must provide one of input_shape or in_channels")
|
242 |
+
|
243 |
+
if in_channels is None:
|
244 |
+
in_channels = self._check_input_shape(input_shape)
|
245 |
+
|
246 |
+
self.conv = nn.Conv1d(
|
247 |
+
in_channels,
|
248 |
+
out_channels,
|
249 |
+
self.kernel_size,
|
250 |
+
stride=self.stride,
|
251 |
+
dilation=self.dilation,
|
252 |
+
padding=0,
|
253 |
+
groups=groups,
|
254 |
+
bias=bias,
|
255 |
+
)
|
256 |
+
|
257 |
+
def forward(self, x):
|
258 |
+
"""Returns the output of the convolution.
|
259 |
+
|
260 |
+
Arguments
|
261 |
+
---------
|
262 |
+
x : torch.Tensor (batch, time, channel)
|
263 |
+
input to convolve. 2d or 4d tensors are expected.
|
264 |
+
"""
|
265 |
+
|
266 |
+
if not self.skip_transpose:
|
267 |
+
x = x.transpose(1, -1)
|
268 |
+
|
269 |
+
if self.unsqueeze:
|
270 |
+
x = x.unsqueeze(1)
|
271 |
+
|
272 |
+
if self.padding == "same":
|
273 |
+
x = self._manage_padding(x, self.kernel_size, self.dilation, self.stride)
|
274 |
+
|
275 |
+
elif self.padding == "causal":
|
276 |
+
num_pad = (self.kernel_size - 1) * self.dilation
|
277 |
+
x = F.pad(x, (num_pad, 0))
|
278 |
+
|
279 |
+
elif self.padding == "valid":
|
280 |
+
pass
|
281 |
+
|
282 |
+
else:
|
283 |
+
raise ValueError(
|
284 |
+
"Padding must be 'same', 'valid' or 'causal'. Got " + self.padding
|
285 |
+
)
|
286 |
+
|
287 |
+
wx = self.conv(x)
|
288 |
+
|
289 |
+
if self.unsqueeze:
|
290 |
+
wx = wx.squeeze(1)
|
291 |
+
|
292 |
+
if not self.skip_transpose:
|
293 |
+
wx = wx.transpose(1, -1)
|
294 |
+
|
295 |
+
return wx
|
296 |
+
|
297 |
+
def _manage_padding(
|
298 |
+
self,
|
299 |
+
x,
|
300 |
+
kernel_size: int,
|
301 |
+
dilation: int,
|
302 |
+
stride: int,
|
303 |
+
):
|
304 |
+
"""This function performs zero-padding on the time axis
|
305 |
+
such that their lengths is unchanged after the convolution.
|
306 |
+
|
307 |
+
Arguments
|
308 |
+
---------
|
309 |
+
x : torch.Tensor
|
310 |
+
Input tensor.
|
311 |
+
kernel_size : int
|
312 |
+
Size of kernel.
|
313 |
+
dilation : int
|
314 |
+
Dilation used.
|
315 |
+
stride : int
|
316 |
+
Stride.
|
317 |
+
"""
|
318 |
+
|
319 |
+
# Detecting input shape
|
320 |
+
L_in = x.shape[-1]
|
321 |
+
|
322 |
+
# Time padding
|
323 |
+
padding = get_padding_elem(L_in, stride, kernel_size, dilation)
|
324 |
+
|
325 |
+
# Applying padding
|
326 |
+
x = F.pad(x, padding, mode=self.padding_mode)
|
327 |
+
|
328 |
+
return x
|
329 |
+
|
330 |
+
def _check_input_shape(self, shape):
|
331 |
+
"""Checks the input shape and returns the number of input channels."""
|
332 |
+
|
333 |
+
if len(shape) == 2:
|
334 |
+
self.unsqueeze = True
|
335 |
+
in_channels = 1
|
336 |
+
elif self.skip_transpose:
|
337 |
+
in_channels = shape[1]
|
338 |
+
elif len(shape) == 3:
|
339 |
+
in_channels = shape[2]
|
340 |
+
else:
|
341 |
+
raise ValueError("conv1d expects 2d, 3d inputs. Got " + str(len(shape)))
|
342 |
+
|
343 |
+
# Kernel size must be odd
|
344 |
+
if self.kernel_size % 2 == 0:
|
345 |
+
raise ValueError(
|
346 |
+
"The field kernel size must be an odd number. Got %s."
|
347 |
+
% (self.kernel_size)
|
348 |
+
)
|
349 |
+
return in_channels
|
350 |
+
|
351 |
+
|
352 |
+
class Fp32BatchNorm(nn.Module):
|
353 |
+
def __init__(self, sync=True, *args, **kwargs):
|
354 |
+
super().__init__()
|
355 |
+
|
356 |
+
if (
|
357 |
+
not torch.distributed.is_initialized()
|
358 |
+
or torch.distributed.get_world_size() == 1
|
359 |
+
):
|
360 |
+
sync = False
|
361 |
+
|
362 |
+
if sync:
|
363 |
+
self.bn = nn.SyncBatchNorm(*args, **kwargs)
|
364 |
+
else:
|
365 |
+
self.bn = nn.BatchNorm1d(*args, **kwargs)
|
366 |
+
|
367 |
+
self.sync = sync
|
368 |
+
|
369 |
+
def forward(self, input):
|
370 |
+
if self.bn.running_mean.dtype != torch.float:
|
371 |
+
if self.sync:
|
372 |
+
self.bn.running_mean = self.bn.running_mean.float()
|
373 |
+
self.bn.running_var = self.bn.running_var.float()
|
374 |
+
if self.bn.affine:
|
375 |
+
try:
|
376 |
+
self.bn.weight = self.bn.weight.float()
|
377 |
+
self.bn.bias = self.bn.bias.float()
|
378 |
+
except:
|
379 |
+
self.bn.float()
|
380 |
+
else:
|
381 |
+
self.bn.float()
|
382 |
+
|
383 |
+
output = self.bn(input.float())
|
384 |
+
return output.type_as(input)
|
385 |
+
|
386 |
+
|
387 |
+
class BatchNorm1d(nn.Module):
|
388 |
+
"""Applies 1d batch normalization to the input tensor.
|
389 |
+
|
390 |
+
Arguments
|
391 |
+
---------
|
392 |
+
input_shape : tuple
|
393 |
+
The expected shape of the input. Alternatively, use ``input_size``.
|
394 |
+
input_size : int
|
395 |
+
The expected size of the input. Alternatively, use ``input_shape``.
|
396 |
+
eps : float
|
397 |
+
This value is added to std deviation estimation to improve the numerical
|
398 |
+
stability.
|
399 |
+
momentum : float
|
400 |
+
It is a value used for the running_mean and running_var computation.
|
401 |
+
affine : bool
|
402 |
+
When set to True, the affine parameters are learned.
|
403 |
+
track_running_stats : bool
|
404 |
+
When set to True, this module tracks the running mean and variance,
|
405 |
+
and when set to False, this module does not track such statistics.
|
406 |
+
combine_batch_time : bool
|
407 |
+
When true, it combines batch an time axis.
|
408 |
+
|
409 |
+
|
410 |
+
Example
|
411 |
+
-------
|
412 |
+
>>> input = torch.randn(100, 10)
|
413 |
+
>>> norm = BatchNorm1d(input_shape=input.shape)
|
414 |
+
>>> output = norm(input)
|
415 |
+
>>> output.shape
|
416 |
+
torch.Size([100, 10])
|
417 |
+
"""
|
418 |
+
|
419 |
+
def __init__(
|
420 |
+
self,
|
421 |
+
input_shape=None,
|
422 |
+
input_size=None,
|
423 |
+
eps=1e-05,
|
424 |
+
momentum=0.1,
|
425 |
+
affine=True,
|
426 |
+
track_running_stats=True,
|
427 |
+
combine_batch_time=False,
|
428 |
+
skip_transpose=True,
|
429 |
+
enabled=True,
|
430 |
+
):
|
431 |
+
super().__init__()
|
432 |
+
self.combine_batch_time = combine_batch_time
|
433 |
+
self.skip_transpose = skip_transpose
|
434 |
+
|
435 |
+
if input_size is None and skip_transpose:
|
436 |
+
input_size = input_shape[1]
|
437 |
+
elif input_size is None:
|
438 |
+
input_size = input_shape[-1]
|
439 |
+
|
440 |
+
if enabled:
|
441 |
+
self.norm = Fp32BatchNorm(
|
442 |
+
num_features=input_size,
|
443 |
+
eps=eps,
|
444 |
+
momentum=momentum,
|
445 |
+
affine=affine,
|
446 |
+
track_running_stats=track_running_stats,
|
447 |
+
)
|
448 |
+
else:
|
449 |
+
self.norm = nn.Identity()
|
450 |
+
|
451 |
+
def forward(self, x):
|
452 |
+
"""Returns the normalized input tensor.
|
453 |
+
|
454 |
+
Arguments
|
455 |
+
---------
|
456 |
+
x : torch.Tensor (batch, time, [channels])
|
457 |
+
input to normalize. 2d or 3d tensors are expected in input
|
458 |
+
4d tensors can be used when combine_dims=True.
|
459 |
+
"""
|
460 |
+
shape_or = x.shape
|
461 |
+
if self.combine_batch_time:
|
462 |
+
if x.ndim == 3:
|
463 |
+
x = x.reshape(shape_or[0] * shape_or[1], shape_or[2])
|
464 |
+
else:
|
465 |
+
x = x.reshape(shape_or[0] * shape_or[1], shape_or[3], shape_or[2])
|
466 |
+
|
467 |
+
elif not self.skip_transpose:
|
468 |
+
x = x.transpose(-1, 1)
|
469 |
+
|
470 |
+
x_n = self.norm(x)
|
471 |
+
|
472 |
+
if self.combine_batch_time:
|
473 |
+
x_n = x_n.reshape(shape_or)
|
474 |
+
elif not self.skip_transpose:
|
475 |
+
x_n = x_n.transpose(1, -1)
|
476 |
+
|
477 |
+
return x_n
|
478 |
+
|
479 |
+
|
480 |
+
class Linear(torch.nn.Module):
|
481 |
+
"""Computes a linear transformation y = wx + b.
|
482 |
+
|
483 |
+
Arguments
|
484 |
+
---------
|
485 |
+
n_neurons : int
|
486 |
+
It is the number of output neurons (i.e, the dimensionality of the
|
487 |
+
output).
|
488 |
+
bias : bool
|
489 |
+
If True, the additive bias b is adopted.
|
490 |
+
combine_dims : bool
|
491 |
+
If True and the input is 4D, combine 3rd and 4th dimensions of input.
|
492 |
+
|
493 |
+
Example
|
494 |
+
-------
|
495 |
+
>>> inputs = torch.rand(10, 50, 40)
|
496 |
+
>>> lin_t = Linear(input_shape=(10, 50, 40), n_neurons=100)
|
497 |
+
>>> output = lin_t(inputs)
|
498 |
+
>>> output.shape
|
499 |
+
torch.Size([10, 50, 100])
|
500 |
+
"""
|
501 |
+
|
502 |
+
def __init__(
|
503 |
+
self,
|
504 |
+
n_neurons,
|
505 |
+
input_shape=None,
|
506 |
+
input_size=None,
|
507 |
+
bias=True,
|
508 |
+
combine_dims=False,
|
509 |
+
):
|
510 |
+
super().__init__()
|
511 |
+
self.combine_dims = combine_dims
|
512 |
+
|
513 |
+
if input_shape is None and input_size is None:
|
514 |
+
raise ValueError("Expected one of input_shape or input_size")
|
515 |
+
|
516 |
+
if input_size is None:
|
517 |
+
input_size = input_shape[-1]
|
518 |
+
if len(input_shape) == 4 and self.combine_dims:
|
519 |
+
input_size = input_shape[2] * input_shape[3]
|
520 |
+
|
521 |
+
# Weights are initialized following pytorch approach
|
522 |
+
self.w = nn.Linear(input_size, n_neurons, bias=bias)
|
523 |
+
|
524 |
+
def forward(self, x):
|
525 |
+
"""Returns the linear transformation of input tensor.
|
526 |
+
|
527 |
+
Arguments
|
528 |
+
---------
|
529 |
+
x : torch.Tensor
|
530 |
+
Input to transform linearly.
|
531 |
+
"""
|
532 |
+
if x.ndim == 4 and self.combine_dims:
|
533 |
+
x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3])
|
534 |
+
|
535 |
+
wx = self.w(x)
|
536 |
+
|
537 |
+
return wx
|
538 |
+
|
539 |
+
|
540 |
+
class TDNNBlock(nn.Module):
|
541 |
+
"""An implementation of TDNN.
|
542 |
+
|
543 |
+
Arguments
|
544 |
+
----------
|
545 |
+
in_channels : int
|
546 |
+
Number of input channels.
|
547 |
+
out_channels : int
|
548 |
+
The number of output channels.
|
549 |
+
kernel_size : int
|
550 |
+
The kernel size of the TDNN blocks.
|
551 |
+
dilation : int
|
552 |
+
The dilation of the Res2Net block.
|
553 |
+
activation : torch class
|
554 |
+
A class for constructing the activation layers.
|
555 |
+
|
556 |
+
Example
|
557 |
+
-------
|
558 |
+
>>> inp_tensor = torch.rand([8, 120, 64]).transpose(1, 2)
|
559 |
+
>>> layer = TDNNBlock(64, 64, kernel_size=3, dilation=1)
|
560 |
+
>>> out_tensor = layer(inp_tensor).transpose(1, 2)
|
561 |
+
>>> out_tensor.shape
|
562 |
+
torch.Size([8, 120, 64])
|
563 |
+
"""
|
564 |
+
|
565 |
+
def __init__(
|
566 |
+
self,
|
567 |
+
in_channels,
|
568 |
+
out_channels,
|
569 |
+
kernel_size,
|
570 |
+
dilation,
|
571 |
+
activation=nn.ReLU,
|
572 |
+
batch_norm=True,
|
573 |
+
):
|
574 |
+
super(TDNNBlock, self).__init__()
|
575 |
+
self.conv = Conv1d(
|
576 |
+
in_channels=in_channels,
|
577 |
+
out_channels=out_channels,
|
578 |
+
kernel_size=kernel_size,
|
579 |
+
dilation=dilation,
|
580 |
+
)
|
581 |
+
self.activation = activation()
|
582 |
+
self.norm = BatchNorm1d(input_size=out_channels, enabled=batch_norm)
|
583 |
+
|
584 |
+
def forward(self, x):
|
585 |
+
return self.norm(self.activation(self.conv(x)))
|
586 |
+
|
587 |
+
|
588 |
+
class Res2NetBlock(torch.nn.Module):
|
589 |
+
"""An implementation of Res2NetBlock w/ dilation.
|
590 |
+
|
591 |
+
Arguments
|
592 |
+
---------
|
593 |
+
in_channels : int
|
594 |
+
The number of channels expected in the input.
|
595 |
+
out_channels : int
|
596 |
+
The number of output channels.
|
597 |
+
scale : int
|
598 |
+
The scale of the Res2Net block.
|
599 |
+
kernel_size: int
|
600 |
+
The kernel size of the Res2Net block.
|
601 |
+
dilation : int
|
602 |
+
The dilation of the Res2Net block.
|
603 |
+
|
604 |
+
Example
|
605 |
+
-------
|
606 |
+
>>> inp_tensor = torch.rand([8, 120, 64]).transpose(1, 2)
|
607 |
+
>>> layer = Res2NetBlock(64, 64, scale=4, dilation=3)
|
608 |
+
>>> out_tensor = layer(inp_tensor).transpose(1, 2)
|
609 |
+
>>> out_tensor.shape
|
610 |
+
torch.Size([8, 120, 64])
|
611 |
+
"""
|
612 |
+
|
613 |
+
def __init__(
|
614 |
+
self,
|
615 |
+
in_channels,
|
616 |
+
out_channels,
|
617 |
+
scale=8,
|
618 |
+
kernel_size=3,
|
619 |
+
dilation=1,
|
620 |
+
batch_norm=True,
|
621 |
+
):
|
622 |
+
super(Res2NetBlock, self).__init__()
|
623 |
+
assert in_channels % scale == 0
|
624 |
+
assert out_channels % scale == 0
|
625 |
+
|
626 |
+
in_channel = in_channels // scale
|
627 |
+
hidden_channel = out_channels // scale
|
628 |
+
|
629 |
+
self.blocks = nn.ModuleList(
|
630 |
+
[
|
631 |
+
TDNNBlock(
|
632 |
+
in_channel,
|
633 |
+
hidden_channel,
|
634 |
+
kernel_size=kernel_size,
|
635 |
+
dilation=dilation,
|
636 |
+
batch_norm=batch_norm,
|
637 |
+
)
|
638 |
+
for i in range(scale - 1)
|
639 |
+
]
|
640 |
+
)
|
641 |
+
self.scale = scale
|
642 |
+
|
643 |
+
def forward(self, x):
|
644 |
+
y = []
|
645 |
+
for i, x_i in enumerate(torch.chunk(x, self.scale, dim=1)):
|
646 |
+
if i == 0:
|
647 |
+
y_i = x_i
|
648 |
+
elif i == 1:
|
649 |
+
y_i = self.blocks[i - 1](x_i)
|
650 |
+
else:
|
651 |
+
y_i = self.blocks[i - 1](x_i + y_i)
|
652 |
+
y.append(y_i)
|
653 |
+
y = torch.cat(y, dim=1)
|
654 |
+
return y
|
655 |
+
|
656 |
+
|
657 |
+
class SEBlock(nn.Module):
|
658 |
+
"""An implementation of squeeze-and-excitation block.
|
659 |
+
|
660 |
+
Arguments
|
661 |
+
---------
|
662 |
+
in_channels : int
|
663 |
+
The number of input channels.
|
664 |
+
se_channels : int
|
665 |
+
The number of output channels after squeeze.
|
666 |
+
out_channels : int
|
667 |
+
The number of output channels.
|
668 |
+
|
669 |
+
Example
|
670 |
+
-------
|
671 |
+
>>> inp_tensor = torch.rand([8, 120, 64]).transpose(1, 2)
|
672 |
+
>>> se_layer = SEBlock(64, 16, 64)
|
673 |
+
>>> lengths = torch.rand((8,))
|
674 |
+
>>> out_tensor = se_layer(inp_tensor, lengths).transpose(1, 2)
|
675 |
+
>>> out_tensor.shape
|
676 |
+
torch.Size([8, 120, 64])
|
677 |
+
"""
|
678 |
+
|
679 |
+
def __init__(self, in_channels, se_channels, out_channels):
|
680 |
+
super(SEBlock, self).__init__()
|
681 |
+
|
682 |
+
self.conv1 = Conv1d(
|
683 |
+
in_channels=in_channels, out_channels=se_channels, kernel_size=1
|
684 |
+
)
|
685 |
+
self.relu = torch.nn.ReLU(inplace=True)
|
686 |
+
self.conv2 = Conv1d(
|
687 |
+
in_channels=se_channels, out_channels=out_channels, kernel_size=1
|
688 |
+
)
|
689 |
+
self.sigmoid = torch.nn.Sigmoid()
|
690 |
+
|
691 |
+
def forward(self, x, lengths=None):
|
692 |
+
L = x.shape[-1]
|
693 |
+
if lengths is not None:
|
694 |
+
mask = length_to_mask(lengths * L, max_len=L, device=x.device)
|
695 |
+
mask = mask.unsqueeze(1)
|
696 |
+
total = mask.sum(dim=2, keepdim=True)
|
697 |
+
s = (x * mask).sum(dim=2, keepdim=True) / total
|
698 |
+
else:
|
699 |
+
s = x.mean(dim=2, keepdim=True)
|
700 |
+
|
701 |
+
s = self.relu(self.conv1(s))
|
702 |
+
s = self.sigmoid(self.conv2(s))
|
703 |
+
|
704 |
+
return s * x
|
705 |
+
|
706 |
+
|
707 |
+
class AttentiveStatisticsPooling(nn.Module):
|
708 |
+
"""This class implements an attentive statistic pooling layer for each channel.
|
709 |
+
It returns the concatenated mean and std of the input tensor.
|
710 |
+
|
711 |
+
Arguments
|
712 |
+
---------
|
713 |
+
channels: int
|
714 |
+
The number of input channels.
|
715 |
+
attention_channels: int
|
716 |
+
The number of attention channels.
|
717 |
+
|
718 |
+
Example
|
719 |
+
-------
|
720 |
+
>>> inp_tensor = torch.rand([8, 120, 64]).transpose(1, 2)
|
721 |
+
>>> asp_layer = AttentiveStatisticsPooling(64)
|
722 |
+
>>> lengths = torch.rand((8,))
|
723 |
+
>>> out_tensor = asp_layer(inp_tensor, lengths).transpose(1, 2)
|
724 |
+
>>> out_tensor.shape
|
725 |
+
torch.Size([8, 1, 128])
|
726 |
+
"""
|
727 |
+
|
728 |
+
def __init__(
|
729 |
+
self, channels, attention_channels=128, global_context=True, batch_norm=True
|
730 |
+
):
|
731 |
+
super().__init__()
|
732 |
+
|
733 |
+
self.eps = 1e-12
|
734 |
+
self.global_context = global_context
|
735 |
+
if global_context:
|
736 |
+
self.tdnn = TDNNBlock(
|
737 |
+
channels * 3, attention_channels, 1, 1, batch_norm=batch_norm
|
738 |
+
)
|
739 |
+
else:
|
740 |
+
self.tdnn = TDNNBlock(
|
741 |
+
channels, attention_channels, 1, 1, batch_norm, batch_norm
|
742 |
+
)
|
743 |
+
self.tanh = nn.Tanh()
|
744 |
+
self.conv = Conv1d(
|
745 |
+
in_channels=attention_channels, out_channels=channels, kernel_size=1
|
746 |
+
)
|
747 |
+
|
748 |
+
def forward(self, x, lengths=None):
|
749 |
+
"""Calculates mean and std for a batch (input tensor).
|
750 |
+
|
751 |
+
Arguments
|
752 |
+
---------
|
753 |
+
x : torch.Tensor
|
754 |
+
Tensor of shape [N, C, L].
|
755 |
+
"""
|
756 |
+
L = x.shape[-1]
|
757 |
+
|
758 |
+
def _compute_statistics(x, m, dim=2, eps=self.eps):
|
759 |
+
mean = (m * x).sum(dim)
|
760 |
+
std = torch.sqrt((m * (x - mean.unsqueeze(dim)).pow(2)).sum(dim).clamp(eps))
|
761 |
+
return mean, std
|
762 |
+
|
763 |
+
if lengths is None:
|
764 |
+
lengths = torch.ones(x.shape[0], device=x.device)
|
765 |
+
|
766 |
+
# Make binary mask of shape [N, 1, L]
|
767 |
+
mask = length_to_mask(lengths * L, max_len=L, device=x.device)
|
768 |
+
mask = mask.unsqueeze(1)
|
769 |
+
|
770 |
+
# Expand the temporal context of the pooling layer by allowing the
|
771 |
+
# self-attention to look at global properties of the utterance.
|
772 |
+
if self.global_context:
|
773 |
+
# torch.std is unstable for backward computation
|
774 |
+
# https://github.com/pytorch/pytorch/issues/4320
|
775 |
+
total = mask.sum(dim=2, keepdim=True).float()
|
776 |
+
mean, std = _compute_statistics(x, mask / total)
|
777 |
+
mean = mean.unsqueeze(2).repeat(1, 1, L)
|
778 |
+
std = std.unsqueeze(2).repeat(1, 1, L)
|
779 |
+
attn = torch.cat([x, mean, std], dim=1)
|
780 |
+
else:
|
781 |
+
attn = x
|
782 |
+
|
783 |
+
# Apply layers
|
784 |
+
attn = self.conv(self.tanh(self.tdnn(attn)))
|
785 |
+
|
786 |
+
# Filter out zero-paddings
|
787 |
+
attn = attn.masked_fill(mask == 0, float("-inf"))
|
788 |
+
|
789 |
+
attn = F.softmax(attn, dim=2)
|
790 |
+
mean, std = _compute_statistics(x, attn)
|
791 |
+
# Append mean and std of the batch
|
792 |
+
pooled_stats = torch.cat((mean, std), dim=1)
|
793 |
+
pooled_stats = pooled_stats.unsqueeze(2)
|
794 |
+
|
795 |
+
return pooled_stats
|
796 |
+
|
797 |
+
|
798 |
+
class SERes2NetBlock(nn.Module):
|
799 |
+
"""An implementation of building block in ECAPA-TDNN, i.e.,
|
800 |
+
TDNN-Res2Net-TDNN-SEBlock.
|
801 |
+
|
802 |
+
Arguments
|
803 |
+
----------
|
804 |
+
out_channels: int
|
805 |
+
The number of output channels.
|
806 |
+
res2net_scale: int
|
807 |
+
The scale of the Res2Net block.
|
808 |
+
kernel_size: int
|
809 |
+
The kernel size of the TDNN blocks.
|
810 |
+
dilation: int
|
811 |
+
The dilation of the Res2Net block.
|
812 |
+
activation : torch class
|
813 |
+
A class for constructing the activation layers.
|
814 |
+
|
815 |
+
Example
|
816 |
+
-------
|
817 |
+
>>> x = torch.rand(8, 120, 64).transpose(1, 2)
|
818 |
+
>>> conv = SERes2NetBlock(64, 64, res2net_scale=4)
|
819 |
+
>>> out = conv(x).transpose(1, 2)
|
820 |
+
>>> out.shape
|
821 |
+
torch.Size([8, 120, 64])
|
822 |
+
"""
|
823 |
+
|
824 |
+
def __init__(
|
825 |
+
self,
|
826 |
+
in_channels,
|
827 |
+
out_channels,
|
828 |
+
res2net_scale=8,
|
829 |
+
se_channels=128,
|
830 |
+
kernel_size=1,
|
831 |
+
dilation=1,
|
832 |
+
activation=torch.nn.ReLU,
|
833 |
+
batch_norm=True,
|
834 |
+
):
|
835 |
+
super().__init__()
|
836 |
+
self.out_channels = out_channels
|
837 |
+
self.tdnn1 = TDNNBlock(
|
838 |
+
in_channels,
|
839 |
+
out_channels,
|
840 |
+
kernel_size=1,
|
841 |
+
dilation=1,
|
842 |
+
activation=activation,
|
843 |
+
batch_norm=batch_norm,
|
844 |
+
)
|
845 |
+
self.res2net_block = Res2NetBlock(
|
846 |
+
out_channels,
|
847 |
+
out_channels,
|
848 |
+
res2net_scale,
|
849 |
+
kernel_size,
|
850 |
+
dilation,
|
851 |
+
batch_norm=batch_norm,
|
852 |
+
)
|
853 |
+
self.tdnn2 = TDNNBlock(
|
854 |
+
out_channels,
|
855 |
+
out_channels,
|
856 |
+
kernel_size=1,
|
857 |
+
dilation=1,
|
858 |
+
activation=activation,
|
859 |
+
batch_norm=batch_norm,
|
860 |
+
)
|
861 |
+
self.se_block = SEBlock(out_channels, se_channels, out_channels)
|
862 |
+
|
863 |
+
self.shortcut = None
|
864 |
+
if in_channels != out_channels:
|
865 |
+
self.shortcut = Conv1d(
|
866 |
+
in_channels=in_channels,
|
867 |
+
out_channels=out_channels,
|
868 |
+
kernel_size=1,
|
869 |
+
)
|
870 |
+
|
871 |
+
def forward(self, x, lengths=None):
|
872 |
+
residual = x
|
873 |
+
if self.shortcut:
|
874 |
+
residual = self.shortcut(x)
|
875 |
+
|
876 |
+
x = self.tdnn1(x)
|
877 |
+
x = self.res2net_block(x)
|
878 |
+
x = self.tdnn2(x)
|
879 |
+
x = self.se_block(x, lengths)
|
880 |
+
|
881 |
+
return x + residual
|
882 |
+
|
883 |
+
|
884 |
+
class ECAPA_TDNN(torch.nn.Module):
|
885 |
+
"""An implementation of the speaker embedding model in a paper.
|
886 |
+
"ECAPA-TDNN: Emphasized Channel Attention, Propagation and Aggregation in
|
887 |
+
TDNN Based Speaker Verification" (https://arxiv.org/abs/2005.07143).
|
888 |
+
|
889 |
+
Arguments
|
890 |
+
---------
|
891 |
+
device : str
|
892 |
+
Device used, e.g., "cpu" or "cuda".
|
893 |
+
activation : torch class
|
894 |
+
A class for constructing the activation layers.
|
895 |
+
channels : list of ints
|
896 |
+
Output channels for TDNN/SERes2Net layer.
|
897 |
+
kernel_sizes : list of ints
|
898 |
+
List of kernel sizes for each layer.
|
899 |
+
dilations : list of ints
|
900 |
+
List of dilations for kernels in each layer.
|
901 |
+
lin_neurons : int
|
902 |
+
Number of neurons in linear layers.
|
903 |
+
|
904 |
+
Example
|
905 |
+
-------
|
906 |
+
>>> input_feats = torch.rand([5, 120, 80])
|
907 |
+
>>> compute_embedding = ECAPA_TDNN(80, lin_neurons=192)
|
908 |
+
>>> outputs = compute_embedding(input_feats)
|
909 |
+
>>> outputs.shape
|
910 |
+
torch.Size([5, 1, 192])
|
911 |
+
"""
|
912 |
+
|
913 |
+
def __init__(
|
914 |
+
self,
|
915 |
+
input_size,
|
916 |
+
lin_neurons=192,
|
917 |
+
activation=torch.nn.ReLU,
|
918 |
+
channels=[512, 512, 512, 512, 1536],
|
919 |
+
kernel_sizes=[5, 3, 3, 3, 1],
|
920 |
+
dilations=[1, 2, 3, 4, 1],
|
921 |
+
attention_channels=128,
|
922 |
+
res2net_scale=8,
|
923 |
+
se_channels=128,
|
924 |
+
global_context=True,
|
925 |
+
batch_norm=True,
|
926 |
+
):
|
927 |
+
|
928 |
+
super().__init__()
|
929 |
+
assert len(channels) == len(kernel_sizes)
|
930 |
+
assert len(channels) == len(dilations)
|
931 |
+
self.channels = channels
|
932 |
+
self.blocks = nn.ModuleList()
|
933 |
+
|
934 |
+
# The initial TDNN layer
|
935 |
+
self.blocks.append(
|
936 |
+
TDNNBlock(
|
937 |
+
input_size,
|
938 |
+
channels[0],
|
939 |
+
kernel_sizes[0],
|
940 |
+
dilations[0],
|
941 |
+
activation,
|
942 |
+
batch_norm=batch_norm,
|
943 |
+
)
|
944 |
+
)
|
945 |
+
|
946 |
+
# SE-Res2Net layers
|
947 |
+
for i in range(1, len(channels) - 1):
|
948 |
+
self.blocks.append(
|
949 |
+
SERes2NetBlock(
|
950 |
+
channels[i - 1],
|
951 |
+
channels[i],
|
952 |
+
res2net_scale=res2net_scale,
|
953 |
+
se_channels=se_channels,
|
954 |
+
kernel_size=kernel_sizes[i],
|
955 |
+
dilation=dilations[i],
|
956 |
+
activation=activation,
|
957 |
+
batch_norm=batch_norm,
|
958 |
+
)
|
959 |
+
)
|
960 |
+
|
961 |
+
# Multi-layer feature aggregation
|
962 |
+
self.mfa = TDNNBlock(
|
963 |
+
channels[-1],
|
964 |
+
channels[-1],
|
965 |
+
kernel_sizes[-1],
|
966 |
+
dilations[-1],
|
967 |
+
activation,
|
968 |
+
batch_norm=batch_norm,
|
969 |
+
)
|
970 |
+
|
971 |
+
# Attentive Statistical Pooling
|
972 |
+
self.asp = AttentiveStatisticsPooling(
|
973 |
+
channels[-1],
|
974 |
+
attention_channels=attention_channels,
|
975 |
+
global_context=global_context,
|
976 |
+
batch_norm=batch_norm,
|
977 |
+
)
|
978 |
+
self.asp_bn = BatchNorm1d(input_size=channels[-1] * 2, enabled=batch_norm)
|
979 |
+
|
980 |
+
# Final linear transformation
|
981 |
+
self.fc = Conv1d(
|
982 |
+
in_channels=channels[-1] * 2,
|
983 |
+
out_channels=lin_neurons,
|
984 |
+
kernel_size=1,
|
985 |
+
)
|
986 |
+
|
987 |
+
# @torch.cuda.amp.autocast(enabled=True, dtype=torch.float32)
|
988 |
+
def forward(self, x, lengths=None):
|
989 |
+
"""Returns the embedding vector.
|
990 |
+
|
991 |
+
Arguments
|
992 |
+
---------
|
993 |
+
x : torch.Tensor
|
994 |
+
Tensor of shape (batch, time, channel).
|
995 |
+
"""
|
996 |
+
# Minimize transpose for efficiency
|
997 |
+
x = x.transpose(1, 2)
|
998 |
+
|
999 |
+
xl = []
|
1000 |
+
for layer in self.blocks:
|
1001 |
+
try:
|
1002 |
+
x = layer(x, lengths=lengths)
|
1003 |
+
except TypeError:
|
1004 |
+
x = layer(x)
|
1005 |
+
xl.append(x)
|
1006 |
+
|
1007 |
+
# Multi-layer feature aggregation
|
1008 |
+
x = torch.cat(xl[1:], dim=1)
|
1009 |
+
x = self.mfa(x)
|
1010 |
+
|
1011 |
+
# Attentive Statistical Pooling
|
1012 |
+
x = self.asp(x, lengths=lengths)
|
1013 |
+
x = self.asp_bn(x)
|
1014 |
+
|
1015 |
+
# Final linear transformation
|
1016 |
+
x = self.fc(x)
|
1017 |
+
|
1018 |
+
x = x.squeeze(-1)
|
1019 |
+
return x
|
1020 |
+
|
1021 |
+
|
1022 |
+
class SpeakerEmbedddingExtractor(object):
|
1023 |
+
|
1024 |
+
def __init__(self, ckpt_path, device="cuda"):
|
1025 |
+
# NOTE: The sampling rate is 16000
|
1026 |
+
self.mel_extractor = TorchMelSpectrogram()
|
1027 |
+
self.mel_extractor.to(device)
|
1028 |
+
model = ECAPA_TDNN(
|
1029 |
+
80,
|
1030 |
+
512,
|
1031 |
+
channels=[512, 512, 512, 512, 1536],
|
1032 |
+
kernel_sizes=[5, 3, 3, 3, 1],
|
1033 |
+
dilations=[1, 2, 3, 4, 1],
|
1034 |
+
attention_channels=128,
|
1035 |
+
res2net_scale=4,
|
1036 |
+
se_channels=128,
|
1037 |
+
global_context=True,
|
1038 |
+
batch_norm=True,
|
1039 |
+
)
|
1040 |
+
model.load_state_dict(torch.load(ckpt_path), strict=True)
|
1041 |
+
model.eval()
|
1042 |
+
self.model = model
|
1043 |
+
self.model.to(device)
|
1044 |
+
|
1045 |
+
def __call__(self, wav):
|
1046 |
+
# wav, sr = torchaudio.load(audio_path)
|
1047 |
+
# assert sr == 16000, f"The sampling rate is not 16000"
|
1048 |
+
# print(wav.shape)
|
1049 |
+
mel = self.mel_extractor(wav.unsqueeze(0))
|
1050 |
+
spk = self.model(mel)
|
1051 |
+
spk = spk[0]
|
1052 |
+
return spk
|
fireredtts/modules/flow/__init__.py
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from fireredtts.modules.flow.codec_embedding import HHGCodecEmbedding
|
2 |
+
from fireredtts.modules.flow.conformer import ConformerDecoderV2
|
3 |
+
from fireredtts.modules.flow.mel_encoder import MelReduceEncoder
|
4 |
+
from fireredtts.modules.flow.decoder import ConditionalCFM, ConditionalDecoder
|
5 |
+
from fireredtts.modules.flow.flow_model import InterpolateRegulator, CrossAttnFlowMatching
|
6 |
+
from fireredtts.modules.flow.mel_spectrogram import MelSpectrogramExtractor
|
7 |
+
|
8 |
+
|
9 |
+
def get_flow_frontend(flow_config):
|
10 |
+
flow = CrossAttnFlowMatching(
|
11 |
+
output_size=flow_config["output_size"],
|
12 |
+
input_embedding=HHGCodecEmbedding(**flow_config["input_embedding"]),
|
13 |
+
encoder=ConformerDecoderV2(**flow_config["encoder"]),
|
14 |
+
length_regulator=InterpolateRegulator(**flow_config["length_regulator"]),
|
15 |
+
mel_encoder=MelReduceEncoder(**flow_config["mel_encoder"]),
|
16 |
+
decoder=ConditionalCFM(
|
17 |
+
estimator=ConditionalDecoder(**flow_config["decoder"]["estimator"]),
|
18 |
+
t_scheduler=flow_config["decoder"]["t_scheduler"],
|
19 |
+
inference_cfg_rate=flow_config["decoder"]["inference_cfg_rate"]
|
20 |
+
)
|
21 |
+
)
|
22 |
+
return flow
|
23 |
+
|
24 |
+
|
fireredtts/modules/flow/codebook.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:f32d6280c561e4d0bf741e3ea07d651aee71da84c17fc55b686cb825bd677656
|
3 |
+
size 131200
|
fireredtts/modules/flow/codec_embedding.py
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
|
5 |
+
|
6 |
+
class HHGCodecEmbedding(nn.Module):
|
7 |
+
def __init__(self, out_channels, codebook_path:str, freeze=True):
|
8 |
+
super().__init__()
|
9 |
+
# (2, 128, 128)
|
10 |
+
codebook = torch.from_numpy(np.load(codebook_path).copy())
|
11 |
+
assert codebook.shape[0] == 2 and codebook.shape[1] == 128
|
12 |
+
self.codebook_dim = codebook.shape[2]
|
13 |
+
|
14 |
+
self.codebook = torch.nn.ModuleList([
|
15 |
+
torch.nn.Embedding.from_pretrained(codebook[i], freeze=freeze)
|
16 |
+
for i in range(codebook.shape[0])]
|
17 |
+
)
|
18 |
+
if self.codebook_dim * 2 != out_channels:
|
19 |
+
self.proj = nn.Linear(self.codebook_dim * 2, out_channels)
|
20 |
+
else:
|
21 |
+
self.proj = nn.Identity()
|
22 |
+
|
23 |
+
def forward(self, tokens):
|
24 |
+
token_embs = torch.cat([
|
25 |
+
self.codebook[0](tokens % 128),
|
26 |
+
self.codebook[1](tokens // 128)
|
27 |
+
], dim=-1)
|
28 |
+
token_embs = self.proj(token_embs)
|
29 |
+
return token_embs
|
30 |
+
|
fireredtts/modules/flow/conformer.py
ADDED
@@ -0,0 +1,730 @@
1 |
+
import typing as tp
|
2 |
+
import math
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
|
7 |
+
from fireredtts.modules.flow.utils import make_pad_mask
|
8 |
+
|
9 |
+
|
10 |
+
class MultiHeadedAttention(nn.Module):
|
11 |
+
"""Multi-Head Attention layer.
|
12 |
+
|
13 |
+
Args:
|
14 |
+
n_head (int): The number of heads.
|
15 |
+
n_feat (int): The number of features.
|
16 |
+
dropout_rate (float): Dropout rate.
|
17 |
+
|
18 |
+
"""
|
19 |
+
|
20 |
+
def __init__(self,
|
21 |
+
n_head: int,
|
22 |
+
n_feat: int,
|
23 |
+
dropout_rate: float,
|
24 |
+
key_bias: bool = True):
|
25 |
+
"""Construct an MultiHeadedAttention object."""
|
26 |
+
super().__init__()
|
27 |
+
assert n_feat % n_head == 0
|
28 |
+
# We assume d_v always equals d_k
|
29 |
+
self.d_k = n_feat // n_head
|
30 |
+
self.h = n_head
|
31 |
+
self.linear_q = nn.Linear(n_feat, n_feat)
|
32 |
+
self.linear_k = nn.Linear(n_feat, n_feat, bias=key_bias)
|
33 |
+
self.linear_v = nn.Linear(n_feat, n_feat)
|
34 |
+
self.linear_out = nn.Linear(n_feat, n_feat)
|
35 |
+
self.dropout = nn.Dropout(p=dropout_rate)
|
36 |
+
|
37 |
+
def forward_qkv(
|
38 |
+
self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
|
39 |
+
) -> tp.Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
40 |
+
"""Transform query, key and value.
|
41 |
+
|
42 |
+
Args:
|
43 |
+
query (torch.Tensor): Query tensor (#batch, time1, size).
|
44 |
+
key (torch.Tensor): Key tensor (#batch, time2, size).
|
45 |
+
value (torch.Tensor): Value tensor (#batch, time2, size).
|
46 |
+
|
47 |
+
Returns:
|
48 |
+
torch.Tensor: Transformed query tensor, size
|
49 |
+
(#batch, n_head, time1, d_k).
|
50 |
+
torch.Tensor: Transformed key tensor, size
|
51 |
+
(#batch, n_head, time2, d_k).
|
52 |
+
torch.Tensor: Transformed value tensor, size
|
53 |
+
(#batch, n_head, time2, d_k).
|
54 |
+
|
55 |
+
"""
|
56 |
+
n_batch = query.size(0)
|
57 |
+
q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
|
58 |
+
k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
|
59 |
+
v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
|
60 |
+
q = q.transpose(1, 2) # (batch, head, time1, d_k)
|
61 |
+
k = k.transpose(1, 2) # (batch, head, time2, d_k)
|
62 |
+
v = v.transpose(1, 2) # (batch, head, time2, d_k)
|
63 |
+
|
64 |
+
return q, k, v
|
65 |
+
|
66 |
+
def forward_attention(
|
67 |
+
self,
|
68 |
+
value: torch.Tensor,
|
69 |
+
scores: torch.Tensor,
|
70 |
+
mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool)
|
71 |
+
) -> torch.Tensor:
|
72 |
+
"""Compute attention context vector.
|
73 |
+
|
74 |
+
Args:
|
75 |
+
value (torch.Tensor): Transformed value, size
|
76 |
+
(#batch, n_head, time2, d_k).
|
77 |
+
scores (torch.Tensor): Attention score, size
|
78 |
+
(#batch, n_head, time1, time2).
|
79 |
+
mask (torch.Tensor): Mask, size (#batch, 1, time2) or
|
80 |
+
(#batch, time1, time2), (0, 0, 0) means fake mask.
|
81 |
+
|
82 |
+
Returns:
|
83 |
+
torch.Tensor: Transformed value (#batch, time1, d_model)
|
84 |
+
weighted by the attention score (#batch, time1, time2).
|
85 |
+
|
86 |
+
"""
|
87 |
+
n_batch = value.size(0)
|
88 |
+
# NOTE(xcsong): When will `if mask.size(2) > 0` be True?
|
89 |
+
# 1. onnx(16/4) [WHY? Because we feed real cache & real mask for the
|
90 |
+
# 1st chunk to ease the onnx export.]
|
91 |
+
# 2. pytorch training
|
92 |
+
if mask.size(2) > 0: # time2 > 0
|
93 |
+
mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
|
94 |
+
# For last chunk, time2 might be larger than scores.size(-1)
|
95 |
+
mask = mask[:, :, :, :scores.size(-1)] # (batch, 1, *, time2)
|
96 |
+
scores = scores.masked_fill(mask, -float('inf'))
|
97 |
+
attn = torch.softmax(scores, dim=-1).masked_fill(
|
98 |
+
mask, 0.0) # (batch, head, time1, time2)
|
99 |
+
# NOTE(xcsong): When will `if mask.size(2) > 0` be False?
|
100 |
+
# 1. onnx(16/-1, -1/-1, 16/0)
|
101 |
+
# 2. jit (16/-1, -1/-1, 16/0, 16/4)
|
102 |
+
else:
|
103 |
+
attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
|
104 |
+
|
105 |
+
p_attn = self.dropout(attn)
|
106 |
+
x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
|
107 |
+
x = (x.transpose(1, 2).contiguous().view(n_batch, -1,
|
108 |
+
self.h * self.d_k)
|
109 |
+
) # (batch, time1, d_model)
|
110 |
+
|
111 |
+
return self.linear_out(x) # (batch, time1, d_model)
|
112 |
+
|
113 |
+
def forward(
|
114 |
+
self,
|
115 |
+
query: torch.Tensor,
|
116 |
+
key: torch.Tensor,
|
117 |
+
value: torch.Tensor,
|
118 |
+
mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
|
119 |
+
pos_emb: torch.Tensor = torch.empty(0),
|
120 |
+
cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
|
121 |
+
) -> tp.Tuple[torch.Tensor, torch.Tensor]:
|
122 |
+
"""Compute scaled dot product attention.
|
123 |
+
|
124 |
+
Args:
|
125 |
+
query (torch.Tensor): Query tensor (#batch, time1, size).
|
126 |
+
key (torch.Tensor): Key tensor (#batch, time2, size).
|
127 |
+
value (torch.Tensor): Value tensor (#batch, time2, size).
|
128 |
+
mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
|
129 |
+
(#batch, time1, time2).
|
130 |
+
1.When applying cross attention between decoder and encoder,
|
131 |
+
the batch padding mask for input is in (#batch, 1, T) shape.
|
132 |
+
2.When applying self attention of encoder,
|
133 |
+
the mask is in (#batch, T, T) shape.
|
134 |
+
cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
|
135 |
+
where `cache_t == chunk_size * num_decoding_left_chunks`
|
136 |
+
and `head * d_k == size`
|
137 |
+
|
138 |
+
|
139 |
+
Returns:
|
140 |
+
torch.Tensor: Output tensor (#batch, time1, d_model).
|
141 |
+
torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
|
142 |
+
where `cache_t == chunk_size * num_decoding_left_chunks`
|
143 |
+
and `head * d_k == size`
|
144 |
+
|
145 |
+
"""
|
146 |
+
q, k, v = self.forward_qkv(query, key, value)
|
147 |
+
|
148 |
+
# NOTE(xcsong):
|
149 |
+
# when export onnx model, for 1st chunk, we feed
|
150 |
+
# cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
|
151 |
+
# or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
|
152 |
+
# In all modes, `if cache.size(0) > 0` will always be `True`
|
153 |
+
# and we will always do splitting and
|
154 |
+
# concatenation (this will simplify onnx export). Note that
|
155 |
+
# it's OK to concat & split zero-shaped tensors(see code below).
|
156 |
+
# when export jit model, for 1st chunk, we always feed
|
157 |
+
# cache(0, 0, 0, 0) since jit supports dynamic if-branch.
|
158 |
+
# >>> a = torch.ones((1, 2, 0, 4))
|
159 |
+
# >>> b = torch.ones((1, 2, 3, 4))
|
160 |
+
# >>> c = torch.cat((a, b), dim=2)
|
161 |
+
# >>> torch.equal(b, c) # True
|
162 |
+
# >>> d = torch.split(a, 2, dim=-1)
|
163 |
+
# >>> torch.equal(d[0], d[1]) # True
|
164 |
+
if cache.size(0) > 0:
|
165 |
+
key_cache, value_cache = torch.split(cache,
|
166 |
+
cache.size(-1) // 2,
|
167 |
+
dim=-1)
|
168 |
+
k = torch.cat([key_cache, k], dim=2)
|
169 |
+
v = torch.cat([value_cache, v], dim=2)
|
170 |
+
# NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
|
171 |
+
# non-trivial to calculate `next_cache_start` here.
|
172 |
+
new_cache = torch.cat((k, v), dim=-1)
|
173 |
+
|
174 |
+
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
|
175 |
+
return self.forward_attention(v, scores, mask), new_cache
|
176 |
+
|
177 |
+
|
178 |
+
class RelPositionMultiHeadedAttention(MultiHeadedAttention):
|
179 |
+
"""Multi-Head Attention layer with relative position encoding.
|
180 |
+
Paper: https://arxiv.org/abs/1901.02860
|
181 |
+
Args:
|
182 |
+
n_head (int): The number of heads.
|
183 |
+
n_feat (int): The number of features.
|
184 |
+
dropout_rate (float): Dropout rate.
|
185 |
+
"""
|
186 |
+
|
187 |
+
def __init__(self,
|
188 |
+
n_head: int,
|
189 |
+
n_feat: int,
|
190 |
+
dropout_rate: float,
|
191 |
+
key_bias: bool = True):
|
192 |
+
"""Construct an RelPositionMultiHeadedAttention object."""
|
193 |
+
super().__init__(n_head, n_feat, dropout_rate, key_bias)
|
194 |
+
# linear transformation for positional encoding
|
195 |
+
self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
|
196 |
+
# these two learnable bias are used in matrix c and matrix d
|
197 |
+
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
|
198 |
+
self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
|
199 |
+
self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
|
200 |
+
torch.nn.init.xavier_uniform_(self.pos_bias_u)
|
201 |
+
torch.nn.init.xavier_uniform_(self.pos_bias_v)
|
202 |
+
|
203 |
+
def rel_shift(self, x):
|
204 |
+
"""Compute relative positional encoding.
|
205 |
+
|
206 |
+
Args:
|
207 |
+
x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1).
|
208 |
+
time1 means the length of query vector.
|
209 |
+
|
210 |
+
Returns:
|
211 |
+
torch.Tensor: Output tensor.
|
212 |
+
|
213 |
+
"""
|
214 |
+
zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype)
|
215 |
+
x_padded = torch.cat([zero_pad, x], dim=-1)
|
216 |
+
|
217 |
+
x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
|
218 |
+
x = x_padded[:, :, 1:].view_as(x)[
|
219 |
+
:, :, :, : x.size(-1) // 2 + 1
|
220 |
+
] # only keep the positions from 0 to time2
|
221 |
+
return x
|
222 |
+
|
223 |
+
def forward(
|
224 |
+
self,
|
225 |
+
query: torch.Tensor,
|
226 |
+
key: torch.Tensor,
|
227 |
+
value: torch.Tensor,
|
228 |
+
mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
|
229 |
+
pos_emb: torch.Tensor = torch.empty(0),
|
230 |
+
cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
|
231 |
+
) -> tp.Tuple[torch.Tensor, torch.Tensor]:
|
232 |
+
"""Compute 'Scaled Dot Product Attention' with rel. positional encoding.
|
233 |
+
Args:
|
234 |
+
query (torch.Tensor): Query tensor (#batch, time1, size).
|
235 |
+
key (torch.Tensor): Key tensor (#batch, time2, size).
|
236 |
+
value (torch.Tensor): Value tensor (#batch, time2, size).
|
237 |
+
mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
|
238 |
+
(#batch, time1, time2), (0, 0, 0) means fake mask.
|
239 |
+
pos_emb (torch.Tensor): Positional embedding tensor
|
240 |
+
(#batch, time2, size).
|
241 |
+
cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
|
242 |
+
where `cache_t == chunk_size * num_decoding_left_chunks`
|
243 |
+
and `head * d_k == size`
|
244 |
+
Returns:
|
245 |
+
torch.Tensor: Output tensor (#batch, time1, d_model).
|
246 |
+
torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
|
247 |
+
where `cache_t == chunk_size * num_decoding_left_chunks`
|
248 |
+
and `head * d_k == size`
|
249 |
+
"""
|
250 |
+
q, k, v = self.forward_qkv(query, key, value)
|
251 |
+
q = q.transpose(1, 2) # (batch, time1, head, d_k)
|
252 |
+
|
253 |
+
# NOTE(xcsong):
|
254 |
+
# when export onnx model, for 1st chunk, we feed
|
255 |
+
# cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
|
256 |
+
# or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
|
257 |
+
# In all modes, `if cache.size(0) > 0` will always be `True`
|
258 |
+
# and we will always do splitting and
|
259 |
+
# concatenation (this will simplify onnx export). Note that
|
260 |
+
# it's OK to concat & split zero-shaped tensors(see code below).
|
261 |
+
# when export jit model, for 1st chunk, we always feed
|
262 |
+
# cache(0, 0, 0, 0) since jit supports dynamic if-branch.
|
263 |
+
# >>> a = torch.ones((1, 2, 0, 4))
|
264 |
+
# >>> b = torch.ones((1, 2, 3, 4))
|
265 |
+
# >>> c = torch.cat((a, b), dim=2)
|
266 |
+
# >>> torch.equal(b, c) # True
|
267 |
+
# >>> d = torch.split(a, 2, dim=-1)
|
268 |
+
# >>> torch.equal(d[0], d[1]) # True
|
269 |
+
if cache.size(0) > 0:
|
270 |
+
key_cache, value_cache = torch.split(cache,
|
271 |
+
cache.size(-1) // 2,
|
272 |
+
dim=-1)
|
273 |
+
k = torch.cat([key_cache, k], dim=2)
|
274 |
+
v = torch.cat([value_cache, v], dim=2)
|
275 |
+
# NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
|
276 |
+
# non-trivial to calculate `next_cache_start` here.
|
277 |
+
new_cache = torch.cat((k, v), dim=-1)
|
278 |
+
|
279 |
+
n_batch_pos = pos_emb.size(0)
|
280 |
+
p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
|
281 |
+
p = p.transpose(1, 2) # (batch, head, time1, d_k)
|
282 |
+
|
283 |
+
# (batch, head, time1, d_k)
|
284 |
+
q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
|
285 |
+
# (batch, head, time1, d_k)
|
286 |
+
q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
|
287 |
+
|
288 |
+
# compute attention score
|
289 |
+
# first compute matrix a and matrix c
|
290 |
+
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
|
291 |
+
# (batch, head, time1, time2)
|
292 |
+
matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
|
293 |
+
|
294 |
+
# compute matrix b and matrix d
|
295 |
+
# (batch, head, time1, time2)
|
296 |
+
matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
|
297 |
+
# NOTE(Xiang Lyu): Keep rel_shift since espnet rel_pos_emb is used
|
298 |
+
if matrix_ac.shape != matrix_bd.shape:
|
299 |
+
matrix_bd = self.rel_shift(matrix_bd)
|
300 |
+
|
301 |
+
scores = (matrix_ac + matrix_bd) / math.sqrt(
|
302 |
+
self.d_k) # (batch, head, time1, time2)
|
303 |
+
|
304 |
+
return self.forward_attention(v, scores, mask), new_cache
|
305 |
+
|
306 |
+
|
307 |
+
class PositionwiseFeedForward(torch.nn.Module):
|
308 |
+
"""Positionwise feed forward layer.
|
309 |
+
|
310 |
+
FeedForward is applied on each position of the sequence.
|
311 |
+
The output dim is the same as the input dim.
|
312 |
+
|
313 |
+
Args:
|
314 |
+
idim (int): Input dimension.
|
315 |
+
hidden_units (int): The number of hidden units.
|
316 |
+
dropout_rate (float): Dropout rate.
|
317 |
+
activation (torch.nn.Module): Activation function
|
318 |
+
"""
|
319 |
+
|
320 |
+
def __init__(
|
321 |
+
self,
|
322 |
+
idim: int,
|
323 |
+
hidden_units: int,
|
324 |
+
dropout_rate: float,
|
325 |
+
activation: torch.nn.Module = torch.nn.ReLU(),
|
326 |
+
):
|
327 |
+
"""Construct a PositionwiseFeedForward object."""
|
328 |
+
super(PositionwiseFeedForward, self).__init__()
|
329 |
+
self.w_1 = torch.nn.Linear(idim, hidden_units)
|
330 |
+
self.activation = activation
|
331 |
+
self.dropout = torch.nn.Dropout(dropout_rate)
|
332 |
+
self.w_2 = torch.nn.Linear(hidden_units, idim)
|
333 |
+
|
334 |
+
def forward(self, xs: torch.Tensor) -> torch.Tensor:
|
335 |
+
"""Forward function.
|
336 |
+
|
337 |
+
Args:
|
338 |
+
xs: input tensor (B, L, D)
|
339 |
+
Returns:
|
340 |
+
output tensor, (B, L, D)
|
341 |
+
"""
|
342 |
+
return self.w_2(self.dropout(self.activation(self.w_1(xs))))
|
343 |
+
|
344 |
+
|
345 |
+
class ConformerDecoderLayer(nn.Module):
|
346 |
+
"""Encoder layer module.
|
347 |
+
Args:
|
348 |
+
size (int): Input dimension.
|
349 |
+
self_attn (torch.nn.Module): Self-attention module instance.
|
350 |
+
`MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
|
351 |
+
instance can be used as the argument.
|
352 |
+
src_attn (torch.nn.Module): Cross-attention module instance.
|
353 |
+
`MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
|
354 |
+
instance can be used as the argument.
|
355 |
+
feed_forward (torch.nn.Module): Feed-forward module instance.
|
356 |
+
`PositionwiseFeedForward` instance can be used as the argument.
|
357 |
+
feed_forward_macaron (torch.nn.Module): Additional feed-forward module
|
358 |
+
instance.
|
359 |
+
`PositionwiseFeedForward` instance can be used as the argument.
|
360 |
+
conv_module (torch.nn.Module): Convolution module instance.
|
361 |
+
`ConvolutionModule` instance can be used as the argument.
|
362 |
+
dropout_rate (float): Dropout rate.
|
363 |
+
normalize_before (bool):
|
364 |
+
True: use layer_norm before each sub-block.
|
365 |
+
False: use layer_norm after each sub-block.
|
366 |
+
"""
|
367 |
+
|
368 |
+
def __init__(
|
369 |
+
self,
|
370 |
+
size: int,
|
371 |
+
self_attn: torch.nn.Module,
|
372 |
+
src_attn: tp.Optional[torch.nn.Module] = None,
|
373 |
+
feed_forward: tp.Optional[nn.Module] = None,
|
374 |
+
feed_forward_macaron: tp.Optional[nn.Module] = None,
|
375 |
+
conv_module: tp.Optional[nn.Module] = None,
|
376 |
+
dropout_rate: float = 0.1,
|
377 |
+
normalize_before: bool = True,
|
378 |
+
):
|
379 |
+
"""Construct an EncoderLayer object."""
|
380 |
+
super().__init__()
|
381 |
+
self.self_attn = self_attn
|
382 |
+
self.src_attn = src_attn
|
383 |
+
self.feed_forward = feed_forward
|
384 |
+
self.feed_forward_macaron = feed_forward_macaron
|
385 |
+
self.conv_module = conv_module
|
386 |
+
self.norm_ff = nn.LayerNorm(size, eps=1e-5) # for the FNN module
|
387 |
+
self.norm_mha = nn.LayerNorm(size, eps=1e-5) # for the MHA module
|
388 |
+
if src_attn is not None:
|
389 |
+
self.norm_mha2 = nn.LayerNorm(size, eps=1e-5) # for the MHA module(src_attn)
|
390 |
+
if feed_forward_macaron is not None:
|
391 |
+
self.norm_ff_macaron = nn.LayerNorm(size, eps=1e-5)
|
392 |
+
self.ff_scale = 0.5
|
393 |
+
else:
|
394 |
+
self.ff_scale = 1.0
|
395 |
+
if self.conv_module is not None:
|
396 |
+
self.norm_conv = nn.LayerNorm(size, eps=1e-5) # for the CNN module
|
397 |
+
self.norm_final = nn.LayerNorm(
|
398 |
+
size, eps=1e-5) # for the final output of the block
|
399 |
+
self.dropout = nn.Dropout(dropout_rate)
|
400 |
+
self.size = size
|
401 |
+
self.normalize_before = normalize_before
|
402 |
+
|
403 |
+
def forward(
|
404 |
+
self,
|
405 |
+
x: torch.Tensor,
|
406 |
+
mask: torch.Tensor,
|
407 |
+
# src-attention
|
408 |
+
memory: torch.Tensor,
|
409 |
+
memory_mask: torch.Tensor,
|
410 |
+
pos_emb: torch.Tensor,
|
411 |
+
mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
|
412 |
+
att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
|
413 |
+
cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
|
414 |
+
) -> tp.Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
415 |
+
"""Compute encoded features.
|
416 |
+
|
417 |
+
Args:
|
418 |
+
x (torch.Tensor): (#batch, time, size)
|
419 |
+
mask (torch.Tensor): Mask tensor for the input (#batch, time, time),
|
420 |
+
(0, 0, 0) means fake mask.
|
421 |
+
pos_emb (torch.Tensor): positional encoding, must not be None
|
422 |
+
for ConformerEncoderLayer.
|
423 |
+
mask_pad (torch.Tensor): batch padding mask used for conv module.
|
424 |
+
(#batch, 1, time), (0, 0, 0) means fake mask.
|
425 |
+
att_cache (torch.Tensor): Cache tensor of the KEY & VALUE
|
426 |
+
(#batch=1, head, cache_t1, d_k * 2), head * d_k == size.
|
427 |
+
cnn_cache (torch.Tensor): Convolution cache in conformer layer
|
428 |
+
(#batch=1, size, cache_t2)
|
429 |
+
Returns:
|
430 |
+
torch.Tensor: Output tensor (#batch, time, size).
|
431 |
+
torch.Tensor: Mask tensor (#batch, time, time).
|
432 |
+
torch.Tensor: att_cache tensor,
|
433 |
+
(#batch=1, head, cache_t1 + time, d_k * 2).
|
434 |
+
torch.Tensor: cnn_cache tensor (#batch, size, cache_t2).
|
435 |
+
"""
|
436 |
+
|
437 |
+
# whether to use macaron style
|
438 |
+
if self.feed_forward_macaron is not None:
|
439 |
+
residual = x
|
440 |
+
if self.normalize_before:
|
441 |
+
x = self.norm_ff_macaron(x)
|
442 |
+
x = residual + self.ff_scale * self.dropout(
|
443 |
+
self.feed_forward_macaron(x))
|
444 |
+
if not self.normalize_before:
|
445 |
+
x = self.norm_ff_macaron(x)
|
446 |
+
|
447 |
+
# multi-headed self-attention module
|
448 |
+
residual = x
|
449 |
+
if self.normalize_before:
|
450 |
+
x = self.norm_mha(x)
|
451 |
+
x_att, new_att_cache = self.self_attn(x, x, x, mask, pos_emb,
|
452 |
+
att_cache)
|
453 |
+
x = residual + self.dropout(x_att)
|
454 |
+
if not self.normalize_before:
|
455 |
+
x = self.norm_mha(x)
|
456 |
+
|
457 |
+
# multi-headed cross-attention module
|
458 |
+
if self.src_attn is not None:
|
459 |
+
residual = x
|
460 |
+
if self.normalize_before:
|
461 |
+
x = self.norm_mha2(x)
|
462 |
+
x_att, _ = self.src_attn(x, memory, memory, memory_mask)
|
463 |
+
x = residual + self.dropout(x_att)
|
464 |
+
if not self.normalize_before:
|
465 |
+
x = self.norm_mha2(x)
|
466 |
+
|
467 |
+
# convolution module
|
468 |
+
# Fake new cnn cache here, and then change it in conv_module
|
469 |
+
new_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
|
470 |
+
if self.conv_module is not None:
|
471 |
+
residual = x
|
472 |
+
if self.normalize_before:
|
473 |
+
x = self.norm_conv(x)
|
474 |
+
x, new_cnn_cache = self.conv_module(x, mask_pad, cnn_cache)
|
475 |
+
x = residual + self.dropout(x)
|
476 |
+
|
477 |
+
if not self.normalize_before:
|
478 |
+
x = self.norm_conv(x)
|
479 |
+
|
480 |
+
# feed forward module
|
481 |
+
residual = x
|
482 |
+
if self.normalize_before:
|
483 |
+
x = self.norm_ff(x)
|
484 |
+
|
485 |
+
x = residual + self.ff_scale * self.dropout(self.feed_forward(x))
|
486 |
+
if not self.normalize_before:
|
487 |
+
x = self.norm_ff(x)
|
488 |
+
|
489 |
+
if self.conv_module is not None:
|
490 |
+
x = self.norm_final(x)
|
491 |
+
|
492 |
+
return x, mask, new_att_cache, new_cnn_cache
|
493 |
+
|
494 |
+
|
495 |
+
class EspnetRelPositionalEncoding(torch.nn.Module):
|
496 |
+
"""Relative positional encoding module (new implementation).
|
497 |
+
|
498 |
+
Details can be found in https://github.com/espnet/espnet/pull/2816.
|
499 |
+
|
500 |
+
See : Appendix B in https://arxiv.org/abs/1901.02860
|
501 |
+
|
502 |
+
Args:
|
503 |
+
d_model (int): Embedding dimension.
|
504 |
+
dropout_rate (float): Dropout rate.
|
505 |
+
max_len (int): Maximum input length.
|
506 |
+
|
507 |
+
"""
|
508 |
+
|
509 |
+
def __init__(self, d_model, dropout_rate, max_len=5000):
|
510 |
+
"""Construct an PositionalEncoding object."""
|
511 |
+
super(EspnetRelPositionalEncoding, self).__init__()
|
512 |
+
self.d_model = d_model
|
513 |
+
self.xscale = math.sqrt(self.d_model)
|
514 |
+
self.dropout = torch.nn.Dropout(p=dropout_rate)
|
515 |
+
self.pe = None
|
516 |
+
self.extend_pe(torch.tensor(0.0).expand(1, max_len))
|
517 |
+
|
518 |
+
def extend_pe(self, x):
|
519 |
+
"""Reset the positional encodings."""
|
520 |
+
if self.pe is not None:
|
521 |
+
# self.pe contains both positive and negative parts
|
522 |
+
# the length of self.pe is 2 * input_len - 1
|
523 |
+
if self.pe.size(1) >= x.size(1) * 2 - 1:
|
524 |
+
if self.pe.dtype != x.dtype or self.pe.device != x.device:
|
525 |
+
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
|
526 |
+
return
|
527 |
+
# Suppose `i` means the position of the query vector and `j` means the
|
528 |
+
# position of the key vector. We use positive relative positions when keys
|
529 |
+
# are to the left (i>j) and negative relative positions otherwise (i<j).
|
530 |
+
pe_positive = torch.zeros(x.size(1), self.d_model)
|
531 |
+
pe_negative = torch.zeros(x.size(1), self.d_model)
|
532 |
+
position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
|
533 |
+
div_term = torch.exp(
|
534 |
+
torch.arange(0, self.d_model, 2, dtype=torch.float32)
|
535 |
+
* -(math.log(10000.0) / self.d_model)
|
536 |
+
)
|
537 |
+
pe_positive[:, 0::2] = torch.sin(position * div_term)
|
538 |
+
pe_positive[:, 1::2] = torch.cos(position * div_term)
|
539 |
+
pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
|
540 |
+
pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
|
541 |
+
|
542 |
+
# Reverse the order of positive indices and concat both positive and
|
543 |
+
# negative indices. This is used to support the shifting trick
|
544 |
+
# as in https://arxiv.org/abs/1901.02860
|
545 |
+
pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
|
546 |
+
pe_negative = pe_negative[1:].unsqueeze(0)
|
547 |
+
pe = torch.cat([pe_positive, pe_negative], dim=1)
|
548 |
+
self.pe = pe.to(device=x.device, dtype=x.dtype)
|
549 |
+
|
550 |
+
def forward(self, x: torch.Tensor, offset: tp.Union[int, torch.Tensor] = 0):
|
551 |
+
"""Add positional encoding.
|
552 |
+
|
553 |
+
Args:
|
554 |
+
x (torch.Tensor): Input tensor (batch, time, `*`).
|
555 |
+
|
556 |
+
Returns:
|
557 |
+
torch.Tensor: Encoded tensor (batch, time, `*`).
|
558 |
+
|
559 |
+
"""
|
560 |
+
self.extend_pe(x)
|
561 |
+
x = x * self.xscale
|
562 |
+
pos_emb = self.position_encoding(size=x.size(1), offset=offset)
|
563 |
+
return self.dropout(x), self.dropout(pos_emb)
|
564 |
+
|
565 |
+
def position_encoding(self,
|
566 |
+
offset: tp.Union[int, torch.Tensor],
|
567 |
+
size: int) -> torch.Tensor:
|
568 |
+
""" For getting encoding in a streaming fashion
|
569 |
+
|
570 |
+
Attention!!!!!
|
571 |
+
we apply dropout only once at the whole utterance level in a non-
|
572 |
+
streaming way, but will call this function several times with
|
573 |
+
increasing input size in a streaming scenario, so the dropout will
|
574 |
+
be applied several times.
|
575 |
+
|
576 |
+
Args:
|
577 |
+
offset (int or torch.tensor): start offset
|
578 |
+
size (int): required size of position encoding
|
579 |
+
|
580 |
+
Returns:
|
581 |
+
torch.Tensor: Corresponding encoding
|
582 |
+
"""
|
583 |
+
pos_emb = self.pe[
|
584 |
+
:,
|
585 |
+
self.pe.size(1) // 2 - size + 1 : self.pe.size(1) // 2 + size,
|
586 |
+
]
|
587 |
+
return pos_emb
|
588 |
+
|
589 |
+
|
590 |
+
class LinearNoSubsampling(torch.nn.Module):
|
591 |
+
"""Linear transform the input without subsampling
|
592 |
+
|
593 |
+
Args:
|
594 |
+
idim (int): Input dimension.
|
595 |
+
odim (int): Output dimension.
|
596 |
+
dropout_rate (float): Dropout rate.
|
597 |
+
|
598 |
+
"""
|
599 |
+
|
600 |
+
def __init__(self, idim: int, odim: int, dropout_rate: float,
|
601 |
+
pos_enc_class: torch.nn.Module):
|
602 |
+
"""Construct an linear object."""
|
603 |
+
super().__init__()
|
604 |
+
self.out = torch.nn.Sequential(
|
605 |
+
torch.nn.Linear(idim, odim),
|
606 |
+
torch.nn.LayerNorm(odim, eps=1e-5),
|
607 |
+
torch.nn.Dropout(dropout_rate),
|
608 |
+
)
|
609 |
+
self.pos_enc = pos_enc_class
|
610 |
+
self.right_context = 0
|
611 |
+
self.subsampling_rate = 1
|
612 |
+
|
613 |
+
def forward(
|
614 |
+
self,
|
615 |
+
x: torch.Tensor,
|
616 |
+
x_mask: torch.Tensor,
|
617 |
+
offset: tp.Union[int, torch.Tensor] = 0
|
618 |
+
) -> tp.Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
619 |
+
"""Input x.
|
620 |
+
|
621 |
+
Args:
|
622 |
+
x (torch.Tensor): Input tensor (#batch, time, idim).
|
623 |
+
x_mask (torch.Tensor): Input mask (#batch, 1, time).
|
624 |
+
|
625 |
+
Returns:
|
626 |
+
torch.Tensor: linear input tensor (#batch, time', odim),
|
627 |
+
where time' = time .
|
628 |
+
torch.Tensor: linear input mask (#batch, 1, time'),
|
629 |
+
where time' = time .
|
630 |
+
|
631 |
+
"""
|
632 |
+
x = self.out(x)
|
633 |
+
x, pos_emb = self.pos_enc(x, offset)
|
634 |
+
return x, pos_emb, x_mask
|
635 |
+
|
636 |
+
|
637 |
+
class ConformerDecoderV2(nn.Module):
|
638 |
+
def __init__(self,
|
639 |
+
input_size: int = 512,
|
640 |
+
output_size: int = 512,
|
641 |
+
attention_heads: int = 8,
|
642 |
+
linear_units: int = 2048,
|
643 |
+
num_blocks: int = 6,
|
644 |
+
dropout_rate: float = 0.01,
|
645 |
+
srcattention_start_index: int = 0,
|
646 |
+
srcattention_end_index: int = 2,
|
647 |
+
attention_dropout_rate: float = 0.01,
|
648 |
+
positional_dropout_rate: float = 0.01,
|
649 |
+
key_bias: bool = True,
|
650 |
+
normalize_before: bool = True,
|
651 |
+
):
|
652 |
+
super().__init__()
|
653 |
+
self.num_blocks = num_blocks
|
654 |
+
self.normalize_before = normalize_before
|
655 |
+
self.output_size = output_size
|
656 |
+
|
657 |
+
self.embed = LinearNoSubsampling(
|
658 |
+
input_size,
|
659 |
+
output_size,
|
660 |
+
dropout_rate,
|
661 |
+
EspnetRelPositionalEncoding(output_size, positional_dropout_rate),
|
662 |
+
)
|
663 |
+
|
664 |
+
self.encoders = torch.nn.ModuleList()
|
665 |
+
for i in range(self.num_blocks):
|
666 |
+
# construct src attention
|
667 |
+
if srcattention_start_index <= i <= srcattention_end_index:
|
668 |
+
srcattention_layer = MultiHeadedAttention(
|
669 |
+
attention_heads,
|
670 |
+
output_size,
|
671 |
+
attention_dropout_rate,
|
672 |
+
key_bias
|
673 |
+
)
|
674 |
+
else:
|
675 |
+
srcattention_layer = None
|
676 |
+
# construct self attention
|
677 |
+
selfattention_layer = RelPositionMultiHeadedAttention(
|
678 |
+
attention_heads,
|
679 |
+
output_size,
|
680 |
+
attention_dropout_rate,
|
681 |
+
key_bias
|
682 |
+
)
|
683 |
+
# construct ffn
|
684 |
+
ffn_layer = PositionwiseFeedForward(
|
685 |
+
output_size,
|
686 |
+
linear_units,
|
687 |
+
dropout_rate,
|
688 |
+
torch.nn.SiLU()
|
689 |
+
)
|
690 |
+
self.encoders.append(
|
691 |
+
ConformerDecoderLayer(
|
692 |
+
output_size,
|
693 |
+
selfattention_layer,
|
694 |
+
srcattention_layer,
|
695 |
+
ffn_layer,
|
696 |
+
None,
|
697 |
+
None,
|
698 |
+
dropout_rate,
|
699 |
+
normalize_before=normalize_before
|
700 |
+
)
|
701 |
+
)
|
702 |
+
self.after_norm = torch.nn.LayerNorm(output_size, eps=1e-5)
|
703 |
+
|
704 |
+
def forward_layers(self, xs: torch.Tensor, chunk_masks: torch.Tensor,
|
705 |
+
memory: torch.Tensor, memory_masks: torch.Tensor,
|
706 |
+
pos_emb: torch.Tensor, mask_pad: torch.Tensor) -> torch.Tensor:
|
707 |
+
for layer in self.encoders:
|
708 |
+
xs, chunk_masks, _, _ = layer(xs, chunk_masks, memory, memory_masks, pos_emb, mask_pad)
|
709 |
+
return xs
|
710 |
+
|
711 |
+
def forward(self,
|
712 |
+
xs:torch.Tensor,
|
713 |
+
xs_lens:torch.Tensor,
|
714 |
+
memory:torch.Tensor,
|
715 |
+
memory_lens: torch.Tensor,
|
716 |
+
):
|
717 |
+
T = xs.size(1)
|
718 |
+
masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T)
|
719 |
+
T2 = memory.size(1)
|
720 |
+
memory_masks = ~make_pad_mask(memory_lens, T2).unsqueeze(1) # (B, 1, T2)
|
721 |
+
|
722 |
+
xs, pos_emb, masks = self.embed(xs, masks)
|
723 |
+
|
724 |
+
xs = self.forward_layers(xs, masks, memory, memory_masks, pos_emb, masks)
|
725 |
+
|
726 |
+
if self.normalize_before:
|
727 |
+
xs = self.after_norm(xs)
|
728 |
+
|
729 |
+
return xs, masks
|
730 |
+
|
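A small smoke-test sketch for the ConformerDecoderV2 above: relative-position self-attention runs over the token features, while the blocks between srcattention_start_index and srcattention_end_index also cross-attend to a prompt "memory". All shapes and lengths below are illustrative, not taken from the released configs.

# Illustrative only; random weights and made-up lengths.
import torch

decoder = ConformerDecoderV2(input_size=512, output_size=512, attention_heads=8,
                             num_blocks=6, srcattention_start_index=0, srcattention_end_index=2)
xs = torch.randn(2, 100, 512)         # (B, T, input_size) token features
xs_lens = torch.tensor([100, 80])     # valid lengths per utterance
memory = torch.randn(2, 40, 512)      # prompt features consumed by the cross-attention layers
memory_lens = torch.tensor([40, 32])
out, masks = decoder(xs, xs_lens, memory, memory_lens)  # out: (2, 100, 512), masks: (2, 1, 100)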
fireredtts/modules/flow/decoder.py
ADDED
@@ -0,0 +1,396 @@
1 |
+
import math
|
2 |
+
from typing import Optional
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
from einops import pack, rearrange, repeat
|
7 |
+
from diffusers.models.activations import get_activation
|
8 |
+
|
9 |
+
from fireredtts.modules.flow.transformer import BasicTransformerBlock
|
10 |
+
|
11 |
+
|
12 |
+
class SinusoidalPosEmb(torch.nn.Module):
|
13 |
+
def __init__(self, dim):
|
14 |
+
super().__init__()
|
15 |
+
self.dim = dim
|
16 |
+
assert self.dim % 2 == 0, "SinusoidalPosEmb requires dim to be even"
|
17 |
+
|
18 |
+
def forward(self, x, scale=1000):
|
19 |
+
if x.ndim < 1:
|
20 |
+
x = x.unsqueeze(0)
|
21 |
+
device = x.device
|
22 |
+
half_dim = self.dim // 2
|
23 |
+
emb = math.log(10000) / (half_dim - 1)
|
24 |
+
emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
|
25 |
+
emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
|
26 |
+
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
|
27 |
+
return emb
|
28 |
+
|
29 |
+
|
30 |
+
class Block1D(torch.nn.Module):
|
31 |
+
def __init__(self, dim, dim_out, groups=8):
|
32 |
+
super().__init__()
|
33 |
+
self.block = torch.nn.Sequential(
|
34 |
+
torch.nn.Conv1d(dim, dim_out, 3, padding=1),
|
35 |
+
torch.nn.GroupNorm(groups, dim_out),
|
36 |
+
nn.Mish(),
|
37 |
+
)
|
38 |
+
|
39 |
+
def forward(self, x, mask):
|
40 |
+
output = self.block(x * mask)
|
41 |
+
return output * mask
|
42 |
+
|
43 |
+
|
44 |
+
class ResnetBlock1D(torch.nn.Module):
|
45 |
+
def __init__(self, dim, dim_out, time_emb_dim, groups=8):
|
46 |
+
super().__init__()
|
47 |
+
self.mlp = torch.nn.Sequential(nn.Mish(), torch.nn.Linear(time_emb_dim, dim_out))
|
48 |
+
|
49 |
+
self.block1 = Block1D(dim, dim_out, groups=groups)
|
50 |
+
self.block2 = Block1D(dim_out, dim_out, groups=groups)
|
51 |
+
|
52 |
+
self.res_conv = torch.nn.Conv1d(dim, dim_out, 1)
|
53 |
+
|
54 |
+
def forward(self, x, mask, time_emb):
|
55 |
+
h = self.block1(x, mask)
|
56 |
+
h += self.mlp(time_emb).unsqueeze(-1)
|
57 |
+
h = self.block2(h, mask)
|
58 |
+
output = h + self.res_conv(x * mask)
|
59 |
+
return output
|
60 |
+
|
61 |
+
|
62 |
+
class Downsample1D(nn.Module):
|
63 |
+
def __init__(self, dim):
|
64 |
+
super().__init__()
|
65 |
+
self.conv = torch.nn.Conv1d(dim, dim, 3, 2, 1)
|
66 |
+
|
67 |
+
def forward(self, x):
|
68 |
+
return self.conv(x)
|
69 |
+
|
70 |
+
|
71 |
+
class TimestepEmbedding(nn.Module):
|
72 |
+
def __init__(
|
73 |
+
self,
|
74 |
+
in_channels: int,
|
75 |
+
time_embed_dim: int,
|
76 |
+
act_fn: str = "silu",
|
77 |
+
out_dim: int = None,
|
78 |
+
post_act_fn: Optional[str] = None,
|
79 |
+
cond_proj_dim=None,
|
80 |
+
):
|
81 |
+
super().__init__()
|
82 |
+
|
83 |
+
self.linear_1 = nn.Linear(in_channels, time_embed_dim)
|
84 |
+
|
85 |
+
if cond_proj_dim is not None:
|
86 |
+
self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
|
87 |
+
else:
|
88 |
+
self.cond_proj = None
|
89 |
+
|
90 |
+
self.act = get_activation(act_fn)
|
91 |
+
|
92 |
+
if out_dim is not None:
|
93 |
+
time_embed_dim_out = out_dim
|
94 |
+
else:
|
95 |
+
time_embed_dim_out = time_embed_dim
|
96 |
+
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out)
|
97 |
+
|
98 |
+
if post_act_fn is None:
|
99 |
+
self.post_act = None
|
100 |
+
else:
|
101 |
+
self.post_act = get_activation(post_act_fn)
|
102 |
+
|
103 |
+
def forward(self, sample, condition=None):
|
104 |
+
if condition is not None:
|
105 |
+
sample = sample + self.cond_proj(condition)
|
106 |
+
sample = self.linear_1(sample)
|
107 |
+
|
108 |
+
if self.act is not None:
|
109 |
+
sample = self.act(sample)
|
110 |
+
|
111 |
+
sample = self.linear_2(sample)
|
112 |
+
|
113 |
+
if self.post_act is not None:
|
114 |
+
sample = self.post_act(sample)
|
115 |
+
return sample
|
116 |
+
|
117 |
+
|
118 |
+
class Upsample1D(nn.Module):
|
119 |
+
"""A 1D upsampling layer with an optional convolution.
|
120 |
+
|
121 |
+
Parameters:
|
122 |
+
channels (`int`):
|
123 |
+
number of channels in the inputs and outputs.
|
124 |
+
use_conv (`bool`, default `False`):
|
125 |
+
option to use a convolution.
|
126 |
+
use_conv_transpose (`bool`, default `False`):
|
127 |
+
option to use a convolution transpose.
|
128 |
+
out_channels (`int`, optional):
|
129 |
+
number of output channels. Defaults to `channels`.
|
130 |
+
"""
|
131 |
+
|
132 |
+
def __init__(self, channels, use_conv=False, use_conv_transpose=True, out_channels=None, name="conv"):
|
133 |
+
super().__init__()
|
134 |
+
self.channels = channels
|
135 |
+
self.out_channels = out_channels or channels
|
136 |
+
self.use_conv = use_conv
|
137 |
+
self.use_conv_transpose = use_conv_transpose
|
138 |
+
self.name = name
|
139 |
+
|
140 |
+
self.conv = None
|
141 |
+
if use_conv_transpose:
|
142 |
+
self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1)
|
143 |
+
elif use_conv:
|
144 |
+
self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1)
|
145 |
+
|
146 |
+
def forward(self, inputs):
|
147 |
+
assert inputs.shape[1] == self.channels
|
148 |
+
if self.use_conv_transpose:
|
149 |
+
return self.conv(inputs)
|
150 |
+
|
151 |
+
outputs = F.interpolate(inputs, scale_factor=2.0, mode="nearest")
|
152 |
+
|
153 |
+
if self.use_conv:
|
154 |
+
outputs = self.conv(outputs)
|
155 |
+
|
156 |
+
return outputs
|
157 |
+
|
158 |
+
|
159 |
+
class ConditionalDecoder(nn.Module):
|
160 |
+
def __init__(
|
161 |
+
self,
|
162 |
+
in_channels,
|
163 |
+
out_channels,
|
164 |
+
channels=(256, 256),
|
165 |
+
dropout=0.0,
|
166 |
+
attention_head_dim=64,
|
167 |
+
n_blocks=4,
|
168 |
+
num_mid_blocks=12,
|
169 |
+
num_heads=8,
|
170 |
+
act_fn="gelu",
|
171 |
+
):
|
172 |
+
"""
|
173 |
+
This decoder requires an input with the same shape as the target. So, if your text content
|
174 |
+
is shorter or longer than the outputs, please resample it before feeding it to the decoder.
|
175 |
+
"""
|
176 |
+
super().__init__()
|
177 |
+
channels = tuple(channels)
|
178 |
+
self.in_channels = in_channels
|
179 |
+
self.out_channels = out_channels
|
180 |
+
|
181 |
+
self.time_embeddings = SinusoidalPosEmb(in_channels)
|
182 |
+
time_embed_dim = channels[0] * 4
|
183 |
+
self.time_mlp = TimestepEmbedding(
|
184 |
+
in_channels=in_channels,
|
185 |
+
time_embed_dim=time_embed_dim,
|
186 |
+
act_fn="silu",
|
187 |
+
)
|
188 |
+
self.down_blocks = nn.ModuleList([])
|
189 |
+
self.mid_blocks = nn.ModuleList([])
|
190 |
+
self.up_blocks = nn.ModuleList([])
|
191 |
+
|
192 |
+
output_channel = in_channels
|
193 |
+
for i in range(len(channels)): # pylint: disable=consider-using-enumerate
|
194 |
+
input_channel = output_channel
|
195 |
+
output_channel = channels[i]
|
196 |
+
is_last = i == len(channels) - 1
|
197 |
+
resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
|
198 |
+
transformer_blocks = nn.ModuleList(
|
199 |
+
[
|
200 |
+
BasicTransformerBlock(
|
201 |
+
dim=output_channel,
|
202 |
+
num_attention_heads=num_heads,
|
203 |
+
attention_head_dim=attention_head_dim,
|
204 |
+
dropout=dropout,
|
205 |
+
activation_fn=act_fn,
|
206 |
+
)
|
207 |
+
for _ in range(n_blocks)
|
208 |
+
]
|
209 |
+
)
|
210 |
+
downsample = (
|
211 |
+
Downsample1D(output_channel) if not is_last else nn.Conv1d(output_channel, output_channel, 3, padding=1)
|
212 |
+
)
|
213 |
+
self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))
|
214 |
+
|
215 |
+
for i in range(num_mid_blocks):
|
216 |
+
input_channel = channels[-1]
|
217 |
+
out_channels = channels[-1]
|
218 |
+
resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
|
219 |
+
|
220 |
+
transformer_blocks = nn.ModuleList(
|
221 |
+
[
|
222 |
+
BasicTransformerBlock(
|
223 |
+
dim=output_channel,
|
224 |
+
num_attention_heads=num_heads,
|
225 |
+
attention_head_dim=attention_head_dim,
|
226 |
+
dropout=dropout,
|
227 |
+
activation_fn=act_fn,
|
228 |
+
)
|
229 |
+
for _ in range(n_blocks)
|
230 |
+
]
|
231 |
+
)
|
232 |
+
|
233 |
+
self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))
|
234 |
+
|
235 |
+
channels = channels[::-1] + (channels[0],)
|
236 |
+
for i in range(len(channels) - 1):
|
237 |
+
input_channel = channels[i] * 2
|
238 |
+
output_channel = channels[i + 1]
|
239 |
+
is_last = i == len(channels) - 2
|
240 |
+
resnet = ResnetBlock1D(
|
241 |
+
dim=input_channel,
|
242 |
+
dim_out=output_channel,
|
243 |
+
time_emb_dim=time_embed_dim,
|
244 |
+
)
|
245 |
+
transformer_blocks = nn.ModuleList(
|
246 |
+
[
|
247 |
+
BasicTransformerBlock(
|
248 |
+
dim=output_channel,
|
249 |
+
num_attention_heads=num_heads,
|
250 |
+
attention_head_dim=attention_head_dim,
|
251 |
+
dropout=dropout,
|
252 |
+
activation_fn=act_fn,
|
253 |
+
)
|
254 |
+
for _ in range(n_blocks)
|
255 |
+
]
|
256 |
+
)
|
257 |
+
upsample = (
|
258 |
+
Upsample1D(output_channel, use_conv_transpose=True)
|
259 |
+
if not is_last
|
260 |
+
else nn.Conv1d(output_channel, output_channel, 3, padding=1)
|
261 |
+
)
|
262 |
+
self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
|
263 |
+
self.final_block = Block1D(channels[-1], channels[-1])
|
264 |
+
self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
|
265 |
+
self.initialize_weights()
|
266 |
+
|
267 |
+
|
268 |
+
def initialize_weights(self):
|
269 |
+
for m in self.modules():
|
270 |
+
if isinstance(m, nn.Conv1d):
|
271 |
+
nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
|
272 |
+
if m.bias is not None:
|
273 |
+
nn.init.constant_(m.bias, 0)
|
274 |
+
elif isinstance(m, nn.GroupNorm):
|
275 |
+
nn.init.constant_(m.weight, 1)
|
276 |
+
nn.init.constant_(m.bias, 0)
|
277 |
+
elif isinstance(m, nn.Linear):
|
278 |
+
nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
|
279 |
+
if m.bias is not None:
|
280 |
+
nn.init.constant_(m.bias, 0)
|
281 |
+
|
282 |
+
def forward(self, x, mask, mu, t):
|
283 |
+
"""Forward pass of the UNet1DConditional model.
|
284 |
+
|
285 |
+
Args:
|
286 |
+
x (torch.Tensor): shape (batch_size, in_channels, time)
|
287 |
+
mask (_type_): shape (batch_size, 1, time)
|
288 |
+
t (_type_): shape (batch_size)
|
289 |
+
mu (torch.Tensor): conditioning features from the encoder, shape (batch_size, in_channels, time).
|
290 |
+
cond (_type_, optional): placeholder for future use. Defaults to None.
|
291 |
+
|
292 |
+
Raises:
|
293 |
+
ValueError: _description_
|
294 |
+
ValueError: _description_
|
295 |
+
|
296 |
+
Returns:
|
297 |
+
torch.Tensor: predicted vector field, shape (batch_size, out_channels, time).
|
298 |
+
"""
|
299 |
+
|
300 |
+
t = self.time_embeddings(t)
|
301 |
+
t = self.time_mlp(t)
|
302 |
+
|
303 |
+
x = pack([x, mu], "b * t")[0]
|
304 |
+
|
305 |
+
hiddens = []
|
306 |
+
masks = [mask]
|
307 |
+
for resnet, transformer_blocks, downsample in self.down_blocks:
|
308 |
+
mask_down = masks[-1]
|
309 |
+
x = resnet(x, mask_down, t)
|
310 |
+
x = rearrange(x, "b c t -> b t c").contiguous()
|
311 |
+
attn_mask = torch.matmul(mask_down.transpose(1, 2).contiguous(), mask_down)
|
312 |
+
for transformer_block in transformer_blocks:
|
313 |
+
x = transformer_block(
|
314 |
+
hidden_states=x,
|
315 |
+
attention_mask=attn_mask,
|
316 |
+
timestep=t,
|
317 |
+
)
|
318 |
+
x = rearrange(x, "b t c -> b c t").contiguous()
|
319 |
+
hiddens.append(x) # Save hidden states for skip connections
|
320 |
+
x = downsample(x * mask_down)
|
321 |
+
masks.append(mask_down[:, :, ::2])
|
322 |
+
masks = masks[:-1]
|
323 |
+
mask_mid = masks[-1]
|
324 |
+
|
325 |
+
for resnet, transformer_blocks in self.mid_blocks:
|
326 |
+
x = resnet(x, mask_mid, t)
|
327 |
+
x = rearrange(x, "b c t -> b t c").contiguous()
|
328 |
+
attn_mask = torch.matmul(mask_mid.transpose(1, 2).contiguous(), mask_mid)
|
329 |
+
for transformer_block in transformer_blocks:
|
330 |
+
x = transformer_block(
|
331 |
+
hidden_states=x,
|
332 |
+
attention_mask=attn_mask,
|
333 |
+
timestep=t,
|
334 |
+
)
|
335 |
+
x = rearrange(x, "b t c -> b c t").contiguous()
|
336 |
+
|
337 |
+
for resnet, transformer_blocks, upsample in self.up_blocks:
|
338 |
+
mask_up = masks.pop()
|
339 |
+
skip = hiddens.pop()
|
340 |
+
x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
|
341 |
+
x = resnet(x, mask_up, t)
|
342 |
+
x = rearrange(x, "b c t -> b t c").contiguous()
|
343 |
+
attn_mask = torch.matmul(mask_up.transpose(1, 2).contiguous(), mask_up)
|
344 |
+
for transformer_block in transformer_blocks:
|
345 |
+
x = transformer_block(
|
346 |
+
hidden_states=x,
|
347 |
+
attention_mask=attn_mask,
|
348 |
+
timestep=t,
|
349 |
+
)
|
350 |
+
x = rearrange(x, "b t c -> b c t").contiguous()
|
351 |
+
x = upsample(x * mask_up)
|
352 |
+
x = self.final_block(x, mask_up)
|
353 |
+
output = self.final_proj(x * mask_up)
|
354 |
+
return output * mask
|
355 |
+
|
356 |
+
|
357 |
+
class ConditionalCFM(nn.Module):
|
358 |
+
def __init__(self,
|
359 |
+
estimator: nn.Module,
|
360 |
+
t_scheduler: str = "cosine",
|
361 |
+
inference_cfg_rate: float = 0.7,
|
362 |
+
):
|
363 |
+
super().__init__()
|
364 |
+
self.estimator = estimator
|
365 |
+
self.t_scheduler = t_scheduler
|
366 |
+
self.inference_cfg_rate = inference_cfg_rate
|
367 |
+
|
368 |
+
def solve_euler(self, x, t_span, mu, mask):
|
369 |
+
t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
|
370 |
+
|
371 |
+
# I am storing this because I can later plot it by putting a debugger here and saving it to a file
|
372 |
+
# Or in future might add like a return_all_steps flag
|
373 |
+
sol = []
|
374 |
+
|
375 |
+
for step in range(1, len(t_span)):
|
376 |
+
dphi_dt = self.estimator(x, mask, mu, t)
|
377 |
+
# Classifier-Free Guidance inference introduced in VoiceBox
|
378 |
+
if self.inference_cfg_rate > 0:
|
379 |
+
cfg_dphi_dt = self.estimator(x, mask, torch.zeros_like(mu), t)
|
380 |
+
dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt -
|
381 |
+
self.inference_cfg_rate * cfg_dphi_dt)
|
382 |
+
x = x + dt * dphi_dt
|
383 |
+
t = t + dt
|
384 |
+
sol.append(x)
|
385 |
+
if step < len(t_span) - 1:
|
386 |
+
dt = t_span[step + 1] - t
|
387 |
+
|
388 |
+
return sol[-1]
|
389 |
+
|
390 |
+
def inference(self, mu, mask, n_timesteps, temperature: float=1.0):
|
391 |
+
z = torch.randn_like(mu) * temperature
|
392 |
+
t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
|
393 |
+
if self.t_scheduler == 'cosine':
|
394 |
+
t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
|
395 |
+
return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask)
|
396 |
+
|
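The ConditionalCFM above integrates the learned vector field with a fixed-step Euler solver and mixes a conditional and an unconditional estimate (classifier-free guidance, as in the solve_euler loop). Below is a standalone sketch of one guided Euler update; the toy estimator is a placeholder standing in for the ConditionalDecoder, and all shapes are made up.

# Toy sketch of the guided update inside solve_euler; not the released model.
import torch

def euler_cfg_step(x, dt, t, mask, mu, estimator, cfg_rate=0.7):
    v_cond = estimator(x, mask, mu, t)                      # conditional velocity estimate
    v_uncond = estimator(x, mask, torch.zeros_like(mu), t)  # unconditional velocity estimate
    v = (1.0 + cfg_rate) * v_cond - cfg_rate * v_uncond     # push away from the unconditional estimate
    return x + dt * v

toy_estimator = lambda x, mask, mu, t: mu - x   # any callable with this signature works here
x = torch.randn(1, 80, 50)                      # noise with the target mel shape (B, C, T)
mu = torch.randn(1, 80, 50)                     # conditioning features aligned to the target length
x = euler_cfg_step(x, dt=0.1, t=torch.tensor(0.0), mask=torch.ones(1, 1, 50),
                   mu=mu, estimator=toy_estimator)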
fireredtts/modules/flow/flow_model.py
ADDED
@@ -0,0 +1,89 @@
1 |
+
import time
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
from torch.nn import functional as F
|
5 |
+
|
6 |
+
from fireredtts.modules.flow.utils import make_pad_mask
|
7 |
+
|
8 |
+
|
9 |
+
class InterpolateRegulator(nn.Module):
|
10 |
+
def __init__(
|
11 |
+
self,
|
12 |
+
channels: int = 512,
|
13 |
+
num_blocks: int = 4,
|
14 |
+
groups: int = 1,
|
15 |
+
):
|
16 |
+
super().__init__()
|
17 |
+
model = []
|
18 |
+
for _ in range(num_blocks):
|
19 |
+
model.extend([
|
20 |
+
nn.Conv1d(channels, channels, 3, 1, 1),
|
21 |
+
nn.GroupNorm(groups, channels),
|
22 |
+
nn.Mish(),
|
23 |
+
])
|
24 |
+
model.append(
|
25 |
+
nn.Conv1d(channels, channels, 1, 1)
|
26 |
+
)
|
27 |
+
self.model = nn.Sequential(*model)
|
28 |
+
|
29 |
+
def forward(self, x, ylens=None):
|
30 |
+
# x in (B, T, D)
|
31 |
+
mask = (~make_pad_mask(ylens)).to(x).unsqueeze(-1)
|
32 |
+
x = F.interpolate(x.transpose(1, 2).contiguous(), size=ylens.max(), mode='nearest')
|
33 |
+
out = self.model(x).transpose(1, 2).contiguous()
|
34 |
+
olens = ylens
|
35 |
+
return out * mask, olens
|
36 |
+
|
37 |
+
|
38 |
+
class CrossAttnFlowMatching(nn.Module):
|
39 |
+
def __init__(self,
|
40 |
+
output_size: int,
|
41 |
+
input_embedding: nn.Module,
|
42 |
+
encoder: nn.Module,
|
43 |
+
length_regulator: nn.Module,
|
44 |
+
mel_encoder: nn.Module,
|
45 |
+
decoder: nn.Module,
|
46 |
+
):
|
47 |
+
super().__init__()
|
48 |
+
self.input_embedding = input_embedding
|
49 |
+
self.encoder = encoder
|
50 |
+
self.length_regulator = length_regulator
|
51 |
+
self.encoder_proj = torch.nn.Linear(self.encoder.output_size, output_size)
|
52 |
+
self.prompt_prenet = mel_encoder
|
53 |
+
self.decoder = decoder
|
54 |
+
|
55 |
+
def inference(self,
|
56 |
+
token: torch.Tensor,
|
57 |
+
token_len: torch.Tensor,
|
58 |
+
prompt_mel: torch.Tensor,
|
59 |
+
prompt_mel_len: torch.Tensor,
|
60 |
+
n_timesteps:int=10,
|
61 |
+
):
|
62 |
+
# prompt projection
|
63 |
+
prompt_feat = self.prompt_prenet(prompt_mel)
|
64 |
+
prompt_feat_len = torch.ceil(prompt_mel_len/self.prompt_prenet.reduction_rate).long()
|
65 |
+
|
66 |
+
# concat text and prompt_text
|
67 |
+
mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(token_len.device)
|
68 |
+
token = self.input_embedding(torch.clamp(token, min=0)) * mask
|
69 |
+
|
70 |
+
# 40ms shift to 10ms shift
|
71 |
+
feat_len = (token_len * 4).int()
|
72 |
+
|
73 |
+
# first encoder
|
74 |
+
h, _ = self.encoder(token, token_len, prompt_feat, prompt_feat_len)
|
75 |
+
# length regulate
|
76 |
+
h, _ = self.length_regulator(h, feat_len)
|
77 |
+
# final projection
|
78 |
+
h = self.encoder_proj(h)
|
79 |
+
|
80 |
+
mask = (~make_pad_mask(feat_len)).to(h)
|
81 |
+
|
82 |
+
feat = self.decoder.inference(
|
83 |
+
mu=h.transpose(1, 2).contiguous(),
|
84 |
+
mask=mask.unsqueeze(1),
|
85 |
+
n_timesteps=n_timesteps,
|
86 |
+
)
|
87 |
+
return feat
|
88 |
+
|
89 |
+
|
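The inference path above only reshapes lengths in two places: the prompt mel is reduced by the prenet's reduction_rate, and each 40 ms semantic token expands to four 10 ms mel frames. A tiny sketch of that bookkeeping, with illustrative numbers:

# Length bookkeeping mirroring CrossAttnFlowMatching.inference (values are made up).
import torch

token_len = torch.tensor([50])        # 50 tokens at 40 ms each = 2 s of target speech
prompt_mel_len = torch.tensor([300])  # 3 s of prompt mel at 10 ms per frame
reduction_rate = 4                    # MelReduceEncoder default (see mel_encoder.py below)

prompt_feat_len = torch.ceil(prompt_mel_len / reduction_rate).long()  # -> 75 prompt frames
feat_len = (token_len * 4).int()                                       # -> 200 output mel frames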
fireredtts/modules/flow/mel_encoder.py
ADDED
@@ -0,0 +1,170 @@
1 |
+
import typing as tp
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
|
5 |
+
|
6 |
+
class ConvLayer(nn.Module):
|
7 |
+
def __init__(self,
|
8 |
+
in_channels:int,
|
9 |
+
out_channels:int,
|
10 |
+
kernel_size:int,
|
11 |
+
stride:int,
|
12 |
+
activation:str="GELU",
|
13 |
+
dropout_rate:float=0.0,
|
14 |
+
):
|
15 |
+
super().__init__()
|
16 |
+
self.conv = nn.Conv1d(
|
17 |
+
in_channels=in_channels,
|
18 |
+
out_channels=out_channels,
|
19 |
+
kernel_size=kernel_size,
|
20 |
+
stride=stride,
|
21 |
+
padding=(kernel_size-stride)//2,
|
22 |
+
)
|
23 |
+
self.drop = nn.Dropout(dropout_rate)
|
24 |
+
self.norm = nn.LayerNorm(out_channels)
|
25 |
+
self.activ = getattr(nn, activation)()
|
26 |
+
|
27 |
+
def forward(self, x:torch.Tensor):
|
28 |
+
"""
|
29 |
+
Args:
|
30 |
+
x: (b, t, c)
|
31 |
+
Return:
|
32 |
+
x: (b, t, c)
|
33 |
+
"""
|
34 |
+
x = x.transpose(2, 1)
|
35 |
+
x = self.conv(x)
|
36 |
+
x = x.transpose(2, 1)
|
37 |
+
x = self.drop(x)
|
38 |
+
x = self.norm(x)
|
39 |
+
x = self.activ(x)
|
40 |
+
return x
|
41 |
+
|
42 |
+
|
43 |
+
class ResidualConvLayer(nn.Module):
|
44 |
+
def __init__(self,
|
45 |
+
hidden_channels:int,
|
46 |
+
n_layers:int=2,
|
47 |
+
kernel_size:int=5,
|
48 |
+
activation:str="GELU",
|
49 |
+
dropout_rate:float=0.0,
|
50 |
+
):
|
51 |
+
super().__init__()
|
52 |
+
layers = [
|
53 |
+
ConvLayer(hidden_channels, hidden_channels, kernel_size, 1, activation, dropout_rate)
|
54 |
+
for _ in range(n_layers)
|
55 |
+
]
|
56 |
+
self.layers = nn.Sequential(*layers)
|
57 |
+
|
58 |
+
def forward(self, x:torch.Tensor):
|
59 |
+
"""
|
60 |
+
Args:
|
61 |
+
x: (b, t, c)
|
62 |
+
Returns:
|
63 |
+
x: (b, t, c)
|
64 |
+
"""
|
65 |
+
return x + self.layers(x)
|
66 |
+
|
67 |
+
|
68 |
+
class ResidualConvBlock(nn.Module):
|
69 |
+
def __init__(self,
|
70 |
+
in_channels:int,
|
71 |
+
hidden_channels:int,
|
72 |
+
out_channels:int,
|
73 |
+
n_layers:int=2,
|
74 |
+
n_blocks:int=5,
|
75 |
+
middle_layer:tp.Optional[nn.Module]=None,
|
76 |
+
kernel_size:int=5,
|
77 |
+
activation:str="GELU",
|
78 |
+
dropout_rate:float=0.0,
|
79 |
+
):
|
80 |
+
super().__init__()
|
81 |
+
self.in_proj = nn.Conv1d(
|
82 |
+
in_channels,
|
83 |
+
hidden_channels,
|
84 |
+
kernel_size=kernel_size,
|
85 |
+
stride=1,
|
86 |
+
padding=(kernel_size-1)//2,
|
87 |
+
) if in_channels != hidden_channels else nn.Identity()
|
88 |
+
|
89 |
+
self.conv1 = nn.Sequential(*[
|
90 |
+
ResidualConvLayer(hidden_channels, n_layers, kernel_size, activation, dropout_rate)
|
91 |
+
for _ in range(n_blocks)
|
92 |
+
])
|
93 |
+
|
94 |
+
if middle_layer is None:
|
95 |
+
self.middle_layer = nn.Identity()
|
96 |
+
elif isinstance(middle_layer, nn.Module):
|
97 |
+
self.middle_layer = middle_layer
|
98 |
+
else:
|
99 |
+
raise TypeError("unknown middle layer type:{}".format(type(middle_layer)))
|
100 |
+
|
101 |
+
self.conv2 = nn.Sequential(*[
|
102 |
+
ResidualConvLayer(hidden_channels, n_layers, kernel_size, activation, dropout_rate)
|
103 |
+
for _ in range(n_blocks)
|
104 |
+
])
|
105 |
+
|
106 |
+
self.out_proj = nn.Conv1d(
|
107 |
+
hidden_channels,
|
108 |
+
out_channels,
|
109 |
+
kernel_size=kernel_size,
|
110 |
+
stride=1,
|
111 |
+
padding=(kernel_size-1)//2,
|
112 |
+
) if out_channels != hidden_channels else nn.Identity()
|
113 |
+
|
114 |
+
def forward(self, x:torch.Tensor, **middle_layer_kwargs):
|
115 |
+
"""
|
116 |
+
Args:
|
117 |
+
x: (b, t1, c)
|
118 |
+
Return:
|
119 |
+
x: (b, t2, c)
|
120 |
+
"""
|
121 |
+
x = self.in_proj(x.transpose(2, 1)).transpose(2, 1)
|
122 |
+
x = self.conv1(x)
|
123 |
+
if isinstance(self.middle_layer, nn.MaxPool1d) or isinstance(self.middle_layer, nn.Conv1d):
|
124 |
+
x = self.middle_layer(x.transpose(2, 1)).transpose(2, 1)
|
125 |
+
elif isinstance(self.middle_layer, nn.Identity):
|
126 |
+
x = self.middle_layer(x)
|
127 |
+
else:
|
128 |
+
# in case of a phoneme-pooling layer
|
129 |
+
x = self.middle_layer(x, **middle_layer_kwargs)
|
130 |
+
x = self.conv2(x)
|
131 |
+
x = self.out_proj(x.transpose(2, 1)).transpose(2, 1)
|
132 |
+
return x
|
133 |
+
|
134 |
+
|
135 |
+
class MelReduceEncoder(nn.Module):
|
136 |
+
def __init__(self,
|
137 |
+
in_channels:int,
|
138 |
+
out_channels:int,
|
139 |
+
hidden_channels:int=384,
|
140 |
+
reduction_rate:int=4,
|
141 |
+
n_layers:int=2,
|
142 |
+
n_blocks:int=5,
|
143 |
+
kernel_size:int=3,
|
144 |
+
activation:str="GELU",
|
145 |
+
dropout:float=0.0,
|
146 |
+
):
|
147 |
+
super().__init__()
|
148 |
+
self.reduction_rate = reduction_rate
|
149 |
+
middle_conv = nn.Conv1d(
|
150 |
+
in_channels=hidden_channels,
|
151 |
+
out_channels=hidden_channels,
|
152 |
+
kernel_size=reduction_rate,
|
153 |
+
stride=reduction_rate,
|
154 |
+
padding=0
|
155 |
+
)
|
156 |
+
self.encoder = ResidualConvBlock(
|
157 |
+
in_channels=in_channels,
|
158 |
+
hidden_channels=hidden_channels,
|
159 |
+
out_channels=out_channels,
|
160 |
+
n_layers=n_layers,
|
161 |
+
n_blocks=n_blocks,
|
162 |
+
middle_layer=middle_conv,
|
163 |
+
kernel_size=kernel_size,
|
164 |
+
activation=activation,
|
165 |
+
dropout_rate=dropout
|
166 |
+
)
|
167 |
+
|
168 |
+
def forward(self, x:torch.Tensor):
|
169 |
+
return self.encoder(x)
|
170 |
+
|
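A quick shape check for the MelReduceEncoder above: the middle Conv1d with kernel = stride = reduction_rate shrinks the time axis by reduction_rate, while the residual conv stacks keep it fixed. Channel counts below are illustrative (100 mel bins matches the extractor defaults in mel_spectrogram.py; the other sizes are made up).

# Illustrative shapes only.
import torch

enc = MelReduceEncoder(in_channels=100, out_channels=512, reduction_rate=4)
mel = torch.randn(2, 300, 100)   # (B, T, mel_bins), time-major as ConvLayer expects
feat = enc(mel)                  # (2, 75, 512): time reduced 4x, channels projected to out_channels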
fireredtts/modules/flow/mel_spectrogram.py
ADDED
@@ -0,0 +1,132 @@
from functools import partial
import torch
import numpy as np
import librosa
from librosa.filters import mel as librosa_mel_fn
from torchaudio.functional import resample as ta_resample_fn

MAX_WAV_VALUE = 32767.0  # NOTE: 32768.0 - 1 to prevent int16 overflow (results in popping sound in corner cases)


def dynamic_range_compression(x, C=1, clip_val=1e-5):
    return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)


def dynamic_range_decompression(x, C=1):
    return np.exp(x) / C


def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
    return torch.log(torch.clamp(x, min=clip_val) * C)


def dynamic_range_decompression_torch(x, C=1):
    return torch.exp(x) / C


def spectral_normalize_torch(magnitudes):
    output = dynamic_range_compression_torch(magnitudes)
    return output


def spectral_de_normalize_torch(magnitudes):
    output = dynamic_range_decompression_torch(magnitudes)
    return output


mel_basis = {}
hann_window = {}


def mel_spectrogram(
    y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False
):
    global mel_basis, hann_window
    if fmax not in mel_basis:
        mel = librosa_mel_fn(
            sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
        )
        str_key_mel_basis = str(fmax) + "_" + str(y.device)
        mel_basis[str_key_mel_basis] = torch.from_numpy(mel).float().to(y.device)
        hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)

    y = torch.nn.functional.pad(
        y.unsqueeze(1),
        (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
        mode="reflect",
    )
    y = y.squeeze(1)

    # complex tensor as default, then use view_as_real for future pytorch compatibility
    spec = torch.stft(
        y,
        n_fft,
        hop_length=hop_size,
        win_length=win_size,
        window=hann_window[str(y.device)],
        center=center,
        pad_mode="reflect",
        normalized=False,
        onesided=True,
        return_complex=True,
    )
    spec = torch.view_as_real(spec)
    spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))

    spec = torch.matmul(mel_basis[str_key_mel_basis], spec)
    spec = spectral_normalize_torch(spec)

    return spec


kaiser_best_resampling_fn = partial(
    ta_resample_fn,
    resampling_method="sinc_interp_kaiser",  # DO NOT CHANGE!
    rolloff=0.917347,  # DO NOT CHANGE!
    beta=12.9846,  # DO NOT CHANGE!
    lowpass_filter_width=50,  # DO NOT CHANGE!
)


class MelSpectrogramExtractor(object):
    def __init__(
        self,
        n_fft=1024,
        win_size=1024,
        num_mels=100,
        hop_size=160,
        sampling_rate=16000,
        fmin=0,
        fmax=None,
    ):
        self.n_fft = n_fft
        self.win_size = win_size
        self.num_mels = num_mels
        self.hop_size = hop_size
        self.sampling_rate = sampling_rate
        self.fmin = fmin
        self.fmax = fmax

    def __call__(self, wav_path) -> np.ndarray:
        wav_data, wav_sr = librosa.load(wav_path, sr=None, mono=True)
        wav_data = torch.from_numpy(wav_data.copy()).unsqueeze(0)
        # for 16k wavs, up-downsample to reduce artifacts
        if wav_sr == self.sampling_rate:
            wav_data = kaiser_best_resampling_fn(wav_data, orig_freq=wav_sr, new_freq=24000)
            wav_data = kaiser_best_resampling_fn(wav_data, orig_freq=24000, new_freq=self.sampling_rate)
        else:
            wav_data = kaiser_best_resampling_fn(wav_data, orig_freq=wav_sr, new_freq=self.sampling_rate)

        # (1, num_mels, t)
        mel = mel_spectrogram(
            wav_data,
            self.n_fft,
            self.num_mels,
            self.sampling_rate,
            self.hop_size,
            self.win_size,
            self.fmin,
            self.fmax,
        )
        mel = mel.squeeze(0).transpose(1, 0)
        return mel  # (t, num_mels)
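A minimal usage sketch (illustrative only; the wav path is hypothetical): the extractor loads a file, resamples it with the fixed kaiser-best settings above, and returns a (t, num_mels) mel matrix at a 160-sample hop on 16 kHz audio.

# Requires librosa and torchaudio, as imported above.
extractor = MelSpectrogramExtractor()      # 16 kHz, 100 mel bins, hop 160
mel = extractor("example_prompt.wav")      # hypothetical path; returns a tensor of shape (t, 100)
print(mel.shape)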
fireredtts/modules/flow/transformer.py
ADDED
@@ -0,0 +1,249 @@
from typing import Any, Dict, Optional

import torch
import torch.nn as nn
from diffusers.models.attention import (
    GEGLU,
    GELU,
    AdaLayerNorm,
    AdaLayerNormZero,
    ApproximateGELU,
)
from diffusers.models.attention_processor import Attention
# from diffusers.models.lora import LoRACompatibleLinear
from diffusers.utils.torch_utils import maybe_allow_in_graph


class FeedForward(nn.Module):
    r"""
    A feed-forward layer.

    Parameters:
        dim (`int`): The number of channels in the input.
        dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
        mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
        final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
    """

    def __init__(
        self,
        dim: int,
        dim_out: Optional[int] = None,
        mult: int = 4,
        dropout: float = 0.0,
        activation_fn: str = "geglu",
        final_dropout: bool = False,
    ):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = dim_out if dim_out is not None else dim

        if activation_fn == "gelu":
            act_fn = GELU(dim, inner_dim)
        if activation_fn == "gelu-approximate":
            act_fn = GELU(dim, inner_dim, approximate="tanh")
        elif activation_fn == "geglu":
            act_fn = GEGLU(dim, inner_dim)
        elif activation_fn == "geglu-approximate":
            act_fn = ApproximateGELU(dim, inner_dim)

        self.net = nn.ModuleList([])
        # project in
        self.net.append(act_fn)
        # project dropout
        self.net.append(nn.Dropout(dropout))
        # project out
        self.net.append(nn.Linear(inner_dim, dim_out))
        # self.net.append(LoRACompatibleLinear(inner_dim, dim_out))
        # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
        if final_dropout:
            self.net.append(nn.Dropout(dropout))

    def forward(self, hidden_states):
        for module in self.net:
            hidden_states = module(hidden_states)
        return hidden_states


@maybe_allow_in_graph
class BasicTransformerBlock(nn.Module):
    r"""
    A basic Transformer block.

    Parameters:
        dim (`int`): The number of channels in the input and output.
        num_attention_heads (`int`): The number of heads to use for multi-head attention.
        attention_head_dim (`int`): The number of channels in each head.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
        only_cross_attention (`bool`, *optional*):
            Whether to use only cross-attention layers. In this case two cross attention layers are used.
        double_self_attention (`bool`, *optional*):
            Whether to use two self-attention layers. In this case no cross attention layers are used.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
        num_embeds_ada_norm (:
            obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
        attention_bias (:
            obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        dropout=0.0,
        cross_attention_dim: Optional[int] = None,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        attention_bias: bool = False,
        only_cross_attention: bool = False,
        double_self_attention: bool = False,
        upcast_attention: bool = False,
        norm_elementwise_affine: bool = True,
        norm_type: str = "layer_norm",
        final_dropout: bool = False,
    ):
        super().__init__()
        self.only_cross_attention = only_cross_attention

        self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
        self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"

        if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
            raise ValueError(
                f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
                f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
            )

        # Define 3 blocks. Each block has its own normalization layer.
        # 1. Self-Attn
        if self.use_ada_layer_norm:
            self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
        elif self.use_ada_layer_norm_zero:
            self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
        else:
            self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
        self.attn1 = Attention(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            dropout=dropout,
            bias=attention_bias,
            cross_attention_dim=cross_attention_dim if only_cross_attention else None,
            upcast_attention=upcast_attention,
        )

        # 2. Cross-Attn
        if cross_attention_dim is not None or double_self_attention:
            # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
            # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
            # the second cross attention block.
            self.norm2 = (
                AdaLayerNorm(dim, num_embeds_ada_norm)
                if self.use_ada_layer_norm
                else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
            )
            self.attn2 = Attention(
                query_dim=dim,
                cross_attention_dim=cross_attention_dim if not double_self_attention else None,
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
                upcast_attention=upcast_attention,
                # scale_qk=False, # uncomment this to not to use flash attention
            )  # is self-attn if encoder_hidden_states is none
        else:
            self.norm2 = None
            self.attn2 = None

        # 3. Feed-forward
        self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
        self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)

        # let chunk size default to None
        self._chunk_size = None
        self._chunk_dim = 0

    def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
        # Sets chunk feed-forward
        self._chunk_size = chunk_size
        self._chunk_dim = dim

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        timestep: Optional[torch.LongTensor] = None,
        cross_attention_kwargs: Dict[str, Any] = None,
        class_labels: Optional[torch.LongTensor] = None,
    ):
        # Notice that normalization is always applied before the real computation in the following blocks.
        # 1. Self-Attention
        if self.use_ada_layer_norm:
            norm_hidden_states = self.norm1(hidden_states, timestep)
        elif self.use_ada_layer_norm_zero:
            norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
                hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
            )
        else:
            norm_hidden_states = self.norm1(hidden_states)

        cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}

        attn_output = self.attn1(
            norm_hidden_states,
            encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
            attention_mask=encoder_attention_mask if self.only_cross_attention else attention_mask,
            **cross_attention_kwargs,
        )
        if self.use_ada_layer_norm_zero:
            attn_output = gate_msa.unsqueeze(1) * attn_output
        hidden_states = attn_output + hidden_states

        # 2. Cross-Attention
        if self.attn2 is not None:
            norm_hidden_states = (
                self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
            )

            attn_output = self.attn2(
                norm_hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                **cross_attention_kwargs,
            )
            hidden_states = attn_output + hidden_states

        # 3. Feed-forward
        norm_hidden_states = self.norm3(hidden_states)

        if self.use_ada_layer_norm_zero:
            norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]

        if self._chunk_size is not None:
            # "feed_forward_chunk_size" can be used to save memory
            if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
                raise ValueError(
                    f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
                )

            num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
            ff_output = torch.cat(
                [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)],
                dim=self._chunk_dim,
            )
        else:
            ff_output = self.ff(norm_hidden_states)

        if self.use_ada_layer_norm_zero:
            ff_output = gate_mlp.unsqueeze(1) * ff_output

        hidden_states = ff_output + hidden_states

        return hidden_states
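A small sketch of how such a block can be driven (shapes and hyper-parameters below are assumptions for illustration, not values taken from the repository, and it presumes a diffusers version exposing the imports above). With `cross_attention_dim` set, `attn2` lets the hidden states attend to a conditioning sequence:

import torch
block = BasicTransformerBlock(
    dim=256, num_attention_heads=4, attention_head_dim=64, cross_attention_dim=256
)
x = torch.randn(2, 120, 256)       # (batch, frames, dim)
cond = torch.randn(2, 40, 256)     # conditioning sequence for cross-attention
y = block(x, encoder_hidden_states=cond)  # same shape as x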
fireredtts/modules/flow/utils.py
ADDED
@@ -0,0 +1,30 @@
import torch


def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
    """Make mask tensor containing indices of padded part.

    See description of make_non_pad_mask.

    Args:
        lengths (torch.Tensor): Batch of lengths (B,).
    Returns:
        torch.Tensor: Mask tensor containing indices of padded part.

    Examples:
        >>> lengths = [5, 3, 2]
        >>> make_pad_mask(lengths)
        masks = [[0, 0, 0, 0, 0],
                 [0, 0, 0, 1, 1],
                 [0, 0, 1, 1, 1]]
    """
    batch_size = lengths.size(0)
    max_len = max_len if max_len > 0 else lengths.max().item()
    seq_range = torch.arange(0,
                             max_len,
                             dtype=torch.int64,
                             device=lengths.device)
    seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
    seq_length_expand = lengths.unsqueeze(-1)
    mask = seq_range_expand >= seq_length_expand
    return mask
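For illustration, the returned boolean mask is True on padded positions, so it can be used directly with masked_fill to zero out padding (the example values below are made up):

import torch
lengths = torch.tensor([5, 3, 2])
mask = make_pad_mask(lengths)               # shape (3, 5), True where padded
x = torch.randn(3, 5, 8)                    # (batch, time, feat)
x = x.masked_fill(mask.unsqueeze(-1), 0.0)  # zero out the padded frames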
fireredtts/modules/gpt/__init__.py
ADDED
File without changes
|
fireredtts/modules/gpt/gpt.py
ADDED
@@ -0,0 +1,356 @@
# ported from: https://github.com/neonbjb/tortoise-tts
# ported from: https://github.com/coqui-ai/TTS/blob/dev/TTS/tts/layers/xtts/gpt.py

import functools
import math
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import GPT2Config, GPT2Model, GPT2PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions


class GPT2InferenceModel(GPT2PreTrainedModel):
    """Override GPT2LMHeadModel to allow for prefix conditioning."""

    def __init__(self, config, gpt, pos_emb, embeddings, norm, linear, kv_cache):
        super().__init__(config)
        self.transformer = gpt
        self.pos_embedding = pos_emb
        self.embeddings = embeddings
        self.final_norm = norm
        self.lm_head = nn.Sequential(norm, linear)
        self.kv_cache = kv_cache

    def store_prefix_emb(self, prefix_emb):
        self.cached_prefix_emb = prefix_emb

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
        token_type_ids = kwargs.get("token_type_ids", None)  # usually None
        if not self.kv_cache:
            past_key_values = None

        # only last token for inputs_ids if past is defined in kwargs
        if past_key_values is not None:
            input_ids = input_ids[:, -1].unsqueeze(-1)
            if token_type_ids is not None:
                token_type_ids = token_type_ids[:, -1].unsqueeze(-1)

        attention_mask = kwargs.get("attention_mask", None)
        position_ids = kwargs.get("position_ids", None)

        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values is not None:
                position_ids = position_ids[:, -1].unsqueeze(-1)
        else:
            position_ids = None
        return {
            "input_ids": input_ids,
            "past_key_values": past_key_values,
            "use_cache": kwargs.get("use_cache"),
            "position_ids": position_ids,
            "attention_mask": attention_mask,
            "token_type_ids": token_type_ids,
        }

    def forward(
        self,
        input_ids=None,
        past_key_values=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        assert self.cached_prefix_emb is not None
        assert inputs_embeds is None  # Not supported by this inference model.
        assert labels is None  # Training not supported by this inference model.
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        # Create embedding
        prefix_len = self.cached_prefix_emb.shape[1]
        if input_ids.shape[1] != 1:
            gen_inputs = input_ids[:, prefix_len:]
            gen_emb = self.embeddings(gen_inputs)
            gen_emb = gen_emb + self.pos_embedding(gen_emb)
            if self.cached_prefix_emb.shape[0] != gen_emb.shape[0]:
                prefix_emb = self.cached_prefix_emb.repeat_interleave(
                    gen_emb.shape[0] // self.cached_prefix_emb.shape[0], 0
                )
            else:
                prefix_emb = self.cached_prefix_emb.to(gen_emb.dtype)
            emb = torch.cat([prefix_emb, gen_emb], dim=1)
        else:
            emb = self.embeddings(input_ids)
            emb = emb + self.pos_embedding.get_fixed_embedding(
                attention_mask.shape[1] - (prefix_len + 1), attention_mask.device
            )
        transformer_outputs = self.transformer(
            inputs_embeds=emb,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]
        lm_logits = self.lm_head(hidden_states)

        if not return_dict:
            return (lm_logits,) + transformer_outputs[1:]

        return CausalLMOutputWithCrossAttentions(
            loss=None,
            logits=lm_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
            cross_attentions=transformer_outputs.cross_attentions,
        )

    @staticmethod
    def _reorder_cache(past, beam_idx):
        """
        This function is used to re-order the :obj:`past_key_values` cache if
        :meth:`~transformers.PreTrainedModel.beam_search` or :meth:`~transformers.PreTrainedModel.beam_sample` is
        called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step.
        """
        return tuple(
            tuple(
                past_state.index_select(0, beam_idx.to(past_state.device))
                for past_state in layer_past
            )
            for layer_past in past
        )


def null_position_embeddings(range, dim):
    return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device)


class LearnedPositionEmbeddings(nn.Module):
    def __init__(self, seq_len, model_dim, init=0.02):
        super().__init__()
        self.emb = torch.nn.Embedding(seq_len, model_dim)
        # Initializing this way is standard for GPT-2
        self.emb.weight.data.normal_(mean=0.0, std=init)

    def forward(self, x):
        sl = x.shape[1]
        return self.emb(torch.arange(0, sl, device=x.device))

    def get_fixed_embedding(self, ind, dev):
        return self.emb(torch.tensor([ind], device=dev)).unsqueeze(0)


def build_hf_gpt_transformer(
    layers,
    model_dim,
    heads,
    max_mel_seq_len,
    max_text_seq_len,
    max_prompt_len,
    checkpointing,
):
    """
    GPT-2 implemented by the HuggingFace library.
    """
    gpt_config = GPT2Config(
        vocab_size=256,  # Unused.
        n_positions=max_mel_seq_len + max_text_seq_len + max_prompt_len,
        n_ctx=max_mel_seq_len + max_text_seq_len + max_prompt_len,
        n_embd=model_dim,
        n_layer=layers,
        n_head=heads,
        gradient_checkpointing=checkpointing,
        use_cache=not checkpointing,
    )
    gpt = GPT2Model(gpt_config)
    # Override the built in positional embeddings
    del gpt.wpe
    gpt.wpe = functools.partial(null_position_embeddings, dim=model_dim)
    # Built-in token embeddings are unused.
    del gpt.wte

    mel_pos_emb = (
        LearnedPositionEmbeddings(max_mel_seq_len, model_dim)
        if max_mel_seq_len != -1
        else functools.partial(null_position_embeddings, dim=model_dim)
    )
    text_pos_emb = (
        LearnedPositionEmbeddings(max_text_seq_len, model_dim)
        if max_mel_seq_len != -1
        else functools.partial(null_position_embeddings, dim=model_dim)
    )
    return gpt, mel_pos_emb, text_pos_emb, None, None


class GPT(nn.Module):
    def __init__(
        self,
        start_text_token=261,
        stop_text_token=0,
        layers=8,
        model_dim=512,
        heads=8,
        max_text_tokens=120,
        max_mel_tokens=250,
        max_prompt_tokens=70,
        max_conditioning_inputs=1,
        code_stride_len=1024,
        number_text_tokens=256,
        num_audio_tokens=8194,
        start_audio_token=8192,
        stop_audio_token=8193,
        checkpointing=False,
        label_smoothing=0.0,
    ):
        """
        Args:

        """
        super().__init__()

        self.label_smoothing = label_smoothing
        self.number_text_tokens = number_text_tokens
        self.start_text_token = start_text_token
        self.stop_text_token = stop_text_token
        self.num_audio_tokens = num_audio_tokens
        self.start_audio_token = start_audio_token
        self.stop_audio_token = stop_audio_token
        self.start_prompt_token = start_audio_token
        self.stop_prompt_token = stop_audio_token
        self.layers = layers
        self.heads = heads
        self.model_dim = model_dim
        self.max_conditioning_inputs = max_conditioning_inputs
        self.max_gen_mel_tokens = max_mel_tokens - self.max_conditioning_inputs - 2
        self.max_mel_tokens = (
            -1
            if max_mel_tokens == -1
            else max_mel_tokens + 2 + self.max_conditioning_inputs
        )
        self.max_text_tokens = -1 if max_text_tokens == -1 else max_text_tokens + 2
        self.max_prompt_tokens = max_prompt_tokens
        self.code_stride_len = code_stride_len
        self.conditioning_dropout = nn.Dropout1d(0.1)
        self.text_embedding = nn.Embedding(self.number_text_tokens, model_dim)
        self.mel_embedding = nn.Embedding(self.num_audio_tokens, model_dim)

        (
            self.gpt,
            self.mel_pos_embedding,
            self.text_pos_embedding,
            self.mel_layer_pos_embedding,
            self.text_layer_pos_embedding,
        ) = build_hf_gpt_transformer(
            layers,
            model_dim,
            heads,
            self.max_mel_tokens,
            self.max_text_tokens,
            self.max_prompt_tokens,
            checkpointing,
        )

        self.final_norm = nn.LayerNorm(model_dim)
        self.text_head = nn.Linear(model_dim, self.number_text_tokens)
        self.mel_head = nn.Linear(model_dim, self.num_audio_tokens)

        # reference_embedding
        self.reference_embedding = nn.Sequential(
            nn.Linear(512, 256),
            nn.Tanh(),
            nn.Linear(256, self.model_dim),
        )

    def init_gpt_for_inference(self, kv_cache=True, use_deepspeed=False):
        seq_length = (
            self.max_prompt_tokens + self.max_mel_tokens + self.max_text_tokens + 1
        )
        gpt_config = GPT2Config(
            vocab_size=self.max_mel_tokens,
            n_positions=seq_length,
            n_ctx=seq_length,
            n_embd=self.model_dim,
            n_layer=self.layers,
            n_head=self.heads,
            gradient_checkpointing=False,
            use_cache=True,
        )
        self.gpt_inference = GPT2InferenceModel(
            gpt_config,
            self.gpt,
            self.mel_pos_embedding,
            self.mel_embedding,
            self.final_norm,
            self.mel_head,
            kv_cache=kv_cache,
        )
        self.gpt.wte = self.mel_embedding

    def inference(self, cond_latents, text_inputs, **hf_generate_kwargs):
        self.compute_embeddings(cond_latents, text_inputs)
        return self.generate(cond_latents, text_inputs, **hf_generate_kwargs)

    def compute_embeddings(
        self,
        cond_latents,
        text_inputs,
    ):
        text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token)
        text_inputs = F.pad(text_inputs, (1, 0), value=self.start_text_token)
        emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
        emb = torch.cat([cond_latents, emb], dim=1)
        self.gpt_inference.store_prefix_emb(emb)
        gpt_inputs = torch.full(
            (
                emb.shape[0],
                emb.shape[1] + 1,  # +1 for the start_audio_token
            ),
            fill_value=1,
            dtype=torch.long,
            device=text_inputs.device,
        )
        gpt_inputs[:, -1] = self.start_audio_token
        return gpt_inputs

    def generate(
        self,
        cond_latents,
        text_inputs,
        **hf_generate_kwargs,
    ):
        gpt_inputs = self.compute_embeddings(cond_latents, text_inputs)
        gen = self.gpt_inference.generate(
            gpt_inputs,
            bos_token_id=self.start_audio_token,
            pad_token_id=self.stop_audio_token,
            eos_token_id=self.stop_audio_token,
            max_length=self.max_gen_mel_tokens + gpt_inputs.shape[-1],
            **hf_generate_kwargs,
        )
        if "return_dict_in_generate" in hf_generate_kwargs:
            return gen.sequences[:, gpt_inputs.shape[1] :], gen
        return gen[:, gpt_inputs.shape[1] :]
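A hedged inference sketch (untrained weights, illustrative shapes only; a real run loads a checkpoint first, and the conditioning latents are produced from the reference audio). The `cond_latents` live in the model dimension and are prepended to the text embedding before audio tokens are decoded autoregressively:

import torch
# Illustrative only: with random weights the generated codes are meaningless.
gpt = GPT(layers=8, model_dim=512, heads=8)
gpt.init_gpt_for_inference(kv_cache=True)
gpt.eval()
cond_latents = torch.randn(1, 32, 512)        # speaker/prompt conditioning, (b, n, model_dim)
text_tokens = torch.randint(1, 200, (1, 24))  # hypothetical text token ids
with torch.no_grad():
    codes = gpt.generate(cond_latents, text_tokens, do_sample=True, top_p=0.85)
print(codes.shape)                            # (1, number_of_generated_audio_tokens)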
fireredtts/modules/text_normalizer/__init__.py
ADDED
File without changes
|
fireredtts/modules/text_normalizer/normalize.py
ADDED
@@ -0,0 +1,178 @@
import re
import regex
import inflect
import unicodedata
from lingua import Language, LanguageDetectorBuilder
from builtins import str as unicode

from tn.chinese.normalizer import Normalizer as ZhNormalizer
from tn.english.normalizer import Normalizer as EnNormalizer

from fireredtts.modules.text_normalizer.regex_common import *
from fireredtts.modules.text_normalizer.utils import *


def preprocess_text(sentence):
    # preprocessing
    sentence = bytes(sentence, "utf-8").decode("utf-8", "ignore")
    sentence = regex.sub("[\p{Cf}--[\u200d]]", "", sentence, flags=regex.V1)
    sentence = regex.sub("\p{Co}", "", sentence)
    sentence = sentence.replace("\u00a0", " ")
    sentence = sentence.replace("\ufffd", "")
    sentence = regex.sub("\p{Zl}", "\n", sentence)
    sentence = regex.sub("\p{Zp}", "\n", sentence)

    sentence = unicode(sentence)
    sentence = "".join(
        char
        for char in unicodedata.normalize("NFD", sentence)
        if unicodedata.category(char) != "Mn"
    )  # Strip accents

    sentence = strip_kaomoji(sentence)
    # full to half with exemption (to be converted after number TN): 。,:
    sentence = f2b(sentence, exemption="。,:")

    # clean spaces
    sentence = sentence.replace("\n", ",")
    sentence = sentence.replace("\t", ",")
    sentence = sentence.replace("\r", ",")
    sentence = re.sub(r"[。.]{3,}", "…", sentence)
    sentence = re.sub(r"[…⋯]{1,}", "…", sentence)
    sentence = re.sub(r"[ ]+", " ", sentence)
    sentence = sentence.strip()

    # punctuation reduction
    result = ""
    for idx, char in enumerate(sentence):
        if char in symbol_reduction:
            char = symbol_reduction[char]

        if char == " ":
            if idx == 0:
                continue
            if is_chinese(sentence[idx + 1]) and (
                is_chinese(sentence[idx - 1]) or sentence[idx - 1] in '") '
            ):
                result += ","
            else:
                result += " "
            continue

        if is_valid_char(char):
            result += char
    result = re.sub(r"[ ]+", " ", result)
    return result


def rettt(sentence):
    # handle abbreviations for all languages
    sentence = sentence.replace("&nd", "and")
    sentence = sentence.replace("Jan.", "january")
    sentence = sentence.replace("Feb.", "february")
    sentence = sentence.replace("Mar.", "march")
    sentence = sentence.replace("Apr.", "april")
    sentence = sentence.replace("May.", "may")
    sentence = sentence.replace("Jun.", "june")
    sentence = sentence.replace("Jul.", "july")
    sentence = sentence.replace("Aug.", "august")
    sentence = sentence.replace("Sept.", "september")
    sentence = sentence.replace("Sep.", "september")
    sentence = sentence.replace("Oct.", "october")
    sentence = sentence.replace("Nov.", "november")
    sentence = sentence.replace("Dec.", "december")
    sentence = sentence.replace("Mon.", "monday")
    sentence = sentence.replace("Tues.", "tuesday")
    sentence = sentence.replace("Wed.", "wednesday")
    sentence = sentence.replace("Thur.", "thursday")
    sentence = sentence.replace("Fri.", "friday")
    sentence = sentence.replace("Sat.", "saturday")
    if sentence != "Sun.":
        sentence = sentence.replace("Sun.", "sunday")
    sentence = re.sub(r" St\. ([A-Z])", r" saint \1", sentence)
    sentence = re.sub(r" St\.", " street", sentence)
    sentence = re.sub(r" Rd\.", " road", sentence)
    sentence = re.sub(r"[Aa]\.[Mm]\.", "A_M", sentence)
    sentence = re.sub(r"[Pp]\.[Mm]\.", "P_M", sentence)
    sentence = re.sub(r"[Bb]\.[Cc]\.", "B_C", sentence)
    sentence = re.sub(r"[Aa]\.[Dd]\.", "A_D", sentence)
    sentence = sentence.replace("Mr.", "mister")
    sentence = sentence.replace("Ms.", "miss")
    sentence = sentence.replace("Mrs.", "misses")
    sentence = sentence.replace("Ph.D", "P_H_D")
    sentence = sentence.replace("i.e.", "that is")
    sentence = sentence.replace("e.g.", "for example")
    sentence = sentence.replace("btw.", "by the way")
    sentence = sentence.replace("btw", "by the way")
    sentence = sentence.replace("b.t.w.", "by the way")
    sentence = sentence.replace("@", " at ")
    return sentence


class TextNormalizer:
    def __init__(self):
        self.language_detector = LanguageDetectorBuilder.from_languages(
            Language.ENGLISH, Language.CHINESE
        ).build()
        self.zh_normalizer = ZhNormalizer()
        self.en_normalizer = EnNormalizer()
        self.inflect_parser = inflect.engine()
        self.lang2token = {Language.ENGLISH: "en", Language.CHINESE: "zh"}

    def tn(self, text):
        text = preprocess_text(text)
        text = rettt(text)  # regex replacements
        # for non-Chinese languages
        language = self.language_detector.detect_language_of(text)
        # enforce chinese if text contains any chinese character
        if contains_chinese(text):
            language = Language.CHINESE
        text_lang = self.lang2token.get(language, "zh")

        if is_upper_eng_and_digit(text):
            language = Language.CHINESE

        if language == Language.CHINESE:
            text = self.zh_normalizer.normalize(text)
            text = text.replace("\n", "")
            text = re.sub(r"[,,]+$", "。", text)
        else:
            text = re.sub(r"[^ 0-9A-Za-z\[\]'.,:?!_\-]", "", text)
            text = self.en_normalizer.normalize(text)
            # fallback number normalization
            pieces = re.split(r"(\d+)", text)
            text = "".join(
                [
                    self.inflect_parser.number_to_words(p) if p.isnumeric() else p
                    for p in pieces
                    if len(p) > 0
                ]
            )

        # cleanup
        text = text.replace("_", " ")
        text = re.sub(r"[ ]+", " ", text)

        # spell capital words
        pieces = re.split(r"([A-Z]{2,4}|[ ])", text)
        for idx, p in enumerate(pieces):
            if re.match("[A-Z]{2,4}", p):
                pieces[idx] = " ".join(p)
        text = " ".join([p for p in pieces if p != " "])

        # post TN full to half
        text = text.replace("。", ".")
        text = text.replace(",", ",")
        text = text.replace(":", ":")

        # model limitations
        text = text.lower().strip()
        text = text.replace('"', "")
        text = text.replace("·", " ")
        text = re.sub("[…~、!,?:;!?:;]+", ",", text)
        text = re.sub("[,]+", ",", text)
        text = re.sub(r"[,. ]+$", ".", text)
        if len(text) > 0 and text[-1] != ".":
            text = text + "."

        return text, text_lang
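A short usage sketch (it needs the lingua, inflect, and tn/WeTextProcessing dependencies imported above; the input strings are made up). `tn` returns the normalized string together with a language tag:

tn_engine = TextNormalizer()
text, lang = tn_engine.tn("FireRedTTS发布于2024年!")      # normalized Chinese text, lang == "zh"
text, lang = tn_engine.tn("Meet me at 9 A.M. on Main St.")  # normalized English text, lang == "en"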
fireredtts/modules/text_normalizer/regex_common.py
ADDED
@@ -0,0 +1,23 @@
import re

kaomoji_regex = re.compile(
    r"[oヽwΣ┗╰O︿Ψ凸]?[(|≡*(].{0,4}[Д✿_▽→≧﹏`∩⊙∇☆≡๑〃′エ≦▔@﹁εヘ•́ω益‿≖ฺ皿•̀艹 ̄△|゚].{0,5}[|≡*))][┛ブ凸cdd︴oOΨ︿w╯ノ]?"
)
chinese_regex = re.compile(r"[\u4e00-\u9fa5]")
digit_regex = re.compile(r"(\d+)(\.\d+)?", re.UNICODE)

chinese_char_regex = re.compile(r"^[\u4e00-\u9fa5]$", re.UNICODE)
eng_and_digit_char_regex = re.compile(r"^[0-9.,A-Za-z]+$", re.UNICODE)
upper_eng_and_digit_regex = re.compile(r"^[ 0-9A-Z\"'.,:?!\-]+$", re.UNICODE)
valid_char_regex = re.compile(
    r"[\t\r\n ]|"
    r"[\u4e00-\u9fa5]|"
    r"\u0080|[\u20a0-\u20bf]|\u00a2|\u00a3|\u00a5|\uffe0|\uffe1|\uffe5|\uffe6|"
    r"\u3000|\u3002|\u00b7|\u2014|\u2019|\u2026|\uff01|\uff1f|\uff0e|\uff1a|\uff1b|\uff0b|\uff0c|\uff0d|\uff0f|[\ufe10-\ufe16]|[\ufe50-\ufe51]|[\ufe55-\ufe57]|\ufe6a|"
    r"[\u0030-\u0039]|"
    r"[\u0391-\u03c9]|"
    r"[\u00b0-\u00b3]|[\u2015-\u2018]|[\u3000-\u303f]|"
    r"[\u0022-\u002f\u003a-\u003e\u0040\u005b-\u0060\u007b-\u007e]|"
    r"[\uff21-\uff3a]|[\uff41-\uff5a]|[\u0041-\u005a]|[\u0061-\u007a]",
    re.UNICODE,
)
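For illustration, these patterns act as character-level filters; a quick sketch of how they behave (the comments state expected results, not captured test output):

print(bool(chinese_regex.search("hello 世界")))  # True: the string contains a CJK character
print(bool(valid_char_regex.match("€")))         # True: currency symbols are on the whitelist
print(bool(valid_char_regex.match("\u2603")))    # False: the snowman character is filtered out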
fireredtts/modules/text_normalizer/utils.py
ADDED
@@ -0,0 +1,121 @@
from fireredtts.modules.text_normalizer.regex_common import *


def contains_chinese(text):
    return bool(chinese_regex.search(text))


def strip_kaomoji(text):
    return kaomoji_regex.sub(" ", text)


def is_chinese(char):
    return chinese_char_regex.match(char)


def is_eng_and_digit(char):
    return eng_and_digit_char_regex.match(char)


def is_upper_eng_and_digit(text):
    return upper_eng_and_digit_regex.match(text)


def is_valid_char(char):
    return valid_char_regex.match(char)


def is_digit(text):
    return digit_regex.match(text)


def contains_chinese(text):
    return bool(chinese_regex.search(text))


def f2b(ustr, exemption="。,:"):
    half = []
    for u in ustr:
        num = ord(u)
        if num == 0x3000:
            half.append(" ")
        elif u in exemption:  # exemption
            half.append(u)
        elif 0xFF01 <= num <= 0xFF5E:
            num -= 0xFEE0
            half.append(chr(num))
        else:
            half.append(u)
    return "".join(half)


symbol_reduction = {
    "「": '"',
    "」": '"',
    "`": '"',
    "〝": '"',
    "〞": '"',
    "‟": '"',
    "„": '"',
    "{": "(",
    "}": ")",
    "【": "(",
    "】": ")",
    "〖": "(",
    "〗": ")",
    "〔": "(",
    "〕": ")",
    "〘": "(",
    "〙": ")",
    "《": "(",
    "》": ")",
    "⦅": "(",
    "⦆": ")",
    "『": '"',
    "』": '"',
    "「": '"',
    "」": '"',
    "{": "(",
    "}": ")",
    "〈": "(",
    "〉": ")",
    "•": "·",
    "‧": "·",
    "〰": "…",
    "﹏": "…",
    "〜": "~",
    "~": "~",
    "+": "+",
    "、": "、",
    "。": "。",
    "︐": ",",
    "﹐": ",",
    "︑": "、",
    "﹑": "、",
    "︒": "。",
    "︓": ":",
    "﹕": ":",
    "︔": ";",
    "﹔": ";",
    "︕": "!",
    "﹗": "!",
    "︖": "?",
    "﹖": "?",
    "﹙": "(",
    "﹚": ")",
    "﹪": "%",
    "﹠": "&",
    ">": ">",
    "|": "、",
    "=": "=",
    "‐": "-",
    "‑": "-",
    "‒": "-",
    "–": "-",
    "—": "-",
    "―": "-",
    "%": "%",
    "μ": "u",
}
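A quick sketch of the full-width-to-half-width conversion (the example strings are made up): code points in the FF01–FF5E block are shifted down by 0xFEE0, while characters in the exemption set are preserved for later normalization:

print(f2b("ABC123!"))    # -> "ABC123!"  (full-width letters and digits become ASCII)
print(f2b("你好,世界。"))  # -> "你好,世界。"  (the full-width "," is exempted by default, "。" is untouched)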
fireredtts/modules/tokenizer/__init__.py
ADDED
File without changes
|
fireredtts/modules/tokenizer/assets/multilingual.tiktoken
ADDED
The diff for this file is too large to render.
See raw diff
fireredtts/modules/tokenizer/tokenizer.py
ADDED
@@ -0,0 +1,46 @@
import os
import torch
from fireredtts.modules.tokenizer.whisper_tokenizer import get_tokenizer
from fireredtts.modules.text_normalizer.normalize import TextNormalizer


DEFAULT_VOCAB_FILE = os.path.join(
    os.path.dirname(os.path.realpath(__file__)), "../data/tokenizer.json"
)


class VoiceBpeTokenizer:
    def __init__(self):
        self.tokenizer = get_tokenizer(multilingual=True)
        self.tn_engine = TextNormalizer()

    def redtts_text_cleaner(self, text):
        text = text.strip()
        text, text_lang = self.tn_engine.tn(text)
        # print("---text after tn:", text)
        return text, text_lang

    def encode(self, text, lang="auto"):
        text, text_lang = self.redtts_text_cleaner(text=text)
        if lang == "auto":
            lang = text_lang
        text = f"[{lang}]{text}"
        return self.tokenizer.encode(text)

    def decode(self, seq):
        if isinstance(seq, torch.Tensor):
            seq = seq.cpu().numpy()
        text = self.tokenizer.decode(seq)
        return text

    def __len__(self):
        return self.tokenizer.get_vocab_size()

    def get_number_tokens(self):
        return self.tokenizer.get_vocab_size()


if __name__ == "__main__":
    tok = VoiceBpeTokenizer()
    codes = tok.encode("我、真是hello USA啊?谢谢你world!")
    print([tok.decode([c]) for c in codes])
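Beyond the __main__ demo above, the language tag can also be forced instead of auto-detected (a small sketch; the sentence is made up and the decoded form depends on the normalizer):

tok = VoiceBpeTokenizer()
ids = tok.encode("FireRedTTS sounds natural.", lang="en")  # prepends "[en]" before BPE encoding
text = tok.decode(ids)                                     # decodes back to the normalized text, tag included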
fireredtts/modules/tokenizer/whisper_tokenizer.py
ADDED
@@ -0,0 +1,456 @@
# adapted from https://github.com/openai/whisper/blob/main/whisper/tokenizer.py
# Copyright (c) 2022 OpenAI

# MIT License for this file

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import base64
import os
import string
from dataclasses import dataclass, field
from functools import cached_property, lru_cache
from typing import Dict, List, Optional, Tuple

import tiktoken

LANGUAGES = {
    "en": "english",
    "zh": "chinese",
    "de": "german",
    "es": "spanish",
    "ru": "russian",
    "ko": "korean",
    "fr": "french",
    "ja": "japanese",
    "pt": "portuguese",
    "tr": "turkish",
    "pl": "polish",
    "ca": "catalan",
    "nl": "dutch",
    "ar": "arabic",
    "sv": "swedish",
    "it": "italian",
    "id": "indonesian",
    "hi": "hindi",
    "fi": "finnish",
    "vi": "vietnamese",
    "he": "hebrew",
    "uk": "ukrainian",
    "el": "greek",
    "ms": "malay",
    "cs": "czech",
    "ro": "romanian",
    "da": "danish",
    "hu": "hungarian",
    "ta": "tamil",
    "no": "norwegian",
    "th": "thai",
    "ur": "urdu",
    "hr": "croatian",
    "bg": "bulgarian",
    "lt": "lithuanian",
    "la": "latin",
    "mi": "maori",
    "ml": "malayalam",
    "cy": "welsh",
    "sk": "slovak",
    "te": "telugu",
    "fa": "persian",
    "lv": "latvian",
    "bn": "bengali",
    "sr": "serbian",
    "az": "azerbaijani",
    "sl": "slovenian",
    "kn": "kannada",
    "et": "estonian",
    "mk": "macedonian",
    "br": "breton",
    "eu": "basque",
    "is": "icelandic",
    "hy": "armenian",
    "ne": "nepali",
    "mn": "mongolian",
    "bs": "bosnian",
    "kk": "kazakh",
    "sq": "albanian",
    "sw": "swahili",
    "gl": "galician",
    "mr": "marathi",
    "pa": "punjabi",
    "si": "sinhala",
    "km": "khmer",
    "sn": "shona",
    "yo": "yoruba",
    "so": "somali",
    "af": "afrikaans",
    "oc": "occitan",
    "ka": "georgian",
    "be": "belarusian",
    "tg": "tajik",
    "sd": "sindhi",
    "gu": "gujarati",
    "am": "amharic",
    "yi": "yiddish",
    "lo": "lao",
    "uz": "uzbek",
    "fo": "faroese",
    "ht": "haitian creole",
    "ps": "pashto",
    "tk": "turkmen",
    "nn": "nynorsk",
    "mt": "maltese",
    "sa": "sanskrit",
    "lb": "luxembourgish",
    "my": "myanmar",
    "bo": "tibetan",
    "tl": "tagalog",
    "mg": "malagasy",
    "as": "assamese",
    "tt": "tatar",
    "haw": "hawaiian",
    "ln": "lingala",
    "ha": "hausa",
    "ba": "bashkir",
    "jw": "javanese",
    "su": "sundanese",
    "yue": "cantonese",
}

# language code lookup by name, with a few language aliases
TO_LANGUAGE_CODE = {
    **{language: code for code, language in LANGUAGES.items()},
    "burmese": "my",
    "valencian": "ca",
    "flemish": "nl",
    "haitian": "ht",
    "letzeburgesch": "lb",
    "pushto": "ps",
    "panjabi": "pa",
    "moldavian": "ro",
    "moldovan": "ro",
    "sinhalese": "si",
    "castilian": "es",
    "mandarin": "zh",
}


@dataclass
class Tokenizer:
    """A thin wrapper around `tiktoken` providing quick access to special tokens"""

    encoding: tiktoken.Encoding
    num_languages: int
    language: Optional[str] = None
    task: Optional[str] = None
    sot_sequence: Tuple[int] = ()
    special_tokens: Dict[str, int] = field(default_factory=dict)

    def __post_init__(self):
        for special in self.encoding.special_tokens_set:
            special_token = self.encoding.encode_single_token(special)
            self.special_tokens[special] = special_token

        sot: int = self.special_tokens["[startoftranscript]"]
        translate: int = self.special_tokens["[translate]"]
        transcribe: int = self.special_tokens["[transcribe]"]

        langs = tuple(LANGUAGES.keys())[: self.num_languages]
        sot_sequence = [sot]
        if self.language is not None:
            sot_sequence.append(sot + 1 + langs.index(self.language))
        if self.task is not None:
            task_token: int = transcribe if self.task == "transcribe" else translate
            sot_sequence.append(task_token)

        self.sot_sequence = tuple(sot_sequence)

    def get_vocab_size(self):
        return self.encoding.n_vocab

    def encode(self, text):
        return self.encoding.encode(text, allowed_special="all")

    def decode(self, token_ids: List[int], **kwargs) -> str:
        return self.encoding.decode(token_ids, **kwargs)

    @cached_property
    def eot(self) -> int:
        return self.encoding.eot_token

    @cached_property
    def stop(self) -> int:
        return self.special_tokens["[STOP]"]

    @cached_property
    def start(self) -> int:
        return self.special_tokens["[START]"]

    @cached_property
    def transcribe(self) -> int:
        return self.special_tokens["[transcribe]"]

    @cached_property
    def translate(self) -> int:
        return self.special_tokens["[translate]"]

    @cached_property
    def sot(self) -> int:
        return self.special_tokens["[startoftranscript]"]

    @cached_property
    def sot_lm(self) -> int:
        return self.special_tokens["[startoflm]"]

    @cached_property
    def sot_prev(self) -> int:
        return self.special_tokens["[startofprev]"]

    @cached_property
    def no_speech(self) -> int:
        return self.special_tokens["[nospeech]"]

    @cached_property
    def language_token(self) -> int:
        """Returns the token id corresponding to the value of the `language` field"""
        if self.language is None:
            raise ValueError("This tokenizer does not have language token configured")

        return self.to_language_token(self.language)

    def to_language_token(self, language):
        if token := self.special_tokens.get(f"[{language}]", None):
            return token

        raise KeyError(f"Language {language} not found in tokenizer.")

    @cached_property
    def all_language_tokens(self) -> Tuple[int]:
        result = []
        for token, token_id in self.special_tokens.items():
            if token.strip("[]") in LANGUAGES:
                result.append(token_id)
        return tuple(result)[: self.num_languages]

    @cached_property
    def all_language_codes(self) -> Tuple[str]:
        return tuple(self.decode([_l]).strip("[]") for _l in self.all_language_tokens)

    @cached_property
    def non_speech_tokens(self) -> Tuple[int]:
        """
        Returns the list of tokens to suppress in order to avoid any speaker tags or non-speech
        annotations, to prevent sampling texts that are not actually spoken in the audio, e.g.

        - ♪♪♪
        - ( SPEAKING FOREIGN LANGUAGE )
        - [DAVID] Hey there,

        keeping basic punctuations like commas, periods, question marks, exclamation points, etc.
        """
        symbols = list('"#()*+/:;<=>@[\\]^_`{|}~「」『』')
        symbols += (
            "<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split()
        )

        # symbols that may be a single token or multiple tokens depending on the tokenizer.
        # In case they're multiple tokens, suppress the first token, which is safe because:
        # These are between U+2640 and U+267F miscellaneous symbols that are okay to suppress
        # in generations, and in the 3-byte UTF-8 representation they share the first two bytes.
        miscellaneous = set("♩♪♫♬♭♮♯")
        assert all(0x2640 <= ord(c) <= 0x267F for c in miscellaneous)

        # allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
        result = {self.encoding.encode(" -")[0], self.encoding.encode(" '")[0]}
        for symbol in symbols + list(miscellaneous):
            for tokens in [
                self.encoding.encode(symbol),
                self.encoding.encode(" " + symbol),
            ]:
                if len(tokens) == 1 or symbol in miscellaneous:
                    result.add(tokens[0])

        return tuple(sorted(result))

    def split_to_word_tokens(self, tokens: List[int]):
        if self.language in {"zh", "ja", "th", "lo", "my", "yue"}:
            # These languages don't typically use spaces, so it is difficult to split words
            # without morpheme analysis. Here, we instead split words at any
            # position where the tokens are decoded as valid unicode points
            return self.split_tokens_on_unicode(tokens)

        return self.split_tokens_on_spaces(tokens)

    def split_tokens_on_unicode(self, tokens: List[int]):
        decoded_full = self.decode(tokens)
        replacement_char = "\ufffd"

        words = []
        word_tokens = []
        current_tokens = []
        unicode_offset = 0

        for token in tokens:
            current_tokens.append(token)
            decoded = self.decode(current_tokens)

            if (
                replacement_char not in decoded
                or decoded_full[unicode_offset + decoded.index(replacement_char)]
                == replacement_char
            ):
                words.append(decoded)
                word_tokens.append(current_tokens)
                current_tokens = []
                unicode_offset += len(decoded)

        return words, word_tokens

    def split_tokens_on_spaces(self, tokens: List[int]):
        subwords, subword_tokens_list = self.split_tokens_on_unicode(tokens)
        words = []
        word_tokens = []

        for subword, subword_tokens in zip(subwords, subword_tokens_list):
|
331 |
+
special = subword_tokens[0] >= self.eot
|
332 |
+
with_space = subword.startswith(" ")
|
333 |
+
punctuation = subword.strip() in string.punctuation
|
334 |
+
if special or with_space or punctuation or len(words) == 0:
|
335 |
+
words.append(subword)
|
336 |
+
word_tokens.append(subword_tokens)
|
337 |
+
else:
|
338 |
+
words[-1] = words[-1] + subword
|
339 |
+
word_tokens[-1].extend(subword_tokens)
|
340 |
+
|
341 |
+
return words, word_tokens
|
342 |
+
|
343 |
+
|
344 |
+
@lru_cache(maxsize=None)
|
345 |
+
def get_encoding(name: str = "multilingual", num_languages: int = 100):
|
346 |
+
vocab_path = os.path.join(os.path.dirname(__file__), "assets", f"{name}.tiktoken")
|
347 |
+
ranks = {
|
348 |
+
base64.b64decode(token): int(rank)
|
349 |
+
for token, rank in (line.split() for line in open(vocab_path) if line)
|
350 |
+
}
|
351 |
+
n_vocab = len(ranks)
|
352 |
+
special_tokens = {}
|
353 |
+
|
354 |
+
specials = [
|
355 |
+
"[STOP]",
|
356 |
+
"[UNK]",
|
357 |
+
"[SPACE]",
|
358 |
+
"[START]",
|
359 |
+
"[nospk]",
|
360 |
+
"[spkemb]",
|
361 |
+
"[emotionemb]",
|
362 |
+
"[contextemb]",
|
363 |
+
"[sbreak]",
|
364 |
+
"[pbreak]",
|
365 |
+
"[uvbreak]",
|
366 |
+
"[bsing]",
|
367 |
+
"[esing]",
|
368 |
+
"[sing]",
|
369 |
+
"[hum]",
|
370 |
+
"[laugh]",
|
371 |
+
"[break]",
|
372 |
+
"[breath]",
|
373 |
+
"[oralsii]",
|
374 |
+
"[oralze]",
|
375 |
+
"[prolong]",
|
376 |
+
"[stress]",
|
377 |
+
"[bstrong]",
|
378 |
+
"[estrong]",
|
379 |
+
"[hiccup]",
|
380 |
+
"[inhale]",
|
381 |
+
"[exhale]",
|
382 |
+
"[emounknown]",
|
383 |
+
"[happy]",
|
384 |
+
"[neutral]",
|
385 |
+
"[sad]",
|
386 |
+
"[surprise]",
|
387 |
+
"[angry]",
|
388 |
+
"[disgust]",
|
389 |
+
"[emo]",
|
390 |
+
"[laugha]",
|
391 |
+
"[laughb]",
|
392 |
+
"[laughc]",
|
393 |
+
"[orala]",
|
394 |
+
"[oralb]",
|
395 |
+
"[oralc]",
|
396 |
+
"[orald]",
|
397 |
+
"[orale]",
|
398 |
+
"[breaka]",
|
399 |
+
"[breakb]",
|
400 |
+
"[breakc]",
|
401 |
+
"[breakd]",
|
402 |
+
"[breake]",
|
403 |
+
"[breakf]",
|
404 |
+
"[endoftext]",
|
405 |
+
"[startoftranscript]",
|
406 |
+
*[f"[{lang}]" for lang in list(LANGUAGES.keys())[:num_languages]],
|
407 |
+
"[translate]",
|
408 |
+
"[transcribe]",
|
409 |
+
"[startoflm]",
|
410 |
+
"[startofprev]",
|
411 |
+
"[nospeech]",
|
412 |
+
]
|
413 |
+
|
414 |
+
for token in specials:
|
415 |
+
special_tokens[token] = n_vocab
|
416 |
+
n_vocab += 1
|
417 |
+
|
418 |
+
return tiktoken.Encoding(
|
419 |
+
name=os.path.basename(vocab_path),
|
420 |
+
explicit_n_vocab=n_vocab,
|
421 |
+
pat_str=r"""'s|'t|'re|'ve|'m|'ll|'d|\[[A-Z]+\]|\[[a-z]+\]|[\x{4e00}-\x{9df5}]| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
|
422 |
+
mergeable_ranks=ranks,
|
423 |
+
special_tokens=special_tokens,
|
424 |
+
)
|
425 |
+
|
426 |
+
|
427 |
+
@lru_cache(maxsize=None)
|
428 |
+
def get_tokenizer(
|
429 |
+
multilingual: bool,
|
430 |
+
*,
|
431 |
+
num_languages: int = 100,
|
432 |
+
language: Optional[str] = None,
|
433 |
+
task: Optional[str] = None, # Literal["transcribe", "translate", None]
|
434 |
+
) -> Tokenizer:
|
435 |
+
if language is not None:
|
436 |
+
language = language.lower()
|
437 |
+
if language not in LANGUAGES:
|
438 |
+
if language in TO_LANGUAGE_CODE:
|
439 |
+
language = TO_LANGUAGE_CODE[language]
|
440 |
+
else:
|
441 |
+
raise ValueError(f"Unsupported language: {language}")
|
442 |
+
|
443 |
+
if multilingual:
|
444 |
+
encoding_name = "multilingual"
|
445 |
+
language = language or "en"
|
446 |
+
task = task or "transcribe"
|
447 |
+
else:
|
448 |
+
encoding_name = "gpt2"
|
449 |
+
language = None
|
450 |
+
task = None
|
451 |
+
|
452 |
+
encoding = get_encoding(name=encoding_name, num_languages=num_languages)
|
453 |
+
|
454 |
+
return Tokenizer(
|
455 |
+
encoding=encoding, num_languages=num_languages, language=language, task=task
|
456 |
+
)
|
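For reference, a minimal usage sketch of the tokenizer defined above (not part of the diff). It assumes the bundled `assets/multilingual.tiktoken` vocabulary is installed next to this module and that the module path is `fireredtts.modules.tokenizer.whisper_tokenizer`; the sample sentence and printed fields are purely illustrative.

```python
# Illustrative sketch only: assumes the fireredtts package is importable and
# that assets/multilingual.tiktoken ships alongside this module.
from fireredtts.modules.tokenizer.whisper_tokenizer import get_tokenizer

tok = get_tokenizer(multilingual=True, language="zh", task="transcribe")

ids = tok.encode("你好，世界")        # BPE ids; special tokens are allowed verbatim
print(tok.decode(ids))                # round-trips back to the input text
print(tok.sot_sequence)               # ids of [startoftranscript], [zh], [transcribe]
print(tok.stop, tok.start)            # ids of the [STOP] / [START] markers
print(len(tok.non_speech_tokens))     # token ids a sampler may want to suppress
```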
fireredtts/utils/__init__.py
ADDED
File without changes
fireredtts/utils/utils.py
ADDED
@@ -0,0 +1,37 @@
import os
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio


def load_audio(audiopath, sampling_rate):
    """Load an audio file, downmix it to mono, and resample it.

    Args:
        audiopath (str): path to the audio file to load
        sampling_rate (int): target sampling rate for the resampled copy

    Returns:
        tuple: (mono audio at the original rate, original sampling rate,
        audio resampled to `sampling_rate`); both waveforms are clipped to [-1, 1]
    """
    audio, lsr = torchaudio.load(audiopath)

    # stereo to mono if needed
    if audio.size(0) != 1:
        audio = torch.mean(audio, dim=0, keepdim=True)

    # resample to the requested rate
    audio_resampled = torchaudio.functional.resample(audio, lsr, sampling_rate)

    # warn about waveforms whose value range looks invalid
    if torch.any(audio > 10) or not torch.any(audio < 0):
        print(f"Error with {audiopath}. Max={audio.max()} min={audio.min()}")

    if torch.any(audio_resampled > 10) or not torch.any(audio_resampled < 0):
        print(
            f"Error with {audiopath}. Max={audio_resampled.max()} min={audio_resampled.min()}"
        )

    # clip audio invalid values
    audio.clip_(-1, 1)
    audio_resampled.clip_(-1, 1)
    return audio, lsr, audio_resampled
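A quick, hypothetical call of `load_audio` for orientation; the file name and the 16 kHz target rate below are placeholders, not values mandated by the code.

```python
# Hypothetical example: "sample.wav" is a placeholder path and 16000 Hz an
# assumed target rate. Returns the mono waveform at its original rate, that
# rate, and the resampled copy, all clipped to [-1, 1].
from fireredtts.utils.utils import load_audio

audio, source_sr, audio_16k = load_audio("sample.wav", sampling_rate=16000)
print(audio.shape, source_sr, audio_16k.shape)
```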
pretrained_models/README.md
ADDED
@@ -0,0 +1,3 @@
## Pretrained Models

Download the required model files and place them in the `pretrained_models` folder.
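One possible way to populate this folder, assuming the checkpoints are hosted on the Hugging Face Hub; the repo id below is a placeholder, not something this diff specifies.

```python
# Sketch only: downloads a Hub snapshot into pretrained_models/.
# "FireRedTeam/FireRedTTS" is an assumed repo id; replace it with the real one.
from huggingface_hub import snapshot_download

snapshot_download(repo_id="FireRedTeam/FireRedTTS", local_dir="pretrained_models")
```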