# chattts / launch.py
# zhzluke96 · update · 01e655b
# raw · history blame · 2.76 kB
import torch
from modules import config
from modules import generate_audio as generate
from functools import lru_cache
from typing import Callable
from modules.api.Api import APIManager
from modules.api.impl import (
base_api,
tts_api,
ssml_api,
google_api,
openai_api,
refiner_api,
)
# Process-wide PyTorch settings, applied as soon as this module is imported.
# Set TorchDynamo's cache size limit to 64.
# NOTE(review): presumably to avoid hitting the recompilation limit when the
# model is compiled with varying inputs — confirm against the PyTorch docs.
torch._dynamo.config.cache_size_limit = 64
# Ask TorchDynamo to suppress its errors.
# NOTE(review): presumably this makes failed compilations fall back to eager
# execution instead of raising — confirm.
torch._dynamo.config.suppress_errors = True
# Select the "high" float32 matmul precision mode (see the
# torch.set_float32_matmul_precision documentation for the exact trade-off).
torch.set_float32_matmul_precision("high")
def create_api():
    """Build an ``APIManager`` and register every API implementation on it.

    Each implementation module's ``setup`` function is called with the
    manager, in this order: base, tts, ssml, google, openai, refiner.

    Returns:
        The ``APIManager`` instance after all ``setup`` calls have run.
    """
    manager = APIManager()
    # Registration order is kept exactly as listed here.
    api_modules = (
        base_api,
        tts_api,
        ssml_api,
        google_api,
        openai_api,
        refiner_api,
    )
    for api_module in api_modules:
        api_module.setup(manager)
    return manager
def conditional_cache(condition: Callable, maxsize: "int | None" = None):
    """Return a decorator that memoizes a function only when ``condition`` holds.

    The decorated function is called with the same arguments either way; the
    only difference is whether the result is looked up in / stored into an
    ``lru_cache``.

    Args:
        condition: Called with the same ``*args, **kwargs`` as the decorated
            function. A truthy result routes the call through the cache; a
            falsy result calls the function directly.
        maxsize: Maximum number of cached results. When omitted, the value of
            ``config.lru_size`` is used if it is an integer (this is what the
            ``--lru_size`` command line option stores); otherwise the cache
            is unbounded, as before.

    Returns:
        A decorator. The wrapper it produces keeps the wrapped function's
        metadata and exposes ``cache_info()`` / ``cache_clear()`` of the
        underlying cache.

    Note:
        Cached calls require hashable arguments (an ``lru_cache`` constraint).
    """
    # Local import: the module-level imports only bring in ``lru_cache``.
    from functools import wraps

    def decorator(func):
        limit = maxsize
        if limit is None:
            # Honour the configured cache size instead of always caching
            # without bound, which would let cached results grow forever.
            configured = getattr(config, "lru_size", None)
            limit = configured if isinstance(configured, int) else None

        cached_func = lru_cache(maxsize=limit)(func)

        @wraps(func)
        def wrapper(*args, **kwargs):
            if condition(*args, **kwargs):
                return cached_func(*args, **kwargs)
            return func(*args, **kwargs)

        # Expose cache introspection/reset on the public wrapper.
        wrapper.cache_info = cached_func.cache_info
        wrapper.cache_clear = cached_func.cache_clear
        return wrapper

    return decorator
if __name__ == "__main__":
    # Script entry point: parse the command line, copy the options onto the
    # shared ``config`` module, optionally wrap ``generate_audio`` in a
    # conditional cache, build the API and serve it with uvicorn.
    import argparse

    import uvicorn

    arg_parser = argparse.ArgumentParser(
        description="Start the FastAPI server with command line arguments"
    )

    # (flag, add_argument keyword arguments), registered in this order.
    cli_options = (
        (
            "--host",
            {"type": str, "default": "0.0.0.0", "help": "Host to run the server on"},
        ),
        (
            "--port",
            {"type": int, "default": 8000, "help": "Port to run the server on"},
        ),
        (
            "--reload",
            {"action": "store_true", "help": "Enable auto-reload for development"},
        ),
        (
            "--compile",
            {"action": "store_true", "help": "Enable model compile"},
        ),
        (
            "--lru_size",
            {
                "type": int,
                "default": 64,
                "help": "Set the size of the request cache pool, set it to 0 will disable lru_cache",
            },
        ),
        (
            "--cors_origin",
            {
                "type": str,
                "default": "*",
                "help": "Allowed CORS origins. Use '*' to allow all origins.",
            },
        ),
    )
    for flag, spec in cli_options:
        arg_parser.add_argument(flag, **spec)

    args = arg_parser.parse_args()
    # Make the parsed options available to the rest of the project.
    config.args = args

    if args.compile:
        print("Model compile is enabled")
        config.enable_model_compile = True

    def should_cache(*_positional, **kwargs):
        """Return True only when both seeds were passed by keyword and are not -1.

        Only keyword arguments are inspected; a seed that is missing from
        ``kwargs`` is treated as -1, so the call is not cached.
        """
        return all(
            kwargs.get(seed_name, -1) != -1
            for seed_name in ("spk_seed", "infer_seed")
        )

    if args.lru_size > 0:
        # Record the requested cache size and route generate_audio through
        # the conditional cache; a size of 0 leaves the function unwrapped.
        config.lru_size = args.lru_size
        cache_when_seeded = conditional_cache(should_cache)
        generate.generate_audio = cache_when_seeded(generate.generate_audio)

    api = create_api()
    config.api = api

    if args.cors_origin:
        # The option value is passed through as a single allowed origin.
        api.set_cors(allow_origins=[args.cors_origin])

    # NOTE(review): uvicorn documents that auto-reload needs the application
    # given as an import string; passing the app object together with
    # --reload is likely rejected at startup — confirm and restructure if so.
    uvicorn.run(api.app, host=args.host, port=args.port, reload=args.reload)