import os
import numpy as np
import torch
import torch.nn as nn
import gradio as gr
import time
from torchvision.models import efficientnet_v2_m, EfficientNet_V2_M_Weights
from torchvision.ops import nms, box_iou
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image, ImageDraw, ImageFont, ImageFilter
from breed_health_info import breed_health_info
from breed_noise_info import breed_noise_info
from dog_database import get_dog_description
from scoring_calculation_system import UserPreferences
from recommendation_html_format import format_recommendation_html, get_breed_recommendations
from history_manager import UserHistoryManager
from search_history import create_history_tab, create_history_component
from styles import get_css_styles
from breed_detection import create_detection_tab
from breed_comparison import create_comparison_tab
from breed_recommendation import create_recommendation_tab
from html_templates import (
    format_description_html,
    format_single_dog_result,
    format_multiple_breeds_result,
    format_error_message,
    format_warning_html,
    format_multi_dog_container,
    format_breed_details_html,
    get_color_scheme,
    get_akc_breeds_link
)
from urllib.parse import quote
from ultralytics import YOLO
import asyncio
import traceback


# YOLOv8-L object detector, loaded once at import time. detect_multiple_dogs
# uses it to locate dogs (COCO class 16) before each crop is classified.
model_yolo = YOLO('yolov8l.pt')

# NOTE(review): not referenced by the code visible in this file; presumably
# consumed by another module via import - confirm before removing.
history_manager = UserHistoryManager()

# Breed class names for the classifier. len(dog_breeds) sets the number of
# classifier outputs, and predict_single_dog maps output index i to
# dog_breeds[i]. The order must therefore match the order used when the
# checkpoint was trained (presumably) - do not sort or reorder this list.
dog_breeds = ["Afghan_Hound", "African_Hunting_Dog", "Airedale", "American_Staffordshire_Terrier",
              "Appenzeller", "Australian_Terrier", "Bedlington_Terrier", "Bernese_Mountain_Dog", "Bichon_Frise",
              "Blenheim_Spaniel", "Border_Collie", "Border_Terrier", "Boston_Bull", "Bouvier_Des_Flandres",
              "Brabancon_Griffon", "Brittany_Spaniel", "Cardigan", "Chesapeake_Bay_Retriever",
              "Chihuahua", "Dachshund", "Dandie_Dinmont", "Doberman", "English_Foxhound", "English_Setter",
              "English_Springer", "EntleBucher", "Eskimo_Dog", "French_Bulldog", "German_Shepherd",
              "German_Short-Haired_Pointer", "Gordon_Setter", "Great_Dane", "Great_Pyrenees",
              "Greater_Swiss_Mountain_Dog","Havanese", "Ibizan_Hound", "Irish_Setter", "Irish_Terrier",
              "Irish_Water_Spaniel", "Irish_Wolfhound", "Italian_Greyhound", "Japanese_Spaniel",
              "Kerry_Blue_Terrier", "Labrador_Retriever", "Lakeland_Terrier", "Leonberg", "Lhasa",
              "Maltese_Dog", "Mexican_Hairless", "Newfoundland", "Norfolk_Terrier", "Norwegian_Elkhound",
              "Norwich_Terrier", "Old_English_Sheepdog", "Pekinese", "Pembroke", "Pomeranian",
              "Rhodesian_Ridgeback", "Rottweiler", "Saint_Bernard", "Saluki", "Samoyed",
              "Scotch_Terrier", "Scottish_Deerhound", "Sealyham_Terrier", "Shetland_Sheepdog", "Shiba_Inu",
              "Shih-Tzu", "Siberian_Husky", "Staffordshire_Bullterrier", "Sussex_Spaniel",
              "Tibetan_Mastiff", "Tibetan_Terrier", "Walker_Hound", "Weimaraner",
              "Welsh_Springer_Spaniel", "West_Highland_White_Terrier", "Yorkshire_Terrier",
              "Affenpinscher", "Basenji", "Basset", "Beagle", "Black-and-Tan_Coonhound", "Bloodhound",
              "Bluetick", "Borzoi", "Boxer", "Briard", "Bull_Mastiff", "Cairn", "Chow", "Clumber",
              "Cocker_Spaniel", "Collie", "Curly-Coated_Retriever", "Dhole", "Dingo",
              "Flat-Coated_Retriever", "Giant_Schnauzer", "Golden_Retriever", "Groenendael", "Keeshond",
              "Kelpie", "Komondor", "Kuvasz", "Malamute", "Malinois", "Miniature_Pinscher",
              "Miniature_Poodle", "Miniature_Schnauzer", "Otterhound", "Papillon", "Pug", "Redbone",
              "Schipperke", "Silky_Terrier", "Soft-Coated_Wheaten_Terrier", "Standard_Poodle",
              "Standard_Schnauzer", "Toy_Poodle", "Toy_Terrier", "Vizsla", "Whippet",
              "Wire-Haired_Fox_Terrier"]


