ChatVC / app.py
Hilley's picture
Update app.py
869769a verified
raw
history blame
10 kB
import spaces
import os
import random
import argparse
import torch
import gradio as gr
import numpy as np
import ChatTTS
import se_extractor
from api import BaseSpeakerTTS, ToneColorConverter
import soundfile
from tts_voice import tts_order_voice
import edge_tts
import tempfile
import anyio
# Build the ChatTTS model once at import time and load its weights; the
# module-level `chat` object is what chat_tts() calls for every request.
print("loading ChatTTS model...")
chat = ChatTTS.Chat()
chat.load_models()
def generate_seed():
    """Draw a new random seed for one of the seed inputs.

    Returns:
        dict: An update-style payload (``"__type__": "update"``) whose
        ``"value"`` is a random integer between 1 and 100000000 inclusive.
        The UI wires this to the seed ``gr.Number`` components.
    """
    seed_update = {"__type__": "update"}
    seed_update["value"] = random.randint(1, 100000000)
    return seed_update
@spaces.GPU
def chat_tts(text, temperature, top_P, top_K, audio_seed_input, text_seed_input, refine_text_flag, refine_text_input, output_path=None):
    """Synthesize speech for ``text`` with the module-level ChatTTS model.

    Args:
        text: Text to speak.
        temperature: Sampling temperature forwarded in ``params_infer_code``.
        top_P: ``top_P`` value forwarded in ``params_infer_code``.
        top_K: ``top_K`` value forwarded in ``params_infer_code``.
        audio_seed_input: Seed applied before the random speaker embedding
            is drawn, so the same seed yields the same embedding.
        text_seed_input: Seed applied before refinement / synthesis.
        refine_text_flag: When truthy, run a refine-only pass first and
            synthesize the refined text instead of the raw input.
        refine_text_input: Optional refine prompt; only replaces the default
            prompt when ``refine_text_flag`` is truthy and this is non-empty.
        output_path: ``None`` to return the audio, otherwise the file path
            the audio is written to.

    Returns:
        ``[(sample_rate, audio_data), text_data]`` when ``output_path`` is
        ``None``; otherwise ``None`` after the audio was written to
        ``output_path``.
    """
    # Seed first: the speaker embedding below is a plain random draw.
    torch.manual_seed(audio_seed_input)
    speaker_embedding = torch.randn(768)

    params_infer_code = {
        'spk_emb': speaker_embedding,
        'temperature': temperature,
        'top_P': top_P,
        'top_K': top_K,
    }

    # The user prompt only takes effect when refinement is enabled.
    use_custom_prompt = bool(refine_text_flag) and bool(refine_text_input)
    params_refine_text = {
        'prompt': refine_text_input if use_custom_prompt else '[oral_2][laugh_0][break_6]',
    }

    torch.manual_seed(text_seed_input)

    if refine_text_flag:
        # Refine-only pass: the result replaces the text that gets spoken.
        text = chat.infer(
            text,
            skip_refine_text=False,
            refine_text_only=True,
            params_refine_text=params_refine_text,
            params_infer_code=params_infer_code,
        )
        print("Text has been refined!")

    # Synthesis pass; refinement (if any) has already been done above.
    wav = chat.infer(
        text,
        skip_refine_text=True,
        params_refine_text=params_refine_text,
        params_infer_code=params_infer_code,
    )

    sample_rate = 24000
    audio_data = np.array(wav[0]).flatten()
    text_data = text[0] if isinstance(text, list) else text

    if output_path is not None:
        soundfile.write(output_path, audio_data, sample_rate)
        return None
    return [(sample_rate, audio_data), text_data]
