import os

# Fetch the model code and install its Python dependencies at startup.
# Both commands stay best-effort (presumably the clone fails when the checkout
# already exists - TODO confirm), so a non-zero exit status is reported rather
# than raised; the imports that follow will surface any real breakage.
_SETUP_COMMANDS = (
    "git clone --recursive https://github.com/JD-P/cloob-latent-diffusion",
    "cd cloob-latent-diffusion;pip install omegaconf pillow pytorch-lightning einops wandb ftfy regex ./CLIP",
)
for _setup_command in _SETUP_COMMANDS:
    _exit_status = os.system(_setup_command)
    if _exit_status != 0:
        print('warning: setup command exited with status', _exit_status, '-', _setup_command)

import argparse
from functools import partial
from pathlib import Path
import sys
sys.path.append('./cloob-latent-diffusion')
sys.path.append('./cloob-latent-diffusion/cloob-training')
sys.path.append('./cloob-latent-diffusion/latent-diffusion')
sys.path.append('./cloob-latent-diffusion/taming-transformers')
sys.path.append('./cloob-latent-diffusion/v-diffusion-pytorch')
from omegaconf import OmegaConf
from PIL import Image
import torch
from torch import nn
from torch.nn import functional as F
from torchvision import transforms
from torchvision.transforms import functional as TF
from tqdm import trange
from CLIP import clip
from cloob_training import model_pt, pretrained
import ldm.models.autoencoder
from diffusion import sampling, utils
import train_latent_diffusion as train
from huggingface_hub import hf_hub_url, cached_download
import random

# Download the model files
def _download_from_hub(repo_id, filename):
    """Resolve `filename` in the Hub repo `repo_id` and download it via the cache.

    Returns whatever `cached_download` returns; the script uses it as a local
    file path when loading the checkpoints and the autoencoder config.
    """
    file_url = hf_hub_url(repo_id, filename=filename)
    return cached_download(file_url)


# Distilled diffusion model weights.
checkpoint = _download_from_hub("huggan/distill-ccld-wa", "model_student.ckpt")
# Autoencoder weights and the config describing its architecture.
ae_model_path = _download_from_hub("huggan/ccld_wa", "ae_model.ckpt")
ae_config_path = _download_from_hub("huggan/ccld_wa", "ae_model.yaml")

# Define a few utility functions


def parse_prompt(prompt, default_weight=3.):
    """Split a prompt of the form ``"text[:weight]"`` into text and weight.

    The part after the last colon is treated as the weight only when it parses
    as a float. Otherwise the colon is considered part of the prompt itself
    (e.g. ``"Title: subtitle"`` or a URL with a port) and the whole string is
    returned with `default_weight`, instead of raising ``ValueError``.

    Args:
        prompt: prompt text, file path or http(s) URL, optionally followed by
            ``:<weight>``.
        default_weight: weight used when the prompt carries no numeric suffix.

    Returns:
        A ``(text, weight)`` tuple where ``weight`` is a float.
    """
    head, colon, tail = prompt.rpartition(':')
    if colon:
        # The colon of an "http://" / "https://" scheme is never mistaken for
        # a weight separator: what follows it starts with "//", which is not
        # a valid float, so no URL special case is needed.
        try:
            return head, float(tail)
        except ValueError:
            pass
    return prompt, float(default_weight)


def resize_and_center_crop(image, size):
    """Scale `image` so it covers `size`, then center-crop it to `size`.

    Args:
        image: image object exposing ``.size`` and ``.resize`` (a PIL image in
            this script); ``image.size`` is read as ``(width, height)``.
        size: target ``(width, height)``.

    Returns:
        The result of ``TF.center_crop`` applied to the resized image.
    """
    target_w, target_h = size[0], size[1]
    src_w, src_h = image.size[0], image.size[1]
    # Use the larger ratio so both resized dimensions reach the target.
    fac = max(target_w / src_w, target_h / src_h)
    # int() truncates, so floating-point rounding in fac * dim can land one
    # pixel short of the target; center_crop would then pad the image instead
    # of cropping it. Clamp to the target to rule that out.
    new_w = max(target_w, int(fac * src_w))
    new_h = max(target_h, int(fac * src_h))
    image = image.resize((new_w, new_h), Image.LANCZOS)
    # center_crop takes (height, width), hence the reversal.
    return TF.center_crop(image, size[::-1])


# Load the models
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
print('loading models')

# autoencoder
ae_config = OmegaConf.load(ae_config_path)
ae_model = ldm.models.autoencoder.AutoencoderKL(**ae_config.model.params)
ae_model.eval().requires_grad_(False).to(device)
# Deserialize onto the CPU, as is done for the diffusion checkpoint below, so
# a checkpoint saved from a GPU still loads on a CPU-only host;
# load_state_dict then copies the values into the module's parameters.
ae_model.load_state_dict(torch.load(ae_model_path, map_location='cpu'))
# Channels / height / width of the latent noise sampled in generate().
n_ch, side_y, side_x = 4, 32, 32

