import argparse
import numpy as np
import os
import sys
sys.path.append("../")

import shutil

import torch
import torch.optim as optim

from torch.utils.data import DataLoader

from tqdm import tqdm

import warnings

# from lib.dataset import MegaDepthDataset

from lib.exceptions import NoGradientError
from lib.loss import loss_function as orig_loss
from lib.losses.lossPhotoTourism import loss_function as ipr_loss
from lib.model import D2Net
from lib.dataloaders.datasetPhotoTourism_combined import PhotoTourismCombined


# CUDA
use_cuda = torch.cuda.is_available()
# Prefer the second GPU (index 1) as before, but fall back to index 0 when
# only one GPU is visible; "cuda:1" is an invalid device on such hosts.
device = torch.device(
    ("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")
    if use_cuda else "cpu"
)

# Seed
torch.manual_seed(1)
if use_cuda:
    # torch.cuda.manual_seed only seeds the *current* GPU, which is not
    # necessarily the one selected above, so seed every visible GPU.
    torch.cuda.manual_seed_all(1)
np.random.seed(1)

# Command-line interface of the training script.
# Options are registered in a fixed order because that order is what
# `print(args)` shows on recent Python versions.
parser = argparse.ArgumentParser(description='Training script')

# Data and model locations
# (the former `--scene_info_path` option is disabled)
parser.add_argument('--dataset_path', type=str,
                    default="/scratch/udit/phototourism/",
                    help='path to the dataset')
parser.add_argument('--preprocessing', type=str, default='caffe',
                    help='image preprocessing (caffe or torch)')
parser.add_argument('--model_file', type=str, default='models/d2_ots.pth',
                    help='path to the full model')

# Optimisation and data-loading settings
parser.add_argument('--num_epochs', type=int, default=10,
                    help='number of training epochs')
parser.add_argument('--lr', type=float, default=1e-3,
                    help='initial learning rate')
parser.add_argument('--batch_size', type=int, default=1,
                    help='batch size')
parser.add_argument('--num_workers', type=int, default=16,
                    help='number of workers for data loading')

# Boolean switch; `store_true` already defaults to False when absent.
parser.add_argument('--use_validation', action='store_true',
                    help='use the validation split')

# Logging
parser.add_argument('--log_interval', type=int, default=250,
                    help='loss logging interval')
parser.add_argument('--log_file', type=str, default='log.txt',
                    help='loss logging file')

# Boolean switch; `store_true` already defaults to False when absent.
parser.add_argument('--plot', action='store_true',
                    help='plot training pairs')

# Checkpointing
parser.add_argument('--checkpoint_directory', type=str, default='checkpoints',
                    help='directory for training checkpoints')
parser.add_argument('--checkpoint_prefix', type=str, default='d2',
                    help='prefix for training checkpoints')

args = parser.parse_args()
print(args)

# Network: constructed with use_cuda=False and then moved explicitly to the
# selected device (presumably so the constructor does not pick a GPU on its
# own -- confirm against lib.model.D2Net).
model = D2Net(model_file=args.model_file, use_cuda=False).to(device)

# Adam over the trainable parameters only (those with requires_grad set).
optimizer = optim.Adam(
    (param for param in model.parameters() if param.requires_grad),
    lr=args.lr
)

# Datasets and loaders.
# The validation objects only exist when --use_validation is given, and they
# are built before the training ones.
if args.use_validation:
    validation_dataset = PhotoTourismCombined(
        preprocessing=args.preprocessing,
        base_path=args.dataset_path,
        pairs_per_scene=25,
        train=False
    )
    validation_dataloader = DataLoader(
        validation_dataset,
        num_workers=args.num_workers,
        batch_size=args.batch_size
    )

# Training split: every other dataset option is left at its default.
training_dataset = PhotoTourismCombined(
    preprocessing=args.preprocessing,
    base_path=args.dataset_path
)
training_dataloader = DataLoader(
    training_dataset,
    num_workers=args.num_workers,
    batch_size=args.batch_size
)


