Esmail-AGumaan committed
Commit 8182f5b
1 Parent(s): ca5cb45

Update clip.py

Files changed (1):
  1. clip.py +63 -63
clip.py CHANGED
@@ -1,64 +1,64 @@
-import torch
-from torch import nn
-from torch.nn import functional as F
-from nanograd.models.stable_diffusion.attention import SelfAttention
-
-class CLIPEmbedding(nn.Module):
-    def __init__(self, n_vocab: int, n_embd: int, n_token: int):
-        super().__init__()
-
-        self.token_embedding = nn.Embedding(n_vocab, n_embd)
-        self.position_embedding = nn.Parameter(torch.zeros((n_token, n_embd)))
-
-    def forward(self, tokens):
-        x = self.token_embedding(tokens)
-        x += self.position_embedding
-
-        return x
-
-class CLIPLayer(nn.Module):
-    def __init__(self, n_head: int, n_embd: int):
-        super().__init__()
-        self.layernorm_1 = nn.LayerNorm(n_embd)
-        self.attention = SelfAttention(n_head, n_embd)
-        self.layernorm_2 = nn.LayerNorm(n_embd)
-        self.linear_1 = nn.Linear(n_embd, 4 * n_embd)
-        self.linear_2 = nn.Linear(4 * n_embd, n_embd)
-
-    def forward(self, x):
-        residue = x
-        x = self.layernorm_1(x)
-        x = self.attention(x, causal_mask=True)
-        x += residue
-
-        residue = x
-        x = self.layernorm_2(x)
-        x = self.linear_1(x)
-
-        x = x * torch.sigmoid(1.702 * x)
-        x = self.linear_2(x)
-        x += residue
-
-        return x
-
-class CLIP(nn.Module):
-    def __init__(self):
-        super().__init__()
-        self.embedding = CLIPEmbedding(49408, 768, 77)
-
-        self.layers = nn.ModuleList([
-            CLIPLayer(12, 768) for i in range(12)
-        ])
-
-        self.layernorm = nn.LayerNorm(768)
-
-    def forward(self, tokens: torch.LongTensor) -> torch.FloatTensor:
-        tokens = tokens.type(torch.long)
-
-        state = self.embedding(tokens)
-
-        for layer in self.layers:
-            state = layer(state)
-        output = self.layernorm(state)
-
+import torch
+from torch import nn
+from torch.nn import functional as F
+from attention import SelfAttention
+
+class CLIPEmbedding(nn.Module):
+    def __init__(self, n_vocab: int, n_embd: int, n_token: int):
+        super().__init__()
+
+        self.token_embedding = nn.Embedding(n_vocab, n_embd)
+        self.position_embedding = nn.Parameter(torch.zeros((n_token, n_embd)))
+
+    def forward(self, tokens):
+        x = self.token_embedding(tokens)
+        x += self.position_embedding
+
+        return x
+
+class CLIPLayer(nn.Module):
+    def __init__(self, n_head: int, n_embd: int):
+        super().__init__()
+        self.layernorm_1 = nn.LayerNorm(n_embd)
+        self.attention = SelfAttention(n_head, n_embd)
+        self.layernorm_2 = nn.LayerNorm(n_embd)
+        self.linear_1 = nn.Linear(n_embd, 4 * n_embd)
+        self.linear_2 = nn.Linear(4 * n_embd, n_embd)
+
+    def forward(self, x):
+        residue = x
+        x = self.layernorm_1(x)
+        x = self.attention(x, causal_mask=True)
+        x += residue
+
+        residue = x
+        x = self.layernorm_2(x)
+        x = self.linear_1(x)
+
+        x = x * torch.sigmoid(1.702 * x)
+        x = self.linear_2(x)
+        x += residue
+
+        return x
+
+class CLIP(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.embedding = CLIPEmbedding(49408, 768, 77)
+
+        self.layers = nn.ModuleList([
+            CLIPLayer(12, 768) for i in range(12)
+        ])
+
+        self.layernorm = nn.LayerNorm(768)
+
+    def forward(self, tokens: torch.LongTensor) -> torch.FloatTensor:
+        tokens = tokens.type(torch.long)
+
+        state = self.embedding(tokens)
+
+        for layer in self.layers:
+            state = layer(state)
+        output = self.layernorm(state)
+
         return output
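
For reference, a minimal smoke test of the module as it stands after this commit. This is a sketch, not part of the commit: it assumes clip.py is importable as `clip` from the working directory, that the sibling attention.py referenced by the new import provides `SelfAttention(n_head, n_embd)` accepting a `causal_mask` keyword, and the random token ids are dummy input for shape-checking only.

    # Smoke test (assumes clip.py and attention.py are on the import path).
    import torch
    from clip import CLIP

    model = CLIP().eval()

    # Dummy prompt: batch of 1, padded to the full 77-token context length.
    # Ids are drawn from the 49408-entry CLIP vocabulary.
    tokens = torch.randint(0, 49408, (1, 77), dtype=torch.long)

    with torch.no_grad():
        out = model(tokens)

    print(out.shape)  # expected: torch.Size([1, 77, 768])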