Spaces:
Sleeping
Sleeping
File size: 4,301 Bytes
7da5a9a c0deafa 7da5a9a ca466ef 7da5a9a a6f4e74 7da5a9a a6f4e74 dca572a 7da5a9a a6f4e74 7da5a9a a6f4e74 7da5a9a a6f4e74 7da5a9a a6f4e74 7da5a9a a6f4e74 7da5a9a 9a18fdd a6f4e74 7da5a9a a6f4e74 7da5a9a a6f4e74 7da5a9a a6f4e74 15a7075 a6f4e74 7da5a9a a6f4e74 58373da 7da5a9a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 |
import os
import sys
import time
import threading
import tempfile
import ctypes
import gc
from fastapi import FastAPI, Depends, HTTPException, Security
from fastapi.responses import FileResponse, JSONResponse
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from pydantic import BaseModel
from huggingface_hub import hf_hub_download
from torch import no_grad, package
import uvicorn
from accentor import accentification, stress_replace_and_shift
import argparse
from passlib.context import CryptContext
# FastAPI application object; both documentation routes are switched off by
# passing None for `docs_url` and `redoc_url`.
app = FastAPI(docs_url=None, redoc_url=None)
# Set environment variable for Hugging Face cache directory
# NOTE(review): this assignment runs after `huggingface_hub` has already been
# imported at the top of the file -- confirm the library still honors the
# value when it is set this late.
os.environ["HF_HOME"] = "/app/.cache"
# Keyword arguments forwarded unchanged to every `synt.tts(...)` call made by
# the synthesis endpoint.
tts_kwargs = {
"speaker_name": "uk",
"language_name": "uk",
}
# Dependency that extracts the bearer token from incoming requests;
# "token" is the relative URL of the login route declared in this file.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
class Auth(BaseModel):
    """Credential record pairing an API key with its password value.

    Entries of `fake_data_db` are unpacked into this model, so `password`
    holds whatever that store keeps for the key.
    """

    api_key: str
    password: str
# Password hashing context shared by `get_password_hash` and
# `verify_password`; bcrypt is the only configured scheme.
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
def get_password_hash(password):
    """Return the hash of `password` produced by the module-level `pwd_context`."""
    hashed = pwd_context.hash(password)
    return hashed
def verify_password(plain_password, hashed_password):
    """Check `plain_password` against `hashed_password` using `pwd_context`.

    Returns whatever `pwd_context.verify` returns for the pair.
    """
    matches = pwd_context.verify(plain_password, hashed_password)
    return matches
# Credentials come from the environment rather than from the source tree.
api_key = os.getenv("XCHE_API_KEY")
password = os.getenv("XCHE_PASSWORD")
# Fail fast with an actionable message. Without this check an unset password
# is handed to the hashing call as None, and an unset API key is stored under
# the None key, which no string token sent by a client can ever match.
if not api_key or not password:
    raise RuntimeError(
        "XCHE_API_KEY and XCHE_PASSWORD environment variables must both be set"
    )
# Single-entry in-memory credential store, keyed by the API key. Only the
# hash of the password is kept in the store.
fake_data_db = {
    api_key: {
        "api_key": api_key,
        "password": get_password_hash(password),  # Pre-hashed password
    }
}
def get_api_key(db, api_key: str):
    """Look up `api_key` in `db` and wrap the stored record in an `Auth` model.

    Args:
        db: Mapping from API key to a dict of `Auth` field values.
        api_key: Key to look up.

    Returns:
        An `Auth` built from the stored record, or None when the key is
        not present in `db`.
    """
    try:
        record = db[api_key]
    except KeyError:
        return None
    return Auth(**record)
def authenticate(fake_db, api_key: str, password: str):
    """Validate an API key / password pair against `fake_db`.

    Args:
        fake_db: Credential store passed through to `get_api_key`.
        api_key: API key presented by the caller.
        password: Plaintext password presented by the caller.

    Returns:
        The matching `Auth` record when the key exists and the password
        verifies against the stored value; otherwise False.
    """
    record = get_api_key(fake_db, api_key)
    if record and verify_password(password, record.password):
        return record
    return False
@app.post("/token")
async def login(form_data: OAuth2PasswordRequestForm = Depends()):
    """Exchange form credentials for a bearer token.

    The form's `username` field carries the API key. On success the returned
    access token is the API key itself.

    Raises:
        HTTPException: 400 when the key / password pair does not authenticate.
    """
    api_data = authenticate(fake_data_db, form_data.username, form_data.password)
    if api_data:
        return {"access_token": api_data.api_key, "token_type": "bearer"}
    raise HTTPException(
        status_code=400,
        detail="Incorrect API KEY or Password",
        headers={"WWW-Authenticate": "Bearer"},
    )
def check_api_token(token: str = Depends(oauth2_scheme)):
    """Dependency that resolves a bearer token to its `Auth` record.

    The token is looked up directly as a key of `fake_data_db`.

    Raises:
        HTTPException: 403 when the token is not a known API key.
    """
    api_data = get_api_key(fake_data_db, token)
    if api_data:
        return api_data
    raise HTTPException(status_code=403, detail="Invalid or missing API Key")