class MultiHeadAttention(nn.Module):
    """Attention across heads for a batch of feature vectors.

    Each input vector (``x`` of shape ``(N, in_dim)``) is:

    1. projected to ``num_heads * head_dim`` features;
    2. split into ``num_heads`` chunks, which attend to one another (the head
       axis plays the role of the sequence axis);
    3. projected back to ``in_dim``.

    NOTE(review): the second einsum in ``forward`` is written as
    ``"nqk,nvd->nqd"``. Because ``k`` and ``v`` are different subscripts,
    einsum sums over each of them independently:
    ``out[n, q, d] = (sum_k attention[n, q, k]) * (sum_v v[n, v, d])``.
    Softmax rows sum to 1 (up to rounding), so every head receives the sum of
    all value heads, and the query/key projections do not affect the output.
    The conventional form is ``"nqk,nkd->nqd"``. It is deliberately left
    unchanged: the checkpoint loaded by this module was presumably trained
    with this exact computation, and changing it would alter predictions.
    Confirm against the training code before fixing.
    """

    def __init__(self, in_dim: int, num_heads: int = 8):
        super().__init__()
        self.num_heads = num_heads
        # head_dim is floored (minimum 1), so scaled_dim can differ from in_dim
        # when in_dim is not divisible by num_heads; fc_in / fc_out bridge the
        # two sizes.
        self.head_dim = max(1, in_dim // num_heads)
        self.scaled_dim = self.head_dim * num_heads
        self.fc_in = nn.Linear(in_dim, self.scaled_dim)
        self.query = nn.Linear(self.scaled_dim, self.scaled_dim)
        self.key = nn.Linear(self.scaled_dim, self.scaled_dim)
        self.value = nn.Linear(self.scaled_dim, self.scaled_dim)
        self.fc_out = nn.Linear(self.scaled_dim, in_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map ``x`` of shape ``(N, in_dim)`` to a tensor of shape ``(N, in_dim)``."""
        N = x.shape[0]
        x = self.fc_in(x)  # (N, scaled_dim)
        # Split each projected vector into heads: (N, num_heads, head_dim).
        q = self.query(x).view(N, self.num_heads, self.head_dim)
        k = self.key(x).view(N, self.num_heads, self.head_dim)
        v = self.value(x).view(N, self.num_heads, self.head_dim)

        # Head-to-head similarity: (N, num_heads, num_heads), scaled by
        # sqrt(head_dim) and normalised over the last axis.
        energy = torch.einsum("nqd,nkd->nqk", [q, k])
        attention = F.softmax(energy / (self.head_dim ** 0.5), dim=2)

        # See the NOTE(review) in the class docstring: 'k' and 'v' are
        # different subscripts here, so attention does not weight v.
        out = torch.einsum("nqk,nvd->nqd", [attention, v])
        out = out.reshape(N, self.scaled_dim)
        out = self.fc_out(out)
        return out

class BaseModel(nn.Module):
    """Breed classifier: EfficientNetV2-M features -> MultiHeadAttention -> linear head.

    NOTE(review): the submodule attribute names (``backbone``, ``attention``,
    ``classifier``) presumably match the keys stored in the trained checkpoint.
    Renaming them would stop those weights from loading.
    """
    def __init__(self, num_classes: int, device='cuda' if torch.cuda.is_available() else 'cpu'):
        """
        Args:
            num_classes: number of output logits (one per breed).
            device: device the module is moved to, and that inputs are moved
                to in ``forward``. The default is evaluated once, when the
                class is defined.
        """
        super().__init__()
        self.device = device
        # ImageNet-pretrained backbone. Its final classifier is replaced with an
        # identity, so the backbone outputs the features that used to feed that
        # classifier (presumably the pooled feature vector).
        self.backbone = efficientnet_v2_m(weights=EfficientNet_V2_M_Weights.IMAGENET1K_V1)
        self.feature_dim = self.backbone.classifier[1].in_features
        self.backbone.classifier = nn.Identity()

        # About one head per 64 feature dims, clamped to the range [1, 8].
        self.num_heads = max(1, min(8, self.feature_dim // 64))
        self.attention = MultiHeadAttention(self.feature_dim, num_heads=self.num_heads)

        self.classifier = nn.Sequential(
            nn.LayerNorm(self.feature_dim),
            nn.Dropout(0.3),
            nn.Linear(self.feature_dim, num_classes)
        )

        self.to(device)

    def forward(self, x):
        """Classify a batch of images.

        Args:
            x: image batch; moved to ``self.device`` before use.

        Returns:
            tuple: ``(logits, attended_features)``, where logits has
            ``num_classes`` columns.
        """
        x = x.to(self.device)
        features = self.backbone(x)
        attended_features = self.attention(features)
        logits = self.classifier(attended_features)
        return logits, attended_features

# Initialize model: one classifier output per entry in dog_breeds.
num_classes = len(dog_breeds)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Initialize base model (BaseModel.__init__ also moves itself to `device`).
model = BaseModel(num_classes=num_classes, device=device).to(device)

# Load the trained weights stored under the checkpoint's "base_model" entry.
model_path = "124_best_model_dog.pth"
checkpoint = torch.load(model_path, map_location=device)

# strict=False tolerates key mismatches. However, any model key absent from the
# checkpoint silently keeps its freshly initialised weights (e.g. a randomly
# initialised classifier head), which would make predictions meaningless, so
# report mismatches instead of hiding them.
_load_result = model.load_state_dict(checkpoint["base_model"], strict=False)
if _load_result.missing_keys:
    print(f"Warning: {len(_load_result.missing_keys)} model key(s) missing from checkpoint: "
          f"{_load_result.missing_keys}")
if _load_result.unexpected_keys:
    print(f"Warning: {len(_load_result.unexpected_keys)} checkpoint key(s) not used by the model: "
          f"{_load_result.unexpected_keys}")

# Inference mode: disables dropout in the classifier head.
model.eval()

# Image preprocessing: resize to the classifier's 224x224 input and apply the
# standard ImageNet mean/std normalisation. Built once at import time rather
# than on every call.
_CLASSIFIER_TRANSFORM = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def preprocess_image(image):
    """Convert an image into a normalised classifier input batch of one.

    Args:
        image: PIL Image or numpy array (anything ``Image.fromarray`` accepts).

    Returns:
        torch.Tensor of shape (1, 3, 224, 224).
    """
    # Convert numpy.ndarray input to PIL.Image.
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)

    # Normalize() uses 3-channel mean/std. RGBA, grayscale ('L') or palette
    # ('P') images would otherwise fail inside the transform, so convert them.
    # RGB images are passed through untouched.
    if image.mode != "RGB":
        image = image.convert("RGB")

    return _CLASSIFIER_TRANSFORM(image).unsqueeze(0)

async def predict_single_dog(image):
    """
    Classify the breed of one dog image using the breed classifier only.

    Args:
        image: PIL Image or numpy array

    Returns:
        tuple: (top1_prob, topk_breeds, relative_probs)
            top1_prob: softmax probability of the most likely breed.
            topk_breeds: names of the three most likely breeds, best first.
            relative_probs: those three probabilities rescaled so they sum to
                100%, formatted as strings such as "61.20%".
    """
    batch = preprocess_image(image).to(device)

    with torch.no_grad():
        # The model returns a tuple. Element 0 holds the logits; the attended
        # features in element 1 are not needed here.
        class_probs = F.softmax(model(batch)[0], dim=1)

        top_values, top_indices = torch.topk(class_probs, k=5)
        ranked = [
            (dog_breeds[index.item()], value.item())
            for index, value in zip(top_indices[0], top_values[0])
        ]

        # Percentages are relative to the top three candidates only.
        leading = [score for _, score in ranked[:3]]
        leading_total = sum(leading)
        relative_probs = [f"{(score / leading_total * 100):.2f}%" for score in leading]

        # Debug output of the top five predictions.
        print("\nClassifier Predictions:")
        for name, score in ranked:
            print(f"{name}: {score:.4f}")

    return ranked[0][1], [name for name, _ in ranked[:3]], relative_probs


async def detect_multiple_dogs(image, conf_threshold=0.3, iou_threshold=0.55):
    """Locate dogs in *image* with YOLO and return a padded crop for each one.

    Args:
        image: PIL Image (its ``width``, ``height`` and ``crop`` are used).
        conf_threshold: minimum detection confidence, passed to the detector.
        iou_threshold: IoU threshold passed to the detector and reused for the
            extra ``non_max_suppression`` pass over the dog boxes.

    Returns:
        list of ``(crop, confidence, [x1, y1, x2, y2])`` tuples. Each box is
        enlarged by 5% of its width/height on every side and clipped to the
        image bounds. If no dog is detected, the whole image is returned as a
        single entry with confidence 1.0.
    """
    coco_dog_class = 16  # COCO dataset class index for "dog"
    detection = model_yolo(image, conf=conf_threshold, iou=iou_threshold)[0]
    candidates = [
        (det.xyxy[0].tolist(), det.conf.item())
        for det in detection.boxes
        if det.cls == coco_dog_class
    ]

    # No dog found: treat the whole frame as a single dog.
    if not candidates:
        return [(image, 1.0, [0, 0, image.width, image.height])]

    crops = []
    for (left, top, right, bottom), score in non_max_suppression(candidates, iou_threshold):
        # Pad by 5% of the original box size, measured before any clipping.
        pad_x = (right - left) * 0.05
        pad_y = (bottom - top) * 0.05
        padded = [
            max(0, left - pad_x),
            max(0, top - pad_y),
            min(image.width, right + pad_x),
            min(image.height, bottom + pad_y),
        ]
        crops.append((image.crop(tuple(padded)), score, padded))

    return crops

def non_max_suppression(boxes, iou_threshold):
    """Greedy non-maximum suppression over ``(box, score)`` pairs.

    Repeatedly keeps the highest-scoring remaining pair and drops every other
    pair whose box overlaps it with IoU >= *iou_threshold*.

    Args:
        boxes: iterable of ``(box, score)`` where ``box`` is ``[x1, y1, x2, y2]``.
        iou_threshold: overlap at or above which a lower-scoring box is dropped.

    Returns:
        list of the kept ``(box, score)`` pairs, highest score first (ties keep
        their input order). The input is not modified.
    """
    pending = sorted(boxes, key=lambda pair: pair[1], reverse=True)
    kept = []
    while pending:
        winner, *pending = pending
        kept.append(winner)
        pending = [
            other for other in pending
            if calculate_iou(winner[0], other[0]) < iou_threshold
        ]
    return kept


def calculate_iou(box1, box2):
    """Return the intersection-over-union of two axis-aligned boxes.

    Args:
        box1, box2: sequences ``[x1, y1, x2, y2]``, where (x1, y1) is the
            top-left corner and (x2, y2) the bottom-right corner.

    Returns:
        float. For well-formed boxes (x2 >= x1, y2 >= y1) the value lies in
        [0, 1]. Returns 0.0 when the union area is not positive (e.g. two
        zero-area boxes) instead of dividing by zero.
    """
    # Corners of the overlapping region (may be inverted if there is no overlap).
    x1 = max(box1[0], box2[0])
    y1 = max(box1[1], box2[1])
    x2 = min(box1[2], box2[2])
    y2 = min(box1[3], box2[3])

    intersection = max(0, x2 - x1) * max(0, y2 - y1)
    area1 = (box1[2] - box1[0]) * (box1[3] - box1[1])
    area2 = (box2[2] - box2[0]) * (box2[3] - box2[1])

    union = area1 + area2 - intersection
    if union <= 0:
        # Degenerate boxes: no meaningful overlap ratio; avoid ZeroDivisionError.
        return 0.0

    return intersection / float(union)



def create_breed_comparison(breed1: str, breed2: str) -> dict:
    """Build a side-by-side numeric comparison of two breeds.

    Categorical traits from ``get_dog_description`` are mapped to ordinal
    scores:

    - Size: 1-4
    - Exercise_Needs: 1-4
    - Care_Level: 1-3
    - Grooming_Needs: 1-3

    Unrecognised or missing values fall back to 2 (the Medium / Moderate level).

    Args:
        breed1: breed name as understood by ``get_dog_description``.
        breed2: second breed name.

    Returns:
        dict keyed by breed name. Each value holds the four scores,
        'Good_with_Children' (True only when the description says 'Yes') and
        'Original_Data' (the raw description, or {} when none was found).
    """
    # Standardised ordinal value mapping for each categorical trait.
    value_mapping = {
        'Size': {'Small': 1, 'Medium': 2, 'Large': 3, 'Giant': 4},
        'Exercise_Needs': {'Low': 1, 'Moderate': 2, 'High': 3, 'Very High': 4},
        'Care_Level': {'Low': 1, 'Moderate': 2, 'High': 3},
        'Grooming_Needs': {'Low': 1, 'Moderate': 2, 'High': 3}
    }

    comparison_data = {}
    for breed in (breed1, breed2):
        # get_dog_description can return None for unknown breeds (this file
        # checks for that elsewhere). Treat that, and any missing trait, as
        # "unknown" so the defaults apply instead of raising TypeError/KeyError.
        info = get_dog_description(breed) or {}
        comparison_data[breed] = {
            'Size': value_mapping['Size'].get(info.get('Size'), 2),  # default: Medium
            'Exercise_Needs': value_mapping['Exercise_Needs'].get(info.get('Exercise Needs'), 2),  # default: Moderate
            'Care_Level': value_mapping['Care_Level'].get(info.get('Care Level'), 2),
            'Grooming_Needs': value_mapping['Grooming_Needs'].get(info.get('Grooming Needs'), 2),
            'Good_with_Children': info.get('Good with Children') == 'Yes',
            'Original_Data': info
        }

    return comparison_data


def _fallback_description(breed):
    """Return a placeholder description dict for a breed with no database entry.

    Every trait is "Unknown". The description text is the breed name with
    underscores replaced by spaces.
    """
    return {
        "Name": breed,
        "Size": "Unknown",
        "Exercise Needs": "Unknown",
        "Grooming Needs": "Unknown",
        "Care Level": "Unknown",
        "Good with Children": "Unknown",
        "Description": f"Identified as {breed.replace('_', ' ')}"
    }


def _load_label_font(size=24):
    """Return Arial at *size* for box labels, or PIL's default font if unavailable."""
    try:
        return ImageFont.truetype("arial.ttf", size)
    except (OSError, ImportError):
        # OSError: the font file cannot be found/opened on this host.
        # ImportError: raised by Pillow builds without FreeType support.
        return ImageFont.load_default()


def _draw_dog_label(draw, box, label, color, font):
    """Outline *box* and write *label* on a white tag inside its top-left corner."""
    draw.rectangle(box, outline=color, width=4)

    label_bbox = draw.textbbox((0, 0), label, font=font)
    label_width = label_bbox[2] - label_bbox[0]
    label_height = label_bbox[3] - label_bbox[1]

    # Label background, then the text, offset 5px into the box.
    label_x = box[0] + 5
    label_y = box[1] + 5
    draw.rectangle(
        [label_x - 2, label_y - 2, label_x + label_width + 4, label_y + label_height + 4],
        fill='white',
        outline=color,
        width=2
    )
    draw.text((label_x, label_y), label, fill=color, font=font)


def _format_dog_html(dog_number, color, detection_confidence, top1_prob, topk_breeds, relative_probs):
    """Build the HTML fragment for one detected dog.

    The fragment depends on the confidences:

    - detection confidence x top-1 probability < 0.2: error card;
    - top-1 probability >= 0.45: single-breed card;
    - otherwise: card listing the candidate breeds with relative probabilities.

    Formatting errors are printed and rendered as the error card.
    """
    combined_confidence = detection_confidence * top1_prob
    try:
        if combined_confidence < 0.2:
            return format_error_message(color, dog_number)
        if top1_prob >= 0.45:
            breed = topk_breeds[0]
            description = get_dog_description(breed)
            # Handle a missing breed description with a basic placeholder.
            if description is None:
                description = _fallback_description(breed)
            return format_single_dog_result(breed, description, color)
        # Let the formatter look up each candidate, with the same fallback.
        return format_multiple_breeds_result(
            topk_breeds,
            relative_probs,
            color,
            dog_number,
            lambda breed: get_dog_description(breed) or _fallback_description(breed)
        )
    except Exception as e:
        print(f"Error formatting results for dog {dog_number}: {str(e)}")
        return format_error_message(color, dog_number)


async def predict(image):
    """
    Main prediction function that handles both single and multiple dog detection.

    The steps are:

    1. Detect dogs with ``detect_multiple_dogs``.
    2. Draw a numbered box for each dog on a copy of the image.
    3. Classify each crop with ``predict_single_dog``.
    4. Assemble the per-dog HTML.

    Args:
        image: PIL Image or numpy array

    Returns:
        tuple: (html_output, annotated_image, initial_state). When no image is
        given, or on any unexpected error, returns (warning_html, None, None).
    """
    if image is None:
        return format_warning_html("Please upload an image to start."), None, None

    try:
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)

        # Detect dogs in the image.
        dogs = await detect_multiple_dogs(image)
        is_single_dog = len(dogs) == 1
        # Expected to be a single color for one dog and a sequence of colors
        # (cycled per dog) otherwise - see how it is indexed below.
        color_scheme = get_color_scheme(is_single_dog)

        # Annotate a copy so the caller's image is left untouched.
        annotated_image = image.copy()
        draw = ImageDraw.Draw(annotated_image)
        font = _load_label_font()

        dogs_info = ""
        for i, (cropped_image, detection_confidence, box) in enumerate(dogs):
            color = color_scheme if is_single_dog else color_scheme[i % len(color_scheme)]
            _draw_dog_label(draw, box, f"Dog {i+1}", color, font)

            top1_prob, topk_breeds, relative_probs = await predict_single_dog(cropped_image)
            dogs_info += _format_dog_html(
                i + 1, color, detection_confidence, top1_prob, topk_breeds, relative_probs
            )

        # Wrap the final HTML output and prepare the initial UI state.
        html_output = format_multi_dog_container(dogs_info)
        initial_state = {
            "dogs_info": dogs_info,
            "image": annotated_image,
            "is_multi_dog": len(dogs) > 1,
            "html_output": html_output
        }

        return html_output, annotated_image, initial_state

    except Exception as e:
        # NOTE(review): the full traceback is shown to the end user in the
        # warning HTML; consider limiting it to the server log in production.
        error_msg = f"An error occurred: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
        print(error_msg)
        return format_warning_html(error_msg), None, None


def show_details_html(choice, previous_output, initial_state):
    """
    Render the detailed HTML page for the breed picked in the UI.

    Args:
        choice: str, selected option. The breed name is the text after the
            last "More about " (the whole string if that prefix is absent).
        previous_output: str, HTML returned unchanged when nothing is selected.
        initial_state: dict, UI state. On success it is updated in place:
            "current_description" gets the rendered HTML, and
            "original_buttons" gets its "buttons" entry (or []).

    Returns:
        tuple: (html_output, gradio_update, updated_state). The update always
        sets visible=True. Exceptions are printed and returned as warning HTML.
    """
    if choice:
        try:
            *_, breed_name = choice.split("More about ")
            rendered = format_breed_details_html(get_dog_description(breed_name), breed_name)

            initial_state["current_description"] = rendered
            initial_state["original_buttons"] = initial_state.get("buttons", [])

            return rendered, gr.update(visible=True), initial_state
        except Exception as err:
            error_msg = f"An error occurred while showing details: {err}"
            print(error_msg)
            return format_warning_html(error_msg), gr.update(visible=True), initial_state

    return previous_output, gr.update(visible=True), initial_state

def main() -> gr.Blocks:
    """Build and return the PawMatch AI Gradio interface (not launched here).

    Layout:

    - header;
    - four tabs: breed detection, breed comparison, breed recommendation and
      search history;
    - footer linking to the source repository.

    NOTE(review): the components returned by the create_*_tab calls are
    assigned but not used in this function.
    """
    with gr.Blocks(css=get_css_styles()) as iface:
        # Header HTML

        gr.HTML("""
        <header style='text-align: center; padding: 20px; margin-bottom: 20px;'>
            <h1 style='font-size: 2.5em; margin-bottom: 10px; color: #2D3748;'>
                🐾 PawMatch AI
            </h1>
            <h2 style='font-size: 1.2em; font-weight: normal; color: #4A5568; margin-top: 5px;'>
                Your Smart Dog Breed Guide
            </h2>
            <div style='width: 50px; height: 3px; background: linear-gradient(90deg, #4299e1, #48bb78); margin: 15px auto;'></div>
            <p style='color: #718096; font-size: 0.9em;'>
                Powered by AI • Breed Recognition • Smart Matching • Companion Guide
            </p>
        </header>
        """)

        # Create the history component instance first (without its tab) so it
        # can be passed to the recommendation tab below.
        history_component = create_history_component()

        with gr.Tabs():
            # 1. Breed detection tab (example images are relative file paths).
            example_images = [
                'Border_Collie.jpg',
                'Golden_Retriever.jpeg',
                'Saint_Bernard.jpeg',
                'Samoyed.jpg',
                'French_Bulldog.jpeg'
            ]
            detection_components = create_detection_tab(predict, example_images)

            # 2. Breed comparison tab
            comparison_components = create_comparison_tab(
                dog_breeds=dog_breeds,
                get_dog_description=get_dog_description,
                breed_health_info=breed_health_info,
                breed_noise_info=breed_noise_info
            )

            # 3. Breed recommendation tab
            recommendation_components = create_recommendation_tab(
                UserPreferences=UserPreferences,
                get_breed_recommendations=get_breed_recommendations,
                format_recommendation_html=format_recommendation_html,
                history_component=history_component
            )


            # 4. Finally, create the search-history tab from the shared component.
            create_history_tab(history_component)

        # Footer
        gr.HTML('''
            <div style="
                display: flex;
                align-items: center;
                justify-content: center;
                gap: 20px;
                padding: 20px 0;
            ">
                <p style="
                    font-family: 'Arial', sans-serif;
                    font-size: 14px;
                    font-weight: 500;
                    letter-spacing: 2px;
                    background: linear-gradient(90deg, #555, #007ACC);
                    -webkit-background-clip: text;
                    -webkit-text-fill-color: transparent;
                    margin: 0;
                    text-transform: uppercase;
                    display: inline-block;
                ">EXPLORE THE CODE →</p>
                <a href="https://github.com/Eric-Chung-0511/Learning-Record/tree/main/Data%20Science%20Projects/PawMatchAI" style="text-decoration: none;">
                    <img src="https://img.shields.io/badge/GitHub-PawMatch_AI-007ACC?logo=github&style=for-the-badge">
                </a>
            </div>
        ''')

    return iface

if __name__ == "__main__":
    # Script entry point: build the Gradio Blocks UI and launch it.
    iface = main()
    iface.launch()