import gradio as gr
import random
from datasets import load_dataset, Features, Value
import json 
import os
import uuid
from huggingface_hub import HfApi
import time 
# # Sample dataset with unique 10-digit IDs
# qa_dataset = {
#     "1234567890": {
#         "question": "What is the capital of France?",
#         "choices": ["A. Berlin", "B. Madrid", "C. Paris", "D. Lisbon"],
#         "answer": "C. Paris"
#     },
#     "0987654321": {
#         "question": "What is the largest planet in our solar system?",
#         "choices": ["A. Earth", "B. Jupiter", "C. Saturn", "D. Mars", "E. Venus"],
#         "answer": "B. Jupiter"
#     },
#     # Add more questions with unique IDs as needed
# }

# Candidate questions that the annotators will review.
truth_data = load_dataset("commonsense-index-dev/commonsense-candidates", "iter7-0520", split="train")

# Time at which the feedback log was last loaded; rewritten by update_logs().
LAST_LOG_UPDATE = time.time()

# Cached feedback dataset; stays None until update_logs() loads it.
logs = None

# Index the candidate questions by ID. Every entry carries a "reason" key so
# that readers never have to special-case rows without usable metadata.
qa_dataset = {}
for item in truth_data:
    # Treat a missing or None metadata field as "no metadata".
    metadata = (item["metadata"] if "metadata" in item else None) or {}
    qa_dataset[item["id"]] = {
        "question": item["task"],
        "choices": item["choices"],
        "answer": item["answer"],
        # Fall back to "N/A" when no reasoning text is available.
        "reason": metadata.get("reasoning") or "N/A",
    }

def update_logs():
    """Load the feedback log into the module-level ``logs`` cache.

    The dataset is fetched only while ``logs`` is still ``None``; once it has
    been loaded, further calls return immediately (a time-based refresh is
    not active). ``LAST_LOG_UPDATE`` records when the load happened.
    """
    global LAST_LOG_UPDATE, logs
    if logs is not None:
        return
    logs = load_dataset("commonsense-index-dev/DemoFeedback", split="train")
    LAST_LOG_UPDATE = time.time()


# Populate the cache once at import time.
update_logs()

def get_random_question(user_name="Anonymous"):
    """Sample a question, preferring those with the least feedback so far.

    Args:
        user_name: Name of the annotator. Currently unused for filtering:
            a question counts as "seen" as soon as anyone has left feedback
            on it (see the TODO below).

    Returns:
        A tuple ``(question_id, question, choices, reasoning)``. ``reasoning``
        is ``"N/A"`` when the question has no stored reasoning.

    Raises:
        ValueError: If ``qa_dataset`` is empty (raised by ``min``).
    """
    update_logs()

    feedback_counts = dict.fromkeys(qa_dataset, 0)
    seen_ids = set()
    for entry in logs:
        qid = entry["question_id"]
        # The feedback log can mention IDs that are not in the current
        # candidate set; ignore them instead of failing with a KeyError.
        if qid not in feedback_counts:
            continue
        feedback_counts[qid] += 1
        # TODO: only mark as seen when entry["user_name"] == user_name once
        # the data is almost all annotated.
        seen_ids.add(qid)

    # Restrict to the questions with the fewest feedback entries.
    min_feedback = min(feedback_counts.values())
    least_reviewed = [qid for qid, count in feedback_counts.items() if count == min_feedback]
    unseen = [qid for qid in least_reviewed if qid not in seen_ids]

    # Once every question has at least one feedback entry, nothing is unseen
    # any more; fall back to the whole least-reviewed pool rather than calling
    # random.choice on an empty list.
    question_id = random.choice(unseen or least_reviewed)

    question_data = qa_dataset[question_id]
    reasoning = question_data.get("reason", "N/A")
    return question_id, question_data["question"], question_data["choices"], reasoning

def get_question_by_id(question_id):
    """Look up a question by its ID.

    Returns:
        ``(question_id, question, choices)`` for a known ID, otherwise
        ``(None, "Invalid question ID", [])``.
    """
    if question_id not in qa_dataset:
        return None, "Invalid question ID", []

    entry = qa_dataset[question_id]
    return question_id, entry["question"], entry["choices"]

