Yarik committed on
Commit
7da5a9a
·
1 Parent(s): 3ddf776

Add application file

Browse files
Files changed (2) hide show
  1. accentor.py +36 -0
  2. app.py +148 -0
accentor.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ukrainian_word_stress import StressSymbol
2
+ from ukrainian_accentor_transformer import Accentor
3
+
4
def stress_replace_and_shift(stressed: str):
    """Convert combining-accent stress marks to ``+`` placed before the stressed character.

    Every ``StressSymbol.CombiningAcuteAccent`` in *stressed* is first replaced
    by ``"+"``. Each ``"+"`` is then moved one position to the left, so that it
    precedes the character it originally followed.

    A marker with no character of its own to attach to (at the very start of
    the text, or directly after another marker) is dropped rather than being
    allowed to index backwards into unrelated text.

    Args:
        stressed: Text whose stress is marked with a trailing combining acute
            accent and/or a trailing ``"+"``.

    Returns:
        The text with each stress marker emitted as ``"+"`` before the
        character it followed.
    """
    stressed = stressed.replace(StressSymbol.CombiningAcuteAccent, "+")

    pieces = []
    last = 0  # index of the first character not yet copied to the output

    while True:
        plus_position = stressed.find("+", last)
        if plus_position == -1:
            # No more markers: copy the remaining tail verbatim.
            pieces.append(stressed[last:])
            break
        if plus_position > last:
            # Copy everything up to (but excluding) the marked character,
            # then emit the marker followed by that character.
            pieces.append(stressed[last:plus_position - 1])
            pieces.append("+" + stressed[plus_position - 1])
        # else: the marker has no preceding unconsumed character; skip it so
        # that `plus_position - 1` never wraps to -1 or re-emits a character.
        last = plus_position + 1

    return "".join(pieces)
24
+
25
# Single module-level accentor instance reused by every call below.
accentor_transformer = Accentor()


def accentification(sentence: str):
    """Strip existing stress marks from *sentence* and run it through the accentor.

    Both ``"+"`` and ``StressSymbol.CombiningAcuteAccent`` are removed before
    the text is handed to ``accentor_transformer``; whatever that call returns
    is returned unchanged.
    """
    for marker in ("+", StressSymbol.CombiningAcuteAccent):
        sentence = sentence.replace(marker, "")
    return accentor_transformer(sentence)
app.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import time
4
+ import threading
5
+ import tempfile
6
+ import ctypes
7
+ import gc
8
+ from typing import Optional
9
+ from fastapi import FastAPI, Depends, HTTPException
10
+ from fastapi.responses import FileResponse, JSONResponse
11
+ from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
12
+ from pydantic import BaseModel
13
+ from huggingface_hub import hf_hub_download
14
+ from torch import no_grad, package
15
+ from pydub import AudioSegment
16
+ import uvicorn
17
+ from accentor import accentification, stress_replace_and_shift
18
+ import argparse
19
+ from passlib.context import CryptContext
20
+
21
# ASGI application object that the route decorators in this module attach to.
app = FastAPI()

# Set environment variable for Hugging Face cache directory.
# NOTE(review): huggingface_hub is imported earlier in this file and appears to
# read HF_HOME at import time, so this assignment may come too late to take
# effect - confirm, and move it above the imports if so.
os.environ["HF_HOME"] = "/app/.cache"

# Keyword arguments forwarded to the model's tts() call on every request.
tts_kwargs = dict(speaker_name="uk", language_name="uk")

# Extracts the bearer token from requests; the token is issued by POST /token.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
32
+
33
class User(BaseModel):
    """Account record built from an entry of the in-memory user store."""

    # Login name; also the key under which the record is stored.
    username: str
    # Stored password value; the user store fills this with a hash, and it is
    # what verify_password() is compared against.
    password: str
36
+
37
# Password hashing context (bcrypt scheme) shared by the hashing helpers.
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")


def get_password_hash(password: str) -> str:
    """Return the hash of *password* produced by the module's ``pwd_context``."""
    hashed = pwd_context.hash(password)
    return hashed
