import numpy as np
import argparse
import copy
import os, sys
import open3d as o3d
from sys import argv, exit
from PIL import Image
import math
from tqdm import tqdm
import cv2


sys.path.append("../../")

from lib.extractMatchTop import getPerspKeypoints, getPerspKeypointsEnsemble, siftMatching
import pandas as pd


import torch
from lib.model_test import D2Net

#### Cuda ####
# Run the networks on the first GPU when CUDA is available, otherwise on the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda:0' if use_cuda else 'cpu')

#### Argument Parsing ####
parser = argparse.ArgumentParser(description='RoRD ICP evaluation on a DiverseView dataset sequence.')

# Sequence CSVs are read from <dataset><sequence>/ (plain string concatenation,
# so --dataset should end with a path separator).
parser.add_argument('--dataset', type=str, default='/scratch/udit/realsense/RoRD_data/preprocessed/',
	help='path to the dataset folder')

parser.add_argument('--sequence', type=str, default='data1')

parser.add_argument(
	'--output_dir', type=str, default='out',
	help='output directory for RT estimates'
)

# Matching method. The main loop checks --model_rord, --model_d2, --model_ens
# and --sift in that order and uses the first one that is set.
parser.add_argument(
	'--model_rord', type=str, help='path to the RoRD model for evaluation'
)

parser.add_argument(
	'--model_d2', type=str, help='path to the vanilla D2-Net model for evaluation'
)

parser.add_argument(
	'--model_ens', action='store_true',
	help='ensemble model of RoRD + D2-Net'
)

# Weight files for --model_ens. The defaults are the paths that were previously
# edited in the source, so existing invocations behave the same.
parser.add_argument(
	'--model_ens_rord', type=str, default='../../models/rord.pth',
	help='RoRD weights used by the --model_ens ensemble'
)

parser.add_argument(
	'--model_ens_d2', type=str, default='../../models/d2net.pth',
	help='D2-Net weights used by the --model_ens ensemble'
)

parser.add_argument(
	'--sift', action='store_true',
	help='Sift'
)

parser.add_argument(
	'--viz3d', action='store_true',
	help='visualize the pointcloud registrations'
)

parser.add_argument(
	'--log_interval', type=int, default=9,
	help='Matched image logging interval'
)

parser.add_argument(
	'--camera_file', type=str, default='../../configs/camera.txt',
	help='path to the camera intrinsics file. In order: focal_x, focal_y, center_x, center_y, scaling_factor.'
)

parser.add_argument(
	'--persp', action='store_true', default=False,
	help='Feature matching on perspective images.'
)

# NOTE(review): args.fp16 is not read anywhere in the visible part of this file.
parser.set_defaults(fp16=False)
args = parser.parse_args()


# Ensemble members, consumed by the main loop when --model_ens is selected.
if args.model_ens:
	model1_ens = args.model_ens_rord
	model2_ens = args.model_ens_d2

def draw_registration_result(source, target, transformation):
	"""Display `target` together with `source` moved by `transformation`.

	The function works on deep copies, so neither input cloud is modified. It
	adds the following to the module-level list `trgSph`, in this order:
	  * the transformed source copy,
	  * the target copy,
	  * a 0.5-sized coordinate frame at the origin,
	  * a second frame moved by `transformation`.
	It then draws everything in `trgSph`. `trgSph` is created by the __main__
	block, and this function mutates it in place.
	"""
	movedSource = copy.deepcopy(source)
	targetCopy = copy.deepcopy(target)
	movedSource.transform(transformation)
	trgSph.extend([movedSource, targetCopy])

	# frames[0] marks the world origin; frames[1] shows where the transform moves it.
	frames = [
		o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.5, origin=[0, 0, 0])
		for _ in range(2)
	]
	frames[1].transform(transformation)
	trgSph.extend(frames)

	o3d.visualization.draw_geometries(trgSph)

def readDepth(depthFile):
	"""Load a depth image and return its raw pixel values as a NumPy array.

	Args:
		depthFile: path to an image that PIL opens in mode "I"
			(32-bit signed integer pixels).

	Returns:
		The unscaled depth values, one entry per pixel.

	Raises:
		ValueError: if the image is not in mode "I". ValueError subclasses
			Exception, so existing ``except Exception`` handlers still catch it.
	"""
	# The context manager closes the file handle even when the mode check fails.
	# np.asarray copies the pixel data, so the array stays valid after close.
	with Image.open(depthFile) as depth:
		if depth.mode != "I":
			# NOTE(review): some Pillow versions may open 16-bit PNGs as "I;16";
			# confirm which mode this dataset's depth files produce.
			raise ValueError("Depth image is not in intensity format")
		return np.asarray(depth)

