import gradio as gr
import asyncio
import os
import json
import urllib.request
from openai import AsyncOpenAI, OpenAI

# 第一个功能:检查YouTube视频是否具有Creative Commons许可证

def get_youtube_id(youtube_url):
    """Extract the video ID from a YouTube URL.

    Supported forms (the scheme is optional):
      - ``youtube.com/watch?v=ID`` on any subdomain (``www.``, ``m.``, ...)
      - ``youtube.com/embed/ID``, ``/shorts/ID``, ``/live/ID``, ``/v/ID``
      - ``youtu.be/ID``

    Args:
        youtube_url: URL text entered by the user; ``None`` or ``""`` allowed.

    Returns:
        The video ID, or ``''`` when the URL is not a recognizable YouTube
        video URL.
    """
    from urllib.parse import parse_qs, urlparse

    url = (youtube_url or "").strip()
    if not url:
        return ""
    # urlparse only populates the hostname when a scheme is present, and
    # users often paste "youtube.com/watch?v=..." without one.
    if "://" not in url:
        url = "https://" + url.lstrip("/")

    parsed = urlparse(url)
    host = (parsed.hostname or "").lower()
    path_parts = [part for part in parsed.path.split("/") if part]

    if host == "youtu.be" or host.endswith(".youtu.be"):
        return path_parts[0] if path_parts else ""

    # Match the real hostname rather than a substring of the whole URL, so a
    # "youtube.com" appearing in another site's query string is rejected.
    if host == "youtube.com" or host.endswith(".youtube.com"):
        video_ids = parse_qs(parsed.query).get("v")
        if video_ids:
            return video_ids[0]
        if len(path_parts) >= 2 and path_parts[0] in ("embed", "shorts", "live", "v"):
            return path_parts[1]

    return ""

def check_cc_license(youtube_url):
    """Report whether a YouTube video is published under a Creative Commons license.

    Queries the YouTube Data API v3 ``videos`` endpoint with ``part=status``
    and compares ``status.license`` of the first returned item against
    ``"creativeCommon"``.

    The API key is read from the ``YOUTUBE_API_KEY`` environment variable,
    falling back to a module-level ``API_KEY`` if one has been set.
    NOTE(review): the previous version referenced ``API_KEY`` directly, but
    this file never defines it, so every lookup failed with a NameError.

    Args:
        youtube_url: YouTube video URL as entered in the UI.

    Returns:
        ``"Yes."``, ``"No."``, ``"Video not found."``,
        ``"Invalid YouTube URL."``, or ``"An error occurred: ..."``.
    """
    from urllib.parse import urlencode

    video_id = get_youtube_id(youtube_url)
    if not video_id:
        return "Invalid YouTube URL."

    api_key = os.environ.get("YOUTUBE_API_KEY") or globals().get("API_KEY", "")
    if not api_key:
        return "An error occurred: YOUTUBE_API_KEY is not set."

    # urlencode escapes the user-derived video ID so it cannot inject extra
    # query parameters into the API request.
    query = urlencode({"id": video_id, "part": "status", "key": api_key})
    api_url = f"https://www.googleapis.com/youtube/v3/videos?{query}"

    try:
        # The context manager closes the HTTP response. The timeout stops a
        # stalled request from blocking the Gradio worker indefinitely.
        with urllib.request.urlopen(api_url, timeout=10) as response:
            data = json.load(response)

        items = data.get("items") or []
        if not items:
            return "Video not found."
        license_type = items[0].get("status", {}).get("license")
        return "Yes." if license_type == "creativeCommon" else "No."
    except Exception as e:  # UI boundary: show the failure instead of crashing.
        return f"An error occurred: {str(e)}"

# 第二个功能:为多项选择题生成干扰项



# Import the prompt-building and OpenAI request helpers from the project's utils package
from utils.generate_distractors import prepare_q_inputs, construct_prompt_textonly, generate_distractors
from utils.api_utils import generate_from_openai_chat_completion

