from ultralytics import YOLO
from PIL import ImageDraw
import numpy as np
from dotenv import load_dotenv
from multiprocessing import Pool

from pdf2image import convert_from_bytes

from ocr_functions import textract_ocr
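
# Load environment variables (e.g. AWS credentials for Textract) from a local
# .env file, if present. Assumption: textract_ocr reads its credentials from
# the environment.
load_dotenv()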



# OpenVINO-exported YOLO model trained to detect resume sections
model = YOLO("yolo_model/model_3_openvino_model")
# Section classes the detector was trained on
labels = ['Achievement', 'Certifications', 'Community', 'Contact', 'Education', 'Experience', 'Interests', 'Languages', 'Name', 'Profil', 'Projects', 'image', 'resume', 'skills']


def check_intersection(bbox1, bbox2):
    """Return True if the two [x1, y1, x2, y2] boxes overlap (or touch)."""
    x1, y1, x2, y2 = bbox1
    x3, y3, x4, y4 = bbox2
    return not (x3 > x2 or x4 < x1 or y3 > y2 or y4 < y1)

def check_inclusion(bbox1, bbox2):
    """Return True if bbox1 is completely inside bbox2."""
    x1, y1, x2, y2 = bbox1
    x3, y3, x4, y4 = bbox2
    return x1 >= x3 and y1 >= y3 and x2 <= x4 and y2 <= y4

def union_bbox(bbox1, bbox2):
    """Return the smallest box enclosing both input boxes."""
    x1 = min(bbox1[0], bbox2[0])
    y1 = min(bbox1[1], bbox2[1])
    x2 = max(bbox1[2], bbox2[2])
    y2 = max(bbox1[3], bbox2[3])
    return [x1, y1, x2, y2]

def filter_bboxes(bboxes):
    """Merge intersecting boxes into their union and drop boxes that are
    completely contained in a box already kept.

    Returns the filtered boxes plus, for each one, the index of the input
    box it originated from, so callers can keep per-box metadata (such as
    class ids) aligned with the filtered list.
    """
    filtered_bboxes = []
    origins = []
    for idx, bbox1 in enumerate(bboxes):
        is_valid = True
        for i, bbox2 in enumerate(filtered_bboxes):
            if check_inclusion(bbox1, bbox2):
                # bbox1 is completely inside an already-kept box: drop it.
                # This test must run before the intersection test, since
                # inclusion implies intersection.
                is_valid = False
                break
            if check_intersection(bbox1, bbox2):
                # The boxes overlap: replace the kept box with their union
                # rather than keeping both.
                filtered_bboxes[i] = union_bbox(bbox1, bbox2)
                is_valid = False
                break
        if is_valid:
            filtered_bboxes.append(bbox1)
            origins.append(idx)
    return filtered_bboxes, origins
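
# Example (sketch) of the merge behaviour: two overlapping detections collapse
# into their union, and the returned index points at the first box of the pair:
#   filter_bboxes([[0, 0, 10, 10], [5, 5, 15, 15]])
#   -> ([[0, 0, 15, 15]], [0])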




def draw_bboxes(image, bboxes):
    """Draw a red rectangle for each box, modifying the image in place."""
    draw = ImageDraw.Draw(image)
    for bbox in bboxes:
        x1, y1, x2, y2 = (int(v) for v in bbox)
        draw.rectangle([(x1, y1), (x2, y2)], outline=(255, 0, 0), width=2)



def extract_image(image, box):
    """Crop the region covered by `box` out of the page image."""
    x1, y1, x2, y2 = box
    return image.crop((x1, y1, x2, y2))




def process_bbox(args):
    """Unpack an (image, bbox) pair so the OCR call can be mapped over a Pool."""
    image, bbox = args
    return textract_ocr(image, bbox)
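
# Sketch of the parallel OCR this wrapper enables (assumptions: textract_ocr
# and PIL images pickle cleanly in this environment, and `bboxes` holds boxes
# already mapped back to the original image size):
#
#   with Pool() as pool:
#       texts = pool.map(process_bbox, [(image, bbox) for bbox in bboxes])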

def convert_bboxes_to_original(bboxes, original_size, resized_size):
    """
    Convert bounding boxes from the resized image size back to the original
    image size using NumPy.

    :param bboxes: bounding boxes in [x1, y1, x2, y2] format for the resized image
    :param original_size: tuple (original_width, original_height)
    :param resized_size: tuple (resized_width, resized_height)
    :return: NumPy array of bounding boxes in [x1, y1, x2, y2] format for the original image
    """
    original_width, original_height = original_size
    resized_width, resized_height = resized_size

    # Calculate per-axis scaling factors
    x_scale = original_width / resized_width
    y_scale = original_height / resized_height

    # Scale every box at once via broadcasting; the float dtype keeps the
    # in-place multiplication safe even if integer boxes are passed in, and
    # reshape(-1, 4) handles an empty input cleanly.
    bboxes_np = np.array(bboxes, dtype=float).reshape(-1, 4)
    bboxes_np *= [x_scale, y_scale, x_scale, y_scale]

    return bboxes_np
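
# Example (sketch): a box detected on the 640x640 input, mapped back to a
# hypothetical 1280x1656 page, is stretched by each axis ratio (2.0, 2.5875):
#   convert_bboxes_to_original([[100, 50, 320, 200]], (1280, 1656), (640, 640))
#   -> array([[200.   , 129.375, 640.   , 517.5  ]])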


def correct_to_by_classifier(text):
    """Placeholder for a classifier-based text-correction step (not implemented)."""
    pass

def extract_text_from_sections(image):
    cv_parse = {}
    original_size = image.size
    original_img = image
    # The detector expects a fixed 640x640 RGB input
    image = image.resize((640, 640)).convert("RGB")
    image_np = np.array(image)

    # Perform model prediction
    results = model(source=image_np, conf=0.20)
    names = results[0].names  # class-id -> class-name mapping
    data = results[0].boxes.data.numpy()

    # Extract bounding boxes and their corresponding class labels
    bboxes = data[:, 0:4].tolist()
    class_ids = data[:, 5].astype(int).tolist()

    # Merge/drop overlapping boxes, keep the class ids aligned with the
    # surviving boxes, then map the boxes back to full page resolution
    bboxes_filter, kept = filter_bboxes(bboxes)
    class_ids = [class_ids[i] for i in kept]
    original_bboxes = convert_bboxes_to_original(bboxes_filter, original_size, (640, 640))

    # OCR every detected section except embedded photos, concatenating the
    # text of repeated sections with newlines
    for bbox, class_id in zip(original_bboxes, class_ids):
        class_name = names[class_id]
        if class_name != "image":
            text = textract_ocr(original_img, bbox)
            if class_name in cv_parse:
                cv_parse[class_name] += "\n" + text
            else:
                cv_parse[class_name] = text

    return cv_parse
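
# Example shape of the parsed output for one page (keys come from the
# detector's classes; values are the OCR'd text of each section):
#   {"Name": "...", "Experience": "...", "skills": "..."}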

def merge_dicts_append_strings(dict1, dict2):
    # Start from a copy of dict1 so neither input is mutated
    merged_dict = dict(dict1)

    # Append values from dict2, joining duplicate keys with a newline
    for key, value in dict2.items():
        if key in merged_dict:
            merged_dict[key] += "\n" + value
        else:
            merged_dict[key] = value

    return merged_dict
    

def cv_to_json(file):
    cv_parsing = {}
    # Render each PDF page to an image, parse it, and fold the per-page
    # results into a single dictionary
    images = convert_from_bytes(file.read())
    for image in images:
        cv_parsing = merge_dicts_append_strings(cv_parsing, extract_text_from_sections(image))
    return cv_parsing
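

if __name__ == "__main__":
    # Minimal usage sketch. Assumption: "resume.pdf" is a hypothetical local
    # file; any binary file-like object with a .read() method works.
    with open("resume.pdf", "rb") as f:
        parsed = cv_to_json(f)
    for section, text in parsed.items():
        print(f"=== {section} ===\n{text}\n")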