谢璐璟 committed on
Commit
8391956
1 Parent(s): bee8e94
utils/__pycache__/api_utils.cpython-310.pyc CHANGED
Binary files a/utils/__pycache__/api_utils.cpython-310.pyc and b/utils/__pycache__/api_utils.cpython-310.pyc differ
 
utils/__pycache__/generate_distractors.cpython-310.pyc CHANGED
Binary files a/utils/__pycache__/generate_distractors.cpython-310.pyc and b/utils/__pycache__/generate_distractors.cpython-310.pyc differ
 
utils/api_utils.py CHANGED
@@ -1,68 +1,59 @@
1
  import base64
2
  import numpy as np
3
- from typing import Dict
4
  import random
5
-
6
- import asyncio
7
  import logging
8
- import os, json
9
- from typing import Any
10
- from aiohttp import ClientSession
11
- from tqdm.asyncio import tqdm_asyncio
12
- import random
13
  from time import sleep
14
 
15
- import aiolimiter
16
 
17
- import openai
18
- from openai import AsyncOpenAI, OpenAIError
19
- from anthropic import AsyncAnthropic
20
-
21
async def _throttled_openai_chat_completion_acreate(
    client: AsyncOpenAI,
    model: str,
    messages,
    temperature: float,
    max_tokens: int,
    top_p: float,
    limiter: aiolimiter.AsyncLimiter,
    json_format: bool = False,
    n: int = 1,
):
    """Rate-limited, retrying wrapper around the async OpenAI chat API.

    Args:
        client: AsyncOpenAI client instance.
        model: Model/engine name to query.
        messages: Chat messages for a single request.
        temperature: Sampling temperature.
        max_tokens: Maximum tokens to generate.
        top_p: Nucleus-sampling probability mass.
        limiter: Shared AsyncLimiter that throttles concurrent callers.
        json_format: When True, request a JSON-object response.
        n: Number of completions to generate.

    Returns:
        The API response object, or None on unrecoverable errors or after
        all 10 retry attempts are exhausted.
    """
    # Build the request once instead of duplicating the create() call
    # in both branches of the json_format check.
    request_kwargs = dict(
        model=model,
        messages=messages,
        temperature=temperature,
        max_tokens=max_tokens,
        top_p=top_p,
        n=n,
    )
    if json_format:
        request_kwargs["response_format"] = {"type": "json_object"}

    async with limiter:
        for _ in range(10):
            try:
                return await client.chat.completions.create(**request_kwargs)
            except openai.RateLimitError:
                print("Rate limit exceeded, retrying...")
                # time.sleep() would block the event loop and stall every
                # other in-flight request; asyncio.sleep() yields instead.
                await asyncio.sleep(random.randint(10, 20))
            except openai.BadRequestError as e:
                # Malformed request: retrying cannot help.
                print(e)
                return None
            except OpenAIError as e:
                print(e)
                await asyncio.sleep(random.randint(5, 10))
                return None
        return None  # all retries exhausted
 
64
 
