from openai import OpenAI
from concurrent.futures import ThreadPoolExecutor
import json
import copy
from tqdm import tqdm
import queue
import time
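
# This script builds single-turn fine-tuning data in JSONL format by calling an
# OpenAI-compatible chat API:
#   * GetDataApi reads a question list, asks the model each question, and saves
#     the (question, answer) pairs.
#   * ChatData reads an existing JSONL file and asks the model again, using each
#     stored answer as the new question while keeping the original input field.
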
base_id_prompt = "# Role: 问答机器人\n\n## Profile\n- author: 尖米\n- version: 1.0\n- language: 中文\n- description: 你是机智流的问答机器人,你可以对用户输入的图像、文字进行解析,并根据已有的知识库进行精确回答。\n\n## Skills\n1. 图像识别与解析:能够识别用户上传的图像,并提取其中的关键信息。\n2. 自然语言处理:能够理解并解析用户输入的文字信息,准确把握用户意图。\n3. 知识库应用:根据解析结果,查询知识库,提供准确、相关的答案。\n4. 多轮对话:支持与用户进行多轮对话,提供连续性、上下文相关的回答。\n\n## Rules\n1. 必须充分理解用户输入的图像和文字内容。\n2. 回答需要简洁明了,避免过于复杂或含糊的表述。\n3. 在回答过程中,优先查询和引用公司已有的知识库。\n4. 对于无法回答的问题,需要引导用户提供更多信息或寻求人工客服帮助。\n\n## Workflows\n1. 接收并分析用户输入的图像或文字信息。\n2. 基于图像识别或自然语言处理技术,提取关键信息。\n3. 查询知识库,匹配相关信息。\n4. 向用户提供精准、相关的回答。\n5. 如有必要,进行多轮对话,确保问题得到有效解决。\n\n## Init\n欢迎使用机智流的问答机器人,请输入您的问题,我将尽力为您提供帮助。\n",

# Define one OpenAI-compatible client per provider. The API keys and base URLs
# below are placeholders and must be replaced with real credentials.
clients = {
    "internlm": OpenAI(
        api_key="your_internlm_api_key",
        base_url="https://internlm-chat.intern-ai.org.cn/puyu/api/v1/",
    ),
    "glm": OpenAI(
        api_key="your_glm_api_key",
        base_url="your_glm_url",
    ),
    "deepseek": OpenAI(
        api_key="your_deepseek_api_key",
        base_url="your_deepseek_url",
    ),
}


class BaseDataAPI:
    """Shared helpers for asking questions and saving JSONL training data."""

    def __init__(self, questions_path, save_path, repeat=0, client_name="internlm"):
        self.client = clients[client_name]
        self.questions_path = questions_path
        self.save_path = save_path
        self.repeat = repeat
        self.data_template = {
            "conversation": [
                {
                    "system": base_id_prompt,
                    "input": "xxx",
                    "output": "xxx"
                }
            ]
        }
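
    # For illustration: each record produced from this template and written by
    # save() is one JSON line of the form (values shortened):
    # {"conversation": [{"system": "<base_id_prompt>", "input": "<question>", "output": "<model answer>"}]}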

    def get_answer(self, question):
        chat_rsp = self.client.chat.completions.create(
            model="internlm2.5-latest",  # or "internlm2-latest" or "glm-4"
            messages=[
                {"role": "system", "content": base_id_prompt},
                {"role": "user", "content": question}
            ],
            stream=False,
        )
        return self.build_data(question, chat_rsp)

    def build_data(self, question, chat_rsp):
        temp = copy.deepcopy(self.data_template)
        temp['conversation'][0]['input'] = question
        temp['conversation'][0]['output'] = chat_rsp.choices[0].message.content
        return temp

    def save(self, train_data):
        # Append each record as one JSON line (JSONL).
        with open(self.save_path, 'a', encoding='utf-8') as f:
            for item in train_data:
                json.dump(item, f, ensure_ascii=False)
                f.write("\n")

    @staticmethod
    def load_txt(path):
        with open(path, 'r', encoding='utf-8') as f:
            return f.read()

    def read_questions(self):
        prompt = self.load_txt(self.questions_path)
        promptlist = prompt.split('\n')
        if self.repeat != 0:
            promptlist = promptlist * self.repeat
        print(f"Total questions: {len(promptlist)}")
        return promptlist


class GetDataApi(BaseDataAPI):
    def run(self):
        answer_queue = queue.Queue()
        promptlist = self.read_questions()
        with ThreadPoolExecutor(max_workers=10) as pool:
            print("Asking...")
            futures = [pool.submit(self.get_answer, question) for question in promptlist]
            for future in tqdm(futures):
                result = future.result()
                answer_queue.put(result)
                if answer_queue.qsize() >= 10:  # flush to disk after every 10 answers
                    self.save([answer_queue.get() for _ in range(10)])
        # Save any remaining answers
        remaining = []
        while not answer_queue.empty():
            remaining.append(answer_queue.get())
        if remaining:
            self.save(remaining)


class ChatData(BaseDataAPI):
    def __init__(self, train_data, save_path, client_name="internlm"):
        super().__init__(train_data, save_path, client_name=client_name)
        self.train_data = train_data

    def load_data(self):
        with open(self.train_data, 'r', encoding='utf-8') as f:
            return f.readlines()

    def ask_for_tts(self, question, save_ask):
        chat_rsp = self.client.chat.completions.create(
            model="internlm2.5-latest",  # or "glm-4"
            messages=[
                {"role": "system", "content": base_id_prompt},
                {"role": "user", "content": question}
            ],
            stream=False,
        )
        return self.build_data(save_ask, chat_rsp)

    def __call__(self):
        train_data = self.load_data()
        answer_queue = queue.Queue()
        with ThreadPoolExecutor(max_workers=10) as pool:
            print("Asking...")
            futures = []
            for item in train_data:
                item = json.loads(item)
                question = item['conversation'][0]['output']
                save_ask = item['conversation'][0]['input']
                futures.append(pool.submit(self.ask_for_tts, question, save_ask))
            for future in tqdm(futures):
                result = future.result()
                answer_queue.put(result)
                if answer_queue.qsize() >= 10:  # flush to disk after every 10 answers
                    self.save([answer_queue.get() for _ in range(10)])
        # Save any remaining answers
        remaining = []
        while not answer_queue.empty():
            remaining.append(answer_queue.get())
        if remaining:
            self.save(remaining)


if __name__ == '__main__':
    questions_path = './tools/L1_XTuner_code/Q_list.txt'
    save_path = './data/train_basic.jsonl'
    start_time = time.time()
    chat_data = GetDataApi(questions_path, save_path)
    chat_data.run()  # GetDataApi exposes run(), not __call__
    end_time = time.time()
    print('Done')
    print(f'Time used: {end_time - start_time:.2f} seconds')
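
# Illustrative follow-up step (not executed above): ChatData can take the JSONL
# produced by GetDataApi and query the model again, using each stored answer as
# the new question. The output file name below is an assumed example, not part
# of the script.
#
# second_pass = ChatData('./data/train_basic.jsonl', './data/train_second.jsonl')
# second_pass()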