Spaces:
Runtime error
Runtime error
File size: 2,969 Bytes
4bee283 d78a77d 4bee283 d78a77d 4bee283 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 |
import gradio as gr
import gradio.inputs as grinputs
import gradio.outputs as groutputs
import numpy as np
import torch
import torch.nn as nn
from torchvision import models
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
# Seed the PyTorch and NumPy global RNGs; the carrier below draws from NumPy's.
torch.manual_seed(0)
np.random.seed(0)
# Presumably the target false-positive rate of the watermark detector (not
# referenced by the code in this file).
FPR = 1e-6
# Random Gaussian carrier of shape (1, 2048); 2048 presumably matches the
# backbone feature size (cf. 'out2048.pth') — TODO confirm. Not normalised here.
carrier = np.random.randn(1, 2048)
def build_backbone(path, name='resnet50'):
    """
    Build a torchvision backbone and load its weights from a checkpoint.

    The model's ``fc`` layer is replaced by ``nn.Identity`` (so the network
    returns the features that would have fed the classifier), and a ``head``
    attribute is also set to ``nn.Identity``.

    Args:
        path: checkpoint file passed to ``torch.load``.
        name: name of the torchvision model constructor (default 'resnet50').

    Returns:
        The model with checkpoint weights loaded non-strictly. It is not moved
        to ``device`` here; only the checkpoint tensors are mapped there.

    Warns:
        UserWarning if any model parameter/buffer was absent from the
        checkpoint (e.g. a wrong file or unexpected key layout).
    """
    import warnings

    # NOTE(review): newer torchvision prefers `weights=None` over
    # `pretrained=False`; kept as-is for compatibility with the pinned version.
    model = getattr(models, name)(pretrained=False)
    model.head = nn.Identity()
    model.fc = nn.Identity()
    # torch.load unpickles the file: only load trusted checkpoints.
    checkpoint = torch.load(path, map_location=device)
    state_dict = checkpoint
    # Weights may be nested under one of these keys; if several are present,
    # the last one in this list wins.
    for ckpt_key in ['state_dict', 'model_state_dict', 'teacher']:
        if ckpt_key in checkpoint:
            state_dict = checkpoint[ckpt_key]
    # Remove "module." and "backbone." from parameter names so they match the
    # bare torchvision model.
    state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
    state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
    # strict=False tolerates extra checkpoint entries, but it would also let a
    # mismatched checkpoint load nothing at all — report missing weights.
    msg = model.load_state_dict(state_dict, strict=False)
    if msg.missing_keys:
        warnings.warn(
            "build_backbone: {} model entries missing from checkpoint {!r}, "
            "e.g. {}".format(len(msg.missing_keys), path, msg.missing_keys[:5])
        )
    return model
def get_linear_layer(weight, bias):
    """
    Wrap a stored affine transform (e.g. whitening or centering) in nn.Linear.

    Args:
        weight: 2-D tensor of shape (dim_out, dim_in).
        bias: tensor added to the output, expected to have length dim_out.

    Returns:
        An ``nn.Linear(dim_in, dim_out)`` whose weight and bias are
        ``nn.Parameter`` wrappers around the given tensors.
    """
    n_out, n_in = weight.shape
    linear = nn.Linear(n_in, n_out)
    # Discard the random initialisation and install the supplied tensors.
    linear.weight, linear.bias = nn.Parameter(weight), nn.Parameter(bias)
    return linear
