"""Live RTMP -> DASH transcoder with Whisper transcription and MarianMT captions.

One ffmpeg process transcodes the RTMP input into a DASH presentation while a
second ffmpeg process extracts 16 kHz mono PCM.  The PCM is cut into 30 ms
frames, filtered with WebRTC VAD, transcribed/translated to English with
Whisper, translated further with MarianMT, and written as WebVTT files next to
the DASH manifest.  All worker threads stop when ``stop_event`` is set (SIGTERM,
SIGINT, or audio inactivity), after which the output directory is emptied.
"""
import argparse
import collections
import fcntl
import os
import queue
import select
import shutil
import signal
import subprocess
import sys
import threading
import time
import traceback

import ffmpeg  # NOTE(review): not referenced anywhere in this script; kept so the dependency surface is unchanged.
import numpy as np
import webrtcvad
import whisper
from transformers import MarianMTModel, MarianTokenizer

# Global variables (filled in from the command line in main()).
rtmp_url = ""
dash_output_path = ""
segment_duration = 2  # DASH segment length passed to ffmpeg's -seg_duration
last_activity_time = 0.0  # time.time() of the last audio bytes read; 0.0 until audio arrives
cleanup_threshold = 10  # seconds of inactivity before cleanup
start_time = 0.0

# Audio framing shared by capture, VAD and transcription.
SAMPLE_RATE = 16000  # Hz; ffmpeg is asked for 16 kHz mono s16le
SAMPLE_WIDTH = 2  # bytes per sample; only 16-bit audio supported
FRAME_DURATION_MS = 30  # length of one VAD frame
VAD_PADDING_MS = 300  # ring-buffer length used to (de)trigger the VAD
BUFFER_DURATION_MS = 1500  # audio accumulated before each transcription pass
CAPTION_DURATION_MS = 5000  # assumed on-screen duration of a caption

# Multiplier applied to frame timestamps (ms).  The original code used an
# unexplained 0.85; the only visible source of drift was zero-padding short
# pipe reads, which capture_audio() no longer does, so the default is 1.0.
# NOTE(review): tune with --timestamp_scale if captions drift from the video.
timestamp_scale = 1.0

# Languages for translation (ISO 639-1 codes)
target_languages = ["es", "zh", "ru"]  # Example: Spanish, Chinese, Russian

# Whisper model, loaded in main()
whisper_model = None


class Frame:
    """One fixed-length chunk of PCM audio.

    Attributes:
        data: int16 numpy array of samples.
        timestamp: start time of the frame in milliseconds (scaled by
            ``timestamp_scale``).
        duration: frame length in milliseconds.
    """

    def __init__(self, data, timestamp, duration):
        self.data = data
        self.timestamp = timestamp
        self.duration = duration


# Audio buffer and caption queues
audio_buffer = queue.Queue()
caption_queues = {lang: queue.Queue() for lang in target_languages + ["original", "en"]}

language_model_names = {
    "es": "Helsinki-NLP/opus-mt-en-es",
    "zh": "Helsinki-NLP/opus-mt-en-zh",
    "ru": "Helsinki-NLP/opus-mt-en-ru",
}

translation_models = {}
tokenizers = {}

# Initialize VAD
vad = webrtcvad.Vad(3)  # Aggressiveness mode 3 (most aggressive)

# Event to signal threads to stop
stop_event = threading.Event()


def _stop_process(process):
    """Kill *process* if it is still running and reap it (avoids zombies)."""
    if process.poll() is None:
        process.kill()
    process.wait()


def transcode_rtmp_to_dash():
    """Run ffmpeg to transcode the RTMP input into a DASH presentation.

    Blocks until ``stop_event`` is set or ffmpeg exits on its own; the ffmpeg
    child is always killed (if needed) and reaped before returning.
    """
    ffmpeg_command = [
        "/opt/homebrew/bin/ffmpeg",
        "-i", rtmp_url,
        "-map", "0:v:0",
        "-map", "0:a:0",
        "-c:v", "libx264",
        "-preset", "slow",
        "-c:a", "aac",
        "-b:a", "128k",
        "-f", "dash",
        "-seg_duration", str(segment_duration),
        "-use_timeline", "1",
        "-use_template", "1",
        "-init_seg_name", "init_$RepresentationID$.m4s",
        "-media_seg_name", "chunk_$RepresentationID$_$Number%05d$.m4s",
        "-adaptation_sets", "id=0,streams=v id=1,streams=a",
        f"{dash_output_path}/manifest.mpd",
    ]
    process = subprocess.Popen(ffmpeg_command)
    try:
        while not stop_event.wait(1):
            if process.poll() is not None:
                # The inactivity watcher in cleanup() decides when the whole
                # pipeline shuts down; just report the early exit.
                print(f"DASH transcoder exited with code {process.returncode}", flush=True)
                break
    finally:
        _stop_process(process)


