'''
LinCIR
Copyright (c) 2023-present NAVER Corp.
CC BY-NC-4.0 (https://creativecommons.org/licenses/by-nc/4.0/)
'''
import os
import time
from argparse import ArgumentParser
import json

import numpy as np
import torch
import gradio as gr
import faiss

from encode_with_pseudo_tokens import encode_with_pseudo_tokens_HF
from models import build_text_encoder, Phi, PIC2WORD

import transformers
from huggingface_hub import hf_hub_url, cached_download



def parse_args():
    """Parse the demo's command-line options.

    Returns:
        argparse.Namespace with the string options ``lincir_ckpt_path`` and
        ``pic2word_ckpt_path`` (default ``None``), ``cache_dir`` (default
        ``'./hf_models'``), ``clip_model_name`` (default ``'large'``),
        ``mixed_precision`` (default ``'fp16'``), plus the boolean flag
        ``test_fps``.
    """
    cli = ArgumentParser()

    # (flag, default, help text) for every string-valued option, in the order
    # they are registered on the parser.
    string_options = [
        ("--lincir_ckpt_path", None,
         "The output directory where the model predictions and checkpoints will be written"),
        ("--pic2word_ckpt_path", None, None),
        ("--cache_dir", "./hf_models", "Path to model cache folder"),
        ("--clip_model_name", "large", "CLIP model to use, e.g 'large', 'huge', 'giga'"),
        ("--mixed_precision", "fp16", None),
    ]
    for flag, default_value, help_text in string_options:
        # help=None is argparse's own default, so options without help text
        # are registered exactly as if `help` had been omitted.
        cli.add_argument(flag, default=default_value, type=str, help=help_text)

    cli.add_argument("--test_fps", action="store_true")
    return cli.parse_args()


def _resolve_checkpoint(user_path, filename):
    """Return a local path to the checkpoint ``filename``.

    Args:
        user_path: explicit checkpoint path given by the user, or ``None``.
        filename: file name inside the ``navervision/zeroshot-cir-models``
            Hugging Face repo; used only when ``user_path`` is not given.

    Returns:
        ``user_path`` when it is non-empty. Otherwise
        ``./pretrained_models/<filename>``, downloaded first if it does not
        exist yet.
    """
    if user_path:
        return user_path
    local_path = f'./pretrained_models/{filename}'
    if not os.path.exists(local_path):
        # NOTE(review): `cached_download` is deprecated in huggingface_hub in
        # favour of `hf_hub_download`; migrate together with the file's import.
        model_file_url = hf_hub_url(repo_id='navervision/zeroshot-cir-models', filename=filename)
        cached_download(model_file_url, cache_dir='./pretrained_models', force_filename=filename)
    return local_path


def load_models(args):
    """Build the CLIP encoders and the three Phi networks used by the demo.

    Models are placed on ``cuda:0`` in float16 when CUDA is available, and on
    the CPU in float32 otherwise.

    Args:
        args: parsed CLI namespace (see ``parse_args``). Reads ``test_fps`` and
            ``clip_model_name``, and optionally ``lincir_ckpt_path`` and
            ``pic2word_ckpt_path``. It is also forwarded to
            ``build_text_encoder``.

    Returns:
        dict with keys ``clip_vision_model``, ``clip_preprocess``,
        ``clip_text_model``, ``tokenizer``, ``phi`` (LinCIR), ``phi_searle``,
        ``phi_pic2word``, ``decoder`` (always ``None``), ``device``, ``dtype``
        and ``clip_model_name``.
    """
    if torch.cuda.is_available():
        device = 'cuda:0'
        dtype = torch.float16
    else:
        device = 'cpu'
        dtype = torch.float32

    clip_vision_model, clip_preprocess, clip_text_model, tokenizer = build_text_encoder(args)

    # Register '[$]' as a dedicated token; predict() rewrites '$' in user text to it.
    tokenizer.add_special_tokens({'additional_special_tokens': ["[$]"]})  # 49408

    projection_dim = clip_text_model.config.projection_dim
    hidden_size = clip_text_model.config.hidden_size

    # LinCIR (ours)
    phi = Phi(input_dim=projection_dim,
              hidden_dim=projection_dim * 4,
              output_dim=hidden_size, dropout=0.0)
    phi.eval()

    # SEARLE, fetched through torch.hub.
    # NOTE(review): the backbone is fixed to ViT-L/14 regardless of
    # args.clip_model_name; confirm before using a non-'large' CLIP model.
    phi_searle, _ = torch.hub.load(repo_or_dir='miccunifi/SEARLE', model='searle', source='github',
                                   backbone='ViT-L/14')
    phi_searle.eval()

    # Pic2Word
    phi_pic2word = PIC2WORD(embed_dim=projection_dim, output_dim=hidden_size)
    phi_pic2word.eval()

    clip_vision_model.to(device, dtype=dtype)
    clip_text_model.to(device, dtype=dtype)

    if not args.test_fps:
        # The FPS benchmark skips checkpoint loading entirely; timings do not
        # depend on the weights.
        # NOTE(review): the default checkpoints are the '*_large' variants even
        # when another clip_model_name is selected.
        lincir_path = _resolve_checkpoint(getattr(args, 'lincir_ckpt_path', None), 'lincir_large.pt')
        state_dict = torch.load(lincir_path, map_location=device)
        phi.load_state_dict(state_dict['Phi'])

        pic2word_path = _resolve_checkpoint(getattr(args, 'pic2word_ckpt_path', None), 'pic2word_large.pt')
        sd = torch.load(pic2word_path, map_location=device)['state_dict_img2text']
        # Strip the 'module.' key prefix (presumably left by DataParallel/DDP
        # wrapping), but only where it is actually present.
        prefix = 'module.'
        sd = {(k[len(prefix):] if k.startswith(prefix) else k): v for k, v in sd.items()}
        phi_pic2word.load_state_dict(sd)

    phi.to(device, dtype=dtype)
    phi_searle.to(device, dtype=dtype)
    phi_pic2word.to(device, dtype=dtype)

    decoder = None

    return {'clip_vision_model': clip_vision_model,
            'clip_preprocess': clip_preprocess,
            'clip_text_model': clip_text_model,
            'tokenizer': tokenizer,
            'phi': phi,
            'phi_searle': phi_searle,
            'phi_pic2word': phi_pic2word,
            'decoder': decoder,
            'device': device,
            'dtype': dtype,
            'clip_model_name': args.clip_model_name,
            }


