import gradio as gr  # gradio >= 3: components live in the top-level namespace

import numpy as np
import json

import torch
from torchvision import transforms

import utils 
import utils_img

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

torch.manual_seed(0)
np.random.seed(0)
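
# Seeds are fixed so that the random carrier built below is reproducible
# across runs: decoding only works against the same carrier that was used
# at encoding time.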

print('Building backbone and normalization layer...')
backbone = utils.build_backbone(path='dino_r50.pth')
normlayer = utils.load_normalization_layer(path='out2048.pth')
model = utils.NormLayerWrapper(backbone, normlayer)
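
# The wrapped model maps an image to a 2048-d feature (BxCxHxW -> BxD).
# Watermarking pushes this feature inside a hypercone around a fixed secret
# carrier; detection tests for hypercone membership.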

print('Building the hypercone...')
FPR = 1e-6
angle = 1.462771101178447 # value for FPR=1e-6 and D=2048
rho = 1 + np.tan(angle)**2
# angle = utils.pvalue_angle(2048, 1, proba=FPR)
carrier = torch.randn(1, 2048)
carrier /= torch.norm(carrier, dim=1, keepdim=True)
carrier = carrier.to(device)  # match the feature device, otherwise `ft @ carrier.T` fails on GPU
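
# Detection rule (loss_R below): a feature ft lies inside the hypercone of
# half-angle `angle` around `carrier` iff
#     rho * <ft, carrier>**2 - ||ft||**2 > 0,
# which, since rho = 1 + tan(angle)**2 = 1/cos(angle)**2, is equivalent to
#     |cos(ft, carrier)| > cos(angle).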

# Standard ImageNet normalization, matching the statistics the backbone was trained with.
default_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

def encode(image, epochs=10, psnr=44, lambda_w=1, lambda_i=1):
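    """Embed the watermark by optimizing the image pixels.

    Gradient descent pushes the image feature inside the hypercone (loss_R)
    while an L2 term (loss_l2_img) plus SSIM attenuation and PSNR clipping
    keep the perturbation imperceptible.

    Args:
        image: input image (anything `default_transform` accepts, e.g. PIL).
        epochs: number of optimization steps.
        psnr: target PSNR (in dB) used to clip the perturbation.
        lambda_w, lambda_i: weights of the watermark and fidelity losses.

    Returns:
        The watermarked image as a PIL image.
    """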
    img_orig = default_transform(image).to(device, non_blocking=True).unsqueeze(0)
    img = img_orig.clone().to(device, non_blocking=True) 
    img.requires_grad = True
    optimizer = torch.optim.Adam([img], lr=1e-2)

    for iteration in range(epochs):
        x = utils_img.ssim_attenuation(img, img_orig)
        x = utils_img.psnr_clip(x, img_orig, psnr)

        ft = model(x) # BxCxWxH -> BxD

        dot_product = (ft @ carrier.T) # BxD @ Dx1 -> Bx1
        norm = torch.norm(ft, dim=-1, keepdim=True) # Bx1
        cosines = torch.abs(dot_product/norm)
        log10_pvalue = np.log10(utils.cosine_pvalue(cosines.item(), ft.shape[-1]))
        loss_R = -(rho * dot_product**2 - norm**2) # Bx1, negative when ft is inside the hypercone

        loss_l2_img = torch.norm(x - img_orig)**2 # CxWxH -> 1
        loss = lambda_w*loss_R + lambda_i*loss_l2_img
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        logs = {
            "keyword": "img_optim",
            "iteration": iteration,
            "loss": loss.item(),
            "loss_R": loss_R.item(),
            "loss_l2_img": loss_l2_img.item(),
            "log10_pvalue": log10_pvalue.item(),
        }
        print("__log__:%s" % json.dumps(logs))

    img = utils_img.ssim_attenuation(img, img_orig)
    img = utils_img.psnr_clip(img, img_orig, psnr)
    img = utils_img.round_pixel(img)
    img = utils_img.unnormalize_img(img.detach().cpu()).squeeze(0)
    img = transforms.ToPILImage()(img)

    return img

def decode(image):
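    """Test whether `image` carries the watermark.

    Extracts the feature, checks hypercone membership (loss_R < 0) and
    reports the p-value of the observed cosine under the null hypothesis
    that the image is unmarked.
    """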
    img = default_transform(image).to(device, non_blocking=True).unsqueeze(0)
    ft = model(img) # BxCxWxH -> BxD

    dot_product = (ft @ carrier.T) # BxD @ Dx1 -> Bx1
    norm = torch.norm(ft, dim=-1, keepdim=True) # Bx1
    cosines = torch.abs(dot_product/norm)
    log10_pvalue = np.log10(utils.cosine_pvalue(cosines.item(), ft.shape[-1]))
    loss_R = -(rho * dot_product**2 - norm**2) # Bx1, negative when ft is inside the hypercone

    text_marked = "marked" if loss_R < 0 else "unmarked"
    return 'Image is {s}, with p-value={p}'.format(s=text_marked, p=10**log10_pvalue)



def on_submit(image, mode):
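    """Gradio callback: encode returns the watermarked image, decode returns
    the input unchanged together with the detection result."""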
    print('{} mode'.format(mode))
    if mode=='Encode':
        return encode(image), 'Successfully encoded'
    else:
        return image, decode(image)

# gradio 3.x API: components are gr.Image/gr.Radio/gr.Textbox, and the
# `allow_screenshot` argument no longer exists on Interface.
iface = gr.Interface(
    fn=on_submit,
    inputs=[
        gr.Image(type="pil"),  # PIL input matches what default_transform expects
        gr.Radio(['Encode', 'Decode'], label="Encode or Decode mode")],
    outputs=[
        gr.Image(label='Watermarked image'),
        gr.Textbox(label='Information')],
    allow_flagging="auto",
    )
iface.launch()