File size: 6,593 Bytes
38ae436 e5563d8 38ae436 b57d37a 38ae436 b57d37a 38ae436 b57d37a eb2c3bb b57d37a 38ae436 eb2c3bb 38ae436 eb2c3bb 38ae436 b57d37a 38ae436 eb2c3bb b57d37a 38ae436 e5563d8 38ae436 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 |
import spaces
import os
import random
import argparse
import torch
import gradio as gr
import numpy as np
import ChatTTS
import se_extractor
from api import BaseSpeakerTTS, ToneColorConverter
# Build the ChatTTS engine once at import time. `chat` is the module-level
# instance that chat_tts() calls for both text refinement and audio synthesis.
print("loading ChatTTS model...")
chat = ChatTTS.Chat()
# presumably loads (and possibly downloads) the model weights — confirm
# against the installed ChatTTS version.
chat.load_models()
def generate_seed():
    """Return a Gradio update payload carrying a fresh random seed.

    The seed is drawn uniformly from 1..100000000 (both ends inclusive) and
    wrapped in the raw ``{"__type__": "update", "value": ...}`` dictionary so
    it can be used directly as the output of a click handler.
    """
    return dict(__type__="update", value=random.randint(1, 100000000))
@spaces.GPU
def chat_tts(text, temperature, top_P, top_K, audio_seed_input, text_seed_input, refine_text_flag, refine_text_input, output_path=None):
    """Synthesize speech for ``text`` with the module-level ChatTTS instance.

    torch is seeded with ``audio_seed_input`` before a random 768-element
    speaker embedding is drawn, then re-seeded with ``text_seed_input`` before
    the optional refinement pass and the audio pass.

    Args:
        text: Text to speak.
        temperature: Sampling temperature forwarded in ``params_infer_code``.
        top_P: ``top_P`` value forwarded in ``params_infer_code``.
        top_K: ``top_K`` value forwarded in ``params_infer_code``.
        audio_seed_input: Seed applied before drawing the speaker embedding.
        text_seed_input: Seed applied before refinement / synthesis.
        refine_text_flag: When truthy, run a refine-only ``chat.infer`` pass
            first and synthesize from its result.
        refine_text_input: Optional refine prompt; when truthy it replaces the
            default ``'[oral_2][laugh_0][break_6]'`` prompt.
        output_path: ``None`` to return the audio, or a file path to write a
            WAV file instead.

    Returns:
        ``[(sample_rate, audio_data), text_data]`` when ``output_path`` is
        ``None``; otherwise ``None`` after the WAV file has been written.
    """
    torch.manual_seed(audio_seed_input)
    rand_spk = torch.randn(768)
    params_infer_code = {
        'spk_emb': rand_spk,
        'temperature': temperature,
        'top_P': top_P,
        'top_K': top_K,
    }
    params_refine_text = {'prompt': '[oral_2][laugh_0][break_6]'}

    torch.manual_seed(text_seed_input)

    if refine_text_flag:
        if refine_text_input:
            params_refine_text['prompt'] = refine_text_input
        # Refinement-only pass: the result replaces `text` for synthesis.
        text = chat.infer(
            text,
            skip_refine_text=False,
            refine_text_only=True,
            params_refine_text=params_refine_text,
            params_infer_code=params_infer_code,
        )

    # Audio pass; refinement is skipped because it already ran (or was not
    # requested) above.
    wav = chat.infer(
        text,
        skip_refine_text=True,
        params_refine_text=params_refine_text,
        params_infer_code=params_infer_code,
    )

    audio_data = np.array(wav[0]).flatten()
    sample_rate = 24000
    text_data = text[0] if isinstance(text, list) else text

    if output_path is None:
        return [(sample_rate, audio_data), text_data]

    # This branch used to call `soundfile.write`, but `soundfile` is not among
    # the file's imports, so it raised NameError. scipy's WAV writer is
    # imported locally to keep the block self-contained.
    from scipy.io import wavfile

    # Cast to float32 so the writer always receives a dtype it supports.
    wavfile.write(output_path, sample_rate, audio_data.astype(np.float32))
    return None
