keithhon commited on
Commit
5fe0715
1 Parent(s): 60eb46a

Upload dalle/models/stage2/layers.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. dalle/models/stage2/layers.py +140 -0
dalle/models/stage2/layers.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------------------
2
+ # Minimal DALL-E
3
+ # Copyright (c) 2021 KakaoBrain. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------------------
6
+ # Modified from minGPT (https://github.com/karpathy/minGPT)
7
+ # Copyright (c) 2020 Andrej Karpathy. All Rights Reserved.
8
+ # ------------------------------------------------------------------------------------
9
+
10
+ import math
11
+ import torch
12
+ import torch.nn as nn
13
+ from torch.nn import functional as F
14
+
15
+
16
+ class GELU(nn.Module):
17
+ def __init__(self, use_approx=False):
18
+ super().__init__()
19
+ self.use_approx = use_approx
20
+
21
+ def forward(self, x):
22
+ if self.use_approx:
23
+ return x * torch.sigmoid(1.702 * x)
24
+ else:
25
+ return F.gelu(x)
26
+
27
+
28
+ class MultiHeadSelfAttention(nn.Module):
29
+
30
+ def __init__(self,
31
+ ctx_len: int,
32
+ embed_dim: int,
33
+ n_heads: int,
34
+ resid_pdrop: float,
35
+ attn_pdrop: float,
36
+ attn_bias: bool,
37
+ use_mask: bool = True):
38
+ super().__init__()
39
+ assert embed_dim % n_heads == 0
40
+
41
+ # key, query, value projections for all heads
42
+ self.key = nn.Linear(embed_dim, embed_dim, bias=attn_bias)
43
+ self.query = nn.Linear(embed_dim, embed_dim, bias=attn_bias)
44
+ self.value = nn.Linear(embed_dim, embed_dim, bias=attn_bias)
45
+
46
+ # regularization
47
+ self.attn_drop = nn.Dropout(attn_pdrop)
48
+ self.resid_drop = nn.Dropout(resid_pdrop)
49
+
50
+ # output projection
51
+ self.proj = nn.Linear(embed_dim, embed_dim, attn_bias)
52
+
53
+ self.n_heads = n_heads
54
+ self.ctx_len = ctx_len
55
+ self.use_mask = use_mask
56
+ if self.use_mask:
57
+ self.register_buffer("mask", torch.ones(ctx_len, ctx_len), persistent=False)
58
+ self.mask = torch.tril(self.mask).view(1, ctx_len, ctx_len)
59
+
60
+ def forward(self, x, use_cache=False, layer_past=None):
61
+ B, T, C = x.shape
62
+ x = x.transpose(0, 1).contiguous() # (B, T, C) -> (T, B, C)
63
+
64
+ # calculate query, key, values for all heads in batch and move head forward to be the batch dim
65
+ k = self.key(x).view(T, B*self.n_heads, C//self.n_heads).transpose(0, 1) # (B*nh, T, hs)
66
+ q = self.query(x).view(T, B*self.n_heads, C//self.n_heads).transpose(0, 1) # (B*nh, T, hs)
67
+ v = self.value(x).view(T, B*self.n_heads, C//self.n_heads).transpose(0, 1) # (B*nh, T, hs)
68
+
69
+ if use_cache:
70
+ present = torch.stack([k, v])
71
+
72
+ if layer_past is not None:
73
+ past_key, past_value = layer_past
74
+ k = torch.cat([past_key, k], dim=-2)
75
+ v = torch.cat([past_value, v], dim=-2)
76
+
77
+ if use_cache and layer_past is not None:
78
+ # Tensor shape below: (B * nh, 1, hs) X (B * nh, hs, K) -> (B * nh, 1, K)
79
+ att = torch.bmm(q, (k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))))
80
+ att = F.softmax(att, dim=-1)
81
+ att = self.attn_drop(att)
82
+ y = torch.bmm(att, v) # (B*nh, 1, K) X (B*nh, K, hs) -> (B*nh, 1, hs)
83
+ else:
84
+ # Tensor shape below: (B * nh, T, hs) X (B * nh, hs, T) -> (B * nh, T, T)
85
+ att = torch.bmm(q, (k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))))
86
+ if self.use_mask:
87
+ mask = self.mask if T == self.ctx_len else self.mask[:, :T, :T]
88
+ att = att.masked_fill(mask == 0, float('-inf'))
89
+ att = F.softmax(att, dim=-1)
90
+ att = self.attn_drop(att)
91
+ y = torch.bmm(att, v) # (B*nh, T, T) X (B*nh, T, hs) -> (B*nh, T, hs)
92
+ y = y.transpose(0, 1).contiguous().view(T, B, C) # re-assemble all head outputs side by side
93
+
94
+ # output projection
95
+ y = self.resid_drop(self.proj(y))
96
+ if use_cache:
97
+ return y.transpose(0, 1).contiguous(), present # (T, B, C) -> (B, T, C)
98
+ else:
99
+ return y.transpose(0, 1).contiguous() # (T, B, C) -> (B, T, C)
100
+
101
+
102
+ class Block(nn.Module):
103
+
104
+ def __init__(self,
105
+ ctx_len: int,
106
+ embed_dim: int,
107
+ n_heads: int,
108
+ mlp_bias: bool,
109
+ attn_bias: bool,
110
+ resid_pdrop: bool,
111
+ attn_pdrop: bool,
112
+ gelu_use_approx: bool):
113
+ super().__init__()
114
+ self.ln1 = nn.LayerNorm(embed_dim)
115
+ self.ln2 = nn.LayerNorm(embed_dim)
116
+
117
+ self.attn = MultiHeadSelfAttention(ctx_len=ctx_len,
118
+ embed_dim=embed_dim,
119
+ n_heads=n_heads,
120
+ attn_pdrop=attn_pdrop,
121
+ resid_pdrop=resid_pdrop,
122
+ attn_bias=attn_bias,
123
+ use_mask=True)
124
+ self.mlp = nn.Sequential(
125
+ nn.Linear(embed_dim, 4 * embed_dim, bias=mlp_bias),
126
+ GELU(gelu_use_approx),
127
+ nn.Linear(4 * embed_dim, embed_dim, bias=mlp_bias),
128
+ nn.Dropout(resid_pdrop),
129
+ )
130
+
131
+ def forward(self, x):
132
+ x = x + self.attn(self.ln1(x))
133
+ x = x + self.mlp(self.ln2(x))
134
+ return x
135
+
136
+ def sample(self, x, layer_past=None):
137
+ attn, present = self.attn(self.ln1(x), use_cache=True, layer_past=layer_past)
138
+ x = x + attn
139
+ x = x + self.mlp(self.ln2(x))
140
+ return x, present