qingxu98 commited on
Commit
312898d
·
1 Parent(s): 20ed30c
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ *.pyc
Dockerfile ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM fuqingxu/bbdown
2
+
3
+ RUN apt update && apt-get install -y python3 python3-dev python3-pip
4
+
5
+ RUN python3 -m pip install fastapi pydantic loguru --break-system-packages
6
+ RUN python3 -m pip install requests python-multipart --break-system-packages
7
+ RUN python3 -m pip install uvicorn --break-system-packages
8
+
9
+ COPY ./docker_as_a_service /docker_as_a_service
10
+
11
+ WORKDIR /docker_as_a_service
12
+
13
+ # ENTRYPOINT [ "python3", "docker_as_a_service.py" ]
14
+ ENTRYPOINT ["/bin/bash", "-c"]
15
+ CMD ["python3 docker_as_a_service.py"]
16
+ # CMD ["python3", "docker_as_a_service.py"]
README.md CHANGED
@@ -1,10 +1,15 @@
1
  ---
2
  title: Bbdown
3
- emoji: 🌍
4
  colorFrom: blue
5
  colorTo: indigo
6
  sdk: docker
7
  pinned: false
 
8
  ---
9
 
10
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
1
  ---
2
  title: Bbdown
3
+ emoji: 🐳
4
  colorFrom: blue
5
  colorTo: indigo
6
  sdk: docker
7
  pinned: false
8
+ app_port: 7860
9
  ---
10
 
