|
|
|
import argparse |
|
from flask import Flask, jsonify, request, Response |
|
import urllib.parse |
|
import requests |
|
import time |
|
import json |
|
from flask_cors import CORS |
|
|
|
|
|
app = Flask(__name__)

# Enable CORS on every route so browser-based frontends hosted elsewhere can call us.
cors = CORS(app)

# Slot id last reported by server.cpp; -1 asks the server to pick any free slot.
slot_id = -1

parser = argparse.ArgumentParser(description="An example of using server.cpp with a similar API to OAI. It must be used together with server.cpp.")
parser.add_argument("--chat-prompt", type=str, help="the top prompt in chat completions(default: 'A chat between a curious user and an artificial intelligence assistant. The assistant follows the given rules no matter what.\\n')", default='A chat between a curious user and an artificial intelligence assistant. The assistant follows the given rules no matter what.\\n')
# NOTE: the help strings below previously advertised defaults that did not match
# the actual default= values; they now describe the real defaults.
parser.add_argument("--user-name", type=str, help="USER name in chat completions(default: 'GPT4 User:')", default="GPT4 User:")
parser.add_argument("--ai-name", type=str, help="ASSISTANT name in chat completions(default: 'GPT4 Assistant:')", default="GPT4 Assistant:")
parser.add_argument("--system-name", type=str, help="SYSTEM name in chat completions(default: '')", default="")
parser.add_argument("--stop", type=str, help="the end of response in chat completions(default: '<|end_of_turn|>')", default="<|end_of_turn|>")
parser.add_argument("--llama-api", type=str, help="Set the address of server.cpp in llama.cpp(default: http://rpoly1.ddns.net:8818)", default='http://rpoly1.ddns.net:8818')
parser.add_argument("--api-key", type=str, help="Set the api key to allow only few user(default: NULL)", default="")
parser.add_argument("--host", type=str, help="Set the ip address to listen.(default: 0.0.0.0)", default='0.0.0.0')
parser.add_argument("--port", type=int, help="Set the port to listen.(default: 7860)", default=7860)

args = parser.parse_args()
|
|
|
def is_present(json, key):
    """Return True if *key* exists in mapping *json* with a non-None value.

    The parameter name shadows the stdlib ``json`` module, but that module is
    not used inside this function; the name is kept for backward compatibility.
    """
    # EAFP: a single lookup replaces the original try/except plus re-lookup,
    # and `is not None` replaces the `== None` comparison.
    try:
        return json[key] is not None
    except KeyError:
        return False
|
|
|
|
|
def convert_chat(messages):
    """Render a chat message list into a single prompt string for server.cpp.

    Every message except the last is emitted as "<role-prefix><text><stop>".
    A trailing "user" message additionally opens the assistant turn; a
    trailing "ai" message is emitted without the stop token so the model can
    continue it. Raises IndexError on an empty list (same as the original).
    Does NOT prepend args.chat_prompt (convert_chat1 does).
    """
    def unesc(raw):
        return raw.replace("\\n", "\n")

    stop = unesc(args.stop)
    prefix = {
        "system": unesc(args.system_name),
        "user": unesc(args.user_name),
        "ai": unesc(args.ai_name),
    }

    pieces = []
    for msg in messages[:-1]:
        role = msg["role"]
        if role in prefix:
            pieces.append(f"{prefix[role]}{msg['text']}{stop}")

    last = messages[-1]
    if last["role"] == "user":
        pieces.append(f"{prefix['user']}{last['text']}{stop}")
        pieces.append(prefix["ai"])
    elif last["role"] == "ai":
        pieces.append(f"{prefix['ai']}{last['text']}")

    return "".join(pieces)
|
def convert_chat1(messages):
    """Alternative chat renderer: prepends args.chat_prompt, omits the stop
    token after system/user turns, closes every "ai" turn with the stop
    token, and always ends by opening a fresh assistant turn.
    """
    def unesc(raw):
        return raw.replace("\\n", "\n")

    stop = unesc(args.stop)
    system_n = unesc(args.system_name)
    user_n = unesc(args.user_name)
    ai_n = unesc(args.ai_name)

    parts = [unesc(args.chat_prompt)]
    for msg in messages:
        role = msg["role"]
        text = msg["text"]
        if role == "system":
            parts.append(f"{system_n}{text}")
        if role == "user":
            parts.append(f"{user_n}{text}")
        if role == "ai":
            parts.append(f"{ai_n}{text}{stop}")
    # Open the next assistant turn without a trailing space/newline.
    parts.append(ai_n.rstrip())

    return "".join(parts)