@torch.no_grad()
def predict(images, input_text, model_name):
    """Run one composed-image-retrieval query and render the top-10 hits.

    Wired to the Gradio ``Submit`` button. Reads the module-level globals
    ``model_dict``, ``faiss_index`` and ``image_urls`` set up under ``__main__``.

    Args:
        images: query image from ``gr.Image(type='pil')``; ``None`` when the
            user has not uploaded one.
        input_text: text guidance. Every ``$`` marks where the image's pseudo
            word token goes.
        model_name: ``'lincir'`` or ``'searle'``; any other value selects the
            Pic2Word Phi network.

    Returns:
        ``(info, markdown)``. ``info`` maps ``'time'`` to per-stage wall-clock
        seconds and ``'setup'`` to the run configuration. ``markdown`` embeds
        one image per retrieved URL.
    """
    if images is None:
        # The CLIP preprocessor cannot handle a missing image; tell the user
        # instead of failing with an opaque error.
        return {}, 'Please upload an image first.'

    device = model_dict['device']
    tokenizer = model_dict['tokenizer']

    def _sync():
        # CUDA kernels are launched asynchronously. Wait for them so each
        # stage's timing covers its own GPU work instead of leaking into a
        # later stage.
        if str(device).startswith('cuda'):
            torch.cuda.synchronize(device)

    start_time = time.perf_counter()
    input_images = model_dict['clip_preprocess'](images, return_tensors='pt')['pixel_values'].to(device)
    # Route '$' through the '[$]' special token registered in load_models so it
    # always maps to a single known id, whatever characters surround it.
    input_text = input_text.replace('$', '[$]')
    input_tokens = tokenizer(text=input_text, return_tensors='pt', padding='max_length', truncation=True)['input_ids'].to(device)
    # Look the id up rather than hard-coding it (49408 for the stock CLIP vocab).
    special_token_id = tokenizer.convert_tokens_to_ids('[$]')
    # 259 is assumed to be the placeholder id that encode_with_pseudo_tokens_HF
    # replaces with the Phi embedding (presumably CLIP's '$</w>'); TODO confirm.
    input_tokens = torch.where(input_tokens == special_token_id,
                               torch.full_like(input_tokens, 259),
                               input_tokens)
    image_features = model_dict['clip_vision_model'](pixel_values=input_images.to(model_dict['dtype'])).image_embeds
    _sync()
    clip_image_time = time.perf_counter() - start_time

    start_time = time.perf_counter()
    if model_name == 'lincir':
        phi = model_dict['phi']
    elif model_name == 'searle':
        phi = model_dict['phi_searle']
    else:  # model_name == 'pic2word'
        phi = model_dict['phi_pic2word']
    estimated_token_embeddings = phi(image_features)
    _sync()
    phi_time = time.perf_counter() - start_time

    start_time = time.perf_counter()
    text_embeddings, _ = encode_with_pseudo_tokens_HF(model_dict['clip_text_model'], input_tokens,
                                                      estimated_token_embeddings, return_last_states=True)
    _sync()
    clip_text_time = time.perf_counter() - start_time

    start_time = time.perf_counter()
    # FAISS searches float32 host arrays; the embeddings may be fp16 on the GPU.
    query = text_embeddings.float().cpu().numpy()
    _, results = faiss_index.search(query, k=10)
    retrieval_time = time.perf_counter() - start_time

    # FAISS pads with label -1 when fewer than k neighbours exist. Skip those
    # rather than rendering image_urls[-1] (the last URL) by accident.
    output = ''.join(f'![image]({image_urls[retrieved_idx]})\n'
                     for retrieved_idx in results[0] if retrieved_idx >= 0)

    time_output = {'CLIP visual extractor': clip_image_time,
                   'CLIP textual extractor': clip_text_time,
                   'Phi projection': phi_time,
                   'CLIP retrieval': retrieval_time,
                   }
    setup_output = {'device': model_dict['device'],
                    'dtype': model_dict['dtype'],
                    'Phi': model_name,
                    'CLIP': model_dict['clip_model_name'],
                    }

    return {'time': time_output, 'setup': setup_output}, output