42
+
43
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Check *plain_password* against *hashed_password* using ``pwd_context``."""
    is_match = pwd_context.verify(plain_password, hashed_password)
    return is_match
45
+
46
# Load the single account's credentials from the environment
# (configured as Hugging Face secrets).
username = os.getenv("XCHE_SECRET_AUTH")
password = os.getenv("XCHE_SECRET_PASSWORD")

# Fail fast with an explicit message instead of letting the hashing call below
# blow up on None. The credentials are deliberately NOT printed: they must not
# end up in the process logs.
if username is None or password is None:
    raise RuntimeError(
        "XCHE_SECRET_AUTH and XCHE_SECRET_PASSWORD must both be set"
    )

# In-memory storage for simplicity; in production use a database
fake_users_db = {
    username: {
        "username": username,
        # Only the hash is kept in the store, never the plain password.
        "password": get_password_hash(password),
    }
}
62
+
63
def get_user(db, username: str):
    """Look up *username* in *db* and return it as a ``User``.

    Returns ``None`` when *db* has no entry for *username*; otherwise the
    stored mapping is unpacked into the ``User`` constructor.
    """
    if username not in db:
        return None
    return User(**db[username])
67
+
68
def authenticate_user(fake_db, username: str, password: str):
    """Return the matching ``User`` when the credentials are valid, else ``False``.

    The user is looked up with ``get_user``; *password* is then checked against
    the stored value with ``verify_password``. An unknown user and a wrong
    password both yield ``False``.
    """
    candidate = get_user(fake_db, username)
    if candidate and verify_password(password, candidate.password):
        return candidate
    return False
75
+
76
import hashlib
import hmac
import secrets

# Random key generated once per process and used to sign access tokens.
# Tokens therefore stop being valid when the process restarts; clients simply
# request a new one from POST /token.
_TOKEN_SIGNING_KEY = secrets.token_bytes(32)


def _sign_username(username: str) -> str:
    """Return the hex HMAC-SHA256 of *username* under the process signing key."""
    return hmac.new(
        _TOKEN_SIGNING_KEY, username.encode("utf-8"), hashlib.sha256
    ).hexdigest()


@app.post("/token")
async def login(form_data: OAuth2PasswordRequestForm = Depends()):
    """Exchange a username/password form for a bearer access token.

    Raises:
        HTTPException: 400 when the credentials do not match a stored user.

    Returns:
        ``{"access_token": ..., "token_type": "bearer"}``. The token has the
        form ``<username>.<signature>``; the signature can only be produced by
        this process, so knowing the username alone is not enough to call the
        protected endpoints.
    """
    user = authenticate_user(fake_users_db, form_data.username, form_data.password)
    if not user:
        raise HTTPException(
            status_code=400,
            detail="Incorrect username or password",
            headers={"WWW-Authenticate": "Bearer"},
        )
    access_token = f"{user.username}.{_sign_username(user.username)}"
    return {"access_token": access_token, "token_type": "bearer"}


def check_api_token(token: str = Depends(oauth2_scheme)):
    """FastAPI dependency: validate the bearer token and return its ``User``.

    Raises:
        HTTPException: 403 when the token is malformed, its signature does not
            verify, or the named user is not in the user store.
    """
    # rpartition splits on the LAST dot, so usernames containing dots work.
    claimed_username, separator, signature = token.rpartition(".")
    expected_signature = _sign_username(claimed_username)
    # Compare as bytes: compare_digest rejects non-ASCII str input, and the
    # signature here is attacker-controlled.
    signature_ok = bool(separator) and hmac.compare_digest(
        signature.encode("utf-8"), expected_signature.encode("utf-8")
    )
    user = get_user(fake_users_db, claimed_username) if signature_ok else None
    if not user:
        raise HTTPException(status_code=403, detail="Invalid or missing API Key")
    return user
92
+
93
def trim_memory():
    """Call ``malloc_trim(0)`` from ``libc.so.6`` and then run a garbage collection.

    The library is loaded by its fixed name ``libc.so.6`` on every call.
    """
    ctypes.CDLL("libc.so.6").malloc_trim(0)
    gc.collect()
97
+
98
def init_models():
    """Download the TTS package file and load the model it contains.

    Fetches ``model.pt`` from the ``theodotus/tts-vits-lada-uk`` repository,
    opens it with ``torch.package`` and loads the ``model`` pickle from the
    ``tts_models`` package.

    Returns:
        A dict with the loaded model stored under the key ``"lada"``.
    """
    # NOTE(review): load_pickle executes code shipped inside the downloaded
    # file, so the upstream repository has to be trusted.
    checkpoint_path = hf_hub_download("theodotus/tts-vits-lada-uk", "model.pt")
    lada_model = package.PackageImporter(checkpoint_path).load_pickle(
        "tts_models", "model"
    )
    return {"lada": lada_model}
104
+
105
@app.post("/create_audio")
async def tts(request: str, user: User = Depends(check_api_token)):
    """Synthesize speech for *request* and return a URL to the resulting WAV file.

    The text is accented, converted to the ``+`` stress notation and passed to
    the ``"lada"`` model. The audio is written to a temporary ``.wav`` file,
    and a background thread is started to delete that file after 300 seconds.

    Args:
        request: Text to synthesize.
        user: Authenticated caller, resolved by ``check_api_token``.

    Returns:
        A JSON response of the form ``{"audio_url": ...}`` pointing at the
        ``/download_audio`` endpoint.
    """
    from urllib.parse import quote

    print(request)

    accented_text = accentification(request)
    plussed_text = stress_replace_and_shift(accented_text)

    synt = models["lada"]
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as wav_fp:
        try:
            with no_grad():
                wav_data = synt.tts(plussed_text, **tts_kwargs)
                synt.save_wav(wav_data, wav_fp)
        except Exception:
            # delete=False means nothing else reclaims this file, and the
            # delayed-deletion thread below is never started on failure.
            wav_fp.close()
            if os.path.exists(wav_fp.name):
                os.remove(wav_fp.name)
            raise

    threading.Thread(target=delete_file_after_delay, args=(wav_fp.name, 300)).start()
    # Percent-encode the path so characters such as spaces or '&' in the temp
    # directory name cannot corrupt the query string ('/' is left as-is).
    encoded_path = quote(wav_fp.name)
    return JSONResponse(content={"audio_url": f"https://pro100sata-xche-audio.hf.space/download_audio?audio_path={encoded_path}"})
121
+
122
@app.get("/download_audio")
async def download_audio(audio_path: str):
    """Serve a previously generated WAV file.

    *audio_path* comes straight from the query string, so it is treated as
    untrusted: only an existing ``.wav`` file located directly inside the
    system temporary directory (where ``NamedTemporaryFile`` creates files by
    default) is served. Anything else is reported as not found, which prevents
    this endpoint from being used to read arbitrary files.

    Raises:
        HTTPException: 404 when the path is outside the temporary directory,
            is not a ``.wav`` file, or does not exist.
    """
    # Resolve symlinks and ".." segments on both sides before comparing.
    temp_root = os.path.realpath(tempfile.gettempdir())
    resolved_path = os.path.realpath(audio_path)
    is_generated_wav = (
        os.path.dirname(resolved_path) == temp_root
        and resolved_path.endswith(".wav")
        and os.path.isfile(resolved_path)
    )
    if not is_generated_wav:
        raise HTTPException(status_code=404, detail="Audio file not found")
    return FileResponse(resolved_path, media_type='audio/wav')


# Load the TTS model once, when the module is imported.
models = init_models()
127
+
128
def delete_file_after_delay(file_path: str, delay: int):
    """Sleep for *delay* seconds, then delete *file_path* if it still exists.

    Intended to run in a background thread. A file that has already been
    removed is not an error.

    Args:
        file_path: Path of the file to delete.
        delay: Seconds to wait before deleting.
    """
    time.sleep(delay)
    try:
        os.remove(file_path)
    except FileNotFoundError:
        # Already gone. Attempting the removal directly avoids the window
        # between an existence check and the delete.
        pass
132
+
133
class ArgParser(argparse.ArgumentParser):
    """Command-line parser for the server's bind address and port.

    The command line (``sys.argv[1:]``) is parsed during construction and the
    result is exposed as ``self.args`` with the attributes ``server`` and
    ``port``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.add_argument(
            "-s", "--server", type=str, default="0.0.0.0",
            help="Server IP to bind the TTS API to",
        )
        self.add_argument(
            "-p", "--port", type=int, default=7860,
            help="Server port to bind the TTS API to",
        )
        self.args = self.parse_args(sys.argv[1:])


if __name__ == "__main__":
    args = ArgParser().args
    uvicorn.run(app, host=args.server, port=args.port, reload=False)