import numpy as np
import math
import types
import torch
import torch.nn as nn
import numpy as np
import cv2
import re
import torch.nn.functional as F
from einops import rearrange
from einops.layers.torch import Rearrange
from PIL import Image

def extract_first_sentence(text):
    """Return ``text`` up to and including its first period, stripped.

    If ``text`` contains no ``'.'`` the whole string is returned. In both
    cases leading and trailing whitespace is removed from the result.
    """
    # partition() yields (text, '', '') when no period exists, so a single
    # expression covers both the "found" and "not found" cases.
    head, period, _ = text.partition('.')
    return (head + period).strip()
    
import re
def remove_duplicate_keywords(text, keywords):
    """Blank out every occurrence of each keyword except the first one.

    ``text`` is split into word tokens and the punctuation marks ``.,;!?``.
    Keyword matching is case-insensitive. Repeated occurrences are replaced
    by an empty token, and all tokens are joined back with single spaces, so
    punctuation ends up space-separated and a removed word leaves a doubled
    space behind.

    Args:
        text: input string.
        keywords: iterable of keyword strings to de-duplicate.

    Returns:
        The re-joined token string.
    """
    tokens = re.findall(r'\b\w+\b|[.,;!?]', text)

    for keyword in keywords:
        wanted = keyword.lower()
        already_seen = False
        for position, token in enumerate(tokens):
            if token.lower() != wanted:
                continue
            if already_seen:
                # Second or later hit: drop the word but keep its slot.
                tokens[position] = ""
            already_seen = True

    return " ".join(tokens)

def process_text_with_markers(text, parsing_mask_list):
    """Annotate a face description with ``<|facial|>`` markers aligned to the masks.

    Steps:
      1. Repeated facial keywords are blanked via ``remove_duplicate_keywords``.
      2. For every facial feature that has at least one entry in
         ``parsing_mask_list``, ``<|feature|>`` is inserted after the first
         (case-sensitive, whole-word) occurrence of the feature word. When the
         word does not occur, every mask mapped to that feature is deleted
         from ``parsing_mask_list``.
      3. For each marker, in face/ears/eyes/nose/mouth order, the clause around
         it (text delimited by ``,``, ``.`` or ``;``) is cut out of the text
         and collected with a trailing comma.
      4. The collected clauses are joined with spaces and every feature
         marker is rewritten to ``<|facial|>``.

    Args:
        text: description to annotate.
        parsing_mask_list: container keyed by parsing-mask names such as
            ``"Face"`` or ``"Left_Eye"``. It is used with ``in`` and
            ``del container[key]``, so it is presumably a dict — confirm
            against the caller. NOTE: it is modified in place.

    Returns:
        ``(align_marked_text, align_parsing_mask_list)`` where the second item
        is the same (pruned) object that was passed in.
    """
    keywords = ["face", "ears", "eyes", "nose", "mouth"]
    text = remove_duplicate_keywords(text, keywords)
    key_parsing_mask_markers = ["Face", "Left_Ear", "Right_Ear", "Left_Eye", "Right_Eye", "Nose", "Upper_Lip", "Lower_Lip"]
    mapping = {
        "Face": "face",
        "Left_Ear": "ears",
        "Right_Ear": "ears",
        "Left_Eye": "eyes",
        "Right_Eye": "eyes",
        "Nose": "nose",
        "Upper_Lip": "mouth",
        "Lower_Lip": "mouth",
    }

    # Features (and their markers) that have at least one mask, de-duplicated
    # while keeping the order of key_parsing_mask_markers.
    facial_features_align = []
    markers_align = []
    for key in key_parsing_mask_markers:
        if key in parsing_mask_list:
            mapped_key = mapping.get(key, key.lower())
            if mapped_key not in facial_features_align:
                facial_features_align.append(mapped_key)
                markers_align.append("<|" + mapped_key + "|>")

    text_marked = text
    # Deliberately an alias: pruning below is visible to the caller as well.
    align_parsing_mask_list = parsing_mask_list
    for feature, marker in zip(facial_features_align[::-1], markers_align[::-1]):
        pattern = rf'\b{feature}\b'
        text_marked_new = re.sub(pattern, f'{feature} {marker}', text_marked, count=1)
        if text_marked == text_marked_new:
            # The feature word is not in the text: drop its masks so that
            # masks and markers stay aligned.
            for key, value in mapping.items():
                if value == feature and key in align_parsing_mask_list:
                    del align_parsing_mask_list[key]
        text_marked = text_marked_new

    text_marked = text_marked.replace('\n', '')

    separators = (",", ".", ";")
    ordered_text = []
    for marker in markers_align:
        start_idx = text_marked.find(marker)
        if start_idx == -1:
            # The marker was never inserted, or an earlier clause that also
            # contained it was already removed. Without this guard, -1 is
            # used as a slice start and can pick up the last character of
            # the text as a bogus clause.
            continue
        end_idx = start_idx + len(marker)

        # Grow the span to the enclosing clause boundaries.
        while start_idx > 0 and text_marked[start_idx - 1] not in separators:
            start_idx -= 1
        while end_idx < len(text_marked) and text_marked[end_idx] not in separators:
            end_idx += 1

        ordered_text.append(text_marked[start_idx:end_idx].strip() + ",")
        # Remove the consumed clause; the separator at end_idx is kept.
        text_marked = text_marked[:start_idx] + text_marked[end_idx:]

    align_marked_text = " ".join(ordered_text)
    replace_list = ["<|face|>", "<|ears|>", "<|nose|>", "<|eyes|>", "<|mouth|>"]
    for item in replace_list:
        align_marked_text = align_marked_text.replace(item, "<|facial|>")

    return align_marked_text, align_parsing_mask_list

