import torch |
from torch.utils.data import DataLoader |
from torch.nn.utils import clip_grad_norm_ |
from tqdm import tqdm |
import os |
import logging |
import csv |
import json |
from torch.optim.lr_scheduler import CosineAnnealingLR |
import math |
import sys |
sys.path.append(os.path.dirname(os.path.dirname(__file__))) |
from src.utils.utils import create_run_directory |
from src.dataset.dataset import VideoDataset |
from src.models.model import create_model |
from src.dataset.video_utils import create_transform |
from visualization.visualize import run_visualization |
from visualization.miscalculations_report import analyze_misclassifications |
def _count_correct(predicted, labels):
    """Return the number of positions where ``predicted`` equals ``labels``, as an int."""
    return (predicted == labels).sum().item()


def _collect_misclassifications(predicted, labels, video_paths, dataset, sink):
    """Append one record to ``sink`` for every sample whose prediction is wrong.

    Args:
        predicted: tensor of predicted class indices for the batch.
        labels: tensor of true class indices, aligned with ``predicted``.
        video_paths: per-sample video paths, in the same order as the tensors.
        dataset: object exposing ``label_map``, indexed by class index; it is
            only read when a misclassification is found.
        sink: list that receives dicts with keys ``video_path``,
            ``true_label`` and ``predicted_label``.
    """
    # Convert once per batch: per-element .item() on CUDA tensors forces a
    # device synchronisation for every sample.
    for pred, label, video_path in zip(predicted.tolist(), labels.tolist(), video_paths):
        if pred != label:
            sink.append({
                'video_path': video_path,
                'true_label': dataset.label_map[label],
                'predicted_label': dataset.label_map[pred],
            })


def _train_one_epoch(model, loader, criterion, optimizer, device, max_norm, desc):
    """Run one training pass over ``loader`` and return ``(mean_loss, accuracy)``.

    Both metrics are weighted by sample count, so a short final batch counts
    only as much as the samples it holds. ``loader`` must yield at least one
    sample (the caller checks the dataset is non-empty).
    """
    model.train()
    loss_sum = 0.0
    correct = 0
    seen = 0
    for frames, labels, _ in tqdm(loader, desc=desc):
        frames = frames.to(device)
        labels = labels.to(device)
        logits = model(frames)
        loss = criterion(logits, labels)
        _, predicted = torch.max(logits, 1)

        optimizer.zero_grad()
        loss.backward()
        clip_grad_norm_(model.parameters(), max_norm=max_norm)
        optimizer.step()

        batch_size = labels.size(0)
        # Assumes a mean-reduced criterion (the CrossEntropyLoss default), so
        # rescale to a per-sample sum before averaging over the whole epoch.
        loss_sum += loss.item() * batch_size
        correct += _count_correct(predicted, labels)
        seen += batch_size
    return loss_sum / seen, correct / seen


