Spaces:
Sleeping
Sleeping
import os | |
import sys | |
import time | |
import threading | |
import tempfile | |
import ctypes | |
import gc | |
from fastapi import FastAPI, Depends, HTTPException, Security | |
from fastapi.responses import FileResponse, JSONResponse | |
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm | |
from pydantic import BaseModel | |
from huggingface_hub import hf_hub_download | |
from torch import no_grad, package | |
import uvicorn | |
from accentor import accentification, stress_replace_and_shift | |
import argparse | |
from passlib.context import CryptContext | |
# Public API docs pages (/docs and /redoc) are turned off.
app = FastAPI(docs_url=None, redoc_url=None)

# Keyword arguments forwarded to the model's ``tts`` call on every request.
tts_kwargs = dict(speaker_name="uk", language_name="uk")

# Point the Hugging Face cache at /app/.cache.
# NOTE(review): huggingface_hub is imported before this line runs and appears
# to resolve its cache paths at import time, so hf_hub_download in this
# process may not honour this value — confirm (``cache_dir=`` is explicit).
os.environ["HF_HOME"] = "/app/.cache"

# Bearer-token dependency; "token" is the URL advertised for obtaining one.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
class Auth(BaseModel):
    """Credential record stored in ``fake_data_db`` and returned after authentication."""

    # Key the record is stored under; ``login`` also hands it out as the bearer token.
    api_key: str
    # Hash produced by ``get_password_hash`` (not the plaintext password).
    password: str
# passlib context used for hashing and verifying passwords (bcrypt scheme).
pwd_context = CryptContext(deprecated="auto", schemes=["bcrypt"])
def get_password_hash(password):
    """Return the hash of *password* computed by ``pwd_context``."""
    hashed = pwd_context.hash(password)
    return hashed
def verify_password(plain_password, hashed_password):
    """Check *plain_password* against *hashed_password* using ``pwd_context``."""
    matches = pwd_context.verify(plain_password, hashed_password)
    return matches
# API credentials, read from the environment (e.g. the Space's secrets).
api_key = os.getenv("XCHE_API_KEY")
password = os.getenv("XCHE_PASSWORD")
# Fail fast with a clear message: a missing value would otherwise crash the
# hashing below with an opaque error. An empty API key is rejected because it
# could let an empty bearer token match the store.
if not api_key or password is None:
    raise RuntimeError(
        "XCHE_API_KEY (non-empty) and XCHE_PASSWORD environment variables must be set"
    )
# Single-entry credential store keyed by API key. The password is hashed once
# at startup and only the hash is kept.
fake_data_db = {
    api_key: {
        "api_key": api_key,
        "password": get_password_hash(password),
    }
}
def get_api_key(db, api_key: str):
    """Look up *api_key* in *db* and wrap the stored record in ``Auth``.

    Returns:
        ``Auth(**db[api_key])`` when the key is present, otherwise ``None``.
    """
    if api_key not in db:
        return None
    return Auth(**db[api_key])
def authenticate(fake_db, api_key: str, password: str):
    """Validate an API key / password pair against *fake_db*.

    Returns:
        The matching ``Auth`` record, or ``False`` when the key is unknown or
        the password does not verify against the stored hash.
    """
    record = get_api_key(fake_db, api_key)
    if record and verify_password(password, record.password):
        return record
    return False
# NOTE(review): no route decorator appears on this handler in this copy of the
# file; tokenUrl="token" suggests it is meant to be served at POST /token.
async def login(form_data: OAuth2PasswordRequestForm = Depends()):
    """Exchange an API key (form ``username``) and password for a bearer token.

    Returns:
        ``{"access_token": <api key>, "token_type": "bearer"}`` on success.

    Raises:
        HTTPException: 400 when the key is unknown or the password is wrong.
    """
    record = authenticate(fake_data_db, form_data.username, form_data.password)
    if record:
        return {"access_token": record.api_key, "token_type": "bearer"}
    raise HTTPException(
        status_code=400,
        detail="Incorrect API KEY or Password",
        headers={"WWW-Authenticate": "Bearer"},
    )
def check_api_token(token: str = Depends(oauth2_scheme)):
    """FastAPI dependency that maps the bearer token to its ``Auth`` record.

    Raises:
        HTTPException: 403 when the token is not a known API key.
    """
    # NOTE(review): the token accepted here is the raw API key (``login``
    # returns it verbatim), so anyone holding the key can skip the password.
    record = get_api_key(fake_data_db, token)
    if record:
        return record
    raise HTTPException(status_code=403, detail="Invalid or missing API Key")
def trim_memory():
    """Run a full garbage collection, then ask glibc to release free heap memory.

    ``malloc_trim`` is glibc-specific. When ``libc.so.6`` or the symbol is not
    available (e.g. musl or macOS), the trim step is skipped and only the
    garbage collection runs.
    """
    # Collect first, so memory freed by the collector is included in the trim.
    gc.collect()
    try:
        libc = ctypes.CDLL("libc.so.6")
        libc.malloc_trim(0)
    except (OSError, AttributeError):
        # Deliberate best effort: no glibc or no malloc_trim on this platform.
        pass