def tokenize_and_mask_noun_phrases_ends(text, image_token_id, facial_token_id, tokenizer):
    """Tokenize ``text``, strip the special marker tokens and flag their positions.

    Image and facial marker tokens are removed from the id sequence. For each
    removed marker, a flag is set in the matching mask:

    * facial marker: at the index of the last kept token before it;
    * image marker: at that index plus the number of image markers already
      seen.

    The ids and both masks are truncated or padded to
    ``tokenizer.model_max_length`` (ids with ``tokenizer.pad_token_id``,
    masks with ``False``).

    Args:
        text: prompt string passed to ``tokenizer.encode``.
        image_token_id: id of the image marker token.
        facial_token_id: id of the facial marker token.
        tokenizer: object exposing ``encode``, ``model_max_length`` and
            ``pad_token_id``.

    Returns:
        ``(clean_input_ids, image_mask, facial_mask)``, each with a leading
        batch dimension of 1; ids are ``torch.long``, masks ``torch.bool``.
    """
    token_ids = tokenizer.encode(text)
    image_flags = [False] * len(token_ids)
    facial_flags = [False] * len(token_ids)
    kept_ids = []
    images_seen = 0

    for token in token_ids:
        if token == image_token_id:
            image_flags[len(kept_ids) + images_seen - 1] = True
            images_seen += 1
        elif token == facial_token_id:
            facial_flags[len(kept_ids) - 1] = True
        else:
            kept_ids.append(token)

    limit = tokenizer.model_max_length

    def fit_to_limit(values, filler):
        # Cut over-long sequences, right-pad short ones.
        if len(values) > limit:
            return values[:limit]
        return values + [filler] * (limit - len(values))

    ids_tensor = torch.tensor(fit_to_limit(kept_ids, tokenizer.pad_token_id), dtype=torch.long)
    image_tensor = torch.tensor(fit_to_limit(image_flags, False), dtype=torch.bool)
    facial_tensor = torch.tensor(fit_to_limit(facial_flags, False), dtype=torch.bool)

    return ids_tensor.unsqueeze(0), image_tensor.unsqueeze(0), facial_tensor.unsqueeze(0)

