upload
Browse files- .gitignore +1 -0
- Dockerfile +16 -0
- README.md +6 -1
- docker_as_a_service/docker_as_a_service.py +145 -0
- docker_as_a_service/shared_utils/advanced_markdown_format.py +478 -0
- docker_as_a_service/shared_utils/char_visual_effect.py +25 -0
- docker_as_a_service/shared_utils/colorful.py +88 -0
- docker_as_a_service/shared_utils/config_loader.py +131 -0
- docker_as_a_service/shared_utils/connect_void_terminal.py +91 -0
- docker_as_a_service/shared_utils/cookie_manager.py +127 -0
- docker_as_a_service/shared_utils/docker_as_service_api.py +70 -0
- docker_as_a_service/shared_utils/fastapi_server.py +322 -0
- docker_as_a_service/shared_utils/handle_upload.py +156 -0
- docker_as_a_service/shared_utils/key_pattern_manager.py +121 -0
- docker_as_a_service/shared_utils/logging.py +69 -0
- docker_as_a_service/shared_utils/map_names.py +34 -0
- docker_as_a_service/shared_utils/text_mask.py +109 -0
.gitignore
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
*.pyc
|
Dockerfile
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Base image; the service invokes /root/.dotnet/tools/BBDown, which is
# presumably provided by this image (confirm against the image contents).
FROM fuqingxu/bbdown

# Use apt-get for both steps (the `apt` front end is not intended for scripts)
# and drop the package index in the same layer so it is not kept in the image.
RUN apt-get update \
    && apt-get install -y python3 python3-dev python3-pip \
    && rm -rf /var/lib/apt/lists/*

# Install all Python dependencies in a single layer; --no-cache-dir keeps
# pip's download cache out of the image.
RUN python3 -m pip install --no-cache-dir --break-system-packages \
        fastapi pydantic loguru \
        requests python-multipart \
        uvicorn

COPY ./docker_as_a_service /docker_as_a_service

WORKDIR /docker_as_a_service

# The command is handed to `bash -c` as one string, so it can be replaced at
# `docker run` time with any other shell command line.
ENTRYPOINT ["/bin/bash", "-c"]
CMD ["python3 docker_as_a_service.py"]
|
README.md
CHANGED
@@ -1,10 +1,15 @@
|
|
1 |
---
|
2 |
title: Bbdown
|
3 |
-
emoji:
|
4 |
colorFrom: blue
|
5 |
colorTo: indigo
|
6 |
sdk: docker
|
7 |
pinned: false
|
|
|
8 |
---
|
9 |
|
10 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
|
|
|
|
|
|
|
1 |
---
|
2 |
title: Bbdown
|
3 |
+
emoji: 🐳
|
4 |
colorFrom: blue
|
5 |
colorTo: indigo
|
6 |
sdk: docker
|
7 |
pinned: false
|
8 |
+
app_port: 7860
|
9 |
---
|
10 |
|
11 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
12 |
+
|
13 |
+
|
14 |
+
1. create space
|
15 |
+
2.
|
docker_as_a_service/docker_as_a_service.py
ADDED
@@ -0,0 +1,145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
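DaaS (Docker as a Service) is a small HTTP service that runs a fixed
command-line tool (BBDown) inside the container on behalf of a remote client,
streaming the tool's output and the files it produced back to that client.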
|
4 |
+
"""
|
5 |
+
|
6 |
+
from fastapi import FastAPI
|
7 |
+
from fastapi.responses import StreamingResponse
|
8 |
+
from fastapi import FastAPI, File, UploadFile, HTTPException
|
9 |
+
from pydantic import BaseModel, Field
|
10 |
+
from typing import Optional, Dict
|
11 |
+
import time
|
12 |
+
import os
|
13 |
+
import asyncio
|
14 |
+
import subprocess
|
15 |
+
import uuid
|
16 |
+
import threading
|
17 |
+
import queue
|
18 |
+
from shared_utils.docker_as_service_api import DockerServiceApiComModel
|
19 |
+
|
20 |
+
app = FastAPI()
|
21 |
+
|
22 |
+
def python_obj_to_pickle_file_bytes(obj):
    """Serialize *obj* with pickle and return the resulting bytes."""
    import pickle

    return pickle.dumps(obj)
|
28 |
+
|
29 |
+
def yield_message(message):
    """Wrap *message* as the ``server_message`` of a DockerServiceApiComModel and pickle it.

    Returns the pickled bytes produced by ``python_obj_to_pickle_file_bytes``.
    """
    return python_obj_to_pickle_file_bytes(
        DockerServiceApiComModel(server_message=message)
    )
|
32 |
+
|
33 |
+
def read_output(stream, output_queue):
    """Copy lines from *stream* into *output_queue* until a read comes back empty.

    Args:
        stream: object with a ``readline()`` method.
        output_queue: object with a ``put()`` method that receives each line.
    """
    while line := stream.readline():
        output_queue.put(line)
|
41 |
+
|
42 |
+
|
43 |
+
def _drain_queue(output_queue):
    """Return every item currently waiting in *output_queue* without blocking."""
    items = []
    while True:
        try:
            items.append(output_queue.get_nowait())
        except queue.Empty:
            return items


async def stream_generator(request_obj):
    """Run BBDown for the client's request and stream progress, then the files.

    Args:
        request_obj: object whose ``client_command`` attribute holds the text
            passed to BBDown (a video id, optionally followed by extra tokens).

    Yields:
        bytes: pickled ``DockerServiceApiComModel`` objects. In order: the
        command line, one message per output line of the process, the return
        code, and finally a "complete" message carrying the produced files in
        ``server_file_attach`` (file name -> file bytes).
    """
    import shlex
    import tempfile

    # Everything BBDown writes lands in a throw-away directory.
    with tempfile.TemporaryDirectory() as temp_dir:
        download_folder = os.path.abspath(temp_dir)

        # SECURITY: client_command is supplied by the remote caller. It is
        # tokenised with shlex and passed as an argument vector (no shell), so
        # shell metacharacters in it are never interpreted.
        cmd = [
            '/root/.dotnet/tools/BBDown',
            *shlex.split(request_obj.client_command),
            '--use-app-api',
            '--work-dir',
            download_folder,
        ]
        yield yield_message(' '.join(cmd))
        process = subprocess.Popen(cmd,
                                   stdout=subprocess.PIPE,
                                   stderr=subprocess.PIPE,
                                   text=True)

        # One reader thread per pipe so that neither pipe can fill up and
        # stall the child while this coroutine is sleeping.
        stdout_queue = queue.Queue()
        stderr_queue = queue.Queue()
        readers = [
            threading.Thread(target=read_output, args=(process.stdout, stdout_queue), daemon=True),
            threading.Thread(target=read_output, args=(process.stderr, stderr_queue), daemon=True),
        ]
        for reader in readers:
            reader.start()

        while True:
            finished = process.poll() is not None
            if finished:
                # Let the readers reach end-of-file so the final drain below
                # sees the last lines; bounded so a lingering pipe cannot hang us.
                for reader in readers:
                    reader.join(timeout=1)

            # The two queues are drained independently: an empty stdout queue
            # must not hide pending stderr lines.
            for line in _drain_queue(stdout_queue) + _drain_queue(stderr_queue):
                print(line)
                yield yield_message(line)

            if finished:
                break

            await asyncio.sleep(0.25)

        yield yield_message("(return code:) " + str(process.returncode))

        # The directory started empty, so every regular file in it was produced
        # by this run. Sub-directories are skipped because they cannot be read
        # as a single blob.
        server_file_attach = {}
        for file_name in os.listdir(download_folder):
            file_path = os.path.join(download_folder, file_name)
            if not os.path.isfile(file_path):
                continue
            with open(file_path, "rb") as f:
                server_file_attach[file_name] = f.read()

        dsacm = DockerServiceApiComModel(
            server_message="complete",
            server_file_attach=server_file_attach,
        )
        yield python_obj_to_pickle_file_bytes(dsacm)
|
126 |
+
|
127 |
+
|
128 |
+
@app.post("/stream")
async def stream_response(file: UploadFile = File(...)):
    """Accept a pickled request object as an upload and stream the job's output.

    Args:
        file: uploaded file whose content is a pickle of the request object.

    Returns:
        StreamingResponse: the byte stream produced by ``stream_generator``.

    Raises:
        HTTPException: status 400 when the upload cannot be unpickled.
    """
    import pickle

    content = await file.read()
    # SECURITY: unpickling runs code chosen by whoever built the payload, so
    # this endpoint must only be reachable by trusted clients. Flagged for
    # review rather than replaced, because clients already speak this format.
    try:
        request_obj = pickle.loads(content)
    except Exception as err:
        # Unpickling can raise many exception types; report them all as a bad
        # request instead of an internal server error.
        raise HTTPException(status_code=400, detail="upload is not a valid pickle payload") from err
    return StreamingResponse(stream_generator(request_obj), media_type="application/octet-stream")
|
138 |
+
|
139 |
+
@app.get("/")
async def hi():
    """Root endpoint: reply with a fixed greeting string."""
    greeting = "Hello, this is Docker as a Service (DaaS)!"
    return greeting
|
142 |
+
|
143 |
+
if __name__ == "__main__":
    # NOTE(review): the server listens on 49000 while the Space front matter in
    # README.md declares app_port 7860 - confirm which port is intended.
    from uvicorn import run as serve

    serve(app, host="0.0.0.0", port=49000)
|
docker_as_a_service/shared_utils/advanced_markdown_format.py
ADDED
@@ -0,0 +1,478 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import markdown
|
2 |
+
import re
|
3 |
+
import os
|
4 |
+
import math
|
5 |
+
|
6 |
+
from loguru import logger
|
7 |
+
from textwrap import dedent
|
8 |
+
from functools import lru_cache
|
9 |
+
from pymdownx.superfences import fence_code_format
|
10 |
+
from latex2mathml.converter import convert as tex2mathml
|
11 |
+
from shared_utils.config_loader import get_conf as get_conf
|
12 |
+
from shared_utils.text_mask import apply_gpt_academic_string_mask
|
13 |
+
|
14 |
+
def _highlight_options():
    """Return a fresh option dict for the "pymdownx.highlight" extension."""
    return {"css_class": "codehilite", "guess_lang": True}


# Options for the "mdx_math" extension.
markdown_extension_configs = {
    "mdx_math": {"enable_dollar_delimiter": True, "use_gitlab_delimiters": False},
}

# Code-fence options that register a custom "mermaid" fence.
code_highlight_configs = {
    "pymdownx.superfences": {
        "css_class": "codehilite",
        "custom_fences": [
            {"name": "mermaid", "class": "mermaid", "format": fence_code_format},
        ],
    },
    "pymdownx.highlight": _highlight_options(),
}

# Same as above but without the custom "mermaid" fence.
code_highlight_configs_block_mermaid = {
    "pymdownx.superfences": {"css_class": "codehilite"},
    "pymdownx.highlight": _highlight_options(),
}

# Regular expression -> options. Every pattern has exactly three groups:
# opening delimiter, formula body, closing delimiter.
mathpatterns = {
    # $...$
    r"(?<!\\|\$)(\$)([^\$]+)(\$)": {"allow_multi_lines": False},
    # $$...$$
    r"(?<!\\)(\$\$)([^\$]+)(\$\$)": {"allow_multi_lines": True},
    # \[...\]
    r"(?<!\\)(\\\[)(.+?)(\\\])": {"allow_multi_lines": False},
    # \(...\)
    r'(?<!\\)(\\\()(.+?)(\\\))': {'allow_multi_lines': False},
    # Not enabled:
    # r'(?<!\\)(\\begin{([a-z]+?\*?)})(.+?)(\\end{\2})': {'allow_multi_lines': True},  # \begin...\end
    # r'(?<!\\)(\$`)([^`]+)(`\$)': {'allow_multi_lines': False},  # $`...`$
}
|
60 |
+
|
61 |
+
def tex2mathml_catch_exception(content, *args, **kwargs):
    """Convert LaTeX *content* to MathML, returning the input unchanged on failure.

    Args:
        content: LaTeX source of one formula.
        *args, **kwargs: forwarded to ``tex2mathml``.

    Returns:
        The converted markup, or *content* itself when the conversion raised.
    """
    try:
        return tex2mathml(content, *args, **kwargs)
    except Exception:
        # Best effort: show the raw formula rather than fail the whole render.
        # Only Exception is caught so that KeyboardInterrupt / SystemExit
        # still propagate.
        return content
|
67 |
+
|
68 |
+
|
69 |
+
def replace_math_no_render(match):
    """Turn a matched math ``<script>`` tag back into coloured, copyable TeX.

    Display-mode formulas are wrapped in ``$$`` and their newlines become
    ``</br>``; all other formulas are wrapped in single ``$``.
    """
    content = match.group(1)
    if "mode=display" not in match.group(0):
        return f'<font color="#00FF00">$</font><font color="#FF00FF">{content}</font><font color="#00FF00">$</font>'
    content = content.replace("\n", "</br>")
    return f'<font color="#00FF00">$$</font><font color="#FF00FF">{content}</font><font color="#00FF00">$$</font>'
|
76 |
+
|
77 |
+
|
78 |
+
def replace_math_render(match):
    """Render the formula inside a matched math ``<script>`` tag via tex2mathml.

    Display-mode formulas are converted with ``display="block"``. Before that,
    an ``aligned`` environment is rewritten to ``array`` and its ``&`` markers
    are replaced by spaces.
    """
    content = match.group(1)
    if "mode=display" not in match.group(0):
        return tex2mathml_catch_exception(content)

    if "\\begin{aligned}" in content:
        rewrites = (
            ("\\begin{aligned}", "\\begin{array}"),
            ("\\end{aligned}", "\\end{array}"),
            ("&", " "),
        )
        for old, new in rewrites:
            content = content.replace(old, new)
    return tex2mathml_catch_exception(content, display="block")
|
89 |
+
|
90 |
+
|
91 |
+
def markdown_bug_hunt(content):
    """Work around an mdx_math bug: a redundant outer ``<script>`` is emitted
    when a ``begin`` command is wrapped in single ``$``.

    Collapses the doubled opening tag and the doubled closing tag.
    """
    repairs = (
        (
            '<script type="math/tex">\n<script type="math/tex; mode=display">',
            '<script type="math/tex; mode=display">',
        ),
        ("</script>\n</script>", "</script>"),
    )
    for broken, fixed in repairs:
        content = content.replace(broken, fixed)
    return content
|
101 |
+
|
102 |
+
|
103 |
+
def is_equation(txt):
    """Decide whether *txt* should be treated as containing formulas.

    Returns False when the text has a code fence (other than a
    "```reference" fence), has neither "$" nor "\\[", matches none of
    ``mathpatterns``, or when any matched formula body contains a non-ASCII
    character or the word "echo". Returns True otherwise.
    """
    if "```" in txt and "```reference" not in txt:
        return False
    # NOTE(review): text that only uses \( ... \) exits here even though
    # mathpatterns has a pattern for it - confirm this is intended.
    if not ("$" in txt or "\\[" in txt):
        return False

    found = []
    for regex, options in mathpatterns.items():
        flags = re.ASCII
        if options["allow_multi_lines"]:
            flags |= re.DOTALL
        found += re.findall(regex, txt, flags)
    if not found:
        return False

    forbidden = re.compile(r"[^\x00-\x7F]|echo")
    for groups in found:
        # groups is (opening delimiter, body, closing delimiter)
        if len(groups) != 3 or forbidden.search(groups[1]):
            return False
    return True
|
129 |
+
|
130 |
+
|
131 |
+
def fix_markdown_indent(txt):
    """Pad nested "-" list items whose indent is one space short of a multiple of 4.

    The text is returned untouched unless it contains both " - " and ". ".
    Fixing starts at the first line that begins with "- " or "1. ".
    """
    needs_scan = (" - " in txt) and (". " in txt)
    if not needs_scan:
        return txt

    nested_item = re.compile(r"^\s+-")
    rows = txt.split("\n")
    in_list = False
    for idx, row in enumerate(rows):
        if row.startswith(("- ", "1. ")):
            in_list = True
        if not (in_list and nested_item.match(row)):
            continue
        body = row.lstrip()
        indent = len(row) - len(body)
        if indent % 4 == 3:
            # rounding up to the next multiple of 4 adds exactly one space
            rows[idx] = " " * (indent + 1) + body
    return "\n".join(rows)
|
150 |
+
|
151 |
+
|
152 |
+
# Verbose-mode source of the fenced-code-block pattern (``~~~`` or backtick
# fences, optional {attrs} / language / hl_lines, then the code up to the
# matching closing fence).
_FENCED_BLOCK_PATTERN = dedent(
    r"""
    (?P<fence>^[ \t]*(?:~{3,}|`{3,}))[ ]*                      # opening fence
    ((\{(?P<attrs>[^\}\n]*)\})|                              # (optional {attrs} or
    (\.?(?P<lang>[\w#.+-]*)[ ]*)?                            # optional (.)lang
    (hl_lines=(?P<quot>"|')(?P<hl_lines>.*?)(?P=quot)[ ]*)?) # optional hl_lines)
    \n                                                       # newline (end of opening fence)
    (?P<code>.*?)(?<=\n)                                     # the code block
    (?P=fence)[ ]*$                                          # closing fence
    """
)
FENCED_BLOCK_RE = re.compile(
    _FENCED_BLOCK_PATTERN,
    re.MULTILINE | re.DOTALL | re.VERBOSE,
)
|
166 |
+
|
167 |
+
|
168 |
+
def get_line_range(re_match_obj, txt):
    """Return ``(line_start, line_end)`` for the span of *re_match_obj* in *txt*.

    ``line_start`` is the number of newlines in *txt* up to and including the
    first matched character; ``line_end`` adds the newlines inside the match
    plus one, so the pair can be used directly with ``range``.
    """
    begin, end = re_match_obj.span()
    line_start = txt.count("\n", 0, begin + 1)
    newlines_inside = txt.count("\n", begin, end)
    return line_start, line_start + newlines_inside + 1
|
174 |
+
|
175 |
+
|
176 |
+
def fix_code_segment_indent(txt):
    """Shift fenced code blocks whose shared indent is one short of a multiple of 4.

    Every block matched by ``FENCED_BLOCK_RE`` is inspected. When the smallest
    leading-whitespace width over its lines leaves remainder 3 modulo 4, each
    line of the block is padded up to the next multiple of 4. *txt* itself is
    returned when no block needed fixing.
    """
    lines = []
    changed = False
    scratch = txt  # search copy in which handled blocks are blanked out
    while (fence_match := FENCED_BLOCK_RE.search(scratch)) is not None:
        if not lines:
            lines = txt.split("\n")

        # Blank the matched span so the next search finds the following block.
        begin, end = fence_match.regs[0]
        scratch = scratch[:begin] + " " * (end - begin) + scratch[end:]
        first, last = get_line_range(fence_match, txt)
        block_rows = range(first, last)

        # Indent shared by every line of the block.
        shared_indent = min(
            (len(lines[row]) - len(lines[row].lstrip()) for row in block_rows),
            default=1e5,
        )

        if shared_indent < 1e5 and shared_indent % 4 == 3:
            pad = " " * (math.ceil(shared_indent / 4) * 4 - shared_indent)
            for row in block_rows:
                lines[row] = pad + lines[row]
            changed = True

    return "\n".join(lines) if changed else txt
|
213 |
+
|
214 |
+
|
215 |
+
def fix_dollar_sticking_bug(txt):
    """Repair non-standard use of dollar-sign formula delimiters.

    The text is scanned one ``$`` at a time while tracking whether a ``$...$``
    or a ``$$...$$`` formula is currently open. An open formula that runs into
    unexpected content is closed forcibly, and a ``$$`` met while a single-``$``
    formula is open is split by inserting a space.

    Args:
        txt: text that may contain dollar-delimited formulas.

    Returns:
        The repaired text.
    """
    txt_result = ""
    single_stack_height = 0  # 1 while a $...$ formula is open
    double_stack_height = 0  # 1 while a $$...$$ formula is open
    while True:
        while True:
            index = txt.find('$')

            if index == -1:
                txt_result += txt
                return txt_result

            if single_stack_height > 0:
                # NOTE(review): `> 0` ignores a hit at offset 0 - confirm
                # whether `>= 0` was intended.
                if txt[:(index+1)].find('\n') > 0 or txt[:(index+1)].find('<td>') > 0 or txt[:(index+1)].find('</td>') > 0:
                    logger.error('公式之中出现了异常 (Unexpect element in equation)')
                    single_stack_height = 0
                    txt_result += ' $'
                    continue

            if double_stack_height > 0:
                if txt[:(index+1)].find('\n\n') > 0:
                    logger.error('公式之中出现了异常 (Unexpect element in equation)')
                    double_stack_height = 0
                    txt_result += '$$'
                    continue

            # Slice instead of index: a '$' that is the very last character
            # has no successor, and txt[index+1] would raise IndexError.
            is_double = (txt[index+1:index+2] == '$')
            if is_double:
                if single_stack_height != 0:
                    # add a padding
                    txt = txt[:(index+1)] + " " + txt[(index+1):]
                    continue
                if double_stack_height == 0:
                    double_stack_height = 1
                else:
                    double_stack_height = 0
                txt_result += txt[:(index+2)]
                txt = txt[(index+2):]
            else:
                if double_stack_height != 0:
                    logger.info('发现异常嵌套公式')
                if single_stack_height == 0:
                    single_stack_height = 1
                else:
                    single_stack_height = 0
                txt_result += txt[:(index+1)]
                txt = txt[(index+1):]
            break
|
268 |
+
|
269 |
+
|
270 |
+
def markdown_convertion_for_file(txt):
    """Convert Markdown text into a standalone HTML document.

    Formulas are first converted by the Markdown extensions and then rendered
    through ``replace_math_render``. Text that already starts and ends with the
    document wrapper produced here is returned unchanged.

    Args:
        txt: Markdown source (or an already converted document).

    Returns:
        The HTML document as a string.
    """
    from themes.theme import advanced_css
    pre = f"""
<!DOCTYPE html><head><meta charset="utf-8"><title>GPT-Academic输出文档</title><style>{advanced_css}</style></head>
<body>
<div class="test_temp1" style="width:10%; height: 500px; float:left;"></div>
<div class="test_temp2" style="width:80%;padding: 40px;float:left;padding-left: 20px;padding-right: 20px;box-shadow: rgba(0, 0, 0, 0.2) 0px 0px 8px 8px;border-radius: 10px;">
<div class="markdown-body">
"""
    suf = """
</div>
</div>
<div class="test_temp3" style="width:10%; height: 500px; float:left;"></div>
</body>
"""

    if txt.startswith(pre) and txt.endswith(suf):
        return txt  # already converted; converting twice could corrupt it

    find_equation_pattern = r'<script type="math/tex(?:.*?)>(.*?)</script>'
    txt = fix_markdown_indent(txt)
    convert_stage_1 = fix_dollar_sticking_bug(txt)
    # convert everything to html format
    convert_stage_2 = markdown.markdown(
        text=convert_stage_1,
        extensions=[
            "sane_lists",
            "tables",
            "mdx_math",
            "pymdownx.superfences",
            "pymdownx.highlight",
        ],
        extension_configs={**markdown_extension_configs, **code_highlight_configs},
    )

    def repl_fn(match):
        # The combined pattern is an alternation of three-group patterns
        # (open, body, close), so the body is the group just before the last
        # one that matched. A fixed group(2) is None for every alternative
        # except the first and would emit the text "None".
        content = match.group(match.lastindex - 1)
        return f'<script type="math/tex">{content}</script>'

    # Wrap the remaining single-line formulas in math script tags.
    pattern = "|".join(
        regex for regex, options in mathpatterns.items() if not options["allow_multi_lines"]
    )
    pattern = re.compile(pattern, flags=re.ASCII)
    convert_stage_3 = pattern.sub(repl_fn, convert_stage_2)

    convert_stage_4 = markdown_bug_hunt(convert_stage_3)

    # render every math script tag
    convert_stage_5 = re.sub(
        find_equation_pattern, replace_math_render, convert_stage_4, flags=re.DOTALL
    )
    return pre + convert_stage_5 + suf
|
326 |
+
|
327 |
+
@lru_cache(maxsize=128)  # cache recent conversions to speed up repeated renders
def markdown_convertion(txt):
    """Convert Markdown text to HTML wrapped in a ``markdown-body`` div.

    When ``is_equation`` reports formulas, the result holds two copies of the
    content separated by a horizontal rule: first with formulas as coloured
    raw TeX, then with formulas rendered. Already wrapped input is returned
    unchanged.
    """
    pre = '<div class="markdown-body">'
    suf = "</div>"
    if txt.startswith(pre) and txt.endswith(suf):
        return txt  # already converted

    find_equation_pattern = r'<script type="math/tex(?:.*?)>(.*?)</script>'
    base_extensions = ["sane_lists", "tables"]
    code_extensions = ["pymdownx.superfences", "pymdownx.highlight"]

    txt = fix_markdown_indent(txt)
    # fix_code_segment_indent is deliberately not applied here.
    if not is_equation(txt):
        html = markdown.markdown(
            txt,
            extensions=base_extensions + code_extensions,
            extension_configs=code_highlight_configs,
        )
        return pre + html + suf

    # Formula path: the text has math delimiters and no ordinary code fence.
    split = markdown.markdown(text="---")
    convert_stage_1 = markdown.markdown(
        text=txt,
        extensions=base_extensions + ["mdx_math"] + code_extensions,
        extension_configs={**markdown_extension_configs, **code_highlight_configs},
    )
    convert_stage_1 = markdown_bug_hunt(convert_stage_1)
    # 1. easy-to-copy TeX (formulas not rendered)
    raw_tex_html = re.sub(
        find_equation_pattern, replace_math_no_render, convert_stage_1, flags=re.DOTALL
    )
    # 2. rendered formulas
    rendered_html = re.sub(
        find_equation_pattern, replace_math_render, convert_stage_1, flags=re.DOTALL
    )
    return pre + raw_tex_html + f"{split}" + rendered_html + suf
|
385 |
+
|
386 |
+
|
387 |
+
def close_up_code_segment_during_stream(gpt_reply):
    """Close a code fence that is still open in a partially streamed reply.

    Args:
        gpt_reply (str): reply text received so far.

    Returns:
        str: *gpt_reply* with a closing fence appended on a new line when it
        contains an odd number of triple-backtick markers and does not already
        end with one; otherwise *gpt_reply* unchanged.
    """
    fence = "```"
    if fence not in gpt_reply or gpt_reply.endswith(fence):
        return gpt_reply

    # An odd number of markers means the last fence was opened but not closed.
    inside_open_block = gpt_reply.count(fence) % 2 == 1
    return gpt_reply + "\n```" if inside_open_block else gpt_reply
|
410 |
+
|
411 |
+
|
412 |
+
@lru_cache(maxsize=1)
def _get_mermaid_special_case():
    """Return the "Suffix" text of the "总结绘制脑图" core function.

    Cached at module level so the lookup runs once; a cache on a function
    defined inside the caller would be rebuilt on every call.
    """
    from core_functional import get_core_functions

    return get_core_functions()["总结绘制脑图"]["Suffix"]


def special_render_issues_for_mermaid(text):
    """Stop the mermaid example inside one specific prompt from being rendered.

    When *text* ends with the suffix of the "总结绘制脑图" prompt, every
    "```mermaid" fence in it is downgraded to a plain "```" fence.
    """
    if text.endswith(_get_mermaid_special_case()):
        text = text.replace("```mermaid", "```")
    return text
|
422 |
+
|
423 |
+
|
424 |
+
def compat_non_markdown_input(text):
    """Improve how non-Markdown input is displayed.

    Text containing a code fence is treated as Markdown and only passed through
    ``special_render_issues_for_mermaid``. Text containing "</div>" is treated
    as HTML and returned unchanged. Anything else is plain text: spaces become
    ``&nbsp;`` and line breaks become ``</br>`` so the layout survives HTML
    rendering.
    """
    if "```" in text:
        # careful input: Markdown
        return special_render_issues_for_mermaid(text)
    if "</div>" in text:
        # careful input: HTML
        return text
    # whatever input: plain text. The replacement target must be the HTML
    # entity; replacing a space with a space would leave the text unchanged.
    return "</br>".join(line.replace(" ", "&nbsp;") for line in text.split("\n"))
|
442 |
+
|
443 |
+
|
444 |
+
@lru_cache(maxsize=128)  # cache recent conversions
def simple_markdown_convertion(text):
    """Convert *text* to HTML in a ``markdown-body`` div, without formula handling.

    Already wrapped input is returned unchanged. Other input is first passed
    through ``compat_non_markdown_input``.
    """
    pre, suf = '<div class="markdown-body">', "</div>"
    already_wrapped = text.startswith(pre) and text.endswith(suf)
    if already_wrapped:
        return text
    html = markdown.markdown(
        compat_non_markdown_input(text),
        extensions=["pymdownx.superfences", "tables", "pymdownx.highlight"],
        extension_configs=code_highlight_configs,
    )
    return pre + html + suf
|
457 |
+
|
458 |
+
|
459 |
+
def format_io(self, y):
    """Render the last (question, answer) pair of *y* as HTML, in place.

    Both parts go through ``apply_gpt_academic_string_mask`` in "show_render"
    mode. The question is converted with ``simple_markdown_convertion`` and the
    answer, after closing a half-streamed code fence, with
    ``markdown_convertion``. ``None`` parts stay ``None``.

    Returns:
        ``[]`` when *y* is ``None`` or empty, otherwise *y* with its last
        entry replaced.
    """
    if y is None or y == []:
        return []

    question, answer = y[-1]
    question = apply_gpt_academic_string_mask(question, mode="show_render")
    answer = apply_gpt_academic_string_mask(answer, mode="show_render")

    # A reply cut off in the middle of a code block gets its closing fence.
    if answer is not None:
        answer = close_up_code_segment_during_stream(answer)

    rendered_question = simple_markdown_convertion(question) if question is not None else None
    rendered_answer = markdown_convertion(answer) if answer is not None else None
    y[-1] = (rendered_question, rendered_answer)
    return y
|
docker_as_a_service/shared_utils/char_visual_effect.py
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
def is_full_width_char(ch):
    """Return True when the single character *ch* falls in a full-width range."""
    wide_ranges = (
        ('\u4e00', '\u9fff'),  # Chinese characters
        ('\uff01', '\uff5e'),  # full-width symbols
        ('\u3000', '\u303f'),  # CJK punctuation
    )
    return any(low <= ch <= high for low, high in wide_ranges)
|
10 |
+
|
11 |
+
def scolling_visual_effect(text, scroller_max_len):
    """Return the tail of *text* that fits a scroller about *scroller_max_len* cells wide.

    Newlines are removed and backticks, spaces and "$" become "."; "<br/>"
    becomes five dots. Text shorter than the limit is returned whole.
    Otherwise characters are counted from the end, a full-width character
    counting as two cells.
    """
    substitutions = (
        ('\n', ''),
        ('`', '.'),
        (' ', '.'),
        ('<br/>', '.....'),
        ('$', '.'),
    )
    for old, new in substitutions:
        text = text.replace(old, new)

    if len(text) < scroller_max_len:
        return text

    used_cells = 0
    cursor = len(text) - 1
    while used_cells < scroller_max_len and cursor > 0:
        used_cells += 2 if is_full_width_char(text[cursor]) else 1
        cursor -= 1

    return text[cursor:]
|
docker_as_a_service/shared_utils/colorful.py
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import platform
|
2 |
+
from sys import stdout
|
3 |
+
from loguru import logger
|
4 |
+
|
5 |
+
if platform.system() != "Linux":
    # colorama is only imported and initialised off Linux - presumably so the
    # ANSI colour codes used below work on other terminals (confirm).
    from colorama import init

    init()
|
10 |
+
|
11 |
+
# Do you like the elegance of Chinese characters?
|
12 |
+
def _print_colored(sgr_code, values, options):
    """Print *values* between the ANSI sequence for *sgr_code* and a reset."""
    print(f"\033[{sgr_code}m", *values, "\033[0m", **options)


# Normal-intensity colours: red, green, yellow, blue, purple, cyan.
def print红(*kw,**kargs):
    _print_colored("0;31", kw, kargs)
def print绿(*kw,**kargs):
    _print_colored("0;32", kw, kargs)
def print黄(*kw,**kargs):
    _print_colored("0;33", kw, kargs)
def print蓝(*kw,**kargs):
    _print_colored("0;34", kw, kargs)
def print紫(*kw,**kargs):
    _print_colored("0;35", kw, kargs)
def print靛(*kw,**kargs):
    _print_colored("0;36", kw, kargs)

# Bright variants of the same six colours.
def print亮红(*kw,**kargs):
    _print_colored("1;31", kw, kargs)
def print亮绿(*kw,**kargs):
    _print_colored("1;32", kw, kargs)
def print亮黄(*kw,**kargs):
    _print_colored("1;33", kw, kargs)
def print亮蓝(*kw,**kargs):
    _print_colored("1;34", kw, kargs)
def print亮紫(*kw,**kargs):
    _print_colored("1;35", kw, kargs)
def print亮靛(*kw,**kargs):
    _print_colored("1;36", kw, kargs)
|
37 |
+
|
38 |
+
# Do you like the elegance of Chinese characters?
|
39 |
+
def _sprint_colored(sgr_code, values):
    """Join *values* with spaces and wrap them in the ANSI sequence for *sgr_code*.

    Each value is converted with ``str`` first, so numbers, exceptions and
    other objects are accepted just as they are by the print* functions.
    """
    return f"\033[{sgr_code}m" + ' '.join(map(str, values)) + "\033[0m"


# Normal-intensity colours: red, green, yellow, blue, purple, cyan.
def sprint红(*kw):
    return _sprint_colored("0;31", kw)
def sprint绿(*kw):
    return _sprint_colored("0;32", kw)
def sprint黄(*kw):
    return _sprint_colored("0;33", kw)
def sprint蓝(*kw):
    return _sprint_colored("0;34", kw)
def sprint紫(*kw):
    return _sprint_colored("0;35", kw)
def sprint靛(*kw):
    return _sprint_colored("0;36", kw)

# Bright variants of the same six colours.
def sprint亮红(*kw):
    return _sprint_colored("1;31", kw)
def sprint亮绿(*kw):
    return _sprint_colored("1;32", kw)
def sprint亮黄(*kw):
    return _sprint_colored("1;33", kw)
def sprint亮蓝(*kw):
    return _sprint_colored("1;34", kw)
def sprint亮紫(*kw):
    return _sprint_colored("1;35", kw)
def sprint亮靛(*kw):
    return _sprint_colored("1;36", kw)
|
63 |
+
|
64 |
+
# Each function logs its arguments at INFO level in one colour. depth=1 makes
# the record point at the caller of the log* function. Keyword arguments are
# accepted but not used.
def log红(*kw,**kargs):
    message = sprint红(*kw)
    logger.opt(depth=1).info(message)
def log绿(*kw,**kargs):
    message = sprint绿(*kw)
    logger.opt(depth=1).info(message)
def log黄(*kw,**kargs):
    message = sprint黄(*kw)
    logger.opt(depth=1).info(message)
def log蓝(*kw,**kargs):
    message = sprint蓝(*kw)
    logger.opt(depth=1).info(message)
def log紫(*kw,**kargs):
    message = sprint紫(*kw)
    logger.opt(depth=1).info(message)
def log靛(*kw,**kargs):
    message = sprint靛(*kw)
    logger.opt(depth=1).info(message)

# Bright variants.
def log亮红(*kw,**kargs):
    message = sprint亮红(*kw)
    logger.opt(depth=1).info(message)
def log亮绿(*kw,**kargs):
    message = sprint亮绿(*kw)
    logger.opt(depth=1).info(message)
def log亮黄(*kw,**kargs):
    message = sprint亮黄(*kw)
    logger.opt(depth=1).info(message)
def log亮蓝(*kw,**kargs):
    message = sprint亮蓝(*kw)
    logger.opt(depth=1).info(message)
def log亮紫(*kw,**kargs):
    message = sprint亮紫(*kw)
    logger.opt(depth=1).info(message)
def log亮靛(*kw,**kargs):
    message = sprint亮靛(*kw)
    logger.opt(depth=1).info(message)
|
docker_as_a_service/shared_utils/config_loader.py
ADDED
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import importlib
|
2 |
+
import time
|
3 |
+
import os
|
4 |
+
from functools import lru_cache
|
5 |
+
from shared_utils.colorful import log亮红, log亮绿, log亮蓝
|
6 |
+
|
7 |
+
pj = os.path.join
|
8 |
+
default_user_name = 'default_user'
|
9 |
+
|
10 |
+
def read_env_variable(arg, default_value):
    """
    Read configuration option `arg` from the environment and convert it to the
    type of `default_value`.

    The environment variable may be named `GPT_ACADEMIC_<arg>` (takes priority)
    or simply `<arg>`. For example, in Windows cmd you can either write:
        set USE_PROXY=True
        set API_KEY=sk-j7caBpkRoxxxxxxxxxxxxxxxxxxxxxxxxxxxx
        set proxies={"http":"http://127.0.0.1:10085", "https":"http://127.0.0.1:10085",}
        set AVAIL_LLM_MODELS=["gpt-3.5-turbo", "chatglm"]
        set AUTHENTICATION=[("username", "password"), ("username2", "password2")]
    or:
        set GPT_ACADEMIC_USE_PROXY=True
        set GPT_ACADEMIC_API_KEY=sk-j7caBpkRoxxxxxxxxxxxxxxxxxxxxxxxxxxxx
        set GPT_ACADEMIC_proxies={"http":"http://127.0.0.1:10085", "https":"http://127.0.0.1:10085",}
        set GPT_ACADEMIC_AVAIL_LLM_MODELS=["gpt-3.5-turbo", "chatglm"]
        set GPT_ACADEMIC_AUTHENTICATION=[("username", "password"), ("username2", "password2")]

    Conversion rules, selected by the type of `default_value`:
        bool  -> the stripped text must be exactly `True` or `False`; anything
                 else is logged and `default_value` is returned
        int / float -> `int(...)` / `float(...)`
        str   -> the stripped text
        dict / list -> the text is evaluated as a Python expression
        None  -> only allowed for `proxies`; evaluated as a Python expression

    Raises:
        KeyError: the variable is not set, its default has an unsupported type,
            or the conversion failed.
    """
    arg_with_prefix = "GPT_ACADEMIC_" + arg
    if arg_with_prefix in os.environ:
        env_arg = os.environ[arg_with_prefix]
    elif arg in os.environ:
        env_arg = os.environ[arg]
    else:
        raise KeyError(arg)
    log亮绿(f"[ENV_VAR] 尝试加载{arg},默认值:{default_value} --> 修正值:{env_arg}")
    try:
        # bool must be tested before int: bool is a subclass of int.
        if isinstance(default_value, bool):
            env_arg = env_arg.strip()
            if env_arg == 'True':
                r = True
            elif env_arg == 'False':
                r = False
            else:
                log亮红('Expect `True` or `False`, but have:', env_arg)
                r = default_value
        elif isinstance(default_value, int):
            r = int(env_arg)
        elif isinstance(default_value, float):
            r = float(env_arg)
        elif isinstance(default_value, str):
            r = env_arg.strip()
        elif isinstance(default_value, (dict, list)):
            # SECURITY NOTE(review): eval() executes arbitrary code taken from the
            # environment. Kept for compatibility; consider ast.literal_eval.
            r = eval(env_arg)
        elif default_value is None:
            assert arg == "proxies"
            # SECURITY NOTE(review): same eval() concern as above.
            r = eval(env_arg)
        else:
            log亮红(f"[ENV_VAR] 环境变量{arg}不支持通过环境变量设置! ")
            raise KeyError
    except Exception as err:
        # `Exception` (not a bare except) so KeyboardInterrupt / SystemExit are
        # not silently converted into a KeyError.
        log亮红(f"[ENV_VAR] 环境变量{arg}加载失败! ")
        raise KeyError(f"[ENV_VAR] 环境变量{arg}加载失败! ") from err

    log亮绿(f"[ENV_VAR] 成功读取环境变量{arg}")
    return r
|
62 |
+
|
63 |
+
|
64 |
+
@lru_cache(maxsize=128)
def read_single_conf_with_lru_cache(arg):
    """
    Resolve a single configuration option by name (result is cached).

    Lookup order:
        1. environment variable (via `read_env_variable`, using the value in
           the `config` module as the type reference),
        2. attribute of the `config_private` module,
        3. attribute of the `config` module.

    After the value is resolved, a few options get extra sanity checks / log
    output: `API_URL_REDIRECT`, `API_KEY` and `proxies`. For `proxies` the
    value is forced to None when `USE_PROXY` is falsy.

    Raises:
        AttributeError / ImportError: if the option cannot be found in `config`.
        AssertionError: if `proxies` is set but is not a dict.
    """
    from shared_utils.key_pattern_manager import is_any_api_key
    try:
        # Priority 1: take the configuration from the environment.
        # The default value is read only as a reference for type conversion.
        default_ref = getattr(importlib.import_module('config'), arg)
        r = read_env_variable(arg, default_ref)
    except Exception:
        # `Exception` rather than a bare except, so Ctrl-C / SystemExit are
        # not swallowed by the fallback chain.
        try:
            # Priority 2: take the configuration from config_private.
            r = getattr(importlib.import_module('config_private'), arg)
        except Exception:
            # Priority 3: take the configuration from config.
            r = getattr(importlib.import_module('config'), arg)

    # When reading API_URL_REDIRECT, warn if the OpenAI redirect target does not
    # look like a chat-completions endpoint.
    if arg == 'API_URL_REDIRECT':
        oai_rd = r.get("https://api.openai.com/v1/chat/completions", None)
        if oai_rd and not oai_rd.endswith('/completions'):
            log亮红("\n\n[API_URL_REDIRECT] API_URL_REDIRECT填错了。请阅读`https://github.com/binary-husky/gpt_academic/wiki/项目配置说明`。如果您确信自己没填错,无视此消息即可。")
            time.sleep(5)
    # When reading API_KEY, check whether the user forgot to edit the config.
    if arg == 'API_KEY':
        log亮蓝(f"[API_KEY] 本项目现已支持OpenAI和Azure的api-key。也支持同时填写多个api-key,如API_KEY=\"openai-key1,openai-key2,azure-key3\"")
        log亮蓝(f"[API_KEY] 您既可以在config.py中修改api-key(s),也可以在问题输入区输入临时的api-key(s),然后回车键提交后即可生效。")
        if is_any_api_key(r):
            log亮绿(f"[API_KEY] 您的 API_KEY 是: {r[:15]}*** API_KEY 导入成功")
        else:
            log亮红(f"[API_KEY] 您的 API_KEY({r[:15]}***)不满足任何一种已知的密钥格式,请在config文件中修改API密钥之后再运行(详见`https://github.com/binary-husky/gpt_academic/wiki/api_key`)。")
    if arg == 'proxies':
        # Honour USE_PROXY so that `proxies` never takes effect on its own.
        if not read_single_conf_with_lru_cache('USE_PROXY'):
            r = None
        if r is None:
            log亮红('[PROXY] 网络代理状态:未配置。无代理状态下很可能无法访问OpenAI家族的模型。建议:检查USE_PROXY选项是否修改。')
        else:
            log亮绿('[PROXY] 网络代理状态:已配置。配置信息如下:', str(r))
            assert isinstance(r, dict), 'proxies格式错误,请注意proxies选项的格式,不要遗漏括号。'
    return r
|
100 |
+
|
101 |
+
|
102 |
+
@lru_cache(maxsize=128)
def get_conf(*args):
    """
    Look up one or more configuration options by name.

    All configuration of this project is collected in config.py. There are
    three ways to change it; pick any one of them:
        - edit config.py directly
        - create and edit config_private.py
        - set environment variables (editing docker-compose.yml is equivalent
          to setting environment variables inside the container)

    Note: when deploying with docker-compose, edit docker-compose (equivalent
    to setting environment variables inside the container).

    Returns the bare value when exactly one name is given, otherwise a list of
    values in the order the names were passed.
    """
    values = [read_single_conf_with_lru_cache(name) for name in args]
    if len(values) == 1:
        return values[0]
    return values
|
118 |
+
|
119 |
+
|
120 |
+
def set_conf(key, value):
    """
    Override configuration option `key` at runtime and return the value that
    is subsequently resolved for it.

    The override is written to `os.environ[key]` as `str(value)`; both config
    caches are cleared first so the next lookup re-reads the environment.
    """
    # Clear the cache of the function defined in this module -- it is the one
    # `get_conf` actually calls -- instead of importing it from `toolbox`,
    # which makes this module depend on an outside package for its own function.
    read_single_conf_with_lru_cache.cache_clear()
    get_conf.cache_clear()
    os.environ[key] = str(value)
    altered = get_conf(key)
    return altered
|
127 |
+
|
128 |
+
|
129 |
+
def set_multi_conf(dic):
    """Apply `set_conf` to every key/value pair of `dic`, in iteration order. Returns None."""
    for key in dic:
        set_conf(key, dic[key])
|
docker_as_a_service/shared_utils/connect_void_terminal.py
ADDED
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
"""
|
4 |
+
=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
|
5 |
+
接驳void-terminal:
|
6 |
+
- set_conf: 在运行过程中动态地修改配置
|
7 |
+
- set_multi_conf: 在运行过程中动态地修改多个配置
|
8 |
+
- get_plugin_handle: 获取插件的句柄
|
9 |
+
- get_plugin_default_kwargs: 获取插件的默认参数
|
10 |
+
- get_chat_handle: 获取简单聊天的句柄
|
11 |
+
- get_chat_default_kwargs: 获取简单聊天的默认参数
|
12 |
+
=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
|
13 |
+
"""
|
14 |
+
|
15 |
+
|
16 |
+
def get_plugin_handle(plugin_name):
    """
    Resolve a plugin specification of the form ``<module path>-><attribute name>``
    to the attribute object itself.

    e.g. plugin_name = 'crazy_functions.Markdown_Translate->Markdown翻译指定语言'

    Raises:
        AssertionError: `plugin_name` does not contain "->".
        ImportError / AttributeError: the module or the attribute does not exist.
    """
    import importlib

    # Raised explicitly (same exception type as before) so the check still
    # runs when Python is started with -O, which strips `assert` statements.
    if "->" not in plugin_name:
        raise AssertionError(
            "Example of plugin_name: crazy_functions.Markdown_Translate->Markdown翻译指定语言"
        )
    module, fn_name = plugin_name.split("->")
    # The second argument of import_module is the anchor *package* for relative
    # imports; the attribute name does not belong there, so it is not passed.
    f_hot_reload = getattr(importlib.import_module(module), fn_name)
    return f_hot_reload
|
28 |
+
|
29 |
+
|
30 |
+
def get_chat_handle():
    """
    Return the `predict_no_ui_long_connection` function from
    `request_llms.bridge_all` (imported lazily on call).
    """
    from request_llms import bridge_all

    return bridge_all.predict_no_ui_long_connection
|
37 |
+
|
38 |
+
|
39 |
+
def get_plugin_default_kwargs():
    """
    Build the default keyword arguments for calling a plugin.

    The api key and model name are taken from `load_chat_cookies()`; the
    chatbot object is a `ChatBotWithCookies` constructed from the LLM kwargs.
    """
    from toolbox import ChatBotWithCookies, load_chat_cookies

    chat_cookies = load_chat_cookies()
    llm_settings = dict(
        api_key=chat_cookies["api_key"],
        llm_model=chat_cookies["llm_model"],
        top_p=1.0,
        max_length=None,
        temperature=1.0,
    )
    bot = ChatBotWithCookies(llm_settings)

    # Plugin signature order: txt, llm_kwargs, plugin_kwargs, chatbot, history,
    # system_prompt, user_request
    return {
        "main_input": "./README.md",
        "llm_kwargs": llm_settings,
        "plugin_kwargs": {},
        "chatbot_with_cookie": bot,
        "history": [],
        "system_prompt": "You are a good AI.",
        "user_request": None,
    }
|
66 |
+
|
67 |
+
|
68 |
+
def get_chat_default_kwargs():
    """
    Build the default keyword arguments for a plain chat call.

    The api key and model name are taken from `load_chat_cookies()`.
    """
    from toolbox import load_chat_cookies

    chat_cookies = load_chat_cookies()
    llm_settings = dict(
        api_key=chat_cookies["api_key"],
        llm_model=chat_cookies["llm_model"],
        top_p=1.0,
        max_length=None,
        temperature=1.0,
    )
    return {
        "inputs": "Hello there, are you ready?",
        "llm_kwargs": llm_settings,
        "history": [],
        "sys_prompt": "You are AI assistant",
        "observe_window": None,
        "console_slience": False,
    }
|
docker_as_a_service/shared_utils/cookie_manager.py
ADDED
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import base64
|
3 |
+
from typing import Callable
|
4 |
+
|
5 |
+
def load_web_cookie_cache__fn_builder(customize_btns, cookies, predefined_btns)->Callable:
    """
    Build the callback that restores custom buttons from the persistent
    (browser-side) cookie string.

    Args:
        customize_btns: mapping of button key -> customizable button component.
        cookies: the component used as the output key for the backend cookie dict.
        predefined_btns: mapping of button key -> predefined button component.

    Returns:
        A function `(persistent_cookie_, cookies_) -> dict` mapping components
        to their updates.
    """
    def load_web_cookie_cache(persistent_cookie_, cookies_):
        import gradio as gr
        from themes.theme import from_cookie_str

        # Start with every customizable button hidden and empty.
        ret = {}
        for k in customize_btns:
            ret.update({customize_btns[k]: gr.update(visible=False, value="")})

        # Decode the persistent cookie string; if that fails, keep buttons hidden.
        try:
            persistent_cookie_ = from_cookie_str(persistent_cookie_)
        except Exception:
            return ret

        customize_fn_overwrite_ = persistent_cookie_.get("custom_bnt", {})
        cookies_['customize_fn_overwrite'] = customize_fn_overwrite_
        ret.update({cookies: cookies_})

        # Iterate the value obtained with the default above: indexing
        # persistent_cookie_["custom_bnt"] directly raised KeyError for cookies
        # that do not carry that key.
        for k, v in customize_fn_overwrite_.items():
            if v['Title'] == "":
                continue
            if k in customize_btns:
                ret.update({customize_btns[k]: gr.update(visible=True, value=v['Title'])})
            else:
                ret.update({predefined_btns[k]: gr.update(visible=True, value=v['Title'])})
        return ret
    return load_web_cookie_cache
|
27 |
+
|
28 |
+
def assign_btn__fn_builder(customize_btns, predefined_btns, cookies, web_cookie_cache)->Callable:
    """
    Build the callback that assigns (or clears) a custom button definition.

    Args:
        customize_btns: mapping of button key -> customizable button component.
        predefined_btns: mapping of button key -> predefined button component.
        cookies: the component used as the output key for the backend cookie dict.
        web_cookie_cache: the component that receives the persistent cookie string.

    Returns:
        A function producing a dict of component -> update.
    """
    def assign_btn(persistent_cookie_, cookies_, basic_btn_dropdown_, basic_fn_title, basic_fn_prefix, basic_fn_suffix, clean_up=False):
        import gradio as gr
        from themes.theme import to_cookie_str, from_cookie_str
        ret = {}
        # Read the previously stored custom buttons.
        customize_fn_overwrite_ = cookies_['customize_fn_overwrite']
        # Add / overwrite the entry for the selected button.
        customize_fn_overwrite_.update({
            basic_btn_dropdown_:
                {
                    "Title":basic_fn_title,
                    "Prefix":basic_fn_prefix,
                    "Suffix":basic_fn_suffix,
                }
            }
        )
        if clean_up:
            customize_fn_overwrite_ = {}
        # Store the button table back under its own key. The previous
        # `cookies_.update(customize_fn_overwrite_)` copied each button entry
        # into the cookie as a top-level key and left the stored table
        # un-cleared when clean_up was requested.
        cookies_['customize_fn_overwrite'] = customize_fn_overwrite_
        visible = (not clean_up) and (basic_fn_title != "")
        if basic_btn_dropdown_ in customize_btns:
            # A customizable button, not a predefined one.
            ret.update({customize_btns[basic_btn_dropdown_]: gr.update(visible=visible, value=basic_fn_title)})
        else:
            # A predefined button.
            ret.update({predefined_btns[basic_btn_dropdown_]: gr.update(visible=visible, value=basic_fn_title)})
        ret.update({cookies: cookies_})
        # Persistent cookie string -> dict; start from an empty dict if it
        # cannot be decoded.
        try:
            persistent_cookie_ = from_cookie_str(persistent_cookie_)
        except Exception:
            persistent_cookie_ = {}
        persistent_cookie_["custom_bnt"] = customize_fn_overwrite_    # store the new button table
        persistent_cookie_ = to_cookie_str(persistent_cookie_)         # dict -> persistent cookie string
        ret.update({web_cookie_cache: persistent_cookie_})             # write persistent cookie
        return ret
    return assign_btn
|
63 |
+
|
64 |
+
# cookies, web_cookie_cache = make_cookie_cache()
|
65 |
+
def make_cookie_cache():
    """
    Create the backend cookie state and its hidden front-end twin.

    Returns:
        (cookies, web_cookie_cache): a `gr.State` initialised from
        `load_chat_cookies()` and an invisible `gr.Textbox` with element id
        "web_cookie_cache".
    """
    import gradio as gr
    from toolbox import load_chat_cookies

    # Backend state holding the cookies.
    initial_cookies = load_chat_cookies()
    cookies = gr.State(initial_cookies)
    # Hidden front-end storage area paired with the state above.
    web_cookie_cache = gr.Textbox(elem_id="web_cookie_cache", visible=False)
    return cookies, web_cookie_cache
|
74 |
+
|
75 |
+
# history, history_cache, history_cache_update = make_history_cache()
|
76 |
+
def make_history_cache():
    """
    Create the backend history state, its hidden front-end twin, and the
    hidden button event that copies the front-end value into the state.

    Returns:
        (history, history_cache, history_cache_update)
    """
    import gradio as gr

    # Backend state holding the history list.
    history = gr.State([])
    # Hidden front-end storage area paired with the state above.
    history_cache = gr.Textbox(elem_id="history_cache", visible=False)

    # Setter history_cache -> history: the textbox content is parsed as JSON.
    # Per the original note, triggering the hidden button first lets JS code
    # update history_cache, then this Python callback updates history.
    def process_history_cache(history_cache):
        return json.loads(history_cache)

    trigger_button = gr.Button("", elem_id="elem_update_history", visible=False)
    history_cache_update = trigger_button.click(
        process_history_cache,
        inputs=[history_cache],
        outputs=[history],
    )
    return history, history_cache, history_cache_update
|
90 |
+
|
91 |
+
|
92 |
+
|
93 |
+
def create_button_with_javascript_callback(btn_value, elem_id, variant, js_callback, input_list, output_list, function, input_name_list, output_name_list):
    """
    Create a button whose click runs `function` in Python and then invokes a
    JavaScript callback with the named outputs.

    The wrapped function returns `function`'s results plus one extra value: a
    base64-encoded JSON object containing the named inputs and outputs. That
    value is written to a hidden textbox and then passed to the JS snippet,
    which decodes it and calls `js_callback` with `Data["<output name>"]` for
    every entry of `output_name_list`.
    """
    import gradio as gr

    # Hidden textbox that carries the encoded summary from Python to JS.
    middle_ware_component = gr.Textbox(visible=False, elem_id=elem_id+'_buffer')

    def fn_wrap(*args):
        # Named inputs first, then named outputs (outputs win on a name clash).
        summary_dict = dict(zip(input_name_list, args))
        res = function(*args)
        summary_dict.update(zip(output_name_list, res))
        encoded = base64.b64encode(json.dumps(summary_dict).encode('utf8'))
        summary = encoded.decode("utf-8")
        return (*res, summary)

    btn = gr.Button(btn_value, elem_id=elem_id, variant=variant)
    call_args = ",".join(f"""Data["{name}"]""" for name in output_name_list)
    js_template = """
(base64MiddleString)=>{
console.log('hello')
const stringData = atob(base64MiddleString);
let Data = JSON.parse(stringData);
call = JS_CALLBACK_GEN;
call(CALL_ARGS);
}
"""
    _js_callback = js_template.replace("JS_CALLBACK_GEN", js_callback).replace("CALL_ARGS", call_args)

    click_event = btn.click(fn_wrap, input_list, output_list+[middle_ware_component])
    click_event.then(None, [middle_ware_component], None, _js=_js_callback)
    return btn
|
docker_as_a_service/shared_utils/docker_as_service_api.py
ADDED
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import requests
|
2 |
+
import pickle
|
3 |
+
import io
|
4 |
+
import os
|
5 |
+
from pydantic import BaseModel, Field
|
6 |
+
from typing import Optional, Dict
|
7 |
+
from loguru import logger
|
8 |
+
|
9 |
+
class DockerServiceApiComModel(BaseModel):
    """Message exchanged between the docker-as-a-service client and server.

    Every field is optional and defaults to None.
    """
    # Fields filled in by the client.
    client_command: Optional[str] = Field(default=None, title="Client command", description="The command to be executed on the client side")
    client_file_attach: Optional[dict] = Field(default=None, title="Client file attach", description="The file to be attached to the client side")
    # Fields filled in by the server.
    # (server_message previously carried the title/description of server_std_err.)
    server_message: Optional[str] = Field(default=None, title="Server message", description="The message from the server side")
    server_std_err: Optional[str] = Field(default=None, title="Server standard error", description="The standard error from the server side")
    server_std_out: Optional[str] = Field(default=None, title="Server standard output", description="The standard output from the server side")
    server_file_attach: Optional[dict] = Field(default=None, title="Server file attach", description="The file to be attached to the server side")
|
16 |
+
|
17 |
+
def process_received(received: "DockerServiceApiComModel", save_file_dir="./daas_output", output_manifest=None):
    """
    Merge one received message into `output_manifest` and save attached files.

    Args:
        received: object exposing `server_message`, `server_std_err`,
            `server_std_out` (text or None) and `server_file_attach`
            (mapping of file name -> bytes, or None).
        save_file_dir: directory under which attached files are written.
        output_manifest: manifest dict to update in place. When omitted, a
            fresh manifest is created for this call; missing keys are
            initialised.

    Returns:
        The manifest: accumulated `server_message` / `server_std_err` /
        `server_std_out` strings and `server_file_attach`, the list of paths
        written so far.
    """
    # A fresh dict per call: the former `output_manifest={}` default was shared
    # between calls and had none of the keys that are appended to below.
    if output_manifest is None:
        output_manifest = {}
    output_manifest.setdefault('server_message', "")
    output_manifest.setdefault('server_std_err', "")
    output_manifest.setdefault('server_std_out', "")
    output_manifest.setdefault('server_file_attach', [])

    # Process the received data
    if received.server_message:
        output_manifest['server_message'] += received.server_message
    if received.server_std_err:
        output_manifest['server_std_err'] += received.server_std_err
    if received.server_std_out:
        output_manifest['server_std_out'] += received.server_std_out
    if received.server_file_attach:
        for file_name, file_content in received.server_file_attach.items():
            # SECURITY NOTE(review): file_name comes from the remote side and is
            # joined without sanitising; "../" or absolute names escape
            # save_file_dir. Confirm the server is trusted or add a check.
            new_fp = os.path.join(save_file_dir, file_name)
            new_fp_dir = os.path.dirname(new_fp)
            if new_fp_dir:  # os.makedirs('') raises, so skip when there is no directory part
                os.makedirs(new_fp_dir, exist_ok=True)
            with open(new_fp, 'wb') as f:
                f.write(file_content)
            output_manifest['server_file_attach'].append(new_fp)
    return output_manifest
|
36 |
+
|
37 |
+
def stream_daas(docker_service_api_com_model, server_url):
    """
    Send a request model to the server and yield the accumulated output
    manifest each time a response chunk is processed.

    The model is pickled and uploaded as a multipart file via HTTP POST. On a
    200 response every non-empty chunk is unpickled and merged into the
    manifest by `process_received`; other status codes are logged as errors.

    Yields:
        dict: the manifest with keys `server_message`, `server_std_err`,
        `server_std_out` (str) and `server_file_attach` (list of saved paths).
        The same dict object is updated and yielded repeatedly.
    """
    # Pickle the object and wrap it in a file-like object for the upload.
    pickled_data = pickle.dumps(docker_service_api_com_model)
    file_obj = io.BytesIO(pickled_data)
    files = {'file': ('docker_service_api_com_model.pkl', file_obj, 'application/octet-stream')}

    max_full_package_size = 1024 * 1024 * 1024 * 1 # 1 GB

    received_output_manifest = {}
    received_output_manifest['server_message'] = ""
    received_output_manifest['server_std_err'] = ""
    received_output_manifest['server_std_out'] = ""
    received_output_manifest['server_file_attach'] = []

    # Send the POST request; the context manager releases the streamed
    # connection even if the consumer stops iterating early.
    with requests.post(server_url, files=files, stream=True) as response:
        # Check if the request was successful
        if response.status_code == 200:
            # Process the streaming response
            # NOTE(review): this assumes each chunk returned by iter_content is
            # exactly one pickled message -- confirm against the server framing.
            for chunk in response.iter_content(max_full_package_size):
                if chunk:
                    # SECURITY NOTE(review): pickle.loads on data received over
                    # the network can execute arbitrary code; only use with a
                    # trusted server.
                    received = pickle.loads(chunk)
                    # Pass the manifest by keyword: positionally it landed in
                    # the `save_file_dir` parameter of process_received.
                    received_output_manifest = process_received(received, output_manifest=received_output_manifest)
                    yield received_output_manifest
        else:
            logger.error(f"Error: Received status code {response.status_code}, response.text: {response.text}")

    return received_output_manifest
|
docker_as_a_service/shared_utils/fastapi_server.py
ADDED
@@ -0,0 +1,322 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Tests:
|
3 |
+
|
4 |
+
- custom_path false / no user auth:
|
5 |
+
-- upload file(yes)
|
6 |
+
-- download file(yes)
|
7 |
+
-- websocket(yes)
|
8 |
+
-- block __pycache__ access(yes)
|
9 |
+
-- rel (yes)
|
10 |
+
-- abs (yes)
|
11 |
+
-- block user access(fail) http://localhost:45013/file=gpt_log/admin/chat_secrets.log
|
12 |
+
-- fix(commit f6bf05048c08f5cd84593f7fdc01e64dec1f584a)-> block successful
|
13 |
+
|
14 |
+
- custom_path yes("/cc/gptac") / no user auth:
|
15 |
+
-- upload file(yes)
|
16 |
+
-- download file(yes)
|
17 |
+
-- websocket(yes)
|
18 |
+
-- block __pycache__ access(yes)
|
19 |
+
-- block user access(yes)
|
20 |
+
|
21 |
+
- custom_path yes("/cc/gptac/") / no user auth:
|
22 |
+
-- upload file(yes)
|
23 |
+
-- download file(yes)
|
24 |
+
-- websocket(yes)
|
25 |
+
-- block user access(yes)
|
26 |
+
|
27 |
+
- custom_path yes("/cc/gptac/") / + user auth:
|
28 |
+
-- upload file(yes)
|
29 |
+
-- download file(yes)
|
30 |
+
-- websocket(yes)
|
31 |
+
-- block user access(yes)
|
32 |
+
-- block user-wise access (yes)
|
33 |
+
|
34 |
+
- custom_path no + user auth:
|
35 |
+
-- upload file(yes)
|
36 |
+
-- download file(yes)
|
37 |
+
-- websocket(yes)
|
38 |
+
-- block user access(yes)
|
39 |
+
-- block user-wise access (yes)
|
40 |
+
|
41 |
+
queue concurrent effectiveness
|
42 |
+
-- upload file(yes)
|
43 |
+
-- download file(yes)
|
44 |
+
-- websocket(yes)
|
45 |
+
"""
|
46 |
+
|
47 |
+
import os, requests, threading, time
|
48 |
+
import uvicorn
|
49 |
+
|
50 |
+
def validate_path_safety(path_or_url, user):
    """
    Check that `user` may use the file at `path_or_url` as a task input.

    The path is made relative to the current directory and then classified by
    prefix:
        - under PATH_LOGGING or PATH_PRIVATE_UPLOAD (per-user directories): the
          first two path components must equal `<root>/<name>` for `user`,
          'autogen', 'arxiv_cache' or the default user;
        - starting with 'tests' or 'build': always allowed;
        - anything else: rejected.

    Returns:
        True when access is allowed.

    Raises:
        FriendlyException: the location is not allowed, or it belongs to
            another user.
    """
    from toolbox import get_conf, default_user_name
    from toolbox import FriendlyException
    PATH_PRIVATE_UPLOAD, PATH_LOGGING = get_conf('PATH_PRIVATE_UPLOAD', 'PATH_LOGGING')

    path_or_url = os.path.relpath(path_or_url)
    if path_or_url.startswith(PATH_LOGGING):
        # log files (partitioned by user)
        protected_root = PATH_LOGGING
    elif path_or_url.startswith(PATH_PRIVATE_UPLOAD):
        # user upload directory (partitioned by user)
        protected_root = PATH_PRIVATE_UPLOAD
    elif path_or_url.startswith(('tests', 'build')):
        # commonly used test directories
        return True
    else:
        raise FriendlyException(f"输入文件的路径 ({path_or_url}) 存在,但位置非法。请将文件上传后再执行该任务。") # return False

    if not protected_root:
        return True

    # user paths that can be accessed: the caller's own plus the shared ones
    permitted_names = (user, 'autogen', 'arxiv_cache', default_user_name)
    owner_dir = os.sep.join(path_or_url.split(os.sep)[:2])
    if any(owner_dir == os.path.join(protected_root, name) for name in permitted_names):
        return True
    raise FriendlyException(f"输入文件的路径 ({path_or_url}) 存在,但属于其他用户。请将文件上传后再执行该任务。") # return False
|
71 |
+
|
72 |
+
def _authorize_user(path_or_url, request, gradio_app):
    """
    Decide whether the requester may download the file at `path_or_url`.

    Paths outside PATH_LOGGING and PATH_PRIVATE_UPLOAD are always allowed.
    Inside those directories, the user is looked up from the `access-token`
    (or `access-token-unsecure`) cookie in `gradio_app.tokens`, and the first
    two path components must equal `<root>/<name>` for that user, 'autogen',
    'arxiv_cache' or the default user.

    Returns:
        True if access is allowed, False otherwise.
    """
    from toolbox import get_conf, default_user_name
    PATH_PRIVATE_UPLOAD, PATH_LOGGING = get_conf('PATH_PRIVATE_UPLOAD', 'PATH_LOGGING')
    sensitive_path = None
    path_or_url = os.path.relpath(path_or_url)
    if path_or_url.startswith(PATH_LOGGING):
        sensitive_path = PATH_LOGGING
    if path_or_url.startswith(PATH_PRIVATE_UPLOAD):
        sensitive_path = PATH_PRIVATE_UPLOAD
    if sensitive_path:
        token = request.cookies.get("access-token") or request.cookies.get("access-token-unsecure")
        user = gradio_app.tokens.get(token)  # get user
        if user is None:
            # No user resolved for this token: deny. Previously None reached
            # os.path.join below and raised TypeError.
            return False
        allowed_users = [user, 'autogen', 'arxiv_cache', default_user_name]  # user paths that can be accessed
        requested_root = os.sep.join(path_or_url.split(os.sep)[:2])
        for user_allowed in allowed_users:
            # exact match
            if requested_root == os.path.join(sensitive_path, user_allowed):
                return True
        return False  # unauthorized access
    return True
|
91 |
+
|
92 |
+
|
93 |
+
class Server(uvicorn.Server):
    """A uvicorn server that runs in a separate (daemon) thread."""

    def install_signal_handlers(self):
        """Override the base-class hook with a no-op."""
        return None

    def run_in_thread(self):
        """Start `self.run` in a daemon thread and block until `self.started` is truthy."""
        worker = threading.Thread(target=self.run, daemon=True)
        self.thread = worker
        worker.start()
        # Poll every 50 ms until the started flag is set.
        while True:
            if self.started:
                break
            time.sleep(0.05)

    def close(self):
        """Set `should_exit` and wait for the server thread to finish."""
        self.should_exit = True
        self.thread.join()
|
107 |
+
|
108 |
+
|
109 |
+
def start_app(app_block, CONCURRENT_COUNT, AUTHENTICATION, PORT, SSL_KEYFILE, SSL_CERTFILE):
|
110 |
+
import uvicorn
|
111 |
+
import fastapi
|
112 |
+
import gradio as gr
|
113 |
+
from fastapi import FastAPI
|
114 |
+
from gradio.routes import App
|
115 |
+
from toolbox import get_conf
|
116 |
+
CUSTOM_PATH, PATH_LOGGING = get_conf('CUSTOM_PATH', 'PATH_LOGGING')
|
117 |
+
|
118 |
+
# --- --- configurate gradio app block --- ---
|
119 |
+
app_block:gr.Blocks
|
120 |
+
app_block.ssl_verify = False
|
121 |
+
app_block.auth_message = '请登录'
|
122 |
+
app_block.favicon_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), "docs/logo.png")
|
123 |
+
app_block.auth = AUTHENTICATION if len(AUTHENTICATION) != 0 else None
|
124 |
+
app_block.blocked_paths = ["config.py", "__pycache__", "config_private.py", "docker-compose.yml", "Dockerfile", f"{PATH_LOGGING}/admin"]
|
125 |
+
app_block.dev_mode = False
|
126 |
+
app_block.config = app_block.get_config_file()
|
127 |
+
app_block.enable_queue = True
|
128 |
+
app_block.queue(concurrency_count=CONCURRENT_COUNT)
|
129 |
+
app_block.validate_queue_settings()
|
130 |
+
app_block.show_api = False
|
131 |
+
app_block.config = app_block.get_config_file()
|
132 |
+
max_threads = 40
|
133 |
+
app_block.max_threads = max(
|
134 |
+
app_block._queue.max_thread_count if app_block.enable_queue else 0, max_threads
|
135 |
+
)
|
136 |
+
app_block.is_colab = False
|
137 |
+
app_block.is_kaggle = False
|
138 |
+
app_block.is_sagemaker = False
|
139 |
+
|
140 |
+
gradio_app = App.create_app(app_block)
|
141 |
+
for route in list(gradio_app.router.routes):
|
142 |
+
if route.path == "/proxy={url_path:path}":
|
143 |
+
gradio_app.router.routes.remove(route)
|
144 |
+
# --- --- replace gradio endpoint to forbid access to sensitive files --- ---
|
145 |
+
if len(AUTHENTICATION) > 0:
|
146 |
+
dependencies = []
|
147 |
+
endpoint = None
|
148 |
+
for route in list(gradio_app.router.routes):
|
149 |
+
if route.path == "/file/{path:path}":
|
150 |
+
gradio_app.router.routes.remove(route)
|
151 |
+
if route.path == "/file={path_or_url:path}":
|
152 |
+
dependencies = route.dependencies
|
153 |
+
endpoint = route.endpoint
|
154 |
+
gradio_app.router.routes.remove(route)
|
155 |
+
@gradio_app.get("/file/{path:path}", dependencies=dependencies)
|
156 |
+
@gradio_app.head("/file={path_or_url:path}", dependencies=dependencies)
|
157 |
+
@gradio_app.get("/file={path_or_url:path}", dependencies=dependencies)
|
158 |
+
async def file(path_or_url: str, request: fastapi.Request):
|
159 |
+
if not _authorize_user(path_or_url, request, gradio_app):
|
160 |
+
return "越权访问!"
|
161 |
+
stripped = path_or_url.lstrip().lower()
|
162 |
+
if stripped.startswith("https://") or stripped.startswith("http://"):
|
163 |
+
return "账户密码授权模式下, 禁止链接!"
|
164 |
+
if '../' in stripped:
|
165 |
+
return "非法路径!"
|
166 |
+
return await endpoint(path_or_url, request)
|
167 |
+
|
168 |
+
from fastapi import Request, status
|
169 |
+
from fastapi.responses import FileResponse, RedirectResponse
|
170 |
+
@gradio_app.get("/academic_logout")
|
171 |
+
async def logout():
|
172 |
+
response = RedirectResponse(url=CUSTOM_PATH, status_code=status.HTTP_302_FOUND)
|
173 |
+
response.delete_cookie('access-token')
|
174 |
+
response.delete_cookie('access-token-unsecure')
|
175 |
+
return response
|
176 |
+
else:
|
177 |
+
dependencies = []
|
178 |
+
endpoint = None
|
179 |
+
for route in list(gradio_app.router.routes):
|
180 |
+
if route.path == "/file/{path:path}":
|
181 |
+
gradio_app.router.routes.remove(route)
|
182 |
+
if route.path == "/file={path_or_url:path}":
|
183 |
+
dependencies = route.dependencies
|
184 |
+
endpoint = route.endpoint
|
185 |
+
gradio_app.router.routes.remove(route)
|
186 |
+
@gradio_app.get("/file/{path:path}", dependencies=dependencies)
|
187 |
+
@gradio_app.head("/file={path_or_url:path}", dependencies=dependencies)
|
188 |
+
@gradio_app.get("/file={path_or_url:path}", dependencies=dependencies)
|
189 |
+
async def file(path_or_url: str, request: fastapi.Request):
|
190 |
+
stripped = path_or_url.lstrip().lower()
|
191 |
+
if stripped.startswith("https://") or stripped.startswith("http://"):
|
192 |
+
return "账户密码授权模式下, 禁止链接!"
|
193 |
+
if '../' in stripped:
|
194 |
+
return "非法路径!"
|
195 |
+
return await endpoint(path_or_url, request)
|
196 |
+
|
197 |
+
# --- --- enable TTS (text-to-speech) functionality --- ---
|
198 |
+
TTS_TYPE = get_conf("TTS_TYPE")
|
199 |
+
if TTS_TYPE != "DISABLE":
|
200 |
+
# audio generation functionality
|
201 |
+
import httpx
|
202 |
+
from fastapi import FastAPI, Request, HTTPException
|
203 |
+
from starlette.responses import Response
|
204 |
+
async def forward_request(request: Request, method: str) -> Response:
|
205 |
+
async with httpx.AsyncClient() as client:
|
206 |
+
try:
|
207 |
+
# Forward the request to the target service
|
208 |
+
if TTS_TYPE == "EDGE_TTS":
|
209 |
+
import tempfile
|
210 |
+
import edge_tts
|
211 |
+
import wave
|
212 |
+
import uuid
|
213 |
+
from pydub import AudioSegment
|
214 |
+
json = await request.json()
|
215 |
+
voice = get_conf("EDGE_TTS_VOICE")
|
216 |
+
tts = edge_tts.Communicate(text=json['text'], voice=voice)
|
217 |
+
temp_folder = tempfile.gettempdir()
|
218 |
+
temp_file_name = str(uuid.uuid4().hex)
|
219 |
+
temp_file = os.path.join(temp_folder, f'{temp_file_name}.mp3')
|
220 |
+
await tts.save(temp_file)
|
221 |
+
try:
|
222 |
+
mp3_audio = AudioSegment.from_file(temp_file, format="mp3")
|
223 |
+
mp3_audio.export(temp_file, format="wav")
|
224 |
+
with open(temp_file, 'rb') as wav_file: t = wav_file.read()
|
225 |
+
os.remove(temp_file)
|
226 |
+
return Response(content=t)
|
227 |
+
except:
|
228 |
+
raise RuntimeError("ffmpeg未安装,无法处理EdgeTTS音频��安装方法见`https://github.com/jiaaro/pydub#getting-ffmpeg-set-up`")
|
229 |
+
if TTS_TYPE == "LOCAL_SOVITS_API":
|
230 |
+
# Forward the request to the target service
|
231 |
+
TARGET_URL = get_conf("GPT_SOVITS_URL")
|
232 |
+
body = await request.body()
|
233 |
+
resp = await client.post(TARGET_URL, content=body, timeout=60)
|
234 |
+
# Return the response from the target service
|
235 |
+
return Response(content=resp.content, status_code=resp.status_code, headers=dict(resp.headers))
|
236 |
+
except httpx.RequestError as e:
|
237 |
+
raise HTTPException(status_code=400, detail=f"Request to the target service failed: {str(e)}")
|
238 |
+
@gradio_app.post("/vits")
|
239 |
+
async def forward_post_request(request: Request):
|
240 |
+
return await forward_request(request, "POST")
|
241 |
+
|
242 |
+
# --- --- app_lifespan --- ---
|
243 |
+
from contextlib import asynccontextmanager
|
244 |
+
@asynccontextmanager
|
245 |
+
async def app_lifespan(app):
|
246 |
+
async def startup_gradio_app():
|
247 |
+
if gradio_app.get_blocks().enable_queue:
|
248 |
+
gradio_app.get_blocks().startup_events()
|
249 |
+
async def shutdown_gradio_app():
|
250 |
+
pass
|
251 |
+
await startup_gradio_app() # startup logic here
|
252 |
+
yield # The application will serve requests after this point
|
253 |
+
await shutdown_gradio_app() # cleanup/shutdown logic here
|
254 |
+
|
255 |
+
# --- --- FastAPI --- ---
|
256 |
+
fastapi_app = FastAPI(lifespan=app_lifespan)
|
257 |
+
fastapi_app.mount(CUSTOM_PATH, gradio_app)
|
258 |
+
|
259 |
+
# --- --- favicon and block fastapi api reference routes --- ---
|
260 |
+
from starlette.responses import JSONResponse
|
261 |
+
if CUSTOM_PATH != '/':
|
262 |
+
from fastapi.responses import FileResponse
|
263 |
+
@fastapi_app.get("/favicon.ico")
|
264 |
+
async def favicon():
|
265 |
+
return FileResponse(app_block.favicon_path)
|
266 |
+
|
267 |
+
@fastapi_app.middleware("http")
|
268 |
+
async def middleware(request: Request, call_next):
|
269 |
+
if request.scope['path'] in ["/docs", "/redoc", "/openapi.json"]:
|
270 |
+
return JSONResponse(status_code=404, content={"message": "Not Found"})
|
271 |
+
response = await call_next(request)
|
272 |
+
return response
|
273 |
+
|
274 |
+
|
275 |
+
# --- --- uvicorn.Config --- ---
|
276 |
+
ssl_keyfile = None if SSL_KEYFILE == "" else SSL_KEYFILE
|
277 |
+
ssl_certfile = None if SSL_CERTFILE == "" else SSL_CERTFILE
|
278 |
+
server_name = "0.0.0.0"
|
279 |
+
config = uvicorn.Config(
|
280 |
+
fastapi_app,
|
281 |
+
host=server_name,
|
282 |
+
port=PORT,
|
283 |
+
reload=False,
|
284 |
+
log_level="warning",
|
285 |
+
ssl_keyfile=ssl_keyfile,
|
286 |
+
ssl_certfile=ssl_certfile,
|
287 |
+
)
|
288 |
+
server = Server(config)
|
289 |
+
url_host_name = "localhost" if server_name == "0.0.0.0" else server_name
|
290 |
+
if ssl_keyfile is not None:
|
291 |
+
if ssl_certfile is None:
|
292 |
+
raise ValueError(
|
293 |
+
"ssl_certfile must be provided if ssl_keyfile is provided."
|
294 |
+
)
|
295 |
+
path_to_local_server = f"https://{url_host_name}:{PORT}/"
|
296 |
+
else:
|
297 |
+
path_to_local_server = f"http://{url_host_name}:{PORT}/"
|
298 |
+
if CUSTOM_PATH != '/':
|
299 |
+
path_to_local_server += CUSTOM_PATH.lstrip('/').rstrip('/') + '/'
|
300 |
+
# --- --- begin --- ---
|
301 |
+
server.run_in_thread()
|
302 |
+
|
303 |
+
# --- --- after server launch --- ---
|
304 |
+
app_block.server = server
|
305 |
+
app_block.server_name = server_name
|
306 |
+
app_block.local_url = path_to_local_server
|
307 |
+
app_block.protocol = (
|
308 |
+
"https"
|
309 |
+
if app_block.local_url.startswith("https") or app_block.is_colab
|
310 |
+
else "http"
|
311 |
+
)
|
312 |
+
|
313 |
+
if app_block.enable_queue:
|
314 |
+
app_block._queue.set_url(path_to_local_server)
|
315 |
+
|
316 |
+
forbid_proxies = {
|
317 |
+
"http": "",
|
318 |
+
"https": "",
|
319 |
+
}
|
320 |
+
requests.get(f"{app_block.local_url}startup-events", verify=app_block.ssl_verify, proxies=forbid_proxies)
|
321 |
+
app_block.is_running = True
|
322 |
+
app_block.block_thread()
|
docker_as_a_service/shared_utils/handle_upload.py
ADDED
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import importlib
|
2 |
+
import time
|
3 |
+
import inspect
|
4 |
+
import re
|
5 |
+
import os
|
6 |
+
import base64
|
7 |
+
import gradio
|
8 |
+
import shutil
|
9 |
+
import glob
|
10 |
+
from shared_utils.config_loader import get_conf
|
11 |
+
from loguru import logger
|
12 |
+
|
13 |
+
def html_local_file(file):
    """Rewrite an existing local path as a ``file=...`` reference.

    If ``str(file)`` names something that exists on disk, every occurrence
    of this module's directory inside the path is replaced by ``"."`` and
    the result is prefixed with ``file=``. Anything that does not exist on
    disk is returned unchanged.
    """
    if not os.path.exists(str(file)):
        return file
    module_dir = os.path.dirname(__file__)  # directory containing this module
    relative = file.replace(module_dir, ".")
    return f'file={relative}'
|
18 |
+
|
19 |
+
|
20 |
+
def html_local_img(__file, layout="left", max_width=None, max_height=None, md=True):
    """Build the markup used to display a local image.

    With ``md=True`` (the default) a Markdown image link is returned and
    ``layout`` / ``max_width`` / ``max_height`` have no effect on the result.
    With ``md=False`` an aligned ``<div>`` wrapping an ``<img>`` is returned,
    carrying the optional size limits as inline CSS.
    """
    css_rules = []
    if max_width is not None:
        css_rules.append(f"max-width: {max_width};")
    if max_height is not None:
        css_rules.append(f"max-height: {max_height};")
    style = "".join(css_rules)

    src = html_local_file(__file)
    if md:
        return f"![{src}]({src})"
    return f'<div align="{layout}"><img src="{src}" style="{style}"></div>'
|
31 |
+
|
32 |
+
|
33 |
+
def file_manifest_filter_type(file_list, filter_: list = None):
    """Replace image paths in ``file_list`` with HTML ``<img>`` markup.

    A file counts as an image when the text after the last ``.`` of its
    base name is listed in ``filter_`` (default: png, jpg, jpeg). Other
    entries are passed through untouched. A new list is returned.
    """
    extensions = filter_ if filter_ else ["png", "jpg", "jpeg"]
    return [
        html_local_img(path, md=False)
        if str(os.path.basename(path)).split(".")[-1] in extensions
        else path
        for path in file_list
    ]
|
43 |
+
|
44 |
+
|
45 |
+
def zip_extract_member_new(self, member, targetpath, pwd):
    """Extract the ZipInfo object 'member' to a physical file under targetpath.

    Drop-in replacement for ``zipfile.ZipFile._extract_member`` that repairs
    garbled Chinese file names. Legacy archives store names in GBK without
    setting the UTF-8 flag, and ``zipfile`` decodes such names as cp437; this
    function reverses that and decodes the raw bytes as GBK instead.

    Names whose entry has the UTF-8 flag set (general-purpose bit 11) were
    already decoded correctly and are left untouched; re-encoding them would
    turn every non-ASCII character into ``?``.

    Args:
        self: the open ``zipfile.ZipFile`` the member belongs to.
        member: a ``ZipInfo`` or a member name.
        targetpath: directory to extract into.
        pwd: password for encrypted members, or None.

    Returns:
        The normalized path of the extracted file or directory.
    """
    import zipfile

    if not isinstance(member, zipfile.ZipInfo):
        member = self.getinfo(member)

    # Build the destination pathname, replacing forward slashes with the
    # platform-specific separator.
    arcname = member.filename.replace('/', os.path.sep)
    if (member.flag_bits & 0x800) == 0:
        # No UTF-8 flag: undo zipfile's cp437 decoding and read the bytes as GBK.
        arcname = arcname.encode('cp437', errors='replace').decode('gbk', errors='replace')

    if os.path.altsep:
        arcname = arcname.replace(os.path.altsep, os.path.sep)
    # Interpret an absolute pathname as relative: remove drive letter or UNC
    # path, redundant separators, "." and ".." components.
    arcname = os.path.splitdrive(arcname)[1]
    invalid_path_parts = ('', os.path.curdir, os.path.pardir)
    arcname = os.path.sep.join(
        part for part in arcname.split(os.path.sep) if part not in invalid_path_parts
    )
    if os.path.sep == '\\':
        # Filter illegal characters on Windows.
        arcname = self._sanitize_windows_name(arcname, os.path.sep)

    targetpath = os.path.normpath(os.path.join(targetpath, arcname))

    # Create all upper directories if necessary.
    upperdirs = os.path.dirname(targetpath)
    if upperdirs:
        os.makedirs(upperdirs, exist_ok=True)

    if member.is_dir():
        if not os.path.isdir(targetpath):
            os.mkdir(targetpath)
        return targetpath

    with self.open(member, pwd=pwd) as source, open(targetpath, "wb") as target:
        shutil.copyfileobj(source, target)

    return targetpath
|
89 |
+
|
90 |
+
|
91 |
+
def extract_archive(file_path, dest_dir):
    """Extract the archive at ``file_path`` into ``dest_dir``.

    The format is chosen from the file extension (case-insensitive):
    ``.zip``, ``.tar`` / ``.gz`` / ``.bz2``, ``.rar`` (needs ``rarfile``) and
    ``.7z`` (needs ``py7zr``). Files with any other extension are ignored.

    Returns:
        An empty string on success or when the extension is not handled; a
        user-facing error message when rar / 7z extraction fails.

    Raises:
        Exception: if a tar member would be written outside ``dest_dir``.
        tarfile.ReadError: if a ``.tar`` / ``.bz2`` file cannot be read.
    """
    import zipfile
    import tarfile
    import os
    import shutil

    # Compare extensions case-insensitively so "PAPER.ZIP" is handled too.
    file_extension = os.path.splitext(file_path)[1].lower()

    if file_extension == ".zip":
        with zipfile.ZipFile(file_path, "r") as zipobj:
            # Patch in the extractor that repairs garbled Chinese file names.
            zipobj._extract_member = lambda a, b, c: zip_extract_member_new(zipobj, a, b, c)
            zipobj.extractall(path=dest_dir)
            logger.info("Successfully extracted zip archive to {}".format(dest_dir))

    elif file_extension in [".tar", ".gz", ".bz2"]:
        try:
            with tarfile.open(file_path, "r:*") as tarobj:
                # Validate every member path before extracting anything.
                dest_root = os.path.abspath(dest_dir)
                for member in tarobj.getmembers():
                    member_path = os.path.normpath(member.name)
                    full_path = os.path.abspath(os.path.join(dest_dir, member_path))
                    # A "." entry resolves to dest_dir itself, which is harmless;
                    # anything else must live strictly below dest_dir.
                    if full_path != dest_root and not full_path.startswith(dest_root + os.sep):
                        raise Exception(f"Attempted Path Traversal in {member.name}")

                tarobj.extractall(path=dest_dir)
                logger.info("Successfully extracted tar archive to {}".format(dest_dir))
        except tarfile.ReadError:
            if file_extension != ".gz":
                raise
            # Some unusual projects ship a plain .gz that is not a tarball and
            # only wraps a single tex file; stream it out as main.tex.
            import gzip
            with gzip.open(file_path, 'rb') as f_in, \
                    open(os.path.join(dest_dir, 'main.tex'), 'wb') as f_out:
                shutil.copyfileobj(f_in, f_out)

    # Third-party library: requires `pip install rarfile` beforehand.
    # On Windows, WinRAR must also be installed and on PATH,
    # e.g. "C:\Program Files\WinRAR".
    elif file_extension == ".rar":
        try:
            import rarfile

            with rarfile.RarFile(file_path) as rf:
                rf.extractall(path=dest_dir)
                logger.info("Successfully extracted rar archive to {}".format(dest_dir))
        except Exception as e:
            # Record the real cause; the message returned to the user stays the same.
            logger.error(f"Failed to extract rar archive: {e!r}")
            logger.info("Rar format requires additional dependencies to install")
            return "\n\n解压失败! 需要安装pip install rarfile来解压rar文件。建议:使用zip压缩格式。"

    # Third-party library: requires `pip install py7zr` beforehand.
    elif file_extension == ".7z":
        try:
            import py7zr

            with py7zr.SevenZipFile(file_path, mode="r") as f:
                f.extractall(path=dest_dir)
                logger.info("Successfully extracted 7z archive to {}".format(dest_dir))
        except Exception as e:
            # Record the real cause; the message returned to the user stays the same.
            logger.error(f"Failed to extract 7z archive: {e!r}")
            logger.info("7z format requires additional dependencies to install")
            return "\n\n解压失败! 需要安装pip install py7zr来解压7z文件"
    else:
        return ""
    return ""
|
156 |
+
|
docker_as_a_service/shared_utils/key_pattern_manager.py
ADDED
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
import os
|
3 |
+
from functools import wraps, lru_cache
|
4 |
+
from shared_utils.advanced_markdown_format import format_io
|
5 |
+
from shared_utils.config_loader import get_conf as get_conf
|
6 |
+
|
7 |
+
|
8 |
+
pj = os.path.join
|
9 |
+
default_user_name = 'default_user'
|
10 |
+
|
11 |
+
# match openai keys
|
12 |
+
# Accepted OpenAI-style key shapes; each alternative is anchored at the end.
openai_regex = re.compile(
    "|".join((
        r"sk-[a-zA-Z0-9_-]{48}$",
        r"sk-[a-zA-Z0-9_-]{92}$",
        r"sk-proj-[a-zA-Z0-9_-]{48}$",
        r"sk-proj-[a-zA-Z0-9_-]{124}$",
        r"sk-proj-[a-zA-Z0-9_-]{156}$",  # added because newer API keys have a different length
        r"sess-[a-zA-Z0-9]{40}$",
    ))
)
|
20 |
+
def is_openai_api_key(key):
    """Return True if ``key`` looks like an OpenAI API key.

    When the ``CUSTOM_API_KEY_PATTERN`` setting is non-empty it is used as
    the regular expression instead of the built-in ``openai_regex``.
    """
    custom_pattern = get_conf('CUSTOM_API_KEY_PATTERN')
    if len(custom_pattern) == 0:
        matched = openai_regex.match(key)
    else:
        matched = re.match(custom_pattern, key)
    return bool(matched)
|
27 |
+
|
28 |
+
|
29 |
+
def is_azure_api_key(key):
    """Return True if ``key`` is 32 alphanumeric characters (Azure key shape)."""
    return re.match(r"[a-zA-Z0-9]{32}$", key) is not None
|
32 |
+
|
33 |
+
|
34 |
+
def is_api2d_key(key):
    """Return True if ``key`` has the API2D shape ``fk`` + 6 chars + ``-`` + 32 chars."""
    return re.match(r"fk[a-zA-Z0-9]{6}-[a-zA-Z0-9]{32}$", key) is not None
|
37 |
+
|
38 |
+
def is_openroute_api_key(key):
    """Return True if ``key`` has the OpenRouter shape ``sk-or-v1-`` + 64 alphanumerics."""
    return re.match(r"sk-or-v1-[a-zA-Z0-9]{64}$", key) is not None
|
41 |
+
|
42 |
+
def is_cohere_api_key(key):
    """Return True if ``key`` is 40 alphanumeric characters (Cohere key shape)."""
    return re.match(r"[a-zA-Z0-9]{40}$", key) is not None
|
45 |
+
|
46 |
+
|
47 |
+
def is_any_api_key(key):
    """Return True if ``key`` contains at least one recognizable API key.

    ``key`` may be a single key or several keys separated by commas; in the
    latter case the result is True as soon as any part is recognized.
    Recognized providers are OpenAI, API2D, Azure, Cohere and OpenRouter.
    OpenRouter is included so this check agrees with ``select_api_key``,
    which accepts such keys for ``openrouter-`` models.
    """
    if ',' in key:
        return any(is_any_api_key(part) for part in key.split(','))
    return (
        is_openai_api_key(key)
        or is_api2d_key(key)
        or is_azure_api_key(key)
        or is_cohere_api_key(key)
        or is_openroute_api_key(key)
    )
|
55 |
+
|
56 |
+
|
57 |
+
def what_keys(keys):
    """Summarize how many OpenAI, Azure and API2D keys ``keys`` contains.

    ``keys`` is a comma-separated string; the returned text is a
    human-readable (Chinese) report of the three counts.
    """
    candidates = keys.split(',')

    openai_count = sum(1 for k in candidates if is_openai_api_key(k))
    api2d_count = sum(1 for k in candidates if is_api2d_key(k))
    azure_count = sum(1 for k in candidates if is_azure_api_key(k))

    return f"检测到: OpenAI Key {openai_count} 个, Azure Key {azure_count} 个, API2D Key {api2d_count} 个"
|
74 |
+
|
75 |
+
|
76 |
+
def select_api_key(keys, llm_model):
    """Pick one API key from ``keys`` that is usable with ``llm_model``.

    ``keys`` is a comma-separated string. The model-name prefix decides which
    key format is acceptable; one matching key is then chosen at random.

    Raises:
        RuntimeError: if none of the supplied keys fits the model.
    """
    import random

    candidates = keys.split(',')
    # (model-name prefixes, predicate that recognizes a suitable key)
    routing = (
        (('gpt-', 'one-api-', 'o1-'), is_openai_api_key),
        (('api2d-',), is_api2d_key),
        (('azure-',), is_azure_api_key),
        (('cohere-',), is_cohere_api_key),
        (('openrouter-',), is_openroute_api_key),
    )

    usable = []
    for prefixes, accepts in routing:
        if llm_model.startswith(prefixes):
            usable.extend(k for k in candidates if accepts(k))

    if not usable:
        raise RuntimeError(f"您提供的api-key不满足要求,不包含任何可用于{llm_model}的api-key。您可能选择了错误的模型或请求源(左上角更换模型菜单中可切换openai,azure,claude,cohere等请求源)。")

    return random.choice(usable)  # random load balancing
|
106 |
+
|
107 |
+
|
108 |
+
def select_api_key_for_embed_models(keys, llm_model):
    """Pick one API key from ``keys`` usable with an embedding model.

    Only ``text-embedding-`` models are served, and only OpenAI-format keys
    qualify. One matching key is chosen at random.

    Raises:
        RuntimeError: if no supplied key fits the model.
    """
    import random

    usable = []
    if llm_model.startswith('text-embedding-'):
        usable = [k for k in keys.split(',') if is_openai_api_key(k)]

    if not usable:
        raise RuntimeError(f"您提供的api-key不满足要求,不包含任何可用于{llm_model}的api-key。您可能选择了错误的模型或请求源。")

    return random.choice(usable)  # random load balancing
|
docker_as_a_service/shared_utils/logging.py
ADDED
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from loguru import logger
|
2 |
+
import logging
|
3 |
+
import sys
|
4 |
+
import os
|
5 |
+
|
6 |
+
def chat_log_filter(record):
    """Accept only records whose ``extra`` dict carries a ``chat_msg`` key."""
    extra = record["extra"]
    return "chat_msg" in extra
|
8 |
+
|
9 |
+
def not_chat_log_filter(record):
    """Accept only records whose ``extra`` dict has no ``chat_msg`` key."""
    extra = record["extra"]
    return "chat_msg" not in extra
|
11 |
+
|
12 |
+
def formatter_with_clip(record):
    """Return the console format template and add fixed-width fields to ``record``.

    Note that the return value is the template string to be formatted, not
    the final log message. Two keys are written into ``record``:
    ``function_x`` (function name centered to 12 characters, or ``..`` plus
    its last 10 characters when longer) and ``line_x`` (line number
    left-justified to 3 characters).
    """
    width = 12
    padded = record['function'].center(width)
    if len(padded) > width:
        # Too long to fit: keep the tail and mark the cut with "..".
        padded = ".." + padded[-(width - 2):]
    record['function_x'] = padded
    record['line_x'] = str(record['line']).ljust(3)
    return '<green>{time:HH:mm}</green> | <cyan>{function_x}</cyan>:<cyan>{line_x}</cyan> | <level>{message}</level>\n'
|
21 |
+
|
22 |
+
def setup_logging(PATH_LOGGING):
    """Configure the loguru sinks used by the application.

    Creates ``<PATH_LOGGING>/admin`` if needed and installs three sinks:
    stderr (non-chat records, compact colored format), ``chat_secrets.log``
    (chat records only) and ``console_log.log`` (everything else). Both
    files rotate at 10 MB. The stdlib ``httpx`` logger is raised to WARNING,
    and a warning telling the user where conversations are stored is logged.
    """
    admin_log_path = os.path.join(PATH_LOGGING, "admin")
    os.makedirs(admin_log_path, exist_ok=True)
    sensitive_log_path = os.path.join(admin_log_path, "chat_secrets.log")
    regular_log_path = os.path.join(admin_log_path, "console_log.log")

    # Both log files share the same verbose line format.
    file_format = '<green>{time:MM-DD HH:mm:ss}</green> | <level>{level: <8}</level> | <cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>'

    logger.remove()
    logger.configure(levels=[dict(name="WARNING", color="<g>")])

    # Console sink: everything except chat messages.
    logger.add(
        sys.stderr,
        format=formatter_with_clip,
        filter=(lambda record: not chat_log_filter(record)),
        colorize=True,
        enqueue=True,
    )

    # File sinks: chat messages first, then the regular log.
    file_sinks = (
        (sensitive_log_path, chat_log_filter),
        (regular_log_path, not_chat_log_filter),
    )
    for sink_path, sink_filter in file_sinks:
        logger.add(
            sink_path,
            format=file_format,
            rotation="10 MB",
            filter=sink_filter,
            enqueue=True,
        )

    logging.getLogger("httpx").setLevel(logging.WARNING)

    logger.warning(f"所有对话记录将自动保存在本地目录{sensitive_log_path}, 请注意自我隐私保护哦!")
|
61 |
+
|
62 |
+
|
63 |
+
# logger.bind(chat_msg=True).info("This message is logged to the file!")
|
64 |
+
# logger.debug(f"debug message")
|
65 |
+
# logger.info(f"info message")
|
66 |
+
# logger.success(f"success message")
|
67 |
+
# logger.error(f"error message")
|
68 |
+
# logger.add("special.log", filter=lambda record: "special" in record["extra"])
|
69 |
+
# logger.debug("This message is not logged to the file")
|
docker_as_a_service/shared_utils/map_names.py
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
# Real model name -> display name. All entries are currently disabled.
mapping_dic = {
    # "qianfan": "qianfan(文心一言大模型)",
    # "zhipuai": "zhipuai(智谱GLM4超级模型🔥)",
    # "gpt-4-1106-preview": "gpt-4-1106-preview(新调优版本GPT-4🔥)",
    # "gpt-4-vision-preview": "gpt-4-vision-preview(识图模型GPT-4V)",
}

# Display name -> real model name (inverse of mapping_dic).
rev_mapping_dic = {friendly: real for real, friendly in mapping_dic.items()}
|
12 |
+
|
13 |
+
def map_model_to_friendly_names(m):
    """Return the display name registered for model ``m``, or ``m`` itself if none."""
    return mapping_dic.get(m, m)
|
17 |
+
|
18 |
+
def map_friendly_names_to_model(m):
    """Return the real model name for display name ``m``, or ``m`` itself if unknown."""
    return rev_mapping_dic.get(m, m)
|
22 |
+
|
23 |
+
def read_one_api_model_name(model: str):
    """Split a model spec into its real name and token limit.

    A ``(max_token=N)`` marker embedded in ``model`` supplies the limit and
    is removed from the returned name. Without a marker the name is returned
    unchanged together with the default limit of 4096.

    Returns:
        A ``(model_name, max_token)`` tuple.
    """
    marker = r"\(max_token=(\d+)\)"
    found = re.search(marker, model)
    if not found:
        return model, 4096
    limit = int(found.group(1))  # value captured from the marker
    stripped = re.sub(marker, "", model)  # drop "(max_token=...)" from the name
    return stripped, limit
|
docker_as_a_service/shared_utils/text_mask.py
ADDED
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
from functools import lru_cache
|
3 |
+
|
4 |
+
# The patterns below are compiled once with the re module and kept in
# const_extract_re / const_extract_langbased_re for fast reuse.
# Regex elements used here:
# - .   matches any single character.
# - *   lets the previous element repeat zero or more times.
# - ?   after * makes the match non-greedy, so (.*?) captures as little text
#       as possible and stops at the nearest closing tag
#       (</show_llm> or </show_render>).
# - ()  is a capture group; (.*?) captures the text up to the closing tag
#       that follows it.

# -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-==-=-=-=/1=-=-=-=-=-=-=-=-=-=-=-=-=-=/2-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
const_extract_re = re.compile(
    r"<gpt_academic_string_mask><show_llm>(.*?)</show_llm><show_render>(.*?)</show_render></gpt_academic_string_mask>"
)
# -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-==-=-=-=-=-=/1=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-/2-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
const_extract_langbased_re = re.compile(
    r"<gpt_academic_string_mask><lang_english>(.*?)</lang_english><lang_chinese>(.*?)</lang_chinese></gpt_academic_string_mask>",
    flags=re.DOTALL,
)
|
22 |
+
|
23 |
+
@lru_cache(maxsize=128)
def apply_gpt_academic_string_mask(string, mode="show_all"):
    """
    Resolve mask tags (<gpt_academic_string_mask><show_...>) for one audience.

    Depending on who will read the string (the LLM or the web renderer),
    each masked section is replaced by its ``show_llm`` or ``show_render``
    text. ``mode="show_all"`` and strings without a mask tag are returned
    unchanged. Any other mode raises ValueError.
    Diagram: https://mermaid.live/edit#pako:eNqlkUtLw0AUhf9KuOta0iaTplkIPlpduFJwoZEwJGNbzItpita2O6tF8QGKogXFtwu7cSHiq3-mk_oznFR8IYLgrGbuOd9hDrcCpmcR0GDW9ubNPKaBMDauuwI_A9M6YN-3y0bODwxsYos4BdMoBrTg5gwHF-d0mBH6-vqFQe58ed5m9XPW2uteX3Tubrj0ljLYcwxxR3h1zB43WeMs3G19yEM9uapDMe_NG9i2dagKw1Fee4c1D9nGEbtc-5n6HbNtJ8IyHOs8tbs7V2HrlDX2w2Y7XD_5haHEtQiNsOwfMVa_7TzsvrWIuJGo02qTrdwLk9gukQylHv3Afv1ML270s-HZUndrmW1tdA-WfvbM_jMFYuAQ6uCCxVdciTJ1CPLEITpo_GphypeouzXuw6XAmyi7JmgBLZEYlHwLB2S4gHMUO-9DH7tTnvf1CVoFFkBLSOk4QmlRTqpIlaWUHINyNFXjaQWpCYRURUKiWovBYo8X4ymEJFlECQUpqaQkJmuvWygPpg
    """
    # Nothing to do for empty input or input without mask tags.
    if not string or "<gpt_academic_string_mask>" not in string:
        return string
    if mode == "show_all":
        return string

    # Capture group that survives for each audience.
    group_for_mode = {"show_llm": r"\1", "show_render": r"\2"}
    if mode not in group_for_mode:
        raise ValueError("Invalid mode")
    return const_extract_re.sub(group_for_mode[mode], string)
|
43 |
+
|
44 |
+
|
45 |
+
@lru_cache(maxsize=128)
def build_gpt_academic_masked_string(text_show_llm="", text_show_render=""):
    """
    Build a masked string holding one text for the LLM and one for the web renderer.
    """
    llm_part = f"<show_llm>{text_show_llm}</show_llm>"
    render_part = f"<show_render>{text_show_render}</show_render>"
    return "<gpt_academic_string_mask>" + llm_part + render_part + "</gpt_academic_string_mask>"
|
51 |
+
|
52 |
+
|
53 |
+
@lru_cache(maxsize=128)
def apply_gpt_academic_string_mask_langbased(string, lang_reference):
    """
    Resolve language mask tags (<gpt_academic_string_mask><lang_...>) in ``string``.

    The language is decided from ``lang_reference``: if it contains any
    Chinese character the ``lang_chinese`` text is kept, otherwise the
    ``lang_english`` text is kept. Strings without a mask tag are returned
    unchanged.

    Example 1
        string = "Note, lang_reference is written in: <gpt_academic_string_mask><lang_english>English</lang_english><lang_chinese>Chinese</lang_chinese></gpt_academic_string_mask>"
        lang_reference = "hello world"
    Result 1
        "Note, lang_reference is written in: English"

    Example 2
        string = "Note, lang_reference is written in Chinese"  # no mask tag, so nothing is processed
        lang_reference = "hello world"
    Result 2
        "Note, lang_reference is written in Chinese"  # returned as-is
    """
    if "<gpt_academic_string_mask>" not in string:  # nothing to process
        return string

    has_chinese = re.search(u'[\u4e00-\u9fff]+', lang_reference) is not None
    # Group 1 holds the English text, group 2 the Chinese text.
    kept_group = r"\2" if has_chinese else r"\1"
    return const_extract_langbased_re.sub(kept_group, string)
|
87 |
+
|
88 |
+
|
89 |
+
@lru_cache(maxsize=128)
def build_gpt_academic_masked_string_langbased(text_show_english="", text_show_chinese=""):
    """
    Build a masked string holding an English variant and a Chinese variant of a text.
    """
    english_part = f"<lang_english>{text_show_english}</lang_english>"
    chinese_part = f"<lang_chinese>{text_show_chinese}</lang_chinese>"
    return "<gpt_academic_string_mask>" + english_part + chinese_part + "</gpt_academic_string_mask>"
|
95 |
+
|
96 |
+
|
97 |
+
if __name__ == "__main__":
    # Smoke test: one masked section placed between two plain lines.
    masked = build_gpt_academic_masked_string(text_show_llm="mermaid", text_show_render="")
    input_string = "你好\n" + masked + "你好\n"

    # LLM view: the masked section is replaced by its show_llm text ("mermaid").
    print(apply_gpt_academic_string_mask(input_string, "show_llm"))

    # Render view: the masked section is replaced by its (empty) show_render text.
    print(apply_gpt_academic_string_mask(input_string, "show_render"))
|