# OpenVoice: checkpoint locations, target device, and the two models that are
# loaded once at import time and shared by every request handler below.
ckpt_base_en = 'checkpoints/base_speakers/EN'
ckpt_converter_en = 'checkpoints/converter'
device = 'cuda:0'
# Alternative device setting kept for reference:
#device = "cpu"
# Base English speaker TTS; vc_en() uses it to produce the source audio.
base_speaker_tts = BaseSpeakerTTS(f'{ckpt_base_en}/config.json', device=device)
base_speaker_tts.load_ckpt(f'{ckpt_base_en}/checkpoint.pth')
# Tone color converter; used for embedding extraction and for the final
# voice conversion in generate_audio(), vc_en() and text_to_speech_edge().
tone_color_converter = ToneColorConverter(f'{ckpt_converter_en}/config.json', device=device)
tone_color_converter.load_ckpt(f'{ckpt_converter_en}/checkpoint.pth')
def generate_audio(text, audio_ref, temperature, top_P, top_K, audio_seed_input, text_seed_input, refine_text_flag, refine_text_input):
    """Speak ``text`` with ChatTTS, then convert it to the reference voice.

    Args:
        text: Text to synthesize.
        audio_ref: Path of the reference audio whose voice is imitated.
        temperature, top_P, top_K, audio_seed_input, text_seed_input,
        refine_text_flag, refine_text_input: Forwarded unchanged to
            ``chat_tts``.

    Returns:
        str: Path of the converted audio file (``"output.wav"``).

    Note:
        The intermediate and output files use fixed names in the working
        directory, so overlapping requests overwrite each other's files.
    """
    save_path = "output.wav"
    src_path = "tmp.wav"

    # Embedding of the voice to imitate, taken from the reference audio.
    target_se, _ = se_extractor.get_se(audio_ref, tone_color_converter, target_dir='processed', vad=True)

    # Passing an output path makes chat_tts write the audio to src_path.
    chat_tts(text, temperature, top_P, top_K, audio_seed_input, text_seed_input, refine_text_flag, refine_text_input, src_path)
    print("Ready for voice cloning!")

    # The source embedding is extracted from the audio just generated; the
    # random ChatTTS speaker differs per seed, so a fixed embedding file is
    # not loaded here (the original loaded one and immediately discarded it).
    source_se, _ = se_extractor.get_se(src_path, tone_color_converter, target_dir='processed', vad=True)
    print("Get source segment!")

    # Run the tone color converter on the generated file.
    tone_color_converter.convert(
        audio_src_path=src_path,
        src_se=source_se,
        tgt_se=target_se,
        output_path=save_path,
        message="@Hilley")

    print("Finished!")
    return save_path
def vc_en(text, audio_ref, style_mode):
    """Speak ``text`` with the base English speaker and clone ``audio_ref``.

    Args:
        text: Text to synthesize.
        audio_ref: Path of the reference audio whose voice is imitated.
        style_mode: ``"default"`` for the default speaker at speed 1.0; any
            other value is passed as the speaker style and spoken at
            speed 0.9.

    Returns:
        str: Path of the converted audio file (``"output.wav"``).
    """
    # The two modes differ only in embedding file, speaker name and speed.
    if style_mode == "default":
        se_file, speaker, speed = 'en_default_se.pth', 'default', 1.0
    else:
        se_file, speaker, speed = 'en_style_se.pth', style_mode, 0.9

    source_se = torch.load(f'{ckpt_base_en}/{se_file}').to(device)
    target_se, _ = se_extractor.get_se(audio_ref, tone_color_converter, target_dir='processed', vad=True)

    save_path = "output.wav"
    src_path = "tmp.wav"

    # Run the base speaker tts
    base_speaker_tts.tts(text, src_path, speaker=speaker, language='English', speed=speed)

    # Run the tone color converter
    tone_color_converter.convert(
        audio_src_path=src_path,
        src_se=source_se,
        tgt_se=target_se,
        output_path=save_path,
        message="@MyShell")

    return save_path
# Lookup table for the multilingual tab: its keys populate the language
# dropdown and its values are passed to edge_tts.Communicate as the voice.
language_dict = tts_order_voice
# Source embedding used by text_to_speech_edge(), extracted once at import
# time from the bundled base audio file.
base_speaker = "base_audio.mp3"
source_se, audio_name = se_extractor.get_se(base_speaker, tone_color_converter, vad=True)
async def text_to_speech_edge(text, audio_ref, language_code):
    """Speak ``text`` with edge-tts, then convert it to the reference voice.

    Args:
        text: Text to synthesize.
        audio_ref: Path of the reference audio whose voice is imitated.
        language_code: Key into ``language_dict`` selecting the edge-tts
            voice.

    Returns:
        str: Path of the converted audio file (``"output.wav"``).

    Raises:
        KeyError: If ``language_code`` is not a key of ``language_dict``.
    """
    voice = language_dict[language_code]
    communicate = edge_tts.Communicate(text, voice)

    # Only the unique file name is needed; the handle is closed right away
    # so edge-tts can write to the path itself.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp_file:
        tmp_path = tmp_file.name

    save_path = "output.wav"
    try:
        await communicate.save(tmp_path)

        target_se, _ = se_extractor.get_se(audio_ref, tone_color_converter, target_dir='processed', vad=True)

        # Run the tone color converter; source_se is the module-level
        # embedding extracted from the base audio at import time.
        tone_color_converter.convert(
            audio_src_path=tmp_path,
            src_se=source_se,
            tgt_se=target_se,
            output_path=save_path,
            message="@MyShell")
    finally:
        # delete=False means nothing else removes the temporary mp3; without
        # this every request would leave one file behind. Cleanup is
        # best-effort and must not mask an error from the steps above.
        try:
            os.remove(tmp_path)
        except OSError:
            pass

    return save_path