def prepare_image_token_idx(image_token_mask, facial_token_mask, max_num_objects=2, max_num_facials=5):
    """Turn boolean token masks into padded index tensors plus validity masks.

    For each input mask, the dim-1 indices of its ``True`` entries are
    collected (``torch.nonzero(..., as_tuple=True)[1]``). If fewer than the
    requested number are found, the indices are right-padded with ``0`` and
    the validity mask with ``False``. Sequences that already have more entries
    than the requested number are returned as-is (no truncation).

    Padding is allocated with ``new_zeros`` so it lives on the same device as
    the input mask; a CPU-only pad would make ``torch.cat`` fail for masks on
    another device.

    Args:
        image_token_mask: boolean tensor with at least two dimensions
            (presumably ``(1, seq_len)`` — confirm against the caller).
        facial_token_mask: boolean tensor, same layout as ``image_token_mask``.
        max_num_objects: minimum length of the image index output.
        max_num_facials: minimum length of the facial index output.

    Returns:
        ``(image_token_idx, image_token_idx_mask, facial_token_idx,
        facial_token_idx_mask)``, each with a leading dimension of 1.
    """

    def _indices_with_validity(token_mask, min_len):
        # Positions of set flags along dim 1, all marked valid.
        positions = torch.nonzero(token_mask, as_tuple=True)[1]
        validity = torch.ones_like(positions, dtype=torch.bool)
        missing = min_len - len(positions)
        if missing > 0:
            positions = torch.cat([positions, positions.new_zeros(missing)])
            validity = torch.cat([validity, validity.new_zeros(missing)])
        return positions.unsqueeze(0), validity.unsqueeze(0)

    image_token_idx, image_token_idx_mask = _indices_with_validity(image_token_mask, max_num_objects)
    facial_token_idx, facial_token_idx_mask = _indices_with_validity(facial_token_mask, max_num_facials)

    return image_token_idx, image_token_idx_mask, facial_token_idx, facial_token_idx_mask

def get_object_localization_loss_for_one_layer(
    cross_attention_scores,
    object_segmaps,
    object_token_idx,
    object_token_idx_mask,
    loss_fn,
):
    """Localization loss between object-token attention and segmentation maps for one layer.

    Args:
        cross_attention_scores: tensor of shape
            ``(batch * heads, num_noise_latents, num_text_tokens)``.
        object_segmaps: tensor of shape ``(batch, max_num_objects, H, W)``;
            resized (bilinear, antialiased) to a square of side
            ``int(num_noise_latents ** 0.5)``.
        object_token_idx: text-token index per object, viewable as
            ``(batch, 1, 1, max_num_objects)``.
        object_token_idx_mask: validity flag per object, viewable as
            ``(batch, 1, max_num_objects)``.
        loss_fn: callable taking the gathered attention and the expanded
            segmaps, both ``(batch, heads, num_noise_latents, max_num_objects)``.
            Its result is masked per object and summed over dim 2.

    Returns:
        Scalar tensor: masked per-object losses summed over objects, divided
        by the number of valid objects (plus ``1e-5``), then averaged.
    """
    batch_heads, n_latents, n_tokens = cross_attention_scores.shape
    batch, n_objects, _, _ = object_segmaps.shape
    side = int(n_latents**0.5)

    # Bring the segmentation maps to the attention map's spatial resolution
    # and flatten the spatial dimensions.
    resized_maps = F.interpolate(object_segmaps, size=(side, side), mode="bilinear", antialias=True)
    flat_maps = resized_maps.view(batch, n_objects, -1)

    heads = batch_heads // batch
    per_head_scores = cross_attention_scores.view(batch, heads, n_latents, n_tokens)

    target_shape = (batch, heads, n_latents, n_objects)

    # Pick, for every latent position, the attention paid to each object token.
    gather_index = object_token_idx.view(batch, 1, 1, n_objects).expand(*target_shape)
    token_attention = torch.gather(per_head_scores, dim=3, index=gather_index)

    # Lay the maps out as (batch, heads, latents, objects) to match.
    expanded_maps = flat_maps.permute(0, 2, 1).unsqueeze(1).expand(*target_shape)

    per_object = loss_fn(token_attention, expanded_maps)
    per_object = per_object * object_token_idx_mask.view(batch, 1, n_objects)

    # 1e-5 keeps the division finite when a sample has no valid object.
    valid_objects = object_token_idx_mask.sum(dim=1).view(batch, 1) + 1e-5
    return (per_object.sum(dim=2) / valid_objects).mean()


