Spaces:
Sleeping
Sleeping
谢璐璟
committed on
Commit
•
8391956
1
Parent(s):
bee8e94
sync
Browse files
utils/__pycache__/api_utils.cpython-310.pyc
CHANGED
Binary files a/utils/__pycache__/api_utils.cpython-310.pyc and b/utils/__pycache__/api_utils.cpython-310.pyc differ
|
|
utils/__pycache__/generate_distractors.cpython-310.pyc
CHANGED
Binary files a/utils/__pycache__/generate_distractors.cpython-310.pyc and b/utils/__pycache__/generate_distractors.cpython-310.pyc differ
|
|
utils/api_utils.py
CHANGED
@@ -1,68 +1,59 @@
|
|
1 |
import base64
|
2 |
import numpy as np
|
3 |
-
from typing import Dict
|
4 |
import random
|
5 |
-
|
6 |
-
import asyncio
|
7 |
import logging
|
8 |
-
import os
|
9 |
-
|
10 |
-
|
11 |
-
from
|
12 |
-
import random
|
13 |
from time import sleep
|
14 |
|
15 |
-
import aiolimiter
|
16 |
|
17 |
-
|
18 |
-
|
19 |
-
from anthropic import AsyncAnthropic
|
20 |
-
|
21 |
-
async def _throttled_openai_chat_completion_acreate(
|
22 |
-
client: AsyncOpenAI,
|
23 |
model: str,
|
24 |
messages,
|
25 |
temperature: float,
|
26 |
max_tokens: int,
|
27 |
top_p: float,
|
28 |
-
limiter: aiolimiter.AsyncLimiter,
|
29 |
json_format: bool = False,
|
30 |
n: int = 1,
|
31 |
):
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
|
|
64 |
|
65 |
-
|
66 |
client,
|
67 |
messages,
|
68 |
engine_name: str,
|
@@ -73,26 +64,24 @@ async def generate_from_openai_chat_completion(
|
|
73 |
json_format: bool = False,
|
74 |
n: int = 1,
|
75 |
):
|
76 |
-
|
77 |
delay = 60.0 / requests_per_minute
|
78 |
-
|
79 |
-
|
80 |
-
|
|
|
81 |
client,
|
82 |
model=engine_name,
|
83 |
messages=message,
|
84 |
temperature=temperature,
|
85 |
max_tokens=max_tokens,
|
86 |
top_p=top_p,
|
87 |
-
limiter=limiter,
|
88 |
json_format=json_format,
|
89 |
n=n,
|
90 |
)
|
91 |
for message in messages
|
92 |
]
|
93 |
-
|
94 |
-
responses = await tqdm_asyncio.gather(*async_responses)
|
95 |
-
|
96 |
empty_dict = {
|
97 |
"question": "",
|
98 |
"options": {
|
@@ -107,7 +96,7 @@ async def generate_from_openai_chat_completion(
|
|
107 |
"G": "",
|
108 |
},
|
109 |
"correct_answer": ""
|
110 |
-
|
111 |
empty_str = ""
|
112 |
outputs = []
|
113 |
for response in responses:
|
@@ -135,77 +124,3 @@ async def generate_from_openai_chat_completion(
|
|
135 |
])
|
136 |
return outputs
|
137 |
|
138 |
-
async def _throttled_claude_chat_completion_acreate(
|
139 |
-
client: AsyncAnthropic,
|
140 |
-
model: str,
|
141 |
-
messages,
|
142 |
-
temperature: float,
|
143 |
-
max_tokens: int,
|
144 |
-
top_p: float,
|
145 |
-
limiter: aiolimiter.AsyncLimiter,
|
146 |
-
):
|
147 |
-
async with limiter:
|
148 |
-
try:
|
149 |
-
return await client.messages.create(
|
150 |
-
model=model,
|
151 |
-
messages=messages,
|
152 |
-
temperature=temperature,
|
153 |
-
max_tokens=max_tokens,
|
154 |
-
top_p=top_p,
|
155 |
-
)
|
156 |
-
except:
|
157 |
-
return None
|
158 |
-
|
159 |
-
async def generate_from_claude_chat_completion(
|
160 |
-
client,
|
161 |
-
messages,
|
162 |
-
engine_name: str,
|
163 |
-
temperature: float = 1.0,
|
164 |
-
max_tokens: int = 512,
|
165 |
-
top_p: float = 1.0,
|
166 |
-
requests_per_minute: int = 100,
|
167 |
-
n: int = 1,
|
168 |
-
):
|
169 |
-
# https://chat.openai.com/share/09154613-5f66-4c74-828b-7bd9384c2168
|
170 |
-
delay = 60.0 / requests_per_minute
|
171 |
-
limiter = aiolimiter.AsyncLimiter(1, delay)
|
172 |
-
|
173 |
-
n_messages = []
|
174 |
-
for message in messages:
|
175 |
-
for _ in range(n):
|
176 |
-
n_messages.append(message)
|
177 |
-
|
178 |
-
async_responses = [
|
179 |
-
_throttled_claude_chat_completion_acreate(
|
180 |
-
client,
|
181 |
-
model=engine_name,
|
182 |
-
messages=message,
|
183 |
-
temperature=temperature,
|
184 |
-
max_tokens=max_tokens,
|
185 |
-
top_p=top_p,
|
186 |
-
limiter=limiter,
|
187 |
-
)
|
188 |
-
for message in n_messages
|
189 |
-
]
|
190 |
-
|
191 |
-
responses = await tqdm_asyncio.gather(*async_responses)
|
192 |
-
|
193 |
-
outputs = []
|
194 |
-
if n == 1:
|
195 |
-
for response in responses:
|
196 |
-
if response and response.content and response.content[0] and response.content[0].text:
|
197 |
-
outputs.append(response.content[0].text)
|
198 |
-
else:
|
199 |
-
outputs.append("")
|
200 |
-
else:
|
201 |
-
idx = 0
|
202 |
-
for response in responses:
|
203 |
-
if idx % n == 0:
|
204 |
-
outputs.append([])
|
205 |
-
idx += 1
|
206 |
-
if response and response.content and response.content[0] and response.content[0].text:
|
207 |
-
outputs[-1].append(response.content[0].text)
|
208 |
-
else:
|
209 |
-
outputs[-1].append("")
|
210 |
-
|
211 |
-
return outputs
|
|
|
1 |
import base64
|
2 |
import numpy as np
|
|
|
3 |
import random
|
|
|
|
|
4 |
import logging
|
5 |
+
import os
|
6 |
+
import json
|
7 |
+
import openai
|
8 |
+
from openai import OpenAIError
|
|
|
9 |
from time import sleep
|
10 |
|
|
|
11 |
|
12 |
+
def _throttled_openai_chat_completion_create(
|
13 |
+
client,
|
|
|
|
|
|
|
|
|
14 |
model: str,
|
15 |
messages,
|
16 |
temperature: float,
|
17 |
max_tokens: int,
|
18 |
top_p: float,
|
|
|
19 |
json_format: bool = False,
|
20 |
n: int = 1,
|
21 |
):
|
22 |
+
"""同步的OpenAI聊天补全函数,支持限流与重试"""
|
23 |
+
for _ in range(10): # 进行10次尝试
|
24 |
+
try:
|
25 |
+
if json_format:
|
26 |
+
return client.chat.completions.create(
|
27 |
+
model=model,
|
28 |
+
messages=messages,
|
29 |
+
temperature=temperature,
|
30 |
+
max_tokens=max_tokens,
|
31 |
+
top_p=top_p,
|
32 |
+
n=n,
|
33 |
+
response_format={"type": "json_object"},
|
34 |
+
)
|
35 |
+
else:
|
36 |
+
return client.chat.completions.create(
|
37 |
+
model=model,
|
38 |
+
messages=messages,
|
39 |
+
temperature=temperature,
|
40 |
+
max_tokens=max_tokens,
|
41 |
+
top_p=top_p,
|
42 |
+
n=n,
|
43 |
+
)
|
44 |
+
except openai.RateLimitError as e:
|
45 |
+
print("Rate limit exceeded, retrying...")
|
46 |
+
sleep(random.randint(10, 20)) # 增加重试等待时间
|
47 |
+
except openai.BadRequestError as e:
|
48 |
+
print(e)
|
49 |
+
return None
|
50 |
+
except OpenAIError as e:
|
51 |
+
print(e)
|
52 |
+
sleep(random.randint(5, 10))
|
53 |
+
return None
|
54 |
+
|
55 |
|
56 |
+
def generate_from_openai_chat_completion(
|
57 |
client,
|
58 |
messages,
|
59 |
engine_name: str,
|
|
|
64 |
json_format: bool = False,
|
65 |
n: int = 1,
|
66 |
):
|
67 |
+
"""同步生成OpenAI聊天补全"""
|
68 |
delay = 60.0 / requests_per_minute
|
69 |
+
sleep(delay) # 简单的限流处理
|
70 |
+
|
71 |
+
responses = [
|
72 |
+
_throttled_openai_chat_completion_create(
|
73 |
client,
|
74 |
model=engine_name,
|
75 |
messages=message,
|
76 |
temperature=temperature,
|
77 |
max_tokens=max_tokens,
|
78 |
top_p=top_p,
|
|
|
79 |
json_format=json_format,
|
80 |
n=n,
|
81 |
)
|
82 |
for message in messages
|
83 |
]
|
84 |
+
|
|
|
|
|
85 |
empty_dict = {
|
86 |
"question": "",
|
87 |
"options": {
|
|
|
96 |
"G": "",
|
97 |
},
|
98 |
"correct_answer": ""
|
99 |
+
}
|
100 |
empty_str = ""
|
101 |
outputs = []
|
102 |
for response in responses:
|
|
|
124 |
])
|
125 |
return outputs
|
126 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
utils/generate_distractors.py
CHANGED
@@ -2,10 +2,9 @@ import json
|
|
2 |
import re
|
3 |
from tqdm import tqdm
|
4 |
import os
|
5 |
-
import
|
6 |
-
from openai import AsyncOpenAI
|
7 |
|
8 |
-
from utils.api_utils import generate_from_openai_chat_completion
|
9 |
|
10 |
|
11 |
def construct_prompt_textonly(question: str, options: list, answer: str, answer_analysis: str) -> str:
|
@@ -19,7 +18,7 @@ Generate a multiple-choice question with additional distractors that increase th
|
|
19 |
3. **Use Answer Analysis**: Reference the **correct answer analysis** when creating distractors to ensure they challenge **subject-matter experts**.
|
20 |
4. **Expert-Level Difficulty**: Keep the distractors **challenging and hard to distinguish** from the correct answer, requiring **advanced knowledge** to avoid the correct answer being too obvious.
|
21 |
5. **Balanced Length**: Ensure all options have **similar lengths** to prevent any one option from standing out.
|
22 |
-
6. **Distractors Analysis**: Provide a **distractor analysis in Chinese**, explaining why the distractors are **incorrect** but **challenging and hard to distinguish
|
23 |
|
24 |
Please output the result in valid JSON format using the structure below. Make sure there are no extra commas, missing commas, extra quotation marks or missing quotation marks:
|
25 |
{{
|
@@ -47,7 +46,6 @@ Answer: {answer}
|
|
47 |
Answer Analysis: {answer_analysis}
|
48 |
"""
|
49 |
|
50 |
-
# prompt = prompt.replace("I don't know.", "Idle.")
|
51 |
return prompt
|
52 |
|
53 |
|
@@ -75,84 +73,32 @@ def prepare_q_inputs(queries):
|
|
75 |
|
76 |
messages.append(prompt_message)
|
77 |
return messages
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
# def extract_json_from_text(text):
|
82 |
-
# text = json.dumps(text)
|
83 |
-
# # 移除转义符和换行符
|
84 |
-
# text = text.replace('\\n', '').replace('\\"', '"')
|
85 |
-
|
86 |
-
# # 定义匹配 JSON 对象的正则表达式模式
|
87 |
-
# json_pattern = re.compile(
|
88 |
-
# r'\{\s*"question":\s*"([^"]*)",\s*"options":\s*\{\s*"A":\s*"([^"]*)",\s*"B":\s*"([^"]*)",\s*"C":\s*"([^"]*)",\s*"D":\s*"([^"]*)"\s*\},'
|
89 |
-
# r'\s*"distractors":\s*\{\s*"E":\s*"([^"]*)",\s*"F":\s*"([^"]*)",\s*"G":\s*"([^"]*)"\s*\},\s*"correct_answer":\s*"([^"]*)"\s*\}',
|
90 |
-
# re.DOTALL
|
91 |
-
# )
|
92 |
-
|
93 |
-
# # 匹配 JSON 结构
|
94 |
-
# match = json_pattern.search(text)
|
95 |
-
|
96 |
-
# if match:
|
97 |
-
# # 捕获到的匹配组
|
98 |
-
# question = match.group(1)
|
99 |
-
# option_a = match.group(2)
|
100 |
-
# option_b = match.group(3)
|
101 |
-
# option_c = match.group(4)
|
102 |
-
# option_d = match.group(5)
|
103 |
-
# distractor_e = match.group(6)
|
104 |
-
# distractor_f = match.group(7)
|
105 |
-
# distractor_g = match.group(8)
|
106 |
-
# correct_answer = match.group(9)
|
107 |
-
|
108 |
-
# # 构建 JSON 对象
|
109 |
-
# json_data = {
|
110 |
-
# "question": question,
|
111 |
-
# "options": {
|
112 |
-
# "A": option_a,
|
113 |
-
# "B": option_b,
|
114 |
-
# "C": option_c,
|
115 |
-
# "D": option_d
|
116 |
-
# },
|
117 |
-
# "distractors": {
|
118 |
-
# "E": distractor_e,
|
119 |
-
# "F": distractor_f,
|
120 |
-
# "G": distractor_g
|
121 |
-
# },
|
122 |
-
# "correct_answer": correct_answer
|
123 |
-
# }
|
124 |
-
|
125 |
-
# return json_data
|
126 |
-
# else:
|
127 |
-
# print("No JSON object found in the text.")
|
128 |
-
# return None
|
129 |
|
130 |
|
131 |
def generate_distractors(model_name: str,
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
|
136 |
assert model_name in ["gpt-4o-mini", "gpt-4-turbo", "gpt-4o", "gpt-4o-2024-08-06"], "Invalid model name"
|
137 |
|
138 |
-
|
|
|
139 |
messages = prepare_q_inputs(queries)
|
140 |
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
)
|
151 |
)
|
152 |
|
153 |
for query, response in zip(queries, responses):
|
154 |
new_options = response
|
155 |
-
# print(new_options)
|
156 |
if new_options and "distractors" in new_options:
|
157 |
query["option_5"] = new_options["distractors"].get("E", "")
|
158 |
else:
|
@@ -170,10 +116,4 @@ def generate_distractors(model_name: str,
|
|
170 |
else:
|
171 |
query["distractor_analysis"] = ""
|
172 |
|
173 |
-
return queries
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
|
|
|
2 |
import re
|
3 |
from tqdm import tqdm
|
4 |
import os
|
5 |
+
from openai import OpenAI # 替换 AsyncOpenAI
|
|
|
6 |
|
7 |
+
from utils.api_utils import generate_from_openai_chat_completion
|
8 |
|
9 |
|
10 |
def construct_prompt_textonly(question: str, options: list, answer: str, answer_analysis: str) -> str:
|
|
|
18 |
3. **Use Answer Analysis**: Reference the **correct answer analysis** when creating distractors to ensure they challenge **subject-matter experts**.
|
19 |
4. **Expert-Level Difficulty**: Keep the distractors **challenging and hard to distinguish** from the correct answer, requiring **advanced knowledge** to avoid the correct answer being too obvious.
|
20 |
5. **Balanced Length**: Ensure all options have **similar lengths** to prevent any one option from standing out.
|
21 |
+
6. **Distractors Analysis**: Provide a **distractor analysis in Chinese**, explaining why the distractors are **incorrect** but **challenging and hard to distinguish**, based on the question, options, and answer analysis.
|
22 |
|
23 |
Please output the result in valid JSON format using the structure below. Make sure there are no extra commas, missing commas, extra quotation marks or missing quotation marks:
|
24 |
{{
|
|
|
46 |
Answer Analysis: {answer_analysis}
|
47 |
"""
|
48 |
|
|
|
49 |
return prompt
|
50 |
|
51 |
|
|
|
73 |
|
74 |
messages.append(prompt_message)
|
75 |
return messages
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
76 |
|
77 |
|
78 |
def generate_distractors(model_name: str,
|
79 |
+
queries: list,
|
80 |
+
n: int=1,
|
81 |
+
max_tokens: int=4096):
|
82 |
|
83 |
assert model_name in ["gpt-4o-mini", "gpt-4-turbo", "gpt-4o", "gpt-4o-2024-08-06"], "Invalid model name"
|
84 |
|
85 |
+
# 改用同步版本的 OpenAI 客户端
|
86 |
+
client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"), base_url="https://yanlp.zeabur.app/v1")
|
87 |
messages = prepare_q_inputs(queries)
|
88 |
|
89 |
+
# Call the synchronous `generate_from_openai_chat_completion` directly
|
90 |
+
responses = generate_from_openai_chat_completion(
|
91 |
+
client,
|
92 |
+
messages=messages,
|
93 |
+
engine_name=model_name,
|
94 |
+
n=n,
|
95 |
+
max_tokens=max_tokens,
|
96 |
+
requests_per_minute=30,
|
97 |
+
json_format=True
|
|
|
98 |
)
|
99 |
|
100 |
for query, response in zip(queries, responses):
|
101 |
new_options = response
|
|
|
102 |
if new_options and "distractors" in new_options:
|
103 |
query["option_5"] = new_options["distractors"].get("E", "")
|
104 |
else:
|
|
|
116 |
else:
|
117 |
query["distractor_analysis"] = ""
|
118 |
|
119 |
+
return queries
|
|
|
|
|
|
|
|
|
|
|
|