import json

import numpy as np
import torch
import torch.nn as nn
from einops import rearrange

from .modules.quantization import ResidualVectorQuantizer
from .modules.seanet import SEANetDecoder, SEANetEncoder


class SpeechTokenizer(nn.Module):
|
    def __init__(self, config: dict):
        """
        Parameters
        ----------
        config : dict
            Model configuration, as parsed from the JSON config file.

        """
        super().__init__()
|
        # Convolutional encoder that maps waveforms to frame-level latents.
        self.encoder = SEANetEncoder(
            n_filters=config.get("n_filters"),
            dimension=config.get("dimension"),
            ratios=config.get("strides"),
            lstm=config.get("lstm_layers"),
            bidirectional=config.get("bidirectional"),
            dilation_base=config.get("dilation_base"),
            residual_kernel_size=config.get("residual_kernel_size"),
            n_residual_layers=config.get("n_residual_layers"),
            activation=config.get("activation"),
        )
        self.sample_rate = config.get("sample_rate")
        self.n_q = config.get("n_q")
        # Total encoder hop size: audio samples per latent frame.
        self.downsample_rate = np.prod(config.get("strides"))
        # Project first-layer RVQ features to the semantic dimension when
        # the two dimensions differ.
        if config.get("dimension") != config.get("semantic_dimension"):
            self.transform = nn.Linear(
                config.get("dimension"), config.get("semantic_dimension")
            )
        else:
            self.transform = nn.Identity()
        self.quantizer = ResidualVectorQuantizer(
            dimension=config.get("dimension"),
            n_q=config.get("n_q"),
            bins=config.get("codebook_size"),
        )
        # The decoder mirrors the encoder but is always unidirectional.
        self.decoder = SEANetDecoder(
            n_filters=config.get("n_filters"),
            dimension=config.get("dimension"),
            ratios=config.get("strides"),
            lstm=config.get("lstm_layers"),
            bidirectional=False,
            dilation_base=config.get("dilation_base"),
            residual_kernel_size=config.get("residual_kernel_size"),
            n_residual_layers=config.get("n_residual_layers"),
            activation=config.get("activation"),
        )

    @classmethod
    def load_from_checkpoint(cls, config_path: str, ckpt_path: str):
""" |
|
|
|
Parameters |
|
---------- |
|
config_path : str |
|
Path of model configuration file. |
|
ckpt_path : str |
|
Path of model checkpoint. |
|
|
|
Returns |
|
------- |
|
model : SpeechTokenizer |
|
SpeechTokenizer model. |
|
|
|
""" |
|
import json |
|
|
|
with open(config_path) as f: |
|
cfg = json.load(f) |
|
model = cls(cfg) |
|
params = torch.load(ckpt_path, map_location="cpu") |
|
model.load_state_dict(params) |
|
return model |
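
    # Usage sketch (illustrative, not from the original source): load a
    # pretrained model from disk; "config.json" and "SpeechTokenizer.pt"
    # are placeholder paths.
    #
    #   model = SpeechTokenizer.load_from_checkpoint(
    #       "config.json", "SpeechTokenizer.pt"
    #   )
    #   model.eval()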
|
    def forward(self, x: torch.Tensor, n_q: int = None, layers: list = [0]):
        """
        Parameters
        ----------
        x : torch.Tensor
            Input wavs. Shape: (batch, channels, timesteps).
        n_q : int, optional
            Number of RVQ quantizers used to encode. The default is all layers.
        layers : list[int], optional
            RVQ layers whose quantized outputs should be returned.
            The default is the first layer only.

        Returns
        -------
        o : torch.Tensor
            Output wavs. Shape: (batch, channels, timesteps).
        commit_loss : torch.Tensor
            Commitment loss from the residual vector quantizers.
        feature : torch.Tensor
            First-layer RVQ output after the semantic projection.
            Shape: (batch, timesteps, semantic_dimension).

        """
        n_q = n_q if n_q else self.n_q
        e = self.encoder(x)
        quantized, codes, commit_loss, quantized_list = self.quantizer(
            e, n_q=n_q, layers=layers
        )
        # (batch, dimension, timesteps) -> (batch, timesteps, dimension)
        feature = rearrange(quantized_list[0], "b d t -> b t d")
        feature = self.transform(feature)
        o = self.decoder(quantized)
        return o, commit_loss, feature
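
    # Usage sketch (illustrative): a training-style forward pass on a batch
    # of 1-second mono clips. Shapes follow from the configured strides.
    #
    #   wav = torch.randn(2, 1, model.sample_rate)  # (batch, channels, T)
    #   recon, commit_loss, feature = model(wav)
    #   # feature: (2, T // model.downsample_rate, semantic_dimension)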
|
    def forward_feature(self, x: torch.Tensor, layers: list = None):
        """
        Parameters
        ----------
        x : torch.Tensor
            Input wavs. Shape: (batch, channels, timesteps).
        layers : list[int], optional
            RVQ layers whose quantized outputs should be returned.
            The default is all layers.

        Returns
        -------
        quantized_list : list[torch.Tensor]
            Quantized outputs of the requested layers.

        """
        e = self.encoder(x)
        layers = layers if layers else list(range(self.n_q))
        quantized, codes, commit_loss, quantized_list = self.quantizer(e, layers=layers)
        return quantized_list
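
    # Usage sketch (illustrative): quantized outputs of the first two RVQ
    # layers; each entry has shape (batch, dimension, timesteps).
    #
    #   q_list = model.forward_feature(wav, layers=[0, 1])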
|
    def encode(self, x: torch.Tensor, n_q: int = None, st: int = None):
        """
        Parameters
        ----------
        x : torch.Tensor
            Input wavs. Shape: (batch, channels, timesteps).
        n_q : int, optional
            Number of RVQ quantizers used to encode. The default is all layers.
        st : int, optional
            Index of the first RVQ quantizer to use. The default is 0.

        Returns
        -------
        codes : torch.Tensor
            Output indices for each quantizer. Shape: (n_q, batch, timesteps).

        """
        e = self.encoder(x)
        if st is None:
            st = 0
        n_q = n_q if n_q else self.n_q
        codes = self.quantizer.encode(e, n_q=n_q, st=st)
        return codes
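
    # Usage sketch (illustrative): tokenize a wav with all quantizers, or
    # with only the first quantizer.
    #
    #   codes = model.encode(wav)               # (n_q, batch, timesteps)
    #   first_layer = model.encode(wav, n_q=1)  # (1, batch, timesteps)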
|
    def decode(self, codes: torch.Tensor, st: int = 0):
        """
        Parameters
        ----------
        codes : torch.Tensor
            Indices for each quantizer. Shape: (n_q, batch, timesteps).
        st : int, optional
            Index of the first RVQ quantizer the codes belong to.
            The default is 0.

        Returns
        -------
        o : torch.Tensor
            Wavs reconstructed from the codes. Shape: (batch, channels, timesteps).

        """
        quantized = self.quantizer.decode(codes, st=st)
        o = self.decoder(quantized)
        return o
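

# Minimal smoke test, sketched under assumed config values: builds the model
# from a hypothetical config dict and runs an encode/decode round trip on
# random audio. The keys mirror those read in __init__; the values below are
# illustrative, not the released configuration.
if __name__ == "__main__":
    demo_cfg = {
        "n_filters": 32,
        "dimension": 1024,
        "semantic_dimension": 768,
        "strides": [8, 5, 4, 2],
        "lstm_layers": 2,
        "bidirectional": True,
        "dilation_base": 2,
        "residual_kernel_size": 3,
        "n_residual_layers": 1,
        "activation": "ELU",
        "sample_rate": 16000,
        "n_q": 8,
        "codebook_size": 1024,
    }
    model = SpeechTokenizer(demo_cfg).eval()
    wav = torch.randn(1, 1, demo_cfg["sample_rate"])  # 1 s of random mono audio
    with torch.no_grad():
        codes = model.encode(wav)    # (n_q, batch, timesteps)
        recon = model.decode(codes)  # (batch, channels, timesteps)
    print(codes.shape, recon.shape)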