Spaces:
Sleeping
Sleeping
""" | |
OpenAI's GPT-2 ported to PyTorch. | |
""" | |
import dataclasses
import math

import attr
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import functional as F
@dataclasses.dataclass
class HParams:
    """Hyperparameters describing the model's sizes and training options.

    The dataclass decorator generates the constructor. Without it, the
    annotated names below would only be type hints, and ``HParams(...)``
    could not accept any arguments.
    """

    n_vocab: int  # rows of the token embedding (wte); also the logits width
    n_ctx: int  # rows of the position embedding (wpe): max past + new steps
    n_embed: int  # embedding width
    n_hidden: int  # transformer width; projections are added if != n_embed
    n_head: int  # attention heads; must divide n_hidden
    n_layer: int  # number of transformer blocks
    # When True, Model.forward runs each block through
    # torch.utils.checkpoint.checkpoint.
    gradient_checkpointing: bool = False
class Model(nn.Module):
    """GPT-2 language model.

    Token and position embeddings feed a stack of ``Block`` layers followed
    by a final ``Norm``. The output logits reuse the token-embedding matrix
    (``wte.weight``) as the output projection. When ``hparams.n_hidden``
    differs from ``hparams.n_embed``, ``in_proj`` and ``out_proj`` map
    between the two widths around the transformer stack.
    """

    def __init__(self, hparams: HParams):
        super().__init__()
        self.hparams = hparams
        self.wpe = nn.Embedding(hparams.n_ctx, hparams.n_embed)
        nn.init.normal_(self.wpe.weight, std=0.01)
        self.wte = nn.Embedding(hparams.n_vocab, hparams.n_embed)
        nn.init.normal_(self.wte.weight, std=0.02)
        self.blocks = nn.ModuleList(
            [Block(hparams) for _ in range(hparams.n_layer)])
        self.ln_f = Norm(self.hparams.n_hidden)
        if hparams.n_hidden != hparams.n_embed:
            self.in_proj = Conv1D(hparams.n_embed, hparams.n_hidden)
            self.out_proj = Conv1D(hparams.n_hidden, hparams.n_embed)
        else:
            # Stored as plain attributes (not submodules); forward() tests
            # them against None.
            self.in_proj = self.out_proj = None

    def forward(self, x, past=None):
        """Run the model on a batch of token ids.

        Args:
            x: integer token ids of shape ``[batch, sequence]``.
            past: optional cache, i.e. the ``'presents'`` tensor from an
                earlier call. Its layer axis is dim 1, and ``past[:, i]``
                is handed to block ``i``. Its second-to-last axis is the
                number of cached steps.

        Returns:
            A dict with two entries:
            ``'presents'``: the per-layer key/value tensors produced by
            each block for ``x``, stacked on dim 1.
            ``'logits'``: a tensor of shape ``[batch, sequence, n_vocab]``.
        """
        # Embedding. Positions continue after the cached prefix. wpe has
        # hparams.n_ctx rows, so past_length + sequence must fit within it.
        past_length = 0 if past is None else past.shape[-2]
        batch_size, n_ctx = x.shape
        position = position_for(batch_size, n_ctx, past_length, x.device)
        h = self.wte(x) + self.wpe(position)
        assert h.shape == (batch_size, n_ctx, self.hparams.n_embed)
        if self.in_proj is not None:
            h = self.in_proj(h)

        # Transformer
        presents = []
        for i, block in enumerate(self.blocks):
            layer_past = None if past is None else past[:, i]
            if self.hparams.gradient_checkpointing:
                # Recompute this block's activations during backward instead
                # of storing them.
                h, present = torch.utils.checkpoint.checkpoint(
                    block, h, layer_past)
            else:
                h, present = block(h, past=layer_past)
            presents.append(present)
        h = self.ln_f(h)
        if self.out_proj is not None:
            h = self.out_proj(h)

        # Output logits: project onto the (tied) token-embedding matrix.
        h_flat = h.reshape([batch_size * n_ctx, self.hparams.n_embed])
        logits = torch.matmul(h_flat, self.wte.weight.t())
        logits = logits.reshape([batch_size, n_ctx, self.hparams.n_vocab])
        return {
            'presents': torch.stack(tuple(presents), dim=1),
            'logits': logits,
        }