def capture_audio():
    """Read 16 kHz mono PCM from ffmpeg and enqueue it as 30 ms Frames.

    Bytes are accumulated across reads so every emitted Frame holds exactly
    ``FRAME_DURATION_MS`` of real audio; a short pipe read is never padded
    with zeros (which would inject silence into speech and inflate the frame
    count used for timestamps).  Updates ``last_activity_time`` whenever audio
    arrives.
    """
    global last_activity_time
    command = [
        '/opt/homebrew/bin/ffmpeg', '-i', rtmp_url,
        '-acodec', 'pcm_s16le', '-ar', str(SAMPLE_RATE), '-ac', '1',
        '-f', 's16le', '-',
    ]
    frame_bytes = int(SAMPLE_RATE * FRAME_DURATION_MS / 1000) * SAMPLE_WIDTH

    process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL)

    # Set stdout to non-blocking mode so the stop flag is re-checked regularly.
    fd = process.stdout.fileno()
    flags = fcntl.fcntl(fd, fcntl.F_GETFL)
    fcntl.fcntl(fd, fcntl.F_SETFL, flags | os.O_NONBLOCK)

    pending = bytearray()  # bytes read but not yet forming a whole frame
    frame_count = 0
    try:
        while not stop_event.is_set():
            ready, _, _ = select.select([fd], [], [], 0.1)
            if not ready:
                continue
            try:
                chunk = os.read(fd, 64 * frame_bytes)
            except BlockingIOError:
                continue
            if not chunk:
                break  # EOF: ffmpeg closed its stdout
            last_activity_time = time.time()
            pending += chunk

            whole = len(pending) - len(pending) % frame_bytes
            for offset in range(0, whole, frame_bytes):
                # bytes(...) copies, so no numpy view pins the bytearray
                # (resizing a bytearray with live exports raises BufferError).
                samples = np.frombuffer(bytes(pending[offset:offset + frame_bytes]), np.int16)
                timestamp = frame_count * FRAME_DURATION_MS * timestamp_scale
                audio_buffer.put(Frame(samples, timestamp, FRAME_DURATION_MS))
                frame_count += 1
            del pending[:whole]
    finally:
        process.stdout.close()
        _stop_process(process)


def frames_to_numpy(frames):
    """Concatenate Frames into a float32 array scaled to [-1.0, 1.0).

    Divides by 32768 (the magnitude of int16's minimum), so every int16 value
    maps inside [-1.0, 1.0).
    """
    all_samples = np.concatenate([f.data for f in frames])
    return all_samples.astype(np.float32) / 32768.0


def vad_collector(sample_rate, frame_duration_ms, padding_duration_ms, vad, frames):
    """Yield voiced Frames from *frames*, with ``None`` marking a segment end.

    A ring buffer of ``padding_duration_ms`` worth of frames is kept.  When
    more than 80% of it is voiced the collector triggers and yields the
    buffered frames, then every following frame, until more than 80% of the
    ring is unvoiced, at which point it yields ``None`` and de-triggers.

    At the end of input:
      * if triggered, everything has already been yielded (nothing is
        re-emitted, so no audio is duplicated);
      * if not triggered, the buffered frames are yielded only when at least
        one of them was voiced, so speech starting right at the end of the
        window is kept but pure silence is not sent to the recognizer.

    Frames whose sample count does not match ``frame_duration_ms`` are
    skipped with a message.
    """
    num_padding_frames = int(padding_duration_ms / frame_duration_ms)
    expected_samples = int(sample_rate * (frame_duration_ms / 1000.0))
    ring_buffer = collections.deque(maxlen=num_padding_frames)
    triggered = False

    for frame in frames:
        if len(frame.data) != expected_samples:
            print(f"Skipping frame with incorrect size: {len(frame.data)} samples", flush=True)
            continue

        is_speech = vad.is_speech(frame.data.tobytes(), sample_rate)
        ring_buffer.append((frame, is_speech))

        if not triggered:
            num_voiced = sum(1 for _, speech in ring_buffer if speech)
            if num_voiced > 0.8 * ring_buffer.maxlen:
                triggered = True
                for buffered, _ in ring_buffer:
                    yield buffered
                ring_buffer.clear()
        else:
            yield frame
            num_unvoiced = sum(1 for _, speech in ring_buffer if not speech)
            if num_unvoiced > 0.8 * ring_buffer.maxlen:
                triggered = False
                yield None
                ring_buffer.clear()

    # While triggered, the ring buffer only holds frames already yielded.
    if not triggered and any(speech for _, speech in ring_buffer):
        for buffered, _ in ring_buffer:
            yield buffered
    ring_buffer.clear()


