import cv2
import math
import numpy as np
from PIL import Image

import torch
import torchvision.transforms.functional as F

class DemoDataset(object):
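    """Builds the inputs for a single demo sample: a normalized reference image
    tensor and a pose/skeleton label tensor rendered from OpenPose-style body
    (and optional hand) keypoints."""
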
    def __init__(self):
        super().__init__()    
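        # Limb pairs of the 18-keypoint OpenPose body model (1-based indices;
        # the -1 offset below converts them to 0-based) and one drawing colour
        # per keypoint/limb in COLORS.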
        self.LIMBSEQ = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10], \
                [10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17], \
                [1, 16], [16, 18], [3, 17], [6, 18]]

        self.COLORS = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], \
                [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], \
                [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]]
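        # Limb pairs for two 21-keypoint hands (0-based: keypoints 0-20 for the
        # first hand, 21-41 for the second); every four consecutive pairs trace
        # one finger. COLORS_hands holds one colour per hand keypoint.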
        
        self.LIMBSEQ_hands = [[0, 1], [1, 2], [2, 3], [3, 4], \
            [0, 5], [5, 6], [6, 7], [7, 8], \
            [0, 9], [9, 10], [10, 11], [11, 12], \
            [0, 13], [13, 14], [14, 15], [15, 16], \
            [0, 17], [17, 18], [18, 19], [19, 20], \
            [21, 22], [22, 23], [23, 24], [24, 25], \
            [21, 26], [26, 27], [27, 28], [28, 29], \
            [21, 30], [30, 31], [31, 32], [32, 33], \
            [21, 34], [34, 35], [35, 36], [36, 37], \
            [21, 38], [38, 39], [39, 40], [40, 41]]
        
        self.COLORS_hands = [[85, 0, 0], [170, 0, 0], [85, 85, 0], [85, 170, 0], [170, 85, 0], [170, 170, 0], [85, 85, 85], \
            [85, 85, 170], [85, 170, 85], [85, 170, 170], [0, 85, 0], [0, 170, 0], [0, 85, 85], [0, 85, 170], \
            [0, 170, 85], [0, 170, 170], [50, 0, 0], [135, 0, 0], [50, 50, 0], [50, 135, 0], [135, 50, 0], \
            [135, 135, 0], [50, 50, 50], [50, 50, 135], [50, 135, 50], [50, 135, 135], [0, 50, 0], [0, 135, 0], \
            [0, 50, 50], [0, 50, 135], [0, 135, 50], [0, 135, 135], [100, 0, 0], [200, 0, 0], [100, 100, 0], \
            [100, 200, 0], [200, 100, 0], [200, 200, 0], [100, 100, 100], [100, 100, 200], [100, 200, 100], [100, 200, 200]
            ]
        
        self.img_size = (512, 352)  # (height, width)
    
    def load_item(self, img, pose, handpose=None):
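        """Build the model inputs for one demo sample.

        Args:
            img: reference image as a BGR uint8 array (e.g. from cv2.imread).
            pose: (18, 2) body keypoints in (x, y) pixels, missing joints set to -1.
            handpose: optional (42, 2) hand keypoints for two 21-point hands.

        Returns:
            Dict with a (1, 3, H, W) reference image tensor, a (1, C, H, W)
            skeleton label tensor and the rendered skeleton canvas (numpy).
        """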

        reference_img = self.get_image_tensor(img)[None,:]
        label, ske = self.get_label_tensor(pose, handpose)
        label = label[None,:]

        return {'reference_image':reference_img, 'target_skeleton':label, 'skeleton_img': ske}
    
    def get_image_tensor(self, bgr_img):
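        """Convert a BGR uint8 image to a normalized (3, H, W) tensor in [-1, 1]."""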
        img = Image.fromarray(cv2.cvtColor(bgr_img, cv2.COLOR_BGR2RGB))
        img = F.resize(img, self.img_size)
        img = F.to_tensor(img)
        img = F.normalize(img, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        return img

    def get_label_tensor(self, pose, hand_pose=None):
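        """Render the pose as a coloured skeleton plus per-limb distance maps.

        Args:
            pose: (18, 2) body keypoints in (x, y) pixels; missing joints are -1.
            hand_pose: optional (42, 2) hand keypoints (two 21-point hands).

        Returns:
            A (3 + #limb channels, H, W) label tensor and the skeleton canvas.
        """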
        canvas = np.zeros((self.img_size[0], self.img_size[1], 3)).astype(np.uint8)
        keypoint = np.array(pose)
        if hand_pose is not None:
            keypoint_hands = np.array(hand_pose)
        else:
            keypoint_hands = None
        
        # keypoint = self.trans_keypoins(keypoint)
        
        stickwidth = 4
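        # Draw a filled dot for every visible body (and, if given, hand) keypoint;
        # keypoints with x == -1 or y == -1 are treated as missing and skipped.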
        for i in range(18):
            x, y = keypoint[i, 0:2]
            if x == -1 or y == -1:
                continue
            cv2.circle(canvas, (int(x), int(y)), 4, self.COLORS[i], thickness=-1)
        if keypoint_hands is not None:
            for i in range(42):
                x, y = keypoint_hands[i, 0:2]
                if x == -1 or y == -1:
                    continue
                cv2.circle(canvas, (int(x), int(y)), 4, self.COLORS_hands[i], thickness=-1)
        
        joints = []
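        # Draw each of the first 17 body limbs as a filled ellipse between its two
        # endpoints and keep one binary mask per limb in `joints` (a zero mask when
        # the limb is missing). Note the inherited OpenPose naming quirk: column 0
        # of `keypoint` is the x coordinate (stored in Y) and column 1 is y (in X).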
        for i in range(17):
            Y = keypoint[np.array(self.LIMBSEQ[i])-1, 0]
            X = keypoint[np.array(self.LIMBSEQ[i])-1, 1]            
            cur_canvas = canvas.copy()
            if -1 in Y or -1 in X:
                joints.append(np.zeros_like(cur_canvas[:, :, 0]))
                continue
            mX = np.mean(X)
            mY = np.mean(Y)
            length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
            angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
            polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1)
            cv2.fillConvexPoly(cur_canvas, polygon, self.COLORS[i])
            canvas = cv2.addWeighted(canvas, 0.4, cur_canvas, 0.6, 0)

            joint = np.zeros_like(cur_canvas[:, :, 0])
            cv2.fillConvexPoly(joint, polygon, 255)
            joint = cv2.addWeighted(joint, 0.4, joint, 0.6, 0)
            joints.append(joint)
        if keypoint_hands is not None:
            # Hand limbs: every four consecutive entries of LIMBSEQ_hands trace one
            # finger, and each finger contributes a single mask channel.
            for i in range(40):
                if i % 4 == 0:
                    # Start a fresh channel for the next finger. Doing this before the
                    # visibility check keeps a missing first segment from leaking the
                    # previous finger's mask into this one.
                    joint = np.zeros_like(canvas[:, :, 0])
                Y = keypoint_hands[np.array(self.LIMBSEQ_hands[i]), 0]
                X = keypoint_hands[np.array(self.LIMBSEQ_hands[i]), 1]
                if -1 in Y or -1 in X:
                    if (i + 1) % 4 == 0:
                        joints.append(joint)
                    continue
                cur_canvas = canvas.copy()
                mX = np.mean(X)
                mY = np.mean(Y)
                length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
                angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
                polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), int(stickwidth / 2)), int(angle), 0, 360, 1)
                cv2.fillConvexPoly(cur_canvas, polygon, self.COLORS_hands[i])
                canvas = cv2.addWeighted(canvas, 0.4, cur_canvas, 0.6, 0)

                # One channel per finger: accumulate this segment into the finger mask.
                cv2.fillConvexPoly(joint, polygon, 255)
                if (i + 1) % 4 == 0:
                    joints.append(joint)
        
        pose = F.to_tensor(Image.fromarray(cv2.cvtColor(canvas, cv2.COLOR_BGR2RGB)))
        
        # One soft distance-transform map per limb/finger mask, stacked channel-wise.
        tensors_dist = []
        for joint in joints:
            im_dist = cv2.distanceTransform(255 - joint, cv2.DIST_L1, 3)
            im_dist = np.clip(im_dist / 3, 0, 255).astype(np.uint8)
            tensors_dist.append(F.to_tensor(Image.fromarray(im_dist)))
        tensors_dist = torch.cat(tensors_dist)

        # Channel layout: 3 RGB pose channels, then one distance map per body limb (17)
        # and, when hand keypoints are given, one per finger (10).
        label_tensor = torch.cat((pose, tensors_dist), dim=0)

        return label_tensor, canvas
    
    def tensor2im(self, image_tensor, imtype=np.uint8, normalize=True,
                  three_channel_output=True):
        r"""Convert tensor to image.

        Args:
            image_tensor (torch.Tensor or list of torch.Tensor): If tensor, then
                (NxCxHxW) or (NxTxCxHxW) or (CxHxW).
            imtype (np.dtype): Type of output image.
            normalize (bool): Is the input image normalized or not?
            three_channel_output (bool): Should single-channel images be made
                3-channel in the output?

        Returns:
            (numpy.ndarray or list of numpy.ndarray): a list when the input is
            batched or a list of tensors.
        """
        if image_tensor is None:
            return None
        if isinstance(image_tensor, list):
            return [self.tensor2im(x, imtype, normalize, three_channel_output)
                    for x in image_tensor]
        if image_tensor.dim() == 5 or image_tensor.dim() == 4:
            return [self.tensor2im(image_tensor[idx], imtype, normalize,
                                   three_channel_output)
                    for idx in range(image_tensor.size(0))]

        if image_tensor.dim() == 3:
            image_numpy = image_tensor.detach().cpu().float().numpy()
            if normalize:
                image_numpy = (np.transpose(
                    image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0
            else:
                image_numpy = np.transpose(image_numpy, (1, 2, 0)) * 255.0
            image_numpy = np.clip(image_numpy, 0, 255)
            if image_numpy.shape[2] == 1 and three_channel_output:
                image_numpy = np.repeat(image_numpy, 3, axis=2)
            elif image_numpy.shape[2] > 3:
                image_numpy = image_numpy[:, :, :3]
            return image_numpy.astype(imtype)

    def trans_keypoins(self, keypoints):
        """Keypoint preprocessing hook; missing joints stay marked with -1."""
        missing_keypoint_index = keypoints == -1
        keypoints[missing_keypoint_index] = -1
        return keypoints
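

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original module).
# It feeds dummy data through DemoDataset.load_item to show the expected input
# layout: a BGR image, an (18, 2) body pose and an optional (42, 2) hand pose,
# with missing keypoints marked as -1. Shapes assume the default 512x352 size.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    dataset = DemoDataset()

    dummy_img = np.zeros((512, 352, 3), dtype=np.uint8)    # BGR image, (H, W, 3)
    dummy_pose = np.full((18, 2), -1, dtype=np.float32)    # all body keypoints missing
    dummy_pose[0] = [176, 60]                               # one visible keypoint for illustration
    dummy_hands = np.full((42, 2), -1, dtype=np.float32)   # all hand keypoints missing

    sample = dataset.load_item(dummy_img, dummy_pose, dummy_hands)
    print(sample['reference_image'].shape)  # torch.Size([1, 3, 512, 352])
    print(sample['target_skeleton'].shape)  # torch.Size([1, 30, 512, 352]): 3 + 17 + 10 channels
    print(sample['skeleton_img'].shape)     # (512, 352, 3) rendered skeleton canvas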