class Block(nn.Module):
    """A single pre-norm transformer layer.

    Self-attention and then the MLP are each applied to a normalised copy of
    the running activations, and their outputs are added back residually.
    """

    def __init__(self, hparams: HParams):
        super().__init__()
        width = hparams.n_hidden
        # Registration order (ln_1, ln_2, mlp, attn) matches the original
        # state_dict / parameter ordering.
        self.ln_1 = Norm(width)
        self.ln_2 = Norm(width)
        self.mlp = MLP(width, width * 4)
        self.attn = Attention(hparams)

    def forward(self, x, past):
        """Return ``(output, present)`` where ``present`` comes from ``attn``."""
        attn_out, present = self.attn(self.ln_1(x), past=past)
        residual = x + attn_out
        output = residual + self.mlp(self.ln_2(residual))
        return output, present
class Norm(nn.Module):
    """Layer normalisation with a learned per-feature gain ``g`` and bias ``b``.

    Centers ``x`` to zero mean and scales it to unit (population) variance
    along ``dim``, then applies ``* g + b``.

    Args:
        n_features: size of the last axis of the input; also the length of
            ``g`` and ``b``.
        dim: axis over which mean and variance are computed.
        epsilon: added to the variance before the inverse square root.
    """

    def __init__(self, n_features, *, dim=-1, epsilon=1e-5):
        super().__init__()
        self.n_features = n_features
        self.dim = dim
        self.epsilon = epsilon
        self.g = nn.Parameter(torch.ones(n_features))
        self.b = nn.Parameter(torch.zeros(n_features))

    def forward(self, x):
        assert x.shape[-1] == self.n_features
        axis = self.dim
        centered = x - x.mean(dim=axis, keepdim=True)
        variance = (centered * centered).mean(dim=axis, keepdim=True)
        inv_std = torch.rsqrt(variance + self.epsilon)
        return centered * inv_std * self.g + self.b
class MLP(nn.Module):
    """Position-wise feed-forward network.

    ``c_fc`` widens ``n_features`` to ``n_hidden``, ``gelu`` is applied, and
    ``c_proj`` maps back to ``n_features``.
    """

    def __init__(self, n_features, n_hidden):
        super().__init__()
        self.c_fc = Conv1D(n_features, n_hidden)
        self.c_proj = Conv1D(n_hidden, n_features)

    def forward(self, x):
        return self.c_proj(gelu(self.c_fc(x)))