65
- async def generate_from_openai_chat_completion(
66
  client,
67
  messages,
68
  engine_name: str,
@@ -73,26 +64,24 @@ async def generate_from_openai_chat_completion(
73
  json_format: bool = False,
74
  n: int = 1,
75
  ):
76
- # https://chat.openai.com/share/09154613-5f66-4c74-828b-7bd9384c2168
77
  delay = 60.0 / requests_per_minute
78
- limiter = aiolimiter.AsyncLimiter(1, delay)
79
- async_responses = [
80
- _throttled_openai_chat_completion_acreate(
 
81
  client,
82
  model=engine_name,
83
  messages=message,
84
  temperature=temperature,
85
  max_tokens=max_tokens,
86
  top_p=top_p,
87
- limiter=limiter,
88
  json_format=json_format,
89
  n=n,
90
  )
91
  for message in messages
92
  ]
93
-
94
- responses = await tqdm_asyncio.gather(*async_responses)
95
-
96
  empty_dict = {
97
  "question": "",
98
  "options": {
@@ -107,7 +96,7 @@ async def generate_from_openai_chat_completion(
107
  "G": "",
108
  },
109
  "correct_answer": ""
110
- }
111
  empty_str = ""
112
  outputs = []
113
  for response in responses:
@@ -135,77 +124,3 @@ async def generate_from_openai_chat_completion(
135
  ])
136
  return outputs
137
 
138
async def _throttled_claude_chat_completion_acreate(
    client: AsyncAnthropic,
    model: str,
    messages,
    temperature: float,
    max_tokens: int,
    top_p: float,
    limiter: aiolimiter.AsyncLimiter,
):
    """Rate-limited single call to the Anthropic messages API.

    Args:
        client: AsyncAnthropic client instance.
        model: Model name to query.
        messages: Chat messages for a single request.
        temperature: Sampling temperature.
        max_tokens: Maximum tokens to generate.
        top_p: Nucleus-sampling probability mass.
        limiter: Shared AsyncLimiter that throttles concurrent callers.

    Returns:
        The API response object, or None if the request fails.
    """
    async with limiter:
        try:
            return await client.messages.create(
                model=model,
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens,
                top_p=top_p,
            )
        except Exception as e:
            # The original bare `except:` also swallowed CancelledError and
            # KeyboardInterrupt, which must propagate for asyncio cancellation
            # to work. Log the failure instead of dropping it silently, then
            # keep the best-effort contract by returning None.
            print(e)
            return None
158
-
159
async def generate_from_claude_chat_completion(
    client,
    messages,
    engine_name: str,
    temperature: float = 1.0,
    max_tokens: int = 512,
    top_p: float = 1.0,
    requests_per_minute: int = 100,
    n: int = 1,
):
    """Fan out chat requests to Claude and collect the generated texts.

    Each prompt in `messages` is submitted `n` times. With n == 1 the result
    is a flat list of strings (one per prompt); otherwise it is a list of
    n-element string lists, one per prompt. Failed requests map to "".
    """
    limiter = aiolimiter.AsyncLimiter(1, 60.0 / requests_per_minute)

    # Duplicate every prompt n times so each copy becomes its own request.
    expanded = [message for message in messages for _ in range(n)]

    tasks = [
        _throttled_claude_chat_completion_acreate(
            client,
            model=engine_name,
            messages=message,
            temperature=temperature,
            max_tokens=max_tokens,
            top_p=top_p,
            limiter=limiter,
        )
        for message in expanded
    ]
    responses = await tqdm_asyncio.gather(*tasks)

    def extract_text(response):
        # Defensive: any missing or empty layer of the response maps to "".
        if response and response.content and response.content[0] and response.content[0].text:
            return response.content[0].text
        return ""

    if n == 1:
        return [extract_text(r) for r in responses]

    # Regroup the flat response list into consecutive chunks of n,
    # one chunk per original prompt.
    return [
        [extract_text(r) for r in responses[start:start + n]]
        for start in range(0, len(responses), n)
    ]
 
1
  import base64
2
  import numpy as np
 
3
  import random
 
 
4
  import logging
5
+ import os
6
+ import json
7
+ import openai
8
+ from openai import OpenAIError
 
9
  from time import sleep
10
 
 
11
 
12
+ def _throttled_openai_chat_completion_create(
13
+ client,
 
 
 
 
14
  model: str,
15
  messages,
16
  temperature: float,
17
  max_tokens: int,
18
  top_p: float,
 
19
  json_format: bool = False,
20
  n: int = 1,
21
  ):
22
+ """同步的OpenAI聊天补全函数,支持限流与重试"""
23
+ for _ in range(10): # 进行10次尝试
24
+ try:
25
+ if json_format:
26
+ return client.chat.completions.create(
27
+ model=model,
28
+ messages=messages,
29
+ temperature=temperature,
30
+ max_tokens=max_tokens,
31
+ top_p=top_p,
32
+ n=n,
33
+ response_format={"type": "json_object"},
34
+ )
35
+ else:
36
+ return client.chat.completions.create(
37
+ model=model,
38
+ messages=messages,
39
+ temperature=temperature,
40
+ max_tokens=max_tokens,
41
+ top_p=top_p,
42
+ n=n,
43
+ )
44
+ except openai.RateLimitError as e:
45
+ print("Rate limit exceeded, retrying...")
46
+ sleep(random.randint(10, 20)) # 增加重试等待时间
47
+ except openai.BadRequestError as e:
48
+ print(e)
49
+ return None
50
+ except OpenAIError as e:
51
+ print(e)
52
+ sleep(random.randint(5, 10))
53
+ return None
54
+
55
 
56
+ def generate_from_openai_chat_completion(
57
  client,
58
  messages,
59
  engine_name: str,
 
64
  json_format: bool = False,
65
  n: int = 1,
66
  ):
67
+ """同步生成OpenAI聊天补全"""
68
  delay = 60.0 / requests_per_minute
69
+ sleep(delay) # 简单的限流处理
70
+
71
+ responses = [
72
+ _throttled_openai_chat_completion_create(
73
  client,
74
  model=engine_name,
75
  messages=message,
76
  temperature=temperature,
77
  max_tokens=max_tokens,
78
  top_p=top_p,
 
79
  json_format=json_format,
80
  n=n,
81
  )
82
  for message in messages
83
  ]
84
+
 
 
85
  empty_dict = {
86
  "question": "",
87
  "options": {
 
96
  "G": "",
97
  },
98
  "correct_answer": ""
99
+ }
100
  empty_str = ""
101
  outputs = []
102
  for response in responses:
 
124
  ])
