import argparse
import copy
import os

import gradio as gr
import torch

from utils import load_hyperparam, load_model
from models.tokenize import Tokenizer
from models.llama import *
from generate import LmGeneration
# CUDA_LAUNCH_BLOCKING=1 is the CUDA debugging switch that makes kernel
# launches synchronous. NOTE(review): it usually slows inference -- confirm it
# is still wanted outside of debugging.
os.environ.update(CUDA_LAUNCH_BLOCKING='1')

# Process-wide state: `args` is assigned in init_args() and `lm_generation`
# in init_model(); chat() reads both.
args = lm_generation = None

def init_args():
    """Build the global ``args`` namespace used by init_model() and chat().

    Every setting that used to be hard-coded here is an optional command-line
    flag. Each flag's default is the previous hard-coded value, so running the
    script with no arguments behaves exactly as before. For example, to serve
    the 13B checkpoint pass::

        --load_model_path ./model_file/chatflow_13b.bin
        --config_path ./config/llama_13b_config.json

    After parsing, the namespace is passed through ``load_hyperparam``. How its
    values interact with these flags is decided by that helper -- TODO confirm
    precedence. The namespace is then extended with the SentencePiece
    ``tokenizer`` and its ``vocab_size``.
    """
    global args
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # Model / tokenizer files.
    parser.add_argument('--load_model_path', default='./model_file/chatllama_7b.bin',
                        help='Path of the model checkpoint to load.')
    parser.add_argument('--config_path', default='./config/llama_7b.json',
                        help='Path of the model config file read by load_hyperparam.')
    parser.add_argument('--spm_model_path', default='./model_file/tokenizer.model',
                        help='Path of the SentencePiece tokenizer model.')
    parser.add_argument('--seq_length', type=int, default=1024,
                        help='Sequence length setting stored on args.')

    # Generation settings, forwarded to LmGeneration.generate through args.
    # Non-string defaults are not run through `type`, so they keep exactly
    # the original values/types (e.g. int 0 for top_p).
    parser.add_argument('--top_p', type=float, default=0,
                        help='top_p value passed to the generator.')
    parser.add_argument('--repetition_penalty_range', type=int, default=1024,
                        help='repetition_penalty_range value passed to the generator.')
    parser.add_argument('--repetition_penalty_slope', type=float, default=0,
                        help='repetition_penalty_slope value passed to the generator.')
    parser.add_argument('--repetition_penalty', type=float, default=1.15,
                        help='repetition_penalty value passed to the generator.')
    args = parser.parse_args()

    # Deliberately not exposed as flags: chat() submits a single prompt per
    # call, so batch_size stays 1; world_size and use_int8 keep their
    # previous fixed values.
    args.batch_size = 1
    args.world_size = 1
    args.use_int8 = False

    args = load_hyperparam(args)

    args.tokenizer = Tokenizer(model_path=args.spm_model_path)
    args.vocab_size = args.tokenizer.sp_model.vocab_size()


def init_model():
    """Load the model checkpoint and create the global ``lm_generation`` wrapper.

    Reads the global ``args`` built by init_args(), which must run first.
    Steps:

    1. Construct the model while half precision is the default tensor type.
    2. Load weights from ``args.load_model_path``.
    3. Switch to eval mode.
    4. Move to CUDA when available, otherwise CPU.
    5. Wrap the model in ``LmGeneration`` together with ``args.tokenizer``.
    """
    global lm_generation
    # Construct the parameters directly in fp16 (presumably to halve memory
    # while loading). The default tensor type is process-global state, so the
    # try/finally restores float32 even if construction raises; otherwise
    # every tensor created afterwards would silently be fp16.
    torch.set_default_tensor_type(torch.HalfTensor)
    try:
        model = LLaMa(args)
    finally:
        torch.set_default_tensor_type(torch.FloatTensor)
    model = load_model(model, args.load_model_path)
    model.eval()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # NOTE(review): on the CPU fallback the weights stay fp16; some PyTorch CPU
    # kernels lack fp16 support -- confirm generation works without a GPU.
    model.to(device)
    if device.type == "cuda":
        # Peak GPU memory allocated so far (after loading), in GiB.
        peak_gib = torch.cuda.max_memory_allocated() / 1024 ** 3
        print(f"Max CUDA memory allocated: {peak_gib:.2f} GiB")
    lm_generation = LmGeneration(model, args.tokenizer)


def chat(prompt, top_k, temperature):
    """Generate a reply to ``prompt`` using the given sampling settings.

    Args:
        prompt: User text from the Gradio text input.
        top_k: Value of the top-k slider; converted to ``int``.
        temperature: Value of the temperature slider, passed through unchanged.

    Returns:
        ``response[0]`` from ``lm_generation.generate`` for the single submitted
        prompt (presumably one response per prompt). It is also printed to stdout.
    """
    # Shallow per-request copy of the shared configuration. Concurrent
    # requests therefore cannot overwrite each other's top_k/temperature on
    # the global `args`. The tokenizer and other values are shared, not copied.
    request_args = copy.copy(args)
    request_args.top_k = int(top_k)
    request_args.temperature = temperature
    response = lm_generation.generate(request_args, [prompt])
    print(response[0])
    return response[0]


if __name__ == '__main__':
    # Parse configuration and load the weights once, before the UI starts
    # accepting requests.
    init_args()
    init_model()

    # Input components map positionally onto chat(prompt, top_k, temperature).
    top_k_slider = gr.Slider(1, 60, value=10, step=1)
    temperature_slider = gr.Slider(0.1, 2.0, value=1.0, step=0.1)
    demo = gr.Interface(
        fn=chat,
        inputs=["text", top_k_slider, temperature_slider],
        outputs="text",
    )
    demo.launch()