|
""" |
|
Tests: |
|
|
|
- custom_path false / no user auth: |
|
-- upload file(yes) |
|
-- download file(yes) |
|
-- websocket(yes) |
|
-- block __pycache__ access(yes) |
|
-- rel (yes) |
|
-- abs (yes) |
|
-- block user access(fail) http://localhost:45013/file=gpt_log/admin/chat_secrets.log |
|
-- fix(commit f6bf05048c08f5cd84593f7fdc01e64dec1f584a)-> block successful |
|
|
|
- custom_path yes("/cc/gptac") / no user auth: |
|
-- upload file(yes) |
|
-- download file(yes) |
|
-- websocket(yes) |
|
-- block __pycache__ access(yes) |
|
-- block user access(yes) |
|
|
|
- custom_path yes("/cc/gptac/") / no user auth: |
|
-- upload file(yes) |
|
-- download file(yes) |
|
-- websocket(yes) |
|
-- block user access(yes) |
|
|
|
- custom_path yes("/cc/gptac/") / + user auth: |
|
-- upload file(yes) |
|
-- download file(yes) |
|
-- websocket(yes) |
|
-- block user access(yes) |
|
-- block user-wise access (yes) |
|
|
|
- custom_path no + user auth: |
|
-- upload file(yes) |
|
-- download file(yes) |
|
-- websocket(yes) |
|
-- block user access(yes) |
|
-- block user-wise access (yes) |
|
|
|
- queue concurrency effectiveness:
|
-- upload file(yes) |
|
-- download file(yes) |
|
-- websocket(yes) |
|
""" |
|
|
|
import os, requests, threading, time |
|
import uvicorn |
|
|
|
def _normalize_sensitive_root(root):
    """Return a configured directory in the same relative form request paths use.

    Request paths are passed through ``os.path.relpath`` before being compared,
    so the configured root is normalized the same way. Otherwise a setting such
    as ``"./gpt_log"`` or an absolute path would never prefix-match a request
    path, and the directory would silently lose its protection.
    """
    if not root:
        return root
    try:
        return os.path.relpath(root)
    except ValueError:
        # relpath raises ValueError e.g. on Windows when root is on another drive.
        return os.path.normpath(root)


def _match_sensitive_root(rel_path, roots):
    """Return the first non-empty root in ``roots`` that ``rel_path`` starts with, else None.

    This is a plain string-prefix test, as in the original checks, so e.g.
    ``"gpt_log2/..."`` also counts as under ``"gpt_log"``. That errs on the side
    of restricting access.
    """
    for root in roots:
        if root and rel_path.startswith(root):
            return root
    return None


def _in_allowed_user_dir(rel_path, sensitive_root, allowed_users):
    """Return True if ``rel_path`` lies in ``<sensitive_root>/<user>`` for some allowed user.

    The comparison is component-wise, so it also works when ``sensitive_root``
    has more than one path component. ``None`` entries (an unresolved user) are
    skipped instead of crashing the comparison.
    """
    root_parts = sensitive_root.split(os.sep)
    owner_parts = rel_path.split(os.sep)[:len(root_parts) + 1]
    return any(
        user is not None and owner_parts == root_parts + [user]
        for user in allowed_users
    )


