bawolf committed
Commit 31fc7e1 · 0 Parent(s)

init working

.gitignore ADDED
@@ -0,0 +1,70 @@
+ # Python
+ __pycache__/
+ *.py[cod]
+ *$py.class
+ *.so
+ .Python
+ env/
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+
+ # Virtual Environment
+ venv/
+ ENV/
+ .env
+
+ # IDE
+ .idea/
+ .vscode/
+ *.swp
+ *.swo
+ .DS_Store
+
+ # Project specific
+ runs/
+ checkpoints/
+ *.pth
+ *.ckpt
+ *.pt
+ wandb/
+ logs/
+ .cog/
+
+ # Data
+ data/
+ *.mp4
+ *.avi
+ *.mov
+ *.jpg
+ *.jpeg
+ *.png
+ *.gif
+ *.h5
+ *.npy
+ *.npz
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+ *.ipynb
+
+ # Logs
+ *.log
+ *.csv
+ *.json
+
+ # Keep specific config files
+ !config.json
+ !requirements.txt
README.md ADDED
@@ -0,0 +1,103 @@
+ # CLIP-Based Break Dance Move Classifier
+
+ A deep learning model for classifying break dance moves using CLIP (Contrastive Language-Image Pre-Training) embeddings. The model is fine-tuned on break dance videos to classify different power moves including windmills, halos, swipes, and baby mills.
+
+ ## Features
+
+ - Video-based classification using CLIP embeddings
+ - Multi-frame temporal analysis
+ - Configurable frame sampling and data augmentation
+ - Real-time inference using Cog
+ - Misclassification analysis tools
+ - Hyperparameter tuning support
+
+ ## Setup
+
+ ```bash
+ # Install dependencies
+ pip install -r requirements.txt
+
+ # Install Cog (if not already installed)
+ curl -o /usr/local/bin/cog -L https://github.com/replicate/cog/releases/latest/download/cog_`uname -s`_`uname -m`
+ chmod +x /usr/local/bin/cog
+ ```
+
+ ## Training
+
+ ```bash
+ # Run training with default configuration
+ python scripts/train.py
+
+ # Run hyperparameter tuning
+ python scripts/hyperparameter_tuning.py
+ ```
+
+ ## Inference
+
+ ```bash
+ # Using Cog for inference
+ cog predict -i video=@path/to/your/video.mp4
+
+ # Using standard Python script
+ python scripts/inference.py --video path/to/your/video.mp4
+ ```
+
+ ## Analysis
+
+ ```bash
+ # Generate misclassification report
+ python scripts/visualization/miscalculations_report.py
+
+ # Visualize model performance
+ python scripts/visualization/visualize.py
+ ```
+
+ ## Project Structure
+
+ ```
+ clip/
+ ├── src/                 # Source code
+ │   ├── data/            # Dataset and data processing
+ │   ├── models/          # Model architecture
+ │   └── utils/           # Utility functions
+ ├── scripts/             # Training and inference scripts
+ │   └── visualization/   # Visualization tools
+ ├── config/              # Configuration files
+ ├── runs/                # Training runs and checkpoints
+ ├── cog.yaml             # Cog configuration
+ └── requirements.txt     # Python dependencies
+ ```
+
+ ## Model Architecture
+
+ - Base: CLIP ViT-Large/14
+ - Custom temporal pooling layer
+ - Fine-tuned vision encoder (last 3 layers)
+ - Output: 4-class classifier
+
+ ## Performance
+
+ - Training Accuracy: ~95%
+ - Validation Accuracy: ~92%
+ - Inference Time: ~100ms per video
+
+ ## Configuration
+
+ Key hyperparameters can be modified in `config/default.yaml`:
+ - Frame sampling: 10 frames per video
+ - Image size: 224x224
+ - Learning rate: 2e-6
+ - Weight decay: 0.007
+ - Data augmentation parameters
+
+ ## License
+
+ [Your License Here]
+
+ ## Citation
+
+ If you use this model in your research, please cite:
+
+ ```bibtex
+ [Your Citation Here]
+ ```
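For reference, the Cog endpoint added later in this commit (`predict.py`) returns the predicted class, its confidence, and the per-class confidences. A minimal sketch of the output shape, with made-up placeholder numbers rather than measured results:

```python
# Illustrative shape of the dict returned by Predictor.predict() in predict.py.
# The keys match the code in this commit; the values below are placeholders.
example_output = {
    "class": "windmill",              # top-1 label from config['class_labels']
    "confidence": 0.91,               # softmax probability of the top-1 class
    "all_confidences": {              # probability for every configured class
        "windmill": 0.91,
        "halo": 0.06,
        "swipe": 0.03,
    },
}
```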
cog.yaml ADDED
@@ -0,0 +1,18 @@
+ build:
+   gpu: true
+   cuda: "12.1"
+   python_version: "3.10"
+   system_packages:
+     - "libgl1-mesa-glx"
+     - "libglib2.0-0"
+   python_packages:
+     - "torch==2.3.0"
+     - "torchvision"
+     - "transformers"
+     - "opencv-python"
+     - "pillow"
+     - "numpy"
+     - "scipy"
+     - "huggingface_hub"
+
+ predict: "predict.py:Predictor"
predict.py ADDED
@@ -0,0 +1,69 @@
+ import os
+ from cog import BasePredictor, Input, Path
+ import torch
+ import json
+ from src.models.model import load_model
+ from src.data.video_utils import create_transform, extract_frames
+
+ CHECKPOINT_DIR = "runs/run_20241024-150232_otherpeopleval_large_model/"
+
+ class Predictor(BasePredictor):
+     def setup(self):
+         """Load the model into memory to make running multiple predictions efficient"""
+         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+         print(f"Using device: {self.device}")
+
+         # Load configuration from JSON
+         with open(os.path.join(CHECKPOINT_DIR, "config.json"), 'r') as f:
+             self.config = json.load(f)
+
+         # Create transform
+         self.transform = create_transform(self.config, training=False)
+
+         # Load model
+         self.model = load_model(
+             self.config['num_classes'],
+             os.path.join(CHECKPOINT_DIR, "best_model.pth"),
+             self.device,
+             self.config['clip_model']
+         )
+         self.model.eval()
+
+     def predict(self, video: Path = Input(description="Input video file")) -> dict:
+         """Run a single prediction on the model"""
+         try:
+             # Extract frames using shared function with config
+             frames, success = extract_frames(
+                 str(video),
+                 self.config,
+                 self.transform
+             )
+
+             if not success or frames is None:
+                 raise ValueError(f"Failed to process video: {video}")
+
+             # Now frames is a tensor, not a tuple
+             frames = frames.unsqueeze(0).to(self.device)
+
+             # Get prediction
+             with torch.no_grad():
+                 output = self.model(frames)
+                 probabilities = torch.softmax(output, dim=1)
+                 predicted_class = torch.argmax(probabilities, dim=1).item()
+                 confidence = probabilities[0][predicted_class].item()
+
+             # Get all class confidences
+             all_confidences = {
+                 label: probabilities[0][i].item()
+                 for i, label in enumerate(self.config['class_labels'])
+             }
+
+             return {
+                 "class": self.config['class_labels'][predicted_class],
+                 "confidence": confidence,
+                 "all_confidences": all_confidences
+             }
+
+         except Exception as e:
+             raise ValueError(f"Error processing video: {str(e)}")
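`predict.py` and the scripts below rely on `create_transform` and `extract_frames` from `src/data/video_utils.py`, which is not shown in this part of the commit. Based only on how they are called (a config-driven per-frame transform, and a `(frames_tensor, success)` return value), a hypothetical sketch of their interfaces might look like the following; the real sampling and augmentation logic (e.g. how `max_frames` and `sigma` are used) may differ:

```python
# Hypothetical sketch of src/data/video_utils.py, inferred from the call sites in
# predict.py, scripts/inference.py, and scripts/train.py. Not the actual file.
import cv2
import torch
from PIL import Image
from torchvision import transforms


def create_transform(config, training=False):
    """Build a per-frame transform from the config dict (assumed keys shown)."""
    size = config["image_size"]
    if training:
        return transforms.Compose([
            transforms.RandomResizedCrop(size, scale=(config["crop_scale_min"], config["crop_scale_max"])),
            transforms.RandomHorizontalFlip(p=config["flip_probability"]),
            transforms.ColorJitter(
                brightness=config["brightness_jitter"],
                contrast=config["contrast_jitter"],
                saturation=config["saturation_jitter"],
                hue=config["hue_jitter"],
            ),
            transforms.ToTensor(),
            transforms.Normalize(config["normalization_mean"], config["normalization_std"]),
        ])
    return transforms.Compose([
        transforms.Resize((size, size)),
        transforms.ToTensor(),
        transforms.Normalize(config["normalization_mean"], config["normalization_std"]),
    ])


def extract_frames(video_path, config, transform):
    """Sample up to config['max_frames'] frames and return (tensor[T, 3, H, W], success)."""
    cap = cv2.VideoCapture(video_path)
    total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    if total <= 0:
        cap.release()
        return None, False

    # Evenly spaced sampling here; the real code may weight sampling with `sigma`.
    indices = torch.linspace(0, total - 1, steps=min(config["max_frames"], total)).long().tolist()
    frames = []
    for idx in indices:
        cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
        ok, frame = cap.read()
        if not ok:
            continue
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        frames.append(transform(Image.fromarray(frame)))
    cap.release()

    if not frames:
        return None, False
    return torch.stack(frames), True
```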
requirements.txt ADDED
@@ -0,0 +1,80 @@
+ alembic==1.13.3
+ annotated-types==0.7.0
+ anyio==4.6.2.post1
+ attrs==23.2.0
+ certifi==2024.8.30
+ charset-normalizer==3.4.0
+ click==8.1.7
+ cog==0.12.0
+ colorlog==6.9.0
+ contourpy==1.3.0
+ cycler==0.12.1
+ fastapi==0.110.3
+ filelock==3.16.1
+ fonttools==4.54.1
+ fsspec==2024.10.0
+ greenlet==3.1.1
+ h11==0.14.0
+ httptools==0.6.4
+ huggingface-hub==0.26.2
+ idna==3.10
+ Jinja2==3.1.4
+ joblib==1.4.2
+ kiwisolver==1.4.7
+ Mako==1.3.6
+ MarkupSafe==3.0.2
+ matplotlib==3.9.2
+ mpmath==1.3.0
+ networkx==3.4.2
+ nms_1d_cpu==0.0.0
+ numpy==2.1.2
+ nvidia-cublas-cu12==12.4.5.8
+ nvidia-cuda-cupti-cu12==12.4.127
+ nvidia-cuda-nvrtc-cu12==12.4.127
+ nvidia-cuda-runtime-cu12==12.4.127
+ nvidia-cudnn-cu12==9.1.0.70
+ nvidia-cufft-cu12==11.2.1.3
+ nvidia-curand-cu12==10.3.5.147
+ nvidia-cusolver-cu12==11.6.1.9
+ nvidia-cusparse-cu12==12.3.1.170
+ nvidia-nccl-cu12==2.21.5
+ nvidia-nvjitlink-cu12==12.4.127
+ nvidia-nvtx-cu12==12.4.127
+ opencv-python==4.10.0.84
+ optuna==4.0.0
+ packaging==24.1
+ pandas==2.2.3
+ pillow==11.0.0
+ pydantic==2.9.2
+ pydantic_core==2.23.4
+ pyparsing==3.2.0
+ python-dateutil==2.9.0.post0
+ python-dotenv==1.0.1
+ pytz==2024.2
+ PyYAML==6.0.2
+ regex==2024.9.11
+ requests==2.32.3
+ safetensors==0.4.5
+ scikit-learn==1.5.2
+ scipy==1.14.1
+ seaborn==0.13.2
+ six==1.16.0
+ sniffio==1.3.1
+ SQLAlchemy==2.0.36
+ starlette==0.37.2
+ structlog==24.4.0
+ sympy==1.13.1
+ threadpoolctl==3.5.0
+ tokenizers==0.20.1
+ torch==2.5.1
+ torchvision==0.20.1
+ tqdm==4.66.6
+ transformers==4.46.1
+ triton==3.1.0
+ typing_extensions==4.12.2
+ tzdata==2024.2
+ urllib3==2.2.3
+ uvicorn==0.32.0
+ uvloop==0.21.0
+ watchfiles==0.24.0
+ websockets==13.1
scripts/hyperparameter_tuning.py ADDED
@@ -0,0 +1,79 @@
+ import os
+ import sys
+
+ import optuna
+
+ sys.path.append(os.path.dirname(os.path.dirname(__file__)))
+ from scripts.train import train_and_evaluate
+ from src.utils.utils import create_run_directory
+
+ def objective(trial, hyperparam_run_dir):
+     config = {
+         "clip_model": trial.suggest_categorical("clip_model", ["openai/clip-vit-base-patch32", "openai/clip-vit-large-patch14"]),
+         "learning_rate": trial.suggest_float("learning_rate", 1e-6, 1e-4, log=True),
+         "weight_decay": trial.suggest_float("weight_decay", 1e-8, 1e-1, log=True),
+         "unfreeze_layers": trial.suggest_int("unfreeze_layers", 1, 6),
+         "batch_size": trial.suggest_categorical("batch_size", [32, 64, 128]),
+         "gradient_clip_max_norm": trial.suggest_float("gradient_clip_max_norm", 0.1, 1.0),
+         "augmentation_strength": trial.suggest_float("augmentation_strength", 0.0, 1.0),
+         "crop_scale_min": trial.suggest_float("crop_scale_min", 0.6, 0.9),
+         "max_frames": trial.suggest_int("max_frames", 5, 15),
+         "sigma": trial.suggest_float("sigma", 0.1, 0.5),
+     }
+
+     class_labels = ["windmill", "halo", "swipe", "baby_mill"][:3]
+
+     # Fixed configurations
+     config.update({
+         "class_labels": class_labels,
+         "num_classes": len(class_labels),
+         "data_path": '../finetune/3moves_test',
+         "num_epochs": 50,  # Reduced for faster trials
+         "patience": 10,  # Adjusted for faster trials
+         "image_size": 224,
+         "crop_scale_max": 1.0,
+         "normalization_mean": [0.485, 0.456, 0.406],
+         "normalization_std": [0.229, 0.224, 0.225],
+         "overfitting_threshold": 10,
+     })
+
+     # Derive augmentation parameters from augmentation_strength
+     config.update({
+         "flip_probability": 0.5 * config["augmentation_strength"],
+         "rotation_degrees": int(15 * config["augmentation_strength"]),
+         "brightness_jitter": 0.2 * config["augmentation_strength"],
+         "contrast_jitter": 0.2 * config["augmentation_strength"],
+         "saturation_jitter": 0.2 * config["augmentation_strength"],
+         "hue_jitter": 0.1 * config["augmentation_strength"],
+     })
+
+     # Create a unique run directory for this trial
+     config["run_dir"] = create_run_directory(prefix="trial", parent_dir=hyperparam_run_dir)
+
+     # Run training and evaluation
+     val_accuracy = train_and_evaluate(config)
+     return val_accuracy
+
+ def main():
+     # Set up the study and optimize
+     hyperparam_run_dir = create_run_directory(suffix='_hyperparam')
+     study = optuna.create_study(direction="maximize")
+     study.optimize(lambda trial: objective(trial, hyperparam_run_dir), n_trials=100)  # Adjust the number of trials as needed
+
+     # Save the study results
+     study.trials_dataframe().to_csv(os.path.join(hyperparam_run_dir, 'study_results.csv'))
+
+     print("Best trial:")
+     trial = study.best_trial
+     print("  Value: ", trial.value)
+     print("  Params: ")
+     for key, value in trial.params.items():
+         print("    {}: {}".format(key, value))
+
+     # Save the best trial parameters
+     with open(os.path.join(hyperparam_run_dir, 'best_params.txt'), 'w') as f:
+         for key, value in trial.params.items():
+             f.write(f"{key}: {value}\n")
+
+ if __name__ == "__main__":
+     main()
scripts/inference.py ADDED
@@ -0,0 +1,106 @@
+ import torch
+ import numpy as np
+ import os
+ import sys
+ sys.path.append(os.path.dirname(os.path.dirname(__file__)))
+
+ from src.utils.utils import get_latest_run_dir, get_latest_model_path, get_config
+ from src.models.model import load_model
+ from src.data.video_utils import create_transform, extract_frames
+
+ def setup_model(run_dir=None):
+     """Setup model and configuration"""
+     # Define the device
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+     # Get run directory
+     if run_dir is None:
+         run_dir = get_latest_run_dir()
+     print(f"Using run directory: {run_dir}")
+
+     try:
+         # Load configuration
+         config = get_config(run_dir)
+         print(f"Loaded configuration from: {run_dir}")
+
+         # Load the model
+         model_path = get_latest_model_path(run_dir)
+         print(f"Loading model from: {model_path}")
+
+         model = load_model(
+             config['num_classes'],
+             model_path,
+             device,
+             config['clip_model']
+         )
+         model.eval()
+
+         return model, config, device
+
+     except (ValueError, FileNotFoundError) as e:
+         print(f"Error loading model: {str(e)}")
+         sys.exit(1)
+
+ def predict(video_path, model, config, device):
+     """Predict class for a video using the model"""
+     transform = create_transform(config, training=False)
+
+     try:
+         frames, success = extract_frames(video_path, config, transform)
+         if not success:
+             raise ValueError(f"Failed to process video: {video_path}")
+
+         frames = frames.to(device)
+
+         # Add batch dimension at the start
+         frames = frames.unsqueeze(0)
+
+         with torch.no_grad():
+             try:
+                 outputs = model(frames)
+                 probabilities = torch.nn.functional.softmax(outputs, dim=1)
+             except Exception as e:
+                 print(f"Error during model forward pass: {str(e)}")
+                 print(f"Model input shape: {frames.shape}")
+                 raise
+
+         # Get predictions
+         avg_probabilities = probabilities[0].cpu().numpy()
+         predicted_class = np.argmax(avg_probabilities)
+
+         # Create a dictionary of class labels and their probabilities
+         class_probabilities = {
+             label: float(prob)
+             for label, prob in zip(config['class_labels'], avg_probabilities)
+         }
+
+         return config['class_labels'][predicted_class], class_probabilities
+
+     except Exception as e:
+         raise ValueError(f"Error processing video: {str(e)}")
+
+ if __name__ == "__main__":
+     import argparse
+
+     parser = argparse.ArgumentParser(description='Run inference on a video file')
+     parser.add_argument('--video', type=str, required=True,
+                         help='Path to the video file')
+     parser.add_argument('--run-dir', type=str,
+                         help='Path to specific run directory (optional)')
+
+     args = parser.parse_args()
+
+     # Setup model and config
+     model, config, device = setup_model(args.run_dir)
+
+     try:
+         predicted_label, class_probabilities = predict(args.video, model, config, device)
+         print(f"\nPredicted label: {predicted_label}")
+         print("\nClass probabilities:")
+         for label, prob in class_probabilities.items():
+             print(f"  {label}: {prob:.4f}")
+     except ValueError as e:
+         print(f"Error: {str(e)}")
scripts/train.py ADDED
@@ -0,0 +1,264 @@
+ import torch
+ from torch.utils.data import DataLoader
+ from torch.nn.utils import clip_grad_norm_
+ from tqdm import tqdm
+ import os
+ import logging
+ import csv
+ import json
+ from torch.optim.lr_scheduler import CosineAnnealingLR
+
+ import sys
+ sys.path.append(os.path.dirname(os.path.dirname(__file__)))
+
+ from src.utils.utils import create_run_directory
+ from src.data.dataset import VideoDataset
+ from src.models.model import create_model
+ from src.data.video_utils import create_transform
+
+ def train_and_evaluate(config):
+     # Create a run directory if it doesn't exist
+     if "run_dir" not in config:
+         config["run_dir"] = create_run_directory()
+
+     # Update paths based on run_dir
+     config.update({
+         "best_model_path": os.path.join(config["run_dir"], 'best_model.pth'),
+         "final_model_path": os.path.join(config["run_dir"], 'final_model.pth'),
+         "csv_path": os.path.join(config["run_dir"], 'training_log.csv'),
+         "misclassifications_dir": os.path.join(config["run_dir"], 'misclassifications'),
+     })
+
+     config_path = os.path.join(config["run_dir"], 'config.json')
+     with open(config_path, 'w') as f:
+         json.dump(config, f, indent=2)
+
+     # Set up logging
+     logging.basicConfig(level=logging.INFO,
+                         format='%(asctime)s - %(levelname)s - %(message)s',
+                         handlers=[logging.FileHandler(os.path.join(config["run_dir"], 'training.log')),
+                                   logging.StreamHandler()])
+     logger = logging.getLogger(__name__)
+
+     # Set device
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     logger.info(f"Using device: {device}")
+
+     # Initialize variables
+     best_val_loss = float('inf')
+     epochs_without_improvement = 0
+
+     model = create_model(config["num_classes"], config["clip_model"])
+
+     # Unfreeze the last N layers of the vision encoder (N comes from the config)
+     model.unfreeze_vision_encoder(num_layers=config["unfreeze_layers"])
+
+     # Move model to device
+     model = model.to(device)
+     logger.info(f"Model architecture:\n{model}")
+
+     # Load datasets
+     train_dataset = VideoDataset(
+         os.path.join(config['data_path'], 'train.csv'),
+         config=config
+     )
+
+     # For validation, create a new config with training=False for transforms
+     val_config = config.copy()
+     val_dataset = VideoDataset(
+         os.path.join(config['data_path'], 'val.csv'),
+         config=val_config,
+         transform=create_transform(config, training=False)
+     )
+
+     # Create data loaders
+     train_loader = DataLoader(train_dataset, batch_size=config["batch_size"], shuffle=True)
+     val_loader = DataLoader(val_dataset, batch_size=config["batch_size"], shuffle=False)
+
+     # Define optimizer and learning rate scheduler
+     optimizer = torch.optim.AdamW(model.parameters(), lr=config["learning_rate"], weight_decay=config["weight_decay"])
+     scheduler = CosineAnnealingLR(optimizer, T_max=config["num_epochs"])
+
+     # Open a CSV file to log training progress
+     with open(config["csv_path"], 'w', newline='') as file:
+         writer = csv.writer(file)
+         writer.writerow(["epoch", "train_loss", "train_accuracy", "val_loss", "val_accuracy"])
+
+     # Function to calculate accuracy
+     def calculate_accuracy(outputs, labels):
+         _, predicted = torch.max(outputs, 1)
+         correct = (predicted == labels).sum().item()
+         total = labels.size(0)
+         return correct / total
+
+     def log_misclassifications(outputs, labels, video_paths, dataset, misclassified_videos):
+         _, predicted = torch.max(outputs, 1)
+         for pred, label, video_path in zip(predicted, labels, video_paths):
+             if pred != label:
+                 true_label = dataset.label_map[label.item()]
+                 predicted_label = dataset.label_map[pred.item()]
+                 misclassified_videos.append({
+                     'video_path': video_path,
+                     'true_label': true_label,
+                     'predicted_label': predicted_label
+                 })
+
+     # Create a subfolder for misclassification logs
+     os.makedirs(config["misclassifications_dir"], exist_ok=True)
+
+     # Training loop
+     for epoch in range(config["num_epochs"]):
+         model.train()
+         total_loss = 0
+         total_accuracy = 0
+         for frames, labels, video_paths in tqdm(train_loader, desc=f"Epoch {epoch + 1}/{config['num_epochs']}"):
+             frames = frames.to(device)
+             labels = labels.to(device)
+
+             logits = model(frames)
+
+             loss = torch.nn.functional.cross_entropy(logits, labels)
+             accuracy = calculate_accuracy(logits, labels)
+
+             optimizer.zero_grad()
+             loss.backward()
+             clip_grad_norm_(model.parameters(), max_norm=config["gradient_clip_max_norm"])
+             optimizer.step()
+
+             total_loss += loss.item()
+             total_accuracy += accuracy
+
+         avg_train_loss = total_loss / len(train_loader)
+         avg_train_accuracy = total_accuracy / len(train_loader)
+
+         # Validation
+         model.eval()
+         val_loss = 0
+         val_accuracy = 0
+         misclassified_videos = []
+         with torch.no_grad():
+             for frames, labels, video_paths in val_loader:
+                 frames = frames.to(device)
+                 labels = labels.to(device)
+
+                 logits = model(frames)
+
+                 loss = torch.nn.functional.cross_entropy(logits, labels)
+                 accuracy = calculate_accuracy(logits, labels)
+
+                 val_loss += loss.item()
+                 val_accuracy += accuracy
+
+                 # Log misclassifications
+                 log_misclassifications(logits, labels, video_paths, val_dataset, misclassified_videos)
+
+         avg_val_loss = val_loss / len(val_loader)
+         avg_val_accuracy = val_accuracy / len(val_loader)
+
+         # Log misclassified videos
+         if misclassified_videos:
+             misclassified_log_path = os.path.join(config["misclassifications_dir"], f'epoch_{epoch+1}.json')
+             with open(misclassified_log_path, 'w') as f:
+                 json.dump(misclassified_videos, f, indent=2)
+             logger.info(f"Logged {len(misclassified_videos)} misclassified videos to {misclassified_log_path}")
+
+         # Log the metrics
+         logger.info(f"Epoch [{epoch+1}/{config['num_epochs']}], "
+                     f"Train Loss: {avg_train_loss:.4f}, Train Accuracy: {avg_train_accuracy*100:.2f}%, "
+                     f"Val Loss: {avg_val_loss:.4f}, Val Accuracy: {avg_val_accuracy*100:.2f}%")
+
+         # Write to CSV
+         with open(config["csv_path"], 'a', newline='') as file:
+             writer = csv.writer(file)
+             writer.writerow([epoch+1, avg_train_loss, avg_train_accuracy*100, avg_val_loss, avg_val_accuracy*100])
+
+         # Learning rate scheduling
+         scheduler.step()
+
+         # Save the best model and check for early stopping
+         if avg_val_loss < best_val_loss:
+             best_val_loss = avg_val_loss
+             torch.save(model.state_dict(), config["best_model_path"])
+             logger.info(f"Saved best model to {config['best_model_path']}")
+             epochs_without_improvement = 0
+         else:
+             epochs_without_improvement += 1
+
+         # Early stopping check
+         if epochs_without_improvement >= config["patience"]:
+             logger.info(f"Early stopping triggered after {config['patience']} epochs without improvement")
+             break
+
+         # Overfitting detection (accuracies are fractions; the threshold is in percentage points)
+         if (avg_train_accuracy - avg_val_accuracy) * 100 > config["overfitting_threshold"]:
+             logger.warning("Possible overfitting detected")
+
+     logger.info("Training finished!")
+
+     # Save the final model
+     torch.save(model.state_dict(), config["final_model_path"])
+     logger.info(f"Saved final model to {config['final_model_path']}")
+
+     # Save run information
+     with open(os.path.join(config["run_dir"], 'run_info.txt'), 'w') as f:
+         for key, value in config.items():
+             f.write(f"{key}: {value}\n")
+         f.write(f"Device: {device}\n")
+         f.write(f"Model: {model.__class__.__name__}\n")
+         f.write(f"Optimizer: {optimizer.__class__.__name__}\n")
+         f.write(f"Scheduler: {scheduler.__class__.__name__}\n")
+         f.write(f"Loss function: CrossEntropyLoss\n")
+         f.write(f"Data augmentation: RandomHorizontalFlip(p={config['flip_probability']}), "
+                 f"RandomRotation({config['rotation_degrees']}), ColorJitter\n")
+         f.write(f"Mixed precision training: {'Enabled' if 'scaler' in locals() else 'Disabled'}\n")
+         f.write(f"Train dataset size: {len(train_dataset)}\n")
+         f.write(f"Validation dataset size: {len(val_dataset)}\n")
+         f.write(f"Vision encoder frozen: {'Partially' if hasattr(model, 'unfreeze_vision_encoder') else 'Unknown'}\n")
+
+     print("Script finished.")
+
+     return avg_val_accuracy
+
+ def main():
+     # Create run directory
+     run_dir = create_run_directory()
+     class_labels = ["windmill", "halo", "swipe", "baby_mill"][:3]
+
+     # Write configuration
+     config = {
+         "class_labels": class_labels,
+         "num_classes": len(class_labels),
+         "data_path": '../finetune/3moves_otherpeopletrain',
+         "batch_size": 32,
+         "learning_rate": 2e-6,
+         "weight_decay": 0.007,
+         "num_epochs": 1,
+         "patience": 10,  # for early stopping
+         "max_frames": 10,
+         "sigma": 0.3,
+         "image_size": 224,
+         "flip_probability": 0.5,
+         "rotation_degrees": 15,
+         "brightness_jitter": 0.2,
+         "contrast_jitter": 0.2,
+         "saturation_jitter": 0.2,
+         "hue_jitter": 0.1,
+         "crop_scale_min": 0.8,
+         "crop_scale_max": 1.0,
+         "normalization_mean": [0.485, 0.456, 0.406],
+         "normalization_std": [0.229, 0.224, 0.225],
+         "unfreeze_layers": 3,
+         "clip_model": "openai/clip-vit-large-patch14",
+         # "clip_model": "openai/clip-vit-base-patch32",
+         "gradient_clip_max_norm": 1.0,
+         "overfitting_threshold": 10,
+         "run_dir": run_dir,
+         "best_model_path": os.path.join(run_dir, 'best_model.pth'),
+         "final_model_path": os.path.join(run_dir, 'final_model.pth'),
+         "csv_path": os.path.join(run_dir, 'training_log.csv'),
+         "misclassifications_dir": os.path.join(run_dir, 'misclassifications'),
+     }
+     train_and_evaluate(config)
+
+ if __name__ == "__main__":
+     main()
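`train.py` iterates over `VideoDataset` from `src/data/dataset.py`, which is likewise not included in this excerpt of the commit. From the way the loaders unpack batches as `(frames, labels, video_paths)` and the misclassification logger reads `dataset.label_map`, a hypothetical sketch of the expected interface could be (CSV layout, fallback behavior, and any frame padding are assumptions):

```python
# Hypothetical sketch of src/data/dataset.py based on how train.py consumes it;
# the actual CSV format and sampling logic in the repository may differ.
import csv
import torch
from torch.utils.data import Dataset

# Assumed helpers; the project imports these from src.data.video_utils elsewhere.
from src.data.video_utils import create_transform, extract_frames


class VideoDataset(Dataset):
    def __init__(self, csv_path, config, transform=None):
        self.config = config
        self.transform = transform or create_transform(config, training=True)
        # Assumed CSV layout: one row per clip with columns video_path,label
        with open(csv_path, newline='') as f:
            self.samples = [(row["video_path"], row["label"]) for row in csv.DictReader(f)]
        # Map integer ids back to label strings (used by log_misclassifications)
        self.label_map = dict(enumerate(config["class_labels"]))
        self.label_to_id = {label: i for i, label in self.label_map.items()}

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        video_path, label = self.samples[idx]
        frames, success = extract_frames(video_path, self.config, self.transform)
        if not success:
            # Fall back to a zero tensor so one unreadable clip doesn't kill the epoch
            size = self.config["image_size"]
            frames = torch.zeros(self.config["max_frames"], 3, size, size)
        return frames, self.label_to_id[label], video_path
```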
scripts/visualization/miscalculations_report.py ADDED
@@ -0,0 +1,93 @@
+ import os
+ import json
+ from collections import Counter
+ import matplotlib.pyplot as plt
+ from pathlib import Path
+ import sys
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
+ from src.utils.utils import get_latest_run_dir
+
+ def analyze_misclassifications(run_dir=None):
+     if run_dir is None:
+         # run_dir = "/home/bawolf/workspace/break/clip/runs/run_20241022-122939_3moves_balanced"
+         run_dir = get_latest_run_dir()
+
+     misclassifications_dir = os.path.join(run_dir, 'misclassifications')
+     all_misclassifications = {}
+
+     # Collect all misclassifications across epochs
+     for file in os.listdir(misclassifications_dir):
+         if file.endswith('.json'):
+             with open(os.path.join(misclassifications_dir, file), 'r') as f:
+                 epoch_misclassifications = json.load(f)
+             for item in epoch_misclassifications:
+                 video_path = item['video_path']
+                 if video_path not in all_misclassifications:
+                     all_misclassifications[video_path] = []
+                 all_misclassifications[video_path].append(item)
+
+     # Determine the total number of epochs from the files
+     epoch_files = [f for f in os.listdir(misclassifications_dir) if f.startswith('epoch_') and f.endswith('.json')]
+     total_epochs = len(epoch_files)
+
+     # Count misclassifications per video
+     misclassification_counts = {video: len(misclassifications)
+                                 for video, misclassifications in all_misclassifications.items()}
+
+     # Calculate percentage of epochs each video was misclassified
+     misclassification_percentages = {video: (count / total_epochs) * 100
+                                      for video, count in misclassification_counts.items()}
+
+     # Sort videos by misclassification percentage
+     sorted_videos = sorted(misclassification_percentages.items(), key=lambda x: x[1], reverse=True)
+
+     # Prepare report
+     report = "Misclassification Analysis Report\n"
+     report += "=================================\n\n"
+
+     # Top N most misclassified videos
+     N = 20
+     report += f"Top {N} Most Misclassified Videos:\n"
+     for video, percentage in sorted_videos[:N]:
+         report += f"{Path(video).name}: Misclassified in {percentage:.2f}% of epochs ({misclassification_counts[video]} out of {total_epochs})\n"
+         misclassifications = all_misclassifications[video]
+         true_label = misclassifications[0]['true_label']
+         predicted_labels = Counter(m['predicted_label'] for m in misclassifications)
+         report += f"  True Label: {true_label}\n"
+         report += f"  Predicted Labels: {dict(predicted_labels)}\n\n"
+
+     # Overall statistics
+     total_misclassifications = sum(misclassification_counts.values())
+     total_videos = len(misclassification_counts)
+     report += "Overall Statistics:\n"
+     report += f"Total misclassified videos: {total_videos}\n"
+     report += f"Total misclassifications: {total_misclassifications}\n"
+     report += f"Average misclassification percentage per video: {sum(misclassification_percentages.values()) / total_videos:.2f}%\n"
+     report += f"Total epochs: {total_epochs}\n"
+
+     # Save report
+     report_path = os.path.join(run_dir, 'misclassification_report.txt')
+     with open(report_path, 'w') as f:
+         f.write(report)
+
+     # Create visualization
+     plt.figure(figsize=(12, 6))
+     plt.bar(range(len(sorted_videos)), [percentage for _, percentage in sorted_videos])
+     plt.title(f'Videos Ranked by Misclassification Percentage (Total Epochs: {total_epochs})')
+     plt.xlabel('Video Rank')
+     plt.ylabel('Misclassification Percentage')
+     plt.ylim(0, 100)  # Set y-axis limit to 0-100%
+     plt.tight_layout()
+     plt.savefig(os.path.join(run_dir, 'misclassification_distribution.png'))
+
+     print(f"Analysis complete. Report saved to {report_path}")
+     print(f"Visualization saved to {os.path.join(run_dir, 'misclassification_distribution.png')}")
+
+ if __name__ == "__main__":
+     if len(sys.argv) > 2:
+         print("Usage: python miscalculations_report.py [path_to_run_directory]")
+         sys.exit(1)
+
+     run_dir = sys.argv[1] if len(sys.argv) == 2 else None
+     analyze_misclassifications(run_dir)
scripts/visualization/visualize.py ADDED
@@ -0,0 +1,163 @@
+ import torch
+ import matplotlib.pyplot as plt
+ import seaborn as sns
+ from sklearn.metrics import confusion_matrix, precision_recall_curve, average_precision_score, roc_curve, auc
+ from torch.utils.data import DataLoader
+ import pandas as pd
+ import numpy as np
+ import os
+ import sys
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
+
+ from src.data.dataset import VideoDataset
+ from src.utils.utils import get_latest_model_path, get_latest_run_dir, get_config
+ from src.models.model import load_model
+ import json
+
+ def plot_training_curves(log_file, output_dir):
+     data = pd.read_csv(log_file)
+
+     plt.figure(figsize=(12, 5))
+
+     # Plot loss curves
+     plt.subplot(1, 2, 1)
+     plt.plot(data['epoch'], data['train_loss'], label='Train Loss')
+     plt.plot(data['epoch'], data['val_loss'], label='Validation Loss')
+     plt.xlabel('Epochs')
+     plt.ylabel('Loss')
+     plt.title('Training and Validation Loss')
+     plt.legend()
+
+     # Plot accuracy curves
+     plt.subplot(1, 2, 2)
+     plt.plot(data['epoch'], data['train_accuracy'], label='Train Accuracy')
+     plt.plot(data['epoch'], data['val_accuracy'], label='Validation Accuracy')
+     plt.xlabel('Epochs')
+     plt.ylabel('Accuracy')
+     plt.title('Training and Validation Accuracy')
+     plt.legend()
+
+     plt.tight_layout()
+     plt.savefig(os.path.join(output_dir, 'training_curves.png'))
+     plt.close()
+
+ def generate_evaluation_metrics(model, data_loader, device, output_dir, class_labels, data_info):
+     model.eval()
+     all_preds = []
+     all_labels = []
+     all_probs = []
+
+     with torch.no_grad():
+         for frames, labels, _ in data_loader:
+             frames = frames.to(device)
+             labels = labels.to(device)
+
+             outputs = model(frames)
+             probs = torch.softmax(outputs, dim=1)
+             _, predicted = outputs.max(1)
+
+             all_preds.extend(predicted.cpu().numpy())
+             all_labels.extend(labels.cpu().numpy())
+             all_probs.extend(probs.cpu().numpy())
+
+     all_labels = np.array(all_labels)
+     all_preds = np.array(all_preds)
+     all_probs = np.array(all_probs)
+
+     # Compute and plot confusion matrix
+     cm = confusion_matrix(all_labels, all_preds)
+     plt.figure(figsize=(10, 8))
+     sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
+     plt.xlabel('Predicted Value')
+     plt.ylabel('Actual Value')
+     plt.title(f'Confusion Matrix\n{data_info}')
+     plt.savefig(os.path.join(output_dir, 'confusion_matrix.png'))
+     plt.close()
+
+     colors = ['blue', 'red', 'green', 'yellow', 'purple', 'orange', 'pink', 'cyan']
+
+     # Precision-Recall Curve
+     plt.figure(figsize=(10, 8))
+     for i, class_label in enumerate(class_labels):
+         precision, recall, _ = precision_recall_curve(all_labels == i, all_probs[:, i])
+         average_precision = average_precision_score(all_labels == i, all_probs[:, i])
+         plt.plot(recall, precision, color=colors[i], lw=2,
+                  label=f'{class_label} (AP = {average_precision:.2f})')
+
+     plt.xlabel('Recall')
+     plt.ylabel('Precision')
+     plt.title(f'Precision-Recall Curve\n{data_info}')
+     plt.legend(loc="lower left")
+     plt.savefig(f'{output_dir}/precision_recall_curve.png')
+     plt.close()
+
+     # ROC Curve
+     plt.figure(figsize=(10, 8))
+     for i, class_label in enumerate(class_labels):
+         fpr, tpr, _ = roc_curve(all_labels == i, all_probs[:, i])
+         roc_auc = auc(fpr, tpr)
+         plt.plot(fpr, tpr, color=colors[i], lw=2,
+                  label=f'{class_label} (AUC = {roc_auc:.2f})')
+
+     plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
+     plt.xlim([0.0, 1.0])
+     plt.ylim([0.0, 1.05])
+     plt.xlabel('False Positive Rate')
+     plt.ylabel('True Positive Rate')
+     plt.title(f'Receiver Operating Characteristic (ROC) Curve\n{data_info}')
+     plt.legend(loc="lower right")
+     plt.savefig(f'{output_dir}/roc_curve.png')
+     plt.close()
+
+     return cm
+
+ if __name__ == "__main__":
+     # Find the most recent run directory
+     run_dir = get_latest_run_dir()
+     # run_dir = "/home/bawolf/workspace/break/clip/runs/run_20241024-150232_otherpeopleval_large_model"
+     # run_dir = "/home/bawolf/workspace/break/clip/runs/run_20241022-122939_3moves_balanced"
+
+     # Load configuration
+     config = get_config(run_dir)
+
+     class_labels = config['class_labels']
+     num_classes = config['num_classes']
+     data_path = config['data_path']
+     # data_path = '../finetune/3moves_otherpeopleval'
+     # data_path = '../finetune/otherpeople3moves'
+
+     # Paths
+     log_file = os.path.join(run_dir, 'training_log.csv')
+     model_path = get_latest_model_path(run_dir)
+     test_csv = os.path.join(data_path, 'test.csv')
+     # test_csv = os.path.join(data_path, 'val.csv')
+     # test_csv = os.path.join(data_path, 'train.csv')
+
+     # Get the last directory of data_path and the file name
+     last_dir = os.path.basename(os.path.normpath(data_path))
+     file_name = os.path.basename(test_csv)
+
+     # Create a directory for visualization outputs
+     vis_dir = os.path.join(run_dir, f'visualization_{last_dir}_{file_name.split(".")[0]}')
+     os.makedirs(vis_dir, exist_ok=True)
+
+     # Create data_info string for chart headers
+     data_info = f'Data: {last_dir}, File: {file_name}'
+
+     # Plot training curves
+     plot_training_curves(log_file, vis_dir)
+
+     # Load model
+     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+     model = load_model(num_classes, model_path, device, config['clip_model'])
+     model.eval()
+
+     # Create test dataset and dataloader
+     test_dataset = VideoDataset(test_csv, config)
+     test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False)
+
+     # Generate evaluation metrics
+     cm = generate_evaluation_metrics(model, test_loader, device, vis_dir, class_labels, data_info)
+
+     print(f"Visualization complete! Check the output directory: {vis_dir}")
src/models/__init__.py ADDED
File without changes
src/models/model.py ADDED
@@ -0,0 +1,42 @@
+ import torch
+ import torch.nn as nn
+ from transformers import CLIPModel
+
+ class VariableLengthCLIP(nn.Module):
+     def __init__(self, clip_model, num_classes):
+         super().__init__()
+         self.clip_model = clip_model
+         self.visual_projection = nn.Linear(clip_model.visual_projection.in_features, num_classes)
+
+     def forward(self, x):
+         batch_size, num_frames, c, h, w = x.shape
+         x = x.view(batch_size * num_frames, c, h, w)
+         features = self.clip_model.vision_model(x).pooler_output
+         features = features.view(batch_size, num_frames, -1)
+         features = torch.mean(features, dim=1)  # Average over frames (temporal pooling)
+         return self.visual_projection(features)
+
+     def unfreeze_vision_encoder(self, num_layers=2):
+         # Freeze the entire vision encoder
+         for param in self.clip_model.vision_model.parameters():
+             param.requires_grad = False
+         # Unfreeze the last few layers of the vision encoder
+         for param in self.clip_model.vision_model.encoder.layers[-num_layers:].parameters():
+             param.requires_grad = True
+
+ def create_model(num_classes, pretrained_model_name="openai/clip-vit-base-patch32"):
+     clip_model = CLIPModel.from_pretrained(pretrained_model_name)
+     return VariableLengthCLIP(clip_model, num_classes)
+
+ def load_model(num_classes, model_path, device, pretrained_model_name="openai/clip-vit-base-patch32"):
+     # Create the model
+     model = create_model(num_classes, pretrained_model_name)
+
+     # Load the state dict
+     state_dict = torch.load(model_path, map_location=device, weights_only=True)
+
+     # Load the state dict, ignoring mismatched keys
+     model.load_state_dict(state_dict, strict=False)
+
+     model.to(device)  # Move the model to the appropriate device
+     return model
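The forward pass above expects a 5-D batch of shape (batch, frames, channels, height, width) and mean-pools the per-frame CLIP features before the classification head. A quick smoke test of that contract, not part of the original commit (it uses the smaller ViT-B/32 checkpoint to keep the download light):

```python
import torch
from src.models.model import create_model

# Dummy batch: 2 clips, 10 frames each, 3x224x224 (matches the config's max_frames/image_size)
model = create_model(num_classes=3, pretrained_model_name="openai/clip-vit-base-patch32")
model.eval()
dummy = torch.randn(2, 10, 3, 224, 224)

with torch.no_grad():
    logits = model(dummy)              # per-frame features are averaged over the 10 frames
    probs = torch.softmax(logits, dim=1)

print(logits.shape)  # torch.Size([2, 3]) - one logit per class per clip
```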
src/utils/utils.py ADDED
@@ -0,0 +1,82 @@
+ import os
+ import json
+ from datetime import datetime
+
+ def create_run_directory(base_dir='runs', prefix='run', suffix='', parent_dir=None):
+     timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
+     dir_name = f"{prefix}_{timestamp}{suffix}"
+
+     if parent_dir:
+         run_dir = os.path.join(parent_dir, dir_name)
+     else:
+         run_dir = os.path.join(base_dir, dir_name)
+     os.makedirs(run_dir, exist_ok=True)
+     return run_dir
+
+ # Find the most recent run directory
+ def get_latest_run_dir(base_dir='runs', include_hyperparam=True):
+     all_dirs = []
+
+     for d in os.listdir(base_dir):
+         if d.startswith('run_'):
+             full_path = os.path.join(base_dir, d)
+             all_dirs.append(full_path)
+
+             if d.endswith('_hyperparam') and include_hyperparam:
+                 # If it's a hyperparam directory, add its trial subdirectories
+                 trial_dirs = [os.path.join(full_path, td) for td in os.listdir(full_path) if td.startswith('trial_')]
+                 all_dirs.extend(trial_dirs)
+
+     if not all_dirs:
+         raise ValueError(f"No run directories found in {base_dir}")
+
+     # Sort directories by timestamp in the directory name
+     return max(all_dirs, key=get_dir_timestamp)
+
+ def get_run_file(filename, run_dir=None, required=True):
+     """Get a file from a run directory
+
+     Args:
+         filename: Name of file to get (e.g., 'best_model.pth', 'config.json')
+         run_dir: Run directory path (uses latest if None)
+         required: Whether to raise an error if file not found
+
+     Returns:
+         str: Path to the file
+         dict: Loaded JSON data if file ends with .json
+     """
+     if run_dir is None:
+         run_dir = get_latest_run_dir()
+
+     file_path = os.path.join(run_dir, filename)
+
+     if not os.path.exists(file_path):
+         if required:
+             raise FileNotFoundError(f"{filename} not found in {run_dir}")
+         return None
+
+     # Load JSON files automatically
+     if filename.endswith('.json'):
+         with open(file_path, 'r') as f:
+             return json.load(f)
+
+     return file_path
+
+ def get_latest_model_path(run_dir=None):
+     """Get path to best_model.pth"""
+     return get_run_file('best_model.pth', run_dir)
+
+ def get_config(run_dir=None):
+     """Get config from run directory"""
+     return get_run_file('config.json', run_dir)
+
+ # Helper function to parse directory name and get timestamp
+ def get_dir_timestamp(dir_path):
+     dir_name = os.path.basename(dir_path)
+     try:
+         # Extract timestamp from directory name
+         timestamp_str = dir_name.split('_')[1]  # Assumes format is always prefix_timestamp or prefix_timestamp_suffix
+         return datetime.strptime(timestamp_str, "%Y%m%d-%H%M%S")
+     except (IndexError, ValueError):
+         # If parsing fails, return the earliest possible date
+         return datetime.min
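A short usage sketch tying these helpers together; the directory name shown in the comment is illustrative only, and it assumes a completed run has already written `config.json` and `best_model.pth`:

```python
# Illustrative use of the run-directory helpers above; paths are examples, not real runs.
from src.utils.utils import create_run_directory, get_latest_run_dir, get_config, get_latest_model_path

run_dir = create_run_directory(suffix='_demo')   # e.g. runs/run_20241024-150232_demo
print("Created:", run_dir)

latest = get_latest_run_dir()                    # newest runs/run_* (or trial_* inside a *_hyperparam run)
config = get_config(latest)                      # config.json parsed into a dict
model_path = get_latest_model_path(latest)       # path to best_model.pth in that run
print(latest, model_path, config["class_labels"])
```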