import torch
import torchaudio
#Andy commented: torchaudio.set_audio_backend('soundfile')
#Andy commented: from audio_diffusion_pytorch import DiffusionAE, UNetV0, VDiffusion, VSampler
from audio_encoders_pytorch import MelE1d, TanhBottleneck
#Andy commented: from audiodiffusion.audio_encoder import AudioEncoder
#Andy commented: from IPython.display import Audio, display
#Andy commented: import matplotlib
#Andy commented: import matplotlib.pyplot as plt
import pandas as pd
#Andy commented: from archisound import ArchiSound
print(torch.cuda.is_available(), torch.cuda.device_count())
#Andy removed: import wandb
#Andy removed: wandb.init(project="audio_encoder_attack")
#Andy commented: from tqdm import tqdm
import auraloss
from transformers import EncodecModel, AutoProcessor
import cdpam
import audio_diffusion_attacks_forhf.src.losses as losses
#Andy edited: from audiotools import AudioSignal
#Andy edited step 2: from audiotools.audiotools.core.audio_signal.py import AudioSignal
from audiotools import AudioSignal
from audio_diffusion_attacks_forhf.src.balancer import Balancer
#Andy commented: from gradnorm_pytorch import (
#Andy commented:     GradNormLossWeighter,
#Andy commented:     MockNetworkWithMultipleLosses
#Andy commented: )
'''Andy commented:
from audiocraft.losses import (
    MelSpectrogramL1Loss,
    MultiScaleMelSpectrogramLoss,
    MRSTFTLoss,
    SISNR,
    STFTLoss,
)
'''
from audio_diffusion_attacks_forhf.src.music_gen import MusicGenEval
from audio_diffusion_attacks_forhf.src.speech_inference import XTTS_Eval

# From https://pytorch.org/tutorials/beginner/audio_preprocessing_tutorial.html#loading-audio-data-into-tensor
def print_stats(waveform, sample_rate=None, src=None):
    if src:
        print("-" * 10)
        print("Source:", src)
        print("-" * 10)
    if sample_rate:
        print("Sample Rate:", sample_rate)
    print("Shape:", tuple(waveform.shape))
    print("Dtype:", waveform.dtype)
    print(f" - Max:     {waveform.max().item():6.3f}")
    print(f" - Min:     {waveform.min().item():6.3f}")
    print(f" - Mean:    {waveform.mean().item():6.3f}")
    print(f" - Std Dev: {waveform.std().item():6.3f}")
    print()
    print(waveform)
    print()
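# Hypothetical usage sketch (not part of the pipeline): inspect a clip loaded
# with torchaudio before protecting it. The file path below is an assumption.
# waveform, sample_rate = torchaudio.load("test_audio/example.wav")
# print_stats(waveform, sample_rate=sample_rate, src="example.wav")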
def si_snr(estimate, reference, epsilon=1e-8):
    # Scale-invariant signal-to-noise ratio between an estimate and a reference
    # waveform, computed along the time axis and returned in dB.
    estimate = estimate - estimate.mean()
    reference = reference - reference.mean()
    reference_pow = reference.pow(2).mean(axis=1, keepdim=True)
    mix_pow = (estimate * reference).mean(axis=1, keepdim=True)
    scale = mix_pow / (reference_pow + epsilon)

    reference = scale * reference
    error = estimate - reference

    reference_pow = reference.pow(2)
    error_pow = error.pow(2)

    reference_pow = reference_pow.mean(axis=1)
    error_pow = error_pow.mean(axis=1)

    si_snr = 10 * torch.log10(reference_pow) - 10 * torch.log10(error_pow)
    return si_snr.item()
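# Minimal sanity check for si_snr, assuming [batch, time] tensors; a clean
# reference compared against a lightly noised copy should score well above 0 dB.
# ref = torch.randn(1, 16000)
# est = ref + 0.01 * torch.randn_like(ref)
# print("SI-SNR (dB):", si_snr(est, ref))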
# Train autoencoder with audio samples
#waveform = torch.randn(2, 2**10) # [batch, in_channels, length]
# loss.backward()

#andy edited: def poison_audio(audio_folder, encoders, audio_difference_weights=[1], method='encoder', weight=1, modality="music"):
def poison_audio(waveform, sample_rate, encoders, audio_difference_weights=[1], method='encoder', weight=1, modality="music"):
    '''
    Protect a single audio clip against the given encoders.
    waveform: torch.Tensor of audio samples, shape [channels, samples].
    sample_rate: int, sample rate of the waveform.
    encoders: encoders to protect against. See initialization at end of file.
    Returns the protected waveform plus the generation-eval outputs.
    '''
    for encoder in encoders:
        #Andy removed: encoder.to(device='cuda')
        encoder.eval()
        for p in encoder.parameters():
            p.requires_grad = False

    audio_len=1000000

    #Andy removed: waveform, sample_rate = torchaudio.load(f"test_audio/Texas Sun.mp3")
    if modality=="music":
        music_gen_eval=MusicGenEval(sample_rate, audio_len)
    elif modality=="speech":
        music_gen_eval=XTTS_Eval(sample_rate)

    processor = AutoProcessor.from_pretrained("facebook/encodec_48khz")
    #Andy edited: loss_fn = cdpam.CDPAM(dev='cuda:0')
    my_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    loss_fn = cdpam.CDPAM(dev=my_device)
    for p in loss_fn.model.parameters():
        p.requires_grad = False

    #Andy removed: for audio_file in tqdm(os.listdir(audio_folder)):
    for diff_weight in audio_difference_weights:
        #Andy edited: waveform, sample_rate = torchaudio.load(os.path.join(audio_folder, audio_file))
        # convert mono to stereo
        if waveform.shape[0]==1:
            stereo_waveform=torch.zeros((2, waveform.shape[1]))
            stereo_waveform[:,:]=waveform
            waveform=stereo_waveform
        waveform=waveform[:, :audio_len]
        inputs = processor(raw_audio=waveform, sampling_rate=processor.sampling_rate, return_tensors="pt")
        waveform=inputs['input_values'][0]
        #Andy removed: wandb.log({f"unperturbed {audio_name}": wandb.Audio(waveform[0].detach().numpy().flatten(), sample_rate=sample_rate)}, step=0)
        waveform=torch.reshape(waveform, (1, waveform.shape[0], waveform.shape[1]))
        #Andy removed: waveform=waveform.to(device='cuda')
        #Andy edited: inputs["padding_mask"]=inputs["padding_mask"].to(device='cuda')
        inputs["padding_mask"]=inputs["padding_mask"]
if method=="encoder": | |
unperturbed_waveform=waveform.clone().detach() | |
unperturbed_latents=[] | |
for encoder in encoders: | |
unperturbed_latent=encoder(waveform, inputs["padding_mask"]).audio_values.detach() | |
unperturbed_latents.append(unperturbed_latent) | |
if method=="style_transfer": | |
style_waveform, style_sample_rate = torchaudio.load(f"test_audio/Il Sogno Del Marinaio - Nanos' Waltz.mp3") | |
style_waveform=style_waveform[:, :audio_len] | |
style_inputs = processor(raw_audio=style_waveform, sampling_rate=processor.sampling_rate, return_tensors="pt") | |
style_waveform=style_inputs['input_values'][0] | |
#Andy removed: wandb.log({f"transfer style": wandb.Audio(style_waveform[0].detach().numpy().flatten(), sample_rate=sample_rate)}, step=0) | |
style_waveform=torch.reshape(style_waveform, (1, style_waveform.shape[0], style_waveform.shape[1])) | |
#Andy edited: style_waveform=style_waveform.to(device='cuda') | |
style_waveform=style_waveform | |
#Andy edited: style_inputs["padding_mask"]=style_inputs["padding_mask"].to(device='cuda') | |
style_inputs["padding_mask"]=style_inputs["padding_mask"] | |
# unperturbed_latent=encoder(waveform, inputs["padding_mask"]).audio_values.detach() | |
unperturbed_waveform=style_waveform.clone().detach() | |
unperturbed_latents=[] | |
for encoder in encoders: | |
unperturbed_latent=encoder(style_waveform, style_inputs["padding_mask"]).audio_values.detach() | |
unperturbed_latents.append(unperturbed_latent) | |
        noise=torch.normal(torch.zeros(waveform.shape), 0.0)
        #Andy removed: noise=noise.to(device='cuda')
        noise.requires_grad=True
        # waveform=torch.nn.parameter.Parameter(waveform)

        weights = {'waveform_diff': weight, 'latent_diff': 1}
        balancer = Balancer(weights)

        l1loss = torch.nn.L1Loss()
        # for p in mel_loss.parameters():
        #     p.requires_grad = False
        optim = torch.optim.AdamW([noise], lr=0.002, weight_decay=0.005)
        #optim_diff = torch.optim.Adam([waveform], lr=0.02)
        # loss_weighter = GradNormLossWeighter(
        #     num_losses = 2,
        #     learning_rate = 0.00002,
        #     restoring_force_alpha = 0.,  # 0. is perfectly balanced losses, while anything greater than 1 would account for the relative training rates of each loss. in the paper, they go as high as 3.
        #     grad_norm_parameters = waveform
        # )

        downsample = torchaudio.transforms.Resample(sample_rate, 22050)
        #Andy removed: downsample=downsample.to(device='cuda')
        cos = torch.nn.CosineSimilarity()
        mrstft = auraloss.perceptual.FIRFilter()  #auraloss.time.SISDRLoss()  #torch.nn.functional.l1_loss
        #Andy removed: mrstft.to(device='cuda')

        waveform_loss = losses.L1Loss()
        stft_loss = losses.MultiScaleSTFTLoss()
        mel_loss = losses.MelSpectrogramLoss(n_mels=[5, 10, 20, 40, 80, 160, 320],
                                             window_lengths=[32, 64, 128, 256, 512, 1024, 2048],
                                             mel_fmin=[0, 0, 0, 0, 0, 0, 0],
                                             pow=1.0,
                                             clamp_eps=1.0e-5,
                                             mag_weight=0.0)

        past_10_latent_losses=[]
        latent_weight=1000
        latent_diff=0
        #Andy edited for testing purposes: number_steps=500
        number_steps=5
        if diff_weight>-1:
            for step in range(number_steps):
                latent_diff=0
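        # Descriptive note (added): each optimization step below learns an additive
        # noise tensor that pushes the encoder's output for (waveform + noise) away
        # from its output for the clean waveform (cosine-similarity term), while a
        # mel-spectrogram loss keeps the perturbed audio close to the original;
        # the Balancer trades off the two gradients per the `weights` dict above.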
                perturbed_waveform=noise+waveform
                for encoder_ind in range(len(encoders)):
                    perturbed_latent = encoders[encoder_ind](perturbed_waveform, inputs["padding_mask"]).audio_values
                    latent_diff+=cos(torch.reshape(perturbed_latent, (1,-1)), torch.reshape(unperturbed_latents[encoder_ind], (1, -1)))
                    #latent_diff+=1-torch.mean(torch.abs((torch.reshape(perturbed_latent, (1,-1))-torch.reshape(unperturbed_latents[encoder_ind], (1, -1)))))
                    #latent_diff=-l1loss(perturbed_latent,unperturbed_latents[0])
                latent_diff=latent_diff/len(encoders)
                #waveform_diff=mrstft(waveform, unperturbed_waveform)
                #waveform_diff=mrstft(torch.reshape(waveform, (1,-1)), torch.reshape(unperturbed_waveform, (1,-1)))
                #waveform_diff=si_snr(waveform, unperturbed_waveform)
                a_waveform=AudioSignal(perturbed_waveform, sample_rate)
                a_uwaveform=AudioSignal(unperturbed_waveform, sample_rate)
                c_waveform_loss=waveform_loss(a_waveform, a_uwaveform)*100
                c_stft_loss=stft_loss(a_waveform, a_uwaveform)/6.0
                c_mel_loss=mel_loss(a_waveform, a_uwaveform)
                l1_loss=torch.mean(torch.abs(perturbed_waveform-unperturbed_waveform))
                waveform_diff=c_mel_loss  #(c_waveform_loss+c_stft_loss+c_mel_loss)/3.0
                #loss=100*latent_diff+waveform_diff
                # loss=latent_weight*latent_diff+waveform_diff
                # past_10_latent_losses.append(latent_diff.detach().cpu().numpy().item())
                # if len(past_10_latent_losses)>10:
                #     mean=sum(past_10_latent_losses)/len(past_10_latent_losses)
                #     if mean<latent_diff:
                #         latent_weight=latent_weight*1.1
                #     elif mean>latent_diff*1.01:
                #         latent_weight=latent_weight/1.1
                #     past_10_latent_losses=past_10_latent_losses[1:]
                #     print('latent_weight', latent_weight)
                # if latent_diff>0.85:
                #     loss=1500*latent_diff+waveform_diff
                #loss=1000*latent_diff+waveform_diff
                # if method=='encoder':
                #     if latent_diff>0.75:
                #         loss=1000*latent_diff+waveform_diff
                #         print('latent')
                #     else:
                loss=waveform_diff+latent_diff
                #         print('waveform_diff')
                # elif method=='style_transfer':
                #     loss=latent_diff
                '''Andy removed:
                if step%10==0 or step==number_steps-1:
                    wandb.log({"loss": loss, "latent_diff": latent_diff, 'waveform_diff': waveform_diff}, step=step)
                if step%100==0 or step==number_steps-1:
                    audio_save=torch.reshape((noise+waveform), (2, waveform.shape[2]))[0, :audio_len].detach().cpu().numpy().flatten()
                    wandb.log({f"perturbed cos_dist_{latent_diff}_diff_weight_{diff_weight}_{audio_name}": wandb.Audio(audio_save, sample_rate=sample_rate)}, step=step)
                if step%100==0 or step==number_steps-1:
                    music_gen_eval_dict, unprotected_gen, protected_gen=music_gen_eval.eval(waveform, noise+waveform)
                    audio_save=torch.reshape(unprotected_gen, (2, unprotected_gen.shape[1]))[0].detach().cpu().numpy().flatten()
                    wandb.log({f"unprotected_gen_{latent_diff}_diff_weight_{diff_weight}": wandb.Audio(audio_save, sample_rate=sample_rate)}, step=step)
                    audio_save=torch.reshape(protected_gen, (2, protected_gen.shape[1]))[0].detach().cpu().numpy().flatten()
                    wandb.log({f"protected_gen_{latent_diff}_diff_weight_{diff_weight}": wandb.Audio(audio_save, sample_rate=sample_rate)}, step=step)
                    wandb.log(music_gen_eval_dict, step=step)
                '''
                # if c_mel_loss>0.5:
                #     loss=waveform_diff
                # else:
                #     loss=latent_diff
                # noise=noise*0.99

                loss_dict = {}
                loss_dict['waveform_diff'] = waveform_diff
                loss_dict['latent_diff'] = latent_diff[0]
                effective_loss = balancer.backward(loss_dict, noise)
                # loss=latent_diff
                # loss.backward()
                #loss_weighter.backward([latent_diff, c_mel_loss])
                # torch.nn.utils.clip_grad_norm_(waveform, 10e-8)

                optim.step()
                optim.zero_grad()
                # if latent_diff>0.5:
                #     latent_diff.backward()
                #     optim_diff.step()
                #     optim_diff.zero_grad()
                # else:
                #     loss=waveform_diff
                #     loss.backward()
                #     optim_diff.step()
                #     optim_diff.zero_grad()
                encoder.zero_grad()
                mel_loss.zero_grad()
                # with torch.no_grad():
                #     noise_clip=0.25
                #     noise.clamp_(-noise_clip, noise_clip)
                #     print('noise max', torch.max(noise))

                print('step', step, 'loss', loss.item(), 'latent loss', latent_diff.item(), 'audio loss', waveform_diff.item(), 'c_waveform_loss', c_waveform_loss.item(), 'c_stft_loss', c_stft_loss.item(), 'l1_loss', l1_loss.item())

        latent_diff=latent_diff.detach().item()
        #Andy removed: audio_save=torch.reshape((noise+waveform), (2, waveform.shape[2]))[0, :audio_len].detach().cpu().numpy().flatten()
        #Andy removed: wandb.log({f"perturbed cos_dist_{latent_diff}_diff_weight_{diff_weight}_{audio_name}": wandb.Audio(audio_save, sample_rate=sample_rate)}, step=step)

        #Andy moved from inside the loop:
        music_gen_eval_dict, unprotected_gen, protected_gen=music_gen_eval.eval(waveform, noise+waveform)
        #Andy edited: torchaudio.save(os.path.join(audio_folder, f"protected_{audio_name}_{audio_len}_mel_{latent_diff}_diff_weight_{waveform_diff}"), torch.reshape((noise+waveform).detach().cpu(), (2, waveform.shape[2])), sample_rate)
        return (torch.reshape((noise+waveform).detach().cpu(), (2, waveform.shape[2]))), music_gen_eval_dict, unprotected_gen, protected_gen
# encoders = [ArchiSound.from_pretrained('autoencoder1d-AT-v1'),
#             ArchiSound.from_pretrained("dmae1d-ATC64-v2"),
#             ArchiSound.from_pretrained("dmae1d-ATC32-v3"),
#             AudioEncoder.from_pretrained("teticio/audio-encoder"),
encoders = [EncodecModel.from_pretrained("facebook/encodec_48khz")]
audio_difference_weights=[1]

#Andy commented out: poison_audio(<audio_folder>, encoders, [1], method="encoder", weight=weight)
#Andy removed: wandb.finish()
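# Hypothetical usage sketch of the edited, waveform-based signature. The input
# path, output path, and modality below are assumptions for illustration only.
# waveform, sample_rate = torchaudio.load("test_audio/example.wav")
# protected, eval_dict, unprotected_gen, protected_gen = poison_audio(
#     waveform, sample_rate, encoders, audio_difference_weights,
#     method="encoder", weight=1, modality="music")
# torchaudio.save("test_audio/example_protected.wav", protected, sample_rate)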