def _distractor_fields(response):
    """Map one model response onto the output fields stored on a query.

    Args:
        response: One element of the list returned by
            ``generate_from_openai_chat_completion``. It is expected to be a
            dict whose ``"distractors"`` entry maps ``"E"``, ``"F"``, ``"G"``
            and ``"analysis_of_distractors"`` to text. A JSON string of that
            shape is also accepted; presumably ``json_format=True`` may
            return raw JSON text (TODO confirm in utils.api_utils).

    Returns:
        Dict with keys ``option_5``, ``option_6``, ``option_7`` and
        ``distractor_analysis``. Every value is ``""`` when the response is
        missing or malformed.
    """
    if isinstance(response, str):
        try:
            response = json.loads(response)
        except ValueError:
            response = None
    distractors = response.get("distractors") if isinstance(response, dict) else None
    if not isinstance(distractors, dict):
        distractors = {}
    return {
        "option_5": distractors.get("E", ""),
        "option_6": distractors.get("F", ""),
        "option_7": distractors.get("G", ""),
        "distractor_analysis": distractors.get("analysis_of_distractors", ""),
    }


def generate_distractors_sync(model_name: str,
                              queries: list,
                              n: int = 1,
                              max_tokens: int = 4096,
                              base_url: str = "https://yanlp.zeabur.app/v1"):
    """Generate three extra distractors (E/F/G) for each multiple-choice query.

    Each query dict is updated in place with ``option_5``, ``option_6``,
    ``option_7`` and ``distractor_analysis``. These are ``""`` when the model
    returned nothing usable for that query.

    Args:
        model_name: One of the supported OpenAI chat model names.
        queries: List of query dicts accepted by ``prepare_q_inputs``.
        n: Number of completions requested per message.
        max_tokens: Completion token limit.
        base_url: OpenAI-compatible endpoint. It defaults to the endpoint
            previously hard-coded here.

    Returns:
        The same ``queries`` list, with its dicts mutated.

    Raises:
        ValueError: If ``model_name`` is not supported.
    """
    supported_models = ("gpt-4o-mini", "gpt-4-turbo", "gpt-4o", "gpt-4o-2024-08-06")
    # Explicit raise rather than assert: asserts are stripped under `python -O`.
    if model_name not in supported_models:
        raise ValueError(f"Invalid model name: {model_name!r}")

    # NOTE(review): by default OPENAI_API_KEY is sent to this third-party
    # proxy, not to api.openai.com. Confirm that this is intended.
    client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"), base_url=base_url)
    messages = prepare_q_inputs(queries)

    # Synchronous call; the async helpers are not used here.
    # NOTE(review): assumes this returns one response per message, in order.
    responses = generate_from_openai_chat_completion(
        client,
        messages=messages,
        engine_name=model_name,
        n=n,
        max_tokens=max_tokens,
        requests_per_minute=30,
        json_format=True,
    )

    responses = list(responses or [])
    for index, query in enumerate(queries):
        # Pad missing responses with None so every query still receives the
        # (empty) output fields, instead of zip() silently skipping it.
        response = responses[index] if index < len(responses) else None
        query.update(_distractor_fields(response))

    return queries

