import os
import argparse

import torch
from torchvision import utils

from model.sg2_model import Generator
from tqdm import tqdm
from pathlib import Path

import numpy as np

import subprocess
import shutil
import copy

from styleclip.styleclip_global import style_tensor_to_style_dict, style_dict_to_style_tensor

# Names of the supported InterfaceGAN edits. Each one needs an entry in
# SUGGESTED_DISTANCES below and a matching "<name>.pt" boundary file under
# editing/interfacegan_boundaries (see project_code_by_edit_name).
VALID_EDITS = ["pose", "age", "smile", "gender", "hair_length", "beard"]

# Default step length along each edit's boundary; project_code_by_edit_name
# scales it by the caller's strength before adding `distance * boundary` to the
# latent. A negative entry moves against the stored boundary direction.
SUGGESTED_DISTANCES = dict(
    pose=3.0,
    smile=2.0,
    age=4.0,
    gender=3.0,
    hair_length=-4.0,
    beard=2.0,
)
                    
def project_code(latent_code, boundary, distance=3.0):
    """Move ``latent_code`` by ``distance`` along the direction ``boundary``.

    Args:
        latent_code: Array/tensor whose last dimension matches the boundary's
            last dimension (e.g. W codes ``(B, D)`` or W+ codes ``(B, L, D)``).
        boundary: Edit direction. A single direction stored as ``(1, D)`` is
            flattened to ``(D,)`` so it broadcasts over every leading
            dimension of ``latent_code``. Other shapes (e.g. per-layer
            ``(L, D)``) are used as-is and rely on normal broadcasting.
        distance: Signed step length; negative values move against
            ``boundary``.

    Returns:
        ``latent_code + distance * boundary``.
    """
    # The original check was `len(boundary) == 2`, which counts rows rather than
    # dimensions: it skipped the usual (1, D) boundary and wrongly flattened
    # any 2-row boundary into a single (1, 1, 2D) vector.
    if np.ndim(boundary) == 2 and np.shape(boundary)[0] == 1:
        boundary = boundary.reshape(-1)

    return latent_code + distance * boundary

def project_code_by_edit_name(latent_code, name, strength):
    """Apply the InterfaceGAN edit called ``name`` to ``latent_code``.

    The step length is ``SUGGESTED_DISTANCES[name] * strength``; the direction
    is loaded (on CPU, converted to NumPy) from
    ``editing/interfacegan_boundaries/<name>.pt`` next to this file.

    Raises:
        KeyError: if ``name`` has no entry in SUGGESTED_DISTANCES. The lookup
            happens before any file is opened, so it also acts as the
            whitelist for the file name built from ``name``.
    """
    step = SUGGESTED_DISTANCES[name] * strength

    module_dir = Path(os.path.abspath(__file__)).parent
    boundary_path = str(module_dir / "editing" / "interfacegan_boundaries" / f'{name}.pt')
    direction = torch.load(boundary_path, map_location="cpu").numpy()

    return project_code(latent_code, direction, step)

