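"""Training and evaluation entry point for the LocoTrack model, built on PyTorch Lightning.

The script is configured through an INI file (see the illustrative config.ini layout
above the __main__ block) plus a few command-line flags. A typical invocation, assuming
this file is saved as train.py (the filename is not fixed by the code), might look like:

    python train.py --config config.ini --mode train --save_path snapshots

For evaluation, pass a mode string that contains 'eval' together with the name of the
evaluation dataset (and optionally 'q_first' for query-first evaluation), plus
--ckpt_path pointing at the checkpoint to test.
"""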
import os
import configparser
import argparse
import logging
from functools import partial
from typing import Any, Dict, Optional, Union

import lightning as L
from lightning.pytorch import seed_everything
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor, TQDMProgressBar
import torch
from torch.utils.data import DataLoader

from data.kubric_data import KubricData
from models.locotrack_model import LocoTrack
import model_utils
from data.evaluation_datasets import get_eval_dataset


class LocoTrackModel(L.LightningModule):
    """LightningModule wrapper around LocoTrack: training, validation, test, and optimizer setup."""

    def __init__(
        self,
        model_kwargs: Optional[Dict[str, Any]] = None,
        model_forward_kwargs: Optional[Dict[str, Any]] = None,
        loss_name: str = 'tapir_loss',
        loss_kwargs: Optional[Dict[str, Any]] = None,
        query_first: bool = False,
        optimizer_name: str = 'Adam',
        optimizer_kwargs: Optional[Dict[str, Any]] = None,
        scheduler_name: str = 'OneCycleLR',
        scheduler_kwargs: Optional[Dict[str, Any]] = None,
    ):
        super().__init__()
        self.model = LocoTrack(**(model_kwargs or {}))
        self.model_forward_kwargs = model_forward_kwargs or {}
        # Resolve the loss function by name from model_utils and bind its keyword arguments.
        self.loss = partial(model_utils.__dict__[loss_name], **(loss_kwargs or {}))
        self.query_first = query_first

        self.optimizer_name = optimizer_name
        self.optimizer_kwargs = optimizer_kwargs or {'lr': 2e-3}
        self.scheduler_name = scheduler_name
        self.scheduler_kwargs = scheduler_kwargs or {'max_lr': 2e-3, 'pct_start': 0.05, 'total_steps': 300000}

    def training_step(self, batch, batch_idx):
        output = self.model(batch['video'], batch['query_points'], **self.model_forward_kwargs)
        loss, loss_scalars = self.loss(batch, output)

        self.log_dict(
            {f'train/{k}': v.item() for k, v in loss_scalars.items()},
            logger=True,
            on_step=True,
            sync_dist=True,
        )

        return loss

    def validation_step(self, batch, batch_idx, dataloader_idx=None):
        output = self.model(batch['video'], batch['query_points'], **self.model_forward_kwargs)
        loss, loss_scalars = self.loss(batch, output)
        metrics = model_utils.eval_batch(batch, output, query_first=self.query_first)

        if self.trainer.global_rank == 0:
            log_prefix = 'val/'
            if dataloader_idx is not None:
                log_prefix = f'val/data_{dataloader_idx}/'

            self.log_dict(
                {log_prefix + k: v for k, v in loss_scalars.items()},
                logger=True,
                rank_zero_only=True,
            )
            self.log_dict(
                {log_prefix + k: v.item() for k, v in metrics.items()},
                logger=True,
                rank_zero_only=True,
            )
            logging.info(f"Batch {batch_idx}: {metrics}")

    def test_step(self, batch, batch_idx, dataloader_idx=None):
        output = self.model(batch['video'], batch['query_points'], **self.model_forward_kwargs)
        loss, loss_scalars = self.loss(batch, output)
        metrics = model_utils.eval_batch(batch, output, query_first=self.query_first)

        if self.trainer.global_rank == 0:
            log_prefix = 'test/'
            if dataloader_idx is not None:
                log_prefix = f'test/data_{dataloader_idx}/'

            self.log_dict(
                {log_prefix + k: v for k, v in loss_scalars.items()},
                logger=True,
                rank_zero_only=True,
            )
            self.log_dict(
                {log_prefix + k: v.item() for k, v in metrics.items()},
                logger=True,
                rank_zero_only=True,
            )
            logging.info(f"Batch {batch_idx}: {metrics}")

    def configure_optimizers(self):
        # Split parameters so that any weight decay in optimizer_kwargs is not applied to bias terms.
        weights = [p for n, p in self.named_parameters() if 'bias' not in n]
        bias = [p for n, p in self.named_parameters() if 'bias' in n]

        optimizer = torch.optim.__dict__[self.optimizer_name](
            [
                {'params': weights, **self.optimizer_kwargs},
                {'params': bias, **self.optimizer_kwargs, 'weight_decay': 0.}
            ]
        )
        scheduler = torch.optim.lr_scheduler.__dict__[self.scheduler_name](optimizer, **self.scheduler_kwargs)

        return [optimizer], [{"scheduler": scheduler, "interval": "step"}]


def train(
    mode: str,
    save_path: str,
    val_dataset_path: str,
    ckpt_path: Optional[str] = None,
    kubric_dir: str = '',
    precision: str = '32',
    batch_size: int = 1,
    val_check_interval: Union[int, float] = 5000,
    log_every_n_steps: int = 10,
    gradient_clip_val: float = 1.0,
    max_steps: int = 300_000,
    model_kwargs: Optional[Dict[str, Any]] = None,
    model_forward_kwargs: Optional[Dict[str, Any]] = None,
    loss_name: str = 'tapir_loss',
    loss_kwargs: Optional[Dict[str, Any]] = None,
    optimizer_name: str = 'Adam',
    optimizer_kwargs: Optional[Dict[str, Any]] = None,
    scheduler_name: str = 'OneCycleLR',
    scheduler_kwargs: Optional[Dict[str, Any]] = None,
):
    """Train or evaluate the LocoTrack model with the specified configuration.

    The mode string is matched by substring: 'train' runs fitting on Kubric data,
    'eval' runs the test loop only, and 'q_first' enables query-first evaluation.
    The full mode string is also passed to get_eval_dataset to select the
    evaluation dataset(s).
    """
    seed_everything(42, workers=True)

    model = LocoTrackModel(
        model_kwargs=model_kwargs,
        model_forward_kwargs=model_forward_kwargs,
        loss_name=loss_name,
        loss_kwargs=loss_kwargs,
        query_first='q_first' in mode,
        optimizer_name=optimizer_name,
        optimizer_kwargs=optimizer_kwargs,
        scheduler_name=scheduler_name,
        scheduler_kwargs=scheduler_kwargs,
    )
    if ckpt_path is not None and 'train' in mode:
        # Initialize the model weights from the checkpoint before training.
        model.load_state_dict(torch.load(ckpt_path)['state_dict'])

    logger = WandbLogger(project='LocoTrack_Pytorch', save_dir=save_path, id=os.path.basename(save_path))
    lr_monitor = LearningRateMonitor(logging_interval='step')
    checkpoint_callback = ModelCheckpoint(
        dirpath=save_path,
        save_last=True,
        save_top_k=3,
        mode="max",
        monitor="val/average_pts_within_thresh",
        auto_insert_metric_name=True,
        save_on_train_epoch_end=False,
    )

    eval_dataset = get_eval_dataset(
        mode=mode,
        path=val_dataset_path,
    )
    eval_dataloader = {
        k: DataLoader(
            v,
            batch_size=1,
            shuffle=False,
        ) for k, v in eval_dataset.items()
    }

    if 'train' in mode:
        trainer = L.Trainer(
            strategy='ddp',
            logger=logger,
            precision=precision,
            val_check_interval=val_check_interval,
            log_every_n_steps=log_every_n_steps,
            gradient_clip_val=gradient_clip_val,
            max_steps=max_steps,
            sync_batchnorm=True,
            callbacks=[checkpoint_callback, lr_monitor],
        )
        train_dataloader = KubricData(
            global_rank=trainer.global_rank,
            data_dir=kubric_dir,
            batch_size=batch_size * trainer.world_size,  # batch size is scaled by the number of processes
        )
        trainer.fit(model, train_dataloader, eval_dataloader, ckpt_path=ckpt_path)
    elif 'eval' in mode:
        trainer = L.Trainer(strategy='ddp', logger=logger, precision=precision)
        trainer.test(model, eval_dataloader, ckpt_path=ckpt_path)
    else:
        raise ValueError(f"Invalid mode: {mode}")


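# Illustrative config.ini layout, reconstructed from the keys read in the __main__ block
# below. The values shown are the documented fallbacks or placeholders, not a tested
# configuration. Dict-valued entries must be written as Python literals because they are
# parsed with eval().
#
# [TRAINING]
# val_dataset_path = {'<eval_dataset_name>': '/path/to/eval/dataset'}
# kubric_dir = /path/to/kubric
# precision = 32
# batch_size = 1
# val_check_interval = 5000
# log_every_n_steps = 10
# gradient_clip_val = 1.0
# max_steps = 300000
#
# [MODEL]
# model_kwargs = {}
# model_forward_kwargs = {}
#
# [LOSS]
# loss_name = tapir_loss
# loss_kwargs = {}
#
# [OPTIMIZER]
# optimizer_name = Adam
# optimizer_kwargs = {'lr': 2e-3}
#
# [SCHEDULER]
# scheduler_name = OneCycleLR
# scheduler_kwargs = {'max_lr': 2e-3, 'pct_start': 0.05, 'total_steps': 300000}
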
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Train or evaluate the LocoTrack model.")
    parser.add_argument('--config', type=str, default='config.ini', help="Path to the configuration file.")
    parser.add_argument('--mode', type=str, required=True, help="Mode to run: 'train' or 'eval', optionally combined with 'q_first' and the name of the evaluation dataset.")
    parser.add_argument('--ckpt_path', type=str, default=None, help="Path to the checkpoint file.")
    parser.add_argument('--save_path', type=str, default='snapshots', help="Path to save the logs and checkpoints.")

    args = parser.parse_args()
    config = configparser.ConfigParser()
    config.read(args.config)

    # Dict-valued options are stored as Python literals in the config file and parsed with eval().
    train_params = {
        'mode': args.mode,
        'ckpt_path': args.ckpt_path,
        'save_path': args.save_path,
        'val_dataset_path': eval(config.get('TRAINING', 'val_dataset_path', fallback='{}')),
        'kubric_dir': config.get('TRAINING', 'kubric_dir', fallback=''),
        'precision': config.get('TRAINING', 'precision', fallback='32'),
        'batch_size': config.getint('TRAINING', 'batch_size', fallback=1),
        'val_check_interval': config.getfloat('TRAINING', 'val_check_interval', fallback=5000),
        'log_every_n_steps': config.getint('TRAINING', 'log_every_n_steps', fallback=10),
        'gradient_clip_val': config.getfloat('TRAINING', 'gradient_clip_val', fallback=1.0),
        'max_steps': config.getint('TRAINING', 'max_steps', fallback=300000),
        'model_kwargs': eval(config.get('MODEL', 'model_kwargs', fallback='{}')),
        'model_forward_kwargs': eval(config.get('MODEL', 'model_forward_kwargs', fallback='{}')),
        'loss_name': config.get('LOSS', 'loss_name', fallback='tapir_loss'),
        'loss_kwargs': eval(config.get('LOSS', 'loss_kwargs', fallback='{}')),
        'optimizer_name': config.get('OPTIMIZER', 'optimizer_name', fallback='Adam'),
        'optimizer_kwargs': eval(config.get('OPTIMIZER', 'optimizer_kwargs', fallback='{"lr": 2e-3}')),
        'scheduler_name': config.get('SCHEDULER', 'scheduler_name', fallback='OneCycleLR'),
        'scheduler_kwargs': eval(config.get('SCHEDULER', 'scheduler_kwargs', fallback='{"max_lr": 2e-3, "pct_start": 0.05, "total_steps": 300000}')),
    }

    train(**train_params)