import os
from ultralytics import YOLO
import cv2
import numpy as np
from PIL import Image
from iopaint.single_processing import batch_inpaint_cv2
import gradio as gr
from bgremover import process

# cache model weights under the project directory instead of the default cache location
os.environ["TORCH_HOME"] = "./pretrained-model"
os.environ["HUGGINGFACE_HUB_CACHE"] = "./pretrained-model"

def resize_image(input_image_path, width=640, height=640):
    """Reads an image from disk and letterboxes it to (width, height), preserving the aspect ratio."""
    try:
        # Read the image from disk
        img = cv2.imread(input_image_path, cv2.IMREAD_COLOR)
        if img is None:
            raise ValueError(f"Could not read image: {input_image_path}")

        # Resize while maintaining the aspect ratio
        shape = img.shape[:2]  # current shape (height, width)
        new_shape = (height, width)  # target shape (height, width)

        # Scale ratio (new / old)
        r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
        new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))  # unpadded (width, height)

        # Resize the image
        im = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)

        # Pad the image
        color = (114, 114, 114)  # color used for padding
        dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1]  # wh padding
        # divide padding into 2 sides
        dw /= 2
        dh /= 2
        # compute padding on all corners
        top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
        left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
        im = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)  # add border
        return im

    except Exception as e:
        raise gr.Error(f"Error in resizing image: {e}")
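
# A quick sanity check for resize_image (illustrative only; "samples/room.jpg" is a hypothetical path):
# letterboxing always yields a canvas of exactly width x height, whatever the source aspect ratio.
#   resized = resize_image("samples/room.jpg")
#   assert resized.shape[:2] == (640, 640)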


def process_images(input_image, append_image, default_class="chair"):
    """Detects the target class in the main image, inpaints it away, and pastes the object image in its place."""
    if not input_image:
        raise gr.Error("Please upload a main image.")

    if not append_image:
        raise gr.Error("Please upload an object image.")

    # Letterbox-resize the input image to 640x640
    img = resize_image(input_image)

    if img is None:
        raise gr.Error("Failed to decode resized image!")

    H, W, _ = img.shape
    # Default paste position and size; overwritten with the detected object's bounding box below
    x_point = 0
    y_point = 0
    width = 1
    height = 1

    # Load a pretrained YOLOv8m segmentation model
    model = YOLO('pretrained-model/yolov8m-seg.pt')

    # Run inference on the letterboxed image with a 0.5 confidence threshold;
    # class filtering is done by name in the loop below
    results = model(img, imgsz=(H, W), conf=0.5)
    names = model.names  # maps class index -> class name

    class_found = False
    for result in results:
        for i, label in enumerate(result.boxes.cls):
            # Check if the predicted class name matches the requested class
            if names[int(label)] == default_class:
                class_found = True
                # Move the mask tensor to the CPU and convert it to a NumPy array
                chair_mask_np = result.masks.data[i].cpu().numpy()

                kernel = np.ones((5, 5), np.uint8)  # Create a 5x5 kernel for dilation
                chair_mask_np = cv2.dilate(chair_mask_np, kernel, iterations=2)  # Apply dilation

                # Find contours of the dilated mask to derive a bounding box
                contours, _ = cv2.findContours((chair_mask_np > 0).astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

                # Use the largest contour as the object's bounding box
                if contours:
                    largest = max(contours, key=cv2.contourArea)
                    x_point, y_point, width, height = cv2.boundingRect(largest)

                # Build the inpainting mask (0/255) and dilate it to cover the object's edges
                mask = result.masks.data[i].cpu().numpy() * 255
                dilated_mask = cv2.dilate(mask, kernel, iterations=2)
                # Resize the mask to match the dimensions of the original image
                resized_mask = cv2.resize(dilated_mask, (img.shape[1], img.shape[0]))

                # Inpaint the masked region and composite the new object in its place
                output_numpy = repaintingAndMerge(append_image, width, height, x_point, y_point, img, resized_mask)
                # Return the composited image as a NumPy array
                return output_numpy

    # The requested class was not found in any prediction
    if not class_found:
        raise gr.Error(f'{default_class} object not found in the image')

def repaintingAndMerge(append_image_path, width, height, xposition, yposition, input_base, mask_base):
    """Inpaints the masked region with LaMa, then pastes the background-removed object image at the given position."""
    # lama inpainting start
    print("lama inpainting start")
    inpaint_result_np = batch_inpaint_cv2('lama', 'cpu', input_base, mask_base)
    print("lama inpainting end")

    # Create PIL Image from NumPy array
    final_image = Image.fromarray(inpaint_result_np)

    print("merge start")
    # Load the object image, keeping any alpha channel
    append_image = cv2.imread(append_image_path, cv2.IMREAD_UNCHANGED)
    if append_image is None:
        raise gr.Error("Failed to read the object image!")
    # Resize the object image to the detected bounding-box size
    resized_image = cv2.resize(append_image, (width, height), interpolation=cv2.INTER_AREA)
    # Convert to RGBA, handling both 3-channel (BGR) and 4-channel (BGRA) inputs
    if resized_image.ndim == 3 and resized_image.shape[2] == 4:
        resized_image = cv2.cvtColor(resized_image, cv2.COLOR_BGRA2RGBA)
    else:
        resized_image = cv2.cvtColor(resized_image, cv2.COLOR_BGR2RGBA)

    # Remove the background from the object image (expected to return a PIL image with an alpha channel)
    append_image_pil = process(resized_image)

    # Paste the object image onto the inpainted image at the detected position, using its alpha as the mask
    final_image.paste(append_image_pil, (xposition, yposition), append_image_pil)
    print("merge end")

    # Convert the final composited image back to a NumPy array for the Gradio output
    output_numpy = np.array(final_image)

    return output_numpy
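

# --- Hypothetical Gradio wiring (illustrative sketch, not part of the original pipeline) ---
# A minimal example of how process_images could be exposed as a Gradio app, assuming
# filepath-type image inputs; component labels and the default class value are placeholders.
if __name__ == "__main__":
    demo = gr.Interface(
        fn=process_images,
        inputs=[
            gr.Image(type="filepath", label="Main image"),
            gr.Image(type="filepath", label="Object image"),
            gr.Textbox(value="chair", label="Class to replace"),
        ],
        outputs=gr.Image(type="numpy", label="Result"),
        title="Object replacement (YOLOv8-seg + LaMa inpainting)",
    )
    demo.launch()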