# Spaces status at capture time: Runtime error
import gradio as gr | |
import gradio.inputs as grinputs | |
import gradio.outputs as groutputs | |
import numpy as np | |
import torch | |
import torch.nn as nn | |
from torchvision import models | |
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed both RNGs so that the carrier drawn below is identical on every start.
torch.manual_seed(0)
np.random.seed(0)

# NOTE(review): presumably the target false-positive rate of the detector;
# it is not used in the visible part of the file -- confirm.
FPR = 1e-6

# Gaussian array of shape (1, 2048), reproducible thanks to the NumPy seed above.
# NOTE(review): 2048 presumably matches the backbone feature width -- confirm.
carrier = np.random.randn(1, 2048)
def build_backbone(path, name='resnet50'):
    """Build a torchvision backbone and load its weights from a checkpoint.

    Args:
        path: Checkpoint location, handed to ``torch.load``.
        name: Name of the model constructor looked up on ``torchvision.models``.

    Returns:
        The model, with its ``head`` and ``fc`` attributes replaced by
        ``nn.Identity`` and the checkpoint weights loaded non-strictly.
    """
    # Built without torchvision's own weights; they come from `path` below.
    model = getattr(models, name)(pretrained=False)

    # Both attribute names are overwritten with a pass-through layer.
    # NOTE(review): which one is the real classifier depends on the
    # architecture (presumably `fc` for ResNets) -- confirm.
    model.head = nn.Identity()
    model.fc = nn.Identity()

    checkpoint = torch.load(path, map_location=device)

    # The weights are either the checkpoint itself or nested under one of
    # these keys; when several keys are present the last one listed wins.
    state_dict = checkpoint
    for ckpt_key in ('state_dict', 'model_state_dict', 'teacher'):
        if ckpt_key in checkpoint:
            state_dict = checkpoint[ckpt_key]

    # Strip the wrapper prefixes one after the other so that the parameter
    # names line up with the bare torchvision model.
    for prefix in ("module.", "backbone."):
        state_dict = {k.replace(prefix, ""): v for k, v in state_dict.items()}

    # strict=False: keys that do not match the model are skipped, not fatal.
    model.load_state_dict(state_dict, strict=False)
    return model
def get_linear_layer(weight, bias):
    """Create an ``nn.Linear`` layer whose parameters are the given tensors.

    Args:
        weight: 2-D tensor of shape ``(dim_out, dim_in)``.
        bias: Tensor used as the layer bias.

    Returns:
        An ``nn.Linear(dim_in, dim_out)`` with ``weight`` and ``bias``
        installed as its parameters (used for feature whitening/centering).
    """
    dim_out, dim_in = weight.shape
    layer = nn.Linear(dim_in, dim_out)
    # Replace the randomly initialised parameters with the supplied ones.
    layer.weight = nn.Parameter(weight)
    layer.bias = nn.Parameter(bias)
    return layer
def load_normalization_layer(path):
    """Load a normalization layer from a checkpoint and place it on ``device``.

    The checkpoint must provide ``'weight'`` and ``'bias'`` entries. When the
    path contains ``'whitening'`` or ``'out'``, both are multiplied by the
    input dimension (``weight.shape[1]``) before the layer is built.

    Args:
        path: Checkpoint location, handed to ``torch.load``.

    Returns:
        The linear layer produced by ``get_linear_layer``, moved to ``device``.
    """
    checkpoint = torch.load(path, map_location=device)
    weight, bias = checkpoint['weight'], checkpoint['bias']

    # NOTE(review): this is a substring test on the whole path, so a directory
    # name containing "out" also triggers the rescaling -- confirm intent.
    if 'whitening' in path or 'out' in path:
        dim_in = weight.shape[1]
        # No nn.Parameter wrapping here: get_linear_layer wraps both tensors.
        weight, bias = dim_in * weight, dim_in * bias

    return get_linear_layer(weight, bias).to(device, non_blocking=True)
class NormLayerWrapper(nn.Module):
    """Chain a backbone and a normalization head into a single module.

    Both sub-modules are switched to evaluation mode when the wrapper is
    constructed.
    """

    def __init__(self, backbone, head):
        """Store ``backbone`` and ``head`` after putting each in eval mode."""
        super().__init__()
        backbone.eval()
        head.eval()
        self.backbone = backbone
        self.head = head

    def forward(self, x):
        """Return ``head(backbone(x))``."""
        features = self.backbone(x)
        return self.head(features)
# Assemble the feature extractor at import time from the two local checkpoints.
backbone = build_backbone(path='dino_r50.pth')
normlayer = load_normalization_layer(path='out2048.pth')
# load_normalization_layer already places the head on `device`, but nothing
# moves the backbone; moving the whole wrapper keeps both on the same device
# (a no-op when `device` is the CPU).
model = NormLayerWrapper(backbone, normlayer).to(device)
def encode(image):
    """Placeholder encoder: return ``image`` unchanged."""
    return image
def decode(image):
    """Placeholder decoder: ignore ``image`` and return the text 'decoded'."""
    return 'decoded'
def on_submit(image, mode):
    """Handle a submission from the interface.

    Args:
        image: The submitted image.
        mode: The selected mode; ``'Encode'`` selects encoding, any other
            value is treated as decoding.

    Returns:
        A pair ``(image, message)``: the encoded image with a success message
        in encode mode, otherwise the input image with the decoded text.
    """
    print('{} mode'.format(mode))
    if mode == 'Encode':
        return encode(image), 'Successfully encoded'
    # Every value other than 'Encode' falls through to decoding.
    return image, decode(image)
# Web UI: one image plus a mode selector in, one image plus a text field out.
# NOTE(review): `gradio.inputs` / `gradio.outputs` and `allow_screenshot` are
# legacy Gradio APIs (deprecated in 3.x, removed in later releases); pin the
# Gradio version or migrate to `gr.Image` / `gr.Radio` / `gr.Textbox` --
# confirm against the installed version.
iface = gr.Interface(
    fn=on_submit,
    inputs=[
        grinputs.Image(),
        grinputs.Radio(['Encode', 'Decode'], label="Encode or Decode mode"),
    ],
    outputs=[
        groutputs.Image(label='Watermarked image'),
        groutputs.Textbox(label='Information'),
    ],
    allow_screenshot=False,
    allow_flagging="auto",
)
iface.launch()