shivendrra committed on
Commit
7f4e854
1 Parent(s): 41885a6

added train and model files

Files changed (6)
  1. base/config.json +10 -0
  2. base/decoder.py +213 -0
  3. base/generate.py +120 -0
  4. base/model.py +279 -0
  5. base/run.py +101 -0
  6. base/tokenizer.py +19 -0
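Note: a minimal sketch of how these files fit together. The exact invocation is an assumption rather than part of the commit: both model files read config.json from the working directory, so the snippet assumes it is run from inside base/ and that no trained checkpoint is needed for a forward pass.

# hypothetical smoke test, run from inside base/ so that config.json resolves
import torch
from tokenizer import Tokenizer    # tiktoken-backed tokenizer added in this commit
from model import Transformer      # encoder-decoder variant; decoder.py holds a decoder-only variant

device = 'cuda' if torch.cuda.is_available() else 'cpu'
tok = Tokenizer()
model = Transformer(tok.get_vocab()).to(device)

idx = torch.tensor([tok.encode("hello world")], dtype=torch.long, device=device)
logits, loss = model(idx)          # loss is None because no targets were passed
print(logits.shape)                # (1, sequence_length, vocab_size)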
base/config.json ADDED
@@ -0,0 +1,10 @@
+ {
+   "batch_size": 10,
+   "block_size": 512,
+   "d_model": 512,
+   "n_heads": 8,
+   "n_layers": 8,
+   "dropout": 0.18,
+   "norm_eps": 1e-5,
+   "learning_rate": 3e-5
+ }
base/decoder.py ADDED
@@ -0,0 +1,213 @@
+ import json
+ with open('config.json', 'r', encoding='utf-8') as file:
+   params = json.load(file)
+
+ # required hyperparameters, shared by every module in this file
+ block_size = params['block_size']
+ d_model = params['d_model']
+ n_head = params['n_heads']
+ n_layers = params['n_layers']
+ learning_rate = params['learning_rate']
+ dropout = params['dropout']
+ norm_eps = params['norm_eps']
+
+ import torch
+ import torch.nn as nn
+ from torch.nn import functional as F
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+ class RMSNorm(nn.Module):
+   """root-mean-square layer normalization with a learnable scale"""
+   def __init__(self, dim: int, eps: float = 1e-6):
+     super().__init__()
+     self.eps = eps
+     self.weight = nn.Parameter(torch.ones(dim))
+
+   def _norm(self, x):
+     return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+
+   def forward(self, x):
+     output = self._norm(x.float()).type_as(x)
+     return output * self.weight
+
+ class MaskedHead(nn.Module):
+   """single causal self-attention head with a lower-triangular mask"""
+   def __init__(self,
+                head_size: int,
+                d_model: int,
+                block_size: int,
+                dropout: float):
+     super().__init__()
+     self.key = nn.Linear(d_model, head_size, bias=True)
+     self.query = nn.Linear(d_model, head_size, bias=True)
+     self.value = nn.Linear(d_model, head_size, bias=True)
+     self.dropout = nn.Dropout(dropout)
+     self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
+
+   def forward(self, x: torch.Tensor):
+     B, T, C = x.shape
+     key = self.key(x)
+     query = self.query(x)
+     # scaled dot-product attention: scale the scores by 1/sqrt(head_size)
+     scores = torch.matmul(query, key.transpose(-2, -1)) * (key.shape[-1] ** -0.5)
+     scores = scores.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
+
+     att_mat = F.softmax(scores, dim=-1)
+     att_mat = self.dropout(att_mat)
+     value = self.value(x)
+     output = torch.matmul(att_mat, value)
+     return output
+
+ class UnMaskedHead(nn.Module):
+   """single bidirectional self-attention head with learned relative position scores"""
+   def __init__(self,
+                head_size: int,
+                d_model: int,
+                block_size: int,
+                dropout: float):
+     super().__init__()
+     self.key = nn.Linear(d_model, head_size, bias=True)
+     self.query = nn.Linear(d_model, head_size, bias=True)
+     self.value = nn.Linear(d_model, head_size, bias=True)
+     self.dropout = nn.Dropout(dropout)
+     self.rel_pos_embd = nn.Parameter(torch.randn(block_size, block_size, head_size))
+
+   def forward(self, x: torch.Tensor):
+     B, T, C = x.shape
+     key = self.key(x)
+     query = self.query(x)
+     scores = torch.matmul(query, key.transpose(-2, -1)) * (key.shape[-1] ** -0.5)
+
+     rel_pos_scores = torch.einsum('btc,tvc->btv', query, self.rel_pos_embd[:T, :T])
+     scores = scores + rel_pos_scores
+
+     att_mat = F.softmax(scores, dim=-1)
+     att_mat = self.dropout(att_mat)
+     value = self.value(x)
+     output = torch.matmul(att_mat, value)
+     return output
+
+ class MaskedAttention(nn.Module):
+   """multi-head causal attention followed by an output projection"""
+   def __init__(self,
+                d_model: int,
+                block_size: int,
+                n_head: int,
+                dropout: float):
+     super().__init__()
+     head_size = d_model // n_head
+     self.heads = nn.ModuleList([MaskedHead(d_model=d_model, dropout=dropout, block_size=block_size, head_size=head_size) for _ in range(n_head)])
+     self.projection = nn.Linear(d_model, d_model)
+     self.dropout = nn.Dropout(dropout)
+
+   def forward(self, x: torch.Tensor):
+     out = torch.cat([h(x) for h in self.heads], dim=-1)
+     out = self.dropout(self.projection(out))
+     return out
+
+ class UnMaskedAttention(nn.Module):
+   """multi-head bidirectional attention followed by an output projection"""
+   def __init__(self,
+                d_model: int,
+                block_size: int,
+                n_head: int,
+                dropout: float):
+     super().__init__()
+     head_size = d_model // n_head
+     self.heads = nn.ModuleList([UnMaskedHead(d_model=d_model, dropout=dropout, block_size=block_size, head_size=head_size) for _ in range(n_head)])
+     self.projection = nn.Linear(d_model, d_model)
+     self.dropout = nn.Dropout(dropout)
+
+   def forward(self, x: torch.Tensor):
+     out = torch.cat([h(x) for h in self.heads], dim=-1)
+     out = self.dropout(self.projection(out))
+     return out
+
+ class FeedForward(nn.Module):
+   """position-wise feed-forward network with GELU activations"""
+   def __init__(self, d_model, dropout):
+     super().__init__()
+     self.net = nn.Sequential(
+       nn.Linear(d_model, 5 * d_model),
+       nn.GELU(),
+       nn.Linear(5 * d_model, 5 * d_model),
+       nn.Dropout(dropout),
+       nn.GELU(),
+       nn.Linear(5 * d_model, d_model),
+       nn.Dropout(dropout),
+     )
+
+   def forward(self, x: torch.Tensor):
+     return self.net(x)
+
+ class DecoderBlock(nn.Module):
+   """pre-norm block: masked attention, unmasked attention, then feed-forward, each with a residual"""
+   def __init__(self, d_model: int,
+                block_size: int,
+                n_head: int,
+                norm_eps: float,
+                dropout: float):
+     super().__init__()
+     self.m_att = MaskedAttention(n_head=n_head, d_model=d_model, dropout=dropout, block_size=block_size)
+     self.um_att = UnMaskedAttention(n_head=n_head, d_model=d_model, dropout=dropout, block_size=block_size)
+     self.ffwd = FeedForward(d_model, dropout)
+     self.dropout = nn.Dropout(dropout)
+     self.norm = RMSNorm(d_model, eps=norm_eps)
+
+   def forward(self, x: torch.Tensor):
+     x_out = self.m_att(self.norm(x))
+     x_out = x + self.dropout(x_out)
+     del x
+
+     x = self.um_att(self.norm(x_out))
+     x = x_out + self.dropout(x)
+     del x_out
+
+     x_out = self.ffwd(self.norm(x))
+     x_out = x + self.dropout(x_out)
+     del x
+
+     return x_out
+
+ class Transformer(nn.Module):
+   """decoder-only language model: token + positional embeddings, stacked DecoderBlocks, and a vocab projection"""
+   def __init__(self, vocab_size: int):
+     super().__init__()
+     self.block_size = block_size
+     self.token_embeddings = nn.Embedding(vocab_size, d_model)
+     self.pos_encodings = nn.Embedding(block_size, d_model)
+     self.decoder = nn.Sequential(*[DecoderBlock(n_head=n_head, d_model=d_model, dropout=dropout, norm_eps=norm_eps, block_size=block_size) for _ in range(n_layers)])
+     self.norm_final = RMSNorm(d_model, eps=norm_eps)
+     self.linear_final = nn.Linear(d_model, vocab_size)
+     self.dropout = nn.Dropout(dropout)
+     self.apply(self._init_weights)
+
+   def _init_weights(self, module):
+     if isinstance(module, nn.Linear):
+       torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+       if module.bias is not None:
+         torch.nn.init.zeros_(module.bias.data)
+     elif isinstance(module, nn.Embedding):
+       torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+
+   def forward(self, idx, targets=None):
+     B, T = idx.shape
+     toked_model = self.token_embeddings(idx)
+     pos_encod = self.pos_encodings(torch.arange(T, device=device))
+     x = toked_model + pos_encod
+
+     x = self.decoder(x)
+     logits = self.linear_final(self.norm_final(x))
+
+     if targets is None:
+       loss = None
+     else:
+       B, T, C = logits.shape
+       logits = logits.view(B*T, C)
+       targets = targets.view(B*T)
+       loss = F.cross_entropy(logits, targets)
+
+     return logits, loss
+
+   def generate(self, idx: torch.Tensor, max_token: int = 10):
+     # greedy decoding: repeatedly append the argmax token over the last block_size positions
+     for _ in range(max_token):
+       idx_cond = idx[:, -self.block_size:]
+       logits, _ = self(idx_cond)                              # forward returns (logits, loss)
+       logits = logits[:, -1, :]
+       probs = F.softmax(logits, dim=-1)
+       idx_next = torch.argmax(probs, dim=-1, keepdim=True)    # keep shape (B, 1) for concatenation
+       idx = torch.cat((idx, idx_next), dim=1)
+     return idx
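Note: a standalone toy illustration of the tril-plus-softmax pattern used by MaskedHead above (self-contained; it does not import decoder.py, and the tensor sizes are arbitrary).

import torch
import torch.nn.functional as F

T, head_size = 4, 64
scores = torch.randn(T, T) * (head_size ** -0.5)        # scaled dot-product scores for a toy sequence
tril = torch.tril(torch.ones(T, T))                     # lower-triangular causal mask
scores = scores.masked_fill(tril == 0, float('-inf'))   # positions in the future get -inf
att = F.softmax(scores, dim=-1)                         # each row sums to 1 over past positions only
print(att)                                              # the upper triangle is exactly 0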
base/generate.py ADDED
@@ -0,0 +1,120 @@
+ import os
+ current_directory = os.path.dirname(os.path.abspath(__file__))
+ os.chdir(current_directory)
+
+ import torch
+ import torch.nn as nn
+ from torch.nn import functional as F
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+ from tokenizer import Tokenizer
+ tokenizer = Tokenizer()
+ vocab_size = tokenizer.get_vocab()
+
+ from model import Transformer
+ model = Transformer(vocab_size)
+ checkpoint_path = '/content/drive/MyDrive/base-500m.pth'
+ checkpoint = torch.load(checkpoint_path, map_location=device)
+ model.load_state_dict(checkpoint)
+ m = model.to(device)
+ m.eval()
+
+ class Generate:
+   def __init__(self):
+     self.vocab_size = vocab_size
+     self.block_size = m.block_size
+
+   @torch.no_grad()
+   def generate(self, idx, max_new_tokens, temperature=1.0, top_k=0):
+     """
+     generate new tokens using the trained model
+
+     Args:
+       - idx (Tensor): input tensor of initial token indices
+       - max_new_tokens (int): maximum number of new tokens to generate
+       - temperature (float): softmax temperature for sampling
+       - top_k (int): number of top tokens to consider in sampling (0 disables filtering)
+
+     Returns:
+       - generated_tokens (list): list of generated token indices
+     """
+     generated_tokens = []
+
+     for _ in range(max_new_tokens):
+       idx_cond = idx[:, -m.block_size:]
+       logits, _ = m(idx_cond)                 # call the trained model, not this wrapper class
+       logits = logits[:, -1, :]
+
+       scaled_logits = logits / temperature
+       if top_k > 0:
+         scaled_logits = self._top_k_filtering(scaled_logits, top_k)
+
+       probs = F.softmax(scaled_logits, dim=-1)
+       sampled_idx = torch.multinomial(probs, num_samples=1)
+       generated_tokens.append(sampled_idx.item())
+       idx = torch.cat((idx, sampled_idx), dim=1)
+
+     return generated_tokens
+
+   @torch.no_grad()
+   def generate_masked_tokens(self, idx, masked_indices, temperature=1.0, top_k=0):
+     """
+     generate predictions for masked tokens using the trained model
+
+     Args:
+       - idx (Tensor): input tensor of token indices
+       - masked_indices (Tensor): indices of the masked positions
+       - temperature (float): softmax temperature for sampling
+       - top_k (int): number of top tokens to consider in sampling (0 disables filtering)
+
+     Returns:
+       - predicted_tokens (Tensor): tensor of predicted token indices
+     """
+     B, T = idx.shape
+
+     toked_model = m.toked_model(idx)
+     pos_encod = m.pos_encod(torch.arange(T, device=device))
+     x = toked_model + pos_encod
+
+     # run the stacked encoder and decoder layers, chaining each layer's output into the next
+     x_out = x
+     for layer in m.enc_layer:
+       x_out = layer(x_out)
+
+     x_final = x
+     for layer in m.dec_layer:
+       x_final = layer(x_final, x_out)
+
+     x_masked = x_final.clone()
+     x_masked[masked_indices] = m.toked_model(torch.tensor([6], device=device))  # hard-coded mask token id
+
+     x_masked = m.norm_final(x_masked)
+     logits = m.linear_final(x_masked)
+
+     masked_logits = logits[masked_indices].view(-1, logits.size(-1))
+     scaled_logits = masked_logits / temperature
+     if top_k > 0:
+       scaled_logits = self._top_k_filtering(scaled_logits, top_k)
+
+     probs = F.softmax(scaled_logits, dim=-1)
+     predicted_indices = torch.argmax(probs, dim=-1)
+
+     return predicted_indices
+
+   def _top_k_filtering(self, logits, top_k):
+     """
+     filter logits so that only the top-k tokens keep their scores
+
+     Args:
+       - logits (Tensor): input tensor of unscaled logits
+       - top_k (int): number of top tokens to keep
+
+     Returns:
+       - filtered_logits (Tensor): logits with everything outside the top-k set to -inf
+     """
+     values, _ = torch.topk(logits, top_k, dim=-1)
+     min_value = values[:, -1].unsqueeze(-1).expand_as(logits)
+     filtered_logits = torch.where(logits < min_value, torch.full_like(logits, float('-inf')), logits)
+
+     return filtered_logits
+
+ generator = Generate()
+
+ target_text = "I was in the market when"
+ context = torch.tensor([tokenizer.encode(target_text)], dtype=torch.long, device=device)
+ generated_output = tokenizer.decode(generator.generate(context, max_new_tokens=50))
+ print(target_text, generated_output)
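Note: to make the top-k step above concrete, a tiny self-contained example of the same idea on a toy logits row (illustrative only; it mirrors _top_k_filtering rather than importing generate.py, which would try to load the checkpoint).

import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, 0.5, 1.0, -1.0]])                # one row of unscaled logits
values, _ = torch.topk(logits, k=2, dim=-1)                   # top-2 values: [[2.0, 1.0]]
cutoff = values[:, -1].unsqueeze(-1)                          # smallest kept value: 1.0
filtered = torch.where(logits < cutoff, torch.full_like(logits, float('-inf')), logits)
probs = F.softmax(filtered, dim=-1)                           # probability mass only on the kept tokens
print(probs)                                                  # roughly [[0.73, 0.00, 0.27, 0.00]]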
base/model.py ADDED
@@ -0,0 +1,279 @@
+ import json
+ with open('config.json', 'r', encoding='utf-8') as file:
+   params = json.load(file)
+
+ # required hyperparameters
+ block_size = params['block_size']
+ d_model = params['d_model']
+ n_head = params['n_heads']
+ n_layers = params['n_layers']
+ learning_rate = params['learning_rate']
+ dropout = params['dropout']
+ norm_eps = params['norm_eps']
+
+ import torch
+ import torch.nn as nn
+ from torch.nn import functional as F
+
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+ class RMSNorm(nn.Module):
+   def __init__(self, dim: int, eps: float = 1e-6):
+     """
+     Initialize the RMSNorm normalization layer.
+     Args:
+       dim (int): The dimension of the input tensor.
+       eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
+     Attributes:
+       eps (float): A small value added to the denominator for numerical stability.
+       weight (nn.Parameter): Learnable scaling parameter.
+     """
+     super().__init__()
+     self.eps = eps
+     self.weight = nn.Parameter(torch.ones(dim))
+
+   def _norm(self, x):
+     """
+     Apply the RMSNorm normalization to the input tensor.
+     Args:
+       x (torch.Tensor): The input tensor.
+     Returns:
+       torch.Tensor: The normalized tensor.
+     """
+     return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+
+   def forward(self, x):
+     """
+     Forward pass through the RMSNorm layer.
+     Args:
+       x (torch.Tensor): The input tensor.
+     Returns:
+       torch.Tensor: The output tensor after applying RMSNorm.
+     """
+     output = self._norm(x.float()).type_as(x)
+     return output * self.weight
+
+ class UnMaskedHead(nn.Module):
+   """single bidirectional self-attention head with learned relative position scores"""
+   def __init__(self, head_size, d_model, block_size, dropout):
+     super().__init__()
+     self.key = nn.Linear(d_model, head_size, bias=True)
+     self.query = nn.Linear(d_model, head_size, bias=True)
+     self.value = nn.Linear(d_model, head_size, bias=True)
+     self.dropout = nn.Dropout(dropout)
+     self.rel_pos_embd = nn.Parameter(torch.randn(block_size, block_size, head_size))
+
+   def forward(self, x):
+     B, T, C = x.shape
+     key = self.key(x)
+     query = self.query(x)
+
+     # scaled dot-product attention: scale the scores by 1/sqrt(head_size)
+     scores = torch.matmul(query, key.transpose(-2, -1)) * (key.shape[-1] ** -0.5)
+     rel_pos_scores = torch.einsum('btc,tvc->btv', query, self.rel_pos_embd[:T, :T])
+     scores = scores + rel_pos_scores
+
+     att_mat = F.softmax(scores, dim=-1)
+     att_mat = self.dropout(att_mat)
+     value = self.value(x)
+     output = torch.matmul(att_mat, value)
+     return output
+
+ class UnMaskedAttention(nn.Module):
+   """multi-head bidirectional attention followed by an output projection"""
+   def __init__(self, d_model, block_size, dropout, n_head):
+     super().__init__()
+     head_size = d_model // n_head
+     self.heads = nn.ModuleList([UnMaskedHead(d_model=d_model, dropout=dropout, block_size=block_size, head_size=head_size) for _ in range(n_head)])
+     self.proj = nn.Linear(n_head * head_size, d_model)
+     self.dropout = nn.Dropout(dropout)
+
+   def forward(self, x):
+     out = torch.cat([h(x) for h in self.heads], dim=-1)
+     out = self.dropout(self.proj(out))
+     return out
+
+ class MaskedHead(nn.Module):
+   """single causal self-attention head with a lower-triangular mask"""
+   def __init__(self, d_model, head_size, dropout, block_size):
+     super().__init__()
+     self.key = nn.Linear(d_model, head_size, bias=False)
+     self.query = nn.Linear(d_model, head_size, bias=False)
+     self.value = nn.Linear(d_model, head_size, bias=False)
+     self.dropout = nn.Dropout(dropout)
+     self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
+
+   def forward(self, x):
+     B, T, C = x.shape
+     key = self.key(x)
+     query = self.query(x)
+
+     scores = torch.matmul(query, key.transpose(-2, -1)) * (key.shape[-1] ** -0.5)
+     scores = scores.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
+
+     att_mat = F.softmax(scores, dim=-1)
+     att_mat = self.dropout(att_mat)
+     value = self.value(x)
+     output = torch.matmul(att_mat, value)
+     return output
+
+ class CasualMaskedAttention(nn.Module):
+   """multi-head causal (masked) attention followed by an output projection"""
+   def __init__(self, d_model, block_size, dropout, n_head):
+     super().__init__()
+     head_size = d_model // n_head
+     self.heads = nn.ModuleList([MaskedHead(d_model=d_model, dropout=dropout, block_size=block_size, head_size=head_size) for _ in range(n_head)])
+     self.proj = nn.Linear(n_head * head_size, d_model)
+     self.dropout = nn.Dropout(dropout)
+
+   def forward(self, x):
+     out = torch.cat([h(x) for h in self.heads], dim=-1)
+     out = self.dropout(self.proj(out))
+     return out
+
+ class FinalHead(nn.Module):
+   """single cross-attention head: keys and queries come from the encoder output, values from the decoder stream"""
+   def __init__(self, d_model, head_size, dropout, block_size):
+     super().__init__()
+     self.key = nn.Linear(d_model, head_size, bias=True)
+     self.query = nn.Linear(d_model, head_size, bias=True)
+     self.value = nn.Linear(d_model, head_size, bias=True)
+     self.dropout = nn.Dropout(dropout)
+
+   def forward(self, x, att):
+     B, T, C = x.shape
+     key = self.key(att)
+     query = self.query(att)
+
+     scores = torch.matmul(query, key.transpose(-2, -1)) * (key.shape[-1] ** -0.5)
+
+     att_mat = F.softmax(scores, dim=-1)
+     att_mat = self.dropout(att_mat)
+     value = self.value(x)
+     output = torch.matmul(att_mat, value)
+     return output
+
+ class FinalAttention(nn.Module):
+   """multi-head cross-attention followed by an output projection"""
+   def __init__(self, d_model, block_size, dropout, n_head):
+     super().__init__()
+     head_size = d_model // n_head
+     self.heads = nn.ModuleList([FinalHead(d_model=d_model, dropout=dropout, block_size=block_size, head_size=head_size) for _ in range(n_head)])
+     self.proj = nn.Linear(n_head * head_size, d_model)
+     self.dropout = nn.Dropout(dropout)
+
+   def forward(self, x, att):
+     out = torch.cat([h(x, att) for h in self.heads], dim=-1)
+     out = self.dropout(self.proj(out))
+     return out
+
+ class FeedForward(nn.Module):
+   """position-wise feed-forward network with a GELU activation"""
+   def __init__(self, d_model, dropout):
+     super().__init__()
+     self.net = nn.Sequential(
+       nn.Linear(d_model, 4*d_model),
+       nn.GELU(),
+       nn.Linear(4*d_model, d_model),
+       nn.Dropout(dropout)
+     )
+
+   def forward(self, x):
+     return self.net(x)
+
+ class EncoderNetwork(nn.Module):
+   """pre-norm encoder block: bidirectional self-attention then feed-forward, each with a residual"""
+   def __init__(self, d_model, n_head, norm_eps, dropout, block_size):
+     super().__init__()
+     self.s_att = UnMaskedAttention(n_head=n_head, d_model=d_model, dropout=dropout, block_size=block_size)
+     self.ffwd = FeedForward(d_model, dropout)
+     self.dropout = nn.Dropout(dropout)
+     self.norm = RMSNorm(d_model, eps=norm_eps)
+
+   def forward(self, src):
+     src = self.norm(src)
+     src_out = src + self.dropout(self.s_att(src))
+
+     src = self.norm(src_out)
+     src_f = src + self.dropout(self.ffwd(src))
+
+     del src_out, src
+     return src_f
+
+ class DecoderNetwork(nn.Module):
+   """pre-norm decoder block: masked self-attention, cross-attention over the encoder output, then feed-forward"""
+   def __init__(self, d_model, n_head, norm_eps, dropout, block_size):
+     super().__init__()
+     self.m_att = CasualMaskedAttention(n_head=n_head, d_model=d_model, dropout=dropout, block_size=block_size)
+     self.f_att = FinalAttention(d_model=d_model, n_head=n_head, dropout=dropout, block_size=block_size)
+     self.ffwd = FeedForward(d_model, dropout)
+     self.dropout = nn.Dropout(dropout)
+     self.norm = RMSNorm(d_model, eps=norm_eps)
+
+   def forward(self, src, att):
+     m_att_out = self.norm(src)
+     m_out = src + self.dropout(self.m_att(m_att_out))
+
+     f_out = self.f_att(m_out, self.norm(att))
+     f_out = m_out + self.dropout(f_out)
+
+     src_f = self.norm(f_out)
+     src_f = f_out + self.dropout(self.ffwd(src_f))
+
+     del f_out, m_out, m_att_out, src, att
+     return src_f
+
+ class Transformer(nn.Module):
+   """encoder-decoder transformer over token + positional embeddings, ending in a vocab projection"""
+   def __init__(self, vocab_size):
+     super().__init__()
+     self.block_size = block_size
+     self.toked_model = nn.Embedding(vocab_size, d_model)
+     self.pos_encod = nn.Embedding(block_size, d_model)
+     self.enc_layer = nn.ModuleList([EncoderNetwork(n_head=n_head, norm_eps=norm_eps, block_size=block_size, dropout=dropout, d_model=d_model) for _ in range(n_layers)])
+     self.dec_layer = nn.ModuleList([DecoderNetwork(n_head=n_head, norm_eps=norm_eps, block_size=block_size, dropout=dropout, d_model=d_model) for _ in range(n_layers)])
+     self.norm_final = RMSNorm(d_model, eps=norm_eps)
+     self.linear_final = nn.Linear(d_model, vocab_size)
+     self.dropout = nn.Dropout(dropout)
+     self.apply(self._init_weights)
+
+   def _init_weights(self, module):
+     """
+     initialize weights of linear and embedding layers
+
+     Args:
+       - module (nn.Module): the module to initialize weights for
+     """
+     if isinstance(module, nn.Linear):
+       torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+       if module.bias is not None:
+         torch.nn.init.zeros_(module.bias.data)
+     elif isinstance(module, nn.Embedding):
+       torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+
+   def forward(self, idx, targets=None):
+     """
+     forward pass of the transformer model
+
+     Args:
+       - idx (Tensor): input tensor of token indices
+       - targets (Tensor): target tensor for computing loss during training
+
+     Returns:
+       - logits (Tensor): output logits from the final linear layer
+       - loss (Tensor): optional; computed cross-entropy loss if targets are provided, else None
+     """
+     B, T = idx.shape
+
+     toked_model = self.toked_model(idx)
+     pos_encod = self.pos_encod(torch.arange(T, device=device))
+     x = toked_model + pos_encod
+
+     # chain each layer's output into the next so the full stack is used
+     x_out = x
+     for layer in self.enc_layer:
+       x_out = layer(x_out)
+
+     x_final = x
+     for layer in self.dec_layer:
+       x_final = layer(x_final, x_out)
+
+     x_final = self.norm_final(x_final)
+     logits = self.linear_final(x_final)
+
+     if targets is None:
+       loss = None
+     else:
+       B, T, C = logits.shape
+       logits = logits.view(B*T, C)
+       targets = targets.view(B*T)
+       loss = F.cross_entropy(logits, targets)
+
+     return logits, loss
base/run.py ADDED
@@ -0,0 +1,101 @@
+ """
+ use this file to train the model
+
+ working:
+   - imports various dependencies first, and then loads the training data
+   - tokenizes it with the tiktoken-based Tokenizer
+   - loads the required hyperparameters and the model file
+   - trains until 'max_iters', saves the model state, and generates outputs
+
+ with the current configuration, the model can reach up to ~60 million parameters
+ and can become ~99% accurate at next-token prediction
+ """
+
+ import torch
+ import json
+ import os
+ current_directory = os.path.dirname(os.path.abspath(__file__))
+ os.chdir(current_directory)
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+ with open('../datasets/wiki_176m.txt', 'r', encoding='utf-8') as file:
+   data = file.read()
+
+ print(f"{(len(data)/1e6):.2f} million letters")
+
+ from tokenizer import Tokenizer
+
+ tokenizer = Tokenizer()
+ vocab_size = tokenizer.get_vocab()
+
+ # Train and test splits
+ data = torch.tensor(tokenizer.encode(data), dtype=torch.long)
+ n = int(0.9*len(data))  # first 90% will be train, rest val
+ train_data = data[:n]
+ val_data = data[n:]
+
+ with open('config.json', 'r', encoding='utf-8') as file:
+   params = json.load(file)
+
+ # required parameters
+ batch_size = params['batch_size']
+ block_size = params['block_size']
+ max_iters = 1000
+ eval_interval = 100
+ eval_iters = 200
+ learning_rate = params['learning_rate']
+
+ torch.manual_seed(1400)
+
+ # data loading
+ def get_batch(split):
+   # generate a small batch of data of inputs x and targets y
+   data = train_data if split == 'train' else val_data
+   ix = torch.randint(len(data) - block_size, (batch_size,))
+   x = torch.stack([data[i:i+block_size] for i in ix])
+   y = torch.stack([data[i+1:i+block_size+1] for i in ix])
+   x, y = x.to(device), y.to(device)
+   return x, y
+
+ @torch.no_grad()
+ def estimate_loss():
+   # average the loss over eval_iters batches for both splits
+   out = {}
+   model.eval()
+   for split in ['train', 'val']:
+     losses = torch.zeros(eval_iters)
+     for k in range(eval_iters):
+       X, Y = get_batch(split)
+       logits, loss = model(X, Y)
+       losses[k] = loss.item()
+     out[split] = losses.mean()
+   model.train()
+   return out
+
+ from model import Transformer
+ model = Transformer(vocab_size)
+ m = model.to(device)
+
+ # number of parameters
+ n_param = sum(p.numel() for p in m.parameters())/1e6
+ print(f"{n_param:.2f} million parameters")
+
+ optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
+ steps = []
+ train_losses = []
+ val_losses = []
+
+ for iter in range(max_iters):
+
+   # periodically evaluate the loss on the train and val splits
+   if iter % eval_interval == 0 or iter == max_iters - 1:
+     losses = estimate_loss()
+     print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
+
+     steps.append(iter)
+     train_losses.append(losses['train'])
+     val_losses.append(losses['val'])
+
+   # sample a batch, compute the loss, and take one optimizer step
+   xb, yb = get_batch('train')
+   logits, loss = model(xb, yb)
+   optimizer.zero_grad(set_to_none=True)
+   loss.backward()
+   optimizer.step()
+
+ torch.save(model.state_dict(), f'enigma_{n_param:.0f}m.pth')
base/tokenizer.py ADDED
@@ -0,0 +1,19 @@
+ import tiktoken
+
+ pre_encodings = 'p50k_base'
+ pre_model = 'text-davinci-003'
+
+ class Tokenizer:
+   def __init__(self, encoding=None, model=None):
+     self.encodings = encoding if encoding is not None else pre_encodings
+     self.model = model if model is not None else pre_model
+     # an explicitly passed encoding wins; otherwise derive the encoding from the model name
+     if encoding is not None:
+       self.tokenizer = tiktoken.get_encoding(self.encodings)
+     else:
+       self.tokenizer = tiktoken.encoding_for_model(self.model)
+
+   def encode(self, data):
+     return self.tokenizer.encode(data)
+
+   def decode(self, tokens):
+     return self.tokenizer.decode(tokens)
+
+   def get_vocab(self):
+     return self.tokenizer.n_vocab
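Note: a short usage sketch for the tokenizer (assumed usage, run from inside base/); get_vocab() is what sizes the embedding and output layers of both Transformer variants.

from tokenizer import Tokenizer

tok = Tokenizer()                      # defaults to the text-davinci-003 / p50k_base encoding
ids = tok.encode("I was in the market when")
print(ids)                             # list of integer token ids
print(tok.decode(ids))                 # round-trips back to the original string
print(tok.get_vocab())                 # vocabulary size passed to Transformer(vocab_size)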