import torch
import torch.nn as nn
import torch.nn.functional as F

from attention import SelfAttention
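
# CLIP text encoder: learned token and positional embeddings, a stack of
# pre-norm transformer layers with a QuickGELU feed-forward activation, and
# a final LayerNorm over the hidden states.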

class CLIPEmbedding(nn.Module):

    def __init__(self, n_vocab: int, n_embd: int, n_token: int):
        super().__init__()

        # Learned token embedding table: one n_embd vector per vocabulary entry.
        self.token_embedding = nn.Embedding(n_vocab, n_embd)
        # Learned positional embedding: one n_embd vector per token position.
        self.position_embedding = nn.Parameter(torch.zeros((n_token, n_embd)))

    def forward(self, tokens):
        # (Batch_Size, Seq_Len) -> (Batch_Size, Seq_Len, Dim)
        x = self.token_embedding(tokens)
        # Add the positional embedding to every sequence in the batch.
        x += self.position_embedding

        return x

class CLIPLayer(nn.Module):

    def __init__(self, n_head: int, n_embd: int):
        super().__init__()

        # Pre-attention normalization
        self.layernorm_1 = nn.LayerNorm(n_embd)
        # Multi-head self-attention
        self.attention = SelfAttention(n_head, n_embd)
        # Pre-feedforward normalization
        self.layernorm_2 = nn.LayerNorm(n_embd)
        # Feed-forward network with a 4x hidden expansion
        self.linear_1 = nn.Linear(n_embd, 4 * n_embd)
        self.linear_2 = nn.Linear(4 * n_embd, n_embd)

    def forward(self, x):
        # x: (Batch_Size, Seq_Len, Dim)

        ### SELF-ATTENTION (with residual connection) ###
        residue = x
        x = self.layernorm_1(x)
        x = self.attention(x, causal_mask=True)
        x += residue

        ### FEED-FORWARD LAYER (with residual connection) ###
        residue = x
        x = self.layernorm_2(x)
        x = self.linear_1(x)
        # QuickGELU activation: x * sigmoid(1.702 * x)
        x = x * torch.sigmoid(1.702 * x)
        x = self.linear_2(x)
        x += residue

        # (Batch_Size, Seq_Len, Dim)
        return x

class CLIP(nn.Module):

    def __init__(self):
        super().__init__()
        # Vocabulary of 49408 tokens, embedding dimension 768, context length 77.
        self.embedding = CLIPEmbedding(49408, 768, 77)

        # 12 transformer layers, each with 12 attention heads.
        self.layers = nn.ModuleList([
            CLIPLayer(12, 768) for i in range(12)
        ])

        self.layernorm = nn.LayerNorm(768)

    def forward(self, tokens: torch.LongTensor) -> torch.FloatTensor:
        tokens = tokens.type(torch.long)

        # (Batch_Size, Seq_Len) -> (Batch_Size, Seq_Len, Dim)
        state = self.embedding(tokens)

        # Apply the encoder layers sequentially.
        for layer in self.layers:
            state = layer(state)

        # (Batch_Size, Seq_Len, Dim)
        output = self.layernorm(state)

        return output
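

# Minimal smoke-test sketch (an assumption, not part of the original module):
# it builds the encoder and runs a batch of dummy token ids through it,
# assuming `attention.SelfAttention` is importable and that inputs are padded
# to the 77-token context length.
if __name__ == "__main__":
    clip = CLIP()
    dummy_tokens = torch.zeros((1, 77), dtype=torch.long)  # (Batch_Size, Seq_Len)
    with torch.no_grad():
        output = clip(dummy_tokens)
    print(output.shape)  # expected: torch.Size([1, 77, 768])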