def get_object_localization_loss(
    cross_attention_scores,
    object_segmaps,
    image_token_idx,
    image_token_idx_mask,
    loss_fn,
):
    """Average the per-layer object localization loss over all stored layers.

    Args:
        cross_attention_scores: mapping from layer name to that layer's
            attention tensor; only the values are used.
        object_segmaps, image_token_idx, image_token_idx_mask, loss_fn:
            forwarded unchanged to
            ``get_object_localization_loss_for_one_layer``.

    Returns:
        Sum of the per-layer losses divided by the number of layers.
    """
    layer_count = len(cross_attention_scores)
    total_loss = sum(
        get_object_localization_loss_for_one_layer(
            layer_scores, object_segmaps, image_token_idx, image_token_idx_mask, loss_fn
        )
        for layer_scores in cross_attention_scores.values()
    )
    return total_loss / layer_count

def unet_store_cross_attention_scores(unet, attention_scores, layers=5):
    """Patch selected attention modules of ``unet`` to record their attention probabilities.

    Every matching module gets its ``get_attention_scores`` method replaced by
    a wrapper that calls the original, stores the result in
    ``attention_scores[module_name]`` and returns it unchanged. The original
    method is kept on the module as ``old_get_attention_scores``.

    Args:
        unet: model exposing ``named_modules()``.
        attention_scores: mutable mapping that receives one entry per patched
            module on every forward pass (overwritten each time).
        layers: number of entries taken from the centre of the internal
            block-name list; the default of 5 selects ``down_blocks.1``
            through ``up_blocks.2``.

    Returns:
        The same ``unet`` object, patched in place.
    """
    from diffusers.models.attention_processor import Attention

    UNET_LAYER_NAMES = [
        "down_blocks.0",
        "down_blocks.1",
        "down_blocks.2",
        "mid_block",
        "up_blocks.1",
        "up_blocks.2",
        "up_blocks.3",
    ]

    # Centered window of `layers` block names.
    start_layer = (len(UNET_LAYER_NAMES) - layers) // 2
    end_layer = start_layer + layers
    applicable_layers = UNET_LAYER_NAMES[start_layer:end_layer]

    def make_new_get_attention_scores_fn(name):
        def new_get_attention_scores(module, query, key, attention_mask=None):
            attention_probs = module.old_get_attention_scores(
                query, key, attention_mask
            )
            attention_scores[name] = attention_probs
            return attention_probs

        return new_get_attention_scores

    for name, module in unet.named_modules():
        # NOTE(review): the filter matches "attn1". In diffusers transformer
        # blocks attn1 is conventionally self-attention and attn2 is
        # cross-attention — confirm that attn1 is intended here.
        if not (isinstance(module, Attention) and "attn1" in name):
            continue
        if not any(layer in name for layer in applicable_layers):
            continue

        # Save the pristine method only once. If this function runs again on
        # the same UNet, re-saving would store the wrapper itself, and the
        # wrapper would then call itself and recurse without end.
        if not hasattr(module, "old_get_attention_scores"):
            module.old_get_attention_scores = module.get_attention_scores
        module.get_attention_scores = types.MethodType(
            make_new_get_attention_scores_fn(name), module
        )
    return unet
    