# Gradio UI: a shared text box and reference-audio input, followed by one tab
# per synthesis backend (ChatTTS, OpenVoice base speaker, edge-tts).
with gr.Blocks() as demo:
    default_text = "Today a man knocked on my door and asked for a small donation toward the local swimming pool. I gave him a glass of water."
    text_input = gr.Textbox(label="Input Text", lines=4, placeholder="Please Input Text...", value=default_text)
    voice_ref = gr.Audio(label="Reference Audio", type="filepath", value="base_audio.mp3")

    # Tab 1: ChatTTS synthesis followed by voice conversion (generate_audio).
    with gr.Tab("💕Super Natural"):
        default_refine_text = "[oral_2][laugh_0][break_6]"
        refine_text_checkbox = gr.Checkbox(label="Refine text", info="'oral' means add filler words, 'laugh' means add laughter, and 'break' means add a pause. (0-10) ", value=True)
        refine_text_input = gr.Textbox(label="Refine Prompt", lines=1, placeholder="Please Refine Prompt...", value=default_refine_text)

        with gr.Row():
            temperature_slider = gr.Slider(minimum=0.00001, maximum=1.0, step=0.00001, value=0.3, label="Audio temperature")
            top_p_slider = gr.Slider(minimum=0.1, maximum=0.9, step=0.05, value=0.7, label="top_P")
            top_k_slider = gr.Slider(minimum=1, maximum=20, step=1, value=20, label="top_K")

        with gr.Row():
            audio_seed_input = gr.Number(value=42, label="Speaker Seed")
            generate_audio_seed = gr.Button("\U0001F3B2")
            text_seed_input = gr.Number(value=42, label="Text Seed")
            generate_text_seed = gr.Button("\U0001F3B2")

        generate_button = gr.Button("Generate!")
        audio_output = gr.Audio(label="Output Audio")

        # The dice buttons replace the matching seed field with a new value.
        generate_audio_seed.click(generate_seed,
                                  inputs=[],
                                  outputs=audio_seed_input)
        generate_text_seed.click(generate_seed,
                                 inputs=[],
                                 outputs=text_seed_input)
        generate_button.click(generate_audio,
                              inputs=[text_input, voice_ref, temperature_slider, top_p_slider, top_k_slider, audio_seed_input, text_seed_input, refine_text_checkbox, refine_text_input],
                              outputs=audio_output)

    # Tab 2: OpenVoice base speaker with a selectable style (vc_en).
    with gr.Tab("💕Emotion Control"):
        emo_pick = gr.Dropdown(label="Emotion", info="🙂default😊friendly🤫whispering😄cheerful😱terrified😡angry😢sad", choices=["default", "friendly", "whispering", "cheerful", "terrified", "angry", "sad"], value="default")
        generate_button_emo = gr.Button("Generate!", variant="primary")
        audio_emo = gr.Audio(label="Output Audio", type="filepath")
        generate_button_emo.click(vc_en, [text_input, voice_ref, emo_pick], audio_emo)

    # Tab 3: edge-tts voices followed by voice conversion (text_to_speech_edge).
    with gr.Tab("💕multilingual"):
        language_choices = list(language_dict)
        # The preferred default is the 16th entry of the voice table. Fall
        # back to the first entry (or no preselection) when the table is
        # shorter, instead of failing with IndexError while building the UI.
        if len(language_choices) > 15:
            default_language = language_choices[15]
        elif language_choices:
            default_language = language_choices[0]
        else:
            default_language = None
        language = gr.Dropdown(choices=language_choices, value=default_language, label="Language")
        generate_button_ml = gr.Button("Generate!", variant="primary")
        audio_ml = gr.Audio(label="Output Audio", type="filepath")
        generate_button_ml.click(text_to_speech_edge, [text_input, voice_ref, language], audio_ml)
# Command-line options. Both default to None so that, when a flag is not
# given, demo.launch() receives None and behaves exactly like a bare
# demo.launch(); previously the flags were parsed but never used.
parser = argparse.ArgumentParser(description='ChatVC demo Launch')
parser.add_argument('--server_name', type=str, default=None, help='Server name (default: chosen by Gradio)')
parser.add_argument('--server_port', type=int, default=None, help='Server port (default: chosen by Gradio)')
args = parser.parse_args()

if __name__ == '__main__':
    demo.launch(server_name=args.server_name, server_port=args.server_port)