import random

import torch
import torch.nn as nn

from modules.commons.common_layers import *  # provides LinearNorm, among others


class MixStyle(nn.Module):
    """MixStyle.
    Reference:
      Zhou et al. Domain Generalization with MixStyle. ICLR 2021.
    """

    def __init__(self, p=0.5, alpha=0.1, eps=1e-6, hidden_size=256):
        """
        Args:
          p (float): probability of using MixStyle.
          alpha (float): parameter of the Beta distribution.
          eps (float): scaling parameter to avoid numerical issues.
          mix (str): how to mix.
        """
        super().__init__()
        self.p = p
        self.beta = torch.distributions.Beta(alpha, alpha)
        self.eps = eps
        self.alpha = alpha
        self._activated = True
        self.hidden_size = hidden_size
        self.affine_layer = LinearNorm(
            hidden_size,
            2 * hidden_size,  # predicts both b (bias) and g (gain)
        )

    def __repr__(self):
        return f'MixStyle(p={self.p}, alpha={self.alpha}, eps={self.eps})'

    def set_activation_status(self, status=True):
        self._activated = status

    def forward(self, x, spk_embed):
        # MixStyle is applied only during training, with probability p
        if not self.training or not self._activated:
            return x

        if random.random() > self.p:
            return x

        B = x.size(0)

        # Normalize the content features over the hidden dimension
        mu, sig = torch.mean(x, dim=-1, keepdim=True), torch.std(x, dim=-1, keepdim=True)
        x_normed = (x - mu) / (sig + self.eps)  # [B, T, H_m]

        # Sample per-sample mixing weights from Beta(alpha, alpha)
        lmda = self.beta.sample((B, 1, 1))
        lmda = lmda.to(x.device)

        # Predict bias and gain from the speaker embedding
        mu1, sig1 = torch.split(self.affine_layer(spk_embed), self.hidden_size, dim=-1)  # [B, 1, 2 * H_m] --> 2 * [B, 1, H_m]

        # MixStyle: pair each sample with a randomly permuted batch entry
        perm = torch.randperm(B)
        mu2, sig2 = mu1[perm], sig1[perm]

        mu_mix = mu1 * lmda + mu2 * (1 - lmda)
        sig_mix = sig1 * lmda + sig2 * (1 - lmda)

        # Perform scaling and shifting with the mixed statistics
        return sig_mix * x_normed + mu_mix  # [B, T, H_m]
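

# A minimal smoke test sketching how this module might be exercised during
# training (an illustrative assumption, not part of the original file):
# content features x of shape [B, T, H_m] are mixed with bias/gain statistics
# predicted from a speaker embedding of shape [B, 1, H_m]. Running it requires
# LinearNorm from modules.commons.common_layers.
if __name__ == '__main__':
    torch.manual_seed(0)
    mix = MixStyle(p=1.0, alpha=0.1, hidden_size=256)  # p=1.0 so the demo always mixes
    mix.train()  # MixStyle is a no-op in eval mode
    x = torch.randn(4, 100, 256)        # [B, T, H_m] content features
    spk_embed = torch.randn(4, 1, 256)  # [B, 1, H_m] speaker embedding
    y = mix(x, spk_embed)
    print(y.shape)  # torch.Size([4, 100, 256])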