def readCamera(camera):
	"""Read pinhole intrinsics from a whitespace-separated text file.

	The first five tokens are interpreted, in order, as focal_x, focal_y,
	center_x, center_y and the depth scaling factor. Any further tokens are
	ignored.

	Returns:
		Tuple of five floats: (focalX, focalY, centerX, centerY, scalingFactor).
	"""
	with open(camera, "rt") as handle:
		tokens = handle.read().split()

	# Indexing (rather than slicing) keeps the IndexError on short files.
	return tuple(float(tokens[k]) for k in range(5))


def getPointCloud(rgbFile, depthFile, pts, thresh=15.0, intrinsics=None):
	"""Back-project an RGB-D frame into a coloured Open3D point cloud.

	A depth pixel (u, v) becomes a point when Z = depth[v, u] / scalingFactor
	satisfies 0 != Z <= thresh. Its coordinates follow the pinhole model:
	X = (u - centerX) * Z / focalX and Y = (v - centerY) * Z / focalY.
	Points are emitted in row-major pixel order (v outer, u inner).

	Args:
		rgbFile: colour image path. It is indexed at the same (v, u) as the
			depth image, so it presumably has the same resolution.
		depthFile: depth image path (must be PIL mode "I"; see readDepth).
		pts: list of (u, v) integer pixel coordinates to look up.
		thresh: maximum scaled depth kept in the cloud.
		intrinsics: optional (focalX, focalY, centerX, centerY, scalingFactor).
			When omitted, the module-level values set in the __main__ block are used.

	Returns:
		(pcd, corIdx, corPts):
		  * pcd: the point cloud, with colours scaled by 1/255.
		  * corIdx: for each entry of pts, the index of its point in pcd, or -1
		    if the pixel is outside the image or has no valid depth.
		  * corPts: for each entry of pts, that point's (X, Y, Z), or None.
		A pixel listed several times in pts is filled in only at its first
		occurrence.
	"""
	if intrinsics is None:
		intrinsics = (focalX, focalY, centerX, centerY, scalingFactor)
	fx, fy, cx, cy, scale = intrinsics

	depth = readDepth(depthFile)
	with Image.open(rgbFile) as rgbImg:
		rgb = np.asarray(rgbImg)

	Z = depth / scale
	valid = (Z != 0) & (Z <= thresh)

	# np.nonzero walks the mask in row-major (C) order, the same order as a
	# nested scan over rows then columns, so point indices are unchanged.
	vs, us = np.nonzero(valid)
	zs = Z[vs, us]
	xs = (us - cx) * zs / fx
	ys = (vs - cy) * zs / fy
	points = np.stack((xs, ys, zs), axis=1)  # shape (N, 3), even when N == 0
	colors = rgb[vs, us]

	height, width = valid.shape
	# Running count of valid pixels in row-major order, giving each valid
	# pixel's index inside `points`.
	pointIndex = np.cumsum(valid.ravel()) - 1

	corIdx = [-1] * len(pts)
	corPts = [None] * len(pts)
	seen = set()
	for i, (u, v) in enumerate(pts):
		# Only the first occurrence of a pixel is used (list.index semantics).
		if (u, v) in seen:
			continue
		seen.add((u, v))
		# Explicit bounds check: negative coordinates must not wrap around.
		if not (0 <= v < height and 0 <= u < width) or not valid[v, u]:
			continue
		k = int(pointIndex[v * width + u])
		corIdx[i] = k
		corPts[i] = tuple(points[k])

	pcd = o3d.geometry.PointCloud()
	pcd.points = o3d.utility.Vector3dVector(points)
	pcd.colors = o3d.utility.Vector3dVector(colors / 255)

	return pcd, corIdx, corPts


def convertPts(A):
	"""Turn a pair of coordinate rows into a list of integer (x, y) tuples.

	A[0] holds the x (column) values and A[1] the y (row) values. Every value
	goes through float() and then int(), so fractional coordinates are
	truncated toward zero. One tuple is produced per entry of A[0].
	"""
	xRow, yRow = A[0], A[1]

	xs = [int(float(xRow[k])) for k in range(len(xRow))]
	ys = [int(float(yRow[k])) for k in range(len(yRow))]

	return [(xs[k], ys[k]) for k in range(len(xs))]