def generate_frames(source_latent, target_latents, g_ema_list, output_dir):
    """Render a source -> target -> source latent walk as numbered JPEG frames.

    Frames are written to ``output_dir`` as ``000.jpg``, ``001.jpg``, ... (the
    pattern ``video_from_interpolations`` reads). The directory is not
    created here.

    Args:
        source_latent: Starting latent. For W/W+ targets it is squeezed along
            dim 0; for style-code targets it is converted to a flattened style
            tensor via the first generator's ``get_s_code``.
        target_latents: Non-empty list of torch tensors to walk towards. A size
            of 9088 in dim 1 marks them as flattened style (S) codes.
        g_ema_list: Generators. With a single entry it is used directly. With
            several, a deep copy of the first is used and its parameters are
            linearly blended between consecutive generators over the course of
            the video, so the caller's models are never modified.
        output_dir: Directory the frames are written to.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    code_is_s = target_latents[0].size()[1] == 9088

    if code_is_s:
        source_s_dict = g_ema_list[0].get_s_code(source_latent, input_is_latent=True)[0]
        np_latent = style_dict_to_style_tensor(source_s_dict, g_ema_list[0]).cpu().detach().numpy()
    else:
        np_latent = source_latent.squeeze(0).cpu().detach().numpy()

    np_target_latents = [target_latent.cpu().detach().numpy() for target_latent in target_latents]

    # For W/W+ codes use at most 10 steps per target and ~30 forward steps in
    # total. max(1, ...) keeps at least one step: with more than 30 targets
    # 30 // n is 0, which previously produced no frames at all.
    num_alphas = 20 if code_is_s else max(1, min(10, 30 // len(target_latents)))

    alphas = np.linspace(0, 1, num=num_alphas)

    latents = interpolate_with_target_latents(np_latent, np_target_latents, alphas)

    segments = len(g_ema_list) - 1

    if segments:
        # Each consecutive pair of generators gets an equal share of the frames.
        segment_length = len(latents) / segments

        # Blend into a private copy so the caller's generators stay untouched.
        g_ema = copy.deepcopy(g_ema_list[0])

        src_pars = dict(g_ema.named_parameters())
        mix_pars = [dict(model.named_parameters()) for model in g_ema_list]
    else:
        g_ema = g_ema_list[0]

    print("Generating frames for video...")
    for idx, latent in tqdm(enumerate(latents), total=len(latents)):

        # Weight blending and synthesis are inference only; no_grad avoids
        # building autograd graphs for the blended parameter tensors.
        with torch.no_grad():
            if segments:
                mix_alpha = (idx % segment_length) / segment_length
                segment_id = int(idx // segment_length)

                for k in src_pars.keys():
                    src_pars[k].data.copy_(mix_pars[segment_id][k] * (1 - mix_alpha) + mix_pars[segment_id + 1][k] * mix_alpha)

            # The hold-on-target frames are the same object repeated, so the
            # previous image is reused; when weights are blended every frame
            # must be regenerated.
            if idx == 0 or segments or latent is not latents[idx - 1]:
                latent_tensor = torch.from_numpy(latent).float().to(device)

                if code_is_s:
                    latent_for_gen = style_tensor_to_style_dict(latent_tensor, g_ema)
                    img, _ = g_ema(latent_for_gen, input_is_s_code=True, input_is_latent=True, truncation=1, randomize_noise=False)
                else:
                    img, _ = g_ema([latent_tensor], input_is_latent=True, truncation=1, randomize_noise=False)

        # Map the generator's [-1, 1] output to [0, 1] explicitly. This is what
        # normalize=True with range=(-1, 1) did, but torchvision renamed that
        # keyword to `value_range` and later removed `range`, so passing it now
        # raises TypeError. The out-of-place clamp keeps the reused `img` intact.
        frame = (img.clamp(-1, 1) + 1) / 2
        utils.save_image(frame, f"{output_dir}/{str(idx).zfill(3)}.jpg", nrow=1)

def interpolate_forward_backward(source_latent, target_latent, alphas):
    """Build a there-and-back sequence between two latents.

    The result is ``3 * len(alphas)`` entries:
      * the blends ``alpha * target + (1 - alpha) * source`` for each alpha,
      * ``target_latent`` itself repeated ``len(alphas)`` times (the same
        object each time, which lets a renderer detect repeats by identity),
      * the blends again in reverse order (the same objects as the first part).
    """
    hold = len(alphas)

    outbound = []
    for alpha in alphas:
        outbound.append(alpha * target_latent + (1 - alpha) * source_latent)

    inbound = list(reversed(outbound))
    pause = [target_latent] * hold

    return outbound + pause + inbound

def interpolate_with_target_latents(source_latent, target_latents, alphas):
    """Chain a source -> target -> source sweep for every target latent.

    Each sweep comes from ``interpolate_forward_backward``. The sweeps are
    concatenated in the order of ``target_latents`` into one flat list.
    """
    print("Interpolating latent codes...")

    return [
        frame_latent
        for target in target_latents
        for frame_latent in interpolate_forward_backward(source_latent, target, alphas)
    ]

def video_from_interpolations(fps, output_dir):
    """Encode the numbered JPEG frames in ``output_dir`` into ``out.mp4``.

    Runs ``ffmpeg`` (looked up on PATH) over ``000.jpg``, ``001.jpg``, ... and
    writes an H.264 (libx264), yuv420p video to ``<output_dir>/out.mp4``. An
    existing ``out.mp4`` is overwritten.

    Args:
        fps: Frame rate used both to read the image sequence and for the output.
        output_dir: Directory holding the frames; the video is written there too.

    Returns:
        ffmpeg's exit status (0 on success). A non-zero status is also reported
        on stdout.

    Raises:
        OSError (e.g. FileNotFoundError): if the ffmpeg executable cannot be
            launched.
    """
    command = [
        "ffmpeg",
        # Without -y an existing out.mp4 makes ffmpeg stop at an interactive
        # "Overwrite? [y/N]" prompt, hanging or aborting re-runs.
        "-y",
        "-r", f"{fps}",
        "-i", f"{output_dir}/%03d.jpg",
        "-c:v", "libx264",
        "-vf", f"fps={fps}",
        "-pix_fmt", "yuv420p",
        f"{output_dir}/out.mp4",
    ]

    returncode = subprocess.call(command)
    if returncode != 0:
        print(f"ffmpeg exited with status {returncode}; video may not have been written.")
    return returncode