# Revision a0a9adf (1,018 bytes).
from torch import nn
from transformers import GPT2LMHeadModel as GPT2LMHeadModelBase
from transformers.models.gpt2.modeling_gpt2 import GPT2Block as GPT2BlockBase
class GPT2Block(GPT2BlockBase):
    """GPT-2 transformer block with a custom forward pass.

    Reuses the sub-modules created by the base block (``ln_1``, ``attn``,
    ``ln_2``, ``mlp``) but replaces the forward computation. Cross-attention
    is not implemented.
    """

    def forward(self, x, layer_past=None,
                attention_mask=None, head_mask=None, use_cache=False,
                encoder_hidden_states=None, encoder_attention_mask=None,
                output_attentions=None):
        """Apply self-attention and the MLP to ``x``.

        Args:
            x: Block input hidden states.
            layer_past: Cached attention state, forwarded to ``self.attn``.
            attention_mask: Forwarded to ``self.attn``.
            head_mask: Forwarded to ``self.attn``.
            use_cache: Forwarded to ``self.attn``.
            encoder_hidden_states: Must be ``None``; cross-attention is not
                implemented by this block.
            encoder_attention_mask: Unused (only relevant to cross-attention).
            output_attentions: Forwarded to ``self.attn`` so that, when
                requested, its extra outputs are included in the result.

        Returns:
            A tuple ``(hidden_states,) + attn_outputs[1:]``: the block output
            followed by any extra items returned by ``self.attn``.

        Raises:
            NotImplementedError: If ``encoder_hidden_states`` is not ``None``.
        """
        if encoder_hidden_states is not None:
            # Silently dropping cross-attention inputs would produce wrong
            # outputs without any signal; fail loudly instead.
            raise NotImplementedError(
                "GPT2Block.forward does not implement cross-attention; "
                "encoder_hidden_states must be None")

        # NOTE(review): ln_1's output replaces x, so the residual additions
        # below start from the *normalized* input rather than the raw input.
        # This appears to differ from the stock pre-LN GPT-2 block — confirm it
        # is intentional, especially when loading pretrained GPT-2 weights.
        x = self.ln_1(x)
        attn_outputs = self.attn(
            x,
            layer_past=layer_past,
            attention_mask=attention_mask,
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )
        x = x + attn_outputs[0]
        x = x + self.mlp(self.ln_2(x))
        # Keep whatever extra outputs attention produced (e.g. cache entries
        # or attention weights) so the caller's tuple layout is preserved.
        return (x,) + attn_outputs[1:]
class GPT2LMHeadModel(GPT2LMHeadModelBase):
    """GPT-2 language model whose transformer layers use the custom ``GPT2Block``."""

    def __init__(self, config):
        """Build the base model, then replace every layer with the custom block.

        Args:
            config: Model configuration; ``config.n_layer`` sets how many
                blocks are created.
        """
        super().__init__(config)
        # Swap the layer stack built by the base constructor for custom
        # blocks, one per layer index.
        self.transformer.h = nn.ModuleList(
            [GPT2Block(config, layer_idx) for layer_idx in range(config.n_layer)]
        )
        # These blocks are created after the base constructor finished, so
        # they would otherwise keep plain module-default initialization; apply
        # the model's own init scheme to them as well. (Pretrained weights
        # loaded afterwards overwrite this either way.)
        self.transformer.h.apply(self._init_weights)