class Attention(nn.Module):
    """Causal multi-head self-attention with an optional key/value cache.

    ``forward`` returns ``(output, present)``. ``present`` holds the keys and
    values computed for *this* call, with shape
    ``[batch, 2, heads, sequence, n_hidden // n_head]`` (index 0 = k,
    1 = v). ``past`` uses the same layout. When it is given, its keys and
    values are prepended along the sequence axis before attention.
    """

    def __init__(self, hparams: "HParams"):
        super().__init__()
        assert hparams.n_hidden % hparams.n_head == 0
        self.hparams = hparams
        self.c_attn = Conv1D(hparams.n_hidden, hparams.n_hidden * 3)
        self.c_proj = Conv1D(hparams.n_hidden, hparams.n_hidden)

    def forward(self, x, past):
        assert len(x.shape) == 3  # [batch, sequence, features]
        assert x.shape[-1] == self.hparams.n_hidden
        if past is not None:
            # Should be [batch, 2, heads, sequence, features_per_head],
            # where 2 is [k, v].
            assert len(past.shape) == 5
            # The cache stores per-head keys/values (see split_heads), so its
            # last axis is the head width, not n_hidden.
            assert past.shape[-1] == self.hparams.n_hidden // self.hparams.n_head
        c = self.c_attn(x)
        q, k, v = map(self.split_heads, torch.split(c, x.shape[-1], dim=2))
        # present holds only this call's keys/values (before the past is
        # prepended).
        present = torch.stack([k, v], dim=1)
        if past is not None:
            pk, pv = past[:, 0], past[:, 1]
            k = torch.cat([pk, k], dim=-2)
            v = torch.cat([pv, v], dim=-2)
        a = self.multihead_attn(q, k, v)
        a = self.merge_heads(a)
        a = self.c_proj(a)
        return a, present

    def split_heads(self, x):
        """ From [batch, sequence, features] to
        [batch, heads, sequence, features].
        """
        return self.split_states(x, self.hparams.n_head).permute(0, 2, 1, 3)

    @staticmethod
    def split_states(x, n):
        """ Reshape the last dimension of x into [n, x.shape[-1]/n].
        """
        *start, m = x.shape
        return x.reshape(start + [n, m // n])

    def merge_heads(self, x):
        """ Reverse of split_heads.
        """
        return self.merge_states(x.permute(0, 2, 1, 3))

    @staticmethod
    def merge_states(x):
        """ Smash the last two dimensions of x into a single dimension.
        """
        *start, a, b = x.shape
        return x.reshape(start + [a * b])

    def mask_attn_weights(self, w):
        """Replace disallowed (future) attention logits with -1e4."""
        # w has shape [batch, heads, dst_sequence, src_sequence],
        # where information flows from src to dst.
        _, _, nd, ns = w.shape
        b = self.attention_mask(nd, ns, dtype=w.dtype, device=w.device)
        b = b.reshape((1, 1, nd, ns))
        w = w * b - 1e4 * (1 - b)
        return w

    @staticmethod
    def attention_mask(nd, ns, *, dtype, device=None):
        """ 1's in the lower triangle, counting from the lower right corner.
        Same as tf.matrix_band_part(tf.ones([nd, ns]), -1, ns-nd),
        but doesn't produce garbage on TPUs.
        """
        # Build the index grids directly on the target device to avoid a
        # host-to-device copy.
        i = torch.arange(0, nd, device=device).unsqueeze(1)
        j = torch.arange(ns, device=device)
        return (i >= j - ns + nd).to(dtype=dtype)

    def multihead_attn(self, q, k, v):
        """Scaled dot-product attention with the causal mask applied."""
        # q, k, v have shape [batch, heads, sequence, features]
        w = torch.matmul(q, k.permute(0, 1, 3, 2))
        w = w / math.sqrt(v.shape[-1])
        w = self.mask_attn_weights(w)
        w = F.softmax(w, dim=-1)
        a = torch.matmul(w, v)
        return a
class Conv1D(nn.Linear):
    """``nn.Linear`` with GPT-2-style initialisation.

    The weight (shape ``[out_features, in_features]``) is drawn from a normal
    distribution with std 0.02. The bias, when present, starts at zero.
    """

    def reset_parameters(self):
        nn.init.normal_(self.weight, std=0.02)
        # nn.Linear(..., bias=False) registers bias as None, and zeros_(None)
        # would raise.
        if self.bias is not None:
            nn.init.zeros_(self.bias)
def gelu(x, c=math.sqrt(2 / math.pi)):
    """Tanh approximation of the Gaussian Error Linear Unit.

    Computes ``0.5 * x * (1 + tanh(c * (x + 0.044715 * x**3)))``, where
    ``c`` defaults to ``sqrt(2 / pi)``.
    """
    cubic_term = 0.044715 * x.pow(3)
    squashed = torch.tanh(c * (x + cubic_term))
    return 0.5 * x * (1 + squashed)
def position_for(batch_size, n_steps, past_length, device=None):
    """Return position ids ``past_length .. past_length + n_steps - 1``.

    The result has shape ``[batch_size, n_steps]``, and every row holds the
    same ``torch.arange``, so positions begin right after ``past_length``
    already-processed steps.
    """
    stop = past_length + n_steps
    steps = torch.arange(past_length, stop, device=device)
    # repeat() with more sizes than tensor dims prepends the batch axis,
    # producing a materialised (non-view) [batch_size, n_steps] tensor.
    return steps.repeat(batch_size, 1)