File size: 1,777 Bytes
2045faa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import tensorflow as tf
import numpy as np
from tensorflow.keras.models import Model
from tensorflow.keras import layers

seed = 42
# Seed NumPy's and TensorFlow's global RNGs with the same value so random
# draws are repeatable between runs. NOTE: this runs as a side effect of
# importing this module.
np.random.seed(seed)
tf.random.set_seed(seed)

class Discriminator(Model):
    """Abstract base for the discriminator models defined in this module.

    Subclasses are expected to override :meth:`call`. Arbitrary keyword
    arguments are accepted by the constructor, but only ``name`` is handed
    to :class:`tf.keras.Model`; everything else in ``kwargs`` is dropped here.
    """

    def __init__(self, name="Discriminator", **kwargs):
        # Deliberately not forwarding **kwargs: Model.__init__ would reject
        # subclass-specific configuration keys.
        super().__init__(name=name)

    def call(self, src, src_len=None):
        """Score a batch of sources; subclasses must implement this.

        Args:
            src     : batch of source inputs, documented at this level as
                      shape `(batch, src_len)`; subclasses may expect
                      additional axes
            src_len : per-example source lengths, shape `(batch)`

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

class BaseDiscriminator(Discriminator):
    """Recurrent discriminator: a stack of GRU layers and a 1-unit Dense head.

    Keyword Args:
        gru : sequence of per-layer specs, one per GRU layer. Only the first
              element of each spec (``spec[0]``) is read, as that layer's
              number of units. Every layer except the last returns full
              sequences so the next GRU can consume them; the last returns
              only its final output. A missing ``gru`` key raises ``KeyError``.
              NOTE(review): with an empty ``gru`` no recurrent layer is built
              and the Dense head is applied straight to ``src``.
    """

    def __init__(self, name="BaseDiscriminator", **kwargs):
        super(BaseDiscriminator, self).__init__(name=name)
        gru_specs = kwargs['gru']
        last = len(gru_specs) - 1
        self.gru = [
            layers.GRU(spec[0], return_sequences=(i != last))
            for i, spec in enumerate(gru_specs)
        ]
        self.dense = layers.Dense(1)
        # Built-in Activation instead of a Lambda wrapper: same sigmoid, but
        # it is serialized by name rather than as Python bytecode.
        self.act = layers.Activation('sigmoid', name='sigmoid')

    def call(self, src, src_len=None):
        """Run the GRU stack and the Dense head.

        Args:
            src         : source of shape `(batch, time, feature)`
            src_len     : lengths of each source of shape `(batch)`; currently
                          unused. Padding is masked only through a Keras mask
                          attached to ``src`` (its ``_keras_mask`` attribute).

        Returns:
            Tuple ``(prob, logit)``: the sigmoid of the Dense output and the
            raw Dense output (shape `(batch, 1)` when at least one GRU layer
            is configured).
        """
        # The mask belongs to the input, so fetch and cast it once. getattr
        # (unlike `'_keras_mask' in vars(src)`) works for objects without a
        # __dict__ and treats a mask attribute set to None as "no mask"
        # instead of failing inside tf.cast.
        mask = getattr(src, '_keras_mask', None)
        if mask is not None:
            mask = tf.cast(mask, tf.bool)

        x = src
        for layer in self.gru:
            # Inner layers: [B, T, m] -> [B, T, units]; last layer -> [B, units].
            x = layer(x) if mask is None else layer(x, mask=mask)
        # [B, 1]
        logits = self.dense(x)
        return self.act(logits), logits