basit-text-to-spech / wavegru.py
basit123796's picture
Upload 25 files
901d6f7 verified
"""
WaveGRU model: melspectrogram => mu-law encoded waveform
"""
from typing import Tuple
import jax
import jax.numpy as jnp
import pax
from pax import GRUState
from tqdm.cli import tqdm
class ReLU(pax.Module):
    """Parameter-free module that applies ``jax.nn.relu`` elementwise.

    Wrapping the activation in a module lets it be composed inside
    ``pax.Sequential`` alongside layer modules.
    """

    def __call__(self, x):
        activated = jax.nn.relu(x)
        return activated
def dilated_residual_conv_block(dim, kernel, stride, dilation):
    """
    Build the branch of a dilated residual block (used to enlarge the
    receptive field).

    The branch is: dilated Conv1D (VALID, no bias) >> LayerNorm >> ReLU >>
    pointwise Conv1D (kernel 1, no bias) >> LayerNorm >> ReLU.  Channel count
    stays ``dim`` throughout.  No skip connection is added here; the caller is
    responsible for combining this branch with its input.

    Args:
        dim: number of input and output channels.
        kernel: kernel width of the dilated convolution.
        stride: stride of the dilated convolution.
        dilation: dilation rate of the dilated convolution.

    Returns:
        A ``pax.Sequential`` module.
    """
    layers = [
        # Dilated conv with VALID padding: the output is shorter than the
        # input, so callers must crop the skip path before adding.
        pax.Conv1D(dim, dim, kernel, stride, dilation, "VALID", with_bias=False),
        pax.LayerNorm(dim, -1, True, True),
        ReLU(),
        # Kernel-1 conv mixes channels without changing the sequence length.
        pax.Conv1D(dim, dim, 1, 1, 1, "VALID", with_bias=False),
        pax.LayerNorm(dim, -1, True, True),
        ReLU(),
    ]
    return pax.Sequential(*layers)
def tile_1d(x, factor):
    """
    Repeat every time step of an ``(N, L, D)`` array ``factor`` times.

    Each step is repeated consecutively (nearest-neighbour upsampling along
    axis 1), producing an array of shape ``(N, L * factor, D)``.  ``x`` must
    be rank 3; unpacking its shape fails otherwise.
    """
    n, length, d = x.shape
    # Insert a new axis after time, broadcast it to ``factor`` copies, then
    # merge it into the time axis so copies of step t sit next to each other.
    expanded = jnp.broadcast_to(x[:, :, None, :], (n, length, factor, d))
    return expanded.reshape(n, length * factor, d)
def up_block(in_dim, out_dim, factor, relu=True):
    """
    Build one upsampling stage: Tile >> Conv >> LayerNorm [>> ReLU].

    The input sequence is first repeated ``factor`` times along the time axis
    (``tile_1d``).  It is then smoothed by a VALID Conv1D of kernel width
    ``2 * factor`` and stride 1, and normalized with LayerNorm.  Because the
    conv is VALID, the output length is ``L * factor - (2 * factor - 1)`` for
    an input of length ``L`` (assuming the Conv1D's default dilation of 1).

    Args:
        in_dim: input channel count.
        out_dim: output channel count.
        factor: time-axis repeat factor.
        relu: when True (default), append a ReLU after the LayerNorm; the
            final stage of an upsampler typically disables it.

    Returns:
        A ``pax.Sequential`` module.
    """
    f = pax.Sequential(
        lambda x: tile_1d(x, factor),
        pax.Conv1D(
            in_dim, out_dim, 2 * factor, stride=1, padding="VALID", with_bias=False
        ),
        pax.LayerNorm(out_dim, -1, True, True),
    )
    if relu:
        f >>= ReLU()
    return f