def _caption_window(frames):
    """Run VAD, Whisper and MarianMT over one audio window; enqueue captions."""
    speech = [
        f for f in vad_collector(SAMPLE_RATE, FRAME_DURATION_MS, VAD_PADDING_MS, vad, frames)
        if f is not None
    ]
    if not speech:
        return

    samples = frames_to_numpy(speech)  # converted once, used for both passes
    timestamp = speech[0].timestamp

    # Transcribe the original audio
    original_text = whisper_model.transcribe(samples)["text"].strip()
    if not original_text:
        return
    caption_queues["original"].put((timestamp, original_text))

    english_text = whisper_model.transcribe(samples, task="translate")["text"].strip()
    if not english_text:
        return
    caption_queues["en"].put((timestamp, english_text))

    # Translate to target languages
    for lang in target_languages:
        tokenizer = tokenizers[lang]
        translation_model = translation_models[lang]
        inputs = tokenizer.encode(english_text, return_tensors="pt", padding=True, truncation=True)
        translated_tokens = translation_model.generate(inputs)
        translated_text = tokenizer.decode(translated_tokens[0], skip_special_tokens=True)
        caption_queues[lang].put((timestamp, translated_text))


def process_audio():
    """Buffer ~BUFFER_DURATION_MS of frames at a time and caption each window."""
    frames = []
    buffered_ms = 0
    while not stop_event.is_set():
        try:
            frame = audio_buffer.get(timeout=0.1)
        except queue.Empty:
            continue
        frames.append(frame)
        buffered_ms += frame.duration
        if buffered_ms < BUFFER_DURATION_MS:
            continue
        try:
            _caption_window(frames)
        except Exception:
            # A failure in one window (model error, odd audio) must not stop
            # captioning for the rest of the live stream.
            traceback.print_exc()
        frames = []
        buffered_ms = 0


def write_captions(lang):
    """Write captions for *lang* to ``captions_<lang>.vtt`` as they arrive.

    Each cue lasts CAPTION_DURATION_MS.  If a new cue starts before the
    previous one ends, the previous cue is shortened to end where the new one
    starts; gaps between cues are left as gaps.
    """
    os.makedirs(dash_output_path, exist_ok=True)
    filename = f"{dash_output_path}/captions_{lang}.vtt"
    with open(filename, "w", encoding="utf-8") as f:
        f.write("WEBVTT\n\n")

    caption_queue = caption_queues[lang]
    last_end_ms = None
    last_end_time = None
    while not stop_event.is_set():
        try:
            timestamp, text = caption_queue.get(timeout=0.1)
        except queue.Empty:
            continue

        end_ms = timestamp + CAPTION_DURATION_MS
        cue_start = format_time(timestamp / 1000)  # Convert ms to seconds
        cue_end = format_time(end_ms / 1000)

        # Compare numerically: only truncate the previous cue on overlap.
        if last_end_ms is not None and timestamp < last_end_ms:
            adjust_previous_caption(filename, last_end_time, cue_start)

        with open(filename, "a", encoding="utf-8") as f:
            f.write(f"{cue_start} --> {cue_end}\n")
            f.write(f"{text}\n\n")

        last_end_ms = end_ms
        last_end_time = cue_end


def adjust_previous_caption(filename, old_end_time, new_end_time):
    """Change the end time of the last cue ending at *old_end_time*.

    The file is rewritten via a temporary file and ``os.replace`` so a
    concurrent reader never observes a truncated caption file.  Nothing is
    written if no matching cue is found.
    """
    with open(filename, "r", encoding="utf-8") as f:
        lines = f.readlines()

    for i in reversed(range(len(lines))):
        if "-->" in lines[i]:
            parts = lines[i].split("-->")
            if parts[1].strip() == old_end_time:
                lines[i] = f"{parts[0].strip()} --> {new_end_time}\n"
                break
    else:
        return

    tmp_name = f"{filename}.tmp"
    with open(tmp_name, "w", encoding="utf-8") as f:
        f.writelines(lines)
    os.replace(tmp_name, filename)


def format_time(seconds):
    """Format *seconds* as a WebVTT timestamp ``HH:MM:SS.mmm``.

    Rounds to whole milliseconds before splitting, so values such as 59.9996
    become ``00:01:00.000`` rather than the invalid ``00:00:60.000``.
    """
    total_ms = int(round(seconds * 1000))
    hours, rest = divmod(total_ms, 3_600_000)
    minutes, rest = divmod(rest, 60_000)
    secs, millis = divmod(rest, 1000)
    return f"{hours:02d}:{minutes:02d}:{secs:02d}.{millis:03d}"


