SunderAli17 commited on
Commit
8b964ac
·
verified ·
1 Parent(s): e3ac27a

Create encoders.py

Browse files
Files changed (1) hide show
  1. toonmage/encoders.py +64 -0
toonmage/encoders.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
+ class IDEncoder(nn.Module):
6
+ def __init__(self, width=1280, context_dim=2048, num_token=5):
7
+ super().__init__()
8
+ self.num_token = num_token
9
+ self.context_dim = context_dim
10
+ h1 = min((context_dim * num_token) // 4, 1024)
11
+ h2 = min((context_dim * num_token) // 2, 1024)
12
+ self.body = nn.Sequential(
13
+ nn.Linear(width, h1),
14
+ nn.LayerNorm(h1),
15
+ nn.LeakyReLU(),
16
+ nn.Linear(h1, h2),
17
+ nn.LayerNorm(h2),
18
+ nn.LeakyReLU(),
19
+ nn.Linear(h2, context_dim * num_token),
20
+ )
21
+
22
+ for i in range(5):
23
+ setattr(
24
+ self,
25
+ f'mapping_{i}',
26
+ nn.Sequential(
27
+ nn.Linear(1024, 1024),
28
+ nn.LayerNorm(1024),
29
+ nn.LeakyReLU(),
30
+ nn.Linear(1024, 1024),
31
+ nn.LayerNorm(1024),
32
+ nn.LeakyReLU(),
33
+ nn.Linear(1024, context_dim),
34
+ ),
35
+ )
36
+
37
+ setattr(
38
+ self,
39
+ f'mapping_patch_{i}',
40
+ nn.Sequential(
41
+ nn.Linear(1024, 1024),
42
+ nn.LayerNorm(1024),
43
+ nn.LeakyReLU(),
44
+ nn.Linear(1024, 1024),
45
+ nn.LayerNorm(1024),
46
+ nn.LeakyReLU(),
47
+ nn.Linear(1024, context_dim),
48
+ ),
49
+ )
50
+
51
+ def forward(self, x, y):
52
+ # x shape [N, C]
53
+ x = self.body(x)
54
+ x = x.reshape(-1, self.num_token, self.context_dim)
55
+
56
+ hidden_states = ()
57
+ for i, emb in enumerate(y):
58
+ hidden_state = getattr(self, f'mapping_{i}')(emb[:, :1]) + getattr(self, f'mapping_patch_{i}')(
59
+ emb[:, 1:]
60
+ ).mean(dim=1, keepdim=True)
61
+ hidden_states += (hidden_state,)
62
+ hidden_states = torch.cat(hidden_states, dim=1)
63
+
64
+ return torch.cat([x, hidden_states], dim=1)