def validate_path_safety(path_or_url, user):
    """Check that a local input file may be processed on behalf of ``user``.

    The path is normalized with ``os.path.relpath`` and then classified:

    * under PATH_LOGGING or PATH_PRIVATE_UPLOAD (checked in that order): allowed
      only inside the subdirectory of ``user`` or of a shared account
      ('autogen', 'arxiv_cache', ``default_user_name``);
    * otherwise starting with ``tests`` or ``build``: allowed;
    * anything else: rejected.

    Returns:
        True when the path is allowed.

    Raises:
        FriendlyException: if the path is in an illegal location or belongs to
            another user.
    """
    from toolbox import get_conf, default_user_name
    from toolbox import FriendlyException
    PATH_PRIVATE_UPLOAD, PATH_LOGGING = get_conf('PATH_PRIVATE_UPLOAD', 'PATH_LOGGING')
    path_or_url = os.path.relpath(path_or_url)
    sensitive_path = _match_sensitive_root(
        path_or_url,
        (_normalize_sensitive_root(PATH_LOGGING), _normalize_sensitive_root(PATH_PRIVATE_UPLOAD)),
    )
    if sensitive_path is None:
        # Not a per-user directory: only the common test/build folders are accepted.
        if path_or_url.startswith(('tests', 'build')):
            return True
        raise FriendlyException(f"输入文件的路径 ({path_or_url}) 存在,但位置非法。请将文件上传后再执行该任务。")
    allowed_users = [user, 'autogen', 'arxiv_cache', default_user_name]
    if _in_allowed_user_dir(path_or_url, sensitive_path, allowed_users):
        return True
    raise FriendlyException(f"输入文件的路径 ({path_or_url}) 存在,但属于其他用户。请将文件上传后再执行该任务。")


def _authorize_user(path_or_url, request, gradio_app):
    """Decide whether the requester may download ``path_or_url``.

    Paths outside PATH_LOGGING / PATH_PRIVATE_UPLOAD are allowed here; further
    filtering is left to the caller. Inside those directories the path must be
    in the subdirectory of the user owning the request's access-token cookie,
    or of a shared account ('autogen', 'arxiv_cache', ``default_user_name``).

    Returns:
        bool: True if access is allowed, False otherwise.
    """
    from toolbox import get_conf, default_user_name
    PATH_PRIVATE_UPLOAD, PATH_LOGGING = get_conf('PATH_PRIVATE_UPLOAD', 'PATH_LOGGING')
    path_or_url = os.path.relpath(path_or_url)
    # PATH_PRIVATE_UPLOAD is tried first: in the original sequential checks, an
    # upload-dir match overrode a logging-dir match.
    sensitive_path = _match_sensitive_root(
        path_or_url,
        (_normalize_sensitive_root(PATH_PRIVATE_UPLOAD), _normalize_sensitive_root(PATH_LOGGING)),
    )
    if sensitive_path is None:
        return True
    token = request.cookies.get("access-token") or request.cookies.get("access-token-unsecure")
    user = gradio_app.tokens.get(token)  # presumably None when the token is unknown
    allowed_users = [user, 'autogen', 'arxiv_cache', default_user_name]
    return _in_allowed_user_dir(path_or_url, sensitive_path, allowed_users)
|
|
|
|
|
class Server(uvicorn.Server):
    """uvicorn server that runs on a background thread instead of blocking the caller."""

    def install_signal_handlers(self):
        # No-op override: the server runs off the main thread, and Python only
        # allows installing signal handlers from the main thread.
        pass

    def run_in_thread(self):
        """Start ``self.run`` on a daemon thread and block until the server has started.

        Raises:
            RuntimeError: if the server thread exits before ``self.started``
                becomes true (i.e. startup failed), instead of polling forever.
        """
        self.thread = threading.Thread(target=self.run, daemon=True)
        self.thread.start()
        while not self.started:
            # Re-read `started` after `is_alive()` so a thread that started and
            # then exited between the two reads is not reported as a failure.
            if not self.thread.is_alive() and not self.started:
                raise RuntimeError("uvicorn server thread exited before the server finished starting up")
            time.sleep(5e-2)

    def close(self):
        """Ask the server loop to exit and wait for its thread, if one was started."""
        self.should_exit = True
        thread = getattr(self, "thread", None)
        if thread is not None:
            thread.join()