def init_models():
    """Download the ``theodotus/tts-vits-lada-uk`` package and load its model.

    Returns:
        dict mapping ``"lada"`` to the object loaded from the package's
        ``tts_models`` / ``model`` pickle.
    """
    # NOTE(review): load_pickle executes pickled code from a third-party Hub
    # repo, and no ``revision=`` is pinned; consider pinning a commit.
    package_file = hf_hub_download("theodotus/tts-vits-lada-uk", "model.pt")
    lada = package.PackageImporter(package_file).load_pickle("tts_models", "model")
    return {"lada": lada}
# NOTE(review): no route decorator appears on this handler in this copy of the
# file, so it is not registered with ``app`` as shown.
async def tts(request: str, api_data: Auth = Depends(check_api_token)):
    """Synthesize speech for *request* and return a URL to the WAV file.

    The text is passed through ``accentification(..., "vocab")`` and
    ``stress_replace_and_shift`` before synthesis with ``models["lada"]``.
    The audio goes into a temporary ``.wav`` file, which a background thread
    deletes after 300 seconds.

    Args:
        request: Text to synthesize.
        api_data: Caller's ``Auth`` record. It is only requested so that the
            ``check_api_token`` dependency enforces authentication.

    Returns:
        JSONResponse ``{"audio_url": ...}`` pointing at ``/download_audio``.
    """
    from urllib.parse import quote

    accented_text = accentification(request, "vocab")
    plussed_text = stress_replace_and_shift(accented_text)
    synt = models["lada"]
    # NOTE(review): synthesis is synchronous inside an ``async def``, so it
    # blocks the event loop (and every other request) while it runs.
    with no_grad():
        wav_data = synt.tts(plussed_text, **tts_kwargs)

    # Create the file only after synthesis succeeded, and remove it if writing
    # fails. A failed request never schedules a cleanup thread, so a file left
    # behind would never be deleted.
    wav_path = None
    try:
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as wav_fp:
            wav_path = wav_fp.name
            synt.save_wav(wav_data, wav_fp)
    except Exception:
        if wav_path is not None:
            try:
                os.remove(wav_path)
            except OSError:
                pass  # best effort; the original error is re-raised below
        raise

    # Daemon thread: a pending 300 s sleep must not hold up interpreter
    # shutdown. Files still pending at exit are simply left in the temp dir.
    threading.Thread(
        target=delete_file_after_delay, args=(wav_path, 300), daemon=True
    ).start()
    # Percent-encode the path so unusual characters in the temp dir can't
    # corrupt the query string. Ordinary temp paths come through unchanged.
    audio_url = (
        "https://pro100sata-xche-audio.hf.space/download_audio?audio_path="
        + quote(wav_path, safe="/")
    )
    return JSONResponse(content={"audio_url": audio_url})
# NOTE(review): no route decorator appears on this handler in this copy of the
# file; ``tts`` builds URLs of the form /download_audio?audio_path=...
async def download_audio(audio_path: str):
    """Serve a WAV file produced by ``tts``.

    ``tts`` writes its output with ``tempfile.NamedTemporaryFile`` in the
    default temp directory. Only existing ``.wav`` files inside that directory
    are served, so the query parameter cannot be used to read arbitrary files
    from the server.

    Args:
        audio_path: File path, as embedded in the ``audio_url`` from ``tts``.

    Raises:
        HTTPException: 404 when the path is outside the temp directory, is not
            a ``.wav`` file, or no longer exists (e.g. it was already cleaned up).
    """
    temp_root = os.path.realpath(tempfile.gettempdir())
    # realpath resolves ".." segments and symlinks before the containment check.
    resolved = os.path.realpath(audio_path)
    try:
        inside_temp = os.path.commonpath([temp_root, resolved]) == temp_root
    except ValueError:  # e.g. paths on different drives (Windows)
        inside_temp = False
    if not (inside_temp and resolved.endswith(".wav") and os.path.isfile(resolved)):
        raise HTTPException(status_code=404, detail="Audio file not found")
    return FileResponse(resolved, media_type='audio/wav')
# Load the model once, at import time; ``tts`` reads ``models["lada"]``.
# NOTE(review): init_models calls hf_hub_download, so importing this module
# may perform a network download.
models = init_models()
def delete_file_after_delay(file_path: str, delay: int):
    """Sleep for *delay* seconds, then delete *file_path* if it still exists.

    Used as a background-thread target. A file that is already gone is
    ignored rather than raising inside the thread.
    """
    time.sleep(delay)
    # EAFP: checking os.path.exists first races with any other remover.
    try:
        os.remove(file_path)
    except FileNotFoundError:
        pass
class ArgParser(argparse.ArgumentParser):
    """Command-line parser for the server's bind address and port.

    ``sys.argv[1:]`` is parsed as soon as the parser is constructed, and the
    resulting namespace is stored on ``self.args``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.add_argument(
            "-s", "--server", type=str, default="0.0.0.0",
            help="Server IP to bind the text-to-speech API to",
        )
        self.add_argument(
            "-p", "--port", type=int, default=7860,
            help="Server port for the text-to-speech API",
        )
        self.args = self.parse_args(sys.argv[1:])
if __name__ == "__main__":
    # Parse host/port from the command line and serve the app without auto-reload.
    cli_args = ArgParser().args
    uvicorn.run(app, host=cli_args.server, port=cli_args.port, reload=False)