import matplotlib.pyplot as plt
import numpy as np
from PIL import Image, ImageEnhance, ImageDraw
import torch
import streamlit as st
from model.inference_cpu import inference_case

# Default fabric.js canvas state used to seed the box-prompt drawing widget:
# one draggable 100x100 rectangle at (50, 50), translucent orange fill,
# solid blue 3px border, no skew/rotation/shadow.
initial_rectangle = {
    "version": "4.4.0",
    "objects": [
        {
            "type": "rect",
            "version": "4.4.0",
            "originX": "left",
            "originY": "top",
            "left": 50,
            "top": 50,
            "width": 100,
            "height": 100,
            "fill": "rgba(255, 165, 0, 0.3)",
            "stroke": "#2909F1",
            "strokeWidth": 3,
            "strokeDashArray": None,
            "strokeLineCap": "butt",
            "strokeDashOffset": 0,
            "strokeLineJoin": "miter",
            "strokeUniform": True,
            "strokeMiterLimit": 4,
            "scaleX": 1,
            "scaleY": 1,
            "angle": 0,
            "flipX": False,
            "flipY": False,
            "opacity": 1,
            "shadow": None,
            "visible": True,
            "backgroundColor": "",
            "fillRule": "nonzero",
            "paintFirst": "fill",
            "globalCompositeOperation": "source-over",
            "skewX": 0,
            "skewY": 0,
            "rx": 0,
            "ry": 0,
        }
    ],
}

def run():
    """Gather the active prompts from Streamlit session state and run inference.

    Reads the current case's image tensors and whichever prompt types the user
    enabled (text / points / box), clears the cached inference result, and
    stores the new predictions back into session state
    (``preds_3D`` and ``preds_3D_ori``).
    """
    state = st.session_state
    image = state.data_item["image"].float()
    image_zoom_out = state.data_item["zoom_out_image"].float()

    text_prompt = state.text_prompt if state.use_text_prompt else None

    point_prompt = None
    if state.use_point_prompt and len(state.points) > 0:
        point_prompt = reflect_points_into_model(state.points)

    box_prompt = (
        reflect_box_into_model(state.rectangle_3Dbox)
        if state.use_box_prompt
        else None
    )

    # Drop any memoized result so the model actually re-runs for the new prompts.
    inference_case.clear()
    state.preds_3D, state.preds_3D_ori = inference_case(
        image,
        image_zoom_out,
        text_prompt=text_prompt,
        _point_prompt=point_prompt,
        _box_prompt=box_prompt,
    )

def reflect_box_into_model(box_3d, vis_size=325.0, model_xy=256.0, model_z=32.0):
    """Map a 3-D bounding box from viewer coordinates into model-input coordinates.

    The viewer volume is assumed cubic with side ``vis_size``; the model input
    grid is ``model_z`` slices of ``model_xy`` x ``model_xy`` pixels
    (defaults match the previous hard-coded 325 / 256 / 32 values, so existing
    callers are unaffected).

    Args:
        box_3d: iterable of six numbers ``(z1, y1, x1, z2, y2, x2)`` in viewer space.
        vis_size: side length of the viewer volume.
        model_xy: model resolution along the x and y axes.
        model_z: model resolution along the z axis.

    Returns:
        torch.Tensor of six ints ``(z1, y1, x1, z2, y2, x2)`` in model space,
        truncated toward zero exactly as the original ``int(v * dim / size)`` did.
    """
    z1, y1, x1, z2, y2, x2 = box_3d
    # Keep the exact expression shape `int(v * dim / size)` so float rounding
    # matches the previous behavior at the default scales.
    return torch.tensor(np.array([
        int(z1 * model_z / vis_size),
        int(y1 * model_xy / vis_size),
        int(x1 * model_xy / vis_size),
        int(z2 * model_z / vis_size),
        int(y2 * model_xy / vis_size),
        int(x2 * model_xy / vis_size),
    ]))

