|
import torch |
|
from modules import config |
|
from modules import generate_audio as generate |
|
|
|
from functools import lru_cache |
|
from typing import Callable |
|
|
|
from modules.api.Api import APIManager |
|
|
|
from modules.api.impl import ( |
|
base_api, |
|
tts_api, |
|
ssml_api, |
|
google_api, |
|
openai_api, |
|
refiner_api, |
|
) |
|
|
|
torch._dynamo.config.cache_size_limit = 64 |
|
torch._dynamo.config.suppress_errors = True |
|
torch.set_float32_matmul_precision("high") |
|
|
|
|
|
def create_api():
    """Build an APIManager and let every API implementation register on it.

    Modules are set up in a fixed order: base, tts, ssml, google, openai,
    refiner. Each module's ``setup`` receives the shared manager instance.

    Returns:
        The populated ``APIManager``.
    """
    manager = APIManager()
    registrars = (
        base_api,
        tts_api,
        ssml_api,
        google_api,
        openai_api,
        refiner_api,
    )
    for registrar in registrars:
        registrar.setup(manager)
    return manager
|
|
|
|
|
def conditional_cache(condition: Callable, *, maxsize=None):
    """Return a decorator that memoizes a function only when *condition* holds.

    For each call, ``condition(*args, **kwargs)`` is evaluated with the same
    arguments as the decorated function. When it is truthy, the call goes
    through a ``functools.lru_cache``. Otherwise the function is called
    directly and nothing is stored.

    Args:
        condition: Predicate taking the decorated function's positional and
            keyword arguments.
        maxsize: Maximum number of cached results, forwarded to
            ``functools.lru_cache``. ``None`` (the default, matching the
            previous behaviour) means the cache is unbounded.

    Returns:
        A decorator. The wrapper it produces keeps the decorated function's
        name and docstring and exposes ``cache_info()`` / ``cache_clear()``
        from the underlying ``lru_cache``.

    Note:
        Arguments of calls that are cached must be hashable; otherwise
        ``lru_cache`` raises ``TypeError``.
    """
    from functools import wraps

    def decorator(func):
        @lru_cache(maxsize=maxsize)
        def cached_func(*args, **kwargs):
            return func(*args, **kwargs)

        @wraps(func)
        def wrapper(*args, **kwargs):
            if condition(*args, **kwargs):
                return cached_func(*args, **kwargs)
            return func(*args, **kwargs)

        # Let callers inspect or reset the cache.
        wrapper.cache_info = cached_func.cache_info
        wrapper.cache_clear = cached_func.cache_clear
        return wrapper

    return decorator


if __name__ == "__main__":
    # Script entry point: parse CLI flags, configure the app, start uvicorn.
    import argparse

    import uvicorn

    parser = argparse.ArgumentParser(
        description="Start the FastAPI server with command line arguments"
    )
    parser.add_argument(
        "--host", type=str, default="0.0.0.0", help="Host to run the server on"
    )
    parser.add_argument(
        "--port", type=int, default=8000, help="Port to run the server on"
    )
    parser.add_argument(
        "--reload", action="store_true", help="Enable auto-reload for development"
    )
    parser.add_argument("--compile", action="store_true", help="Enable model compile")
    parser.add_argument(
        "--lru_size",
        type=int,
        default=64,
        help="Set the size of the request cache pool, set it to 0 will disable lru_cache",
    )
    parser.add_argument(
        "--cors_origin",
        type=str,
        default="*",
        help="Allowed CORS origins, comma-separated. Use '*' to allow all origins.",
    )

    args = parser.parse_args()

    # Make the parsed CLI arguments available to the rest of the application.
    config.args = args

    if args.compile:
        print("Model compile is enabled")
        config.enable_model_compile = True

    def should_cache(*_args, **kwargs):
        # Cache only when both seeds are given explicitly. A seed of -1
        # presumably means "random", so caching would replay one sample.
        # NOTE: seeds are read from keyword arguments only; positional seeds
        # are treated as -1 (not cached).
        spk_seed = kwargs.get("spk_seed", -1)
        infer_seed = kwargs.get("infer_seed", -1)
        return spk_seed != -1 and infer_seed != -1

    if args.lru_size > 0:
        config.lru_size = args.lru_size
        # Bound the cache by --lru_size. Previously lru_cache(None) made it
        # unbounded no matter what this flag said.
        # NOTE(review): only callers that look up `generate.generate_audio`
        # at call time see the cached version. Modules that already imported
        # the function by name (the API modules are imported above) keep the
        # uncached one — confirm against those modules.
        generate.generate_audio = conditional_cache(
            should_cache, maxsize=args.lru_size
        )(generate.generate_audio)

    api = create_api()
    config.api = api

    # Accept a comma-separated list so several explicit origins can be
    # allowed. A single value (including the default "*") behaves as before.
    cors_origins = [
        origin.strip() for origin in args.cors_origin.split(",") if origin.strip()
    ]
    if cors_origins:
        api.set_cors(allow_origins=cors_origins)

    # uvicorn can only reload an app given as an import string. Passed an app
    # object with reload=True, it logs a warning and exits. The app here is
    # built at runtime, so reload cannot be honoured.
    reload = args.reload
    if reload:
        print(
            "--reload is ignored: uvicorn can only reload an app passed as an "
            "import string, and this script builds the app at runtime"
        )
        reload = False

    uvicorn.run(api.app, host=args.host, port=args.port, reload=reload)
|
|