Spaces:
Runtime error
Runtime error
File size: 1,777 Bytes
2045faa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 |
import tensorflow as tf
import numpy as np
from tensorflow.keras.models import Model
from tensorflow.keras import layers
# Fix the global RNG seeds of both NumPy and TensorFlow so that weight
# initialisation and any random sampling are repeatable across runs.
seed = 42
np.random.seed(seed)
tf.random.set_seed(seed)
class Discriminator(Model):
    """Abstract base model for discriminators.

    Subclasses are expected to override :meth:`call`. Extra keyword
    arguments are accepted by the constructor but ignored: only ``name`` is
    forwarded to :class:`tf.keras.Model`.
    """

    def __init__(self, name="Discriminator", **kwargs):
        super().__init__(name=name)

    def call(self, src, src_len=None):
        """Run the discriminator on a batch of sources.

        Args:
            src: source of shape `(batch, src_len)`.
            src_len: lengths of each source, of shape `(batch)`.

        Raises:
            NotImplementedError: always; subclasses supply the implementation.
        """
        raise NotImplementedError
class BaseDiscriminator(Discriminator):
    """Stacked-GRU discriminator that scores each input sequence.

    The input runs through a stack of GRU layers (every layer but the last
    returns full sequences; the last returns only its final output), then a
    single-unit ``Dense`` layer, then a sigmoid.
    """

    def __init__(self, name="BaseDiscriminator", **kwargs):
        """Build the GRU stack and the output head.

        Args:
            name: Keras model name.
            **kwargs: must contain ``gru``, a non-empty sequence with one
                entry per GRU layer; ``entry[0]`` is used as that layer's
                unit count. Other keyword arguments are ignored.

        Raises:
            KeyError: if ``gru`` is not supplied.
            ValueError: if ``gru`` is empty (there would be no recurrent
                layer to collapse the time axis before the Dense head).
        """
        super().__init__(name=name)
        gru_config = kwargs['gru']
        # len() rather than truthiness so array-like configs also work.
        if len(gru_config) == 0:
            raise ValueError(
                "BaseDiscriminator needs at least one entry in kwargs['gru']"
            )
        last = len(gru_config) - 1
        # Only the final GRU drops the time axis; earlier ones must emit full
        # sequences so the next GRU still receives a sequence input.
        self.gru = [
            layers.GRU(units[0], return_sequences=(i < last))
            for i, units in enumerate(gru_config)
        ]
        self.dense = layers.Dense(1)
        # Built-in Activation computes the same sigmoid as a Lambda wrapper
        # but can be saved/loaded without serializing Python bytecode.
        self.act = layers.Activation('sigmoid', name='sigmoid')

    def call(self, src, src_len=None):
        """Score a batch of sequences.

        Args:
            src: source of shape `(batch, time, feature)`.
            src_len: lengths of each source of shape `(batch)`. Currently
                unused; masking is only applied through a Keras mask
                attached to ``src``.

        Returns:
            Tuple ``(probs, logits)``: the sigmoid output and the raw
            ``Dense`` output, each of shape `(batch, 1)`.
        """
        # A Keras mask, if any, is stored on the tensor as `_keras_mask`.
        # getattr works for any input object, whereas vars(src) raises
        # TypeError for objects without __dict__ (e.g. a NumPy array), and a
        # None-valued attribute would make tf.cast fail.
        keras_mask = getattr(src, '_keras_mask', None)
        mask = None if keras_mask is None else tf.cast(keras_mask, tf.bool)
        x = src
        for layer in self.gru:
            # Pass `mask=` only when a mask exists so Keras' implicit mask
            # handling is otherwise left exactly as before.
            x = layer(x) if mask is None else layer(x, mask=mask)
        # [B, units] -> [B, 1]
        logits = self.dense(x)
        return self.act(logits), logits