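# app.py -- Gradio demo for SSL watermarking: the encoder nudges an image so that
# its self-supervised (DINO) feature falls inside a secret hypercone; the decoder
# runs a cosine test against the same carrier to detect the mark.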
import gradio as gr  # the legacy gradio.inputs/gradio.outputs modules were removed in recent Gradio
import numpy as np
import json
import torch
from torchvision import transforms
import utils
import utils_img
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(0)
np.random.seed(0)
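# Fixed seeds make the randomly drawn carrier below reproducible across launches.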
print('Building backbone and normalization layer...')
backbone = utils.build_backbone(path='dino_r50.pth')
normlayer = utils.load_normalization_layer(path='out2048.pth')
model = utils.NormLayerWrapper(backbone, normlayer)
model = model.to(device).eval()  # eval mode: freeze batch-norm statistics during the optimization
print('Building the hypercone...')
FPR = 1e-6
angle = 1.462771101178447 # value for FPR=1e-6 and D=2048
# angle = utils.pvalue_angle(2048, 1, proba=FPR)
rho = 1 + np.tan(angle)**2
carrier = torch.randn(1, 2048)
carrier /= torch.norm(carrier, dim=1, keepdim=True)
carrier = carrier.to(device)  # keep the carrier on the same device as the features
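# Detection rule: an image is deemed marked when |cos(ft, carrier)| > cos(angle),
# which is equivalent to rho * (ft @ carrier)**2 - ||ft||**2 > 0 (see loss_R below).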
default_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
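# The mean/std above are the standard ImageNet statistics, matching the
# preprocessing the DINO backbone was trained with.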
def encode(image, epochs=10, psnr=44, lambda_w=1, lambda_i=1):
    img_orig = default_transform(image).to(device, non_blocking=True).unsqueeze(0)
    img = img_orig.clone().to(device, non_blocking=True)
    img.requires_grad = True
    optimizer = torch.optim.Adam([img], lr=1e-2)
    for iteration in range(epochs):
        # Keep the distortion perceptually small: SSIM-based attenuation, then a PSNR clip.
        x = utils_img.ssim_attenuation(img, img_orig)
        x = utils_img.psnr_clip(x, img_orig, psnr)
        ft = model(x) # BxCxWxH -> BxD
        dot_product = (ft @ carrier.T) # BxD @ Dx1 -> Bx1
        norm = torch.norm(ft, dim=-1, keepdim=True) # Bx1
        cosines = torch.abs(dot_product/norm)
        log10_pvalue = np.log10(utils.cosine_pvalue(cosines.item(), ft.shape[-1]))
        # loss_R < 0 iff the feature lies inside the hypercone around the carrier.
        loss_R = -(rho * dot_product**2 - norm**2) # Bx1
        loss_l2_img = torch.norm(x - img_orig)**2 # image-domain regularizer
        loss = lambda_w*loss_R + lambda_i*loss_l2_img
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        logs = {
            "keyword": "img_optim",
            "iteration": iteration,
            "loss": loss.item(),
            "loss_R": loss_R.item(),
            "loss_l2_img": loss_l2_img.item(),
            "log10_pvalue": log10_pvalue.item(),
        }
        print("__log__:%s" % json.dumps(logs))
    img = utils_img.ssim_attenuation(img, img_orig)
    img = utils_img.psnr_clip(img, img_orig, psnr)
    img = utils_img.round_pixel(img)
    img = img.squeeze(0).detach().cpu()
    img = transforms.ToPILImage()(utils_img.unnormalize_img(img))
    return img
def decode(image):
    img = default_transform(image).to(device, non_blocking=True).unsqueeze(0)
    with torch.no_grad():  # no gradients needed at detection time
        ft = model(img) # BxCxWxH -> BxD
    dot_product = (ft @ carrier.T) # BxD @ Dx1 -> Bx1
    norm = torch.norm(ft, dim=-1, keepdim=True) # Bx1
    cosines = torch.abs(dot_product/norm)
    log10_pvalue = np.log10(utils.cosine_pvalue(cosines.item(), ft.shape[-1]))
    loss_R = -(rho * dot_product**2 - norm**2) # Bx1
    text_marked = "marked" if loss_R < 0 else "unmarked"
    return 'Image is {s}, with p-value={p}'.format(s=text_marked, p=10**log10_pvalue)
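# Hypothetical usage outside the UI (paths are placeholders):
#   from PIL import Image
#   marked = encode(Image.open("input.png"))
#   marked.save("marked.png")
#   print(decode(marked))  # e.g. "Image is marked, with p-value=..."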
def on_submit(image, mode):
    print('{} mode'.format(mode))
    if mode == 'Encode':
        return encode(image), 'Successfully encoded'
    else:
        return image, decode(image)
iface = gr.Interface(
    fn=on_submit,
    inputs=[
        gr.Image(type="pil"),  # PIL input matches default_transform
        gr.Radio(['Encode', 'Decode'], label="Encode or Decode mode"),
    ],
    outputs=[
        gr.Image(label='Watermarked image'),
        gr.Textbox(label='Information'),
    ],
    # allow_screenshot was dropped from gr.Interface in recent Gradio versions.
    allow_flagging="auto",
)
iface.launch()