RobbiePasquale
committed
Commit 7f47926
1 Parent(s): c9a5651
Upload 2 files
- distill.py +264 -0
- main_menu_new.py +191 -0
distill.py
ADDED
@@ -0,0 +1,264 @@
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, random_split
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from typing import List, Optional
import argparse
import os
from tqdm import tqdm
from torch.cuda.amp import autocast, GradScaler
from torch.utils.tensorboard import SummaryWriter

# Set up device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class CustomDataset(Dataset):
    def __init__(self, inputs, labels):
        self.inputs = inputs
        self.labels = labels

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        return {'input_ids': self.inputs[idx], 'labels': self.labels[idx]}

def load_filtered_dataset(dataset_name: str, config: str, queries: Optional[List[str]] = None):
    dataset = load_dataset(dataset_name, config)
    if queries:
        # With batched=True the filter callback receives a batch, so
        # examples["text"] is a list and we must return one bool per example.
        def filter_func(examples):
            return [
                any(query.lower() in text.lower() for query in queries)
                for text in examples["text"]
            ]
        dataset = dataset.filter(filter_func, batched=True)
    return dataset

def prepare_data(tokenizer, dataset, max_length, batch_size):
    # Tokenize the training split once; for causal-LM distillation the labels
    # are a copy of the input ids, so a second tokenization pass is unnecessary.
    tokenized = tokenizer(dataset["train"]["text"], return_tensors="pt", padding=True, truncation=True, max_length=max_length)

    # Create custom dataset (labels mirror the inputs)
    custom_dataset = CustomDataset(tokenized["input_ids"], tokenized["input_ids"].clone())

    # Split into training and validation sets (90/10)
    train_size = int(0.9 * len(custom_dataset))
    val_size = len(custom_dataset) - train_size
    train_dataset, val_dataset = random_split(custom_dataset, [train_size, val_size])

    # Create DataLoaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)

    return train_loader, val_loader

def train_step(teacher, student, data_loader, optimizer, criterion, scaler, temperature=2.0):
    teacher.eval()
    student.train()
    total_loss = 0

    for batch in tqdm(data_loader, desc="Training"):
        inputs = batch["input_ids"].to(device)
        # batch["labels"] is not needed here: the loss is pure logit distillation.

        optimizer.zero_grad()

        with autocast():
            with torch.no_grad():
                teacher_logits = teacher(inputs).logits / temperature

            student_logits = student(inputs).logits / temperature

            # KL divergence between the temperature-softened distributions,
            # scaled by T^2 to keep gradient magnitudes comparable across temperatures
            loss = criterion(nn.functional.log_softmax(student_logits, dim=-1), nn.functional.softmax(teacher_logits, dim=-1))
            loss = loss * (temperature ** 2)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        total_loss += loss.item()

    avg_loss = total_loss / len(data_loader)
    return avg_loss

def validate(teacher, student, data_loader, criterion, temperature=2.0):
    teacher.eval()
    student.eval()
    total_loss = 0

    with torch.no_grad():
        for batch in tqdm(data_loader, desc="Validation"):
            inputs = batch["input_ids"].to(device)

            teacher_logits = teacher(inputs).logits / temperature
            student_logits = student(inputs).logits / temperature

            loss = criterion(nn.functional.log_softmax(student_logits, dim=-1), nn.functional.softmax(teacher_logits, dim=-1))
            loss = loss * (temperature ** 2)

            total_loss += loss.item()

    avg_loss = total_loss / len(data_loader)
    return avg_loss

def save_checkpoint(state, save_dir, epoch):
    os.makedirs(save_dir, exist_ok=True)
    checkpoint_path = os.path.join(save_dir, f'checkpoint_epoch_{epoch}.pt')
    torch.save(state, checkpoint_path)
    print(f"Checkpoint saved at {checkpoint_path}")

def load_checkpoint(model, optimizer, scheduler, scaler, save_dir, epoch):
    checkpoint_path = os.path.join(save_dir, f'checkpoint_epoch_{epoch}.pt')
    if os.path.isfile(checkpoint_path):
        checkpoint = torch.load(checkpoint_path)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        scaler.load_state_dict(checkpoint['scaler_state_dict'])
        print(f"Loaded checkpoint from {checkpoint_path}")
    else:
        print(f"No checkpoint found at {checkpoint_path}")

def distill_model(
    teacher_model_name: str,
    student_model_name: str,
    dataset_name: str,
    config: str,
    distill_full_model: bool = True,
    query_terms: Optional[List[str]] = None,
    num_epochs: int = 3,
    batch_size: int = 4,
    max_length: int = 128,
    learning_rate: float = 5e-5,
    temperature: float = 2.0,
    save_path: str = "./distilled_model",
    log_dir: str = "./logs",
    checkpoint_dir: str = "./checkpoints",
    early_stopping_patience: int = 3
):
    # Initialize TensorBoard writer
    writer = SummaryWriter(log_dir=log_dir)

    # Load tokenizer. Teacher and student must share a vocabulary (e.g. gpt2 and
    # distilgpt2) for their logit distributions to be comparable token for token.
    tokenizer = AutoTokenizer.from_pretrained(teacher_model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Load teacher and student models
    teacher = AutoModelForCausalLM.from_pretrained(teacher_model_name).to(device)
    student = AutoModelForCausalLM.from_pretrained(student_model_name).to(device)

    # Freeze teacher model parameters; only the student is trained
    for param in teacher.parameters():
        param.requires_grad = False

    # Load and prepare dataset
    if distill_full_model:
        dataset = load_dataset(dataset_name, config)
    else:
        dataset = load_filtered_dataset(dataset_name, config, query_terms)

    train_loader, val_loader = prepare_data(tokenizer, dataset, max_length, batch_size)

    # Define optimizer, scheduler, and scaler for mixed precision
    optimizer = optim.AdamW(student.parameters(), lr=learning_rate)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
    scaler = GradScaler()

    # Define loss criterion
    criterion = nn.KLDivLoss(reduction="batchmean")

    best_val_loss = float('inf')
    epochs_no_improve = 0

    # Training loop
    for epoch in range(1, num_epochs + 1):
        print(f"\nEpoch {epoch}/{num_epochs}")
        print("-" * 20)

        # Training
        train_loss = train_step(teacher, student, train_loader, optimizer, criterion, scaler, temperature)
        print(f"Training Loss: {train_loss:.4f}")
        writer.add_scalar("Loss/Train", train_loss, epoch)

        # Validation
        val_loss = validate(teacher, student, val_loader, criterion, temperature)
        print(f"Validation Loss: {val_loss:.4f}")
        writer.add_scalar("Loss/Validation", val_loss, epoch)

        # Check for improvement
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            epochs_no_improve = 0
            # Save a resumable checkpoint for the best model
            save_checkpoint({
                'epoch': epoch,
                'model_state_dict': student.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'scaler_state_dict': scaler.state_dict(),
                'best_val_loss': best_val_loss
            }, checkpoint_dir, epoch)
            # Also export the best student in Hugging Face format
            student.save_pretrained(save_path)
            tokenizer.save_pretrained(save_path)
            print(f"Best model saved at epoch {epoch}")
        else:
            epochs_no_improve += 1
            print(f"No improvement in validation loss for {epochs_no_improve} epoch(s)")
            if epochs_no_improve >= early_stopping_patience:
                print("Early stopping triggered")
                break

        # Step the scheduler
        scheduler.step()

    writer.close()
    print("\nDistillation completed.")

def parse_args():
    parser = argparse.ArgumentParser(description="Distill a large LLM into a smaller one.")
    parser.add_argument("--teacher_model_name", type=str, required=True, help="Name of the teacher model")
    parser.add_argument("--student_model_name", type=str, required=True, help="Name of the student model")
    parser.add_argument("--dataset_name", type=str, required=True, help="Name of the dataset")
    parser.add_argument("--config", type=str, default=None, help="Dataset configuration (e.g., 'wikitext-2-raw-v1')")
    parser.add_argument("--distill_full_model", action="store_true", help="Distill on the full dataset rather than a query-filtered subset")
    parser.add_argument("--query_terms", type=str, nargs="+", help="Query terms for filtering the dataset")
    parser.add_argument("--num_epochs", type=int, default=3, help="Number of epochs")
    parser.add_argument("--batch_size", type=int, default=4, help="Batch size")
    parser.add_argument("--max_length", type=int, default=128, help="Maximum sequence length")
    parser.add_argument("--learning_rate", type=float, default=5e-5, help="Learning rate")
    parser.add_argument("--temperature", type=float, default=2.0, help="Distillation temperature")
    parser.add_argument("--save_path", type=str, default="./distilled_model", help="Path to save the distilled model")
    parser.add_argument("--log_dir", type=str, default="./logs", help="Directory for TensorBoard logs")
    parser.add_argument("--checkpoint_dir", type=str, default="./checkpoints", help="Directory to save checkpoints")
    parser.add_argument("--early_stopping_patience", type=int, default=3, help="Early stopping patience")
    return parser.parse_args()

if __name__ == "__main__":
    args = parse_args()
    distill_model(
        teacher_model_name=args.teacher_model_name,
        student_model_name=args.student_model_name,
        dataset_name=args.dataset_name,
        config=args.config,
        distill_full_model=args.distill_full_model,
        query_terms=args.query_terms,
        num_epochs=args.num_epochs,
        batch_size=args.batch_size,
        max_length=args.max_length,
        learning_rate=args.learning_rate,
        temperature=args.temperature,
        save_path=args.save_path,
        log_dir=args.log_dir,
        checkpoint_dir=args.checkpoint_dir,
        early_stopping_patience=args.early_stopping_patience
    )
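For reference, a minimal invocation of the script above might look like the following. The model and dataset names are the parser's own examples and defaults, not values fixed by this commit, and the teacher/student pair is assumed to share a tokenizer (e.g. gpt2 and distilgpt2):

python distill.py \
    --teacher_model_name gpt2 \
    --student_model_name distilgpt2 \
    --dataset_name wikitext \
    --config wikitext-2-raw-v1 \
    --distill_full_model \
    --num_epochs 3 --batch_size 4 --max_length 128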
main_menu_new.py
ADDED
@@ -0,0 +1,191 @@
# main_menu.py

import argparse
import sys
from train_agent import train_agent
from test_agent import TestAgent, run_test_session
from twisted.internet import reactor, task
from lightbulb_custom import main as lightbulb_custom_main
from distill import distill_model  # Distillation function from the distill.py added in this commit
from transformers import logging

# Suppress transformers warnings for cleaner output
logging.set_verbosity_error()

def parse_main_args():
    parser = argparse.ArgumentParser(description="Main Menu for Selecting Tasks")

    # Task selection
    parser.add_argument('--task', type=str, choices=[
        'train_llm_world',
        'train_agent',
        'test_agent',
        'inference_llm',
        'inference_world_model',
        'advanced_inference',
        'distill_full_model',       # New option for full model distillation
        'distill_domain_specific'   # New option for selective distillation
    ],
    required=True,
    help='Choose task to execute: train_llm_world, train_agent, test_agent, inference_llm, inference_world_model, advanced_inference, distill_full_model, distill_domain_specific')

    # Common arguments
    parser.add_argument('--model_name', type=str, default='gpt2', help='Pretrained model name for the LLM (also used as the distillation teacher)')
    parser.add_argument('--student_model_name', type=str, default='distilgpt2', help='Name of the student model for distillation')
    parser.add_argument('--dataset_name', type=str, default='wikitext', help='Dataset name for training')
    parser.add_argument('--dataset_config', type=str, default='wikitext-2-raw-v1', help='Dataset configuration name')
    parser.add_argument('--batch_size', type=int, default=4, help='Batch size for training')
    parser.add_argument('--num_epochs', type=int, default=3, help='Number of epochs for training')
    parser.add_argument('--max_length', type=int, default=128, help='Maximum sequence length for training')
    parser.add_argument('--temperature', type=float, default=2.0, help='Distillation temperature')
    parser.add_argument('--learning_rate', type=float, default=5e-5, help='Learning rate')

    # Distillation output arguments
    parser.add_argument('--save_path', type=str, default="./distilled_model", help="Path to save the distilled model")
    parser.add_argument('--log_dir', type=str, default="./logs", help="Directory for TensorBoard logs")
    parser.add_argument('--checkpoint_dir', type=str, default="./checkpoints", help="Directory to save checkpoints")
    parser.add_argument('--early_stopping_patience', type=int, default=3, help="Early stopping patience")

    # Inference-specific arguments
    parser.add_argument('--query', type=str, default='', help='Query for the test_agent or inference tasks')
    parser.add_argument('--inference_mode', type=str, choices=['without_world_model', 'world_model', 'world_model_tree_of_thought'], help='Inference mode')
    parser.add_argument('--beam_size', type=int, default=5, help='Beam size for beam search during inference')
    parser.add_argument('--n_tokens_predict', type=int, default=3, help='Number of tokens to predict at each step during inference')
    parser.add_argument('--mcts_iterations', type=int, default=10, help='Number of MCTS iterations during inference')
    parser.add_argument('--mcts_exploration_constant', type=float, default=1.414, help='Exploration constant for MCTS during inference')

    # Distillation mode arguments
    parser.add_argument('--distill_full_model', action="store_true", help="Distill on the full dataset rather than a query-filtered subset")
    parser.add_argument('--query_terms', type=str, nargs="+", help="Query terms for domain-specific distillation")

    # Load model for inference
    parser.add_argument('--load_model', type=str, help='Path to load the distilled model for inference')

    return parser.parse_args()

def main():
    # Parse arguments for the main function
    args = parse_main_args()

    # Execute tasks based on user input
    if args.task == 'train_llm_world':
        print("Starting LLM and World Model Training...")
        # Directly call the world model main function with appropriate arguments
        sys.argv = [
            'lightbulb_custom.py',
            '--mode', 'train',
            '--model_name', args.model_name,
            '--dataset_name', args.dataset_name,
            '--dataset_config', args.dataset_config,
            '--batch_size', str(args.batch_size),
            '--num_epochs', str(args.num_epochs),
            '--max_length', str(args.max_length)
        ]
        lightbulb_custom_main()

    elif args.task == 'train_agent':
        print("Starting Agent Training...")
        # Call the train_agent function from train_agent.py using the Twisted reactor.
        # Note: print() does not accept exc_info, so report the traceback via the Failure.
        d = task.deferLater(reactor, 0, train_agent)
        d.addErrback(lambda failure: print(f"An error occurred: {failure.getTraceback()}"))
        d.addBoth(lambda _: reactor.stop())
        reactor.run()

    elif args.task == 'test_agent':
        print("Starting Test Agent...")
        test_agent = TestAgent()
        if args.query:
            # Directly process a single query
            result = test_agent.process_query(args.query)
            print("\nAgent's response:")
            print(result)
        else:
            # Run the interactive session
            reactor.callWhenRunning(run_test_session)
            reactor.run()

    elif args.task in ['inference_llm', 'inference_world_model', 'advanced_inference']:
        print("Starting Inference Task...")
        # Map the main_menu task to lightbulb_custom.py's inference_mode
        inference_mode_map = {
            'inference_llm': 'without_world_model',
            'inference_world_model': 'world_model',
            'advanced_inference': 'world_model_tree_of_thought'
        }
        selected_inference_mode = inference_mode_map.get(args.task, 'world_model_tree_of_thought')

        # Construct sys.argv for lightbulb_custom.py
        lightbulb_inf_args = [
            'lightbulb_custom.py',
            '--mode', 'inference',
            '--model_name', args.model_name,
            '--query', args.query,
            '--max_length', str(args.max_length),
            '--inference_mode', selected_inference_mode,
            '--beam_size', str(args.beam_size),
            '--n_tokens_predict', str(args.n_tokens_predict),
            '--mcts_iterations', str(args.mcts_iterations),
            '--mcts_exploration_constant', str(args.mcts_exploration_constant)
        ]

        # Include additional arguments if they exist
        if args.load_model:
            lightbulb_inf_args += ['--load_model', args.load_model]

        # Update sys.argv and call the inference main function
        sys.argv = lightbulb_inf_args
        lightbulb_custom_main()

    elif args.task == 'distill_full_model':
        print("Starting Full Model Distillation...")
        distill_model(
            teacher_model_name=args.model_name,
            student_model_name=args.student_model_name,
            dataset_name=args.dataset_name,
            config=args.dataset_config,
            distill_full_model=True,
            query_terms=None,
            num_epochs=args.num_epochs,
            batch_size=args.batch_size,
            max_length=args.max_length,
            learning_rate=args.learning_rate,
            temperature=args.temperature,
            save_path=args.save_path,
            log_dir=args.log_dir,
            checkpoint_dir=args.checkpoint_dir,
            early_stopping_patience=args.early_stopping_patience
        )

    elif args.task == 'distill_domain_specific':
        print("Starting Domain-Specific Distillation...")
        if not args.query_terms:
            print("Error: --query_terms must be provided for domain-specific distillation.")
            sys.exit(1)
        distill_model(
            teacher_model_name=args.model_name,
            student_model_name=args.student_model_name,
            dataset_name=args.dataset_name,
            config=args.dataset_config,
            distill_full_model=False,
            query_terms=args.query_terms,
            num_epochs=args.num_epochs,
            batch_size=args.batch_size,
            max_length=args.max_length,
            learning_rate=args.learning_rate,
            temperature=args.temperature,
            save_path=args.save_path,
            log_dir=args.log_dir,
            checkpoint_dir=args.checkpoint_dir,
            early_stopping_patience=args.early_stopping_patience
        )

    else:
        print(f"Unknown task: {args.task}")
        sys.exit(1)

if __name__ == "__main__":
    main()
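As a usage sketch, the menu could drive the query-filtered distillation path as shown below. The query terms are illustrative placeholders, and the companion modules (train_agent, test_agent, lightbulb_custom) must exist alongside this file for the non-distillation tasks to run:

python main_menu_new.py --task distill_domain_specific \
    --model_name gpt2 --student_model_name distilgpt2 \
    --dataset_name wikitext --dataset_config wikitext-2-raw-v1 \
    --query_terms "neural network" "transformer"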