import spaces

import os
import torch
import numpy as np
import torch.nn.functional as F
import cv2
import torchvision
from PIL import Image
from einops import rearrange
import tempfile

from objctrl_2_5d.utils.objmask_util import RT2Plucker, Unprojected, roll_with_ignore_multidim, dilate_mask_pytorch
from objctrl_2_5d.utils.filter_utils import get_freq_filter, freq_mix_3d

DEBUG = False

if DEBUG:
    cur_OUTPUT_PATH = 'outputs/tmp'
    os.makedirs(cur_OUTPUT_PATH, exist_ok=True)

min_guidance_scale = 1.0
max_guidance_scale = 3.0

area_ratio = 0.3
depth_scale_ = 5.2
center_margin = 10

height, width = 320, 576
num_frames = 14

intrinsics = np.array([[float(width), float(width), float(width) / 2, float(height) / 2]])
intrinsics = np.repeat(intrinsics, num_frames, axis=0) # [n_frame, 4]
fx = intrinsics[0, 0] / width
fy = intrinsics[0, 1] / height
cx = intrinsics[0, 2] / width
cy = intrinsics[0, 3] / height

down_scale = 8
H, W = height // down_scale, width // down_scale
K = np.array([[width / down_scale, 0, W / 2], [0, width / down_scale, H / 2], [0, 0, 1]])
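# e.g. with width=576, height=320, down_scale=8: H=40, W=72, so
# K = [[72, 0, 36], [0, 72, 20], [0, 0, 1]] (square pixels, focal length = W)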

@spaces.GPU(duration=50)
def run(pipeline, device):
    def run_objctrl_2_5d(condition_image, 
                         mask, 
                         depth, 
                         RTs, 
                         bg_mode, 
                         shared_warping_latents, 
                         scale_wise_masks, 
                         rescale, 
                         seed, 
                         ds, dt, 
                         num_inference_steps=25):
        
        seed = int(seed)
                
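        # mean depth over a small patch at the image center, used to set the
        # scene scale: rescale > 0 maps this to roughly depth_scale_ * rescale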
        center_h_margin, center_w_margin = center_margin, center_margin
        depth_center = np.mean(depth[height//2-center_h_margin:height//2+center_h_margin, width//2-center_w_margin:width//2+center_w_margin])
        
        if rescale > 0:
            depth_rescale = round(depth_scale_ * rescale / depth_center, 2)
        else:
            depth_rescale = 1.0
            
        depth = depth * depth_rescale
        
        depth_down = F.interpolate(torch.tensor(depth).unsqueeze(0).unsqueeze(0), 
                                    (H, W), mode='bilinear', align_corners=False).squeeze().numpy() # [H, W]
        
        ## latent
        generator = torch.Generator()
        generator.manual_seed(seed)
        
        latents_org = pipeline.prepare_latents(
                1,
                num_frames,
                8,
                height,
                width,
                pipeline.dtype,
                device,
                generator,
                None,
            )
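        # prepare_latents scales the noise by the scheduler's init_noise_sigma; it is
        # divided out here so the warping / frequency mixing below operates on
        # unit-variance noise (assuming an SVD-style prepare_latents that rescales
        # latents passed into the pipeline call)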
        latents_org = latents_org / pipeline.scheduler.init_noise_sigma
        
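        # encode the camera trajectory as Plücker ray embeddings and run them
        # through the pose encoder to get multi-scale object-branch pose features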
        cur_plucker_embedding, _, _ = RT2Plucker(RTs, RTs.shape[0], (height, width), fx, fy, cx, cy) # 6, V, H, W
        cur_plucker_embedding = cur_plucker_embedding.to(device)
        cur_plucker_embedding = cur_plucker_embedding[None, ...] # b 6 f h w
        cur_plucker_embedding = cur_plucker_embedding.permute(0, 2, 1, 3, 4) # b f 6 h w
        cur_plucker_embedding = cur_plucker_embedding[:, :num_frames, ...]
        cur_pose_features = pipeline.pose_encoder(cur_plucker_embedding)
        
        # bg_mode = ["Fixed", "Reverse", "Free"]
        if bg_mode == "Fixed":
            fix_RTs = np.repeat(RTs[0][None, ...], num_frames, axis=0) # [n_frame, 3, 4]
            fix_plucker_embedding, _, _ = RT2Plucker(fix_RTs, num_frames, (height, width), fx, fy, cx, cy) # 6, V, H, W
            fix_plucker_embedding = fix_plucker_embedding.to(device)
            fix_plucker_embedding = fix_plucker_embedding[None, ...] # b 6 f h w
            fix_plucker_embedding = fix_plucker_embedding.permute(0, 2, 1, 3, 4) # b f 6 h w
            fix_plucker_embedding = fix_plucker_embedding[:, :num_frames, ...]
            fix_pose_features = pipeline.pose_encoder(fix_plucker_embedding)
            
        elif bg_mode == "Reverse":
            bg_plucker_embedding, _, _ = RT2Plucker(RTs[::-1], RTs.shape[0], (height, width), fx, fy, cx, cy) # 6, V, H, W
            bg_plucker_embedding = bg_plucker_embedding.to(device)
            bg_plucker_embedding = bg_plucker_embedding[None, ...] # b 6 f h w
            bg_plucker_embedding = bg_plucker_embedding.permute(0, 2, 1, 3, 4) # b f 6 h w
            bg_plucker_embedding = bg_plucker_embedding[:, :num_frames, ...]
            fix_pose_features = pipeline.pose_encoder(bg_plucker_embedding)
            
        else:
            fix_pose_features = None
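        # fix_pose_features drives the background branch: frozen first-frame camera
        # ("Fixed"), reversed trajectory ("Reverse"), or None ("Free"; replaced by
        # zeros when blending below)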
            
        #### preparing mask
        
        mask = Image.fromarray(mask)
        mask = mask.resize((W, H))
        mask = np.array(mask).astype(np.float32)
        mask = np.expand_dims(mask, axis=-1)
        
        # visualize mask
        if DEBUG:
            mask_sum_vis = mask[..., 0]
            mask_sum_vis = (mask_sum_vis * 255.0).astype(np.uint8)
            mask_sum_vis = Image.fromarray(mask_sum_vis)
            
            mask_sum_vis.save(f'{cur_OUTPUT_PATH}/org_mask.png')
        
        try:
            warped_masks = Unprojected(mask, depth_down, RTs, H=H, W=W, K=K)
        
            warped_masks.insert(0, mask)
                
        except Exception:
            # fallback: shift the mask by the motion of the reprojected bbox center
            print('!!! Mask is too small to warp; falling back to bbox-based shifting')
            mask = mask[:, :, 0]
            coords = cv2.findNonZero(mask.astype(np.uint8))  # findNonZero expects an 8-bit single-channel array
            x, y, w, h = cv2.boundingRect(coords)
            # mask[y:y+h, x:x+w] = 1.0
            
            center_x, center_y = x + w // 2, y + h // 2
            center_z = depth_down[center_y, center_x]
            
            # RTs [n_frame, 3, 4] to [n_frame, 4, 4] , add [0, 0, 0, 1]
            RTs = np.concatenate([RTs, np.array([[[0, 0, 0, 1]]] * num_frames)], axis=1)
            
            # RTs are world-to-camera; unproject the bbox center from frame 0 into
            # world space, then reproject it into every frame to get per-frame shifts
            P0 = np.array([center_x, center_y, 1])
            Pc0 = np.linalg.inv(K) @ P0 * center_z
            pw = np.linalg.inv(RTs[0]) @ np.array([Pc0[0], Pc0[1], center_z, 1]) # [4]
            
            P = [np.array([center_x, center_y])]
            for i in range(1, num_frames):
                Pci = RTs[i] @ pw
                Pi = K @ Pci[:3] / Pci[2]
                P.append(Pi[:2])
            
            warped_masks = [mask]
            for i in range(1, num_frames):
                shift_x = int(round(P[i][0] - P[0][0]))
                shift_y = int(round(P[i][1] - P[0][1]))

                cur_mask = roll_with_ignore_multidim(mask, [shift_y, shift_x])
                warped_masks.append(cur_mask)
                
                
            warped_masks = [v[..., None] for v in warped_masks]
                
        warped_masks = np.stack(warped_masks, axis=0) # [f, h, w, 1]
        warped_masks = np.repeat(warped_masks, 3, axis=-1) # [f, h, w, 3]
        
        mask_sum = np.sum(warped_masks, axis=0, keepdims=True)  # [1, H, W, 3]
        mask_sum[mask_sum > 1.0] = 1.0
        mask_sum = mask_sum[0, :, :, 0]
        
        if DEBUG:
            ## visualize warp mask
            warp_masks_vis = torch.tensor(warped_masks)
            warp_masks_vis = (warp_masks_vis * 255.0).to(torch.uint8)
            torchvision.io.write_video(f'{cur_OUTPUT_PATH}/warped_masks.mp4', warp_masks_vis, fps=10, video_codec='h264', options={'crf': '10'})
            
            # visualize mask
            mask_sum_vis = mask_sum
            mask_sum_vis = (mask_sum_vis * 255.0).astype(np.uint8)
            mask_sum_vis = Image.fromarray(mask_sum_vis)
            
            mask_sum_vis.save(f'{cur_OUTPUT_PATH}/merged_mask.png')
            
        if scale_wise_masks:
            min_area = H * W * area_ratio # computed at the downscaled resolution
            non_zero_len = mask_sum.sum()
            
            print(f'non_zero_len: {non_zero_len}, min_area: {min_area}')
            
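            # one dilation kernel per pose-feature scale: the smaller the visible
            # object, the stronger the dilation so it still covers coarse scales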
            if non_zero_len > min_area:
                kernel_sizes = [1, 1, 1, 3]
            elif non_zero_len > min_area * 0.5:
                kernel_sizes = [3, 1, 1, 5]
            else:
                kernel_sizes = [5, 3, 3, 7]
        else:
            kernel_sizes = [1, 1, 1, 1]
            
        mask = torch.from_numpy(mask_sum) # [h, w]
        mask = mask[None, None, ...] # [1, 1, h, w]
        mask = F.interpolate(mask, (height, width), mode='bilinear', align_corners=False) # [1, 1, H, W]
        # mask = mask.repeat(1, num_frames, 1, 1) # [1, f, H, W]
        mask = mask.to(pipeline.dtype).to(device)
        
        ##### Mask End ######
        
        ### Got blending pose features Start ###
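        # blend the two branches per scale with the (dilated) object mask: object
        # pixels follow the user trajectory (cur_pose_features), background pixels
        # follow fix_pose_features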
    
        pose_features = []
        for i in range(len(cur_pose_features)):
            kernel_size = kernel_sizes[i]
            h, w = cur_pose_features[i].shape[-2:]
            
            if fix_pose_features is None:
                pose_features.append(torch.zeros_like(cur_pose_features[i]))
            else:
                pose_features.append(fix_pose_features[i])
                
            cur_mask = F.interpolate(mask, (h, w), mode='bilinear', align_corners=False)
            cur_mask = dilate_mask_pytorch(cur_mask, kernel_size=kernel_size) # [1, 1, h, w]
            cur_mask = cur_mask.repeat(1, num_frames, 1, 1) # [1, f, h, w]
            
            if DEBUG:
                # visualize mask
                mask_vis = cur_mask[0, 0].cpu().numpy() * 255.0
                mask_vis = Image.fromarray(mask_vis.astype(np.uint8))
                mask_vis.save(f'{cur_OUTPUT_PATH}/mask_k{kernel_size}_scale{i}.png')
                
            cur_mask = cur_mask[None, ...] # [1, 1, f, h, w]
            pose_features[-1] = cur_pose_features[i] * cur_mask + pose_features[-1] * (1 - cur_mask)

        ### Got blending pose features End ###
        
        ##### Warp Noise Start ######
        
        if shared_warping_latents:
            noise = latents_org[0, 0].data.cpu().numpy().copy() # first-frame latent: [4, 40, 72]
            noise = np.transpose(noise, (1, 2, 0)) # [40, 72, 4]

            try:
                warp_noise = Unprojected(noise, depth_down, RTs, H=H, W=W, K=K)
                warp_noise.insert(0, noise)
            except Exception:
                print('!!! Noise is too small to warp; falling back to bbox-based shifting')
                # NOTE: reuses P from the mask-warping fallback above; if mask warping
                # succeeded but noise warping failed, P is undefined here
                warp_noise = [noise]
                for i in range(1, num_frames):
                    shift_x = int(round(P[i][0] - P[0][0]))
                    shift_y = int(round(P[i][1] - P[0][1]))

                    cur_noise = roll_with_ignore_multidim(noise, [shift_y, shift_x])
                    warp_noise.append(cur_noise)

            warp_noise = np.stack(warp_noise, axis=0) # [f, h, w, 4]
        
            if DEBUG:
                ## visualize warp noise
                warp_noise_vis = torch.tensor(warp_noise)[..., :3] * torch.tensor(warped_masks)
                warp_noise_vis = (warp_noise_vis - warp_noise_vis.min()) / (warp_noise_vis.max() - warp_noise_vis.min())
                warp_noise_vis = (warp_noise_vis * 255.0).to(torch.uint8)
        
                torchvision.io.write_video(f'{cur_OUTPUT_PATH}/warp_noise.mp4', warp_noise_vis, fps=10, video_codec='h264', options={'crf': '10'})
        
        
            warp_latents = torch.tensor(warp_noise).permute(0, 3, 1, 2).to(latents_org.device).to(latents_org.dtype) # [frame, 4, H, W]
            warp_latents = warp_latents.unsqueeze(0) # [1, frame, 4, H, W]
            
            warped_masks = torch.tensor(warped_masks).permute(0, 3, 1, 2).unsqueeze(0) # [1, frame, 3, H, W]
            mask_extend = torch.cat([warped_masks, warped_masks[:, :, 0:1]], dim=2) # [1, frame, 4, H, W]
            mask_extend = mask_extend.to(latents_org.device).to(latents_org.dtype)
            
            warp_latents = warp_latents * mask_extend + latents_org * (1 - mask_extend)
            warp_latents = warp_latents.permute(0, 2, 1, 3, 4)
            random_noise = latents_org.clone().permute(0, 2, 1, 3, 4)
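            # frequency-domain noise mixing (FreeInit-style): keep the low-frequency
            # band of the warped latents and take high frequencies from fresh noise;
            # ds / dt set the spatial / temporal Butterworth cutoffs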
                
            filter_shape = warp_latents.shape

            freq_filter = get_freq_filter(
                filter_shape,
                device=device,
                filter_type='butterworth',
                n=4,
                d_s=ds,
                d_t=dt
            )
            
            warp_latents = freq_mix_3d(warp_latents, random_noise, freq_filter)
            warp_latents = warp_latents.permute(0, 2, 1, 3, 4)
            
        else:
            warp_latents = latents_org.clone()
            
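        # NOTE: the denoising loop is re-seeded with a fixed value, so the user
        # seed only affects the initial latents prepared above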
        generator.manual_seed(42)

        with torch.no_grad():
            result = pipeline(
                image=condition_image,
                pose_embedding=cur_plucker_embedding,
                height=height,
                width=width,
                num_frames=num_frames,
                num_inference_steps=num_inference_steps,
                min_guidance_scale=min_guidance_scale,
                max_guidance_scale=max_guidance_scale,
                do_image_process=True,
                generator=generator,
                output_type='pt',
                pose_features=pose_features,
                latents=warp_latents
            ).frames[0].cpu() # [f, c, h, w]
            
        
        result = rearrange(result, 'f c h w -> f h w c')
        result = (result * 255.0).to(torch.uint8)

        # delete=False keeps the file on disk after the handle is garbage-collected
        video_path = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False).name
        torchvision.io.write_video(video_path, result, fps=10, video_codec='h264', options={'crf': '8'})
        
        return video_path
    
    return run_objctrl_2_5d
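
# Example usage (a sketch, not part of the app). Assumes `pipeline` is an
# SVD-style camera-control pipeline exposing `pose_encoder`, `prepare_latents`,
# and the call signature used above, already moved to `device`; input shapes
# follow the conventions used in this file.
#
#   run_fn = run(pipeline, device)
#   video_path = run_fn(
#       condition_image,              # PIL.Image at 576x320
#       mask,                         # np.uint8 [320, 576], binary {0, 1} object mask
#       depth,                        # np.float32 [320, 576] depth map
#       RTs,                          # np.float32 [14, 3, 4] world-to-camera poses
#       "Fixed",                      # bg_mode: "Fixed" | "Reverse" | "Free"
#       True,                         # shared_warping_latents
#       True,                         # scale_wise_masks
#       1.0,                          # rescale (0 disables depth rescaling)
#       42,                           # seed
#       0.25, 0.1,                    # ds, dt (example cutoff values)
#   )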