Pierre Fernandez commited on
Commit
4bee283
β€’
1 Parent(s): a55c404

Create first draft

Browse files
Files changed (4) hide show
  1. README.md +3 -3
  2. app.py +99 -0
  3. dino_r50.pth +3 -0
  4. out2048.pth +3 -0
README.md CHANGED
@@ -1,6 +1,6 @@
1
  ---
2
- title: Ssl_watermarking
3
- emoji: πŸ“‰
4
  colorFrom: indigo
5
  colorTo: red
6
  sdk: gradio
@@ -8,4 +8,4 @@ app_file: app.py
8
  pinned: false
9
  ---
10
 
11
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces#reference
 
1
  ---
2
+ title: Watermarking in SSL latent spaces
3
+ emoji: :lock:
4
  colorFrom: indigo
5
  colorTo: red
6
  sdk: gradio
 
8
  pinned: false
9
  ---
10
 
11
+ Watermark an image using *Watermarking Images in Self-Supervised Latent Spaces*.
app.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import gradio.inputs as grinputs
3
+ import gradio.outputs as groutputs
4
+
5
+ import numpy as np
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ from torchvision import models
10
+
11
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12
+
13
+ torch.manual_seed(0)
14
+ np.random.seed(0)
15
+
16
+ FPR = 1e-6
17
+ carrier = np.random.randn(size=(1, 2048))
18
+
19
+
20
+ def build_backbone(path, name='resnet50'):
21
+ """ Builds a pretrained ResNet-50 backbone. """
22
+ model = getattr(models, name)(pretrained=True)
23
+ model.head = nn.Identity()
24
+ model.fc = nn.Identity()
25
+ checkpoint = torch.load(path, map_location=device)
26
+ state_dict = checkpoint
27
+ for ckpt_key in ['state_dict', 'model_state_dict', 'teacher']:
28
+ if ckpt_key in checkpoint:
29
+ state_dict = checkpoint[ckpt_key]
30
+ state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
31
+ state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
32
+ msg = model.load_state_dict(state_dict, strict=False)
33
+ return model
34
+
35
+ def get_linear_layer(weight, bias):
36
+ """ Creates a layer that performs feature whitening or centering """
37
+ dim_out, dim_in = weight.shape
38
+ layer = nn.Linear(dim_in, dim_out)
39
+ layer.weight = nn.Parameter(weight)
40
+ layer.bias = nn.Parameter(bias)
41
+ return layer
42
+
43
+ def load_normalization_layer(path):
44
+ """
45
+ Loads the normalization layer from a checkpoint and returns the layer.
46
+ """
47
+ checkpoint = torch.load(path, map_location=device)
48
+ if 'whitening' in path or 'out' in path:
49
+ D = checkpoint['weight'].shape[1]
50
+ weight = torch.nn.Parameter(D*checkpoint['weight'])
51
+ bias = torch.nn.Parameter(D*checkpoint['bias'])
52
+ else:
53
+ weight = checkpoint['weight']
54
+ bias = checkpoint['bias']
55
+ return get_linear_layer(weight, bias).to(device, non_blocking=True)
56
+
57
+ class NormLayerWrapper(nn.Module):
58
+ """
59
+ Wraps backbone model and normalization layer
60
+ """
61
+ def __init__(self, backbone, head):
62
+ super(NormLayerWrapper, self).__init__()
63
+ backbone.eval(), head.eval()
64
+ self.backbone = backbone
65
+ self.head = head
66
+
67
+ def forward(self, x):
68
+ output = self.backbone(x)
69
+ return self.head(output)
70
+
71
+ backbone = build_backbone(path='dino_r50.pth')
72
+ normlayer = load_normalization_layer(path='out2048.pth')
73
+ model = NormLayerWrapper(backbone, normlayer)
74
+
75
+ def encode(image):
76
+ return image
77
+
78
+ def decode(image):
79
+ return 'decoded'
80
+
81
+ def on_submit(image, mode):
82
+ print('{} mode'.format(mode))
83
+ if mode=='Encode':
84
+ return encode(image), 'Successfully encoded'
85
+ else:
86
+ return image, decode(image)
87
+
88
+ iface = gr.Interface(
89
+ fn=on_submit,
90
+ inputs=[
91
+ grinputs.Image(),
92
+ grinputs.Radio(['Encode', 'Decode'], label="Encode or Decode mode")],
93
+ outputs=[
94
+ groutputs.Image(label='Watermarked image'),
95
+ groutputs.Textbox(label='Information')],
96
+ allow_screenshot=False,
97
+ allow_flagging="auto",
98
+ )
99
+ iface.launch()
dino_r50.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ab26d85d00cb1be8e757cf8820cf0fd8aa729ea7e21b1cf6c44875952ba8eb0f
3
+ size 788803344
out2048.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4b256188454d8f7cf440de048df398e2a3209136a52cd7cdac834f5792f526a3
3
+ size 16786561