def getSphere(pts):
	"""Build a small marker sphere for every non-None 3D point.

	Each marker is an Open3D sphere of radius 0.03, painted [0.9, 0.2, 0] and
	translated to the point's first three coordinates. None entries (keypoints
	without depth) are skipped.

	Returns:
		List of sphere meshes, in the same order as the non-None inputs.
	"""
	markers = []
	for centre in pts:
		if centre is None:
			continue

		marker = o3d.geometry.TriangleMesh.create_sphere(radius=0.03)
		marker.paint_uniform_color([0.9, 0.2, 0])

		# Pure translation: identity rotation plus the point as the offset column.
		offset = np.identity(4)
		offset[0, 3], offset[1, 3], offset[2, 3] = centre[0], centre[1], centre[2]

		marker.transform(offset)
		markers.append(marker)

	return markers


def get3dCor(src, trg):
	"""Pair up point indices that are valid in both clouds.

	Args:
		src, trg: parallel sequences of point indices, as returned by
			getPointCloud. -1 marks a keypoint that has no 3D point.

	Returns:
		Integer array of shape (N, 2) whose rows are (srcIdx, trgIdx). When no
		pair is valid the result has shape (0, 2) instead of (0,), so it still
		matches the (n, 2) layout that o3d.utility.Vector2iVector expects.
	"""
	pairs = [(s, t) for s, t in zip(src, trg) if s != -1 and t != -1]
	return np.asarray(pairs, dtype=int).reshape(-1, 2)

def _matchPair(srcImg, trgImg, srcH, trgH, model1, model2):
	"""Match one image pair with the method chosen on the command line.

	Flags are checked in the order --model_rord, --model_d2, --model_ens, --sift.
	With --persp, the single-model matchers and SIFT are called with
	HFile1=None and HFile2=None instead of the .npy files srcH/trgH.

	Args:
		srcImg, trgImg: paths of the query and database images.
		srcH, trgH: paths of the .npy files passed to the matchers as HFile1/HFile2.
		model1: D2-Net model (with --model_ens, the first ensemble member).
		model2: RoRD model (with --model_ens, the second ensemble member).

	Returns:
		(srcPts, trgPts, matchImg) as produced by lib.extractMatchTop.

	Raises:
		ValueError: if no matching method was selected.
	"""
	if args.model_rord:
		if args.persp:
			srcPts, trgPts, matchImg, _ = getPerspKeypoints(srcImg, trgImg, HFile1=None, HFile2=None, model=model2, device=device)
		else:
			srcPts, trgPts, matchImg, _ = getPerspKeypoints(srcImg, trgImg, srcH, trgH, model2, device)
	elif args.model_d2:
		# The D2-Net path uses model1 in both modes (the perspective branch used to pass the RoRD model).
		if args.persp:
			srcPts, trgPts, matchImg, _ = getPerspKeypoints(srcImg, trgImg, HFile1=None, HFile2=None, model=model1, device=device)
		else:
			srcPts, trgPts, matchImg, _ = getPerspKeypoints(srcImg, trgImg, srcH, trgH, model1, device)
	elif args.model_ens:
		srcPts, trgPts, matchImg = getPerspKeypointsEnsemble(model1, model2, srcImg, trgImg, srcH, trgH, device)
	elif args.sift:
		if args.persp:
			srcPts, trgPts, matchImg, _ = siftMatching(srcImg, trgImg, HFile1=None, HFile2=None, device=device)
		else:
			srcPts, trgPts, matchImg, _ = siftMatching(srcImg, trgImg, srcH, trgH, device)
	else:
		raise ValueError('no matching method selected')
	return srcPts, trgPts, matchImg