|
def make_postData(body, chat=False, stream=False):
    """Translate an OAI-style request body into a server.cpp /completion payload.

    chat=True renders body["messages"] through convert_chat; otherwise
    body["prompt"] is passed straight through. Optional sampling parameters
    are copied only when present and non-None.
    """
    postData = {}
    if chat:
        postData["prompt"] = convert_chat(body["messages"])
    else:
        postData["prompt"] = body["prompt"]

    # (request key, server.cpp key) pairs copied verbatim when supplied.
    passthrough = [
        ("temperature", "temperature"),
        ("top_k", "top_k"),
        ("top_p", "top_p"),
        ("max_tokens", "n_predict"),
        ("presence_penalty", "presence_penalty"),
        ("frequency_penalty", "frequency_penalty"),
        ("repeat_penalty", "repeat_penalty"),
        ("mirostat", "mirostat"),
        ("mirostat_tau", "mirostat_tau"),
        ("mirostat_eta", "mirostat_eta"),
        ("seed", "seed"),
    ]
    for src, dst in passthrough:
        if is_present(body, src):
            postData[dst] = body[src]

    if is_present(body, "logit_bias"):
        # OAI sends {token_id_str: bias}; server.cpp wants [[token_id, bias], ...].
        postData["logit_bias"] = [[int(tok), bias] for tok, bias in body["logit_bias"].items()]

    postData["stop"] = [args.stop] if args.stop != "" else []
    if is_present(body, "stop"):
        postData["stop"] += body["stop"]

    postData["n_keep"] = -1
    postData["stream"] = stream
    postData["cache_prompt"] = True
    # Reuse the slot server.cpp assigned to an earlier request (if any).
    postData["slot_id"] = slot_id
    return postData
|
|
|
def make_resData(data, chat=False, promptToken=None):
    """Build an OpenAI-style (non-streaming) response from a server.cpp reply.

    Args:
        data: parsed JSON from server.cpp's /completion endpoint; must carry
            "truncated", "tokens_evaluated", "tokens_predicted", "text",
            "stopped_eos" and "stopped_word".
        chat: True -> chat.completion shape, False -> text_completion shape.
        promptToken: optional token-id list, echoed back when non-empty.
            (Default changed from a shared mutable [] to None -- the classic
            mutable-default-argument pitfall; observable behavior is the same.)

    Returns:
        dict ready to pass to jsonify().
    """
    # Hoisted: the same expression appeared in both branches below.
    finish_reason = "stop" if (data["stopped_eos"] or data["stopped_word"]) else "length"
    resData = {
        "id": "chatcmpl" if chat else "cmpl",
        "object": "chat.completion" if chat else "text_completion",
        "created": int(time.time()),
        "truncated": data["truncated"],
        "model": "LLaMA_CPP",
        "usage": {
            "prompt_tokens": data["tokens_evaluated"],
            "completion_tokens": data["tokens_predicted"],
            "total_tokens": data["tokens_evaluated"] + data["tokens_predicted"]
        }
    }
    if promptToken:
        resData["promptToken"] = promptToken
    if chat:
        # Role "ai" (not OAI's "assistant") matches the role names this proxy
        # uses throughout (see convert_chat / make_resData_stream).
        resData["choices"] = [{
            "index": 0,
            "message": {
                "role": "ai",
                "content": data["text"],
            },
            "finish_reason": finish_reason
        }]
    else:
        resData["choices"] = [{
            "text": data["text"],
            "index": 0,
            "logprobs": None,
            "finish_reason": finish_reason
        }]
    return resData
|
|
|
def make_resData_stream(data, chat=False, time_now = 0, start=False):
    """Build one OpenAI-style streaming chunk from a server.cpp SSE payload.

    Args:
        data: parsed JSON of one streamed /completion line ({} for the
            synthetic start chunk).
        chat: True -> chat.completion.chunk with a "delta"; False ->
            text_completion.chunk with a "text" field.
        time_now: unix timestamp stamped into the chunk.
        start: chat-only; emit the initial role-announcing delta instead of
            content.

    Side effect: records data["slot_id"] in the module-level slot_id so the
    next request's cache_prompt reuses the same server.cpp slot.
    """
    # Bug fix: without this declaration the assignment below created a local
    # variable and the module-level slot_id stayed at -1 forever.
    global slot_id
    resData = {
        "id": "chatcmpl" if (chat) else "cmpl",
        "object": "chat.completion.chunk" if (chat) else "text_completion.chunk",
        "created": time_now,
        "model": "LLaMA_CPP",
        "choices": [
            {
                "finish_reason": None,
                "index": 0
            }
        ]
    }
    try:
        slot_id = data["slot_id"]
    except KeyError:
        # Narrowed from a bare except: start chunks (and error payloads)
        # legitimately carry no slot_id; log them for inspection.
        print(data)
    if (chat):
        if (start):
            resData["choices"][0]["delta"] = {
                "role": "ai"
            }
        else:
            resData["choices"][0]["delta"] = {
                "content": data["content"]
            }
            if (data["stop"]):
                resData["choices"][0]["finish_reason"] = "stop" if (data["stopped_eos"] or data["stopped_word"]) else "length"
    else:
        resData["choices"][0]["text"] = data["content"]
        if (data["stop"]):
            resData["choices"][0]["finish_reason"] = "stop" if (data["stopped_eos"] or data["stopped_word"]) else "length"
    return resData
