from typing import Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch import nn
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention


class GPT2KNNAttention(GPT2Attention):
    """GPT-2 attention augmented with an external kNN memory.

    Local (causal) attention is computed by the parent class; additionally, the
    top `num_retrieve_memories` (key, value) pairs are retrieved from
    `knn_memory` and attended to, and the two outputs are combined with a
    learned per-head gate.
    """

    def __init__(self, config, knn_memory, device, is_cross_attention=False, layer_idx=None,
                 num_retrieve_memories=32):
        super().__init__(config, is_cross_attention, layer_idx)
        self.knn_memory = knn_memory
        self.device = device
        self.num_retrieve_memories = num_retrieve_memories
        self.knn_attn_dropout = nn.Dropout(config.attn_pdrop)
        # per-head gating parameter; passed through a sigmoid when combining attentions
        self.attn_comb_bias = nn.Parameter(torch.empty(self.num_heads))
        nn.init.normal_(self.attn_comb_bias, mean=0.0, std=1.0)
        # self.attn_comb_bias = nn.Parameter(torch.full((self.num_heads,), 1.0))

    def _knn_attn(self, query, key, value, mask, head_mask=None):
        # each query token attends only to its own retrieved (key, value) set
        query = query.unsqueeze(-2)
        attn_weights = torch.matmul(query, key.transpose(-1, -2))

        if self.scale_attn_weights:
            attn_weights = attn_weights / torch.full(
                [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device
            )

        # Layer-wise attention scaling
        if self.scale_attn_by_inverse_layer_idx:
            attn_weights = attn_weights / float(self.layer_idx + 1)

        # if not self.is_cross_attention:
        #     raise RuntimeError("KNN attention is not yet implemented for !cross_attention")
        #     # if only "normal" attention layer implements causal mask
        #     query_length, key_length = query.size(-3), key.size(-3)
        #     causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
        #     mask_value = torch.finfo(attn_weights.dtype).min
        #     # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
        #     # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
        #     mask_value = torch.full([], mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
        #     attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value)

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise
        attn_weights = attn_weights.type(value.dtype)
        attn_weights = self.knn_attn_dropout(attn_weights)

        # zero out attention to missing (padded) memory slots
        sh = mask.size()
        attn_weights = attn_weights * mask.view((sh[0], 1, 1, 1, sh[1]))

        # Mask heads if we want to
        if head_mask is not None:
            attn_weights = attn_weights * head_mask

        attn_output = torch.matmul(attn_weights, value)
        attn_output.squeeze_(dim=-2)

        return attn_output

    def _attn(self, query, key, value, attention_mask=None, head_mask=None):
        # local causal attention from the parent class
        attn_output, attn_weights = super()._attn(
            query, key, value, attention_mask, head_mask)

        # retrieve the nearest (key, value) memories for the current queries
        knn_key, knn_value, knn_mask = self.knn_memory.search(
            query, self.num_retrieve_memories)

        # per-head gate in (0, 1) mixing memory attention with local attention
        g = torch.sigmoid(self.attn_comb_bias)[:, None, None]

        if knn_key.numel() == 0:
            # nothing retrieved yet: only the (gated) local attention contributes
            return attn_output * (1 - g), attn_weights

        knn_key, knn_value, knn_mask = knn_key.to(
            self.device), knn_value.to(self.device), knn_mask.to(self.device)
        knn_attn_output = self._knn_attn(
            query, knn_key, knn_value, knn_mask, head_mask)

        # combining the two attentions
        attn = knn_attn_output * g + attn_output * (1 - g)
        return attn, attn_weights

    def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, head_mask=None):
        raise RuntimeError(
            "KNN attention is not yet implemented for _upcast_and_reordered_attn")

    def forward(
        self,
        hidden_states: Optional[Tuple[torch.FloatTensor]],
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
        if encoder_hidden_states is not None:
            if not hasattr(self, "q_attn"):
                raise ValueError(
                    "If class is used as cross attention, the weights `q_attn` have to be defined. "
                    "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`."
                )
            query = self.q_attn(hidden_states)
            key, value = self.c_attn(encoder_hidden_states).split(
                self.split_size, dim=2)
            attention_mask = encoder_attention_mask
        else:
            query, key, value = self.c_attn(
                hidden_states).split(self.split_size, dim=2)

        query = self._split_heads(query, self.num_heads, self.head_dim)
        key = self._split_heads(key, self.num_heads, self.head_dim)
        value = self._split_heads(value, self.num_heads, self.head_dim)

        # normalization of queries and keys reduces the effect of staleness
        query, key = F.normalize(query, dim=-1), F.normalize(key, dim=-1)

        # keys/values of the current chunk, to be added to the kNN memory below
        new_memories = (key, value)

        if layer_past is not None:
            past_key, past_value = layer_past
            key = torch.cat((past_key, key), dim=-2)
            value = torch.cat((past_value, value), dim=-2)

        if use_cache is True:
            present = (key, value)
        else:
            present = None

        if self.reorder_and_upcast_attn:
            raise RuntimeError("Not implemented")
            attn_output, attn_weights = self._upcast_and_reordered_attn(
                query, key, value, attention_mask, head_mask)
        else:
            attn_output, attn_weights = self._attn(
                query, key, value, attention_mask, head_mask)

        attn_output = self._merge_heads(
            attn_output, self.num_heads, self.head_dim)
        attn_output = self.c_proj(attn_output)
        attn_output = self.resid_dropout(attn_output)

        outputs = (attn_output, present)
        if output_attentions:
            outputs += (attn_weights,)

        self.knn_memory.add(*new_memories)

        return outputs  # a, present, (attentions)
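

# ---------------------------------------------------------------------------
# Usage sketch (illustration only, not part of the original module).
# Assumptions: `knn_memory` is any object exposing the two methods used above,
#   search(query, num_retrieve_memories) -> (key, value, mask)  and  add(key, value),
# with shapes compatible with `_knn_attn`; its implementation lives elsewhere.
# The helper name `attach_knn_attention` and the choice of which block receives
# a memory are illustrative, not prescribed by this file.

from transformers import GPT2LMHeadModel


def attach_knn_attention(model: GPT2LMHeadModel, knn_memory, device, layer_idx,
                         num_retrieve_memories=32):
    """Replace the self-attention of one GPT-2 block with GPT2KNNAttention,
    reusing that block's pretrained c_attn / c_proj weights."""
    block = model.transformer.h[layer_idx]
    knn_attn = GPT2KNNAttention(
        model.config,
        knn_memory=knn_memory,
        device=device,
        layer_idx=layer_idx,
        num_retrieve_memories=num_retrieve_memories,
    )
    # copy the matching pretrained weights; attn_comb_bias keeps its random init
    knn_attn.load_state_dict(block.attn.state_dict(), strict=False)
    block.attn = knn_attn.to(device)
    return model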