File size: 702 Bytes
3cdcba2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 |
import torch
import torch.nn as nn
def log(t, eps=1e-20):
    """Return the element-wise natural log of *t*, floored at ``eps``.

    Every element of ``t`` smaller than ``eps`` is raised to ``eps`` before
    the logarithm is taken, so zero and negative entries map to
    ``log(eps)`` rather than ``-inf`` / ``nan``.

    Args:
        t: Input tensor.
        eps: Lower bound applied to ``t`` before taking the log.

    Returns:
        A tensor holding ``log(max(t, eps))`` element-wise.

    NOTE(review): the default ``eps`` of 1e-20 is presumably too small to be
    represented in float16; confirm the behaviour for half-precision inputs.
    """
    floored = torch.clamp(t, min=eps)
    return floored.log()
def gumbel_noise(t):
    """Draw Gumbel noise shaped like *t*.

    Samples ``u`` uniformly from ``[0, 1)`` with the same shape, dtype and
    device as ``t`` and returns ``-log(-log(u))``. Both logarithms go
    through the module's clamped ``log`` helper, so ``u == 0`` does not
    produce ``-inf`` inside the inner log.

    Args:
        t: Tensor used only as a template for the sample; its values are
            not read.

    Returns:
        A tensor of noise samples matching ``t``.
    """
    uniform = torch.rand_like(t)
    inner = -log(uniform)
    return -log(inner)
class NoisyGate(nn.Module):
    """Linear gate whose output is perturbed with scaled Gumbel noise.

    A single ``nn.Linear`` maps ``hidden_dim`` features to ``num_experts``
    scores. On every forward call, noise from ``gumbel_noise`` multiplied by
    ``noise_mult`` is added to those scores; ``forward`` does not check
    ``self.training``, so the noise is applied regardless of module mode.
    """

    def __init__(self, hidden_dim, num_experts, noise_mult=1.0, bias=False):
        """Build the gate.

        Args:
            hidden_dim: Size of the input feature dimension.
            num_experts: Number of output scores produced per input.
            noise_mult: Scale factor applied to the noise.
            bias: Whether the linear projection has a bias term.
        """
        super().__init__()
        # Keep the configuration on the instance, as the original did.
        self.hidden_dim, self.num_experts = hidden_dim, num_experts
        self.noise_mult, self.bias = noise_mult, bias
        self.gate = nn.Linear(hidden_dim, num_experts, bias=bias)

    def forward(self, x):
        """Return the gate scores for *x* with scaled noise added.

        Args:
            x: Input tensor accepted by the linear projection.

        Returns:
            ``gate(x) + gumbel_noise(gate(x)) * noise_mult``.
        """
        scores = self.gate(x)
        return scores + gumbel_noise(scores) * self.noise_mult
|