# diffusion model
model = train.DiffusionModel(192, [1,1,2,2], autoencoder_scale=torch.tensor(4.3084))
model.load_state_dict(torch.load(checkpoint, map_location='cpu'))
model = model.to(device).eval().requires_grad_(False)

# CLOOB
cloob_config = pretrained.get_config('cloob_laion_400m_vit_b_16_16_epochs')
cloob = model_pt.get_pt_model(cloob_config)
# NOTE: this rebinds `checkpoint`, which until here referred to the diffusion
# model checkpoint.
checkpoint = pretrained.download_checkpoint(cloob_config)
cloob.load_state_dict(model_pt.get_pt_params(cloob_config, checkpoint))
cloob.eval().requires_grad_(False).to(device)


# The key function: returns a list of n PIL images
def generate(n=1, prompts=['a red circle'], images=[], seed=42, steps=15,
             method='plms', eta=None):
    """Sample `n` images conditioned on weighted text and image prompts.

    Every prompt is parsed with `parse_prompt`, so it may carry a trailing
    ``:weight``. Prompts are embedded with CLOOB; the model output for each
    embedding is combined as a weighted sum in which the all-zero embedding
    receives weight ``1 - sum(prompt weights)``.

    Args:
        n: number of images to sample.
        prompts: text prompts. The default list is only read, never mutated.
        images: image prompts; each parsed path is opened via `utils.fetch`.
            The default list is only read, never mutated.
        seed: value passed to `torch.manual_seed` before the initial noise is
            drawn.
        steps: number of timesteps in the sampling schedule.
        method: sampler name, one of 'ddpm', 'ddim', 'prk', 'plms', 'pie',
            'plms2'.
        eta: forwarded to `sampling.sample` for the 'ddim' method only.

    Returns:
        A list with one image per sample, as produced by `utils.to_pil_image`.

    Raises:
        ValueError: if `method` is not a supported sampler name.
    """
    zero_embed = torch.zeros([1, cloob.config['d_embed']], device=device)
    target_embeds, prompt_weights = [zero_embed], []

    for prompt in prompts:
        txt, weight = parse_prompt(prompt)
        target_embeds.append(cloob.text_encoder(cloob.tokenize(txt).to(device)).float())
        prompt_weights.append(weight)

    for prompt in images:
        path, weight = parse_prompt(prompt)
        img = Image.open(utils.fetch(path)).convert('RGB')
        clip_size = cloob.config['image_encoder']['image_size']
        img = resize_and_center_crop(img, (clip_size, clip_size))
        batch = TF.to_tensor(img)[None].to(device)
        embed = F.normalize(cloob.image_encoder(cloob.normalize(batch)).float(), dim=-1)
        target_embeds.append(embed)
        prompt_weights.append(weight)

    # First entry weights the zero embedding so that all weights sum to 1.
    weights = torch.tensor([1 - sum(prompt_weights), *prompt_weights], device=device)

    def cfg_model_fn(x, t):
        # Evaluate the model once per conditioning embedding in one stacked
        # batch, then blend the per-condition outputs with `weights`.
        n_samples = x.shape[0]
        n_conds = len(target_embeds)
        x_in = x.repeat([n_conds, 1, 1, 1])
        t_in = t.repeat([n_conds])
        clip_embed_in = torch.cat([*target_embeds]).repeat_interleave(n_samples, 0)
        vs = model(x_in, t_in, clip_embed_in).view([n_conds, n_samples, *x.shape[1:]])
        return vs.mul(weights[:, None, None, None, None]).sum(0)

    # Wrapped in lambdas so that only the selected sampler is looked up on
    # the `sampling` module.
    samplers = {
        'ddpm': lambda x, sched: sampling.sample(cfg_model_fn, x, sched, 1., {}),
        'ddim': lambda x, sched: sampling.sample(cfg_model_fn, x, sched, eta, {}),
        'prk': lambda x, sched: sampling.prk_sample(cfg_model_fn, x, sched, {}),
        'plms': lambda x, sched: sampling.plms_sample(cfg_model_fn, x, sched, {}),
        'pie': lambda x, sched: sampling.pie_sample(cfg_model_fn, x, sched, {}),
        'plms2': lambda x, sched: sampling.plms2_sample(cfg_model_fn, x, sched, {}),
    }
    if method not in samplers:
        # An explicit error rather than `assert False`, which is stripped
        # under `python -O` and carries no message.
        raise ValueError(
            'unknown sampling method %r; expected one of %s'
            % (method, ', '.join(samplers)))
    run = samplers[method]

    torch.manual_seed(seed)
    x = torch.randn([n, n_ch, side_y, side_x], device=device)
    t = torch.linspace(1, 0, steps + 1, device=device)[:-1]
    schedule = utils.get_spliced_ddpm_cosine_schedule(t)
    # Factor applied to the sampled latents before decoding.
    latent_scale = torch.tensor(2.55).to(device)

    # All n samples are processed as a single batch.
    batch_size = n
    pil_ims = []
    for i in trange(0, n, batch_size):
        cur_batch_size = min(n - i, batch_size)
        out_latents = run(x[i:i + cur_batch_size], schedule)
        outs = ae_model.decode(out_latents * latent_scale)
        pil_ims.extend(utils.to_pil_image(out) for out in outs)

    return pil_ims
  
  
