# CL-KWS_202408_v1 / model/encoder.py
import tensorflow as tf
import numpy as np
from tensorflow.keras.models import Model
from tensorflow.keras import layers
seed = 42
tf.random.set_seed(seed)
np.random.seed(seed)
class Encoder(Model):
"""Base class for encoders"""
def __init__(self, name="Encoder", **kwargs):
super(Encoder, self).__init__(name=name)
def call(self, src, src_len=None):
"""
Args:
src : source of shape `(batch, src_len)`
src_len : lengths of each source of shape `(batch)`
"""
raise NotImplementedError
class AudioEncoder(Encoder):
"""Base class for audio encoders"""
def __init__(self, name="BaseAudioEncoder", **kwargs):
super(AudioEncoder, self).__init__(name=name)
self.crnn = []
self.stride = 1
        if kwargs['audio_input'] == 'raw':
            # Conv1D front-end is only used for raw waveform input; track the cumulative
            # temporal stride so the padding mask can be downsampled to match.
            for f, k, s in kwargs['conv']:
                self.crnn.append(layers.Conv1D(f, k, s, padding='same'))
                self.crnn.append(layers.BatchNormalization())
                self.crnn.append(layers.ReLU())
                self.stride *= s
        # each entry in kwargs['gru'] holds the hidden size of one GRU layer
        for unit in kwargs['gru']:
            self.crnn.append(layers.GRU(unit[0], return_sequences=True))
self.dense = layers.Dense(kwargs['fc'])
self.act = layers.LeakyReLU()
def call(self, src):
"""
Args:
src : source of shape `(batch, time, feature)`
src_len : lengths of each source of shape `(batch)`
"""
        # keep the batch padding mask, if the input carries one
        mask_flag = 'mask' in vars(src)
        if mask_flag:
            # downsample the mask by the total Conv1D stride so it matches the recurrent input length
            mask = src.mask[:, ::self.stride]
x = src
for layer in self.crnn:
            # [B, T, F] -> [B, T/stride, conv] -> [B, T/stride, gru]
if isinstance(layer, layers.GRU):
if mask_flag:
x = layer(x, mask=mask)
else:
x = layer(x)
else:
x = layer(x)
        # [B, T/stride, fc]
        x = self.dense(x)
        LD = x  # pre-activation Dense output, returned as an auxiliary feature
        x = self.act(x)
        if mask_flag:
            # propagate the downsampled padding mask to downstream layers
            x._keras_mask = mask
        return x, LD
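
# --- Usage sketch (illustrative only, not part of the training pipeline) ----
# A minimal example of how AudioEncoder might be instantiated, assuming a
# raw-waveform config; the concrete values for 'conv', 'gru', and 'fc' are
# placeholders, and _demo_audio_encoder is a hypothetical helper.
def _demo_audio_encoder():
    cfg = {
        'audio_input': 'raw',
        'conv': [(64, 5, 2), (64, 5, 2)],  # (filters, kernel, stride) per Conv1D block
        'gru': [(128,), (128,)],           # hidden units per GRU layer
        'fc': 128,                         # width of the final Dense projection
    }
    enc = AudioEncoder(**cfg)
    wav = tf.random.normal([2, 1600, 1])   # (batch, time, feature)
    out, pre_act = enc(wav)
    # time axis is reduced by the product of the conv strides (here 4x)
    print(out.shape, pre_act.shape)
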
class EfficientAudioEncoder(Encoder):
"""Efficient encoder class for audio encoders"""
def __init__(self, name="EfficientAudioEncoder", downsample=True, **kwargs):
super(EfficientAudioEncoder, self).__init__(name=name)
self.downsample = downsample
self.layer = []
        if self.downsample:
            # two [Conv1D(stride 2) -> BN -> ReLU -> MaxPool(2)] stages; the trailing
            # MaxPool is dropped, giving an overall temporal stride of 8
            for _ in range(2):
                self.layer.append(layers.Conv1D(kwargs['fc'], 5, 2, padding='same'))
                self.layer.append(layers.BatchNormalization())
                self.layer.append(layers.ReLU())
                self.layer.append(layers.MaxPool1D(pool_size=2, strides=2, padding='valid'))
            self.layer = self.layer[:-1]
        else:
            # lighter front-end with an overall temporal stride of 2
            self.layer.append(layers.Conv1D(kwargs['fc'], 3, 2, padding='same'))
            self.layer.append(layers.BatchNormalization())
            self.layer.append(layers.ReLU())
            self.layer.append(layers.Conv1D(kwargs['fc'], 3, 1, padding='same'))
            self.layer.append(layers.BatchNormalization())
            self.layer.append(layers.ReLU())
        # upsamples the 1/8-rate speech embedding by 4x for the non-downsampled path
        self.deConv = layers.Conv1DTranspose(kwargs['fc'], 5, 4)
self.dense = layers.Dense(kwargs['fc'])
self.act = layers.LeakyReLU()
def call(self, specgram, embed):
"""
Args:
embed : google speech embedding of shape `(batch, time / 8, 96)`
specgram : log mel-spectrogram of shape `(batch, time, mel)`
"""
        # keep the batch padding mask, if the inputs carry one
        mask_flag = 'mask' in vars(embed)
        if mask_flag:
            # downsample the mask to match the front-end's temporal stride (8 or 2)
            if self.downsample:
                mask = specgram.mask[:, ::8]
            else:
                mask = specgram.mask[:, ::2]
x = specgram
for l in self.layer:
# [B, T, F] -> [B, T/8, dense] or [B, T/2, dense]
x = l(x)
LD = x
# [B, T/8, dense] or [B, T/2, dense]
if self.downsample:
y = self.act(self.dense(embed))
            # sum the two embeddings, zero-padding y along time to match x
x += tf.pad(y, [[0, 0],[0, tf.shape(x)[1] - tf.shape(y)[1]],[0, 0]], 'CONSTANT', constant_values=0.0)
        else:
            y = self.act(self.deConv(embed))
            # align the upsampled embedding with x along time before summing:
            # pad y when it is shorter, otherwise trim it to the length of x
            if tf.shape(x)[1] > tf.shape(y)[1]:
                x += tf.pad(y, [[0, 0], [0, tf.shape(x)[1] - tf.shape(y)[1]], [0, 0]], 'CONSTANT', constant_values=0.0)
            else:
                x += y[:, :tf.shape(x)[1], :]
if mask_flag:
x._keras_mask = mask
LD._keras_mask = mask
return x, LD
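
# --- Usage sketch (illustrative only) ----------------------------------------
# Shows one way EfficientAudioEncoder could be driven with dummy tensors; the
# feature width (fc=96) and the input shapes are assumptions chosen to satisfy
# the (batch, time, mel) / (batch, time / 8, 96) contract in the docstring
# above, and _demo_efficient_audio_encoder is a hypothetical helper.
def _demo_efficient_audio_encoder():
    enc = EfficientAudioEncoder(downsample=True, fc=96)
    specgram = tf.random.normal([2, 80, 40])  # (batch, time, mel)
    embed = tf.random.normal([2, 10, 96])     # (batch, time / 8, 96)
    out, pre_sum = enc(specgram, embed)
    # with downsample=True both outputs live on the time / 8 grid
    print(out.shape, pre_sum.shape)
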
class TextEncoder(Encoder):
"""Base class for text encoders"""
def __init__(self, name="BaseTextEncoder", **kwargs):
super(TextEncoder, self).__init__(name=name)
self.features = kwargs['text_input']
if self.features == 'phoneme':
self.mask = tf.keras.layers.Masking(mask_value=0, input_shape=(None,))
            # kwargs['vocab'] is the vocabulary size, used as the one-hot depth
            vocab = tf.convert_to_tensor(kwargs['vocab'], dtype=tf.int32)
            self.one_hot = layers.Lambda(lambda x: tf.one_hot(x, vocab), dtype=tf.int32, name='one_hot')
elif self.features == 'g2p_embed':
self.mask = tf.keras.layers.Masking(mask_value=0, input_shape=(None, 256))
self.dense = layers.Dense(kwargs['fc'])
self.act = layers.LeakyReLU()
def call(self, src):
"""
Args:
src : phoneme token of shape `(batch, phoneme)`
src_len : lengths of each source of shape `(batch)`
"""
        # [B, phoneme] -> [B, phoneme, embedding]
        # the Masking layer output is reused as the mask below: padded positions
        # are zero, so casting the values to bool marks the valid timesteps
        mask = self.mask(src)
        if self.features == 'phoneme':
            x = self.one_hot(src)
        elif self.features == 'g2p_embed':
            x = src
            # keep only the first embedding dimension as the per-timestep indicator
            mask = mask[:, :, 0]
        x = self.act(self.dense(x))
        x._keras_mask = tf.cast(mask, tf.bool)
return x
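
# --- Usage sketch (illustrative only) ----------------------------------------
# Demonstrates the phoneme-token path of TextEncoder; the vocabulary size and
# projection width are placeholders, zero is assumed to be the padding id (as
# implied by mask_value=0 above), and _demo_text_encoder is a hypothetical helper.
def _demo_text_encoder():
    enc = TextEncoder(text_input='phoneme', vocab=40, fc=128)
    tokens = tf.constant([[3, 7, 12, 0, 0],
                          [5, 9, 0, 0, 0]], dtype=tf.int32)  # 0 = padding
    out = enc(tokens)
    # padded positions are masked out for downstream attention / pooling layers
    print(out.shape, out._keras_mask)

if __name__ == "__main__":
    # smoke-test the sketches above with random inputs
    _demo_audio_encoder()
    _demo_efficient_audio_encoder()
    _demo_text_encoder()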