@torch.no_grad()
def test_fps(batch_size=1, n_repeat=100, n_warmup=6, model_names=('phi', 'phi_pic2word')):
    """Benchmark the end-to-end latency of each Phi pipeline on random input.

    Each iteration runs the CLIP vision encoder, one Phi network and the
    pseudo-token text encoder. For every model the mean wall-clock seconds per
    iteration are printed, after discarding the first ``n_warmup`` iterations.
    Reads the module-level global ``model_dict``.

    Args:
        batch_size: number of dummy 224x224 images (and prompts) per iteration.
        n_repeat: total iterations per model, warm-up included.
        n_warmup: leading iterations excluded from the mean.
        model_names: ``model_dict`` keys of the Phi networks to benchmark.

    Raises:
        ValueError: if ``n_repeat <= n_warmup`` (no timed iterations would remain).
    """
    if n_repeat <= n_warmup:
        raise ValueError(f'n_repeat ({n_repeat}) must exceed n_warmup ({n_warmup})')

    device = model_dict['device']
    tokenizer = model_dict['tokenizer']
    vision_model = model_dict['clip_vision_model']
    text_model = model_dict['clip_text_model']

    dummy_images = torch.rand([batch_size, 3, 224, 224])

    # Tokenize the same way predict() does: '$' goes through the '[$]' special
    # token, whose id is then swapped for the 259 placeholder.
    prompt = 'a photo of $1 with flowers'.replace('$', '[$]')
    input_tokens = tokenizer(text=[prompt] * batch_size, return_tensors='pt', padding='max_length', truncation=True)['input_ids'].to(device)
    special_token_id = tokenizer.convert_tokens_to_ids('[$]')
    input_tokens = torch.where(input_tokens == special_token_id,
                               torch.full_like(input_tokens, 259),
                               input_tokens)

    use_cuda = str(device).startswith('cuda')

    for model_name in model_names:
        phi = model_dict[model_name]
        time_array = []
        for step in range(n_repeat):
            start_time = time.perf_counter()
            image_features = vision_model(pixel_values=dummy_images.to(vision_model.device, dtype=vision_model.dtype)).image_embeds
            token_embeddings = phi(image_features)
            encode_with_pseudo_tokens_HF(text_model, input_tokens, token_embeddings)
            if use_cuda:
                # Kernels run asynchronously; without this only launch
                # overhead would be timed.
                torch.cuda.synchronize(device)
            end_time = time.perf_counter()
            if step >= n_warmup:
                time_array.append(end_time - start_time)
        print(f"{model_name}: {np.mean(time_array):.4f}")


if __name__ == '__main__':
    args = parse_args()

    # model_dict, faiss_index and image_urls are assigned at module scope here,
    # so they are the module globals that predict() and test_fps() read.
    model_dict = load_models(args)

    if args.test_fps:
        # Benchmark the Phi pipelines, then exit without starting the UI.
        test_fps(1)
        raise SystemExit

    faiss_index = faiss.read_index('./clip_large.index', faiss.IO_FLAG_MMAP | faiss.IO_FLAG_READ_ONLY)

    # Close the file deterministically; JSON is UTF-8 by specification.
    with open('./image_urls.json', encoding='utf-8') as fp:
        image_urls = json.load(fp)

    title = 'Zeroshot CIR demo to search high-quality AI images'

    md_title = f'''# {title}
    [LinCIR](https://arxiv.org/abs/2312.01998): Language-only Training of Zero-shot Composed Image Retrieval  
    [SEARLE](https://arxiv.org/abs/2303.15247): Zero-shot Composed Image Retrieval with Textual Inversion  
    [Pic2Word](https://arxiv.org/abs/2302.03084): Mapping Pictures to Words for Zero-shot Composed Image Retrieval  

    K-NN index for the retrieval results are entirely trained using [the upscaled midjourney v5 images (444,901)](https://huggingface.co/datasets/wanng/midjourney-v5-202304-clean).
    '''

    with gr.Blocks(title=title) as demo:
        gr.Markdown(md_title)
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    image_source = gr.Image(type='pil', label='image1')
                model_name = gr.Radio(['lincir', 'searle', 'pic2word'], label='Phi model', value='lincir')
                text_input = gr.Textbox(value='', label='Input text guidance. Special token is $')
                submit_button = gr.Button('Submit')
                gr.Examples([["example1.jpg", "$, pencil sketch", 'lincir']], inputs=[image_source, text_input, model_name])
            with gr.Column():
                json_output = gr.JSON(label='Processing time')
                md_output = gr.Markdown(label='Output')

        submit_button.click(predict, inputs=[image_source, text_input, model_name], outputs=[json_output, md_output])

    demo.queue()

    demo.launch()