import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.patches import Patch
import io
import cv2
from PIL import Image, ImageDraw, ImageFont
import numpy as np
import csv
import pandas as pd

from ultralytics import YOLO
import torch

from paddleocr import PaddleOCR
import postprocess

import gradio as gr


device = "cuda" if torch.cuda.is_available() else "cpu"
detection_model = YOLO('yolov8/runs/detect/yolov8s-custom-detection/weights/best.pt').to(device)
structure_model = YOLO('yolov8/runs/detect/yolov8s-custom-structure-all/weights/best.pt').to(device)
ocr_model = PaddleOCR(use_angle_cls=True, lang="ch", det_limit_side_len=1920)  # a larger det_limit_side_len gives better OCR results on high-resolution tables

detection_class_names = ['table', 'table rotated']
structure_class_names = [
    'table', 'table column', 'table row', 'table column header',
    'table projected row header', 'table spanning cell', 'no object'
]
structure_class_map = {k: v for v, k in enumerate(structure_class_names)}
structure_class_thresholds = {
    "table": 0.5,
    "table column": 0.5,
    "table row": 0.5,
    "table column header": 0.5,
    "table projected row header": 0.5,
    "table spanning cell": 0.5,
    "no object": 10
}


def table_detection(image):
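    """Detect tables with the YOLOv8 detection model; returns a list of
    [x_center, y_center, width, height, confidence, class_id] with coordinates
    normalized to the image size (xywhn)."""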
    imgsz = 800
    pred = detection_model.predict(image, imgsz=imgsz)
    pred = pred[0].boxes
    result = pred.cpu().numpy()
    result_list = [list(result.xywhn[i]) + [result.conf[i], result.cls[i]] for i in range(result.shape[0])]
    return result_list


def table_structure(image):
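    """Detect table structure elements (rows, columns, headers, spanning cells) with the
    YOLOv8 structure model; output format matches table_detection."""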
    imgsz = 1024
    pred = structure_model.predict(image, imgsz=imgsz)
    pred = pred[0].boxes
    result = pred.cpu().numpy()
    result_list = [list(result.xywhn[i]) + [result.conf[i], result.cls[i]] for i in range(result.shape[0])]
    return result_list


def crop_image(image, detection_result):
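    """Crop the image to the first detected table, padding the box by 10px on each side
    and clamping it to the image bounds."""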
    # crop_filenames = []
    width = image.shape[1]
    height = image.shape[0]
    # print(width, height)
    cropped_image = image
    for i, result in enumerate(detection_result[:1]):  # TODO: only the first detected table is processed
        class_id = int(result[5])
        score = float(result[4])
        min_x = result[0]
        min_y = result[1]
        w = result[2]
        h = result[3]
        
        # x1 = max(0, int((min_x-w/2-0.02)*width))  # TODO expand 2%
        # y1 = max(0, int((min_y-h/2-0.02)*height))  # TODO expand 2%
        # x2 = min(width, int((min_x+w/2+0.02)*width))  # TODO expand 2%
        # y2 = min(height, int((min_y+h/2+0.02)*height))  # TODO expand 2%
        x1 = max(0, int((min_x-w/2)*width)-10)   # expand the crop by 10px on each side
        y1 = max(0, int((min_y-h/2)*height)-10)
        x2 = min(width, int((min_x+w/2)*width)+10)
        y2 = min(height, int((min_y+h/2)*height)+10)
        # print(x1, y1, x2, y2)
        cropped_image = image[y1:y2, x1:x2, :]
        # crop_filename = filename[:-4]+'_'+str(i)+'_'+detection_class_names[class_id]+filename[-4:]
        # crop_filenames.append(crop_filename)
        # cv2.imwrite(crop_filename, crop_image)
    return cropped_image


def convert_structure(ocr_result, image, structure_result):
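    """Convert structure detections to pixel-space objects, pick the highest-scoring
    'table' box, keep the OCR tokens that fall inside it, and build table structures
    and cells via postprocess.objects_to_cells."""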
    width = image.shape[1]
    height = image.shape[0]
    # print(width, height)
    
    bboxes = []
    scores = []
    labels = []
    for i, result in enumerate(structure_result):
        class_id = int(result[5])
        score = float(result[4])
        min_x = result[0]
        min_y = result[1]
        w = result[2]
        h = result[3]
        
        x1 = int((min_x-w/2)*width)
        y1 = int((min_y-h/2)*height)
        x2 = int((min_x+w/2)*width)
        y2 = int((min_y+h/2)*height)
        # print(x1, y1, x2, y2)

        bboxes.append([x1, y1, x2, y2])
        scores.append(score)
        labels.append(class_id)
    
    table_objects = []
    for bbox, score, label in zip(bboxes, scores, labels):
        table_objects.append({'bbox': bbox, 'score': score, 'label': label})
    # print('table_objects:', table_objects)
        
    table = {'objects': table_objects, 'page_num': 0}
    
    table_class_objects = [obj for obj in table_objects if obj['label'] == structure_class_map['table']]
    if len(table_class_objects) > 1:
        table_class_objects = sorted(table_class_objects, key=lambda x: x['score'], reverse=True)
    if table_class_objects:
        table_bbox = list(table_class_objects[0]['bbox'])
    else:
        # no table object detected: fall back to the whole image
        table_bbox = [0, 0, width, height]
    # print('table_class_objects:', table_class_objects)
    # print('table_bbox:', table_bbox)
    
    page_tokens = ocr_result
    tokens_in_table = [token for token in page_tokens if postprocess.iob(token['bbox'], table_bbox) >= 0.5]
    # print('tokens_in_table:', tokens_in_table)
    
    table_structures, cells, confidence_score = postprocess.objects_to_cells(table, table_objects, tokens_in_table, structure_class_names, structure_class_thresholds)
    
    return table_structures, cells, confidence_score


def visualize_cells(image, table_structures, cells):
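    """Draw the recognized cell boxes on the image and collect each cell's OCR text into a
    row/column grid; returns the annotated image, a pandas DataFrame, and its JSON form."""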
    width = image.shape[1]
    height = image.shape[0]
    # print(width, height)
    empty_image = np.zeros((height, width, 3), np.uint8)
    empty_image.fill(255)
    empty_image = Image.fromarray(cv2.cvtColor(empty_image, cv2.COLOR_BGR2RGB))
    draw = ImageDraw.Draw(empty_image)
    try:
        fontStyle = ImageFont.truetype("SimSong.ttc", 10, encoding="utf-8")
    except OSError:
        # fall back to PIL's default bitmap font if SimSong.ttc is not installed
        fontStyle = ImageFont.load_default()
    
    num_cols = len(table_structures['columns'])
    num_rows = len(table_structures['rows'])
    data_rows = [['' for _ in range(num_cols)] for _ in range(num_rows)]
    for i, cell in enumerate(cells):
        bbox = cell['bbox']
        x1 = int(bbox[0])
        y1 = int(bbox[1])
        x2 = int(bbox[2])
        y2 = int(bbox[3])
        col_num = cell['column_nums'][0]
        row_num = cell['row_nums'][0]
        spans = cell['spans']
        text = ''
        for span in spans:
            if 'text' in span:
                text += span['text']     
        data_rows[row_num][col_num] = text
        
        # print('text:', text)
        text_len = len(text)
        # print('text_len:', text_len)
        cell_width = x2-x1
        # print('cell_width:', cell_width)
        num_per_line = max(1, cell_width//10)  # at least one character per line, so narrow cells don't lose their text
        # print('num_per_line:', num_per_line)
        line_num = text_len//num_per_line
        # print('line_num:', line_num)
        new_text = text[:num_per_line]+'\n'
        for j in range(line_num):
            new_text += text[(j+1)*num_per_line:(j+2)*num_per_line]+'\n'
        # print('new_text:', new_text)
        text = new_text
        
        cv2.rectangle(image, (x1, y1), (x2, y2), color=(0,255,0))
        # cv2.putText(image, str(row_num)+'-'+str(col_num), (x1, y1+30), cv2.FONT_HERSHEY_SIMPLEX, fontScale=1, color=(0,0,255))
        
        # cv2.rectangle(empty_image, (x1, y1), (x2, y2), color=(0,0,255))
        # cv2.putText(empty_image, str(row_num)+'-'+str(col_num), (x1-10, y1), cv2.FONT_HERSHEY_SIMPLEX, fontScale=1, color=(0,0,255))
        # cv2.putText(empty_image, text, (x1, y1), cv2.FONT_HERSHEY_SIMPLEX, fontScale=1, color=(0,0,255))
        draw.rectangle([(x1, y1), (x2, y2)], (255,255,255), (0,255,0))
        # draw.text((x1-20, y1), str(row_num)+'-'+str(col_num), (255,0,0), font=fontStyle)
        # draw.text((x1, y1), text, (0,0,255), font=fontStyle)

    df = pd.DataFrame(data_rows)
    df.columns = df.columns.astype(str)
    return image, df, df.to_json()


def ocr(image):
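    """Run PaddleOCR on the image and return a list of tokens, each with an axis-aligned
    'bbox' ([x1, y1, x2, y2] from the quad's top-left and bottom-right corners) and 'text'."""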
    result = ocr_model.ocr(image, cls=True)
    result = result[0]
    new_result = []
    if result is not None:
        bounding_boxes = [line[0] for line in result]
        txts = [line[1][0] for line in result]
        scores = [line[1][1] for line in result]
        # print('txts:', txts)
        # print('scores:', scores)
        # print('bounding_boxes:', bounding_boxes)
        for label, bbox in zip(txts, bounding_boxes):
            new_result.append({'bbox': [bbox[0][0], bbox[0][1], bbox[2][0], bbox[2][1]], 'text': label})
    
    return new_result


def detect_and_crop_table(image):
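    """Detect tables in the image and return a crop around the first detection."""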
    detection_result = table_detection(image)
    # print('detection_result:', detection_result)
    cropped_table = crop_image(image, detection_result)

    return cropped_table


def recognize_table(image, ocr_result):
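    """Recognize the structure of a cropped table and merge it with the OCR tokens,
    returning the annotated image, a DataFrame, and JSON data."""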
    structure_result = table_structure(image)
    print('structure_result:', structure_result)
    table_structures, cells, confidence_score = convert_structure(ocr_result, image, structure_result)
    print('table_structures:', table_structures)
    print('cells:', cells)
    print('confidence_score:', confidence_score)
    image, df, data = visualize_cells(image, table_structures, cells)
        
    return image, df, data


def process_pdf(image):
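    """Gradio entry point. Despite the name, the input is an RGB image (numpy array):
    convert to BGR, detect and crop the table, run OCR, then recognize the table
    structure and return the annotated image, DataFrame, and JSON."""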
    image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    
    cropped_table = detect_and_crop_table(image)
    
    ocr_result = ocr(cropped_table)
    # print('ocr_result:', ocr_result)

    image, df, data = recognize_table(cropped_table, ocr_result)
    print('df:', df)
    
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    return image, df, data
    

title = "Demo: table detection & recognition with Table Structure Recognition (YOLOv8)"
description = """Demo for table extraction with Table Structure Recognition (YOLOv8) and PaddleOCR."""
examples = [['image.png'], ['mistral_paper.png']]

app = gr.Interface(fn=process_pdf,
                   inputs=gr.Image(type="numpy"),
                   outputs=[gr.Image(type="numpy", label="Detected table"), gr.Dataframe(label="Table as CSV"), gr.JSON(label="Data as JSON")],
                   title=title,
                   description=description,
                   examples=examples)
app.queue()
# app.launch(debug=True, share=True)
app.launch()