|
|
|
|
|
def _install_guarded_file_routes(gradio_app, AUTHENTICATION, CUSTOM_PATH):
    """Replace gradio's file-serving routes with filtered wrappers.

    Removes the ``/file/{path:path}`` and ``/file={path_or_url:path}`` routes
    from ``gradio_app``. It then re-registers those paths (GET, plus HEAD for
    ``/file=``) on one wrapper. The wrapper refuses http(s) links and ``..``
    traversal, then delegates to the original ``/file=`` endpoint with that
    route's dependencies.

    When ``AUTHENTICATION`` is non-empty, the wrapper also enforces per-user
    access through ``_authorize_user``, and an ``/academic_logout`` route is
    added.
    """
    import fastapi
    from fastapi import status
    from fastapi.responses import RedirectResponse

    auth_enabled = len(AUTHENTICATION) > 0
    dependencies = []
    endpoint = None
    for route in list(gradio_app.router.routes):
        if route.path == "/file/{path:path}":
            gradio_app.router.routes.remove(route)
        if route.path == "/file={path_or_url:path}":
            dependencies = route.dependencies
            endpoint = route.endpoint
            gradio_app.router.routes.remove(route)

    # NOTE(review): the "/file/{path:path}" registration uses a handler whose
    # parameter is named `path_or_url`, so FastAPI presumably does not bind the
    # `path` segment to it (it would expect a query parameter instead). This
    # mirrors the original code and is kept unchanged.
    @gradio_app.get("/file/{path:path}", dependencies=dependencies)
    @gradio_app.head("/file={path_or_url:path}", dependencies=dependencies)
    @gradio_app.get("/file={path_or_url:path}", dependencies=dependencies)
    async def file(path_or_url: str, request: fastapi.Request):
        if auth_enabled and not _authorize_user(path_or_url, request, gradio_app):
            return "越权访问!"
        stripped = path_or_url.lstrip().lower()
        if stripped.startswith(("https://", "http://")):
            return "账户密码授权模式下, 禁止链接!"
        # Reject parent-directory traversal with either separator style.
        if '../' in stripped or '..\\' in stripped:
            return "非法路径!"
        return await endpoint(path_or_url, request)

    if auth_enabled:
        @gradio_app.get("/academic_logout")
        async def logout():
            # Clear both auth cookie variants and redirect to the app root.
            response = RedirectResponse(url=CUSTOM_PATH, status_code=status.HTTP_302_FOUND)
            response.delete_cookie('access-token')
            response.delete_cookie('access-token-unsecure')
            return response


def _install_tts_route(gradio_app, TTS_TYPE):
    """Register ``POST /vits`` on ``gradio_app`` to serve text-to-speech audio.

    ``EDGE_TTS`` synthesizes the request's ``text`` with edge-tts and returns
    WAV bytes, transcoded through pydub. ``LOCAL_SOVITS_API`` forwards the raw
    request body to ``GPT_SOVITS_URL`` and relays the response.
    """
    import httpx
    from contextlib import suppress
    from fastapi import Request, HTTPException
    from starlette.responses import Response
    from toolbox import get_conf

    # These headers describe how httpx received the upstream body. httpx has
    # already decoded and fully read that body, so forwarding them would
    # misdescribe the bytes actually sent back.
    stale_headers = {"content-encoding", "content-length", "transfer-encoding", "connection"}

    async def forward_request(request: Request, method: str) -> Response:
        async with httpx.AsyncClient() as client:
            try:
                if TTS_TYPE == "EDGE_TTS":
                    import tempfile
                    import uuid
                    import edge_tts
                    from pydub import AudioSegment
                    payload = await request.json()
                    voice = get_conf("EDGE_TTS_VOICE")
                    tts = edge_tts.Communicate(text=payload['text'], voice=voice)
                    temp_file = os.path.join(tempfile.gettempdir(), f'{uuid.uuid4().hex}.mp3')
                    try:
                        await tts.save(temp_file)
                        try:
                            # Transcode the mp3 to wav in place (pydub relies on ffmpeg for this).
                            mp3_audio = AudioSegment.from_file(temp_file, format="mp3")
                            mp3_audio.export(temp_file, format="wav")
                            with open(temp_file, 'rb') as wav_file:
                                wav_bytes = wav_file.read()
                        except Exception as e:
                            raise RuntimeError("ffmpeg未安装,无法处理EdgeTTS音频。安装方法见`https://github.com/jiaaro/pydub#getting-ffmpeg-set-up`") from e
                    finally:
                        # Delete the temp file even when synthesis or transcoding fails.
                        with suppress(FileNotFoundError):
                            os.remove(temp_file)
                    return Response(content=wav_bytes)
                if TTS_TYPE == "LOCAL_SOVITS_API":
                    TARGET_URL = get_conf("GPT_SOVITS_URL")
                    body = await request.body()
                    resp = await client.post(TARGET_URL, content=body, timeout=60)
                    headers = {k: v for k, v in resp.headers.items() if k.lower() not in stale_headers}
                    return Response(content=resp.content, status_code=resp.status_code, headers=headers)
                # NOTE(review): any other TTS_TYPE falls through and returns None.
            except httpx.RequestError as e:
                raise HTTPException(status_code=400, detail=f"Request to the target service failed: {str(e)}") from e

    @gradio_app.post("/vits")
    async def forward_post_request(request: Request):
        return await forward_request(request, "POST")


