import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, random_split
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from typing import List, Optional
import argparse
import os
from tqdm import tqdm
from torch.cuda.amp import autocast, GradScaler
from torch.utils.tensorboard import SummaryWriter

# Set up device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class CustomDataset(Dataset):
    """Wraps pre-tokenized input and label tensors for use with a DataLoader."""

    def __init__(self, inputs, labels):
        self.inputs = inputs
        self.labels = labels

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        return {'input_ids': self.inputs[idx], 'labels': self.labels[idx]}


def load_filtered_dataset(dataset_name: str, config: str, queries: Optional[List[str]] = None):
    dataset = load_dataset(dataset_name, config)
    if queries:
        # Keep only examples whose text contains at least one query term.
        # The filter runs per example (batched=False), so example["text"] is a string.
        def filter_func(example):
            return any(query.lower() in example["text"].lower() for query in queries)

        dataset = dataset.filter(filter_func)
    return dataset


def prepare_data(tokenizer, dataset, max_length, batch_size):
    # Tokenize the inputs and labels; for pure logit distillation the labels
    # are simply the same token ids as the inputs.
    tokenized_inputs = tokenizer(dataset["train"]["text"], return_tensors="pt",
                                 padding=True, truncation=True, max_length=max_length)
    tokenized_labels = tokenizer(dataset["train"]["text"], return_tensors="pt",
                                 padding=True, truncation=True, max_length=max_length)

    # Create custom dataset
    custom_dataset = CustomDataset(tokenized_inputs["input_ids"], tokenized_labels["input_ids"])

    # Split into training and validation sets
    train_size = int(0.9 * len(custom_dataset))
    val_size = len(custom_dataset) - train_size
    train_dataset, val_dataset = random_split(custom_dataset, [train_size, val_size])

    # Create DataLoaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                              num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False,
                            num_workers=4, pin_memory=True)

    return train_loader, val_loader


def train_step(teacher, student, data_loader, optimizer, criterion, scaler, temperature=2.0):
    teacher.eval()
    student.train()
    total_loss = 0

    for batch in tqdm(data_loader, desc="Training"):
        inputs = batch["input_ids"].to(device)
        labels = batch["labels"].to(device)

        optimizer.zero_grad()

        with autocast():
            with torch.no_grad():
                teacher_outputs = teacher(inputs).logits
                teacher_logits = teacher_outputs / temperature

            student_outputs = student(inputs).logits
            student_logits = student_outputs / temperature

            # Compute KL divergence between softened student and teacher distributions
            loss = criterion(nn.functional.log_softmax(student_logits, dim=-1),
                             nn.functional.softmax(teacher_logits, dim=-1))
            loss = loss * (temperature ** 2)  # Scale loss by temperature squared

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        total_loss += loss.item()

    avg_loss = total_loss / len(data_loader)
    return avg_loss


def validate(teacher, student, data_loader, criterion, temperature=2.0):
    teacher.eval()
    student.eval()
    total_loss = 0

    with torch.no_grad():
        for batch in tqdm(data_loader, desc="Validation"):
            inputs = batch["input_ids"].to(device)
            labels = batch["labels"].to(device)

            teacher_outputs = teacher(inputs).logits
            teacher_logits = teacher_outputs / temperature

            student_outputs = student(inputs).logits
            student_logits = student_outputs / temperature

            loss = criterion(nn.functional.log_softmax(student_logits, dim=-1),
                             nn.functional.softmax(teacher_logits, dim=-1))
            loss = loss * (temperature ** 2)

            total_loss += loss.item()

    avg_loss = total_loss / len(data_loader)
    return avg_loss
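# --- Illustrative sketch (not called anywhere in the training pipeline) ---
# A minimal, self-contained example of the temperature-scaled KD objective used in
# train_step/validate above, evaluated on toy logits. The shapes and values are
# assumptions chosen purely for demonstration, and the helper name _kd_loss_demo
# is hypothetical; it is not referenced elsewhere in this script.
def _kd_loss_demo(temperature: float = 2.0) -> float:
    # Toy logits: a batch of 2 positions over a 5-token vocabulary
    teacher_logits = torch.tensor([[2.0, 0.5, 0.1, -1.0, 0.0],
                                   [0.2, 1.5, -0.3, 0.0, 0.4]])
    student_logits = torch.tensor([[1.5, 0.7, 0.0, -0.5, 0.1],
                                   [0.0, 1.0, -0.2, 0.1, 0.3]])

    kd = nn.KLDivLoss(reduction="batchmean")
    # Soften both distributions with the temperature, then compare:
    # KLDivLoss expects log-probabilities for the student and probabilities for the teacher.
    loss = kd(nn.functional.log_softmax(student_logits / temperature, dim=-1),
              nn.functional.softmax(teacher_logits / temperature, dim=-1))
    # Multiply by T^2 so gradient magnitudes stay comparable across temperatures
    return (loss * temperature ** 2).item()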
def save_checkpoint(state, save_dir, epoch):
    os.makedirs(save_dir, exist_ok=True)
    checkpoint_path = os.path.join(save_dir, f'checkpoint_epoch_{epoch}.pt')
    torch.save(state, checkpoint_path)
    print(f"Checkpoint saved at {checkpoint_path}")


def load_checkpoint(model, optimizer, scheduler, scaler, save_dir, epoch):
    checkpoint_path = os.path.join(save_dir, f'checkpoint_epoch_{epoch}.pt')
    if os.path.isfile(checkpoint_path):
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        scaler.load_state_dict(checkpoint['scaler_state_dict'])
        print(f"Loaded checkpoint from {checkpoint_path}")
    else:
        print(f"No checkpoint found at {checkpoint_path}")


def distill_model(
    teacher_model_name: str,
    student_model_name: str,
    dataset_name: str,
    config: str,
    distill_full_model: bool = True,
    query_terms: Optional[List[str]] = None,
    num_epochs: int = 3,
    batch_size: int = 4,
    max_length: int = 128,
    learning_rate: float = 5e-5,
    temperature: float = 2.0,
    save_path: str = "./distilled_model",
    log_dir: str = "./logs",
    checkpoint_dir: str = "./checkpoints",
    early_stopping_patience: int = 3
):
    # Initialize TensorBoard writer
    writer = SummaryWriter(log_dir=log_dir)

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(teacher_model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Load teacher and student models
    teacher = AutoModelForCausalLM.from_pretrained(teacher_model_name).to(device)
    student = AutoModelForCausalLM.from_pretrained(student_model_name).to(device)

    # Freeze teacher model parameters; only the student is trained
    for param in teacher.parameters():
        param.requires_grad = False

    # Load and prepare dataset
    if distill_full_model:
        dataset = load_dataset(dataset_name, config)
    else:
        dataset = load_filtered_dataset(dataset_name, config, query_terms)

    train_loader, val_loader = prepare_data(tokenizer, dataset, max_length, batch_size)

    # Define optimizer, scheduler, and scaler for mixed precision
    optimizer = optim.AdamW(student.parameters(), lr=learning_rate)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
    scaler = GradScaler()

    # Define loss criterion
    criterion = nn.KLDivLoss(reduction="batchmean")

    best_val_loss = float('inf')
    epochs_no_improve = 0

    # Training loop
    for epoch in range(1, num_epochs + 1):
        print(f"\nEpoch {epoch}/{num_epochs}")
        print("-" * 20)

        # Training
        train_loss = train_step(teacher, student, train_loader, optimizer, criterion, scaler, temperature)
        print(f"Training Loss: {train_loss:.4f}")
        writer.add_scalar("Loss/Train", train_loss, epoch)

        # Validation
        val_loss = validate(teacher, student, val_loader, criterion, temperature)
        print(f"Validation Loss: {val_loss:.4f}")
        writer.add_scalar("Loss/Validation", val_loss, epoch)

        # Check for improvement
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            epochs_no_improve = 0
            # Save the best checkpoint (model, optimizer, scheduler, and scaler state)
            save_checkpoint({
                'epoch': epoch,
                'model_state_dict': student.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'scaler_state_dict': scaler.state_dict(),
                'best_val_loss': best_val_loss
            }, checkpoint_dir, epoch)
            # Also save the best student model and tokenizer in Hugging Face format
            student.save_pretrained(save_path)
            tokenizer.save_pretrained(save_path)
            print(f"Best model saved at epoch {epoch}")
        else:
            epochs_no_improve += 1
            print(f"No improvement in validation loss for {epochs_no_improve} epoch(s)")
            if epochs_no_improve >= early_stopping_patience:
                print("Early stopping triggered")
                break
        # Step the scheduler
        scheduler.step()

    writer.close()
    print("\nDistillation completed.")


def main():
    parser = argparse.ArgumentParser(description="Distill a large LLM into a smaller one.")
    parser.add_argument("--teacher_model_name", type=str, required=True, help="Name of the teacher model")
    parser.add_argument("--student_model_name", type=str, required=True, help="Name of the student model")
    parser.add_argument("--dataset_name", type=str, required=True, help="Name of the dataset")
    parser.add_argument("--config", type=str, default=None, help="Dataset configuration (e.g., 'wikitext-2-raw-v1')")
    parser.add_argument("--distill_full_model", action="store_true", help="Distill on the full dataset instead of a query-filtered subset")
    parser.add_argument("--query_terms", type=str, nargs="+", help="Query terms for filtering the dataset")
    parser.add_argument("--num_epochs", type=int, default=3, help="Number of epochs")
    parser.add_argument("--batch_size", type=int, default=4, help="Batch size")
    parser.add_argument("--max_length", type=int, default=128, help="Maximum sequence length")
    parser.add_argument("--learning_rate", type=float, default=5e-5, help="Learning rate")
    parser.add_argument("--temperature", type=float, default=2.0, help="Distillation temperature")
    parser.add_argument("--save_path", type=str, default="./distilled_model", help="Path to save the distilled model")
    parser.add_argument("--log_dir", type=str, default="./logs", help="Directory for TensorBoard logs")
    parser.add_argument("--checkpoint_dir", type=str, default="./checkpoints", help="Directory to save checkpoints")
    parser.add_argument("--early_stopping_patience", type=int, default=3, help="Early stopping patience")
    return parser.parse_args()


if __name__ == "__main__":
    args = main()
    distill_model(
        teacher_model_name=args.teacher_model_name,
        student_model_name=args.student_model_name,
        dataset_name=args.dataset_name,
        config=args.config,
        distill_full_model=args.distill_full_model,
        query_terms=args.query_terms,
        num_epochs=args.num_epochs,
        batch_size=args.batch_size,
        max_length=args.max_length,
        learning_rate=args.learning_rate,
        temperature=args.temperature,
        save_path=args.save_path,
        log_dir=args.log_dir,
        checkpoint_dir=args.checkpoint_dir,
        early_stopping_patience=args.early_stopping_patience
    )
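# Example invocation (illustrative only; the script filename, model names, and
# dataset below are assumptions, not requirements of this script):
#
#   python distill.py \
#       --teacher_model_name gpt2-large \
#       --student_model_name distilgpt2 \
#       --dataset_name wikitext \
#       --config wikitext-2-raw-v1 \
#       --distill_full_model \
#       --num_epochs 3 --batch_size 4 --max_length 128
#
# Omitting --distill_full_model and passing --query_terms instead distills on the
# query-filtered subset produced by load_filtered_dataset.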