File size: 5,954 Bytes
c968fc3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 |
# Copyright (c) 2023 Amphion.
#
# This code is modified from https://github.com/ZhangXInFD/SpeechTokenizer/blob/main/speechtokenizer/model.py
# Licensed under Apache License 2.0
from typing import Optional

import numpy as np
import torch
import torch.nn as nn
from einops import rearrange

from .modules.quantization import ResidualVectorQuantizer
from .modules.seanet import SEANetEncoder, SEANetDecoder
class SpeechTokenizer(nn.Module):
    """Encoder / residual-vector-quantizer / decoder model for waveforms.

    The model is assembled from three sub-modules built from ``config``:
    a ``SEANetEncoder``, a ``ResidualVectorQuantizer`` and a
    ``SEANetDecoder``. An additional ``transform`` layer projects the
    quantized feature returned by :meth:`forward` to
    ``semantic_dimension`` when it differs from ``dimension``.
    """

    def __init__(self, config):
        """
        Parameters
        ----------
        config : dict-like
            Model config (typically parsed from JSON). Only ``.get`` is used
            to read it; the keys read are ``n_filters``, ``dimension``,
            ``strides``, ``lstm_layers``, ``bidirectional``,
            ``dilation_base``, ``residual_kernel_size``,
            ``n_residual_layers``, ``activation``, ``sample_rate``, ``n_q``,
            ``semantic_dimension`` and ``codebook_size``. A missing key is
            forwarded to the sub-module as ``None``.
        """
        super().__init__()
        dimension = config.get("dimension")
        semantic_dimension = config.get("semantic_dimension")
        strides = config.get("strides")

        self.encoder = SEANetEncoder(
            n_filters=config.get("n_filters"),
            dimension=dimension,
            ratios=strides,
            lstm=config.get("lstm_layers"),
            bidirectional=config.get("bidirectional"),
            dilation_base=config.get("dilation_base"),
            residual_kernel_size=config.get("residual_kernel_size"),
            n_residual_layers=config.get("n_residual_layers"),
            activation=config.get("activation"),
        )
        self.sample_rate = config.get("sample_rate")
        self.n_q = config.get("n_q")
        # Product of all strides, stored as the model's downsample rate.
        self.downsample_rate = np.prod(strides)

        # Project the quantized feature only when the two sizes differ.
        if dimension != semantic_dimension:
            self.transform = nn.Linear(dimension, semantic_dimension)
        else:
            self.transform = nn.Identity()

        self.quantizer = ResidualVectorQuantizer(
            dimension=dimension,
            n_q=config.get("n_q"),
            bins=config.get("codebook_size"),
        )
        # NOTE(review): the decoder is always built with bidirectional=False,
        # regardless of config["bidirectional"] -- presumably intentional;
        # confirm against the upstream implementation.
        self.decoder = SEANetDecoder(
            n_filters=config.get("n_filters"),
            dimension=dimension,
            ratios=strides,
            lstm=config.get("lstm_layers"),
            bidirectional=False,
            dilation_base=config.get("dilation_base"),
            residual_kernel_size=config.get("residual_kernel_size"),
            n_residual_layers=config.get("n_residual_layers"),
            activation=config.get("activation"),
        )

    @classmethod
    def load_from_checkpoint(cls, config_path: str, ckpt_path: str):
        """
        Build a model from a JSON config file and load its weights.

        Parameters
        ----------
        config_path : str
            Path of model configuration file (JSON).
        ckpt_path : str
            Path of model checkpoint (a state dict readable by
            ``torch.load``).

        Returns
        -------
        model : SpeechTokenizer
            SpeechTokenizer model with the checkpoint weights loaded. The
            weights are mapped to CPU while loading.
        """
        import json

        # JSON is UTF-8 by specification; do not depend on the locale default.
        with open(config_path, encoding="utf-8") as f:
            cfg = json.load(f)
        model = cls(cfg)
        # SECURITY: torch.load unpickles the file, which can execute arbitrary
        # code. Only load checkpoints from trusted sources.
        params = torch.load(ckpt_path, map_location="cpu")
        model.load_state_dict(params)
        return model

    def forward(
        self,
        x: torch.Tensor,
        n_q: Optional[int] = None,
        layers: Optional[list] = None,
    ):
        """
        Encode, quantize and reconstruct a batch of waveforms.

        Parameters
        ----------
        x : torch.Tensor
            Input wavs. Shape: (batch, channels, timesteps).
        n_q : int, optional
            Number of quantizers in RVQ used to encode. ``None`` (or any
            falsy value such as 0) selects all ``self.n_q`` quantizers.
        layers : list[int], optional
            Layers of RVQ should return quantized result. The default is the
            first layer, i.e. ``[0]``.

        Returns
        -------
        o : torch.Tensor
            Output wavs. Shape: (batch, channels, timesteps).
        commit_loss : torch.Tensor
            Commitment loss from residual vector quantizers.
        feature : torch.Tensor
            Quantized output of the first layer listed in ``layers``, moved
            to (batch, timesteps, dimension) layout and then passed through
            ``self.transform``.
        """
        # A fresh list per call: a literal list default would be one shared
        # object that downstream code could mutate across calls.
        if layers is None:
            layers = [0]
        n_q = n_q if n_q else self.n_q

        e = self.encoder(x)
        quantized, codes, commit_loss, quantized_list = self.quantizer(
            e, n_q=n_q, layers=layers
        )
        feature = rearrange(quantized_list[0], "b d t -> b t d")
        feature = self.transform(feature)
        o = self.decoder(quantized)
        return o, commit_loss, feature

    def forward_feature(self, x: torch.Tensor, layers: Optional[list] = None):
        """
        Return the quantized outputs of the requested RVQ layers.

        Parameters
        ----------
        x : torch.Tensor
            Input wavs. Shape should be (batch, channels, timesteps).
        layers : list[int], optional
            Layers of RVQ should return quantized result. ``None`` or an
            empty list selects all ``self.n_q`` layers.

        Returns
        -------
        quantized_list : list[torch.Tensor]
            Quantized of required layers.
        """
        e = self.encoder(x)
        layers = layers if layers else list(range(self.n_q))
        quantized, codes, commit_loss, quantized_list = self.quantizer(e, layers=layers)
        return quantized_list

    def encode(
        self,
        x: torch.Tensor,
        n_q: Optional[int] = None,
        st: Optional[int] = None,
    ):
        """
        Encode waveforms into RVQ code indices.

        Parameters
        ----------
        x : torch.Tensor
            Input wavs. Shape: (batch, channels, timesteps).
        n_q : int, optional
            Number of quantizers in RVQ used to encode. ``None`` (or any
            falsy value such as 0) selects all ``self.n_q`` quantizers.
        st : int, optional
            Start quantizer index in RVQ. The default is 0.

        Returns
        -------
        codes : torch.Tensor
            Output indices for each quantizer. Shape: (n_q, batch, timesteps)
        """
        e = self.encoder(x)
        if st is None:
            st = 0
        n_q = n_q if n_q else self.n_q
        codes = self.quantizer.encode(e, n_q=n_q, st=st)
        return codes

    def decode(self, codes: torch.Tensor, st: int = 0):
        """
        Reconstruct waveforms from RVQ code indices.

        Parameters
        ----------
        codes : torch.Tensor
            Indices for each quantizer. Shape: (n_q, batch, timesteps).
        st : int, optional
            Start quantizer index in RVQ. The default is 0.

        Returns
        -------
        o : torch.Tensor
            Reconstruct wavs from codes. Shape: (batch, channels, timesteps)
        """
        quantized = self.quantizer.decode(codes, st=st)
        o = self.decoder(quantized)
        return o
|