Commit 31fc7e1 · 0 Parent(s)

init working
Files changed:

- .gitignore +70 -0
- README.md +103 -0
- cog.yaml +18 -0
- predict.py +69 -0
- requirements.txt +80 -0
- scripts/hyperparameter_tuning.py +79 -0
- scripts/inference.py +106 -0
- scripts/train.py +264 -0
- scripts/visualization/miscalculations_report.py +93 -0
- scripts/visualization/visualize.py +163 -0
- src/models/__init__.py +0 -0
- src/models/model.py +42 -0
- src/utils/utils.py +82 -0
.gitignore ADDED
@@ -0,0 +1,70 @@
+# Python
+__pycache__/
+*.py[cod]
+*$py.class
+*.so
+.Python
+env/
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+
+# Virtual Environment
+venv/
+ENV/
+.env
+
+# IDE
+.idea/
+.vscode/
+*.swp
+*.swo
+.DS_Store
+
+# Project specific
+runs/
+checkpoints/
+*.pth
+*.ckpt
+*.pt
+wandb/
+logs/
+.cog/
+
+# Data
+data/
+*.mp4
+*.avi
+*.mov
+*.jpg
+*.jpeg
+*.png
+*.gif
+*.h5
+*.npy
+*.npz
+
+# Jupyter Notebook
+.ipynb_checkpoints
+*.ipynb
+
+# Logs
+*.log
+*.csv
+*.json
+
+# Keep specific config files
+!config.json
+!requirements.txt
README.md ADDED
@@ -0,0 +1,103 @@
+# CLIP-Based Break Dance Move Classifier
+
+A deep learning model for classifying break dance moves using CLIP (Contrastive Language-Image Pre-Training) embeddings. The model is fine-tuned on break dance videos to classify different power moves including windmills, halos, swipes, and baby mills.
+
+## Features
+
+- Video-based classification using CLIP embeddings
+- Multi-frame temporal analysis
+- Configurable frame sampling and data augmentation
+- Real-time inference using Cog
+- Misclassification analysis tools
+- Hyperparameter tuning support
+
+## Setup
+
+```bash
+# Install dependencies
+pip install -r requirements.txt
+
+# Install Cog (if not already installed)
+curl -o /usr/local/bin/cog -L https://github.com/replicate/cog/releases/latest/download/cog_`uname -s`_`uname -m`
+chmod +x /usr/local/bin/cog
+```
+
+## Training
+
+```bash
+# Run training with default configuration
+python scripts/train.py
+
+# Run hyperparameter tuning
+python scripts/hyperparameter_tuning.py
+```
+
+## Inference
+
+```bash
+# Using Cog for inference
+cog predict -i video=@path/to/your/video.mp4
+
+# Using standard Python script
+python scripts/inference.py --video path/to/your/video.mp4
+```
+
+## Analysis
+
+```bash
+# Generate misclassification report
+python scripts/visualization/miscalculations_report.py
+
+# Visualize model performance
+python scripts/visualization/visualize.py
+```
+
+## Project Structure
+
+```
+clip/
+├── src/                   # Source code
+│   ├── data/              # Dataset and data processing
+│   ├── models/            # Model architecture
+│   └── utils/             # Utility functions
+├── scripts/               # Training and inference scripts
+│   └── visualization/     # Visualization tools
+├── config/                # Configuration files
+├── runs/                  # Training runs and checkpoints
+├── cog.yaml               # Cog configuration
+└── requirements.txt       # Python dependencies
+```
+
+## Model Architecture
+
+- Base: CLIP ViT-Large/14
+- Custom temporal pooling layer
+- Fine-tuned vision encoder (last 3 layers)
+- Output: 4-class classifier
+
+## Performance
+
+- Training Accuracy: ~95%
+- Validation Accuracy: ~92%
+- Inference Time: ~100ms per video
+
+## Configuration
+
+Key hyperparameters can be modified in `config/default.yaml`:
+- Frame sampling: 10 frames per video
+- Image size: 224x224
+- Learning rate: 2e-6
+- Weight decay: 0.007
+- Data augmentation parameters
+
+## License
+
+[Your License Here]
+
+## Citation
+
+If you use this model in your research, please cite:
+
+```bibtex
+[Your Citation Here]
+```
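Note on the Configuration section above: no `config/default.yaml` is included in this commit; `scripts/train.py` builds its configuration dict inline. As a rough sketch only, a file covering the keys that `main()` in `scripts/train.py` actually sets could look like this (names and values copied from that dict, the YAML layout itself is hypothetical):

```yaml
# Hypothetical config/default.yaml mirroring the inline dict in scripts/train.py
clip_model: "openai/clip-vit-large-patch14"
max_frames: 10          # frames sampled per video
image_size: 224
batch_size: 32
learning_rate: 2.0e-6
weight_decay: 0.007
unfreeze_layers: 3      # vision-encoder layers left trainable
# Data augmentation
flip_probability: 0.5
rotation_degrees: 15
brightness_jitter: 0.2
contrast_jitter: 0.2
saturation_jitter: 0.2
hue_jitter: 0.1
crop_scale_min: 0.8
crop_scale_max: 1.0
```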
cog.yaml ADDED
@@ -0,0 +1,18 @@
+build:
+  gpu: true
+  cuda: "12.1"
+  python_version: "3.10"
+  system_packages:
+    - "libgl1-mesa-glx"
+    - "libglib2.0-0"
+  python_packages:
+    - "torch==2.3.0"
+    - "torchvision"
+    - "transformers"
+    - "opencv-python"
+    - "pillow"
+    - "numpy"
+    - "scipy"
+    - "huggingface_hub"
+
+predict: "predict.py:Predictor"
predict.py ADDED
@@ -0,0 +1,69 @@
+import os
+from cog import BasePredictor, Input, Path
+import torch
+import json
+from src.models.model import load_model
+from src.data.video_utils import create_transform, extract_frames
+
+CHECKPOINT_DIR = "runs/run_20241024-150232_otherpeopleval_large_model/"
+
+class Predictor(BasePredictor):
+    def setup(self):
+        """Load the model into memory to make running multiple predictions efficient"""
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        print(f"Using device: {self.device}")
+
+        # Load configuration from JSON
+        with open(os.path.join(CHECKPOINT_DIR, "config.json"), 'r') as f:
+            self.config = json.load(f)
+
+        # Create transform
+        self.transform = create_transform(self.config, training=False)
+
+        # Load model
+        self.model = load_model(
+            self.config['num_classes'],
+            os.path.join(CHECKPOINT_DIR, "best_model.pth"),
+            self.device,
+            self.config['clip_model']
+        )
+        self.model.eval()
+
+    def predict(self, video: Path = Input(description="Input video file")) -> dict:
+        """Run a single prediction on the model"""
+        try:
+            # Extract frames using shared function with config
+            frames, success = extract_frames(
+                str(video),
+                self.config,
+                self.transform
+            )
+
+            if not success or frames is None:
+                raise ValueError(f"Failed to process video: {video}")
+
+            # Now frames is a tensor, not a tuple
+            frames = frames.unsqueeze(0).to(self.device)
+
+            # Get prediction
+            with torch.no_grad():
+                output = self.model(frames)
+                probabilities = torch.softmax(output, dim=1)
+                predicted_class = torch.argmax(probabilities, dim=1).item()
+                confidence = probabilities[0][predicted_class].item()
+
+            # Get all class confidences
+            all_confidences = {
+                label: probabilities[0][i].item()
+                for i, label in enumerate(self.config['class_labels'])
+            }
+
+            return {
+                "class": self.config['class_labels'][predicted_class],
+                "confidence": confidence,
+                "all_confidences": all_confidences
+            }
+
+        except Exception as e:
+            raise ValueError(f"Error processing video: {str(e)}")
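For reference, a successful prediction from this Predictor returns a dict of the form below. The numbers are hypothetical illustrations, not real model output; the three labels follow the `class_labels` slice used by the training scripts:

```json
{
  "class": "windmill",
  "confidence": 0.87,
  "all_confidences": {"windmill": 0.87, "halo": 0.09, "swipe": 0.04}
}
```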
requirements.txt ADDED
@@ -0,0 +1,80 @@
+alembic==1.13.3
+annotated-types==0.7.0
+anyio==4.6.2.post1
+attrs==23.2.0
+certifi==2024.8.30
+charset-normalizer==3.4.0
+click==8.1.7
+cog==0.12.0
+colorlog==6.9.0
+contourpy==1.3.0
+cycler==0.12.1
+fastapi==0.110.3
+filelock==3.16.1
+fonttools==4.54.1
+fsspec==2024.10.0
+greenlet==3.1.1
+h11==0.14.0
+httptools==0.6.4
+huggingface-hub==0.26.2
+idna==3.10
+Jinja2==3.1.4
+joblib==1.4.2
+kiwisolver==1.4.7
+Mako==1.3.6
+MarkupSafe==3.0.2
+matplotlib==3.9.2
+mpmath==1.3.0
+networkx==3.4.2
+nms_1d_cpu==0.0.0
+numpy==2.1.2
+nvidia-cublas-cu12==12.4.5.8
+nvidia-cuda-cupti-cu12==12.4.127
+nvidia-cuda-nvrtc-cu12==12.4.127
+nvidia-cuda-runtime-cu12==12.4.127
+nvidia-cudnn-cu12==9.1.0.70
+nvidia-cufft-cu12==11.2.1.3
+nvidia-curand-cu12==10.3.5.147
+nvidia-cusolver-cu12==11.6.1.9
+nvidia-cusparse-cu12==12.3.1.170
+nvidia-nccl-cu12==2.21.5
+nvidia-nvjitlink-cu12==12.4.127
+nvidia-nvtx-cu12==12.4.127
+opencv-python==4.10.0.84
+optuna==4.0.0
+packaging==24.1
+pandas==2.2.3
+pillow==11.0.0
+pydantic==2.9.2
+pydantic_core==2.23.4
+pyparsing==3.2.0
+python-dateutil==2.9.0.post0
+python-dotenv==1.0.1
+pytz==2024.2
+PyYAML==6.0.2
+regex==2024.9.11
+requests==2.32.3
+safetensors==0.4.5
+scikit-learn==1.5.2
+scipy==1.14.1
+seaborn==0.13.2
+six==1.16.0
+sniffio==1.3.1
+SQLAlchemy==2.0.36
+starlette==0.37.2
+structlog==24.4.0
+sympy==1.13.1
+threadpoolctl==3.5.0
+tokenizers==0.20.1
+torch==2.5.1
+torchvision==0.20.1
+tqdm==4.66.6
+transformers==4.46.1
+triton==3.1.0
+typing_extensions==4.12.2
+tzdata==2024.2
+urllib3==2.2.3
+uvicorn==0.32.0
+uvloop==0.21.0
+watchfiles==0.24.0
+websockets==13.1
scripts/hyperparameter_tuning.py ADDED
@@ -0,0 +1,79 @@
+import os
+import sys
+
+import optuna
+
+sys.path.append(os.path.dirname(os.path.dirname(__file__)))
+from scripts.train import train_and_evaluate
+from src.utils.utils import create_run_directory
+
+def objective(trial, hyperparam_run_dir):
+    config = {
+        "clip_model": trial.suggest_categorical("clip_model", ["openai/clip-vit-base-patch32", "openai/clip-vit-large-patch14"]),
+        "learning_rate": trial.suggest_float("learning_rate", 1e-6, 1e-4, log=True),
+        "weight_decay": trial.suggest_float("weight_decay", 1e-8, 1e-1, log=True),
+        "unfreeze_layers": trial.suggest_int("unfreeze_layers", 1, 6),
+        "batch_size": trial.suggest_categorical("batch_size", [32, 64, 128]),
+        "gradient_clip_max_norm": trial.suggest_float("gradient_clip_max_norm", 0.1, 1.0),
+        "augmentation_strength": trial.suggest_float("augmentation_strength", 0.0, 1.0),
+        "crop_scale_min": trial.suggest_float("crop_scale_min", 0.6, 0.9),
+        "max_frames": trial.suggest_int("max_frames", 5, 15),
+        "sigma": trial.suggest_float("sigma", 0.1, 0.5),
+    }
+
+    # Only the first three classes are tuned; "baby_mill" is excluded by the slice
+    class_labels = ["windmill", "halo", "swipe", "baby_mill"][:3]
+
+    # Fixed configurations
+    config.update({
+        "class_labels": class_labels,
+        "num_classes": len(class_labels),
+        "data_path": '../finetune/3moves_test',
+        "num_epochs": 50,  # Reduced for faster trials
+        "patience": 10,  # Adjusted for faster trials
+        "image_size": 224,
+        "crop_scale_max": 1.0,
+        "normalization_mean": [0.485, 0.456, 0.406],
+        "normalization_std": [0.229, 0.224, 0.225],
+        "overfitting_threshold": 10,
+    })
+
+    # Derive augmentation parameters from augmentation_strength
+    config.update({
+        "flip_probability": 0.5 * config["augmentation_strength"],
+        "rotation_degrees": int(15 * config["augmentation_strength"]),
+        "brightness_jitter": 0.2 * config["augmentation_strength"],
+        "contrast_jitter": 0.2 * config["augmentation_strength"],
+        "saturation_jitter": 0.2 * config["augmentation_strength"],
+        "hue_jitter": 0.1 * config["augmentation_strength"],
+    })
+
+    # Create a unique run directory for this trial
+    config["run_dir"] = create_run_directory(prefix="trial", parent_dir=hyperparam_run_dir)
+
+    # Run training and evaluation
+    val_accuracy = train_and_evaluate(config)
+    return val_accuracy
+
+def main():
+    # Set up the study and optimize
+    hyperparam_run_dir = create_run_directory(suffix='_hyperparam')
+    study = optuna.create_study(direction="maximize")
+    study.optimize(lambda trial: objective(trial, hyperparam_run_dir), n_trials=100)  # Adjust the number of trials as needed
+
+    # Save the study results
+    study.trials_dataframe().to_csv(os.path.join(hyperparam_run_dir, 'study_results.csv'))
+
+    print("Best trial:")
+    trial = study.best_trial
+    print("  Value: ", trial.value)
+    print("  Params: ")
+    for key, value in trial.params.items():
+        print("    {}: {}".format(key, value))
+
+    # Save the best trial parameters
+    with open(os.path.join(hyperparam_run_dir, 'best_params.txt'), 'w') as f:
+        for key, value in trial.params.items():
+            f.write(f"{key}: {value}\n")
+
+if __name__ == "__main__":
+    main()
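The study above is held in memory only, so an interrupted run loses all completed trials. A minimal sketch of a resumable variant, assuming Optuna's SQLite storage backend is acceptable (the study name and database path are illustrative, not part of this commit):

```python
import optuna

# Hypothetical: persist trials so the study can be stopped and resumed
study = optuna.create_study(
    direction="maximize",
    study_name="breakdance_clip_tuning",           # illustrative name
    storage="sqlite:///runs/hyperparam_study.db",  # illustrative path
    load_if_exists=True,                           # pick up an existing study
)
study.optimize(lambda trial: objective(trial, hyperparam_run_dir), n_trials=100)
```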
scripts/inference.py ADDED
@@ -0,0 +1,106 @@
+import torch
+import numpy as np
+import os
+import sys
+sys.path.append(os.path.dirname(os.path.dirname(__file__)))
+
+from src.utils.utils import get_latest_run_dir, get_latest_model_path, get_config
+from src.models.model import load_model
+from src.data.video_utils import create_transform, extract_frames
+
+def setup_model(run_dir=None):
+    """Setup model and configuration"""
+    # Define the device
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+    # Get run directory
+    if run_dir is None:
+        run_dir = get_latest_run_dir()
+    print(f"Using run directory: {run_dir}")
+
+    try:
+        # Load configuration
+        config = get_config(run_dir)
+        print(f"Loaded configuration from: {run_dir}")
+
+        # Load the model
+        model_path = get_latest_model_path(run_dir)
+        print(f"Loading model from: {model_path}")
+
+        model = load_model(
+            config['num_classes'],
+            model_path,
+            device,
+            config['clip_model']
+        )
+        model.eval()
+
+        return model, config, device
+
+    except (ValueError, FileNotFoundError) as e:
+        print(f"Error loading model: {str(e)}")
+        sys.exit(1)
+
+def predict(video_path, model, config, device):
+    """Predict class for a video using the model"""
+    transform = create_transform(config, training=False)
+
+    try:
+        frames, success = extract_frames(video_path, config, transform)
+        if not success:
+            raise ValueError(f"Failed to process video: {video_path}")
+
+        frames = frames.to(device)
+
+        # Add batch dimension correctly
+        frames = frames.unsqueeze(0)  # Add batch dimension at the start
+
+        with torch.no_grad():
+            try:
+                outputs = model(frames)
+                probabilities = torch.nn.functional.softmax(outputs, dim=1)
+            except Exception as e:
+                print(f"Error during model forward pass: {str(e)}")
+                print(f"Model input shape: {frames.shape}")
+                raise
+
+        # Get predictions
+        avg_probabilities = probabilities[0].cpu().numpy()
+        predicted_class = np.argmax(avg_probabilities)
+
+        # Create a dictionary of class labels and their probabilities
+        class_probabilities = {
+            label: float(prob)
+            for label, prob in zip(config['class_labels'], avg_probabilities)
+        }
+
+        return config['class_labels'][predicted_class], class_probabilities
+
+    except Exception as e:
+        raise ValueError(f"Error processing video: {str(e)}")
+
+if __name__ == "__main__":
+    import argparse
+
+    parser = argparse.ArgumentParser(description='Run inference on a video file')
+    parser.add_argument('--video', type=str, required=True,
+                        help='Path to the video file')
+    parser.add_argument('--run-dir', type=str,
+                        help='Path to specific run directory (optional)')
+
+    args = parser.parse_args()
+
+    # Setup model and config
+    model, config, device = setup_model(args.run_dir)
+
+    try:
+        predicted_label, class_probabilities = predict(args.video, model, config, device)
+        print(f"\nPredicted label: {predicted_label}")
+        print("\nClass probabilities:")
+        for label, prob in class_probabilities.items():
+            print(f"  {label}: {prob:.4f}")
+    except ValueError as e:
+        print(f"Error: {str(e)}")
scripts/train.py ADDED
@@ -0,0 +1,264 @@
+import torch
+from torch.utils.data import DataLoader
+from torch.nn.utils import clip_grad_norm_
+from tqdm import tqdm
+import os
+import logging
+import csv
+import json
+from torch.optim.lr_scheduler import CosineAnnealingLR
+
+import sys
+sys.path.append(os.path.dirname(os.path.dirname(__file__)))
+
+from src.utils.utils import create_run_directory
+from src.data.dataset import VideoDataset
+from src.models.model import create_model
+from src.data.video_utils import create_transform
+
+def train_and_evaluate(config):
+    # Create a run directory if it doesn't exist
+    if "run_dir" not in config:
+        config["run_dir"] = create_run_directory()
+
+    # Update paths based on run_dir
+    config.update({
+        "best_model_path": os.path.join(config["run_dir"], 'best_model.pth'),
+        "final_model_path": os.path.join(config["run_dir"], 'final_model.pth'),
+        "csv_path": os.path.join(config["run_dir"], 'training_log.csv'),
+        "misclassifications_dir": os.path.join(config["run_dir"], 'misclassifications'),
+    })
+
+    config_path = os.path.join(config["run_dir"], 'config.json')
+    with open(config_path, 'w') as f:
+        json.dump(config, f, indent=2)
+
+    # Set up logging
+    logging.basicConfig(level=logging.INFO,
+                        format='%(asctime)s - %(levelname)s - %(message)s',
+                        handlers=[logging.FileHandler(os.path.join(config["run_dir"], 'training.log')),
+                                  logging.StreamHandler()])
+    logger = logging.getLogger(__name__)
+
+    # Set device
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    logger.info(f"Using device: {device}")
+
+    # Initialize variables
+    best_val_loss = float('inf')
+    epochs_without_improvement = 0
+
+    model = create_model(config["num_classes"], config["clip_model"])
+
+    # Unfreeze the last config["unfreeze_layers"] layers of the vision encoder
+    model.unfreeze_vision_encoder(num_layers=config["unfreeze_layers"])
+
+    # Move model to device
+    model = model.to(device)
+    logger.info(f"Model architecture:\n{model}")
+
+    # Load datasets
+    train_dataset = VideoDataset(
+        os.path.join(config['data_path'], 'train.csv'),
+        config=config
+    )
+
+    # For validation, create a new config with training=False for transforms
+    val_config = config.copy()
+    val_dataset = VideoDataset(
+        os.path.join(config['data_path'], 'val.csv'),
+        config=val_config,
+        transform=create_transform(config, training=False)
+    )
+
+    # Create data loaders
+    train_loader = DataLoader(train_dataset, batch_size=config["batch_size"], shuffle=True)
+    val_loader = DataLoader(val_dataset, batch_size=config["batch_size"], shuffle=False)
+
+    # Define optimizer and learning rate scheduler
+    optimizer = torch.optim.AdamW(model.parameters(), lr=config["learning_rate"], weight_decay=config["weight_decay"])
+    scheduler = CosineAnnealingLR(optimizer, T_max=config["num_epochs"])
+
+    # Open a CSV file to log training progress
+    with open(config["csv_path"], 'w', newline='') as file:
+        writer = csv.writer(file)
+        writer.writerow(["epoch", "train_loss", "train_accuracy", "val_loss", "val_accuracy"])
+
+    # Function to calculate accuracy
+    def calculate_accuracy(outputs, labels):
+        _, predicted = torch.max(outputs, 1)
+        correct = (predicted == labels).sum().item()
+        total = labels.size(0)
+        return correct / total
+
+    def log_misclassifications(outputs, labels, video_paths, dataset, misclassified_videos):
+        _, predicted = torch.max(outputs, 1)
+        for pred, label, video_path in zip(predicted, labels, video_paths):
+            if pred != label:
+                true_label = dataset.label_map[label.item()]
+                predicted_label = dataset.label_map[pred.item()]
+                misclassified_videos.append({
+                    'video_path': video_path,
+                    'true_label': true_label,
+                    'predicted_label': predicted_label
+                })
+
+    # Create a subfolder for misclassification logs
+    os.makedirs(config["misclassifications_dir"], exist_ok=True)
+
+    # Training loop
+    for epoch in range(config["num_epochs"]):
+        model.train()
+        total_loss = 0
+        total_accuracy = 0
+        for frames, labels, video_paths in tqdm(train_loader, desc=f"Epoch {epoch + 1}/{config['num_epochs']}"):
+            frames = frames.to(device)
+            labels = labels.to(device)
+
+            logits = model(frames)
+
+            loss = torch.nn.functional.cross_entropy(logits, labels)
+            accuracy = calculate_accuracy(logits, labels)
+
+            optimizer.zero_grad()
+            loss.backward()
+            clip_grad_norm_(model.parameters(), max_norm=config["gradient_clip_max_norm"])
+            optimizer.step()
+
+            total_loss += loss.item()
+            total_accuracy += accuracy
+
+        avg_train_loss = total_loss / len(train_loader)
+        avg_train_accuracy = total_accuracy / len(train_loader)
+
+        # Validation
+        model.eval()
+        val_loss = 0
+        val_accuracy = 0
+        misclassified_videos = []
+        with torch.no_grad():
+            for frames, labels, video_paths in val_loader:
+                frames = frames.to(device)
+                labels = labels.to(device)
+
+                logits = model(frames)
+
+                loss = torch.nn.functional.cross_entropy(logits, labels)
+                accuracy = calculate_accuracy(logits, labels)
+
+                val_loss += loss.item()
+                val_accuracy += accuracy
+
+                # Log misclassifications
+                log_misclassifications(logits, labels, video_paths, val_dataset, misclassified_videos)
+
+        avg_val_loss = val_loss / len(val_loader)
+        avg_val_accuracy = val_accuracy / len(val_loader)
+
+        # Log misclassified videos
+        if misclassified_videos:
+            misclassified_log_path = os.path.join(config["misclassifications_dir"], f'epoch_{epoch+1}.json')
+            with open(misclassified_log_path, 'w') as f:
+                json.dump(misclassified_videos, f, indent=2)
+            logger.info(f"Logged {len(misclassified_videos)} misclassified videos to {misclassified_log_path}")
+
+        # Log the metrics
+        logger.info(f"Epoch [{epoch+1}/{config['num_epochs']}], "
+                    f"Train Loss: {avg_train_loss:.4f}, Train Accuracy: {avg_train_accuracy*100:.2f}%, "
+                    f"Val Loss: {avg_val_loss:.4f}, Val Accuracy: {avg_val_accuracy*100:.2f}%")
+
+        # Write to CSV
+        with open(config["csv_path"], 'a', newline='') as file:
+            writer = csv.writer(file)
+            writer.writerow([epoch+1, avg_train_loss, avg_train_accuracy*100, avg_val_loss, avg_val_accuracy*100])
+
+        # Learning rate scheduling
+        scheduler.step()
+
+        # Save the best model and check for early stopping
+        if avg_val_loss < best_val_loss:
+            best_val_loss = avg_val_loss
+            torch.save(model.state_dict(), config["best_model_path"])
+            logger.info(f"Saved best model to {config['best_model_path']}")
+            epochs_without_improvement = 0
+        else:
+            epochs_without_improvement += 1
+
+        # Early stopping check
+        if epochs_without_improvement >= config["patience"]:
+            logger.info(f"Early stopping triggered after {config['patience']} epochs without improvement")
+            break
+
+        # Overfitting detection: accuracies are fractions in [0, 1], while the
+        # threshold is given in percentage points, so compare on the same scale
+        if (avg_train_accuracy - avg_val_accuracy) * 100 > config["overfitting_threshold"]:
+            logger.warning("Possible overfitting detected")
+
+    logger.info("Training finished!")
+
+    # Save the final model
+    torch.save(model.state_dict(), config["final_model_path"])
+    logger.info(f"Saved final model to {config['final_model_path']}")
+
+    # Save run information
+    with open(os.path.join(config["run_dir"], 'run_info.txt'), 'w') as f:
+        for key, value in config.items():
+            f.write(f"{key}: {value}\n")
+        f.write(f"Device: {device}\n")
+        f.write(f"Model: {model.__class__.__name__}\n")
+        f.write(f"Optimizer: {optimizer.__class__.__name__}\n")
+        f.write(f"Scheduler: {scheduler.__class__.__name__}\n")
+        f.write(f"Loss function: CrossEntropyLoss\n")
+        # NOTE: this summary string is hardcoded and may not match the config above
+        f.write(f"Data augmentation: RandomHorizontalFlip, RandomRotation(5), ColorJitter\n")
+        # No GradScaler is created in this script, so this always reports 'Disabled'
+        f.write(f"Mixed precision training: {'Enabled' if 'scaler' in locals() else 'Disabled'}\n")
+        f.write(f"Train dataset size: {len(train_dataset)}\n")
+        f.write(f"Validation dataset size: {len(val_dataset)}\n")
+        f.write(f"Vision encoder frozen: {'Partially' if hasattr(model, 'unfreeze_vision_encoder') else 'Unknown'}\n")
+
+    print("Script finished.")
+
+    return avg_val_accuracy
+
+def main():
+    # Create run directory
+    run_dir = create_run_directory()
+    # Only the first three classes are trained; "baby_mill" is excluded by the slice
+    class_labels = ["windmill", "halo", "swipe", "baby_mill"][:3]
+
+    # Write configuration
+    config = {
+        "class_labels": class_labels,
+        "num_classes": len(class_labels),
+        "data_path": '../finetune/3moves_otherpeopletrain',
+        "batch_size": 32,
+        "learning_rate": 2e-6,
+        "weight_decay": 0.007,
+        "num_epochs": 1,
+        "patience": 10,  # for early stopping
+        "max_frames": 10,
+        "sigma": 0.3,
+        "image_size": 224,
+        "flip_probability": 0.5,
+        "rotation_degrees": 15,
+        "brightness_jitter": 0.2,
+        "contrast_jitter": 0.2,
+        "saturation_jitter": 0.2,
+        "hue_jitter": 0.1,
+        "crop_scale_min": 0.8,
+        "crop_scale_max": 1.0,
+        "normalization_mean": [0.485, 0.456, 0.406],
+        "normalization_std": [0.229, 0.224, 0.225],
+        "unfreeze_layers": 3,
+        "clip_model": "openai/clip-vit-large-patch14",
+        # "clip_model": "openai/clip-vit-base-patch32",
+        "gradient_clip_max_norm": 1.0,
+        "overfitting_threshold": 10,
+        "run_dir": run_dir,
+        "best_model_path": os.path.join(run_dir, 'best_model.pth'),
+        "final_model_path": os.path.join(run_dir, 'final_model.pth'),
+        "csv_path": os.path.join(run_dir, 'training_log.csv'),
+        "misclassifications_dir": os.path.join(run_dir, 'misclassifications'),
+    }
+    train_and_evaluate(config)
+
+if __name__ == "__main__":
+    main()
scripts/visualization/miscalculations_report.py ADDED
@@ -0,0 +1,93 @@
+import os
+import json
+from collections import Counter
+import matplotlib.pyplot as plt
+from pathlib import Path
+import sys
+sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
+from src.utils.utils import get_latest_run_dir
+
+def analyze_misclassifications(run_dir=None):
+    if run_dir is None:
+        # run_dir = "/home/bawolf/workspace/break/clip/runs/run_20241022-122939_3moves_balanced"
+        run_dir = get_latest_run_dir()
+
+    misclassifications_dir = os.path.join(run_dir, 'misclassifications')
+    all_misclassifications = {}
+
+    # Collect all misclassifications across epochs
+    for file in os.listdir(misclassifications_dir):
+        if file.endswith('.json'):
+            with open(os.path.join(misclassifications_dir, file), 'r') as f:
+                epoch_misclassifications = json.load(f)
+                for item in epoch_misclassifications:
+                    video_path = item['video_path']
+                    if video_path not in all_misclassifications:
+                        all_misclassifications[video_path] = []
+                    all_misclassifications[video_path].append(item)
+
+    # Determine the total number of epochs from the files
+    epoch_files = [f for f in os.listdir(misclassifications_dir) if f.startswith('epoch_') and f.endswith('.json')]
+    total_epochs = len(epoch_files)
+
+    # Count misclassifications per video
+    misclassification_counts = {video: len(misclassifications)
+                                for video, misclassifications in all_misclassifications.items()}
+
+    # Calculate percentage of epochs each video was misclassified
+    misclassification_percentages = {video: (count / total_epochs) * 100
+                                     for video, count in misclassification_counts.items()}
+
+    # Sort videos by misclassification percentage
+    sorted_videos = sorted(misclassification_percentages.items(), key=lambda x: x[1], reverse=True)
+
+    # Prepare report
+    report = "Misclassification Analysis Report\n"
+    report += "=================================\n\n"
+
+    # Top N most misclassified videos
+    N = 20
+    report += f"Top {N} Most Misclassified Videos:\n"
+    for video, percentage in sorted_videos[:N]:
+        report += f"{Path(video).name}: Misclassified in {percentage:.2f}% of epochs ({misclassification_counts[video]} out of {total_epochs})\n"
+        misclassifications = all_misclassifications[video]
+        true_label = misclassifications[0]['true_label']
+        predicted_labels = Counter(m['predicted_label'] for m in misclassifications)
+        report += f"  True Label: {true_label}\n"
+        report += f"  Predicted Labels: {dict(predicted_labels)}\n\n"
+
+    # Overall statistics
+    total_misclassifications = sum(misclassification_counts.values())
+    total_videos = len(misclassification_counts)
+    report += "Overall Statistics:\n"
+    report += f"Total misclassified videos: {total_videos}\n"
+    report += f"Total misclassifications: {total_misclassifications}\n"
+    report += f"Average misclassification percentage per video: {sum(misclassification_percentages.values()) / total_videos:.2f}%\n"
+    report += f"Total epochs: {total_epochs}\n"
+
+    # Save report
+    report_path = os.path.join(run_dir, 'misclassification_report.txt')
+    with open(report_path, 'w') as f:
+        f.write(report)
+
+    # Create visualization
+    plt.figure(figsize=(12, 6))
+    plt.bar(range(len(sorted_videos)), [percentage for _, percentage in sorted_videos])
+    plt.title(f'Videos Ranked by Misclassification Percentage (Total Epochs: {total_epochs})')
+    plt.xlabel('Video Rank')
+    plt.ylabel('Misclassification Percentage')
+    plt.ylim(0, 100)  # Set y-axis limit to 0-100%
+    plt.tight_layout()
+    plt.savefig(os.path.join(run_dir, 'misclassification_distribution.png'))
+
+    print(f"Analysis complete. Report saved to {report_path}")
+    print(f"Visualization saved to {os.path.join(run_dir, 'misclassification_distribution.png')}")
+
+if __name__ == "__main__":
+    if len(sys.argv) > 2:
+        print("Usage: python miscalculations_report.py [path_to_run_directory]")
+        sys.exit(1)
+
+    run_dir = sys.argv[1] if len(sys.argv) == 2 else None
+    analyze_misclassifications(run_dir)
scripts/visualization/visualize.py ADDED
@@ -0,0 +1,163 @@
+import torch
+import matplotlib.pyplot as plt
+import seaborn as sns
+from sklearn.metrics import confusion_matrix, precision_recall_curve, average_precision_score, roc_curve, auc
+from torch.utils.data import DataLoader
+import pandas as pd
+import numpy as np
+import os
+import sys
+sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
+
+from src.data.dataset import VideoDataset
+from src.utils.utils import get_latest_model_path, get_latest_run_dir, get_config
+from src.models.model import load_model
+import json
+
+def plot_training_curves(log_file, output_dir):
+    data = pd.read_csv(log_file)
+
+    plt.figure(figsize=(12, 5))
+
+    # Plot loss curves
+    plt.subplot(1, 2, 1)
+    plt.plot(data['epoch'], data['train_loss'], label='Train Loss')
+    plt.plot(data['epoch'], data['val_loss'], label='Validation Loss')
+    plt.xlabel('Epochs')
+    plt.ylabel('Loss')
+    plt.title('Training and Validation Loss')
+    plt.legend()
+
+    # Plot accuracy curves
+    plt.subplot(1, 2, 2)
+    plt.plot(data['epoch'], data['train_accuracy'], label='Train Accuracy')
+    plt.plot(data['epoch'], data['val_accuracy'], label='Validation Accuracy')
+    plt.xlabel('Epochs')
+    plt.ylabel('Accuracy')
+    plt.title('Training and Validation Accuracy')
+    plt.legend()
+
+    plt.tight_layout()
+    plt.savefig(os.path.join(output_dir, 'training_curves.png'))
+    plt.close()
+
+def generate_evaluation_metrics(model, data_loader, device, output_dir, class_labels, data_info):
+    model.eval()
+    all_preds = []
+    all_labels = []
+    all_probs = []
+
+    with torch.no_grad():
+        for frames, labels, _ in data_loader:
+            frames = frames.to(device)
+            labels = labels.to(device)
+
+            outputs = model(frames)
+            probs = torch.softmax(outputs, dim=1)
+            _, predicted = outputs.max(1)
+
+            all_preds.extend(predicted.cpu().numpy())
+            all_labels.extend(labels.cpu().numpy())
+            all_probs.extend(probs.cpu().numpy())
+
+    all_labels = np.array(all_labels)
+    all_preds = np.array(all_preds)
+    all_probs = np.array(all_probs)
+
+    # Compute and plot confusion matrix
+    cm = confusion_matrix(all_labels, all_preds)
+    plt.figure(figsize=(10, 8))
+    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
+    plt.xlabel('Predicted Value')
+    plt.ylabel('Actual Value')
+    plt.title(f'Confusion Matrix\n{data_info}')
+    plt.savefig(os.path.join(output_dir, 'confusion_matrix.png'))
+    plt.close()
+
+    colors = ['blue', 'red', 'green', 'yellow', 'purple', 'orange', 'pink', 'cyan']
+
+    # Precision-Recall Curve
+    plt.figure(figsize=(10, 8))
+    for i, class_label in enumerate(class_labels):
+        precision, recall, _ = precision_recall_curve(all_labels == i, all_probs[:, i])
+        average_precision = average_precision_score(all_labels == i, all_probs[:, i])
+        plt.plot(recall, precision, color=colors[i], lw=2,
+                 label=f'{class_label} (AP = {average_precision:.2f})')
+
+    plt.xlabel('Recall')
+    plt.ylabel('Precision')
+    plt.title(f'Precision-Recall Curve\n{data_info}')
+    plt.legend(loc="lower left")
+    plt.savefig(f'{output_dir}/precision_recall_curve.png')
+    plt.close()
+
+    # ROC Curve
+    plt.figure(figsize=(10, 8))
+    for i, class_label in enumerate(class_labels):
+        fpr, tpr, _ = roc_curve(all_labels == i, all_probs[:, i])
+        roc_auc = auc(fpr, tpr)
+        plt.plot(fpr, tpr, color=colors[i], lw=2,
+                 label=f'{class_label} (AUC = {roc_auc:.2f})')
+
+    plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
+    plt.xlim([0.0, 1.0])
+    plt.ylim([0.0, 1.05])
+    plt.xlabel('False Positive Rate')
+    plt.ylabel('True Positive Rate')
+    plt.title(f'Receiver Operating Characteristic (ROC) Curve\n{data_info}')
+    plt.legend(loc="lower right")
+    plt.savefig(f'{output_dir}/roc_curve.png')
+    plt.close()
+
+    return cm
+
+if __name__ == "__main__":
+    # Find the most recent run directory
+    run_dir = get_latest_run_dir()
+    # run_dir = "/home/bawolf/workspace/break/clip/runs/run_20241024-150232_otherpeopleval_large_model"
+    # run_dir = "/home/bawolf/workspace/break/clip/runs/run_20241022-122939_3moves_balanced"
+
+    # Load configuration
+    config = get_config(run_dir)
+
+    class_labels = config['class_labels']
+    num_classes = config['num_classes']
+    data_path = config['data_path']
+    # data_path = '../finetune/3moves_otherpeopleval'
+    # data_path = '../finetune/otherpeople3moves'
+
+    # Paths
+    log_file = os.path.join(run_dir, 'training_log.csv')
+    model_path = get_latest_model_path(run_dir)
+    test_csv = os.path.join(data_path, 'test.csv')
+    # test_csv = os.path.join(data_path, 'val.csv')
+    # test_csv = os.path.join(data_path, 'train.csv')
+
+    # Get the last directory of data_path and the file name
+    last_dir = os.path.basename(os.path.normpath(data_path))
+    file_name = os.path.basename(test_csv)
+
+    # Create a directory for visualization outputs
+    vis_dir = os.path.join(run_dir, f'visualization_{last_dir}_{file_name.split(".")[0]}')
+    os.makedirs(vis_dir, exist_ok=True)
+
+    # Create data_info string for chart headers
+    data_info = f'Data: {last_dir}, File: {file_name}'
+
+    # Plot training curves
+    plot_training_curves(log_file, vis_dir)
+
+    # Load model
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    model = load_model(num_classes, model_path, device, config['clip_model'])
+    model.eval()
+
+    # Create test dataset and dataloader
+    test_dataset = VideoDataset(test_csv, config)
+    test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False)
+
+    # Generate evaluation metrics
+    cm = generate_evaluation_metrics(model, test_loader, device, vis_dir, class_labels, data_info)
+
+    print(f"Visualization complete! Check the output directory: {vis_dir}")
src/models/__init__.py ADDED (empty file)
src/models/model.py ADDED
@@ -0,0 +1,42 @@
+import torch
+import torch.nn as nn
+from transformers import CLIPModel
+
+class VariableLengthCLIP(nn.Module):
+    def __init__(self, clip_model, num_classes):
+        super().__init__()
+        self.clip_model = clip_model
+        self.visual_projection = nn.Linear(clip_model.visual_projection.in_features, num_classes)
+
+    def forward(self, x):
+        batch_size, num_frames, c, h, w = x.shape
+        x = x.view(batch_size * num_frames, c, h, w)
+        features = self.clip_model.vision_model(x).pooler_output
+        features = features.view(batch_size, num_frames, -1)
+        features = torch.mean(features, dim=1)  # Average over frames
+        return self.visual_projection(features)
+
+    def unfreeze_vision_encoder(self, num_layers=2):
+        # Freeze the entire vision encoder
+        for param in self.clip_model.vision_model.parameters():
+            param.requires_grad = False
+        # Unfreeze the last few layers of the vision encoder
+        for param in self.clip_model.vision_model.encoder.layers[-num_layers:].parameters():
+            param.requires_grad = True
+
+def create_model(num_classes, pretrained_model_name="openai/clip-vit-base-patch32"):
+    clip_model = CLIPModel.from_pretrained(pretrained_model_name)
+    return VariableLengthCLIP(clip_model, num_classes)
+
+def load_model(num_classes, model_path, device, pretrained_model_name="openai/clip-vit-base-patch32"):
+    # Create the model
+    model = create_model(num_classes, pretrained_model_name)
+
+    # Load the state dict
+    state_dict = torch.load(model_path, map_location=device, weights_only=True)
+
+    # Load the state dict, ignoring mismatched keys
+    model.load_state_dict(state_dict, strict=False)
+
+    model.to(device)  # Move the model to the appropriate device
+    return model
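To make the expected tensor shapes concrete, here is a small smoke test of the module above. It is a sketch, not part of the commit: it assumes it is run from the repository root and that the CLIP weights can be downloaded.

```python
import torch
from src.models.model import create_model

# Build the classifier head on top of the smaller CLIP backbone
model = create_model(num_classes=3, pretrained_model_name="openai/clip-vit-base-patch32")
model.eval()

# A batch of 2 videos, 10 frames each, 3x224x224 per frame: (B, T, C, H, W)
frames = torch.randn(2, 10, 3, 224, 224)
with torch.no_grad():
    logits = model(frames)

print(logits.shape)  # torch.Size([2, 3]): frame features are mean-pooled, then projected
```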
src/utils/utils.py ADDED
@@ -0,0 +1,82 @@
+import os
+import json
+from datetime import datetime
+
+def create_run_directory(base_dir='runs', prefix='run', suffix='', parent_dir=None):
+    timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
+    dir_name = f"{prefix}_{timestamp}{suffix}"
+
+    if parent_dir:
+        run_dir = os.path.join(parent_dir, dir_name)
+    else:
+        run_dir = os.path.join(base_dir, dir_name)
+    os.makedirs(run_dir, exist_ok=True)
+    return run_dir
+
+# Find the most recent run directory
+def get_latest_run_dir(base_dir='runs', include_hyperparam=True):
+    all_dirs = []
+
+    for d in os.listdir(base_dir):
+        if d.startswith('run_'):
+            full_path = os.path.join(base_dir, d)
+            all_dirs.append(full_path)
+
+            if d.endswith('_hyperparam') and include_hyperparam:
+                # If it's a hyperparam directory, add its trial subdirectories
+                trial_dirs = [os.path.join(full_path, td) for td in os.listdir(full_path) if td.startswith('trial_')]
+                all_dirs.extend(trial_dirs)
+
+    if not all_dirs:
+        raise ValueError(f"No run directories found in {base_dir}")
+
+    # Sort directories by timestamp in the directory name
+    return max(all_dirs, key=get_dir_timestamp)
+
+def get_run_file(filename, run_dir=None, required=True):
+    """Get a file from a run directory
+
+    Args:
+        filename: Name of file to get (e.g., 'best_model.pth', 'config.json')
+        run_dir: Run directory path (uses latest if None)
+        required: Whether to raise an error if file not found
+
+    Returns:
+        str: Path to the file
+        dict: Loaded JSON data if file ends with .json
+    """
+    if run_dir is None:
+        run_dir = get_latest_run_dir()
+
+    file_path = os.path.join(run_dir, filename)
+
+    if not os.path.exists(file_path):
+        if required:
+            raise FileNotFoundError(f"{filename} not found in {run_dir}")
+        return None
+
+    # Load JSON files automatically
+    if filename.endswith('.json'):
+        with open(file_path, 'r') as f:
+            return json.load(f)
+
+    return file_path
+
+def get_latest_model_path(run_dir=None):
+    """Get path to best_model.pth"""
+    return get_run_file('best_model.pth', run_dir)
+
+def get_config(run_dir=None):
+    """Get config from run directory"""
+    return get_run_file('config.json', run_dir)
+
+# Helper function to parse directory name and get timestamp
+def get_dir_timestamp(dir_path):
+    dir_name = os.path.basename(dir_path)
+    try:
+        # Extract timestamp from directory name
+        timestamp_str = dir_name.split('_')[1]  # Assumes format is always prefix_timestamp or prefix_timestamp_suffix
+        return datetime.strptime(timestamp_str, "%Y%m%d-%H%M%S")
+    except (IndexError, ValueError):
+        # If parsing fails, return the earliest possible date
+        return datetime.min
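A quick illustration of how these helpers compose (the directory name shown is hypothetical; `create_run_directory` derives it from the current timestamp):

```python
from src.utils.utils import create_run_directory, get_latest_run_dir, get_config

run_dir = create_run_directory()   # e.g. 'runs/run_20241101-093000'
# ... training writes config.json, best_model.pth, etc. into run_dir ...

latest = get_latest_run_dir()      # newest run, chosen by the timestamp in the name
config = get_config(latest)        # parsed config.json as a dict
```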
|