File size: 3,536 Bytes
b2458f3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 |
from io import BytesIO
from typing import Dict, List
import torch
from av import open as avopen
from flask import Flask, request, Response
from scipy.io import wavfile
import utils
from config import config
from infer import infer, get_net_g, latest_version
# Flask Init
app = Flask(__name__)
# Emit non-ASCII characters verbatim in JSON responses.
# Flask < 2.3 reads the JSON_AS_ASCII config key, while Flask >= 2.2 uses the
# JSON provider's ensure_ascii attribute (2.3 removed the config key), so set
# both to keep the intent on every Flask version.
app.config["JSON_AS_ASCII"] = False
if hasattr(app, "json"):
    app.json.ensure_ascii = False
def replace_punctuation(text, i=2):
    """Return ``text`` with every mark from ",。?!" repeated ``i`` times.

    Each occurrence of one of those characters becomes ``i`` copies of
    itself; all other characters are left untouched (``i=0`` removes them).
    """
    # One translation pass replaces the sequential str.replace calls; this is
    # equivalent because each mark only ever expands to copies of itself.
    repeat_table = str.maketrans({mark: mark * i for mark in ",。?!"})
    return text.translate(repeat_table)
def wav2(i, o, format):
    """Transcode the audio in ``i`` into container ``format`` written to ``o``.

    Args:
        i: input path or readable file-like object opened with PyAV
            (``main`` passes an in-memory WAV buffer).
        o: output path or writable file-like object.
        format: output container name. For "ogg" the audio stream is encoded
            with the "libvorbis" codec; otherwise the codec name is the same
            as the container name.

    Both containers are closed even if decoding, encoding or muxing raises.
    """
    # The container keeps the requested format; only the codec name differs.
    codec = "libvorbis" if format == "ogg" else format
    # Exiting the with-statement closes ``out`` first, then ``inp``, matching
    # the original close order, and also runs on error.
    with avopen(i, "rb") as inp, avopen(o, "wb", format=format) as out:
        ostream = out.add_stream(codec)
        for frame in inp.decode(audio=0):
            for packet in ostream.encode(frame):
                out.mux(packet)
        # Encoding None flushes packets still buffered inside the encoder.
        for packet in ostream.encode(None):
            out.mux(packet)
# Per-model state. Each list is indexed by the position of the model in
# config.server_config.models, which is the "model" query parameter.
net_g_List = []
hps_List = []
# Per-model speaker dictionary, mapping speaker id -> speaker name.
# Usage: chr_name = chrsMap[model_id][chr_id]
chrsMap: List[Dict[int, str]] = list()
# Load every configured model.
models = config.server_config.models
for model in models:
    hps = utils.get_hparams_from_file(model["config"])
    hps_List.append(hps)
    # Invert spk2id (name -> id) so a numeric speaker id can be resolved to a name.
    chrsMap.append({cid: name for name, cid in hps.data.spk2id.items()})
    # Hyperparameters without a `version` attribute fall back to latest_version.
    version = hps.version if hasattr(hps, "version") else latest_version
    net_g_List.append(
        get_net_g(
            model_path=model["model"],
            # Previously computed but never passed, so the fallback had no effect.
            version=version,
            device=model["device"],
            hps=hps,
        )
    )
@app.route("/")
def main():
    """Synthesize speech for the query-string parameters and return the audio.

    Query parameters (read from ``request.args``):
        model: required integer index into the loaded models.
        speaker: speaker name, passed to ``infer`` as ``sid``.
        speaker_id: decimal speaker id. When present it overrides ``speaker``
            through ``chrsMap[model]``.
        text: required text to synthesize. Every literal "/n" sequence is
            removed, and the result must be shorter than 250 characters.
        sdp_ratio, noise, noisew, length: floats forwarded to ``infer``
            (defaults 0.2, 0.5, 0.6 and 1.2). ``length`` must be < 2.
        language: must be "SH" or "ZH".
        format: "wav" (default), "mp3" or "ogg".

    Returns:
        A plain-text error message for invalid input. Otherwise a
        ``Response`` whose body is the encoded audio.
    """
    try:
        model = int(request.args.get("model"))
        sdp_ratio = float(request.args.get("sdp_ratio", 0.2))
        noise = float(request.args.get("noise", 0.5))
        noisew = float(request.args.get("noisew", 0.6))
        length = float(request.args.get("length", 1.2))
    except (TypeError, ValueError):
        # int(None) raises TypeError for a missing model; malformed numbers
        # raise ValueError. Other exceptions are no longer swallowed.
        return "Invalid Parameter"
    speaker = request.args.get("speaker", "")  # speaker name
    speaker_id = request.args.get("speaker_id", None)  # explicit speaker id
    text = request.args.get("text")
    language = request.args.get("language")
    fmt = request.args.get("format", "wav")

    if text is None:
        return "Missing Parameter"
    text = text.replace("/n", "")
    if length >= 2:
        return "Too big length"
    if len(text) >= 250:
        return "Too long text"
    if fmt not in ("mp3", "wav", "ogg"):
        return "Invalid Format"
    # NOTE(review): the request language is validated but never passed to
    # infer(), which uses the per-model language from the server config.
    # "SH" may also be a typo for another language code. Confirm both.
    if language not in ("SH", "ZH"):
        return "Invalid language"
    # Reject out-of-range indices instead of failing with IndexError (HTTP 500)
    # or silently picking a model from the end of the list for negatives.
    if not 0 <= model < len(net_g_List):
        return "Invalid Parameter"

    # isdecimal() (unlike isdigit()) only accepts strings int() can parse.
    if speaker_id is not None and speaker_id.isdecimal():
        resolved = chrsMap[model].get(int(speaker_id))
        if resolved is None:
            # Unknown id: previously an unhandled KeyError.
            return "Invalid Parameter"
        speaker = resolved
    if not speaker:
        return "Missing Parameter"

    with torch.no_grad():
        audio = infer(
            text=text,
            sdp_ratio=sdp_ratio,
            noise_scale=noise,
            noise_scale_w=noisew,
            length_scale=length,
            sid=speaker,
            language=models[model]["language"],
            hps=hps_List[model],
            net_g=net_g_List[model],
            device=models[model]["device"],
        )
    with BytesIO() as wav:
        wavfile.write(wav, hps_List[model].data.sampling_rate, audio)
        # Return unoccupied cached blocks held by torch's CUDA allocator.
        torch.cuda.empty_cache()
        if fmt == "wav":
            return Response(wav.getvalue(), mimetype="audio/wav")
        # Rewind so wav2 reads the WAV data from the beginning.
        wav.seek(0, 0)
        with BytesIO() as ofp:
            wav2(wav, ofp, fmt)
            return Response(
                ofp.getvalue(),
                mimetype="audio/mpeg" if fmt == "mp3" else "audio/ogg",
            )
if __name__ == "__main__":
    # Run Flask's built-in development server on the configured port.
    listen_port = config.server_config.port
    app.run(port=listen_port)
|