class BalancedL1Loss(nn.Module):
    """Difference between mean attention outside and inside an object region.

    ``threshold`` is stored on the instance but not used by ``forward``.
    With ``normalize=True`` the attention is first divided by its maximum
    along dim 2 (plus ``1e-5``).
    """

    def __init__(self, threshold=1.0, normalize=False):
        super().__init__()
        self.threshold = threshold
        self.normalize = normalize

    def forward(self, object_token_attn_prob, object_segmaps):
        """Return background-weighted minus object-weighted mean attention.

        Both inputs are multiplied element-wise and reduced over dim 2, so
        they must be broadcast-compatible; the result has dim 2 removed.
        """
        attention = object_token_attn_prob
        if self.normalize:
            peak = attention.max(dim=2, keepdim=True)[0]
            attention = attention / (peak + 1e-5)

        def region_mean(region):
            # Attention mass inside `region`, normalized by the region's
            # area; 1e-5 guards against empty regions.
            area = region.sum(dim=2) + 1e-5
            return (attention * region).sum(dim=2) / area

        outside = region_mean(1 - object_segmaps)
        inside = region_mean(object_segmaps)
        return outside - inside

def fetch_mask_raw_image(raw_image, mask_image):
    """Composite ``raw_image`` over a black canvas using ``mask_image``.

    The mask is first resized to the raw image's size, then passed as the
    mask argument of ``Image.composite`` with ``raw_image`` as the first image
    and an all-black RGB image of the same size as the second.

    Returns:
        The composited PIL image.
    """
    target_size = raw_image.size
    resized_mask = mask_image.resize(target_size)
    black_canvas = Image.new('RGB', target_size, (0, 0, 0))
    return Image.composite(raw_image, black_canvas, resized_mask)

# Parsing-mask lookup table: one dict per mask value with the keys
# "Mask Value", "Body Part" and "RGB Color". Mask values are consecutive from
# 0, so they are generated from each row's position in the list below.
# NOTE(review): "Hat" (17) and "Hand" (21) share the colour [255, 0, 255].
mapping_table = [
    {"Mask Value": mask_value, "Body Part": body_part, "RGB Color": list(rgb_color)}
    for mask_value, (body_part, rgb_color) in enumerate(
        [
            ("Background", (0, 0, 0)),
            ("Face", (255, 0, 0)),
            ("Left_Eyebrow", (255, 85, 0)),
            ("Right_Eyebrow", (255, 170, 0)),
            ("Left_Eye", (255, 0, 85)),
            ("Right_Eye", (255, 0, 170)),
            ("Hair", (0, 0, 255)),
            ("Left_Ear", (85, 0, 255)),
            ("Right_Ear", (170, 0, 255)),
            ("Mouth_External Contour", (0, 255, 85)),
            ("Nose", (0, 255, 0)),
            ("Mouth_Inner_Contour", (0, 255, 170)),
            ("Upper_Lip", (85, 255, 0)),
            ("Lower_Lip", (170, 255, 0)),
            ("Neck", (0, 85, 255)),
            ("Neck_Inner Contour", (0, 170, 255)),
            ("Cloth", (255, 255, 0)),
            ("Hat", (255, 0, 255)),
            ("Earring", (255, 85, 255)),
            ("Necklace", (255, 255, 85)),
            ("Glasses", (255, 170, 255)),
            ("Hand", (255, 0, 255)),
            ("Wristband", (0, 255, 255)),
            ("Clothes_Upper", (85, 255, 255)),
            ("Clothes_Lower", (170, 255, 255)),
        ]
    )
]


