shivendrra committed on
Commit: 7f4e854
1 Parent(s): 41885a6
added train and model files
- base/config.json +10 -0
- base/decoder.py +213 -0
- base/generate.py +120 -0
- base/model.py +279 -0
- base/run.py +101 -0
- base/tokenizer.py +19 -0
base/config.json
ADDED
@@ -0,0 +1,10 @@
{
  "batch_size": 10,
  "block_size": 512,
  "d_model": 512,
  "n_heads": 8,
  "n_layers": 8,
  "dropout": 0.18,
  "norm_eps": 1e-5,
  "learning_rate": 3e-5
}
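Every script in this commit reads these hyperparameters straight from config.json at import time, and the attention modules split d_model evenly across n_heads. A minimal sketch of that loading pattern; the divisibility assert is a suggested sanity check, not part of the commit:

import json

# load the shared hyperparameters; the scripts assume config.json sits next to them
with open('config.json', 'r', encoding='utf-8') as f:
  params = json.load(f)

d_model, n_heads = params['d_model'], params['n_heads']
assert d_model % n_heads == 0, "head_size = d_model // n_heads must divide evenly"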
base/decoder.py
ADDED
@@ -0,0 +1,213 @@
import json

with open('config.json', 'r', encoding='utf-8') as file:
  params = json.load(file)

# required parameters
block_size = params['block_size']
d_model = params['d_model']
n_head = params['n_heads']
n_layers = params['n_layers']
learning_rate = params['learning_rate']
dropout = params['dropout']
norm_eps = params['norm_eps']

import torch
import torch.nn as nn
from torch.nn import functional as F

device = 'cuda' if torch.cuda.is_available() else 'cpu'

class RMSNorm(nn.Module):
  def __init__(self, dim: int, eps: float = 1e-6):
    super().__init__()
    self.eps = eps
    self.weight = nn.Parameter(torch.ones(dim))

  def _norm(self, x):
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

  def forward(self, x):
    output = self._norm(x.float()).type_as(x)
    return output * self.weight

class MaskedHead(nn.Module):
  """single causal self-attention head"""
  def __init__(self, head_size: int, d_model: int, block_size: int, dropout: float):
    super().__init__()
    self.key = nn.Linear(d_model, head_size, bias=True)
    self.query = nn.Linear(d_model, head_size, bias=True)
    self.value = nn.Linear(d_model, head_size, bias=True)
    self.dropout = nn.Dropout(dropout)
    self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

  def forward(self, x: torch.Tensor):
    B, T, C = x.shape
    key = self.key(x)
    query = self.query(x)
    # scaled dot-product attention: divide by sqrt(head_size)
    scores = torch.matmul(query, key.transpose(-2, -1)) / (key.shape[-1] ** 0.5)
    scores = scores.masked_fill(self.tril[:T, :T] == 0, float('-inf'))

    att_mat = F.softmax(scores, dim=-1)
    att_mat = self.dropout(att_mat)
    value = self.value(x)
    output = torch.matmul(att_mat, value)
    return output

class UnMaskedHead(nn.Module):
  """single bidirectional self-attention head with relative position embeddings"""
  def __init__(self, head_size: int, d_model: int, block_size: int, dropout: float):
    super().__init__()
    self.key = nn.Linear(d_model, head_size, bias=True)
    self.query = nn.Linear(d_model, head_size, bias=True)
    self.value = nn.Linear(d_model, head_size, bias=True)
    self.dropout = nn.Dropout(dropout)
    self.rel_pos_embd = nn.Parameter(torch.randn(block_size, block_size, head_size))

  def forward(self, x: torch.Tensor):
    B, T, C = x.shape
    key = self.key(x)
    query = self.query(x)
    scores = torch.matmul(query, key.transpose(-2, -1)) / (key.shape[-1] ** 0.5)

    rel_pos_scores = torch.einsum('btc,tvc->btv', query, self.rel_pos_embd[:T, :T])
    scores = scores + rel_pos_scores

    att_mat = F.softmax(scores, dim=-1)
    att_mat = self.dropout(att_mat)
    value = self.value(x)
    output = torch.matmul(att_mat, value)
    return output

class MaskedAttention(nn.Module):
  def __init__(self, d_model: int, block_size: int, n_head: int, dropout: float):
    head_size = d_model // n_head
    super().__init__()
    self.heads = nn.ModuleList([MaskedHead(d_model=d_model, dropout=dropout, block_size=block_size, head_size=head_size) for _ in range(n_head)])
    self.projection = nn.Linear(d_model, d_model)
    self.dropout = nn.Dropout(dropout)

  def forward(self, x: torch.Tensor):
    out = torch.cat([h(x) for h in self.heads], dim=-1)
    out = self.dropout(self.projection(out))
    return out

class UnMaskedAttention(nn.Module):
  def __init__(self, d_model: int, block_size: int, n_head: int, dropout: float):
    head_size = d_model // n_head
    super().__init__()
    self.heads = nn.ModuleList([UnMaskedHead(d_model=d_model, dropout=dropout, block_size=block_size, head_size=head_size) for _ in range(n_head)])
    self.projection = nn.Linear(d_model, d_model)
    self.dropout = nn.Dropout(dropout)

  def forward(self, x: torch.Tensor):
    out = torch.cat([h(x) for h in self.heads], dim=-1)
    out = self.dropout(self.projection(out))
    return out

class FeedForward(nn.Module):
  def __init__(self, d_model, dropout):
    super().__init__()
    self.net = nn.Sequential(
      nn.Linear(d_model, 5 * d_model),
      nn.GELU(),
      nn.Linear(5 * d_model, 5 * d_model),
      nn.Dropout(dropout),
      nn.GELU(),
      nn.Linear(5 * d_model, d_model),
      nn.Dropout(dropout),
    )

  def forward(self, x: torch.Tensor):
    return self.net(x)

class DecoderBlock(nn.Module):
  def __init__(self, d_model: int, block_size: int, n_head: int, norm_eps: float, dropout: float):
    super().__init__()
    self.m_att = MaskedAttention(n_head=n_head, d_model=d_model, dropout=dropout, block_size=block_size)
    self.um_att = UnMaskedAttention(n_head=n_head, d_model=d_model, dropout=dropout, block_size=block_size)
    self.ffwd = FeedForward(d_model, dropout)
    self.dropout = nn.Dropout(dropout)
    self.norm = RMSNorm(d_model, eps=norm_eps)

  def forward(self, x: torch.Tensor):
    # pre-norm residual around masked attention
    x_out = self.m_att(self.norm(x))
    x_out = x + self.dropout(x_out)
    del x

    # pre-norm residual around unmasked attention
    x = self.um_att(self.norm(x_out))
    x = x_out + self.dropout(x)
    del x_out

    # pre-norm residual around the feed-forward block
    x_out = self.ffwd(self.norm(x))
    x_out = x + self.dropout(x_out)
    del x

    return x_out

class Transformer(nn.Module):
  def __init__(self, vocab_size: int):
    super().__init__()
    self.block_size = block_size
    self.token_embeddings = nn.Embedding(vocab_size, d_model)
    self.pos_encodings = nn.Embedding(block_size, d_model)
    self.decoder = nn.Sequential(*[DecoderBlock(n_head=n_head, d_model=d_model, dropout=dropout, norm_eps=norm_eps, block_size=block_size) for _ in range(n_layers)])
    self.norm_final = RMSNorm(d_model, eps=norm_eps)
    self.linear_final = nn.Linear(d_model, vocab_size)
    self.dropout = nn.Dropout(dropout)
    self.apply(self._init_weights)

  def _init_weights(self, module):
    if isinstance(module, nn.Linear):
      torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
      if module.bias is not None:
        torch.nn.init.zeros_(module.bias.data)
    elif isinstance(module, nn.Embedding):
      torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

  def forward(self, idx, targets=None):
    B, T = idx.shape
    toked_model = self.token_embeddings(idx)
    pos_encod = self.pos_encodings(torch.arange(T, device=device))
    x = toked_model + pos_encod

    x = self.decoder(x)
    logits = self.linear_final(self.norm_final(x))

    if targets is None:
      loss = None
    else:
      B, T, C = logits.shape
      logits = logits.view(B*T, C)
      targets = targets.view(B*T)
      loss = F.cross_entropy(logits, targets)

    return logits, loss

  def generate(self, idx: torch.Tensor, max_token: int = 10):
    for _ in range(max_token):
      idx_cond = idx[:, -self.block_size:]
      logits, _ = self(idx_cond)                             # forward returns (logits, loss)
      logits = logits[:, -1, :]
      probs = F.softmax(logits, dim=-1)
      idx_next = torch.argmax(probs, dim=-1, keepdim=True)   # keep the batch dim for concatenation
      idx = torch.cat((idx, idx_next), dim=1)
    return idx
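A hedged smoke-test sketch for decoder.py's decoder-only Transformer. The vocab size, batch shape, and dummy targets below are illustrative assumptions rather than values from the commit; the real entry points are run.py and generate.py, and config.json must be in the working directory:

# illustrative smoke test for decoder.py (assumed values, not part of the commit)
import torch
from decoder import Transformer, device

vocab_size = 50281                                           # assumed: p50k_base size from tokenizer.get_vocab()
model = Transformer(vocab_size).to(device)

idx = torch.randint(0, vocab_size, (2, 16), device=device)   # (batch, time) token ids
logits, loss = model(idx, targets=idx)                       # dummy targets, just to exercise the loss path
print(logits.shape, loss.item())

out = model.generate(idx, max_token=5)                       # greedy decoding, appends 5 tokens per sequence
print(out.shape)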
base/generate.py
ADDED
@@ -0,0 +1,120 @@
import os
current_directory = os.path.dirname(os.path.abspath(__file__))
os.chdir(current_directory)

import torch
import torch.nn as nn
from torch.nn import functional as F

device = 'cuda' if torch.cuda.is_available() else 'cpu'

from tokenizer import Tokenizer
tokenizer = Tokenizer()
vocab_size = tokenizer.get_vocab()

from model import Transformer
model = Transformer(vocab_size)
checkpoint_path = '/content/drive/MyDrive/base-500m.pth'
checkpoint = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(checkpoint)
m = model.to(device)

class Generate:
  def __init__(self):
    self.vocab_size = vocab_size
    self.block_size = m.block_size

  def generate(self, idx, max_new_tokens, temperature=1.0, top_k=0):
    """
    generate new tokens using the trained model

    Args:
      - idx (Tensor): input tensor representing initial token indices
      - max_new_tokens (int): max no of new tokens to generate
      - temperature (float): softmax temperature for sampling
      - top_k (int): no of top tokens to consider in sampling

    Returns:
      - generated_tokens (list): list of generated token indices
    """
    generated_tokens = []

    for _ in range(max_new_tokens):
      idx_cond = idx[:, -m.block_size:]
      logits, _ = m(idx_cond)   # call the model; this class itself is not callable
      logits = logits[:, -1, :]

      scaled_logits = logits / temperature
      if top_k > 0:
        scaled_logits = self._top_k_filtering(scaled_logits, top_k)

      probs = F.softmax(scaled_logits, dim=-1)
      sampled_idx = torch.multinomial(probs, num_samples=1)
      generated_tokens.append(sampled_idx.item())
      idx = torch.cat((idx, sampled_idx), dim=1)

    return generated_tokens

  def generate_masked_tokens(self, idx, masked_indices, temperature=1.0, top_k=0):
    """
    generate predictions for masked tokens using the trained model

    Args:
      - idx (Tensor): input tensor representing token indices
      - masked_indices (Tensor): tensor of indices indicating masked positions
      - temperature (float): softmax temperature for sampling
      - top_k (int): no of top tokens to consider in sampling

    Returns:
      - predicted_tokens (Tensor): tensor of predicted token indices
    """
    B, T = idx.shape

    toked_model = m.toked_model(idx)
    pos_encod = m.pos_encod(torch.arange(T, device=device))
    x = toked_model + pos_encod

    for layer in m.enc_layer:
      x_out = layer(x)

    for layer in m.dec_layer:
      x_final = layer(x, x_out)

    x_masked = x_final.clone()
    x_masked[masked_indices] = m.toked_model(torch.tensor([6], device=device))

    x_masked = m.norm_final(x_masked)
    logits = m.linear_final(x_masked)

    masked_logits = logits[masked_indices].view(-1, logits.size(-1))
    scaled_logits = masked_logits / temperature
    if top_k > 0:
      scaled_logits = self._top_k_filtering(scaled_logits, top_k)

    probs = F.softmax(scaled_logits, dim=-1)
    predicted_indices = torch.argmax(probs, dim=-1)

    return predicted_indices

  def _top_k_filtering(self, logits, top_k):
    """
    filter logits to keep only the top-k tokens

    Args:
      - logits (Tensor): input tensor representing unscaled logits
      - top_k (int): no of top tokens to keep

    Returns:
      - filtered_logits (Tensor): filtered logits with only top-k tokens remaining
    """
    values, indices = torch.topk(logits, top_k, dim=-1)
    min_value = values[:, -1].unsqueeze(-1).expand_as(logits)
    filtered_logits = torch.where(logits < min_value, torch.ones_like(logits) * -float('inf'), logits)

    return filtered_logits

generator = Generate()

target_text = "I was in the market when"
context = torch.tensor([tokenizer.encode(target_text)], dtype=torch.long, device=device)
generated_output = tokenizer.decode(generator.generate(context, max_new_tokens=50))
print(target_text, generated_output)
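generate_masked_tokens is defined above but never exercised by the script. A hedged sketch of how it might be called, assuming masked_indices is a boolean (B, T) mask and the checkpoint above has already been loaded:

# illustrative only: predict the token at a "masked" position (mask shape is an assumption)
masked_text = "I was in the market when"
ids = torch.tensor([tokenizer.encode(masked_text)], dtype=torch.long, device=device)
mask = torch.zeros_like(ids, dtype=torch.bool)
mask[0, -1] = True                      # pretend the last token is masked
predicted = generator.generate_masked_tokens(ids, mask, temperature=1.0, top_k=5)
print(tokenizer.decode(predicted.tolist()))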
base/model.py
ADDED
@@ -0,0 +1,279 @@
import json
with open('config.json', 'r', encoding='utf-8') as file:
  params = json.load(file)

# required parameters
block_size = params['block_size']
d_model = params['d_model']
n_head = params['n_heads']
n_layers = params['n_layers']
learning_rate = params['learning_rate']
dropout = params['dropout']
norm_eps = params['norm_eps']

import torch
import torch.nn as nn
from torch.nn import functional as F

device = 'cuda' if torch.cuda.is_available() else 'cpu'

class RMSNorm(nn.Module):
  def __init__(self, dim: int, eps: float = 1e-6):
    """
    Initialize the RMSNorm normalization layer.
    Args:
      dim (int): The dimension of the input tensor.
      eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
    Attributes:
      eps (float): A small value added to the denominator for numerical stability.
      weight (nn.Parameter): Learnable scaling parameter.
    """
    super().__init__()
    self.eps = eps
    self.weight = nn.Parameter(torch.ones(dim))

  def _norm(self, x):
    """
    Apply the RMSNorm normalization to the input tensor.
    Args:
      x (torch.Tensor): The input tensor.
    Returns:
      torch.Tensor: The normalized tensor.
    """
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

  def forward(self, x):
    """
    Forward pass through the RMSNorm layer.
    Args:
      x (torch.Tensor): The input tensor.
    Returns:
      torch.Tensor: The output tensor after applying RMSNorm.
    """
    output = self._norm(x.float()).type_as(x)
    return output * self.weight

class UnMaskedHead(nn.Module):
  def __init__(self, head_size, d_model, block_size, dropout):
    super().__init__()
    self.key = nn.Linear(d_model, head_size, bias=True)
    self.query = nn.Linear(d_model, head_size, bias=True)
    self.value = nn.Linear(d_model, head_size, bias=True)
    self.dropout = nn.Dropout(dropout)
    self.rel_pos_embd = nn.Parameter(torch.randn(block_size, block_size, head_size))

  def forward(self, x):
    B, T, C = x.shape
    key = self.key(x)
    query = self.query(x)

    # scaled dot-product attention: divide by sqrt(head_size)
    scores = torch.matmul(query, key.transpose(-2, -1)) / (key.shape[-1] ** 0.5)
    rel_pos_scores = torch.einsum('btc,tvc->btv', query, self.rel_pos_embd[:T, :T])
    scores = scores + rel_pos_scores

    att_mat = F.softmax(scores, dim=-1)
    att_mat = self.dropout(att_mat)
    value = self.value(x)
    output = torch.matmul(att_mat, value)
    return output

class UnMaskedAttention(nn.Module):
  def __init__(self, d_model, block_size, dropout, n_head):
    head_size = d_model // n_head
    super().__init__()
    self.heads = nn.ModuleList([UnMaskedHead(d_model=d_model, dropout=dropout, block_size=block_size, head_size=head_size) for _ in range(n_head)])
    self.proj = nn.Linear(n_head * head_size, d_model)
    self.dropout = nn.Dropout(dropout)

  def forward(self, x):
    out = torch.cat([h(x) for h in self.heads], dim=-1)
    out = self.dropout(self.proj(out))
    return out

class MaskedHead(nn.Module):
  def __init__(self, d_model, head_size, dropout, block_size):
    super().__init__()
    self.key = nn.Linear(d_model, head_size, bias=False)
    self.query = nn.Linear(d_model, head_size, bias=False)
    self.value = nn.Linear(d_model, head_size, bias=False)
    self.dropout = nn.Dropout(dropout)
    self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

  def forward(self, x):
    B, T, C = x.shape
    key = self.key(x)
    query = self.query(x)

    scores = torch.matmul(query, key.transpose(-2, -1)) / (key.shape[-1] ** 0.5)
    scores = scores.masked_fill(self.tril[:T, :T] == 0, float('-inf'))

    att_mat = F.softmax(scores, dim=-1)
    att_mat = self.dropout(att_mat)
    value = self.value(x)
    output = torch.matmul(att_mat, value)
    return output

class CasualMaskedAttention(nn.Module):
  def __init__(self, d_model, block_size, dropout, n_head):
    head_size = d_model // n_head
    super().__init__()
    self.heads = nn.ModuleList([MaskedHead(d_model=d_model, dropout=dropout, block_size=block_size, head_size=head_size) for _ in range(n_head)])
    self.proj = nn.Linear(n_head * head_size, d_model)
    self.dropout = nn.Dropout(dropout)

  def forward(self, x):
    out = torch.cat([h(x) for h in self.heads], dim=-1)
    out = self.dropout(self.proj(out))
    return out

class FinalHead(nn.Module):
  def __init__(self, d_model, head_size, dropout, block_size):
    super().__init__()
    self.key = nn.Linear(d_model, head_size, bias=True)
    self.query = nn.Linear(d_model, head_size, bias=True)
    self.value = nn.Linear(d_model, head_size, bias=True)
    self.dropout = nn.Dropout(dropout)

  def forward(self, x, att):
    B, T, C = x.shape
    # cross-attention: keys and queries come from the encoder output, values from the decoder stream
    key = self.key(att)
    query = self.query(att)

    scores = torch.matmul(query, key.transpose(-2, -1)) / (key.shape[-1] ** 0.5)

    att_mat = F.softmax(scores, dim=-1)
    att_mat = self.dropout(att_mat)
    value = self.value(x)
    output = torch.matmul(att_mat, value)
    return output

class FinalAttention(nn.Module):
  def __init__(self, d_model, block_size, dropout, n_head):
    head_size = d_model // n_head
    super().__init__()
    self.heads = nn.ModuleList([FinalHead(d_model=d_model, dropout=dropout, block_size=block_size, head_size=head_size) for _ in range(n_head)])
    self.proj = nn.Linear(n_head * head_size, d_model)
    self.dropout = nn.Dropout(dropout)

  def forward(self, x, att):
    out = torch.cat([h(x, att) for h in self.heads], dim=-1)
    out = self.dropout(self.proj(out))
    return out

class FeedForward(nn.Module):
  def __init__(self, d_model, dropout):
    super().__init__()
    self.net = nn.Sequential(
      nn.Linear(d_model, 4 * d_model),
      nn.GELU(),
      nn.Linear(4 * d_model, d_model),
      nn.Dropout(dropout)
    )

  def forward(self, x):
    return self.net(x)

class EncoderNetwork(nn.Module):
  def __init__(self, d_model, n_head, norm_eps, dropout, block_size):
    super().__init__()
    self.s_att = UnMaskedAttention(n_head=n_head, d_model=d_model, dropout=dropout, block_size=block_size)
    self.ffwd = FeedForward(d_model, dropout)
    self.dropout = nn.Dropout(dropout)
    self.norm = RMSNorm(d_model, eps=norm_eps)

  def forward(self, src):
    src = self.norm(src)
    src_out = src + self.dropout(self.s_att(src))

    src = self.norm(src_out)
    src_f = src + self.dropout(self.ffwd(src))

    del src_out, src
    return src_f

class DecoderNetwork(nn.Module):
  def __init__(self, d_model, n_head, norm_eps, dropout, block_size):
    super().__init__()
    self.m_att = CasualMaskedAttention(n_head=n_head, d_model=d_model, dropout=dropout, block_size=block_size)
    self.f_att = FinalAttention(d_model=d_model, n_head=n_head, dropout=dropout, block_size=block_size)
    self.ffwd = FeedForward(d_model, dropout)
    self.dropout = nn.Dropout(dropout)
    self.norm = RMSNorm(d_model, eps=norm_eps)

  def forward(self, src, att):
    m_att_out = self.norm(src)
    m_out = src + self.dropout(self.m_att(m_att_out))

    f_out = self.f_att(m_out, self.norm(att))
    f_out = m_out + self.dropout(f_out)

    src_f = self.norm(f_out)
    src_f = f_out + self.dropout(self.ffwd(src_f))

    del f_out, m_out, m_att_out, src, att
    return src_f

class Transformer(nn.Module):
  def __init__(self, vocab_size):
    super().__init__()
    self.block_size = block_size
    self.toked_model = nn.Embedding(vocab_size, d_model)
    self.pos_encod = nn.Embedding(block_size, d_model)
    self.enc_layer = nn.ModuleList([EncoderNetwork(n_head=n_head, norm_eps=norm_eps, block_size=block_size, dropout=dropout, d_model=d_model) for _ in range(n_layers)])
    self.dec_layer = nn.ModuleList([DecoderNetwork(n_head=n_head, norm_eps=norm_eps, block_size=block_size, dropout=dropout, d_model=d_model) for _ in range(n_layers)])
    self.norm_final = RMSNorm(d_model, eps=norm_eps)
    self.linear_final = nn.Linear(d_model, vocab_size)
    self.dropout = nn.Dropout(dropout)
    self.apply(self._init_weights)

  def _init_weights(self, module):
    """
    initialize weights of linear and embedding layers

    Args:
      - module (nn.Module): the module to initialize weights for
    """
    if isinstance(module, nn.Linear):
      torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
      if module.bias is not None:
        torch.nn.init.zeros_(module.bias.data)
    elif isinstance(module, nn.Embedding):
      torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

  def forward(self, idx, targets=None):
    """
    forward pass of the transformer model

    Args:
      - idx (Tensor): input tensor representing token indices
      - targets (Tensor): target tensor for computing loss during training

    Returns:
      - logits (Tensor): output logits from the final linear layer
      - loss (Tensor): optional. computed cross-entropy loss if targets are provided, else None
    """
    B, T = idx.shape

    toked_model = self.toked_model(idx)
    pos_encod = self.pos_encod(torch.arange(T, device=device))
    x = toked_model + pos_encod

    for layer in self.enc_layer:
      x_out = layer(x)

    for layer in self.dec_layer:
      x_final = layer(x, x_out)

    x_final = self.norm_final(x_final)
    logits = self.linear_final(x_final)

    if targets is None:
      loss = None
    else:
      B, T, C = logits.shape
      logits = logits.view(B*T, C)
      targets = targets.view(B*T)
      loss = F.cross_entropy(logits, targets)

    return logits, loss
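For orientation, a hedged sketch of a bare forward pass through model.py's encoder-decoder Transformer. The vocab size and shapes are illustrative assumptions; run.py is the actual training driver and expects config.json in the working directory:

# illustrative forward pass through model.py (assumed values, not part of the commit)
import torch
from model import Transformer, device

vocab_size = 50281                 # assumed: p50k_base size reported by tokenizer.get_vocab()
model = Transformer(vocab_size).to(device)

idx = torch.randint(0, vocab_size, (2, 32), device=device)
logits, loss = model(idx)          # without targets: logits is (B, T, vocab_size), loss is None
print(logits.shape, loss)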
base/run.py
ADDED
@@ -0,0 +1,101 @@
"""
use this file to train the model

working:
- imports various dependencies first, and then loads the training data
- tokenizes it with the tiktoken BPE tokenizer (see tokenizer.py)
- loads the required hyper-parameters and the model file
- trains it till 'max_iters' and saves the model state, and generates outputs

with the current configuration, the model can reach up to ~60 million parameters
and can become ~99% accurate at next-token prediction
"""

import torch
import json
import os

current_directory = os.path.dirname(os.path.abspath(__file__))
os.chdir(current_directory)
device = 'cuda' if torch.cuda.is_available() else 'cpu'

with open('../datasets/wiki_176m.txt', 'r', encoding='utf-8') as file:
  data = file.read()

print(f"{(len(data)/1e6):.2f} million letters")

from tokenizer import Tokenizer

tokenizer = Tokenizer()
vocab_size = tokenizer.get_vocab()

# train and test splits
data = torch.tensor(tokenizer.encode(data), dtype=torch.long)
n = int(0.9*len(data))  # first 90% will be train, rest val
train_data = data[:n]
val_data = data[n:]

with open('config.json', 'r', encoding='utf-8') as file:
  params = json.load(file)

# required parameters
batch_size = params['batch_size']
block_size = params['block_size']
max_iters = 1000
eval_interval = 100
eval_iters = 200
learning_rate = params['learning_rate']

torch.manual_seed(1400)

# data loading
def get_batch(split):
  # generate a small batch of data of inputs x and targets y
  data = train_data if split == 'train' else val_data
  ix = torch.randint(len(data) - block_size, (batch_size,))
  x = torch.stack([data[i:i+block_size] for i in ix])
  y = torch.stack([data[i+1:i+block_size+1] for i in ix])
  x, y = x.to(device), y.to(device)
  return x, y

@torch.no_grad()
def estimate_loss():
  out = {}
  model.eval()
  for split in ['train', 'val']:
    losses = torch.zeros(eval_iters)
    for k in range(eval_iters):
      X, Y = get_batch(split)
      logits, loss = model(X, Y)
      losses[k] = loss.item()
    out[split] = losses.mean()
  model.train()
  return out

from model import Transformer
model = Transformer(vocab_size)
m = model.to(device)

# no. of parameters
n_param = sum(p.numel() for p in m.parameters())/1e6
print(f"{n_param:.2f} million")
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
steps = []
train_losses = []
val_losses = []

for iter in range(max_iters):

  if iter % eval_interval == 0 or iter == max_iters - 1:
    losses = estimate_loss()
    print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    steps.append(iter)
    train_losses.append(losses['train'])
    val_losses.append(losses['val'])

  xb, yb = get_batch('train')
  logits, loss = model(xb, yb)
  optimizer.zero_grad(set_to_none=True)
  loss.backward()
  optimizer.step()

torch.save(model.state_dict(), f'enigma_{n_param:.0f}m.pth')
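run.py collects steps, train_losses, and val_losses but never plots them. A small optional sketch for visualizing the curves after training; matplotlib is an assumed extra dependency, not part of this commit:

# optional: plot the loss curves gathered during training
import matplotlib.pyplot as plt

plt.plot(steps, [float(l) for l in train_losses], label='train')
plt.plot(steps, [float(l) for l in val_losses], label='val')
plt.xlabel('iteration')
plt.ylabel('cross-entropy loss')
plt.legend()
plt.savefig('loss_curve.png')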
base/tokenizer.py
ADDED
@@ -0,0 +1,19 @@
import tiktoken

pre_encodings = 'p50k_base'
pre_model = 'text-davinci-003'

class Tokenizer:
  def __init__(self, encoding=None, model=None):
    self.encodings = encoding if encoding is not None else pre_encodings
    self.model = model if model is not None else pre_model
    self.tokenizer = tiktoken.get_encoding(self.encodings)
    # the model-specific encoding takes precedence over the named encoding above
    self.tokenizer = tiktoken.encoding_for_model(self.model)

  def encode(self, data):
    return self.tokenizer.encode(data)

  def decode(self, tokens):
    return self.tokenizer.decode(tokens)

  def get_vocab(self):
    return self.tokenizer.n_vocab
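A quick usage check of the Tokenizer wrapper; the text is arbitrary, and tiktoken fetches the text-davinci-003 encoding on first use:

# round-trip check of the Tokenizer wrapper
from tokenizer import Tokenizer

tok = Tokenizer()
ids = tok.encode("hello world")   # list of token ids
print(ids)
print(tok.decode(ids))            # round-trips back to "hello world"
print(tok.get_vocab())            # vocabulary size used for the model's embedding table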