import torch.nn as nn

from tencentpretrain.layers.layer_norm import *
from tencentpretrain.layers.position_ffn import PositionwiseFeedForward, GatedFeedForward
from tencentpretrain.layers.multi_headed_attn import MultiHeadedAttention
from tencentpretrain.layers.relative_position_embedding import RelativePositionEmbedding


class TransformerLayer(nn.Module):
    """
    Transformer layer mainly consists of two parts:
    multi-headed self-attention and feed forward layer.
    """

    def __init__(self, args):
        super(TransformerLayer, self).__init__()

        self.layernorm_positioning = args.layernorm_positioning

        if hasattr(args, "attention_head_size"):
            attention_head_size = args.attention_head_size
        else:
            attention_head_size = args.hidden_size // args.heads_num

        # Bias and attention scaling can be switched off via the remove_* flags.
        has_bias = bool(1 - args.remove_transformer_bias)
        with_scale = bool(1 - args.remove_attention_scale)

        # Multi-headed self-attention.
        self.self_attn = MultiHeadedAttention(
            args.hidden_size, args.heads_num, attention_head_size, args.dropout, has_bias=has_bias, with_scale=with_scale
        )
        self.dropout_1 = nn.Dropout(args.dropout)

        # Feed forward layer.
        if args.feed_forward == "gated":
            self.feed_forward = GatedFeedForward(
                args.hidden_size, args.feedforward_size, args.hidden_act, has_bias
            )
        else:
            self.feed_forward = PositionwiseFeedForward(
                args.hidden_size, args.feedforward_size, args.hidden_act, has_bias
            )
        self.dropout_2 = nn.Dropout(args.dropout)

        # Layer Normalization (T5, RMSNorm, or standard LayerNorm).
        if args.layernorm == "t5":
            self.layer_norm_1 = T5LayerNorm(args.hidden_size)
            self.layer_norm_2 = T5LayerNorm(args.hidden_size)
        elif args.layernorm == "rms":
            self.layer_norm_1 = RMSNorm(args.hidden_size)
            self.layer_norm_2 = RMSNorm(args.hidden_size)
        else:
            self.layer_norm_1 = LayerNorm(args.hidden_size)
            self.layer_norm_2 = LayerNorm(args.hidden_size)

    def forward(self, hidden, mask, position_bias=None, has_residual_attention=False, prev_attn=None, freqs_cis=None):
        """
        Args:
            hidden: [batch_size x seq_length x emb_size]
            mask: [batch_size x 1 x seq_length x seq_length]
            position_bias: [1 x heads_num x seq_length x seq_length]
            has_residual_attention: whether to add the previous layer's attention scores (residual attention)
            prev_attn: attention scores from the previous layer, forwarded to self-attention
            freqs_cis: precomputed rotary position embedding frequencies, forwarded to self-attention
        Returns:
            output: [batch_size x seq_length x hidden_size]
            prev_attn_out: attention scores returned by self-attention, to be passed to the next layer
        """

        if self.layernorm_positioning == "post":
            inter, prev_attn_out = self.self_attn(hidden, hidden, hidden, mask, position_bias, has_residual_attention, prev_attn, freqs_cis)
            inter = self.dropout_1(inter)
            inter = self.layer_norm_1(inter + hidden)
            output = self.dropout_2(self.feed_forward(inter))
            output = self.layer_norm_2(output + inter)
        else:
            inter = self.layer_norm_1(hidden)
            inter, prev_attn_out = self.self_attn(inter, inter, inter, mask, position_bias, has_residual_attention, prev_attn, freqs_cis)
            inter = self.dropout_1(inter)
            hidden = hidden + inter
            output = self.layer_norm_2(hidden)
            output = self.dropout_2(self.feed_forward(output)) + hidden
        return output, prev_attn_out
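
# Residual wiring implemented by TransformerLayer.forward, summarized for quick
# reference (comments only, not executed):
#
#   post-norm:  h = LN_1(x + Dropout(SelfAttn(x)));  out = LN_2(h + Dropout(FFN(h)))
#   pre-norm:   h = x + Dropout(SelfAttn(LN_1(x)));  out = h + Dropout(FFN(LN_2(h)))
#
# i.e. post-norm normalizes after each residual add (original Transformer),
# while pre-norm normalizes the sublayer input and keeps the residual path
# un-normalized.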


class TransformerDecoderLayer(nn.Module):
    """
    Transformer decoder layer mainly consists of three parts:
    multi-headed self-attention, multi-headed context attention
    over the encoder output, and feed forward layer.
    """

    def __init__(self, args):
        super(TransformerDecoderLayer, self).__init__()

        self.layernorm_positioning = args.layernorm_positioning

        if hasattr(args, "attention_head_size"):
            attention_head_size = args.attention_head_size
        else:
            attention_head_size = args.hidden_size // args.heads_num

        has_bias = bool(1 - args.remove_transformer_bias)
        with_scale = bool(1 - args.remove_attention_scale)

        # Multi-headed self-attention.
        self.self_attn = MultiHeadedAttention(
            args.hidden_size, args.heads_num, attention_head_size, args.dropout, has_bias=has_bias, with_scale=with_scale
        )
        self.dropout_1 = nn.Dropout(args.dropout)

        # Multi-headed context-attention.
        self.context_attn = MultiHeadedAttention(
            args.hidden_size, args.heads_num, attention_head_size, args.dropout, has_bias=has_bias, with_scale=with_scale
        )
        self.dropout_2 = nn.Dropout(args.dropout)

        # Feed forward layer.
        if args.feed_forward == "gated":
            self.feed_forward = GatedFeedForward(
                args.hidden_size, args.feedforward_size, args.hidden_act, has_bias
            )
        else:
            self.feed_forward = PositionwiseFeedForward(
                args.hidden_size, args.feedforward_size, args.hidden_act, has_bias
            )
        self.dropout_3 = nn.Dropout(args.dropout)

        # Layer Normalization
        if args.layernorm == "t5":
            self.layer_norm_1 = T5LayerNorm(args.hidden_size)
            self.layer_norm_2 = T5LayerNorm(args.hidden_size)
            self.layer_norm_3 = T5LayerNorm(args.hidden_size)
        else:
            self.layer_norm_1 = LayerNorm(args.hidden_size)
            self.layer_norm_2 = LayerNorm(args.hidden_size)
            self.layer_norm_3 = LayerNorm(args.hidden_size)

    def forward(self, hidden, encoder_hidden, mask_decoder, mask_encoder, self_position_bias=None, context_position_bias=None):
        """
        Args:
            hidden: [batch_size x seq_length x emb_size]
            encoder_hidden: [batch_size x seq_length x emb_size]
            mask_decoder: [batch_size x 1 x seq_length x seq_length]
            mask_encoder: [batch_size x 1 x seq_length x seq_length]
            self_position_bias: [1 x heads_num x seq_length x seq_length]
            context_position_bias: [1 x heads_num x seq_length x seq_length]
        Returns:
            output: [batch_size x seq_length x hidden_size]
        """

        if self.layernorm_positioning == "post":
            query, _ = self.self_attn(hidden, hidden, hidden, mask_decoder, self_position_bias)
            query = self.dropout_1(query)
            query_norm = self.layer_norm_1(query + hidden)
            mid, _ = self.context_attn(encoder_hidden, encoder_hidden, query_norm, mask_encoder, context_position_bias)
            mid = self.dropout_2(mid)
            mid_norm = self.layer_norm_2(mid + query_norm)
            output = self.dropout_3(self.feed_forward(mid_norm))
            output = self.layer_norm_3(output + mid_norm)
        else:
            hidden_norm = self.layer_norm_1(hidden)
            query, _ = self.self_attn(hidden_norm, hidden_norm, hidden_norm, mask_decoder, self_position_bias)
            query = self.dropout_1(query)
            query = query + hidden
            query_norm = self.layer_norm_2(query)
            mid, _ = self.context_attn(encoder_hidden, encoder_hidden, query_norm, mask_encoder, context_position_bias)
            mid = self.dropout_2(mid)
            mid = mid + query
            mid_norm = self.layer_norm_3(mid)
            output = self.dropout_3(self.feed_forward(mid_norm)) + mid
        return output
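

# ---------------------------------------------------------------------------
# Minimal smoke test (illustrative sketch, not part of the original module).
# It builds a hypothetical `args` namespace containing only the fields read by
# the constructors above; real TencentPretrain configs carry more options, and
# the chosen values ("gelu", "dense", "normal") are assumptions about what the
# feed-forward and layer-norm helpers accept. The masks are plain zeros, i.e.
# every position visible, assuming the additive attention-mask convention.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from argparse import Namespace

    import torch

    args = Namespace(
        hidden_size=64,
        heads_num=4,
        feedforward_size=256,
        dropout=0.1,
        hidden_act="gelu",            # assumed to be a supported activation name
        feed_forward="dense",         # any value other than "gated" selects PositionwiseFeedForward
        layernorm="normal",           # any value other than "t5"/"rms" selects LayerNorm
        layernorm_positioning="pre",
        remove_transformer_bias=False,
        remove_attention_scale=False,
    )

    batch_size, seq_length = 2, 8
    hidden = torch.randn(batch_size, seq_length, args.hidden_size)
    mask = torch.zeros(batch_size, 1, seq_length, seq_length)

    # Encoder-style layer: returns the hidden states plus attention scores.
    encoder_layer = TransformerLayer(args)
    output, _ = encoder_layer(hidden, mask)
    print("TransformerLayer output:", tuple(output.shape))

    # Decoder layer: also attends over (here random) encoder hidden states.
    decoder_layer = TransformerDecoderLayer(args)
    encoder_hidden = torch.randn(batch_size, seq_length, args.hidden_size)
    output = decoder_layer(hidden, encoder_hidden, mask, mask)
    print("TransformerDecoderLayer output:", tuple(output.shape))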