import argparse
import gc
import socket
import struct
import torch
import torchaudio
import traceback
from importlib.resources import files
from threading import Thread
from cached_path import cached_path
from infer.utils_infer import infer_batch_process, preprocess_ref_audio_text, load_vocoder, load_model
from model.backbones.dit import DiT


class TTSStreamingProcessor:
    def __init__(self, ckpt_file, vocab_file, ref_audio, ref_text, device=None, dtype=torch.float32):
        self.device = device or (
            "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
        )

        # Load the model using the provided checkpoint and vocab files
        self.model = load_model(
            model_cls=DiT,
            model_cfg=dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4),
            ckpt_path=ckpt_file,
            mel_spec_type="vocos",  # or "bigvgan" depending on vocoder
            vocab_file=vocab_file,
            ode_method="euler",
            use_ema=True,
            device=self.device,
        ).to(self.device, dtype=dtype)

        # Load the vocoder
        self.vocoder = load_vocoder(is_local=False)

        # Sampling rate for streaming; must match the rate the client plays back at
        self.sampling_rate = 24000

        # Set reference audio and text
        self.ref_audio = ref_audio
        self.ref_text = ref_text

        # Warm up the model
        self._warm_up()

    def _warm_up(self):
        """Warm up the model with a dummy input to ensure it's ready for real-time processing."""
        print("Warming up the model...")
        ref_audio, ref_text = preprocess_ref_audio_text(self.ref_audio, self.ref_text)
        audio, sr = torchaudio.load(ref_audio)
        gen_text = "Warm-up text for the model."

        # Run one full inference pass (model + vocoder) so later requests avoid first-call overhead
        infer_batch_process((audio, sr), ref_text, [gen_text], self.model, self.vocoder, device=self.device)
        print("Warm-up completed.")

    def generate_stream(self, text, play_steps_in_s=0.5):
        """Generate audio in chunks and yield them in real-time."""
        # Preprocess the reference audio and text
        ref_audio, ref_text = preprocess_ref_audio_text(self.ref_audio, self.ref_text)

        # Load reference audio
        audio, sr = torchaudio.load(ref_audio)

        # Run inference for the input text (the vocoder is passed alongside the model)
        audio_chunk, final_sample_rate, _ = infer_batch_process(
            (audio, sr),
            ref_text,
            [text],
            self.model,
            self.vocoder,
            device=self.device,
        )

        # Break the generated audio into chunks of roughly `play_steps_in_s` seconds each
        chunk_size = int(final_sample_rate * play_steps_in_s)

        # If the whole utterance fits in a single chunk, send it and stop
        if len(audio_chunk) < chunk_size:
            packed_audio = struct.pack(f"{len(audio_chunk)}f", *audio_chunk)
            yield packed_audio
            return

        for i in range(0, len(audio_chunk), chunk_size):
            chunk = audio_chunk[i : i + chunk_size]

            # The final chunk takes whatever samples remain
            if i + chunk_size >= len(audio_chunk):
                chunk = audio_chunk[i:]

            # Send the chunk if it is not empty
            if len(chunk) > 0:
                packed_audio = struct.pack(f"{len(chunk)}f", *chunk)
                yield packed_audio
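

# Illustrative helper, not part of the upstream script: shows how the struct-packed
# float32 chunks yielded by `generate_stream` can be reassembled into a waveform,
# e.g. for debugging the streaming path without a socket client. The function name
# and the output path are assumptions.
def save_stream_to_wav(processor, text, out_path="stream_debug.wav"):
    samples = []
    for packed_audio in processor.generate_stream(text):
        # Each chunk is a run of 32-bit floats packed with struct.pack(f"{n}f", ...)
        samples.extend(struct.unpack(f"{len(packed_audio) // 4}f", packed_audio))
    waveform = torch.tensor(samples).unsqueeze(0)  # shape: (1, num_samples)
    torchaudio.save(out_path, waveform, processor.sampling_rate)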


def handle_client(client_socket, processor):
    try:
        while True:
            # Receive data from the client (assumes each text request fits in a single recv)
            data = client_socket.recv(1024).decode("utf-8")
            if not data:
                break

            try:
                # The client sends the text input
                text = data.strip()

                # Generate and stream audio chunks
                for audio_chunk in processor.generate_stream(text):
                    client_socket.sendall(audio_chunk)

                # Send end-of-audio signal
                client_socket.sendall(b"END_OF_AUDIO")

            except Exception as inner_e:
                print(f"Error during processing: {inner_e}")
                traceback.print_exc()  # Print the full traceback to diagnose the issue
                break

    except Exception as e:
        print(f"Error handling client: {e}")
        traceback.print_exc()
    finally:
        client_socket.close()
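

# Minimal client sketch, included for illustration only (the client shipped with F5-TTS
# may differ): it follows the protocol implemented by handle_client above, i.e. send one
# UTF-8 text message, then read raw float32 audio until the b"END_OF_AUDIO" marker.
# The host/port defaults and the summary print are assumptions.
def example_client(text, host="localhost", port=9998):
    with socket.create_connection((host, port)) as sock:
        sock.sendall(text.encode("utf-8"))
        buffer = b""
        while True:
            packet = sock.recv(4096)
            if not packet:
                break
            buffer += packet
            # The server keeps the connection open after an utterance, so watch for the marker
            if buffer.endswith(b"END_OF_AUDIO"):
                buffer = buffer[: -len(b"END_OF_AUDIO")]
                break
        samples = struct.unpack(f"{len(buffer) // 4}f", buffer)
        print(f"Received {len(samples)} samples (~{len(samples) / 24000:.2f} s at 24 kHz)")
        return samples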


def start_server(host, port, processor):
    server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    server.bind((host, port))
    server.listen(5)
    print(f"Server listening on {host}:{port}")

    while True:
        client_socket, addr = server.accept()
        print(f"Accepted connection from {addr}")
        # One thread per client; the single processor (model + vocoder) is shared across threads
        client_handler = Thread(target=handle_client, args=(client_socket, processor))
        client_handler.start()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", default="0.0.0.0")
    parser.add_argument("--port", default=9998, type=int)

    parser.add_argument(
        "--ckpt_file",
        default=str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors")),
        help="Path to the model checkpoint file",
    )
    parser.add_argument(
        "--vocab_file",
        default="",
        help="Path to the vocab file if customized",
    )
    parser.add_argument(
        "--ref_audio",
        default=str(files("f5_tts").joinpath("infer/examples/basic/basic_ref_en.wav")),
        help="Reference audio to provide model with speaker characteristics",
    )
    parser.add_argument(
        "--ref_text",
        default="",
        help="Reference audio subtitle, leave empty to auto-transcribe",
    )

    parser.add_argument("--device", default=None, help="Device to run the model on")
    parser.add_argument(
        "--dtype",
        default="float32",
        help="Data type to use for model inference (e.g. float32, float16, bfloat16)",
    )

    args = parser.parse_args()

    try:
        # Initialize the processor with the model and vocoder
        processor = TTSStreamingProcessor(
            ckpt_file=args.ckpt_file,
            vocab_file=args.vocab_file,
            ref_audio=args.ref_audio,
            ref_text=args.ref_text,
            device=args.device,
            dtype=getattr(torch, args.dtype),  # resolve the dtype name to a torch dtype
        )

        # Start the server
        start_server(args.host, args.port, processor)
    except KeyboardInterrupt:
        gc.collect()
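
# Usage sketch (the script name and paths are assumptions, adjust to your checkout):
#   python socket_server.py --ref_audio path/to/ref.wav --ref_text "transcript of the reference"
# The server loads the F5-TTS model once, then accepts TCP clients on --host/--port
# (default 0.0.0.0:9998) and streams raw float32 audio back for each text request.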