125
  return outputs
126
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
utils/generate_distractors.py CHANGED
@@ -2,10 +2,9 @@ import json
2
  import re
3
  from tqdm import tqdm
4
  import os
5
- import asyncio
6
- from openai import AsyncOpenAI
7
 
8
- from utils.api_utils import generate_from_openai_chat_completion, generate_from_claude_chat_completion
9
 
10
 
11
  def construct_prompt_textonly(question: str, options: list, answer: str, answer_analysis: str) -> str:
@@ -19,7 +18,7 @@ Generate a multiple-choice question with additional distractors that increase th
19
  3. **Use Answer Analysis**: Reference the **correct answer analysis** when creating distractors to ensure they challenge **subject-matter experts**.
20
  4. **Expert-Level Difficulty**: Keep the distractors **challenging and hard to distinguish** from the correct answer, requiring **advanced knowledge** to avoid the correct answer being too obvious.
21
  5. **Balanced Length**: Ensure all options have **similar lengths** to prevent any one option from standing out.
22
- 6. **Distractors Analysis**: Provide a **distractor analysis in Chinese**, explaining why the distractors are **incorrect** but **challenging and hard to distinguish**.
23
 
24
  Please output the result in valid JSON format using the structure below. Make sure there are no extra commas, missing commas, extra quotation marks or missing quotation marks:
