import csv
import json
import logging
import math
import os
import sys

import torch
from torch.nn.utils import clip_grad_norm_
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader
from tqdm import tqdm

# Make the repository root importable before pulling in project modules.
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
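
# Project-local modules (resolvable once the repo root is on sys.path).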
from src.utils.utils import create_run_directory
from src.dataset.dataset import VideoDataset
from src.models.model import create_model
from src.dataset.video_utils import create_transform
from visualization.visualize import run_visualization
from visualization.miscalculations_report import analyze_misclassifications


def train_and_evaluate(config):
    # Get the logger before the try block so the outer exception handler can
    # log errors that occur before logging is configured below.
    logger = logging.getLogger(__name__)
    try:
        if "run_dir" not in config:
            config["run_dir"] = create_run_directory()
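
        # Derive every output artifact path from the run directory.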
        config.update({
            "best_model_path": os.path.join(config["run_dir"], 'best_model.pth'),
            "final_model_path": os.path.join(config["run_dir"], 'final_model.pth'),
            "csv_path": os.path.join(config["run_dir"], 'training_log.csv'),
            "misclassifications_dir": os.path.join(config["run_dir"], 'misclassifications'),
        })
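
        # Snapshot the resolved config so the run can be reproduced later.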
        config_path = os.path.join(config["run_dir"], 'config.json')
        with open(config_path, 'w') as f:
            json.dump(config, f, indent=2)

        logging.basicConfig(level=logging.INFO,
                            format='%(asctime)s - %(levelname)s - %(message)s',
                            handlers=[logging.FileHandler(os.path.join(config["run_dir"], 'training.log')),
                                      logging.StreamHandler()])

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"Using device: {device}")

        best_val_loss = float('inf')
        epochs_without_improvement = 0

        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            logger.info(f"Available GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f}GB")
            logger.info(f"Currently allocated: {torch.cuda.memory_allocated() / 1e9:.2f}GB")
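
        # Build the CLIP-based classifier and fine-tune only the last
        # `unfreeze_layers` blocks of the vision encoder; the rest of the
        # backbone stays frozen.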
        model = create_model(config["num_classes"], config["clip_model"])
        model.unfreeze_vision_encoder(num_layers=config["unfreeze_layers"])
        model = model.to(device)

        criterion = torch.nn.CrossEntropyLoss().to(device)
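
        # The training set uses the dataset's default transform (presumably
        # the augmenting pipeline, given the jitter settings in config), while
        # validation gets an explicit deterministic transform.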
        train_dataset = VideoDataset(
            os.path.join(config['data_path'], 'train.csv'),
            config=config
        )

        val_config = config.copy()
        val_dataset = VideoDataset(
            os.path.join(config['data_path'], 'val.csv'),
            config=val_config,
            transform=create_transform(config, training=False)
        )

        train_loader = DataLoader(train_dataset, batch_size=config["batch_size"], shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=config["batch_size"], shuffle=False)

        optimizer = torch.optim.AdamW(model.parameters(), lr=config["learning_rate"],
                                      weight_decay=config["weight_decay"])
        scheduler = CosineAnnealingLR(optimizer, T_max=config["num_epochs"])
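
        # Write the CSV header once; each epoch appends a row to this file.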
        with open(config["csv_path"], 'w', newline='') as file:
            writer = csv.writer(file)
            writer.writerow(["epoch", "train_loss", "train_accuracy", "val_loss", "val_accuracy"])
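
        # Fraction of correct predictions in a batch.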
        def calculate_accuracy(outputs, labels):
            _, predicted = torch.max(outputs, 1)
            correct = (predicted == labels).sum().item()
            total = labels.size(0)
            return correct / total
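
        # Collect misclassified clips for the per-epoch JSON report;
        # dataset.label_map is assumed to map a class index to its name.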
        def log_misclassifications(outputs, labels, video_paths, dataset, misclassified_videos):
            _, predicted = torch.max(outputs, 1)
            for pred, label, video_path in zip(predicted, labels, video_paths):
                if pred != label:
                    true_label = dataset.label_map[label.item()]
                    predicted_label = dataset.label_map[pred.item()]
                    misclassified_videos.append({
                        'video_path': video_path,
                        'true_label': true_label,
                        'predicted_label': predicted_label
                    })

        os.makedirs(config["misclassifications_dir"], exist_ok=True)
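
        # One training pass and one validation pass per epoch.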
        for epoch in range(config["num_epochs"]):
            model.train()
            total_loss = 0
            total_accuracy = 0
            for frames, labels, video_paths in tqdm(train_loader, desc=f"Epoch {epoch + 1}/{config['num_epochs']}"):
                frames = frames.to(device)
                labels = labels.to(device)

                logits = model(frames)
                loss = criterion(logits, labels)
                accuracy = calculate_accuracy(logits, labels)

                optimizer.zero_grad()
                loss.backward()
                clip_grad_norm_(model.parameters(), max_norm=config["gradient_clip_max_norm"])
                optimizer.step()

                total_loss += loss.item()
                total_accuracy += accuracy

            avg_train_loss = total_loss / len(train_loader)
            avg_train_accuracy = total_accuracy / len(train_loader)
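
            # Validation pass: gradients disabled, misclassifications collected.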
            model.eval()
            val_loss = 0
            val_accuracy = 0
            misclassified_videos = []
            with torch.no_grad():
                for frames, labels, video_paths in val_loader:
                    frames = frames.to(device)
                    labels = labels.to(device)

                    logits = model(frames)
                    loss = criterion(logits, labels)
                    accuracy = calculate_accuracy(logits, labels)

                    val_loss += loss.item()
                    val_accuracy += accuracy

                    log_misclassifications(logits, labels, video_paths, val_dataset, misclassified_videos)

            avg_val_loss = val_loss / len(val_loader)
            avg_val_accuracy = val_accuracy / len(val_loader)

            if misclassified_videos:
                misclassified_log_path = os.path.join(config["misclassifications_dir"], f'epoch_{epoch + 1}.json')
                with open(misclassified_log_path, 'w') as f:
                    json.dump(misclassified_videos, f, indent=2)
                logger.info(f"Logged {len(misclassified_videos)} misclassified videos to {misclassified_log_path}")

            logger.info(f"Epoch [{epoch + 1}/{config['num_epochs']}], "
                        f"Train Loss: {avg_train_loss:.4f}, Train Accuracy: {avg_train_accuracy * 100:.2f}%, "
                        f"Val Loss: {avg_val_loss:.4f}, Val Accuracy: {avg_val_accuracy * 100:.2f}%")

            with open(config["csv_path"], 'a', newline='') as file:
                writer = csv.writer(file)
                writer.writerow([epoch + 1, avg_train_loss, avg_train_accuracy * 100,
                                 avg_val_loss, avg_val_accuracy * 100])

            scheduler.step()
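
            # Checkpoint on the best validation loss and count stale epochs
            # for early stopping.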
            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                torch.save(model.state_dict(), config["best_model_path"])
                logger.info(f"Saved best model to {config['best_model_path']}")
                epochs_without_improvement = 0
            else:
                epochs_without_improvement += 1

            if epochs_without_improvement >= config["patience"]:
                logger.info(f"Early stopping triggered after {config['patience']} epochs without improvement")
                break

            # Accuracies are fractions in [0, 1], so the gap threshold must be
            # a fraction as well (see "overfitting_threshold" in main()).
            if avg_train_accuracy - avg_val_accuracy > config["overfitting_threshold"]:
                logger.warning("Possible overfitting detected")

        logger.info("Training finished!")

        torch.save(model.state_dict(), config["final_model_path"])
        logger.info(f"Saved final model to {config['final_model_path']}")
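
        # Human-readable summary of the run alongside the JSON config.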
        with open(os.path.join(config["run_dir"], 'run_info.txt'), 'w') as f:
            for key, value in config.items():
                f.write(f"{key}: {value}\n")
            f.write(f"Device: {device}\n")
            f.write(f"Model: {model.__class__.__name__}\n")
            f.write(f"Optimizer: {optimizer.__class__.__name__}\n")
            f.write(f"Scheduler: {scheduler.__class__.__name__}\n")
            f.write("Loss function: CrossEntropyLoss\n")
            f.write(f"Data augmentation: RandomHorizontalFlip(p={config['flip_probability']}), "
                    f"RandomRotation({config['rotation_degrees']}), ColorJitter\n")
            # No GradScaler is created in this script, so AMP is never active.
            f.write("Mixed precision training: Disabled\n")
            f.write(f"Train dataset size: {len(train_dataset)}\n")
            f.write(f"Validation dataset size: {len(val_dataset)}\n")
            f.write(f"Vision encoder frozen: Partially ({config['unfreeze_layers']} layers unfrozen)\n")

        # Post-training visualization can fail independently of training, so
        # guard it; vis_dir is initialized so the final return works either way.
        vis_dir = None
        try:
            logger.info("Running visualization...")
            vis_dir, confusion_matrix = run_visualization(config["run_dir"])
            logger.info(f"Visualization complete! Check the output directory: {vis_dir}")

            # Per-class accuracy: diagonal counts over row (true-label) totals.
            class_accuracies = confusion_matrix.diagonal() / confusion_matrix.sum(axis=1)
            overall_accuracy = confusion_matrix.diagonal().sum() / confusion_matrix.sum()

            logger.info("\nConfusion Matrix Results:")
            for label, accuracy in zip(config['class_labels'], class_accuracies):
                logger.info(f"{label}: {accuracy:.2%}")
            logger.info(f"Overall Accuracy: {overall_accuracy:.2%}")
        except Exception as e:
            logger.error(f"Error running visualization: {str(e)}")

        try:
            analyze_misclassifications(config["run_dir"])
            logger.info(f"Misclassification analysis complete! Check the output directory: {config['run_dir']}")
        except Exception as e:
            logger.error(f"Error running misclassification analysis: {str(e)}")

        # Fail loudly on a degenerate run before reporting success.
        if math.isnan(avg_val_accuracy) or math.isinf(avg_val_accuracy):
            raise ValueError(f"Invalid validation accuracy: {avg_val_accuracy}")

        logger.info("Script finished.")

        return avg_val_accuracy, vis_dir

    except Exception as e:
        logger.error(f"Training error: {str(e)}")
        raise


def main():
    run_dir = create_run_directory()
    # Only the first three classes are trained; "baby_mill" is held out.
    class_labels = ["windmill", "halo", "swipe", "baby_mill"][:3]

    config = {
        "class_labels": class_labels,
        "num_classes": len(class_labels),
        "data_path": './data/blog/datasets/bryant/random',
        "batch_size": 8,
        "learning_rate": 2e-6,
        "weight_decay": 0.007,
        "num_epochs": 2,
        "patience": 10,
        "max_frames": 10,
        "sigma": 0.3,
        "image_size": 224,
        "flip_probability": 0.5,
        "rotation_degrees": 15,
        "brightness_jitter": 0.2,
        "contrast_jitter": 0.2,
        "saturation_jitter": 0.2,
        "hue_jitter": 0.1,
        "crop_scale_min": 0.8,
        "crop_scale_max": 1.0,
        "normalization_mean": [0.485, 0.456, 0.406],
        "normalization_std": [0.229, 0.224, 0.225],
        "unfreeze_layers": 3,
        "clip_model": "openai/clip-vit-base-patch32",
        "gradient_clip_max_norm": 1.0,
        # Train/val accuracy gap above which an overfitting warning is logged;
        # accuracies are fractions in [0, 1], so this is a fraction too.
        "overfitting_threshold": 0.1,
        "run_dir": run_dir,
    }
    train_and_evaluate(config)


if __name__ == "__main__":
    main()