File size: 940 Bytes
d945eeb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 |
import torch
import torch.nn as nn
class Modulation(nn.Module):
    """Scale-and-shift modulation of ``x`` driven by a conditioning tensor.

    The condition is passed through ``linear1`` (a ``condition_dim ->
    condition_dim`` linear layer, or an identity when ``single_layer`` is
    set), a SiLU activation, and ``linear2``, which produces
    ``2 * embedding_dim`` features.  Those features are split in half along
    dim 1 into a scale and a shift, and the output is
    ``x * (1 + scale) + shift``.

    Args:
        embedding_dim: Width of each of the scale and shift halves.
        condition_dim: Feature width of the conditioning input.
        zero_init: If true, the weight and bias of ``linear2`` start at
            zero, so scale and shift are both zero and the module initially
            returns ``x`` unchanged.
        single_layer: If true, ``linear1`` is ``nn.Identity`` and only
            ``linear2`` carries learnable parameters.
    """

    def __init__(
        self,
        embedding_dim: int,
        condition_dim: int,
        zero_init: bool = False,
        single_layer: bool = False,
    ):
        super().__init__()
        self.silu = nn.SiLU()
        # Optional first projection; the identity keeps forward() branch-free.
        self.linear1 = (
            nn.Identity()
            if single_layer
            else nn.Linear(condition_dim, condition_dim)
        )
        # One projection emits both halves (scale first, then shift).
        self.linear2 = nn.Linear(condition_dim, embedding_dim * 2)

        # Only the final projection is zero-initialised.
        if zero_init:
            for param in (self.linear2.weight, self.linear2.bias):
                nn.init.zeros_(param)

    def forward(self, x: torch.Tensor, condition: torch.Tensor) -> torch.Tensor:
        """Return ``x`` modulated by the scale/shift derived from ``condition``.

        Assumes ``condition`` is ``(batch, condition_dim)`` and ``x`` is
        ``(batch, tokens, embedding_dim)`` -- TODO confirm against callers.
        The code itself only fixes that the projected condition is split
        along dim 1 and that a singleton axis is inserted at position 1
        before broadcasting against ``x``.
        """
        hidden = self.linear1(condition)
        modulation = self.linear2(self.silu(hidden))
        scale, shift = modulation.chunk(2, dim=1)
        # Insert a singleton axis at position 1 so both halves broadcast
        # against ``x``.
        return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|