class Upsample(pax.Module):
    """
    Upsample melspectrogram to match raw audio sample rate.

    Stages, in order:
      1. kernel-1 Conv1D + LayerNorm mapping ``input_dim`` -> ``hidden_dim``;
      2. five dilated residual blocks (dilations 1, 2, 4, 8, 16), each added
         back onto a center-cropped copy of its input;
      3. one ``up_block`` per factor in ``upsample_factors[:-1]``.  The last
         of these has no ReLU and outputs ``3 * rnn_dim`` channels, or
         ``hidden_dim`` channels followed by a Linear to ``3 * rnn_dim`` when
         ``has_linear_output`` is True;
      4. a plain repeat by ``upsample_factors[-1]`` (skipped with
         ``no_repeat=True``).

    ``upsample_factors`` needs at least two entries; the constructor indexes
    ``upsample_factors[:-1][-1]``.
    """

    def __init__(
        self, input_dim, hidden_dim, rnn_dim, upsample_factors, has_linear_output=False
    ):
        super().__init__()
        self.input_conv = pax.Sequential(
            pax.Conv1D(input_dim, hidden_dim, 1, with_bias=False),
            pax.LayerNorm(hidden_dim, -1, True, True),
        )
        self.upsample_factors = upsample_factors
        self.dilated_convs = [
            dilated_residual_conv_block(hidden_dim, 3, 1, 2**i) for i in range(5)
        ]
        # All but the last factor are learned (conv) upsampling stages; the
        # last factor is applied as a plain repeat in __call__.
        self.up_factors = upsample_factors[:-1]
        self.up_blocks = [
            up_block(hidden_dim, hidden_dim, x) for x in self.up_factors[:-1]
        ]
        self.up_blocks.append(
            up_block(
                hidden_dim,
                hidden_dim if has_linear_output else 3 * rnn_dim,
                self.up_factors[-1],
                relu=False,
            )
        )
        if has_linear_output:
            self.x2zrh_fc = pax.Linear(hidden_dim, rnn_dim * 3)
        self.has_linear_output = has_linear_output
        self.final_tile = upsample_factors[-1]

    def __call__(self, x, no_repeat=False):
        """
        Upsample ``x`` along the time axis.

        Assumes ``x`` is ``(batch, frames, input_dim)`` (channels-last) — TODO
        confirm against the pax Conv1D data format.  When ``no_repeat`` is
        True, the final plain repeat by ``upsample_factors[-1]`` is skipped.
        """
        x = self.input_conv(x)
        for residual in self.dilated_convs:
            y = residual(x)
            # VALID convs shorten the sequence, so center-crop the skip path
            # to y's length.  An explicit end index (rather than ``-pad``)
            # stays correct when pad == 0, where ``x[:, 0:-0]`` is empty, and
            # when the length difference is odd.
            pad = (x.shape[1] - y.shape[1]) // 2
            x = x[:, pad : pad + y.shape[1], :] + y
        for f in self.up_blocks:
            x = f(x)
        if self.has_linear_output:
            x = self.x2zrh_fc(x)
        if no_repeat:
            return x
        x = tile_1d(x, self.final_tile)
        return x
class GRU(pax.Module):
    """
    A GRU cell whose input projection is done by the caller.

    The input ``x`` given to ``__call__`` must already hold ``3 * hidden_dim``
    features: the first ``2 * hidden_dim`` feed the update/reset gates and the
    rest feed the candidate state.  Only the recurrent (hidden-to-hidden)
    projection ``h_zrh_fc`` lives in this module.  The reset gate multiplies
    the recurrent candidate term after projection.
    """

    input_dim: int
    hidden_dim: int

    def __init__(self, hidden_dim: int):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.h_zrh_fc = pax.Linear(
            hidden_dim,
            hidden_dim * 3,
            w_init=jax.nn.initializers.variance_scaling(
                1, "fan_out", "truncated_normal"
            ),
        )

    def initial_state(self, batch_size: int) -> GRUState:
        """Return a zero hidden state of shape ``(batch_size, hidden_dim)``."""
        zeros = jnp.zeros((batch_size, self.hidden_dim), dtype=jnp.float32)
        return GRUState(zeros)

    def __call__(self, state: GRUState, x) -> Tuple[GRUState, jnp.ndarray]:
        """Advance one step; returns ``(new_state, new_hidden)``."""
        prev_h = state.hidden
        n = self.hidden_dim
        recurrent = self.h_zrh_fc(prev_h)

        # Update (z) and reset (r) gates from the first 2n features.
        gates = jax.nn.sigmoid(x[..., : 2 * n] + recurrent[..., : 2 * n])
        update_gate = gates[..., :n]
        reset_gate = gates[..., n:]

        # Candidate state: reset gate scales the projected recurrent term.
        candidate = jnp.tanh(x[..., 2 * n :] + reset_gate * recurrent[..., 2 * n :])

        new_h = (1 - update_gate) * prev_h + update_gate * candidate
        return GRUState(new_h), new_h
