import numpy as np
import gradio as gr
import cv2
from copy import deepcopy
import torch
from torchvision import transforms
from PIL import Image, ImageDraw, ImageFont

from sam.efficient_sam.build_efficient_sam import build_efficient_sam_vits
from src.utils.utils import resize_numpy_image

# Build the EfficientSAM (ViT-S) model once at import time.
sam = build_efficient_sam_vits()

def show_point_or_box(image, global_points):
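    """Draw the current selection on ``image``: one click is shown as a point,
    two clicks as the rectangle they span."""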
    # A single click: draw it as a point.
    if len(global_points) == 1:
        p = global_points[0]
        image = cv2.circle(image, (int(p[0]), int(p[1])), 10, (0, 0, 255), -1)
    # Two clicks: draw the box they span.
    elif len(global_points) == 2:
        p1, p2 = global_points
        image = cv2.rectangle(image, (int(p1[0]), int(p1[1])), (int(p2[0]), int(p2[1])), (0, 0, 255), 2)

    return image
    
def segment_with_points(
    image,
    original_image,
    global_points,
    global_point_label,
    evt: gr.SelectData,
    img_direction,
    save_dir = "./tmp"  # unused
):
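    """Gradio click handler for box-prompted segmentation. The first click
    stores one box corner, the second completes the box and runs EfficientSAM
    on it, and any further click starts a new box.
    """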
    if original_image is None:
        original_image = image
    else:
        image = original_image
    if img_direction is None:
        img_direction = original_image
    x, y = evt.index
    if len(global_points) == 0:
        global_points.append([x, y])
        global_point_label.append(2)
        image_with_point = show_point_or_box(image.copy(), global_points)
        return image_with_point, original_image, None, global_points, global_point_label
    elif len(global_points) == 1:
        global_points.append([x, y])
        global_point_label.append(3)
        # Normalize the two clicks into top-left / bottom-right box corners.
        x1, y1 = global_points[0]
        x2, y2 = global_points[1]
        global_points[0] = [min(x1, x2), min(y1, y2)]
        global_points[1] = [max(x1, x2), max(y1, y2)]
        image_with_point = show_point_or_box(image.copy(), global_points)
        # Pack the box prompt into the tensor layout EfficientSAM expects.
        input_point = np.array(global_points)
        input_label = np.array(global_point_label)
        pts_sampled = torch.reshape(torch.tensor(input_point), [1, 1, -1, 2])
        pts_labels = torch.reshape(torch.tensor(input_label), [1, 1, -1])
        img_tensor = transforms.ToTensor()(image)
        # Run EfficientSAM on the full image with the box prompt.
        predicted_logits, predicted_iou = sam(
            img_tensor[None, ...],
            pts_sampled,
            pts_labels,
        )
        # Threshold the logits at zero to obtain a binary mask.
        mask = torch.ge(predicted_logits[0, 0, 0, :, :], 0).float().detach().cpu().numpy()
        mask_uint8 = (mask * 255.).astype(np.uint8)
        return image_with_point, original_image, mask_uint8, global_points, global_point_label
    else:
        global_points = [[x, y]]
        global_point_label = [2]
        image_with_point = show_point_or_box(image.copy(), global_points)
        return image_with_point, original_image, None, global_points, global_point_label


def segment_with_points_paste(
    image,
    original_image,
    global_points,
    global_point_label,
    image_b,
    evt: gr.SelectData,
    dx,
    dy,
    resize_scale,
):
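    """Like ``segment_with_points``, but after the second click the segmented
    region is also pasted onto ``image_b``, shifted by (dx, dy) and scaled by
    ``resize_scale``, as a composited preview.
    """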
    if original_image is None:
        original_image = image
    else:
        image = original_image
    x, y = evt.index
    if len(global_points) == 0:
        global_points.append([x, y])
        global_point_label.append(2)
        image_with_point = show_point_or_box(image.copy(), global_points)
        return image_with_point, original_image, None, global_points, global_point_label, None
    elif len(global_points) == 1:
        global_points.append([x, y])
        global_point_label.append(3)
        # Normalize the two clicks into top-left / bottom-right box corners.
        x1, y1 = global_points[0]
        x2, y2 = global_points[1]
        global_points[0] = [min(x1, x2), min(y1, y2)]
        global_points[1] = [max(x1, x2), max(y1, y2)]
        image_with_point = show_point_or_box(image.copy(), global_points)
        # Pack the box prompt into the tensor layout EfficientSAM expects.
        input_point = np.array(global_points)
        input_label = np.array(global_point_label)
        pts_sampled = torch.reshape(torch.tensor(input_point), [1, 1, -1, 2])
        pts_labels = torch.reshape(torch.tensor(input_label), [1, 1, -1])
        img_tensor = transforms.ToTensor()(image)
        # Run EfficientSAM on the full image with the box prompt.
        predicted_logits, predicted_iou = sam(
            img_tensor[None, ...],
            pts_sampled,
            pts_labels,
        )
        # Threshold the logits at zero to obtain a binary mask.
        mask = torch.ge(predicted_logits[0, 0, 0, :, :], 0).float().detach().cpu().numpy()
        mask_uint8 = (mask * 255.).astype(np.uint8)

        preview = paste_with_mask_and_offset(image, image_b, mask_uint8, dx, dy, resize_scale)
        return image_with_point, original_image, preview, global_points, global_point_label, mask_uint8
    else:
        global_points = [[x, y]]
        global_point_label = [2]
        image_with_point = show_point_or_box(image.copy(), global_points)
        return image_with_point, original_image, None, global_points, global_point_label, None

def paste_with_mask_and_offset(image_a, image_b, mask, x_offset=0, y_offset=0, delta=1):
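    """Composite the masked region of ``image_a`` onto a semi-transparent copy
    of ``image_b``, shifted by (x_offset, y_offset) and scaled by ``delta``.
    Returns None when the mask is empty or the inputs are inconsistent.
    """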
    try:
        # Bounding box and center of the masked (non-zero) region.
        numpy_mask = np.array(mask)
        y_coords, x_coords = np.nonzero(numpy_mask)
        x_min, x_max = x_coords.min(), x_coords.max()
        y_min, y_max = y_coords.min(), y_coords.max()
        target_center_x = int((x_min + x_max) / 2)
        target_center_y = int((y_min + y_max) / 2)

        image_a = Image.fromarray(image_a)
        image_b = Image.fromarray(image_b)
        mask = Image.fromarray(mask)

        if image_a.size != mask.size:
            mask = mask.resize(image_a.size)

        # Cut the masked region out of image_a onto a transparent canvas.
        cropped_image = Image.composite(image_a, Image.new('RGBA', image_a.size, (0, 0, 0, 0)), mask)
        # Compensate the offset so scaling by delta keeps the region's center anchored.
        x_b = int(target_center_x * (image_b.width / cropped_image.width))
        y_b = int(target_center_y * (image_b.height / cropped_image.height))
        x_offset = x_offset - int((delta - 1) * x_b)
        y_offset = y_offset - int((delta - 1) * y_b)
        cropped_image = cropped_image.resize(image_b.size)
        new_size = (int(cropped_image.width * delta), int(cropped_image.height * delta))
        cropped_image = cropped_image.resize(new_size)
        # Show image_b at half opacity and paste the scaled cut-out on top.
        image_b.putalpha(128)
        result_image = Image.new('RGBA', image_b.size, (0, 0, 0, 0))
        result_image.paste(image_b, (0, 0))
        result_image.paste(cropped_image, (x_offset, y_offset), mask=cropped_image)

        return result_image
    except Exception:
        # An empty mask or inconsistent inputs leaves nothing to paste.
        return None

def upload_image_move(img, original_image):
    if original_image is not None:
        return original_image
    else:
        return img

def fun_clear(*args):
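    """Reset Gradio state: list inputs become empty lists, all others None."""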
    result = []
    for arg in args:
        if isinstance(arg, list):
            result.append([])
        else:
            result.append(None)
    return tuple(result)

def clear_points(img):
    image, mask = img["image"], np.float32(img["mask"][:, :, 0]) / 255.
    if mask.sum() > 0:
        mask = np.uint8(mask > 0)
        masked_img = mask_image(image, 1 - mask, color=[0, 0, 0], alpha=0.3)
    else:
        masked_img = image.copy()

    return [], masked_img

def get_point(img, sel_pix, evt: gr.SelectData):
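    """Record a click and redraw the selection: even- and odd-indexed points
    are drawn in different colors, and each (source, target) pair of clicks
    is joined by an arrow."""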
    sel_pix.append(evt.index)
    points = []
    for idx, point in enumerate(sel_pix):
        if idx % 2 == 0:
            cv2.circle(img, tuple(point), 10, (0, 0, 255), -1)
        else:
            cv2.circle(img, tuple(point), 10, (255, 0, 0), -1)
        points.append(tuple(point))
        if len(points) == 2:
            cv2.arrowedLine(img, points[0], points[1], (255, 255, 255), 4, tipLength=0.5)
            points = []
    return img if isinstance(img, np.ndarray) else np.array(img)

def calculate_translation_percentage(ori_shape, selected_points):
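    """Return the drag vector between the two selected points as fractions of
    image width (dx) and height (dy)."""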
    dx = selected_points[1][0] - selected_points[0][0]
    dy = selected_points[1][1] - selected_points[0][1]
    dx_percentage = dx / ori_shape[1]
    dy_percentage = dy / ori_shape[0]
    
    return dx_percentage, dy_percentage

def get_point_move(original_image, img, sel_pix, evt: gr.SelectData):
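    """Click handler for the move tool: collect up to two points, draw the
    drag arrow, and return the translation as width/height fractions."""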
    if original_image is not None:
        img = original_image.copy()
    else:
        original_image = img.copy()
    if len(sel_pix) < 2:
        sel_pix.append(evt.index)
    else:
        sel_pix = [evt.index]
    points = []
    dx, dy = 0, 0
    for idx, point in enumerate(sel_pix):
        if idx % 2 == 0:
            cv2.circle(img, tuple(point), 10, (0, 0, 255), -1)
        else:
            cv2.circle(img, tuple(point), 10, (255, 0, 0), -1)
        points.append(tuple(point))
        if len(points) == 2:
            cv2.arrowedLine(img, points[0], points[1], (255, 255, 255), 4, tipLength=0.5)
            dx, dy = calculate_translation_percentage(original_image.shape, sel_pix)
            points = []
    img = np.array(img)

    return img, original_image, sel_pix, dx, dy

def store_img(img):
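    """Split a Gradio sketch payload into the raw image, a dimmed preview of
    the unmasked area, and the binary mask."""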
    image, mask = img["image"], np.float32(img["mask"][:, :, 0]) / 255.
    if mask.sum() > 0:
        mask = np.uint8(mask > 0)
        masked_img = mask_image(image, 1 - mask, color=[0, 0, 0], alpha=0.3)
    else:
        masked_img = image.copy()

    return image, masked_img, mask
# im["background"], im["layers"][0]
def store_img_move(img, mask=None):
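    """``store_img`` for the editor payload described above; a precomputed
    ``mask`` bypasses extraction from the drawn layer."""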
    if mask is not None:
        image = img["background"]
        return image, None, mask
    image, mask = img["background"], np.float32(["layers"][0][:, :, 0]) / 255.
    if mask.sum() > 0:
        mask = np.uint8(mask > 0)
        masked_img = mask_image(image, 1 - mask, color=[0, 0, 0], alpha=0.3)
    else:
        masked_img = image.copy()

    return image, masked_img, (mask*255.).astype(np.uint8)

def store_img_move_old(img, mask=None):
    if mask is not None:
        image = img["image"]
        return image, None, mask
    image, mask = img["image"], np.float32(img["mask"][:, :, 0]) / 255.
    if mask.sum() > 0:
        mask = np.uint8(mask > 0)
        masked_img = mask_image(image, 1 - mask, color=[0, 0, 0], alpha=0.3)
    else:
        masked_img = image.copy()

    return image, masked_img, (mask*255.).astype(np.uint8)

def mask_image(image, mask, color=[255, 0, 0], alpha=0.5, max_resolution=None):
    """Overlay a mask on an image for visualization purposes.

    Args:
        image (H, W, 3) or (H, W): input image
        mask (H, W): binary mask to overlay
        color: RGB color of the overlaid mask
        alpha: opacity of the mask overlay
        max_resolution: if set, first resize the image so its area is at most
            max_resolution**2, and resize the mask to match

    Returns:
        The image blended with the colored mask.
    """
    if max_resolution is not None:
        image, _ = resize_numpy_image(image, max_resolution * max_resolution)
        mask = cv2.resize(mask, (image.shape[1], image.shape[0]), interpolation=cv2.INTER_NEAREST)

    out = deepcopy(image)
    img = deepcopy(image)
    img[mask == 1] = color
    out = cv2.addWeighted(img, alpha, out, 1 - alpha, 0, out)
    return out
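

if __name__ == "__main__":
    # Illustrative smoke test for the mask_image helper on synthetic data.
    # Note: importing this module still builds EfficientSAM above, so the
    # checkpoint must be available; the test itself needs no GPU.
    demo_img = np.full((64, 64, 3), 200, dtype=np.uint8)
    demo_mask = np.zeros((64, 64), dtype=np.uint8)
    demo_mask[16:48, 16:48] = 1
    overlay = mask_image(demo_img, demo_mask, color=[255, 0, 0], alpha=0.5)
    print(overlay.shape, overlay.dtype)  # expected: (64, 64, 3) uint8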