# OpenVoice
# Checkpoint directories (relative paths) for the English base speaker and
# for the tone color converter.
ckpt_base_en = 'checkpoints/base_speakers/EN'
ckpt_converter_en = 'checkpoints/converter'
# Device string handed to both OpenVoice models and used by generate_audio()
# when moving the source speaker embedding.
device = 'cuda:0'
#device = "cpu"
# Both models are constructed from their config.json and then load their
# checkpoint.pth at import time.
base_speaker_tts = BaseSpeakerTTS(f'{ckpt_base_en}/config.json', device=device)
base_speaker_tts.load_ckpt(f'{ckpt_base_en}/checkpoint.pth')
tone_color_converter = ToneColorConverter(f'{ckpt_converter_en}/config.json', device=device)
tone_color_converter.load_ckpt(f'{ckpt_converter_en}/checkpoint.pth')
def generate_audio(text, audio_ref, temperature, top_P, top_K, audio_seed_input, text_seed_input, refine_text_flag, refine_text_input, style_mode="default"):
    """Speak ``text`` and convert the result toward the voice in ``audio_ref``.

    The source audio is written to ``tmp.wav`` — by ``chat_tts`` when
    ``style_mode`` is ``"default"``, otherwise by ``base_speaker_tts.tts`` with
    ``style_mode`` as the speaker style — and then passed through
    ``tone_color_converter.convert`` into ``output.wav``.

    Args:
        text: Text to speak.
        audio_ref: Path of the reference audio whose speaker embedding is
            extracted with ``se_extractor.get_se``.
        temperature, top_P, top_K: Sampling parameters forwarded to
            ``chat_tts`` (used only in the ``"default"`` mode).
        audio_seed_input, text_seed_input: Seeds forwarded to ``chat_tts``.
        refine_text_flag, refine_text_input: Refinement options forwarded to
            ``chat_tts``.
        style_mode: ``"default"`` selects ChatTTS with ``en_default_se.pth``;
            any other value selects the OpenVoice base speaker with
            ``en_style_se.pth``. Previously this name was read without being
            defined anywhere in the function.

    Returns:
        The path of the converted audio file, ``"output.wav"``.

    Raises:
        gr.Error: If no reference audio was supplied.
    """
    if not audio_ref:
        # Fail with a readable message instead of an error from deep inside
        # the embedding extractor.
        raise gr.Error("Please upload a reference audio file to clone.")

    use_chat_tts = style_mode == "default"
    src_path = "tmp.wav"
    save_path = "output.wav"

    se_file = 'en_default_se.pth' if use_chat_tts else 'en_style_se.pth'
    source_se = torch.load(f'{ckpt_base_en}/{se_file}').to(device)
    target_se, _audio_name = se_extractor.get_se(audio_ref, tone_color_converter, target_dir='processed', vad=True)

    # Produce the source audio that the converter will re-color.
    if use_chat_tts:
        # The earlier call passed `text` twice and put a positional argument
        # after `output_path=None`, which is a SyntaxError; the file path
        # belongs in `output_path`.
        chat_tts(
            text,
            temperature,
            top_P,
            top_K,
            audio_seed_input,
            text_seed_input,
            refine_text_flag,
            refine_text_input,
            output_path=src_path,
        )
    else:
        base_speaker_tts.tts(text, src_path, speaker=style_mode, language='English', speed=0.9)

    # Run the tone color converter.
    tone_color_converter.convert(
        audio_src_path=src_path,
        src_se=source_se,
        tgt_se=target_se,
        output_path=save_path,
        message="@MyShell",
    )
    return save_path
# Gradio front end: text and refine-prompt inputs, a reference-voice upload,
# sampling sliders, two seed fields with dice buttons, and a Generate button
# that feeds everything to generate_audio().
# NOTE(review): the Column/Row nesting below is assumed — confirm the layout.
with gr.Blocks() as demo:
    gr.Markdown("# Enjoy chatting with your ai friends on website, telegram and so on! (https://linkin.love)")

    default_text = "Today a man knocked on my door and asked for a small donation toward the local swimming pool. I gave him a glass of water."
    text_input = gr.Textbox(
        value=default_text,
        label="Input Text",
        lines=4,
        placeholder="Please Input Text...",
    )

    default_refine_text = "[oral_2][laugh_0][break_6]"
    refine_text_checkbox = gr.Checkbox(
        value=True,
        label="Refine text:'oral' means add filler words, 'laugh' means add laughter, and 'break' means add a pause. (0-10) ",
    )
    refine_text_input = gr.Textbox(
        value=default_refine_text,
        label="Refine Prompt",
        lines=1,
        placeholder="Please Refine Prompt...",
    )

    # Reference voice to clone, delivered to the handler as a file path.
    with gr.Column():
        clone_voice = gr.Audio(type="filepath", label="请上传您喜欢的语音文件")

    # Sampling controls.
    with gr.Row():
        temperature_slider = gr.Slider(label="Audio temperature", value=0.3, minimum=0.00001, maximum=1.0, step=0.00001)
        top_p_slider = gr.Slider(label="top_P", value=0.7, minimum=0.1, maximum=0.9, step=0.05)
        top_k_slider = gr.Slider(label="top_K", value=20, minimum=1, maximum=20, step=1)

    # Seed fields, each followed by a dice button that re-rolls it.
    with gr.Row():
        audio_seed_input = gr.Number(label="Speaker Seed", value=42)
        generate_audio_seed = gr.Button("\U0001F3B2")
        text_seed_input = gr.Number(label="Text Seed", value=42)
        generate_text_seed = gr.Button("\U0001F3B2")

    generate_button = gr.Button("Generate")

    # A "Refined Text" textbox output is currently disabled; only audio is shown.
    audio_output = gr.Audio(label="Output Audio")

    generate_audio_seed.click(fn=generate_seed, inputs=[], outputs=audio_seed_input)
    generate_text_seed.click(fn=generate_seed, inputs=[], outputs=text_seed_input)

    # Order must match the positional parameters of generate_audio().
    generate_button.click(
        fn=generate_audio,
        inputs=[
            text_input,
            clone_voice,
            temperature_slider,
            top_p_slider,
            top_k_slider,
            audio_seed_input,
            text_seed_input,
            refine_text_checkbox,
            refine_text_input,
        ],
        outputs=audio_output,
    )
# Command-line options for launching the demo.
# Both flags default to None so that, when they are omitted, Gradio applies
# its own host/port resolution exactly as a bare `demo.launch()` would.
# Previously the flags were parsed but never passed to `launch`, so they had
# no effect.
parser = argparse.ArgumentParser(description='ChatTTS demo Launch')
parser.add_argument('--server_name', type=str, default=None, help='Server name')
parser.add_argument('--server_port', type=int, default=None, help='Server port')
args = parser.parse_args()

if __name__ == '__main__':
    demo.launch(server_name=args.server_name, server_port=args.server_port)