Spaces:
Sleeping
Sleeping
Yarik
committed on
Commit
·
7da5a9a
1
Parent(s):
3ddf776
Add application file
Browse files- accentor.py +36 -0
- app.py +148 -0
accentor.py
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from ukrainian_word_stress import StressSymbol
|
2 |
+
from ukrainian_accentor_transformer import Accentor
|
3 |
+
|
4 |
+
def stress_replace_and_shift(stressed: str):
    """Convert trailing stress marks into "+" markers placed before the character.

    Every ``StressSymbol.CombiningAcuteAccent`` in *stressed* is first replaced
    with "+". Each "+" is then moved one position to the left, so that it
    precedes the character it used to follow (``"a+b"`` becomes ``"+ab"``).

    A "+" that has no character to jump over -- at the very start of the text
    or directly after another "+" -- is kept where it is. Previously the index
    ``plus_position - 1`` wrapped around to the end of the string in the first
    case and duplicated the marker in the second.

    Args:
        stressed: Text that may contain combining acute accents and/or "+".

    Returns:
        The text with every stress marker shifted in front of its character.
    """
    stressed = stressed.replace(StressSymbol.CombiningAcuteAccent, "+")

    pieces = []
    last = 0  # index of the first character not yet copied to the output

    while True:
        plus_position = stressed.find("+", last)
        if plus_position == -1:
            # No more markers: copy the remaining tail unchanged.
            pieces.append(stressed[last:])
            break

        if plus_position > last:
            # Copy everything up to the marked character, then emit the
            # marker followed by that character.
            pieces.append(stressed[last:plus_position - 1])
            pieces.append("+")
            pieces.append(stressed[plus_position - 1])
        else:
            # Nothing precedes the marker in this segment; keep it in place.
            pieces.append("+")

        last = plus_position + 1

    return "".join(pieces)
|
24 |
+
|
25 |
+
# Single Accentor instance created at import time and reused by every
# accentification() call.
accentor_transformer = Accentor()
|
26 |
+
|
27 |
+
def accentification(sentence: str):
    """Strip existing stress markers from *sentence* and re-accent it.

    Both "+" characters and combining acute accents are removed first, so the
    module-level ``accentor_transformer`` always receives unmarked text.

    Args:
        sentence: Input text, possibly already carrying stress markers.

    Returns:
        Whatever ``accentor_transformer`` returns for the cleaned sentence.
    """
    for marker in ("+", StressSymbol.CombiningAcuteAccent):
        sentence = sentence.replace(marker, "")

    return accentor_transformer(sentence)
|
app.py
ADDED
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
import time
|
4 |
+
import threading
|
5 |
+
import tempfile
|
6 |
+
import ctypes
|
7 |
+
import gc
|
8 |
+
from typing import Optional
|
9 |
+
from fastapi import FastAPI, Depends, HTTPException
|
10 |
+
from fastapi.responses import FileResponse, JSONResponse
|
11 |
+
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
|
12 |
+
from pydantic import BaseModel
|
13 |
+
from huggingface_hub import hf_hub_download
|
14 |
+
from torch import no_grad, package
|
15 |
+
from pydub import AudioSegment
|
16 |
+
import uvicorn
|
17 |
+
from accentor import accentification, stress_replace_and_shift
|
18 |
+
import argparse
|
19 |
+
from passlib.context import CryptContext
|
20 |
+
|
21 |
+
# FastAPI application object; the route decorators below register on it.
app = FastAPI()

# Set environment variable for the Hugging Face cache directory.
# NOTE(review): huggingface_hub is imported above this assignment and may have
# already resolved its cache location -- confirm this setting takes effect.
os.environ["HF_HOME"] = "/app/.cache"

# Keyword arguments forwarded to the synthesizer's tts() call.
tts_kwargs = dict(speaker_name="uk", language_name="uk")

# Bearer-token extractor; clients obtain a token from the "token" endpoint.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
32 |
+
|
33 |
+
class User(BaseModel):
    """Account record built from an entry of the user store."""

    # Login name; also used as the lookup key in the user store.
    username: str
    # Value stored under "password" in the user store (fake_users_db keeps
    # the output of get_password_hash there, not the plain text).
    password: str
|
36 |
+
|
37 |
+
# passlib context shared by get_password_hash() and verify_password();
# configured with bcrypt as the only scheme.
pwd_context = CryptContext(deprecated="auto", schemes=["bcrypt"])
|
39 |
+
|
40 |
+
def get_password_hash(password):
    """Return the hash of *password* produced by the module's ``pwd_context``."""
    hashed = pwd_context.hash(password)
    return hashed
