import re, math, torch
from collections import OrderedDict
from typing import Optional, Tuple

from torch import nn
from torch.nn.init import trunc_normal_, normal_
import torch.utils.checkpoint

from transformers import PreTrainedModel, PretrainedConfig, AutoConfig, AutoModel


class ClassInstantier(OrderedDict):
    def __getitem__(self, key):
        content = super().__getitem__(key)
        cls, kwargs = content if isinstance(content, tuple) else (content, {})
        return cls(**kwargs)


ACT2CLS = {"silu": nn.SiLU}

ACT2FN = ClassInstantier(ACT2CLS)
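# Usage note (illustrative): indexing ACT2FN instantiates the mapped class on every
# lookup, so ACT2FN["silu"] is equivalent to nn.SiLU(). Only "silu" is registered above.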


class WeightedNorm(nn.Module):
    def __init__(self, hidden_size):
        """LayerNorm followed by a learnable per-channel scale."""
        super().__init__()
        self.hidden_size = hidden_size
        self.norm = nn.LayerNorm(self.hidden_size)
        self.weight = nn.Parameter(torch.ones(self.hidden_size))
        normal_(self.weight, mean=1, std=0.02)

    def forward(self, x):
        x = self.norm(x)
        return x * self.weight


class PerceiverMLP(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        output_size: int,
        hidden_act: str,
    ):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, output_size, bias=False)
        self.act_fn = ACT2FN[hidden_act]

    def forward(self, x):
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
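# Note: this is a gated MLP in the LLaMA/SwiGLU style, computing
# down_proj(act_fn(gate_proj(x)) * up_proj(x)) with all projections bias-free.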


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from
    (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim).
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
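# Example: with hidden_states of shape (2, 4, 10, 64) and n_rep=3, repeat_kv returns a
# tensor of shape (2, 12, 10, 64); each key/value head is duplicated 3 times so that
# grouped-query attention lines up with the number of query heads.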


class PerceiverAttention(nn.Module):
    def __init__(self, connector_config, layer_idx: Optional[int] = None) -> None:
        """Perceiver Cross-Attention Module --> let long-form inputs be `context`, resampled embeddings be `latents`"""
        super().__init__()

        self.layer_idx = layer_idx
        self.hidden_size = connector_config.text_hidden_size
        self.num_heads = connector_config.resampler_n_heads
        self.head_dim = connector_config.resampler_head_dim
        self.num_key_value_heads = connector_config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads

        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

        self.is_causal = False

    def forward(
        self,
        latents: torch.Tensor,
        context: torch.Tensor,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """
        Runs Perceiver cross-attention: queries come from `latents`, while keys and values are computed
        over `context` concatenated with `latents` along the `seq` dimension.

        Args:
            latents (`torch.Tensor`): Tensor of shape [bsz, n_latents, embed_dim] representing fixed-length latents to compress to.
            context (`torch.Tensor`): Tensor of shape [bsz, seq, embed_dim] representing long-form context to resample.
            past_key_value (`Tuple(torch.Tensor)`, *optional*): cached past key and value projection states.
            output_attentions (`bool`, *optional*, defaults to `False`): Whether to return attention weights.
            use_cache (`bool`, *optional*, defaults to `False`): Whether to use past_key_value for caching.
        """
        bsz, q_len, _ = latents.size()
        kv_seq_len = q_len + context.size(1)

        # Keys/values attend over the concatenation of context and latents.
        hidden_states = torch.concat([context, latents], dim=-2)

        query_states = self.q_proj(latents)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, kv_seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, kv_seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        past_key_value = getattr(self, "past_key_value", past_key_value)

        if past_key_value is not None:
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx)

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                f" {attn_weights.size()}"
            )

        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim)

        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value
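# Shape summary (illustrative): latents (bsz, n_latents, hidden_size) and context
# (bsz, seq, hidden_size) produce attn_output of shape (bsz, n_latents, hidden_size),
# with kv_seq_len = seq + n_latents since keys/values cover concat(context, latents).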


PERCEIVER_ATTENTION_CLASSES = {
    "eager": PerceiverAttention,
}


class PerceiverLayer(nn.Module):
    def __init__(self, connector_config, layer_idx: int):
        super().__init__()
        self.hidden_size = connector_config.text_hidden_size
        self.n_latents = connector_config.num_output_tokens
        self.depth = connector_config.resampler_depth
        self.ff_multi = connector_config.ff_multi

        self.input_latents_norm = WeightedNorm(self.hidden_size)
        self.input_context_norm = WeightedNorm(self.hidden_size)
        self.self_attn = PERCEIVER_ATTENTION_CLASSES[connector_config._attn_implementation](
            connector_config, layer_idx=layer_idx
        )
        self.post_attention_layernorm = WeightedNorm(self.hidden_size)
        self.mlp = PerceiverMLP(
            hidden_size=connector_config.text_hidden_size,
            intermediate_size=connector_config.text_hidden_size * self.ff_multi,
            output_size=connector_config.text_hidden_size,
            hidden_act=connector_config.hidden_act,
        )

    def forward(
        self,
        latents: torch.Tensor,
        context: torch.Tensor,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
            latents (`torch.FloatTensor`): input to the layer of shape `(batch, n_latents, embed_dim)`
            context (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
        """
        residual = latents

        latents = self.input_latents_norm(latents)
        context = self.input_context_norm(context)

        latents, self_attn_weights, present_key_value = self.self_attn(
            latents=latents,
            context=context,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
        )
        latents = residual + latents
        residual = latents

        latents = self.post_attention_layernorm(latents)
        latents = self.mlp(latents)
        latents = residual + latents

        outputs = (latents,)

        if output_attentions:
            outputs += (self_attn_weights,)

        if use_cache:
            outputs += (present_key_value,)

        return outputs
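# Structure note: this is a pre-norm residual block; the cross-attention sub-block and the
# MLP sub-block are each wrapped in WeightedNorm and added back through a residual connection.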


class PerceiverResampler(nn.Module):
    """Perceiver Resampler that compresses input embeddings into a fixed number of latents."""

    def __init__(self, connector_config) -> None:
        super().__init__()
        self.hidden_size = connector_config.text_hidden_size
        self.hidden_act = connector_config.hidden_act
        self.n_latents = connector_config.num_output_tokens
        self.depth = connector_config.resampler_depth

        self.latents = nn.Parameter(torch.zeros(self.n_latents, self.hidden_size))

        self.layers = nn.ModuleList([PerceiverLayer(connector_config, idx) for idx in range(self.depth)])
        self.norm = WeightedNorm(self.hidden_size)
        self._use_flash_attention_2 = connector_config._attn_implementation == "flash_attention_2"

    def forward(
        self,
        context: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        latents = self.latents.unsqueeze(0).expand((context.shape[0], *self.latents.size()))

        compressed_context = latents
        for i, perceiver_layer in enumerate(self.layers):
            layer_outputs = perceiver_layer(
                compressed_context,
                context,
                past_key_value=None,
                output_attentions=False,
                use_cache=False,
            )
            compressed_context = layer_outputs[0]

        compressed_context = self.norm(compressed_context)
        return compressed_context
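# Example (illustrative shapes): a context of shape (bsz, seq, text_hidden_size) is
# compressed to (bsz, num_output_tokens, text_hidden_size), independent of seq; the
# attention_mask argument is accepted but not used by this eager implementation.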


def build_mm_projector(
    input_dim,
    output_dim,
    projector_type,
    hidden_act='silu',
    delay_load=False,
    token_input_shape=0,
    **kwargs
) -> nn.Sequential:
    modules = [nn.Linear(input_dim, output_dim)]
    mlp_gelu_match = re.match(r'.*mlp(\d+)x_gelu$', projector_type)
    if mlp_gelu_match is not None:
        mlp_depth = int(mlp_gelu_match.group(1))
        for _ in range(mlp_depth - 1):
            modules.append(nn.GELU())
            modules.append(nn.Linear(output_dim, output_dim))

    return nn.Sequential(*modules)
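# Example: projector_type="mlp2x_gelu" builds Linear(input_dim, output_dim) -> GELU ->
# Linear(output_dim, output_dim); a projector_type that does not match the mlp{N}x_gelu
# pattern falls back to a single Linear(input_dim, output_dim).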


class MMConnector(PreTrainedModel):
    config_class = PretrainedConfig

    def __init__(self, config: PretrainedConfig) -> None:
        super().__init__(config)
        self.proj = build_mm_projector(
            config.vision_hidden_size,
            config.text_hidden_size,
            config.projector_type,
            token_input_shape=config.token_input_shape,
        )
        self.resampler = PerceiverResampler(config)

    def forward(self, x):
        x = self.proj(x)
        x = self.resampler(x)
        return x
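

if __name__ == "__main__":
    # Minimal smoke test (illustrative only). The field names below mirror the config
    # attributes read throughout this module; the concrete values are arbitrary
    # placeholders, not the settings of any released checkpoint.
    demo_config = PretrainedConfig(
        vision_hidden_size=1152,
        text_hidden_size=2048,
        projector_type="mlp2x_gelu",
        token_input_shape=729,
        num_output_tokens=64,
        resampler_depth=3,
        resampler_n_heads=16,
        resampler_head_dim=96,
        num_key_value_heads=4,
        ff_multi=4,
        hidden_act="silu",
    )
    demo_config._attn_implementation = "eager"

    connector = MMConnector(demo_config)
    vision_features = torch.randn(2, 729, demo_config.vision_hidden_size)
    compressed = connector(vision_features)
    # Expected: torch.Size([2, 64, 2048]), i.e. num_output_tokens latents in the text hidden size.
    print(compressed.shape)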