def reflect_json_data_to_3D_box(json_data, view):
    """Copy the rectangle drawn on the 2-D canvas into the session 3-D box.

    Only the ``'xy'`` view is handled here: indices 1/2 of
    ``st.session_state.rectangle_3Dbox`` receive the top/left corner and
    indices 4/5 the bottom/right corner (scaled by the canvas scale factors).
    The updated box is printed for debugging.
    """
    if view == 'xy':
        rect = json_data['objects'][0]
        box = st.session_state.rectangle_3Dbox
        box[1] = rect['top']
        box[2] = rect['left']
        box[4] = rect['top'] + rect['height'] * rect['scaleY']
        box[5] = rect['left'] + rect['width'] * rect['scaleX']
    print(st.session_state.rectangle_3Dbox)

def reflect_points_into_model(points, vis_size=325.0, model_xy=256.0, model_z=32.0):
    """Map viewer-space point prompts into model-input coordinates.

    Defaults match the previous hard-coded 325 / 256 / 32 scales, so existing
    callers are unaffected. The leftover debug ``print`` was removed.

    Args:
        points: iterable of ``(z, y, x)`` triples in viewer space.
        vis_size: side length of the viewer volume.
        model_xy: model resolution along the x and y axes.
        model_z: model resolution along the z axis.

    Returns:
        Tuple ``(points_prompt, points_label)``: a tensor of ``(z, y, x)``
        int coordinates in model space and a tensor of ones (all points are
        treated as positive prompts).
    """
    # Keep the exact `int(v * dim / size)` expression shape so truncation
    # matches the previous behavior at the default scales.
    points_prompt = np.array([
        [int(z * model_z / vis_size),
         int(y * model_xy / vis_size),
         int(x * model_xy / vis_size)]
        for z, y, x in points
    ])
    points_label = np.ones(points_prompt.shape[0])
    return (torch.tensor(points_prompt), torch.tensor(points_label))

def show_points(points_ax, points_label, ax):
    """Scatter one prompt point on *ax*: red when its label is 0, blue otherwise."""
    marker_color = 'blue'
    if points_label == 0:
        marker_color = 'red'
    ax.scatter(points_ax[0], points_ax[1], c=marker_color, marker='o', s=200)

def make_fig(image, preds, point_axs=None, current_idx=None, view=None):
    """Render one 2-D slice as a contrast-boosted RGB PIL image.

    Optionally blends a yellow mask where ``preds == 1`` (opacity from
    ``st.session_state.transparency``) and draws blue dots for any prompt
    points that fall on the currently displayed slice of the given view.

    Args:
        image: 2-D array of values in [0, 1] — assumed; TODO confirm at callers.
        preds: 2-D binary prediction array matching the slice, or None.
        point_axs: iterable of (z, y, x) prompt points, or None.
        current_idx: index of the displayed slice along the view axis.
        view: 'xy' or 'xz' — selects which point coordinate is the slice index.

    Returns:
        PIL.Image.Image ready for display.
    """
    pil_img = Image.fromarray((image * 255).astype(np.uint8)).convert("RGB")
    pil_img = ImageEnhance.Contrast(pil_img).enhance(2.0)

    if preds is not None:
        # Yellow overlay: predictions drive the R and G channels, B stays zero.
        mask_arr = np.where(preds == 1, 255, 0).astype(np.uint8)
        lit = Image.fromarray(mask_arr)
        dark = Image.fromarray(np.zeros_like(mask_arr, dtype=np.uint8))
        overlay = Image.merge("RGB", (lit, lit, dark))
        pil_img = Image.blend(pil_img.convert("RGB"), overlay,
                              alpha=st.session_state.transparency)

    if point_axs is not None:
        draw = ImageDraw.Draw(pil_img)
        radius = 5
        for z, y, x in point_axs:
            # Only draw points lying on the slice currently shown.
            if view == 'xy' and z == current_idx:
                draw.ellipse((x - radius, y - radius, x + radius, y + radius), fill="blue")
            elif view == 'xz' and y == current_idx:
                draw.ellipse((x - radius, z - radius, x + radius, z + radius), fill="blue")
    return pil_img