# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import os, sys
# HACK: install runtime dependencies at startup. This script is meant for a
# hosted-demo environment (e.g. Hugging Face Spaces) with no separate build
# step, so dependencies are pulled in when the app process launches.
os.system('pip install -r requirements.txt')

import gradio as gr
import numpy as np
import dnnlib
import time
import legacy
import torch
import glob
import cv2

from torch_utils import misc
from renderer import Renderer
from training.networks import Generator
from huggingface_hub import hf_hub_download


# Render on GPU when one is visible; otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# An optional first CLI argument overrides the default serving port.
port = int(sys.argv[1]) if len(sys.argv) > 1 else 21111

# Pretrained StyleNeRF checkpoints on the Hugging Face Hub, keyed by a
# human-readable model name. All demo checkpoints live in the same repo.
_REPO = 'facebook/stylenerf-ffhq-config-basic'
model_lists = {
    name: {'repo_id': _REPO, 'filename': fname}
    for name, fname in [
        ('ffhq-512x512-basic',   'ffhq_512.pkl'),
        ('ffhq-512x512-cc',      'ffhq_512_cc.pkl'),
        ('ffhq-256x256-basic',   'ffhq_256.pkl'),
        ('ffhq-1024x1024-basic', 'ffhq_1024.pkl'),
    ]
}
model_names = list(model_lists)


def set_random_seed(seed):
    """Seed both the torch and numpy global RNGs for reproducible sampling."""
    for seeder in (torch.manual_seed, np.random.seed):
        seeder(seed)


def get_camera_traj(model, pitch, yaw, fov=12, batch_size=1, model_name=None):
    """Build camera matrices for the requested viewing angles.

    Args:
        model: generator whose ``.synthesis`` network exposes the camera
            sampler (``gen.C`` with ``range_u``/``range_v``) and ``get_camera``.
        pitch, yaw: viewing angles in [-1, 1] from the UI sliders.
        fov: field of view forwarded to ``get_camera``.
        batch_size: number of camera matrices to produce.
        model_name: checkpoint name/path; only used to special-case car models.

    Returns:
        Camera matrices as produced by ``gen.get_camera``.
    """
    gen = model.synthesis
    range_u, range_v = gen.C.range_u, gen.C.range_v
    # Bug fix: with the default model_name=None, `'car' in model_name`
    # raised TypeError; treat None as an empty (non-car) name.
    name = model_name or ''
    if not (('car' in name) or ('Car' in name)):  # TODO: hack, better option?
        # Compress the slider range, then offset pitch so 0 sits at the equator.
        yaw, pitch = 0.5 * yaw, 0.3 * pitch
        pitch = pitch + np.pi/2
        u = (yaw - range_u[0]) / (range_u[1] - range_u[0])
        v = (pitch - range_v[0]) / (range_v[1] - range_v[0])
    else:
        # Car models map the full [-1, 1] slider range directly to [0, 1].
        u = (yaw + 1) / 2
        v = (pitch + 1) / 2
    cam = gen.get_camera(batch_size=batch_size, mode=[u, v, 0.5], device=device, fov=fov)
    return cam


def check_name(model_name):
    """Resolve a model identifier to a network pickle path/URL."""
    if model_name in model_lists:
        # Known demo model: download (or reuse the cached copy) from the Hub.
        return hf_hub_download(**model_lists[model_name])
    if os.path.isdir(model_name):
        # A snapshot directory: pick the last .pkl in lexicographic order.
        return sorted(glob.glob(model_name + '/*.pkl'))[-1]
    # Otherwise assume it is already a pickle path or URL.
    return model_name