|
42 |
+
|
43 |
+
def verify_password(plain_password, hashed_password):
    """Check *plain_password* against *hashed_password* via ``pwd_context``.

    Returns:
        The result of ``pwd_context.verify`` for the given pair.
    """
    matches = pwd_context.verify(plain_password, hashed_password)
    return matches
|
45 |
+
|
46 |
+
# Load the API account credentials from environment variables
# (configured as Hugging Face secrets).
username = os.getenv("XCHE_SECRET_AUTH")
password = os.getenv("XCHE_SECRET_PASSWORD")

# Fail fast with a clear message instead of hashing None or registering a
# None key. The values themselves are deliberately never printed: the
# username doubles as the bearer token handed out by the /token endpoint.
if username is None or password is None:
    raise RuntimeError(
        "XCHE_SECRET_AUTH and XCHE_SECRET_PASSWORD must both be set"
    )

# In-memory storage for simplicity; in production use a database
fake_users_db = {
    username: {
        "username": username,
        # Only the hash is kept in memory, not the plain-text password.
        "password": get_password_hash(password),
    }
}
|
62 |
+
|
63 |
+
def get_user(db, username: str):
    """Look up *username* in *db* and wrap the stored record in a ``User``.

    Args:
        db: Mapping from username to a dict of ``User`` fields.
        username: Key to look up.

    Returns:
        A ``User`` built from the stored record, or ``None`` when the
        username is not present.
    """
    if username not in db:
        return None
    return User(**db[username])
|
67 |
+
|
68 |
+
def authenticate_user(fake_db, username: str, password: str):
    """Validate a username/password pair against *fake_db*.

    Args:
        fake_db: Mapping from username to a stored user record.
        username: Name supplied by the client.
        password: Plain-text password supplied by the client.

    Returns:
        The matching ``User`` on success, otherwise ``False`` (unknown user
        or password mismatch).
    """
    candidate = get_user(fake_db, username)
    if candidate and verify_password(password, candidate.password):
        return candidate
    return False
|
75 |
+
|
76 |
+
@app.post("/token")
async def login(form_data: OAuth2PasswordRequestForm = Depends()):
    """Exchange form credentials for a bearer token.

    The returned access token is the account's username; check_api_token()
    later resolves the token by looking it up as a username.

    Raises:
        HTTPException: 400 when the username/password pair is rejected.
    """
    account = authenticate_user(
        fake_users_db, form_data.username, form_data.password
    )
    if account:
        return {"access_token": account.username, "token_type": "bearer"}

    raise HTTPException(
        status_code=400,
        detail="Incorrect username or password",
        headers={"WWW-Authenticate": "Bearer"},
    )
|
86 |
+
|
87 |
+
def check_api_token(token: str = Depends(oauth2_scheme)):
    """Dependency that resolves a bearer token to a ``User``.

    The token is looked up in ``fake_users_db`` as if it were a username.

    Raises:
        HTTPException: 403 when no user matches the token.
    """
    matched = get_user(fake_users_db, token)
    if matched:
        return matched
    raise HTTPException(status_code=403, detail="Invalid or missing API Key")
|
92 |
+
|
93 |
+
def trim_memory():
    """Run a garbage collection, then ask glibc to release free heap memory.

    The collection runs first so that memory freed by it is already back on
    the allocator's free lists when ``malloc_trim`` is called. On platforms
    where ``libc.so.6`` or ``malloc_trim`` is not available the trim step is
    skipped; the function is a best-effort optimization and returns nothing.
    """
    gc.collect()
    try:
        libc = ctypes.CDLL("libc.so.6")
        libc.malloc_trim(0)
    except (OSError, AttributeError):
        # Library or symbol missing: nothing to trim on this platform.
        pass
|
97 |
+
|
98 |
+
def init_models():
    """Download the packaged TTS model and return it in a name-keyed dict.

    Fetches ``model.pt`` from the ``theodotus/tts-vits-lada-uk`` repository
    and loads the ``model`` resource of its ``tts_models`` package.

    Returns:
        A dict with a single entry, ``"lada"``, holding the loaded model.
    """
    # NOTE(review): load_pickle unpickles content from the downloaded
    # archive; this is only safe as long as the repository is trusted.
    checkpoint = hf_hub_download("theodotus/tts-vits-lada-uk", "model.pt")
    lada = package.PackageImporter(checkpoint).load_pickle("tts_models", "model")
    return {"lada": lada}
