Kabatubare's picture
Create app.py
c07da0e verified
raw
history blame
5.22 kB
import gradio as gr
import librosa
import numpy as np
import torch
import torchaudio
import torchaudio.transforms as T
import matplotlib.pyplot as plt
from PIL import Image
import io
import tempfile
import logging
from audioseal import AudioSeal
import random
from pathlib import Path
# Log everything at DEBUG and above to app.log, truncating the file on each
# run (filemode='w'); messages carry the logger name and severity.
logging.basicConfig(level=logging.DEBUG, filename='app.log', filemode='w', format='%(name)s - %(levelname)s - %(message)s')
# Module-level logger for this app.
logger = logging.getLogger(__name__)
def generate_random_binary_message(length=16):
    """Return a random bit string such as '0110...' of the given length.

    Each bit is drawn independently via torch.randint, so the result
    follows torch's global RNG state.
    """
    bits = (str(torch.randint(0, 2, (1,)).item()) for _ in range(length))
    return ''.join(bits)
def load_and_resample_audio(audio_file_path, target_sample_rate=16000):
    """Load an audio file and resample it to target_sample_rate if needed.

    Returns:
        tuple: (waveform, target_sample_rate) where waveform is the
        (channels, time) tensor produced by torchaudio.load, resampled
        whenever the file's native rate differs from the target.
    """
    waveform, native_rate = torchaudio.load(audio_file_path)
    if native_rate != target_sample_rate:
        resample = T.Resample(orig_freq=native_rate, new_freq=target_sample_rate)
        waveform = resample(waveform)
    return waveform, target_sample_rate
def plot_spectrogram_to_image(waveform, sample_rate, n_fft=400):
    """Render a power spectrogram of `waveform` as a PIL image.

    Args:
        waveform: audio tensor; may be (time,), (channels, time) or
            (batch, channels, time) — callers in this file pass both of the
            latter shapes.
        sample_rate: unused here, kept for interface compatibility.
        n_fft: FFT size for the spectrogram.

    Returns:
        PIL.Image.Image: the rendered spectrogram (axes hidden).

    Bug fix: the spectrogram of a (channels, time) or (batch, channels, time)
    tensor has 3 or 4 dims, which plt.imshow cannot render as a heatmap;
    collapse leading dims to the first channel's 2D (freq, time) spectrogram.
    """
    spectrogram = T.Spectrogram(n_fft=n_fft, power=2)(waveform)
    spectrogram_db = torchaudio.transforms.AmplitudeToDB()(spectrogram)
    image_2d = spectrogram_db.detach()
    while image_2d.ndim > 2:  # e.g. (batch, channel, freq, time) -> (freq, time)
        image_2d = image_2d[0]
    plt.figure(figsize=(10, 4))
    plt.imshow(image_2d.numpy(), cmap='hot', aspect='auto')
    plt.axis('off')
    buf = io.BytesIO()
    try:
        plt.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
    finally:
        # Always release the figure so repeated Gradio calls don't leak them.
        plt.close()
    buf.seek(0)
    return Image.open(buf)
def plot_waveform_to_image(waveform, sample_rate):
    """Render `waveform` as a single black line plot and return a PIL image.

    Args:
        waveform: audio tensor; (channels, time) or (batch, channels, time).
        sample_rate: unused here, kept for interface compatibility.

    Returns:
        PIL.Image.Image: the rendered waveform (axes hidden).

    Bug fix: watermark_audio passes a 3D (1, 1, time) tensor, so
    `.numpy()[0]` is still 2D and plt.plot would draw each COLUMN as a
    separate one-point series. Flatten to 1D samples before plotting.
    """
    samples = waveform.detach().numpy()[0]
    samples = samples.reshape(-1)
    plt.figure(figsize=(10, 4))
    plt.plot(samples, color='black')
    plt.axis('off')
    buf = io.BytesIO()
    try:
        plt.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
    finally:
        # Release the figure so repeated calls don't accumulate open figures.
        plt.close()
    buf.seek(0)
    return Image.open(buf)
