Pierre Fernandez
resolved icon issue
d78a77d
raw
history blame
2.97 kB
import os

import gradio as gr
import gradio.inputs as grinputs
import gradio.outputs as groutputs
import numpy as np
import torch
import torch.nn as nn
from torchvision import models
# Run on the GPU when one is present; the checkpoints loaded below are mapped
# onto this device via torch.load(map_location=device).
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Fix both RNG seeds so that ``carrier`` (drawn from NumPy just below) is the
# same on every run.
np.random.seed(0)
torch.manual_seed(0)

FPR = 1e-6  # NOTE(review): unused in the code visible here; presumably a target false-positive rate.
carrier = np.random.randn(1, 2048)  # one row of 2048 standard-normal samples
def build_backbone(path, name='resnet50'):
    """
    Build a torchvision backbone and load its weights from a checkpoint.

    The architecture is ``torchvision.models.<name>`` created with
    ``pretrained=False``, so no torchvision weights are downloaded; all weights
    come from ``path``. The classifier ``fc`` is replaced by ``nn.Identity`` so
    the model outputs features instead of logits.

    Args:
        path: checkpoint file passed to ``torch.load``.
        name: torchvision model constructor name (default ``'resnet50'``).

    Returns:
        The backbone module, moved to the module-level ``device`` so it sits on
        the same device as the layer returned by ``load_normalization_layer``.
    """
    model = getattr(models, name)(pretrained=False)
    # ResNets have no ``head``; presumably set for architectures that do.
    model.head = nn.Identity()
    model.fc = nn.Identity()

    checkpoint = torch.load(path, map_location=device)
    # The weights may be nested under one of these keys; if several are
    # present, the last one in this list wins.
    state_dict = checkpoint
    for ckpt_key in ['state_dict', 'model_state_dict', 'teacher']:
        if ckpt_key in checkpoint:
            state_dict = checkpoint[ckpt_key]
    # Strip wrapper prefixes (presumably from DataParallel/DDP "module." and a
    # "backbone." sub-module) so keys match the bare torchvision model.
    state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
    state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}

    # strict=False tolerates extra keys (e.g. projection heads), but a
    # mismatched checkpoint would otherwise silently leave the randomly
    # initialised weights in place -- report it instead of dropping ``msg``.
    msg = model.load_state_dict(state_dict, strict=False)
    if msg.missing_keys or msg.unexpected_keys:
        print('build_backbone: missing keys {}, unexpected keys {}'.format(
            msg.missing_keys, msg.unexpected_keys))

    # load_state_dict copies values into the model's existing (CPU) parameters,
    # so move the model explicitly to match the normalization layer's device.
    return model.to(device)
def get_linear_layer(weight, bias):
    """
    Wrap a given weight matrix and bias vector in an ``nn.Linear``.

    Used here to build the feature whitening / centering layer. ``weight`` must
    be 2-D with shape ``(out_features, in_features)`` (the layout ``nn.Linear``
    uses); it and ``bias`` replace the layer's freshly initialised parameters.
    """
    out_features, in_features = weight.shape
    linear = nn.Linear(in_features, out_features)
    # Swap in the supplied tensors as the layer's parameters.
    linear.weight, linear.bias = nn.Parameter(weight), nn.Parameter(bias)
    return linear
def load_normalization_layer(path):
    """
    Load the linear normalization layer from a checkpoint.

    The checkpoint must hold ``'weight'`` and ``'bias'`` tensors. When the
    checkpoint's *file name* contains ``'whitening'`` or ``'out'``, both
    tensors are multiplied by ``D = weight.shape[1]`` (the layer's input
    dimension) before the layer is built.

    Args:
        path: checkpoint file (``str`` or ``os.PathLike``) for ``torch.load``.

    Returns:
        An ``nn.Linear`` (built by ``get_linear_layer``) on the module-level
        ``device``.
    """
    checkpoint = torch.load(path, map_location=device)
    weight = checkpoint['weight']
    bias = checkpoint['bias']
    # Inspect only the file name: matching on the full path would also fire
    # for directory names such as ".../outputs/...".
    filename = os.path.basename(path)
    if 'whitening' in filename or 'out' in filename:
        # NOTE(review): the reason for scaling by D is not visible here --
        # presumably it undoes a 1/D factor applied when the layer was saved.
        D = weight.shape[1]
        weight = D * weight
        bias = D * bias
    # get_linear_layer wraps the tensors in nn.Parameter itself.
    return get_linear_layer(weight, bias).to(device, non_blocking=True)
class NormLayerWrapper(nn.Module):
    """
    Chain a backbone and a normalization head into a single module.

    Both sub-modules are switched to eval mode when the wrapper is built; the
    wrapper's own ``training`` flag is left at its ``nn.Module`` default.
    """

    def __init__(self, backbone, head):
        super().__init__()
        # Put both parts in inference mode before registering them.
        for submodule in (backbone, head):
            submodule.eval()
        self.backbone = backbone
        self.head = head

    def forward(self, x):
        """Return ``head(backbone(x))``."""
        return self.head(self.backbone(x))
# Assemble the feature extractor at import time. The backbone weights come from
# 'dino_r50.pth' and the linear normalization head from 'out2048.pth'; both
# relative paths are resolved against the current working directory.
backbone = build_backbone(path='dino_r50.pth')
normlayer = load_normalization_layer(path='out2048.pth')
model = NormLayerWrapper(backbone=backbone, head=normlayer)
def encode(image):
    """
    Placeholder for the watermark encoder.

    Currently the identity: ``image`` is returned unchanged (the same object).
    TODO: presumably this should embed a watermark, since the UI labels the
    output 'Watermarked image'.
    """
    watermarked = image
    return watermarked
def decode(image):
    """
    Placeholder for the watermark decoder.

    ``image`` is not read; the fixed status string ``'decoded'`` is returned.
    TODO: extract the watermark from ``image``.
    """
    status = 'decoded'
    return status
def on_submit(image, mode):
    """
    Gradio callback: route ``image`` to the encoder or the decoder.

    Returns an ``(image, message)`` pair for the two output components. In
    'Encode' mode it returns the encoded image and a success message. For any
    other mode value (including None) it returns the input image unchanged and
    the decoder's result.
    """
    print(f'{mode} mode')
    if mode != 'Encode':
        return image, decode(image)
    return encode(image), 'Successfully encoded'
# Two inputs (image + mode selector) map onto on_submit(image, mode); the two
# outputs receive the (image, message) tuple that on_submit returns.
# NOTE(review): per the Gradio docs, allow_flagging="auto" flags every
# submission automatically -- confirm that storing user uploads is intended.
iface = gr.Interface(
    fn=on_submit,
    inputs=[grinputs.Image(),
            grinputs.Radio(['Encode', 'Decode'], label="Encode or Decode mode")],
    outputs=[groutputs.Image(label='Watermarked image'),
             groutputs.Textbox(label='Information')],
    allow_flagging="auto",
    allow_screenshot=False,
)
iface.launch()  # starts the web server when the module is executed