if __name__ == "__main__":
	# Without a matcher, no branch would assign srcPts and the loop would die
	# with a NameError on the first pair, so fail early with a usage message.
	if not (args.model_rord or args.model_d2 or args.model_ens or args.sift):
		parser.error('select a matcher: --model_rord, --model_d2, --model_ens or --sift')

	camera_file = args.camera_file
	seq_dir = args.dataset + args.sequence
	rgb_csv = seq_dir + '/rtImagesRgb.csv'
	depth_csv = seq_dir + '/rtImagesDepth.csv'

	# Creating the vis/ sub-directory also creates output_dir itself.
	vis_dir = os.path.join(args.output_dir, 'vis')
	os.makedirs(vis_dir, exist_ok=True)

	# Module-level on purpose: getPointCloud falls back to these globals.
	focalX, focalY, centerX, centerY, scalingFactor = readCamera(camera_file)

	df_rgb = pd.read_csv(rgb_csv)
	df_dep = pd.read_csv(depth_csv)

	# Load each network once, and only if the selected method can use it.
	model1 = model2 = None
	if args.model_rord or args.model_d2:
		if args.model_d2:
			model1 = D2Net(model_file=args.model_d2).to(device)
		if args.model_rord:
			model2 = D2Net(model_file=args.model_rord).to(device)
	elif args.model_ens:
		model1 = D2Net(model_file=model1_ens).to(device)
		model2 = D2Net(model_file=model2_ens).to(device)

	depth_root = os.path.dirname(os.path.dirname(args.dataset))

	for queryId, (im_q, dep_q) in enumerate(tqdm(zip(df_rgb['query'], df_dep['query']), total=df_rgb.shape[0])):
		filter_list = []  # one 4x4 transform per database image
		dbId = 0
		# DataFrame.items() replaces iteritems(), which was removed in pandas 2.0.
		for im_d, dep_d in tqdm(zip(df_rgb.items(), df_dep.items()), total=df_rgb.shape[1]):
			if im_d[0] == 'query':
				continue
			rgb_name_src = os.path.basename(im_q)
			H_name_src = os.path.splitext(rgb_name_src)[0] + '.npy'
			srcH = seq_dir + '/rgb/' + H_name_src
			# NOTE(review): reads row label 1 of each database column regardless
			# of the current query row; confirm this matches the CSV layout.
			rgb_name_trg = os.path.basename(im_d[1][1])
			H_name_trg = os.path.splitext(rgb_name_trg)[0] + '.npy'
			trgH = seq_dir + '/rgb/' + H_name_trg

			srcImg = srcH.replace('.npy', '.jpg')
			trgImg = trgH.replace('.npy', '.jpg')

			srcPts, trgPts, matchImg = _matchPair(srcImg, trgImg, srcH, trgH, model1, model2)

			# A plain list is treated as "no usable match": record the identity.
			if isinstance(srcPts, list):
				print(np.identity(4))
				filter_list.append(np.identity(4))
				continue

			srcPts = convertPts(srcPts)
			trgPts = convertPts(trgPts)

			depth_name_src = depth_root + '/' + dep_q
			depth_name_trg = depth_root + '/' + dep_d[1][1]

			srcCld, srcIdx, srcCor = getPointCloud(srcImg, depth_name_src, srcPts)
			trgCld, trgIdx, trgCor = getPointCloud(trgImg, depth_name_trg, trgPts)

			corr = get3dCor(srcIdx, trgIdx)

			p2p = o3d.pipelines.registration.TransformationEstimationPointToPoint()
			trans_init = p2p.compute_transformation(srcCld, trgCld, o3d.utility.Vector2iVector(corr))
			filter_list.append(trans_init)

			if args.viz3d:
				# Keypoint spheres are only needed for visualisation. trgSph must stay
				# a module global because draw_registration_result reads it.
				axis = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.5, origin=[0, 0, 0])
				srcSph = getSphere(srcCor)
				trgSph = getSphere(trgCor)
				srcSph.append(srcCld); srcSph.append(axis)
				trgSph.append(trgCld); trgSph.append(axis)
				o3d.visualization.draw_geometries(srcSph)
				o3d.visualization.draw_geometries(trgSph)
				draw_registration_result(srcCld, trgCld, trans_init)

			if dbId % args.log_interval == 0:
				cv2.imwrite(os.path.join(vis_dir, "matchImg.%02d.%02d.jpg" % (queryId, dbId // args.log_interval)), matchImg)
			dbId += 1

		# Stack to (N, 4, 4), then move the database axis last: shape (4, 4, N).
		RT = np.stack(filter_list).transpose(1, 2, 0)

		np.save(os.path.join(args.output_dir, str(queryId) + '.npy'), RT)
		print('-----check-------', RT.shape)