import numpy as np
import torch
from diffusers import DDIMScheduler
import cv2
from utils.sdxl import sdxl
from utils.inversion import Inversion
import math
import torch.nn.functional as F
import utils.utils as utils
import os
import matplotlib.pyplot as plt
from PIL import Image, ImageDraw, ImageFont
import spaces

MAX_NUM_WORDS = 77


class LayerFusion:
    def get_mask(self, maps, alpha, use_pool, x_t):
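        """Aggregate token-weighted attention maps into a binary blend mask.

        Sketch of the pipeline (shapes inferred from the comments below):
        weight the token dimension by `alpha`, average over heads,
        optionally max-pool to dilate, upsample to the spatial size of
        `x_t`, min-max normalize, then threshold at `self.mask_threshold`.
        """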
        k = 1
        maps = (maps * alpha).sum(-1).mean(1)
        if use_pool:
            maps = F.max_pool2d(maps, (k * 2 + 1, k * 2 + 1), (1, 1), padding=(k, k))
        mask = F.interpolate(maps, size=x_t.shape[2:])  # e.g. [2, 1, 128, 128]
        mask = mask / mask.max(2, keepdims=True)[0].max(3, keepdims=True)[0]
        mask = (mask - mask.min()) / (mask.max() - mask.min())  # min-max normalize to [0, 1]
        mask = mask.gt(self.mask_threshold)
        self.mask = mask
        mask = mask[:1] + mask  # union with the first (reference) mask
        return mask

    def get_one_mask(self, maps, use_pool, x_t, idx_lst, i=None, sav_img=False):
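        """Build a union mask over the attention maps of the token indices
        in `idx_lst`. With `sav_img=True` the per-object masks are instead
        thresholded at full resolution and written under ./img/sam_mask/
        for inspection.
        """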
        k = 1
        if not sav_img:
            mask_tot = 0
            for obj in idx_lst:
                mask = maps[0, :, :, :, obj].mean(0).reshape(1, 1, 32, 32)
                if use_pool:
                    mask = F.max_pool2d(mask, (k * 2 + 1, k * 2 + 1), (1, 1), padding=(k, k))
                mask = F.interpolate(mask, size=x_t.shape[2:])
                mask = mask / mask.max(2, keepdims=True)[0].max(3, keepdims=True)[0]
                mask = (mask - mask.min()) / (mask.max() - mask.min())
                mask = mask.gt(self.mask_threshold[int(self.counter / 10)])
                mask_tot |= mask  # union across object tokens
            return mask_tot
        else: 
            for obj in idx_lst:
                mask = maps[0, :, :, :, obj].mean(0).reshape(1, 1, 32, 32)
                if use_pool:
                    mask = F.max_pool2d(mask, (k * 2 + 1, k * 2 + 1), (1, 1), padding=(k, k))
                mask = F.interpolate(mask, size=(1024, 1024))  # [1, 1, 1024, 1024]
                mask = mask / mask.max(2, keepdims=True)[0].max(3, keepdims=True)[0]
                mask = (mask - mask.min()) / (mask.max() - mask.min())
                mask = mask.gt(0.6)
                mask = np.array(mask[0][0].clone().cpu()).astype(np.uint8) * 255
                os.makedirs('./img/sam_mask', exist_ok=True)
                cv2.imwrite(f'./img/sam_mask/{self.blend_list[i][0]}_{self.counter}.jpg', mask)
        return mask

    def mv_op(self, mp, op, scale=0.2, ones=False, flip=None):
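        """Translate a latent / mask along one axis by `scale` of its extent.

        E.g. for a [1, 4, 128, 128] map, op='right' with scale=0.2 shifts
        the content int(0.2 * 128) = 25 columns to the right and fills the
        vacated strip with zeros (ones when `ones=True`). `flip`, if given,
        is passed to torch.flip as its `dims` argument.
        """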
        _, b, H, W = mp.shape
        new_mp = torch.ones_like(mp) if ones else torch.zeros_like(mp)
        # shift size is a fraction of the relevant spatial extent
        K = int(scale * W) if op in ('right', 'left') else int(scale * H)
        if op == 'right':
            new_mp[:, :, :, K:] = mp[:, :, :, 0:W - K]
        elif op == 'left':
            new_mp[:, :, :, 0:W - K] = mp[:, :, :, K:]
        elif op == 'down':
            new_mp[:, :, K:, :] = mp[:, :, 0:H - K, :]
        elif op == 'up':
            new_mp[:, :, 0:H - K, :] = mp[:, :, K:, :]
        if flip is not None:
            new_mp = torch.flip(new_mp, dims=flip)

        return new_mp

    def mv_layer(self, x_t, bg_id, fg_id, op_id):
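        """Composite a translated foreground layer over a background layer.

        Batch layout (inferred from sample_ref_match in the run_* methods):
        index 1 is the inpainted background, index 2 the output canvas, and
        indices >= 3 the foreground layers, hence the `fg_id - 3` offsets
        into `self.fg_mask_list` and `self.op_list`.
        """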
        bg_img = x_t[bg_id:(bg_id+1)].clone()
        fg_img = x_t[fg_id:(fg_id+1)].clone()
        fg_mask = self.fg_mask_list[fg_id-3]
        op_list = self.op_list[fg_id-3]

        for item in op_list:
            op, scale = item[0], item[1]
            if scale != 0:
                fg_img = self.mv_op(fg_img, op=op, scale=scale)
                fg_mask = self.mv_op(fg_mask, op=op, scale=scale)
        x_t[op_id:(op_id+1)] = bg_img*(1-fg_mask) + fg_img*fg_mask

    def __call__(self, x_t):
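        """Per-step latent callback.

        While `counter` is inside `blend_time`, the background latent is
        inpainted by pasting the reference latent outside `remove_mask`;
        one step after the blend window closes, the foreground layers are
        composited onto the canvas (skipped in pure removal mode).
        """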
        self.counter += 1
        # inpainting
        if self.blend_time[0] <= self.counter <= self.blend_time[1]:
            x_t[1:2] = x_t[1:2]*self.remove_mask + x_t[0:1]*(1-self.remove_mask) 

        if self.counter == self.blend_time[1] + 1 and self.mode != "removal":
            b = x_t.shape[0]
            bg_id = 1  # background layer
            op_id = 2  # output canvas
            for fg_id in range(3, b):  # foreground layers
                self.mv_layer(x_t, bg_id=bg_id, fg_id=fg_id, op_id=op_id)
                bg_id = op_id
    
        return x_t

    def __init__(self, remove_mask, fg_mask_list, refine_mask=None,
                 blend_time=[0, 40], mode="removal", op_list=None):
        self.counter = 0
        self.mode = mode
        self.op_list = op_list
        self.blend_time = blend_time

        self.remove_mask = remove_mask
        self.refine_mask = refine_mask
        if self.refine_mask is not None:
            self.new_mask = self.remove_mask + self.refine_mask
            self.new_mask[self.new_mask > 0] = 1
        else:
            self.new_mask = None
        self.fg_mask_list = fg_mask_list


class Control:
    def step_callback(self, x_t):
        if self.layer_fusion is not None:
            x_t = self.layer_fusion(x_t)
        return x_t

    def __init__(self, layer_fusion):
        self.layer_fusion = layer_fusion

def register_attention_control(model, controller, mask_time=[0, 40], refine_time=[0, 25]):
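    """Monkey-patch every `Attention` module in `model.unet` so that, inside
    the `mask_time` window, self-attention keys are computed from hidden
    states whose masked region has been zeroed out, preventing the model
    from attending to the content being removed.
    """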
    def ca_forward(self, place_in_unet):
        to_out = self.to_out
        if isinstance(to_out, torch.nn.ModuleList):
            to_out = self.to_out[0]
        self.counter = 0  # denoising-step counter for this layer
        def forward(hidden_states, encoder_hidden_states=None, attention_mask=None):  # handles self- and cross-attention
            x = hidden_states.clone() 
            context = encoder_hidden_states
            is_cross = context is not None
            if is_cross is False:
                if controller.layer_fusion is not None and (mask_time[0] < self.counter < mask_time[1]):
                    b, i, j = x.shape
                    H = W = int(math.sqrt(i))
                    x_old = x.clone()
                    x = x.reshape(b, H, W, j)
                    new_mask = controller.layer_fusion.remove_mask
                    if new_mask is not None:
                        new_mask[new_mask > 0] = 1
                        new_mask = F.interpolate(new_mask.to(dtype=torch.float32).clone(), size=(H, W), mode='bilinear').cuda()
                        new_mask = (1 - new_mask).reshape(1, H, W).unsqueeze(-1)
                        if (refine_time[0] < self.counter <= refine_time[1]) and controller.layer_fusion.refine_mask is not None:
                            new_mask = controller.layer_fusion.new_mask
                            new_mask = F.interpolate(new_mask.to(dtype=torch.float32).clone(), size=(H, W), mode='bilinear').cuda()
                            new_mask = (1 - new_mask).reshape(1, H, W).unsqueeze(-1)
                        idx = 1  # inpaint idx: the background layer
                        x[int(b / 2) + idx, :, :] = x[int(b / 2) + idx, :, :] * new_mask[0]
                    x = x.reshape(b, i, j)
            if is_cross:
                q = self.to_q(x)
                k = self.to_k(context)
                v = self.to_v(context)
            else:
                # self-attention: keys come from the masked states so the
                # suppressed region cannot be attended to, while queries and
                # values keep the original hidden states
                context = x
                q = self.to_q(hidden_states)
                k = self.to_k(x)
                v = self.to_v(hidden_states)
            q = self.head_to_batch_dim(q)
            k = self.head_to_batch_dim(k)
            v = self.head_to_batch_dim(v)

            if hasattr(controller, 'count_layers'):
                controller.count_layers(place_in_unet, is_cross)
            sim = torch.einsum("b i d, b j d -> b i j", q.clone(), k.clone()) * self.scale

            attn = sim.softmax(dim=-1)
            out = torch.einsum("b i j, b j d -> b i d", attn, v)
            out = self.batch_to_head_dim(out)
            self.counter += 1
            return to_out(out)
        
        return forward

    def register_recr(net_, count, place_in_unet):
        if net_.__class__.__name__ == 'Attention':
            net_.forward = ca_forward(net_, place_in_unet)
            return count + 1
        elif hasattr(net_, 'children'):
            for net__ in net_.children():
                count = register_recr(net__, count, place_in_unet)
        return count

    cross_att_count = 0
    sub_nets = model.unet.named_children()
    for net in sub_nets:
        if "down" in net[0]:
            cross_att_count += register_recr(net[1], 0, "down")
        elif "up" in net[0]:
            cross_att_count += register_recr(net[1], 0, "up")
        elif "mid" in net[0]:
            cross_att_count += register_recr(net[1], 0, "mid")

    controller.num_att_layers = cross_att_count
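
# A minimal wiring sketch (hypothetical variable names; the real entry
# points are the DesignEdit.run_* methods below):
#
#   lf = LayerFusion(remove_mask=mask, fg_mask_list=None)
#   controller = Control(layer_fusion=lf)
#   register_attention_control(model=pipe, controller=controller)
#   images = pipe(controller=controller, prompt=prompts, latents=x_t, ...)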

class DesignEdit:
    def __init__(self, pretrained_model_path="/home/jyr/model/stable-diffusion-xl-base-1.0"):
        self.model_dtype = "fp16"
        self.pretrained_model_path = pretrained_model_path
        self.num_ddim_steps = 50
        self.mask_time = [0, 40]
        self.op_list = {}
        self.attend_scale = {}
        scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
        if self.model_dtype == "fp16":
            torch_dtype = torch.float16
        elif self.model_dtype == "fp32":
            torch_dtype = torch.float32
        self.pipe = sdxl.from_pretrained(self.pretrained_model_path, torch_dtype=torch_dtype, use_safetensors=True, variant=self.model_dtype, scheduler=scheduler)
   
    @spaces.GPU
    def init_model(self, num_ddim_steps=50):
        device = torch.device('cuda:0')
        self.pipe.to(device)
        inversion = Inversion(self.pipe, num_ddim_steps)
        return self.pipe, inversion
    
    @spaces.GPU(duration=120, enable_queue=True)
    def run_remove(self, original_image=None, mask_1=None, mask_2=None, mask_3=None, refine_mask=None, 
        ori_1=None, ori_2=None, ori_3=None,
        prompt="", save_dir="./tmp", mode='removal',):
        # 01-1: prepare: op_list, attend_scale, sample_ref_match
        self.ldm_model, self.inversion = self.init_model(num_ddim_steps=self.num_ddim_steps)
        if original_image is None:
            original_image = ori_1 if ori_1 is not None else ori_2 if ori_2 is not None else ori_3
        op_list = None
        attend_scale = 20
        sample_ref_match = {0: 0, 1: 0}
        ori_shape = original_image.shape

        # 01-2: prepare: image_gt, remove_mask, fg_mask_list, refine_mask
        image_gt = Image.fromarray(original_image).resize((1024, 1024))
        image_gt = np.stack([np.array(image_gt)])
        mask_list = [mask_1, mask_2, mask_3]
        remove_mask = utils.attend_mask(utils.add_masks_resized(mask_list), attend_scale=attend_scale) # numpy to tensor
        fg_mask_list = None
        refine_mask = utils.attend_mask(utils.convert_and_resize_mask(refine_mask)) if refine_mask is not None else None

        # 01-3: prepare: prompts, blend_time, refine_time
        prompts = len(sample_ref_match) * [prompt]  # one prompt per entry in sample_ref_match
        blend_time = [0, 41]
        refine_time = [0, 25]
        
        # 02: invert
        _, x_t, x_stars, prompt_embeds, pooled_prompt_embeds = self.inversion.invert(image_gt, prompts, inv_batch_size=1)
        
        # 03: init layer_fusion and controller
        lb = LayerFusion(remove_mask=remove_mask, fg_mask_list=fg_mask_list, refine_mask=refine_mask,
                    blend_time=blend_time, mode=mode, op_list=op_list)
        controller = Control(layer_fusion=lb)
        register_attention_control(model=self.ldm_model, controller=controller, mask_time=self.mask_time, refine_time=refine_time)
        
        # 04: generate images
        images = self.ldm_model(controller=controller, prompt=prompts,
                        latents=x_t, x_stars=x_stars,  
                        negative_prompt_embeds=prompt_embeds, 
                        negative_pooled_prompt_embeds=pooled_prompt_embeds,
                        sample_ref_match=sample_ref_match)
        folder = None
        utils.view_images(images, folder=folder)
        return [cv2.resize(images[1], (ori_shape[1], ori_shape[0]))]
    
    @spaces.GPU(duration=120, enable_queue=True)
    def run_zooming(self, original_image, width_scale=1, height_scale=1, prompt="", save_dir="./tmp", mode='removal'):
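        """Zoom out: shrink the image onto a larger canvas via
        utils.zooming, then treat the revealed border as a removal region
        to be outpainted.
        """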
        self.ldm_model, self.inversion = self.init_model(num_ddim_steps=self.num_ddim_steps)
        # 01-1: prepare: op_list, attend_scale, sample_ref_match
        op_list = {0: ['zooming', [height_scale, width_scale]]}
        ori_shape = original_image.shape
        attend_scale = 30
        sample_ref_match = {0: 0, 1: 0}

        # 01-2: prepare: image_gt, remove_mask, fg_mask_list, refine_mask
        img_new, mask = utils.zooming(original_image, [height_scale, width_scale])
        img_new_copy = img_new.copy()
        mask_copy = mask.copy()
        
        image_gt = Image.fromarray(img_new).resize((1024, 1024))
        image_gt = np.stack([np.array(image_gt)])

        remove_mask = utils.attend_mask(utils.convert_and_resize_mask(mask), attend_scale=attend_scale) # numpy to tensor
        fg_mask_list = None
        refine_mask = None

        # 01-3: prepare: prompts, blend_time, refine_time
        prompts = len(sample_ref_match) * [prompt]  # one prompt per entry in sample_ref_match
        blend_time = [0, 41]
        refine_time = [0, 25]

        # 02: invert
        _, x_t, x_stars, prompt_embeds, pooled_prompt_embeds = self.inversion.invert(image_gt, prompts, inv_batch_size=1)
        
        # 03: init layer_fusion and controller
        lb = LayerFusion(remove_mask=remove_mask, fg_mask_list=fg_mask_list, blend_time=blend_time,
                    mode=mode, op_list=op_list)
        controller = Control(layer_fusion=lb)
        register_attention_control(model=self.ldm_model, controller=controller, mask_time=self.mask_time, refine_time=refine_time)
        
        # 04: generate images
        images = self.ldm_model(controller=controller, prompt=prompts,
                        latents=x_t, x_stars=x_stars,  
                        negative_prompt_embeds=prompt_embeds, 
                        negative_pooled_prompt_embeds=pooled_prompt_embeds,
                        sample_ref_match=sample_ref_match)
        folder = None
        utils.view_images(images, folder=folder)
        resized_img = cv2.resize(images[1], (ori_shape[1], ori_shape[0]))
        return [resized_img], [img_new_copy], [mask_copy]
    
    @spaces.GPU(duration=120, enable_queue=True)
    def run_panning(self, original_image, w_direction, w_scale, h_direction, h_scale, prompt="", save_dir="./tmp", mode='removal'):
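        """Pan: shift the image on the canvas via utils.panning, then
        outpaint the uncovered strip as a removal region.
        """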
        # 01-1: prepare: op_list, attend_scale, sample_ref_match
        self.ldm_model, self.inversion = self.init_model(num_ddim_steps=self.num_ddim_steps)
        ori_shape = original_image.shape
        attend_scale = 30
        sample_ref_match = {0: 0, 1: 0}

        # 01-2: prepare: image_gt, remove_mask, fg_mask_list, refine_mask
        op_list = [[w_direction, w_scale], [h_direction, h_scale]]
        img_new, mask = utils.panning(original_image, op_list=op_list)
        img_new_copy = img_new.copy()
        mask_copy = mask.copy()
        
        image_gt = Image.fromarray(img_new).resize((1024, 1024))
        image_gt = np.stack([np.array(image_gt)])
        remove_mask = utils.attend_mask(utils.convert_and_resize_mask(mask), attend_scale=attend_scale) # numpy to tensor

        fg_mask_list = None
        refine_mask = None

        # 01-3: prepare: prompts, blend_time, refine_time
        prompts = len(sample_ref_match) * [prompt]  # one prompt per entry in sample_ref_match
        blend_time = [0, 41]
        refine_time = [0, 25]

        # 02: invert
        _, x_t, x_stars, prompt_embeds, pooled_prompt_embeds = self.inversion.invert(image_gt, prompts, inv_batch_size=1)
        # 03: init layer_fusion and controller
        lb = LayerFusion(remove_mask=remove_mask, fg_mask_list=fg_mask_list, blend_time=blend_time,
                    mode=mode, op_list=op_list)
        controller = Control(layer_fusion=lb)
        register_attention_control(model=self.ldm_model, controller=controller, mask_time=self.mask_time, refine_time=refine_time)
        
        # 04: generate images

        images = self.ldm_model(controller=controller, prompt=prompts,
                        latents=x_t, x_stars=x_stars,  
                        negative_prompt_embeds=prompt_embeds, 
                        negative_pooled_prompt_embeds=pooled_prompt_embeds,
                        sample_ref_match=sample_ref_match)
        folder = None
        utils.view_images(images, folder=folder)
        resized_img = cv2.resize(images[1], (ori_shape[1], ori_shape[0]))
        return [resized_img], [img_new_copy], [mask_copy]

    # layer-wise multi-object editing
    def process_layer_states(self, layer_states):
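        """Convert per-layer UI state tuples into resized images, attended
        masks, and move commands, plus a sample_ref_match truncated to the
        number of layers actually provided.
        """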
        self.ldm_model, self.inversion = self.init_model(num_ddim_steps=self.num_ddim_steps)
        image_paths = []
        mask_paths = []
        op_list = []
        
        for state in layer_states:
            img, mask, dx, dy, resize, w_flip, h_flip = state
            if img is not None:  
                img = cv2.resize(img, (1024, 1024))
                mask = utils.convert_and_resize_mask(mask)
                dx_command = ['right', dx] if dx > 0 else ['left', -dx]
                dy_command = ['up', dy] if dy > 0 else ['down', -dy]
                flip_code = None
                if w_flip == "left/right" and h_flip == "down/up":
                    flip_code = -1
                elif w_flip == "left/right":
                    flip_code = 1  # or another default value, set as needed
                elif h_flip == "down/up":
                    flip_code = 0
                op_list.append([dx_command, dy_command])
                img, mask, _ = utils.resize_image_with_mask(img, mask, resize)
                img, mask, _ = utils.flip_image_with_mask(img, mask, flip_code=flip_code)
                image_paths.append(img)
                mask_paths.append(utils.attend_mask(mask))
        sample_ref_match = {0: 0, 1: 0, 2: 0, 3: 1, 4: 2, 5: 3}
        required_length = len(image_paths) + 3
        truncated_sample_ref_match = {k: sample_ref_match[k] for k in sorted(sample_ref_match.keys())[:required_length]}
        return image_paths, mask_paths, op_list, truncated_sample_ref_match

    @spaces.GPU(duration=200)
    def run_layer(self, bg_img, l1_img, l1_dx, l1_dy, l1_resize, l1_w_flip, l1_h_flip, 
        l2_img, l2_dx, l2_dy, l2_resize, l2_w_flip, l2_h_flip,
        l3_img, l3_dx, l3_dy, l3_resize, l3_w_flip, l3_h_flip,
        bg_mask, l1_mask, l2_mask, l3_mask,
        bg_ori=None, l1_ori=None, l2_ori=None, l3_ori=None,
        prompt="", save_dir="./tmp", mode='layerwise'):
        self.ldm_model, self.inversion = self.init_model(num_ddim_steps=self.num_ddim_steps)
        # 00: prepare: layer-wise states
        bg_img = bg_ori if bg_ori is not None else bg_img
        l1_img = l1_ori if l1_ori is not None else l1_img
        l2_img = l2_ori if l2_ori is not None else l2_img
        l3_img = l3_ori if l3_ori is not None else l3_img
        # normalize masks; rebuild the list because reassigning the loop
        # variable would not update the originals
        masks = []
        for mask in [bg_mask, l1_mask, l2_mask, l3_mask]:
            if mask is None:
                masks.append(np.zeros((1024, 1024), dtype=np.uint8))
            else:
                masks.append(utils.convert_and_resize_mask(mask))
        bg_mask, l1_mask, l2_mask, l3_mask = masks
        l1_state = [l1_img, l1_mask, l1_dx, l1_dy, l1_resize, l1_w_flip, l1_h_flip]
        l2_state = [l2_img, l2_mask, l2_dx, l2_dy, l2_resize, l2_w_flip, l2_h_flip]
        l3_state = [l3_img, l3_mask, l3_dx, l3_dy, l3_resize, l3_w_flip, l3_h_flip]
        ori_shape = bg_img.shape

        image_paths, fg_mask_list, op_list, sample_ref_match = self.process_layer_states([l1_state, l2_state, l3_state])
        if not image_paths:
            mode = "removal"
        # 01-1: prepare: image_gt, remove_mask, fg_mask_list, refine_mask
        attend_scale = 20
        image_gt = [bg_img] + image_paths
        image_gt = [Image.fromarray(img).resize((1024, 1024)) for img in image_gt]
        image_gt = np.stack(image_gt)      
        remove_mask = utils.attend_mask(bg_mask, attend_scale=attend_scale)
        refine_mask = None

        # 01-2: prepare: prompts, blend_time, refine_time
        prompts = len(sample_ref_match) * [prompt]  # one prompt per entry in sample_ref_match
        blend_time = [0, 41]
        refine_time = [0, 25]

        # 02: invert
        _, x_t, x_stars, prompt_embeds, pooled_prompt_embeds = self.inversion.invert(image_gt, prompts, inv_batch_size=len(image_gt))
        # 03: init layer_fusion and controller
        lb = LayerFusion(remove_mask=remove_mask, fg_mask_list=fg_mask_list, blend_time=blend_time, refine_mask=refine_mask,
                    mode=mode, op_list=op_list)
        controller = Control(layer_fusion=lb)
        register_attention_control(model=self.ldm_model, controller=controller, mask_time=self.mask_time, refine_time=refine_time)
        # 04: generate images
        images = self.ldm_model(controller=controller, prompt=prompts,
                        latents=x_t, x_stars=x_stars,  
                        negative_prompt_embeds=prompt_embeds, 
                        negative_pooled_prompt_embeds=pooled_prompt_embeds,
                        sample_ref_match=sample_ref_match)
        folder = None
        utils.view_images(images, folder=folder) 
        if mode == 'removal':
            resized_img = cv2.resize(images[1], (ori_shape[1], ori_shape[0]))       
        else:
            resized_img = cv2.resize(images[2], (ori_shape[1], ori_shape[0]))       
        return [resized_img]

    @spaces.GPU(duration=120, enable_queue=True)
    def run_moving(self, bg_img, bg_ori, bg_mask, l1_dx, l1_dy, l1_resize, 
        l1_w_flip=None, l1_h_flip=None, selected_points=None,
        prompt="", save_dir="./tmp", mode='layerwise'):
        self.ldm_model, self.inversion = self.init_model(num_ddim_steps=self.num_ddim_steps)
        # 00: prepare: layer-wise states
        bg_img = bg_ori if bg_ori is not None else bg_img
        l1_img = bg_img
        if bg_mask is None:
            bg_mask = np.zeros((1024, 1024), dtype=np.uint8)
        else:
            bg_mask = utils.convert_and_resize_mask(bg_mask)
        l1_mask = bg_mask
        l1_state = [l1_img, l1_mask, l1_dx, l1_dy, l1_resize, l1_w_flip, l1_h_flip]
        ori_shape = bg_img.shape

        image_paths, fg_mask_list, op_list, sample_ref_match = self.process_layer_states([l1_state])

        # 01-1: prepare: image_gt, remove_mask, fg_mask_list, refine_mask
        attend_scale = 20
        image_gt = [bg_img] + image_paths
        image_gt = [Image.fromarray(img).resize((1024, 1024)) for img in image_gt]
        image_gt = np.stack(image_gt)      
        remove_mask = utils.attend_mask(bg_mask, attend_scale=attend_scale)
        refine_mask = None

        # 01-2: prepare: prompts, blend_time, refine_time
        prompts = len(sample_ref_match) * [prompt]  # one prompt per entry in sample_ref_match
        blend_time = [0, 41]
        refine_time = [0, 25]

        # 02: invert
        _, x_t, x_stars, prompt_embeds, pooled_prompt_embeds = self.inversion.invert(image_gt, prompts, inv_batch_size=len(image_gt))
        # 03: init layer_fusion and controller
        lb = LayerFusion(remove_mask=remove_mask, fg_mask_list=fg_mask_list, blend_time=blend_time, refine_mask=refine_mask,
                    mode=mode, op_list=op_list)
        controller = Control(layer_fusion=lb)
        register_attention_control(model=self.ldm_model, controller=controller, mask_time=self.mask_time, refine_time=refine_time)
        # 04: generate images
        images = self.ldm_model(controller=controller, prompt=prompts,
                        latents=x_t, x_stars=x_stars,  
                        negative_prompt_embeds=prompt_embeds, 
                        negative_pooled_prompt_embeds=pooled_prompt_embeds,
                        sample_ref_match=sample_ref_match)
        folder = None
        utils.view_images(images, folder=folder) 
        resized_img = cv2.resize(images[2], (ori_shape[1], ori_shape[0]))       
        return [resized_img]

    # merge up to four masks into a single 1024x1024 uint8 mask
    def run_mask(self, mask_1, mask_2, mask_3, mask_4):
        mask_list = [mask_1, mask_2, mask_3, mask_4]
        final_mask = utils.add_masks_resized(mask_list)
        return final_mask
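
# A minimal sketch of driving the pipeline directly, assuming a local SDXL
# checkpoint and a grayscale object mask on disk (paths and filenames are
# illustrative only):
#
#   editor = DesignEdit(pretrained_model_path="./stable-diffusion-xl-base-1.0")
#   image = np.array(Image.open("input.jpg").convert("RGB"))
#   mask = np.array(Image.open("mask.png").convert("L"))
#   result = editor.run_remove(original_image=image, mask_1=mask)
#   Image.fromarray(result[0]).save("removed.jpg")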