# Define epoch function
def process_epoch(
        epoch_idx,
        model, loss_function, optimizer, dataloader, device,
        log_file, args, train=True, plot_path=None
):
    """Run one pass over `dataloader` and return the mean batch loss.

    Args:
        epoch_idx: epoch number; stored in every batch and written to the log.
        model: network handed to the selected loss callable.
        loss_function: indexable collection of loss callables. The entry used
            for a batch is `loss_function[method]`, where `method` is the
            second element yielded by the dataloader.
        optimizer: optimizer that is zeroed and stepped for every batch when
            `train` is true.
        dataloader: sized iterable yielding `(batch, method)` pairs; `batch`
            must support item assignment.
        device: forwarded unchanged to the loss callable.
        log_file: writable text stream that receives the progress lines.
        args: namespace providing `batch_size`, `preprocessing`,
            `log_interval` and `plot`.
        train: when true, gradients are enabled and the optimizer is stepped;
            otherwise the pass only evaluates the loss.
        plot_path: forwarded unchanged to the loss callable.

    Returns:
        The mean of the per-batch losses, or NaN when every batch was skipped
        because the loss callable raised `NoGradientError`.
    """
    epoch_losses = []

    progress_bar = tqdm(enumerate(dataloader), total=len(dataloader))

    # Scope the gradient mode to this call so that an evaluation pass does not
    # leave gradients disabled for whatever runs afterwards.
    with torch.set_grad_enabled(train):
        # NOTE(review): `method` is used directly as an index; this assumes
        # the dataloader yields something list-indexable (an int or a
        # one-element integer tensor) -- confirm for batch_size > 1.
        for batch_idx, (batch, method) in progress_bar:
            if train:
                optimizer.zero_grad()

            # Bookkeeping values passed to the loss callable via the batch.
            batch['train'] = train
            batch['epoch_idx'] = epoch_idx
            batch['batch_idx'] = batch_idx
            batch['batch_size'] = args.batch_size
            batch['preprocessing'] = args.preprocessing
            batch['log_interval'] = args.log_interval

            try:
                loss = loss_function[method](
                    model, batch, device,
                    plot=args.plot, plot_path=plot_path
                )
            except NoGradientError:
                # The batch produced no usable loss; skip it entirely.
                continue

            # Take the first element as a Python float. Flattening first makes
            # this work for a 0-dim loss as well as a 1-element vector, where
            # the previous `.numpy()[0]` raised IndexError on 0-dim tensors.
            current_loss = loss.detach().reshape(-1)[0].item()
            epoch_losses.append(current_loss)

            progress_bar.set_postfix(loss=('%.4f' % np.mean(epoch_losses)))

            if batch_idx % args.log_interval == 0:
                log_file.write(
                    '[%s] epoch %d - batch %d / %d - avg_loss: %f\n' % (
                        'train' if train else 'valid',
                        epoch_idx, batch_idx, len(dataloader),
                        np.mean(epoch_losses)
                    )
                )

            if train:
                loss.backward()
                optimizer.step()

    # np.mean([]) would emit a RuntimeWarning; report NaN explicitly instead
    # when no batch contributed a loss.
    mean_loss = np.mean(epoch_losses) if epoch_losses else float('nan')

    log_file.write('[%s] epoch %d - avg_loss: %f\n' % (
        'train' if train else 'valid',
        epoch_idx,
        mean_loss
    ))
    log_file.flush()

    return mean_loss


# Output directory: <checkpoint_directory>/<checkpoint_prefix>.
# An existing directory is reused, with a warning.
checkpoint_directory = os.path.join(
    args.checkpoint_directory, args.checkpoint_prefix
)
if not os.path.isdir(checkpoint_directory):
    os.makedirs(checkpoint_directory, exist_ok=True)
else:
    print('[Warning] Checkpoint directory already exists.')

# Log file inside the output directory; opened in append mode, so an
# existing log is extended rather than overwritten.
if os.path.exists(os.path.join(checkpoint_directory, args.log_file)):
    print('[Warning] Log file already exists.')
log_file = open(os.path.join(checkpoint_directory, args.log_file), 'a+')

# Plot directory, only when --plot is given; otherwise plot_path stays None.
plot_path = (
    os.path.join(checkpoint_directory, 'train_vis') if args.plot else None
)
if plot_path is not None:
    if not os.path.isdir(plot_path):
        os.makedirs(plot_path, exist_ok=True)
    else:
        print('[Warning] Plotting directory already exists.')


# Initialize the history
train_loss_history = []
validation_loss_history = []

# Both losses are handed over together; process_epoch selects one per batch.
loss_functions = [orig_loss, ipr_loss]

# Everything that writes to the log runs inside try/finally so the log file
# is closed even when training fails or is interrupted.
try:
    if args.use_validation:
        # Validation loss of the initial model (epoch 0); it is the baseline
        # a later epoch has to beat to be stored as the best checkpoint.
        min_validation_loss = process_epoch(
            0,
            model, loss_functions, optimizer, validation_dataloader, device,
            log_file, args,
            train=False
        )

    # Start the training
    for epoch_idx in range(1, args.num_epochs + 1):
        # Process epoch
        train_loss_history.append(
            process_epoch(
                epoch_idx,
                model, loss_functions, optimizer, training_dataloader, device,
                log_file, args, train=True, plot_path=plot_path
            )
        )

        if args.use_validation:
            validation_loss_history.append(
                process_epoch(
                    epoch_idx,
                    model, loss_functions, optimizer, validation_dataloader,
                    device,
                    log_file, args,
                    train=False
                )
            )

        # Save the current checkpoint (one file per epoch: 01.pth, 02.pth, ...)
        checkpoint_path = os.path.join(
            checkpoint_directory,
            '%02d.pth' % (epoch_idx)
        )
        checkpoint = {
            'args': args,
            'epoch_idx': epoch_idx,
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'train_loss_history': train_loss_history,
            'validation_loss_history': validation_loss_history
        }
        torch.save(checkpoint, checkpoint_path)

        # Keep a copy of the checkpoint with the lowest validation loss so far
        if (
            args.use_validation and
            validation_loss_history[-1] < min_validation_loss
        ):
            min_validation_loss = validation_loss_history[-1]
            best_checkpoint_path = os.path.join(
                checkpoint_directory,
                '%s.best.pth' % args.checkpoint_prefix
            )
            shutil.copy(checkpoint_path, best_checkpoint_path)
finally:
    # Close the log file
    log_file.close()