def signal_handler(signum, frame):
    """Ask every worker thread to stop."""
    print(f"Received signal {signum}. Cleaning up and exiting...", flush=True)
    stop_event.set()


def _remove_outputs(path):
    """Delete everything inside *path*; the directory itself is kept."""
    if not path or not os.path.isdir(path):
        return
    with os.scandir(path) as entries:
        for entry in entries:
            try:
                if entry.is_dir(follow_symlinks=False):
                    shutil.rmtree(entry.path)
                else:
                    os.remove(entry.path)
            except FileNotFoundError:
                pass  # already gone


def cleanup(workers=()):
    """Stop the pipeline after inactivity, then clear queues and outputs.

    Sets ``stop_event`` once no audio has been read for ``cleanup_threshold``
    seconds (after the first audio arrived).  Once stopped for any reason, it
    waits for *workers* to finish, so no thread or ffmpeg child recreates
    files during deletion, then drains the caption queues and empties the
    DASH output directory.
    """
    while not stop_event.is_set():
        idle_for = time.time() - last_activity_time
        if last_activity_time != 0.0 and idle_for > cleanup_threshold:
            print(f"No activity detected for {cleanup_threshold} seconds. Cleaning up...", flush=True)
            stop_event.set()
            break
        stop_event.wait(1)  # Check for inactivity every second

    for worker in workers:
        worker.join()

    # Clear caption queues
    for caption_queue in caption_queues.values():
        while True:
            try:
                caption_queue.get_nowait()
            except queue.Empty:
                break

    # Delete DASH output files
    _remove_outputs(dash_output_path)
    print("Cleanup completed.", flush=True)


def main():
    """Parse arguments, load models, run all worker threads until stopped."""
    global rtmp_url, dash_output_path, whisper_model, start_time, timestamp_scale

    signal.signal(signal.SIGTERM, signal_handler)

    parser = argparse.ArgumentParser(description="Process audio for translation.")
    parser.add_argument('--rtmp_url', required=True, help='rtmp url')
    parser.add_argument('--output_directory', required=True, help='Dash directory')
    parser.add_argument('--model', default='base',
                        help='Whisper model size: base|small|medium|large|large-v2')
    parser.add_argument('--timestamp_scale', type=float, default=1.0,
                        help='Multiplier for caption timestamps (previously hard-coded 0.85)')
    start_time = time.time()
    args = parser.parse_args()
    rtmp_url = args.rtmp_url
    dash_output_path = args.output_directory
    model_size = args.model
    timestamp_scale = args.timestamp_scale

    print(f"RTMP URL: {rtmp_url}")
    print(f"DASH output path: {dash_output_path}")
    print(f"Model: {model_size}")

    print("Downloading models\n")
    print("Whisper\n")
    whisper_model = whisper.load_model(model_size, download_root="/tmp/model/")
    for lang, model_name in language_model_names.items():
        print(f"Lang: {lang}, model: {model_name}\n")
        tokenizers[lang] = MarianTokenizer.from_pretrained(model_name)
        translation_models[lang] = MarianMTModel.from_pretrained(model_name)

    # Both ffmpeg and the caption writers write here; create it up front.
    os.makedirs(dash_output_path, exist_ok=True)

    # Registered only now so Ctrl-C during model loading still aborts at once;
    # from here on it triggers an orderly shutdown of the non-daemon threads.
    signal.signal(signal.SIGINT, signal_handler)

    transcode_thread = threading.Thread(target=transcode_rtmp_to_dash)
    audio_capture_thread = threading.Thread(target=capture_audio)
    audio_processing_thread = threading.Thread(target=process_audio)
    caption_threads = [
        threading.Thread(target=write_captions, args=(lang,))
        for lang in target_languages + ["original", "en"]
    ]
    workers = [transcode_thread, audio_capture_thread, audio_processing_thread, *caption_threads]
    for worker in workers:
        worker.start()

    # The cleanup thread joins the workers itself before deleting outputs.
    cleanup_thread = threading.Thread(target=cleanup, args=(workers,))
    cleanup_thread.start()

    # Wait for all threads to complete
    print("Join transcode", flush=True)
    transcode_thread.join()
    print("Join audio capture", flush=True)
    audio_capture_thread.join()
    print("Join audio processing", flush=True)
    audio_processing_thread.join()
    for thread in caption_threads:
        thread.join()
    print("Join cleanup", flush=True)
    cleanup_thread.join()

    print("All threads have been stopped and cleaned up.")
    return 0


if __name__ == "__main__":
    sys.exit(main())