RobbiePasquale committed on
Commit a8090dd
1 Parent(s): 70782ac

Update distill.py

Files changed (1)
  1. distill.py +838 -264
distill.py CHANGED
@@ -1,264 +1,838 @@
- import torch
- import torch.nn as nn
- import torch.optim as optim
- from torch.utils.data import DataLoader, Dataset, random_split
- from transformers import AutoModelForCausalLM, AutoTokenizer
- from datasets import load_dataset
- from typing import List, Optional
- import argparse
- import os
- import json
- import jsonlines
- from tqdm import tqdm
- from torch.cuda.amp import autocast, GradScaler
- from torch.utils.tensorboard import SummaryWriter
-
- # Set up device
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
- class CustomDataset(Dataset):
-     def __init__(self, inputs, labels):
-         self.inputs = inputs
-         self.labels = labels
-
-     def __len__(self):
-         return len(self.inputs)
-
-     def __getitem__(self, idx):
-         return {'input_ids': self.inputs[idx], 'labels': self.labels[idx]}
-
- def load_filtered_dataset(dataset_name: str, config: str, queries: Optional[List[str]] = None):
-     dataset = load_dataset(dataset_name, config)
-     if queries:
-         def filter_func(examples):
-             return any(query.lower() in examples["text"].lower() for query in queries)
-         dataset = dataset.filter(filter_func, batched=True)
-     return dataset
-
- def prepare_data(tokenizer, dataset, max_length, batch_size):
-     # Tokenize the inputs and labels
-     tokenized_inputs = tokenizer(dataset["train"]["text"], return_tensors="pt", padding=True, truncation=True, max_length=max_length)
-     tokenized_labels = tokenizer(dataset["train"]["text"], return_tensors="pt", padding=True, truncation=True, max_length=max_length)
-
-     # Create custom dataset
-     custom_dataset = CustomDataset(tokenized_inputs["input_ids"], tokenized_labels["input_ids"])
-
-     # Split into training and validation sets
-     train_size = int(0.9 * len(custom_dataset))
-     val_size = len(custom_dataset) - train_size
-     train_dataset, val_dataset = random_split(custom_dataset, [train_size, val_size])
-
-     # Create DataLoaders
-     train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
-     val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)
-
-     return train_loader, val_loader
-
- def train_step(teacher, student, data_loader, optimizer, criterion, scaler, temperature=2.0):
-     teacher.eval()
-     student.train()
-     total_loss = 0
-
-     for batch in tqdm(data_loader, desc="Training"):
-         inputs = batch["input_ids"].to(device)
-         labels = batch["labels"].to(device)
-
-         with autocast():
-             with torch.no_grad():
-                 teacher_outputs = teacher(inputs).logits
-                 teacher_logits = teacher_outputs / temperature
-
-             student_outputs = student(inputs).logits
-             student_logits = student_outputs / temperature
-
-             # Compute KL Divergence Loss
-             loss = criterion(nn.functional.log_softmax(student_logits, dim=-1), nn.functional.softmax(teacher_logits, dim=-1))
-             loss = loss * (temperature ** 2)  # Scale loss by temperature squared
-
-         scaler.scale(loss).backward()
-         scaler.step(optimizer)
-         scaler.update()
-         optimizer.zero_grad()
-
-         total_loss += loss.item()
-
-     avg_loss = total_loss / len(data_loader)
-     return avg_loss
-
- def validate(teacher, student, data_loader, criterion, temperature=2.0):
-     teacher.eval()
-     student.eval()
-     total_loss = 0
-
-     with torch.no_grad():
-         for batch in tqdm(data_loader, desc="Validation"):
-             inputs = batch["input_ids"].to(device)
-             labels = batch["labels"].to(device)
-
-             teacher_outputs = teacher(inputs).logits
-             teacher_logits = teacher_outputs / temperature
-
-             student_outputs = student(inputs).logits
-             student_logits = student_outputs / temperature
-
-             loss = criterion(nn.functional.log_softmax(student_logits, dim=-1), nn.functional.softmax(teacher_logits, dim=-1))
-             loss = loss * (temperature ** 2)
-
-             total_loss += loss.item()
-
-     avg_loss = total_loss / len(data_loader)
-     return avg_loss
-
- def save_checkpoint(state, save_dir, epoch):
-     os.makedirs(save_dir, exist_ok=True)
-     checkpoint_path = os.path.join(save_dir, f'checkpoint_epoch_{epoch}.pt')
-     torch.save(state, checkpoint_path)
-     print(f"Checkpoint saved at {checkpoint_path}")
-
- def load_checkpoint(model, optimizer, scheduler, scaler, save_dir, epoch):
-     checkpoint_path = os.path.join(save_dir, f'checkpoint_epoch_{epoch}.pt')
-     if os.path.isfile(checkpoint_path):
-         checkpoint = torch.load(checkpoint_path)
-         model.load_state_dict(checkpoint['model_state_dict'])
-         optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
-         scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
-         scaler.load_state_dict(checkpoint['scaler_state_dict'])
-         print(f"Loaded checkpoint from {checkpoint_path}")
-     else:
-         print(f"No checkpoint found at {checkpoint_path}")
-
- def distill_model(
-     teacher_model_name: str,
-     student_model_name: str,
-     dataset_name: str,
-     config: str,
-     distill_full_model: bool = True,
-     query_terms: Optional[List[str]] = None,
-     num_epochs: int = 3,
-     batch_size: int = 4,
-     max_length: int = 128,
-     learning_rate: float = 5e-5,
-     temperature: float = 2.0,
-     save_path: str = "./distilled_model",
-     log_dir: str = "./logs",
-     checkpoint_dir: str = "./checkpoints",
-     early_stopping_patience: int = 3
- ):
-     # Initialize TensorBoard writer
-     writer = SummaryWriter(log_dir=log_dir)
-
-     # Load tokenizer
-     tokenizer = AutoTokenizer.from_pretrained(teacher_model_name)
-     if tokenizer.pad_token is None:
-         tokenizer.pad_token = tokenizer.eos_token
-
-     # Load teacher and student models
-     teacher = AutoModelForCausalLM.from_pretrained(teacher_model_name).to(device)
-     student = AutoModelForCausalLM.from_pretrained(student_model_name).to(device)
-
-     # Optionally freeze teacher model parameters
-     for param in teacher.parameters():
-         param.requires_grad = False
-
-     # Load and prepare dataset
-     if distill_full_model:
-         dataset = load_dataset(dataset_name, config)
-     else:
-         dataset = load_filtered_dataset(dataset_name, config, query_terms)
-
-     train_loader, val_loader = prepare_data(tokenizer, dataset, max_length, batch_size)
-
-     # Define optimizer, scheduler, and scaler for mixed precision
-     optimizer = optim.AdamW(student.parameters(), lr=learning_rate)
-     scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
-     scaler = GradScaler()
-
-     # Define loss criterion
-     criterion = nn.KLDivLoss(reduction="batchmean")
-
-     best_val_loss = float('inf')
-     epochs_no_improve = 0
-
-     # Training loop
-     for epoch in range(1, num_epochs + 1):
-         print(f"\nEpoch {epoch}/{num_epochs}")
-         print("-" * 20)
-
-         # Training
-         train_loss = train_step(teacher, student, train_loader, optimizer, criterion, scaler, temperature)
-         print(f"Training Loss: {train_loss:.4f}")
-         writer.add_scalar("Loss/Train", train_loss, epoch)
-
-         # Validation
-         val_loss = validate(teacher, student, val_loader, criterion, temperature)
-         print(f"Validation Loss: {val_loss:.4f}")
-         writer.add_scalar("Loss/Validation", val_loss, epoch)
-
-         # Check for improvement
-         if val_loss < best_val_loss:
-             best_val_loss = val_loss
-             epochs_no_improve = 0
-             # Save the best model
-             save_checkpoint({
-                 'epoch': epoch,
-                 'model_state_dict': student.state_dict(),
-                 'optimizer_state_dict': optimizer.state_dict(),
-                 'scheduler_state_dict': scheduler.state_dict(),
-                 'scaler_state_dict': scaler.state_dict(),
-                 'best_val_loss': best_val_loss
-             }, checkpoint_dir, epoch)
-             # Save the model as the best one
-             student.save_pretrained(save_path)
-             tokenizer.save_pretrained(save_path)
-             print(f"Best model saved at epoch {epoch}")
-         else:
-             epochs_no_improve += 1
-             print(f"No improvement in validation loss for {epochs_no_improve} epoch(s)")
-             if epochs_no_improve >= early_stopping_patience:
-                 print("Early stopping triggered")
-                 break
-
-         # Step the scheduler
-         scheduler.step()
-
-     writer.close()
-     print("\nDistillation completed.")
-
- def main():
-     parser = argparse.ArgumentParser(description="Distill a large LLM into a smaller one.")
-     parser.add_argument("--teacher_model_name", type=str, required=True, help="Name of the teacher model")
-     parser.add_argument("--student_model_name", type=str, required=True, help="Name of the student model")
-     parser.add_argument("--dataset_name", type=str, required=True, help="Name of the dataset")
-     parser.add_argument("--config", type=str, default=None, help="Dataset configuration (e.g., 'wikitext-2-raw-v1')")
-     parser.add_argument("--distill_full_model", action="store_true", help="Whether to distill the full model or not")
-     parser.add_argument("--query_terms", type=str, nargs="+", help="Query terms for filtering the dataset")
-     parser.add_argument("--num_epochs", type=int, default=3, help="Number of epochs")
-     parser.add_argument("--batch_size", type=int, default=4, help="Batch size")
-     parser.add_argument("--max_length", type=int, default=128, help="Maximum sequence length")
-     parser.add_argument("--learning_rate", type=float, default=5e-5, help="Learning rate")
-     parser.add_argument("--temperature", type=float, default=2.0, help="Distillation temperature")
-     parser.add_argument("--save_path", type=str, default="./distilled_model", help="Path to save the distilled model")
-     parser.add_argument("--log_dir", type=str, default="./logs", help="Directory for TensorBoard logs")
-     parser.add_argument("--checkpoint_dir", type=str, default="./checkpoints", help="Directory to save checkpoints")
-     parser.add_argument("--early_stopping_patience", type=int, default=3, help="Early stopping patience")
-     return parser.parse_args()
-
- if __name__ == "__main__":
-     args = main()
-     distill_model(
-         teacher_model_name=args.teacher_model_name,
-         student_model_name=args.student_model_name,
-         dataset_name=args.dataset_name,
-         config=args.config,
-         distill_full_model=args.distill_full_model,
-         query_terms=args.query_terms,
-         num_epochs=args.num_epochs,
-         batch_size=args.batch_size,
-         max_length=args.max_length,
-         learning_rate=args.learning_rate,
-         temperature=args.temperature,
-         save_path=args.save_path,
-         log_dir=args.log_dir,
-         checkpoint_dir=args.checkpoint_dir,
-         early_stopping_patience=args.early_stopping_patience
-     )
+ import argparse
+ import math
+ import os
+ import sys
+ import json
+ import jsonlines
+ import copy
+ from typing import List, Optional, Tuple
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torch.optim as optim
+ from torch.utils.data import DataLoader, Dataset, random_split
+ from torch.cuda.amp import autocast, GradScaler
+ from torch.utils.tensorboard import SummaryWriter
+
+ from datasets import load_dataset
+ from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
+ from tqdm import tqdm
+
+ # Set up device
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ # ======================================
+ # Import Custom Components from lightbulb_custom
+ # ======================================
+ from lightbulb_custom import (
+     RotaryPositionalEncoding,
+     MultiHeadAttention,
+     MoE,
+     TransformerBlock,
+     Transformer,
+     InfoNCE_Loss,
+     CovarianceRegularization,
+     DynamicsPerformanceLoss,
+     ThoughtConsistencyLoss,
+     PolicyValueJointLoss,
+     ActionDiversityReward,
+     ExpectedThoughtValueLoss,
+     ExplorationRegularization,
+     KL_DivergenceLoss,
+     ActionEncoder,
+     RepresentationNetwork,
+     DynamicsNetwork,
+     PredictionNetwork,
+     ThoughtNode,
+     MCTS,
+     State
+ )
+
+ # ==========================
+ # Custom Dataset Definition
+ # ==========================
+ class CustomDataset(Dataset):
+     def __init__(self, inputs, labels):
+         self.inputs = inputs
+         self.labels = labels
+
+     def __len__(self):
+         return len(self.inputs)
+
+     def __getitem__(self, idx):
+         return {'input_ids': self.inputs[idx], 'labels': self.labels[idx]}
+
+ # ================================
+ # Utility Functions for Data Loading
+ # ================================
+ def load_filtered_dataset(dataset_name: str, config: str, queries: Optional[List[str]] = None):
+     dataset = load_dataset(dataset_name, config)
+     if queries:
+         def filter_func(examples):
+             return [any(query.lower() in text.lower() for query in queries) for text in examples["text"]]
+         dataset = dataset.filter(filter_func, batched=True)
+     return dataset
+
+ def load_custom_data_from_files(file_paths):
+     custom_data = []
+     for file_path in file_paths:
+         if file_path.endswith('.json'):
+             with open(file_path, 'r') as f:
+                 data = json.load(f)
+                 if isinstance(data, list):
+                     custom_data.extend(data)
+                 else:
+                     custom_data.append(data)
+         elif file_path.endswith('.jsonl'):
+             with jsonlines.open(file_path) as reader:
+                 custom_data.extend(reader)
+     return custom_data
+
+ def preprocess_custom_data(data_list):
+     processed_data = []
+     for item in data_list:
+         # Check if the item is a string (JSON)
+         if isinstance(item, str):
+             try:
+                 item = json.loads(item)
+             except json.JSONDecodeError:
+                 print(f"Failed to parse JSON: {item[:100]}...")  # Print first 100 chars for debugging
+                 continue  # Skip this item if it's not valid JSON
+
+         # Process query and content
+         query = item.get('query', '')
+         content = item.get('content', '')
+         if content == "RAG response generation failed.":
+             content = ""
+
+         # Combine query and content
+         combined_text = f"Query: {query} Content: {content}"
+
+         # Process numerical data (assuming these are available in the item dict)
+         episode_reward = item.get('episode_reward', 0)
+         loss = item.get('loss', 0)
+         cosine_similarity = item.get('cosine_similarity', 0)
+         rag_performance = item.get('rag_performance', 0)
+         ranking_model_performance = item.get('ranking_model_performance', 0)
+
+         # Create a dictionary with processed data
+         processed_item = {
+             'text': combined_text,
+             'episode_reward': episode_reward,
+             'loss': loss,
+             'cosine_similarity': cosine_similarity,
+             'rag_performance': rag_performance,
+             'ranking_model_performance': ranking_model_performance
+         }
+
+         processed_data.append(processed_item)
+
+     return processed_data
+
+ def load_custom_data(args, tokenizer, custom_data):
+     # Preprocess the custom data
+     processed_data = preprocess_custom_data(custom_data)
+
+     # Create a custom dataset
+     class CustomDatasetProcessed(torch.utils.data.Dataset):
+         def __init__(self, data, tokenizer, max_length):
+             self.data = data
+             self.tokenizer = tokenizer
+             self.max_length = max_length
+
+         def __len__(self):
+             return len(self.data)
+
+         def __getitem__(self, idx):
+             item = self.data[idx]
+             encoded = self.tokenizer.encode_plus(
+                 item['text'],
+                 max_length=self.max_length,
+                 padding='max_length',
+                 truncation=True,
+                 return_tensors='pt'
+             )
+             return {
+                 'input_ids': encoded['input_ids'].squeeze(),
+                 'attention_mask': encoded['attention_mask'].squeeze(),
+                 'episode_reward': torch.tensor(item['episode_reward'], dtype=torch.float),
+                 'loss': torch.tensor(item['loss'], dtype=torch.float),
+                 'cosine_similarity': torch.tensor(item['cosine_similarity'], dtype=torch.float),
+                 'rag_performance': torch.tensor(item['rag_performance'], dtype=torch.float),
+                 'ranking_model_performance': torch.tensor(item['ranking_model_performance'], dtype=torch.float)
+             }
+
+     # Create dataset and dataloader
+     dataset = CustomDatasetProcessed(processed_data, tokenizer, args.max_length)
+
+     # Split the dataset into train and eval
+     train_size = int(0.8 * len(dataset))
+     eval_size = len(dataset) - train_size
+     train_dataset, eval_dataset = random_split(dataset, [train_size, eval_size])
+
+     train_loader = DataLoader(
+         train_dataset,
+         batch_size=args.batch_size,
+         shuffle=True,
+         num_workers=4
+     )
+     eval_loader = DataLoader(
+         eval_dataset,
+         batch_size=args.batch_size,
+         shuffle=False,
+         num_workers=4
+     )
+
+     return train_loader, eval_loader
+
+ def prepare_data(tokenizer, dataset, max_length, batch_size):
+     # Tokenize the inputs and labels
+     tokenized_inputs = tokenizer(dataset["train"]["text"], return_tensors="pt", padding=True, truncation=True, max_length=max_length)
+     tokenized_labels = tokenizer(dataset["train"]["text"], return_tensors="pt", padding=True, truncation=True, max_length=max_length)
+
+     # Create custom dataset
+     custom_dataset = CustomDataset(tokenized_inputs["input_ids"], tokenized_labels["input_ids"])
+
+     # Split into training and validation sets
+     train_size = int(0.9 * len(custom_dataset))
+     val_size = len(custom_dataset) - train_size
+     train_dataset, val_dataset = random_split(custom_dataset, [train_size, val_size])
+
+     # Create DataLoaders
+     train_loader = DataLoader(
+         train_dataset,
+         shuffle=True,
+         batch_size=batch_size,
+         num_workers=4,
+         pin_memory=True
+     )
+     val_loader = DataLoader(
+         val_dataset,
+         shuffle=False,
+         batch_size=batch_size,
+         num_workers=4,
+         pin_memory=True
+     )
+
+     return train_loader, val_loader
+
+ # ==========================
+ # Training and Validation Functions
+ # ==========================
+
+ def save_all_models(transformer_model, representation_network, dynamics_network, prediction_network, action_encoder, save_dir, epoch):
+     """
+     Save all models to the specified directory.
+     Args:
+         transformer_model (nn.Module): Transformer model.
+         representation_network (nn.Module): Representation network.
+         dynamics_network (nn.Module): Dynamics network.
+         prediction_network (nn.Module): Prediction network.
+         action_encoder (nn.Module): Action encoder.
+         save_dir (str): Directory to save the models.
+         epoch (int): Current epoch number.
+     """
+     os.makedirs(save_dir, exist_ok=True)
+
+     torch.save(transformer_model.state_dict(), os.path.join(save_dir, f'transformer_model_epoch_{epoch}.pt'))
+     torch.save(representation_network.state_dict(), os.path.join(save_dir, f'representation_network_epoch_{epoch}.pt'))
+     torch.save(dynamics_network.state_dict(), os.path.join(save_dir, f'dynamics_network_epoch_{epoch}.pt'))
+     torch.save(prediction_network.state_dict(), os.path.join(save_dir, f'prediction_network_epoch_{epoch}.pt'))
+     torch.save(action_encoder.state_dict(), os.path.join(save_dir, f'action_encoder_epoch_{epoch}.pt'))
+
+     print(f"All models saved for epoch {epoch}.")
+
+ def train_epoch_world_model(world_model_components, train_loader, optimizer, scheduler, scaler, args, model_transformer, state_dim, embed_dim, input_dim):
+     representation_network, dynamics_network, prediction_network, action_encoder, ppo_agent, model_transformer = world_model_components
+     representation_network.train()
+     dynamics_network.train()
+     prediction_network.train()
+     action_encoder.train()
+     ppo_agent.policy_network.train()
+
+     total_loss = 0.0
+     optimizer.zero_grad()
+     print(f"Starting World Model training epoch with {len(train_loader)} batches...")
+
+     for i, batch in enumerate(train_loader):
+         print(f"Processing batch {i+1}/{len(train_loader)}...")
+
+         # Move batches to the device
+         src_batch = batch['input_ids'].to(device)
+         tgt_batch = batch['labels'].to(device)
+
+         with torch.cuda.amp.autocast():
+             print("Forward pass through Transformer (frozen)...")
+             with torch.no_grad():
+                 transformer_output = model_transformer(src_batch, tgt_batch[:, :-1])
+
+             # World Model - Representation
+             state_representation = representation_network(transformer_output)
+
+             # For simplicity, let's assume true actions are provided (e.g., next tokens)
+             true_actions = tgt_batch[:, :-1]
+             print(f"True actions shape: {true_actions.shape}")
+             action_sequences = true_actions
+
+             # Get action embeddings
+             action_embeddings = action_encoder(action_sequences)
+             print(f"Action embeddings shape: {action_embeddings.shape}")
+
+             # Apply dynamics network
+             predicted_next_state_batch = dynamics_network(state_representation, action_embeddings)
+             print(f"Predicted next state batch shape: {predicted_next_state_batch.shape}")
+
+             # Prediction Network - Policy logits and value
+             policy_logits, value_estimates = prediction_network(predicted_next_state_batch)
+
+             # Define true_policy and true_value as placeholders on the GPU
+             true_policy = F.one_hot(true_actions, num_classes=input_dim).float()
+             true_value = torch.zeros_like(value_estimates).to(device)
+
+             # Compute individual losses
+             ppo_loss = ppo_agent.compute_loss(
+                 state_representation,
+                 torch.zeros_like(true_actions, dtype=torch.float32).to(device),
+                 true_actions,
+                 torch.zeros_like(value_estimates, dtype=torch.float32).to(device),
+                 torch.zeros_like(value_estimates, dtype=torch.float32).to(device)
+             )
+
+             info_nce = InfoNCE_Loss()(state_representation.reshape(-1, state_dim),
+                                       F.dropout(state_representation.reshape(-1, state_dim), p=0.1, training=True))
+
+             covariance = CovarianceRegularization()(predicted_next_state_batch.view(-1, predicted_next_state_batch.size(-1)))
+             dynamics_loss = DynamicsPerformanceLoss()(state_representation, predicted_next_state_batch)
+
+             perturbed_next_state = predicted_next_state_batch + torch.randn_like(predicted_next_state_batch) * 0.01
+             thought_loss = ThoughtConsistencyLoss()(predicted_next_state_batch, perturbed_next_state)
+
+             pv_loss = PolicyValueJointLoss()(policy_logits, true_policy, value_estimates.squeeze(-1), true_value.squeeze(-1))
+             action_diversity = ActionDiversityReward()(action_embeddings.view(-1, embed_dim))
+
+             mcts_best_values = torch.zeros(true_actions.size(0)).to(device)
+             etv = ExpectedThoughtValueLoss()(mcts_best_values)
+
+             visit_counts = torch.ones(true_actions.size(0), policy_logits.size(-1)).to(device)
+             exploration = ExplorationRegularization()(visit_counts)
+
+             old_policy = F.softmax(policy_logits.detach(), dim=-1)
+             new_policy = F.softmax(policy_logits, dim=-1)
+             kl_loss = KL_DivergenceLoss()(old_policy, new_policy)
+
+             # Total Loss
+             loss = (
+                 ppo_loss +
+                 info_nce +
+                 covariance +
+                 dynamics_loss +
+                 thought_loss +
+                 pv_loss +
+                 action_diversity +
+                 etv +
+                 exploration +
+                 kl_loss
+             )
+             loss = loss / args.accumulation_steps
+
+         print("Backward pass...")
+         scaler.scale(loss).backward()
+
+         if (i + 1) % args.accumulation_steps == 0 or (i + 1) == len(train_loader):
+             print("Gradient clipping...")
+             scaler.unscale_(optimizer)
+             torch.nn.utils.clip_grad_norm_(
+                 [param for group in optimizer.param_groups for param in group['params']],
+                 args.max_grad_norm
+             )
+
+             print("Optimizer step...")
+             scaler.step(optimizer)
+             scaler.update()
+
+             print("Zeroing gradients...")
+             optimizer.zero_grad()
+
+             print("Updating learning rate...")
+             scheduler.step()
+
+         total_loss += loss.item() * args.accumulation_steps
+
+         # Print individual losses and total loss for this batch
+         print(f"Batch {i+1} completed. Losses:")
+         print(f"  PPO Loss: {ppo_loss.item():.4f}")
+         print(f"  InfoNCE Loss: {info_nce.item():.4f}")
+         print(f"  Covariance Loss: {covariance.item():.4f}")
+         print(f"  Dynamics Loss: {dynamics_loss.item():.4f}")
+         print(f"  Thought Consistency Loss: {thought_loss.item():.4f}")
+         print(f"  Policy-Value Loss: {pv_loss.item():.4f}")
+         print(f"  Action Diversity Loss: {action_diversity.item():.4f}")
+         print(f"  Expected Thought Value Loss: {etv.item():.4f}")
+         print(f"  Exploration Loss: {exploration.item():.4f}")
+         print(f"  KL Divergence Loss: {kl_loss.item():.4f}")
+         print(f"  Total Loss: {loss.item():.4f}")
+
+     avg_loss = total_loss / len(train_loader)
+     print(f"World Model training epoch completed. Average loss: {avg_loss:.4f}")
+     return avg_loss
+
+ def train_step(teacher, student, data_loader, optimizer, criterion, scaler, temperature=2.0):
+     teacher.eval()
+     student.train()
+     total_loss = 0
+
+     for batch in tqdm(data_loader, desc="Training"):
+         inputs = batch["input_ids"].to(device)
+         labels = batch["labels"].to(device)
+
+         with autocast():
+             with torch.no_grad():
+                 teacher_outputs = teacher(inputs).logits
+                 teacher_logits = teacher_outputs / temperature
+
+             student_outputs = student(inputs).logits
+             student_logits = student_outputs / temperature
+
+             # Compute KL Divergence Loss
+             loss = criterion(nn.functional.log_softmax(student_logits, dim=-1), nn.functional.softmax(teacher_logits, dim=-1))
+             loss = loss * (temperature ** 2)  # Scale loss by temperature squared
+
+         scaler.scale(loss).backward()
+         scaler.step(optimizer)
+         scaler.update()
+         optimizer.zero_grad()
+
+         total_loss += loss.item()
+
+     avg_loss = total_loss / len(data_loader)
+     return avg_loss
+
+ def validate(teacher, student, data_loader, criterion, temperature=2.0):
+     teacher.eval()
+     student.eval()
+     total_loss = 0
+
+     with torch.no_grad():
+         for batch in tqdm(data_loader, desc="Validation"):
+             inputs = batch["input_ids"].to(device)
+             labels = batch["labels"].to(device)
+
+             teacher_outputs = teacher(inputs).logits
+             teacher_logits = teacher_outputs / temperature
+
+             student_outputs = student(inputs).logits
+             student_logits = student_outputs / temperature
+
+             loss = criterion(nn.functional.log_softmax(student_logits, dim=-1), nn.functional.softmax(teacher_logits, dim=-1))
+             loss = loss * (temperature ** 2)
+
+             total_loss += loss.item()
+
+     avg_loss = total_loss / len(data_loader)
+     return avg_loss
+
+ def save_checkpoint(state, save_dir, epoch):
+     os.makedirs(save_dir, exist_ok=True)
+     checkpoint_path = os.path.join(save_dir, f'checkpoint_epoch_{epoch}.pt')
+     torch.save(state, checkpoint_path)
+     print(f"Checkpoint saved at {checkpoint_path}")
+
+ # ==========================
+ # Inference Functions
+ # ==========================
+
+ def infer(query, world_model_components, root_thought_node, tokenizer, max_length=2000, inference_mode='world_model', beam_size=5, n_tokens_predict=3, mcts_iterations=10, exploration_constant=1.414):
+     """
+     Perform inference given a query, utilizing the Tree of Thought and MCTS with multi-token beam search.
+     Args:
+         query (str): The input query or prompt.
+         world_model_components (tuple): Tuple containing the model components.
+         root_thought_node (ThoughtNode): The root node of the Tree of Thought.
+         tokenizer (transformers.PreTrainedTokenizer): The tokenizer used.
+         max_length (int): Maximum length for the generated sequence.
+         inference_mode (str): Inference mode ('world_model', 'without_world_model', 'world_model_tree_of_thought')
+         beam_size (int): Size of the beam for beam search
+         n_tokens_predict (int): Number of tokens to predict at each step
+         mcts_iterations (int): Number of MCTS iterations
+         exploration_constant (float): Exploration constant for MCTS
+     Returns:
+         List[str] or str: The sequence of actions (thoughts) selected or generated text.
+     """
+     if inference_mode != 'world_model':
+         print("Inference mode other than 'world_model' not implemented yet.")
+         return ""
+
+     representation_network, dynamics_network, prediction_network, action_encoder, ppo_agent, model_transformer = world_model_components
+
+     # Tokenize and encode the query
+     input_ids = tokenizer.encode(query, return_tensors='pt').to(device)
+     attention_mask = (input_ids != tokenizer.pad_token_id).long()
+
+     # Use the world model components
+     with torch.no_grad():
+         transformer_output = model_transformer(input_ids, input_ids)
+         # Get the initial state representation
+         initial_representation = representation_network(transformer_output)  # Shape: (batch_size=1, seq_len, state_dim)
+         initial_representation = initial_representation[:, -1, :].unsqueeze(1)  # Shape: (batch_size=1, 1, state_dim)
+         initial_state = State(
+             representation=initial_representation,
+             dynamics_network=dynamics_network,
+             action_encoder=action_encoder,
+             thought_node=root_thought_node
+         )
+         # Use MCTS with Tree of Thought and multi-token beam search
+         mcts = MCTS(prediction_network, dynamics_network, action_encoder, num_iterations=mcts_iterations, exploration_constant=exploration_constant)
+
+         current_state = initial_state
+         thought_sequence = []
+
+         for _ in range(max_length // n_tokens_predict):
+             best_actions = mcts.search_with_beam(current_state)
+
+             thought_sequence.extend(best_actions)
+
+             # Apply the best actions to get the next state
+             for action in best_actions:
+                 current_state = current_state.apply_action(action)
+
+             # Check if we've reached a leaf node (no further actions)
+             if len(current_state.thought_node.children) == 0:
+                 break
+
+     return thought_sequence
+
+ # ==========================
+ # Main Training Function
+ # ==========================
+
+ def distill_model(
+     teacher_model_name: str,
+     student_model_name: str,
+     dataset_name: str,
+     config: str,
+     distill_full_model: bool = True,
+     query_terms: Optional[List[str]] = None,
+     num_epochs: int = 3,
+     batch_size: int = 4,
+     max_length: int = 128,
+     learning_rate: float = 5e-5,
+     temperature: float = 2.0,
+     save_path: str = "./distilled_model",
+     log_dir: str = "./logs",
+     checkpoint_dir: str = "./checkpoints",
+     early_stopping_patience: int = 3,
+     accumulation_steps: int = 1,
+     max_grad_norm: float = 1.0,
+     weight_decay: float = 0.01
+ ):
+     # Initialize TensorBoard writer
+     writer = SummaryWriter(log_dir=log_dir)
+
+     # Load tokenizer
+     print("Loading tokenizer...")
+     tokenizer = AutoTokenizer.from_pretrained(teacher_model_name)
+     if tokenizer.pad_token is None:
+         tokenizer.pad_token = tokenizer.eos_token
+     print("Tokenizer loaded successfully.")
+
+     # Load teacher model
+     print("Loading teacher model...")
+     teacher = AutoModelForCausalLM.from_pretrained(teacher_model_name).to(device)
+     print("Teacher model loaded successfully.")
+
+     if distill_full_model:
+         # Full World Model Distillation
+         print(f"Starting Full World Model Distillation into '{student_model_name}'.")
+
+         # Load or instantiate student model
+         print(f"Attempting to load student model '{student_model_name}'...")
+         try:
+             student = AutoModelForCausalLM.from_pretrained(student_model_name).to(device)
+             print(f"Student model '{student_model_name}' loaded successfully.")
+         except (OSError, ValueError) as e:
+             print(f"Student model '{student_model_name}' not found. Instantiating a new student model.")
+             # Instantiate a smaller pre-trained model as the student, e.g., distilgpt2
+             try:
+                 student = AutoModelForCausalLM.from_pretrained('distilgpt2').to(device)
+                 # Save the instantiated student model with the desired name
+                 student.save_pretrained(save_path)
+                 tokenizer.save_pretrained(save_path)
+                 print(f"New student model '{student_model_name}' instantiated and saved to '{save_path}'.")
+             except Exception as inst_e:
+                 print(f"Failed to instantiate and save student model: {inst_e}")
+                 sys.exit(1)
+
+         # Optionally freeze teacher model parameters
+         for param in teacher.parameters():
+             param.requires_grad = False
+
+         # Load and prepare dataset
+         print(f"Loading full dataset '{dataset_name}' with config '{config}'...")
+         dataset = load_dataset(dataset_name, config)
+         train_loader, val_loader = prepare_data(tokenizer, dataset, max_length, batch_size)
+         print("Data loaded and preprocessed successfully.")
+
+         # Define optimizer, scheduler, and scaler for mixed precision
+         optimizer = optim.AdamW(student.parameters(), lr=learning_rate, weight_decay=weight_decay)
+         scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
+         scaler = GradScaler()
+
+         # Define loss criterion
+         criterion = nn.KLDivLoss(reduction="batchmean")
+
+         best_val_loss = float('inf')
+         epochs_no_improve = 0
+
+         # Training loop
+         for epoch in range(1, num_epochs + 1):
+             print(f"\nEpoch {epoch}/{num_epochs}")
+             print("-" * 20)
+
+             # Training
+             train_loss = train_step(teacher, student, train_loader, optimizer, criterion, scaler, temperature)
+             print(f"Training Loss: {train_loss:.4f}")
+             writer.add_scalar("Loss/Train", train_loss, epoch)
+
+             # Validation
+             val_loss = validate(teacher, student, val_loader, criterion, temperature)
+             print(f"Validation Loss: {val_loss:.4f}")
+             writer.add_scalar("Loss/Validation", val_loss, epoch)
+
+             # Check for improvement
+             if val_loss < best_val_loss:
+                 best_val_loss = val_loss
+                 epochs_no_improve = 0
+                 # Save the best model
+                 save_checkpoint({
+                     'epoch': epoch,
+                     'model_state_dict': student.state_dict(),
+                     'optimizer_state_dict': optimizer.state_dict(),
+                     'scheduler_state_dict': scheduler.state_dict(),
+                     'scaler_state_dict': scaler.state_dict(),
+                     'best_val_loss': best_val_loss
+                 }, checkpoint_dir, epoch)
+                 # Save the model as the best one
+                 student.save_pretrained(save_path)
+                 tokenizer.save_pretrained(save_path)
+                 print(f"Best model saved at epoch {epoch}")
+             else:
+                 epochs_no_improve += 1
+                 print(f"No improvement in validation loss for {epochs_no_improve} epoch(s)")
+                 if epochs_no_improve >= early_stopping_patience:
+                     print("Early stopping triggered")
+                     break
+
+             # Step the scheduler
+             scheduler.step()
+
+         writer.close()
+         print("\nFull World Model Distillation completed.")
+
+     else:
+         # Standard Language Model Distillation
+         print(f"Starting Standard Language Model Distillation into '{student_model_name}'.")
+
+         if not query_terms:
+             print("Error: --query_terms must be provided for standard language model distillation.")
+             sys.exit(1)
+
+         # Load or instantiate student model
+         print(f"Attempting to load student model '{student_model_name}'...")
+         try:
+             student = AutoModelForCausalLM.from_pretrained(student_model_name).to(device)
+             print(f"Student model '{student_model_name}' loaded successfully.")
+         except (OSError, ValueError) as e:
+             print(f"Student model '{student_model_name}' not found. Instantiating a new student model.")
+             # Instantiate a smaller pre-trained model as the student, e.g., distilgpt2
+             try:
+                 student = AutoModelForCausalLM.from_pretrained('distilgpt2').to(device)
+                 # Save the instantiated student model with the desired name
+                 student.save_pretrained(save_path)
+                 tokenizer.save_pretrained(save_path)
+                 print(f"New student model '{student_model_name}' instantiated and saved to '{save_path}'.")
+             except Exception as inst_e:
+                 print(f"Failed to instantiate and save student model: {inst_e}")
+                 sys.exit(1)
+
+         # Optionally freeze teacher model parameters
+         for param in teacher.parameters():
+             param.requires_grad = False
+
+         # Load and prepare custom dataset
+         print(f"Loading custom data files: {query_terms}")
+         custom_data = load_custom_data_from_files(query_terms)
+         train_loader, val_loader = load_custom_data(
+             args=argparse.Namespace(max_length=max_length, batch_size=batch_size),
+             tokenizer=tokenizer,
+             custom_data=custom_data
+         )
+         print("Custom data loaded and preprocessed successfully.")
+
+         # Define optimizer, scheduler, and scaler for mixed precision
+         optimizer = optim.AdamW(student.parameters(), lr=learning_rate, weight_decay=weight_decay)
+         scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
+         scaler = GradScaler()
+
+         # Define loss criterion
+         criterion = nn.KLDivLoss(reduction="batchmean")
+
+         best_val_loss = float('inf')
+         epochs_no_improve = 0
+
+         # Training loop
+         for epoch in range(1, num_epochs + 1):
+             print(f"\nEpoch {epoch}/{num_epochs}")
+             print("-" * 20)
+
+             # Training
+             train_loss = train_step(teacher, student, train_loader, optimizer, criterion, scaler, temperature)
+             print(f"Training Loss: {train_loss:.4f}")
+             writer.add_scalar("Loss/Train", train_loss, epoch)
+
+             # Validation
+             val_loss = validate(teacher, student, val_loader, criterion, temperature)
+             print(f"Validation Loss: {val_loss:.4f}")
+             writer.add_scalar("Loss/Validation", val_loss, epoch)
+
+             # Check for improvement
+             if val_loss < best_val_loss:
+                 best_val_loss = val_loss
+                 epochs_no_improve = 0
+                 # Save the best model
+                 save_checkpoint({
+                     'epoch': epoch,
+                     'model_state_dict': student.state_dict(),
+                     'optimizer_state_dict': optimizer.state_dict(),
+                     'scheduler_state_dict': scheduler.state_dict(),
+                     'scaler_state_dict': scaler.state_dict(),
+                     'best_val_loss': best_val_loss
+                 }, checkpoint_dir, epoch)
+                 # Save the model as the best one
+                 student.save_pretrained(save_path)
+                 tokenizer.save_pretrained(save_path)
+                 print(f"Best model saved at epoch {epoch}")
+             else:
+                 epochs_no_improve += 1
+                 print(f"No improvement in validation loss for {epochs_no_improve} epoch(s)")
+                 if epochs_no_improve >= early_stopping_patience:
+                     print("Early stopping triggered")
+                     break
+
+             # Step the scheduler
+             scheduler.step()
+
+         writer.close()
+         print("\nStandard Language Model Distillation completed.")
+
+ # ==========================
+ # Argument Parsing
+ # ==========================
+
+ def parse_args():
+     parser = argparse.ArgumentParser(description="Distill a large LLM into a smaller one or a full language world model.")
+
+     # Required arguments
+     parser.add_argument("--teacher_model_name", type=str, required=True, help="Name of the teacher model")
+     parser.add_argument("--student_model_name", type=str, required=True, help="Name of the student model")
+
+     # Dataset arguments
+     parser.add_argument("--dataset_name", type=str, required=True, help="Name of the dataset")
+     parser.add_argument("--config", type=str, default=None, help="Dataset configuration (e.g., 'wikitext-2-raw-v1')")
+
+     # Mode selection
+     parser.add_argument("--distill_full_model", action="store_true", help="Whether to distill into the full language world model")
+
+     # For standard distillation
+     parser.add_argument("--query_terms", type=str, nargs="+", help="Paths to custom data files for standard language model distillation")
+
+     # Training hyperparameters
+     parser.add_argument("--num_epochs", type=int, default=3, help="Number of epochs")
+     parser.add_argument("--batch_size", type=int, default=4, help="Batch size")
+     parser.add_argument("--max_length", type=int, default=128, help="Maximum sequence length")
+     parser.add_argument("--learning_rate", type=float, default=5e-5, help="Learning rate")
+     parser.add_argument("--temperature", type=float, default=2.0, help="Distillation temperature")
+
+     # Saving and logging
+     parser.add_argument("--save_path", type=str, default="./distilled_model", help="Path to save the distilled model")
+     parser.add_argument("--log_dir", type=str, default="./logs", help="Directory for TensorBoard logs")
+     parser.add_argument("--checkpoint_dir", type=str, default="./checkpoints", help="Directory to save checkpoints")
+
+     # Early stopping
+     parser.add_argument("--early_stopping_patience", type=int, default=3, help="Early stopping patience")
+
+     # Gradient accumulation and optimization
+     parser.add_argument("--accumulation_steps", type=int, default=1, help="Gradient accumulation steps")
+     parser.add_argument("--max_grad_norm", type=float, default=1.0, help="Maximum gradient norm for clipping")
+     parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay for optimizer")
+
+     return parser.parse_args()
+
+ # ==========================
+ # Main Function
+ # ==========================
+
+ def main():
+     args = parse_args()
+     print("Arguments parsed successfully.")
+
+     # Create save directories
+     os.makedirs(args.save_path, exist_ok=True)
+     os.makedirs(args.log_dir, exist_ok=True)
+     os.makedirs(args.checkpoint_dir, exist_ok=True)
+     print(f"Save directory created: {args.save_path}")
+     print(f"Log directory created: {args.log_dir}")
+     print(f"Checkpoint directory created: {args.checkpoint_dir}")
+
+     # Handle dataset loading based on distillation mode
+     if args.distill_full_model:
+         # Full World Model Distillation
+         distill_model(
+             teacher_model_name=args.teacher_model_name,
+             student_model_name=args.student_model_name,
+             dataset_name=args.dataset_name,
+             config=args.config,
+             distill_full_model=args.distill_full_model,
+             query_terms=args.query_terms,  # Not used in this mode
+             num_epochs=args.num_epochs,
+             batch_size=args.batch_size,
+             max_length=args.max_length,
+             learning_rate=args.learning_rate,
+             temperature=args.temperature,
+             save_path=args.save_path,
+             log_dir=args.log_dir,
+             checkpoint_dir=args.checkpoint_dir,
+             early_stopping_patience=args.early_stopping_patience,
+             accumulation_steps=args.accumulation_steps,
+             max_grad_norm=args.max_grad_norm,
+             weight_decay=args.weight_decay
+         )
+     else:
+         # Standard Language Model Distillation
+         distill_model(
+             teacher_model_name=args.teacher_model_name,
+             student_model_name=args.student_model_name,
+             dataset_name=args.dataset_name,
+             config=args.config,
+             distill_full_model=args.distill_full_model,
+             query_terms=args.query_terms,
+             num_epochs=args.num_epochs,
+             batch_size=args.batch_size,
+             max_length=args.max_length,
+             learning_rate=args.learning_rate,
+             temperature=args.temperature,
+             save_path=args.save_path,
+             log_dir=args.log_dir,
+             checkpoint_dir=args.checkpoint_dir,
+             early_stopping_patience=args.early_stopping_patience,
+             accumulation_steps=args.accumulation_steps,
+             max_grad_norm=args.max_grad_norm,
+             weight_decay=args.weight_decay
+         )
+
+
+ if __name__ == "__main__":
+     main()
+
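For reference, a minimal, self-contained sketch of the temperature-scaled KL objective that train_step and validate above compute: soften both logit distributions by the temperature, compare them with batchmean KL divergence, and rescale by T^2. The toy logit shapes below are illustrative only and are not taken from the commit.

import torch
import torch.nn.functional as F

def distillation_kl_loss(student_logits, teacher_logits, temperature=2.0):
    # Soften both distributions, compare with KL divergence (batchmean reduction,
    # matching nn.KLDivLoss in the script), and rescale by T^2 so gradient
    # magnitudes stay comparable across temperatures.
    student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)
    teacher_probs = F.softmax(teacher_logits / temperature, dim=-1)
    return F.kl_div(student_log_probs, teacher_probs, reduction="batchmean") * (temperature ** 2)

# Toy check with random logits of shape (batch, vocab_size).
student = torch.randn(4, 100)
teacher = torch.randn(4, 100)
print(distillation_kl_loss(student, teacher).item())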
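A hedged usage sketch of the updated entry point, assuming distill.py and its lightbulb_custom dependency are importable from the working directory; the teacher/student names, dataset, and file path below are placeholders, not values taken from the commit.

from distill import distill_model

# Full-dataset distillation (the distill_full_model=True branch).
distill_model(
    teacher_model_name="gpt2",          # placeholder teacher
    student_model_name="distilgpt2",    # placeholder student
    dataset_name="wikitext",
    config="wikitext-2-raw-v1",
    distill_full_model=True,
)

# Standard distillation from local JSON/JSONL files; in this mode query_terms
# holds file paths and dataset_name/config are not used by the else branch.
distill_model(
    teacher_model_name="gpt2",
    student_model_name="distilgpt2",
    dataset_name="wikitext",            # ignored on this path
    config=None,
    distill_full_model=False,
    query_terms=["./my_runs.jsonl"],    # hypothetical path
)

The same two modes are reachable from the command line through parse_args, e.g. by passing --teacher_model_name, --student_model_name, --dataset_name, --config, and optionally --distill_full_model or --query_terms.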