# RA-BART / custom_bart / encoder_layer.py
# (web file-viewer residue kept as a comment: MrVicente's picture | added demo base code | 6cf191b | raw | history blame | 4.47 kB)
#############################
# Imports
#############################
# Python modules
from typing import Optional, Tuple
# Remote modules
import torch
from torch import nn
from transformers import BartConfig
from transformers.activations import ACT2FN
# Local modules
from .bart_attention import BartCustomAttention
from .bart_mask_attention import BartCustomMaskAttention
from .config import BartCustomConfig
class BartCustomEncoderLayer(nn.Module):
    """Single encoder layer: relation-aware self-attention followed by a feed-forward block.

    Each of the two sub-blocks is wrapped as ``LayerNorm(residual + dropout(sub_block(x)))``,
    i.e. layer normalization is applied *after* the residual addition.

    The self-attention implementation is picked from ``config.is_simple_mask_commonsense``:
    ``BartCustomMaskAttention`` when it is truthy, ``BartCustomAttention`` otherwise.
    """

    def __init__(self, config: BartCustomConfig, heads_mask: Optional[torch.Tensor]):
        """Build the attention module, the feed-forward projections and both layer norms.

        Args:
            config: Model configuration. The fields read here are `d_model`,
                `is_simple_mask_commonsense`, `encoder_attention_heads`, `attention_dropout`,
                `num_relation_kinds`, `use_same_relation_kv_emb` (complex attention only),
                `dropout`, `activation_function`, `activation_dropout` and `encoder_ffn_dim`.
            heads_mask: Forwarded unchanged to the attention module's constructor; its
                meaning is defined by that module.
        """
        super().__init__()
        self.embed_dim = config.d_model

        # Keyword arguments accepted by both attention variants.
        shared_attn_kwargs = dict(
            embed_dim=self.embed_dim,
            num_heads=config.encoder_attention_heads,
            dropout=config.attention_dropout,
            num_relation_kinds=config.num_relation_kinds,
            heads_mask=heads_mask,
        )
        if config.is_simple_mask_commonsense:
            print("Selecting simple (MASK) relation attention")
            self.self_attn = BartCustomMaskAttention(**shared_attn_kwargs)
        else:
            print("Selecting complex relation attention")
            # Only the complex variant takes `use_same_relation_kv_emb`.
            self.self_attn = BartCustomAttention(
                use_same_relation_kv_emb=config.use_same_relation_kv_emb,
                **shared_attn_kwargs,
            )

        # NOTE: attribute names and registration order below are kept exactly as before so
        # that parameter names in existing state dicts continue to match.
        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        self.dropout = config.dropout
        self.activation_fn = ACT2FN[config.activation_function]
        self.activation_dropout = config.activation_dropout
        self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
        self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
        self.final_layer_norm = nn.LayerNorm(self.embed_dim)

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        attention_mask: torch.FloatTensor,
        layer_head_mask: torch.FloatTensor,
        output_attentions: Optional[bool] = False,
        relation_inputs: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.FloatTensor, ...]:
        """Run self-attention and the feed-forward block over `hidden_states`.

        Args:
            hidden_states (`torch.FloatTensor`): input to the layer; the last dimension must be
                `embed_dim`. The original documentation gives the shape as
                `(seq_len, batch, embed_dim)` -- TODO confirm against the attention module.
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very
                large negative values.
            layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of
                size `(encoder_attention_heads,)`.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attention weights produced by the self-attention
                module.
            relation_inputs (`torch.Tensor`, *optional*): forwarded unchanged to the
                self-attention module; its expected layout is defined there.

        Returns:
            `(hidden_states,)` when `output_attentions` is falsy, otherwise
            `(hidden_states, attn_weights)`.
        """
        # --- Self-attention sub-block -------------------------------------------------------
        residual = hidden_states
        # The attention module returns three values; the third one is not used here.
        hidden_states, attn_weights, _ = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            layer_head_mask=layer_head_mask,
            output_attentions=output_attentions,
            relation_inputs=relation_inputs,
        )
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = self.self_attn_layer_norm(residual + hidden_states)

        # --- Feed-forward sub-block ---------------------------------------------------------
        residual = hidden_states
        hidden_states = self.activation_fn(self.fc1(hidden_states))
        hidden_states = nn.functional.dropout(
            hidden_states, p=self.activation_dropout, training=self.training
        )
        hidden_states = self.fc2(hidden_states)
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = self.final_layer_norm(residual + hidden_states)

        # In half precision, pull overflowed activations back inside the finite fp16 range.
        # `isfinite` is False for both +/-inf and NaN, so one scan replaces the former
        # separate `isinf` / `isnan` scans.
        # NOTE(review): clamping bounds infinities; NaN entries are presumably left as-is --
        # confirm against the `torch.clamp` documentation if NaN recovery is expected.
        if hidden_states.dtype == torch.float16 and not torch.isfinite(hidden_states).all():
            clamp_value = torch.finfo(hidden_states.dtype).max - 1000
            hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

        if output_attentions:
            return (hidden_states, attn_weights)
        return (hidden_states,)