"""Gradio demo that streams microphone audio to a local mini-omni server
and plays back the model's spoken reply as it is generated."""

import base64
import io
import time
import traceback
from dataclasses import dataclass
from pathlib import Path
from threading import Thread
import wave

import gradio as gr
import librosa
import numpy as np
import requests
from huggingface_hub import snapshot_download

from utils.vad import get_speech_timestamps, collect_chunks, VadOptions
from server import serve

# Fetch the mini-omni checkpoint from the Hugging Face Hub.
repo_id = "gpt-omni/mini-omni"
snapshot_download(repo_id, local_dir="./checkpoint", revision="main")

IP = "0.0.0.0"
PORT = 60808

# Run the inference server in a background daemon thread so the Gradio UI
# can start in the foreground.
thread = Thread(target=serve, daemon=True)
thread.start()

API_URL = f"http://{IP}:{PORT}/chat"

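# Sketch of the round trip this demo assumes (shapes inferred from speaking()
# below rather than from server documentation):
#   request:  POST /chat with JSON {"audio": "<base64-encoded WAV>"}
#   response: a stream of raw audio chunks, relayed to the UI as they arrive
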
# Microphone input format: 16-bit mono PCM at 24 kHz.
IN_CHANNELS = 1
IN_RATE = 24000
IN_CHUNK = 1024
IN_SAMPLE_WIDTH = 2  # bytes per sample
VAD_STRIDE = 0.5  # seconds of audio to accumulate before running VAD

# Server response format: 16-bit mono PCM at 24 kHz.
OUT_CHANNELS = 1
OUT_RATE = 24000
OUT_SAMPLE_WIDTH = 2
OUT_CHUNK = 4096  # bytes per streamed response chunk

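# Worked example of the buffering threshold used in determine_pause():
#   IN_SAMPLE_WIDTH * IN_RATE * IN_CHANNELS * VAD_STRIDE
#   = 2 * 24000 * 1 * 0.5 = 24000 bytes, i.e. half a second of audio.
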
def run_vad(ori_audio, sr):
    """Run voice activity detection on 16-bit PCM bytes; return only the speech."""
    _st = time.time()
    try:
        # Decode 16-bit PCM to float32 in [-1, 1].
        audio = np.frombuffer(ori_audio, dtype=np.int16)
        audio = audio.astype(np.float32) / 32768.0
        sampling_rate = 16000
        if sr != sampling_rate:
            # The VAD model expects 16 kHz input.
            audio = librosa.resample(audio, orig_sr=sr, target_sr=sampling_rate)

        vad_parameters = VadOptions()
        speech_chunks = get_speech_timestamps(audio, vad_parameters)
        audio = collect_chunks(audio, speech_chunks)
        duration_after_vad = audio.shape[0] / sampling_rate

        if sr != sampling_rate:
            # Resample the speech-only audio back to the caller's rate.
            vad_audio = librosa.resample(audio, orig_sr=sampling_rate, target_sr=sr)
        else:
            vad_audio = audio
        vad_audio = np.round(vad_audio * 32768.0).astype(np.int16)
        vad_audio_bytes = vad_audio.tobytes()

        return duration_after_vad, vad_audio_bytes, round(time.time() - _st, 4)
    except Exception:
        msg = f"[asr vad error] audio_len: {len(ori_audio) / (sr * 2):.3f} s, trace: {traceback.format_exc()}"
        print(msg)
        return -1, ori_audio, round(time.time() - _st, 4)

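# Example (hypothetical values): for one second of 24 kHz speech in pcm_bytes,
#   dur, speech_pcm, cost = run_vad(pcm_bytes, 24000)
# dur is the seconds of detected speech, speech_pcm holds only those samples
# (back at 24 kHz), and cost is the VAD wall-clock time. On failure dur is -1
# and the input bytes are returned unchanged.
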
def warm_up():
    """Run VAD once on silence so the first real request is served quickly."""
    frames = b"\x00\x00" * 1024 * 2  # 2048 samples of 16-bit silence
    dur, frames, tcost = run_vad(frames, 16000)
    print(f"warm up done, time_cost: {tcost:.3f} s")


warm_up()

def determine_pause(stream: bytes, start_talking: bool) -> tuple[bool, bool]:
    """Take in the stream, determine if a pause happened."""
    temp_audio = stream

    # Wait until at least VAD_STRIDE seconds of audio have accumulated.
    if len(temp_audio) > IN_SAMPLE_WIDTH * IN_RATE * IN_CHANNELS * VAD_STRIDE:
        dur_vad, _, time_vad = run_vad(temp_audio, IN_RATE)
        print(f"duration_after_vad: {dur_vad:.3f} s, time_vad: {time_vad:.3f} s")

        if dur_vad > 0.2 and not start_talking:
            # Enough speech detected: the user has started talking.
            start_talking = True
            return False, start_talking
        if dur_vad < 0.1 and start_talking:
            # Speech has died down after talking began: treat it as a pause.
            print("pause detected")
            return True, start_talking
    return False, start_talking

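# Typical call pattern (see process_audio() below):
#   pause, talking = determine_pause(state.stream, state.start_talking)
# The first flag reports whether a pause was just detected; the second echoes
# back the possibly updated "user has started talking" flag for the next call.
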
def speaking(total_frames: bytes):
    """Wrap the recorded PCM in a WAV container, POST it to the server, and
    yield the streamed reply chunk by chunk."""
    audio_buffer = io.BytesIO()
    wf = wave.open(audio_buffer, "wb")
    wf.setnchannels(IN_CHANNELS)
    wf.setsampwidth(IN_SAMPLE_WIDTH)
    wf.setframerate(IN_RATE)

    dur = len(total_frames) / (IN_RATE * IN_CHANNELS * IN_SAMPLE_WIDTH)
    print(f"Speaking... recorded audio duration: {dur:.3f} s")

    wf.writeframes(total_frames)
    wf.close()  # finalize the WAV header before reading the buffer

    audio_bytes = audio_buffer.getvalue()
    base64_encoded = base64.b64encode(audio_bytes).decode("utf-8")
    payload = {"audio": base64_encoded}
    with requests.post(API_URL, json=payload, stream=True) as response:
        try:
            # Relay the server's audio back to the UI as it arrives.
            for chunk in response.iter_content(chunk_size=OUT_CHUNK):
                if chunk:
                    yield chunk
        except Exception as e:
            raise gr.Error(f"Error during audio streaming: {e}")

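# speaking() is a generator, so callers can drain the reply incrementally:
#   for pcm_chunk in speaking(recorded_bytes):
#       ...  # forward each chunk to the audio player as it arrives
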
@dataclass
class AppState:
    """Per-session state carried across streamed microphone chunks."""

    start_talking: bool = False
    stream: bytes = b""
    pause_detected: bool = False

def process_audio(audio: str, state: AppState):
    """Accumulate streamed microphone chunks; once a pause is detected, answer."""
    state.stream += Path(audio).read_bytes()

    pause_detected, start_talking = determine_pause(state.stream, state.start_talking)
    state.pause_detected = pause_detected
    state.start_talking = start_talking

    if not state.pause_detected:
        # Still listening: emit no audio, just carry the state forward.
        yield None, state
        return

    for out_bytes in speaking(state.stream):
        yield out_bytes, state

    # Reset for the next utterance.
    state = AppState()
    yield None, state

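# Each yield hands Gradio an (output_audio_chunk, updated_state) pair; yielding
# None in the audio slot emits no new chunk for that tick while the session
# state still round-trips through gr.State.
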
with gr.Blocks() as demo:
    with gr.Row():
        # type="filepath" so process_audio receives a path it can read bytes from.
        input_audio = gr.Audio(label="Input Audio", sources=["microphone"], type="filepath")
    with gr.Row():
        output_audio = gr.Audio(label="Output Audio", streaming=True, autoplay=True)
    state = gr.State(value=AppState())

    # Forward a new microphone chunk every 0.5 s, for at most 30 s per turn.
    input_audio.stream(
        process_audio,
        [input_audio, state],
        [output_audio, state],
        stream_every=0.5,
        time_limit=30,
    )


demo.launch()