def watermark_audio(audio_file_path):
    """Embed a random 16-bit AudioSeal watermark into an audio file.

    Returns the 6-tuple consumed by the Gradio outputs list:
    (path to the watermarked .wav, the embedded bit string,
     original waveform image, original spectrogram image,
     watermarked waveform image, watermarked spectrogram image).
    """
    waveform, sample_rate = load_and_resample_audio(audio_file_path, 16000)
    # Clamp samples to the valid [-1, 1] range before watermarking.
    waveform = torch.clamp(waveform, min=-1.0, max=1.0)
    # Promote to a 3D (batch, channels, time) tensor for the generator:
    # torchaudio.load yields (channels, time), so only the second branch
    # normally fires; the first covers a bare (time,) tensor.
    # NOTE(review): a stereo file produces a (1, 2, time) batch here —
    # presumably the 16-bit generator expects mono; confirm multi-channel
    # handling against the AudioSeal API.
    if waveform.ndim == 1:
        waveform = waveform.unsqueeze(0)
    if waveform.ndim == 2:
        waveform = waveform.unsqueeze(0)
    original_waveform_image = plot_waveform_to_image(waveform, sample_rate)
    original_spec_image = plot_spectrogram_to_image(waveform, sample_rate)
    generator = AudioSeal.load_generator("audioseal_wm_16bits")
    # Random 16-bit payload, passed to the generator as a (1, 16) int32 tensor.
    message = generate_random_binary_message()
    message_tensor = torch.tensor([int(bit) for bit in message], dtype=torch.int32).unsqueeze(0)
    watermarked_audio = generator(waveform, message=message_tensor)
    # delete=False so the file survives for Gradio to serve; note the temp
    # files are never cleaned up over the app's lifetime.
    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.wav')
    # squeeze(0) drops the batch dim back to (channels, time) for saving.
    torchaudio.save(temp_file.name, watermarked_audio.squeeze(0), sample_rate)
    watermarked_waveform_image = plot_waveform_to_image(watermarked_audio, sample_rate)
    watermarked_spec_image = plot_spectrogram_to_image(watermarked_audio, sample_rate)
    return temp_file.name, message, original_waveform_image, original_spec_image, watermarked_waveform_image, watermarked_spec_image
def detect_watermark(audio_file_path, sample_rate=16000):
    """Run the AudioSeal 16-bit detector on an audio file.

    Returns:
        tuple: (human-readable detection message, spectrogram PIL image of
        the resampled input).
    """
    waveform, sample_rate = load_and_resample_audio(audio_file_path, sample_rate)
    detector = AudioSeal.load_detector("audioseal_detector_16bits")
    # Detector wants a (batch, channels, time) tensor; add the batch dim.
    results, messages = detector.forward(waveform.unsqueeze(0), sample_rate=sample_rate)
    # Index 1 along dim 1 holds the per-sample "watermarked" probability;
    # average it over the clip for a single score.
    mean_prob = results[:, 1, :].mean().cpu().item()
    verdict = 'Watermarked Audio' if mean_prob > 0.5 else 'Not watermarked'
    message = f"Detection result: {verdict}"
    spectrogram_image = plot_spectrogram_to_image(waveform, sample_rate)
    return message, spectrogram_image
# ---- Gradio UI ----
# Load the custom stylesheet; raises FileNotFoundError if style.css is absent
# (the app deliberately fails fast rather than running unstyled).
style_path = Path("style.css")
style = style_path.read_text()

with gr.Blocks(css=style) as demo:
    with gr.Tab("Watermark Audio"):
        with gr.Column(scale=6):
            audio_input_watermark = gr.Audio(label="Upload Audio File for Watermarking", type="filepath")
            watermark_button = gr.Button("Apply Watermark")
            watermarked_audio_output = gr.Audio(label="Watermarked Audio")
            binary_message_output = gr.Textbox(label="Binary Message")
            original_waveform_output = gr.Image(label="Original Waveform")
            original_spectrogram_output = gr.Image(label="Original Spectrogram")
            watermarked_waveform_output = gr.Image(label="Watermarked Waveform")
            watermarked_spectrogram_output = gr.Image(label="Watermarked Spectrogram")
            watermark_button.click(
                fn=watermark_audio,
                inputs=audio_input_watermark,
                outputs=[
                    watermarked_audio_output,
                    binary_message_output,
                    original_waveform_output,
                    original_spectrogram_output,
                    watermarked_waveform_output,
                    watermarked_spectrogram_output,
                ],
            )
    with gr.Tab("Detect Watermark"):
        with gr.Column(scale=6):
            audio_input_detect_watermark = gr.Audio(label="Upload Audio File for Watermark Detection", type="filepath")
            detect_watermark_button = gr.Button("Detect Watermark")
            watermark_detection_output = gr.Textbox(label="Watermark Detection Result")
            spectrogram_image_output = gr.Image(label="Spectrogram")
            # Bug fix: the original passed the string "16000" in `inputs`, but
            # Gradio event inputs must be components — it raised at wiring
            # time. Pass only the audio component; detect_watermark's
            # sample_rate parameter already defaults to 16000.
            detect_watermark_button.click(
                fn=detect_watermark,
                inputs=audio_input_detect_watermark,
                outputs=[watermark_detection_output, spectrogram_image_output],
            )

demo.launch()