JMalott commited on
Commit
1a2253c
1 Parent(s): 893e878

Upload dalle_bart_decoder.py

Browse files
min_dalle/models/dalle_bart_decoder.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Tuple, List
2
+ import torch
3
+ from torch import nn, LongTensor, FloatTensor, BoolTensor
4
+ from .dalle_bart_encoder import GLU, AttentionBase
5
+
6
+ IMAGE_TOKEN_COUNT = 256
7
+
8
+
9
+ class DecoderCrossAttention(AttentionBase):
10
+ def forward(
11
+ self,
12
+ decoder_state: FloatTensor,
13
+ encoder_state: FloatTensor,
14
+ attention_mask: BoolTensor
15
+ ) -> FloatTensor:
16
+ keys = self.k_proj.forward(encoder_state)
17
+ values = self.v_proj.forward(encoder_state)
18
+ queries = self.q_proj.forward(decoder_state)
19
+ return super().forward(keys, values, queries, attention_mask)
20
+
21
+
22
+ class DecoderSelfAttention(AttentionBase):
23
+ def __init__(self, head_count: int, embed_count: int):
24
+ super().__init__(head_count, embed_count)
25
+
26
+
27
+ def forward(
28
+ self,
29
+ decoder_state: FloatTensor,
30
+ attention_state: FloatTensor,
31
+ attn_mask: BoolTensor,
32
+ token_index: LongTensor
33
+ ) -> Tuple[FloatTensor, FloatTensor]:
34
+ keys = self.k_proj.forward(decoder_state)
35
+ values = self.v_proj.forward(decoder_state)
36
+ queries = self.q_proj.forward(decoder_state)
37
+ attn_state_new = torch.cat([keys, values]).to(attention_state.dtype)
38
+ attention_state[:, token_index] = attn_state_new
39
+ batch_count = decoder_state.shape[0]
40
+ keys = attention_state[:batch_count]
41
+ values = attention_state[batch_count:]
42
+ decoder_state = super().forward(keys, values, queries, attn_mask)
43
+ return decoder_state, attention_state
44
+
45
+
46
+ class DecoderLayer(nn.Module):
47
+ def __init__(
48
+ self,
49
+ head_count: int,
50
+ embed_count: int,
51
+ glu_embed_count: int,
52
+ device: str
53
+ ):
54
+ super().__init__()
55
+ self.pre_self_attn_layer_norm = nn.LayerNorm(embed_count)
56
+ self.self_attn = DecoderSelfAttention(head_count, embed_count)
57
+ self.self_attn_layer_norm = nn.LayerNorm(embed_count)
58
+ self.pre_encoder_attn_layer_norm = nn.LayerNorm(embed_count)
59
+ self.encoder_attn = DecoderCrossAttention(head_count, embed_count)
60
+ self.encoder_attn_layer_norm = nn.LayerNorm(embed_count)
61
+ self.glu = GLU(embed_count, glu_embed_count)
62
+ self.token_indices = torch.arange(IMAGE_TOKEN_COUNT, device=device)
63
+
64
+
65
+ def forward(
66
+ self,
67
+ decoder_state: FloatTensor,
68
+ encoder_state: FloatTensor,
69
+ attention_state: FloatTensor,
70
+ attention_mask: BoolTensor,
71
+ token_index: LongTensor
72
+ ) -> Tuple[FloatTensor, FloatTensor]:
73
+ # Self Attention
74
+ self_attn_mask = self.token_indices < token_index + 1
75
+ self_attn_mask = self_attn_mask[None][[0] * decoder_state.shape[0]]
76
+ residual = decoder_state
77
+ decoder_state = self.pre_self_attn_layer_norm.forward(decoder_state)
78
+ decoder_state, attention_state = self.self_attn.forward(
79
+ decoder_state=decoder_state,
80
+ attention_state=attention_state,
81
+ attn_mask=self_attn_mask,
82
+ token_index=token_index
83
+ )
84
+ decoder_state = self.self_attn_layer_norm.forward(decoder_state)
85
+ decoder_state = residual + decoder_state
86
+
87
+ # Cross Attention
88
+ residual = decoder_state
89
+ decoder_state = self.pre_encoder_attn_layer_norm.forward(decoder_state)
90
+ decoder_state = self.encoder_attn.forward(
91
+ decoder_state=decoder_state,
92
+ encoder_state=encoder_state,
93
+ attention_mask=attention_mask
94
+ )
95
+ decoder_state = self.encoder_attn_layer_norm.forward(decoder_state)
96
+ decoder_state = residual + decoder_state
97
+
98
+ # Feed forward
99
+ residual = decoder_state
100
+ decoder_state = self.glu.forward(decoder_state)
101
+ decoder_state = residual + decoder_state
102
+
103
+ return decoder_state, attention_state
104
+
105
+
106
+ class DalleBartDecoder(nn.Module):
107
+ def __init__(
108
+ self,
109
+ image_vocab_count: int,
110
+ embed_count: int,
111
+ attention_head_count: int,
112
+ glu_embed_count: int,
113
+ layer_count: int,
114
+ device: str
115
+ ):
116
+ super().__init__()
117
+ self.layer_count = layer_count
118
+ self.embed_count = embed_count
119
+ self.image_vocab_count = image_vocab_count
120
+ self.embed_tokens = nn.Embedding(image_vocab_count + 1, embed_count)
121
+ self.embed_positions = nn.Embedding(IMAGE_TOKEN_COUNT, embed_count)
122
+ self.layers: List[DecoderLayer] = nn.ModuleList([
123
+ DecoderLayer(
124
+ head_count=attention_head_count,
125
+ embed_count=embed_count,
126
+ glu_embed_count=glu_embed_count,
127
+ device=device
128
+ )
129
+ for _ in range(layer_count)
130
+ ])
131
+ self.layernorm_embedding = nn.LayerNorm(embed_count)
132
+ self.final_ln = nn.LayerNorm(embed_count)
133
+ self.lm_head = nn.Linear(embed_count, image_vocab_count + 1, bias=False)
134
+ self.token_indices = torch.arange(IMAGE_TOKEN_COUNT, device=device)
135
+
136
+
137
+ def forward(
138
+ self,
139
+ settings: FloatTensor,
140
+ attention_mask: BoolTensor,
141
+ encoder_state: FloatTensor,
142
+ attention_state: FloatTensor,
143
+ prev_tokens: LongTensor,
144
+ token_index: LongTensor
145
+ ) -> Tuple[LongTensor, FloatTensor]:
146
+ image_count = encoder_state.shape[0] // 2
147
+ token_index_batched = token_index[[0] * image_count * 2]
148
+ prev_tokens = prev_tokens[list(range(image_count)) * 2]
149
+ prev_tokens.clamp_(0, self.image_vocab_count)
150
+ decoder_state = self.embed_tokens.forward(prev_tokens)
151
+ decoder_state += self.embed_positions.forward(token_index_batched)
152
+ decoder_state = self.layernorm_embedding.forward(decoder_state)
153
+ decoder_state = decoder_state[:, None]
154
+ for i in range(self.layer_count):
155
+ decoder_state, attention_state[i] = self.layers[i].forward(
156
+ decoder_state,
157
+ encoder_state,
158
+ attention_state[i],
159
+ attention_mask,
160
+ token_index
161
+ )
162
+ decoder_state = self.final_ln(decoder_state)
163
+ logits = self.lm_head(decoder_state)
164
+ temperature = settings[[0]]
165
+ top_k = settings[[1]].to(torch.long)
166
+ supercondition_factor = settings[[2]]
167
+ logits = logits[:, -1, : 2 ** 14]
168
+ logits: FloatTensor = (
169
+ logits[:image_count] * (1 - supercondition_factor) +
170
+ logits[image_count:] * supercondition_factor
171
+ )
172
+ logits_sorted, _ = logits.sort(descending=True)
173
+ is_kept = logits >= logits_sorted[:, top_k - 1]
174
+ logits -= logits_sorted[:, [0]]
175
+ logits /= temperature
176
+ logits.exp_()
177
+ logits *= is_kept.to(torch.float32)
178
+ image_tokens = torch.multinomial(logits, 1)[:, 0]
179
+ return image_tokens, attention_state