def trim_memory():
    """Run a garbage-collection pass, then ask glibc to trim its heap.

    Trimming is a best-effort optimization: on platforms where `libc.so.6`
    cannot be loaded or does not export `malloc_trim`, only the
    garbage-collection pass is performed.
    """
    # Collect first so that memory released by the collector is already free
    # by the time the allocator is asked to trim.
    gc.collect()
    try:
        libc = ctypes.CDLL("libc.so.6")
        libc.malloc_trim(0)
    except (OSError, AttributeError):
        # OSError: the library could not be loaded; AttributeError: it has no
        # `malloc_trim` symbol. Either way there is nothing to trim here.
        return
def init_models():
    """Fetch the packaged TTS model and return it in a name-keyed dict.

    Downloads `model.pt` from the `theodotus/tts-vits-lada-uk` repository via
    `hf_hub_download`, then loads the `model` pickle from the `tts_models`
    package inside it.

    Returns:
        A dict with the single key "lada" mapped to the loaded model.
    """
    checkpoint_path = hf_hub_download("theodotus/tts-vits-lada-uk", "model.pt")
    lada = package.PackageImporter(checkpoint_path).load_pickle("tts_models", "model")
    return {"lada": lada}
@app.post("/create_audio")
async def tts(request: str, api_data: Auth = Depends(check_api_token)):
    """Synthesize speech for `request` and return a download URL for the WAV.

    The text is passed through `accentification` and
    `stress_replace_and_shift` before being handed to the "lada" model.
    The audio is written to a temporary `.wav` file, and a background thread
    is started that deletes it after 300 seconds.

    Args:
        request: Text to synthesize.
        api_data: Credentials resolved by `check_api_token`; present only to
            enforce authentication.

    Returns:
        JSONResponse whose `audio_url` points at the `/download_audio` route
        with the temporary file path as its `audio_path` query parameter.
    """
    # NOTE(review): synthesis runs synchronously inside this coroutine, so the
    # event loop is occupied for the duration of the call -- confirm that is
    # acceptable for the expected load.
    accented_text = accentification(request, "vocab")
    plussed_text = stress_replace_and_shift(accented_text)
    synt = models["lada"]
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as wav_fp:
        try:
            with no_grad():
                wav_data = synt.tts(plussed_text, **tts_kwargs)
                synt.save_wav(wav_data, wav_fp)
        except Exception:
            # delete=False means nothing else removes the file, and the
            # delayed-cleanup thread below is never reached on failure, so
            # remove the partial file here before propagating the error.
            wav_fp.close()
            try:
                os.remove(wav_fp.name)
            except OSError:
                pass
            raise
        threading.Thread(target=delete_file_after_delay, args=(wav_fp.name, 300)).start()
    return JSONResponse(content={"audio_url": f"https://pro100sata-xche-audio.hf.space/download_audio?audio_path={wav_fp.name}"})
@app.get("/download_audio")
async def download_audio(audio_path: str):
    """Serve a previously generated WAV file from the temporary directory.

    `audio_path` comes straight from the query string of an unauthenticated
    request, so it is validated before use. Only regular files that end in
    `.wav` and sit directly inside the system temporary directory (where
    `/create_audio` writes its output) are served.

    Args:
        audio_path: File path as handed out in the `audio_url` of
            `/create_audio`.

    Raises:
        HTTPException: 404 when the path is outside the temporary directory,
            is not a `.wav` file, or does not exist.
    """
    # Resolve symlinks and `..` segments on both sides so the comparison is
    # made on canonical paths; a symlink inside the temp directory that
    # points elsewhere resolves to a different parent and is rejected.
    temp_root = os.path.realpath(tempfile.gettempdir())
    resolved = os.path.realpath(audio_path)
    is_allowed = (
        os.path.dirname(resolved) == temp_root
        and resolved.endswith(".wav")
        and os.path.isfile(resolved)
    )
    if not is_allowed:
        raise HTTPException(status_code=404, detail="Audio file not found")
    return FileResponse(resolved, media_type='audio/wav')
# Load the models once at import time; the `/create_audio` handler reads the
# "lada" entry of this dict on every request.
models = init_models()
def delete_file_after_delay(file_path: str, delay: int):
    """Sleep for `delay` seconds, then remove `file_path` if it still exists.

    Intended to run in a background thread, since it blocks for the whole
    delay.

    Args:
        file_path: Path of the file to remove.
        delay: Seconds to wait before removing it.
    """
    time.sleep(delay)
    # Attempt the removal directly instead of checking for existence first:
    # the file can disappear between a check and the removal.
    try:
        os.remove(file_path)
    except FileNotFoundError:
        # Already gone, so there is nothing left to clean up.
        pass
class ArgParser(argparse.ArgumentParser):
    """Argument parser that registers the server options and parses at once.

    Constructing an instance parses `sys.argv[1:]` immediately; the resulting
    namespace is exposed as `self.args` with `server` and `port` attributes.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        option_specs = (
            (
                ("-s", "--server"),
                {
                    "type": str,
                    "default": "0.0.0.0",
                    "help": "Server IP for HF LLM Chat API",
                },
            ),
            (
                ("-p", "--port"),
                {
                    "type": int,
                    "default": 7860,
                    "help": "Server Port for HF LLM Chat API",
                },
            ),
        )
        for flags, options in option_specs:
            self.add_argument(*flags, **options)
        self.args = self.parse_args(sys.argv[1:])
if __name__ == "__main__":
    # Script entry point: read host and port from the command line, then
    # serve the app with uvicorn (auto-reload off).
    args = ArgParser().args
    uvicorn.run(
        app,
        host=args.server,
        port=args.port,
        reload=False,
    )
|