|
104 |
+
|
105 |
+
@app.post("/create_audio")
async def tts(request: str, user: User = Depends(check_api_token)):
    """Synthesize speech for *request* and return a download URL.

    The text is accented, its stress marks are converted to "+" markers, and
    the "lada" model writes the audio into a temporary ``.wav`` file. The
    file is scheduled for deletion 300 seconds after a successful synthesis.

    Args:
        request: Text to synthesize.
        user: Authenticated user; present only to enforce the token check.

    Returns:
        A JSON response whose ``audio_url`` points at the download endpoint
        for the generated file.
    """
    print(request)

    accented_text = accentification(request)
    plussed_text = stress_replace_and_shift(accented_text)

    synt = models["lada"]
    # delete=False: the file must outlive this handler so it can be fetched.
    wav_fp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
    try:
        with wav_fp, no_grad():
            wav_data = synt.tts(plussed_text, **tts_kwargs)
            synt.save_wav(wav_data, wav_fp)
    except Exception:
        # Synthesis failed, so the delayed cleanup below never gets
        # scheduled; remove the file now instead of leaking it.
        try:
            os.remove(wav_fp.name)
        except FileNotFoundError:
            pass
        raise

    # Daemon thread: a pending 300-second sleep must not keep the process
    # alive when the server shuts down.
    threading.Thread(
        target=delete_file_after_delay,
        args=(wav_fp.name, 300),
        daemon=True,
    ).start()
    return JSONResponse(content={"audio_url": f"https://pro100sata-xche-audio.hf.space/download_audio?audio_path={wav_fp.name}"})
|
121 |
+
|
122 |
+
@app.get("/download_audio")
async def download_audio(audio_path: str):
    """Serve a previously generated ``.wav`` file from the temp directory.

    *audio_path* comes straight from the query string, so it is resolved and
    accepted only when it names an existing ``.wav`` file located directly in
    the system temporary directory (where the synthesis endpoint creates its
    files). Any other path is rejected, which prevents this unauthenticated
    endpoint from being used to read arbitrary files.

    Args:
        audio_path: Path of the audio file as handed out in ``audio_url``.

    Raises:
        HTTPException: 404 when the path is outside the temp directory, is
            not a ``.wav`` file, or does not exist.
    """
    temp_root = os.path.realpath(tempfile.gettempdir())
    # realpath collapses ".." segments and symlinks before the check.
    resolved = os.path.realpath(audio_path)

    allowed = (
        os.path.dirname(resolved) == temp_root
        and resolved.endswith(".wav")
        and os.path.isfile(resolved)
    )
    if not allowed:
        raise HTTPException(status_code=404, detail="Audio file not found")

    return FileResponse(resolved, media_type='audio/wav')
|
125 |
+
|
126 |
+
# Load the TTS models once at import time; the synthesis endpoint reads
# models["lada"] on every request.
models = init_models()
|
127 |
+
|
128 |
+
def delete_file_after_delay(file_path: str, delay: int):
    """Sleep for *delay* seconds, then delete *file_path* if it still exists.

    Intended to run in a background thread. The removal is attempted
    directly rather than after an existence check, so a file that disappears
    in the meantime cannot raise inside the thread.

    Args:
        file_path: Path of the file to remove.
        delay: Seconds to wait before removing it.
    """
    time.sleep(delay)
    try:
        os.remove(file_path)
    except FileNotFoundError:
        # Already gone; nothing left to clean up.
        pass
|
132 |
+
|
133 |
+
class ArgParser(argparse.ArgumentParser):
    """Argument parser that parses ``sys.argv`` as soon as it is constructed.

    Registers the ``-s/--server`` and ``-p/--port`` options and stores the
    parsed namespace on ``self.args``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # (flags, add_argument keyword arguments) for each supported option.
        option_specs = (
            (
                ("-s", "--server"),
                {
                    "type": str,
                    "default": "0.0.0.0",
                    "help": "Server IP for HF LLM Chat API",
                },
            ),
            (
                ("-p", "--port"),
                {
                    "type": int,
                    "default": 7860,
                    "help": "Server Port for HF LLM Chat API",
                },
            ),
        )
        for flags, settings in option_specs:
            self.add_argument(*flags, **settings)

        # Parse immediately so callers can read ArgParser().args directly.
        self.args = self.parse_args(sys.argv[1:])
|
145 |
+
|
146 |
+
# Script entry point: read host/port from the command line and serve the app.
if __name__ == "__main__":
    args = ArgParser().args
    uvicorn.run(
        app,
        host=args.server,
        port=args.port,
        reload=False,
    )
|