|
import gevent.pywsgi |
|
from gevent import monkey;monkey.patch_all() |
|
from flask import Flask, request, Response, jsonify |
|
import argparse |
|
import requests |
|
import random |
|
import string |
|
import time |
|
import json |
|
import os |
|
|
|
app = Flask(__name__)
# Keep dictionary keys in insertion order in jsonify() output.
app.json.sort_keys = False

# Command-line options controlling where the WSGI server binds.
parser = argparse.ArgumentParser(
    description="An example of Hunyuan demo with a similar API to OAI."
)
parser.add_argument(
    "--host",
    type=str,
    default='0.0.0.0',
    help="Set the ip address.(default: 0.0.0.0)",
)
parser.add_argument(
    "--port",
    type=int,
    default=7860,
    help="Set the port.(default: 7860)",
)
args = parser.parse_args()

# Root URL of the upstream Gradio app; every proxied request URL is built from it.
# Echoed at startup (prints None when MODEL_BASE_URL is not set).
base_url = os.environ.get('MODEL_BASE_URL')
print(base_url)
|
|
|
@app.route('/api/v1/models', methods=["GET", "POST"])
@app.route('/v1/models', methods=["GET", "POST"])
def model_list():
    """Return the advertised model catalogue in OpenAI ``/v1/models`` list format.

    Every entry reports the current Unix time as its ``created`` timestamp.
    """
    created = int(time.time())
    advertised_ids = ("hunyuan-large", "gpt-3.5-turbo")
    entries = [
        {
            "id": model_id,
            "object": "model",
            "created": created,
            "owned_by": "tastypear",
        }
        for model_id in advertised_ids
    ]
    return jsonify({"object": "list", "data": entries})
|
|
|
@app.route("/", methods=["GET"])
def index():
    """Serve a short HTML help page explaining how to point a chatbot at this proxy.

    The public base URL comes from the ``SPACE_URL`` environment variable. When it
    is unset, the URL this request arrived on is used instead, rather than
    rendering the literal text ``None``. The URL is HTML-escaped because the
    fallback is derived from the client-supplied Host header.
    """
    import html

    space_url = os.getenv("SPACE_URL") or request.host_url.rstrip("/")
    space_url = html.escape(space_url)
    return Response(
        "Hunyuan OpenAI Compatible API<br><br>"
        f'Set "{space_url}/api" as proxy (or API Domain) in your Chatbot.<br><br>'
        f"The complete API is: {space_url}/api/v1/chat/completions<br><br>"
        # System messages are dropped by chat_completions, hence this warning.
        "Don't set the System Prompt. It will be ignored."
    )
|
|
|
# CORS headers attached to every chat-completion response so browser-based chat
# clients can call this proxy cross-origin.
_CORS_HEADERS = {
    "Access-Control-Allow-Origin": "*",
    "Access-Control-Allow-Headers": "*",
}

# (connect, read) timeouts in seconds for upstream Gradio calls, so a stalled
# upstream cannot block a worker indefinitely. For the streamed response the read
# timeout bounds the gap between received chunks, not the whole generation.
_UPSTREAM_TIMEOUT = (10, 300)


def _error_response(message, status, error_type):
    """Return an OpenAI-style JSON error body with the given HTTP status and CORS headers."""
    return (
        jsonify({"error": {"message": message, "type": error_type}}),
        status,
        _CORS_HEADERS,
    )


def _split_messages(messages):
    """Split an OpenAI ``messages`` list into ``(prompt, chat_history)``.

    The last message's content is the prompt. Every earlier ``user`` message
    becomes a ``[user_content, reply]`` pair in ``chat_history``. ``reply`` is
    the content of the immediately following message when that message is an
    ``assistant`` turn, otherwise a single space. System messages and other roles
    contribute nothing. An empty list yields ``("", [])``.
    """
    if not messages:
        return "", []
    prompt = messages[-1].get("content")
    chat_history = []
    for current, following in zip(messages, messages[1:]):
        if current.get("role") != "user":
            continue
        if following.get("role") == "assistant":
            reply = following.get("content")
        else:
            reply = " "
        chat_history.append([current.get("content"), reply])
    return prompt, chat_history


def _new_session_hash(length=10):
    """Return a random lowercase-alphanumeric Gradio session hash.

    ``secrets`` is used because anyone who knows the hash can read this
    session's result stream from the upstream ``queue/data`` endpoint.
    """
    import secrets

    alphabet = string.ascii_lowercase + string.digits
    return "".join(secrets.choice(alphabet) for _ in range(length))


def _parse_sse_data(raw_line):
    """Decode one line of Gradio's queue SSE stream into a dict, or return None.

    Only ``data:`` lines whose payload is a JSON object are returned. Blank
    lines, other SSE fields/comments and undecodable payloads yield ``None``.
    """
    if not raw_line:
        return None
    line = raw_line.decode("utf-8")
    if not line.startswith("data:"):
        return None
    try:
        event = json.loads(line[len("data:"):])
    except json.JSONDecodeError:
        return None
    return event if isinstance(event, dict) else None


