"""
OpenAI's GPT-2 ported to PyTorch.
"""
import math

import attr
import torch
from torch import nn
from torch.nn import functional as F
import torch.utils.checkpoint
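
# Architecture overview: token and position embeddings, a stack of pre-norm
# residual Blocks (Norm -> Attention and Norm -> MLP, each added back to the
# residual stream), a final Norm, and logits computed against the (tied)
# token embedding matrix.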


@attr.s(auto_attribs=True, frozen=True)
class HParams:
    n_vocab: int
    n_ctx: int
    n_embed: int
    n_hidden: int
    n_head: int
    n_layer: int
    gradient_checkpointing: bool = False


class Model(nn.Module):
    def __init__(self, hparams: HParams):
        super().__init__()
        self.hparams = hparams
        self.wpe = nn.Embedding(hparams.n_ctx, hparams.n_embed)
        nn.init.normal_(self.wpe.weight, std=0.01)
        self.wte = nn.Embedding(hparams.n_vocab, hparams.n_embed)
        nn.init.normal_(self.wte.weight, std=0.02)
        self.blocks = nn.ModuleList(
            [Block(hparams) for _ in range(hparams.n_layer)])
        self.ln_f = Norm(self.hparams.n_hidden)
        # When the hidden size used inside the transformer blocks differs from
        # the embedding size, project in and out of it around the blocks.
        if hparams.n_hidden != hparams.n_embed:
            self.in_proj = Conv1D(hparams.n_embed, hparams.n_hidden)
            self.out_proj = Conv1D(hparams.n_hidden, hparams.n_embed)
        else:
            self.in_proj = self.out_proj = None

    def forward(self, x, past=None):
        # Embedding
        past_length = 0 if past is None else past.shape[-2]
        batch_size, n_ctx = x.shape
        position = position_for(batch_size, n_ctx, past_length, x.device)
        h = self.wte(x) + self.wpe(position)
        assert h.shape == (batch_size, n_ctx, self.hparams.n_embed)
        if self.in_proj is not None:
            h = self.in_proj(h)
        # Transformer
        presents = []
        for i, block in enumerate(self.blocks):
            if self.hparams.gradient_checkpointing:
                h, present = torch.utils.checkpoint.checkpoint(
                    block, h, past[:, i] if past is not None else None)
            else:
                h, present = block(
                    h, past=past[:, i] if past is not None else None)
            presents.append(present)
        h = self.ln_f(h)
        if self.out_proj is not None:
            h = self.out_proj(h)
        # Output logits, with the output projection tied to the token embedding weights
        h_flat = h.reshape([batch_size * n_ctx, self.hparams.n_embed])
        logits = torch.matmul(h_flat, self.wte.weight.t())
        logits = logits.reshape([batch_size, n_ctx, self.hparams.n_vocab])
        return {
            'presents': torch.stack(tuple(presents), dim=1),
            'logits': logits,
        }


class Block(nn.Module):
    def __init__(self, hparams: HParams):
        super().__init__()
        self.ln_1 = Norm(hparams.n_hidden)
        self.ln_2 = Norm(hparams.n_hidden)
        self.mlp = MLP(hparams.n_hidden, hparams.n_hidden * 4)
        self.attn = Attention(hparams)

    def forward(self, x, past):
        a, present = self.attn(self.ln_1(x), past=past)
        x = x + a
        m = self.mlp(self.ln_2(x))
        x = x + m
        return x, present


class Norm(nn.Module):
    """ Normalize to mean = 0, std = 1, then do a diagonal affine transform.
    """
    def __init__(self, n_features, *, dim=-1, epsilon=1e-5):
        super().__init__()
        self.n_features = n_features
        self.dim = dim
        self.epsilon = epsilon
        self.g = nn.Parameter(torch.ones(n_features))
        self.b = nn.Parameter(torch.zeros(n_features))

    def forward(self, x):
        assert x.shape[-1] == self.n_features
        u = torch.mean(x, dim=self.dim, keepdim=True)
        xmu = x - u
        s = torch.mean(xmu * xmu, dim=self.dim, keepdim=True)
        return xmu * torch.rsqrt(s + self.epsilon) * self.g + self.b


class MLP(nn.Module):
    def __init__(self, n_features, n_hidden):
        super().__init__()
        self.c_fc = Conv1D(n_features, n_hidden)
        self.c_proj = Conv1D(n_hidden, n_features)

    def forward(self, x):
        x = gelu(self.c_fc(x))
        x = self.c_proj(x)
        return x


class Attention(nn.Module):
    def __init__(self, hparams: HParams):
        super().__init__()
        assert hparams.n_hidden % hparams.n_head == 0
        self.hparams = hparams
        # One projection produces the concatenated query, key and value.
        self.c_attn = Conv1D(hparams.n_hidden, hparams.n_hidden * 3)
        self.c_proj = Conv1D(hparams.n_hidden, hparams.n_hidden)

    def forward(self, x, past):
        assert len(x.shape) == 3  # [batch, sequence, features]
        assert x.shape[-1] == self.hparams.n_hidden
        if past is not None:
            # Should be [batch, 2, heads, sequence, features], where 2 is
            # [k, v] and features is the per-head size n_hidden // n_head.
            assert len(past.shape) == 5
            assert past.shape[-1] == self.hparams.n_hidden // self.hparams.n_head
        c = self.c_attn(x)
        q, k, v = map(self.split_heads, torch.split(c, x.shape[-1], dim=2))
        present = torch.stack([k, v], dim=1)
        if past is not None:
            pk, pv = past[:, 0], past[:, 1]
            k = torch.cat([pk, k], dim=-2)
            v = torch.cat([pv, v], dim=-2)
        a = self.multihead_attn(q, k, v)
        a = self.merge_heads(a)
        a = self.c_proj(a)
        return a, present

    def split_heads(self, x):
        """ From [batch, sequence, features] to
        [batch, heads, sequence, features].
        """
        return self.split_states(x, self.hparams.n_head).permute(0, 2, 1, 3)

    @staticmethod
    def split_states(x, n):
        """ Reshape the last dimension of x into [n, x.shape[-1]/n].
        """
        *start, m = x.shape
        return x.reshape(start + [n, m // n])

    def merge_heads(self, x):
        """ Reverse of split_heads.
        """
        return self.merge_states(x.permute(0, 2, 1, 3))

    @staticmethod
    def merge_states(x):
        """ Smash the last two dimensions of x into a single dimension.
        """
        *start, a, b = x.shape
        return x.reshape(start + [a * b])

    def mask_attn_weights(self, w):
        # w has shape [batch, heads, dst_sequence, src_sequence],
        # where information flows from src to dst.
        _, _, nd, ns = w.shape
        b = self.attention_mask(nd, ns, dtype=w.dtype, device=w.device)
        b = b.reshape((1, 1, nd, ns))
        # Keep allowed positions; push masked positions to a large negative
        # value so they get (near) zero weight after the softmax.
        w = w * b - 1e4 * (1 - b)
        return w

    @staticmethod
    def attention_mask(nd, ns, *, dtype, device=None):
        """ 1's in the lower triangle, counting from the lower right corner.
        Same as tf.matrix_band_part(tf.ones([nd, ns]), -1, ns-nd),
        but doesn't produce garbage on TPUs.
        """
        i = torch.arange(0, nd).unsqueeze(1)
        j = torch.arange(ns)
        return (i >= j - ns + nd).to(dtype=dtype, device=device)

    def multihead_attn(self, q, k, v):
        # q, k, v have shape [batch, heads, sequence, features]
        w = torch.matmul(q, k.permute(0, 1, 3, 2))
        w = w / math.sqrt(v.shape[-1])
        w = self.mask_attn_weights(w)
        w = F.softmax(w, dim=-1)
        a = torch.matmul(w, v)
        return a


class Conv1D(nn.Linear):
    """ A Linear layer with GPT-2-style initialization: normal weights with
    std 0.02 and zero bias. The name mirrors the conv1d op used in the
    original TF implementation.
    """
    def reset_parameters(self):
        nn.init.normal_(self.weight, std=0.02)
        nn.init.zeros_(self.bias)


def gelu(x, c=math.sqrt(2 / math.pi)):
    # Tanh approximation of the Gaussian Error Linear Unit, matching the
    # approximation used in the original TF implementation.
    return 0.5 * x * (1 + torch.tanh(c * (x + 0.044715 * torch.pow(x, 3))))


def position_for(batch_size, n_steps, past_length, device=None):
    return (torch.arange(past_length, n_steps + past_length, device=device)
            .unsqueeze(0).repeat(batch_size, 1))
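

# --- Illustrative usage sketch (not part of the original port) -------------
# A minimal smoke test showing how the pieces above fit together. The
# hyperparameter values are arbitrary placeholders, not a released GPT-2
# configuration.
if __name__ == '__main__':
    hparams = HParams(
        n_vocab=100,
        n_ctx=32,
        n_embed=16,
        n_hidden=16,
        n_head=4,
        n_layer=2,
    )
    model = Model(hparams)

    # Full forward pass: 'logits' has shape [batch, sequence, n_vocab];
    # 'presents' caches per-layer keys/values with shape
    # [batch, n_layer, 2, n_head, sequence, n_hidden // n_head].
    tokens = torch.randint(0, hparams.n_vocab, (2, 8))
    out = model(tokens)
    assert out['logits'].shape == (2, 8, hparams.n_vocab)

    # Incremental decoding: feed one new token together with the cached
    # keys/values from the previous call.
    next_token = out['logits'][:, -1:].argmax(dim=-1)
    out_step = model(next_token, past=out['presents'])
    assert out_step['logits'].shape == (2, 1, hparams.n_vocab)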