File size: 5,217 Bytes
c07da0e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import gradio as gr
import librosa
import numpy as np
import torch
import torchaudio
import torchaudio.transforms as T
import matplotlib.pyplot as plt
from PIL import Image
import io
import tempfile
import logging
from audioseal import AudioSeal
import random
from pathlib import Path

# Configure root logging: DEBUG-level records go to app.log, which is
# truncated on every run (filemode='w').
logging.basicConfig(level=logging.DEBUG, filename='app.log', filemode='w', format='%(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

def generate_random_binary_message(length=16):
    """Return a random bit string such as '0110...' of *length* characters.

    Draws all bits with a single ``torch.randint`` call instead of one
    call (and one tensor allocation) per bit.
    """
    bits = torch.randint(0, 2, (length,))
    return ''.join(str(int(bit)) for bit in bits)

def load_and_resample_audio(audio_file_path, target_sample_rate=16000):
    """Load an audio file and resample it to *target_sample_rate* if needed.

    Returns a ``(waveform, sample_rate)`` tuple; the returned sample rate
    is always ``target_sample_rate``.
    """
    waveform, native_rate = torchaudio.load(audio_file_path)
    if native_rate != target_sample_rate:
        resample = T.Resample(orig_freq=native_rate, new_freq=target_sample_rate)
        waveform = resample(waveform)
    return waveform, target_sample_rate

def plot_spectrogram_to_image(waveform, sample_rate, n_fft=400):
    """Render a power spectrogram (in dB) of *waveform* as a PIL image.

    Accepts tensors with leading batch/channel dimensions; only the first
    channel is plotted. ``T.Spectrogram`` preserves those leading dims, so
    the raw output is 3-D (or 4-D for batched input) — it must be reduced
    to 2-D before ``plt.imshow`` will accept it.
    """
    spectrogram = T.Spectrogram(n_fft=n_fft, power=2)(waveform)
    spectrogram_db = T.AmplitudeToDB()(spectrogram)
    spec_2d = spectrogram_db.detach().numpy()
    # Collapse any leading batch/channel dims down to (freq, time).
    while spec_2d.ndim > 2:
        spec_2d = spec_2d[0]
    plt.figure(figsize=(10, 4))
    plt.imshow(spec_2d, cmap='hot', aspect='auto')
    plt.axis('off')
    buf = io.BytesIO()
    plt.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
    buf.seek(0)
    plt.close()
    return Image.open(buf)

def plot_waveform_to_image(waveform, sample_rate):
    """Render the first channel of *waveform* as a line-plot PIL image.

    The original indexed ``numpy()[0]``, which is only correct for 2-D
    (channel, time) input; ``watermark_audio`` passes a 3-D
    (batch, channel, time) tensor, which left a 2-D array and produced a
    garbage plot. Collapse all leading dims to a single 1-D trace instead.
    """
    samples = waveform.detach().numpy()
    while samples.ndim > 1:
        samples = samples[0]
    plt.figure(figsize=(10, 4))
    plt.plot(samples, color='black')
    plt.axis('off')
    buf = io.BytesIO()
    plt.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
    buf.seek(0)
    plt.close()
    return Image.open(buf)

def watermark_audio(audio_file_path):
    """Embed a random 16-bit AudioSeal watermark into an audio file.

    Returns a 6-tuple: (path to the watermarked .wav, the embedded bit
    string, original waveform image, original spectrogram image,
    watermarked waveform image, watermarked spectrogram image).
    """
    waveform, sample_rate = load_and_resample_audio(audio_file_path, 16000)
    waveform = torch.clamp(waveform, min=-1.0, max=1.0)
    # The AudioSeal generator expects a (batch, channels, time) tensor.
    if waveform.ndim == 1:
        waveform = waveform.unsqueeze(0)
    if waveform.ndim == 2:
        waveform = waveform.unsqueeze(0)

    original_waveform_image = plot_waveform_to_image(waveform, sample_rate)
    original_spec_image = plot_spectrogram_to_image(waveform, sample_rate)

    generator = AudioSeal.load_generator("audioseal_wm_16bits")
    message = generate_random_binary_message()
    message_tensor = torch.tensor([int(bit) for bit in message], dtype=torch.int32).unsqueeze(0)
    watermarked_audio = generator(waveform, message=message_tensor)

    # Close the handle before torchaudio writes to the same path: keeping
    # it open leaks a file descriptor and fails outright on Windows.
    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.wav')
    temp_file.close()
    torchaudio.save(temp_file.name, watermarked_audio.squeeze(0), sample_rate)

    watermarked_waveform_image = plot_waveform_to_image(watermarked_audio, sample_rate)
    watermarked_spec_image = plot_spectrogram_to_image(watermarked_audio, sample_rate)

    return temp_file.name, message, original_waveform_image, original_spec_image, watermarked_waveform_image, watermarked_spec_image

def detect_watermark(audio_file_path, sample_rate=16000):
    """Run the AudioSeal detector on an audio file.

    Returns (human-readable detection string, spectrogram PIL image).
    """
    waveform, sample_rate = load_and_resample_audio(audio_file_path, sample_rate)
    detector = AudioSeal.load_detector("audioseal_detector_16bits")
    # Inference only: skip autograd graph construction.
    with torch.no_grad():
        results, messages = detector.forward(waveform.unsqueeze(0), sample_rate=sample_rate)
    # results[:, 1, :] — presumably the per-frame probability that a
    # watermark is present; averaged over time and thresholded at 0.5.
    detect_probs = results[:, 1, :]
    result = detect_probs.mean().cpu().item()
    message = f"Detection result: {'Watermarked Audio' if result > 0.5 else 'Not watermarked'}"
    spectrogram_image = plot_spectrogram_to_image(waveform, sample_rate)
    return message, spectrogram_image

# Load custom CSS if present; fall back to Gradio defaults so a missing
# style.css does not crash startup.
style_path = Path("style.css")
style = style_path.read_text() if style_path.exists() else None

with gr.Blocks(css=style) as demo:
    with gr.Tab("Watermark Audio"):
        with gr.Column(scale=6):
            audio_input_watermark = gr.Audio(label="Upload Audio File for Watermarking", type="filepath")
            watermark_button = gr.Button("Apply Watermark")
            watermarked_audio_output = gr.Audio(label="Watermarked Audio")
            binary_message_output = gr.Textbox(label="Binary Message")
            original_waveform_output = gr.Image(label="Original Waveform")
            original_spectrogram_output = gr.Image(label="Original Spectrogram")
            watermarked_waveform_output = gr.Image(label="Watermarked Waveform")
            watermarked_spectrogram_output = gr.Image(label="Watermarked Spectrogram")
            watermark_button.click(fn=watermark_audio, inputs=audio_input_watermark, outputs=[watermarked_audio_output, binary_message_output, original_waveform_output, original_spectrogram_output, watermarked_waveform_output, watermarked_spectrogram_output])

    with gr.Tab("Detect Watermark"):
        with gr.Column(scale=6):
            audio_input_detect_watermark = gr.Audio(label="Upload Audio File for Watermark Detection", type="filepath")
            detect_watermark_button = gr.Button("Detect Watermark")
            watermark_detection_output = gr.Textbox(label="Watermark Detection Result")
            spectrogram_image_output = gr.Image(label="Spectrogram")
            # Only pass the audio component: event inputs must be Gradio
            # components, so the original literal "16000" was invalid —
            # detect_watermark's default sample_rate covers it.
            detect_watermark_button.click(fn=detect_watermark, inputs=[audio_input_detect_watermark], outputs=[watermark_detection_output, spectrogram_image_output])

demo.launch()