|
import math |
|
|
|
import torch |
|
|
|
|
|
class PositionalEncoding(torch.nn.Module):
    """Add a fixed sinusoidal position table to a sequence of embeddings.

    Adapted from
    https://pytorch.org/tutorials/beginner/transformer_tutorial.html
    with one layout difference. The tutorial interleaves sine/cosine in
    even/odd columns. Here the first ``d_model // 2`` columns hold the sines
    and the remaining columns hold the cosines.

    Args:
        d_model: Width of the vectors the table is added to. Must be even:
            for odd values the sine assignment below raises a shape error.
        max_len: Number of positions precomputed in the table.
    """

    def __init__(self, d_model: int, max_len: int = 512):
        super().__init__()

        half = d_model // 2

        # Inverse frequencies 10000 ** (-2i / d_model) for i = 0, 1, ...
        inv_freq = torch.exp(
            torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
        )
        # Outer product: angles[p, i] = p * inv_freq[i].
        angles = torch.arange(max_len).unsqueeze(1) * inv_freq

        table = torch.zeros(max_len, d_model)
        table[:, :half] = angles.sin()
        table[:, half:] = angles.cos()

        # A buffer (not a Parameter): it moves with .to()/.cuda() and is
        # saved in state_dict, but is never trained.
        self.register_buffer("pe", table)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` plus the encodings of its first ``x.size(0)`` positions.

        Dim 0 of ``x`` is treated as the position axis. Presumably ``x`` is
        ``(seq_len, d_model)`` with ``seq_len <= max_len`` (TODO confirm).
        """
        seq_len = x.size(0)
        return x + self.pe[:seq_len]
|
|
|
|
|
class MultiheadSelfAttention(torch.nn.Module):
    """Multi-head scaled dot-product self-attention.

    Queries, keys and values come from three independent
    ``embed_dim -> embed_dim`` linear maps. Each head attends over its own
    contiguous slice of ``embed_dim // num_heads`` features, and the per-head
    results are concatenated. There is no output projection and no masking.

    Args:
        embed_dim: Feature width of the input and output. Must be divisible
            by ``num_heads`` for the head split in ``forward`` to succeed.
        num_heads: Number of attention heads.
    """

    def __init__(self, embed_dim: int, num_heads: int = 8):
        super().__init__()

        self.embed_dim = embed_dim
        self.num_heads = num_heads

        def projection() -> torch.nn.Linear:
            return torch.nn.Linear(in_features=embed_dim, out_features=embed_dim)

        # Created in query -> key -> value order so seeded initialisation and
        # state_dict layout stay the same.
        self.query = projection()
        self.key = projection()
        self.value = projection()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Attend every position of ``x`` to every other position.

        Assumes ``x`` is ``(seq_len, embed_dim)`` with no batch axis
        (TODO confirm). Dim 0 is used as the sequence axis. Returns a tensor
        of shape ``(seq_len, embed_dim)``.
        """
        seq_len = x.shape[0]
        scale = (self.embed_dim / self.num_heads) ** 0.5

        def split_heads(t: torch.Tensor) -> torch.Tensor:
            # (seq_len, embed_dim) -> (num_heads, seq_len, head_dim)
            return t.view(seq_len, self.num_heads, -1).transpose(0, 1)

        queries = split_heads(self.query(x))
        keys = split_heads(self.key(x))
        values = split_heads(self.value(x))

        # (num_heads, seq_len, seq_len) attention weights, normalised per query.
        weights = torch.softmax(queries @ keys.transpose(1, 2) / scale, dim=-1)
        context = weights @ values

        # (num_heads, seq_len, head_dim) -> (seq_len, num_heads * head_dim)
        return context.transpose(0, 1).reshape(seq_len, -1)
|
|
|
|
|
class Block(torch.nn.Module):
    """Pre-LayerNorm transformer layer.

    The input is normalised before each sub-layer: multi-head self-attention,
    then a ``d_model -> 4 * d_model -> d_model`` ReLU MLP. Each sub-layer's
    output is added back onto the un-normalised residual stream.

    Args:
        d_model: Feature width of the input and output.
        num_heads: Attention heads passed to ``MultiheadSelfAttention``.
        eps: Epsilon for both LayerNorms.
    """

    def __init__(self, d_model: int, num_heads: int = 8, eps: float = 1e-6):
        super().__init__()

        hidden = d_model * 4

        # Submodules are created in the original order so seeded
        # initialisation and state_dict keys are unchanged.
        self.ln1 = torch.nn.LayerNorm(normalized_shape=d_model, eps=eps)
        self.attn = MultiheadSelfAttention(embed_dim=d_model, num_heads=num_heads)
        self.ln2 = torch.nn.LayerNorm(normalized_shape=d_model, eps=eps)
        self.linear1 = torch.nn.Linear(in_features=d_model, out_features=hidden)
        self.linear2 = torch.nn.Linear(in_features=hidden, out_features=d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the attention and MLP sub-layers with residual connections."""
        attn_out = self.attn(self.ln1(x))

        ffn_in = self.ln2(x + attn_out)
        ffn_out = self.linear2(torch.relu(self.linear1(ffn_in)))

        # Same summation order as before: (ffn_out + x) + attn_out.
        return ffn_out + x + attn_out
|
|
|
|
|
class Head(torch.nn.Module):
    """Pool a sequence of vectors into one L2-normalised embedding.

    After a LayerNorm, two poolings over dim 0 (the position axis) are
    concatenated:

    * the plain mean, and
    * a learned, per-feature softmax-weighted mean. Scores come from
      ``linear2(elu(linear1(.)))``.

    The ``2 * d_model`` result is projected back to ``d_model``, squashed with
    ``tanh`` and scaled to unit Euclidean norm.

    Args:
        d_model: Feature width of the input and output.
        eps: Epsilon for the LayerNorm.
    """

    def __init__(
        self,
        d_model: int,
        eps: float = 1e-6,
    ):
        super().__init__()

        self.d_model = d_model
        self.eps = eps

        self.ln = torch.nn.LayerNorm(normalized_shape=d_model, eps=eps)
        self.linear1 = torch.nn.Linear(in_features=d_model, out_features=d_model)
        self.linear2 = torch.nn.Linear(in_features=d_model, out_features=d_model)
        self.tanh_layer = torch.nn.Linear(in_features=d_model * 2, out_features=d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return a unit-norm ``(d_model,)`` summary of ``x``.

        Presumably ``x`` is ``(seq_len, d_model)`` (TODO confirm). Pooling
        reduces dim 0.
        """
        ln = self.ln(x)

        mean_pool = ln.mean(dim=0)

        # softmax over positions == exp(s) / exp(s).sum(0) mathematically,
        # but subtracts the max first so large scores cannot overflow to
        # inf and turn the output into NaN.
        scores = self.linear2(torch.nn.functional.elu(self.linear1(ln)))
        weights = torch.softmax(scores, dim=0)
        attn_pool = (weights * ln).sum(dim=0)

        res = torch.tanh(self.tanh_layer(torch.cat([mean_pool, attn_pool])))

        # Out-of-place division: tanh saves its output for backward, so the
        # previous in-place `res /= ...` made backward() raise. A single
        # normalisation suffices; the former second pass divided by ~1.
        return res / (res**2).sum() ** 0.5
|
|
|
|
|
class MUSE(torch.nn.Module):
    """Encode a token-id sequence into one unit-norm ``d_model`` vector.

    Pipeline:

    1. embedding lookup;
    2. linear projection to ``d_model``;
    3. sinusoidal position encoding;
    4. six pre-LN transformer ``Block``s;
    5. a pooling ``Head``.

    Args:
        num_embeddings: Vocabulary size of the token embedding table.
        embedding_dim: Width of the token embeddings.
        d_model: Model width used by every block and the head.
        num_heads: Attention heads per block. Must divide ``d_model``.
            Previously this argument was ignored and every block used 8 heads.
        eps: LayerNorm epsilon for every block and the head. Previously this
            argument was ignored and 1e-6 was always used.
        max_len: Longest sequence the position table supports. Defaults to
            the previously hard-coded 512.

    Raises:
        ValueError: If ``num_heads`` is not a positive divisor of ``d_model``.
    """

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        d_model: int,
        num_heads: int,
        eps: float = 1e-6,
        max_len: int = 512,
    ):
        super().__init__()

        # Fail early instead of with an opaque .view() error in the first forward.
        if num_heads <= 0 or d_model % num_heads != 0:
            raise ValueError(
                f"num_heads ({num_heads}) must be a positive divisor of d_model ({d_model})"
            )

        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.d_model = d_model
        self.num_heads = num_heads
        self.eps = eps
        self.max_len = max_len

        # Submodules keep their original names and creation order so existing
        # checkpoints (state_dict keys block0..block5, head, ...) still load.
        self.embedding = torch.nn.Embedding(
            num_embeddings=num_embeddings,
            embedding_dim=embedding_dim,
        )
        self.linear = torch.nn.Linear(
            in_features=embedding_dim,
            out_features=d_model,
        )
        self.pe = PositionalEncoding(
            d_model=d_model,
            max_len=max_len,
        )
        self.block0 = Block(d_model=d_model, num_heads=num_heads, eps=eps)
        self.block1 = Block(d_model=d_model, num_heads=num_heads, eps=eps)
        self.block2 = Block(d_model=d_model, num_heads=num_heads, eps=eps)
        self.block3 = Block(d_model=d_model, num_heads=num_heads, eps=eps)
        self.block4 = Block(d_model=d_model, num_heads=num_heads, eps=eps)
        self.block5 = Block(d_model=d_model, num_heads=num_heads, eps=eps)
        self.head = Head(d_model=d_model, eps=eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed token ids ``x`` and return the pooled sentence vector.

        Presumably ``x`` is a 1-D integer tensor of shape ``(seq_len,)`` with
        ``seq_len <= max_len`` (TODO confirm against callers).
        """
        h = self.pe(self.linear(self.embedding(x)))
        for block in (
            self.block0,
            self.block1,
            self.block2,
            self.block3,
            self.block4,
            self.block5,
        ):
            h = block(h)
        return self.head(h)
|
|