import argparse
import os
import shutil
import warnings

import numpy as np
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm

from lib.dataset import MegaDepthDataset
from lib.exceptions import NoGradientError
from lib.loss import loss_function
from lib.model import D2Net
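
# Select the GPU when one is available and fix the random seeds for reproducibility.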
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")

torch.manual_seed(1)
if use_cuda:
    torch.cuda.manual_seed(1)
np.random.seed(1)
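
# Command-line arguments: dataset paths, preprocessing mode, optimization
# hyper-parameters, logging and checkpointing options.
# Example invocation (script name and paths are illustrative):
#   python train.py --dataset_path /path/to/MegaDepth --scene_info_path /path/to/scene_info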
parser = argparse.ArgumentParser(description="Training script")

parser.add_argument(
    "--dataset_path", type=str, required=True, help="path to the dataset"
)
parser.add_argument(
    "--scene_info_path", type=str, required=True, help="path to the processed scenes"
)

parser.add_argument(
    "--preprocessing",
    type=str,
    default="caffe",
    help="image preprocessing (caffe or torch)",
)
parser.add_argument(
    "--model_file", type=str, default="models/d2_ots.pth", help="path to the full model"
)

parser.add_argument(
    "--num_epochs", type=int, default=10, help="number of training epochs"
)
parser.add_argument("--lr", type=float, default=1e-3, help="initial learning rate")
parser.add_argument("--batch_size", type=int, default=1, help="batch size")
parser.add_argument(
    "--num_workers", type=int, default=4, help="number of workers for data loading"
)

parser.add_argument(
    "--use_validation",
    dest="use_validation",
    action="store_true",
    help="use the validation split",
)
parser.set_defaults(use_validation=False)

parser.add_argument(
    "--log_interval", type=int, default=250, help="loss logging interval"
)
parser.add_argument("--log_file", type=str, default="log.txt", help="loss logging file")

parser.add_argument(
    "--plot", dest="plot", action="store_true", help="plot training pairs"
)
parser.set_defaults(plot=False)

parser.add_argument(
    "--checkpoint_directory",
    type=str,
    default="checkpoints",
    help="directory for training checkpoints",
)
parser.add_argument(
    "--checkpoint_prefix",
    type=str,
    default="d2",
    help="prefix for training checkpoints",
)

args = parser.parse_args()

print(args)
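
# Create the output directory for training-pair visualizations when --plot is set.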
if args.plot:
    plot_path = "train_vis"
    if os.path.isdir(plot_path):
        print("[Warning] Plotting directory already exists.")
    else:
        os.mkdir(plot_path)
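
# Create the D2-Net model, loading the weights given by --model_file.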
model = D2Net(model_file=args.model_file, use_cuda=use_cuda)
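
# Optimize only the parameters that require gradients.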
optimizer = optim.Adam(
    filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr
)
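
# MegaDepth datasets and data loaders; the validation split is only built when
# --use_validation is set.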
if args.use_validation:
    validation_dataset = MegaDepthDataset(
        scene_list_path="megadepth_utils/valid_scenes.txt",
        scene_info_path=args.scene_info_path,
        base_path=args.dataset_path,
        train=False,
        preprocessing=args.preprocessing,
        pairs_per_scene=25,
    )
    validation_dataloader = DataLoader(
        validation_dataset, batch_size=args.batch_size, num_workers=args.num_workers
    )

training_dataset = MegaDepthDataset(
    scene_list_path="megadepth_utils/train_scenes.txt",
    scene_info_path=args.scene_info_path,
    base_path=args.dataset_path,
    preprocessing=args.preprocessing,
)
training_dataloader = DataLoader(
    training_dataset, batch_size=args.batch_size, num_workers=args.num_workers
)
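

# Run a single training (train=True) or validation (train=False) epoch over the given
# dataloader, logging intermediate losses to log_file and returning the mean epoch loss.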
def process_epoch(
    epoch_idx,
    model,
    loss_function,
    optimizer,
    dataloader,
    device,
    log_file,
    args,
    train=True,
):
    epoch_losses = []

    torch.set_grad_enabled(train)

    progress_bar = tqdm(enumerate(dataloader), total=len(dataloader))
    for batch_idx, batch in progress_bar:
        if train:
            optimizer.zero_grad()

        batch["train"] = train
        batch["epoch_idx"] = epoch_idx
        batch["batch_idx"] = batch_idx
        batch["batch_size"] = args.batch_size
        batch["preprocessing"] = args.preprocessing
        batch["log_interval"] = args.log_interval
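
        # Skip batches for which no loss/gradient could be computed
        # (signalled by the loss function raising NoGradientError).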
        try:
            loss = loss_function(model, batch, device, plot=args.plot)
        except NoGradientError:
            continue

        current_loss = loss.data.cpu().numpy()[0]
        epoch_losses.append(current_loss)

        progress_bar.set_postfix(loss=("%.4f" % np.mean(epoch_losses)))

        if batch_idx % args.log_interval == 0:
            log_file.write(
                "[%s] epoch %d - batch %d / %d - avg_loss: %f\n"
                % (
                    "train" if train else "valid",
                    epoch_idx,
                    batch_idx,
                    len(dataloader),
                    np.mean(epoch_losses),
                )
            )

        if train:
            loss.backward()
            optimizer.step()

    log_file.write(
        "[%s] epoch %d - avg_loss: %f\n"
        % ("train" if train else "valid", epoch_idx, np.mean(epoch_losses))
    )
    log_file.flush()

    return np.mean(epoch_losses)
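

# Create the checkpoint directory.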
if os.path.isdir(args.checkpoint_directory):
    print("[Warning] Checkpoint directory already exists.")
else:
    os.mkdir(args.checkpoint_directory)
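
# Open the loss log file (appended to if it already exists).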
if os.path.exists(args.log_file):
    print("[Warning] Log file already exists.")
log_file = open(args.log_file, "a+")
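
# Loss histories; when validation is enabled, run an initial validation pass
# to establish a baseline validation loss.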
train_loss_history = []
validation_loss_history = []
if args.use_validation:
    validation_dataset.build_dataset()
    min_validation_loss = process_epoch(
        0,
        model,
        loss_function,
        optimizer,
        validation_dataloader,
        device,
        log_file,
        args,
        train=False,
    )
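
# Main training loop: resample the training pairs each epoch, train,
# optionally validate, and checkpoint.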
for epoch_idx in range(1, args.num_epochs + 1):
    training_dataset.build_dataset()
    train_loss_history.append(
        process_epoch(
            epoch_idx,
            model,
            loss_function,
            optimizer,
            training_dataloader,
            device,
            log_file,
            args,
        )
    )

    if args.use_validation:
        validation_loss_history.append(
            process_epoch(
                epoch_idx,
                model,
                loss_function,
                optimizer,
                validation_dataloader,
                device,
                log_file,
                args,
                train=False,
            )
        )
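
    # Save a checkpoint after every epoch and keep a copy of the best one
    # according to validation loss.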
    checkpoint_path = os.path.join(
        args.checkpoint_directory, "%s.%02d.pth" % (args.checkpoint_prefix, epoch_idx)
    )
    checkpoint = {
        "args": args,
        "epoch_idx": epoch_idx,
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "train_loss_history": train_loss_history,
        "validation_loss_history": validation_loss_history,
    }
    torch.save(checkpoint, checkpoint_path)
    if args.use_validation and validation_loss_history[-1] < min_validation_loss:
        min_validation_loss = validation_loss_history[-1]
        best_checkpoint_path = os.path.join(
            args.checkpoint_directory, "%s.best.pth" % args.checkpoint_prefix
        )
        shutil.copy(checkpoint_path, best_checkpoint_path)

log_file.close()