def masks_for_unique_values(image_raw_mask):
    """Build one filled binary mask image per label value found in a parsing mask.

    For every distinct value in ``image_raw_mask``, the external contours of
    the pixels equal to that value are filled with 255 on a zero canvas.
    Values without an entry in ``mapping_table`` are skipped. For value 0 an
    extra ``"WithoutBackground"`` mask is stored, which is the inverse of the
    filled background mask.

    Args:
        image_raw_mask: anything ``np.array`` can convert (presumably a
            single-channel label image — confirm against the caller).

    Returns:
        Dict mapping body-part name to a PIL image.
    """
    label_array = np.array(image_raw_mask)
    masks_dict = {}

    for value in np.unique(label_array):
        region = np.uint8(label_array == value) * 255
        contours, _ = cv2.findContours(region, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

        # Fill each external contour, which also closes holes in the region.
        filled = np.zeros_like(label_array)
        for contour in contours:
            cv2.drawContours(filled, [contour], -1, (255), thickness=cv2.FILLED)

        if value == 0:
            inverted = np.where(filled == 255, 0, 255).astype(filled.dtype)
            masks_dict["WithoutBackground"] = Image.fromarray(inverted)

        part_name = next(
            (entry["Body Part"] for entry in mapping_table if entry["Mask Value"] == value),
            None,
        )
        # Unmapped values (and names carrying the "Unknown_" prefix) get no mask.
        if part_name is None or part_name.startswith("Unknown_"):
            continue

        masks_dict[part_name] = Image.fromarray(filled)

    return masks_dict
# FFN
def FeedForward(dim, mult=4):
    """Return a LayerNorm -> Linear -> GELU -> Linear block without biases.

    The hidden width is ``int(dim * mult)``; input and output width are both
    ``dim``.
    """
    hidden_dim = int(dim * mult)
    stages = [
        nn.LayerNorm(dim),
        nn.Linear(dim, hidden_dim, bias=False),
        nn.GELU(),
        nn.Linear(hidden_dim, dim, bias=False),
    ]
    return nn.Sequential(*stages)


def reshape_tensor(x, heads):
    """Split the last dimension into ``heads`` and move the head axis forward.

    Args:
        x: tensor of shape ``(batch, length, width)``.
        heads: number of heads; ``width`` must be divisible by it.

    Returns:
        Tensor of shape ``(batch, heads, length, width // heads)``.
    """
    batch, seq_len, _ = x.shape
    per_head = x.view(batch, seq_len, heads, -1).transpose(1, 2)
    return per_head.reshape(batch, heads, seq_len, -1)

class PerceiverAttention(nn.Module):
    """Multi-head attention where latents attend to features and to themselves.

    Queries come from the latents; keys and values come from the concatenation
    of the input features and the latents. All projections are bias-free.
    ``self.scale`` is stored but ``forward`` computes its own scaling.
    """

    def __init__(self, *, dim, dim_head=64, heads=8):
        super().__init__()
        self.scale = dim_head**-0.5
        self.dim_head = dim_head
        self.heads = heads
        projected_dim = dim_head * heads

        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

        self.to_q = nn.Linear(dim, projected_dim, bias=False)
        self.to_kv = nn.Linear(dim, projected_dim * 2, bias=False)
        self.to_out = nn.Linear(projected_dim, dim, bias=False)

    def forward(self, x, latents):
        """
        Args:
            x (torch.Tensor): image features
                shape (b, n1, D)
            latents (torch.Tensor): latent features
                shape (b, n2, D)

        Returns:
            torch.Tensor: updated latents, shape (b, n2, D)
        """
        x = self.norm1(x)
        latents = self.norm2(latents)

        batch, num_latents, _ = latents.shape

        queries = self.to_q(latents)
        # Keys/values see both the features and the latents themselves.
        keys, values = self.to_kv(torch.cat((x, latents), dim=-2)).chunk(2, dim=-1)

        queries = reshape_tensor(queries, self.heads)
        keys = reshape_tensor(keys, self.heads)
        values = reshape_tensor(values, self.heads)

        # Scale q and k by dim_head ** -0.25 each, so their product carries
        # the usual dim_head ** -0.5 factor.
        quarter_scale = 1 / math.sqrt(math.sqrt(self.dim_head))
        logits = (queries * quarter_scale) @ (keys * quarter_scale).transpose(-2, -1)
        # Softmax runs in float32, then is cast back to the logits' dtype.
        attention = torch.softmax(logits.float(), dim=-1).type(logits.dtype)
        attended = attention @ values

        # (b, heads, n2, dim_head) -> (b, n2, heads * dim_head)
        merged = attended.permute(0, 2, 1, 3).reshape(batch, num_latents, -1)

        return self.to_out(merged)

class FacePerceiverResampler(torch.nn.Module):
    """Stack of perceiver-attention and feed-forward layers that refines latents.

    Input features are projected from ``embedding_dim`` to ``dim``. Each of the
    ``depth`` layers updates the latents with residual attention over the
    features followed by a residual feed-forward. The result is projected to
    ``output_dim`` and layer-normalized.
    """

    def __init__(
        self,
        *,
        dim=768,
        depth=4,
        dim_head=64,
        heads=16,
        embedding_dim=1280,
        output_dim=768,
        ff_mult=4,
    ):
        super().__init__()

        self.proj_in = torch.nn.Linear(embedding_dim, dim)
        self.proj_out = torch.nn.Linear(dim, output_dim)
        self.norm_out = torch.nn.LayerNorm(output_dim)

        # One (attention, feed-forward) pair per layer.
        self.layers = torch.nn.ModuleList(
            [
                torch.nn.ModuleList(
                    [
                        PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
                        FeedForward(dim=dim, mult=ff_mult),
                    ]
                )
                for _ in range(depth)
            ]
        )

    def forward(self, latents, x):
        """Refine ``latents`` using features ``x``.

        Args:
            latents: tensor whose last dimension is ``dim``
                (e.g. ``(2, 4, 768)`` with the defaults).
            x: tensor whose last dimension is ``embedding_dim``
                (e.g. ``(2, 257, 1280)`` with the defaults).

        Returns:
            Tensor shaped like ``latents`` but with last dimension ``output_dim``.
        """
        x = self.proj_in(x)
        for attention, feed_forward in self.layers:
            latents = attention(x, latents) + latents
            latents = feed_forward(latents) + latents
        projected = self.proj_out(latents)
        return self.norm_out(projected)
  
class ProjPlusModel(torch.nn.Module):
    """Project an ID embedding to tokens and refine them against CLIP features.

    The ID embedding is mapped by a two-layer MLP to ``num_tokens`` tokens of
    width ``cross_attention_dim``, layer-normalized, and then used as the
    latents of a ``FacePerceiverResampler`` that attends over the CLIP
    embeddings.
    """

    def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, clip_embeddings_dim=1280, num_tokens=4):
        super().__init__()

        self.cross_attention_dim = cross_attention_dim
        self.num_tokens = num_tokens

        hidden_dim = id_embeddings_dim * 2
        self.proj = torch.nn.Sequential(
            torch.nn.Linear(id_embeddings_dim, hidden_dim),
            torch.nn.GELU(),
            torch.nn.Linear(hidden_dim, cross_attention_dim * num_tokens),
        )
        self.norm = torch.nn.LayerNorm(cross_attention_dim)

        self.perceiver_resampler = FacePerceiverResampler(
            dim=cross_attention_dim,
            depth=4,
            dim_head=64,
            heads=cross_attention_dim // 64,
            embedding_dim=clip_embeddings_dim,
            output_dim=cross_attention_dim,
            ff_mult=4,
        )

    def forward(self, id_embeds, clip_embeds, shortcut=False, scale=1.0):
        """Return the refined ID tokens.

        Args:
            id_embeds: tensor whose last dimension is ``id_embeddings_dim``.
            clip_embeds: features handed to the resampler as its ``x`` input.
            shortcut: when true, ``scale`` times the normalized projected
                tokens is added to the resampler output.
            scale: weight of the shortcut branch.

        Returns:
            Tensor of shape ``(-1, num_tokens, cross_attention_dim)``.
        """
        id_tokens = self.proj(id_embeds)
        id_tokens = id_tokens.reshape(-1, self.num_tokens, self.cross_attention_dim)
        id_tokens = self.norm(id_tokens)

        resampled = self.perceiver_resampler(id_tokens, clip_embeds)
        if not shortcut:
            return resampled
        return scale * id_tokens + resampled
    