def get_model(network_pkl, render_option=None):
    """Load a checkpoint, rebuild the generator, and run one warm-up pass.

    Args:
        network_pkl: local path or URL of the network pickle.
        render_option: extra rendering options forwarded to the generator.

    Returns:
        Tuple ``(G2, res, imgs)``: the generator in eval mode, its output
        resolution, and a zero-initialized ``(res, res//2, 3)`` buffer used to
        cache the two per-seed preview images.
    """
    print('Loading networks from "%s"...' % network_pkl)
    with dnnlib.util.open_url(network_pkl) as f:
        network = legacy.load_network_pkl(f)
        G = network['G_ema'].to(device)  # type: ignore

    with torch.no_grad():
        # Rebuild a fresh Generator and copy weights over so stale buffers or
        # compiled state from the pickle are not carried along.
        G2 = Generator(*G.init_args, **G.init_kwargs).to(device)
        misc.copy_params_and_buffers(G, G2, require_all=False)

    print('compile and go through the initial image')
    G2 = G2.eval()
    # Warm-up forward pass. Fixes vs. the original: the latent is Gaussian
    # (randn, float32) to match how latents are sampled at synthesis time,
    # and the throwaway forward runs under no_grad so no autograd graph is
    # built for an output we discard.
    init_z = torch.from_numpy(
        np.random.RandomState(0).randn(1, G2.z_dim).astype('float32')).to(device)
    init_cam = get_camera_traj(G2, 0, 0, model_name=network_pkl)
    with torch.no_grad():
        dummy = G2(z=init_z, c=None, camera_matrices=init_cam,
                   render_option=render_option, theta=0)
    res = dummy['img'].shape[-1]
    imgs = np.zeros((res, res//2, 3))
    return G2, res, imgs


# Module-level mutable state shared across Gradio callbacks:
# global_states = [current model, its output resolution, preview image buffer].
global_states = list(get_model(check_name(model_names[0])))
# Cached w-space latents for the two seed inputs; filled lazily in f_synthesis.
wss  = [None, None]

def proc_seed(history, seed):
    """Normalize a raw seed input to an int.

    Args:
        history: session state dict (unused; kept for interface compatibility).
        seed: value from the Gradio number input; may arrive as a string.

    Returns:
        ``int(seed)``, or 0 when the input was a (non-numeric) string.
    """
    # Bug fix: the original computed the normalized seed but never returned
    # it, so the function always returned None.
    if isinstance(seed, str):
        return 0
    return int(seed)


def f_synthesis(model_name, model_find, render_option, early, trunc, seed1, seed2, mix1, mix2, yaw, pitch, roll, fov, history):
    """Gradio callback: render one frame for the current UI settings.

    Caches the loaded model in ``global_states`` and the two per-seed w-space
    latents in ``wss`` so that only inputs that actually changed trigger
    recomputation. Returns the composited uint8 image (seed-preview strip
    concatenated with the rendered view) and the updated per-session
    ``history`` state dict.
    """
    history = history or {}
    seeds = []
    trunc = trunc / 100  # slider is in percent; the model expects [0, 1]
    
    if model_find != "":
        # A manually typed checkpoint path overrides the dropdown choice.
        model_name = model_find

    model_name = check_name(model_name)
    if model_name != history.get("model_name", None):
        # Model changed for this session: reload and refresh the shared cache.
        model, res, imgs = get_model(model_name, render_option)
        global_states[0] = model
        global_states[1] = res
        global_states[2] = imgs

    model, res, imgs = global_states
    for idx, seed in enumerate([seed1, seed2]):
        # Non-numeric (string) seed entries fall back to 0.
        if isinstance(seed, str):
            seed = 0
        else:
            seed = int(seed)

        # Re-sample this seed's latent/preview only if something it depends on
        # (seed value, model, truncation) changed, or nothing is cached yet.
        if (seed != history.get(f'seed{idx}', -1)) or \
            (model_name != history.get("model_name", None)) or \
            (trunc != history.get("trunc", 0.7)) or \
            (wss[idx] is None):
            print(f'use seed {seed}')
            set_random_seed(seed)
            z   = torch.from_numpy(np.random.RandomState(int(seed)).randn(1, model.z_dim).astype('float32')).to(device)
            ws  = model.mapping(z=z, c=None, truncation_psi=trunc)
            img = model.get_final_output(styles=ws, camera_matrices=get_camera_traj(model, 0, 0, model_name=model_name), render_option=render_option)
            ws  = ws.detach().cpu().numpy()
            img = img[0].permute(1,2,0).detach().cpu().numpy()


            # Map [-1, 1] pixels to [0, 1], downscale to res//2, and store in
            # the top (idx=0) or bottom (idx=1) half of the preview strip.
            imgs[idx * res // 2: (1 + idx) * res // 2] = cv2.resize(
                np.asarray(img).clip(-1, 1) * 0.5 + 0.5,
                (res//2, res//2), cv2.INTER_AREA)
            wss[idx] = ws
        else:
            seed = history[f'seed{idx}']
        seeds += [seed]

        history[f'seed{idx}'] = seed
    history['trunc'] = trunc
    history['model_name'] = model_name

    set_random_seed(sum(seeds))

    # style mixing (?)
    # Blend the two cached latents: per the UI slider labels, the first 8
    # w layers steer geometry (mix1) and the remaining layers appearance (mix2).
    ws1, ws2 = [torch.from_numpy(ws).to(device) for ws in wss]
    ws = ws1.clone()
    ws[:, :8] = ws1[:, :8] * mix1 + ws2[:, :8] * (1 - mix1)
    ws[:, 8:] = ws1[:, 8:] * mix2 + ws2[:, 8:] * (1 - mix2)

    # set visualization for other types of inputs.
    if early == 'Normal Map':
        render_option += ',normal,early'
    elif early == 'Gradient Map':
        render_option += ',gradient,early'

    start_t = time.time()
    with torch.no_grad():
        cam = get_camera_traj(model, pitch, yaw, fov, model_name=model_name)
        image = model.get_final_output(
            styles=ws, camera_matrices=cam,
            theta=roll * np.pi,
            render_option=render_option)
    end_t = time.time()

    # Map the rendered image from [-1, 1] to [0, 1].
    image = image[0].permute(1,2,0).detach().cpu().numpy().clip(-1, 1) * 0.5 + 0.5

    if imgs.shape[0] == image.shape[0]:
        image = np.concatenate([imgs, image], 1)
    else:
        # Preview strip and render heights differ (e.g. right after switching
        # to a model of another resolution): rescale the strip to match.
        a = image.shape[0]
        b = int(imgs.shape[1] / imgs.shape[0] * a)
        print(f'resize {a} {b} {image.shape} {imgs.shape}')
        image = np.concatenate([cv2.resize(imgs, (b, a), cv2.INTER_AREA), image], 1)

    print(f'rendering time = {end_t-start_t:.4f}s')
    image = (image * 255).astype('uint8')
    return image, history

# --- Gradio UI wiring (legacy gr.inputs API) ---
# Input widgets, in the same order they are passed to f_synthesis below.
model_name = gr.inputs.Dropdown(model_names)
model_find = gr.inputs.Textbox(label="Checkpoint path (folder or .pkl file)", default="")
render_option = gr.inputs.Textbox(label="Additional rendering options", default='freeze_bg,steps:50')
trunc  = gr.inputs.Slider(default=70, maximum=100, minimum=0, label='Truncation trick (%)')
seed1  = gr.inputs.Number(default=1, label="Random seed1")
seed2  = gr.inputs.Number(default=9, label="Random seed2")
mix1   = gr.inputs.Slider(minimum=0, maximum=1, default=0, label="Linear mixing ratio (geometry)")
mix2   = gr.inputs.Slider(minimum=0, maximum=1, default=0, label="Linear mixing ratio (apparence)")
early  = gr.inputs.Radio(['None', 'Normal Map', 'Gradient Map'], default='None', label='Intermedia output')
yaw    = gr.inputs.Slider(minimum=-1, maximum=1, default=0, label="Yaw")
pitch  = gr.inputs.Slider(minimum=-1, maximum=1, default=0, label="Pitch")
roll   = gr.inputs.Slider(minimum=-1, maximum=1, default=0, label="Roll (optional, not suggested for basic config)")
fov    = gr.inputs.Slider(minimum=10, maximum=14, default=12, label="Fov")
# Enlarge image panes beyond the Gradio default height.
css = ".output-image, .input-image, .image-preview {height: 600px !important} "

# "state" carries the per-session history dict in and out of f_synthesis;
# live=True re-renders on every input change.
gr.Interface(fn=f_synthesis,
             inputs=[model_name, model_find, render_option, early, trunc, seed1, seed2, mix1, mix2, yaw, pitch, roll, fov, "state"],
             title="Interactive Web Demo for StyleNeRF (ICLR 2022)",
             description="StyleNeRF: A Style-based 3D-Aware Generator for High-resolution Image Synthesis. Currently the demo runs on CPU only.",
             outputs=["image", "state"],
             layout='unaligned',
             css=css, theme='dark-seafoam',
             live=True).launch(enable_queue=True)