import torch
import torchaudio
#Andy commented: torchaudio.set_audio_backend('soundfile')
#Andy commented: from audio_diffusion_pytorch import DiffusionAE, UNetV0, VDiffusion, VSampler
from audio_encoders_pytorch import MelE1d, TanhBottleneck
#Andy commented: from audiodiffusion.audio_encoder import AudioEncoder
#Andy commented: from IPython.display import Audio, display
#Andy commented: import matplotlib
#Andy commented: import matplotlib.pyplot as plt
import pandas as pd
#Andy commented: from archisound import ArchiSound
print(torch.cuda.is_available(), torch.cuda.device_count())
#Andy removed: import wandb
#Andy removed: wandb.init(project="audio_encoder_attack")
#Andy commented: from tqdm import tqdm
import auraloss
from transformers import EncodecModel, AutoProcessor
import cdpam
import audio_diffusion_attacks_forhf.src.losses as losses
#Andy edited: from audiotools import AudioSignal
#Andy edited step 2: from audiotools.audiotools.core.audio_signal.py import AudioSignal
from audiotools import AudioSignal
from audio_diffusion_attacks_forhf.src.balancer import Balancer
#Andy commented: from gradnorm_pytorch import (
#Andy commented: GradNormLossWeighter,
#Andy commented: MockNetworkWithMultipleLosses
#Andy commented: )
'''Andy commented:
from audiocraft.losses import (
MelSpectrogramL1Loss,
MultiScaleMelSpectrogramLoss,
MRSTFTLoss,
SISNR,
STFTLoss,
)
'''
from audio_diffusion_attacks_forhf.src.music_gen import MusicGenEval
from audio_diffusion_attacks_forhf.src.speech_inference import XTTS_Eval
# From https://pytorch.org/tutorials/beginner/audio_preprocessing_tutorial.html#loading-audio-data-into-tensor
def print_stats(waveform, sample_rate=None, src=None):
if src:
print("-" * 10)
print("Source:", src)
print("-" * 10)
    if sample_rate:
print("Sample Rate:", sample_rate)
print("Shape:", tuple(waveform.shape))
print("Dtype:", waveform.dtype)
print(f" - Max: {waveform.max().item():6.3f}")
print(f" - Min: {waveform.min().item():6.3f}")
print(f" - Mean: {waveform.mean().item():6.3f}")
print(f" - Std Dev: {waveform.std().item():6.3f}")
print()
print(waveform)
print()
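# Scale-invariant signal-to-noise ratio (SI-SNR) in dB. The reference is rescaled by
# <estimate, reference> / ||reference||^2 (the projection of the estimate onto the
# reference), so the result is insensitive to the overall gain of the estimate.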
def si_snr(estimate, reference, epsilon=1e-8):
estimate = estimate - estimate.mean()
reference = reference - reference.mean()
reference_pow = reference.pow(2).mean(axis=1, keepdim=True)
mix_pow = (estimate * reference).mean(axis=1, keepdim=True)
scale = mix_pow / (reference_pow + epsilon)
reference = scale * reference
error = estimate - reference
reference_pow = reference.pow(2)
error_pow = error.pow(2)
reference_pow = reference_pow.mean(axis=1)
error_pow = error_pow.mean(axis=1)
si_snr = 10 * torch.log10(reference_pow) - 10 * torch.log10(error_pow)
return si_snr.item()
# Train autoencoder with audio samples
#waveform = torch.randn(2, 2**10) # [batch, in_channels, length]
# loss.backward()
#andy edited: def poison_audio(audio_folder, encoders, audio_difference_weights=[1], method='encoder', weight=1, modality="music"):
def poison_audio(waveform, sample_rate, encoders, audio_difference_weights=[1], method='encoder', weight=1, modality="music"):
    '''
    Protect a single piece of audio against the given encoders.
    waveform: audio tensor of shape [channels, samples].
    sample_rate: sample rate of the waveform in Hz.
    encoders: encoders to protect against. See initialization at end of file.
    method: 'encoder' pushes latents away from the original audio; 'style_transfer' uses a style clip as the reference.
    modality: 'music' or 'speech'; selects the generation model used for evaluation.
    Returns the protected waveform, an evaluation dict, and unprotected/protected generations.
    '''
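    # Approach: optimize an additive noise tensor so that the encoders' outputs for
    # (waveform + noise) move away from cached reference outputs (computed from the original
    # audio in 'encoder' mode, or from a style clip in 'style_transfer' mode), while a
    # mel-spectrogram loss keeps the perturbed audio perceptually close to the reference.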
for encoder in encoders:
#Andy removed: encoder.to(device='cuda')
encoder.eval()
for p in encoder.parameters():
p.requires_grad = False
audio_len=1000000
#Andy removed: waveform, sample_rate = torchaudio.load(f"test_audio/Texas Sun.mp3")
if modality=="music":
music_gen_eval=MusicGenEval(sample_rate, audio_len)
elif modality=="speech":
music_gen_eval=XTTS_Eval(sample_rate)
processor = AutoProcessor.from_pretrained("facebook/encodec_48khz")
#Andy edited: loss_fn = cdpam.CDPAM(dev='cuda:0')
my_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
loss_fn = cdpam.CDPAM(dev=my_device)
for p in loss_fn.model.parameters():
p.requires_grad = False
#Andy removed: for audio_file in tqdm(os.listdir(audio_folder)):
for diff_weight in audio_difference_weights:
#Andy edited: waveform, sample_rate = torchaudio.load(os.path.join(audio_folder, audio_file))
# convert mono to stereo
if waveform.shape[0]==1:
stereo_waveform=torch.zeros((2, waveform.shape[1]))
stereo_waveform[:,:]=waveform
waveform=stereo_waveform
waveform=waveform[:, :audio_len]
inputs = processor(raw_audio=waveform, sampling_rate=processor.sampling_rate, return_tensors="pt")
waveform=inputs['input_values'][0]
#Andy removed: wandb.log({f"unperturbed {audio_name}": wandb.Audio(waveform[0].detach().numpy().flatten(), sample_rate=sample_rate)}, step=0)
waveform=torch.reshape(waveform, (1, waveform.shape[0], waveform.shape[1]))
#Andy removed: waveform=waveform.to(device='cuda')
#Andy edited: inputs["padding_mask"]=inputs["padding_mask"].to(device='cuda')
inputs["padding_mask"]=inputs["padding_mask"]
if method=="encoder":
unperturbed_waveform=waveform.clone().detach()
unperturbed_latents=[]
for encoder in encoders:
unperturbed_latent=encoder(waveform, inputs["padding_mask"]).audio_values.detach()
unperturbed_latents.append(unperturbed_latent)
if method=="style_transfer":
style_waveform, style_sample_rate = torchaudio.load(f"test_audio/Il Sogno Del Marinaio - Nanos' Waltz.mp3")
style_waveform=style_waveform[:, :audio_len]
style_inputs = processor(raw_audio=style_waveform, sampling_rate=processor.sampling_rate, return_tensors="pt")
style_waveform=style_inputs['input_values'][0]
#Andy removed: wandb.log({f"transfer style": wandb.Audio(style_waveform[0].detach().numpy().flatten(), sample_rate=sample_rate)}, step=0)
style_waveform=torch.reshape(style_waveform, (1, style_waveform.shape[0], style_waveform.shape[1]))
#Andy edited: style_waveform=style_waveform.to(device='cuda')
style_waveform=style_waveform
#Andy edited: style_inputs["padding_mask"]=style_inputs["padding_mask"].to(device='cuda')
style_inputs["padding_mask"]=style_inputs["padding_mask"]
# unperturbed_latent=encoder(waveform, inputs["padding_mask"]).audio_values.detach()
unperturbed_waveform=style_waveform.clone().detach()
unperturbed_latents=[]
for encoder in encoders:
unperturbed_latent=encoder(style_waveform, style_inputs["padding_mask"]).audio_values.detach()
unperturbed_latents.append(unperturbed_latent)
noise=torch.normal(torch.zeros(waveform.shape), 0.0)
#Andy removed: noise=noise.to(device='cuda')
noise.requires_grad=True
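        # The perturbation starts as all zeros (std 0.0) and is the only tensor with
        # requires_grad=True, so the optimizer below updates the noise rather than the audio.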
# waveform=torch.nn.parameter.Parameter(waveform)
weights = {'waveform_diff': weight, 'latent_diff': 1}
balancer = Balancer(weights)
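        # Balancer (src/balancer.py) combines the gradients of the two loss terms
        # ('waveform_diff' and 'latent_diff') with respect to `noise`, scaled by `weights`.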
l1loss = torch.nn.L1Loss()
# for p in mel_loss.parameters():
# p.requires_grad = False
optim = torch.optim.AdamW([noise], lr=0.002, weight_decay=0.005)
#optim_diff = torch.optim.Adam([waveform], lr=0.02)
# loss_weighter = GradNormLossWeighter(
# num_losses = 2,
# learning_rate = 0.00002,
# restoring_force_alpha = 0., # 0. is perfectly balanced losses, while anything greater than 1 would account for the relative training rates of each loss. in the paper, they go as high as 3.
# grad_norm_parameters = waveform
# )
downsample = torchaudio.transforms.Resample(sample_rate, 22050)
#Andy removed: downsample=downsample.to(device='cuda')
cos = torch.nn.CosineSimilarity()
mrstft = auraloss.perceptual.FIRFilter()#auraloss.time.SISDRLoss()#torch.nn.functional.l1_loss
#Andy removed: mrstft.to(device='cuda')
waveform_loss = losses.L1Loss()
stft_loss = losses.MultiScaleSTFTLoss()
mel_loss = losses.MelSpectrogramLoss(n_mels=[5, 10, 20, 40, 80, 160, 320],
window_lengths=[32, 64, 128, 256, 512, 1024, 2048],
mel_fmin=[0, 0, 0, 0, 0, 0, 0],
pow=1.0,
clamp_eps=1.0e-5,
mag_weight=0.0)
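        # Perceptual distances between perturbed and reference audio; only the multi-scale
        # mel-spectrogram loss is used as `waveform_diff` below, the others are computed for logging.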
past_10_latent_losses=[]
latent_weight=1000
latent_diff=0
#Andy edited for testing purposes: number_steps=500
number_steps=5
if diff_weight>-1:
for step in range(number_steps):
latent_diff=0
                perturbed_waveform=noise+waveform
for encoder_ind in range(len(encoders)):
                    perturbed_latent = encoders[encoder_ind](perturbed_waveform, inputs["padding_mask"]).audio_values
latent_diff+=cos(torch.reshape(perturbed_latent, (1,-1)), torch.reshape(unperturbed_latents[encoder_ind], (1, -1)))
#latent_diff+=1-torch.mean(torch.abs((torch.reshape(perturbed_latent, (1,-1))-torch.reshape(unperturbed_latents[encoder_ind], (1, -1)))))
#latent_diff=-l1loss(perturbed_latent,unperturbed_latents[0])
latent_diff=latent_diff/len(encoders)
#waveform_diff=mrstft(waveform, unperturbed_waveform)
#waveform_diff=mrstft(torch.reshape(waveform, (1,-1)), torch.reshape(unperturbed_waveform, (1,-1)))
#waveform_diff=si_snr(waveform, unperturbed_waveform)
                a_waveform=AudioSignal(perturbed_waveform, sample_rate)
a_uwaveform=AudioSignal(unperturbed_waveform, sample_rate)
c_waveform_loss=waveform_loss(a_waveform, a_uwaveform)*100
c_stft_loss=stft_loss(a_waveform, a_uwaveform)/6.0
c_mel_loss=mel_loss(a_waveform, a_uwaveform)
                l1_loss=torch.mean(torch.abs(perturbed_waveform-unperturbed_waveform))
waveform_diff=c_mel_loss#(c_waveform_loss+c_stft_loss+c_mel_loss)/3.0
#loss=100*latent_diff+waveform_diff
# loss=latent_weight*latent_diff+waveform_diff
# past_10_latent_losses.append(latent_diff.detach().cpu().numpy().item())
# if len(past_10_latent_losses)>10:
# mean=sum(past_10_latent_losses)/len(past_10_latent_losses)
# if mean<latent_diff:
# latent_weight=latent_weight*1.1
# elif mean>latent_diff*1.01:
# latent_weight=latent_weight/1.1
# past_10_latent_losses=past_10_latent_losses[1:]
# print('latent_weight', latent_weight)
# if latent_diff>0.85:
# loss=1500*latent_diff+waveform_diff
#loss=1000*latent_diff+waveform_diff
# if method=='encoder':
# if latent_diff>0.75:
# loss=1000*latent_diff+waveform_diff
# print('latent')
# else:
loss=waveform_diff+latent_diff
# print('waveform_diff')
# elif method=='style_transfer':
# loss=latent_diff
'''Andy removed:
if step%10==0 or step==number_steps-1:
wandb.log({"loss": loss, "latent_diff": latent_diff, 'waveform_diff': waveform_diff}, step=step)
if step%100==0 or step==number_steps-1:
audio_save=torch.reshape((noise+waveform), (2, waveform.shape[2]))[0, :audio_len].detach().cpu().numpy().flatten()
wandb.log({f"perturbed cos_dist_{latent_diff}_diff_weight_{diff_weight}_{audio_name}": wandb.Audio(audio_save, sample_rate=sample_rate)}, step=step)
if step%100==0 or step==number_steps-1:
music_gen_eval_dict, unprotected_gen, protected_gen=music_gen_eval.eval(waveform, noise+waveform)
audio_save=torch.reshape(unprotected_gen, (2, unprotected_gen.shape[1]))[0].detach().cpu().numpy().flatten()
wandb.log({f"unprotected_gen_{latent_diff}_diff_weight_{diff_weight}": wandb.Audio(audio_save, sample_rate=sample_rate)}, step=step)
audio_save=torch.reshape(protected_gen, (2, protected_gen.shape[1]))[0].detach().cpu().numpy().flatten()
wandb.log({f"protected_gen_{latent_diff}_diff_weight_{diff_weight}": wandb.Audio(audio_save, sample_rate=sample_rate)}, step=step)
wandb.log(music_gen_eval_dict, step=step)
'''
# if c_mel_loss>0.5:
# loss=waveform_diff
# else:
# loss=latent_diff
# noise=noise*0.99
loss_dict = {}
loss_dict['waveform_diff'] = waveform_diff
loss_dict['latent_diff'] = latent_diff[0]
effective_loss = balancer.backward(loss_dict, noise)
# loss=latent_diff
# loss.backward()
#loss_weighter.backward([latent_diff, c_mel_loss])
# torch.nn.utils.clip_grad_norm_(waveform, 10e-8)
optim.step()
optim.zero_grad()
# if latent_diff>0.5:
# latent_diff.backward()
# optim_diff.step()
# optim_diff.zero_grad()
# else:
# loss=waveform_diff
# loss.backward()
# optim_diff.step()
# optim_diff.zero_grad()
encoder.zero_grad()
mel_loss.zero_grad()
# with torch.no_grad():
# noise_clip=0.25
# noise.clamp_(-noise_clip, noise_clip)
# print('noise max', torch.max(noise))
print('step', step, 'loss', loss.item(), 'latent loss', latent_diff.item(), 'audio loss', waveform_diff.item(), 'c_waveform_loss', c_waveform_loss.item(), 'c_stft_loss', c_stft_loss.item(), 'l1_loss', l1_loss.item())
latent_diff=latent_diff.detach().item()
#Andy removed: audio_save=torch.reshape((noise+waveform), (2, waveform.shape[2]))[0, :audio_len].detach().cpu().numpy().flatten()
#Andy removed: wandb.log({f"perturbed cos_dist_{latent_diff}_diff_weight_{diff_weight}_{audio_name}": wandb.Audio(audio_save, sample_rate=sample_rate)}, step=step)
#Andy moved from inside the loop:
music_gen_eval_dict, unprotected_gen, protected_gen=music_gen_eval.eval(waveform, noise+waveform)
#Andy edited: torchaudio.save(os.path.join(audio_folder, f"protected_{audio_name}_{audio_len}_mel_{latent_diff}_diff_weight_{waveform_diff}"), torch.reshape((noise+waveform).detach().cpu(), (2, waveform.shape[2])), sample_rate)
return (torch.reshape((noise+waveform).detach().cpu(), (2, waveform.shape[2]))), music_gen_eval_dict, unprotected_gen, protected_gen
# encoders = [ArchiSound.from_pretrained('autoencoder1d-AT-v1'),
# ArchiSound.from_pretrained("dmae1d-ATC64-v2"),
# ArchiSound.from_pretrained("dmae1d-ATC32-v3"),
# AudioEncoder.from_pretrained("teticio/audio-encoder"),
encoders = [EncodecModel.from_pretrained("facebook/encodec_48khz")]
audio_difference_weights=[1]
#Andy commented out: poison_audio(<audio_folder>, encoders, [1], method="encoder", weight=weight)
#Andy removed: wandb.finish()
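
# Example usage (a minimal sketch, not part of the original pipeline): load a clip with
# torchaudio, protect it against the EnCodec model above, and save the result. The paths
# below are placeholders, not files shipped with this repo.
# if __name__ == "__main__":
#     waveform, sample_rate = torchaudio.load("test_audio/example.wav")  # hypothetical path
#     protected, eval_dict, unprotected_gen, protected_gen = poison_audio(
#         waveform, sample_rate, encoders, audio_difference_weights,
#         method="encoder", weight=1, modality="music")
#     torchaudio.save("test_audio/example_protected.wav", protected, sample_rate)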