import os
import sys

import numpy as np
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.models import Model

sys.path.append(os.path.dirname(__file__))
from utils import make_adjacency_matrix, make_feature_matrix  # local helpers

seed = 42
tf.random.set_seed(seed)
np.random.seed(seed)
class Extractor(Model):
    """Base class for extractors."""

    def __init__(self, name="Extractor", **kwargs):
        super().__init__(name=name)

    def call(self, emb_s, emb_t, emb_s_len=None, emb_t_len=None):
        """
        Args:
            emb_s     : speech embedding of shape `(batch, time, embedding)`
            emb_t     : text embedding of shape `(batch, phoneme, embedding)`
            emb_s_len : lengths of the speech embeddings, shape `(batch,)`
            emb_t_len : lengths of the text embeddings, shape `(batch,)`
        """
        raise NotImplementedError
class BaseExtractor(Extractor):
    """Cross-attention based pattern extractor."""

    def __init__(self, name="BaseExtractor", **kwargs):
        super().__init__(name=name)
        self.embedding = kwargs['embedding']
        self.num_heads = 1
        self.attn = layers.MultiHeadAttention(num_heads=self.num_heads,
                                              key_dim=self.embedding)

    def call(self, emb_s, emb_t):
        """
        Args:
            emb_s : speech embedding of shape `(batch, time, embedding)`
            emb_t : text embedding of shape `(batch, phoneme, embedding)`

        Query is the text embedding; key and value are the speech embedding.
        """
        Q = emb_t
        V = emb_s
        # The [B, Tt, m] / [B, Tt, Ta] shape notation follows
        # "Learning Audio-Text Agreement for Open-vocabulary Keyword Spotting".
        q_mask = getattr(Q, '_keras_mask', None)
        v_mask = getattr(V, '_keras_mask', None)
        if q_mask is not None and v_mask is not None:
            # Outer product of the per-sequence masks: position (i, j) is valid
            # only when both the i-th text step and the j-th speech step are.
            attn_mask = (tf.expand_dims(tf.cast(q_mask, tf.int32), -1)
                         * tf.expand_dims(tf.cast(v_mask, tf.int32), 1))
        else:
            attn_mask = None
        attention_output, affinity_matrix = self.attn(
            Q, V,
            return_attention_scores=True,
            attention_mask=attn_mask,
        )
        if self.num_heads == 1:
            # Drop the singleton head axis: [B, 1, Tt, Ta] -> [B, Tt, Ta].
            affinity_matrix = affinity_matrix[:, 0, :, :]
        if attn_mask is not None:
            affinity_matrix._keras_mask = attn_mask
            attention_output._keras_mask = q_mask
        return attention_output, affinity_matrix
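
# A minimal smoke test for BaseExtractor -- a sketch, not part of the model.
# The batch/time/phoneme sizes and the embedding width of 8 are illustrative
# assumptions; the masks are attached by hand, the way Keras masking layers
# would attach them upstream.
def _demo_base_extractor():
    extractor = BaseExtractor(embedding=8)
    emb_s = tf.random.normal((2, 50, 8))   # (batch, time, embedding)
    emb_t = tf.random.normal((2, 10, 8))   # (batch, phoneme, embedding)
    emb_s._keras_mask = tf.sequence_mask([50, 30], maxlen=50)
    emb_t._keras_mask = tf.sequence_mask([10, 7], maxlen=10)
    out, affinity = extractor(emb_s, emb_t)
    print(out.shape)       # expected (2, 10, 8): one output vector per text step
    print(affinity.shape)  # expected (2, 10, 50): text-to-speech affinity matrix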
class StackExtractor(Extractor):
    """Self-attention based pattern extractor."""

    def __init__(self, name="StackExtractor", **kwargs):
        super().__init__(name=name)
        self.embedding = kwargs['embedding']
        self.num_heads = 1
        self.attn = layers.MultiHeadAttention(num_heads=self.num_heads,
                                              key_dim=self.embedding)

    def call(self, emb_s, emb_t):
        """
        Args:
            emb_s : speech embedding of shape `(batch, time, embedding)`
            emb_t : text embedding of shape `(batch, phoneme, embedding)`

        Both inputs must carry a Keras mask. Speech and text embeddings are
        stacked into a single sequence, which then attends to itself
        (query, key, and value are all the stacked features).
        """
        Q = make_feature_matrix(emb_s, emb_s._keras_mask, emb_t, emb_t._keras_mask)
        V = Q
        attn_mask = make_adjacency_matrix(emb_s._keras_mask, emb_t._keras_mask)
        attention_output, affinity_matrix = self.attn(
            Q, V,
            return_attention_scores=True,
            attention_mask=attn_mask,
        )
        if self.num_heads == 1:
            # Drop the singleton head axis.
            affinity_matrix = affinity_matrix[:, 0, :, :]
        if attn_mask is not None:
            affinity_matrix._keras_mask = attn_mask
            # Reduce the 2-D attention mask back to a per-position mask via
            # its first column.
            attention_output._keras_mask = attn_mask[:, :, 0]
        return attention_output, affinity_matrix
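
# A matching smoke test for StackExtractor -- again a sketch. It assumes the
# utils helpers make_feature_matrix / make_adjacency_matrix accept the
# (embedding, mask, embedding, mask) and (mask, mask) arguments used in
# call() above; the exact output shapes depend on that module.
def _demo_stack_extractor():
    extractor = StackExtractor(embedding=8)
    emb_s = tf.random.normal((2, 50, 8))
    emb_t = tf.random.normal((2, 10, 8))
    emb_s._keras_mask = tf.sequence_mask([50, 30], maxlen=50)
    emb_t._keras_mask = tf.sequence_mask([10, 7], maxlen=10)
    out, affinity = extractor(emb_s, emb_t)
    print(out.shape, affinity.shape)


if __name__ == "__main__":
    _demo_base_extractor()
    _demo_stack_extractor()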