"""Gradio app for embedding and detecting AudioSeal watermarks in audio files.

Tab 1 ("Watermark Audio"): resamples an upload to 16 kHz, embeds a random
16-bit message with the AudioSeal generator, and shows before/after waveform
and spectrogram plots.
Tab 2 ("Detect Watermark"): runs the AudioSeal detector and reports whether
the audio appears watermarked.
"""

import io
import logging
import random  # noqa: F401  -- kept from original file; not used directly
import tempfile
from pathlib import Path

import gradio as gr
import librosa  # noqa: F401  -- kept from original file; not used directly
import matplotlib.pyplot as plt
import numpy as np  # noqa: F401  -- kept from original file; not used directly
import torch
import torchaudio
import torchaudio.transforms as T
from audioseal import AudioSeal
from PIL import Image

logging.basicConfig(
    level=logging.DEBUG,
    filename='app.log',
    filemode='w',
    format='%(name)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger(__name__)

# AudioSeal's 16-bit models operate on mono audio at 16 kHz.
TARGET_SAMPLE_RATE = 16000


def generate_random_binary_message(length=16):
    """Return a random bit string (e.g. '0110...') of the given length.

    Uses a single vectorized torch.randint call instead of one RNG call
    per bit.
    """
    bits = torch.randint(0, 2, (length,))
    return ''.join(str(int(b)) for b in bits)


def load_and_resample_audio(audio_file_path, target_sample_rate=16000):
    """Load an audio file and resample it to target_sample_rate if needed.

    Returns:
        (waveform, sample_rate): waveform is a (channels, time) tensor,
        sample_rate is always target_sample_rate.
    """
    waveform, sample_rate = torchaudio.load(audio_file_path)
    if sample_rate != target_sample_rate:
        resampler = T.Resample(orig_freq=sample_rate, new_freq=target_sample_rate)
        waveform = resampler(waveform)
    return waveform, target_sample_rate


def _to_mono(waveform):
    """Downmix a (channels, time) tensor to (1, time) by averaging channels.

    AudioSeal's watermark generator/detector expect single-channel input, so
    stereo uploads must be collapsed before embedding/detection.
    """
    if waveform.dim() == 2 and waveform.size(0) > 1:
        waveform = waveform.mean(dim=0, keepdim=True)
    return waveform


def plot_spectrogram_to_image(waveform, sample_rate, n_fft=400):
    """Render a power spectrogram (in dB) of the waveform as a PIL Image."""
    spectrogram_transform = T.Spectrogram(n_fft=n_fft, power=2)
    spectrogram = spectrogram_transform(waveform)
    spectrogram_db = torchaudio.transforms.AmplitudeToDB()(spectrogram)
    plt.figure(figsize=(10, 4))
    # squeeze() drops batch/channel singleton dims so imshow gets a 2-D array
    plt.imshow(spectrogram_db.squeeze().detach().numpy(), cmap='hot', aspect='auto')
    plt.axis('off')
    buf = io.BytesIO()
    plt.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
    buf.seek(0)
    plt.close()
    return Image.open(buf)


def plot_waveform_to_image(waveform, sample_rate):
    """Render the first channel of the waveform as a PIL Image."""
    plt.figure(figsize=(10, 4))
    if waveform.dim() == 3:
        # Drop the batch dimension: (1, channels, time) -> (channels, time)
        waveform = waveform.squeeze(0)
    plt.plot(waveform.detach().numpy()[0], color='black')
    plt.axis('off')
    buf = io.BytesIO()
    plt.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
    buf.seek(0)
    plt.close()
    return Image.open(buf)


def watermark_audio(audio_file_path):
    """Embed a random 16-bit AudioSeal watermark into an audio file.

    Returns a 6-tuple consumed by the Gradio UI:
        (watermarked wav path, bit-string message,
         original waveform image, original spectrogram image,
         watermarked waveform image, watermarked spectrogram image)
    """
    waveform, sample_rate = load_and_resample_audio(audio_file_path, TARGET_SAMPLE_RATE)
    waveform = torch.clamp(waveform, min=-1.0, max=1.0)
    # AudioSeal expects mono; average stereo channels down to one.
    waveform = _to_mono(waveform)
    # Normalize to a batched (1, 1, time) tensor for the generator.
    if waveform.ndim == 1:
        waveform = waveform.unsqueeze(0)
    if waveform.ndim == 2:
        waveform = waveform.unsqueeze(0)

    original_waveform_image = plot_waveform_to_image(waveform, sample_rate)
    original_spec_image = plot_spectrogram_to_image(waveform, sample_rate)

    generator = AudioSeal.load_generator("audioseal_wm_16bits")
    message = generate_random_binary_message()
    message_tensor = torch.tensor(
        [int(bit) for bit in message], dtype=torch.int32
    ).unsqueeze(0)
    watermarked_audio = generator(waveform, message=message_tensor)

    # delete=False: Gradio needs the file to persist after this function returns.
    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.wav')
    torchaudio.save(temp_file.name, watermarked_audio.squeeze(0), sample_rate)

    watermarked_waveform_image = plot_waveform_to_image(watermarked_audio, sample_rate)
    watermarked_spec_image = plot_spectrogram_to_image(watermarked_audio, sample_rate)
    return (
        temp_file.name,
        message,
        original_waveform_image,
        original_spec_image,
        watermarked_waveform_image,
        watermarked_spec_image,
    )


def detect_watermark(audio_file_path):
    """Check an audio file for an AudioSeal watermark.

    Returns:
        (message, spectrogram_image): a human-readable verdict and a
        spectrogram of the analyzed audio.
    """
    sample_rate = TARGET_SAMPLE_RATE
    waveform, _ = load_and_resample_audio(audio_file_path, sample_rate)
    waveform = _to_mono(waveform)  # detector expects mono input

    detector = AudioSeal.load_detector("audioseal_detector_16bits")
    batched = waveform.unsqueeze(0) if waveform.dim() == 2 else waveform
    results, _ = detector.forward(batched, sample_rate=sample_rate)
    # results[:, 1, :] is the per-frame probability of "watermarked".
    detect_probs = results[:, 1, :]
    result = detect_probs.mean().cpu().item()
    message = (
        f"Detection result: "
        f"{'Watermarked Audio' if result > 0.5 else 'Not watermarked'}"
    )
    spectrogram_image = plot_spectrogram_to_image(waveform, sample_rate)
    return message, spectrogram_image


def _build_demo():
    """Construct the two-tab Gradio interface."""
    # Custom CSS is optional: a missing style.css must not crash the app.
    style_path = Path("style.css")
    style = style_path.read_text() if style_path.exists() else None

    with gr.Blocks(css=style) as demo:
        with gr.Tab("Watermark Audio"):
            with gr.Column(scale=6):
                audio_input_watermark = gr.Audio(
                    label="Upload Audio File for Watermarking", type="filepath"
                )
                watermark_button = gr.Button("Apply Watermark")
                watermarked_audio_output = gr.Audio(label="Watermarked Audio")
                binary_message_output = gr.Textbox(label="Binary Message")
                original_waveform_output = gr.Image(label="Original Waveform")
                original_spectrogram_output = gr.Image(label="Original Spectrogram")
                watermarked_waveform_output = gr.Image(label="Watermarked Waveform")
                watermarked_spectrogram_output = gr.Image(
                    label="Watermarked Spectrogram"
                )
                watermark_button.click(
                    fn=watermark_audio,
                    inputs=audio_input_watermark,
                    outputs=[
                        watermarked_audio_output,
                        binary_message_output,
                        original_waveform_output,
                        original_spectrogram_output,
                        watermarked_waveform_output,
                        watermarked_spectrogram_output,
                    ],
                )
        with gr.Tab("Detect Watermark"):
            with gr.Column(scale=6):
                audio_input_detect_watermark = gr.Audio(
                    label="Upload Audio File for Watermark Detection",
                    type="filepath",
                )
                detect_watermark_button = gr.Button("Detect Watermark")
                watermark_detection_output = gr.Textbox(
                    label="Watermark Detection Result"
                )
                spectrogram_image_output = gr.Image(label="Spectrogram")
                detect_watermark_button.click(
                    fn=detect_watermark,
                    inputs=audio_input_detect_watermark,
                    outputs=[watermark_detection_output, spectrogram_image_output],
                )
    return demo


demo = _build_demo()

if __name__ == "__main__":
    demo.launch()