import torch
from PIL import Image
import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import os
from threading import Thread
import random
from datasets import load_dataset

HF_TOKEN = os.environ.get("HF_TOKEN", None)
MODEL_ID = "aixsatoshi/Meta-Llama-3.1-8B-Instruct-plus-Swallow"
MODELS = os.environ.get("MODELS")
MODEL_NAME = MODEL_ID.split("/")[-1]

TITLE = "<h1><center>New japanese LLM model webui</center></h1>"

DESCRIPTION = f"""
<h3>MODEL: <a href="https://huggingface.co/aixsatoshi/Meta-Llama-3.1-8B-Instruct-plus-Swallow">Meta-Llama-3.1-8B-Instruct-plus-Swallow</a></h3>
<center>
<p>aixsatoshi/Meta-Llama-3.1-8B-Instruct-plus-Swallow is the merged model.
<br>
Feel free to test without log.
</p>
</center>
"""

CSS = """
.duplicate-button {
    margin: auto !important;
    color: white !important;
    background: black !important;
    border-radius: 100vh !important;
}
h3 {
    text-align: center;
}
.chatbox .messages .message.user {
    background-color: #e1f5fe;
}
.chatbox .messages .message.bot {
    background-color: #eeeeee;
}
"""

# モデルとトークナイザーの読み込み
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float16,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# データセットをロードしてスプリットを確認
dataset = load_dataset("elyza/ELYZA-tasks-100")
print(dataset)

# 使用するスプリット名を確認
split_name = "train" if "train" in dataset else "test"  # デフォルトをtrainにし、なければtestにフォールバック

# 適切なスプリットから10個の例を取得
examples_list = list(dataset[split_name])  # スプリットをリストに変換
examples = random.sample(examples_list, 10)  # リストからランダムに10個選択
example_inputs = [[example['input']] for example in examples]  # ネストされたリストに変換

@spaces.GPU
def stream_chat(message: str, history: list, temperature: float, max_new_tokens: int, top_p: float, top_k: int, penalty: float):
    print(f'message is - {message}')
    print(f'history is - {history}')
    conversation = []
    for prompt, answer in history:
        conversation.extend([{"role": "user", "content": prompt}, {"role": "assistant", "content": answer}])
    conversation.append({"role": "user", "content": message})

    input_ids = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(input_ids, return_tensors="pt").to(0)
    
    streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)

    generate_kwargs = dict(
        inputs, 
        streamer=streamer,
        top_k=top_k,
        top_p=top_p,
        repetition_penalty=penalty,
        max_new_tokens=max_new_tokens, 
        do_sample=True, 
        temperature=temperature,
        eos_token_id=[128001, 128009],
    )
    
    thread = Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()

    buffer = ""
    for new_text in streamer:
        buffer += new_text
        yield buffer

chatbot = gr.Chatbot(height=500)

with gr.Blocks(css=CSS) as demo:
    gr.HTML(TITLE)
    gr.HTML(DESCRIPTION)
    gr.ChatInterface(
        fn=stream_chat,
        chatbot=chatbot,
        fill_height=True,
        theme="soft",
        retry_btn=None,
        undo_btn="Delete Previous",
        clear_btn="Clear",
        additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
        additional_inputs=[
            gr.Slider(
                minimum=0,
                maximum=1,
                step=0.1,
                value=0.8,
                label="Temperature",
                render=False,
            ),
            gr.Slider(
                minimum=128,
                maximum=4096,
                step=1,
                value=1024,
                label="Max new tokens",
                render=False,
            ),
            gr.Slider(
                minimum=0.0,
                maximum=1.0,
                step=0.1,
                value=0.8,
                label="top_p",
                render=False,
            ),
            gr.Slider(
                minimum=1,
                maximum=20,
                step=1,
                value=20,
                label="top_k",
                render=False,
            ),
            gr.Slider(
                minimum=0.0,
                maximum=2.0,
                step=0.1,
                value=1.0,
                label="Repetition penalty",
                render=False,
            ),
        ],
        #examples=example_inputs,  # ネストされたリストを渡す
        examples=[
            ["Give me five ideas for a child's summer science project."],
            ["Create a tutorial for building a breakout game using markdown."],
            ["超能力を持つ主人公のSF物語のシナリオを考えてください。伏線の設定、テーマやログラインを理論的に使用してください"],
            ["子供の夏休みの自由研究のための、5つのアイデアと、その手法を簡潔に教えてください。"],
            ["パズルゲームのスクリプト作成のためにアドバイスお願いします"],
            ["マークダウン記法にて、ブロック崩しのゲーム作成の教科書作成してください"],
            ["お笑いのトンチ大会のお題を考えてください"],
            ["日本語の慣用句、ことわざについての試験問題を考えてください"],
            ["ドラえもんの登場人物教えて"],
            ["お好み焼きの作り方教えてください"],
        ],
        cache_examples=False,
    )

if __name__ == "__main__":
    demo.launch()