@app.route("/api/v1/chat/completions", methods=["POST", "OPTIONS"])
@app.route("/v1/chat/completions", methods=["POST", "OPTIONS"])
def chat_completions():
    """Proxy an OpenAI-style chat completion request to the upstream Gradio app.

    The answer is always streamed back as ``chat.completion.chunk`` SSE events
    terminated by ``data: [DONE]``. Any ``stream`` flag in the request is not
    consulted, and system messages are ignored.

    Responses:
        OPTIONS: empty CORS preflight response.
        400: the body is not a JSON object with a non-empty ``messages`` list of objects.
        502: the upstream ``queue/join`` call returned a non-success status.
        200: ``text/event-stream`` of completion chunks.
    """
    if request.method == "OPTIONS":
        return Response(headers=_CORS_HEADERS)

    # force=True: accept JSON even when the client omits the Content-Type header.
    payload = request.get_json(force=True, silent=True)
    messages = payload.get("messages") if isinstance(payload, dict) else None
    if (
        not isinstance(messages, list)
        or not messages
        or not all(isinstance(message, dict) for message in messages)
    ):
        return _error_response(
            "Request body must be a JSON object with a non-empty 'messages' list.",
            400,
            "invalid_request_error",
        )

    prompt, chat_history = _split_messages(messages)
    session_hash = _new_session_hash()

    # fn_index / trigger_id address specific event handlers in the upstream
    # Gradio app; the values are copied from that app and are not defined here.
    # NOTE(review): the fn_index=1 call with the bare prompt and an empty history
    # presumably primes per-session state upstream. Its result is intentionally
    # unused; confirm against the Space's API page.
    requests.post(
        f"{base_url}/gradio_api/run/predict",
        json={
            "data": [prompt, []],
            "event_data": None,
            "fn_index": 1,
            "trigger_id": 5,
            "session_hash": session_hash,
        },
        timeout=_UPSTREAM_TIMEOUT,
    )

    # Enqueue the actual chat job. The history ends with the new prompt paired
    # with a None reply that the upstream fills in while streaming.
    join = requests.post(
        f"{base_url}/gradio_api/queue/join",
        json={
            "data": [None, chat_history + [[prompt, None]]],
            "event_data": None,
            "fn_index": 3,
            "trigger_id": 5,
            "session_hash": session_hash,
        },
        timeout=_UPSTREAM_TIMEOUT,
    )
    if not join.ok:
        return _error_response(
            f"Upstream queue/join failed with HTTP {join.status_code}.",
            502,
            "upstream_error",
        )

    def generate():
        """Relay the upstream queue events for this session as OpenAI SSE chunks."""
        created = int(time.time())
        stream_url = f"{base_url}/gradio_api/queue/data?session_hash={session_hash}"
        # `with` releases the upstream connection on normal completion and also
        # when this generator is closed early (e.g. the client disconnects).
        with requests.get(stream_url, stream=True, timeout=_UPSTREAM_TIMEOUT) as upstream:
            for raw_line in upstream.iter_lines():
                event = _parse_sse_data(raw_line)
                if event is None:
                    continue
                status = event.get("msg")
                if status == "process_starts":
                    chunk = gen_res_data({}, time_now=created, start=True)
                elif status == "process_generating":
                    chunk = gen_res_data(event, time_now=created)
                elif status == "process_completed":
                    # SSE events end with a blank line, including the terminator.
                    yield "data: [DONE]\n\n"
                    return
                else:
                    continue
                yield f"data: {json.dumps(chunk)}\n\n"

    return Response(
        generate(),
        mimetype="text/event-stream",
        headers=_CORS_HEADERS,
    )
|
|
|
|
|
def gen_res_data(data, time_now=0, start=False):
    """Build one OpenAI ``chat.completion.chunk`` dict from a Gradio queue event.

    Args:
        data: A decoded ``process_generating`` event from the upstream queue
            stream; ignored when ``start`` is true. ``data["output"]["data"][0]``
            is read. When it is empty (or None) the chunk marks the end of the
            answer; otherwise the last element of its last entry is sent as the
            content delta.
            NOTE(review): assumes Gradio's streaming diff format, where each
            entry looks like ``[op, path, value]``; confirm for the upstream
            Gradio version.
        time_now: Unix timestamp placed in ``created``.
        start: When true, emit the opening chunk that announces the assistant role.

    Returns:
        A dict ready for ``json.dumps``. The single choice always carries a
        ``delta`` object, as OpenAI stream consumers expect.
    """
    choice = {"index": 0, "finish_reason": None}
    res_data = {
        "id": "chatcmpl",
        "object": "chat.completion.chunk",
        "created": time_now,
        "model": "hunyuan-large",
        "choices": [choice],
    }

    if start:
        choice["delta"] = {"role": "assistant", "content": ""}
        return res_data

    chat_pair = data["output"]["data"][0]
    if not chat_pair:
        # Final chunk: report the stop reason with an empty delta rather than
        # omitting the `delta` key, which some OpenAI clients require.
        choice["finish_reason"] = "stop"
        choice["delta"] = {}
    else:
        choice["delta"] = {"content": chat_pair[-1][-1]}
    return res_data
|
|
|
|
|
if __name__ == "__main__":
    # Serve the Flask app on gevent's WSGI server, bound to the --host/--port
    # command-line values; serve_forever() blocks until the process is stopped.
    server = gevent.pywsgi.WSGIServer((args.host, args.port), app)
    server.serve_forever()