class Pruner(pax.Module):
    """
    Base class for pruners.

    Provides a cubic sparsity schedule (``compute_sparsity``) and a
    block-magnitude pruning rule (``prune``) that subclasses apply to
    specific weight matrices.
    """

    def compute_sparsity(
        self, step, *, start_step=1_000, num_steps=200_000, final_sparsity=0.95
    ):
        """
        Return the target fraction of blocks to prune at ``step``.

        The schedule is ``final_sparsity * clip(1 - (1 - p)**3, 0, 1)``, where
        ``p = (step - start_step) / num_steps``.  It is 0 up to
        ``start_step``, rises to ``final_sparsity`` over ``num_steps`` steps,
        and stays there afterwards.  The defaults reproduce the original
        hard-coded schedule.
        """
        progress = (step * 1.0 - start_step) / num_steps
        t = jnp.power(1 - progress, 3)
        # Positional bounds: the ``a_min``/``a_max`` keywords are deprecated
        # in recent JAX releases in favour of ``min``/``max``.
        return final_sparsity * jnp.clip(1.0 - t, 0, 1)

    def prune(self, step, weights, *, block_size=4):
        """
        Return a boolean keep-mask with the same shape as ``weights``.

        ``weights`` is split into ``block_size x block_size`` tiles.  Each
        tile is scored by the sum of its absolute values.  Tiles whose score
        is at or above the ``compute_sparsity(step)`` quantile of all scores
        are kept (True); the rest are pruned (False).

        Raises:
            ValueError: if either dimension of ``weights`` is not divisible
                by ``block_size``.
        """
        z = self.compute_sparsity(step)
        H, W = weights.shape
        if H % block_size or W % block_size:
            raise ValueError(
                f"weight shape {weights.shape} cannot be split into "
                f"{block_size}x{block_size} blocks"
            )
        blocks = jnp.abs(weights).reshape(
            H // block_size, block_size, W // block_size, block_size
        )
        scores = jnp.sum(blocks, axis=(1, 3), keepdims=True)
        threshold = jnp.quantile(jnp.reshape(scores, (-1,)), z)
        keep = scores >= threshold
        # Expand each per-block decision back to every element of the block.
        keep = jnp.tile(keep, (1, block_size, 1, block_size))
        return jnp.reshape(keep, (H, W))
class GRUPruner(Pruner):
    """
    Block-sparsity pruner for the recurrent projection ``h_zrh_fc`` of this
    file's custom ``GRU`` (not ``pax.GRU``, which has no ``h_zrh_fc``).
    """

    def __init__(self, gru):
        super().__init__()
        # Start fully dense: an all-True boolean mask shaped like the weight.
        self.h_zrh_fc_mask = jnp.ones_like(gru.h_zrh_fc.weight) == 1

    def __call__(self, gru: "GRU"):
        """
        Apply mask after an optimization step.

        Returns a copy of ``gru`` whose ``h_zrh_fc.weight`` has masked-out
        entries set to 0.
        """
        masked_weight = jnp.where(self.h_zrh_fc_mask, gru.h_zrh_fc.weight, 0)
        # Same replace-based style as LinearPruner.
        return gru.replace(h_zrh_fc=gru.h_zrh_fc.replace(weight=masked_weight))

    def update_mask(self, step, gru: "GRU"):
        """
        Update internal masks.

        The weight is split into its z, r and h column groups (axis 1), and
        each group is pruned independently.  The result is ANDed into the
        existing mask, so pruned entries never come back.
        """
        z_weight, r_weight, h_weight = jnp.split(gru.h_zrh_fc.weight, 3, axis=1)
        new_mask = jnp.concatenate(
            [self.prune(step, w) for w in (z_weight, r_weight, h_weight)], axis=1
        )
        self.h_zrh_fc_mask *= new_mask
class LinearPruner(Pruner):
    """Block-sparsity pruner for the ``weight`` of a single ``pax.Linear``."""

    def __init__(self, linear):
        super().__init__()
        # Boolean mask, initially all True (nothing pruned).
        self.mask = jnp.ones_like(linear.weight, dtype=bool)

    def __call__(self, linear: pax.Linear):
        """Return a copy of ``linear`` with masked-out weights zeroed."""
        pruned_weight = jnp.where(self.mask, linear.weight, 0)
        return linear.replace(weight=pruned_weight)

    def update_mask(self, step, linear: pax.Linear):
        """AND the current schedule's keep-mask into the stored mask."""
        keep = self.prune(step, linear.weight)
        self.mask = self.mask * keep