def _evaluate(model, loader, criterion, device, dataset):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Returns ``(mean_loss, accuracy, misclassified)``. The metrics are weighted
    by sample count. ``misclassified`` is a list of dicts built by
    ``_collect_misclassifications`` using ``dataset.label_map``.
    """
    model.eval()
    loss_sum = 0.0
    correct = 0
    seen = 0
    misclassified = []
    with torch.no_grad():
        for frames, labels, video_paths in loader:
            frames = frames.to(device)
            labels = labels.to(device)
            logits = model(frames)
            loss = criterion(logits, labels)
            _, predicted = torch.max(logits, 1)

            batch_size = labels.size(0)
            loss_sum += loss.item() * batch_size
            correct += _count_correct(predicted, labels)
            seen += batch_size
            _collect_misclassifications(predicted, labels, video_paths, dataset, misclassified)
    return loss_sum / seen, correct / seen, misclassified


def _write_run_info(config, device, model, optimizer, scheduler, train_dataset, val_dataset):
    """Write a plain-text summary of the config and run setup to ``<run_dir>/run_info.txt``."""
    with open(os.path.join(config["run_dir"], 'run_info.txt'), 'w') as f:
        for key, value in config.items():
            f.write(f"{key}: {value}\n")
        f.write(f"Device: {device}\n")
        f.write(f"Model: {model.__class__.__name__}\n")
        f.write(f"Optimizer: {optimizer.__class__.__name__}\n")
        f.write(f"Scheduler: {scheduler.__class__.__name__}\n")
        f.write("Loss function: CrossEntropyLoss\n")
        # NOTE(review): hard-coded description; it may not match the transform
        # actually built from config (e.g. config["rotation_degrees"]) — confirm.
        f.write("Data augmentation: RandomHorizontalFlip, RandomRotation(5), ColorJitter\n")
        # The training loop uses no autocast/GradScaler.
        f.write("Mixed precision training: Disabled\n")
        f.write(f"Train dataset size: {len(train_dataset)}\n")
        f.write(f"Validation dataset size: {len(val_dataset)}\n")
        f.write(f"Vision encoder frozen: {'Partially' if hasattr(model, 'unfreeze_vision_encoder') else 'Unknown'}\n")


def _report_visualization(config, logger):
    """Run ``run_visualization`` for the run directory and log per-class accuracies.

    Best-effort: any exception is logged and swallowed so that a plotting
    failure does not abort a finished run. Returns the visualization
    directory, or ``None`` if ``run_visualization`` did not return one.
    """
    vis_dir = None
    try:
        logger.info("Running visualization...")
        vis_dir, confusion_matrix = run_visualization(config["run_dir"])
        logger.info(f"Visualization complete! Check the output directory: {vis_dir}")
        # presumably rows are true classes, so row sums are per-class counts — confirm in visualize.py
        class_accuracies = confusion_matrix.diagonal() / confusion_matrix.sum(axis=1)
        overall_accuracy = confusion_matrix.diagonal().sum() / confusion_matrix.sum()
        logger.info("\nConfusion Matrix Results:")
        for label, accuracy in zip(config['class_labels'], class_accuracies):
            logger.info(f"{label}: {accuracy:.2%}")
        logger.info(f"Overall Accuracy: {overall_accuracy:.2%}")
    except Exception as e:
        logger.error(f"Error running visualization: {str(e)}")
    return vis_dir


def train_and_evaluate(config):
    """Train a video classifier, checkpoint it, and run post-training analysis.

    ``config`` is mutated in place: ``run_dir`` is created if missing, and
    ``best_model_path``, ``final_model_path``, ``csv_path`` and
    ``misclassifications_dir`` are added. The resulting config is saved to
    ``<run_dir>/config.json``.

    Keys read from ``config`` include: num_epochs, num_classes, clip_model,
    unfreeze_layers, data_path, batch_size, learning_rate, weight_decay,
    gradient_clip_max_norm, patience, overfitting_threshold (percentage
    points of train-minus-validation accuracy) and class_labels.

    Returns:
        ``(val_accuracy, vis_dir)``. ``val_accuracy`` is the validation
        accuracy of the last completed epoch, as a fraction in [0, 1].
        ``vis_dir`` is the visualization output directory, or ``None`` if
        the visualization step failed.

    Raises:
        ValueError: if ``num_epochs`` < 1, if either dataset is empty, or if
            the final validation accuracy is NaN/inf.
        Any exception raised during setup or training is logged with its
        traceback and then re-raised.
    """
    # Bound before the try so the error handler can always log, even when
    # setup fails before logging has been configured.
    logger = logging.getLogger(__name__)
    try:
        if config["num_epochs"] < 1:
            raise ValueError(f"num_epochs must be >= 1, got {config['num_epochs']}")

        if "run_dir" not in config:
            config["run_dir"] = create_run_directory()
        run_dir = config["run_dir"]
        config.update({
            "best_model_path": os.path.join(run_dir, 'best_model.pth'),
            "final_model_path": os.path.join(run_dir, 'final_model.pth'),
            "csv_path": os.path.join(run_dir, 'training_log.csv'),
            "misclassifications_dir": os.path.join(run_dir, 'misclassifications'),
        })
        with open(os.path.join(run_dir, 'config.json'), 'w') as f:
            json.dump(config, f, indent=2)

        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(levelname)s - %(message)s',
            handlers=[logging.FileHandler(os.path.join(run_dir, 'training.log')),
                      logging.StreamHandler()],
            # Without force=True basicConfig is a no-op once the root logger
            # already has handlers (e.g. a second run in the same process),
            # and this run's training.log would never be written.
            force=True,
        )

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"Using device: {device}")
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            print(f"Available GPU memory: {torch.cuda.get_device_properties(0).total_memory/1e9:.2f}GB")
            print(f"Currently allocated: {torch.cuda.memory_allocated()/1e9:.2f}GB")

        model = create_model(config["num_classes"], config["clip_model"])
        model.unfreeze_vision_encoder(num_layers=config["unfreeze_layers"])
        model = model.to(device)
        criterion = torch.nn.CrossEntropyLoss().to(device)

        train_dataset = VideoDataset(
            os.path.join(config['data_path'], 'train.csv'),
            config=config,
        )
        # The validation set gets its own shallow copy of the config and an
        # explicit non-training transform.
        val_dataset = VideoDataset(
            os.path.join(config['data_path'], 'val.csv'),
            config=config.copy(),
            transform=create_transform(config, training=False),
        )
        # Epoch metrics divide by the number of samples seen; fail clearly here
        # instead of with a ZeroDivisionError after the first epoch.
        if len(train_dataset) == 0 or len(val_dataset) == 0:
            raise ValueError(
                f"Empty dataset: {len(train_dataset)} training and "
                f"{len(val_dataset)} validation samples")

        train_loader = DataLoader(train_dataset, batch_size=config["batch_size"], shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=config["batch_size"], shuffle=False)
        optimizer = torch.optim.AdamW(model.parameters(), lr=config["learning_rate"],
                                      weight_decay=config["weight_decay"])
        scheduler = CosineAnnealingLR(optimizer, T_max=config["num_epochs"])

        with open(config["csv_path"], 'w', newline='') as csv_file:
            csv.writer(csv_file).writerow(
                ["epoch", "train_loss", "train_accuracy", "val_loss", "val_accuracy"])
        os.makedirs(config["misclassifications_dir"], exist_ok=True)

        num_epochs = config["num_epochs"]
        best_val_loss = float('inf')
        epochs_without_improvement = 0
        for epoch in range(num_epochs):
            avg_train_loss, avg_train_accuracy = _train_one_epoch(
                model, train_loader, criterion, optimizer, device,
                max_norm=config["gradient_clip_max_norm"],
                desc=f"Epoch {epoch + 1}/{num_epochs}",
            )
            avg_val_loss, avg_val_accuracy, misclassified_videos = _evaluate(
                model, val_loader, criterion, device, val_dataset)

            if misclassified_videos:
                misclassified_log_path = os.path.join(
                    config["misclassifications_dir"], f'epoch_{epoch+1}.json')
                with open(misclassified_log_path, 'w') as f:
                    json.dump(misclassified_videos, f, indent=2)
                logger.info(f"Logged {len(misclassified_videos)} misclassified videos to {misclassified_log_path}")

            logger.info(f"Epoch [{epoch+1}/{num_epochs}], "
                        f"Train Loss: {avg_train_loss:.4f}, Train Accuracy: {avg_train_accuracy*100:.2f}%, "
                        f"Val Loss: {avg_val_loss:.4f}, Val Accuracy: {avg_val_accuracy*100:.2f}%")
            with open(config["csv_path"], 'a', newline='') as csv_file:
                csv.writer(csv_file).writerow(
                    [epoch + 1, avg_train_loss, avg_train_accuracy * 100,
                     avg_val_loss, avg_val_accuracy * 100])
            scheduler.step()

            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                torch.save(model.state_dict(), config["best_model_path"])
                logger.info(f"Saved best model to {config['best_model_path']}")
                epochs_without_improvement = 0
            else:
                epochs_without_improvement += 1
                if epochs_without_improvement >= config["patience"]:
                    logger.info(f"Early stopping triggered after {config['patience']} epochs without improvement")
                    break

            # The accuracies are fractions in [0, 1], so their raw difference can
            # never exceed a percentage-point threshold such as 10. Compare in
            # percentage points, the same unit written to the CSV and the log.
            accuracy_gap = (avg_train_accuracy - avg_val_accuracy) * 100
            if accuracy_gap > config["overfitting_threshold"]:
                logger.warning(f"Possible overfitting detected "
                               f"(train/val accuracy gap: {accuracy_gap:.2f} percentage points)")

        logger.info("Training finished!")
        torch.save(model.state_dict(), config["final_model_path"])
        logger.info(f"Saved final model to {config['final_model_path']}")
        _write_run_info(config, device, model, optimizer, scheduler, train_dataset, val_dataset)

        vis_dir = _report_visualization(config, logger)
        try:
            analyze_misclassifications(run_dir)
            logger.info(f"Misclassification analysis complete! Check the output directory: {run_dir}")
        except Exception as e:
            logger.error(f"Error running misclassification analysis: {str(e)}")

        if not math.isfinite(avg_val_accuracy):
            raise ValueError(f"Invalid validation accuracy: {avg_val_accuracy}")
        print("Script finished.")
        return avg_val_accuracy, vis_dir
    except Exception as e:
        # logger.exception also records the traceback, which logger.error dropped.
        logger.exception(f"Training error: {str(e)}")
        raise
def main():
    """Build the default experiment configuration and launch one training run.

    The configuration dict is handed to ``train_and_evaluate``; its return
    value is discarded.
    """
    experiment_dir = create_run_directory()

    # Four moves are listed, but only the first three are used in this run.
    all_moves = ["windmill", "halo", "swipe", "baby_mill"]
    selected_moves = all_moves[:3]

    # Key order matters: it fixes the order of config.json and run_info.txt.
    experiment_config = dict(
        class_labels=selected_moves,
        num_classes=len(selected_moves),
        data_path='./data/blog/datasets/bryant/random',
        # Optimisation and stopping
        batch_size=8,
        learning_rate=2e-6,
        weight_decay=0.007,
        num_epochs=2,
        patience=10,
        # Frame sampling / preprocessing / augmentation settings
        max_frames=10,
        sigma=0.3,
        image_size=224,
        flip_probability=0.5,
        rotation_degrees=15,
        brightness_jitter=0.2,
        contrast_jitter=0.2,
        saturation_jitter=0.2,
        hue_jitter=0.1,
        crop_scale_min=0.8,
        crop_scale_max=1.0,
        normalization_mean=[0.485, 0.456, 0.406],
        normalization_std=[0.229, 0.224, 0.225],
        # Model
        unfreeze_layers=3,
        clip_model="openai/clip-vit-base-patch32",
        gradient_clip_max_norm=1.0,
        # Threshold on the train-minus-validation accuracy gap (unit defined by train_and_evaluate)
        overfitting_threshold=10,
        run_dir=experiment_dir,
    )
    train_and_evaluate(experiment_config)


if __name__ == "__main__":
    main()