RobbiePasquale committed on
Commit
fa2c349
1 Parent(s): c453493

Upload lightbulb_lm.py

Files changed (1)
  1. lightbulb_lm.py +517 -0
lightbulb_lm.py ADDED
@@ -0,0 +1,517 @@
#!/usr/bin/env python

import argparse
import math
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.amp import autocast, GradScaler
from datasets import load_dataset
from transformers import AutoTokenizer

# Set the device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def parse_args():
    parser = argparse.ArgumentParser(description='Train Transformer model with advanced features.')
    parser.add_argument('--model_name', type=str, default='gpt2', help='Pretrained model name or path')
    parser.add_argument('--dataset_name', type=str, default='wikitext', help='Dataset name from HuggingFace Datasets')
    parser.add_argument('--dataset_config', type=str, default='wikitext-2-raw-v1', help='Dataset configuration name')
    parser.add_argument('--batch_size', type=int, default=8, help='Batch size')
    parser.add_argument('--num_epochs', type=int, default=3, help='Number of epochs')
    parser.add_argument('--max_length', type=int, default=128, help='Maximum sequence length')
    parser.add_argument('--accumulation_steps', type=int, default=4, help='Gradient accumulation steps')
    parser.add_argument('--learning_rate', type=float, default=1e-4, help='Learning rate')
    parser.add_argument('--weight_decay', type=float, default=1e-2, help='Weight decay')
    parser.add_argument('--alpha', type=float, default=0.1, help='Entropy regularization weight')
    parser.add_argument('--beta', type=float, default=0.1, help='Variance regularization weight')
    parser.add_argument('--max_grad_norm', type=float, default=1.0, help='Max gradient norm for clipping')
    parser.add_argument('--save_dir', type=str, default='./models', help='Directory to save the models')
    parser.add_argument('--temperature', type=float, default=1.0, help='Temperature parameter for entropy and variance')
    args = parser.parse_args()
    return args


def load_data(args, tokenizer):
    # Load the dataset
    dataset = load_dataset(args.dataset_name, args.dataset_config)

    # Ensure the tokenizer has a padding token
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    def tokenize_function(examples):
        return tokenizer(examples['text'], truncation=True, max_length=args.max_length)

    tokenized_datasets = dataset.map(
        tokenize_function,
        batched=True,
        num_proc=4,
        remove_columns=dataset['train'].column_names,
    )

    # Build inputs and labels for language modeling
    block_size = args.max_length

    def group_texts(examples):
        # Concatenate all texts
        concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
        total_length = len(concatenated_examples['input_ids'])
        # We drop the small remainder
        total_length = (total_length // block_size) * block_size
        # Split by chunks of block_size
        result = {
            k: [t[i: i + block_size] for i in range(0, total_length, block_size)]
            for k, t in concatenated_examples.items()
        }
        result['labels'] = result['input_ids'].copy()
        return result

    lm_datasets = tokenized_datasets.map(
        group_texts,
        batched=True,
        num_proc=4,
    )

    # Create DataLoaders
    train_dataset = lm_datasets['train']
    eval_dataset = lm_datasets['validation'] if 'validation' in lm_datasets else lm_datasets['test']

    data_collator = lambda data: {
        'input_ids': torch.tensor([f['input_ids'] for f in data], dtype=torch.long),
        'labels': torch.tensor([f['labels'] for f in data], dtype=torch.long)
    }

    train_loader = DataLoader(train_dataset, shuffle=True, batch_size=args.batch_size, collate_fn=data_collator)
    eval_loader = DataLoader(eval_dataset, shuffle=False, batch_size=args.batch_size, collate_fn=data_collator)

    return train_loader, eval_loader


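# Illustrative sketch (hypothetical helper, not called anywhere): with the
# collator above, each batch is a dict of two LongTensors of identical shape
# (batch_size, max_length), because group_texts drops the trailing remainder so
# every example is exactly block_size tokens long.
def _demo_batch_shapes(train_loader, args):
    batch = next(iter(train_loader))
    assert batch['input_ids'].shape == batch['labels'].shape
    assert batch['input_ids'].shape[1] == args.max_length

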
class RotaryPositionalEncoding(nn.Module):
    def __init__(self, d_model):
        super(RotaryPositionalEncoding, self).__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, d_model, 2).float() / d_model))
        self.register_buffer('inv_freq', inv_freq)

    def forward(self, x):
        seq_len, batch_size, _ = x.size()
        t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
        sinusoid_inp = torch.einsum("i,j->ij", t, self.inv_freq)
        sin = sinusoid_inp.sin().unsqueeze(1)  # (seq_len, 1, d_model/2)
        cos = sinusoid_inp.cos().unsqueeze(1)  # (seq_len, 1, d_model/2)

        x1 = x[..., 0::2]
        x2 = x[..., 1::2]

        # Apply rotation
        x_rotated = torch.zeros_like(x)
        x_rotated[..., 0::2] = x1 * cos - x2 * sin
        x_rotated[..., 1::2] = x1 * sin + x2 * cos

        return x_rotated


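# Illustrative sketch (hypothetical, not invoked): the rotary encoding above
# expects input shaped (seq_len, batch_size, d_model) and, being a pure rotation
# of (even, odd) channel pairs, should preserve both the shape and the per-token
# vector norm. A minimal check under that assumption (d_model must be even):
def _demo_rotary_shapes():
    d_model = 512
    rope = RotaryPositionalEncoding(d_model)
    x = torch.randn(10, 2, d_model)  # (seq_len, batch_size, d_model)
    y = rope(x)
    assert y.shape == x.shape
    # Norms match up to floating-point error because each pair is rotated orthogonally
    assert torch.allclose(x.norm(dim=-1), y.norm(dim=-1), atol=1e-4)

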
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_k = d_model // num_heads
        self.num_heads = num_heads
        self.linear_q = nn.Linear(d_model, d_model)
        self.linear_k = nn.Linear(d_model, d_model)
        self.linear_v = nn.Linear(d_model, d_model)
        self.linear_out = nn.Linear(d_model, d_model)

    def forward(self, query, key, value, mask=None):
        batch_size = query.size(0)
        query = self.linear_q(query).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        key = self.linear_k(key).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        value = self.linear_v(value).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        attn = F.softmax(scores, dim=-1)
        output = torch.matmul(attn, value)

        output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.num_heads * self.d_k)
        return self.linear_out(output)


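# Illustrative sketch: the attention above fills positions where `mask == 0`
# with -1e9, so a causal (autoregressive) mask compatible with this module can
# be built as a lower-triangular matrix of ones. This helper is a hypothetical
# example of how such a mask could be constructed; the training and generation
# code below currently passes mask=None.
def _demo_causal_mask(seq_len, device=None):
    # Shape (1, 1, seq_len, seq_len) broadcasts over the (batch, heads, q, k) scores
    return torch.tril(torch.ones(seq_len, seq_len, device=device)).unsqueeze(0).unsqueeze(0)

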
class MoE(nn.Module):
    def __init__(self, d_model, num_experts, d_ff, top_k=2, dropout=0.1):
        super(MoE, self).__init__()
        self.num_experts = num_experts
        self.top_k = top_k
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(d_model, d_ff),
                nn.GELU() if i % 2 == 0 else nn.SiLU(),
                nn.Linear(d_ff, d_model)
            )
            for i in range(num_experts)
        ])
        self.gate = nn.Linear(d_model, num_experts)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        batch_size, seq_len, d_model = x.size()
        # Compute gating scores
        gate_scores = self.gate(x)  # (batch_size, seq_len, num_experts)
        top_k_scores, top_k_indices = torch.topk(gate_scores, self.top_k, dim=-1)  # (batch_size, seq_len, top_k)
        top_k_scores = F.softmax(top_k_scores, dim=-1)  # (batch_size, seq_len, top_k)

        # Initialize output
        output = torch.zeros_like(x)

        # Flatten batch and sequence dimensions
        x_flat = x.view(-1, d_model)  # (batch_size * seq_len, d_model)
        output_flat = output.view(-1, d_model)
        top_k_indices_flat = top_k_indices.view(-1, self.top_k)  # (batch_size * seq_len, top_k)
        top_k_scores_flat = top_k_scores.view(-1, self.top_k)  # (batch_size * seq_len, top_k)

        for k in range(self.top_k):
            expert_idx_flat = top_k_indices_flat[:, k]  # (batch_size * seq_len)
            expert_scores_flat = top_k_scores_flat[:, k]  # (batch_size * seq_len)
            for e in range(self.num_experts):
                mask = (expert_idx_flat == e)  # Boolean mask
                if mask.any():
                    x_masked = x_flat[mask]  # Select tokens for expert e
                    expert_output = self.experts[e](x_masked)  # Apply expert e
                    output_flat[mask] += expert_scores_flat[mask].unsqueeze(-1) * expert_output

        output = output_flat.view(batch_size, seq_len, d_model)
        return self.dropout(output)


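# Illustrative sketch (hypothetical, not invoked): the MoE layer routes each
# token to its top_k experts and mixes their outputs with the softmaxed gate
# scores, so the output shape always matches the input shape.
def _demo_moe_shapes():
    moe = MoE(d_model=64, num_experts=4, d_ff=128, top_k=2)
    x = torch.randn(3, 7, 64)  # (batch_size, seq_len, d_model)
    assert moe(x).shape == x.shape

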
class TransformerBlock(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, num_experts, dropout=0.1, top_k=2):
        super(TransformerBlock, self).__init__()
        self.self_attention = MultiHeadAttention(d_model, num_heads)
        self.norm1 = nn.LayerNorm(d_model)
        self.cross_attention = MultiHeadAttention(d_model, num_heads)
        self.norm2 = nn.LayerNorm(d_model)
        self.moe = MoE(d_model, num_experts, d_ff, top_k, dropout)
        self.norm3 = nn.LayerNorm(d_model)

    def forward(self, x, mask=None, enc_output=None, enc_mask=None):
        # Self-attention
        attn_output = self.self_attention(x, x, x, mask)
        x = self.norm1(x + attn_output)
        # Cross-attention (only in decoder)
        if enc_output is not None:
            cross_attn_output = self.cross_attention(x, enc_output, enc_output, enc_mask)
            x = self.norm2(x + cross_attn_output)
        # Feedforward/MoE
        moe_output = self.moe(x)
        return self.norm3(x + moe_output)


class Transformer(nn.Module):
    def __init__(self, input_dim, d_model, num_heads, num_layers, d_ff, num_experts, output_dim, dropout=0.1, top_k=2):
        super(Transformer, self).__init__()
        self.embedding = nn.Embedding(input_dim, d_model, padding_idx=input_dim - 1)
        self.rotary_positional_encoding = RotaryPositionalEncoding(d_model)
        self.encoder_layers = nn.ModuleList(
            [TransformerBlock(d_model, num_heads, d_ff, num_experts, dropout, top_k) for _ in range(num_layers)]
        )
        self.decoder_layers = nn.ModuleList(
            [TransformerBlock(d_model, num_heads, d_ff, num_experts, dropout, top_k) for _ in range(num_layers)]
        )
        self.output_layer = nn.Linear(d_model, output_dim)
        self.d_model = d_model

    def forward(self, src, tgt, src_mask=None, tgt_mask=None):
        # Encoder
        src = self.embedding(src) * math.sqrt(self.d_model)
        src = src.transpose(0, 1)  # (batch_size, seq_len, d_model) -> (seq_len, batch_size, d_model)
        src = self.rotary_positional_encoding(src)
        src = src.transpose(0, 1)  # (seq_len, batch_size, d_model) -> (batch_size, seq_len, d_model)
        for layer in self.encoder_layers:
            src = layer(src, src_mask)

        # Decoder
        tgt = self.embedding(tgt) * math.sqrt(self.d_model)
        tgt = tgt.transpose(0, 1)
        tgt = self.rotary_positional_encoding(tgt)
        tgt = tgt.transpose(0, 1)
        for layer in self.decoder_layers:
            tgt = layer(tgt, tgt_mask, src, src_mask)
        output = self.output_layer(tgt)
        return output

    def generate(self, src, tokenizer, max_length=20, temperature=1.0):
        """
        Generate sequences using differentiable sampling (Gumbel-Softmax).

        Args:
            src (torch.Tensor): Source input tensor of shape (batch_size, seq_len)
            tokenizer (transformers.PreTrainedTokenizer): Tokenizer to access special tokens
            max_length (int): Maximum length of the generated sequence
            temperature (float): Temperature parameter for Gumbel-Softmax

        Returns:
            torch.Tensor: Generated sequences of shape (batch_size, max_length)
            torch.Tensor: Entropy values for each time step
            torch.Tensor: Variance values for each time step
        """
        batch_size = src.size(0)

        # Encode the source
        src_enc = self.embedding(src) * math.sqrt(self.d_model)
        src_enc = src_enc.transpose(0, 1)
        src_enc = self.rotary_positional_encoding(src_enc)
        src_enc = src_enc.transpose(0, 1)
        for layer in self.encoder_layers:
            src_enc = layer(src_enc)

        # Initialize decoder input with <sos> tokens
        tgt_seq = torch.full((batch_size, 1), tokenizer.bos_token_id, dtype=torch.long, device=src.device)
        entropies = []
        variances = []

        for _ in range(max_length):
            tgt_emb = self.embedding(tgt_seq) * math.sqrt(self.d_model)
            tgt_emb = tgt_emb.transpose(0, 1)
            tgt_emb = self.rotary_positional_encoding(tgt_emb)
            tgt_emb = tgt_emb.transpose(0, 1)
            tgt_dec = tgt_emb
            for layer in self.decoder_layers:
                tgt_dec = layer(tgt_dec, None, src_enc, None)
            output = self.output_layer(tgt_dec)  # (batch_size, seq_len, vocab_size)
            logits = output[:, -1, :]  # Get logits for the last time step

            # Compute token probabilities
            probs = F.softmax(logits / temperature, dim=-1)  # (batch_size, vocab_size)

            # Compute entropy
            entropy = -torch.sum(probs * torch.log(probs + 1e-9), dim=-1)  # (batch_size)
            entropies.append(entropy)

            # Sample token using Gumbel-Softmax
            gumbel_noise = -torch.log(-torch.log(torch.rand_like(probs) + 1e-9) + 1e-9)
            y = (logits + gumbel_noise) / temperature
            y = F.softmax(y, dim=-1)  # (batch_size, vocab_size)

            # Compute variance
            variance = torch.var(y, dim=-1)  # (batch_size)
            variances.append(variance)

            # Get token indices (argmax for hard selection)
            next_tokens = torch.argmax(y, dim=-1, keepdim=True)  # (batch_size, 1)
            tgt_seq = torch.cat([tgt_seq, next_tokens], dim=1)

        # Stack entropies and variances
        entropies = torch.stack(entropies, dim=1)  # (batch_size, max_length)
        variances = torch.stack(variances, dim=1)  # (batch_size, max_length)

        return tgt_seq[:, 1:], entropies, variances  # Exclude the initial <sos> token


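# Illustrative usage sketch for Transformer.generate (assumes a GPT-2-style
# tokenizer where bos_token_id is defined; for GPT-2 it equals eos_token_id).
# Hypothetical helper, not part of the training loop; shown only to document
# the expected shapes of the returned tensors.
def _demo_generate(model, tokenizer):
    prompt = tokenizer("The lightbulb", return_tensors='pt')['input_ids'].to(device)
    tokens, entropies, variances = model.generate(prompt, tokenizer, max_length=20, temperature=1.0)
    # tokens: (batch_size, 20); entropies, variances: (batch_size, 20)
    return tokenizer.decode(tokens[0].tolist())

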
def compute_loss(output, target, padding_idx, alpha=0.1, beta=0.1, temperature=1.0):
    """
    Compute the loss with entropy and variance regularization.

    Args:
        output (torch.Tensor): Model output logits of shape (batch_size, seq_len, vocab_size)
        target (torch.Tensor): Target sequences of shape (batch_size, seq_len)
        padding_idx (int): Padding index to ignore in the loss
        alpha (float): Weight for the entropy regularization term
        beta (float): Weight for the variance regularization term
        temperature (float): Temperature parameter for computing probabilities

    Returns:
        torch.Tensor: Scalar loss value
    """
    # Cross-entropy loss
    output_flat = output.contiguous().view(-1, output.size(-1))
    target_flat = target.contiguous().view(-1)
    ce_loss = F.cross_entropy(
        output_flat,
        target_flat,
        ignore_index=padding_idx
    )

    # Compute probabilities
    probs = F.softmax(output / temperature, dim=-1)  # (batch_size, seq_len, vocab_size)

    # Compute entropy
    entropy = -torch.sum(probs * torch.log(probs + 1e-9), dim=-1)  # (batch_size, seq_len)
    entropy_loss = -alpha * torch.mean(entropy)

    # Compute variance
    variance = torch.var(probs, dim=-1)  # (batch_size, seq_len)
    variance_loss = -beta * torch.mean(variance)

    # Total loss
    total_loss = ce_loss + entropy_loss + variance_loss
    return total_loss


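# Illustrative sketch: the objective above is
#     L = CE(output, target) - alpha * mean(H(p)) - beta * mean(Var(p)),
# i.e. higher predictive entropy and higher per-token probability variance are
# rewarded, since both terms enter with negative weights. A minimal shape check
# under hypothetical sizes (not part of the training pipeline):
def _demo_compute_loss():
    vocab, pad_id = 100, 0
    logits = torch.randn(2, 5, vocab)          # (batch_size, seq_len, vocab_size)
    targets = torch.randint(1, vocab, (2, 5))  # avoid the padding index
    loss = compute_loss(logits, targets, padding_idx=pad_id, alpha=0.1, beta=0.1)
    assert loss.dim() == 0  # scalar loss

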
def train_epoch(model, train_loader, optimizer, scheduler, scaler, args, padding_idx):
    model.train()
    total_loss = 0.0
    optimizer.zero_grad()
    print(f"Starting training epoch with {len(train_loader)} batches...")
    for i, batch in enumerate(train_loader):
        print(f"Processing batch {i+1}/{len(train_loader)}...")
        src_batch = batch['input_ids'].to(device)
        tgt_batch = batch['labels'].to(device)

        with autocast(device_type='cuda'):
            print("Forward pass...")
            output = model(src_batch, tgt_batch[:, :-1])
            print("Computing loss...")
            loss = compute_loss(
                output,
                tgt_batch[:, 1:],
                padding_idx,
                alpha=args.alpha,
                beta=args.beta,
                temperature=args.temperature
            )
            loss = loss / args.accumulation_steps

        print("Backward pass...")
        scaler.scale(loss).backward()

        if (i + 1) % args.accumulation_steps == 0:
            print("Gradient clipping...")
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)

            print("Optimizer step...")
            scaler.step(optimizer)
            scaler.update()

            print("Zeroing gradients...")
            optimizer.zero_grad()

            print("Updating learning rate...")
            scheduler.step()

        total_loss += loss.item() * args.accumulation_steps
        print(f"Batch {i+1} completed. Current loss: {loss.item():.4f}")

    avg_loss = total_loss / len(train_loader)
    print(f"Epoch completed. Average loss: {avg_loss:.4f}")
    return avg_loss


def evaluate(model, eval_loader, args, padding_idx):
    model.eval()
    total_loss = 0.0
    with torch.no_grad():
        for batch in eval_loader:
            src_batch = batch['input_ids'].to(device)
            tgt_batch = batch['labels'].to(device)

            with autocast(device_type='cuda'):
                # Forward pass
                output = model(src_batch, tgt_batch[:, :-1])
                # Compute loss
                loss = compute_loss(
                    output,
                    tgt_batch[:, 1:],
                    padding_idx,
                    alpha=args.alpha,
                    beta=args.beta,
                    temperature=args.temperature
                )

            total_loss += loss.item()

    avg_loss = total_loss / len(eval_loader)
    return avg_loss


def main():
    args = parse_args()
    print("Arguments parsed successfully.")

    # Create save directory
    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)
        print(f"Save directory created: {args.save_dir}")

    # Load tokenizer
    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    print("Tokenizer loaded successfully.")

    # Load data
    print("Loading and preprocessing data...")
    train_loader, eval_loader = load_data(args, tokenizer)
    print("Data loaded and preprocessed successfully.")

    # Define model parameters
    input_dim = len(tokenizer)
    d_model = 512
    num_heads = 8
    num_layers = 6
    d_ff = 2048
    num_experts = 4
    output_dim = input_dim
    dropout = 0.1
    top_k = 2

    print("Initializing model...")
    model = Transformer(
        input_dim, d_model, num_heads, num_layers, d_ff, num_experts, output_dim, dropout, top_k
    )
    model = model.to(device)
    print(f"Model initialized and moved to device: {device}")

    padding_idx = tokenizer.pad_token_id

    print("Setting up optimizer and scheduler...")
    optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay)
    scheduler = CosineAnnealingLR(optimizer, T_max=args.num_epochs)
    scaler = GradScaler()
    print("Optimizer and scheduler set up successfully.")

    print("Starting training loop...")
    for epoch in range(args.num_epochs):
        print(f"Epoch {epoch + 1}/{args.num_epochs} started.")
        avg_train_loss = train_epoch(
            model,
            train_loader,
            optimizer,
            scheduler,
            scaler,
            args,
            padding_idx
        )
        print(f"Epoch {epoch + 1}/{args.num_epochs} training completed.")

        print(f"Starting evaluation for epoch {epoch + 1}...")
        avg_eval_loss = evaluate(model, eval_loader, args, padding_idx)
        print(f"Evaluation for epoch {epoch + 1} completed.")

        print(f"Epoch {epoch + 1}/{args.num_epochs}, Train Loss: {avg_train_loss:.4f}, Eval Loss: {avg_eval_loss:.4f}")

        model_save_path = os.path.join(args.save_dir, f"model_epoch_{epoch + 1}.pt")
        torch.save(model.state_dict(), model_save_path)
        print(f"Model saved for epoch {epoch + 1}")

    print("Training completed.")


if __name__ == '__main__':
    main()


'''
Example usage:
python lightbulb_lm.py --model_name gpt2 --dataset_name wikitext --dataset_config wikitext-2-raw-v1 --batch_size 8 --num_epochs 3
'''