def check_answer(question_id, choice, reasoning):
    """Build the markdown verdict for a clicked choice.

    Args:
        question_id: ID of the question being answered.
        choice: Label of the clicked choice; its first three characters are
            dropped before comparing with the stored answer.
        reasoning: Explanation appended to the message on a correct answer.

    Returns:
        Markdown text saying whether the choice was correct.
    """
    expected = qa_dataset[question_id]["answer"]
    if choice[3:] != expected:
        return "### ❌ Incorrect. Try again!"
    return "### ✅ Correct!" + "\n### Reasoning: " + reasoning

def load_question(question_id=None, user_name="Anonymous"):
    """Sample a question and produce the values that reset the UI for it.

    Args:
        question_id: Ignored; it is replaced by the ID of the sampled question.
        user_name: Forwarded to ``get_random_question``.

    Returns:
        A 7-tuple: the sampled question ID, the question rendered as markdown,
        the choices joined one per line, a cleared visible update, the
        reasoning text, a second cleared visible update, and an update that
        restores the submit button's label and makes it interactive.
    """
    sampled_id, task, choices, reasoning = get_random_question(user_name)
    header = f"---\n#### QID: {sampled_id}\n## {task} \n---"
    return (
        sampled_id,
        header,
        "\n".join(choices),
        gr.update(value="", visible=True),
        reasoning,
        gr.update(value="", visible=True),
        gr.update(value="Submit your feedback! 🚀", interactive=True),
    )

def show_buttons(choices_markdown, max_buttons=10):
    """Turn the newline-separated choices into per-button updates.

    Args:
        choices_markdown: Choices joined with ``"\\n"``, one per line.
        max_buttons: Number of updates to return. It must equal the number of
            output buttons wired to this handler; choices beyond this limit
            are not shown.

    Returns:
        A list of ``max_buttons`` updates. The first ones are visible and
        labelled ``"A) ..."``, ``"B) ..."``, and so on; the rest are hidden.
    """
    updates = [gr.update(visible=False) for _ in range(max_buttons)]
    # "".split("\n") yields [""], which would surface a blank "A) " button.
    if not choices_markdown:
        return updates

    choices = choices_markdown.split("\n")
    # Slice so that an over-long choice list cannot index past the buttons.
    for index, text in enumerate(choices[:max_buttons]):
        # Labels are A, B, C, ... followed by ") ".
        label = chr(65 + index) + ") " + text
        updates[index] = gr.update(visible=True, value=label)

    return updates


def submit_feedback(question_id, user_reason, user_revision, example_quality, user_name_text):
    """Upload one feedback record to the feedback dataset repository.

    Args:
        question_id: ID of the question the feedback is about.
        user_reason: Free-text reason entered by the user.
        user_revision: Free-text suggested revision.
        example_quality: Selected quality rating.
        user_name_text: Name entered by the user; empty values are stored as
            ``"Anonymous"``.

    Returns:
        A dict mapping ``submit_button`` to an update: a reminder (button kept
        interactive) when no question or quality has been chosen, otherwise a
        confirmation with the button disabled.

    Raises:
        ValueError: If the ``HF_TOKEN`` environment variable is not set.
    """
    # An empty or None value means nothing was sampled or selected yet. The
    # emptiness checks come first so the "N/A" membership tests never run
    # against None.
    if (
        not question_id
        or not example_quality
        or "N/A" in question_id
        or "N/A" in example_quality
    ):
        # Ask the user to sample an example and select a choice first.
        return {
            submit_button: {"interactive": True, "__type__": "update", "value": "Submit your feedback! 🚀 Please sample an example and select a choice!"},
        }

    token = os.getenv("HF_TOKEN")
    if token is None:
        raise ValueError("Hugging Face token not found. Ensure the HF_TOKEN environment variable is set.")

    if not user_name_text:
        user_name_text = "Anonymous"

    feedback_item = {
        "question_id": question_id,
        "user_name": user_name_text,
        "user_reason": user_reason,
        "revision": user_revision,
        "example_quality": example_quality,
    }
    payload = json.dumps(feedback_item)

    # A random UUID filename keeps each submission in its own file.
    filename = f"{uuid.uuid4()}.json"
    repo_id = "commonsense-index-dev/DemoFeedback"

    # Upload the serialized record directly, without a temporary local file.
    api = HfApi()
    api.upload_file(
        token=token,
        repo_id=repo_id,
        repo_type="dataset",
        path_or_fileobj=payload.encode("utf-8"),  # upload_file is given bytes here
        path_in_repo=filename,
        commit_message=f"{user_name_text}'s feedback on {question_id}",
    )
    return {
        submit_button: {"interactive": False, "__type__": "update", "value": "Submitted! ✅ \n Please sample the next one."}
    }