# Gradio callback: validate the form inputs, then generate distractors synchronously
def generate_distractors_gradio(question, option1, option2, option3, option4, answer, answer_analysis):
    """Gradio callback: validate the form, then generate distractors E/F/G.

    Args:
        question: Question text.
        option1, option2, option3, option4: Existing options A-D.
        answer: The correct answer as entered (labelled A/B/C/D in the UI).
        answer_analysis: Explanation of the correct answer.

    Returns:
        Tuple ``(options_text, analysis_text)`` for the two output textboxes.
        On invalid input or a generation failure, ``options_text`` holds the
        message and ``analysis_text`` is ``""``.
    """
    import logging

    is_valid, message = validate_inputs(question, option1, option2, option3, option4, answer, answer_analysis)
    if not is_valid:
        # Both outputs are Textboxes, so return plain text. The previous
        # {"error": ...} dict was displayed as its Python repr.
        return message, ""

    query = {
        'question': question,
        'option_1': option1,
        'option_2': option2,
        'option_3': option3,
        'option_4': option4,
        'answer': answer,
        'answer_analysis': answer_analysis
    }

    try:
        results = generate_distractors_sync(
            model_name="gpt-4o",
            queries=[query],
            n=1,
            max_tokens=4096
        )
    except Exception as e:  # UI boundary: report in the textbox, keep the traceback in logs.
        logging.getLogger(__name__).exception("Distractor generation failed")
        return f"An error occurred: {str(e)}", ""

    result = results[0]
    # Same "X: value" layout for every line (previously F/G lacked the space).
    new_option_str = (
        f"E: {result.get('option_5', '')}\n"
        f"F: {result.get('option_6', '')}\n"
        f"G: {result.get('option_7', '')}"
    )
    distractor_analysis = result.get('distractor_analysis', '')

    return new_option_str, distractor_analysis

def validate_inputs(question, option1, option2, option3, option4, answer, analysis):
    """Check that every field of the multiple-choice form is filled in.

    A field counts as empty when it is falsy (``None``, ``""``) or, for
    strings, when it contains only whitespace.

    Returns:
        ``(True, "")`` when all fields are present. Otherwise
        ``(False, message)``, where ``message`` names the first empty field
        in form order.
    """
    checks = (
        (question, "问题不能为空"),
        (option1, "选项A不能为空"),
        (option2, "选项B不能为空"),
        (option3, "选项C不能为空"),
        (option4, "选项D不能为空"),
        (answer, "正确答案不能为空"),
        (analysis, "答案解析不能为空"),
    )
    for value, message in checks:
        # Without strip(), whitespace-only text would pass and be sent to the model.
        if not (value.strip() if isinstance(value, str) else value):
            return False, message
    return True, ""

# Build the two-tab Gradio UI: a YouTube Creative Commons checker and a
# multiple-choice distractor generator.
with gr.Blocks() as demo:
    gr.Markdown("# CC检查器和干扰项生成器")

    with gr.Tabs():
        # Tab 1: look up a video's license via check_cc_license.
        with gr.TabItem("YouTube Creative Commons检查器"):
            gr.Markdown("## 检查YouTube视频是否具有Creative Commons许可证")
            youtube_url_input = gr.Textbox(label="YouTube视频URL")
            cc_license_output = gr.Textbox(label="是否为Creative Commons许可证?")
            check_button = gr.Button("检查许可证")
            check_button.click(fn=check_cc_license, inputs=youtube_url_input, outputs=cc_license_output)

        # Tab 2: generate extra distractors via generate_distractors_gradio.
        with gr.TabItem("多项选择题干扰项生成器"):
            gr.Markdown("## 为多项选择题生成干扰项")
            with gr.Row():
                question_input = gr.Textbox(label="问题", lines=2)
            # One textbox per existing option, created left to right as 选项A..选项D.
            with gr.Row():
                option1_input, option2_input, option3_input, option4_input = [
                    gr.Textbox(label=f"选项{letter}") for letter in "ABCD"
                ]
            with gr.Row():
                answer_input = gr.Textbox(label="正确答案(A/B/C/D)")
            with gr.Row():
                answer_analysis_input = gr.Textbox(label="答案解析", lines=3)
            generate_button = gr.Button("生成干扰项")
            output_options = gr.Textbox(label="生成的干扰选项")
            output_analysis = gr.Textbox(label="干扰项解析", lines=5)

            # Order must match generate_distractors_gradio's positional parameters.
            distractor_form_inputs = [
                question_input,
                option1_input,
                option2_input,
                option3_input,
                option4_input,
                answer_input,
                answer_analysis_input,
            ]
            generate_button.click(
                fn=generate_distractors_gradio,
                inputs=distractor_form_inputs,
                outputs=[output_options, output_analysis],
            )

# Start the Gradio server. This runs at module level, so importing the
# file also launches it.
demo.launch()