import gradio as gr

def gen_ims(prompt, im_prompt=None, seed=None, n_steps=10, method='plms'):
    """Generate a single image from a text prompt and an optional image prompt.

    Args:
        prompt: text prompt, passed to `generate` as its only text prompt.
        im_prompt: optional image prompt, passed to `generate` as its only
            image prompt; skipped when None.
        seed: seed for `generate`; a random integer in [0, 10000] is drawn
            when None.
        n_steps: number of sampling steps passed to `generate`.
        method: sampler name passed to `generate`.

    Returns:
        The first image of the list returned by `generate`.
    """
    if seed is None:
        seed = random.randint(0, 10000)
    print(prompt, im_prompt, seed, n_steps)
    im_prompts = [] if im_prompt is None else [im_prompt]
    pil_ims = generate(n=1, prompts=[prompt], images=im_prompts, seed=seed,
                       steps=n_steps, method=method)
    return pil_ims[0]

# Example prompts shown in the UI: artwork / "in the style of" prompts first,
# then bare genre names. Each one becomes a single-element example row.
_STYLE_EXAMPLES = (
    "Virgin and Child, in the style of Jacopo Bellini",
    "Katsushika Hokusai, The Dragon of Smoke Escaping from Mount Fuji",
    "Moon Light Sonata by Basuki Abdullah",
    "Twon Tree by M.C. Escher",
    "Futurism, in the style of Wassily Kandinsky",
    "Art Nouveau, in the style of John Singer Sargent",
    "Surrealism, in the style of Edgar Degas",
    "Expressionism, in the style of Wassily Kandinsky",
    "Futurism, in the style of Egon Schiele",
    "Neoclassicism, in the style of Gustav Klimt",
    "Cubism, in the style of Gustav Klimt",
    "Op Art, in the style of Marc Chagall",
    "Romanticism, in the style of M.C. Escher",
    "Futurism, in the style of M.C. Escher",
    "Abstract Art, in the style of M.C. Escher",
    "Mannerism, in the style of Paul Klee",
    "Romanesque Art, in the style of Leonardo da Vinci",
    "High Renaissance, in the style of Rembrandt",
    "Magic Realism, in the style of Gustave Dore",
    "Realism, in the style of Jean-Michel Basquiat",
    "Art Nouveau, in the style of Paul Gauguin",
    "Avant-garde, in the style of Pierre-Auguste Renoir",
    "Baroque, in the style of Edward Hopper",
    "Post-Impressionism, in the style of Wassily Kandinsky",
    "Naturalism, in the style of Rene Magritte",
    "Constructivism, in the style of Paul Cezanne",
    "Abstract Expressionism, in the style of Henri Matisse",
    "Pop Art, in the style of Vincent van Gogh",
    "Futurism, in the style of Wassily Kandinsky",
    "Futurism, in the style of Zdzislaw Beksinski",
    "Surrealism, in the style of Salvador Dali",
    "Aaron Wacker, oil on canvas",
)
_GENRE_EXAMPLES = (
    "abstract",
    "landscape",
    "portrait",
    "sculpture",
    "genre painting",
    "installation",
    "photo",
    "figurative",
    "illustration",
    "still life",
    "history painting",
    "cityscape",
    "marina",
    "animal painting",
    "design",
    "calligraphy",
    "symbolic painting",
    "graffiti",
    "performance",
    "mythological painting",
    "battle painting",
    "self-portrait",
    "Impressionism, oil on canvas",
)

# Only a text prompt and an optional image prompt are exposed as inputs; no
# controls for image count, seed or step count are wired up, so gen_ims falls
# back to its own defaults for the remaining parameters.
iface = gr.Interface(
    fn=gen_ims,
    inputs=[
        gr.inputs.Textbox(label="Text prompt"),
        gr.inputs.Image(optional=True, label="Image prompt", type='filepath'),
    ],
    outputs=[gr.outputs.Image(type="pil", label="Generated Image")],
    examples=[[example] for example in _STYLE_EXAMPLES + _GENRE_EXAMPLES],
    title='Art Generator and Style Mixer from 🧠 Cloob and 🎨 WikiArt - Visual Art Encyclopedia:',
    description="Trained on images from the [WikiArt](https://www.wikiart.org/) dataset, comprised of visual arts",
    article='Model used is: [model card](https://huggingface.co/huggan/distill-ccld-wa)..',
)
iface.launch(enable_queue=True)  # add debug=True for colab debugging