11
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
12
+
13
+
14
+ 1. create space
15
+ 2.
docker_as_a_service/docker_as_a_service.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ DaaS (Docker as a Service) is a service
3
+ that allows users to run docker commands on the server side.
4
+ """
5
+
6
+ from fastapi import FastAPI
7
+ from fastapi.responses import StreamingResponse
8
+ from fastapi import FastAPI, File, UploadFile, HTTPException
9
+ from pydantic import BaseModel, Field
10
+ from typing import Optional, Dict
11
+ import time
12
+ import os
13
+ import asyncio
14
+ import subprocess
15
+ import uuid
16
+ import threading
17
+ import queue
18
+ from shared_utils.docker_as_service_api import DockerServiceApiComModel
19
+
20
+ app = FastAPI()
21
+
22
+ def python_obj_to_pickle_file_bytes(obj):
23
+ import pickle
24
+ import io
25
+ with io.BytesIO() as f:
26
+ pickle.dump(obj, f)
27
+ return f.getvalue()
28
+
29
+ def yield_message(message):
30
+ dsacm = DockerServiceApiComModel(server_message=message)
31
+ return python_obj_to_pickle_file_bytes(dsacm)
32
+
33
+ def read_output(stream, output_queue):
34
+ while True:
35
+ line_stdout = stream.readline()
36
+ # print('recv')
37
+ if line_stdout:
38
+ output_queue.put(line_stdout)
39
+ else:
40
+ break
41
+
42
+
43
+ async def stream_generator(request_obj):
44
+ import tempfile
45
+ # Create a temporary directory
46
+ with tempfile.TemporaryDirectory() as temp_dir:
47
+
48
+ # Construct the docker command
49
+ download_folder = temp_dir
50
+
51
+ # Get list of existing files before download
52
+ existing_file_before_download = []
53
+
54
+ video_id = request_obj.client_command
55
+ cmd = [
56
+ '/root/.dotnet/tools/BBDown',
57
+ video_id,
58
+ '--use-app-api',
59
+ '--work-dir',
60
+ f'{os.path.abspath(temp_dir)}'
61
+ ]
62
+ cmd = ' '.join(cmd)
63
+ yield yield_message(cmd)
64
+ process = subprocess.Popen(cmd,
65
+ stdout=subprocess.PIPE,
66
+ stderr=subprocess.PIPE,
67
+ shell=True,
68
+ text=True)
69
+
70
+ stdout_queue = queue.Queue()
71
+ thread = threading.Thread(target=read_output, args=(process.stdout, stdout_queue))
72
+ thread.daemon = True
73
+ thread.start()
74
+ stderr_queue = queue.Queue()
75
+ thread = threading.Thread(target=read_output, args=(process.stderr, stderr_queue))
76
+ thread.daemon = True
77
+ thread.start()
78
+
79
+ while True:
80
+ print("looping")
81
+ # Check if there is any output in the queue
82
+ try:
83
+ output_stdout = stdout_queue.get_nowait() # Non-blocking get
84
+ if output_stdout:
85
+ print(output_stdout)
86
+ yield yield_message(output_stdout)
87
+
88
+ output_stderr = stderr_queue.get_nowait() # Non-blocking get
89
+ if output_stderr:
90
+ print(output_stdout)
91
+ yield yield_message(output_stderr)
92
+ except queue.Empty:
93
+ pass # No output available
94
+
95
+ # Break the loop if the process has finished
96
+ if process.poll() is not None:
97
+ break
98
+
99
+ await asyncio.sleep(0.25)
100
+
101
+ # Get the return code
102
+ return_code = process.returncode
103
+ yield yield_message("(return code:) " + str(return_code))
104
+
105
+ # print(f"Successfully downloaded video {video_id}")
106
+ existing_file_after_download = list(os.listdir(download_folder))
107
+ # get the difference
108
+ downloaded_files = [
109
+ f for f in existing_file_after_download if f not in existing_file_before_download
110
+ ]
111
+ downloaded_files_path = [
112
+ os.path.join(download_folder, f) for f in existing_file_after_download if f not in existing_file_before_download
113
+ ]
114
+ # read file
115
+ server_file_attach = {}
116
+ for fp, fn in zip(downloaded_files_path, downloaded_files):
117
+ with open(fp, "rb") as f:
118
+ file_bytes = f.read()
119
+ server_file_attach[fn] = file_bytes
120
+
121
+ dsacm = DockerServiceApiComModel(
122
+ server_message="complete",
123
+ server_file_attach=server_file_attach,
124
+ )
125
+ yield python_obj_to_pickle_file_bytes(dsacm)
126
+
127
+
128
+ @app.post("/stream")
129
+ async def stream_response(file: UploadFile = File(...)):
130
+ # read the file in memory, treat it as pickle file, and unpickle it
131
+ import pickle
132
+ import io
133
+ content = await file.read()
134
+ with io.BytesIO(content) as f:
135
+ request_obj = pickle.load(f)
136
+ # process the request_obj
137
+ return StreamingResponse(stream_generator(request_obj), media_type="application/octet-stream")
138
+
139
+ @app.get("/")
140
+ async def hi():
141
+ return "Hello, this is Docker as a Service (DaaS)!"
142
+
143
+ if __name__ == "__main__":
144
+ import uvicorn
145
+ uvicorn.run(app, host="0.0.0.0", port=49000)
docker_as_a_service/shared_utils/advanced_markdown_format.py ADDED
@@ -0,0 +1,478 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import markdown
2
+ import re
3
+ import os
4
+ import math
5
+
6
+ from loguru import logger
7
+ from textwrap import dedent
8
+ from functools import lru_cache
9
+ from pymdownx.superfences import fence_code_format
10
+ from latex2mathml.converter import convert as tex2mathml
11
+ from shared_utils.config_loader import get_conf as get_conf
12
+ from shared_utils.text_mask import apply_gpt_academic_string_mask
13
+
14
+ markdown_extension_configs = {
15
+ "mdx_math": {
16
+ "enable_dollar_delimiter": True,
17
+ "use_gitlab_delimiters": False,
18
+ },
19
+ }
20
+
21
+ code_highlight_configs = {
22
+ "pymdownx.superfences": {
23
+ "css_class": "codehilite",
24
+ "custom_fences": [
25
+ {"name": "mermaid", "class": "mermaid", "format": fence_code_format}
26
+ ],
27
+ },
28
+ "pymdownx.highlight": {
29
+ "css_class": "codehilite",
30
+ "guess_lang": True,
31
+ # 'auto_title': True,
32
+ # 'linenums': True
33
+ },
34
+ }
35
+
36
+ code_highlight_configs_block_mermaid = {
37
+ "pymdownx.superfences": {
38
+ "css_class": "codehilite",
39
+ # "custom_fences": [
40
+ # {"name": "mermaid", "class": "mermaid", "format": fence_code_format}
41
+ # ],
42
+ },
43
+ "pymdownx.highlight": {
44
+ "css_class": "codehilite",
45
+ "guess_lang": True,
46
+ # 'auto_title': True,
47
+ # 'linenums': True
48
+ },
49
+ }
50
+
51
+
52
+ mathpatterns = {
53
+ r"(?<!\\|\$)(\$)([^\$]+)(\$)": {"allow_multi_lines": False}, #  $...$
54
+ r"(?<!\\)(\$\$)([^\$]+)(\$\$)": {"allow_multi_lines": True}, # $$...$$
55
+ r"(?<!\\)(\\\[)(.+?)(\\\])": {"allow_multi_lines": False}, # \[...\]
56
+ r'(?<!\\)(\\\()(.+?)(\\\))': {'allow_multi_lines': False}, # \(...\)
57
+ # r'(?<!\\)(\\begin{([a-z]+?\*?)})(.+?)(\\end{\2})': {'allow_multi_lines': True}, # \begin...\end
58
+ # r'(?<!\\)(\$`)([^`]+)(`\$)': {'allow_multi_lines': False}, # $`...`$
59
+ }
60
+
61
+ def tex2mathml_catch_exception(content, *args, **kwargs):
62
+ try:
63
+ content = tex2mathml(content, *args, **kwargs)
64
+ except:
65
+ content = content
66
+ return content
67
+
68
+
69
+ def replace_math_no_render(match):
70
+ content = match.group(1)
71
+ if "mode=display" in match.group(0):
72
+ content = content.replace("\n", "</br>")
73
+ return f'<font color="#00FF00">$$</font><font color="#FF00FF">{content}</font><font color="#00FF00">$$</font>'
74
+ else:
75
+ return f'<font color="#00FF00">$</font><font color="#FF00FF">{content}</font><font color="#00FF00">$</font>'
76
+
77
+
78
+ def replace_math_render(match):
79
+ content = match.group(1)
80
+ if "mode=display" in match.group(0):
81
+ if "\\begin{aligned}" in content:
82
+ content = content.replace("\\begin{aligned}", "\\begin{array}")
83
+ content = content.replace("\\end{aligned}", "\\end{array}")
84
+ content = content.replace("&", " ")
85
+ content = tex2mathml_catch_exception(content, display="block")
86
+ return content
87
+ else:
88
+ return tex2mathml_catch_exception(content)
89
+
90
+
91
+ def markdown_bug_hunt(content):
92
+ """
93
+ 解决一个mdx_math的bug(单$包裹begin命令时多余<script>)
94
+ """
95
+ content = content.replace(
96
+ '<script type="math/tex">\n<script type="math/tex; mode=display">',
97
+ '<script type="math/tex; mode=display">',
98
+ )
99
+ content = content.replace("</script>\n</script>", "</script>")
100
+ return content
101
+
102
+
103
+ def is_equation(txt):
104
+ """
105
+ 判定是否为公式 | 测试1 写出洛伦兹定律,使用tex格式公式 测试2 给出柯西不等式,使用latex格式 测试3 写出麦克斯韦方程组
106
+ """
107
+ if "```" in txt and "```reference" not in txt:
108
+ return False
109
+ if "$" not in txt and "\\[" not in txt:
110
+ return False
111
+
112
+ matches = []
113
+ for pattern, property in mathpatterns.items():
114
+ flags = re.ASCII | re.DOTALL if property["allow_multi_lines"] else re.ASCII
115
+ matches.extend(re.findall(pattern, txt, flags))
116
+ if len(matches) == 0:
117
+ return False
118
+ contain_any_eq = False
119
+ illegal_pattern = re.compile(r"[^\x00-\x7F]|echo")
120
+ for match in matches:
121
+ if len(match) != 3:
122
+ return False
123
+ eq_canidate = match[1]
124
+ if illegal_pattern.search(eq_canidate):
125
+ return False
126
+ else:
127
+ contain_any_eq = True
128
+ return contain_any_eq
129
+
130
+
131
+ def fix_markdown_indent(txt):
132
+ # fix markdown indent
133
+ if (" - " not in txt) or (". " not in txt):
134
+ # do not need to fix, fast escape
135
+ return txt
136
+ # walk through the lines and fix non-standard indentation
137
+ lines = txt.split("\n")
138
+ pattern = re.compile(r"^\s+-")
139
+ activated = False
140
+ for i, line in enumerate(lines):
141
+ if line.startswith("- ") or line.startswith("1. "):
142
+ activated = True
143
+ if activated and pattern.match(line):
144
+ stripped_string = line.lstrip()
145
+ num_spaces = len(line) - len(stripped_string)
146
+ if (num_spaces % 4) == 3:
147
+ num_spaces_should_be = math.ceil(num_spaces / 4) * 4
148
+ lines[i] = " " * num_spaces_should_be + stripped_string
149
+ return "\n".join(lines)
150
+
151
+
152
+ FENCED_BLOCK_RE = re.compile(
153
+ dedent(
154
+ r"""
155
+ (?P<fence>^[ \t]*(?:~{3,}|`{3,}))[ ]* # opening fence
156
+ ((\{(?P<attrs>[^\}\n]*)\})| # (optional {attrs} or
157
+ (\.?(?P<lang>[\w#.+-]*)[ ]*)? # optional (.)lang
158
+ (hl_lines=(?P<quot>"|')(?P<hl_lines>.*?)(?P=quot)[ ]*)?) # optional hl_lines)
159
+ \n # newline (end of opening fence)
160
+ (?P<code>.*?)(?<=\n) # the code block
161
+ (?P=fence)[ ]*$ # closing fence
162
+ """
163
+ ),
164
+ re.MULTILINE | re.DOTALL | re.VERBOSE,
165
+ )
166
+
167
+
168
+ def get_line_range(re_match_obj, txt):
169
+ start_pos, end_pos = re_match_obj.regs[0]
170
+ num_newlines_before = txt[: start_pos + 1].count("\n")
171
+ line_start = num_newlines_before
172
+ line_end = num_newlines_before + txt[start_pos:end_pos].count("\n") + 1
173
+ return line_start, line_end
174
+
175
+
176
+ def fix_code_segment_indent(txt):
177
+ lines = []
178
+ change_any = False
179
+ txt_tmp = txt
180
+ while True:
181
+ re_match_obj = FENCED_BLOCK_RE.search(txt_tmp)
182
+ if not re_match_obj:
183
+ break
184
+ if len(lines) == 0:
185
+ lines = txt.split("\n")
186
+
187
+ # 清空 txt_tmp 对应的位置方便下次搜索
188
+ start_pos, end_pos = re_match_obj.regs[0]
189
+ txt_tmp = txt_tmp[:start_pos] + " " * (end_pos - start_pos) + txt_tmp[end_pos:]
190
+ line_start, line_end = get_line_range(re_match_obj, txt)
191
+
192
+ # 获取公共缩进
193
+ shared_indent_cnt = 1e5
194
+ for i in range(line_start, line_end):
195
+ stripped_string = lines[i].lstrip()
196
+ num_spaces = len(lines[i]) - len(stripped_string)
197
+ if num_spaces < shared_indent_cnt:
198
+ shared_indent_cnt = num_spaces
199
+
200
+ # 修复缩进
201
+ if (shared_indent_cnt < 1e5) and (shared_indent_cnt % 4) == 3:
202
+ num_spaces_should_be = math.ceil(shared_indent_cnt / 4) * 4
203
+ for i in range(line_start, line_end):
204
+ add_n = num_spaces_should_be - shared_indent_cnt
205
+ lines[i] = " " * add_n + lines[i]
206
+ if not change_any: # 遇到第一个
207
+ change_any = True
208
+
209
+ if change_any:
210
+ return "\n".join(lines)
211
+ else:
212
+ return txt
213
+
214
+
215
+ def fix_dollar_sticking_bug(txt):
216
+ """
217
+ 修复不标准的dollar公式符号的问题
218
+ """
219
+ txt_result = ""
220
+ single_stack_height = 0
221
+ double_stack_height = 0
222
+ while True:
223
+ while True:
224
+ index = txt.find('$')
225
+
226
+ if index == -1:
227
+ txt_result += txt
228
+ return txt_result
229
+
230
+ if single_stack_height > 0:
231
+ if txt[:(index+1)].find('\n') > 0 or txt[:(index+1)].find('<td>') > 0 or txt[:(index+1)].find('</td>') > 0:
232
+ logger.error('公式之中出现了异常 (Unexpect element in equation)')
233
+ single_stack_height = 0
234
+ txt_result += ' $'
235
+ continue
236
+
237
+ if double_stack_height > 0:
238
+ if txt[:(index+1)].find('\n\n') > 0:
239
+ logger.error('公式之中出现了异常 (Unexpect element in equation)')
240
+ double_stack_height = 0
241
+ txt_result += '$$'
242
+ continue
243
+
244
+ is_double = (txt[index+1] == '$')
245
+ if is_double:
246
+ if single_stack_height != 0:
247
+ # add a padding
248
+ txt = txt[:(index+1)] + " " + txt[(index+1):]
249
+ continue
250
+ if double_stack_height == 0:
251
+ double_stack_height = 1
252
+ else:
253
+ double_stack_height = 0
254
+ txt_result += txt[:(index+2)]
255
+ txt = txt[(index+2):]
256
+ else:
257
+ if double_stack_height != 0:
258
+ # logger.info(txt[:(index)])
259
+ logger.info('发现异常嵌套公式')
260
+ if single_stack_height == 0:
261
+ single_stack_height = 1
262
+ else:
263
+ single_stack_height = 0
264
+ # logger.info(txt[:(index)])
265
+ txt_result += txt[:(index+1)]
266
+ txt = txt[(index+1):]
267
+ break
268
+
269
+
270
+ def markdown_convertion_for_file(txt):
271
+ """
272
+ 将Markdown格式的文本转换为HTML格式。如果包含数学公式,则先将公式转换为HTML格式。
273
+ """
274
+ from themes.theme import advanced_css
275
+ pre = f"""
276
+ <!DOCTYPE html><head><meta charset="utf-8"><title>GPT-Academic输出文档</title><style>{advanced_css}</style></head>
277
+ <body>
278
+ <div class="test_temp1" style="width:10%; height: 500px; float:left;"></div>
279
+ <div class="test_temp2" style="width:80%;padding: 40px;float:left;padding-left: 20px;padding-right: 20px;box-shadow: rgba(0, 0, 0, 0.2) 0px 0px 8px 8px;border-radius: 10px;">
280
+ <div class="markdown-body">
281
+ """
282
+ suf = """
283
+ </div>
284
+ </div>
285
+ <div class="test_temp3" style="width:10%; height: 500px; float:left;"></div>
286
+ </body>
287
+ """
288
+
289
+ if txt.startswith(pre) and txt.endswith(suf):
290
+ # print('警告,输入了已经经过转化的字符串,二次转化可能出问题')
291
+ return txt # 已经被转化过,不需要再次转化
292
+
293
+ find_equation_pattern = r'<script type="math/tex(?:.*?)>(.*?)</script>'
294
+ txt = fix_markdown_indent(txt)
295
+ convert_stage_1 = fix_dollar_sticking_bug(txt)
296
+ # convert everything to html format
297
+ convert_stage_2 = markdown.markdown(
298
+ text=convert_stage_1,
299
+ extensions=[
300
+ "sane_lists",
301
+ "tables",
302
+ "mdx_math",
303
+ "pymdownx.superfences",
304
+ "pymdownx.highlight",
305
+ ],
306
+ extension_configs={**markdown_extension_configs, **code_highlight_configs},
307
+ )
308
+
309
+
310
+ def repl_fn(match):
311
+ content = match.group(2)
312
+ return f'<script type="math/tex">{content}</script>'
313
+
314
+ pattern = "|".join([pattern for pattern, property in mathpatterns.items() if not property["allow_multi_lines"]])
315
+ pattern = re.compile(pattern, flags=re.ASCII)
316
+ convert_stage_3 = pattern.sub(repl_fn, convert_stage_2)
317
+
318
+ convert_stage_4 = markdown_bug_hunt(convert_stage_3)
319
+
320
+ # 2. convert to rendered equation
321
+ convert_stage_5, n = re.subn(
322
+ find_equation_pattern, replace_math_render, convert_stage_4, flags=re.DOTALL
323
+ )
324
+ # cat them together
325
+ return pre + convert_stage_5 + suf
326
+
327
+ @lru_cache(maxsize=128) # 使用 lru缓存 加快转换速度
328
+ def markdown_convertion(txt):
329
+ """
330
+ 将Markdown格式的文本转换为HTML格式。如果包含数学公式,则先将公式转换为HTML格式。
331
+ """
332
+ pre = '<div class="markdown-body">'
333
+ suf = "</div>"
334
+ if txt.startswith(pre) and txt.endswith(suf):
335
+ # print('警告,输入了已经经过转化的字符串,二次转化可能出问题')
336
+ return txt # 已经被转化过,不需要再次转化
337
+
338
+ find_equation_pattern = r'<script type="math/tex(?:.*?)>(.*?)</script>'
339
+
340
+ txt = fix_markdown_indent(txt)
341
+ # txt = fix_code_segment_indent(txt)
342
+ if is_equation(txt): # 有$标识的公式符号,且没有代码段```的标识
343
+ # convert everything to html format
344
+ split = markdown.markdown(text="---")
345
+ convert_stage_1 = markdown.markdown(
346
+ text=txt,
347
+ extensions=[
348
+ "sane_lists",
349
+ "tables",
350
+ "mdx_math",
351
+ "pymdownx.superfences",
352
+ "pymdownx.highlight",
353
+ ],
354
+ extension_configs={**markdown_extension_configs, **code_highlight_configs},
355
+ )
356
+ convert_stage_1 = markdown_bug_hunt(convert_stage_1)
357
+ # 1. convert to easy-to-copy tex (do not render math)
358
+ convert_stage_2_1, n = re.subn(
359
+ find_equation_pattern,
360
+ replace_math_no_render,
361
+ convert_stage_1,
362
+ flags=re.DOTALL,
363
+ )
364
+ # 2. convert to rendered equation
365
+ convert_stage_2_2, n = re.subn(
366
+ find_equation_pattern, replace_math_render, convert_stage_1, flags=re.DOTALL
367
+ )
368
+ # cat them together
369
+ return pre + convert_stage_2_1 + f"{split}" + convert_stage_2_2 + suf
370
+ else:
371
+ return (
372
+ pre
373
+ + markdown.markdown(
374
+ txt,
375
+ extensions=[
376
+ "sane_lists",
377
+ "tables",
378
+ "pymdownx.superfences",
379
+ "pymdownx.highlight",
380
+ ],
381
+ extension_configs=code_highlight_configs,
382
+ )
383
+ + suf
384
+ )
385
+
386
+
387
+ def close_up_code_segment_during_stream(gpt_reply):
388
+ """
389
+ 在gpt输出代码的中途(输出了前面的```,但还没输出完后面的```),补上后面的```
390
+
391
+ Args:
392
+ gpt_reply (str): GPT模型返回的回复字符串。
393
+
394
+ Returns:
395
+ str: 返回一个新的字符串,将输出代码片段的“后面的```”补上。
396
+
397
+ """
398
+ if "```" not in gpt_reply:
399
+ return gpt_reply
400
+ if gpt_reply.endswith("```"):
401
+ return gpt_reply
402
+
403
+ # 排除了以上两个情况,我们
404
+ segments = gpt_reply.split("```")
405
+ n_mark = len(segments) - 1
406
+ if n_mark % 2 == 1:
407
+ return gpt_reply + "\n```" # 输出代码片段中!
408
+ else:
409
+ return gpt_reply
410
+
411
+
412
+ def special_render_issues_for_mermaid(text):
413
+ # 用不太优雅的方式处理一个core_functional.py中出现的mermaid渲染特例:
414
+ # 我不希望"总结绘制脑图"prompt中的mermaid渲染出来
415
+ @lru_cache(maxsize=1)
416
+ def get_special_case():
417
+ from core_functional import get_core_functions
418
+ special_case = get_core_functions()["总结绘制脑图"]["Suffix"]
419
+ return special_case
420
+ if text.endswith(get_special_case()): text = text.replace("```mermaid", "```")
421
+ return text
422
+
423
+
424
+ def compat_non_markdown_input(text):
425
+ """
426
+ 改善非markdown输入的显示效果,例如将空格转换为&nbsp;,将换行符转换为</br>等。
427
+ """
428
+ if "```" in text:
429
+ # careful input:markdown输入
430
+ text = special_render_issues_for_mermaid(text) # 处理特殊的渲染问题
431
+ return text
432
+ elif "</div>" in text:
433
+ # careful input:html输入
434
+ return text
435
+ else:
436
+ # whatever input:非markdown输入
437
+ lines = text.split("\n")
438
+ for i, line in enumerate(lines):
439
+ lines[i] = lines[i].replace(" ", "&nbsp;") # 空格转换为&nbsp;
440
+ text = "</br>".join(lines) # 换行符转换为</br>
441
+ return text
442
+
443
+
444
+ @lru_cache(maxsize=128) # 使用lru缓存
445
+ def simple_markdown_convertion(text):
446
+ pre = '<div class="markdown-body">'
447
+ suf = "</div>"
448
+ if text.startswith(pre) and text.endswith(suf):
449
+ return text # 已经被转化过,不需要再次转化
450
+ text = compat_non_markdown_input(text) # 兼容非markdown输入
451
+ text = markdown.markdown(
452
+ text,
453
+ extensions=["pymdownx.superfences", "tables", "pymdownx.highlight"],
454
+ extension_configs=code_highlight_configs,
455
+ )
456
+ return pre + text + suf
457
+
458
+
459
+ def format_io(self, y):
460
+ """
461
+ 将输入和输出解析为HTML格式。将y中最后一项的输入部分段落化,并将输出部分的Markdown和数学公式转换为HTML格式。
462
+ """
463
+ if y is None or y == []:
464
+ return []
465
+ i_ask, gpt_reply = y[-1]
466
+ i_ask = apply_gpt_academic_string_mask(i_ask, mode="show_render")
467
+ gpt_reply = apply_gpt_academic_string_mask(gpt_reply, mode="show_render")
468
+ # 当代码输出半截的时候,试着补上后个```
469
+ if gpt_reply is not None:
470
+ gpt_reply = close_up_code_segment_during_stream(gpt_reply)
471
+ # 处理提问与输出
472
+ y[-1] = (
473
+ # 输入部分
474
+ None if i_ask is None else simple_markdown_convertion(i_ask),
475
+ # 输出部分
476
+ None if gpt_reply is None else markdown_convertion(gpt_reply),
477
+ )
478
+ return y
docker_as_a_service/shared_utils/char_visual_effect.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ def is_full_width_char(ch):
2
+ """判断给定的单个字符是否是全角字符"""
3
+ if '\u4e00' <= ch <= '\u9fff':
4
+ return True # 中文字符
5
+ if '\uff01' <= ch <= '\uff5e':
6
+ return True # 全角符号
7
+ if '\u3000' <= ch <= '\u303f':
8
+ return True # CJK标点符号
9
+ return False
10
+
11
+ def scolling_visual_effect(text, scroller_max_len):
12
+ text = text.\
13
+ replace('\n', '').replace('`', '.').replace(' ', '.').replace('<br/>', '.....').replace('$', '.')
14
+ place_take_cnt = 0
15
+ pointer = len(text) - 1
16
+
17
+ if len(text) < scroller_max_len:
18
+ return text
19
+
20
+ while place_take_cnt < scroller_max_len and pointer > 0:
21
+ if is_full_width_char(text[pointer]): place_take_cnt += 2
22
+ else: place_take_cnt += 1
23
+ pointer -= 1
24
+
25
+ return text[pointer:]
docker_as_a_service/shared_utils/colorful.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import platform
2
+ from sys import stdout
3
+ from loguru import logger
4
+
5
+ if platform.system()=="Linux":
6
+ pass
7
+ else:
8
+ from colorama import init
9
+ init()
10
+
11
+ # Do you like the elegance of Chinese characters?
12
+ def print红(*kw,**kargs):
13
+ print("\033[0;31m",*kw,"\033[0m",**kargs)
14
+ def print绿(*kw,**kargs):
15
+ print("\033[0;32m",*kw,"\033[0m",**kargs)
16
+ def print黄(*kw,**kargs):
17
+ print("\033[0;33m",*kw,"\033[0m",**kargs)
18
+ def print蓝(*kw,**kargs):
19
+ print("\033[0;34m",*kw,"\033[0m",**kargs)
20
+ def print紫(*kw,**kargs):
21
+ print("\033[0;35m",*kw,"\033[0m",**kargs)
22
+ def print靛(*kw,**kargs):
23
+ print("\033[0;36m",*kw,"\033[0m",**kargs)
24
+
25
+ def print亮红(*kw,**kargs):
26
+ print("\033[1;31m",*kw,"\033[0m",**kargs)
27
+ def print亮绿(*kw,**kargs):
28
+ print("\033[1;32m",*kw,"\033[0m",**kargs)
29
+ def print亮黄(*kw,**kargs):
30
+ print("\033[1;33m",*kw,"\033[0m",**kargs)
31
+ def print亮蓝(*kw,**kargs):
32
+ print("\033[1;34m",*kw,"\033[0m",**kargs)
33
+ def print亮紫(*kw,**kargs):
34
+ print("\033[1;35m",*kw,"\033[0m",**kargs)
35
+ def print亮靛(*kw,**kargs):
36
+ print("\033[1;36m",*kw,"\033[0m",**kargs)
37
+
38
+ # Do you like the elegance of Chinese characters?
39
+ def sprint红(*kw):
40
+ return "\033[0;31m"+' '.join(kw)+"\033[0m"
41
+ def sprint绿(*kw):
42
+ return "\033[0;32m"+' '.join(kw)+"\033[0m"
43
+ def sprint黄(*kw):
44
+ return "\033[0;33m"+' '.join(kw)+"\033[0m"
45
+ def sprint蓝(*kw):
46
+ return "\033[0;34m"+' '.join(kw)+"\033[0m"
47
+ def sprint紫(*kw):
48
+ return "\033[0;35m"+' '.join(kw)+"\033[0m"
49
+ def sprint靛(*kw):
50
+ return "\033[0;36m"+' '.join(kw)+"\033[0m"
51
+ def sprint亮红(*kw):
52
+ return "\033[1;31m"+' '.join(kw)+"\033[0m"
53
+ def sprint亮绿(*kw):
54
+ return "\033[1;32m"+' '.join(kw)+"\033[0m"
55
+ def sprint亮黄(*kw):
56
+ return "\033[1;33m"+' '.join(kw)+"\033[0m"
57
+ def sprint亮蓝(*kw):
58
+ return "\033[1;34m"+' '.join(kw)+"\033[0m"
59
+ def sprint亮紫(*kw):
60
+ return "\033[1;35m"+' '.join(kw)+"\033[0m"
61
+ def sprint亮靛(*kw):
62
+ return "\033[1;36m"+' '.join(kw)+"\033[0m"
63
+
64
+ def log红(*kw,**kargs):
65
+ logger.opt(depth=1).info(sprint红(*kw))
66
+ def log绿(*kw,**kargs):
67
+ logger.opt(depth=1).info(sprint绿(*kw))
68
+ def log黄(*kw,**kargs):
69
+ logger.opt(depth=1).info(sprint黄(*kw))
70
+ def log蓝(*kw,**kargs):
71
+ logger.opt(depth=1).info(sprint蓝(*kw))
72
+ def log紫(*kw,**kargs):
73
+ logger.opt(depth=1).info(sprint紫(*kw))
74
+ def log靛(*kw,**kargs):
75
+ logger.opt(depth=1).info(sprint靛(*kw))
76
+
77
+ def log亮红(*kw,**kargs):
78
+ logger.opt(depth=1).info(sprint亮红(*kw))
79
+ def log亮绿(*kw,**kargs):
80
+ logger.opt(depth=1).info(sprint亮绿(*kw))
81
+ def log亮黄(*kw,**kargs):
82
+ logger.opt(depth=1).info(sprint亮黄(*kw))
83
+ def log亮蓝(*kw,**kargs):
84
+ logger.opt(depth=1).info(sprint亮蓝(*kw))
85
+ def log亮紫(*kw,**kargs):
86
+ logger.opt(depth=1).info(sprint亮紫(*kw))
87
+ def log亮靛(*kw,**kargs):
88
+ logger.opt(depth=1).info(sprint亮靛(*kw))
docker_as_a_service/shared_utils/config_loader.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import importlib
2
+ import time
3
+ import os
4
+ from functools import lru_cache
5
+ from shared_utils.colorful import log亮红, log亮绿, log亮蓝
6
+
7
+ pj = os.path.join
8
+ default_user_name = 'default_user'
9
+
10
+ def read_env_variable(arg, default_value):
11
+ """
12
+ 环境变量可以是 `GPT_ACADEMIC_CONFIG`(优先),也可以直接是`CONFIG`
13
+ 例如在windows cmd中,既可以写:
14
+ set USE_PROXY=True
15
+ set API_KEY=sk-j7caBpkRoxxxxxxxxxxxxxxxxxxxxxxxxxxxx
16
+ set proxies={"http":"http://127.0.0.1:10085", "https":"http://127.0.0.1:10085",}
17
+ set AVAIL_LLM_MODELS=["gpt-3.5-turbo", "chatglm"]
18
+ set AUTHENTICATION=[("username", "password"), ("username2", "password2")]
19
+ 也可以写:
20
+ set GPT_ACADEMIC_USE_PROXY=True
21
+ set GPT_ACADEMIC_API_KEY=sk-j7caBpkRoxxxxxxxxxxxxxxxxxxxxxxxxxxxx
22
+ set GPT_ACADEMIC_proxies={"http":"http://127.0.0.1:10085", "https":"http://127.0.0.1:10085",}
23
+ set GPT_ACADEMIC_AVAIL_LLM_MODELS=["gpt-3.5-turbo", "chatglm"]
24
+ set GPT_ACADEMIC_AUTHENTICATION=[("username", "password"), ("username2", "password2")]
25
+ """
26
+ arg_with_prefix = "GPT_ACADEMIC_" + arg
27
+ if arg_with_prefix in os.environ:
28
+ env_arg = os.environ[arg_with_prefix]
29
+ elif arg in os.environ:
30
+ env_arg = os.environ[arg]
31
+ else:
32
+ raise KeyError
33
+ log亮绿(f"[ENV_VAR] 尝试加载{arg},默认值:{default_value} --> 修正值:{env_arg}")
34
+ try:
35
+ if isinstance(default_value, bool):
36
+ env_arg = env_arg.strip()
37
+ if env_arg == 'True': r = True
38
+ elif env_arg == 'False': r = False
39
+ else: log亮红('Expect `True` or `False`, but have:', env_arg); r = default_value
40
+ elif isinstance(default_value, int):
41
+ r = int(env_arg)
42
+ elif isinstance(default_value, float):
43
+ r = float(env_arg)
44
+ elif isinstance(default_value, str):
45
+ r = env_arg.strip()
46
+ elif isinstance(default_value, dict):
47
+ r = eval(env_arg)
48
+ elif isinstance(default_value, list):
49
+ r = eval(env_arg)
50
+ elif default_value is None:
51
+ assert arg == "proxies"
52
+ r = eval(env_arg)
53
+ else:
54
+ log亮红(f"[ENV_VAR] 环境变量{arg}不支持通过环境变量设置! ")
55
+ raise KeyError
56
+ except:
57
+ log亮红(f"[ENV_VAR] 环境变量{arg}加载失败! ")
58
+ raise KeyError(f"[ENV_VAR] 环境变量{arg}加载失败! ")
59
+
60
+ log亮绿(f"[ENV_VAR] 成功读取环境变量{arg}")
61
+ return r
62
+
63
+
64
+ @lru_cache(maxsize=128)
65
+ def read_single_conf_with_lru_cache(arg):
66
+ from shared_utils.key_pattern_manager import is_any_api_key
67
+ try:
68
+ # 优先级1. 获取环境变量作为配置
69
+ default_ref = getattr(importlib.import_module('config'), arg) # 读取默认值作为数据类型转换的参考
70
+ r = read_env_variable(arg, default_ref)
71
+ except:
72
+ try:
73
+ # 优先级2. 获取config_private中的配置
74
+ r = getattr(importlib.import_module('config_private'), arg)
75
+ except:
76
+ # 优先级3. 获取config中的配置
77
+ r = getattr(importlib.import_module('config'), arg)
78
+
79
+ # 在读取API_KEY时,检查一下是不是忘了改config
80
+ if arg == 'API_URL_REDIRECT':
81
+ oai_rd = r.get("https://api.openai.com/v1/chat/completions", None) # API_URL_REDIRECT填写格式是错误的,请阅读`https://github.com/binary-husky/gpt_academic/wiki/项目配置说明`
82
+ if oai_rd and not oai_rd.endswith('/completions'):
83
+ log亮红("\n\n[API_URL_REDIRECT] API_URL_REDIRECT填错了。请阅读`https://github.com/binary-husky/gpt_academic/wiki/项目配置说明`。如果您确信自己没填错,无视此消息即可。")
84
+ time.sleep(5)
85
+ if arg == 'API_KEY':
86
+ log亮蓝(f"[API_KEY] 本项目现已支持OpenAI和Azure的api-key。也支持同时填写多个api-key,如API_KEY=\"openai-key1,openai-key2,azure-key3\"")
87
+ log亮蓝(f"[API_KEY] 您既可以在config.py中修改api-key(s),也可以在问题输入区输入临时的api-key(s),然后回车键提交后即可生效。")
88
+ if is_any_api_key(r):
89
+ log亮绿(f"[API_KEY] 您的 API_KEY 是: {r[:15]}*** API_KEY 导入成功")
90
+ else:
91
+ log亮红(f"[API_KEY] 您的 API_KEY({r[:15]}***)不满足任何一种已知的密钥格式,请在config文件中修改API密钥之后再运行(详见`https://github.com/binary-husky/gpt_academic/wiki/api_key`)。")
92
+ if arg == 'proxies':
93
+ if not read_single_conf_with_lru_cache('USE_PROXY'): r = None # 检查USE_PROXY,防止proxies单独起作用
94
+ if r is None:
95
+ log亮红('[PROXY] 网络代理状态:未配置。无代理状态下很可能无法访问OpenAI家族的模型。建议:检查USE_PROXY选项是否修改。')
96
+ else:
97
+ log亮绿('[PROXY] 网络代理状态:已配置。配置信息如下:', str(r))
98
+ assert isinstance(r, dict), 'proxies格式错误,请注意proxies选项的格式,不要遗漏括号。'
99
+ return r
100
+
101
+
102
+ @lru_cache(maxsize=128)
103
+ def get_conf(*args):
104
+ """
105
+ 本项目的所有配置都集中在config.py中。 修改配置有三种方法,您只需要选择其中一种即可:
106
+ - 直接修改config.py
107
+ - 创建并修改config_private.py
108
+ - 修改环境变量(修改docker-compose.yml等价于修改容器内部的环境变量)
109
+
110
+ 注意:如果您使用docker-compose部署,请修改docker-compose(等价于修改容器内部的环境变量)
111
+ """
112
+ res = []
113
+ for arg in args:
114
+ r = read_single_conf_with_lru_cache(arg)
115
+ res.append(r)
116
+ if len(res) == 1: return res[0]
117
+ return res
118
+
119
+
120
+ def set_conf(key, value):
121
+ from toolbox import read_single_conf_with_lru_cache
122
+ read_single_conf_with_lru_cache.cache_clear()
123
+ get_conf.cache_clear()
124
+ os.environ[key] = str(value)
125
+ altered = get_conf(key)
126
+ return altered
127
+
128
+
129
+ def set_multi_conf(dic):
130
+ for k, v in dic.items(): set_conf(k, v)
131
+ return
docker_as_a_service/shared_utils/connect_void_terminal.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ """
4
+ =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
5
+ 接驳void-terminal:
6
+ - set_conf: 在运行过程中动态地修改配置
7
+ - set_multi_conf: 在运行过程中动态地修改多个配置
8
+ - get_plugin_handle: 获取插件的句柄
9
+ - get_plugin_default_kwargs: 获取插件的默认参数
10
+ - get_chat_handle: 获取简单聊天的句柄
11
+ - get_chat_default_kwargs: 获取简单聊天的默认参数
12
+ =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
13
+ """
14
+
15
+
16
+ def get_plugin_handle(plugin_name):
17
+ """
18
+ e.g. plugin_name = 'crazy_functions.Markdown_Translate->Markdown翻译指定语言'
19
+ """
20
+ import importlib
21
+
22
+ assert (
23
+ "->" in plugin_name
24
+ ), "Example of plugin_name: crazy_functions.Markdown_Translate->Markdown翻译指定语言"
25
+ module, fn_name = plugin_name.split("->")
26
+ f_hot_reload = getattr(importlib.import_module(module, fn_name), fn_name)
27
+ return f_hot_reload
28
+
29
+
30
+ def get_chat_handle():
31
+ """
32
+ Get chat function
33
+ """
34
+ from request_llms.bridge_all import predict_no_ui_long_connection
35
+
36
+ return predict_no_ui_long_connection
37
+
38
+
39
+ def get_plugin_default_kwargs():
40
+ """
41
+ Get Plugin Default Arguments
42
+ """
43
+ from toolbox import ChatBotWithCookies, load_chat_cookies
44
+
45
+ cookies = load_chat_cookies()
46
+ llm_kwargs = {
47
+ "api_key": cookies["api_key"],
48
+ "llm_model": cookies["llm_model"],
49
+ "top_p": 1.0,
50
+ "max_length": None,
51
+ "temperature": 1.0,
52
+ }
53
+ chatbot = ChatBotWithCookies(llm_kwargs)
54
+
55
+ # txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request
56
+ DEFAULT_FN_GROUPS_kwargs = {
57
+ "main_input": "./README.md",
58
+ "llm_kwargs": llm_kwargs,
59
+ "plugin_kwargs": {},
60
+ "chatbot_with_cookie": chatbot,
61
+ "history": [],
62
+ "system_prompt": "You are a good AI.",
63
+ "user_request": None,
64
+ }
65
+ return DEFAULT_FN_GROUPS_kwargs
66
+
67
+
68
+ def get_chat_default_kwargs():
69
+ """
70
+ Get Chat Default Arguments
71
+ """
72
+ from toolbox import load_chat_cookies
73
+
74
+ cookies = load_chat_cookies()
75
+ llm_kwargs = {
76
+ "api_key": cookies["api_key"],
77
+ "llm_model": cookies["llm_model"],
78
+ "top_p": 1.0,
79
+ "max_length": None,
80
+ "temperature": 1.0,
81
+ }
82
+ default_chat_kwargs = {
83
+ "inputs": "Hello there, are you ready?",
84
+ "llm_kwargs": llm_kwargs,
85
+ "history": [],
86
+ "sys_prompt": "You are AI assistant",
87
+ "observe_window": None,
88
+ "console_slience": False,
89
+ }
90
+
91
+ return default_chat_kwargs
docker_as_a_service/shared_utils/cookie_manager.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import base64
3
+ from typing import Callable
4
+
5
+ def load_web_cookie_cache__fn_builder(customize_btns, cookies, predefined_btns)->Callable:
6
+ def load_web_cookie_cache(persistent_cookie_, cookies_):
7
+ import gradio as gr
8
+ from themes.theme import load_dynamic_theme, to_cookie_str, from_cookie_str, assign_user_uuid
9
+
10
+ ret = {}
11
+ for k in customize_btns:
12
+ ret.update({customize_btns[k]: gr.update(visible=False, value="")})
13
+
14
+ try: persistent_cookie_ = from_cookie_str(persistent_cookie_) # persistent cookie to dict
15
+ except: return ret
16
+
17
+ customize_fn_overwrite_ = persistent_cookie_.get("custom_bnt", {})
18
+ cookies_['customize_fn_overwrite'] = customize_fn_overwrite_
19
+ ret.update({cookies: cookies_})
20
+
21
+ for k,v in persistent_cookie_["custom_bnt"].items():
22
+ if v['Title'] == "": continue
23
+ if k in customize_btns: ret.update({customize_btns[k]: gr.update(visible=True, value=v['Title'])})
24
+ else: ret.update({predefined_btns[k]: gr.update(visible=True, value=v['Title'])})
25
+ return ret
26
+ return load_web_cookie_cache
27
+
28
+ def assign_btn__fn_builder(customize_btns, predefined_btns, cookies, web_cookie_cache)->Callable:
29
+ def assign_btn(persistent_cookie_, cookies_, basic_btn_dropdown_, basic_fn_title, basic_fn_prefix, basic_fn_suffix, clean_up=False):
30
+ import gradio as gr
31
+ from themes.theme import load_dynamic_theme, to_cookie_str, from_cookie_str, assign_user_uuid
32
+ ret = {}
33
+ # 读取之前的自定义按钮
34
+ customize_fn_overwrite_ = cookies_['customize_fn_overwrite']
35
+ # 更新新的自定义按钮
36
+ customize_fn_overwrite_.update({
37
+ basic_btn_dropdown_:
38
+ {
39
+ "Title":basic_fn_title,
40
+ "Prefix":basic_fn_prefix,
41
+ "Suffix":basic_fn_suffix,
42
+ }
43
+ }
44
+ )
45
+ if clean_up:
46
+ customize_fn_overwrite_ = {}
47
+ cookies_.update(customize_fn_overwrite_) # 更新cookie
48
+ visible = (not clean_up) and (basic_fn_title != "")
49
+ if basic_btn_dropdown_ in customize_btns:
50
+ # 是自定义按钮,不是预定义按钮
51
+ ret.update({customize_btns[basic_btn_dropdown_]: gr.update(visible=visible, value=basic_fn_title)})
52
+ else:
53
+ # 是预定义按钮
54
+ ret.update({predefined_btns[basic_btn_dropdown_]: gr.update(visible=visible, value=basic_fn_title)})
55
+ ret.update({cookies: cookies_})
56
+ try: persistent_cookie_ = from_cookie_str(persistent_cookie_) # persistent cookie to dict
57
+ except: persistent_cookie_ = {}
58
+ persistent_cookie_["custom_bnt"] = customize_fn_overwrite_ # dict update new value
59
+ persistent_cookie_ = to_cookie_str(persistent_cookie_) # persistent cookie to dict
60
+ ret.update({web_cookie_cache: persistent_cookie_}) # write persistent cookie
61
+ return ret
62
+ return assign_btn
63
+
64
+ # cookies, web_cookie_cache = make_cookie_cache()
65
+ def make_cookie_cache():
66
+ # 定义 后端state(cookies)、前端(web_cookie_cache)两兄弟
67
+ import gradio as gr
68
+ from toolbox import load_chat_cookies
69
+ # 定义cookies的后端state
70
+ cookies = gr.State(load_chat_cookies())
71
+ # 定义cookies的一个孪生的前端存储区(隐藏)
72
+ web_cookie_cache = gr.Textbox(visible=False, elem_id="web_cookie_cache")
73
+ return cookies, web_cookie_cache
74
+
75
+ # history, history_cache, history_cache_update = make_history_cache()
76
+ def make_history_cache():
77
+ # 定义 后端state(history)、前端(history_cache)、后端setter(history_cache_update)三兄弟
78
+ import gradio as gr
79
+ # 定义history的后端state
80
+ history = gr.State([])
81
+ # 定义history的一个孪生的前端存储区(隐藏)
82
+ history_cache = gr.Textbox(visible=False, elem_id="history_cache")
83
+ # 定义history_cache->history的更新方法(隐藏)。在触发这个按钮时,会先执行js代码更新history_cache,然后再执行python代码更新history
84
+ def process_history_cache(history_cache):
85
+ return json.loads(history_cache)
86
+ # 另一种更简单的setter方法
87
+ history_cache_update = gr.Button("", elem_id="elem_update_history", visible=False).click(
88
+ process_history_cache, inputs=[history_cache], outputs=[history])
89
+ return history, history_cache, history_cache_update
90
+
91
+
92
+
93
+ def create_button_with_javascript_callback(btn_value, elem_id, variant, js_callback, input_list, output_list, function, input_name_list, output_name_list):
94
+ import gradio as gr
95
+ middle_ware_component = gr.Textbox(visible=False, elem_id=elem_id+'_buffer')
96
+ def get_fn_wrap():
97
+ def fn_wrap(*args):
98
+ summary_dict = {}
99
+ for name, value in zip(input_name_list, args):
100
+ summary_dict.update({name: value})
101
+
102
+ res = function(*args)
103
+
104
+ for name, value in zip(output_name_list, res):
105
+ summary_dict.update({name: value})
106
+
107
+ summary = base64.b64encode(json.dumps(summary_dict).encode('utf8')).decode("utf-8")
108
+ return (*res, summary)
109
+ return fn_wrap
110
+
111
+ btn = gr.Button(btn_value, elem_id=elem_id, variant=variant)
112
+ call_args = ""
113
+ for name in output_name_list:
114
+ call_args += f"""Data["{name}"],"""
115
+ call_args = call_args.rstrip(",")
116
+ _js_callback = """
117
+ (base64MiddleString)=>{
118
+ console.log('hello')
119
+ const stringData = atob(base64MiddleString);
120
+ let Data = JSON.parse(stringData);
121
+ call = JS_CALLBACK_GEN;
122
+ call(CALL_ARGS);
123
+ }
124
+ """.replace("JS_CALLBACK_GEN", js_callback).replace("CALL_ARGS", call_args)
125
+
126
+ btn.click(get_fn_wrap(), input_list, output_list+[middle_ware_component]).then(None, [middle_ware_component], None, _js=_js_callback)
127
+ return btn
docker_as_a_service/shared_utils/docker_as_service_api.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+ import pickle
3
+ import io
4
+ import os
5
+ from pydantic import BaseModel, Field
6
+ from typing import Optional, Dict
7
+ from loguru import logger
8
+
9
+ class DockerServiceApiComModel(BaseModel):
10
+ client_command: Optional[str] = Field(default=None, title="Client command", description="The command to be executed on the client side")
11
+ client_file_attach: Optional[dict] = Field(default=None, title="Client file attach", description="The file to be attached to the client side")
12
+ server_message: Optional[str] = Field(default=None, title="Server standard error", description="The standard error from the server side")
13
+ server_std_err: Optional[str] = Field(default=None, title="Server standard error", description="The standard error from the server side")
14
+ server_std_out: Optional[str] = Field(default=None, title="Server standard output", description="The standard output from the server side")
15
+ server_file_attach: Optional[dict] = Field(default=None, title="Server file attach", description="The file to be attached to the server side")
16
+
17
+ def process_received(received: DockerServiceApiComModel, save_file_dir="./daas_output", output_manifest={}):
18
+ # Process the received data
19
+ if received.server_message:
20
+ output_manifest['server_message'] += received.server_message
21
+ if received.server_std_err:
22
+ output_manifest['server_std_err'] += received.server_std_err
23
+ if received.server_std_out:
24
+ output_manifest['server_std_out'] += received.server_std_out
25
+ if received.server_file_attach:
26
+ # print(f"Recv file attach: {received.server_file_attach}")
27
+ for file_name, file_content in received.server_file_attach.items():
28
+ new_fp = os.path.join(save_file_dir, file_name)
29
+ new_fp_dir = os.path.dirname(new_fp)
30
+ if not os.path.exists(new_fp_dir):
31
+ os.makedirs(new_fp_dir, exist_ok=True)
32
+ with open(new_fp, 'wb') as f:
33
+ f.write(file_content)
34
+ output_manifest['server_file_attach'].append(new_fp)
35
+ return output_manifest
36
+
37
+ def stream_daas(docker_service_api_com_model, server_url):
38
+ # Prepare the file
39
+ # Pickle the object
40
+ pickled_data = pickle.dumps(docker_service_api_com_model)
41
+
42
+ # Create a file-like object from the pickled data
43
+ file_obj = io.BytesIO(pickled_data)
44
+
45
+ # Prepare the file for sending
46
+ files = {'file': ('docker_service_api_com_model.pkl', file_obj, 'application/octet-stream')}
47
+
48
+ # Send the POST request
49
+ response = requests.post(server_url, files=files, stream=True)
50
+
51
+ max_full_package_size = 1024 * 1024 * 1024 * 1 # 1 GB
52
+
53
+ received_output_manifest = {}
54
+ received_output_manifest['server_message'] = ""
55
+ received_output_manifest['server_std_err'] = ""
56
+ received_output_manifest['server_std_out'] = ""
57
+ received_output_manifest['server_file_attach'] = []
58
+
59
+ # Check if the request was successful
60
+ if response.status_code == 200:
61
+ # Process the streaming response
62
+ for chunk in response.iter_content(max_full_package_size):
63
+ if chunk:
64
+ received = pickle.loads(chunk)
65
+ received_output_manifest = process_received(received, received_output_manifest)
66
+ yield received_output_manifest
67
+ else:
68
+ logger.error(f"Error: Received status code {response.status_code}, response.text: {response.text}")
69
+
70
+ return received_output_manifest
docker_as_a_service/shared_utils/fastapi_server.py ADDED
@@ -0,0 +1,322 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Tests:
3
+
4
+ - custom_path false / no user auth:
5
+ -- upload file(yes)
6
+ -- download file(yes)
7
+ -- websocket(yes)
8
+ -- block __pycache__ access(yes)
9
+ -- rel (yes)
10
+ -- abs (yes)
11
+ -- block user access(fail) http://localhost:45013/file=gpt_log/admin/chat_secrets.log
12
+ -- fix(commit f6bf05048c08f5cd84593f7fdc01e64dec1f584a)-> block successful
13
+
14
+ - custom_path yes("/cc/gptac") / no user auth:
15
+ -- upload file(yes)
16
+ -- download file(yes)
17
+ -- websocket(yes)
18
+ -- block __pycache__ access(yes)
19
+ -- block user access(yes)
20
+
21
+ - custom_path yes("/cc/gptac/") / no user auth:
22
+ -- upload file(yes)
23
+ -- download file(yes)
24
+ -- websocket(yes)
25
+ -- block user access(yes)
26
+
27
+ - custom_path yes("/cc/gptac/") / + user auth:
28
+ -- upload file(yes)
29
+ -- download file(yes)
30
+ -- websocket(yes)
31
+ -- block user access(yes)
32
+ -- block user-wise access (yes)
33
+
34
+ - custom_path no + user auth:
35
+ -- upload file(yes)
36
+ -- download file(yes)
37
+ -- websocket(yes)
38
+ -- block user access(yes)
39
+ -- block user-wise access (yes)
40
+
41
+ queue cocurrent effectiveness
42
+ -- upload file(yes)
43
+ -- download file(yes)
44
+ -- websocket(yes)
45
+ """
46
+
47
+ import os, requests, threading, time
48
+ import uvicorn
49
+
50
+ def validate_path_safety(path_or_url, user):
51
+ from toolbox import get_conf, default_user_name
52
+ from toolbox import FriendlyException
53
+ PATH_PRIVATE_UPLOAD, PATH_LOGGING = get_conf('PATH_PRIVATE_UPLOAD', 'PATH_LOGGING')
54
+ sensitive_path = None
55
+ path_or_url = os.path.relpath(path_or_url)
56
+ if path_or_url.startswith(PATH_LOGGING): # 日志文件(按用户划分)
57
+ sensitive_path = PATH_LOGGING
58
+ elif path_or_url.startswith(PATH_PRIVATE_UPLOAD): # 用户的上传目录(按用户划分)
59
+ sensitive_path = PATH_PRIVATE_UPLOAD
60
+ elif path_or_url.startswith('tests') or path_or_url.startswith('build'): # 一个常用的测试目录
61
+ return True
62
+ else:
63
+ raise FriendlyException(f"输入文件的路径 ({path_or_url}) 存在,但位置非法。请将文件上传后再执行该任务。") # return False
64
+ if sensitive_path:
65
+ allowed_users = [user, 'autogen', 'arxiv_cache', default_user_name] # three user path that can be accessed
66
+ for user_allowed in allowed_users:
67
+ if f"{os.sep}".join(path_or_url.split(os.sep)[:2]) == os.path.join(sensitive_path, user_allowed):
68
+ return True
69
+ raise FriendlyException(f"输入文件的路径 ({path_or_url}) 存在,但属于其他用户。请将文件上传后再执行该任务。") # return False
70
+ return True
71
+
72
+ def _authorize_user(path_or_url, request, gradio_app):
73
+ from toolbox import get_conf, default_user_name
74
+ PATH_PRIVATE_UPLOAD, PATH_LOGGING = get_conf('PATH_PRIVATE_UPLOAD', 'PATH_LOGGING')
75
+ sensitive_path = None
76
+ path_or_url = os.path.relpath(path_or_url)
77
+ if path_or_url.startswith(PATH_LOGGING):
78
+ sensitive_path = PATH_LOGGING
79
+ if path_or_url.startswith(PATH_PRIVATE_UPLOAD):
80
+ sensitive_path = PATH_PRIVATE_UPLOAD
81
+ if sensitive_path:
82
+ token = request.cookies.get("access-token") or request.cookies.get("access-token-unsecure")
83
+ user = gradio_app.tokens.get(token) # get user
84
+ allowed_users = [user, 'autogen', 'arxiv_cache', default_user_name] # three user path that can be accessed
85
+ for user_allowed in allowed_users:
86
+ # exact match
87
+ if f"{os.sep}".join(path_or_url.split(os.sep)[:2]) == os.path.join(sensitive_path, user_allowed):
88
+ return True
89
+ return False # "越权访问!"
90
+ return True
91
+
92
+
93
+ class Server(uvicorn.Server):
94
+ # A server that runs in a separate thread
95
+ def install_signal_handlers(self):
96
+ pass
97
+
98
+ def run_in_thread(self):
99
+ self.thread = threading.Thread(target=self.run, daemon=True)
100
+ self.thread.start()
101
+ while not self.started:
102
+ time.sleep(5e-2)
103
+
104
+ def close(self):
105
+ self.should_exit = True
106
+ self.thread.join()
107
+
108
+
109
+ def start_app(app_block, CONCURRENT_COUNT, AUTHENTICATION, PORT, SSL_KEYFILE, SSL_CERTFILE):
110
+ import uvicorn
111
+ import fastapi
112
+ import gradio as gr
113
+ from fastapi import FastAPI
114
+ from gradio.routes import App
115
+ from toolbox import get_conf
116
+ CUSTOM_PATH, PATH_LOGGING = get_conf('CUSTOM_PATH', 'PATH_LOGGING')
117
+
118
+ # --- --- configurate gradio app block --- ---
119
+ app_block:gr.Blocks
120
+ app_block.ssl_verify = False
121
+ app_block.auth_message = '请登录'
122
+ app_block.favicon_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), "docs/logo.png")
123
+ app_block.auth = AUTHENTICATION if len(AUTHENTICATION) != 0 else None
124
+ app_block.blocked_paths = ["config.py", "__pycache__", "config_private.py", "docker-compose.yml", "Dockerfile", f"{PATH_LOGGING}/admin"]
125
+ app_block.dev_mode = False
126
+ app_block.config = app_block.get_config_file()
127
+ app_block.enable_queue = True
128
+ app_block.queue(concurrency_count=CONCURRENT_COUNT)
129
+ app_block.validate_queue_settings()
130
+ app_block.show_api = False
131
+ app_block.config = app_block.get_config_file()
132
+ max_threads = 40
133
+ app_block.max_threads = max(
134
+ app_block._queue.max_thread_count if app_block.enable_queue else 0, max_threads
135
+ )
136
+ app_block.is_colab = False
137
+ app_block.is_kaggle = False
138
+ app_block.is_sagemaker = False
139
+
140
+ gradio_app = App.create_app(app_block)
141
+ for route in list(gradio_app.router.routes):
142
+ if route.path == "/proxy={url_path:path}":
143
+ gradio_app.router.routes.remove(route)
144
+ # --- --- replace gradio endpoint to forbid access to sensitive files --- ---
145
+ if len(AUTHENTICATION) > 0:
146
+ dependencies = []
147
+ endpoint = None
148
+ for route in list(gradio_app.router.routes):
149
+ if route.path == "/file/{path:path}":
150
+ gradio_app.router.routes.remove(route)
151
+ if route.path == "/file={path_or_url:path}":
152
+ dependencies = route.dependencies
153
+ endpoint = route.endpoint
154
+ gradio_app.router.routes.remove(route)
155
+ @gradio_app.get("/file/{path:path}", dependencies=dependencies)
156
+ @gradio_app.head("/file={path_or_url:path}", dependencies=dependencies)
157
+ @gradio_app.get("/file={path_or_url:path}", dependencies=dependencies)
158
+ async def file(path_or_url: str, request: fastapi.Request):
159
+ if not _authorize_user(path_or_url, request, gradio_app):
160
+ return "越权访问!"
161
+ stripped = path_or_url.lstrip().lower()
162
+ if stripped.startswith("https://") or stripped.startswith("http://"):
163
+ return "账户密码授权模式下, 禁止链接!"
164
+ if '../' in stripped:
165
+ return "非法路径!"
166
+ return await endpoint(path_or_url, request)
167
+
168
+ from fastapi import Request, status
169
+ from fastapi.responses import FileResponse, RedirectResponse
170
+ @gradio_app.get("/academic_logout")
171
+ async def logout():
172
+ response = RedirectResponse(url=CUSTOM_PATH, status_code=status.HTTP_302_FOUND)
173
+ response.delete_cookie('access-token')
174
+ response.delete_cookie('access-token-unsecure')
175
+ return response
176
+ else:
177
+ dependencies = []
178
+ endpoint = None
179
+ for route in list(gradio_app.router.routes):
180
+ if route.path == "/file/{path:path}":
181
+ gradio_app.router.routes.remove(route)
182
+ if route.path == "/file={path_or_url:path}":
183
+ dependencies = route.dependencies
184
+ endpoint = route.endpoint
185
+ gradio_app.router.routes.remove(route)
186
+ @gradio_app.get("/file/{path:path}", dependencies=dependencies)
187
+ @gradio_app.head("/file={path_or_url:path}", dependencies=dependencies)
188
+ @gradio_app.get("/file={path_or_url:path}", dependencies=dependencies)
189
+ async def file(path_or_url: str, request: fastapi.Request):
190
+ stripped = path_or_url.lstrip().lower()
191
+ if stripped.startswith("https://") or stripped.startswith("http://"):
192
+ return "账户密码授权模式下, 禁止链接!"
193
+ if '../' in stripped:
194
+ return "非法路径!"
195
+ return await endpoint(path_or_url, request)
196
+
197
+ # --- --- enable TTS (text-to-speech) functionality --- ---
198
+ TTS_TYPE = get_conf("TTS_TYPE")
199
+ if TTS_TYPE != "DISABLE":
200
+ # audio generation functionality
201
+ import httpx
202
+ from fastapi import FastAPI, Request, HTTPException
203
+ from starlette.responses import Response
204
+ async def forward_request(request: Request, method: str) -> Response:
205
+ async with httpx.AsyncClient() as client:
206
+ try:
207
+ # Forward the request to the target service
208
+ if TTS_TYPE == "EDGE_TTS":
209
+ import tempfile
210
+ import edge_tts
211
+ import wave
212
+ import uuid
213
+ from pydub import AudioSegment
214
+ json = await request.json()
215
+ voice = get_conf("EDGE_TTS_VOICE")
216
+ tts = edge_tts.Communicate(text=json['text'], voice=voice)
217
+ temp_folder = tempfile.gettempdir()
218
+ temp_file_name = str(uuid.uuid4().hex)
219
+ temp_file = os.path.join(temp_folder, f'{temp_file_name}.mp3')
220
+ await tts.save(temp_file)
221
+ try:
222
+ mp3_audio = AudioSegment.from_file(temp_file, format="mp3")
223
+ mp3_audio.export(temp_file, format="wav")
224
+ with open(temp_file, 'rb') as wav_file: t = wav_file.read()
225
+ os.remove(temp_file)
226
+ return Response(content=t)
227
+ except:
228
+ raise RuntimeError("ffmpeg未安装,无法处理EdgeTTS音频��安装方法见`https://github.com/jiaaro/pydub#getting-ffmpeg-set-up`")
229
+ if TTS_TYPE == "LOCAL_SOVITS_API":
230
+ # Forward the request to the target service
231
+ TARGET_URL = get_conf("GPT_SOVITS_URL")
232
+ body = await request.body()
233
+ resp = await client.post(TARGET_URL, content=body, timeout=60)
234
+ # Return the response from the target service
235
+ return Response(content=resp.content, status_code=resp.status_code, headers=dict(resp.headers))
236
+ except httpx.RequestError as e:
237
+ raise HTTPException(status_code=400, detail=f"Request to the target service failed: {str(e)}")
238
+ @gradio_app.post("/vits")
239
+ async def forward_post_request(request: Request):
240
+ return await forward_request(request, "POST")
241
+
242
+ # --- --- app_lifespan --- ---
243
+ from contextlib import asynccontextmanager
244
+ @asynccontextmanager
245
+ async def app_lifespan(app):
246
+ async def startup_gradio_app():
247
+ if gradio_app.get_blocks().enable_queue:
248
+ gradio_app.get_blocks().startup_events()
249
+ async def shutdown_gradio_app():
250
+ pass
251
+ await startup_gradio_app() # startup logic here
252
+ yield # The application will serve requests after this point
253
+ await shutdown_gradio_app() # cleanup/shutdown logic here
254
+
255
+ # --- --- FastAPI --- ---
256
+ fastapi_app = FastAPI(lifespan=app_lifespan)
257
+ fastapi_app.mount(CUSTOM_PATH, gradio_app)
258
+
259
+ # --- --- favicon and block fastapi api reference routes --- ---
260
+ from starlette.responses import JSONResponse
261
+ if CUSTOM_PATH != '/':
262
+ from fastapi.responses import FileResponse
263
+ @fastapi_app.get("/favicon.ico")
264
+ async def favicon():
265
+ return FileResponse(app_block.favicon_path)
266
+
267
+ @fastapi_app.middleware("http")
268
+ async def middleware(request: Request, call_next):
269
+ if request.scope['path'] in ["/docs", "/redoc", "/openapi.json"]:
270
+ return JSONResponse(status_code=404, content={"message": "Not Found"})
271
+ response = await call_next(request)
272
+ return response
273
+
274
+
275
+ # --- --- uvicorn.Config --- ---
276
+ ssl_keyfile = None if SSL_KEYFILE == "" else SSL_KEYFILE
277
+ ssl_certfile = None if SSL_CERTFILE == "" else SSL_CERTFILE
278
+ server_name = "0.0.0.0"
279
+ config = uvicorn.Config(
280
+ fastapi_app,
281
+ host=server_name,
282
+ port=PORT,
283
+ reload=False,
284
+ log_level="warning",
285
+ ssl_keyfile=ssl_keyfile,
286
+ ssl_certfile=ssl_certfile,
287
+ )
288
+ server = Server(config)
289
+ url_host_name = "localhost" if server_name == "0.0.0.0" else server_name
290
+ if ssl_keyfile is not None:
291
+ if ssl_certfile is None:
292
+ raise ValueError(
293
+ "ssl_certfile must be provided if ssl_keyfile is provided."
294
+ )
295
+ path_to_local_server = f"https://{url_host_name}:{PORT}/"
296
+ else:
297
+ path_to_local_server = f"http://{url_host_name}:{PORT}/"
298
+ if CUSTOM_PATH != '/':
299
+ path_to_local_server += CUSTOM_PATH.lstrip('/').rstrip('/') + '/'
300
+ # --- --- begin --- ---
301
+ server.run_in_thread()
302
+
303
+ # --- --- after server launch --- ---
304
+ app_block.server = server
305
+ app_block.server_name = server_name
306
+ app_block.local_url = path_to_local_server
307
+ app_block.protocol = (
308
+ "https"
309
+ if app_block.local_url.startswith("https") or app_block.is_colab
310
+ else "http"
311
+ )
312
+
313
+ if app_block.enable_queue:
314
+ app_block._queue.set_url(path_to_local_server)
315
+
316
+ forbid_proxies = {
317
+ "http": "",
318
+ "https": "",
319
+ }
320
+ requests.get(f"{app_block.local_url}startup-events", verify=app_block.ssl_verify, proxies=forbid_proxies)
321
+ app_block.is_running = True
322
+ app_block.block_thread()
docker_as_a_service/shared_utils/handle_upload.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import importlib
2
+ import time
3
+ import inspect
4
+ import re
5
+ import os
6
+ import base64
7
+ import gradio
8
+ import shutil
9
+ import glob
10
+ from shared_utils.config_loader import get_conf
11
+ from loguru import logger
12
+
13
+ def html_local_file(file):
14
+ base_path = os.path.dirname(__file__) # 项目目录
15
+ if os.path.exists(str(file)):
16
+ file = f'file={file.replace(base_path, ".")}'
17
+ return file
18
+
19
+
20
+ def html_local_img(__file, layout="left", max_width=None, max_height=None, md=True):
21
+ style = ""
22
+ if max_width is not None:
23
+ style += f"max-width: {max_width};"
24
+ if max_height is not None:
25
+ style += f"max-height: {max_height};"
26
+ __file = html_local_file(__file)
27
+ a = f'<div align="{layout}"><img src="{__file}" style="{style}"></div>'
28
+ if md:
29
+ a = f"![{__file}]({__file})"
30
+ return a
31
+
32
+
33
+ def file_manifest_filter_type(file_list, filter_: list = None):
34
+ new_list = []
35
+ if not filter_:
36
+ filter_ = ["png", "jpg", "jpeg"]
37
+ for file in file_list:
38
+ if str(os.path.basename(file)).split(".")[-1] in filter_:
39
+ new_list.append(html_local_img(file, md=False))
40
+ else:
41
+ new_list.append(file)
42
+ return new_list
43
+
44
+
45
+ def zip_extract_member_new(self, member, targetpath, pwd):
46
+ # 修复中文乱码的问题
47
+ """Extract the ZipInfo object 'member' to a physical
48
+ file on the path targetpath.
49
+ """
50
+ import zipfile
51
+ if not isinstance(member, zipfile.ZipInfo):
52
+ member = self.getinfo(member)
53
+
54
+ # build the destination pathname, replacing
55
+ # forward slashes to platform specific separators.
56
+ arcname = member.filename.replace('/', os.path.sep)
57
+ arcname = arcname.encode('cp437', errors='replace').decode('gbk', errors='replace')
58
+
59
+ if os.path.altsep:
60
+ arcname = arcname.replace(os.path.altsep, os.path.sep)
61
+ # interpret absolute pathname as relative, remove drive letter or
62
+ # UNC path, redundant separators, "." and ".." components.
63
+ arcname = os.path.splitdrive(arcname)[1]
64
+ invalid_path_parts = ('', os.path.curdir, os.path.pardir)
65
+ arcname = os.path.sep.join(x for x in arcname.split(os.path.sep)
66
+ if x not in invalid_path_parts)
67
+ if os.path.sep == '\\':
68
+ # filter illegal characters on Windows
69
+ arcname = self._sanitize_windows_name(arcname, os.path.sep)
70
+
71
+ targetpath = os.path.join(targetpath, arcname)
72
+ targetpath = os.path.normpath(targetpath)
73
+
74
+ # Create all upper directories if necessary.
75
+ upperdirs = os.path.dirname(targetpath)
76
+ if upperdirs and not os.path.exists(upperdirs):
77
+ os.makedirs(upperdirs)
78
+
79
+ if member.is_dir():
80
+ if not os.path.isdir(targetpath):
81
+ os.mkdir(targetpath)
82
+ return targetpath
83
+
84
+ with self.open(member, pwd=pwd) as source, \
85
+ open(targetpath, "wb") as target:
86
+ shutil.copyfileobj(source, target)
87
+
88
+ return targetpath
89
+
90
+
91
+ def extract_archive(file_path, dest_dir):
92
+ import zipfile
93
+ import tarfile
94
+ import os
95
+
96
+ # Get the file extension of the input file
97
+ file_extension = os.path.splitext(file_path)[1]
98
+
99
+ # Extract the archive based on its extension
100
+ if file_extension == ".zip":
101
+ with zipfile.ZipFile(file_path, "r") as zipobj:
102
+ zipobj._extract_member = lambda a,b,c: zip_extract_member_new(zipobj, a,b,c) # 修复中文乱码的问题
103
+ zipobj.extractall(path=dest_dir)
104
+ logger.info("Successfully extracted zip archive to {}".format(dest_dir))
105
+
106
+ elif file_extension in [".tar", ".gz", ".bz2"]:
107
+ try:
108
+ with tarfile.open(file_path, "r:*") as tarobj:
109
+ # 清理提取路径,移除任何不安全的元素
110
+ for member in tarobj.getmembers():
111
+ member_path = os.path.normpath(member.name)
112
+ full_path = os.path.join(dest_dir, member_path)
113
+ full_path = os.path.abspath(full_path)
114
+ if not full_path.startswith(os.path.abspath(dest_dir) + os.sep):
115
+ raise Exception(f"Attempted Path Traversal in {member.name}")
116
+
117
+ tarobj.extractall(path=dest_dir)
118
+ logger.info("Successfully extracted tar archive to {}".format(dest_dir))
119
+ except tarfile.ReadError as e:
120
+ if file_extension == ".gz":
121
+ # 一些特别奇葩的项目,是一个gz文件,里面不是tar,只有一个tex文件
122
+ import gzip
123
+ with gzip.open(file_path, 'rb') as f_in:
124
+ with open(os.path.join(dest_dir, 'main.tex'), 'wb') as f_out:
125
+ f_out.write(f_in.read())
126
+ else:
127
+ raise e
128
+
129
+ # 第三方库,需要预先pip install rarfile
130
+ # 此外,Windows上还需要安装winrar软件,配置其Path环境变量,如"C:\Program Files\WinRAR"才可以
131
+ elif file_extension == ".rar":
132
+ try:
133
+ import rarfile
134
+
135
+ with rarfile.RarFile(file_path) as rf:
136
+ rf.extractall(path=dest_dir)
137
+ logger.info("Successfully extracted rar archive to {}".format(dest_dir))
138
+ except:
139
+ logger.info("Rar format requires additional dependencies to install")
140
+ return "\n\n解压失败! 需要安装pip install rarfile来解压rar文件。建议:使用zip压缩格式。"
141
+
142
+ # 第三方库,需要预先pip install py7zr
143
+ elif file_extension == ".7z":
144
+ try:
145
+ import py7zr
146
+
147
+ with py7zr.SevenZipFile(file_path, mode="r") as f:
148
+ f.extractall(path=dest_dir)
149
+ logger.info("Successfully extracted 7z archive to {}".format(dest_dir))
150
+ except:
151
+ logger.info("7z format requires additional dependencies to install")
152
+ return "\n\n解压失败! 需要安装pip install py7zr来解压7z文件"
153
+ else:
154
+ return ""
155
+ return ""
156
+
docker_as_a_service/shared_utils/key_pattern_manager.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import os
3
+ from functools import wraps, lru_cache
4
+ from shared_utils.advanced_markdown_format import format_io
5
+ from shared_utils.config_loader import get_conf as get_conf
6
+
7
+
8
+ pj = os.path.join
9
+ default_user_name = 'default_user'
10
+
11
+ # match openai keys
12
+ openai_regex = re.compile(
13
+ r"sk-[a-zA-Z0-9_-]{48}$|" +
14
+ r"sk-[a-zA-Z0-9_-]{92}$|" +
15
+ r"sk-proj-[a-zA-Z0-9_-]{48}$|"+
16
+ r"sk-proj-[a-zA-Z0-9_-]{124}$|"+
17
+ r"sk-proj-[a-zA-Z0-9_-]{156}$|"+ #新版apikey位数不匹配故修改此正则表达式
18
+ r"sess-[a-zA-Z0-9]{40}$"
19
+ )
20
+ def is_openai_api_key(key):
21
+ CUSTOM_API_KEY_PATTERN = get_conf('CUSTOM_API_KEY_PATTERN')
22
+ if len(CUSTOM_API_KEY_PATTERN) != 0:
23
+ API_MATCH_ORIGINAL = re.match(CUSTOM_API_KEY_PATTERN, key)
24
+ else:
25
+ API_MATCH_ORIGINAL = openai_regex.match(key)
26
+ return bool(API_MATCH_ORIGINAL)
27
+
28
+
29
+ def is_azure_api_key(key):
30
+ API_MATCH_AZURE = re.match(r"[a-zA-Z0-9]{32}$", key)
31
+ return bool(API_MATCH_AZURE)
32
+
33
+
34
+ def is_api2d_key(key):
35
+ API_MATCH_API2D = re.match(r"fk[a-zA-Z0-9]{6}-[a-zA-Z0-9]{32}$", key)
36
+ return bool(API_MATCH_API2D)
37
+
38
+ def is_openroute_api_key(key):
39
+ API_MATCH_OPENROUTE = re.match(r"sk-or-v1-[a-zA-Z0-9]{64}$", key)
40
+ return bool(API_MATCH_OPENROUTE)
41
+
42
+ def is_cohere_api_key(key):
43
+ API_MATCH_AZURE = re.match(r"[a-zA-Z0-9]{40}$", key)
44
+ return bool(API_MATCH_AZURE)
45
+
46
+
47
+ def is_any_api_key(key):
48
+ if ',' in key:
49
+ keys = key.split(',')
50
+ for k in keys:
51
+ if is_any_api_key(k): return True
52
+ return False
53
+ else:
54
+ return is_openai_api_key(key) or is_api2d_key(key) or is_azure_api_key(key) or is_cohere_api_key(key)
55
+
56
+
57
+ def what_keys(keys):
58
+ avail_key_list = {'OpenAI Key': 0, "Azure Key": 0, "API2D Key": 0}
59
+ key_list = keys.split(',')
60
+
61
+ for k in key_list:
62
+ if is_openai_api_key(k):
63
+ avail_key_list['OpenAI Key'] += 1
64
+
65
+ for k in key_list:
66
+ if is_api2d_key(k):
67
+ avail_key_list['API2D Key'] += 1
68
+
69
+ for k in key_list:
70
+ if is_azure_api_key(k):
71
+ avail_key_list['Azure Key'] += 1
72
+
73
+ return f"检测到: OpenAI Key {avail_key_list['OpenAI Key']} 个, Azure Key {avail_key_list['Azure Key']} 个, API2D Key {avail_key_list['API2D Key']} 个"
74
+
75
+
76
+ def select_api_key(keys, llm_model):
77
+ import random
78
+ avail_key_list = []
79
+ key_list = keys.split(',')
80
+
81
+ if llm_model.startswith('gpt-') or llm_model.startswith('one-api-') or llm_model.startswith('o1-'):
82
+ for k in key_list:
83
+ if is_openai_api_key(k): avail_key_list.append(k)
84
+
85
+ if llm_model.startswith('api2d-'):
86
+ for k in key_list:
87
+ if is_api2d_key(k): avail_key_list.append(k)
88
+
89
+ if llm_model.startswith('azure-'):
90
+ for k in key_list:
91
+ if is_azure_api_key(k): avail_key_list.append(k)
92
+
93
+ if llm_model.startswith('cohere-'):
94
+ for k in key_list:
95
+ if is_cohere_api_key(k): avail_key_list.append(k)
96
+
97
+ if llm_model.startswith('openrouter-'):
98
+ for k in key_list:
99
+ if is_openroute_api_key(k): avail_key_list.append(k)
100
+
101
+ if len(avail_key_list) == 0:
102
+ raise RuntimeError(f"您提供的api-key不满足要求,不包含任何可用于{llm_model}的api-key。您可能选择了错误的模型或请求源(左上角更换模型菜单中可切换openai,azure,claude,cohere等请求源)。")
103
+
104
+ api_key = random.choice(avail_key_list) # 随机负载均衡
105
+ return api_key
106
+
107
+
108
+ def select_api_key_for_embed_models(keys, llm_model):
109
+ import random
110
+ avail_key_list = []
111
+ key_list = keys.split(',')
112
+
113
+ if llm_model.startswith('text-embedding-'):
114
+ for k in key_list:
115
+ if is_openai_api_key(k): avail_key_list.append(k)
116
+
117
+ if len(avail_key_list) == 0:
118
+ raise RuntimeError(f"您提供的api-key不满足要求,不包含任何可用于{llm_model}的api-key。您可能选择了错误的模型或请求源。")
119
+
120
+ api_key = random.choice(avail_key_list) # 随机负载均衡
121
+ return api_key
docker_as_a_service/shared_utils/logging.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from loguru import logger
2
+ import logging
3
+ import sys
4
+ import os
5
+
6
+ def chat_log_filter(record):
7
+ return "chat_msg" in record["extra"]
8
+
9
+ def not_chat_log_filter(record):
10
+ return "chat_msg" not in record["extra"]
11
+
12
+ def formatter_with_clip(record):
13
+ # Note this function returns the string to be formatted, not the actual message to be logged
14
+ # record["extra"]["serialized"] = "555555"
15
+ max_len = 12
16
+ record['function_x'] = record['function'].center(max_len)
17
+ if len(record['function_x']) > max_len:
18
+ record['function_x'] = ".." + record['function_x'][-(max_len-2):]
19
+ record['line_x'] = str(record['line']).ljust(3)
20
+ return '<green>{time:HH:mm}</green> | <cyan>{function_x}</cyan>:<cyan>{line_x}</cyan> | <level>{message}</level>\n'
21
+
22
+ def setup_logging(PATH_LOGGING):
23
+
24
+ admin_log_path = os.path.join(PATH_LOGGING, "admin")
25
+ os.makedirs(admin_log_path, exist_ok=True)
26
+ sensitive_log_path = os.path.join(admin_log_path, "chat_secrets.log")
27
+ regular_log_path = os.path.join(admin_log_path, "console_log.log")
28
+ logger.remove()
29
+ logger.configure(
30
+ levels=[dict(name="WARNING", color="<g>")],
31
+ )
32
+
33
+ logger.add(
34
+ sys.stderr,
35
+ format=formatter_with_clip,
36
+ # format='<green>{time:HH:mm}</green> | <cyan>{function}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>',
37
+ filter=(lambda record: not chat_log_filter(record)),
38
+ colorize=True,
39
+ enqueue=True
40
+ )
41
+
42
+ logger.add(
43
+ sensitive_log_path,
44
+ format='<green>{time:MM-DD HH:mm:ss}</green> | <level>{level: <8}</level> | <cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>',
45
+ rotation="10 MB",
46
+ filter=chat_log_filter,
47
+ enqueue=True,
48
+ )
49
+
50
+ logger.add(
51
+ regular_log_path,
52
+ format='<green>{time:MM-DD HH:mm:ss}</green> | <level>{level: <8}</level> | <cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>',
53
+ rotation="10 MB",
54
+ filter=not_chat_log_filter,
55
+ enqueue=True,
56
+ )
57
+
58
+ logging.getLogger("httpx").setLevel(logging.WARNING)
59
+
60
+ logger.warning(f"所有对话记录将自动保存在本地目录{sensitive_log_path}, 请注意自我隐私保护哦!")
61
+
62
+
63
+ # logger.bind(chat_msg=True).info("This message is logged to the file!")
64
+ # logger.debug(f"debug message")
65
+ # logger.info(f"info message")
66
+ # logger.success(f"success message")
67
+ # logger.error(f"error message")
68
+ # logger.add("special.log", filter=lambda record: "special" in record["extra"])
69
+ # logger.debug("This message is not logged to the file")
docker_as_a_service/shared_utils/map_names.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ mapping_dic = {
3
+ # "qianfan": "qianfan(文心一言大模型)",
4
+ # "zhipuai": "zhipuai(智谱GLM4超级模型🔥)",
5
+ # "gpt-4-1106-preview": "gpt-4-1106-preview(新调优版本GPT-4🔥)",
6
+ # "gpt-4-vision-preview": "gpt-4-vision-preview(识图模型GPT-4V)",
7
+ }
8
+
9
+ rev_mapping_dic = {}
10
+ for k, v in mapping_dic.items():
11
+ rev_mapping_dic[v] = k
12
+
13
+ def map_model_to_friendly_names(m):
14
+ if m in mapping_dic:
15
+ return mapping_dic[m]
16
+ return m
17
+
18
+ def map_friendly_names_to_model(m):
19
+ if m in rev_mapping_dic:
20
+ return rev_mapping_dic[m]
21
+ return m
22
+
23
+ def read_one_api_model_name(model: str):
24
+ """return real model name and max_token.
25
+ """
26
+ max_token_pattern = r"\(max_token=(\d+)\)"
27
+ match = re.search(max_token_pattern, model)
28
+ if match:
29
+ max_token_tmp = match.group(1) # 获取 max_token 的值
30
+ max_token_tmp = int(max_token_tmp)
31
+ model = re.sub(max_token_pattern, "", model) # 从原字符串中删除 "(max_token=...)"
32
+ else:
33
+ max_token_tmp = 4096
34
+ return model, max_token_tmp
docker_as_a_service/shared_utils/text_mask.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ from functools import lru_cache
3
+
4
+ # 这段代码是使用Python编程语言中的re模块,即正则表达式库,来定义了一个正则表达式模式。
5
+ # 这个模式被编译成一个正则表达式对象,存储在名为const_extract_exp的变量中,以便于后续快速的匹配和查找操作。
6
+ # 这里解释一下正则表达式中的几个特殊字符:
7
+ # - . 表示任意单一字符。
8
+ # - * 表示前一个字符可以出现0次或多次。
9
+ # - ? 在这里用作非贪婪匹配,也就是说它会匹配尽可能少的字符。在(.*?)中,它确保我们匹配的任意文本是尽可能短的,也就是说,它会在</show_llm>和</show_render>标签之前停止匹配。
10
+ # - () 括号在正则表达式中表示捕获组。
11
+ # - 在这个例子中,(.*?)表示捕获任意长度的文本,直到遇到括号外部最近的限定符,即</show_llm>和</show_render>。
12
+
13
+ # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-==-=-=-=/1=-=-=-=-=-=-=-=-=-=-=-=-=-=/2-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
14
+ const_extract_re = re.compile(
15
+ r"<gpt_academic_string_mask><show_llm>(.*?)</show_llm><show_render>(.*?)</show_render></gpt_academic_string_mask>"
16
+ )
17
+ # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-==-=-=-=-=-=/1=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-/2-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
18
+ const_extract_langbased_re = re.compile(
19
+ r"<gpt_academic_string_mask><lang_english>(.*?)</lang_english><lang_chinese>(.*?)</lang_chinese></gpt_academic_string_mask>",
20
+ flags=re.DOTALL,
21
+ )
22
+
23
+ @lru_cache(maxsize=128)
24
+ def apply_gpt_academic_string_mask(string, mode="show_all"):
25
+ """
26
+ 当字符串中有掩码tag时(<gpt_academic_string_mask><show_...>),根据字符串要给谁看(大模型,还是web渲染),对字符串进行处理,返回处理后的字符串
27
+ 示意图:https://mermaid.live/edit#pako:eNqlkUtLw0AUhf9KuOta0iaTplkIPlpduFJwoZEwJGNbzItpita2O6tF8QGKogXFtwu7cSHiq3-mk_oznFR8IYLgrGbuOd9hDrcCpmcR0GDW9ubNPKaBMDauuwI_A9M6YN-3y0bODwxsYos4BdMoBrTg5gwHF-d0mBH6-vqFQe58ed5m9XPW2uteX3Tubrj0ljLYcwxxR3h1zB43WeMs3G19yEM9uapDMe_NG9i2dagKw1Fee4c1D9nGEbtc-5n6HbNtJ8IyHOs8tbs7V2HrlDX2w2Y7XD_5haHEtQiNsOwfMVa_7TzsvrWIuJGo02qTrdwLk9gukQylHv3Afv1ML270s-HZUndrmW1tdA-WfvbM_jMFYuAQ6uCCxVdciTJ1CPLEITpo_GphypeouzXuw6XAmyi7JmgBLZEYlHwLB2S4gHMUO-9DH7tTnvf1CVoFFkBLSOk4QmlRTqpIlaWUHINyNFXjaQWpCYRURUKiWovBYo8X4ymEJFlECQUpqaQkJmuvWygPpg
28
+ """
29
+ if not string:
30
+ return string
31
+ if "<gpt_academic_string_mask>" not in string: # No need to process
32
+ return string
33
+
34
+ if mode == "show_all":
35
+ return string
36
+ if mode == "show_llm":
37
+ string = const_extract_re.sub(r"\1", string)
38
+ elif mode == "show_render":
39
+ string = const_extract_re.sub(r"\2", string)
40
+ else:
41
+ raise ValueError("Invalid mode")
42
+ return string
43
+
44
+
45
+ @lru_cache(maxsize=128)
46
+ def build_gpt_academic_masked_string(text_show_llm="", text_show_render=""):
47
+ """
48
+ 根据字符串要给谁看(大模型,还是web渲染),生成带掩码tag的字符串
49
+ """
50
+ return f"<gpt_academic_string_mask><show_llm>{text_show_llm}</show_llm><show_render>{text_show_render}</show_render></gpt_academic_string_mask>"
51
+
52
+
53
+ @lru_cache(maxsize=128)
54
+ def apply_gpt_academic_string_mask_langbased(string, lang_reference):
55
+ """
56
+ 当字符串中有掩码tag时(<gpt_academic_string_mask><lang_...>),根据语言,选择提示词,对字符串进行处理,返回处理后的字符串
57
+ 例如,如果lang_reference是英文,那么就只显示英文提示词,中文提示词就不显示了
58
+ 举例:
59
+ 输入1
60
+ string = "注意,lang_reference这段文字是:<gpt_academic_string_mask><lang_english>英语</lang_english><lang_chinese>中文</lang_chinese></gpt_academic_string_mask>"
61
+ lang_reference = "hello world"
62
+ 输出1
63
+ "注意,lang_reference这段文字是:英语"
64
+
65
+ 输入2
66
+ string = "注意,lang_reference这段文字是中文" # 注意这里没有掩码tag,所以不会被处理
67
+ lang_reference = "hello world"
68
+ 输出2
69
+ "注意,lang_reference这段文字是中文" # 原样返回
70
+ """
71
+
72
+ if "<gpt_academic_string_mask>" not in string: # No need to process
73
+ return string
74
+
75
+ def contains_chinese(string):
76
+ chinese_regex = re.compile(u'[\u4e00-\u9fff]+')
77
+ return chinese_regex.search(string) is not None
78
+
79
+ mode = "english" if not contains_chinese(lang_reference) else "chinese"
80
+ if mode == "english":
81
+ string = const_extract_langbased_re.sub(r"\1", string)
82
+ elif mode == "chinese":
83
+ string = const_extract_langbased_re.sub(r"\2", string)
84
+ else:
85
+ raise ValueError("Invalid mode")
86
+ return string
87
+
88
+
89
+ @lru_cache(maxsize=128)
90
+ def build_gpt_academic_masked_string_langbased(text_show_english="", text_show_chinese=""):
91
+ """
92
+ 根据语言,选择提示词,对字符串进行处理,返回处理后的字符串
93
+ """
94
+ return f"<gpt_academic_string_mask><lang_english>{text_show_english}</lang_english><lang_chinese>{text_show_chinese}</lang_chinese></gpt_academic_string_mask>"
95
+
96
+
97
+ if __name__ == "__main__":
98
+ # Test
99
+ input_string = (
100
+ "你好\n"
101
+ + build_gpt_academic_masked_string(text_show_llm="mermaid", text_show_render="")
102
+ + "你好\n"
103
+ )
104
+ print(
105
+ apply_gpt_academic_string_mask(input_string, "show_llm")
106
+ ) # Should print the strings with 'abc' in place of the academic mask tags
107
+ print(
108
+ apply_gpt_academic_string_mask(input_string, "show_render")
109
+ ) # Should print the strings with 'xyz' in place of the academic mask tags