|
|
|
|
|
@app.route('/chat/completions', methods=['POST'])
@app.route('/v1/chat/completions', methods=['POST'])
def chat_completions():
    """OAI-style chat endpoint: render messages to a prompt and proxy it to
    server.cpp's /completion, optionally streaming the reply as SSE.

    Expects a JSON body with "messages" (and optional "stream"/"tokenize"
    flags plus sampling parameters — see make_postData).
    """
    body = request.get_json()
    stream = False
    tokenize = False
    if(is_present(body, "stream")): stream = body["stream"]
    if(is_present(body, "tokenize")): tokenize = body["tokenize"]
    postData = make_postData(body, chat=True, stream=stream)

    promptToken = []
    if (tokenize):
        # Optionally ask server.cpp for the token ids of the rendered prompt.
        tokenData = requests.request("POST", urllib.parse.urljoin(args.llama_api, "/tokenize"), data=json.dumps({"content": postData["prompt"]})).json()
        promptToken = tokenData["tokens"]

    if (not stream):
        data = requests.request("POST", urllib.parse.urljoin(args.llama_api, "/completion"), data=json.dumps(postData))
        print(data.json())
        resData = make_resData(data.json(), chat=True, promptToken=promptToken)
        return jsonify(resData)
    else:
        def generate():
            # Forward the request with stream=True and relay chunks as SSE.
            data = requests.request("POST", urllib.parse.urljoin(args.llama_api, "/completion"), data=json.dumps(postData), stream=True)
            time_now = int(time.time())
            # NOTE(review): this start chunk is built but never yielded; only
            # its slot_id side effect (none for {}) would matter.
            resData = make_resData_stream({}, chat=True, time_now=time_now, start=True)

            # Emits {"text": ...} events — a custom format, NOT OAI chunk
            # format (resData is discarded below except for its finish_reason).
            yield "data: {}\n\n".format(json.dumps({"text": ""}))
            for line in data.iter_lines():
                if line:
                    decoded_line = line.decode('utf-8')
                    # server.cpp SSE lines look like "data: {...}"; [6:] strips the prefix.
                    resData = make_resData_stream(json.loads(decoded_line[6:]), chat=True, time_now=time_now)
                    # The final chunk (finish_reason == "stop") is swallowed;
                    # "length"-terminated chunks are still forwarded.
                    if not resData["choices"][0]["finish_reason"] == "stop":
                        content=resData["choices"][0]["delta"]['content']
                        try:
                            yield "data: {}\n\n".format(json.dumps({"text": content}))
                        except:
                            # NOTE(review): bare except silently drops any
                            # serialization/yield error — consider narrowing.
                            print("error")
        return Response(generate(), mimetype='text/event-stream')
|
|
|
|
|
@app.route('/completions', methods=['POST'])
@app.route('/v1/completions', methods=['POST'])
def completion():
    """Plain /completions endpoint: proxy body["prompt"] to server.cpp's
    /completion, optionally streaming the reply as SSE chunks.
    """
    body = request.get_json()
    stream = body["stream"] if is_present(body, "stream") else False
    tokenize = body["tokenize"] if is_present(body, "tokenize") else False
    postData = make_postData(body, chat=False, stream=stream)

    promptToken = []
    if tokenize:
        # Optionally ask server.cpp for the token ids of the final prompt.
        tokenize_url = urllib.parse.urljoin(args.llama_api, "/tokenize")
        tokenData = requests.post(tokenize_url, data=json.dumps({"content": postData["prompt"]})).json()
        promptToken = tokenData["tokens"]

    completion_url = urllib.parse.urljoin(args.llama_api, "/completion")

    if not stream:
        data = requests.post(completion_url, data=json.dumps(postData))
        print(data.json())
        resData = make_resData(data.json(), chat=False, promptToken=promptToken)
        return jsonify(resData)

    def generate():
        # Forward with stream=True and relay each chunk in OAI chunk shape.
        data = requests.post(completion_url, data=json.dumps(postData), stream=True)
        time_now = int(time.time())
        for line in data.iter_lines():
            if not line:
                continue
            decoded_line = line.decode('utf-8')
            # server.cpp SSE lines look like "data: {...}"; [6:] strips the prefix.
            resData = make_resData_stream(json.loads(decoded_line[6:]), chat=False, time_now=time_now)
            yield 'data: {}\n'.format(json.dumps(resData))

    return Response(generate(), mimetype='text/event-stream')
|
|
|
if __name__ == '__main__':
    # Bug fix: the Flask-CORS extension object has no run() method, so
    # `cors.run(...)` raised AttributeError at startup. Start the Flask app.
    app.run(args.host, port=args.port)