import os

import gradio as gr
import gradio.inputs as grinputs
import gradio.outputs as groutputs
import numpy as np
import torch
import torch.nn as nn
from torchvision import models

# Use the GPU when one is available. Both the backbone and the normalization
# layer built below are moved to this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Fixed seeds so that `carrier` (drawn from NumPy's RNG just below) is
# reproducible across runs.
torch.manual_seed(0)
np.random.seed(0)

# NOTE(review): presumably the target false-positive rate for watermark
# detection; not referenced by the code in this file yet.
FPR = 1e-6
# Random (1, 2048) vector. 2048 matches the width of ResNet-50's pooled
# features. NOTE(review): not referenced by the code in this file yet.
carrier = np.random.randn(1, 2048)


def build_backbone(path, name='resnet50'):
    """
    Build a torchvision backbone and load its weights from a local checkpoint.

    The classification layer (``fc``) is replaced by an identity, so the model
    returns pooled features instead of class logits.

    Args:
        path: checkpoint file readable by ``torch.load``. It may be a raw state
            dict, or a dict wrapping one under ``'state_dict'``,
            ``'model_state_dict'`` or ``'teacher'``. If several of these keys
            are present, the last one in that list wins.
        name: name of a model constructor in ``torchvision.models``
            (default ``'resnet50'``).

    Returns:
        The model with loaded weights, moved to ``device``.
    """
    model = getattr(models, name)(pretrained=False)
    # Replace any classification head so forward() yields features.
    model.head = nn.Identity()
    model.fc = nn.Identity()
    # SECURITY: torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(path, map_location=device)
    state_dict = checkpoint
    for ckpt_key in ['state_dict', 'model_state_dict', 'teacher']:
        if ckpt_key in checkpoint:
            state_dict = checkpoint[ckpt_key]
    # Strip wrapper prefixes ("module." from DataParallel-style saving,
    # "backbone." from a wrapping module) so keys match the bare model.
    state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
    state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
    # strict=False tolerates keys that do not match the model. Report them
    # rather than silently discarding the result.
    msg = model.load_state_dict(state_dict, strict=False)
    print('Loaded backbone from {}: {}'.format(path, msg))
    # The model was constructed on the CPU. Move it to `device` so it matches
    # the normalization layer; otherwise forward() fails on CUDA.
    return model.to(device)


def get_linear_layer(weight, bias):
    """
    Create an ``nn.Linear`` layer with the given parameters (used for feature
    whitening or centering).

    Args:
        weight: tensor of shape (dim_out, dim_in).
        bias: tensor of shape (dim_out,).

    Returns:
        An ``nn.Linear(dim_in, dim_out)`` whose weight and bias are the given
        tensors.
    """
    dim_out, dim_in = weight.shape
    layer = nn.Linear(dim_in, dim_out)
    layer.weight = nn.Parameter(weight)
    layer.bias = nn.Parameter(bias)
    return layer


def load_normalization_layer(path):
    """
    Load a linear normalization layer from a checkpoint and return it on
    ``device``.

    The checkpoint must hold ``'weight'`` and ``'bias'`` tensors. When the
    checkpoint's file name contains ``'whitening'`` or ``'out'``, both tensors
    are multiplied by the layer's input dimension D before use.

    Args:
        path: path (str or path-like) to the checkpoint file.

    Returns:
        An ``nn.Linear`` layer on ``device``.
    """
    # SECURITY: torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(path, map_location=device)
    # Match on the file name only, so a directory such as "output/" does not
    # switch the scaling on.
    filename = os.path.basename(path)
    if 'whitening' in filename or 'out' in filename:
        D = checkpoint['weight'].shape[1]
        # NOTE(review): presumably undoes a 1/D factor applied when the layer
        # was fitted — confirm against the export script.
        weight = D * checkpoint['weight']
        bias = D * checkpoint['bias']
    else:
        weight = checkpoint['weight']
        bias = checkpoint['bias']
    return get_linear_layer(weight, bias).to(device, non_blocking=True)


class NormLayerWrapper(nn.Module):
    """
    Chain a backbone and a normalization head:
    ``forward(x) = head(backbone(x))``.

    Both sub-modules are switched to eval mode on construction.
    """

    def __init__(self, backbone, head):
        super().__init__()
        backbone.eval()
        head.eval()
        self.backbone = backbone
        self.head = head

    def forward(self, x):
        output = self.backbone(x)
        return self.head(output)


backbone = build_backbone(path='dino_r50.pth')
normlayer = load_normalization_layer(path='out2048.pth')
model = NormLayerWrapper(backbone, normlayer)


def encode(image):
    """Placeholder: return the input image unchanged."""
    return image


def decode(image):
    """Placeholder: return the fixed string 'decoded', ignoring the image."""
    return 'decoded'


def on_submit(image, mode):
    """
    Gradio callback: dispatch to encode or decode based on the radio choice.

    Returns:
        A pair (output image, information message).
    """
    print('{} mode'.format(mode))
    if mode == 'Encode':
        return encode(image), 'Successfully encoded'
    else:
        return image, decode(image)


iface = gr.Interface(
    fn=on_submit,
    inputs=[
        grinputs.Image(),
        grinputs.Radio(['Encode', 'Decode'], label="Encode or Decode mode")],
    outputs=[
        groutputs.Image(label='Watermarked image'),
        groutputs.Textbox(label='Information')],
    allow_screenshot=False,
    allow_flagging="auto",
)
iface.launch()