class WaveGRU(pax.Module):
    """
    WaveGRU vocoder model.

    Conditions a GRU on upsampled mel features and predicts logits over 256
    sample classes (``pax.Embed(256, ...)`` input, ``pax.Linear(..., 256)``
    output).  The product of ``upsample_factors`` is the number of output
    samples per mel frame, ignoring the frames lost to VALID convolutions.
    """

    def __init__(
        self,
        mel_dim=80,
        rnn_dim=1024,
        upsample_factors=(5, 3, 20),
        has_linear_output=False,
    ):
        super().__init__()
        # Sample ids are embedded directly into the GRU's 3 * rnn_dim input
        # space, so they can be summed with the upsampled conditioning.
        self.embed = pax.Embed(256, 3 * rnn_dim)
        self.upsample = Upsample(
            input_dim=mel_dim,
            hidden_dim=512,
            rnn_dim=rnn_dim,
            upsample_factors=upsample_factors,
            has_linear_output=has_linear_output,
        )
        self.rnn = GRU(rnn_dim)
        self.o1 = pax.Linear(rnn_dim, rnn_dim)
        self.o2 = pax.Linear(rnn_dim, 256)
        self.gru_pruner = GRUPruner(self.rnn)
        self.o1_pruner = LinearPruner(self.o1)
        self.o2_pruner = LinearPruner(self.o2)

    def output(self, x):
        """Map GRU hidden states to 256 logits via Linear >> ReLU >> Linear."""
        x = self.o1(x)
        x = jax.nn.relu(x)
        x = self.o2(x)
        return x

    def inference(self, mel, no_gru=False, seed=42):
        """
        Generate waveform from melspectrogram.

        Sampling is autoregressive with batch size 1, starting from sample id
        127.  Each step embeds the previous id, adds one conditioning frame,
        advances the GRU, and draws the next id from the categorical
        distribution over the 256 output logits.  ``seed`` seeds the
        sampling RNG.

        Returns:
            A 1-D int array with one id per upsampled time step; or, when
            ``no_gru`` is True, the upsampler output without the final repeat.
        """

        @jax.jit
        def step(rnn_state, mel, rng_key, x):
            x = self.embed(x)
            x = x + mel
            rnn_state, x = self.rnn(rnn_state, x)
            x = self.output(x)
            rng_key, next_rng_key = jax.random.split(rng_key, 2)
            x = jax.random.categorical(rng_key, x, axis=-1)
            return rnn_state, next_rng_key, x

        y = self.upsample(mel, no_repeat=no_gru)
        if no_gru:
            return y
        x = jnp.array([127], dtype=jnp.int32)
        rnn_state = self.rnn.initial_state(1)
        output = []
        rng_key = jax.random.PRNGKey(seed)
        for i in tqdm(range(y.shape[1])):
            rnn_state, rng_key, x = step(rnn_state, y[:, i], rng_key, x)
            output.append(x)
        x = jnp.concatenate(output, axis=0)
        return x

    def __call__(self, mel, x):
        """
        Teacher-forced forward pass: return logits over 256 classes.

        ``x`` holds integer sample ids (embedding table size 256) and must be
        at least as long along axis 1 as the upsampled conditioning.  It is
        center-cropped to that length before being summed with it.

        Raises:
            ValueError: if ``x`` is shorter than the upsampled conditioning.
        """
        x = self.embed(x)
        y = self.upsample(mel)
        if x.shape[1] < y.shape[1]:
            raise ValueError(
                f"input length {x.shape[1]} is shorter than the upsampled "
                f"conditioning length {y.shape[1]}"
            )
        # Center-crop with an explicit end index: the old
        # ``x[:, pad_left:-pad_right]`` is empty when pad_right == 0.
        pad_left = (x.shape[1] - y.shape[1]) // 2
        x = x[:, pad_left : pad_left + y.shape[1]]
        x = x + y
        _, x = pax.scan(
            self.rnn,
            self.rnn.initial_state(x.shape[0]),
            x,
            time_major=False,
        )
        x = self.output(x)
        return x