class AttentionMLP(nn.Module):
    """Perceiver-style resampler with learned latent tokens.

    ``single_num_tokens`` learned latents (optionally preceded by latents
    derived from the mean-pooled input) attend over the projected input
    through ``depth`` attention + feed-forward layers. The result is projected
    to ``output_dim`` and layer-normalized. The ``dtype`` argument is accepted
    but not used by this class.
    """

    def __init__(
        self,
        dtype=torch.float16,
        dim=1024,
        depth=8,
        dim_head=64,
        heads=16,
        single_num_tokens=1,
        embedding_dim=1280,
        output_dim=768,
        ff_mult=4,
        max_seq_len: int = 257*2,
        apply_pos_emb: bool = False,
        num_latents_mean_pooled: int = 0,
    ):
        super().__init__()
        # Optional learned positional embedding, added before proj_in.
        if apply_pos_emb:
            self.pos_emb = nn.Embedding(max_seq_len, embedding_dim)
        else:
            self.pos_emb = None

        self.single_num_tokens = single_num_tokens
        self.latents = nn.Parameter(torch.randn(1, self.single_num_tokens, dim) / dim**0.5)

        self.proj_in = nn.Linear(embedding_dim, dim)

        self.proj_out = nn.Linear(dim, output_dim)
        self.norm_out = nn.LayerNorm(output_dim)

        # Optional extra latents computed from the mean-pooled sequence.
        if num_latents_mean_pooled > 0:
            self.to_latents_from_mean_pooled_seq = nn.Sequential(
                nn.LayerNorm(dim),
                nn.Linear(dim, dim * num_latents_mean_pooled),
                Rearrange("b (n d) -> b n d", n=num_latents_mean_pooled),
            )
        else:
            self.to_latents_from_mean_pooled_seq = None

        self.layers = nn.ModuleList(
            [
                nn.ModuleList(
                    [
                        PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
                        FeedForward(dim=dim, mult=ff_mult),
                    ]
                )
                for _ in range(depth)
            ]
        )

    def forward(self, x):
        """Resample ``x`` into latent tokens.

        Args:
            x: tensor of shape ``(batch, seq_len, embedding_dim)``
                (e.g. ``(5, 257, 1280)`` with the defaults).

        Returns:
            Tensor of shape ``(batch, num_latents, output_dim)``.
        """
        if self.pos_emb is not None:
            positions = torch.arange(x.shape[1], device=x.device)
            x = x + self.pos_emb(positions)

        # One copy of the learned latents per batch element.
        latents = self.latents.repeat(x.size(0), 1, 1)

        x = self.proj_in(x)

        if self.to_latents_from_mean_pooled_seq is not None:
            everything = torch.ones(x.shape[:2], device=x.device, dtype=torch.bool)
            pooled = masked_mean(x, dim=1, mask=everything)
            pooled_latents = self.to_latents_from_mean_pooled_seq(pooled)
            latents = torch.cat((pooled_latents, latents), dim=-2)

        for attention, feed_forward in self.layers:
            latents = attention(x, latents) + latents
            latents = feed_forward(latents) + latents

        return self.norm_out(self.proj_out(latents))


def masked_mean(t, *, dim, mask=None):
    """Mean of ``t`` along ``dim``, counting only positions where ``mask`` is true.

    Args:
        t: input tensor.
        dim: dimension to reduce.
        mask: optional boolean tensor of shape ``(b, n)``. It is broadcast
            over a trailing axis of ``t``. When ``None``, a plain mean is
            returned.

    Returns:
        Tensor with ``dim`` reduced. The divisor is clamped to at least
        ``1e-5`` so fully masked rows do not divide by zero.
    """
    if mask is None:
        return t.mean(dim=dim)

    valid_count = mask.sum(dim=dim, keepdim=True)
    keep = rearrange(mask, "b n -> b n 1")
    kept_total = t.masked_fill(~keep, 0.0).sum(dim=dim)

    return kept_total / valid_count.clamp(min=1e-5)