def refresh_feedback(question_id):
    """Clear the feedback widgets when the displayed question changes.

    Args:
        question_id: The new question ID; unused, it only triggers the reset.

    Returns:
        Three cleared, visible updates, one for each output of the
        ``question_id.change`` listener (reason, revision, quality).
    """
    # One update per declared output of the listener.
    return (
        gr.update(value="", visible=True),
        gr.update(value="", visible=True),
        gr.update(value="", visible=True),
    )

# Build the annotation UI and wire its events.
with gr.Blocks() as app:
    gr.Markdown("# Commonsense Index Data Viewer")

    with gr.Row():
        # question_id_input = gr.Textbox(label="Enter Question ID", placeholder="leave empty for random sampling")
        random_button = gr.Button("🎲 Click here to randomly sample an example")

    question_display = gr.Markdown(visible=True)
    # Hidden carrier for the choices; its change event populates the buttons.
    choices_markdown = gr.Markdown(visible=False)
    choice_buttons = [gr.Button(visible=False) for _ in range(10)]
    result_display = gr.Markdown(visible=True)
    # Hidden holder for the reasoning text shown after a correct answer.
    reasoning_display = gr.Markdown(visible=False)

    question_id = gr.Textbox(label="Question ID:", interactive=False, visible=False)

    with gr.Row():
        with gr.Column(scale=2):
            reason_textbox = gr.Textbox(
                label="Reason",
                placeholder="Please talk why the correct answer is correct and why the others are wrong. If you think this is a bad example, please explain too.",
                type="text",
                elem_classes="",
                max_lines=5,
                lines=3,
                show_copy_button=False,
                visible=True,
                scale=4,
                interactive=True,
            )
            revision_textbox = gr.Textbox(
                label="Revision",
                placeholder="Please suggest a revision to the question.",
                type="text",
                elem_classes="",
                max_lines=5,
                lines=3,
                show_copy_button=False,
                visible=True,
                scale=4,
                interactive=True,
            )
        with gr.Column():
            example_quality = gr.Radio(label="Quality", choices=["Good", "Bad"], interactive=True, visible=True)
            user_name = gr.Textbox(
                label="Your username",
                placeholder="Your username",
                type="text",
                elem_classes="",
                max_lines=1,
                show_copy_button=False,
                visible=True,
                interactive=True,
                show_label=False,
            )
            submit_button = gr.Button("Submit your feedback! 🚀", elem_classes="btn_boderline", visible=True, interactive=True)

    # Inputs are passed positionally, so hand the username over by keyword;
    # otherwise it would bind to load_question's first parameter, question_id.
    random_button.click(
        fn=lambda name: load_question(user_name=name),
        inputs=[user_name],
        outputs=[question_id, question_display, choices_markdown, result_display, reasoning_display, example_quality, submit_button],
    )
    choices_markdown.change(fn=show_buttons, inputs=choices_markdown, outputs=choice_buttons)
    question_id.change(fn=refresh_feedback, inputs=[question_id], outputs=[reason_textbox, revision_textbox, example_quality])
    submit_button.click(
        fn=submit_feedback,
        inputs=[question_id, reason_textbox, revision_textbox, example_quality, user_name],
        outputs=[submit_button],
    )
    # Each button sends its own label as the chosen answer.
    for button in choice_buttons:
        button.click(fn=check_answer, inputs=[question_id, button, reasoning_display], outputs=result_display)

app.launch()