import torch
import torch.nn as nn
from tencentpretrain.utils.rope import precompute_freqs_cis
from tencentpretrain.layers.transformer import TransformerLayer
from tencentpretrain.layers.layer_norm import *
from tencentpretrain.layers.relative_position_embedding import RelativePositionEmbedding
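# Note: the wildcard import above is what supplies the LayerNorm, T5LayerNorm
# and RMSNorm classes referenced below.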


class TransformerEncoder(nn.Module):
    """
    BERT-style encoder that stacks transformer layers (typically 12 or 24)
    to extract contextual features from token embeddings.
    """
    def __init__(self, args):
        super(TransformerEncoder, self).__init__()
        self.mask = args.mask
        self.layers_num = args.layers_num
        self.parameter_sharing = args.parameter_sharing
        self.factorized_embedding_parameterization = args.factorized_embedding_parameterization
        self.layernorm_positioning = args.layernorm_positioning
        self.relative_position_embedding = args.relative_position_embedding
        self.rotary_position_embedding = args.rotary_position_embedding
        self.has_residual_attention = args.has_residual_attention

        if "deepspeed_checkpoint_activations" in args:
            self.deepspeed_checkpoint_activations = args.deepspeed_checkpoint_activations
            self.deepspeed_checkpoint_layers_num = args.deepspeed_checkpoint_layers_num
        else:
            self.deepspeed_checkpoint_activations = False

        # Note: has_bias is computed but not referenced elsewhere in this module.
        has_bias = bool(1 - args.remove_transformer_bias)

        # ALBERT-style factorization: project a small embedding up to hidden_size.
        if self.factorized_embedding_parameterization:
            self.linear = nn.Linear(args.emb_size, args.hidden_size)

        # With parameter sharing (ALBERT-style), one layer is reused layers_num times.
        if self.parameter_sharing:
            self.transformer = TransformerLayer(args)
        else:
            self.transformer = nn.ModuleList(
                [TransformerLayer(args) for _ in range(self.layers_num)]
            )

        # Pre-LN encoders need a final normalization after the last layer.
        if self.layernorm_positioning == "pre":
            if args.layernorm == "t5":
                self.layer_norm = T5LayerNorm(args.hidden_size)
            elif args.layernorm == "rms":
                self.layer_norm = RMSNorm(args.hidden_size)
            else:
                self.layer_norm = LayerNorm(args.hidden_size)

        if self.relative_position_embedding:
            self.relative_pos_emb = RelativePositionEmbedding(
                bidirectional=True, heads_num=args.heads_num,
                num_buckets=args.relative_attention_buckets_num
            )
        elif self.rotary_position_embedding:
            # Precompute rotary frequencies per head dim, for twice max_seq_length.
            self.freqs_cis = precompute_freqs_cis(args.hidden_size // args.heads_num, args.max_seq_length * 2)

    def forward(self, emb, seg):
        """
        Args:
            emb: [batch_size x seq_length x emb_size]
            seg: [batch_size x seq_length]
        Returns:
            hidden: [batch_size x seq_length x hidden_size]
        """
        if self.factorized_embedding_parameterization:
            emb = self.linear(emb)

        batch_size, seq_length, _ = emb.size()
        # Generate the attention mask from the segment indicators.
        # mask: [batch_size x 1 x seq_length x seq_length]; 0 keeps a position,
        # -10000 effectively removes it from the attention softmax.
        if self.mask == "fully_visible":
            # Bidirectional attention: non-pad tokens (seg > 0) see each other.
            mask = (seg > 0). \
                unsqueeze(1). \
                repeat(1, seq_length, 1). \
                unsqueeze(1)
            mask = mask.float()
            mask = (1.0 - mask) * -10000.0
        elif self.mask == "causal":
            # Autoregressive attention: position i sees positions <= i.
            mask = torch.ones(seq_length, seq_length, device=emb.device)
            mask = torch.tril(mask)
            mask = (1.0 - mask) * -10000.0
            mask = mask.repeat(batch_size, 1, 1, 1)
        else:
            # seq2seq (prefix-LM) attention: segment-1 tokens are fully visible,
            # segment-2 tokens are only causally visible, padding is masked out.
            mask_a = (seg == 1). \
                unsqueeze(1). \
                repeat(1, seq_length, 1). \
                unsqueeze(1).float()
            mask_b = (seg > 0). \
                unsqueeze(1). \
                repeat(1, seq_length, 1). \
                unsqueeze(1).float()
            mask_tril = torch.ones(seq_length, seq_length, device=emb.device)
            mask_tril = torch.tril(mask_tril)
            mask_tril = mask_tril.repeat(batch_size, 1, 1, 1)

            mask = (mask_a + mask_b + mask_tril >= 2).float()
            mask = (1.0 - mask) * -10000.0
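
        # Worked example (illustrative values, not executed here; a runnable
        # version appears at the bottom of this file). With seg = [[1, 1, 2, 0]]:
        #   fully_visible - every row sees columns 0-2 (non-pad); pad column 3
        #                   gets -10000.
        #   causal        - lower-triangular: position i sees positions <= i.
        #   seq2seq       - segment-1 columns (0-1) score mask_a + mask_b = 2, so
        #                   they are visible from every row; segment-2 column 2
        #                   needs the tril term (mask_b + mask_tril = 2), so it is
        #                   visible only causally; the pad column never reaches 2.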

        hidden = emb

        # T5-style relative position bias, shared across all layers.
        if self.relative_position_embedding:
            position_bias = self.relative_pos_emb(hidden, hidden)
        else:
            position_bias = None

        # Rotary embeddings: slice the precomputed frequencies to this length.
        if self.rotary_position_embedding:
            freqs_cis = self.freqs_cis[:seq_length].to(hidden.device)
        else:
            freqs_cis = None

        prev_attn = None

        if self.deepspeed_checkpoint_activations:
            # Activation checkpointing: recompute activations during backward in
            # chunks of deepspeed_checkpoint_layers_num layers to save memory.
            from deepspeed import checkpointing

            def custom(start, end):
                def custom_forward(*inputs):
                    x_, y_, position_bias_, freqs_cis_ = inputs
                    for index in range(start, end):
                        if self.parameter_sharing:
                            x_, y_ = self.transformer(x_, mask, position_bias=position_bias_,
                                                      has_residual_attention=self.has_residual_attention,
                                                      prev_attn=y_, freqs_cis=freqs_cis_)
                        else:
                            x_, y_ = self.transformer[index](x_, mask, position_bias=position_bias_,
                                                             has_residual_attention=self.has_residual_attention,
                                                             prev_attn=y_, freqs_cis=freqs_cis_)
                    return x_, y_
                return custom_forward

            l = 0
            while l < self.layers_num:
                hidden, prev_attn = checkpointing.checkpoint(custom(l, l + self.deepspeed_checkpoint_layers_num),
                                                             hidden, prev_attn, position_bias, freqs_cis)
                l += self.deepspeed_checkpoint_layers_num
        else:
            for i in range(self.layers_num):
                if self.parameter_sharing:
                    hidden, prev_attn = self.transformer(hidden, mask, position_bias=position_bias,
                                                         has_residual_attention=self.has_residual_attention,
                                                         prev_attn=prev_attn, freqs_cis=freqs_cis)
                else:
                    hidden, prev_attn = self.transformer[i](hidden, mask, position_bias=position_bias,
                                                            has_residual_attention=self.has_residual_attention,
                                                            prev_attn=prev_attn, freqs_cis=freqs_cis)

        if self.layernorm_positioning == "pre":
            return self.layer_norm(hidden)
        else:
            return hidden
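

# What follows is not part of the original module: a minimal, self-contained
# sketch that reproduces the three mask constructions from forward() on a toy
# batch, so their semantics can be inspected without building a full model.
# The seg values below are illustrative assumptions.
if __name__ == "__main__":
    seg = torch.tensor([[1, 1, 2, 0]])  # two source tokens, one target token, one pad
    batch_size, seq_length = seg.size()

    # fully_visible: every query row sees all non-pad columns.
    full = (seg > 0).unsqueeze(1).repeat(1, seq_length, 1).unsqueeze(1).float()
    full = (1.0 - full) * -10000.0

    # causal: lower-triangular visibility, independent of seg.
    causal = torch.tril(torch.ones(seq_length, seq_length))
    causal = ((1.0 - causal) * -10000.0).repeat(batch_size, 1, 1, 1)

    # seq2seq: segment 1 fully visible, segment 2 causally visible, pads masked.
    mask_a = (seg == 1).unsqueeze(1).repeat(1, seq_length, 1).unsqueeze(1).float()
    mask_b = (seg > 0).unsqueeze(1).repeat(1, seq_length, 1).unsqueeze(1).float()
    mask_tril = torch.tril(torch.ones(seq_length, seq_length)).repeat(batch_size, 1, 1, 1)
    seq2seq = (1.0 - (mask_a + mask_b + mask_tril >= 2).float()) * -10000.0

    print("fully_visible:\n", full[0, 0])
    print("causal:\n", causal[0, 0])
    print("seq2seq:\n", seq2seq[0, 0])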