dayyass's picture
Upload 3 files
ef7bae2 verified
raw
history blame
5.03 kB
import math
import torch
class PositionalEncoding(torch.nn.Module):
"""
https://pytorch.org/tutorials/beginner/transformer_tutorial.html
"""
def __init__(self, d_model: int, max_len: int = 512):
super().__init__()
position = torch.arange(max_len).unsqueeze(1)
div_term = torch.exp(
torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
)
pe = torch.zeros(max_len, d_model)
pe[:, : d_model // 2] = torch.sin(position * div_term)
pe[:, d_model // 2 :] = torch.cos(position * div_term)
self.register_buffer("pe", pe)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = x + self.pe[: x.size(0)]
return x
class MultiheadSelfAttention(torch.nn.Module):
def __init__(self, embed_dim: int, num_heads: int = 8):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.query = torch.nn.Linear(
in_features=embed_dim,
out_features=embed_dim,
)
self.key = torch.nn.Linear(
in_features=embed_dim,
out_features=embed_dim,
)
self.value = torch.nn.Linear(
in_features=embed_dim,
out_features=embed_dim,
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
q = self.query(x).view(x.shape[0], self.num_heads, -1).transpose(0, 1)
k = self.key(x).view(x.shape[0], self.num_heads, -1).permute(1, 2, 0)
v = self.value(x).view(x.shape[0], self.num_heads, -1).transpose(0, 1)
qk = torch.softmax(
torch.matmul(q, k) / (self.embed_dim / self.num_heads) ** 0.5,
dim=-1,
)
qkv = torch.matmul(qk, v).transpose(0, 1).reshape(x.shape[0], -1)
return qkv
class Block(torch.nn.Module):
def __init__(self, d_model: int, num_heads: int = 8, eps: float = 1e-6):
super().__init__()
self.ln1 = torch.nn.LayerNorm(normalized_shape=d_model, eps=eps)
self.attn = MultiheadSelfAttention(embed_dim=d_model, num_heads=num_heads)
self.ln2 = torch.nn.LayerNorm(normalized_shape=d_model, eps=eps)
self.linear1 = torch.nn.Linear(in_features=d_model, out_features=d_model * 4)
self.linear2 = torch.nn.Linear(in_features=d_model * 4, out_features=d_model)
def forward(self, x: torch.Tensor) -> torch.Tensor:
ln1 = self.ln1(x)
attn = self.attn(ln1)
ln2 = self.ln2(x + attn)
mlp = self.linear2(torch.relu(self.linear1(ln2)))
return mlp + x + attn
class Head(torch.nn.Module):
def __init__(
self,
d_model: int,
eps: float = 1e-6,
):
super().__init__()
self.d_model = d_model
self.eps = eps
self.ln = torch.nn.LayerNorm(normalized_shape=d_model, eps=eps)
self.linear1 = torch.nn.Linear(in_features=d_model, out_features=d_model)
self.linear2 = torch.nn.Linear(in_features=d_model, out_features=d_model)
self.tanh_layer = torch.nn.Linear(in_features=d_model * 2, out_features=d_model)
def forward(self, x: torch.Tensor) -> torch.Tensor:
ln = self.ln(x)
mlp = torch.exp(self.linear2(torch.nn.functional.elu(self.linear1(ln))))
res = torch.cat(
[
ln.sum(dim=0) / ln.shape[0],
(mlp * ln).sum(dim=0) / mlp.sum(dim=0),
]
)
res = torch.tanh(self.tanh_layer(res))
res /= (res**2).sum() ** 0.5
res /= (res**2).sum() ** 0.5
return res
class MUSE(torch.nn.Module):
def __init__(
self,
num_embeddings: int,
embedding_dim: int,
d_model: int,
num_heads: int,
eps: float = 1e-6,
):
super().__init__()
self.num_embeddings = num_embeddings
self.embedding_dim = embedding_dim
self.d_model = d_model
self.num_heads = num_heads
self.eps = eps
self.embedding = torch.nn.Embedding(
num_embeddings=num_embeddings,
embedding_dim=embedding_dim,
)
self.linear = torch.nn.Linear(
in_features=embedding_dim,
out_features=d_model,
)
self.pe = PositionalEncoding(
d_model=d_model,
max_len=512, # TODO: remove hardcode
)
self.block0 = Block(d_model=d_model)
self.block1 = Block(d_model=d_model)
self.block2 = Block(d_model=d_model)
self.block3 = Block(d_model=d_model)
self.block4 = Block(d_model=d_model)
self.block5 = Block(d_model=d_model)
self.head = Head(d_model=d_model)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.embedding(x)
x = self.linear(x)
x = self.pe(x)
x = self.block0(x)
x = self.block1(x)
x = self.block2(x)
x = self.block3(x)
x = self.block4(x)
x = self.block5(x)
x = self.head(x)
return x