Chengxu Zhuang committed on
Commit
6ec18a1
1 Parent(s): b71988c

model upload

config.json ADDED
@@ -0,0 +1,43 @@
+{
+  "_name_or_path": "facebook/opt-125m",
+  "_remove_final_layer_norm": false,
+  "activation_dropout": 0.0,
+  "activation_function": "relu",
+  "architectures": [
+    "FlamingoForCausalLM"
+  ],
+  "auto_map": {
+    "AutoConfig": "configuration_flamingo.FlamingoConfig",
+    "AutoModelForCausalLM": "modeling_flamingo.FlamingoForCausalLM"
+  },
+  "attention_dropout": 0.0,
+  "bos_token_id": 2,
+  "cross_attn_every": 2,
+  "do_layer_norm_before": true,
+  "dropout": 0.1,
+  "enable_bias": true,
+  "eos_token_id": 2,
+  "ffn_dim": 3072,
+  "finetune_LM": true,
+  "hidden_size": 768,
+  "id_perceiver": false,
+  "init_std": 0.02,
+  "inp_dim": 768,
+  "layer_norm_elementwise_affine": true,
+  "layerdrop": 0.0,
+  "max_position_embeddings": 2048,
+  "media_token_id": 32768,
+  "model_type": "opt",
+  "num_attention_heads": 12,
+  "num_hidden_layers": 12,
+  "only_attend_immediate_media": true,
+  "pad_token_id": 1,
+  "perceiver_depth": 2,
+  "perceiver_num_latents": 64,
+  "prefix": "</s>",
+  "torch_dtype": "float32",
+  "transformers_version": "4.29.0",
+  "use_cache": true,
+  "vocab_size": 32778,
+  "word_embed_proj_dim": 768
+}
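The auto_map entries above point AutoConfig and AutoModelForCausalLM at the configuration_flamingo.py and modeling_flamingo.py files added below, so the checkpoint is intended to be loaded with trust_remote_code=True. A minimal loading sketch; the repository id is a placeholder, not something stated in this commit:

from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "<this-repo-id>"  # placeholder for the Hub repository this commit belongs to
tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)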
configuration_flamingo.py ADDED
@@ -0,0 +1,35 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+from typing import Union
+
+import transformers.models.opt.configuration_opt as configuration_opt
+
+
+class FlamingoConfig(configuration_opt.OPTConfig, dict):
+    model_type = "flamingo"
+    def __init__(
+        self,
+        cross_attn_every=2,
+        vocab_size=32778,
+        media_token_id=32768,
+        **kwargs,
+    ):
+        configuration_opt.OPTConfig.__init__(
+            self, vocab_size=vocab_size, **kwargs)
+        self.media_token_id = media_token_id
+        self.cross_attn_every = cross_attn_every
+        dict.__init__(self, **self.__dict__)
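FlamingoConfig extends OPTConfig with the Flamingo-specific fields (cross_attn_every, media_token_id) and additionally initializes itself as a dict so the extra keys survive serialization. A small construction sketch, assuming configuration_flamingo.py is importable from the working directory:

from configuration_flamingo import FlamingoConfig

# values mirror config.json above
config = FlamingoConfig(cross_attn_every=2, vocab_size=32778, media_token_id=32768)
print(config.model_type)        # "flamingo"
print(config.cross_attn_every)  # 2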
flamingo_pytorch.py ADDED
@@ -0,0 +1,220 @@
+import torch
+from torch import nn, einsum
+import torch.nn.functional as F
+
+from einops import rearrange, repeat
+from einops_exts import rearrange_many, repeat_many
+import pdb
+
+
+def exists(val):
+    return val is not None
+
+def FeedForward(dim, mult = 4):
+    inner_dim = int(dim * mult)
+    return nn.Sequential(
+        nn.LayerNorm(dim),
+        nn.Linear(dim, inner_dim, bias = False),
+        nn.GELU(),
+        nn.Linear(inner_dim, dim, bias = False)
+    )
+
+class PerceiverAttention(nn.Module):
+    def __init__(
+        self,
+        *,
+        dim,
+        dim_head = 64,
+        heads = 8
+    ):
+        super().__init__()
+        self.scale = dim_head ** -0.5
+        self.heads = heads
+        inner_dim = dim_head * heads
+
+        self.norm_media = nn.LayerNorm(dim)
+        self.norm_latents = nn.LayerNorm(dim)
+
+        self.to_q = nn.Linear(dim, inner_dim, bias = False)
+        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
+        self.to_out = nn.Linear(inner_dim, dim, bias = False)
+
+    def forward(self, x, latents):
+        """
+        einstein notation
+        b - batch
+        t - time
+        n - sequence
+        d - dimension
+        """
+        x = self.norm_media(x)
+        latents = self.norm_latents(latents)
+
+        b, m, h = *x.shape[:2], self.heads
+
+        q = self.to_q(latents)
+
+        # the paper differs from Perceiver in which they also concat the key / values derived from the latents to be attended to
+        kv_input = torch.cat((x, latents), dim = -2)
+        k, v = self.to_kv(kv_input).chunk(2, dim = -1)
+
+        q, k, v = rearrange_many((q, k, v), 'b t n (h d) -> b h t n d', h = h)
+
+        q = q * self.scale
+
+        # attention
+
+        sim = einsum('... i d, ... j d -> ... i j', q, k)
+
+        sim = sim - sim.amax(dim = -1, keepdim = True).detach()
+        attn = sim.softmax(dim = -1)
+
+        out = einsum('... i j, ... j d -> ... i d', attn, v)
+        out = rearrange(out, 'b h t n d -> b t n (h d)', h = h)
+        return self.to_out(out)
+
+class PerceiverResampler(nn.Module):
+    def __init__(
+        self,
+        *,
+        dim,
+        depth,
+        dim_head = 64,
+        heads = 8,
+        num_latents = 64,
+        num_time_embeds = 4,
+        ff_mult = 4,
+        inp_dim=None,
+    ):
+        super().__init__()
+        self.latents = nn.Parameter(torch.randn(num_latents, dim))
+        self.time_pos_emb = nn.Parameter(torch.randn(num_time_embeds, 1, dim))
+        if inp_dim is not None:
+            self.inp_linear = nn.Linear(inp_dim, dim, bias=False)
+        else:
+            self.inp_linear = None
+
+        self.layers = nn.ModuleList([])
+        for _ in range(depth):
+            self.layers.append(nn.ModuleList([
+                PerceiverAttention(dim = dim, dim_head = dim_head, heads = heads),
+                FeedForward(dim = dim, mult = ff_mult)
+            ]))
+
+        self.norm = nn.LayerNorm(dim)
+
+    def forward(self, x):
+        if x.ndim == 3:
+            x = rearrange(x, 'b n d -> b 1 n d')
+
+        if self.inp_linear is not None:
+            x = self.inp_linear(x)
+
+        times = x.shape[1]
+        x = x + self.time_pos_emb[:times]
+
+        latents = repeat(self.latents, 'n d -> b m n d', b = x.shape[0], m = x.shape[1])
+
+        for attn, ff in self.layers:
+            latents = attn(x, latents) + latents
+            latents = ff(latents) + latents
+
+        return self.norm(latents)
+
+# gated cross attention
+
+class MaskedCrossAttention(nn.Module):
+    def __init__(
+        self,
+        *,
+        dim,
+        dim_head = 64,
+        heads = 8,
+        only_attend_immediate_media = True
+    ):
+        super().__init__()
+        self.scale = dim_head ** -0.5
+        self.heads = heads
+        inner_dim = dim_head * heads
+
+        self.norm = nn.LayerNorm(dim)
+
+        self.to_q = nn.Linear(dim, inner_dim, bias = False)
+        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
+        self.to_out = nn.Linear(inner_dim, dim, bias = False)
+
+        # whether for text to only attend to immediate preceding image, or all images
+
+        self.only_attend_immediate_media = only_attend_immediate_media
+
+    def forward(
+        self,
+        x,
+        media,
+        media_locations = None
+    ):
+        b, t, m = media.shape[:3]
+        h = self.heads
+
+        x = self.norm(x)
+
+        q = self.to_q(x)
+        media = rearrange(media, 'b t n d -> b (t n) d')
+
+        k, v = self.to_kv(media).chunk(2, dim = -1)
+        q, k, v = rearrange_many((q, k, v), 'b n (h d) -> b h n d', h = h)
+
+        q = q * self.scale
+
+        sim = einsum('... i d, ... j d -> ... i j', q, k)
+
+        if exists(media_locations):
+            text_time = media_locations.cumsum(dim = -1) # at each boolean of True, increment the time counter (relative to media time)
+            media_time = torch.arange(t, device = x.device) + 1
+
+            # text time must equal media time if only attending to most immediate image
+            # otherwise, as long as text time is greater than media time (if attending to all previous images / media)
+            mask_op = torch.eq if self.only_attend_immediate_media else torch.ge
+
+            text_to_media_mask = mask_op(rearrange(text_time, 'b i -> b 1 i 1'), repeat(media_time, 'j -> 1 1 1 (j m)', m = m))
+            sim = sim.masked_fill(~text_to_media_mask, -torch.finfo(sim.dtype).max)
+
+        sim = sim - sim.amax(dim = -1, keepdim = True).detach()
+        attn = sim.softmax(dim = -1)
+
+        if exists(media_locations) and self.only_attend_immediate_media:
+            # any text without a preceding media needs to have attention zeroed out
+            text_without_media_mask = text_time == 0
+            text_without_media_mask = rearrange(text_without_media_mask, 'b i -> b 1 i 1')
+            attn = attn.masked_fill(text_without_media_mask, 0.)  # masked_fill is out-of-place; the result must be reassigned
+
+        out = einsum('... i j, ... j d -> ... i d', attn, v)
+        out = rearrange(out, 'b h n d -> b n (h d)')
+        return self.to_out(out)
+
+class GatedCrossAttentionBlock(nn.Module):
+    def __init__(
+        self,
+        *,
+        dim,
+        dim_head = 64,
+        heads = 8,
+        ff_mult = 4,
+        only_attend_immediate_media = True
+    ):
+        super().__init__()
+        self.attn = MaskedCrossAttention(dim = dim, dim_head = dim_head, heads = heads, only_attend_immediate_media = only_attend_immediate_media)
+        self.attn_gate = nn.Parameter(torch.tensor([0.]))
+
+        self.ff = FeedForward(dim, mult = ff_mult)
+        self.ff_gate = nn.Parameter(torch.tensor([0.]))
+
+    def forward(
+        self,
+        x,
+        media, # media tensor, encoded by perceiver resample - (batch, time, latents, dim)
+        media_locations = None # boolean tensor indicating positions of media - (batch, sequence)
+    ):
+        x = self.attn(x, media, media_locations = media_locations) * self.attn_gate.tanh() + x
+        x = self.ff(x) * self.ff_gate.tanh() + x
+        return x
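A rough shape-check sketch for the two modules above (not part of the commit): the sizes mirror config.json (hidden_size 768, 12 heads, 64 latents, perceiver_depth 2), and the imports assume flamingo_pytorch.py is on the Python path.

import torch
from flamingo_pytorch import PerceiverResampler, GatedCrossAttentionBlock

resampler = PerceiverResampler(dim=768, depth=2, dim_head=64, heads=12, num_latents=64, inp_dim=768)
image_feats = torch.randn(2, 1, 197, 768)                   # (batch, time, patches, dim), e.g. ViT token outputs
media = resampler(image_feats)                              # -> (2, 1, 64, 768): 64 latents per image

xattn = GatedCrossAttentionBlock(dim=768, dim_head=64, heads=12)
text = torch.randn(2, 32, 768)                              # (batch, seq, dim) decoder hidden states
media_locations = torch.zeros(2, 32, dtype=torch.bool)
media_locations[:, 0] = True                                # pretend an <image> token sits at position 0
out = xattn(text, media, media_locations=media_locations)   # -> (2, 32, 768)
# note: the gates start at tanh(0) = 0, so a freshly initialized block passes the text through unchanged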
generation_config.json ADDED
@@ -0,0 +1,7 @@
+{
+  "_from_model_config": true,
+  "bos_token_id": 2,
+  "eos_token_id": 2,
+  "pad_token_id": 1,
+  "transformers_version": "4.29.0"
+}
modeling_flamingo.py ADDED
@@ -0,0 +1,516 @@
+import random
+import pdb
+from einops import rearrange
+from typing import List, Optional, Tuple, Union
+import os
+
+import torch
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import CrossEntropyLoss
+
+from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+import transformers.models.opt.modeling_opt as modeling_opt
+from transformers.models.opt.modeling_opt\
+    import OPTDecoderLayer, OPTPreTrainedModel, OPTConfig
+from transformers import ViTModel
+from .utils import exists, freeze_all_layers_, unfreeze_all_layers_
+from .flamingo_pytorch import GatedCrossAttentionBlock, PerceiverResampler
+
+
+class OPTLearnedPositionalEmbedding(nn.Embedding):
+    """
+    This module learns positional embeddings up to a fixed maximum size.
+    """
+
+    def __init__(self, num_embeddings: int, embedding_dim: int):
+        # OPT is set up so that if padding_idx is specified then offset the embedding ids by 2
+        # and adjust num_embeddings appropriately. Other models don't have this hack
+        self.offset = 2
+        super().__init__(num_embeddings + self.offset, embedding_dim)
+
+    def forward(self, attention_mask: torch.LongTensor, past_key_values_length: int = 0):
+        """`input_ids_shape` is expected to be [bsz x seqlen]."""
+        attention_mask = attention_mask.long()
+
+        # create positions depending on attention_mask
+        positions = torch.cumsum(attention_mask, dim=1)
+        positions = (positions.type_as(attention_mask) * attention_mask).long() - 1
+
+        # cut positions if `past_key_values_length` is > 0
+        positions = positions[:, past_key_values_length:]
+
+        return super().forward(positions + self.offset)
+
+
+class OPTDecoder(modeling_opt.OPTDecoder):
+    """
+    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`OPTDecoderLayer`]
+
+    Args:
+        config: OPTConfig
+        embed_tokens (nn.Embedding): output embedding
+    """
+
+    def __init__(self, config: OPTConfig):
+        OPTPreTrainedModel.__init__(self, config)
+        self.dropout = config.dropout
+        self.layerdrop = config.layerdrop
+        self.padding_idx = config.pad_token_id
+        self.max_target_positions = config.max_position_embeddings
+        self.vocab_size = config.vocab_size
+        self.media_token_id = config.media_token_id
+
+        self.embed_tokens = nn.Embedding(config.vocab_size, config.word_embed_proj_dim, self.padding_idx)
+        self.embed_positions = OPTLearnedPositionalEmbedding(config.max_position_embeddings, config.hidden_size)
+
+        if config.word_embed_proj_dim != config.hidden_size:
+            self.project_out = nn.Linear(config.hidden_size, config.word_embed_proj_dim, bias=False)
+        else:
+            self.project_out = None
+
+        if config.word_embed_proj_dim != config.hidden_size:
+            self.project_in = nn.Linear(config.word_embed_proj_dim, config.hidden_size, bias=False)
+        else:
+            self.project_in = None
+
+        # Note that the only purpose of `config._remove_final_layer_norm` is to keep backward compatibility
+        # with checkpoints that have been fine-tuned before transformers v4.20.1
+        # see https://github.com/facebookresearch/metaseq/pull/164
+        if config.do_layer_norm_before and not config._remove_final_layer_norm:
+            self.final_layer_norm = nn.LayerNorm(config.hidden_size)
+        else:
+            self.final_layer_norm = None
+
+        dim_head = config.hidden_size // config.num_attention_heads
+        if not config.id_perceiver:
+            self.perceiver_resampler = PerceiverResampler(
+                dim=config.hidden_size,
+                depth=config.perceiver_depth,
+                dim_head=dim_head,
+                heads=config.num_attention_heads,
+                num_latents=config.perceiver_num_latents,
+                inp_dim=config.inp_dim,
+            )
+        else:
+            if config.inp_dim is None:
+                self.perceiver_resampler = nn.Identity()
+            else:
+                self.perceiver_resampler = nn.Linear(
+                    config.inp_dim, config.hidden_size,
+                    bias=False)
+        self.layers = nn.ModuleList([OPTDecoderLayer(config) for _ in range(config.num_hidden_layers)])
+        self.gated_attn_layers = nn.ModuleList(
+            [GatedCrossAttentionBlock(
+                dim=config.hidden_size, dim_head=dim_head, heads=config.num_attention_heads,
+                only_attend_immediate_media=config.only_attend_immediate_media)\
+                if not (ind % config.cross_attn_every) else None \
+                for ind in range(config.num_hidden_layers)])
+
+        self.gradient_checkpointing = False
+        # Initialize weights and apply final processing
+        self.post_init()
+
+        # in flamingo mode, freeze everything but perceiver and gated cross attention
+        if not config.finetune_LM:
+            freeze_all_layers_(self)
+            unfreeze_all_layers_(self.perceiver_resampler)
+            [unfreeze_all_layers_(cross_attn) for cross_attn in self.gated_attn_layers if exists(cross_attn)]
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        pixel_values=None,
+        image_embeds=None
+    ) -> Union[Tuple, BaseModelOutputWithPast]:
+        r"""
+        Args:
+            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
+                provide it.
+
+                Indices can be obtained using [`OPTTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+                [`PreTrainedTokenizer.__call__`] for details.
+
+                [What are input IDs?](../glossary#input-ids)
+            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+                - 1 for tokens that are **not masked**,
+                - 0 for tokens that are **masked**.
+
+                [What are attention masks?](../glossary#attention-mask)
+            head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*):
+                Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
+
+                - 1 indicates the head is **not masked**,
+                - 0 indicates the head is **masked**.
+
+            past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+                Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
+                shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
+                shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
+                Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
+                cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+
+                If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
+                that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
+                all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+
+            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+                than the model's internal embedding lookup matrix.
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+            output_hidden_states (`bool`, *optional*):
+                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+                for more detail.
+            return_dict (`bool`, *optional*):
+                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+        """
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # retrieve input_ids and inputs_embeds
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+        elif input_ids is not None:
+            input_shape = input_ids.size()
+            input_ids = input_ids.view(-1, input_shape[-1])
+        elif inputs_embeds is not None:
+            input_shape = inputs_embeds.size()[:-1]
+        else:
+            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+        batch, device = input_ids.shape[0], input_ids.device
+
+        flamingo_mode = exists(pixel_values) or exists(image_embeds)
+
+        # derive the media token ids (as a boolean tensor), for calculating the masked cross attention
+        if flamingo_mode:
+            media_locations = input_ids == self.media_token_id
+
+        assert not (exists(pixel_values) and exists(image_embeds))
+        # encode images into embeddings
+        # with the img_encoder passed in at init
+        # it can also accept precomputed image embeddings
+
+        if exists(pixel_values):
+            assert exists(self.img_encoder), 'img_encoder must be passed in for automatic image encoding'
+            if len(pixel_values.shape) == 4:
+                pixel_values = torch.unsqueeze(pixel_values, 1)
+            pixel_values = rearrange(pixel_values, 'b t ... -> (b t) ...')
+
+            with torch.no_grad():
+                if getattr(self.img_encoder, 'vision_model', None) is not None:
+                    image_outputs = self.img_encoder.vision_model(
+                        pixel_values=pixel_values,
+                        output_hidden_states=True, return_dict=True)
+                else:
+                    image_outputs = self.img_encoder(
+                        pixel_values=pixel_values,
+                        output_hidden_states=True, return_dict=True)
+
+            image_embeds = image_outputs['last_hidden_state']
+            image_embeds = rearrange(image_embeds, '(b t) ... -> b t ...', b = batch)
+
+        if exists(image_embeds):
+            image_embeds = self.perceiver_resampler(image_embeds)
+
+        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
+
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids)
+
+        # embed positions
+        if attention_mask is None:
+            attention_mask = torch.ones(inputs_embeds.shape[:2], dtype=torch.bool, device=inputs_embeds.device)
+        pos_embeds = self.embed_positions(attention_mask, past_key_values_length)
+
+        attention_mask = self._prepare_decoder_attention_mask(
+            attention_mask, input_shape, inputs_embeds, past_key_values_length
+        )
+
+        if self.project_in is not None:
+            inputs_embeds = self.project_in(inputs_embeds)
+
+        hidden_states = inputs_embeds + pos_embeds
+        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+
+        # decoder layers
+        all_hidden_states = () if output_hidden_states else None
+        all_self_attns = () if output_attentions else None
+        next_decoder_cache = () if use_cache else None
+
+        # check if head_mask has a correct number of layers specified if desired
+        for attn_mask, mask_name in zip([head_mask], ["head_mask"]):
+            if attn_mask is not None:
+                if attn_mask.size()[0] != (len(self.layers)):
+                    raise ValueError(
+                        f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
+                        f" {head_mask.size()[0]}."
+                    )
+
+        for idx, decoder_layer in enumerate(self.layers):
+            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
+            if output_hidden_states:
+                all_hidden_states += (hidden_states,)
+
+            dropout_probability = random.uniform(0, 1)
+            if self.training and (dropout_probability < self.layerdrop):
+                continue
+
+            past_key_value = past_key_values[idx] if past_key_values is not None else None
+
+            flamingo_cross_attn = self.gated_attn_layers[idx]
+            if exists(flamingo_cross_attn) and exists(image_embeds):
+                hidden_states = flamingo_cross_attn(
+                    hidden_states,
+                    image_embeds,
+                    media_locations = media_locations
+                )
+
+            layer_outputs = decoder_layer(
+                hidden_states,
+                attention_mask=attention_mask,
+                layer_head_mask=(head_mask[idx] if head_mask is not None else None),
+                past_key_value=past_key_value,
+                output_attentions=output_attentions,
+                use_cache=use_cache,
+            )
+
+            hidden_states = layer_outputs[0]
+
+            if use_cache:
+                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
+
+            if output_attentions:
+                all_self_attns += (layer_outputs[1],)
+
+        if self.final_layer_norm is not None:
+            hidden_states = self.final_layer_norm(hidden_states)
+
+        if self.project_out is not None:
+            hidden_states = self.project_out(hidden_states)
+
+        # add hidden states from the last decoder layer
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+
+        next_cache = next_decoder_cache if use_cache else None
+        if not return_dict:
+            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+        return BaseModelOutputWithPast(
+            last_hidden_state=hidden_states,
+            past_key_values=next_cache,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attns,
+        )
+
+
+class OPTModel(modeling_opt.OPTModel):
+    def __init__(self, config: OPTConfig):
+        OPTPreTrainedModel.__init__(self, config)
+        self.decoder = OPTDecoder(config)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+
+class OPTForCausalLM(modeling_opt.OPTForCausalLM):
+    _keys_to_ignore_on_load_missing = [r"lm_head.weight"]
+
+    def __init__(self, config):
+        OPTPreTrainedModel.__init__(self, config)
+        self.model = OPTModel(config)
+
+        # the lm_head weight is automatically tied to the embed tokens weight
+        self.lm_head = nn.Linear(config.word_embed_proj_dim, config.vocab_size, bias=False)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+
+def set_default_if_nonexist(config, key, value):
+    if getattr(config, key, None) is None:
+        setattr(config, key, value)
+    return config
+
+
+def setup_default_flamingo_configs(config):
+    set_default_if_nonexist(config, 'perceiver_depth', 2)
+    set_default_if_nonexist(config, 'perceiver_num_latents', 64)
+    set_default_if_nonexist(config, 'cross_attn_every', 3)
+    set_default_if_nonexist(config, 'only_attend_immediate_media', True)
+    set_default_if_nonexist(config, 'media_token_id', 50265)
+    set_default_if_nonexist(config, 'inp_dim', 768)
+    set_default_if_nonexist(config, 'finetune_LM', True)
+    set_default_if_nonexist(config, 'id_perceiver', False)
+    return config
+
+
+class FlamingoForCausalLM(modeling_opt.OPTForCausalLM):
+    _keys_to_ignore_on_load_missing = [
+        r"lm_head.weight",
+    ]
+
+    def __init__(self, config):
+        OPTPreTrainedModel.__init__(self, config)
+        config = setup_default_flamingo_configs(config)
+        self.model = OPTModel(config)
+
+        # the lm_head weight is automatically tied to the embed tokens weight
+        self.lm_head = nn.Linear(config.word_embed_proj_dim, config.vocab_size, bias=False)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+        self.model.decoder.img_encoder = None
+        self.loss_fct = CrossEntropyLoss()
+        dino_model = ViTModel.from_pretrained("facebook/dino-vitb16")
+        self.setup_vis_encoder(dino_model)
+
+    def setup_vis_encoder(self, img_encoder):
+        self.model.decoder.img_encoder = img_encoder
+        freeze_all_layers_(img_encoder)
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        *args, **kwargs) -> Union[Tuple, CausalLMOutputWithPast]:
+        r"""
+        Args:
+            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
+                provide it.
+
+                Indices can be obtained using [`OPTTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+                [`PreTrainedTokenizer.__call__`] for details.
+
+                [What are input IDs?](../glossary#input-ids)
+            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+                - 1 for tokens that are **not masked**,
+                - 0 for tokens that are **masked**.
+
+                [What are attention masks?](../glossary#attention-mask)
+            head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*):
+                Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
+
+                - 1 indicates the head is **not masked**,
+                - 0 indicates the head is **masked**.
+
+            past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+                Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
+                shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
+                shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional
+                tensors are only required when the model is used as a decoder in a Sequence to Sequence model.
+
+                Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
+                cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+
+                If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
+                that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
+                all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+                than the model's internal embedding lookup matrix.
+            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+            use_cache (`bool`, *optional*):
+                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+                (see `past_key_values`).
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+            output_hidden_states (`bool`, *optional*):
+                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+                for more detail.
+            return_dict (`bool`, *optional*):
+                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+
+        Returns:
+
+        Example:
+
+        ```python
+        >>> from transformers import GPT2Tokenizer, OPTForCausalLM
+
+        >>> model = OPTForCausalLM.from_pretrained("facebook/opt-350m")
+        >>> tokenizer = GPT2Tokenizer.from_pretrained("facebook/opt-350m")
+
+        >>> prompt = "Hey, are you consciours? Can you talk to me?"
+        >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+        >>> # Generate
+        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+        "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you."
+        ```"""
+
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+        outputs = self.model.decoder(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            head_mask=head_mask,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            *args, **kwargs)
+
+        logits = self.lm_head(outputs[0]).contiguous()
+
+        loss = None
+        if labels is not None:
+            # Shift so that tokens < n predict n
+            shift_logits = logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+            # Flatten the tokens
+            loss = self.loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))
+
+        if not return_dict:
+            output = (logits,) + outputs[1:]
+            return (loss,) + output if loss is not None else output
+
+        return CausalLMOutputWithPast(
+            loss=loss,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
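FlamingoForCausalLM wires a frozen facebook/dino-vitb16 ViT in as the image encoder and forwards extra keyword arguments such as pixel_values to the decoder, which applies the gated cross-attention blocks at the configured interval. A hedged end-to-end sketch; the repository id and image path are placeholders, and it assumes the <image> special token maps to media_token_id (32768), which this commit does not state explicitly:

import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForCausalLM, AutoTokenizer

repo_id = "<this-repo-id>"                   # placeholder
tokenizer = AutoTokenizer.from_pretrained(repo_id)
processor = AutoImageProcessor.from_pretrained(repo_id)
model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)

image = Image.open("example.jpg")            # placeholder image
pixel_values = processor(image, return_tensors="pt").pixel_values        # (1, 3, 224, 224)

inputs = tokenizer("<image> A photo of", return_tensors="pt")
with torch.no_grad():
    outputs = model(input_ids=inputs.input_ids,
                    attention_mask=inputs.attention_mask,
                    pixel_values=pixel_values)                            # forwarded to the decoder via **kwargs
print(outputs.logits.shape)                  # (1, seq_len, 32778)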
preprocessor_config.json ADDED
@@ -0,0 +1,23 @@
+{
+  "do_normalize": true,
+  "do_rescale": true,
+  "do_resize": true,
+  "feature_extractor_type": "ViTFeatureExtractor",
+  "image_mean": [
+    0.485,
+    0.456,
+    0.406
+  ],
+  "image_processor_type": "ViTFeatureExtractor",
+  "image_std": [
+    0.229,
+    0.224,
+    0.225
+  ],
+  "resample": 2,
+  "rescale_factor": 0.00392156862745098,
+  "size": {
+    "height": 224,
+    "width": 224
+  }
+}
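These are the standard ImageNet statistics used for the DINO ViT encoder: resize to 224x224, rescale by 1/255, then normalize. For reference, a roughly equivalent manual transform (illustrative only; torchvision is not a dependency of this repo):

from torchvision import transforms

# mirrors preprocessor_config.json: resize -> rescale to [0, 1] -> normalize
vit_transform = transforms.Compose([
    transforms.Resize((224, 224)),           # resample=2 corresponds to bilinear interpolation
    transforms.ToTensor(),                   # also applies the 1/255 rescale_factor
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])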
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d489df10d3aea1f59fdfeabaa5b0ea4ec5a35832f61c0965537441c6d93892ef
+size 1022117679
special_tokens_map.json ADDED
@@ -0,0 +1,7 @@
+{
+  "additional_special_tokens": [
+    "<image>",
+    "<PERSON>"
+  ],
+  "pad_token": "<pad>"
+}
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,5 @@
+{
+  "clean_up_tokenization_spaces": true,
+  "model_max_length": 1000000000000000019884624838656,
+  "tokenizer_class": "PreTrainedTokenizerFast"
+}
utils.py ADDED
@@ -0,0 +1,37 @@
+import torch
+
+def exists(val):
+    return val is not None
+
+# for controlling freezing during training of flamingo
+
+def set_module_requires_grad_(module, requires_grad):
+    for param in module.parameters():
+        param.requires_grad = requires_grad
+
+def freeze_all_layers_(module):
+    set_module_requires_grad_(module, False)
+
+def unfreeze_all_layers_(module):
+    set_module_requires_grad_(module, True)
+
+def freeze_model_and_make_eval_(model):
+    model.eval()
+    freeze_all_layers_(model)
+
+def _make_att_wd_mask(
+    input_ids_shape: torch.Size,
+    dtype: torch.dtype, device: torch.device,
+    past_key_values_length: int = 0,
+    att_wd_size: int = 0,
+):
+    bsz, tgt_len = input_ids_shape
+    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
+    mask_cond = torch.arange(mask.size(-1), device=device)
+    mask.masked_fill_(
+        mask_cond > (mask_cond - att_wd_size).view(mask.size(-1), 1), 0)
+    mask = mask.to(dtype)
+
+    if past_key_values_length > 0:
+        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
+    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
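A quick illustration of these helpers (not from the commit; it assumes utils.py is importable directly). The freeze functions flip requires_grad for every parameter of a module, and _make_att_wd_mask builds an additive attention-window mask: 0 where attention is allowed and a large negative value at positions att_wd_size or more tokens behind the query, intended to be combined with the usual causal mask elsewhere.

import torch
from torch import nn
from utils import freeze_all_layers_, unfreeze_all_layers_, _make_att_wd_mask

layer = nn.Linear(4, 4)
freeze_all_layers_(layer)
print(any(p.requires_grad for p in layer.parameters()))   # False
unfreeze_all_layers_(layer)

mask = _make_att_wd_mask(torch.Size([1, 5]), torch.float32, torch.device("cpu"), att_wd_size=2)
print((mask[0, 0] == 0).int())   # 1 = allowed; row i blocks columns j <= i - 2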