谢璐璟 commited on
Commit
bee8e94
1 Parent(s): 23117b2
Files changed (1) hide show
  1. app.py +13 -18
app.py CHANGED
@@ -3,13 +3,10 @@ import asyncio
3
  import os
4
  import json
5
  import urllib.request
6
- from openai import AsyncOpenAI
7
 
8
  # 第一个功能:检查YouTube视频是否具有Creative Commons许可证
9
 
10
- # 请确保在环境变量中设置了您的YouTube Data API密钥
11
- API_KEY = "AIzaSyDyPpkFRUpUuSMQbhxwTFxCBLK5qTHU-ms"
12
-
13
  def get_youtube_id(youtube_url):
14
  if 'youtube.com' in youtube_url:
15
  video_id = youtube_url.split('v=')[-1]
@@ -55,17 +52,17 @@ def check_cc_license(youtube_url):
55
  from utils.generate_distractors import prepare_q_inputs, construct_prompt_textonly, generate_distractors
56
  from utils.api_utils import generate_from_openai_chat_completion
57
 
58
- async def generate_distractors_async(model_name: str,
59
- queries: list,
60
- n: int=1,
61
- max_tokens: int=4096):
62
  assert model_name in ["gpt-4o-mini", "gpt-4-turbo", "gpt-4o", "gpt-4o-2024-08-06"], "Invalid model name"
63
 
64
- client = AsyncOpenAI(api_key=os.environ.get("OPENAI_API_KEY"), base_url="https://yanlp.zeabur.app/v1")
65
  messages = prepare_q_inputs(queries)
66
 
67
- # 直接等待协程而不是使用asyncio.run()
68
- responses = await generate_from_openai_chat_completion(
69
  client,
70
  messages=messages,
71
  engine_name=model_name,
@@ -90,8 +87,8 @@ async def generate_distractors_async(model_name: str,
90
 
91
  return queries
92
 
93
- # 定义异步处理函数
94
- async def generate_distractors_gradio(question, option1, option2, option3, option4, answer, answer_analysis):
95
  is_valid, message = validate_inputs(question, option1, option2, option3, option4, answer, answer_analysis)
96
  if not is_valid:
97
  return {"error": message}, "" # Output error message
@@ -106,17 +103,16 @@ async def generate_distractors_gradio(question, option1, option2, option3, optio
106
  'answer_analysis': answer_analysis
107
  }
108
 
109
- queries = [query] # 因为函数期望的是一个列表
110
 
111
- # 调用异步生成干扰项的函数
112
- results = await generate_distractors_async(
113
  model_name="gpt-4o",
114
  queries=queries,
115
  n=1,
116
  max_tokens=4096
117
  )
118
 
119
- # 提取结果
120
  result = results[0]
121
  new_options = {
122
  'E': result.get('option_5', ''),
@@ -126,7 +122,6 @@ async def generate_distractors_gradio(question, option1, option2, option3, optio
126
  new_option_str = f"E: {new_options['E']}\nF:{new_options['F']}\nG:{new_options['G']}"
127
  distractor_analysis = result.get('distractor_analysis', '')
128
 
129
- # 返回新的干扰项和分析
130
  return new_option_str, distractor_analysis
131
 
132
  def validate_inputs(question, option1, option2, option3, option4, answer, analysis):
 
3
  import os
4
  import json
5
  import urllib.request
6
+ from openai import AsyncOpenAI, OpenAI
7
 
8
  # 第一个功能:检查YouTube视频是否具有Creative Commons许可证
9
 
 
 
 
10
  def get_youtube_id(youtube_url):
11
  if 'youtube.com' in youtube_url:
12
  video_id = youtube_url.split('v=')[-1]
 
52
  from utils.generate_distractors import prepare_q_inputs, construct_prompt_textonly, generate_distractors
53
  from utils.api_utils import generate_from_openai_chat_completion
54
 
55
+ def generate_distractors_sync(model_name: str,
56
+ queries: list,
57
+ n: int=1,
58
+ max_tokens: int=4096):
59
  assert model_name in ["gpt-4o-mini", "gpt-4-turbo", "gpt-4o", "gpt-4o-2024-08-06"], "Invalid model name"
60
 
61
+ client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"), base_url="https://yanlp.zeabur.app/v1")
62
  messages = prepare_q_inputs(queries)
63
 
64
+ # 同步调用,不使用异步函数
65
+ responses = generate_from_openai_chat_completion(
66
  client,
67
  messages=messages,
68
  engine_name=model_name,
 
87
 
88
  return queries
89
 
90
+ # 处理生成干扰项的同步版本
91
+ def generate_distractors_gradio(question, option1, option2, option3, option4, answer, answer_analysis):
92
  is_valid, message = validate_inputs(question, option1, option2, option3, option4, answer, answer_analysis)
93
  if not is_valid:
94
  return {"error": message}, "" # Output error message
 
103
  'answer_analysis': answer_analysis
104
  }
105
 
106
+ queries = [query]
107
 
108
+ # 调用同步生成干扰项的函数
109
+ results = generate_distractors_sync(
110
  model_name="gpt-4o",
111
  queries=queries,
112
  n=1,
113
  max_tokens=4096
114
  )
115
 
 
116
  result = results[0]
117
  new_options = {
118
  'E': result.get('option_5', ''),
 
122
  new_option_str = f"E: {new_options['E']}\nF:{new_options['F']}\nG:{new_options['G']}"
123
  distractor_analysis = result.get('distractor_analysis', '')
124
 
 
125
  return new_option_str, distractor_analysis
126
 
127
  def validate_inputs(question, option1, option2, option3, option4, answer, analysis):