|
import fnmatch |
|
import logging |
|
|
|
from fastapi import FastAPI |
|
from fastapi.middleware.cors import CORSMiddleware |
|
from fastapi.staticfiles import StaticFiles |
|
|
|
|
|
def is_excluded(path, exclude_patterns): |
|
""" |
|
检查路径是否被排除 |
|
|
|
:param path: 需要检查的路径 |
|
:param exclude_patterns: 包含通配符的排除路径列表 |
|
:return: 如果路径被排除,返回 True;否则返回 False |
|
""" |
|
for pattern in exclude_patterns: |
|
if fnmatch.fnmatch(path, pattern): |
|
print(path, pattern) |
|
return True |
|
return False |
|
|
|
|
|
class APIManager: |
|
def __init__(self, app: FastAPI, exclude_patterns=[]): |
|
self.app = app |
|
self.registered_apis = {} |
|
self.logger = logging.getLogger(__name__) |
|
self.exclude = exclude_patterns |
|
|
|
def is_excluded(self, path): |
|
return is_excluded(path, self.exclude) |
|
|
|
def set_cors( |
|
self, |
|
allow_origins: list = ["*"], |
|
allow_credentials: bool = True, |
|
allow_methods: list = ["*"], |
|
allow_headers: list = ["*"], |
|
): |
|
|
|
self.app.middleware_stack = None |
|
self.app.add_middleware( |
|
CORSMiddleware, |
|
allow_origins=allow_origins, |
|
allow_credentials=allow_credentials, |
|
allow_methods=allow_methods, |
|
allow_headers=allow_headers, |
|
) |
|
self.app.build_middleware_stack() |
|
|
|
def setup_playground(self): |
|
app = self.app |
|
app.mount( |
|
"/playground", |
|
StaticFiles(directory="playground", html=True), |
|
name="playground", |
|
) |
|
|
|
def get(self, path: str, **kwargs): |
|
def decorator(func): |
|
if self.is_excluded(path): |
|
return func |
|
|
|
self.app.get(path, **kwargs)(func) |
|
|
|
self.registered_apis[path] = func |
|
self.logger.info(f"Registered API: GET {path}") |
|
|
|
return func |
|
|
|
return decorator |
|
|
|
def post(self, path: str, **kwargs): |
|
def decorator(func): |
|
if self.is_excluded(path): |
|
return func |
|
|
|
self.app.post(path, **kwargs)(func) |
|
|
|
self.registered_apis[path] = func |
|
self.logger.info(f"Registered API: POST {path}") |
|
|
|
return func |
|
|
|
return decorator |
|
|