25
  {{
@@ -47,7 +46,6 @@ Answer: {answer}
47
  Answer Analysis: {answer_analysis}
48
  """
49
 
50
- # prompt = prompt.replace("I don't know.", "Idle.")
51
  return prompt
52
 
53
 
@@ -75,84 +73,32 @@ def prepare_q_inputs(queries):
75
 
76
  messages.append(prompt_message)
77
  return messages
78
-
79
-
80
-
81
- # def extract_json_from_text(text):
82
- # text = json.dumps(text)
83
- # # 移除转义符和换行符
84
- # text = text.replace('\\n', '').replace('\\"', '"')
85
-
86
- # # 定义匹配 JSON 对象的正则表达式模式
87
- # json_pattern = re.compile(
88
- # r'\{\s*"question":\s*"([^"]*)",\s*"options":\s*\{\s*"A":\s*"([^"]*)",\s*"B":\s*"([^"]*)",\s*"C":\s*"([^"]*)",\s*"D":\s*"([^"]*)"\s*\},'
89
- # r'\s*"distractors":\s*\{\s*"E":\s*"([^"]*)",\s*"F":\s*"([^"]*)",\s*"G":\s*"([^"]*)"\s*\},\s*"correct_answer":\s*"([^"]*)"\s*\}',
90
- # re.DOTALL
91
- # )
92
-
93
- # # 匹配 JSON 结构
94
- # match = json_pattern.search(text)
95
-
96
- # if match:
97
- # # 捕获到的匹配组
98
- # question = match.group(1)
99
- # option_a = match.group(2)
100
- # option_b = match.group(3)
101
- # option_c = match.group(4)
102
- # option_d = match.group(5)
103
- # distractor_e = match.group(6)
104
- # distractor_f = match.group(7)
105
- # distractor_g = match.group(8)
106
- # correct_answer = match.group(9)
107
-
108
- # # 构建 JSON 对象
109
- # json_data = {
110
- # "question": question,
111
- # "options": {
112
- # "A": option_a,
113
- # "B": option_b,
114
- # "C": option_c,
115
- # "D": option_d
116
- # },
117
- # "distractors": {
118
- # "E": distractor_e,
119
- # "F": distractor_f,
120
- # "G": distractor_g
121
- # },
122
- # "correct_answer": correct_answer
123
- # }
124
-
125
- # return json_data
126
- # else:
127
- # print("No JSON object found in the text.")
128
- # return None
129
 
130
 
131
  def generate_distractors(model_name: str,
132
- queries: list,
133
- n: int=1,
134
- max_tokens: int=4096):
135
 
136
  assert model_name in ["gpt-4o-mini", "gpt-4-turbo", "gpt-4o", "gpt-4o-2024-08-06"], "Invalid model name"
137
 
138
- client = AsyncOpenAI(api_key=os.environ.get("OPENAI_API_KEY"),base_url="https://yanlp.zeabur.app/v1")
 
139
  messages = prepare_q_inputs(queries)
140
 
141
- responses = asyncio.run(
142
- generate_from_openai_chat_completion(
143
- client,
144
- messages=messages,
145
- engine_name=model_name,
146
- n = n,
147
- max_tokens=max_tokens,
148
- requests_per_minute=30,
149
- json_format=True
150
- )
151
  )
152
 
153
  for query, response in zip(queries, responses):
154
  new_options = response
155
- # print(new_options)
156
  if new_options and "distractors" in new_options:
157
  query["option_5"] = new_options["distractors"].get("E", "")
158
  else:
@@ -170,10 +116,4 @@ def generate_distractors(model_name: str,
170
  else:
171
  query["distractor_analysis"] = ""
172
 
173
- return queries
174
-
175
-
176
-
177
-
178
-
179
-
 
2
  import re
3
  from tqdm import tqdm
4
  import os
5
+ from openai import OpenAI # 替换 AsyncOpenAI
 
6
 
7
+ from utils.api_utils import generate_from_openai_chat_completion
8
 
9
 
10
  def construct_prompt_textonly(question: str, options: list, answer: str, answer_analysis: str) -> str:
 
18
  3. **Use Answer Analysis**: Reference the **correct answer analysis** when creating distractors to ensure they challenge **subject-matter experts**.
19
  4. **Expert-Level Difficulty**: Keep the distractors **challenging and hard to distinguish** from the correct answer, requiring **advanced knowledge** to avoid the correct answer being too obvious.
20
  5. **Balanced Length**: Ensure all options have **similar lengths** to prevent any one option from standing out.
21
+ 6. **Distractors Analysis**: Provide a **distractor analysis in Chinese**, explaining why the distractors are **incorrect** but **challenging and hard to distinguish**, based on the question, options, and answer analysis.
22
 
23
  Please output the result in valid JSON format using the structure below. Make sure there are no extra commas, missing commas, extra quotation marks or missing quotation marks:
24
  {{
 
46
  Answer Analysis: {answer_analysis}
47
  """
48
 
 
49
  return prompt
50
 
51
 
 
73
 
74
  messages.append(prompt_message)
75
  return messages
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
 
77
 
78
  def generate_distractors(model_name: str,
79
+ queries: list,
80
+ n: int=1,
81
+ max_tokens: int=4096):
82
 
83
  assert model_name in ["gpt-4o-mini", "gpt-4-turbo", "gpt-4o", "gpt-4o-2024-08-06"], "Invalid model name"
84
 
85
+ # 改用同步版本的 OpenAI 客户端
86
+ client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"), base_url="https://yanlp.zeabur.app/v1")
87
  messages = prepare_q_inputs(queries)
88
 
89
+ # 直接调用同步的 `generate_from_openai_chat_completion_sync`
90
+ responses = generate_from_openai_chat_completion(
91
+ client,
92
+ messages=messages,
93
+ engine_name=model_name,
94
+ n=n,
95
+ max_tokens=max_tokens,
96
+ requests_per_minute=30,
97
+ json_format=True
 
98
  )
99
 
100
  for query, response in zip(queries, responses):
101
  new_options = response
 
102
  if new_options and "distractors" in new_options:
103
  query["option_5"] = new_options["distractors"].get("E", "")
104
  else:
 
116
  else:
117
  query["distractor_analysis"] = ""
118
 
119
+ return queries