def _build_local_url(server_name, PORT, ssl_keyfile, ssl_certfile, CUSTOM_PATH):
    """Return the URL at which the app is reachable from this machine.

    ``0.0.0.0`` is replaced by ``localhost``. The scheme is https when an SSL
    key file is given. A non-root ``CUSTOM_PATH`` is appended with exactly one
    trailing slash.

    Raises:
        ValueError: if ``ssl_keyfile`` is given without ``ssl_certfile``.
    """
    url_host_name = "localhost" if server_name == "0.0.0.0" else server_name
    if ssl_keyfile is not None:
        if ssl_certfile is None:
            raise ValueError(
                "ssl_certfile must be provided if ssl_keyfile is provided."
            )
        scheme = "https"
    else:
        scheme = "http"
    url = f"{scheme}://{url_host_name}:{PORT}/"
    if CUSTOM_PATH != '/':
        url += CUSTOM_PATH.strip('/') + '/'
    return url


def start_app(app_block, CONCURRENT_COUNT, AUTHENTICATION, PORT, SSL_KEYFILE, SSL_CERTFILE):
    """Serve the gradio ``app_block`` through a customized FastAPI + uvicorn stack.

    Steps:

    * configure ``app_block`` and build its gradio FastAPI app;
    * harden the generated routes: drop ``/proxy=``, filter ``/file=``, and
      404 the FastAPI docs endpoints;
    * optionally add the ``/vits`` TTS route;
    * mount everything under ``CUSTOM_PATH`` and start uvicorn on a background
      thread;
    * block in ``app_block.block_thread()``.

    Args:
        app_block: the gradio Blocks to serve.
        CONCURRENT_COUNT: forwarded to ``app_block.queue(concurrency_count=...)``.
        AUTHENTICATION: assigned to ``app_block.auth`` when non-empty; an empty
            value disables auth and the per-user file checks.
        PORT: port to listen on, on all interfaces.
        SSL_KEYFILE, SSL_CERTFILE: TLS key/cert paths, or "" to serve plain HTTP.
    """
    import uvicorn
    import gradio as gr
    from contextlib import asynccontextmanager
    # `Request` is imported unconditionally because the http middleware below
    # annotates with it. Binding it only inside the auth/TTS branches raised
    # UnboundLocalError when both were disabled.
    from fastapi import FastAPI, Request
    from fastapi.responses import FileResponse
    from gradio.routes import App
    from starlette.responses import JSONResponse
    from toolbox import get_conf
    CUSTOM_PATH, PATH_LOGGING = get_conf('CUSTOM_PATH', 'PATH_LOGGING')

    # --- configure the gradio Blocks before building its FastAPI app ---
    app_block: gr.Blocks
    app_block.ssl_verify = False
    app_block.auth_message = '请登录'
    app_block.favicon_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), "docs/logo.png")
    app_block.auth = AUTHENTICATION if len(AUTHENTICATION) != 0 else None
    app_block.blocked_paths = ["config.py", "__pycache__", "config_private.py", "docker-compose.yml", "Dockerfile", f"{PATH_LOGGING}/admin"]
    app_block.dev_mode = False
    app_block.config = app_block.get_config_file()
    app_block.enable_queue = True
    app_block.queue(concurrency_count=CONCURRENT_COUNT)
    app_block.validate_queue_settings()
    app_block.show_api = False
    app_block.config = app_block.get_config_file()
    max_threads = 40
    app_block.max_threads = max(
        app_block._queue.max_thread_count if app_block.enable_queue else 0, max_threads
    )
    app_block.is_colab = False
    app_block.is_kaggle = False
    app_block.is_sagemaker = False

    gradio_app = App.create_app(app_block)

    # Remove gradio's proxy route.
    for route in list(gradio_app.router.routes):
        if route.path == "/proxy={url_path:path}":
            gradio_app.router.routes.remove(route)

    # Replace the file routes to forbid access to sensitive files.
    _install_guarded_file_routes(gradio_app, AUTHENTICATION, CUSTOM_PATH)

    # Optional text-to-speech endpoint.
    TTS_TYPE = get_conf("TTS_TYPE")
    if TTS_TYPE != "DISABLE":
        _install_tts_route(gradio_app, TTS_TYPE)

    @asynccontextmanager
    async def app_lifespan(app):
        # Startup: run the Blocks' startup events when its queue is enabled.
        if gradio_app.get_blocks().enable_queue:
            gradio_app.get_blocks().startup_events()
        yield
        # Shutdown: nothing to clean up.

    fastapi_app = FastAPI(lifespan=app_lifespan)
    fastapi_app.mount(CUSTOM_PATH, gradio_app)

    if CUSTOM_PATH != '/':
        # gradio is not mounted at the root, so serve the favicon there directly.
        @fastapi_app.get("/favicon.ico")
        async def favicon():
            return FileResponse(app_block.favicon_path)

    @fastapi_app.middleware("http")
    async def middleware(request: Request, call_next):
        # Hide FastAPI's auto-generated API documentation endpoints.
        if request.scope['path'] in ["/docs", "/redoc", "/openapi.json"]:
            return JSONResponse(status_code=404, content={"message": "Not Found"})
        return await call_next(request)

    # --- uvicorn ---
    ssl_keyfile = None if SSL_KEYFILE == "" else SSL_KEYFILE
    ssl_certfile = None if SSL_CERTFILE == "" else SSL_CERTFILE
    server_name = "0.0.0.0"
    config = uvicorn.Config(
        fastapi_app,
        host=server_name,
        port=PORT,
        reload=False,
        log_level="warning",
        ssl_keyfile=ssl_keyfile,
        ssl_certfile=ssl_certfile,
    )
    server = Server(config)
    path_to_local_server = _build_local_url(server_name, PORT, ssl_keyfile, ssl_certfile, CUSTOM_PATH)

    server.run_in_thread()

    # --- after server launch: record server state on the Blocks ---
    app_block.server = server
    app_block.server_name = server_name
    app_block.local_url = path_to_local_server
    app_block.protocol = (
        "https"
        if app_block.local_url.startswith("https") or app_block.is_colab
        else "http"
    )

    if app_block.enable_queue:
        app_block._queue.set_url(path_to_local_server)

    # Empty proxy entries keep requests from using environment proxy settings
    # when calling our own local server.
    forbid_proxies = {
        "http": "",
        "https": "",
    }
    requests.get(f"{app_block.local_url}startup-events", verify=app_block.ssl_verify, proxies=forbid_proxies)
    app_block.is_running = True
    app_block.block_thread()