|
|
|
|
|
|
|
|
|
|
|
from typing import Optional, Tuple |
|
|
|
import torch |
|
from torch import nn |
|
|
|
|
|
from .attention_utils import ( |
|
create_layer_with_commonsense_on_specific_head, |
|
find_head_to_mask, |
|
convert_relations_to_binary_mask, |
|
update_weights_regarding_relations_on_specific_head |
|
) |
|
|
|
|
|
class BartCustomAttention(nn.Module): |
|
"""Multi-headed attention from 'Attention Is All You Need' paper""" |
|
|
|
def __init__( |
|
self, |
|
embed_dim: int, |
|
num_heads: int, |
|
dropout: float = 0.0, |
|
is_decoder: bool = False, |
|
bias: bool = True, |
|
num_relation_kinds: int = 0, |
|
use_same_relation_kv_emb: bool = True, |
|
heads_mask: Optional[torch.Tensor] = None, |
|
): |
|
super().__init__() |
|
self.embed_dim = embed_dim |
|
self.num_heads = num_heads |
|
self.dropout = dropout |
|
self.head_dim = embed_dim // num_heads |
|
|
|
if (self.head_dim * num_heads) != self.embed_dim: |
|
raise ValueError( |
|
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" |
|
f" and `num_heads`: {num_heads})." |
|
) |
|
if heads_mask.size() != (self.num_heads,): |
|
raise ValueError( |
|
f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {heads_mask.size()}" |
|
) |
|
self.heads_mask = heads_mask |
|
|
|
self.scaling = self.head_dim**-0.5 |
|
self.is_decoder = is_decoder |
|
|
|
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) |
|
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) |
|
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) |
|
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) |
|
|
|
self.num_relation_kinds = num_relation_kinds |
|
self.relation_k_emb = nn.Embedding(num_relation_kinds + 1, self.head_dim, padding_idx=0) |
|
if use_same_relation_kv_emb: |
|
self.relation_v_emb = self.relation_k_emb |
|
else: |
|
self.relation_v_emb = nn.Embedding(num_relation_kinds + 1, self.head_dim, padding_idx=0) |
|
|
|
self.k_rel_scale = 0.0 |
|
self.v_rel_scale = 1.0 |
|
|
|
|
|
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): |
|
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() |
|
|
|
def forward( |
|
self, |
|
hidden_states: torch.Tensor, |
|
key_value_states: Optional[torch.Tensor] = None, |
|
past_key_value: Optional[Tuple[torch.Tensor]] = None, |
|
attention_mask: Optional[torch.Tensor] = None, |
|
layer_head_mask: Optional[torch.Tensor] = None, |
|
output_attentions: bool = False, |
|
relation_inputs: Optional[torch.Tensor] = None, |
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: |
|
"""Input shape: Batch x Time x Channel""" |
|
|
|
|
|
|
|
|
|
is_cross_attention = key_value_states is not None |
|
|
|
bsz, tgt_len, embed_dim = hidden_states.size() |
|
|
|
|
|
if relation_inputs is None: |
|
|
|
print('oh no') |
|
relation_inputs = torch.zeros((bsz, tgt_len, tgt_len)).to('cuda').long() |
|
|
|
assert relation_inputs.shape == (bsz, tgt_len, tgt_len) |
|
|
|
|
|
relation_k_embeds = self.relation_k_emb(relation_inputs) |
|
relation_v_embeds = self.relation_v_emb(relation_inputs) |
|
|
|
|
|
query_states = self.q_proj(hidden_states) * self.scaling |
|
|
|
if is_cross_attention and past_key_value is not None: |
|
|
|
key_states = past_key_value[0] |
|
value_states = past_key_value[1] |
|
elif is_cross_attention: |
|
|
|
key_states = self._shape(self.k_proj(key_value_states), -1, bsz) |
|
value_states = self._shape(self.v_proj(key_value_states), -1, bsz) |
|
elif past_key_value is not None: |
|
|
|
key_states = self._shape(self.k_proj(hidden_states), -1, bsz) |
|
value_states = self._shape(self.v_proj(hidden_states), -1, bsz) |
|
key_states = torch.cat([past_key_value[0], key_states], dim=2) |
|
value_states = torch.cat([past_key_value[1], value_states], dim=2) |
|
else: |
|
|
|
key_states = self._shape(self.k_proj(hidden_states), -1, bsz) |
|
value_states = self._shape(self.v_proj(hidden_states), -1, bsz) |
|
|
|
if self.is_decoder: |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
past_key_value = (key_states, value_states) |
|
|
|
proj_shape = (bsz * self.num_heads, -1, self.head_dim) |
|
query_states = self._shape(query_states, tgt_len, bsz) |
|
src_len = key_states.size(2) |
|
|
|
|
|
attn_weights = torch.matmul( |
|
query_states, key_states.transpose(3, 2) |
|
) |
|
|
|
|
|
q_t = query_states.permute(0, 2, 1, 3) |
|
|
|
|
|
r_t = relation_k_embeds.transpose(-2, -1) |
|
|
|
|
|
q_tr_t_matmul = torch.matmul(q_t, r_t) |
|
q_tr_tmatmul_t = q_tr_t_matmul.permute(0, 2, 1, 3) |
|
|
|
|
|
|
|
|
|
|
|
""" |
|
q_tr_tmatmul_t = self.layer_heads_relation_attention_update( |
|
self.heads_mask, |
|
q_tr_tmatmul_t, |
|
) |
|
""" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
attn_weights += self.k_rel_scale * q_tr_tmatmul_t |
|
|
|
|
|
|
|
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) |
|
|
|
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): |
|
raise ValueError( |
|
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}" |
|
) |
|
|
|
if attention_mask is not None: |
|
if attention_mask.size() != (bsz, 1, tgt_len, src_len): |
|
raise ValueError( |
|
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" |
|
) |
|
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask |
|
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) |
|
|
|
attn_weights = nn.functional.softmax(attn_weights, dim=-1) |
|
|
|
|
|
""" |
|
attn_weights = self.layer_heads_relation_attention_update(layer_head_mask, |
|
relation_inputs, |
|
attn_weights, |
|
bsz, |
|
tgt_len, |
|
src_len) |
|
""" |
|
if layer_head_mask is not None: |
|
if layer_head_mask.size() != (self.num_heads,): |
|
raise ValueError( |
|
f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}" |
|
) |
|
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) |
|
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) |
|
|
|
|
|
if output_attentions: |
|
|
|
|
|
|
|
|
|
attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) |
|
attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) |
|
else: |
|
attn_weights_reshaped = None |
|
|
|
attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) |
|
|
|
attn_output = torch.bmm(attn_probs, value_states.view(*proj_shape)) |
|
|
|
|
|
|
|
w_t = attn_probs.view(bsz, self.num_heads, tgt_len, src_len).permute(0, 2, 1, 3) |
|
|
|
|
|
|
|
w_tr_matmul = torch.matmul(w_t, relation_v_embeds) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
w_tr_matmul = self.layer_heads_relation_attention_v_update( |
|
self.heads_mask, |
|
w_tr_matmul, |
|
bsz, |
|
tgt_len, |
|
) |
|
""" |
|
w_tr_matmul = self.v_rel_scale * w_tr_matmul |
|
|
|
|
|
|
|
|
|
w_tr_matmul = w_tr_matmul.permute(0, 2, 1, 3) |
|
w_tr_matmul = w_tr_matmul.reshape(bsz * self.num_heads, tgt_len, self.head_dim) |
|
|
|
|
|
|
|
attn_output += w_tr_matmul |
|
|
|
|
|
|
|
|
|
|
|
|
|
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): |
|
raise ValueError( |
|
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}" |
|
) |
|
|
|
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) |
|
attn_output = attn_output.transpose(1, 2) |
|
|
|
|
|
|
|
attn_output = attn_output.reshape(bsz, tgt_len, embed_dim) |
|
|
|
attn_output = self.out_proj(attn_output) |
|
|
|
return attn_output, attn_weights_reshaped, past_key_value |
|
|
|
def layer_heads_relation_attention_update(self, |
|
layer_head_mask, |
|
data, |
|
): |
|
if layer_head_mask is not None: |
|
if layer_head_mask.size() != (self.num_heads,): |
|
raise ValueError( |
|
f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}" |
|
) |
|
|
|
masked_weights = layer_head_mask.view(self.num_heads, 1, 1) * data |
|
return masked_weights |
|
return data |
|
|
|
def layer_heads_relation_attention_v_update(self, |
|
layer_head_mask, |
|
data, |
|
bsz, |
|
tgt_len, |
|
): |
|
if layer_head_mask is not None: |
|
if layer_head_mask.size() != (self.num_heads,): |
|
raise ValueError( |
|
f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}" |
|
) |
|
|
|
|
|
|
|
|
|
masked_weights = layer_head_mask.view(self.num_heads, 1, 1) * data.view(bsz, self.num_heads, tgt_len, self.head_dim) |
|
return masked_weights.view(bsz, tgt_len, self.num_heads, self.head_dim) |
|
return data |