def load_normalization_layer(path, *, rescale=None):
    """
    Load a normalization layer (whitening or centering) from a checkpoint.

    The checkpoint must contain ``'weight'`` and ``'bias'`` tensors. Optionally
    both are multiplied by ``D = weight.shape[1]`` (the input dimension).

    Args:
        path: checkpoint file (str or os.PathLike) passed to ``torch.load``.
        rescale: whether to apply the ``D`` scaling. ``None`` (default) keeps
            the original heuristic: rescale when the path string contains
            'whitening' or 'out'. Note the heuristic inspects the whole path,
            including directory names.

    Returns:
        An ``nn.Linear`` built by ``get_linear_layer`` and moved to ``device``.

    Raises:
        KeyError if the checkpoint lacks 'weight' or 'bias'.
    """
    import os

    # torch.load unpickles the file: only load trusted checkpoints.
    checkpoint = torch.load(path, map_location=device)
    weight = checkpoint['weight']
    bias = checkpoint['bias']
    if rescale is None:
        # os.fspath lets pathlib.Path work with the substring check.
        path_str = os.fspath(path)
        rescale = 'whitening' in path_str or 'out' in path_str
    if rescale:
        # Presumably undoes a 1/D factor applied when the layer was saved —
        # TODO confirm against the script that produced the checkpoint.
        dim = weight.shape[1]
        weight = dim * weight
        bias = dim * bias
    # get_linear_layer wraps the tensors in nn.Parameter itself.
    return get_linear_layer(weight, bias).to(device, non_blocking=True)
class NormLayerWrapper(nn.Module):
    """
    Chain a feature extractor with a normalization head.

    ``forward(x)`` returns ``head(backbone(x))``. Both sub-modules are put in
    eval mode when the wrapper is constructed; the wrapper's own ``training``
    flag is left untouched.
    """

    def __init__(self, backbone, head):
        super().__init__()
        for submodule in (backbone, head):
            submodule.eval()
        self.backbone = backbone
        self.head = head

    def forward(self, x):
        features = self.backbone(x)
        return self.head(features)
# Assemble the feature extractor: backbone weights from 'dino_r50.pth' and the
# normalization layer from 'out2048.pth' (paths relative to the current working
# directory).
backbone = build_backbone(path='dino_r50.pth')
normlayer = load_normalization_layer(path='out2048.pth')
# load_normalization_layer already moves the head to `device`, but the backbone
# keeps the parameters torchvision created on the CPU (load_state_dict copies
# into them). Move the whole wrapper so both halves share a device and
# forward() does not fail with a device mismatch on GPU machines.
model = NormLayerWrapper(backbone, normlayer).to(device)
def encode(image):
    """
    Placeholder watermark encoder.

    No watermark is embedded yet: the input image is returned unchanged.
    """
    watermarked = image  # TODO: embed the watermark here once implemented
    return watermarked
def decode(image):
    """
    Placeholder watermark decoder.

    The image is ignored; the fixed string 'decoded' is always returned.
    """
    message = 'decoded'  # TODO: extract the real message once implemented
    return message
def on_submit(image, mode):
    """
    Gradio callback: route the uploaded image to the encoder or the decoder.

    Args:
        image: value produced by the Gradio image input.
        mode: selected radio option. 'Encode' runs ``encode``; any other
            value (including 'Decode' or no selection) runs ``decode``.

    Returns:
        A pair (image to display, information text).
    """
    print('{} mode'.format(mode))
    if mode != 'Encode':
        # Decoding shows the input image unchanged alongside the decoded text.
        return image, decode(image)
    return encode(image), 'Successfully encoded'
# Gradio UI: an image plus an Encode/Decode choice go in; an image and an
# information string come out, both produced by on_submit.
# NOTE(review): gradio.inputs / gradio.outputs and `allow_screenshot` belong to
# the legacy gradio 2.x API and are deprecated/removed in later releases —
# confirm the pinned gradio version before upgrading.
iface = gr.Interface(
    fn=on_submit,
    inputs=[
        grinputs.Image(),
        grinputs.Radio(['Encode', 'Decode'], label="Encode or Decode mode"),
    ],
    outputs=[
        groutputs.Image(label='Watermarked image'),
        groutputs.Textbox(label='Information'),
    ],
    allow_screenshot=False,
    allow_flagging="auto",
)
iface.launch()