Upload folder using huggingface_hub
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- annotator/annotator_path.py +22 -0
- annotator/binary/__init__.py +14 -0
- annotator/canny/__init__.py +5 -0
- annotator/clipvision/__init__.py +123 -0
- annotator/color/__init__.py +20 -0
- annotator/hed/__init__.py +98 -0
- annotator/hed/network-bsds500.pth +3 -0
- annotator/keypose/__init__.py +212 -0
- annotator/keypose/faster_rcnn_r50_fpn_coco.py +182 -0
- annotator/keypose/hrnet_w48_coco_256x192.py +169 -0
- annotator/lama/__init__.py +58 -0
- annotator/lama/config.yaml +157 -0
- annotator/lama/saicinpainting/__init__.py +0 -0
- annotator/lama/saicinpainting/training/__init__.py +0 -0
- annotator/lama/saicinpainting/training/data/__init__.py +0 -0
- annotator/lama/saicinpainting/training/data/masks.py +332 -0
- annotator/lama/saicinpainting/training/losses/__init__.py +0 -0
- annotator/lama/saicinpainting/training/losses/adversarial.py +177 -0
- annotator/lama/saicinpainting/training/losses/constants.py +152 -0
- annotator/lama/saicinpainting/training/losses/distance_weighting.py +126 -0
- annotator/lama/saicinpainting/training/losses/feature_matching.py +33 -0
- annotator/lama/saicinpainting/training/losses/perceptual.py +113 -0
- annotator/lama/saicinpainting/training/losses/segmentation.py +43 -0
- annotator/lama/saicinpainting/training/losses/style_loss.py +155 -0
- annotator/lama/saicinpainting/training/modules/__init__.py +31 -0
- annotator/lama/saicinpainting/training/modules/base.py +80 -0
- annotator/lama/saicinpainting/training/modules/depthwise_sep_conv.py +17 -0
- annotator/lama/saicinpainting/training/modules/fake_fakes.py +47 -0
- annotator/lama/saicinpainting/training/modules/ffc.py +485 -0
- annotator/lama/saicinpainting/training/modules/multidilated_conv.py +98 -0
- annotator/lama/saicinpainting/training/modules/multiscale.py +244 -0
- annotator/lama/saicinpainting/training/modules/pix2pixhd.py +669 -0
- annotator/lama/saicinpainting/training/modules/spatial_transform.py +49 -0
- annotator/lama/saicinpainting/training/modules/squeeze_excitation.py +20 -0
- annotator/lama/saicinpainting/training/trainers/__init__.py +29 -0
- annotator/lama/saicinpainting/training/trainers/base.py +293 -0
- annotator/lama/saicinpainting/training/trainers/default.py +175 -0
- annotator/lama/saicinpainting/training/visualizers/__init__.py +15 -0
- annotator/lama/saicinpainting/training/visualizers/base.py +73 -0
- annotator/lama/saicinpainting/training/visualizers/colors.py +76 -0
- annotator/lama/saicinpainting/training/visualizers/directory.py +36 -0
- annotator/lama/saicinpainting/training/visualizers/noop.py +9 -0
- annotator/lama/saicinpainting/utils.py +174 -0
- annotator/leres/__init__.py +113 -0
- annotator/leres/leres/LICENSE +23 -0
- annotator/leres/leres/Resnet.py +199 -0
- annotator/leres/leres/Resnext_torch.py +237 -0
- annotator/leres/leres/depthmap.py +546 -0
- annotator/leres/leres/multi_depth_model_woauxi.py +34 -0
- annotator/leres/leres/net_tools.py +54 -0
annotator/annotator_path.py
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from modules import shared
|
3 |
+
|
4 |
+
models_path = shared.opts.data.get('control_net_modules_path', None)
|
5 |
+
if not models_path:
|
6 |
+
models_path = getattr(shared.cmd_opts, 'controlnet_annotator_models_path', None)
|
7 |
+
if not models_path:
|
8 |
+
models_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'downloads')
|
9 |
+
|
10 |
+
if not os.path.isabs(models_path):
|
11 |
+
models_path = os.path.join(shared.data_path, models_path)
|
12 |
+
|
13 |
+
clip_vision_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'clip_vision')
|
14 |
+
# clip vision is always inside controlnet "extensions\sd-webui-controlnet"
|
15 |
+
# and any problem can be solved by removing controlnet and reinstall
|
16 |
+
|
17 |
+
models_path = os.path.realpath(models_path)
|
18 |
+
os.makedirs(models_path, exist_ok=True)
|
19 |
+
print(f'ControlNet preprocessor location: {models_path}')
|
20 |
+
# Make sure that the default location is inside controlnet "extensions\sd-webui-controlnet"
|
21 |
+
# so that any problem can be solved by removing controlnet and reinstall
|
22 |
+
# if users do not change configs on their own (otherwise users will know what is wrong)
|
annotator/binary/__init__.py
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
|
3 |
+
|
4 |
+
def apply_binary(img, bin_threshold):
|
5 |
+
img_gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
|
6 |
+
|
7 |
+
if bin_threshold == 0 or bin_threshold == 255:
|
8 |
+
# Otsu's threshold
|
9 |
+
otsu_threshold, img_bin = cv2.threshold(img_gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
|
10 |
+
print("Otsu threshold:", otsu_threshold)
|
11 |
+
else:
|
12 |
+
_, img_bin = cv2.threshold(img_gray, bin_threshold, 255, cv2.THRESH_BINARY_INV)
|
13 |
+
|
14 |
+
return cv2.cvtColor(img_bin, cv2.COLOR_GRAY2RGB)
|
annotator/canny/__init__.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
|
3 |
+
|
4 |
+
def apply_canny(img, low_threshold, high_threshold):
|
5 |
+
return cv2.Canny(img, low_threshold, high_threshold)
|
annotator/clipvision/__init__.py
ADDED
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
|
4 |
+
from modules import devices
|
5 |
+
from modules.modelloader import load_file_from_url
|
6 |
+
from annotator.annotator_path import models_path
|
7 |
+
from transformers import CLIPVisionModelWithProjection, CLIPVisionConfig, CLIPImageProcessor, modeling_utils
|
8 |
+
|
9 |
+
|
10 |
+
config_clip_g = {
|
11 |
+
"attention_dropout": 0.0,
|
12 |
+
"dropout": 0.0,
|
13 |
+
"hidden_act": "gelu",
|
14 |
+
"hidden_size": 1664,
|
15 |
+
"image_size": 224,
|
16 |
+
"initializer_factor": 1.0,
|
17 |
+
"initializer_range": 0.02,
|
18 |
+
"intermediate_size": 8192,
|
19 |
+
"layer_norm_eps": 1e-05,
|
20 |
+
"model_type": "clip_vision_model",
|
21 |
+
"num_attention_heads": 16,
|
22 |
+
"num_channels": 3,
|
23 |
+
"num_hidden_layers": 48,
|
24 |
+
"patch_size": 14,
|
25 |
+
"projection_dim": 1280,
|
26 |
+
"torch_dtype": "float32"
|
27 |
+
}
|
28 |
+
|
29 |
+
config_clip_h = {
|
30 |
+
"attention_dropout": 0.0,
|
31 |
+
"dropout": 0.0,
|
32 |
+
"hidden_act": "gelu",
|
33 |
+
"hidden_size": 1280,
|
34 |
+
"image_size": 224,
|
35 |
+
"initializer_factor": 1.0,
|
36 |
+
"initializer_range": 0.02,
|
37 |
+
"intermediate_size": 5120,
|
38 |
+
"layer_norm_eps": 1e-05,
|
39 |
+
"model_type": "clip_vision_model",
|
40 |
+
"num_attention_heads": 16,
|
41 |
+
"num_channels": 3,
|
42 |
+
"num_hidden_layers": 32,
|
43 |
+
"patch_size": 14,
|
44 |
+
"projection_dim": 1024,
|
45 |
+
"torch_dtype": "float32"
|
46 |
+
}
|
47 |
+
|
48 |
+
config_clip_vitl = {
|
49 |
+
"attention_dropout": 0.0,
|
50 |
+
"dropout": 0.0,
|
51 |
+
"hidden_act": "quick_gelu",
|
52 |
+
"hidden_size": 1024,
|
53 |
+
"image_size": 224,
|
54 |
+
"initializer_factor": 1.0,
|
55 |
+
"initializer_range": 0.02,
|
56 |
+
"intermediate_size": 4096,
|
57 |
+
"layer_norm_eps": 1e-05,
|
58 |
+
"model_type": "clip_vision_model",
|
59 |
+
"num_attention_heads": 16,
|
60 |
+
"num_channels": 3,
|
61 |
+
"num_hidden_layers": 24,
|
62 |
+
"patch_size": 14,
|
63 |
+
"projection_dim": 768,
|
64 |
+
"torch_dtype": "float32"
|
65 |
+
}
|
66 |
+
|
67 |
+
configs = {
|
68 |
+
'clip_g': config_clip_g,
|
69 |
+
'clip_h': config_clip_h,
|
70 |
+
'clip_vitl': config_clip_vitl,
|
71 |
+
}
|
72 |
+
|
73 |
+
downloads = {
|
74 |
+
'clip_vitl': 'https://huggingface.co/openai/clip-vit-large-patch14/resolve/main/pytorch_model.bin',
|
75 |
+
'clip_g': 'https://huggingface.co/lllyasviel/Annotators/resolve/main/clip_g.pth',
|
76 |
+
'clip_h': 'https://huggingface.co/h94/IP-Adapter/resolve/main/models/image_encoder/pytorch_model.bin'
|
77 |
+
}
|
78 |
+
|
79 |
+
|
80 |
+
class ClipVisionDetector:
|
81 |
+
def __init__(self, config):
|
82 |
+
assert config in downloads
|
83 |
+
self.download_link = downloads[config]
|
84 |
+
self.model_path = os.path.join(models_path, 'clip_vision')
|
85 |
+
self.file_name = config + '.pth'
|
86 |
+
self.config = configs[config]
|
87 |
+
self.device = devices.get_device_for("controlnet")
|
88 |
+
os.makedirs(self.model_path, exist_ok=True)
|
89 |
+
file_path = os.path.join(self.model_path, self.file_name)
|
90 |
+
if not os.path.exists(file_path):
|
91 |
+
load_file_from_url(url=self.download_link, model_dir=self.model_path, file_name=self.file_name)
|
92 |
+
config = CLIPVisionConfig(**self.config)
|
93 |
+
self.model = CLIPVisionModelWithProjection(config)
|
94 |
+
self.processor = CLIPImageProcessor(crop_size=224,
|
95 |
+
do_center_crop=True,
|
96 |
+
do_convert_rgb=True,
|
97 |
+
do_normalize=True,
|
98 |
+
do_resize=True,
|
99 |
+
image_mean=[0.48145466, 0.4578275, 0.40821073],
|
100 |
+
image_std=[0.26862954, 0.26130258, 0.27577711],
|
101 |
+
resample=3,
|
102 |
+
size=224)
|
103 |
+
|
104 |
+
sd = torch.load(file_path, map_location=torch.device('cpu'))
|
105 |
+
self.model.load_state_dict(sd, strict=False)
|
106 |
+
del sd
|
107 |
+
|
108 |
+
self.model.eval()
|
109 |
+
self.model.cpu()
|
110 |
+
|
111 |
+
def unload_model(self):
|
112 |
+
if self.model is not None:
|
113 |
+
self.model.to('meta')
|
114 |
+
|
115 |
+
def __call__(self, input_image):
|
116 |
+
with torch.no_grad():
|
117 |
+
clip_vision_model = self.model.cpu()
|
118 |
+
feat = self.processor(images=input_image, return_tensors="pt")
|
119 |
+
feat['pixel_values'] = feat['pixel_values'].cpu()
|
120 |
+
result = clip_vision_model(**feat, output_hidden_states=True)
|
121 |
+
result['hidden_states'] = [v.to(devices.get_device_for("controlnet")) for v in result['hidden_states']]
|
122 |
+
result = {k: v.to(devices.get_device_for("controlnet")) if isinstance(v, torch.Tensor) else v for k, v in result.items()}
|
123 |
+
return result
|
annotator/color/__init__.py
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
|
3 |
+
def cv2_resize_shortest_edge(image, size):
|
4 |
+
h, w = image.shape[:2]
|
5 |
+
if h < w:
|
6 |
+
new_h = size
|
7 |
+
new_w = int(round(w / h * size))
|
8 |
+
else:
|
9 |
+
new_w = size
|
10 |
+
new_h = int(round(h / w * size))
|
11 |
+
resized_image = cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_AREA)
|
12 |
+
return resized_image
|
13 |
+
|
14 |
+
def apply_color(img, res=512):
|
15 |
+
img = cv2_resize_shortest_edge(img, res)
|
16 |
+
h, w = img.shape[:2]
|
17 |
+
|
18 |
+
input_img_color = cv2.resize(img, (w//64, h//64), interpolation=cv2.INTER_CUBIC)
|
19 |
+
input_img_color = cv2.resize(input_img_color, (w, h), interpolation=cv2.INTER_NEAREST)
|
20 |
+
return input_img_color
|
annotator/hed/__init__.py
ADDED
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# This is an improved version and model of HED edge detection with Apache License, Version 2.0.
|
2 |
+
# Please use this implementation in your products
|
3 |
+
# This implementation may produce slightly different results from Saining Xie's official implementations,
|
4 |
+
# but it generates smoother edges and is more suitable for ControlNet as well as other image-to-image translations.
|
5 |
+
# Different from official models and other implementations, this is an RGB-input model (rather than BGR)
|
6 |
+
# and in this way it works better for gradio's RGB protocol
|
7 |
+
|
8 |
+
import os
|
9 |
+
import cv2
|
10 |
+
import torch
|
11 |
+
import numpy as np
|
12 |
+
|
13 |
+
from einops import rearrange
|
14 |
+
import os
|
15 |
+
from modules import devices
|
16 |
+
from annotator.annotator_path import models_path
|
17 |
+
from annotator.util import safe_step, nms
|
18 |
+
|
19 |
+
|
20 |
+
class DoubleConvBlock(torch.nn.Module):
|
21 |
+
def __init__(self, input_channel, output_channel, layer_number):
|
22 |
+
super().__init__()
|
23 |
+
self.convs = torch.nn.Sequential()
|
24 |
+
self.convs.append(torch.nn.Conv2d(in_channels=input_channel, out_channels=output_channel, kernel_size=(3, 3), stride=(1, 1), padding=1))
|
25 |
+
for i in range(1, layer_number):
|
26 |
+
self.convs.append(torch.nn.Conv2d(in_channels=output_channel, out_channels=output_channel, kernel_size=(3, 3), stride=(1, 1), padding=1))
|
27 |
+
self.projection = torch.nn.Conv2d(in_channels=output_channel, out_channels=1, kernel_size=(1, 1), stride=(1, 1), padding=0)
|
28 |
+
|
29 |
+
def __call__(self, x, down_sampling=False):
|
30 |
+
h = x
|
31 |
+
if down_sampling:
|
32 |
+
h = torch.nn.functional.max_pool2d(h, kernel_size=(2, 2), stride=(2, 2))
|
33 |
+
for conv in self.convs:
|
34 |
+
h = conv(h)
|
35 |
+
h = torch.nn.functional.relu(h)
|
36 |
+
return h, self.projection(h)
|
37 |
+
|
38 |
+
|
39 |
+
class ControlNetHED_Apache2(torch.nn.Module):
|
40 |
+
def __init__(self):
|
41 |
+
super().__init__()
|
42 |
+
self.norm = torch.nn.Parameter(torch.zeros(size=(1, 3, 1, 1)))
|
43 |
+
self.block1 = DoubleConvBlock(input_channel=3, output_channel=64, layer_number=2)
|
44 |
+
self.block2 = DoubleConvBlock(input_channel=64, output_channel=128, layer_number=2)
|
45 |
+
self.block3 = DoubleConvBlock(input_channel=128, output_channel=256, layer_number=3)
|
46 |
+
self.block4 = DoubleConvBlock(input_channel=256, output_channel=512, layer_number=3)
|
47 |
+
self.block5 = DoubleConvBlock(input_channel=512, output_channel=512, layer_number=3)
|
48 |
+
|
49 |
+
def __call__(self, x):
|
50 |
+
h = x - self.norm
|
51 |
+
h, projection1 = self.block1(h)
|
52 |
+
h, projection2 = self.block2(h, down_sampling=True)
|
53 |
+
h, projection3 = self.block3(h, down_sampling=True)
|
54 |
+
h, projection4 = self.block4(h, down_sampling=True)
|
55 |
+
h, projection5 = self.block5(h, down_sampling=True)
|
56 |
+
return projection1, projection2, projection3, projection4, projection5
|
57 |
+
|
58 |
+
|
59 |
+
netNetwork = None
|
60 |
+
remote_model_path = "https://huggingface.co/lllyasviel/Annotators/resolve/main/ControlNetHED.pth"
|
61 |
+
modeldir = os.path.join(models_path, "hed")
|
62 |
+
old_modeldir = os.path.dirname(os.path.realpath(__file__))
|
63 |
+
|
64 |
+
|
65 |
+
def apply_hed(input_image, is_safe=False):
|
66 |
+
global netNetwork
|
67 |
+
if netNetwork is None:
|
68 |
+
modelpath = os.path.join(modeldir, "ControlNetHED.pth")
|
69 |
+
old_modelpath = os.path.join(old_modeldir, "ControlNetHED.pth")
|
70 |
+
if os.path.exists(old_modelpath):
|
71 |
+
modelpath = old_modelpath
|
72 |
+
elif not os.path.exists(modelpath):
|
73 |
+
from basicsr.utils.download_util import load_file_from_url
|
74 |
+
load_file_from_url(remote_model_path, model_dir=modeldir)
|
75 |
+
netNetwork = ControlNetHED_Apache2().to(devices.get_device_for("controlnet"))
|
76 |
+
netNetwork.load_state_dict(torch.load(modelpath, map_location='cpu'))
|
77 |
+
netNetwork.to(devices.get_device_for("controlnet")).float().eval()
|
78 |
+
|
79 |
+
assert input_image.ndim == 3
|
80 |
+
H, W, C = input_image.shape
|
81 |
+
with torch.no_grad():
|
82 |
+
image_hed = torch.from_numpy(input_image.copy()).float().to(devices.get_device_for("controlnet"))
|
83 |
+
image_hed = rearrange(image_hed, 'h w c -> 1 c h w')
|
84 |
+
edges = netNetwork(image_hed)
|
85 |
+
edges = [e.detach().cpu().numpy().astype(np.float32)[0, 0] for e in edges]
|
86 |
+
edges = [cv2.resize(e, (W, H), interpolation=cv2.INTER_LINEAR) for e in edges]
|
87 |
+
edges = np.stack(edges, axis=2)
|
88 |
+
edge = 1 / (1 + np.exp(-np.mean(edges, axis=2).astype(np.float64)))
|
89 |
+
if is_safe:
|
90 |
+
edge = safe_step(edge)
|
91 |
+
edge = (edge * 255.0).clip(0, 255).astype(np.uint8)
|
92 |
+
return edge
|
93 |
+
|
94 |
+
|
95 |
+
def unload_hed_model():
|
96 |
+
global netNetwork
|
97 |
+
if netNetwork is not None:
|
98 |
+
netNetwork.cpu()
|
annotator/hed/network-bsds500.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:58a858782f5fa3e0ca3dc92e7a1a609add93987d77be3dfa54f8f8419d881a94
|
3 |
+
size 58871680
|
annotator/keypose/__init__.py
ADDED
@@ -0,0 +1,212 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import cv2
|
3 |
+
import torch
|
4 |
+
|
5 |
+
import os
|
6 |
+
from modules import devices
|
7 |
+
from annotator.annotator_path import models_path
|
8 |
+
|
9 |
+
import mmcv
|
10 |
+
from mmdet.apis import inference_detector, init_detector
|
11 |
+
from mmpose.apis import inference_top_down_pose_model
|
12 |
+
from mmpose.apis import init_pose_model, process_mmdet_results, vis_pose_result
|
13 |
+
|
14 |
+
|
15 |
+
def preprocessing(image, device):
|
16 |
+
# Resize
|
17 |
+
scale = 640 / max(image.shape[:2])
|
18 |
+
image = cv2.resize(image, dsize=None, fx=scale, fy=scale)
|
19 |
+
raw_image = image.astype(np.uint8)
|
20 |
+
|
21 |
+
# Subtract mean values
|
22 |
+
image = image.astype(np.float32)
|
23 |
+
image -= np.array(
|
24 |
+
[
|
25 |
+
float(104.008),
|
26 |
+
float(116.669),
|
27 |
+
float(122.675),
|
28 |
+
]
|
29 |
+
)
|
30 |
+
|
31 |
+
# Convert to torch.Tensor and add "batch" axis
|
32 |
+
image = torch.from_numpy(image.transpose(2, 0, 1)).float().unsqueeze(0)
|
33 |
+
image = image.to(device)
|
34 |
+
|
35 |
+
return image, raw_image
|
36 |
+
|
37 |
+
|
38 |
+
def imshow_keypoints(img,
|
39 |
+
pose_result,
|
40 |
+
skeleton=None,
|
41 |
+
kpt_score_thr=0.1,
|
42 |
+
pose_kpt_color=None,
|
43 |
+
pose_link_color=None,
|
44 |
+
radius=4,
|
45 |
+
thickness=1):
|
46 |
+
"""Draw keypoints and links on an image.
|
47 |
+
Args:
|
48 |
+
img (ndarry): The image to draw poses on.
|
49 |
+
pose_result (list[kpts]): The poses to draw. Each element kpts is
|
50 |
+
a set of K keypoints as an Kx3 numpy.ndarray, where each
|
51 |
+
keypoint is represented as x, y, score.
|
52 |
+
kpt_score_thr (float, optional): Minimum score of keypoints
|
53 |
+
to be shown. Default: 0.3.
|
54 |
+
pose_kpt_color (np.array[Nx3]`): Color of N keypoints. If None,
|
55 |
+
the keypoint will not be drawn.
|
56 |
+
pose_link_color (np.array[Mx3]): Color of M links. If None, the
|
57 |
+
links will not be drawn.
|
58 |
+
thickness (int): Thickness of lines.
|
59 |
+
"""
|
60 |
+
|
61 |
+
img_h, img_w, _ = img.shape
|
62 |
+
img = np.zeros(img.shape)
|
63 |
+
|
64 |
+
for idx, kpts in enumerate(pose_result):
|
65 |
+
if idx > 1:
|
66 |
+
continue
|
67 |
+
kpts = kpts['keypoints']
|
68 |
+
# print(kpts)
|
69 |
+
kpts = np.array(kpts, copy=False)
|
70 |
+
|
71 |
+
# draw each point on image
|
72 |
+
if pose_kpt_color is not None:
|
73 |
+
assert len(pose_kpt_color) == len(kpts)
|
74 |
+
|
75 |
+
for kid, kpt in enumerate(kpts):
|
76 |
+
x_coord, y_coord, kpt_score = int(kpt[0]), int(kpt[1]), kpt[2]
|
77 |
+
|
78 |
+
if kpt_score < kpt_score_thr or pose_kpt_color[kid] is None:
|
79 |
+
# skip the point that should not be drawn
|
80 |
+
continue
|
81 |
+
|
82 |
+
color = tuple(int(c) for c in pose_kpt_color[kid])
|
83 |
+
cv2.circle(img, (int(x_coord), int(y_coord)),
|
84 |
+
radius, color, -1)
|
85 |
+
|
86 |
+
# draw links
|
87 |
+
if skeleton is not None and pose_link_color is not None:
|
88 |
+
assert len(pose_link_color) == len(skeleton)
|
89 |
+
|
90 |
+
for sk_id, sk in enumerate(skeleton):
|
91 |
+
pos1 = (int(kpts[sk[0], 0]), int(kpts[sk[0], 1]))
|
92 |
+
pos2 = (int(kpts[sk[1], 0]), int(kpts[sk[1], 1]))
|
93 |
+
|
94 |
+
if (pos1[0] <= 0 or pos1[0] >= img_w or pos1[1] <= 0 or pos1[1] >= img_h or pos2[0] <= 0
|
95 |
+
or pos2[0] >= img_w or pos2[1] <= 0 or pos2[1] >= img_h or kpts[sk[0], 2] < kpt_score_thr
|
96 |
+
or kpts[sk[1], 2] < kpt_score_thr or pose_link_color[sk_id] is None):
|
97 |
+
# skip the link that should not be drawn
|
98 |
+
continue
|
99 |
+
color = tuple(int(c) for c in pose_link_color[sk_id])
|
100 |
+
cv2.line(img, pos1, pos2, color, thickness=thickness)
|
101 |
+
|
102 |
+
return img
|
103 |
+
|
104 |
+
|
105 |
+
human_det, pose_model = None, None
|
106 |
+
det_model_path = "https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r50_fpn_1x_coco/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth"
|
107 |
+
pose_model_path = "https://download.openmmlab.com/mmpose/top_down/hrnet/hrnet_w48_coco_256x192-b9e0b3ab_20200708.pth"
|
108 |
+
|
109 |
+
modeldir = os.path.join(models_path, "keypose")
|
110 |
+
old_modeldir = os.path.dirname(os.path.realpath(__file__))
|
111 |
+
|
112 |
+
det_config = 'faster_rcnn_r50_fpn_coco.py'
|
113 |
+
pose_config = 'hrnet_w48_coco_256x192.py'
|
114 |
+
|
115 |
+
det_checkpoint = 'faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth'
|
116 |
+
pose_checkpoint = 'hrnet_w48_coco_256x192-b9e0b3ab_20200708.pth'
|
117 |
+
det_cat_id = 1
|
118 |
+
bbox_thr = 0.2
|
119 |
+
|
120 |
+
skeleton = [
|
121 |
+
[15, 13], [13, 11], [16, 14], [14, 12], [11, 12], [5, 11], [6, 12], [5, 6], [5, 7], [6, 8],
|
122 |
+
[7, 9], [8, 10],
|
123 |
+
[1, 2], [0, 1], [0, 2], [1, 3], [2, 4], [3, 5], [4, 6]
|
124 |
+
]
|
125 |
+
|
126 |
+
pose_kpt_color = [
|
127 |
+
[51, 153, 255], [51, 153, 255], [51, 153, 255], [51, 153, 255], [51, 153, 255],
|
128 |
+
[0, 255, 0],
|
129 |
+
[255, 128, 0], [0, 255, 0], [255, 128, 0], [0, 255, 0], [255, 128, 0], [0, 255, 0],
|
130 |
+
[255, 128, 0],
|
131 |
+
[0, 255, 0], [255, 128, 0], [0, 255, 0], [255, 128, 0]
|
132 |
+
]
|
133 |
+
|
134 |
+
pose_link_color = [
|
135 |
+
[0, 255, 0], [0, 255, 0], [255, 128, 0], [255, 128, 0],
|
136 |
+
[51, 153, 255], [51, 153, 255], [51, 153, 255], [51, 153, 255], [0, 255, 0],
|
137 |
+
[255, 128, 0],
|
138 |
+
[0, 255, 0], [255, 128, 0], [51, 153, 255], [51, 153, 255], [51, 153, 255],
|
139 |
+
[51, 153, 255],
|
140 |
+
[51, 153, 255], [51, 153, 255], [51, 153, 255]
|
141 |
+
]
|
142 |
+
|
143 |
+
def find_download_model(checkpoint, remote_path):
|
144 |
+
modelpath = os.path.join(modeldir, checkpoint)
|
145 |
+
old_modelpath = os.path.join(old_modeldir, checkpoint)
|
146 |
+
|
147 |
+
if os.path.exists(old_modelpath):
|
148 |
+
modelpath = old_modelpath
|
149 |
+
elif not os.path.exists(modelpath):
|
150 |
+
from basicsr.utils.download_util import load_file_from_url
|
151 |
+
load_file_from_url(remote_path, model_dir=modeldir)
|
152 |
+
|
153 |
+
return modelpath
|
154 |
+
|
155 |
+
def apply_keypose(input_image):
|
156 |
+
global human_det, pose_model
|
157 |
+
if netNetwork is None:
|
158 |
+
det_model_local = find_download_model(det_checkpoint, det_model_path)
|
159 |
+
hrnet_model_local = find_download_model(pose_checkpoint, pose_model_path)
|
160 |
+
det_config_mmcv = mmcv.Config.fromfile(det_config)
|
161 |
+
pose_config_mmcv = mmcv.Config.fromfile(pose_config)
|
162 |
+
human_det = init_detector(det_config_mmcv, det_model_local, device=devices.get_device_for("controlnet"))
|
163 |
+
pose_model = init_pose_model(pose_config_mmcv, hrnet_model_local, device=devices.get_device_for("controlnet"))
|
164 |
+
|
165 |
+
assert input_image.ndim == 3
|
166 |
+
input_image = input_image.copy()
|
167 |
+
with torch.no_grad():
|
168 |
+
image = torch.from_numpy(input_image).float().to(devices.get_device_for("controlnet"))
|
169 |
+
image = image / 255.0
|
170 |
+
mmdet_results = inference_detector(human_det, image)
|
171 |
+
|
172 |
+
# keep the person class bounding boxes.
|
173 |
+
person_results = process_mmdet_results(mmdet_results, det_cat_id)
|
174 |
+
|
175 |
+
return_heatmap = False
|
176 |
+
dataset = pose_model.cfg.data['test']['type']
|
177 |
+
|
178 |
+
# e.g. use ('backbone', ) to return backbone feature
|
179 |
+
output_layer_names = None
|
180 |
+
pose_results, _ = inference_top_down_pose_model(
|
181 |
+
pose_model,
|
182 |
+
image,
|
183 |
+
person_results,
|
184 |
+
bbox_thr=bbox_thr,
|
185 |
+
format='xyxy',
|
186 |
+
dataset=dataset,
|
187 |
+
dataset_info=None,
|
188 |
+
return_heatmap=return_heatmap,
|
189 |
+
outputs=output_layer_names
|
190 |
+
)
|
191 |
+
|
192 |
+
im_keypose_out = imshow_keypoints(
|
193 |
+
image,
|
194 |
+
pose_results,
|
195 |
+
skeleton=skeleton,
|
196 |
+
pose_kpt_color=pose_kpt_color,
|
197 |
+
pose_link_color=pose_link_color,
|
198 |
+
radius=2,
|
199 |
+
thickness=2
|
200 |
+
)
|
201 |
+
im_keypose_out = im_keypose_out.astype(np.uint8)
|
202 |
+
|
203 |
+
# image_hed = rearrange(image_hed, 'h w c -> 1 c h w')
|
204 |
+
# edge = netNetwork(image_hed)[0]
|
205 |
+
# edge = (edge.cpu().numpy() * 255.0).clip(0, 255).astype(np.uint8)
|
206 |
+
return im_keypose_out
|
207 |
+
|
208 |
+
|
209 |
+
def unload_hed_model():
|
210 |
+
global netNetwork
|
211 |
+
if netNetwork is not None:
|
212 |
+
netNetwork.cpu()
|
annotator/keypose/faster_rcnn_r50_fpn_coco.py
ADDED
@@ -0,0 +1,182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
checkpoint_config = dict(interval=1)
|
2 |
+
# yapf:disable
|
3 |
+
log_config = dict(
|
4 |
+
interval=50,
|
5 |
+
hooks=[
|
6 |
+
dict(type='TextLoggerHook'),
|
7 |
+
# dict(type='TensorboardLoggerHook')
|
8 |
+
])
|
9 |
+
# yapf:enable
|
10 |
+
dist_params = dict(backend='nccl')
|
11 |
+
log_level = 'INFO'
|
12 |
+
load_from = None
|
13 |
+
resume_from = None
|
14 |
+
workflow = [('train', 1)]
|
15 |
+
# optimizer
|
16 |
+
optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
|
17 |
+
optimizer_config = dict(grad_clip=None)
|
18 |
+
# learning policy
|
19 |
+
lr_config = dict(
|
20 |
+
policy='step',
|
21 |
+
warmup='linear',
|
22 |
+
warmup_iters=500,
|
23 |
+
warmup_ratio=0.001,
|
24 |
+
step=[8, 11])
|
25 |
+
total_epochs = 12
|
26 |
+
|
27 |
+
model = dict(
|
28 |
+
type='FasterRCNN',
|
29 |
+
pretrained='torchvision://resnet50',
|
30 |
+
backbone=dict(
|
31 |
+
type='ResNet',
|
32 |
+
depth=50,
|
33 |
+
num_stages=4,
|
34 |
+
out_indices=(0, 1, 2, 3),
|
35 |
+
frozen_stages=1,
|
36 |
+
norm_cfg=dict(type='BN', requires_grad=True),
|
37 |
+
norm_eval=True,
|
38 |
+
style='pytorch'),
|
39 |
+
neck=dict(
|
40 |
+
type='FPN',
|
41 |
+
in_channels=[256, 512, 1024, 2048],
|
42 |
+
out_channels=256,
|
43 |
+
num_outs=5),
|
44 |
+
rpn_head=dict(
|
45 |
+
type='RPNHead',
|
46 |
+
in_channels=256,
|
47 |
+
feat_channels=256,
|
48 |
+
anchor_generator=dict(
|
49 |
+
type='AnchorGenerator',
|
50 |
+
scales=[8],
|
51 |
+
ratios=[0.5, 1.0, 2.0],
|
52 |
+
strides=[4, 8, 16, 32, 64]),
|
53 |
+
bbox_coder=dict(
|
54 |
+
type='DeltaXYWHBBoxCoder',
|
55 |
+
target_means=[.0, .0, .0, .0],
|
56 |
+
target_stds=[1.0, 1.0, 1.0, 1.0]),
|
57 |
+
loss_cls=dict(
|
58 |
+
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
|
59 |
+
loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
|
60 |
+
roi_head=dict(
|
61 |
+
type='StandardRoIHead',
|
62 |
+
bbox_roi_extractor=dict(
|
63 |
+
type='SingleRoIExtractor',
|
64 |
+
roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
|
65 |
+
out_channels=256,
|
66 |
+
featmap_strides=[4, 8, 16, 32]),
|
67 |
+
bbox_head=dict(
|
68 |
+
type='Shared2FCBBoxHead',
|
69 |
+
in_channels=256,
|
70 |
+
fc_out_channels=1024,
|
71 |
+
roi_feat_size=7,
|
72 |
+
num_classes=80,
|
73 |
+
bbox_coder=dict(
|
74 |
+
type='DeltaXYWHBBoxCoder',
|
75 |
+
target_means=[0., 0., 0., 0.],
|
76 |
+
target_stds=[0.1, 0.1, 0.2, 0.2]),
|
77 |
+
reg_class_agnostic=False,
|
78 |
+
loss_cls=dict(
|
79 |
+
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
|
80 |
+
loss_bbox=dict(type='L1Loss', loss_weight=1.0))),
|
81 |
+
# model training and testing settings
|
82 |
+
train_cfg=dict(
|
83 |
+
rpn=dict(
|
84 |
+
assigner=dict(
|
85 |
+
type='MaxIoUAssigner',
|
86 |
+
pos_iou_thr=0.7,
|
87 |
+
neg_iou_thr=0.3,
|
88 |
+
min_pos_iou=0.3,
|
89 |
+
match_low_quality=True,
|
90 |
+
ignore_iof_thr=-1),
|
91 |
+
sampler=dict(
|
92 |
+
type='RandomSampler',
|
93 |
+
num=256,
|
94 |
+
pos_fraction=0.5,
|
95 |
+
neg_pos_ub=-1,
|
96 |
+
add_gt_as_proposals=False),
|
97 |
+
allowed_border=-1,
|
98 |
+
pos_weight=-1,
|
99 |
+
debug=False),
|
100 |
+
rpn_proposal=dict(
|
101 |
+
nms_pre=2000,
|
102 |
+
max_per_img=1000,
|
103 |
+
nms=dict(type='nms', iou_threshold=0.7),
|
104 |
+
min_bbox_size=0),
|
105 |
+
rcnn=dict(
|
106 |
+
assigner=dict(
|
107 |
+
type='MaxIoUAssigner',
|
108 |
+
pos_iou_thr=0.5,
|
109 |
+
neg_iou_thr=0.5,
|
110 |
+
min_pos_iou=0.5,
|
111 |
+
match_low_quality=False,
|
112 |
+
ignore_iof_thr=-1),
|
113 |
+
sampler=dict(
|
114 |
+
type='RandomSampler',
|
115 |
+
num=512,
|
116 |
+
pos_fraction=0.25,
|
117 |
+
neg_pos_ub=-1,
|
118 |
+
add_gt_as_proposals=True),
|
119 |
+
pos_weight=-1,
|
120 |
+
debug=False)),
|
121 |
+
test_cfg=dict(
|
122 |
+
rpn=dict(
|
123 |
+
nms_pre=1000,
|
124 |
+
max_per_img=1000,
|
125 |
+
nms=dict(type='nms', iou_threshold=0.7),
|
126 |
+
min_bbox_size=0),
|
127 |
+
rcnn=dict(
|
128 |
+
score_thr=0.05,
|
129 |
+
nms=dict(type='nms', iou_threshold=0.5),
|
130 |
+
max_per_img=100)
|
131 |
+
# soft-nms is also supported for rcnn testing
|
132 |
+
# e.g., nms=dict(type='soft_nms', iou_threshold=0.5, min_score=0.05)
|
133 |
+
))
|
134 |
+
|
135 |
+
dataset_type = 'CocoDataset'
|
136 |
+
data_root = 'data/coco'
|
137 |
+
img_norm_cfg = dict(
|
138 |
+
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
|
139 |
+
train_pipeline = [
|
140 |
+
dict(type='LoadImageFromFile'),
|
141 |
+
dict(type='LoadAnnotations', with_bbox=True),
|
142 |
+
dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
|
143 |
+
dict(type='RandomFlip', flip_ratio=0.5),
|
144 |
+
dict(type='Normalize', **img_norm_cfg),
|
145 |
+
dict(type='Pad', size_divisor=32),
|
146 |
+
dict(type='DefaultFormatBundle'),
|
147 |
+
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
|
148 |
+
]
|
149 |
+
test_pipeline = [
|
150 |
+
dict(type='LoadImageFromFile'),
|
151 |
+
dict(
|
152 |
+
type='MultiScaleFlipAug',
|
153 |
+
img_scale=(1333, 800),
|
154 |
+
flip=False,
|
155 |
+
transforms=[
|
156 |
+
dict(type='Resize', keep_ratio=True),
|
157 |
+
dict(type='RandomFlip'),
|
158 |
+
dict(type='Normalize', **img_norm_cfg),
|
159 |
+
dict(type='Pad', size_divisor=32),
|
160 |
+
dict(type='DefaultFormatBundle'),
|
161 |
+
dict(type='Collect', keys=['img']),
|
162 |
+
])
|
163 |
+
]
|
164 |
+
data = dict(
|
165 |
+
samples_per_gpu=2,
|
166 |
+
workers_per_gpu=2,
|
167 |
+
train=dict(
|
168 |
+
type=dataset_type,
|
169 |
+
ann_file=f'{data_root}/annotations/instances_train2017.json',
|
170 |
+
img_prefix=f'{data_root}/train2017/',
|
171 |
+
pipeline=train_pipeline),
|
172 |
+
val=dict(
|
173 |
+
type=dataset_type,
|
174 |
+
ann_file=f'{data_root}/annotations/instances_val2017.json',
|
175 |
+
img_prefix=f'{data_root}/val2017/',
|
176 |
+
pipeline=test_pipeline),
|
177 |
+
test=dict(
|
178 |
+
type=dataset_type,
|
179 |
+
ann_file=f'{data_root}/annotations/instances_val2017.json',
|
180 |
+
img_prefix=f'{data_root}/val2017/',
|
181 |
+
pipeline=test_pipeline))
|
182 |
+
evaluation = dict(interval=1, metric='bbox')
|
annotator/keypose/hrnet_w48_coco_256x192.py
ADDED
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# _base_ = [
|
2 |
+
# '../../../../_base_/default_runtime.py',
|
3 |
+
# '../../../../_base_/datasets/coco.py'
|
4 |
+
# ]
|
5 |
+
evaluation = dict(interval=10, metric='mAP', save_best='AP')
|
6 |
+
|
7 |
+
optimizer = dict(
|
8 |
+
type='Adam',
|
9 |
+
lr=5e-4,
|
10 |
+
)
|
11 |
+
optimizer_config = dict(grad_clip=None)
|
12 |
+
# learning policy
|
13 |
+
lr_config = dict(
|
14 |
+
policy='step',
|
15 |
+
warmup='linear',
|
16 |
+
warmup_iters=500,
|
17 |
+
warmup_ratio=0.001,
|
18 |
+
step=[170, 200])
|
19 |
+
total_epochs = 210
|
20 |
+
channel_cfg = dict(
|
21 |
+
num_output_channels=17,
|
22 |
+
dataset_joints=17,
|
23 |
+
dataset_channel=[
|
24 |
+
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],
|
25 |
+
],
|
26 |
+
inference_channel=[
|
27 |
+
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16
|
28 |
+
])
|
29 |
+
|
30 |
+
# model settings
|
31 |
+
model = dict(
|
32 |
+
type='TopDown',
|
33 |
+
pretrained='https://download.openmmlab.com/mmpose/'
|
34 |
+
'pretrain_models/hrnet_w48-8ef0771d.pth',
|
35 |
+
backbone=dict(
|
36 |
+
type='HRNet',
|
37 |
+
in_channels=3,
|
38 |
+
extra=dict(
|
39 |
+
stage1=dict(
|
40 |
+
num_modules=1,
|
41 |
+
num_branches=1,
|
42 |
+
block='BOTTLENECK',
|
43 |
+
num_blocks=(4, ),
|
44 |
+
num_channels=(64, )),
|
45 |
+
stage2=dict(
|
46 |
+
num_modules=1,
|
47 |
+
num_branches=2,
|
48 |
+
block='BASIC',
|
49 |
+
num_blocks=(4, 4),
|
50 |
+
num_channels=(48, 96)),
|
51 |
+
stage3=dict(
|
52 |
+
num_modules=4,
|
53 |
+
num_branches=3,
|
54 |
+
block='BASIC',
|
55 |
+
num_blocks=(4, 4, 4),
|
56 |
+
num_channels=(48, 96, 192)),
|
57 |
+
stage4=dict(
|
58 |
+
num_modules=3,
|
59 |
+
num_branches=4,
|
60 |
+
block='BASIC',
|
61 |
+
num_blocks=(4, 4, 4, 4),
|
62 |
+
num_channels=(48, 96, 192, 384))),
|
63 |
+
),
|
64 |
+
keypoint_head=dict(
|
65 |
+
type='TopdownHeatmapSimpleHead',
|
66 |
+
in_channels=48,
|
67 |
+
out_channels=channel_cfg['num_output_channels'],
|
68 |
+
num_deconv_layers=0,
|
69 |
+
extra=dict(final_conv_kernel=1, ),
|
70 |
+
loss_keypoint=dict(type='JointsMSELoss', use_target_weight=True)),
|
71 |
+
train_cfg=dict(),
|
72 |
+
test_cfg=dict(
|
73 |
+
flip_test=True,
|
74 |
+
post_process='default',
|
75 |
+
shift_heatmap=True,
|
76 |
+
modulate_kernel=11))
|
77 |
+
|
78 |
+
data_cfg = dict(
|
79 |
+
image_size=[192, 256],
|
80 |
+
heatmap_size=[48, 64],
|
81 |
+
num_output_channels=channel_cfg['num_output_channels'],
|
82 |
+
num_joints=channel_cfg['dataset_joints'],
|
83 |
+
dataset_channel=channel_cfg['dataset_channel'],
|
84 |
+
inference_channel=channel_cfg['inference_channel'],
|
85 |
+
soft_nms=False,
|
86 |
+
nms_thr=1.0,
|
87 |
+
oks_thr=0.9,
|
88 |
+
vis_thr=0.2,
|
89 |
+
use_gt_bbox=False,
|
90 |
+
det_bbox_thr=0.0,
|
91 |
+
bbox_file='data/coco/person_detection_results/'
|
92 |
+
'COCO_val2017_detections_AP_H_56_person.json',
|
93 |
+
)
|
94 |
+
|
95 |
+
train_pipeline = [
|
96 |
+
dict(type='LoadImageFromFile'),
|
97 |
+
dict(type='TopDownGetBboxCenterScale', padding=1.25),
|
98 |
+
dict(type='TopDownRandomShiftBboxCenter', shift_factor=0.16, prob=0.3),
|
99 |
+
dict(type='TopDownRandomFlip', flip_prob=0.5),
|
100 |
+
dict(
|
101 |
+
type='TopDownHalfBodyTransform',
|
102 |
+
num_joints_half_body=8,
|
103 |
+
prob_half_body=0.3),
|
104 |
+
dict(
|
105 |
+
type='TopDownGetRandomScaleRotation', rot_factor=40, scale_factor=0.5),
|
106 |
+
dict(type='TopDownAffine'),
|
107 |
+
dict(type='ToTensor'),
|
108 |
+
dict(
|
109 |
+
type='NormalizeTensor',
|
110 |
+
mean=[0.485, 0.456, 0.406],
|
111 |
+
std=[0.229, 0.224, 0.225]),
|
112 |
+
dict(type='TopDownGenerateTarget', sigma=2),
|
113 |
+
dict(
|
114 |
+
type='Collect',
|
115 |
+
keys=['img', 'target', 'target_weight'],
|
116 |
+
meta_keys=[
|
117 |
+
'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale',
|
118 |
+
'rotation', 'bbox_score', 'flip_pairs'
|
119 |
+
]),
|
120 |
+
]
|
121 |
+
|
122 |
+
val_pipeline = [
|
123 |
+
dict(type='LoadImageFromFile'),
|
124 |
+
dict(type='TopDownGetBboxCenterScale', padding=1.25),
|
125 |
+
dict(type='TopDownAffine'),
|
126 |
+
dict(type='ToTensor'),
|
127 |
+
dict(
|
128 |
+
type='NormalizeTensor',
|
129 |
+
mean=[0.485, 0.456, 0.406],
|
130 |
+
std=[0.229, 0.224, 0.225]),
|
131 |
+
dict(
|
132 |
+
type='Collect',
|
133 |
+
keys=['img'],
|
134 |
+
meta_keys=[
|
135 |
+
'image_file', 'center', 'scale', 'rotation', 'bbox_score',
|
136 |
+
'flip_pairs'
|
137 |
+
]),
|
138 |
+
]
|
139 |
+
|
140 |
+
test_pipeline = val_pipeline
|
141 |
+
|
142 |
+
data_root = 'data/coco'
|
143 |
+
data = dict(
|
144 |
+
samples_per_gpu=32,
|
145 |
+
workers_per_gpu=2,
|
146 |
+
val_dataloader=dict(samples_per_gpu=32),
|
147 |
+
test_dataloader=dict(samples_per_gpu=32),
|
148 |
+
train=dict(
|
149 |
+
type='TopDownCocoDataset',
|
150 |
+
ann_file=f'{data_root}/annotations/person_keypoints_train2017.json',
|
151 |
+
img_prefix=f'{data_root}/train2017/',
|
152 |
+
data_cfg=data_cfg,
|
153 |
+
pipeline=train_pipeline,
|
154 |
+
dataset_info={{_base_.dataset_info}}),
|
155 |
+
val=dict(
|
156 |
+
type='TopDownCocoDataset',
|
157 |
+
ann_file=f'{data_root}/annotations/person_keypoints_val2017.json',
|
158 |
+
img_prefix=f'{data_root}/val2017/',
|
159 |
+
data_cfg=data_cfg,
|
160 |
+
pipeline=val_pipeline,
|
161 |
+
dataset_info={{_base_.dataset_info}}),
|
162 |
+
test=dict(
|
163 |
+
type='TopDownCocoDataset',
|
164 |
+
ann_file=f'{data_root}/annotations/person_keypoints_val2017.json',
|
165 |
+
img_prefix=f'{data_root}/val2017/',
|
166 |
+
data_cfg=data_cfg,
|
167 |
+
pipeline=test_pipeline,
|
168 |
+
dataset_info={{_base_.dataset_info}}),
|
169 |
+
)
|
annotator/lama/__init__.py
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# https://github.com/advimman/lama
|
2 |
+
|
3 |
+
import yaml
|
4 |
+
import torch
|
5 |
+
from omegaconf import OmegaConf
|
6 |
+
import numpy as np
|
7 |
+
|
8 |
+
from einops import rearrange
|
9 |
+
import os
|
10 |
+
from modules import devices
|
11 |
+
from annotator.annotator_path import models_path
|
12 |
+
from annotator.lama.saicinpainting.training.trainers import load_checkpoint
|
13 |
+
|
14 |
+
|
15 |
+
class LamaInpainting:
|
16 |
+
model_dir = os.path.join(models_path, "lama")
|
17 |
+
|
18 |
+
def __init__(self):
|
19 |
+
self.model = None
|
20 |
+
self.device = devices.get_device_for("controlnet")
|
21 |
+
|
22 |
+
def load_model(self):
|
23 |
+
remote_model_path = "https://huggingface.co/lllyasviel/Annotators/resolve/main/ControlNetLama.pth"
|
24 |
+
modelpath = os.path.join(self.model_dir, "ControlNetLama.pth")
|
25 |
+
if not os.path.exists(modelpath):
|
26 |
+
from basicsr.utils.download_util import load_file_from_url
|
27 |
+
load_file_from_url(remote_model_path, model_dir=self.model_dir)
|
28 |
+
config_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'config.yaml')
|
29 |
+
cfg = yaml.safe_load(open(config_path, 'rt'))
|
30 |
+
cfg = OmegaConf.create(cfg)
|
31 |
+
cfg.training_model.predict_only = True
|
32 |
+
cfg.visualizer.kind = 'noop'
|
33 |
+
self.model = load_checkpoint(cfg, os.path.abspath(modelpath), strict=False, map_location='cpu')
|
34 |
+
self.model = self.model.to(self.device)
|
35 |
+
self.model.eval()
|
36 |
+
|
37 |
+
def unload_model(self):
|
38 |
+
if self.model is not None:
|
39 |
+
self.model.cpu()
|
40 |
+
|
41 |
+
def __call__(self, input_image):
|
42 |
+
if self.model is None:
|
43 |
+
self.load_model()
|
44 |
+
self.model.to(self.device)
|
45 |
+
color = np.ascontiguousarray(input_image[:, :, 0:3]).astype(np.float32) / 255.0
|
46 |
+
mask = np.ascontiguousarray(input_image[:, :, 3:4]).astype(np.float32) / 255.0
|
47 |
+
with torch.no_grad():
|
48 |
+
color = torch.from_numpy(color).float().to(self.device)
|
49 |
+
mask = torch.from_numpy(mask).float().to(self.device)
|
50 |
+
mask = (mask > 0.5).float()
|
51 |
+
color = color * (1 - mask)
|
52 |
+
image_feed = torch.cat([color, mask], dim=2)
|
53 |
+
image_feed = rearrange(image_feed, 'h w c -> 1 c h w')
|
54 |
+
result = self.model(image_feed)[0]
|
55 |
+
result = rearrange(result, 'c h w -> h w c')
|
56 |
+
result = result * mask + color * (1 - mask)
|
57 |
+
result *= 255.0
|
58 |
+
return result.detach().cpu().numpy().clip(0, 255).astype(np.uint8)
|
annotator/lama/config.yaml
ADDED
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
run_title: b18_ffc075_batch8x15
|
2 |
+
training_model:
|
3 |
+
kind: default
|
4 |
+
visualize_each_iters: 1000
|
5 |
+
concat_mask: true
|
6 |
+
store_discr_outputs_for_vis: true
|
7 |
+
losses:
|
8 |
+
l1:
|
9 |
+
weight_missing: 0
|
10 |
+
weight_known: 10
|
11 |
+
perceptual:
|
12 |
+
weight: 0
|
13 |
+
adversarial:
|
14 |
+
kind: r1
|
15 |
+
weight: 10
|
16 |
+
gp_coef: 0.001
|
17 |
+
mask_as_fake_target: true
|
18 |
+
allow_scale_mask: true
|
19 |
+
feature_matching:
|
20 |
+
weight: 100
|
21 |
+
resnet_pl:
|
22 |
+
weight: 30
|
23 |
+
weights_path: ${env:TORCH_HOME}
|
24 |
+
|
25 |
+
optimizers:
|
26 |
+
generator:
|
27 |
+
kind: adam
|
28 |
+
lr: 0.001
|
29 |
+
discriminator:
|
30 |
+
kind: adam
|
31 |
+
lr: 0.0001
|
32 |
+
visualizer:
|
33 |
+
key_order:
|
34 |
+
- image
|
35 |
+
- predicted_image
|
36 |
+
- discr_output_fake
|
37 |
+
- discr_output_real
|
38 |
+
- inpainted
|
39 |
+
rescale_keys:
|
40 |
+
- discr_output_fake
|
41 |
+
- discr_output_real
|
42 |
+
kind: directory
|
43 |
+
outdir: /group-volume/User-Driven-Content-Generation/r.suvorov/inpainting/experiments/r.suvorov_2021-04-30_14-41-12_train_simple_pix2pix2_gap_sdpl_novgg_large_b18_ffc075_batch8x15/samples
|
44 |
+
location:
|
45 |
+
data_root_dir: /group-volume/User-Driven-Content-Generation/datasets/inpainting_data_root_large
|
46 |
+
out_root_dir: /group-volume/User-Driven-Content-Generation/${env:USER}/inpainting/experiments
|
47 |
+
tb_dir: /group-volume/User-Driven-Content-Generation/${env:USER}/inpainting/tb_logs
|
48 |
+
data:
|
49 |
+
batch_size: 15
|
50 |
+
val_batch_size: 2
|
51 |
+
num_workers: 3
|
52 |
+
train:
|
53 |
+
indir: ${location.data_root_dir}/train
|
54 |
+
out_size: 256
|
55 |
+
mask_gen_kwargs:
|
56 |
+
irregular_proba: 1
|
57 |
+
irregular_kwargs:
|
58 |
+
max_angle: 4
|
59 |
+
max_len: 200
|
60 |
+
max_width: 100
|
61 |
+
max_times: 5
|
62 |
+
min_times: 1
|
63 |
+
box_proba: 1
|
64 |
+
box_kwargs:
|
65 |
+
margin: 10
|
66 |
+
bbox_min_size: 30
|
67 |
+
bbox_max_size: 150
|
68 |
+
max_times: 3
|
69 |
+
min_times: 1
|
70 |
+
segm_proba: 0
|
71 |
+
segm_kwargs:
|
72 |
+
confidence_threshold: 0.5
|
73 |
+
max_object_area: 0.5
|
74 |
+
min_mask_area: 0.07
|
75 |
+
downsample_levels: 6
|
76 |
+
num_variants_per_mask: 1
|
77 |
+
rigidness_mode: 1
|
78 |
+
max_foreground_coverage: 0.3
|
79 |
+
max_foreground_intersection: 0.7
|
80 |
+
max_mask_intersection: 0.1
|
81 |
+
max_hidden_area: 0.1
|
82 |
+
max_scale_change: 0.25
|
83 |
+
horizontal_flip: true
|
84 |
+
max_vertical_shift: 0.2
|
85 |
+
position_shuffle: true
|
86 |
+
transform_variant: distortions
|
87 |
+
dataloader_kwargs:
|
88 |
+
batch_size: ${data.batch_size}
|
89 |
+
shuffle: true
|
90 |
+
num_workers: ${data.num_workers}
|
91 |
+
val:
|
92 |
+
indir: ${location.data_root_dir}/val
|
93 |
+
img_suffix: .png
|
94 |
+
dataloader_kwargs:
|
95 |
+
batch_size: ${data.val_batch_size}
|
96 |
+
shuffle: false
|
97 |
+
num_workers: ${data.num_workers}
|
98 |
+
visual_test:
|
99 |
+
indir: ${location.data_root_dir}/korean_test
|
100 |
+
img_suffix: _input.png
|
101 |
+
pad_out_to_modulo: 32
|
102 |
+
dataloader_kwargs:
|
103 |
+
batch_size: 1
|
104 |
+
shuffle: false
|
105 |
+
num_workers: ${data.num_workers}
|
106 |
+
generator:
|
107 |
+
kind: ffc_resnet
|
108 |
+
input_nc: 4
|
109 |
+
output_nc: 3
|
110 |
+
ngf: 64
|
111 |
+
n_downsampling: 3
|
112 |
+
n_blocks: 18
|
113 |
+
add_out_act: sigmoid
|
114 |
+
init_conv_kwargs:
|
115 |
+
ratio_gin: 0
|
116 |
+
ratio_gout: 0
|
117 |
+
enable_lfu: false
|
118 |
+
downsample_conv_kwargs:
|
119 |
+
ratio_gin: ${generator.init_conv_kwargs.ratio_gout}
|
120 |
+
ratio_gout: ${generator.downsample_conv_kwargs.ratio_gin}
|
121 |
+
enable_lfu: false
|
122 |
+
resnet_conv_kwargs:
|
123 |
+
ratio_gin: 0.75
|
124 |
+
ratio_gout: ${generator.resnet_conv_kwargs.ratio_gin}
|
125 |
+
enable_lfu: false
|
126 |
+
discriminator:
|
127 |
+
kind: pix2pixhd_nlayer
|
128 |
+
input_nc: 3
|
129 |
+
ndf: 64
|
130 |
+
n_layers: 4
|
131 |
+
evaluator:
|
132 |
+
kind: default
|
133 |
+
inpainted_key: inpainted
|
134 |
+
integral_kind: ssim_fid100_f1
|
135 |
+
trainer:
|
136 |
+
kwargs:
|
137 |
+
gpus: -1
|
138 |
+
accelerator: ddp
|
139 |
+
max_epochs: 200
|
140 |
+
gradient_clip_val: 1
|
141 |
+
log_gpu_memory: None
|
142 |
+
limit_train_batches: 25000
|
143 |
+
val_check_interval: ${trainer.kwargs.limit_train_batches}
|
144 |
+
log_every_n_steps: 1000
|
145 |
+
precision: 32
|
146 |
+
terminate_on_nan: false
|
147 |
+
check_val_every_n_epoch: 1
|
148 |
+
num_sanity_val_steps: 8
|
149 |
+
limit_val_batches: 1000
|
150 |
+
replace_sampler_ddp: false
|
151 |
+
checkpoint_kwargs:
|
152 |
+
verbose: true
|
153 |
+
save_top_k: 5
|
154 |
+
save_last: true
|
155 |
+
period: 1
|
156 |
+
monitor: val_ssim_fid100_f1_total_mean
|
157 |
+
mode: max
|
annotator/lama/saicinpainting/__init__.py
ADDED
File without changes
|
annotator/lama/saicinpainting/training/__init__.py
ADDED
File without changes
|
annotator/lama/saicinpainting/training/data/__init__.py
ADDED
File without changes
|
annotator/lama/saicinpainting/training/data/masks.py
ADDED
@@ -0,0 +1,332 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import random
|
3 |
+
import hashlib
|
4 |
+
import logging
|
5 |
+
from enum import Enum
|
6 |
+
|
7 |
+
import cv2
|
8 |
+
import numpy as np
|
9 |
+
|
10 |
+
# from annotator.lama.saicinpainting.evaluation.masks.mask import SegmentationMask
|
11 |
+
from annotator.lama.saicinpainting.utils import LinearRamp
|
12 |
+
|
13 |
+
LOGGER = logging.getLogger(__name__)
|
14 |
+
|
15 |
+
|
16 |
+
class DrawMethod(Enum):
|
17 |
+
LINE = 'line'
|
18 |
+
CIRCLE = 'circle'
|
19 |
+
SQUARE = 'square'
|
20 |
+
|
21 |
+
|
22 |
+
def make_random_irregular_mask(shape, max_angle=4, max_len=60, max_width=20, min_times=0, max_times=10,
|
23 |
+
draw_method=DrawMethod.LINE):
|
24 |
+
draw_method = DrawMethod(draw_method)
|
25 |
+
|
26 |
+
height, width = shape
|
27 |
+
mask = np.zeros((height, width), np.float32)
|
28 |
+
times = np.random.randint(min_times, max_times + 1)
|
29 |
+
for i in range(times):
|
30 |
+
start_x = np.random.randint(width)
|
31 |
+
start_y = np.random.randint(height)
|
32 |
+
for j in range(1 + np.random.randint(5)):
|
33 |
+
angle = 0.01 + np.random.randint(max_angle)
|
34 |
+
if i % 2 == 0:
|
35 |
+
angle = 2 * 3.1415926 - angle
|
36 |
+
length = 10 + np.random.randint(max_len)
|
37 |
+
brush_w = 5 + np.random.randint(max_width)
|
38 |
+
end_x = np.clip((start_x + length * np.sin(angle)).astype(np.int32), 0, width)
|
39 |
+
end_y = np.clip((start_y + length * np.cos(angle)).astype(np.int32), 0, height)
|
40 |
+
if draw_method == DrawMethod.LINE:
|
41 |
+
cv2.line(mask, (start_x, start_y), (end_x, end_y), 1.0, brush_w)
|
42 |
+
elif draw_method == DrawMethod.CIRCLE:
|
43 |
+
cv2.circle(mask, (start_x, start_y), radius=brush_w, color=1., thickness=-1)
|
44 |
+
elif draw_method == DrawMethod.SQUARE:
|
45 |
+
radius = brush_w // 2
|
46 |
+
mask[start_y - radius:start_y + radius, start_x - radius:start_x + radius] = 1
|
47 |
+
start_x, start_y = end_x, end_y
|
48 |
+
return mask[None, ...]
|
49 |
+
|
50 |
+
|
51 |
+
class RandomIrregularMaskGenerator:
|
52 |
+
def __init__(self, max_angle=4, max_len=60, max_width=20, min_times=0, max_times=10, ramp_kwargs=None,
|
53 |
+
draw_method=DrawMethod.LINE):
|
54 |
+
self.max_angle = max_angle
|
55 |
+
self.max_len = max_len
|
56 |
+
self.max_width = max_width
|
57 |
+
self.min_times = min_times
|
58 |
+
self.max_times = max_times
|
59 |
+
self.draw_method = draw_method
|
60 |
+
self.ramp = LinearRamp(**ramp_kwargs) if ramp_kwargs is not None else None
|
61 |
+
|
62 |
+
def __call__(self, img, iter_i=None, raw_image=None):
|
63 |
+
coef = self.ramp(iter_i) if (self.ramp is not None) and (iter_i is not None) else 1
|
64 |
+
cur_max_len = int(max(1, self.max_len * coef))
|
65 |
+
cur_max_width = int(max(1, self.max_width * coef))
|
66 |
+
cur_max_times = int(self.min_times + 1 + (self.max_times - self.min_times) * coef)
|
67 |
+
return make_random_irregular_mask(img.shape[1:], max_angle=self.max_angle, max_len=cur_max_len,
|
68 |
+
max_width=cur_max_width, min_times=self.min_times, max_times=cur_max_times,
|
69 |
+
draw_method=self.draw_method)
|
70 |
+
|
71 |
+
|
72 |
+
def make_random_rectangle_mask(shape, margin=10, bbox_min_size=30, bbox_max_size=100, min_times=0, max_times=3):
|
73 |
+
height, width = shape
|
74 |
+
mask = np.zeros((height, width), np.float32)
|
75 |
+
bbox_max_size = min(bbox_max_size, height - margin * 2, width - margin * 2)
|
76 |
+
times = np.random.randint(min_times, max_times + 1)
|
77 |
+
for i in range(times):
|
78 |
+
box_width = np.random.randint(bbox_min_size, bbox_max_size)
|
79 |
+
box_height = np.random.randint(bbox_min_size, bbox_max_size)
|
80 |
+
start_x = np.random.randint(margin, width - margin - box_width + 1)
|
81 |
+
start_y = np.random.randint(margin, height - margin - box_height + 1)
|
82 |
+
mask[start_y:start_y + box_height, start_x:start_x + box_width] = 1
|
83 |
+
return mask[None, ...]
|
84 |
+
|
85 |
+
|
86 |
+
class RandomRectangleMaskGenerator:
|
87 |
+
def __init__(self, margin=10, bbox_min_size=30, bbox_max_size=100, min_times=0, max_times=3, ramp_kwargs=None):
|
88 |
+
self.margin = margin
|
89 |
+
self.bbox_min_size = bbox_min_size
|
90 |
+
self.bbox_max_size = bbox_max_size
|
91 |
+
self.min_times = min_times
|
92 |
+
self.max_times = max_times
|
93 |
+
self.ramp = LinearRamp(**ramp_kwargs) if ramp_kwargs is not None else None
|
94 |
+
|
95 |
+
def __call__(self, img, iter_i=None, raw_image=None):
|
96 |
+
coef = self.ramp(iter_i) if (self.ramp is not None) and (iter_i is not None) else 1
|
97 |
+
cur_bbox_max_size = int(self.bbox_min_size + 1 + (self.bbox_max_size - self.bbox_min_size) * coef)
|
98 |
+
cur_max_times = int(self.min_times + (self.max_times - self.min_times) * coef)
|
99 |
+
return make_random_rectangle_mask(img.shape[1:], margin=self.margin, bbox_min_size=self.bbox_min_size,
|
100 |
+
bbox_max_size=cur_bbox_max_size, min_times=self.min_times,
|
101 |
+
max_times=cur_max_times)
|
102 |
+
|
103 |
+
|
104 |
+
class RandomSegmentationMaskGenerator:
|
105 |
+
def __init__(self, **kwargs):
|
106 |
+
self.impl = None # will be instantiated in first call (effectively in subprocess)
|
107 |
+
self.kwargs = kwargs
|
108 |
+
|
109 |
+
def __call__(self, img, iter_i=None, raw_image=None):
|
110 |
+
if self.impl is None:
|
111 |
+
self.impl = SegmentationMask(**self.kwargs)
|
112 |
+
|
113 |
+
masks = self.impl.get_masks(np.transpose(img, (1, 2, 0)))
|
114 |
+
masks = [m for m in masks if len(np.unique(m)) > 1]
|
115 |
+
return np.random.choice(masks)
|
116 |
+
|
117 |
+
|
118 |
+
def make_random_superres_mask(shape, min_step=2, max_step=4, min_width=1, max_width=3):
|
119 |
+
height, width = shape
|
120 |
+
mask = np.zeros((height, width), np.float32)
|
121 |
+
step_x = np.random.randint(min_step, max_step + 1)
|
122 |
+
width_x = np.random.randint(min_width, min(step_x, max_width + 1))
|
123 |
+
offset_x = np.random.randint(0, step_x)
|
124 |
+
|
125 |
+
step_y = np.random.randint(min_step, max_step + 1)
|
126 |
+
width_y = np.random.randint(min_width, min(step_y, max_width + 1))
|
127 |
+
offset_y = np.random.randint(0, step_y)
|
128 |
+
|
129 |
+
for dy in range(width_y):
|
130 |
+
mask[offset_y + dy::step_y] = 1
|
131 |
+
for dx in range(width_x):
|
132 |
+
mask[:, offset_x + dx::step_x] = 1
|
133 |
+
return mask[None, ...]
|
134 |
+
|
135 |
+
|
136 |
+
class RandomSuperresMaskGenerator:
|
137 |
+
def __init__(self, **kwargs):
|
138 |
+
self.kwargs = kwargs
|
139 |
+
|
140 |
+
def __call__(self, img, iter_i=None):
|
141 |
+
return make_random_superres_mask(img.shape[1:], **self.kwargs)
|
142 |
+
|
143 |
+
|
144 |
+
class DumbAreaMaskGenerator:
|
145 |
+
min_ratio = 0.1
|
146 |
+
max_ratio = 0.35
|
147 |
+
default_ratio = 0.225
|
148 |
+
|
149 |
+
def __init__(self, is_training):
|
150 |
+
#Parameters:
|
151 |
+
# is_training(bool): If true - random rectangular mask, if false - central square mask
|
152 |
+
self.is_training = is_training
|
153 |
+
|
154 |
+
def _random_vector(self, dimension):
|
155 |
+
if self.is_training:
|
156 |
+
lower_limit = math.sqrt(self.min_ratio)
|
157 |
+
upper_limit = math.sqrt(self.max_ratio)
|
158 |
+
mask_side = round((random.random() * (upper_limit - lower_limit) + lower_limit) * dimension)
|
159 |
+
u = random.randint(0, dimension-mask_side-1)
|
160 |
+
v = u+mask_side
|
161 |
+
else:
|
162 |
+
margin = (math.sqrt(self.default_ratio) / 2) * dimension
|
163 |
+
u = round(dimension/2 - margin)
|
164 |
+
v = round(dimension/2 + margin)
|
165 |
+
return u, v
|
166 |
+
|
167 |
+
def __call__(self, img, iter_i=None, raw_image=None):
|
168 |
+
c, height, width = img.shape
|
169 |
+
mask = np.zeros((height, width), np.float32)
|
170 |
+
x1, x2 = self._random_vector(width)
|
171 |
+
y1, y2 = self._random_vector(height)
|
172 |
+
mask[x1:x2, y1:y2] = 1
|
173 |
+
return mask[None, ...]
|
174 |
+
|
175 |
+
|
176 |
+
class OutpaintingMaskGenerator:
|
177 |
+
def __init__(self, min_padding_percent:float=0.04, max_padding_percent:int=0.25, left_padding_prob:float=0.5, top_padding_prob:float=0.5,
|
178 |
+
right_padding_prob:float=0.5, bottom_padding_prob:float=0.5, is_fixed_randomness:bool=False):
|
179 |
+
"""
|
180 |
+
is_fixed_randomness - get identical paddings for the same image if args are the same
|
181 |
+
"""
|
182 |
+
self.min_padding_percent = min_padding_percent
|
183 |
+
self.max_padding_percent = max_padding_percent
|
184 |
+
self.probs = [left_padding_prob, top_padding_prob, right_padding_prob, bottom_padding_prob]
|
185 |
+
self.is_fixed_randomness = is_fixed_randomness
|
186 |
+
|
187 |
+
assert self.min_padding_percent <= self.max_padding_percent
|
188 |
+
assert self.max_padding_percent > 0
|
189 |
+
assert len([x for x in [self.min_padding_percent, self.max_padding_percent] if (x>=0 and x<=1)]) == 2, f"Padding percentage should be in [0,1]"
|
190 |
+
assert sum(self.probs) > 0, f"At least one of the padding probs should be greater than 0 - {self.probs}"
|
191 |
+
assert len([x for x in self.probs if (x >= 0) and (x <= 1)]) == 4, f"At least one of padding probs is not in [0,1] - {self.probs}"
|
192 |
+
if len([x for x in self.probs if x > 0]) == 1:
|
193 |
+
LOGGER.warning(f"Only one padding prob is greater than zero - {self.probs}. That means that the outpainting masks will be always on the same side")
|
194 |
+
|
195 |
+
def apply_padding(self, mask, coord):
|
196 |
+
mask[int(coord[0][0]*self.img_h):int(coord[1][0]*self.img_h),
|
197 |
+
int(coord[0][1]*self.img_w):int(coord[1][1]*self.img_w)] = 1
|
198 |
+
return mask
|
199 |
+
|
200 |
+
def get_padding(self, size):
|
201 |
+
n1 = int(self.min_padding_percent*size)
|
202 |
+
n2 = int(self.max_padding_percent*size)
|
203 |
+
return self.rnd.randint(n1, n2) / size
|
204 |
+
|
205 |
+
@staticmethod
|
206 |
+
def _img2rs(img):
|
207 |
+
arr = np.ascontiguousarray(img.astype(np.uint8))
|
208 |
+
str_hash = hashlib.sha1(arr).hexdigest()
|
209 |
+
res = hash(str_hash)%(2**32)
|
210 |
+
return res
|
211 |
+
|
212 |
+
def __call__(self, img, iter_i=None, raw_image=None):
|
213 |
+
c, self.img_h, self.img_w = img.shape
|
214 |
+
mask = np.zeros((self.img_h, self.img_w), np.float32)
|
215 |
+
at_least_one_mask_applied = False
|
216 |
+
|
217 |
+
if self.is_fixed_randomness:
|
218 |
+
assert raw_image is not None, f"Cant calculate hash on raw_image=None"
|
219 |
+
rs = self._img2rs(raw_image)
|
220 |
+
self.rnd = np.random.RandomState(rs)
|
221 |
+
else:
|
222 |
+
self.rnd = np.random
|
223 |
+
|
224 |
+
coords = [[
|
225 |
+
(0,0),
|
226 |
+
(1,self.get_padding(size=self.img_h))
|
227 |
+
],
|
228 |
+
[
|
229 |
+
(0,0),
|
230 |
+
(self.get_padding(size=self.img_w),1)
|
231 |
+
],
|
232 |
+
[
|
233 |
+
(0,1-self.get_padding(size=self.img_h)),
|
234 |
+
(1,1)
|
235 |
+
],
|
236 |
+
[
|
237 |
+
(1-self.get_padding(size=self.img_w),0),
|
238 |
+
(1,1)
|
239 |
+
]]
|
240 |
+
|
241 |
+
for pp, coord in zip(self.probs, coords):
|
242 |
+
if self.rnd.random() < pp:
|
243 |
+
at_least_one_mask_applied = True
|
244 |
+
mask = self.apply_padding(mask=mask, coord=coord)
|
245 |
+
|
246 |
+
if not at_least_one_mask_applied:
|
247 |
+
idx = self.rnd.choice(range(len(coords)), p=np.array(self.probs)/sum(self.probs))
|
248 |
+
mask = self.apply_padding(mask=mask, coord=coords[idx])
|
249 |
+
return mask[None, ...]
|
250 |
+
|
251 |
+
|
252 |
+
class MixedMaskGenerator:
|
253 |
+
def __init__(self, irregular_proba=1/3, irregular_kwargs=None,
|
254 |
+
box_proba=1/3, box_kwargs=None,
|
255 |
+
segm_proba=1/3, segm_kwargs=None,
|
256 |
+
squares_proba=0, squares_kwargs=None,
|
257 |
+
superres_proba=0, superres_kwargs=None,
|
258 |
+
outpainting_proba=0, outpainting_kwargs=None,
|
259 |
+
invert_proba=0):
|
260 |
+
self.probas = []
|
261 |
+
self.gens = []
|
262 |
+
|
263 |
+
if irregular_proba > 0:
|
264 |
+
self.probas.append(irregular_proba)
|
265 |
+
if irregular_kwargs is None:
|
266 |
+
irregular_kwargs = {}
|
267 |
+
else:
|
268 |
+
irregular_kwargs = dict(irregular_kwargs)
|
269 |
+
irregular_kwargs['draw_method'] = DrawMethod.LINE
|
270 |
+
self.gens.append(RandomIrregularMaskGenerator(**irregular_kwargs))
|
271 |
+
|
272 |
+
if box_proba > 0:
|
273 |
+
self.probas.append(box_proba)
|
274 |
+
if box_kwargs is None:
|
275 |
+
box_kwargs = {}
|
276 |
+
self.gens.append(RandomRectangleMaskGenerator(**box_kwargs))
|
277 |
+
|
278 |
+
if segm_proba > 0:
|
279 |
+
self.probas.append(segm_proba)
|
280 |
+
if segm_kwargs is None:
|
281 |
+
segm_kwargs = {}
|
282 |
+
self.gens.append(RandomSegmentationMaskGenerator(**segm_kwargs))
|
283 |
+
|
284 |
+
if squares_proba > 0:
|
285 |
+
self.probas.append(squares_proba)
|
286 |
+
if squares_kwargs is None:
|
287 |
+
squares_kwargs = {}
|
288 |
+
else:
|
289 |
+
squares_kwargs = dict(squares_kwargs)
|
290 |
+
squares_kwargs['draw_method'] = DrawMethod.SQUARE
|
291 |
+
self.gens.append(RandomIrregularMaskGenerator(**squares_kwargs))
|
292 |
+
|
293 |
+
if superres_proba > 0:
|
294 |
+
self.probas.append(superres_proba)
|
295 |
+
if superres_kwargs is None:
|
296 |
+
superres_kwargs = {}
|
297 |
+
self.gens.append(RandomSuperresMaskGenerator(**superres_kwargs))
|
298 |
+
|
299 |
+
if outpainting_proba > 0:
|
300 |
+
self.probas.append(outpainting_proba)
|
301 |
+
if outpainting_kwargs is None:
|
302 |
+
outpainting_kwargs = {}
|
303 |
+
self.gens.append(OutpaintingMaskGenerator(**outpainting_kwargs))
|
304 |
+
|
305 |
+
self.probas = np.array(self.probas, dtype='float32')
|
306 |
+
self.probas /= self.probas.sum()
|
307 |
+
self.invert_proba = invert_proba
|
308 |
+
|
309 |
+
def __call__(self, img, iter_i=None, raw_image=None):
|
310 |
+
kind = np.random.choice(len(self.probas), p=self.probas)
|
311 |
+
gen = self.gens[kind]
|
312 |
+
result = gen(img, iter_i=iter_i, raw_image=raw_image)
|
313 |
+
if self.invert_proba > 0 and random.random() < self.invert_proba:
|
314 |
+
result = 1 - result
|
315 |
+
return result
|
316 |
+
|
317 |
+
|
318 |
+
def get_mask_generator(kind, kwargs):
|
319 |
+
if kind is None:
|
320 |
+
kind = "mixed"
|
321 |
+
if kwargs is None:
|
322 |
+
kwargs = {}
|
323 |
+
|
324 |
+
if kind == "mixed":
|
325 |
+
cl = MixedMaskGenerator
|
326 |
+
elif kind == "outpainting":
|
327 |
+
cl = OutpaintingMaskGenerator
|
328 |
+
elif kind == "dumb":
|
329 |
+
cl = DumbAreaMaskGenerator
|
330 |
+
else:
|
331 |
+
raise NotImplementedError(f"No such generator kind = {kind}")
|
332 |
+
return cl(**kwargs)
|
annotator/lama/saicinpainting/training/losses/__init__.py
ADDED
File without changes
|
annotator/lama/saicinpainting/training/losses/adversarial.py
ADDED
@@ -0,0 +1,177 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Tuple, Dict, Optional
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
|
7 |
+
|
8 |
+
class BaseAdversarialLoss:
|
9 |
+
def pre_generator_step(self, real_batch: torch.Tensor, fake_batch: torch.Tensor,
|
10 |
+
generator: nn.Module, discriminator: nn.Module):
|
11 |
+
"""
|
12 |
+
Prepare for generator step
|
13 |
+
:param real_batch: Tensor, a batch of real samples
|
14 |
+
:param fake_batch: Tensor, a batch of samples produced by generator
|
15 |
+
:param generator:
|
16 |
+
:param discriminator:
|
17 |
+
:return: None
|
18 |
+
"""
|
19 |
+
|
20 |
+
def pre_discriminator_step(self, real_batch: torch.Tensor, fake_batch: torch.Tensor,
|
21 |
+
generator: nn.Module, discriminator: nn.Module):
|
22 |
+
"""
|
23 |
+
Prepare for discriminator step
|
24 |
+
:param real_batch: Tensor, a batch of real samples
|
25 |
+
:param fake_batch: Tensor, a batch of samples produced by generator
|
26 |
+
:param generator:
|
27 |
+
:param discriminator:
|
28 |
+
:return: None
|
29 |
+
"""
|
30 |
+
|
31 |
+
def generator_loss(self, real_batch: torch.Tensor, fake_batch: torch.Tensor,
|
32 |
+
discr_real_pred: torch.Tensor, discr_fake_pred: torch.Tensor,
|
33 |
+
mask: Optional[torch.Tensor] = None) \
|
34 |
+
-> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
|
35 |
+
"""
|
36 |
+
Calculate generator loss
|
37 |
+
:param real_batch: Tensor, a batch of real samples
|
38 |
+
:param fake_batch: Tensor, a batch of samples produced by generator
|
39 |
+
:param discr_real_pred: Tensor, discriminator output for real_batch
|
40 |
+
:param discr_fake_pred: Tensor, discriminator output for fake_batch
|
41 |
+
:param mask: Tensor, actual mask, which was at input of generator when making fake_batch
|
42 |
+
:return: total generator loss along with some values that might be interesting to log
|
43 |
+
"""
|
44 |
+
raise NotImplemented()
|
45 |
+
|
46 |
+
def discriminator_loss(self, real_batch: torch.Tensor, fake_batch: torch.Tensor,
|
47 |
+
discr_real_pred: torch.Tensor, discr_fake_pred: torch.Tensor,
|
48 |
+
mask: Optional[torch.Tensor] = None) \
|
49 |
+
-> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
|
50 |
+
"""
|
51 |
+
Calculate discriminator loss and call .backward() on it
|
52 |
+
:param real_batch: Tensor, a batch of real samples
|
53 |
+
:param fake_batch: Tensor, a batch of samples produced by generator
|
54 |
+
:param discr_real_pred: Tensor, discriminator output for real_batch
|
55 |
+
:param discr_fake_pred: Tensor, discriminator output for fake_batch
|
56 |
+
:param mask: Tensor, actual mask, which was at input of generator when making fake_batch
|
57 |
+
:return: total discriminator loss along with some values that might be interesting to log
|
58 |
+
"""
|
59 |
+
raise NotImplemented()
|
60 |
+
|
61 |
+
def interpolate_mask(self, mask, shape):
|
62 |
+
assert mask is not None
|
63 |
+
assert self.allow_scale_mask or shape == mask.shape[-2:]
|
64 |
+
if shape != mask.shape[-2:] and self.allow_scale_mask:
|
65 |
+
if self.mask_scale_mode == 'maxpool':
|
66 |
+
mask = F.adaptive_max_pool2d(mask, shape)
|
67 |
+
else:
|
68 |
+
mask = F.interpolate(mask, size=shape, mode=self.mask_scale_mode)
|
69 |
+
return mask
|
70 |
+
|
71 |
+
def make_r1_gp(discr_real_pred, real_batch):
|
72 |
+
if torch.is_grad_enabled():
|
73 |
+
grad_real = torch.autograd.grad(outputs=discr_real_pred.sum(), inputs=real_batch, create_graph=True)[0]
|
74 |
+
grad_penalty = (grad_real.view(grad_real.shape[0], -1).norm(2, dim=1) ** 2).mean()
|
75 |
+
else:
|
76 |
+
grad_penalty = 0
|
77 |
+
real_batch.requires_grad = False
|
78 |
+
|
79 |
+
return grad_penalty
|
80 |
+
|
81 |
+
class NonSaturatingWithR1(BaseAdversarialLoss):
|
82 |
+
def __init__(self, gp_coef=5, weight=1, mask_as_fake_target=False, allow_scale_mask=False,
|
83 |
+
mask_scale_mode='nearest', extra_mask_weight_for_gen=0,
|
84 |
+
use_unmasked_for_gen=True, use_unmasked_for_discr=True):
|
85 |
+
self.gp_coef = gp_coef
|
86 |
+
self.weight = weight
|
87 |
+
# use for discr => use for gen;
|
88 |
+
# otherwise we teach only the discr to pay attention to very small difference
|
89 |
+
assert use_unmasked_for_gen or (not use_unmasked_for_discr)
|
90 |
+
# mask as target => use unmasked for discr:
|
91 |
+
# if we don't care about unmasked regions at all
|
92 |
+
# then it doesn't matter if the value of mask_as_fake_target is true or false
|
93 |
+
assert use_unmasked_for_discr or (not mask_as_fake_target)
|
94 |
+
self.use_unmasked_for_gen = use_unmasked_for_gen
|
95 |
+
self.use_unmasked_for_discr = use_unmasked_for_discr
|
96 |
+
self.mask_as_fake_target = mask_as_fake_target
|
97 |
+
self.allow_scale_mask = allow_scale_mask
|
98 |
+
self.mask_scale_mode = mask_scale_mode
|
99 |
+
self.extra_mask_weight_for_gen = extra_mask_weight_for_gen
|
100 |
+
|
101 |
+
def generator_loss(self, real_batch: torch.Tensor, fake_batch: torch.Tensor,
|
102 |
+
discr_real_pred: torch.Tensor, discr_fake_pred: torch.Tensor,
|
103 |
+
mask=None) \
|
104 |
+
-> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
|
105 |
+
fake_loss = F.softplus(-discr_fake_pred)
|
106 |
+
if (self.mask_as_fake_target and self.extra_mask_weight_for_gen > 0) or \
|
107 |
+
not self.use_unmasked_for_gen: # == if masked region should be treated differently
|
108 |
+
mask = self.interpolate_mask(mask, discr_fake_pred.shape[-2:])
|
109 |
+
if not self.use_unmasked_for_gen:
|
110 |
+
fake_loss = fake_loss * mask
|
111 |
+
else:
|
112 |
+
pixel_weights = 1 + mask * self.extra_mask_weight_for_gen
|
113 |
+
fake_loss = fake_loss * pixel_weights
|
114 |
+
|
115 |
+
return fake_loss.mean() * self.weight, dict()
|
116 |
+
|
117 |
+
def pre_discriminator_step(self, real_batch: torch.Tensor, fake_batch: torch.Tensor,
|
118 |
+
generator: nn.Module, discriminator: nn.Module):
|
119 |
+
real_batch.requires_grad = True
|
120 |
+
|
121 |
+
def discriminator_loss(self, real_batch: torch.Tensor, fake_batch: torch.Tensor,
|
122 |
+
discr_real_pred: torch.Tensor, discr_fake_pred: torch.Tensor,
|
123 |
+
mask=None) \
|
124 |
+
-> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
|
125 |
+
|
126 |
+
real_loss = F.softplus(-discr_real_pred)
|
127 |
+
grad_penalty = make_r1_gp(discr_real_pred, real_batch) * self.gp_coef
|
128 |
+
fake_loss = F.softplus(discr_fake_pred)
|
129 |
+
|
130 |
+
if not self.use_unmasked_for_discr or self.mask_as_fake_target:
|
131 |
+
# == if masked region should be treated differently
|
132 |
+
mask = self.interpolate_mask(mask, discr_fake_pred.shape[-2:])
|
133 |
+
# use_unmasked_for_discr=False only makes sense for fakes;
|
134 |
+
# for reals there is no difference beetween two regions
|
135 |
+
fake_loss = fake_loss * mask
|
136 |
+
if self.mask_as_fake_target:
|
137 |
+
fake_loss = fake_loss + (1 - mask) * F.softplus(-discr_fake_pred)
|
138 |
+
|
139 |
+
sum_discr_loss = real_loss + grad_penalty + fake_loss
|
140 |
+
metrics = dict(discr_real_out=discr_real_pred.mean(),
|
141 |
+
discr_fake_out=discr_fake_pred.mean(),
|
142 |
+
discr_real_gp=grad_penalty)
|
143 |
+
return sum_discr_loss.mean(), metrics
|
144 |
+
|
145 |
+
class BCELoss(BaseAdversarialLoss):
|
146 |
+
def __init__(self, weight):
|
147 |
+
self.weight = weight
|
148 |
+
self.bce_loss = nn.BCEWithLogitsLoss()
|
149 |
+
|
150 |
+
def generator_loss(self, discr_fake_pred: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
|
151 |
+
real_mask_gt = torch.zeros(discr_fake_pred.shape).to(discr_fake_pred.device)
|
152 |
+
fake_loss = self.bce_loss(discr_fake_pred, real_mask_gt) * self.weight
|
153 |
+
return fake_loss, dict()
|
154 |
+
|
155 |
+
def pre_discriminator_step(self, real_batch: torch.Tensor, fake_batch: torch.Tensor,
|
156 |
+
generator: nn.Module, discriminator: nn.Module):
|
157 |
+
real_batch.requires_grad = True
|
158 |
+
|
159 |
+
def discriminator_loss(self,
|
160 |
+
mask: torch.Tensor,
|
161 |
+
discr_real_pred: torch.Tensor,
|
162 |
+
discr_fake_pred: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
|
163 |
+
|
164 |
+
real_mask_gt = torch.zeros(discr_real_pred.shape).to(discr_real_pred.device)
|
165 |
+
sum_discr_loss = (self.bce_loss(discr_real_pred, real_mask_gt) + self.bce_loss(discr_fake_pred, mask)) / 2
|
166 |
+
metrics = dict(discr_real_out=discr_real_pred.mean(),
|
167 |
+
discr_fake_out=discr_fake_pred.mean(),
|
168 |
+
discr_real_gp=0)
|
169 |
+
return sum_discr_loss, metrics
|
170 |
+
|
171 |
+
|
172 |
+
def make_discrim_loss(kind, **kwargs):
|
173 |
+
if kind == 'r1':
|
174 |
+
return NonSaturatingWithR1(**kwargs)
|
175 |
+
elif kind == 'bce':
|
176 |
+
return BCELoss(**kwargs)
|
177 |
+
raise ValueError(f'Unknown adversarial loss kind {kind}')
|
annotator/lama/saicinpainting/training/losses/constants.py
ADDED
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
weights = {"ade20k":
|
2 |
+
[6.34517766497462,
|
3 |
+
9.328358208955224,
|
4 |
+
11.389521640091116,
|
5 |
+
16.10305958132045,
|
6 |
+
20.833333333333332,
|
7 |
+
22.22222222222222,
|
8 |
+
25.125628140703515,
|
9 |
+
43.29004329004329,
|
10 |
+
50.5050505050505,
|
11 |
+
54.6448087431694,
|
12 |
+
55.24861878453038,
|
13 |
+
60.24096385542168,
|
14 |
+
62.5,
|
15 |
+
66.2251655629139,
|
16 |
+
84.74576271186442,
|
17 |
+
90.90909090909092,
|
18 |
+
91.74311926605505,
|
19 |
+
96.15384615384616,
|
20 |
+
96.15384615384616,
|
21 |
+
97.08737864077669,
|
22 |
+
102.04081632653062,
|
23 |
+
135.13513513513513,
|
24 |
+
149.2537313432836,
|
25 |
+
153.84615384615384,
|
26 |
+
163.93442622950818,
|
27 |
+
166.66666666666666,
|
28 |
+
188.67924528301887,
|
29 |
+
192.30769230769232,
|
30 |
+
217.3913043478261,
|
31 |
+
227.27272727272725,
|
32 |
+
227.27272727272725,
|
33 |
+
227.27272727272725,
|
34 |
+
303.03030303030306,
|
35 |
+
322.5806451612903,
|
36 |
+
333.3333333333333,
|
37 |
+
370.3703703703703,
|
38 |
+
384.61538461538464,
|
39 |
+
416.6666666666667,
|
40 |
+
416.6666666666667,
|
41 |
+
434.7826086956522,
|
42 |
+
434.7826086956522,
|
43 |
+
454.5454545454545,
|
44 |
+
454.5454545454545,
|
45 |
+
500.0,
|
46 |
+
526.3157894736842,
|
47 |
+
526.3157894736842,
|
48 |
+
555.5555555555555,
|
49 |
+
555.5555555555555,
|
50 |
+
555.5555555555555,
|
51 |
+
555.5555555555555,
|
52 |
+
555.5555555555555,
|
53 |
+
555.5555555555555,
|
54 |
+
555.5555555555555,
|
55 |
+
588.2352941176471,
|
56 |
+
588.2352941176471,
|
57 |
+
588.2352941176471,
|
58 |
+
588.2352941176471,
|
59 |
+
588.2352941176471,
|
60 |
+
666.6666666666666,
|
61 |
+
666.6666666666666,
|
62 |
+
666.6666666666666,
|
63 |
+
666.6666666666666,
|
64 |
+
714.2857142857143,
|
65 |
+
714.2857142857143,
|
66 |
+
714.2857142857143,
|
67 |
+
714.2857142857143,
|
68 |
+
714.2857142857143,
|
69 |
+
769.2307692307693,
|
70 |
+
769.2307692307693,
|
71 |
+
769.2307692307693,
|
72 |
+
833.3333333333334,
|
73 |
+
833.3333333333334,
|
74 |
+
833.3333333333334,
|
75 |
+
833.3333333333334,
|
76 |
+
909.090909090909,
|
77 |
+
1000.0,
|
78 |
+
1111.111111111111,
|
79 |
+
1111.111111111111,
|
80 |
+
1111.111111111111,
|
81 |
+
1111.111111111111,
|
82 |
+
1111.111111111111,
|
83 |
+
1250.0,
|
84 |
+
1250.0,
|
85 |
+
1250.0,
|
86 |
+
1250.0,
|
87 |
+
1250.0,
|
88 |
+
1428.5714285714287,
|
89 |
+
1428.5714285714287,
|
90 |
+
1428.5714285714287,
|
91 |
+
1428.5714285714287,
|
92 |
+
1428.5714285714287,
|
93 |
+
1428.5714285714287,
|
94 |
+
1428.5714285714287,
|
95 |
+
1666.6666666666667,
|
96 |
+
1666.6666666666667,
|
97 |
+
1666.6666666666667,
|
98 |
+
1666.6666666666667,
|
99 |
+
1666.6666666666667,
|
100 |
+
1666.6666666666667,
|
101 |
+
1666.6666666666667,
|
102 |
+
1666.6666666666667,
|
103 |
+
1666.6666666666667,
|
104 |
+
1666.6666666666667,
|
105 |
+
1666.6666666666667,
|
106 |
+
2000.0,
|
107 |
+
2000.0,
|
108 |
+
2000.0,
|
109 |
+
2000.0,
|
110 |
+
2000.0,
|
111 |
+
2000.0,
|
112 |
+
2000.0,
|
113 |
+
2000.0,
|
114 |
+
2000.0,
|
115 |
+
2000.0,
|
116 |
+
2000.0,
|
117 |
+
2000.0,
|
118 |
+
2000.0,
|
119 |
+
2000.0,
|
120 |
+
2000.0,
|
121 |
+
2000.0,
|
122 |
+
2000.0,
|
123 |
+
2500.0,
|
124 |
+
2500.0,
|
125 |
+
2500.0,
|
126 |
+
2500.0,
|
127 |
+
2500.0,
|
128 |
+
2500.0,
|
129 |
+
2500.0,
|
130 |
+
2500.0,
|
131 |
+
2500.0,
|
132 |
+
2500.0,
|
133 |
+
2500.0,
|
134 |
+
2500.0,
|
135 |
+
2500.0,
|
136 |
+
3333.3333333333335,
|
137 |
+
3333.3333333333335,
|
138 |
+
3333.3333333333335,
|
139 |
+
3333.3333333333335,
|
140 |
+
3333.3333333333335,
|
141 |
+
3333.3333333333335,
|
142 |
+
3333.3333333333335,
|
143 |
+
3333.3333333333335,
|
144 |
+
3333.3333333333335,
|
145 |
+
3333.3333333333335,
|
146 |
+
3333.3333333333335,
|
147 |
+
3333.3333333333335,
|
148 |
+
3333.3333333333335,
|
149 |
+
5000.0,
|
150 |
+
5000.0,
|
151 |
+
5000.0]
|
152 |
+
}
|
annotator/lama/saicinpainting/training/losses/distance_weighting.py
ADDED
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import torchvision
|
5 |
+
|
6 |
+
from annotator.lama.saicinpainting.training.losses.perceptual import IMAGENET_STD, IMAGENET_MEAN
|
7 |
+
|
8 |
+
|
9 |
+
def dummy_distance_weighter(real_img, pred_img, mask):
|
10 |
+
return mask
|
11 |
+
|
12 |
+
|
13 |
+
def get_gauss_kernel(kernel_size, width_factor=1):
|
14 |
+
coords = torch.stack(torch.meshgrid(torch.arange(kernel_size),
|
15 |
+
torch.arange(kernel_size)),
|
16 |
+
dim=0).float()
|
17 |
+
diff = torch.exp(-((coords - kernel_size // 2) ** 2).sum(0) / kernel_size / width_factor)
|
18 |
+
diff /= diff.sum()
|
19 |
+
return diff
|
20 |
+
|
21 |
+
|
22 |
+
class BlurMask(nn.Module):
|
23 |
+
def __init__(self, kernel_size=5, width_factor=1):
|
24 |
+
super().__init__()
|
25 |
+
self.filter = nn.Conv2d(1, 1, kernel_size, padding=kernel_size // 2, padding_mode='replicate', bias=False)
|
26 |
+
self.filter.weight.data.copy_(get_gauss_kernel(kernel_size, width_factor=width_factor))
|
27 |
+
|
28 |
+
def forward(self, real_img, pred_img, mask):
|
29 |
+
with torch.no_grad():
|
30 |
+
result = self.filter(mask) * mask
|
31 |
+
return result
|
32 |
+
|
33 |
+
|
34 |
+
class EmulatedEDTMask(nn.Module):
|
35 |
+
def __init__(self, dilate_kernel_size=5, blur_kernel_size=5, width_factor=1):
|
36 |
+
super().__init__()
|
37 |
+
self.dilate_filter = nn.Conv2d(1, 1, dilate_kernel_size, padding=dilate_kernel_size// 2, padding_mode='replicate',
|
38 |
+
bias=False)
|
39 |
+
self.dilate_filter.weight.data.copy_(torch.ones(1, 1, dilate_kernel_size, dilate_kernel_size, dtype=torch.float))
|
40 |
+
self.blur_filter = nn.Conv2d(1, 1, blur_kernel_size, padding=blur_kernel_size // 2, padding_mode='replicate', bias=False)
|
41 |
+
self.blur_filter.weight.data.copy_(get_gauss_kernel(blur_kernel_size, width_factor=width_factor))
|
42 |
+
|
43 |
+
def forward(self, real_img, pred_img, mask):
|
44 |
+
with torch.no_grad():
|
45 |
+
known_mask = 1 - mask
|
46 |
+
dilated_known_mask = (self.dilate_filter(known_mask) > 1).float()
|
47 |
+
result = self.blur_filter(1 - dilated_known_mask) * mask
|
48 |
+
return result
|
49 |
+
|
50 |
+
|
51 |
+
class PropagatePerceptualSim(nn.Module):
|
52 |
+
def __init__(self, level=2, max_iters=10, temperature=500, erode_mask_size=3):
|
53 |
+
super().__init__()
|
54 |
+
vgg = torchvision.models.vgg19(pretrained=True).features
|
55 |
+
vgg_avg_pooling = []
|
56 |
+
|
57 |
+
for weights in vgg.parameters():
|
58 |
+
weights.requires_grad = False
|
59 |
+
|
60 |
+
cur_level_i = 0
|
61 |
+
for module in vgg.modules():
|
62 |
+
if module.__class__.__name__ == 'Sequential':
|
63 |
+
continue
|
64 |
+
elif module.__class__.__name__ == 'MaxPool2d':
|
65 |
+
vgg_avg_pooling.append(nn.AvgPool2d(kernel_size=2, stride=2, padding=0))
|
66 |
+
else:
|
67 |
+
vgg_avg_pooling.append(module)
|
68 |
+
if module.__class__.__name__ == 'ReLU':
|
69 |
+
cur_level_i += 1
|
70 |
+
if cur_level_i == level:
|
71 |
+
break
|
72 |
+
|
73 |
+
self.features = nn.Sequential(*vgg_avg_pooling)
|
74 |
+
|
75 |
+
self.max_iters = max_iters
|
76 |
+
self.temperature = temperature
|
77 |
+
self.do_erode = erode_mask_size > 0
|
78 |
+
if self.do_erode:
|
79 |
+
self.erode_mask = nn.Conv2d(1, 1, erode_mask_size, padding=erode_mask_size // 2, bias=False)
|
80 |
+
self.erode_mask.weight.data.fill_(1)
|
81 |
+
|
82 |
+
def forward(self, real_img, pred_img, mask):
|
83 |
+
with torch.no_grad():
|
84 |
+
real_img = (real_img - IMAGENET_MEAN.to(real_img)) / IMAGENET_STD.to(real_img)
|
85 |
+
real_feats = self.features(real_img)
|
86 |
+
|
87 |
+
vertical_sim = torch.exp(-(real_feats[:, :, 1:] - real_feats[:, :, :-1]).pow(2).sum(1, keepdim=True)
|
88 |
+
/ self.temperature)
|
89 |
+
horizontal_sim = torch.exp(-(real_feats[:, :, :, 1:] - real_feats[:, :, :, :-1]).pow(2).sum(1, keepdim=True)
|
90 |
+
/ self.temperature)
|
91 |
+
|
92 |
+
mask_scaled = F.interpolate(mask, size=real_feats.shape[-2:], mode='bilinear', align_corners=False)
|
93 |
+
if self.do_erode:
|
94 |
+
mask_scaled = (self.erode_mask(mask_scaled) > 1).float()
|
95 |
+
|
96 |
+
cur_knowness = 1 - mask_scaled
|
97 |
+
|
98 |
+
for iter_i in range(self.max_iters):
|
99 |
+
new_top_knowness = F.pad(cur_knowness[:, :, :-1] * vertical_sim, (0, 0, 1, 0), mode='replicate')
|
100 |
+
new_bottom_knowness = F.pad(cur_knowness[:, :, 1:] * vertical_sim, (0, 0, 0, 1), mode='replicate')
|
101 |
+
|
102 |
+
new_left_knowness = F.pad(cur_knowness[:, :, :, :-1] * horizontal_sim, (1, 0, 0, 0), mode='replicate')
|
103 |
+
new_right_knowness = F.pad(cur_knowness[:, :, :, 1:] * horizontal_sim, (0, 1, 0, 0), mode='replicate')
|
104 |
+
|
105 |
+
new_knowness = torch.stack([new_top_knowness, new_bottom_knowness,
|
106 |
+
new_left_knowness, new_right_knowness],
|
107 |
+
dim=0).max(0).values
|
108 |
+
|
109 |
+
cur_knowness = torch.max(cur_knowness, new_knowness)
|
110 |
+
|
111 |
+
cur_knowness = F.interpolate(cur_knowness, size=mask.shape[-2:], mode='bilinear')
|
112 |
+
result = torch.min(mask, 1 - cur_knowness)
|
113 |
+
|
114 |
+
return result
|
115 |
+
|
116 |
+
|
117 |
+
def make_mask_distance_weighter(kind='none', **kwargs):
|
118 |
+
if kind == 'none':
|
119 |
+
return dummy_distance_weighter
|
120 |
+
if kind == 'blur':
|
121 |
+
return BlurMask(**kwargs)
|
122 |
+
if kind == 'edt':
|
123 |
+
return EmulatedEDTMask(**kwargs)
|
124 |
+
if kind == 'pps':
|
125 |
+
return PropagatePerceptualSim(**kwargs)
|
126 |
+
raise ValueError(f'Unknown mask distance weighter kind {kind}')
|
annotator/lama/saicinpainting/training/losses/feature_matching.py
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn.functional as F
|
5 |
+
|
6 |
+
|
7 |
+
def masked_l2_loss(pred, target, mask, weight_known, weight_missing):
|
8 |
+
per_pixel_l2 = F.mse_loss(pred, target, reduction='none')
|
9 |
+
pixel_weights = mask * weight_missing + (1 - mask) * weight_known
|
10 |
+
return (pixel_weights * per_pixel_l2).mean()
|
11 |
+
|
12 |
+
|
13 |
+
def masked_l1_loss(pred, target, mask, weight_known, weight_missing):
|
14 |
+
per_pixel_l1 = F.l1_loss(pred, target, reduction='none')
|
15 |
+
pixel_weights = mask * weight_missing + (1 - mask) * weight_known
|
16 |
+
return (pixel_weights * per_pixel_l1).mean()
|
17 |
+
|
18 |
+
|
19 |
+
def feature_matching_loss(fake_features: List[torch.Tensor], target_features: List[torch.Tensor], mask=None):
|
20 |
+
if mask is None:
|
21 |
+
res = torch.stack([F.mse_loss(fake_feat, target_feat)
|
22 |
+
for fake_feat, target_feat in zip(fake_features, target_features)]).mean()
|
23 |
+
else:
|
24 |
+
res = 0
|
25 |
+
norm = 0
|
26 |
+
for fake_feat, target_feat in zip(fake_features, target_features):
|
27 |
+
cur_mask = F.interpolate(mask, size=fake_feat.shape[-2:], mode='bilinear', align_corners=False)
|
28 |
+
error_weights = 1 - cur_mask
|
29 |
+
cur_val = ((fake_feat - target_feat).pow(2) * error_weights).mean()
|
30 |
+
res = res + cur_val
|
31 |
+
norm += 1
|
32 |
+
res = res / norm
|
33 |
+
return res
|
annotator/lama/saicinpainting/training/losses/perceptual.py
ADDED
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import torchvision
|
5 |
+
|
6 |
+
# from models.ade20k import ModelBuilder
|
7 |
+
from annotator.lama.saicinpainting.utils import check_and_warn_input_range
|
8 |
+
|
9 |
+
|
10 |
+
IMAGENET_MEAN = torch.FloatTensor([0.485, 0.456, 0.406])[None, :, None, None]
|
11 |
+
IMAGENET_STD = torch.FloatTensor([0.229, 0.224, 0.225])[None, :, None, None]
|
12 |
+
|
13 |
+
|
14 |
+
class PerceptualLoss(nn.Module):
|
15 |
+
def __init__(self, normalize_inputs=True):
|
16 |
+
super(PerceptualLoss, self).__init__()
|
17 |
+
|
18 |
+
self.normalize_inputs = normalize_inputs
|
19 |
+
self.mean_ = IMAGENET_MEAN
|
20 |
+
self.std_ = IMAGENET_STD
|
21 |
+
|
22 |
+
vgg = torchvision.models.vgg19(pretrained=True).features
|
23 |
+
vgg_avg_pooling = []
|
24 |
+
|
25 |
+
for weights in vgg.parameters():
|
26 |
+
weights.requires_grad = False
|
27 |
+
|
28 |
+
for module in vgg.modules():
|
29 |
+
if module.__class__.__name__ == 'Sequential':
|
30 |
+
continue
|
31 |
+
elif module.__class__.__name__ == 'MaxPool2d':
|
32 |
+
vgg_avg_pooling.append(nn.AvgPool2d(kernel_size=2, stride=2, padding=0))
|
33 |
+
else:
|
34 |
+
vgg_avg_pooling.append(module)
|
35 |
+
|
36 |
+
self.vgg = nn.Sequential(*vgg_avg_pooling)
|
37 |
+
|
38 |
+
def do_normalize_inputs(self, x):
|
39 |
+
return (x - self.mean_.to(x.device)) / self.std_.to(x.device)
|
40 |
+
|
41 |
+
def partial_losses(self, input, target, mask=None):
|
42 |
+
check_and_warn_input_range(target, 0, 1, 'PerceptualLoss target in partial_losses')
|
43 |
+
|
44 |
+
# we expect input and target to be in [0, 1] range
|
45 |
+
losses = []
|
46 |
+
|
47 |
+
if self.normalize_inputs:
|
48 |
+
features_input = self.do_normalize_inputs(input)
|
49 |
+
features_target = self.do_normalize_inputs(target)
|
50 |
+
else:
|
51 |
+
features_input = input
|
52 |
+
features_target = target
|
53 |
+
|
54 |
+
for layer in self.vgg[:30]:
|
55 |
+
|
56 |
+
features_input = layer(features_input)
|
57 |
+
features_target = layer(features_target)
|
58 |
+
|
59 |
+
if layer.__class__.__name__ == 'ReLU':
|
60 |
+
loss = F.mse_loss(features_input, features_target, reduction='none')
|
61 |
+
|
62 |
+
if mask is not None:
|
63 |
+
cur_mask = F.interpolate(mask, size=features_input.shape[-2:],
|
64 |
+
mode='bilinear', align_corners=False)
|
65 |
+
loss = loss * (1 - cur_mask)
|
66 |
+
|
67 |
+
loss = loss.mean(dim=tuple(range(1, len(loss.shape))))
|
68 |
+
losses.append(loss)
|
69 |
+
|
70 |
+
return losses
|
71 |
+
|
72 |
+
def forward(self, input, target, mask=None):
|
73 |
+
losses = self.partial_losses(input, target, mask=mask)
|
74 |
+
return torch.stack(losses).sum(dim=0)
|
75 |
+
|
76 |
+
def get_global_features(self, input):
|
77 |
+
check_and_warn_input_range(input, 0, 1, 'PerceptualLoss input in get_global_features')
|
78 |
+
|
79 |
+
if self.normalize_inputs:
|
80 |
+
features_input = self.do_normalize_inputs(input)
|
81 |
+
else:
|
82 |
+
features_input = input
|
83 |
+
|
84 |
+
features_input = self.vgg(features_input)
|
85 |
+
return features_input
|
86 |
+
|
87 |
+
|
88 |
+
class ResNetPL(nn.Module):
|
89 |
+
def __init__(self, weight=1,
|
90 |
+
weights_path=None, arch_encoder='resnet50dilated', segmentation=True):
|
91 |
+
super().__init__()
|
92 |
+
self.impl = ModelBuilder.get_encoder(weights_path=weights_path,
|
93 |
+
arch_encoder=arch_encoder,
|
94 |
+
arch_decoder='ppm_deepsup',
|
95 |
+
fc_dim=2048,
|
96 |
+
segmentation=segmentation)
|
97 |
+
self.impl.eval()
|
98 |
+
for w in self.impl.parameters():
|
99 |
+
w.requires_grad_(False)
|
100 |
+
|
101 |
+
self.weight = weight
|
102 |
+
|
103 |
+
def forward(self, pred, target):
|
104 |
+
pred = (pred - IMAGENET_MEAN.to(pred)) / IMAGENET_STD.to(pred)
|
105 |
+
target = (target - IMAGENET_MEAN.to(target)) / IMAGENET_STD.to(target)
|
106 |
+
|
107 |
+
pred_feats = self.impl(pred, return_feature_maps=True)
|
108 |
+
target_feats = self.impl(target, return_feature_maps=True)
|
109 |
+
|
110 |
+
result = torch.stack([F.mse_loss(cur_pred, cur_target)
|
111 |
+
for cur_pred, cur_target
|
112 |
+
in zip(pred_feats, target_feats)]).sum() * self.weight
|
113 |
+
return result
|
annotator/lama/saicinpainting/training/losses/segmentation.py
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
from .constants import weights as constant_weights
|
6 |
+
|
7 |
+
|
8 |
+
class CrossEntropy2d(nn.Module):
|
9 |
+
def __init__(self, reduction="mean", ignore_label=255, weights=None, *args, **kwargs):
|
10 |
+
"""
|
11 |
+
weight (Tensor, optional): a manual rescaling weight given to each class.
|
12 |
+
If given, has to be a Tensor of size "nclasses"
|
13 |
+
"""
|
14 |
+
super(CrossEntropy2d, self).__init__()
|
15 |
+
self.reduction = reduction
|
16 |
+
self.ignore_label = ignore_label
|
17 |
+
self.weights = weights
|
18 |
+
if self.weights is not None:
|
19 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
20 |
+
self.weights = torch.FloatTensor(constant_weights[weights]).to(device)
|
21 |
+
|
22 |
+
def forward(self, predict, target):
|
23 |
+
"""
|
24 |
+
Args:
|
25 |
+
predict:(n, c, h, w)
|
26 |
+
target:(n, 1, h, w)
|
27 |
+
"""
|
28 |
+
target = target.long()
|
29 |
+
assert not target.requires_grad
|
30 |
+
assert predict.dim() == 4, "{0}".format(predict.size())
|
31 |
+
assert target.dim() == 4, "{0}".format(target.size())
|
32 |
+
assert predict.size(0) == target.size(0), "{0} vs {1} ".format(predict.size(0), target.size(0))
|
33 |
+
assert target.size(1) == 1, "{0}".format(target.size(1))
|
34 |
+
assert predict.size(2) == target.size(2), "{0} vs {1} ".format(predict.size(2), target.size(2))
|
35 |
+
assert predict.size(3) == target.size(3), "{0} vs {1} ".format(predict.size(3), target.size(3))
|
36 |
+
target = target.squeeze(1)
|
37 |
+
n, c, h, w = predict.size()
|
38 |
+
target_mask = (target >= 0) * (target != self.ignore_label)
|
39 |
+
target = target[target_mask]
|
40 |
+
predict = predict.transpose(1, 2).transpose(2, 3).contiguous()
|
41 |
+
predict = predict[target_mask.view(n, h, w, 1).repeat(1, 1, 1, c)].view(-1, c)
|
42 |
+
loss = F.cross_entropy(predict, target, weight=self.weights, reduction=self.reduction)
|
43 |
+
return loss
|
annotator/lama/saicinpainting/training/losses/style_loss.py
ADDED
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torchvision.models as models
|
4 |
+
|
5 |
+
|
6 |
+
class PerceptualLoss(nn.Module):
|
7 |
+
r"""
|
8 |
+
Perceptual loss, VGG-based
|
9 |
+
https://arxiv.org/abs/1603.08155
|
10 |
+
https://github.com/dxyang/StyleTransfer/blob/master/utils.py
|
11 |
+
"""
|
12 |
+
|
13 |
+
def __init__(self, weights=[1.0, 1.0, 1.0, 1.0, 1.0]):
|
14 |
+
super(PerceptualLoss, self).__init__()
|
15 |
+
self.add_module('vgg', VGG19())
|
16 |
+
self.criterion = torch.nn.L1Loss()
|
17 |
+
self.weights = weights
|
18 |
+
|
19 |
+
def __call__(self, x, y):
|
20 |
+
# Compute features
|
21 |
+
x_vgg, y_vgg = self.vgg(x), self.vgg(y)
|
22 |
+
|
23 |
+
content_loss = 0.0
|
24 |
+
content_loss += self.weights[0] * self.criterion(x_vgg['relu1_1'], y_vgg['relu1_1'])
|
25 |
+
content_loss += self.weights[1] * self.criterion(x_vgg['relu2_1'], y_vgg['relu2_1'])
|
26 |
+
content_loss += self.weights[2] * self.criterion(x_vgg['relu3_1'], y_vgg['relu3_1'])
|
27 |
+
content_loss += self.weights[3] * self.criterion(x_vgg['relu4_1'], y_vgg['relu4_1'])
|
28 |
+
content_loss += self.weights[4] * self.criterion(x_vgg['relu5_1'], y_vgg['relu5_1'])
|
29 |
+
|
30 |
+
|
31 |
+
return content_loss
|
32 |
+
|
33 |
+
|
34 |
+
class VGG19(torch.nn.Module):
|
35 |
+
def __init__(self):
|
36 |
+
super(VGG19, self).__init__()
|
37 |
+
features = models.vgg19(pretrained=True).features
|
38 |
+
self.relu1_1 = torch.nn.Sequential()
|
39 |
+
self.relu1_2 = torch.nn.Sequential()
|
40 |
+
|
41 |
+
self.relu2_1 = torch.nn.Sequential()
|
42 |
+
self.relu2_2 = torch.nn.Sequential()
|
43 |
+
|
44 |
+
self.relu3_1 = torch.nn.Sequential()
|
45 |
+
self.relu3_2 = torch.nn.Sequential()
|
46 |
+
self.relu3_3 = torch.nn.Sequential()
|
47 |
+
self.relu3_4 = torch.nn.Sequential()
|
48 |
+
|
49 |
+
self.relu4_1 = torch.nn.Sequential()
|
50 |
+
self.relu4_2 = torch.nn.Sequential()
|
51 |
+
self.relu4_3 = torch.nn.Sequential()
|
52 |
+
self.relu4_4 = torch.nn.Sequential()
|
53 |
+
|
54 |
+
self.relu5_1 = torch.nn.Sequential()
|
55 |
+
self.relu5_2 = torch.nn.Sequential()
|
56 |
+
self.relu5_3 = torch.nn.Sequential()
|
57 |
+
self.relu5_4 = torch.nn.Sequential()
|
58 |
+
|
59 |
+
for x in range(2):
|
60 |
+
self.relu1_1.add_module(str(x), features[x])
|
61 |
+
|
62 |
+
for x in range(2, 4):
|
63 |
+
self.relu1_2.add_module(str(x), features[x])
|
64 |
+
|
65 |
+
for x in range(4, 7):
|
66 |
+
self.relu2_1.add_module(str(x), features[x])
|
67 |
+
|
68 |
+
for x in range(7, 9):
|
69 |
+
self.relu2_2.add_module(str(x), features[x])
|
70 |
+
|
71 |
+
for x in range(9, 12):
|
72 |
+
self.relu3_1.add_module(str(x), features[x])
|
73 |
+
|
74 |
+
for x in range(12, 14):
|
75 |
+
self.relu3_2.add_module(str(x), features[x])
|
76 |
+
|
77 |
+
for x in range(14, 16):
|
78 |
+
self.relu3_2.add_module(str(x), features[x])
|
79 |
+
|
80 |
+
for x in range(16, 18):
|
81 |
+
self.relu3_4.add_module(str(x), features[x])
|
82 |
+
|
83 |
+
for x in range(18, 21):
|
84 |
+
self.relu4_1.add_module(str(x), features[x])
|
85 |
+
|
86 |
+
for x in range(21, 23):
|
87 |
+
self.relu4_2.add_module(str(x), features[x])
|
88 |
+
|
89 |
+
for x in range(23, 25):
|
90 |
+
self.relu4_3.add_module(str(x), features[x])
|
91 |
+
|
92 |
+
for x in range(25, 27):
|
93 |
+
self.relu4_4.add_module(str(x), features[x])
|
94 |
+
|
95 |
+
for x in range(27, 30):
|
96 |
+
self.relu5_1.add_module(str(x), features[x])
|
97 |
+
|
98 |
+
for x in range(30, 32):
|
99 |
+
self.relu5_2.add_module(str(x), features[x])
|
100 |
+
|
101 |
+
for x in range(32, 34):
|
102 |
+
self.relu5_3.add_module(str(x), features[x])
|
103 |
+
|
104 |
+
for x in range(34, 36):
|
105 |
+
self.relu5_4.add_module(str(x), features[x])
|
106 |
+
|
107 |
+
# don't need the gradients, just want the features
|
108 |
+
for param in self.parameters():
|
109 |
+
param.requires_grad = False
|
110 |
+
|
111 |
+
def forward(self, x):
|
112 |
+
relu1_1 = self.relu1_1(x)
|
113 |
+
relu1_2 = self.relu1_2(relu1_1)
|
114 |
+
|
115 |
+
relu2_1 = self.relu2_1(relu1_2)
|
116 |
+
relu2_2 = self.relu2_2(relu2_1)
|
117 |
+
|
118 |
+
relu3_1 = self.relu3_1(relu2_2)
|
119 |
+
relu3_2 = self.relu3_2(relu3_1)
|
120 |
+
relu3_3 = self.relu3_3(relu3_2)
|
121 |
+
relu3_4 = self.relu3_4(relu3_3)
|
122 |
+
|
123 |
+
relu4_1 = self.relu4_1(relu3_4)
|
124 |
+
relu4_2 = self.relu4_2(relu4_1)
|
125 |
+
relu4_3 = self.relu4_3(relu4_2)
|
126 |
+
relu4_4 = self.relu4_4(relu4_3)
|
127 |
+
|
128 |
+
relu5_1 = self.relu5_1(relu4_4)
|
129 |
+
relu5_2 = self.relu5_2(relu5_1)
|
130 |
+
relu5_3 = self.relu5_3(relu5_2)
|
131 |
+
relu5_4 = self.relu5_4(relu5_3)
|
132 |
+
|
133 |
+
out = {
|
134 |
+
'relu1_1': relu1_1,
|
135 |
+
'relu1_2': relu1_2,
|
136 |
+
|
137 |
+
'relu2_1': relu2_1,
|
138 |
+
'relu2_2': relu2_2,
|
139 |
+
|
140 |
+
'relu3_1': relu3_1,
|
141 |
+
'relu3_2': relu3_2,
|
142 |
+
'relu3_3': relu3_3,
|
143 |
+
'relu3_4': relu3_4,
|
144 |
+
|
145 |
+
'relu4_1': relu4_1,
|
146 |
+
'relu4_2': relu4_2,
|
147 |
+
'relu4_3': relu4_3,
|
148 |
+
'relu4_4': relu4_4,
|
149 |
+
|
150 |
+
'relu5_1': relu5_1,
|
151 |
+
'relu5_2': relu5_2,
|
152 |
+
'relu5_3': relu5_3,
|
153 |
+
'relu5_4': relu5_4,
|
154 |
+
}
|
155 |
+
return out
|
annotator/lama/saicinpainting/training/modules/__init__.py
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
|
3 |
+
from annotator.lama.saicinpainting.training.modules.ffc import FFCResNetGenerator
|
4 |
+
from annotator.lama.saicinpainting.training.modules.pix2pixhd import GlobalGenerator, MultiDilatedGlobalGenerator, \
|
5 |
+
NLayerDiscriminator, MultidilatedNLayerDiscriminator
|
6 |
+
|
7 |
+
def make_generator(config, kind, **kwargs):
|
8 |
+
logging.info(f'Make generator {kind}')
|
9 |
+
|
10 |
+
if kind == 'pix2pixhd_multidilated':
|
11 |
+
return MultiDilatedGlobalGenerator(**kwargs)
|
12 |
+
|
13 |
+
if kind == 'pix2pixhd_global':
|
14 |
+
return GlobalGenerator(**kwargs)
|
15 |
+
|
16 |
+
if kind == 'ffc_resnet':
|
17 |
+
return FFCResNetGenerator(**kwargs)
|
18 |
+
|
19 |
+
raise ValueError(f'Unknown generator kind {kind}')
|
20 |
+
|
21 |
+
|
22 |
+
def make_discriminator(kind, **kwargs):
|
23 |
+
logging.info(f'Make discriminator {kind}')
|
24 |
+
|
25 |
+
if kind == 'pix2pixhd_nlayer_multidilated':
|
26 |
+
return MultidilatedNLayerDiscriminator(**kwargs)
|
27 |
+
|
28 |
+
if kind == 'pix2pixhd_nlayer':
|
29 |
+
return NLayerDiscriminator(**kwargs)
|
30 |
+
|
31 |
+
raise ValueError(f'Unknown discriminator kind {kind}')
|
annotator/lama/saicinpainting/training/modules/base.py
ADDED
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import abc
|
2 |
+
from typing import Tuple, List
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
|
7 |
+
from annotator.lama.saicinpainting.training.modules.depthwise_sep_conv import DepthWiseSeperableConv
|
8 |
+
from annotator.lama.saicinpainting.training.modules.multidilated_conv import MultidilatedConv
|
9 |
+
|
10 |
+
|
11 |
+
class BaseDiscriminator(nn.Module):
|
12 |
+
@abc.abstractmethod
|
13 |
+
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
|
14 |
+
"""
|
15 |
+
Predict scores and get intermediate activations. Useful for feature matching loss
|
16 |
+
:return tuple (scores, list of intermediate activations)
|
17 |
+
"""
|
18 |
+
raise NotImplemented()
|
19 |
+
|
20 |
+
|
21 |
+
def get_conv_block_ctor(kind='default'):
|
22 |
+
if not isinstance(kind, str):
|
23 |
+
return kind
|
24 |
+
if kind == 'default':
|
25 |
+
return nn.Conv2d
|
26 |
+
if kind == 'depthwise':
|
27 |
+
return DepthWiseSeperableConv
|
28 |
+
if kind == 'multidilated':
|
29 |
+
return MultidilatedConv
|
30 |
+
raise ValueError(f'Unknown convolutional block kind {kind}')
|
31 |
+
|
32 |
+
|
33 |
+
def get_norm_layer(kind='bn'):
|
34 |
+
if not isinstance(kind, str):
|
35 |
+
return kind
|
36 |
+
if kind == 'bn':
|
37 |
+
return nn.BatchNorm2d
|
38 |
+
if kind == 'in':
|
39 |
+
return nn.InstanceNorm2d
|
40 |
+
raise ValueError(f'Unknown norm block kind {kind}')
|
41 |
+
|
42 |
+
|
43 |
+
def get_activation(kind='tanh'):
|
44 |
+
if kind == 'tanh':
|
45 |
+
return nn.Tanh()
|
46 |
+
if kind == 'sigmoid':
|
47 |
+
return nn.Sigmoid()
|
48 |
+
if kind is False:
|
49 |
+
return nn.Identity()
|
50 |
+
raise ValueError(f'Unknown activation kind {kind}')
|
51 |
+
|
52 |
+
|
53 |
+
class SimpleMultiStepGenerator(nn.Module):
|
54 |
+
def __init__(self, steps: List[nn.Module]):
|
55 |
+
super().__init__()
|
56 |
+
self.steps = nn.ModuleList(steps)
|
57 |
+
|
58 |
+
def forward(self, x):
|
59 |
+
cur_in = x
|
60 |
+
outs = []
|
61 |
+
for step in self.steps:
|
62 |
+
cur_out = step(cur_in)
|
63 |
+
outs.append(cur_out)
|
64 |
+
cur_in = torch.cat((cur_in, cur_out), dim=1)
|
65 |
+
return torch.cat(outs[::-1], dim=1)
|
66 |
+
|
67 |
+
def deconv_factory(kind, ngf, mult, norm_layer, activation, max_features):
|
68 |
+
if kind == 'convtranspose':
|
69 |
+
return [nn.ConvTranspose2d(min(max_features, ngf * mult),
|
70 |
+
min(max_features, int(ngf * mult / 2)),
|
71 |
+
kernel_size=3, stride=2, padding=1, output_padding=1),
|
72 |
+
norm_layer(min(max_features, int(ngf * mult / 2))), activation]
|
73 |
+
elif kind == 'bilinear':
|
74 |
+
return [nn.Upsample(scale_factor=2, mode='bilinear'),
|
75 |
+
DepthWiseSeperableConv(min(max_features, ngf * mult),
|
76 |
+
min(max_features, int(ngf * mult / 2)),
|
77 |
+
kernel_size=3, stride=1, padding=1),
|
78 |
+
norm_layer(min(max_features, int(ngf * mult / 2))), activation]
|
79 |
+
else:
|
80 |
+
raise Exception(f"Invalid deconv kind: {kind}")
|
annotator/lama/saicinpainting/training/modules/depthwise_sep_conv.py
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
class DepthWiseSeperableConv(nn.Module):
|
5 |
+
def __init__(self, in_dim, out_dim, *args, **kwargs):
|
6 |
+
super().__init__()
|
7 |
+
if 'groups' in kwargs:
|
8 |
+
# ignoring groups for Depthwise Sep Conv
|
9 |
+
del kwargs['groups']
|
10 |
+
|
11 |
+
self.depthwise = nn.Conv2d(in_dim, in_dim, *args, groups=in_dim, **kwargs)
|
12 |
+
self.pointwise = nn.Conv2d(in_dim, out_dim, kernel_size=1)
|
13 |
+
|
14 |
+
def forward(self, x):
|
15 |
+
out = self.depthwise(x)
|
16 |
+
out = self.pointwise(out)
|
17 |
+
return out
|
annotator/lama/saicinpainting/training/modules/fake_fakes.py
ADDED
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from kornia import SamplePadding
|
3 |
+
from kornia.augmentation import RandomAffine, CenterCrop
|
4 |
+
|
5 |
+
|
6 |
+
class FakeFakesGenerator:
|
7 |
+
def __init__(self, aug_proba=0.5, img_aug_degree=30, img_aug_translate=0.2):
|
8 |
+
self.grad_aug = RandomAffine(degrees=360,
|
9 |
+
translate=0.2,
|
10 |
+
padding_mode=SamplePadding.REFLECTION,
|
11 |
+
keepdim=False,
|
12 |
+
p=1)
|
13 |
+
self.img_aug = RandomAffine(degrees=img_aug_degree,
|
14 |
+
translate=img_aug_translate,
|
15 |
+
padding_mode=SamplePadding.REFLECTION,
|
16 |
+
keepdim=True,
|
17 |
+
p=1)
|
18 |
+
self.aug_proba = aug_proba
|
19 |
+
|
20 |
+
def __call__(self, input_images, masks):
|
21 |
+
blend_masks = self._fill_masks_with_gradient(masks)
|
22 |
+
blend_target = self._make_blend_target(input_images)
|
23 |
+
result = input_images * (1 - blend_masks) + blend_target * blend_masks
|
24 |
+
return result, blend_masks
|
25 |
+
|
26 |
+
def _make_blend_target(self, input_images):
|
27 |
+
batch_size = input_images.shape[0]
|
28 |
+
permuted = input_images[torch.randperm(batch_size)]
|
29 |
+
augmented = self.img_aug(input_images)
|
30 |
+
is_aug = (torch.rand(batch_size, device=input_images.device)[:, None, None, None] < self.aug_proba).float()
|
31 |
+
result = augmented * is_aug + permuted * (1 - is_aug)
|
32 |
+
return result
|
33 |
+
|
34 |
+
def _fill_masks_with_gradient(self, masks):
|
35 |
+
batch_size, _, height, width = masks.shape
|
36 |
+
grad = torch.linspace(0, 1, steps=width * 2, device=masks.device, dtype=masks.dtype) \
|
37 |
+
.view(1, 1, 1, -1).expand(batch_size, 1, height * 2, width * 2)
|
38 |
+
grad = self.grad_aug(grad)
|
39 |
+
grad = CenterCrop((height, width))(grad)
|
40 |
+
grad *= masks
|
41 |
+
|
42 |
+
grad_for_min = grad + (1 - masks) * 10
|
43 |
+
grad -= grad_for_min.view(batch_size, -1).min(-1).values[:, None, None, None]
|
44 |
+
grad /= grad.view(batch_size, -1).max(-1).values[:, None, None, None] + 1e-6
|
45 |
+
grad.clamp_(min=0, max=1)
|
46 |
+
|
47 |
+
return grad
|
annotator/lama/saicinpainting/training/modules/ffc.py
ADDED
@@ -0,0 +1,485 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Fast Fourier Convolution NeurIPS 2020
|
2 |
+
# original implementation https://github.com/pkumivision/FFC/blob/main/model_zoo/ffc.py
|
3 |
+
# paper https://proceedings.neurips.cc/paper/2020/file/2fd5d41ec6cfab47e32164d5624269b1-Paper.pdf
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
import torch.nn.functional as F
|
9 |
+
|
10 |
+
from annotator.lama.saicinpainting.training.modules.base import get_activation, BaseDiscriminator
|
11 |
+
from annotator.lama.saicinpainting.training.modules.spatial_transform import LearnableSpatialTransformWrapper
|
12 |
+
from annotator.lama.saicinpainting.training.modules.squeeze_excitation import SELayer
|
13 |
+
from annotator.lama.saicinpainting.utils import get_shape
|
14 |
+
|
15 |
+
|
16 |
+
class FFCSE_block(nn.Module):
|
17 |
+
|
18 |
+
def __init__(self, channels, ratio_g):
|
19 |
+
super(FFCSE_block, self).__init__()
|
20 |
+
in_cg = int(channels * ratio_g)
|
21 |
+
in_cl = channels - in_cg
|
22 |
+
r = 16
|
23 |
+
|
24 |
+
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
|
25 |
+
self.conv1 = nn.Conv2d(channels, channels // r,
|
26 |
+
kernel_size=1, bias=True)
|
27 |
+
self.relu1 = nn.ReLU(inplace=True)
|
28 |
+
self.conv_a2l = None if in_cl == 0 else nn.Conv2d(
|
29 |
+
channels // r, in_cl, kernel_size=1, bias=True)
|
30 |
+
self.conv_a2g = None if in_cg == 0 else nn.Conv2d(
|
31 |
+
channels // r, in_cg, kernel_size=1, bias=True)
|
32 |
+
self.sigmoid = nn.Sigmoid()
|
33 |
+
|
34 |
+
def forward(self, x):
|
35 |
+
x = x if type(x) is tuple else (x, 0)
|
36 |
+
id_l, id_g = x
|
37 |
+
|
38 |
+
x = id_l if type(id_g) is int else torch.cat([id_l, id_g], dim=1)
|
39 |
+
x = self.avgpool(x)
|
40 |
+
x = self.relu1(self.conv1(x))
|
41 |
+
|
42 |
+
x_l = 0 if self.conv_a2l is None else id_l * \
|
43 |
+
self.sigmoid(self.conv_a2l(x))
|
44 |
+
x_g = 0 if self.conv_a2g is None else id_g * \
|
45 |
+
self.sigmoid(self.conv_a2g(x))
|
46 |
+
return x_l, x_g
|
47 |
+
|
48 |
+
|
49 |
+
class FourierUnit(nn.Module):
|
50 |
+
|
51 |
+
def __init__(self, in_channels, out_channels, groups=1, spatial_scale_factor=None, spatial_scale_mode='bilinear',
|
52 |
+
spectral_pos_encoding=False, use_se=False, se_kwargs=None, ffc3d=False, fft_norm='ortho'):
|
53 |
+
# bn_layer not used
|
54 |
+
super(FourierUnit, self).__init__()
|
55 |
+
self.groups = groups
|
56 |
+
|
57 |
+
self.conv_layer = torch.nn.Conv2d(in_channels=in_channels * 2 + (2 if spectral_pos_encoding else 0),
|
58 |
+
out_channels=out_channels * 2,
|
59 |
+
kernel_size=1, stride=1, padding=0, groups=self.groups, bias=False)
|
60 |
+
self.bn = torch.nn.BatchNorm2d(out_channels * 2)
|
61 |
+
self.relu = torch.nn.ReLU(inplace=True)
|
62 |
+
|
63 |
+
# squeeze and excitation block
|
64 |
+
self.use_se = use_se
|
65 |
+
if use_se:
|
66 |
+
if se_kwargs is None:
|
67 |
+
se_kwargs = {}
|
68 |
+
self.se = SELayer(self.conv_layer.in_channels, **se_kwargs)
|
69 |
+
|
70 |
+
self.spatial_scale_factor = spatial_scale_factor
|
71 |
+
self.spatial_scale_mode = spatial_scale_mode
|
72 |
+
self.spectral_pos_encoding = spectral_pos_encoding
|
73 |
+
self.ffc3d = ffc3d
|
74 |
+
self.fft_norm = fft_norm
|
75 |
+
|
76 |
+
def forward(self, x):
|
77 |
+
batch = x.shape[0]
|
78 |
+
|
79 |
+
if self.spatial_scale_factor is not None:
|
80 |
+
orig_size = x.shape[-2:]
|
81 |
+
x = F.interpolate(x, scale_factor=self.spatial_scale_factor, mode=self.spatial_scale_mode, align_corners=False)
|
82 |
+
|
83 |
+
r_size = x.size()
|
84 |
+
# (batch, c, h, w/2+1, 2)
|
85 |
+
fft_dim = (-3, -2, -1) if self.ffc3d else (-2, -1)
|
86 |
+
ffted = torch.fft.rfftn(x, dim=fft_dim, norm=self.fft_norm)
|
87 |
+
ffted = torch.stack((ffted.real, ffted.imag), dim=-1)
|
88 |
+
ffted = ffted.permute(0, 1, 4, 2, 3).contiguous() # (batch, c, 2, h, w/2+1)
|
89 |
+
ffted = ffted.view((batch, -1,) + ffted.size()[3:])
|
90 |
+
|
91 |
+
if self.spectral_pos_encoding:
|
92 |
+
height, width = ffted.shape[-2:]
|
93 |
+
coords_vert = torch.linspace(0, 1, height)[None, None, :, None].expand(batch, 1, height, width).to(ffted)
|
94 |
+
coords_hor = torch.linspace(0, 1, width)[None, None, None, :].expand(batch, 1, height, width).to(ffted)
|
95 |
+
ffted = torch.cat((coords_vert, coords_hor, ffted), dim=1)
|
96 |
+
|
97 |
+
if self.use_se:
|
98 |
+
ffted = self.se(ffted)
|
99 |
+
|
100 |
+
ffted = self.conv_layer(ffted) # (batch, c*2, h, w/2+1)
|
101 |
+
ffted = self.relu(self.bn(ffted))
|
102 |
+
|
103 |
+
ffted = ffted.view((batch, -1, 2,) + ffted.size()[2:]).permute(
|
104 |
+
0, 1, 3, 4, 2).contiguous() # (batch,c, t, h, w/2+1, 2)
|
105 |
+
ffted = torch.complex(ffted[..., 0], ffted[..., 1])
|
106 |
+
|
107 |
+
ifft_shape_slice = x.shape[-3:] if self.ffc3d else x.shape[-2:]
|
108 |
+
output = torch.fft.irfftn(ffted, s=ifft_shape_slice, dim=fft_dim, norm=self.fft_norm)
|
109 |
+
|
110 |
+
if self.spatial_scale_factor is not None:
|
111 |
+
output = F.interpolate(output, size=orig_size, mode=self.spatial_scale_mode, align_corners=False)
|
112 |
+
|
113 |
+
return output
|
114 |
+
|
115 |
+
|
116 |
+
class SeparableFourierUnit(nn.Module):
|
117 |
+
|
118 |
+
def __init__(self, in_channels, out_channels, groups=1, kernel_size=3):
|
119 |
+
# bn_layer not used
|
120 |
+
super(SeparableFourierUnit, self).__init__()
|
121 |
+
self.groups = groups
|
122 |
+
row_out_channels = out_channels // 2
|
123 |
+
col_out_channels = out_channels - row_out_channels
|
124 |
+
self.row_conv = torch.nn.Conv2d(in_channels=in_channels * 2,
|
125 |
+
out_channels=row_out_channels * 2,
|
126 |
+
kernel_size=(kernel_size, 1), # kernel size is always like this, but the data will be transposed
|
127 |
+
stride=1, padding=(kernel_size // 2, 0),
|
128 |
+
padding_mode='reflect',
|
129 |
+
groups=self.groups, bias=False)
|
130 |
+
self.col_conv = torch.nn.Conv2d(in_channels=in_channels * 2,
|
131 |
+
out_channels=col_out_channels * 2,
|
132 |
+
kernel_size=(kernel_size, 1), # kernel size is always like this, but the data will be transposed
|
133 |
+
stride=1, padding=(kernel_size // 2, 0),
|
134 |
+
padding_mode='reflect',
|
135 |
+
groups=self.groups, bias=False)
|
136 |
+
self.row_bn = torch.nn.BatchNorm2d(row_out_channels * 2)
|
137 |
+
self.col_bn = torch.nn.BatchNorm2d(col_out_channels * 2)
|
138 |
+
self.relu = torch.nn.ReLU(inplace=True)
|
139 |
+
|
140 |
+
def process_branch(self, x, conv, bn):
|
141 |
+
batch = x.shape[0]
|
142 |
+
|
143 |
+
r_size = x.size()
|
144 |
+
# (batch, c, h, w/2+1, 2)
|
145 |
+
ffted = torch.fft.rfft(x, norm="ortho")
|
146 |
+
ffted = torch.stack((ffted.real, ffted.imag), dim=-1)
|
147 |
+
ffted = ffted.permute(0, 1, 4, 2, 3).contiguous() # (batch, c, 2, h, w/2+1)
|
148 |
+
ffted = ffted.view((batch, -1,) + ffted.size()[3:])
|
149 |
+
|
150 |
+
ffted = self.relu(bn(conv(ffted)))
|
151 |
+
|
152 |
+
ffted = ffted.view((batch, -1, 2,) + ffted.size()[2:]).permute(
|
153 |
+
0, 1, 3, 4, 2).contiguous() # (batch,c, t, h, w/2+1, 2)
|
154 |
+
ffted = torch.complex(ffted[..., 0], ffted[..., 1])
|
155 |
+
|
156 |
+
output = torch.fft.irfft(ffted, s=x.shape[-1:], norm="ortho")
|
157 |
+
return output
|
158 |
+
|
159 |
+
|
160 |
+
def forward(self, x):
|
161 |
+
rowwise = self.process_branch(x, self.row_conv, self.row_bn)
|
162 |
+
colwise = self.process_branch(x.permute(0, 1, 3, 2), self.col_conv, self.col_bn).permute(0, 1, 3, 2)
|
163 |
+
out = torch.cat((rowwise, colwise), dim=1)
|
164 |
+
return out
|
165 |
+
|
166 |
+
|
167 |
+
class SpectralTransform(nn.Module):
|
168 |
+
|
169 |
+
def __init__(self, in_channels, out_channels, stride=1, groups=1, enable_lfu=True, separable_fu=False, **fu_kwargs):
|
170 |
+
# bn_layer not used
|
171 |
+
super(SpectralTransform, self).__init__()
|
172 |
+
self.enable_lfu = enable_lfu
|
173 |
+
if stride == 2:
|
174 |
+
self.downsample = nn.AvgPool2d(kernel_size=(2, 2), stride=2)
|
175 |
+
else:
|
176 |
+
self.downsample = nn.Identity()
|
177 |
+
|
178 |
+
self.stride = stride
|
179 |
+
self.conv1 = nn.Sequential(
|
180 |
+
nn.Conv2d(in_channels, out_channels //
|
181 |
+
2, kernel_size=1, groups=groups, bias=False),
|
182 |
+
nn.BatchNorm2d(out_channels // 2),
|
183 |
+
nn.ReLU(inplace=True)
|
184 |
+
)
|
185 |
+
fu_class = SeparableFourierUnit if separable_fu else FourierUnit
|
186 |
+
self.fu = fu_class(
|
187 |
+
out_channels // 2, out_channels // 2, groups, **fu_kwargs)
|
188 |
+
if self.enable_lfu:
|
189 |
+
self.lfu = fu_class(
|
190 |
+
out_channels // 2, out_channels // 2, groups)
|
191 |
+
self.conv2 = torch.nn.Conv2d(
|
192 |
+
out_channels // 2, out_channels, kernel_size=1, groups=groups, bias=False)
|
193 |
+
|
194 |
+
def forward(self, x):
|
195 |
+
|
196 |
+
x = self.downsample(x)
|
197 |
+
x = self.conv1(x)
|
198 |
+
output = self.fu(x)
|
199 |
+
|
200 |
+
if self.enable_lfu:
|
201 |
+
n, c, h, w = x.shape
|
202 |
+
split_no = 2
|
203 |
+
split_s = h // split_no
|
204 |
+
xs = torch.cat(torch.split(
|
205 |
+
x[:, :c // 4], split_s, dim=-2), dim=1).contiguous()
|
206 |
+
xs = torch.cat(torch.split(xs, split_s, dim=-1),
|
207 |
+
dim=1).contiguous()
|
208 |
+
xs = self.lfu(xs)
|
209 |
+
xs = xs.repeat(1, 1, split_no, split_no).contiguous()
|
210 |
+
else:
|
211 |
+
xs = 0
|
212 |
+
|
213 |
+
output = self.conv2(x + output + xs)
|
214 |
+
|
215 |
+
return output
|
216 |
+
|
217 |
+
|
218 |
+
class FFC(nn.Module):
|
219 |
+
|
220 |
+
def __init__(self, in_channels, out_channels, kernel_size,
|
221 |
+
ratio_gin, ratio_gout, stride=1, padding=0,
|
222 |
+
dilation=1, groups=1, bias=False, enable_lfu=True,
|
223 |
+
padding_type='reflect', gated=False, **spectral_kwargs):
|
224 |
+
super(FFC, self).__init__()
|
225 |
+
|
226 |
+
assert stride == 1 or stride == 2, "Stride should be 1 or 2."
|
227 |
+
self.stride = stride
|
228 |
+
|
229 |
+
in_cg = int(in_channels * ratio_gin)
|
230 |
+
in_cl = in_channels - in_cg
|
231 |
+
out_cg = int(out_channels * ratio_gout)
|
232 |
+
out_cl = out_channels - out_cg
|
233 |
+
#groups_g = 1 if groups == 1 else int(groups * ratio_gout)
|
234 |
+
#groups_l = 1 if groups == 1 else groups - groups_g
|
235 |
+
|
236 |
+
self.ratio_gin = ratio_gin
|
237 |
+
self.ratio_gout = ratio_gout
|
238 |
+
self.global_in_num = in_cg
|
239 |
+
|
240 |
+
module = nn.Identity if in_cl == 0 or out_cl == 0 else nn.Conv2d
|
241 |
+
self.convl2l = module(in_cl, out_cl, kernel_size,
|
242 |
+
stride, padding, dilation, groups, bias, padding_mode=padding_type)
|
243 |
+
module = nn.Identity if in_cl == 0 or out_cg == 0 else nn.Conv2d
|
244 |
+
self.convl2g = module(in_cl, out_cg, kernel_size,
|
245 |
+
stride, padding, dilation, groups, bias, padding_mode=padding_type)
|
246 |
+
module = nn.Identity if in_cg == 0 or out_cl == 0 else nn.Conv2d
|
247 |
+
self.convg2l = module(in_cg, out_cl, kernel_size,
|
248 |
+
stride, padding, dilation, groups, bias, padding_mode=padding_type)
|
249 |
+
module = nn.Identity if in_cg == 0 or out_cg == 0 else SpectralTransform
|
250 |
+
self.convg2g = module(
|
251 |
+
in_cg, out_cg, stride, 1 if groups == 1 else groups // 2, enable_lfu, **spectral_kwargs)
|
252 |
+
|
253 |
+
self.gated = gated
|
254 |
+
module = nn.Identity if in_cg == 0 or out_cl == 0 or not self.gated else nn.Conv2d
|
255 |
+
self.gate = module(in_channels, 2, 1)
|
256 |
+
|
257 |
+
def forward(self, x):
|
258 |
+
x_l, x_g = x if type(x) is tuple else (x, 0)
|
259 |
+
out_xl, out_xg = 0, 0
|
260 |
+
|
261 |
+
if self.gated:
|
262 |
+
total_input_parts = [x_l]
|
263 |
+
if torch.is_tensor(x_g):
|
264 |
+
total_input_parts.append(x_g)
|
265 |
+
total_input = torch.cat(total_input_parts, dim=1)
|
266 |
+
|
267 |
+
gates = torch.sigmoid(self.gate(total_input))
|
268 |
+
g2l_gate, l2g_gate = gates.chunk(2, dim=1)
|
269 |
+
else:
|
270 |
+
g2l_gate, l2g_gate = 1, 1
|
271 |
+
|
272 |
+
if self.ratio_gout != 1:
|
273 |
+
out_xl = self.convl2l(x_l) + self.convg2l(x_g) * g2l_gate
|
274 |
+
if self.ratio_gout != 0:
|
275 |
+
out_xg = self.convl2g(x_l) * l2g_gate + self.convg2g(x_g)
|
276 |
+
|
277 |
+
return out_xl, out_xg
|
278 |
+
|
279 |
+
|
280 |
+
class FFC_BN_ACT(nn.Module):
|
281 |
+
|
282 |
+
def __init__(self, in_channels, out_channels,
|
283 |
+
kernel_size, ratio_gin, ratio_gout,
|
284 |
+
stride=1, padding=0, dilation=1, groups=1, bias=False,
|
285 |
+
norm_layer=nn.BatchNorm2d, activation_layer=nn.Identity,
|
286 |
+
padding_type='reflect',
|
287 |
+
enable_lfu=True, **kwargs):
|
288 |
+
super(FFC_BN_ACT, self).__init__()
|
289 |
+
self.ffc = FFC(in_channels, out_channels, kernel_size,
|
290 |
+
ratio_gin, ratio_gout, stride, padding, dilation,
|
291 |
+
groups, bias, enable_lfu, padding_type=padding_type, **kwargs)
|
292 |
+
lnorm = nn.Identity if ratio_gout == 1 else norm_layer
|
293 |
+
gnorm = nn.Identity if ratio_gout == 0 else norm_layer
|
294 |
+
global_channels = int(out_channels * ratio_gout)
|
295 |
+
self.bn_l = lnorm(out_channels - global_channels)
|
296 |
+
self.bn_g = gnorm(global_channels)
|
297 |
+
|
298 |
+
lact = nn.Identity if ratio_gout == 1 else activation_layer
|
299 |
+
gact = nn.Identity if ratio_gout == 0 else activation_layer
|
300 |
+
self.act_l = lact(inplace=True)
|
301 |
+
self.act_g = gact(inplace=True)
|
302 |
+
|
303 |
+
def forward(self, x):
|
304 |
+
x_l, x_g = self.ffc(x)
|
305 |
+
x_l = self.act_l(self.bn_l(x_l))
|
306 |
+
x_g = self.act_g(self.bn_g(x_g))
|
307 |
+
return x_l, x_g
|
308 |
+
|
309 |
+
|
310 |
+
class FFCResnetBlock(nn.Module):
|
311 |
+
def __init__(self, dim, padding_type, norm_layer, activation_layer=nn.ReLU, dilation=1,
|
312 |
+
spatial_transform_kwargs=None, inline=False, **conv_kwargs):
|
313 |
+
super().__init__()
|
314 |
+
self.conv1 = FFC_BN_ACT(dim, dim, kernel_size=3, padding=dilation, dilation=dilation,
|
315 |
+
norm_layer=norm_layer,
|
316 |
+
activation_layer=activation_layer,
|
317 |
+
padding_type=padding_type,
|
318 |
+
**conv_kwargs)
|
319 |
+
self.conv2 = FFC_BN_ACT(dim, dim, kernel_size=3, padding=dilation, dilation=dilation,
|
320 |
+
norm_layer=norm_layer,
|
321 |
+
activation_layer=activation_layer,
|
322 |
+
padding_type=padding_type,
|
323 |
+
**conv_kwargs)
|
324 |
+
if spatial_transform_kwargs is not None:
|
325 |
+
self.conv1 = LearnableSpatialTransformWrapper(self.conv1, **spatial_transform_kwargs)
|
326 |
+
self.conv2 = LearnableSpatialTransformWrapper(self.conv2, **spatial_transform_kwargs)
|
327 |
+
self.inline = inline
|
328 |
+
|
329 |
+
def forward(self, x):
|
330 |
+
if self.inline:
|
331 |
+
x_l, x_g = x[:, :-self.conv1.ffc.global_in_num], x[:, -self.conv1.ffc.global_in_num:]
|
332 |
+
else:
|
333 |
+
x_l, x_g = x if type(x) is tuple else (x, 0)
|
334 |
+
|
335 |
+
id_l, id_g = x_l, x_g
|
336 |
+
|
337 |
+
x_l, x_g = self.conv1((x_l, x_g))
|
338 |
+
x_l, x_g = self.conv2((x_l, x_g))
|
339 |
+
|
340 |
+
x_l, x_g = id_l + x_l, id_g + x_g
|
341 |
+
out = x_l, x_g
|
342 |
+
if self.inline:
|
343 |
+
out = torch.cat(out, dim=1)
|
344 |
+
return out
|
345 |
+
|
346 |
+
|
347 |
+
class ConcatTupleLayer(nn.Module):
|
348 |
+
def forward(self, x):
|
349 |
+
assert isinstance(x, tuple)
|
350 |
+
x_l, x_g = x
|
351 |
+
assert torch.is_tensor(x_l) or torch.is_tensor(x_g)
|
352 |
+
if not torch.is_tensor(x_g):
|
353 |
+
return x_l
|
354 |
+
return torch.cat(x, dim=1)
|
355 |
+
|
356 |
+
|
357 |
+
class FFCResNetGenerator(nn.Module):
|
358 |
+
def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=3, n_blocks=9, norm_layer=nn.BatchNorm2d,
|
359 |
+
padding_type='reflect', activation_layer=nn.ReLU,
|
360 |
+
up_norm_layer=nn.BatchNorm2d, up_activation=nn.ReLU(True),
|
361 |
+
init_conv_kwargs={}, downsample_conv_kwargs={}, resnet_conv_kwargs={},
|
362 |
+
spatial_transform_layers=None, spatial_transform_kwargs={},
|
363 |
+
add_out_act=True, max_features=1024, out_ffc=False, out_ffc_kwargs={}):
|
364 |
+
assert (n_blocks >= 0)
|
365 |
+
super().__init__()
|
366 |
+
|
367 |
+
model = [nn.ReflectionPad2d(3),
|
368 |
+
FFC_BN_ACT(input_nc, ngf, kernel_size=7, padding=0, norm_layer=norm_layer,
|
369 |
+
activation_layer=activation_layer, **init_conv_kwargs)]
|
370 |
+
|
371 |
+
### downsample
|
372 |
+
for i in range(n_downsampling):
|
373 |
+
mult = 2 ** i
|
374 |
+
if i == n_downsampling - 1:
|
375 |
+
cur_conv_kwargs = dict(downsample_conv_kwargs)
|
376 |
+
cur_conv_kwargs['ratio_gout'] = resnet_conv_kwargs.get('ratio_gin', 0)
|
377 |
+
else:
|
378 |
+
cur_conv_kwargs = downsample_conv_kwargs
|
379 |
+
model += [FFC_BN_ACT(min(max_features, ngf * mult),
|
380 |
+
min(max_features, ngf * mult * 2),
|
381 |
+
kernel_size=3, stride=2, padding=1,
|
382 |
+
norm_layer=norm_layer,
|
383 |
+
activation_layer=activation_layer,
|
384 |
+
**cur_conv_kwargs)]
|
385 |
+
|
386 |
+
mult = 2 ** n_downsampling
|
387 |
+
feats_num_bottleneck = min(max_features, ngf * mult)
|
388 |
+
|
389 |
+
### resnet blocks
|
390 |
+
for i in range(n_blocks):
|
391 |
+
cur_resblock = FFCResnetBlock(feats_num_bottleneck, padding_type=padding_type, activation_layer=activation_layer,
|
392 |
+
norm_layer=norm_layer, **resnet_conv_kwargs)
|
393 |
+
if spatial_transform_layers is not None and i in spatial_transform_layers:
|
394 |
+
cur_resblock = LearnableSpatialTransformWrapper(cur_resblock, **spatial_transform_kwargs)
|
395 |
+
model += [cur_resblock]
|
396 |
+
|
397 |
+
model += [ConcatTupleLayer()]
|
398 |
+
|
399 |
+
### upsample
|
400 |
+
for i in range(n_downsampling):
|
401 |
+
mult = 2 ** (n_downsampling - i)
|
402 |
+
model += [nn.ConvTranspose2d(min(max_features, ngf * mult),
|
403 |
+
min(max_features, int(ngf * mult / 2)),
|
404 |
+
kernel_size=3, stride=2, padding=1, output_padding=1),
|
405 |
+
up_norm_layer(min(max_features, int(ngf * mult / 2))),
|
406 |
+
up_activation]
|
407 |
+
|
408 |
+
if out_ffc:
|
409 |
+
model += [FFCResnetBlock(ngf, padding_type=padding_type, activation_layer=activation_layer,
|
410 |
+
norm_layer=norm_layer, inline=True, **out_ffc_kwargs)]
|
411 |
+
|
412 |
+
model += [nn.ReflectionPad2d(3),
|
413 |
+
nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
|
414 |
+
if add_out_act:
|
415 |
+
model.append(get_activation('tanh' if add_out_act is True else add_out_act))
|
416 |
+
self.model = nn.Sequential(*model)
|
417 |
+
|
418 |
+
def forward(self, input):
|
419 |
+
return self.model(input)
|
420 |
+
|
421 |
+
|
422 |
+
class FFCNLayerDiscriminator(BaseDiscriminator):
|
423 |
+
def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, max_features=512,
|
424 |
+
init_conv_kwargs={}, conv_kwargs={}):
|
425 |
+
super().__init__()
|
426 |
+
self.n_layers = n_layers
|
427 |
+
|
428 |
+
def _act_ctor(inplace=True):
|
429 |
+
return nn.LeakyReLU(negative_slope=0.2, inplace=inplace)
|
430 |
+
|
431 |
+
kw = 3
|
432 |
+
padw = int(np.ceil((kw-1.0)/2))
|
433 |
+
sequence = [[FFC_BN_ACT(input_nc, ndf, kernel_size=kw, padding=padw, norm_layer=norm_layer,
|
434 |
+
activation_layer=_act_ctor, **init_conv_kwargs)]]
|
435 |
+
|
436 |
+
nf = ndf
|
437 |
+
for n in range(1, n_layers):
|
438 |
+
nf_prev = nf
|
439 |
+
nf = min(nf * 2, max_features)
|
440 |
+
|
441 |
+
cur_model = [
|
442 |
+
FFC_BN_ACT(nf_prev, nf,
|
443 |
+
kernel_size=kw, stride=2, padding=padw,
|
444 |
+
norm_layer=norm_layer,
|
445 |
+
activation_layer=_act_ctor,
|
446 |
+
**conv_kwargs)
|
447 |
+
]
|
448 |
+
sequence.append(cur_model)
|
449 |
+
|
450 |
+
nf_prev = nf
|
451 |
+
nf = min(nf * 2, 512)
|
452 |
+
|
453 |
+
cur_model = [
|
454 |
+
FFC_BN_ACT(nf_prev, nf,
|
455 |
+
kernel_size=kw, stride=1, padding=padw,
|
456 |
+
norm_layer=norm_layer,
|
457 |
+
activation_layer=lambda *args, **kwargs: nn.LeakyReLU(*args, negative_slope=0.2, **kwargs),
|
458 |
+
**conv_kwargs),
|
459 |
+
ConcatTupleLayer()
|
460 |
+
]
|
461 |
+
sequence.append(cur_model)
|
462 |
+
|
463 |
+
sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]]
|
464 |
+
|
465 |
+
for n in range(len(sequence)):
|
466 |
+
setattr(self, 'model'+str(n), nn.Sequential(*sequence[n]))
|
467 |
+
|
468 |
+
def get_all_activations(self, x):
|
469 |
+
res = [x]
|
470 |
+
for n in range(self.n_layers + 2):
|
471 |
+
model = getattr(self, 'model' + str(n))
|
472 |
+
res.append(model(res[-1]))
|
473 |
+
return res[1:]
|
474 |
+
|
475 |
+
def forward(self, x):
|
476 |
+
act = self.get_all_activations(x)
|
477 |
+
feats = []
|
478 |
+
for out in act[:-1]:
|
479 |
+
if isinstance(out, tuple):
|
480 |
+
if torch.is_tensor(out[1]):
|
481 |
+
out = torch.cat(out, dim=1)
|
482 |
+
else:
|
483 |
+
out = out[0]
|
484 |
+
feats.append(out)
|
485 |
+
return act[-1], feats
|
annotator/lama/saicinpainting/training/modules/multidilated_conv.py
ADDED
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import random
|
4 |
+
from annotator.lama.saicinpainting.training.modules.depthwise_sep_conv import DepthWiseSeperableConv
|
5 |
+
|
6 |
+
class MultidilatedConv(nn.Module):
|
7 |
+
def __init__(self, in_dim, out_dim, kernel_size, dilation_num=3, comb_mode='sum', equal_dim=True,
|
8 |
+
shared_weights=False, padding=1, min_dilation=1, shuffle_in_channels=False, use_depthwise=False, **kwargs):
|
9 |
+
super().__init__()
|
10 |
+
convs = []
|
11 |
+
self.equal_dim = equal_dim
|
12 |
+
assert comb_mode in ('cat_out', 'sum', 'cat_in', 'cat_both'), comb_mode
|
13 |
+
if comb_mode in ('cat_out', 'cat_both'):
|
14 |
+
self.cat_out = True
|
15 |
+
if equal_dim:
|
16 |
+
assert out_dim % dilation_num == 0
|
17 |
+
out_dims = [out_dim // dilation_num] * dilation_num
|
18 |
+
self.index = sum([[i + j * (out_dims[0]) for j in range(dilation_num)] for i in range(out_dims[0])], [])
|
19 |
+
else:
|
20 |
+
out_dims = [out_dim // 2 ** (i + 1) for i in range(dilation_num - 1)]
|
21 |
+
out_dims.append(out_dim - sum(out_dims))
|
22 |
+
index = []
|
23 |
+
starts = [0] + out_dims[:-1]
|
24 |
+
lengths = [out_dims[i] // out_dims[-1] for i in range(dilation_num)]
|
25 |
+
for i in range(out_dims[-1]):
|
26 |
+
for j in range(dilation_num):
|
27 |
+
index += list(range(starts[j], starts[j] + lengths[j]))
|
28 |
+
starts[j] += lengths[j]
|
29 |
+
self.index = index
|
30 |
+
assert(len(index) == out_dim)
|
31 |
+
self.out_dims = out_dims
|
32 |
+
else:
|
33 |
+
self.cat_out = False
|
34 |
+
self.out_dims = [out_dim] * dilation_num
|
35 |
+
|
36 |
+
if comb_mode in ('cat_in', 'cat_both'):
|
37 |
+
if equal_dim:
|
38 |
+
assert in_dim % dilation_num == 0
|
39 |
+
in_dims = [in_dim // dilation_num] * dilation_num
|
40 |
+
else:
|
41 |
+
in_dims = [in_dim // 2 ** (i + 1) for i in range(dilation_num - 1)]
|
42 |
+
in_dims.append(in_dim - sum(in_dims))
|
43 |
+
self.in_dims = in_dims
|
44 |
+
self.cat_in = True
|
45 |
+
else:
|
46 |
+
self.cat_in = False
|
47 |
+
self.in_dims = [in_dim] * dilation_num
|
48 |
+
|
49 |
+
conv_type = DepthWiseSeperableConv if use_depthwise else nn.Conv2d
|
50 |
+
dilation = min_dilation
|
51 |
+
for i in range(dilation_num):
|
52 |
+
if isinstance(padding, int):
|
53 |
+
cur_padding = padding * dilation
|
54 |
+
else:
|
55 |
+
cur_padding = padding[i]
|
56 |
+
convs.append(conv_type(
|
57 |
+
self.in_dims[i], self.out_dims[i], kernel_size, padding=cur_padding, dilation=dilation, **kwargs
|
58 |
+
))
|
59 |
+
if i > 0 and shared_weights:
|
60 |
+
convs[-1].weight = convs[0].weight
|
61 |
+
convs[-1].bias = convs[0].bias
|
62 |
+
dilation *= 2
|
63 |
+
self.convs = nn.ModuleList(convs)
|
64 |
+
|
65 |
+
self.shuffle_in_channels = shuffle_in_channels
|
66 |
+
if self.shuffle_in_channels:
|
67 |
+
# shuffle list as shuffling of tensors is nondeterministic
|
68 |
+
in_channels_permute = list(range(in_dim))
|
69 |
+
random.shuffle(in_channels_permute)
|
70 |
+
# save as buffer so it is saved and loaded with checkpoint
|
71 |
+
self.register_buffer('in_channels_permute', torch.tensor(in_channels_permute))
|
72 |
+
|
73 |
+
def forward(self, x):
|
74 |
+
if self.shuffle_in_channels:
|
75 |
+
x = x[:, self.in_channels_permute]
|
76 |
+
|
77 |
+
outs = []
|
78 |
+
if self.cat_in:
|
79 |
+
if self.equal_dim:
|
80 |
+
x = x.chunk(len(self.convs), dim=1)
|
81 |
+
else:
|
82 |
+
new_x = []
|
83 |
+
start = 0
|
84 |
+
for dim in self.in_dims:
|
85 |
+
new_x.append(x[:, start:start+dim])
|
86 |
+
start += dim
|
87 |
+
x = new_x
|
88 |
+
for i, conv in enumerate(self.convs):
|
89 |
+
if self.cat_in:
|
90 |
+
input = x[i]
|
91 |
+
else:
|
92 |
+
input = x
|
93 |
+
outs.append(conv(input))
|
94 |
+
if self.cat_out:
|
95 |
+
out = torch.cat(outs, dim=1)[:, self.index]
|
96 |
+
else:
|
97 |
+
out = sum(outs)
|
98 |
+
return out
|
annotator/lama/saicinpainting/training/modules/multiscale.py
ADDED
@@ -0,0 +1,244 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List, Tuple, Union, Optional
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
|
7 |
+
from annotator.lama.saicinpainting.training.modules.base import get_conv_block_ctor, get_activation
|
8 |
+
from annotator.lama.saicinpainting.training.modules.pix2pixhd import ResnetBlock
|
9 |
+
|
10 |
+
|
11 |
+
class ResNetHead(nn.Module):
|
12 |
+
def __init__(self, input_nc, ngf=64, n_downsampling=3, n_blocks=9, norm_layer=nn.BatchNorm2d,
|
13 |
+
padding_type='reflect', conv_kind='default', activation=nn.ReLU(True)):
|
14 |
+
assert (n_blocks >= 0)
|
15 |
+
super(ResNetHead, self).__init__()
|
16 |
+
|
17 |
+
conv_layer = get_conv_block_ctor(conv_kind)
|
18 |
+
|
19 |
+
model = [nn.ReflectionPad2d(3),
|
20 |
+
conv_layer(input_nc, ngf, kernel_size=7, padding=0),
|
21 |
+
norm_layer(ngf),
|
22 |
+
activation]
|
23 |
+
|
24 |
+
### downsample
|
25 |
+
for i in range(n_downsampling):
|
26 |
+
mult = 2 ** i
|
27 |
+
model += [conv_layer(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1),
|
28 |
+
norm_layer(ngf * mult * 2),
|
29 |
+
activation]
|
30 |
+
|
31 |
+
mult = 2 ** n_downsampling
|
32 |
+
|
33 |
+
### resnet blocks
|
34 |
+
for i in range(n_blocks):
|
35 |
+
model += [ResnetBlock(ngf * mult, padding_type=padding_type, activation=activation, norm_layer=norm_layer,
|
36 |
+
conv_kind=conv_kind)]
|
37 |
+
|
38 |
+
self.model = nn.Sequential(*model)
|
39 |
+
|
40 |
+
def forward(self, input):
|
41 |
+
return self.model(input)
|
42 |
+
|
43 |
+
|
44 |
+
class ResNetTail(nn.Module):
|
45 |
+
def __init__(self, output_nc, ngf=64, n_downsampling=3, n_blocks=9, norm_layer=nn.BatchNorm2d,
|
46 |
+
padding_type='reflect', conv_kind='default', activation=nn.ReLU(True),
|
47 |
+
up_norm_layer=nn.BatchNorm2d, up_activation=nn.ReLU(True), add_out_act=False, out_extra_layers_n=0,
|
48 |
+
add_in_proj=None):
|
49 |
+
assert (n_blocks >= 0)
|
50 |
+
super(ResNetTail, self).__init__()
|
51 |
+
|
52 |
+
mult = 2 ** n_downsampling
|
53 |
+
|
54 |
+
model = []
|
55 |
+
|
56 |
+
if add_in_proj is not None:
|
57 |
+
model.append(nn.Conv2d(add_in_proj, ngf * mult, kernel_size=1))
|
58 |
+
|
59 |
+
### resnet blocks
|
60 |
+
for i in range(n_blocks):
|
61 |
+
model += [ResnetBlock(ngf * mult, padding_type=padding_type, activation=activation, norm_layer=norm_layer,
|
62 |
+
conv_kind=conv_kind)]
|
63 |
+
|
64 |
+
### upsample
|
65 |
+
for i in range(n_downsampling):
|
66 |
+
mult = 2 ** (n_downsampling - i)
|
67 |
+
model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=2, padding=1,
|
68 |
+
output_padding=1),
|
69 |
+
up_norm_layer(int(ngf * mult / 2)),
|
70 |
+
up_activation]
|
71 |
+
self.model = nn.Sequential(*model)
|
72 |
+
|
73 |
+
out_layers = []
|
74 |
+
for _ in range(out_extra_layers_n):
|
75 |
+
out_layers += [nn.Conv2d(ngf, ngf, kernel_size=1, padding=0),
|
76 |
+
up_norm_layer(ngf),
|
77 |
+
up_activation]
|
78 |
+
out_layers += [nn.ReflectionPad2d(3),
|
79 |
+
nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
|
80 |
+
|
81 |
+
if add_out_act:
|
82 |
+
out_layers.append(get_activation('tanh' if add_out_act is True else add_out_act))
|
83 |
+
|
84 |
+
self.out_proj = nn.Sequential(*out_layers)
|
85 |
+
|
86 |
+
def forward(self, input, return_last_act=False):
|
87 |
+
features = self.model(input)
|
88 |
+
out = self.out_proj(features)
|
89 |
+
if return_last_act:
|
90 |
+
return out, features
|
91 |
+
else:
|
92 |
+
return out
|
93 |
+
|
94 |
+
|
95 |
+
class MultiscaleResNet(nn.Module):
|
96 |
+
def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=2, n_blocks_head=2, n_blocks_tail=6, n_scales=3,
|
97 |
+
norm_layer=nn.BatchNorm2d, padding_type='reflect', conv_kind='default', activation=nn.ReLU(True),
|
98 |
+
up_norm_layer=nn.BatchNorm2d, up_activation=nn.ReLU(True), add_out_act=False, out_extra_layers_n=0,
|
99 |
+
out_cumulative=False, return_only_hr=False):
|
100 |
+
super().__init__()
|
101 |
+
|
102 |
+
self.heads = nn.ModuleList([ResNetHead(input_nc, ngf=ngf, n_downsampling=n_downsampling,
|
103 |
+
n_blocks=n_blocks_head, norm_layer=norm_layer, padding_type=padding_type,
|
104 |
+
conv_kind=conv_kind, activation=activation)
|
105 |
+
for i in range(n_scales)])
|
106 |
+
tail_in_feats = ngf * (2 ** n_downsampling) + ngf
|
107 |
+
self.tails = nn.ModuleList([ResNetTail(output_nc,
|
108 |
+
ngf=ngf, n_downsampling=n_downsampling,
|
109 |
+
n_blocks=n_blocks_tail, norm_layer=norm_layer, padding_type=padding_type,
|
110 |
+
conv_kind=conv_kind, activation=activation, up_norm_layer=up_norm_layer,
|
111 |
+
up_activation=up_activation, add_out_act=add_out_act,
|
112 |
+
out_extra_layers_n=out_extra_layers_n,
|
113 |
+
add_in_proj=None if (i == n_scales - 1) else tail_in_feats)
|
114 |
+
for i in range(n_scales)])
|
115 |
+
|
116 |
+
self.out_cumulative = out_cumulative
|
117 |
+
self.return_only_hr = return_only_hr
|
118 |
+
|
119 |
+
@property
|
120 |
+
def num_scales(self):
|
121 |
+
return len(self.heads)
|
122 |
+
|
123 |
+
def forward(self, ms_inputs: List[torch.Tensor], smallest_scales_num: Optional[int] = None) \
|
124 |
+
-> Union[torch.Tensor, List[torch.Tensor]]:
|
125 |
+
"""
|
126 |
+
:param ms_inputs: List of inputs of different resolutions from HR to LR
|
127 |
+
:param smallest_scales_num: int or None, number of smallest scales to take at input
|
128 |
+
:return: Depending on return_only_hr:
|
129 |
+
True: Only the most HR output
|
130 |
+
False: List of outputs of different resolutions from HR to LR
|
131 |
+
"""
|
132 |
+
if smallest_scales_num is None:
|
133 |
+
assert len(self.heads) == len(ms_inputs), (len(self.heads), len(ms_inputs), smallest_scales_num)
|
134 |
+
smallest_scales_num = len(self.heads)
|
135 |
+
else:
|
136 |
+
assert smallest_scales_num == len(ms_inputs) <= len(self.heads), (len(self.heads), len(ms_inputs), smallest_scales_num)
|
137 |
+
|
138 |
+
cur_heads = self.heads[-smallest_scales_num:]
|
139 |
+
ms_features = [cur_head(cur_inp) for cur_head, cur_inp in zip(cur_heads, ms_inputs)]
|
140 |
+
|
141 |
+
all_outputs = []
|
142 |
+
prev_tail_features = None
|
143 |
+
for i in range(len(ms_features)):
|
144 |
+
scale_i = -i - 1
|
145 |
+
|
146 |
+
cur_tail_input = ms_features[-i - 1]
|
147 |
+
if prev_tail_features is not None:
|
148 |
+
if prev_tail_features.shape != cur_tail_input.shape:
|
149 |
+
prev_tail_features = F.interpolate(prev_tail_features, size=cur_tail_input.shape[2:],
|
150 |
+
mode='bilinear', align_corners=False)
|
151 |
+
cur_tail_input = torch.cat((cur_tail_input, prev_tail_features), dim=1)
|
152 |
+
|
153 |
+
cur_out, cur_tail_feats = self.tails[scale_i](cur_tail_input, return_last_act=True)
|
154 |
+
|
155 |
+
prev_tail_features = cur_tail_feats
|
156 |
+
all_outputs.append(cur_out)
|
157 |
+
|
158 |
+
if self.out_cumulative:
|
159 |
+
all_outputs_cum = [all_outputs[0]]
|
160 |
+
for i in range(1, len(ms_features)):
|
161 |
+
cur_out = all_outputs[i]
|
162 |
+
cur_out_cum = cur_out + F.interpolate(all_outputs_cum[-1], size=cur_out.shape[2:],
|
163 |
+
mode='bilinear', align_corners=False)
|
164 |
+
all_outputs_cum.append(cur_out_cum)
|
165 |
+
all_outputs = all_outputs_cum
|
166 |
+
|
167 |
+
if self.return_only_hr:
|
168 |
+
return all_outputs[-1]
|
169 |
+
else:
|
170 |
+
return all_outputs[::-1]
|
171 |
+
|
172 |
+
|
173 |
+
class MultiscaleDiscriminatorSimple(nn.Module):
|
174 |
+
def __init__(self, ms_impl):
|
175 |
+
super().__init__()
|
176 |
+
self.ms_impl = nn.ModuleList(ms_impl)
|
177 |
+
|
178 |
+
@property
|
179 |
+
def num_scales(self):
|
180 |
+
return len(self.ms_impl)
|
181 |
+
|
182 |
+
def forward(self, ms_inputs: List[torch.Tensor], smallest_scales_num: Optional[int] = None) \
|
183 |
+
-> List[Tuple[torch.Tensor, List[torch.Tensor]]]:
|
184 |
+
"""
|
185 |
+
:param ms_inputs: List of inputs of different resolutions from HR to LR
|
186 |
+
:param smallest_scales_num: int or None, number of smallest scales to take at input
|
187 |
+
:return: List of pairs (prediction, features) for different resolutions from HR to LR
|
188 |
+
"""
|
189 |
+
if smallest_scales_num is None:
|
190 |
+
assert len(self.ms_impl) == len(ms_inputs), (len(self.ms_impl), len(ms_inputs), smallest_scales_num)
|
191 |
+
smallest_scales_num = len(self.heads)
|
192 |
+
else:
|
193 |
+
assert smallest_scales_num == len(ms_inputs) <= len(self.ms_impl), \
|
194 |
+
(len(self.ms_impl), len(ms_inputs), smallest_scales_num)
|
195 |
+
|
196 |
+
return [cur_discr(cur_input) for cur_discr, cur_input in zip(self.ms_impl[-smallest_scales_num:], ms_inputs)]
|
197 |
+
|
198 |
+
|
199 |
+
class SingleToMultiScaleInputMixin:
|
200 |
+
def forward(self, x: torch.Tensor) -> List:
|
201 |
+
orig_height, orig_width = x.shape[2:]
|
202 |
+
factors = [2 ** i for i in range(self.num_scales)]
|
203 |
+
ms_inputs = [F.interpolate(x, size=(orig_height // f, orig_width // f), mode='bilinear', align_corners=False)
|
204 |
+
for f in factors]
|
205 |
+
return super().forward(ms_inputs)
|
206 |
+
|
207 |
+
|
208 |
+
class GeneratorMultiToSingleOutputMixin:
|
209 |
+
def forward(self, x):
|
210 |
+
return super().forward(x)[0]
|
211 |
+
|
212 |
+
|
213 |
+
class DiscriminatorMultiToSingleOutputMixin:
|
214 |
+
def forward(self, x):
|
215 |
+
out_feat_tuples = super().forward(x)
|
216 |
+
return out_feat_tuples[0][0], [f for _, flist in out_feat_tuples for f in flist]
|
217 |
+
|
218 |
+
|
219 |
+
class DiscriminatorMultiToSingleOutputStackedMixin:
|
220 |
+
def __init__(self, *args, return_feats_only_levels=None, **kwargs):
|
221 |
+
super().__init__(*args, **kwargs)
|
222 |
+
self.return_feats_only_levels = return_feats_only_levels
|
223 |
+
|
224 |
+
def forward(self, x):
|
225 |
+
out_feat_tuples = super().forward(x)
|
226 |
+
outs = [out for out, _ in out_feat_tuples]
|
227 |
+
scaled_outs = [outs[0]] + [F.interpolate(cur_out, size=outs[0].shape[-2:],
|
228 |
+
mode='bilinear', align_corners=False)
|
229 |
+
for cur_out in outs[1:]]
|
230 |
+
out = torch.cat(scaled_outs, dim=1)
|
231 |
+
if self.return_feats_only_levels is not None:
|
232 |
+
feat_lists = [out_feat_tuples[i][1] for i in self.return_feats_only_levels]
|
233 |
+
else:
|
234 |
+
feat_lists = [flist for _, flist in out_feat_tuples]
|
235 |
+
feats = [f for flist in feat_lists for f in flist]
|
236 |
+
return out, feats
|
237 |
+
|
238 |
+
|
239 |
+
class MultiscaleDiscrSingleInput(SingleToMultiScaleInputMixin, DiscriminatorMultiToSingleOutputStackedMixin, MultiscaleDiscriminatorSimple):
|
240 |
+
pass
|
241 |
+
|
242 |
+
|
243 |
+
class MultiscaleResNetSingle(GeneratorMultiToSingleOutputMixin, SingleToMultiScaleInputMixin, MultiscaleResNet):
|
244 |
+
pass
|
annotator/lama/saicinpainting/training/modules/pix2pixhd.py
ADDED
@@ -0,0 +1,669 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# original: https://github.com/NVIDIA/pix2pixHD/blob/master/models/networks.py
|
2 |
+
import collections
|
3 |
+
from functools import partial
|
4 |
+
import functools
|
5 |
+
import logging
|
6 |
+
from collections import defaultdict
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
import torch.nn as nn
|
10 |
+
|
11 |
+
from annotator.lama.saicinpainting.training.modules.base import BaseDiscriminator, deconv_factory, get_conv_block_ctor, get_norm_layer, get_activation
|
12 |
+
from annotator.lama.saicinpainting.training.modules.ffc import FFCResnetBlock
|
13 |
+
from annotator.lama.saicinpainting.training.modules.multidilated_conv import MultidilatedConv
|
14 |
+
|
15 |
+
class DotDict(defaultdict):
|
16 |
+
# https://stackoverflow.com/questions/2352181/how-to-use-a-dot-to-access-members-of-dictionary
|
17 |
+
"""dot.notation access to dictionary attributes"""
|
18 |
+
__getattr__ = defaultdict.get
|
19 |
+
__setattr__ = defaultdict.__setitem__
|
20 |
+
__delattr__ = defaultdict.__delitem__
|
21 |
+
|
22 |
+
class Identity(nn.Module):
|
23 |
+
def __init__(self):
|
24 |
+
super().__init__()
|
25 |
+
|
26 |
+
def forward(self, x):
|
27 |
+
return x
|
28 |
+
|
29 |
+
|
30 |
+
class ResnetBlock(nn.Module):
|
31 |
+
def __init__(self, dim, padding_type, norm_layer, activation=nn.ReLU(True), use_dropout=False, conv_kind='default',
|
32 |
+
dilation=1, in_dim=None, groups=1, second_dilation=None):
|
33 |
+
super(ResnetBlock, self).__init__()
|
34 |
+
self.in_dim = in_dim
|
35 |
+
self.dim = dim
|
36 |
+
if second_dilation is None:
|
37 |
+
second_dilation = dilation
|
38 |
+
self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, activation, use_dropout,
|
39 |
+
conv_kind=conv_kind, dilation=dilation, in_dim=in_dim, groups=groups,
|
40 |
+
second_dilation=second_dilation)
|
41 |
+
|
42 |
+
if self.in_dim is not None:
|
43 |
+
self.input_conv = nn.Conv2d(in_dim, dim, 1)
|
44 |
+
|
45 |
+
self.out_channnels = dim
|
46 |
+
|
47 |
+
def build_conv_block(self, dim, padding_type, norm_layer, activation, use_dropout, conv_kind='default',
|
48 |
+
dilation=1, in_dim=None, groups=1, second_dilation=1):
|
49 |
+
conv_layer = get_conv_block_ctor(conv_kind)
|
50 |
+
|
51 |
+
conv_block = []
|
52 |
+
p = 0
|
53 |
+
if padding_type == 'reflect':
|
54 |
+
conv_block += [nn.ReflectionPad2d(dilation)]
|
55 |
+
elif padding_type == 'replicate':
|
56 |
+
conv_block += [nn.ReplicationPad2d(dilation)]
|
57 |
+
elif padding_type == 'zero':
|
58 |
+
p = dilation
|
59 |
+
else:
|
60 |
+
raise NotImplementedError('padding [%s] is not implemented' % padding_type)
|
61 |
+
|
62 |
+
if in_dim is None:
|
63 |
+
in_dim = dim
|
64 |
+
|
65 |
+
conv_block += [conv_layer(in_dim, dim, kernel_size=3, padding=p, dilation=dilation),
|
66 |
+
norm_layer(dim),
|
67 |
+
activation]
|
68 |
+
if use_dropout:
|
69 |
+
conv_block += [nn.Dropout(0.5)]
|
70 |
+
|
71 |
+
p = 0
|
72 |
+
if padding_type == 'reflect':
|
73 |
+
conv_block += [nn.ReflectionPad2d(second_dilation)]
|
74 |
+
elif padding_type == 'replicate':
|
75 |
+
conv_block += [nn.ReplicationPad2d(second_dilation)]
|
76 |
+
elif padding_type == 'zero':
|
77 |
+
p = second_dilation
|
78 |
+
else:
|
79 |
+
raise NotImplementedError('padding [%s] is not implemented' % padding_type)
|
80 |
+
conv_block += [conv_layer(dim, dim, kernel_size=3, padding=p, dilation=second_dilation, groups=groups),
|
81 |
+
norm_layer(dim)]
|
82 |
+
|
83 |
+
return nn.Sequential(*conv_block)
|
84 |
+
|
85 |
+
def forward(self, x):
|
86 |
+
x_before = x
|
87 |
+
if self.in_dim is not None:
|
88 |
+
x = self.input_conv(x)
|
89 |
+
out = x + self.conv_block(x_before)
|
90 |
+
return out
|
91 |
+
|
92 |
+
class ResnetBlock5x5(nn.Module):
|
93 |
+
def __init__(self, dim, padding_type, norm_layer, activation=nn.ReLU(True), use_dropout=False, conv_kind='default',
|
94 |
+
dilation=1, in_dim=None, groups=1, second_dilation=None):
|
95 |
+
super(ResnetBlock5x5, self).__init__()
|
96 |
+
self.in_dim = in_dim
|
97 |
+
self.dim = dim
|
98 |
+
if second_dilation is None:
|
99 |
+
second_dilation = dilation
|
100 |
+
self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, activation, use_dropout,
|
101 |
+
conv_kind=conv_kind, dilation=dilation, in_dim=in_dim, groups=groups,
|
102 |
+
second_dilation=second_dilation)
|
103 |
+
|
104 |
+
if self.in_dim is not None:
|
105 |
+
self.input_conv = nn.Conv2d(in_dim, dim, 1)
|
106 |
+
|
107 |
+
self.out_channnels = dim
|
108 |
+
|
109 |
+
def build_conv_block(self, dim, padding_type, norm_layer, activation, use_dropout, conv_kind='default',
|
110 |
+
dilation=1, in_dim=None, groups=1, second_dilation=1):
|
111 |
+
conv_layer = get_conv_block_ctor(conv_kind)
|
112 |
+
|
113 |
+
conv_block = []
|
114 |
+
p = 0
|
115 |
+
if padding_type == 'reflect':
|
116 |
+
conv_block += [nn.ReflectionPad2d(dilation * 2)]
|
117 |
+
elif padding_type == 'replicate':
|
118 |
+
conv_block += [nn.ReplicationPad2d(dilation * 2)]
|
119 |
+
elif padding_type == 'zero':
|
120 |
+
p = dilation * 2
|
121 |
+
else:
|
122 |
+
raise NotImplementedError('padding [%s] is not implemented' % padding_type)
|
123 |
+
|
124 |
+
if in_dim is None:
|
125 |
+
in_dim = dim
|
126 |
+
|
127 |
+
conv_block += [conv_layer(in_dim, dim, kernel_size=5, padding=p, dilation=dilation),
|
128 |
+
norm_layer(dim),
|
129 |
+
activation]
|
130 |
+
if use_dropout:
|
131 |
+
conv_block += [nn.Dropout(0.5)]
|
132 |
+
|
133 |
+
p = 0
|
134 |
+
if padding_type == 'reflect':
|
135 |
+
conv_block += [nn.ReflectionPad2d(second_dilation * 2)]
|
136 |
+
elif padding_type == 'replicate':
|
137 |
+
conv_block += [nn.ReplicationPad2d(second_dilation * 2)]
|
138 |
+
elif padding_type == 'zero':
|
139 |
+
p = second_dilation * 2
|
140 |
+
else:
|
141 |
+
raise NotImplementedError('padding [%s] is not implemented' % padding_type)
|
142 |
+
conv_block += [conv_layer(dim, dim, kernel_size=5, padding=p, dilation=second_dilation, groups=groups),
|
143 |
+
norm_layer(dim)]
|
144 |
+
|
145 |
+
return nn.Sequential(*conv_block)
|
146 |
+
|
147 |
+
def forward(self, x):
|
148 |
+
x_before = x
|
149 |
+
if self.in_dim is not None:
|
150 |
+
x = self.input_conv(x)
|
151 |
+
out = x + self.conv_block(x_before)
|
152 |
+
return out
|
153 |
+
|
154 |
+
|
155 |
+
class MultidilatedResnetBlock(nn.Module):
|
156 |
+
def __init__(self, dim, padding_type, conv_layer, norm_layer, activation=nn.ReLU(True), use_dropout=False):
|
157 |
+
super().__init__()
|
158 |
+
self.conv_block = self.build_conv_block(dim, padding_type, conv_layer, norm_layer, activation, use_dropout)
|
159 |
+
|
160 |
+
def build_conv_block(self, dim, padding_type, conv_layer, norm_layer, activation, use_dropout, dilation=1):
|
161 |
+
conv_block = []
|
162 |
+
conv_block += [conv_layer(dim, dim, kernel_size=3, padding_mode=padding_type),
|
163 |
+
norm_layer(dim),
|
164 |
+
activation]
|
165 |
+
if use_dropout:
|
166 |
+
conv_block += [nn.Dropout(0.5)]
|
167 |
+
|
168 |
+
conv_block += [conv_layer(dim, dim, kernel_size=3, padding_mode=padding_type),
|
169 |
+
norm_layer(dim)]
|
170 |
+
|
171 |
+
return nn.Sequential(*conv_block)
|
172 |
+
|
173 |
+
def forward(self, x):
|
174 |
+
out = x + self.conv_block(x)
|
175 |
+
return out
|
176 |
+
|
177 |
+
|
178 |
+
class MultiDilatedGlobalGenerator(nn.Module):
|
179 |
+
def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=3,
|
180 |
+
n_blocks=3, norm_layer=nn.BatchNorm2d,
|
181 |
+
padding_type='reflect', conv_kind='default',
|
182 |
+
deconv_kind='convtranspose', activation=nn.ReLU(True),
|
183 |
+
up_norm_layer=nn.BatchNorm2d, affine=None, up_activation=nn.ReLU(True),
|
184 |
+
add_out_act=True, max_features=1024, multidilation_kwargs={},
|
185 |
+
ffc_positions=None, ffc_kwargs={}):
|
186 |
+
assert (n_blocks >= 0)
|
187 |
+
super().__init__()
|
188 |
+
|
189 |
+
conv_layer = get_conv_block_ctor(conv_kind)
|
190 |
+
resnet_conv_layer = functools.partial(get_conv_block_ctor('multidilated'), **multidilation_kwargs)
|
191 |
+
norm_layer = get_norm_layer(norm_layer)
|
192 |
+
if affine is not None:
|
193 |
+
norm_layer = partial(norm_layer, affine=affine)
|
194 |
+
up_norm_layer = get_norm_layer(up_norm_layer)
|
195 |
+
if affine is not None:
|
196 |
+
up_norm_layer = partial(up_norm_layer, affine=affine)
|
197 |
+
|
198 |
+
model = [nn.ReflectionPad2d(3),
|
199 |
+
conv_layer(input_nc, ngf, kernel_size=7, padding=0),
|
200 |
+
norm_layer(ngf),
|
201 |
+
activation]
|
202 |
+
|
203 |
+
identity = Identity()
|
204 |
+
### downsample
|
205 |
+
for i in range(n_downsampling):
|
206 |
+
mult = 2 ** i
|
207 |
+
|
208 |
+
model += [conv_layer(min(max_features, ngf * mult),
|
209 |
+
min(max_features, ngf * mult * 2),
|
210 |
+
kernel_size=3, stride=2, padding=1),
|
211 |
+
norm_layer(min(max_features, ngf * mult * 2)),
|
212 |
+
activation]
|
213 |
+
|
214 |
+
mult = 2 ** n_downsampling
|
215 |
+
feats_num_bottleneck = min(max_features, ngf * mult)
|
216 |
+
|
217 |
+
### resnet blocks
|
218 |
+
for i in range(n_blocks):
|
219 |
+
if ffc_positions is not None and i in ffc_positions:
|
220 |
+
model += [FFCResnetBlock(feats_num_bottleneck, padding_type, norm_layer, activation_layer=nn.ReLU,
|
221 |
+
inline=True, **ffc_kwargs)]
|
222 |
+
model += [MultidilatedResnetBlock(feats_num_bottleneck, padding_type=padding_type,
|
223 |
+
conv_layer=resnet_conv_layer, activation=activation,
|
224 |
+
norm_layer=norm_layer)]
|
225 |
+
|
226 |
+
### upsample
|
227 |
+
for i in range(n_downsampling):
|
228 |
+
mult = 2 ** (n_downsampling - i)
|
229 |
+
model += deconv_factory(deconv_kind, ngf, mult, up_norm_layer, up_activation, max_features)
|
230 |
+
model += [nn.ReflectionPad2d(3),
|
231 |
+
nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
|
232 |
+
if add_out_act:
|
233 |
+
model.append(get_activation('tanh' if add_out_act is True else add_out_act))
|
234 |
+
self.model = nn.Sequential(*model)
|
235 |
+
|
236 |
+
def forward(self, input):
|
237 |
+
return self.model(input)
|
238 |
+
|
239 |
+
class ConfigGlobalGenerator(nn.Module):
|
240 |
+
def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=3,
|
241 |
+
n_blocks=3, norm_layer=nn.BatchNorm2d,
|
242 |
+
padding_type='reflect', conv_kind='default',
|
243 |
+
deconv_kind='convtranspose', activation=nn.ReLU(True),
|
244 |
+
up_norm_layer=nn.BatchNorm2d, affine=None, up_activation=nn.ReLU(True),
|
245 |
+
add_out_act=True, max_features=1024,
|
246 |
+
manual_block_spec=[],
|
247 |
+
resnet_block_kind='multidilatedresnetblock',
|
248 |
+
resnet_conv_kind='multidilated',
|
249 |
+
resnet_dilation=1,
|
250 |
+
multidilation_kwargs={}):
|
251 |
+
assert (n_blocks >= 0)
|
252 |
+
super().__init__()
|
253 |
+
|
254 |
+
conv_layer = get_conv_block_ctor(conv_kind)
|
255 |
+
resnet_conv_layer = functools.partial(get_conv_block_ctor(resnet_conv_kind), **multidilation_kwargs)
|
256 |
+
norm_layer = get_norm_layer(norm_layer)
|
257 |
+
if affine is not None:
|
258 |
+
norm_layer = partial(norm_layer, affine=affine)
|
259 |
+
up_norm_layer = get_norm_layer(up_norm_layer)
|
260 |
+
if affine is not None:
|
261 |
+
up_norm_layer = partial(up_norm_layer, affine=affine)
|
262 |
+
|
263 |
+
model = [nn.ReflectionPad2d(3),
|
264 |
+
conv_layer(input_nc, ngf, kernel_size=7, padding=0),
|
265 |
+
norm_layer(ngf),
|
266 |
+
activation]
|
267 |
+
|
268 |
+
identity = Identity()
|
269 |
+
|
270 |
+
### downsample
|
271 |
+
for i in range(n_downsampling):
|
272 |
+
mult = 2 ** i
|
273 |
+
model += [conv_layer(min(max_features, ngf * mult),
|
274 |
+
min(max_features, ngf * mult * 2),
|
275 |
+
kernel_size=3, stride=2, padding=1),
|
276 |
+
norm_layer(min(max_features, ngf * mult * 2)),
|
277 |
+
activation]
|
278 |
+
|
279 |
+
mult = 2 ** n_downsampling
|
280 |
+
feats_num_bottleneck = min(max_features, ngf * mult)
|
281 |
+
|
282 |
+
if len(manual_block_spec) == 0:
|
283 |
+
manual_block_spec = [
|
284 |
+
DotDict(lambda : None, {
|
285 |
+
'n_blocks': n_blocks,
|
286 |
+
'use_default': True})
|
287 |
+
]
|
288 |
+
|
289 |
+
### resnet blocks
|
290 |
+
for block_spec in manual_block_spec:
|
291 |
+
def make_and_add_blocks(model, block_spec):
|
292 |
+
block_spec = DotDict(lambda : None, block_spec)
|
293 |
+
if not block_spec.use_default:
|
294 |
+
resnet_conv_layer = functools.partial(get_conv_block_ctor(block_spec.resnet_conv_kind), **block_spec.multidilation_kwargs)
|
295 |
+
resnet_conv_kind = block_spec.resnet_conv_kind
|
296 |
+
resnet_block_kind = block_spec.resnet_block_kind
|
297 |
+
if block_spec.resnet_dilation is not None:
|
298 |
+
resnet_dilation = block_spec.resnet_dilation
|
299 |
+
for i in range(block_spec.n_blocks):
|
300 |
+
if resnet_block_kind == "multidilatedresnetblock":
|
301 |
+
model += [MultidilatedResnetBlock(feats_num_bottleneck, padding_type=padding_type,
|
302 |
+
conv_layer=resnet_conv_layer, activation=activation,
|
303 |
+
norm_layer=norm_layer)]
|
304 |
+
if resnet_block_kind == "resnetblock":
|
305 |
+
model += [ResnetBlock(ngf * mult, padding_type=padding_type, activation=activation, norm_layer=norm_layer,
|
306 |
+
conv_kind=resnet_conv_kind)]
|
307 |
+
if resnet_block_kind == "resnetblock5x5":
|
308 |
+
model += [ResnetBlock5x5(ngf * mult, padding_type=padding_type, activation=activation, norm_layer=norm_layer,
|
309 |
+
conv_kind=resnet_conv_kind)]
|
310 |
+
if resnet_block_kind == "resnetblockdwdil":
|
311 |
+
model += [ResnetBlock(ngf * mult, padding_type=padding_type, activation=activation, norm_layer=norm_layer,
|
312 |
+
conv_kind=resnet_conv_kind, dilation=resnet_dilation, second_dilation=resnet_dilation)]
|
313 |
+
make_and_add_blocks(model, block_spec)
|
314 |
+
|
315 |
+
### upsample
|
316 |
+
for i in range(n_downsampling):
|
317 |
+
mult = 2 ** (n_downsampling - i)
|
318 |
+
model += deconv_factory(deconv_kind, ngf, mult, up_norm_layer, up_activation, max_features)
|
319 |
+
model += [nn.ReflectionPad2d(3),
|
320 |
+
nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
|
321 |
+
if add_out_act:
|
322 |
+
model.append(get_activation('tanh' if add_out_act is True else add_out_act))
|
323 |
+
self.model = nn.Sequential(*model)
|
324 |
+
|
325 |
+
def forward(self, input):
|
326 |
+
return self.model(input)
|
327 |
+
|
328 |
+
|
329 |
+
def make_dil_blocks(dilated_blocks_n, dilation_block_kind, dilated_block_kwargs):
|
330 |
+
blocks = []
|
331 |
+
for i in range(dilated_blocks_n):
|
332 |
+
if dilation_block_kind == 'simple':
|
333 |
+
blocks.append(ResnetBlock(**dilated_block_kwargs, dilation=2 ** (i + 1)))
|
334 |
+
elif dilation_block_kind == 'multi':
|
335 |
+
blocks.append(MultidilatedResnetBlock(**dilated_block_kwargs))
|
336 |
+
else:
|
337 |
+
raise ValueError(f'dilation_block_kind could not be "{dilation_block_kind}"')
|
338 |
+
return blocks
|
339 |
+
|
340 |
+
|
341 |
+
class GlobalGenerator(nn.Module):
|
342 |
+
def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=3, n_blocks=9, norm_layer=nn.BatchNorm2d,
|
343 |
+
padding_type='reflect', conv_kind='default', activation=nn.ReLU(True),
|
344 |
+
up_norm_layer=nn.BatchNorm2d, affine=None,
|
345 |
+
up_activation=nn.ReLU(True), dilated_blocks_n=0, dilated_blocks_n_start=0,
|
346 |
+
dilated_blocks_n_middle=0,
|
347 |
+
add_out_act=True,
|
348 |
+
max_features=1024, is_resblock_depthwise=False,
|
349 |
+
ffc_positions=None, ffc_kwargs={}, dilation=1, second_dilation=None,
|
350 |
+
dilation_block_kind='simple', multidilation_kwargs={}):
|
351 |
+
assert (n_blocks >= 0)
|
352 |
+
super().__init__()
|
353 |
+
|
354 |
+
conv_layer = get_conv_block_ctor(conv_kind)
|
355 |
+
norm_layer = get_norm_layer(norm_layer)
|
356 |
+
if affine is not None:
|
357 |
+
norm_layer = partial(norm_layer, affine=affine)
|
358 |
+
up_norm_layer = get_norm_layer(up_norm_layer)
|
359 |
+
if affine is not None:
|
360 |
+
up_norm_layer = partial(up_norm_layer, affine=affine)
|
361 |
+
|
362 |
+
if ffc_positions is not None:
|
363 |
+
ffc_positions = collections.Counter(ffc_positions)
|
364 |
+
|
365 |
+
model = [nn.ReflectionPad2d(3),
|
366 |
+
conv_layer(input_nc, ngf, kernel_size=7, padding=0),
|
367 |
+
norm_layer(ngf),
|
368 |
+
activation]
|
369 |
+
|
370 |
+
identity = Identity()
|
371 |
+
### downsample
|
372 |
+
for i in range(n_downsampling):
|
373 |
+
mult = 2 ** i
|
374 |
+
|
375 |
+
model += [conv_layer(min(max_features, ngf * mult),
|
376 |
+
min(max_features, ngf * mult * 2),
|
377 |
+
kernel_size=3, stride=2, padding=1),
|
378 |
+
norm_layer(min(max_features, ngf * mult * 2)),
|
379 |
+
activation]
|
380 |
+
|
381 |
+
mult = 2 ** n_downsampling
|
382 |
+
feats_num_bottleneck = min(max_features, ngf * mult)
|
383 |
+
|
384 |
+
dilated_block_kwargs = dict(dim=feats_num_bottleneck, padding_type=padding_type,
|
385 |
+
activation=activation, norm_layer=norm_layer)
|
386 |
+
if dilation_block_kind == 'simple':
|
387 |
+
dilated_block_kwargs['conv_kind'] = conv_kind
|
388 |
+
elif dilation_block_kind == 'multi':
|
389 |
+
dilated_block_kwargs['conv_layer'] = functools.partial(
|
390 |
+
get_conv_block_ctor('multidilated'), **multidilation_kwargs)
|
391 |
+
|
392 |
+
# dilated blocks at the start of the bottleneck sausage
|
393 |
+
if dilated_blocks_n_start is not None and dilated_blocks_n_start > 0:
|
394 |
+
model += make_dil_blocks(dilated_blocks_n_start, dilation_block_kind, dilated_block_kwargs)
|
395 |
+
|
396 |
+
# resnet blocks
|
397 |
+
for i in range(n_blocks):
|
398 |
+
# dilated blocks at the middle of the bottleneck sausage
|
399 |
+
if i == n_blocks // 2 and dilated_blocks_n_middle is not None and dilated_blocks_n_middle > 0:
|
400 |
+
model += make_dil_blocks(dilated_blocks_n_middle, dilation_block_kind, dilated_block_kwargs)
|
401 |
+
|
402 |
+
if ffc_positions is not None and i in ffc_positions:
|
403 |
+
for _ in range(ffc_positions[i]): # same position can occur more than once
|
404 |
+
model += [FFCResnetBlock(feats_num_bottleneck, padding_type, norm_layer, activation_layer=nn.ReLU,
|
405 |
+
inline=True, **ffc_kwargs)]
|
406 |
+
|
407 |
+
if is_resblock_depthwise:
|
408 |
+
resblock_groups = feats_num_bottleneck
|
409 |
+
else:
|
410 |
+
resblock_groups = 1
|
411 |
+
|
412 |
+
model += [ResnetBlock(feats_num_bottleneck, padding_type=padding_type, activation=activation,
|
413 |
+
norm_layer=norm_layer, conv_kind=conv_kind, groups=resblock_groups,
|
414 |
+
dilation=dilation, second_dilation=second_dilation)]
|
415 |
+
|
416 |
+
|
417 |
+
# dilated blocks at the end of the bottleneck sausage
|
418 |
+
if dilated_blocks_n is not None and dilated_blocks_n > 0:
|
419 |
+
model += make_dil_blocks(dilated_blocks_n, dilation_block_kind, dilated_block_kwargs)
|
420 |
+
|
421 |
+
# upsample
|
422 |
+
for i in range(n_downsampling):
|
423 |
+
mult = 2 ** (n_downsampling - i)
|
424 |
+
model += [nn.ConvTranspose2d(min(max_features, ngf * mult),
|
425 |
+
min(max_features, int(ngf * mult / 2)),
|
426 |
+
kernel_size=3, stride=2, padding=1, output_padding=1),
|
427 |
+
up_norm_layer(min(max_features, int(ngf * mult / 2))),
|
428 |
+
up_activation]
|
429 |
+
model += [nn.ReflectionPad2d(3),
|
430 |
+
nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
|
431 |
+
if add_out_act:
|
432 |
+
model.append(get_activation('tanh' if add_out_act is True else add_out_act))
|
433 |
+
self.model = nn.Sequential(*model)
|
434 |
+
|
435 |
+
def forward(self, input):
|
436 |
+
return self.model(input)
|
437 |
+
|
438 |
+
|
439 |
+
class GlobalGeneratorGated(GlobalGenerator):
|
440 |
+
def __init__(self, *args, **kwargs):
|
441 |
+
real_kwargs=dict(
|
442 |
+
conv_kind='gated_bn_relu',
|
443 |
+
activation=nn.Identity(),
|
444 |
+
norm_layer=nn.Identity
|
445 |
+
)
|
446 |
+
real_kwargs.update(kwargs)
|
447 |
+
super().__init__(*args, **real_kwargs)
|
448 |
+
|
449 |
+
|
450 |
+
class GlobalGeneratorFromSuperChannels(nn.Module):
|
451 |
+
def __init__(self, input_nc, output_nc, n_downsampling, n_blocks, super_channels, norm_layer="bn", padding_type='reflect', add_out_act=True):
|
452 |
+
super().__init__()
|
453 |
+
self.n_downsampling = n_downsampling
|
454 |
+
norm_layer = get_norm_layer(norm_layer)
|
455 |
+
if type(norm_layer) == functools.partial:
|
456 |
+
use_bias = (norm_layer.func == nn.InstanceNorm2d)
|
457 |
+
else:
|
458 |
+
use_bias = (norm_layer == nn.InstanceNorm2d)
|
459 |
+
|
460 |
+
channels = self.convert_super_channels(super_channels)
|
461 |
+
self.channels = channels
|
462 |
+
|
463 |
+
model = [nn.ReflectionPad2d(3),
|
464 |
+
nn.Conv2d(input_nc, channels[0], kernel_size=7, padding=0, bias=use_bias),
|
465 |
+
norm_layer(channels[0]),
|
466 |
+
nn.ReLU(True)]
|
467 |
+
|
468 |
+
for i in range(n_downsampling): # add downsampling layers
|
469 |
+
mult = 2 ** i
|
470 |
+
model += [nn.Conv2d(channels[0+i], channels[1+i], kernel_size=3, stride=2, padding=1, bias=use_bias),
|
471 |
+
norm_layer(channels[1+i]),
|
472 |
+
nn.ReLU(True)]
|
473 |
+
|
474 |
+
mult = 2 ** n_downsampling
|
475 |
+
|
476 |
+
n_blocks1 = n_blocks // 3
|
477 |
+
n_blocks2 = n_blocks1
|
478 |
+
n_blocks3 = n_blocks - n_blocks1 - n_blocks2
|
479 |
+
|
480 |
+
for i in range(n_blocks1):
|
481 |
+
c = n_downsampling
|
482 |
+
dim = channels[c]
|
483 |
+
model += [ResnetBlock(dim, padding_type=padding_type, norm_layer=norm_layer)]
|
484 |
+
|
485 |
+
for i in range(n_blocks2):
|
486 |
+
c = n_downsampling+1
|
487 |
+
dim = channels[c]
|
488 |
+
kwargs = {}
|
489 |
+
if i == 0:
|
490 |
+
kwargs = {"in_dim": channels[c-1]}
|
491 |
+
model += [ResnetBlock(dim, padding_type=padding_type, norm_layer=norm_layer, **kwargs)]
|
492 |
+
|
493 |
+
for i in range(n_blocks3):
|
494 |
+
c = n_downsampling+2
|
495 |
+
dim = channels[c]
|
496 |
+
kwargs = {}
|
497 |
+
if i == 0:
|
498 |
+
kwargs = {"in_dim": channels[c-1]}
|
499 |
+
model += [ResnetBlock(dim, padding_type=padding_type, norm_layer=norm_layer, **kwargs)]
|
500 |
+
|
501 |
+
for i in range(n_downsampling): # add upsampling layers
|
502 |
+
mult = 2 ** (n_downsampling - i)
|
503 |
+
model += [nn.ConvTranspose2d(channels[n_downsampling+3+i],
|
504 |
+
channels[n_downsampling+3+i+1],
|
505 |
+
kernel_size=3, stride=2,
|
506 |
+
padding=1, output_padding=1,
|
507 |
+
bias=use_bias),
|
508 |
+
norm_layer(channels[n_downsampling+3+i+1]),
|
509 |
+
nn.ReLU(True)]
|
510 |
+
model += [nn.ReflectionPad2d(3)]
|
511 |
+
model += [nn.Conv2d(channels[2*n_downsampling+3], output_nc, kernel_size=7, padding=0)]
|
512 |
+
|
513 |
+
if add_out_act:
|
514 |
+
model.append(get_activation('tanh' if add_out_act is True else add_out_act))
|
515 |
+
self.model = nn.Sequential(*model)
|
516 |
+
|
517 |
+
def convert_super_channels(self, super_channels):
|
518 |
+
n_downsampling = self.n_downsampling
|
519 |
+
result = []
|
520 |
+
cnt = 0
|
521 |
+
|
522 |
+
if n_downsampling == 2:
|
523 |
+
N1 = 10
|
524 |
+
elif n_downsampling == 3:
|
525 |
+
N1 = 13
|
526 |
+
else:
|
527 |
+
raise NotImplementedError
|
528 |
+
|
529 |
+
for i in range(0, N1):
|
530 |
+
if i in [1,4,7,10]:
|
531 |
+
channel = super_channels[cnt] * (2 ** cnt)
|
532 |
+
config = {'channel': channel}
|
533 |
+
result.append(channel)
|
534 |
+
logging.info(f"Downsample channels {result[-1]}")
|
535 |
+
cnt += 1
|
536 |
+
|
537 |
+
for i in range(3):
|
538 |
+
for counter, j in enumerate(range(N1 + i * 3, N1 + 3 + i * 3)):
|
539 |
+
if len(super_channels) == 6:
|
540 |
+
channel = super_channels[3] * 4
|
541 |
+
else:
|
542 |
+
channel = super_channels[i + 3] * 4
|
543 |
+
config = {'channel': channel}
|
544 |
+
if counter == 0:
|
545 |
+
result.append(channel)
|
546 |
+
logging.info(f"Bottleneck channels {result[-1]}")
|
547 |
+
cnt = 2
|
548 |
+
|
549 |
+
for i in range(N1+9, N1+21):
|
550 |
+
if i in [22, 25,28]:
|
551 |
+
cnt -= 1
|
552 |
+
if len(super_channels) == 6:
|
553 |
+
channel = super_channels[5 - cnt] * (2 ** cnt)
|
554 |
+
else:
|
555 |
+
channel = super_channels[7 - cnt] * (2 ** cnt)
|
556 |
+
result.append(int(channel))
|
557 |
+
logging.info(f"Upsample channels {result[-1]}")
|
558 |
+
return result
|
559 |
+
|
560 |
+
def forward(self, input):
|
561 |
+
return self.model(input)
|
562 |
+
|
563 |
+
|
564 |
+
# Defines the PatchGAN discriminator with the specified arguments.
|
565 |
+
class NLayerDiscriminator(BaseDiscriminator):
|
566 |
+
def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d,):
|
567 |
+
super().__init__()
|
568 |
+
self.n_layers = n_layers
|
569 |
+
|
570 |
+
kw = 4
|
571 |
+
padw = int(np.ceil((kw-1.0)/2))
|
572 |
+
sequence = [[nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
|
573 |
+
nn.LeakyReLU(0.2, True)]]
|
574 |
+
|
575 |
+
nf = ndf
|
576 |
+
for n in range(1, n_layers):
|
577 |
+
nf_prev = nf
|
578 |
+
nf = min(nf * 2, 512)
|
579 |
+
|
580 |
+
cur_model = []
|
581 |
+
cur_model += [
|
582 |
+
nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw),
|
583 |
+
norm_layer(nf),
|
584 |
+
nn.LeakyReLU(0.2, True)
|
585 |
+
]
|
586 |
+
sequence.append(cur_model)
|
587 |
+
|
588 |
+
nf_prev = nf
|
589 |
+
nf = min(nf * 2, 512)
|
590 |
+
|
591 |
+
cur_model = []
|
592 |
+
cur_model += [
|
593 |
+
nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw),
|
594 |
+
norm_layer(nf),
|
595 |
+
nn.LeakyReLU(0.2, True)
|
596 |
+
]
|
597 |
+
sequence.append(cur_model)
|
598 |
+
|
599 |
+
sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]]
|
600 |
+
|
601 |
+
for n in range(len(sequence)):
|
602 |
+
setattr(self, 'model'+str(n), nn.Sequential(*sequence[n]))
|
603 |
+
|
604 |
+
def get_all_activations(self, x):
|
605 |
+
res = [x]
|
606 |
+
for n in range(self.n_layers + 2):
|
607 |
+
model = getattr(self, 'model' + str(n))
|
608 |
+
res.append(model(res[-1]))
|
609 |
+
return res[1:]
|
610 |
+
|
611 |
+
def forward(self, x):
|
612 |
+
act = self.get_all_activations(x)
|
613 |
+
return act[-1], act[:-1]
|
614 |
+
|
615 |
+
|
616 |
+
class MultidilatedNLayerDiscriminator(BaseDiscriminator):
|
617 |
+
def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, multidilation_kwargs={}):
|
618 |
+
super().__init__()
|
619 |
+
self.n_layers = n_layers
|
620 |
+
|
621 |
+
kw = 4
|
622 |
+
padw = int(np.ceil((kw-1.0)/2))
|
623 |
+
sequence = [[nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
|
624 |
+
nn.LeakyReLU(0.2, True)]]
|
625 |
+
|
626 |
+
nf = ndf
|
627 |
+
for n in range(1, n_layers):
|
628 |
+
nf_prev = nf
|
629 |
+
nf = min(nf * 2, 512)
|
630 |
+
|
631 |
+
cur_model = []
|
632 |
+
cur_model += [
|
633 |
+
MultidilatedConv(nf_prev, nf, kernel_size=kw, stride=2, padding=[2, 3], **multidilation_kwargs),
|
634 |
+
norm_layer(nf),
|
635 |
+
nn.LeakyReLU(0.2, True)
|
636 |
+
]
|
637 |
+
sequence.append(cur_model)
|
638 |
+
|
639 |
+
nf_prev = nf
|
640 |
+
nf = min(nf * 2, 512)
|
641 |
+
|
642 |
+
cur_model = []
|
643 |
+
cur_model += [
|
644 |
+
nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw),
|
645 |
+
norm_layer(nf),
|
646 |
+
nn.LeakyReLU(0.2, True)
|
647 |
+
]
|
648 |
+
sequence.append(cur_model)
|
649 |
+
|
650 |
+
sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]]
|
651 |
+
|
652 |
+
for n in range(len(sequence)):
|
653 |
+
setattr(self, 'model'+str(n), nn.Sequential(*sequence[n]))
|
654 |
+
|
655 |
+
def get_all_activations(self, x):
|
656 |
+
res = [x]
|
657 |
+
for n in range(self.n_layers + 2):
|
658 |
+
model = getattr(self, 'model' + str(n))
|
659 |
+
res.append(model(res[-1]))
|
660 |
+
return res[1:]
|
661 |
+
|
662 |
+
def forward(self, x):
|
663 |
+
act = self.get_all_activations(x)
|
664 |
+
return act[-1], act[:-1]
|
665 |
+
|
666 |
+
|
667 |
+
class NLayerDiscriminatorAsGen(NLayerDiscriminator):
|
668 |
+
def forward(self, x):
|
669 |
+
return super().forward(x)[0]
|
annotator/lama/saicinpainting/training/modules/spatial_transform.py
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from kornia.geometry.transform import rotate
|
5 |
+
|
6 |
+
|
7 |
+
class LearnableSpatialTransformWrapper(nn.Module):
|
8 |
+
def __init__(self, impl, pad_coef=0.5, angle_init_range=80, train_angle=True):
|
9 |
+
super().__init__()
|
10 |
+
self.impl = impl
|
11 |
+
self.angle = torch.rand(1) * angle_init_range
|
12 |
+
if train_angle:
|
13 |
+
self.angle = nn.Parameter(self.angle, requires_grad=True)
|
14 |
+
self.pad_coef = pad_coef
|
15 |
+
|
16 |
+
def forward(self, x):
|
17 |
+
if torch.is_tensor(x):
|
18 |
+
return self.inverse_transform(self.impl(self.transform(x)), x)
|
19 |
+
elif isinstance(x, tuple):
|
20 |
+
x_trans = tuple(self.transform(elem) for elem in x)
|
21 |
+
y_trans = self.impl(x_trans)
|
22 |
+
return tuple(self.inverse_transform(elem, orig_x) for elem, orig_x in zip(y_trans, x))
|
23 |
+
else:
|
24 |
+
raise ValueError(f'Unexpected input type {type(x)}')
|
25 |
+
|
26 |
+
def transform(self, x):
|
27 |
+
height, width = x.shape[2:]
|
28 |
+
pad_h, pad_w = int(height * self.pad_coef), int(width * self.pad_coef)
|
29 |
+
x_padded = F.pad(x, [pad_w, pad_w, pad_h, pad_h], mode='reflect')
|
30 |
+
x_padded_rotated = rotate(x_padded, angle=self.angle.to(x_padded))
|
31 |
+
return x_padded_rotated
|
32 |
+
|
33 |
+
def inverse_transform(self, y_padded_rotated, orig_x):
|
34 |
+
height, width = orig_x.shape[2:]
|
35 |
+
pad_h, pad_w = int(height * self.pad_coef), int(width * self.pad_coef)
|
36 |
+
|
37 |
+
y_padded = rotate(y_padded_rotated, angle=-self.angle.to(y_padded_rotated))
|
38 |
+
y_height, y_width = y_padded.shape[2:]
|
39 |
+
y = y_padded[:, :, pad_h : y_height - pad_h, pad_w : y_width - pad_w]
|
40 |
+
return y
|
41 |
+
|
42 |
+
|
43 |
+
if __name__ == '__main__':
|
44 |
+
layer = LearnableSpatialTransformWrapper(nn.Identity())
|
45 |
+
x = torch.arange(2* 3 * 15 * 15).view(2, 3, 15, 15).float()
|
46 |
+
y = layer(x)
|
47 |
+
assert x.shape == y.shape
|
48 |
+
assert torch.allclose(x[:, :, 1:, 1:][:, :, :-1, :-1], y[:, :, 1:, 1:][:, :, :-1, :-1])
|
49 |
+
print('all ok')
|
annotator/lama/saicinpainting/training/modules/squeeze_excitation.py
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn as nn
|
2 |
+
|
3 |
+
|
4 |
+
class SELayer(nn.Module):
|
5 |
+
def __init__(self, channel, reduction=16):
|
6 |
+
super(SELayer, self).__init__()
|
7 |
+
self.avg_pool = nn.AdaptiveAvgPool2d(1)
|
8 |
+
self.fc = nn.Sequential(
|
9 |
+
nn.Linear(channel, channel // reduction, bias=False),
|
10 |
+
nn.ReLU(inplace=True),
|
11 |
+
nn.Linear(channel // reduction, channel, bias=False),
|
12 |
+
nn.Sigmoid()
|
13 |
+
)
|
14 |
+
|
15 |
+
def forward(self, x):
|
16 |
+
b, c, _, _ = x.size()
|
17 |
+
y = self.avg_pool(x).view(b, c)
|
18 |
+
y = self.fc(y).view(b, c, 1, 1)
|
19 |
+
res = x * y.expand_as(x)
|
20 |
+
return res
|
annotator/lama/saicinpainting/training/trainers/__init__.py
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import torch
|
3 |
+
from annotator.lama.saicinpainting.training.trainers.default import DefaultInpaintingTrainingModule
|
4 |
+
|
5 |
+
|
6 |
+
def get_training_model_class(kind):
|
7 |
+
if kind == 'default':
|
8 |
+
return DefaultInpaintingTrainingModule
|
9 |
+
|
10 |
+
raise ValueError(f'Unknown trainer module {kind}')
|
11 |
+
|
12 |
+
|
13 |
+
def make_training_model(config):
|
14 |
+
kind = config.training_model.kind
|
15 |
+
kwargs = dict(config.training_model)
|
16 |
+
kwargs.pop('kind')
|
17 |
+
kwargs['use_ddp'] = config.trainer.kwargs.get('accelerator', None) == 'ddp'
|
18 |
+
|
19 |
+
logging.info(f'Make training model {kind}')
|
20 |
+
|
21 |
+
cls = get_training_model_class(kind)
|
22 |
+
return cls(config, **kwargs)
|
23 |
+
|
24 |
+
|
25 |
+
def load_checkpoint(train_config, path, map_location='cuda', strict=True):
|
26 |
+
model = make_training_model(train_config).generator
|
27 |
+
state = torch.load(path, map_location=map_location)
|
28 |
+
model.load_state_dict(state, strict=strict)
|
29 |
+
return model
|
annotator/lama/saicinpainting/training/trainers/base.py
ADDED
@@ -0,0 +1,293 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import copy
|
2 |
+
import logging
|
3 |
+
from typing import Dict, Tuple
|
4 |
+
|
5 |
+
import pandas as pd
|
6 |
+
import pytorch_lightning as ptl
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
import torch.nn.functional as F
|
10 |
+
# from torch.utils.data import DistributedSampler
|
11 |
+
|
12 |
+
# from annotator.lama.saicinpainting.evaluation import make_evaluator
|
13 |
+
# from annotator.lama.saicinpainting.training.data.datasets import make_default_train_dataloader, make_default_val_dataloader
|
14 |
+
# from annotator.lama.saicinpainting.training.losses.adversarial import make_discrim_loss
|
15 |
+
# from annotator.lama.saicinpainting.training.losses.perceptual import PerceptualLoss, ResNetPL
|
16 |
+
from annotator.lama.saicinpainting.training.modules import make_generator #, make_discriminator
|
17 |
+
# from annotator.lama.saicinpainting.training.visualizers import make_visualizer
|
18 |
+
from annotator.lama.saicinpainting.utils import add_prefix_to_keys, average_dicts, set_requires_grad, flatten_dict, \
|
19 |
+
get_has_ddp_rank
|
20 |
+
|
21 |
+
LOGGER = logging.getLogger(__name__)
|
22 |
+
|
23 |
+
|
24 |
+
def make_optimizer(parameters, kind='adamw', **kwargs):
|
25 |
+
if kind == 'adam':
|
26 |
+
optimizer_class = torch.optim.Adam
|
27 |
+
elif kind == 'adamw':
|
28 |
+
optimizer_class = torch.optim.AdamW
|
29 |
+
else:
|
30 |
+
raise ValueError(f'Unknown optimizer kind {kind}')
|
31 |
+
return optimizer_class(parameters, **kwargs)
|
32 |
+
|
33 |
+
|
34 |
+
def update_running_average(result: nn.Module, new_iterate_model: nn.Module, decay=0.999):
|
35 |
+
with torch.no_grad():
|
36 |
+
res_params = dict(result.named_parameters())
|
37 |
+
new_params = dict(new_iterate_model.named_parameters())
|
38 |
+
|
39 |
+
for k in res_params.keys():
|
40 |
+
res_params[k].data.mul_(decay).add_(new_params[k].data, alpha=1 - decay)
|
41 |
+
|
42 |
+
|
43 |
+
def make_multiscale_noise(base_tensor, scales=6, scale_mode='bilinear'):
|
44 |
+
batch_size, _, height, width = base_tensor.shape
|
45 |
+
cur_height, cur_width = height, width
|
46 |
+
result = []
|
47 |
+
align_corners = False if scale_mode in ('bilinear', 'bicubic') else None
|
48 |
+
for _ in range(scales):
|
49 |
+
cur_sample = torch.randn(batch_size, 1, cur_height, cur_width, device=base_tensor.device)
|
50 |
+
cur_sample_scaled = F.interpolate(cur_sample, size=(height, width), mode=scale_mode, align_corners=align_corners)
|
51 |
+
result.append(cur_sample_scaled)
|
52 |
+
cur_height //= 2
|
53 |
+
cur_width //= 2
|
54 |
+
return torch.cat(result, dim=1)
|
55 |
+
|
56 |
+
|
57 |
+
class BaseInpaintingTrainingModule(ptl.LightningModule):
|
58 |
+
def __init__(self, config, use_ddp, *args, predict_only=False, visualize_each_iters=100,
|
59 |
+
average_generator=False, generator_avg_beta=0.999, average_generator_start_step=30000,
|
60 |
+
average_generator_period=10, store_discr_outputs_for_vis=False,
|
61 |
+
**kwargs):
|
62 |
+
super().__init__(*args, **kwargs)
|
63 |
+
LOGGER.info('BaseInpaintingTrainingModule init called')
|
64 |
+
|
65 |
+
self.config = config
|
66 |
+
|
67 |
+
self.generator = make_generator(config, **self.config.generator)
|
68 |
+
self.use_ddp = use_ddp
|
69 |
+
|
70 |
+
if not get_has_ddp_rank():
|
71 |
+
LOGGER.info(f'Generator\n{self.generator}')
|
72 |
+
|
73 |
+
# if not predict_only:
|
74 |
+
# self.save_hyperparameters(self.config)
|
75 |
+
# self.discriminator = make_discriminator(**self.config.discriminator)
|
76 |
+
# self.adversarial_loss = make_discrim_loss(**self.config.losses.adversarial)
|
77 |
+
# self.visualizer = make_visualizer(**self.config.visualizer)
|
78 |
+
# self.val_evaluator = make_evaluator(**self.config.evaluator)
|
79 |
+
# self.test_evaluator = make_evaluator(**self.config.evaluator)
|
80 |
+
#
|
81 |
+
# if not get_has_ddp_rank():
|
82 |
+
# LOGGER.info(f'Discriminator\n{self.discriminator}')
|
83 |
+
#
|
84 |
+
# extra_val = self.config.data.get('extra_val', ())
|
85 |
+
# if extra_val:
|
86 |
+
# self.extra_val_titles = list(extra_val)
|
87 |
+
# self.extra_evaluators = nn.ModuleDict({k: make_evaluator(**self.config.evaluator)
|
88 |
+
# for k in extra_val})
|
89 |
+
# else:
|
90 |
+
# self.extra_evaluators = {}
|
91 |
+
#
|
92 |
+
# self.average_generator = average_generator
|
93 |
+
# self.generator_avg_beta = generator_avg_beta
|
94 |
+
# self.average_generator_start_step = average_generator_start_step
|
95 |
+
# self.average_generator_period = average_generator_period
|
96 |
+
# self.generator_average = None
|
97 |
+
# self.last_generator_averaging_step = -1
|
98 |
+
# self.store_discr_outputs_for_vis = store_discr_outputs_for_vis
|
99 |
+
#
|
100 |
+
# if self.config.losses.get("l1", {"weight_known": 0})['weight_known'] > 0:
|
101 |
+
# self.loss_l1 = nn.L1Loss(reduction='none')
|
102 |
+
#
|
103 |
+
# if self.config.losses.get("mse", {"weight": 0})['weight'] > 0:
|
104 |
+
# self.loss_mse = nn.MSELoss(reduction='none')
|
105 |
+
#
|
106 |
+
# if self.config.losses.perceptual.weight > 0:
|
107 |
+
# self.loss_pl = PerceptualLoss()
|
108 |
+
#
|
109 |
+
# # if self.config.losses.get("resnet_pl", {"weight": 0})['weight'] > 0:
|
110 |
+
# # self.loss_resnet_pl = ResNetPL(**self.config.losses.resnet_pl)
|
111 |
+
# # else:
|
112 |
+
# # self.loss_resnet_pl = None
|
113 |
+
#
|
114 |
+
# self.loss_resnet_pl = None
|
115 |
+
|
116 |
+
self.visualize_each_iters = visualize_each_iters
|
117 |
+
LOGGER.info('BaseInpaintingTrainingModule init done')
|
118 |
+
|
119 |
+
def configure_optimizers(self):
|
120 |
+
discriminator_params = list(self.discriminator.parameters())
|
121 |
+
return [
|
122 |
+
dict(optimizer=make_optimizer(self.generator.parameters(), **self.config.optimizers.generator)),
|
123 |
+
dict(optimizer=make_optimizer(discriminator_params, **self.config.optimizers.discriminator)),
|
124 |
+
]
|
125 |
+
|
126 |
+
def train_dataloader(self):
|
127 |
+
kwargs = dict(self.config.data.train)
|
128 |
+
if self.use_ddp:
|
129 |
+
kwargs['ddp_kwargs'] = dict(num_replicas=self.trainer.num_nodes * self.trainer.num_processes,
|
130 |
+
rank=self.trainer.global_rank,
|
131 |
+
shuffle=True)
|
132 |
+
dataloader = make_default_train_dataloader(**self.config.data.train)
|
133 |
+
return dataloader
|
134 |
+
|
135 |
+
def val_dataloader(self):
|
136 |
+
res = [make_default_val_dataloader(**self.config.data.val)]
|
137 |
+
|
138 |
+
if self.config.data.visual_test is not None:
|
139 |
+
res = res + [make_default_val_dataloader(**self.config.data.visual_test)]
|
140 |
+
else:
|
141 |
+
res = res + res
|
142 |
+
|
143 |
+
extra_val = self.config.data.get('extra_val', ())
|
144 |
+
if extra_val:
|
145 |
+
res += [make_default_val_dataloader(**extra_val[k]) for k in self.extra_val_titles]
|
146 |
+
|
147 |
+
return res
|
148 |
+
|
149 |
+
def training_step(self, batch, batch_idx, optimizer_idx=None):
|
150 |
+
self._is_training_step = True
|
151 |
+
return self._do_step(batch, batch_idx, mode='train', optimizer_idx=optimizer_idx)
|
152 |
+
|
153 |
+
def validation_step(self, batch, batch_idx, dataloader_idx):
|
154 |
+
extra_val_key = None
|
155 |
+
if dataloader_idx == 0:
|
156 |
+
mode = 'val'
|
157 |
+
elif dataloader_idx == 1:
|
158 |
+
mode = 'test'
|
159 |
+
else:
|
160 |
+
mode = 'extra_val'
|
161 |
+
extra_val_key = self.extra_val_titles[dataloader_idx - 2]
|
162 |
+
self._is_training_step = False
|
163 |
+
return self._do_step(batch, batch_idx, mode=mode, extra_val_key=extra_val_key)
|
164 |
+
|
165 |
+
def training_step_end(self, batch_parts_outputs):
|
166 |
+
if self.training and self.average_generator \
|
167 |
+
and self.global_step >= self.average_generator_start_step \
|
168 |
+
and self.global_step >= self.last_generator_averaging_step + self.average_generator_period:
|
169 |
+
if self.generator_average is None:
|
170 |
+
self.generator_average = copy.deepcopy(self.generator)
|
171 |
+
else:
|
172 |
+
update_running_average(self.generator_average, self.generator, decay=self.generator_avg_beta)
|
173 |
+
self.last_generator_averaging_step = self.global_step
|
174 |
+
|
175 |
+
full_loss = (batch_parts_outputs['loss'].mean()
|
176 |
+
if torch.is_tensor(batch_parts_outputs['loss']) # loss is not tensor when no discriminator used
|
177 |
+
else torch.tensor(batch_parts_outputs['loss']).float().requires_grad_(True))
|
178 |
+
log_info = {k: v.mean() for k, v in batch_parts_outputs['log_info'].items()}
|
179 |
+
self.log_dict(log_info, on_step=True, on_epoch=False)
|
180 |
+
return full_loss
|
181 |
+
|
182 |
+
def validation_epoch_end(self, outputs):
|
183 |
+
outputs = [step_out for out_group in outputs for step_out in out_group]
|
184 |
+
averaged_logs = average_dicts(step_out['log_info'] for step_out in outputs)
|
185 |
+
self.log_dict({k: v.mean() for k, v in averaged_logs.items()})
|
186 |
+
|
187 |
+
pd.set_option('display.max_columns', 500)
|
188 |
+
pd.set_option('display.width', 1000)
|
189 |
+
|
190 |
+
# standard validation
|
191 |
+
val_evaluator_states = [s['val_evaluator_state'] for s in outputs if 'val_evaluator_state' in s]
|
192 |
+
val_evaluator_res = self.val_evaluator.evaluation_end(states=val_evaluator_states)
|
193 |
+
val_evaluator_res_df = pd.DataFrame(val_evaluator_res).stack(1).unstack(0)
|
194 |
+
val_evaluator_res_df.dropna(axis=1, how='all', inplace=True)
|
195 |
+
LOGGER.info(f'Validation metrics after epoch #{self.current_epoch}, '
|
196 |
+
f'total {self.global_step} iterations:\n{val_evaluator_res_df}')
|
197 |
+
|
198 |
+
for k, v in flatten_dict(val_evaluator_res).items():
|
199 |
+
self.log(f'val_{k}', v)
|
200 |
+
|
201 |
+
# standard visual test
|
202 |
+
test_evaluator_states = [s['test_evaluator_state'] for s in outputs
|
203 |
+
if 'test_evaluator_state' in s]
|
204 |
+
test_evaluator_res = self.test_evaluator.evaluation_end(states=test_evaluator_states)
|
205 |
+
test_evaluator_res_df = pd.DataFrame(test_evaluator_res).stack(1).unstack(0)
|
206 |
+
test_evaluator_res_df.dropna(axis=1, how='all', inplace=True)
|
207 |
+
LOGGER.info(f'Test metrics after epoch #{self.current_epoch}, '
|
208 |
+
f'total {self.global_step} iterations:\n{test_evaluator_res_df}')
|
209 |
+
|
210 |
+
for k, v in flatten_dict(test_evaluator_res).items():
|
211 |
+
self.log(f'test_{k}', v)
|
212 |
+
|
213 |
+
# extra validations
|
214 |
+
if self.extra_evaluators:
|
215 |
+
for cur_eval_title, cur_evaluator in self.extra_evaluators.items():
|
216 |
+
cur_state_key = f'extra_val_{cur_eval_title}_evaluator_state'
|
217 |
+
cur_states = [s[cur_state_key] for s in outputs if cur_state_key in s]
|
218 |
+
cur_evaluator_res = cur_evaluator.evaluation_end(states=cur_states)
|
219 |
+
cur_evaluator_res_df = pd.DataFrame(cur_evaluator_res).stack(1).unstack(0)
|
220 |
+
cur_evaluator_res_df.dropna(axis=1, how='all', inplace=True)
|
221 |
+
LOGGER.info(f'Extra val {cur_eval_title} metrics after epoch #{self.current_epoch}, '
|
222 |
+
f'total {self.global_step} iterations:\n{cur_evaluator_res_df}')
|
223 |
+
for k, v in flatten_dict(cur_evaluator_res).items():
|
224 |
+
self.log(f'extra_val_{cur_eval_title}_{k}', v)
|
225 |
+
|
226 |
+
def _do_step(self, batch, batch_idx, mode='train', optimizer_idx=None, extra_val_key=None):
|
227 |
+
if optimizer_idx == 0: # step for generator
|
228 |
+
set_requires_grad(self.generator, True)
|
229 |
+
set_requires_grad(self.discriminator, False)
|
230 |
+
elif optimizer_idx == 1: # step for discriminator
|
231 |
+
set_requires_grad(self.generator, False)
|
232 |
+
set_requires_grad(self.discriminator, True)
|
233 |
+
|
234 |
+
batch = self(batch)
|
235 |
+
|
236 |
+
total_loss = 0
|
237 |
+
metrics = {}
|
238 |
+
|
239 |
+
if optimizer_idx is None or optimizer_idx == 0: # step for generator
|
240 |
+
total_loss, metrics = self.generator_loss(batch)
|
241 |
+
|
242 |
+
elif optimizer_idx is None or optimizer_idx == 1: # step for discriminator
|
243 |
+
if self.config.losses.adversarial.weight > 0:
|
244 |
+
total_loss, metrics = self.discriminator_loss(batch)
|
245 |
+
|
246 |
+
if self.get_ddp_rank() in (None, 0) and (batch_idx % self.visualize_each_iters == 0 or mode == 'test'):
|
247 |
+
if self.config.losses.adversarial.weight > 0:
|
248 |
+
if self.store_discr_outputs_for_vis:
|
249 |
+
with torch.no_grad():
|
250 |
+
self.store_discr_outputs(batch)
|
251 |
+
vis_suffix = f'_{mode}'
|
252 |
+
if mode == 'extra_val':
|
253 |
+
vis_suffix += f'_{extra_val_key}'
|
254 |
+
self.visualizer(self.current_epoch, batch_idx, batch, suffix=vis_suffix)
|
255 |
+
|
256 |
+
metrics_prefix = f'{mode}_'
|
257 |
+
if mode == 'extra_val':
|
258 |
+
metrics_prefix += f'{extra_val_key}_'
|
259 |
+
result = dict(loss=total_loss, log_info=add_prefix_to_keys(metrics, metrics_prefix))
|
260 |
+
if mode == 'val':
|
261 |
+
result['val_evaluator_state'] = self.val_evaluator.process_batch(batch)
|
262 |
+
elif mode == 'test':
|
263 |
+
result['test_evaluator_state'] = self.test_evaluator.process_batch(batch)
|
264 |
+
elif mode == 'extra_val':
|
265 |
+
result[f'extra_val_{extra_val_key}_evaluator_state'] = self.extra_evaluators[extra_val_key].process_batch(batch)
|
266 |
+
|
267 |
+
return result
|
268 |
+
|
269 |
+
def get_current_generator(self, no_average=False):
|
270 |
+
if not no_average and not self.training and self.average_generator and self.generator_average is not None:
|
271 |
+
return self.generator_average
|
272 |
+
return self.generator
|
273 |
+
|
274 |
+
def forward(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
275 |
+
"""Pass data through generator and obtain at leas 'predicted_image' and 'inpainted' keys"""
|
276 |
+
raise NotImplementedError()
|
277 |
+
|
278 |
+
def generator_loss(self, batch) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
|
279 |
+
raise NotImplementedError()
|
280 |
+
|
281 |
+
def discriminator_loss(self, batch) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
|
282 |
+
raise NotImplementedError()
|
283 |
+
|
284 |
+
def store_discr_outputs(self, batch):
|
285 |
+
out_size = batch['image'].shape[2:]
|
286 |
+
discr_real_out, _ = self.discriminator(batch['image'])
|
287 |
+
discr_fake_out, _ = self.discriminator(batch['predicted_image'])
|
288 |
+
batch['discr_output_real'] = F.interpolate(discr_real_out, size=out_size, mode='nearest')
|
289 |
+
batch['discr_output_fake'] = F.interpolate(discr_fake_out, size=out_size, mode='nearest')
|
290 |
+
batch['discr_output_diff'] = batch['discr_output_real'] - batch['discr_output_fake']
|
291 |
+
|
292 |
+
def get_ddp_rank(self):
|
293 |
+
return self.trainer.global_rank if (self.trainer.num_nodes * self.trainer.num_processes) > 1 else None
|
annotator/lama/saicinpainting/training/trainers/default.py
ADDED
@@ -0,0 +1,175 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from omegaconf import OmegaConf
|
6 |
+
|
7 |
+
# from annotator.lama.saicinpainting.training.data.datasets import make_constant_area_crop_params
|
8 |
+
from annotator.lama.saicinpainting.training.losses.distance_weighting import make_mask_distance_weighter
|
9 |
+
from annotator.lama.saicinpainting.training.losses.feature_matching import feature_matching_loss, masked_l1_loss
|
10 |
+
# from annotator.lama.saicinpainting.training.modules.fake_fakes import FakeFakesGenerator
|
11 |
+
from annotator.lama.saicinpainting.training.trainers.base import BaseInpaintingTrainingModule, make_multiscale_noise
|
12 |
+
from annotator.lama.saicinpainting.utils import add_prefix_to_keys, get_ramp
|
13 |
+
|
14 |
+
LOGGER = logging.getLogger(__name__)
|
15 |
+
|
16 |
+
|
17 |
+
def make_constant_area_crop_batch(batch, **kwargs):
|
18 |
+
crop_y, crop_x, crop_height, crop_width = make_constant_area_crop_params(img_height=batch['image'].shape[2],
|
19 |
+
img_width=batch['image'].shape[3],
|
20 |
+
**kwargs)
|
21 |
+
batch['image'] = batch['image'][:, :, crop_y : crop_y + crop_height, crop_x : crop_x + crop_width]
|
22 |
+
batch['mask'] = batch['mask'][:, :, crop_y: crop_y + crop_height, crop_x: crop_x + crop_width]
|
23 |
+
return batch
|
24 |
+
|
25 |
+
|
26 |
+
class DefaultInpaintingTrainingModule(BaseInpaintingTrainingModule):
|
27 |
+
def __init__(self, *args, concat_mask=True, rescale_scheduler_kwargs=None, image_to_discriminator='predicted_image',
|
28 |
+
add_noise_kwargs=None, noise_fill_hole=False, const_area_crop_kwargs=None,
|
29 |
+
distance_weighter_kwargs=None, distance_weighted_mask_for_discr=False,
|
30 |
+
fake_fakes_proba=0, fake_fakes_generator_kwargs=None,
|
31 |
+
**kwargs):
|
32 |
+
super().__init__(*args, **kwargs)
|
33 |
+
self.concat_mask = concat_mask
|
34 |
+
self.rescale_size_getter = get_ramp(**rescale_scheduler_kwargs) if rescale_scheduler_kwargs is not None else None
|
35 |
+
self.image_to_discriminator = image_to_discriminator
|
36 |
+
self.add_noise_kwargs = add_noise_kwargs
|
37 |
+
self.noise_fill_hole = noise_fill_hole
|
38 |
+
self.const_area_crop_kwargs = const_area_crop_kwargs
|
39 |
+
self.refine_mask_for_losses = make_mask_distance_weighter(**distance_weighter_kwargs) \
|
40 |
+
if distance_weighter_kwargs is not None else None
|
41 |
+
self.distance_weighted_mask_for_discr = distance_weighted_mask_for_discr
|
42 |
+
|
43 |
+
self.fake_fakes_proba = fake_fakes_proba
|
44 |
+
if self.fake_fakes_proba > 1e-3:
|
45 |
+
self.fake_fakes_gen = FakeFakesGenerator(**(fake_fakes_generator_kwargs or {}))
|
46 |
+
|
47 |
+
def forward(self, batch):
|
48 |
+
if self.training and self.rescale_size_getter is not None:
|
49 |
+
cur_size = self.rescale_size_getter(self.global_step)
|
50 |
+
batch['image'] = F.interpolate(batch['image'], size=cur_size, mode='bilinear', align_corners=False)
|
51 |
+
batch['mask'] = F.interpolate(batch['mask'], size=cur_size, mode='nearest')
|
52 |
+
|
53 |
+
if self.training and self.const_area_crop_kwargs is not None:
|
54 |
+
batch = make_constant_area_crop_batch(batch, **self.const_area_crop_kwargs)
|
55 |
+
|
56 |
+
img = batch['image']
|
57 |
+
mask = batch['mask']
|
58 |
+
|
59 |
+
masked_img = img * (1 - mask)
|
60 |
+
|
61 |
+
if self.add_noise_kwargs is not None:
|
62 |
+
noise = make_multiscale_noise(masked_img, **self.add_noise_kwargs)
|
63 |
+
if self.noise_fill_hole:
|
64 |
+
masked_img = masked_img + mask * noise[:, :masked_img.shape[1]]
|
65 |
+
masked_img = torch.cat([masked_img, noise], dim=1)
|
66 |
+
|
67 |
+
if self.concat_mask:
|
68 |
+
masked_img = torch.cat([masked_img, mask], dim=1)
|
69 |
+
|
70 |
+
batch['predicted_image'] = self.generator(masked_img)
|
71 |
+
batch['inpainted'] = mask * batch['predicted_image'] + (1 - mask) * batch['image']
|
72 |
+
|
73 |
+
if self.fake_fakes_proba > 1e-3:
|
74 |
+
if self.training and torch.rand(1).item() < self.fake_fakes_proba:
|
75 |
+
batch['fake_fakes'], batch['fake_fakes_masks'] = self.fake_fakes_gen(img, mask)
|
76 |
+
batch['use_fake_fakes'] = True
|
77 |
+
else:
|
78 |
+
batch['fake_fakes'] = torch.zeros_like(img)
|
79 |
+
batch['fake_fakes_masks'] = torch.zeros_like(mask)
|
80 |
+
batch['use_fake_fakes'] = False
|
81 |
+
|
82 |
+
batch['mask_for_losses'] = self.refine_mask_for_losses(img, batch['predicted_image'], mask) \
|
83 |
+
if self.refine_mask_for_losses is not None and self.training \
|
84 |
+
else mask
|
85 |
+
|
86 |
+
return batch
|
87 |
+
|
88 |
+
def generator_loss(self, batch):
|
89 |
+
img = batch['image']
|
90 |
+
predicted_img = batch[self.image_to_discriminator]
|
91 |
+
original_mask = batch['mask']
|
92 |
+
supervised_mask = batch['mask_for_losses']
|
93 |
+
|
94 |
+
# L1
|
95 |
+
l1_value = masked_l1_loss(predicted_img, img, supervised_mask,
|
96 |
+
self.config.losses.l1.weight_known,
|
97 |
+
self.config.losses.l1.weight_missing)
|
98 |
+
|
99 |
+
total_loss = l1_value
|
100 |
+
metrics = dict(gen_l1=l1_value)
|
101 |
+
|
102 |
+
# vgg-based perceptual loss
|
103 |
+
if self.config.losses.perceptual.weight > 0:
|
104 |
+
pl_value = self.loss_pl(predicted_img, img, mask=supervised_mask).sum() * self.config.losses.perceptual.weight
|
105 |
+
total_loss = total_loss + pl_value
|
106 |
+
metrics['gen_pl'] = pl_value
|
107 |
+
|
108 |
+
# discriminator
|
109 |
+
# adversarial_loss calls backward by itself
|
110 |
+
mask_for_discr = supervised_mask if self.distance_weighted_mask_for_discr else original_mask
|
111 |
+
self.adversarial_loss.pre_generator_step(real_batch=img, fake_batch=predicted_img,
|
112 |
+
generator=self.generator, discriminator=self.discriminator)
|
113 |
+
discr_real_pred, discr_real_features = self.discriminator(img)
|
114 |
+
discr_fake_pred, discr_fake_features = self.discriminator(predicted_img)
|
115 |
+
adv_gen_loss, adv_metrics = self.adversarial_loss.generator_loss(real_batch=img,
|
116 |
+
fake_batch=predicted_img,
|
117 |
+
discr_real_pred=discr_real_pred,
|
118 |
+
discr_fake_pred=discr_fake_pred,
|
119 |
+
mask=mask_for_discr)
|
120 |
+
total_loss = total_loss + adv_gen_loss
|
121 |
+
metrics['gen_adv'] = adv_gen_loss
|
122 |
+
metrics.update(add_prefix_to_keys(adv_metrics, 'adv_'))
|
123 |
+
|
124 |
+
# feature matching
|
125 |
+
if self.config.losses.feature_matching.weight > 0:
|
126 |
+
need_mask_in_fm = OmegaConf.to_container(self.config.losses.feature_matching).get('pass_mask', False)
|
127 |
+
mask_for_fm = supervised_mask if need_mask_in_fm else None
|
128 |
+
fm_value = feature_matching_loss(discr_fake_features, discr_real_features,
|
129 |
+
mask=mask_for_fm) * self.config.losses.feature_matching.weight
|
130 |
+
total_loss = total_loss + fm_value
|
131 |
+
metrics['gen_fm'] = fm_value
|
132 |
+
|
133 |
+
if self.loss_resnet_pl is not None:
|
134 |
+
resnet_pl_value = self.loss_resnet_pl(predicted_img, img)
|
135 |
+
total_loss = total_loss + resnet_pl_value
|
136 |
+
metrics['gen_resnet_pl'] = resnet_pl_value
|
137 |
+
|
138 |
+
return total_loss, metrics
|
139 |
+
|
140 |
+
def discriminator_loss(self, batch):
|
141 |
+
total_loss = 0
|
142 |
+
metrics = {}
|
143 |
+
|
144 |
+
predicted_img = batch[self.image_to_discriminator].detach()
|
145 |
+
self.adversarial_loss.pre_discriminator_step(real_batch=batch['image'], fake_batch=predicted_img,
|
146 |
+
generator=self.generator, discriminator=self.discriminator)
|
147 |
+
discr_real_pred, discr_real_features = self.discriminator(batch['image'])
|
148 |
+
discr_fake_pred, discr_fake_features = self.discriminator(predicted_img)
|
149 |
+
adv_discr_loss, adv_metrics = self.adversarial_loss.discriminator_loss(real_batch=batch['image'],
|
150 |
+
fake_batch=predicted_img,
|
151 |
+
discr_real_pred=discr_real_pred,
|
152 |
+
discr_fake_pred=discr_fake_pred,
|
153 |
+
mask=batch['mask'])
|
154 |
+
total_loss = total_loss + adv_discr_loss
|
155 |
+
metrics['discr_adv'] = adv_discr_loss
|
156 |
+
metrics.update(add_prefix_to_keys(adv_metrics, 'adv_'))
|
157 |
+
|
158 |
+
|
159 |
+
if batch.get('use_fake_fakes', False):
|
160 |
+
fake_fakes = batch['fake_fakes']
|
161 |
+
self.adversarial_loss.pre_discriminator_step(real_batch=batch['image'], fake_batch=fake_fakes,
|
162 |
+
generator=self.generator, discriminator=self.discriminator)
|
163 |
+
discr_fake_fakes_pred, _ = self.discriminator(fake_fakes)
|
164 |
+
fake_fakes_adv_discr_loss, fake_fakes_adv_metrics = self.adversarial_loss.discriminator_loss(
|
165 |
+
real_batch=batch['image'],
|
166 |
+
fake_batch=fake_fakes,
|
167 |
+
discr_real_pred=discr_real_pred,
|
168 |
+
discr_fake_pred=discr_fake_fakes_pred,
|
169 |
+
mask=batch['mask']
|
170 |
+
)
|
171 |
+
total_loss = total_loss + fake_fakes_adv_discr_loss
|
172 |
+
metrics['discr_adv_fake_fakes'] = fake_fakes_adv_discr_loss
|
173 |
+
metrics.update(add_prefix_to_keys(fake_fakes_adv_metrics, 'adv_'))
|
174 |
+
|
175 |
+
return total_loss, metrics
|
annotator/lama/saicinpainting/training/visualizers/__init__.py
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
|
3 |
+
from annotator.lama.saicinpainting.training.visualizers.directory import DirectoryVisualizer
|
4 |
+
from annotator.lama.saicinpainting.training.visualizers.noop import NoopVisualizer
|
5 |
+
|
6 |
+
|
7 |
+
def make_visualizer(kind, **kwargs):
|
8 |
+
logging.info(f'Make visualizer {kind}')
|
9 |
+
|
10 |
+
if kind == 'directory':
|
11 |
+
return DirectoryVisualizer(**kwargs)
|
12 |
+
if kind == 'noop':
|
13 |
+
return NoopVisualizer()
|
14 |
+
|
15 |
+
raise ValueError(f'Unknown visualizer kind {kind}')
|
annotator/lama/saicinpainting/training/visualizers/base.py
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import abc
|
2 |
+
from typing import Dict, List
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
from skimage import color
|
7 |
+
from skimage.segmentation import mark_boundaries
|
8 |
+
|
9 |
+
from . import colors
|
10 |
+
|
11 |
+
COLORS, _ = colors.generate_colors(151) # 151 - max classes for semantic segmentation
|
12 |
+
|
13 |
+
|
14 |
+
class BaseVisualizer:
|
15 |
+
@abc.abstractmethod
|
16 |
+
def __call__(self, epoch_i, batch_i, batch, suffix='', rank=None):
|
17 |
+
"""
|
18 |
+
Take a batch, make an image from it and visualize
|
19 |
+
"""
|
20 |
+
raise NotImplementedError()
|
21 |
+
|
22 |
+
|
23 |
+
def visualize_mask_and_images(images_dict: Dict[str, np.ndarray], keys: List[str],
|
24 |
+
last_without_mask=True, rescale_keys=None, mask_only_first=None,
|
25 |
+
black_mask=False) -> np.ndarray:
|
26 |
+
mask = images_dict['mask'] > 0.5
|
27 |
+
result = []
|
28 |
+
for i, k in enumerate(keys):
|
29 |
+
img = images_dict[k]
|
30 |
+
img = np.transpose(img, (1, 2, 0))
|
31 |
+
|
32 |
+
if rescale_keys is not None and k in rescale_keys:
|
33 |
+
img = img - img.min()
|
34 |
+
img /= img.max() + 1e-5
|
35 |
+
if len(img.shape) == 2:
|
36 |
+
img = np.expand_dims(img, 2)
|
37 |
+
|
38 |
+
if img.shape[2] == 1:
|
39 |
+
img = np.repeat(img, 3, axis=2)
|
40 |
+
elif (img.shape[2] > 3):
|
41 |
+
img_classes = img.argmax(2)
|
42 |
+
img = color.label2rgb(img_classes, colors=COLORS)
|
43 |
+
|
44 |
+
if mask_only_first:
|
45 |
+
need_mark_boundaries = i == 0
|
46 |
+
else:
|
47 |
+
need_mark_boundaries = i < len(keys) - 1 or not last_without_mask
|
48 |
+
|
49 |
+
if need_mark_boundaries:
|
50 |
+
if black_mask:
|
51 |
+
img = img * (1 - mask[0][..., None])
|
52 |
+
img = mark_boundaries(img,
|
53 |
+
mask[0],
|
54 |
+
color=(1., 0., 0.),
|
55 |
+
outline_color=(1., 1., 1.),
|
56 |
+
mode='thick')
|
57 |
+
result.append(img)
|
58 |
+
return np.concatenate(result, axis=1)
|
59 |
+
|
60 |
+
|
61 |
+
def visualize_mask_and_images_batch(batch: Dict[str, torch.Tensor], keys: List[str], max_items=10,
|
62 |
+
last_without_mask=True, rescale_keys=None) -> np.ndarray:
|
63 |
+
batch = {k: tens.detach().cpu().numpy() for k, tens in batch.items()
|
64 |
+
if k in keys or k == 'mask'}
|
65 |
+
|
66 |
+
batch_size = next(iter(batch.values())).shape[0]
|
67 |
+
items_to_vis = min(batch_size, max_items)
|
68 |
+
result = []
|
69 |
+
for i in range(items_to_vis):
|
70 |
+
cur_dct = {k: tens[i] for k, tens in batch.items()}
|
71 |
+
result.append(visualize_mask_and_images(cur_dct, keys, last_without_mask=last_without_mask,
|
72 |
+
rescale_keys=rescale_keys))
|
73 |
+
return np.concatenate(result, axis=0)
|
annotator/lama/saicinpainting/training/visualizers/colors.py
ADDED
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
import colorsys
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import matplotlib
|
6 |
+
matplotlib.use('agg')
|
7 |
+
import matplotlib.pyplot as plt
|
8 |
+
from matplotlib.colors import LinearSegmentedColormap
|
9 |
+
|
10 |
+
|
11 |
+
def generate_colors(nlabels, type='bright', first_color_black=False, last_color_black=True, verbose=False):
|
12 |
+
# https://stackoverflow.com/questions/14720331/how-to-generate-random-colors-in-matplotlib
|
13 |
+
"""
|
14 |
+
Creates a random colormap to be used together with matplotlib. Useful for segmentation tasks
|
15 |
+
:param nlabels: Number of labels (size of colormap)
|
16 |
+
:param type: 'bright' for strong colors, 'soft' for pastel colors
|
17 |
+
:param first_color_black: Option to use first color as black, True or False
|
18 |
+
:param last_color_black: Option to use last color as black, True or False
|
19 |
+
:param verbose: Prints the number of labels and shows the colormap. True or False
|
20 |
+
:return: colormap for matplotlib
|
21 |
+
"""
|
22 |
+
if type not in ('bright', 'soft'):
|
23 |
+
print ('Please choose "bright" or "soft" for type')
|
24 |
+
return
|
25 |
+
|
26 |
+
if verbose:
|
27 |
+
print('Number of labels: ' + str(nlabels))
|
28 |
+
|
29 |
+
# Generate color map for bright colors, based on hsv
|
30 |
+
if type == 'bright':
|
31 |
+
randHSVcolors = [(np.random.uniform(low=0.0, high=1),
|
32 |
+
np.random.uniform(low=0.2, high=1),
|
33 |
+
np.random.uniform(low=0.9, high=1)) for i in range(nlabels)]
|
34 |
+
|
35 |
+
# Convert HSV list to RGB
|
36 |
+
randRGBcolors = []
|
37 |
+
for HSVcolor in randHSVcolors:
|
38 |
+
randRGBcolors.append(colorsys.hsv_to_rgb(HSVcolor[0], HSVcolor[1], HSVcolor[2]))
|
39 |
+
|
40 |
+
if first_color_black:
|
41 |
+
randRGBcolors[0] = [0, 0, 0]
|
42 |
+
|
43 |
+
if last_color_black:
|
44 |
+
randRGBcolors[-1] = [0, 0, 0]
|
45 |
+
|
46 |
+
random_colormap = LinearSegmentedColormap.from_list('new_map', randRGBcolors, N=nlabels)
|
47 |
+
|
48 |
+
# Generate soft pastel colors, by limiting the RGB spectrum
|
49 |
+
if type == 'soft':
|
50 |
+
low = 0.6
|
51 |
+
high = 0.95
|
52 |
+
randRGBcolors = [(np.random.uniform(low=low, high=high),
|
53 |
+
np.random.uniform(low=low, high=high),
|
54 |
+
np.random.uniform(low=low, high=high)) for i in range(nlabels)]
|
55 |
+
|
56 |
+
if first_color_black:
|
57 |
+
randRGBcolors[0] = [0, 0, 0]
|
58 |
+
|
59 |
+
if last_color_black:
|
60 |
+
randRGBcolors[-1] = [0, 0, 0]
|
61 |
+
random_colormap = LinearSegmentedColormap.from_list('new_map', randRGBcolors, N=nlabels)
|
62 |
+
|
63 |
+
# Display colorbar
|
64 |
+
if verbose:
|
65 |
+
from matplotlib import colors, colorbar
|
66 |
+
from matplotlib import pyplot as plt
|
67 |
+
fig, ax = plt.subplots(1, 1, figsize=(15, 0.5))
|
68 |
+
|
69 |
+
bounds = np.linspace(0, nlabels, nlabels + 1)
|
70 |
+
norm = colors.BoundaryNorm(bounds, nlabels)
|
71 |
+
|
72 |
+
cb = colorbar.ColorbarBase(ax, cmap=random_colormap, norm=norm, spacing='proportional', ticks=None,
|
73 |
+
boundaries=bounds, format='%1i', orientation=u'horizontal')
|
74 |
+
|
75 |
+
return randRGBcolors, random_colormap
|
76 |
+
|
annotator/lama/saicinpainting/training/visualizers/directory.py
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
import cv2
|
4 |
+
import numpy as np
|
5 |
+
|
6 |
+
from annotator.lama.saicinpainting.training.visualizers.base import BaseVisualizer, visualize_mask_and_images_batch
|
7 |
+
from annotator.lama.saicinpainting.utils import check_and_warn_input_range
|
8 |
+
|
9 |
+
|
10 |
+
class DirectoryVisualizer(BaseVisualizer):
|
11 |
+
DEFAULT_KEY_ORDER = 'image predicted_image inpainted'.split(' ')
|
12 |
+
|
13 |
+
def __init__(self, outdir, key_order=DEFAULT_KEY_ORDER, max_items_in_batch=10,
|
14 |
+
last_without_mask=True, rescale_keys=None):
|
15 |
+
self.outdir = outdir
|
16 |
+
os.makedirs(self.outdir, exist_ok=True)
|
17 |
+
self.key_order = key_order
|
18 |
+
self.max_items_in_batch = max_items_in_batch
|
19 |
+
self.last_without_mask = last_without_mask
|
20 |
+
self.rescale_keys = rescale_keys
|
21 |
+
|
22 |
+
def __call__(self, epoch_i, batch_i, batch, suffix='', rank=None):
|
23 |
+
check_and_warn_input_range(batch['image'], 0, 1, 'DirectoryVisualizer target image')
|
24 |
+
vis_img = visualize_mask_and_images_batch(batch, self.key_order, max_items=self.max_items_in_batch,
|
25 |
+
last_without_mask=self.last_without_mask,
|
26 |
+
rescale_keys=self.rescale_keys)
|
27 |
+
|
28 |
+
vis_img = np.clip(vis_img * 255, 0, 255).astype('uint8')
|
29 |
+
|
30 |
+
curoutdir = os.path.join(self.outdir, f'epoch{epoch_i:04d}{suffix}')
|
31 |
+
os.makedirs(curoutdir, exist_ok=True)
|
32 |
+
rank_suffix = f'_r{rank}' if rank is not None else ''
|
33 |
+
out_fname = os.path.join(curoutdir, f'batch{batch_i:07d}{rank_suffix}.jpg')
|
34 |
+
|
35 |
+
vis_img = cv2.cvtColor(vis_img, cv2.COLOR_RGB2BGR)
|
36 |
+
cv2.imwrite(out_fname, vis_img)
|
annotator/lama/saicinpainting/training/visualizers/noop.py
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from annotator.lama.saicinpainting.training.visualizers.base import BaseVisualizer
|
2 |
+
|
3 |
+
|
4 |
+
class NoopVisualizer(BaseVisualizer):
|
5 |
+
def __init__(self, *args, **kwargs):
|
6 |
+
pass
|
7 |
+
|
8 |
+
def __call__(self, epoch_i, batch_i, batch, suffix='', rank=None):
|
9 |
+
pass
|
annotator/lama/saicinpainting/utils.py
ADDED
@@ -0,0 +1,174 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import bisect
|
2 |
+
import functools
|
3 |
+
import logging
|
4 |
+
import numbers
|
5 |
+
import os
|
6 |
+
import signal
|
7 |
+
import sys
|
8 |
+
import traceback
|
9 |
+
import warnings
|
10 |
+
|
11 |
+
import torch
|
12 |
+
from pytorch_lightning import seed_everything
|
13 |
+
|
14 |
+
LOGGER = logging.getLogger(__name__)
|
15 |
+
|
16 |
+
|
17 |
+
def check_and_warn_input_range(tensor, min_value, max_value, name):
|
18 |
+
actual_min = tensor.min()
|
19 |
+
actual_max = tensor.max()
|
20 |
+
if actual_min < min_value or actual_max > max_value:
|
21 |
+
warnings.warn(f"{name} must be in {min_value}..{max_value} range, but it ranges {actual_min}..{actual_max}")
|
22 |
+
|
23 |
+
|
24 |
+
def sum_dict_with_prefix(target, cur_dict, prefix, default=0):
|
25 |
+
for k, v in cur_dict.items():
|
26 |
+
target_key = prefix + k
|
27 |
+
target[target_key] = target.get(target_key, default) + v
|
28 |
+
|
29 |
+
|
30 |
+
def average_dicts(dict_list):
|
31 |
+
result = {}
|
32 |
+
norm = 1e-3
|
33 |
+
for dct in dict_list:
|
34 |
+
sum_dict_with_prefix(result, dct, '')
|
35 |
+
norm += 1
|
36 |
+
for k in list(result):
|
37 |
+
result[k] /= norm
|
38 |
+
return result
|
39 |
+
|
40 |
+
|
41 |
+
def add_prefix_to_keys(dct, prefix):
|
42 |
+
return {prefix + k: v for k, v in dct.items()}
|
43 |
+
|
44 |
+
|
45 |
+
def set_requires_grad(module, value):
|
46 |
+
for param in module.parameters():
|
47 |
+
param.requires_grad = value
|
48 |
+
|
49 |
+
|
50 |
+
def flatten_dict(dct):
|
51 |
+
result = {}
|
52 |
+
for k, v in dct.items():
|
53 |
+
if isinstance(k, tuple):
|
54 |
+
k = '_'.join(k)
|
55 |
+
if isinstance(v, dict):
|
56 |
+
for sub_k, sub_v in flatten_dict(v).items():
|
57 |
+
result[f'{k}_{sub_k}'] = sub_v
|
58 |
+
else:
|
59 |
+
result[k] = v
|
60 |
+
return result
|
61 |
+
|
62 |
+
|
63 |
+
class LinearRamp:
|
64 |
+
def __init__(self, start_value=0, end_value=1, start_iter=-1, end_iter=0):
|
65 |
+
self.start_value = start_value
|
66 |
+
self.end_value = end_value
|
67 |
+
self.start_iter = start_iter
|
68 |
+
self.end_iter = end_iter
|
69 |
+
|
70 |
+
def __call__(self, i):
|
71 |
+
if i < self.start_iter:
|
72 |
+
return self.start_value
|
73 |
+
if i >= self.end_iter:
|
74 |
+
return self.end_value
|
75 |
+
part = (i - self.start_iter) / (self.end_iter - self.start_iter)
|
76 |
+
return self.start_value * (1 - part) + self.end_value * part
|
77 |
+
|
78 |
+
|
79 |
+
class LadderRamp:
|
80 |
+
def __init__(self, start_iters, values):
|
81 |
+
self.start_iters = start_iters
|
82 |
+
self.values = values
|
83 |
+
assert len(values) == len(start_iters) + 1, (len(values), len(start_iters))
|
84 |
+
|
85 |
+
def __call__(self, i):
|
86 |
+
segment_i = bisect.bisect_right(self.start_iters, i)
|
87 |
+
return self.values[segment_i]
|
88 |
+
|
89 |
+
|
90 |
+
def get_ramp(kind='ladder', **kwargs):
|
91 |
+
if kind == 'linear':
|
92 |
+
return LinearRamp(**kwargs)
|
93 |
+
if kind == 'ladder':
|
94 |
+
return LadderRamp(**kwargs)
|
95 |
+
raise ValueError(f'Unexpected ramp kind: {kind}')
|
96 |
+
|
97 |
+
|
98 |
+
def print_traceback_handler(sig, frame):
|
99 |
+
LOGGER.warning(f'Received signal {sig}')
|
100 |
+
bt = ''.join(traceback.format_stack())
|
101 |
+
LOGGER.warning(f'Requested stack trace:\n{bt}')
|
102 |
+
|
103 |
+
|
104 |
+
def register_debug_signal_handlers(sig=None, handler=print_traceback_handler):
|
105 |
+
LOGGER.warning(f'Setting signal {sig} handler {handler}')
|
106 |
+
signal.signal(sig, handler)
|
107 |
+
|
108 |
+
|
109 |
+
def handle_deterministic_config(config):
|
110 |
+
seed = dict(config).get('seed', None)
|
111 |
+
if seed is None:
|
112 |
+
return False
|
113 |
+
|
114 |
+
seed_everything(seed)
|
115 |
+
return True
|
116 |
+
|
117 |
+
|
118 |
+
def get_shape(t):
|
119 |
+
if torch.is_tensor(t):
|
120 |
+
return tuple(t.shape)
|
121 |
+
elif isinstance(t, dict):
|
122 |
+
return {n: get_shape(q) for n, q in t.items()}
|
123 |
+
elif isinstance(t, (list, tuple)):
|
124 |
+
return [get_shape(q) for q in t]
|
125 |
+
elif isinstance(t, numbers.Number):
|
126 |
+
return type(t)
|
127 |
+
else:
|
128 |
+
raise ValueError('unexpected type {}'.format(type(t)))
|
129 |
+
|
130 |
+
|
131 |
+
def get_has_ddp_rank():
|
132 |
+
master_port = os.environ.get('MASTER_PORT', None)
|
133 |
+
node_rank = os.environ.get('NODE_RANK', None)
|
134 |
+
local_rank = os.environ.get('LOCAL_RANK', None)
|
135 |
+
world_size = os.environ.get('WORLD_SIZE', None)
|
136 |
+
has_rank = master_port is not None or node_rank is not None or local_rank is not None or world_size is not None
|
137 |
+
return has_rank
|
138 |
+
|
139 |
+
|
140 |
+
def handle_ddp_subprocess():
|
141 |
+
def main_decorator(main_func):
|
142 |
+
@functools.wraps(main_func)
|
143 |
+
def new_main(*args, **kwargs):
|
144 |
+
# Trainer sets MASTER_PORT, NODE_RANK, LOCAL_RANK, WORLD_SIZE
|
145 |
+
parent_cwd = os.environ.get('TRAINING_PARENT_WORK_DIR', None)
|
146 |
+
has_parent = parent_cwd is not None
|
147 |
+
has_rank = get_has_ddp_rank()
|
148 |
+
assert has_parent == has_rank, f'Inconsistent state: has_parent={has_parent}, has_rank={has_rank}'
|
149 |
+
|
150 |
+
if has_parent:
|
151 |
+
# we are in the worker
|
152 |
+
sys.argv.extend([
|
153 |
+
f'hydra.run.dir={parent_cwd}',
|
154 |
+
# 'hydra/hydra_logging=disabled',
|
155 |
+
# 'hydra/job_logging=disabled'
|
156 |
+
])
|
157 |
+
# do nothing if this is a top-level process
|
158 |
+
# TRAINING_PARENT_WORK_DIR is set in handle_ddp_parent_process after hydra initialization
|
159 |
+
|
160 |
+
main_func(*args, **kwargs)
|
161 |
+
return new_main
|
162 |
+
return main_decorator
|
163 |
+
|
164 |
+
|
165 |
+
def handle_ddp_parent_process():
|
166 |
+
parent_cwd = os.environ.get('TRAINING_PARENT_WORK_DIR', None)
|
167 |
+
has_parent = parent_cwd is not None
|
168 |
+
has_rank = get_has_ddp_rank()
|
169 |
+
assert has_parent == has_rank, f'Inconsistent state: has_parent={has_parent}, has_rank={has_rank}'
|
170 |
+
|
171 |
+
if parent_cwd is None:
|
172 |
+
os.environ['TRAINING_PARENT_WORK_DIR'] = os.getcwd()
|
173 |
+
|
174 |
+
return has_parent
|
annotator/leres/__init__.py
ADDED
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
import os
|
5 |
+
from modules import devices, shared
|
6 |
+
from annotator.annotator_path import models_path
|
7 |
+
from torchvision.transforms import transforms
|
8 |
+
|
9 |
+
# AdelaiDepth/LeReS imports
|
10 |
+
from .leres.depthmap import estimateleres, estimateboost
|
11 |
+
from .leres.multi_depth_model_woauxi import RelDepthModel
|
12 |
+
from .leres.net_tools import strip_prefix_if_present
|
13 |
+
|
14 |
+
# pix2pix/merge net imports
|
15 |
+
from .pix2pix.options.test_options import TestOptions
|
16 |
+
from .pix2pix.models.pix2pix4depth_model import Pix2Pix4DepthModel
|
17 |
+
|
18 |
+
base_model_path = os.path.join(models_path, "leres")
|
19 |
+
old_modeldir = os.path.dirname(os.path.realpath(__file__))
|
20 |
+
|
21 |
+
remote_model_path_leres = "https://huggingface.co/lllyasviel/Annotators/resolve/main/res101.pth"
|
22 |
+
remote_model_path_pix2pix = "https://huggingface.co/lllyasviel/Annotators/resolve/main/latest_net_G.pth"
|
23 |
+
|
24 |
+
model = None
|
25 |
+
pix2pixmodel = None
|
26 |
+
|
27 |
+
def unload_leres_model():
|
28 |
+
global model, pix2pixmodel
|
29 |
+
if model is not None:
|
30 |
+
model = model.cpu()
|
31 |
+
if pix2pixmodel is not None:
|
32 |
+
pix2pixmodel = pix2pixmodel.unload_network('G')
|
33 |
+
|
34 |
+
|
35 |
+
def apply_leres(input_image, thr_a, thr_b, boost=False):
|
36 |
+
global model, pix2pixmodel
|
37 |
+
if model is None:
|
38 |
+
model_path = os.path.join(base_model_path, "res101.pth")
|
39 |
+
old_model_path = os.path.join(old_modeldir, "res101.pth")
|
40 |
+
|
41 |
+
if os.path.exists(old_model_path):
|
42 |
+
model_path = old_model_path
|
43 |
+
elif not os.path.exists(model_path):
|
44 |
+
from basicsr.utils.download_util import load_file_from_url
|
45 |
+
load_file_from_url(remote_model_path_leres, model_dir=base_model_path)
|
46 |
+
|
47 |
+
if torch.cuda.is_available():
|
48 |
+
checkpoint = torch.load(model_path)
|
49 |
+
else:
|
50 |
+
checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
|
51 |
+
|
52 |
+
model = RelDepthModel(backbone='resnext101')
|
53 |
+
model.load_state_dict(strip_prefix_if_present(checkpoint['depth_model'], "module."), strict=True)
|
54 |
+
del checkpoint
|
55 |
+
|
56 |
+
if boost and pix2pixmodel is None:
|
57 |
+
pix2pixmodel_path = os.path.join(base_model_path, "latest_net_G.pth")
|
58 |
+
if not os.path.exists(pix2pixmodel_path):
|
59 |
+
from basicsr.utils.download_util import load_file_from_url
|
60 |
+
load_file_from_url(remote_model_path_pix2pix, model_dir=base_model_path)
|
61 |
+
|
62 |
+
opt = TestOptions().parse()
|
63 |
+
if not torch.cuda.is_available():
|
64 |
+
opt.gpu_ids = [] # cpu mode
|
65 |
+
pix2pixmodel = Pix2Pix4DepthModel(opt)
|
66 |
+
pix2pixmodel.save_dir = base_model_path
|
67 |
+
pix2pixmodel.load_networks('latest')
|
68 |
+
pix2pixmodel.eval()
|
69 |
+
|
70 |
+
if devices.get_device_for("controlnet").type != 'mps':
|
71 |
+
model = model.to(devices.get_device_for("controlnet"))
|
72 |
+
|
73 |
+
assert input_image.ndim == 3
|
74 |
+
height, width, dim = input_image.shape
|
75 |
+
|
76 |
+
with torch.no_grad():
|
77 |
+
|
78 |
+
if boost:
|
79 |
+
depth = estimateboost(input_image, model, 0, pix2pixmodel, max(width, height))
|
80 |
+
else:
|
81 |
+
depth = estimateleres(input_image, model, width, height)
|
82 |
+
|
83 |
+
numbytes=2
|
84 |
+
depth_min = depth.min()
|
85 |
+
depth_max = depth.max()
|
86 |
+
max_val = (2**(8*numbytes))-1
|
87 |
+
|
88 |
+
# check output before normalizing and mapping to 16 bit
|
89 |
+
if depth_max - depth_min > np.finfo("float").eps:
|
90 |
+
out = max_val * (depth - depth_min) / (depth_max - depth_min)
|
91 |
+
else:
|
92 |
+
out = np.zeros(depth.shape)
|
93 |
+
|
94 |
+
# single channel, 16 bit image
|
95 |
+
depth_image = out.astype("uint16")
|
96 |
+
|
97 |
+
# convert to uint8
|
98 |
+
depth_image = cv2.convertScaleAbs(depth_image, alpha=(255.0/65535.0))
|
99 |
+
|
100 |
+
# remove near
|
101 |
+
if thr_a != 0:
|
102 |
+
thr_a = ((thr_a/100)*255)
|
103 |
+
depth_image = cv2.threshold(depth_image, thr_a, 255, cv2.THRESH_TOZERO)[1]
|
104 |
+
|
105 |
+
# invert image
|
106 |
+
depth_image = cv2.bitwise_not(depth_image)
|
107 |
+
|
108 |
+
# remove bg
|
109 |
+
if thr_b != 0:
|
110 |
+
thr_b = ((thr_b/100)*255)
|
111 |
+
depth_image = cv2.threshold(depth_image, thr_b, 255, cv2.THRESH_TOZERO)[1]
|
112 |
+
|
113 |
+
return depth_image
|
annotator/leres/leres/LICENSE
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
https://github.com/thygate/stable-diffusion-webui-depthmap-script
|
2 |
+
|
3 |
+
MIT License
|
4 |
+
|
5 |
+
Copyright (c) 2023 Bob Thiry
|
6 |
+
|
7 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
8 |
+
of this software and associated documentation files (the "Software"), to deal
|
9 |
+
in the Software without restriction, including without limitation the rights
|
10 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
11 |
+
copies of the Software, and to permit persons to whom the Software is
|
12 |
+
furnished to do so, subject to the following conditions:
|
13 |
+
|
14 |
+
The above copyright notice and this permission notice shall be included in all
|
15 |
+
copies or substantial portions of the Software.
|
16 |
+
|
17 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
18 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
19 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
20 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
21 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
22 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
23 |
+
SOFTWARE.
|
annotator/leres/leres/Resnet.py
ADDED
@@ -0,0 +1,199 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn as nn
|
2 |
+
import torch.nn as NN
|
3 |
+
|
4 |
+
__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
|
5 |
+
'resnet152']
|
6 |
+
|
7 |
+
|
8 |
+
model_urls = {
|
9 |
+
'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
|
10 |
+
'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
|
11 |
+
'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
|
12 |
+
'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
|
13 |
+
'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
|
14 |
+
}
|
15 |
+
|
16 |
+
|
17 |
+
def conv3x3(in_planes, out_planes, stride=1):
|
18 |
+
"""3x3 convolution with padding"""
|
19 |
+
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
|
20 |
+
padding=1, bias=False)
|
21 |
+
|
22 |
+
|
23 |
+
class BasicBlock(nn.Module):
|
24 |
+
expansion = 1
|
25 |
+
|
26 |
+
def __init__(self, inplanes, planes, stride=1, downsample=None):
|
27 |
+
super(BasicBlock, self).__init__()
|
28 |
+
self.conv1 = conv3x3(inplanes, planes, stride)
|
29 |
+
self.bn1 = NN.BatchNorm2d(planes) #NN.BatchNorm2d
|
30 |
+
self.relu = nn.ReLU(inplace=True)
|
31 |
+
self.conv2 = conv3x3(planes, planes)
|
32 |
+
self.bn2 = NN.BatchNorm2d(planes) #NN.BatchNorm2d
|
33 |
+
self.downsample = downsample
|
34 |
+
self.stride = stride
|
35 |
+
|
36 |
+
def forward(self, x):
|
37 |
+
residual = x
|
38 |
+
|
39 |
+
out = self.conv1(x)
|
40 |
+
out = self.bn1(out)
|
41 |
+
out = self.relu(out)
|
42 |
+
|
43 |
+
out = self.conv2(out)
|
44 |
+
out = self.bn2(out)
|
45 |
+
|
46 |
+
if self.downsample is not None:
|
47 |
+
residual = self.downsample(x)
|
48 |
+
|
49 |
+
out += residual
|
50 |
+
out = self.relu(out)
|
51 |
+
|
52 |
+
return out
|
53 |
+
|
54 |
+
|
55 |
+
class Bottleneck(nn.Module):
|
56 |
+
expansion = 4
|
57 |
+
|
58 |
+
def __init__(self, inplanes, planes, stride=1, downsample=None):
|
59 |
+
super(Bottleneck, self).__init__()
|
60 |
+
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
|
61 |
+
self.bn1 = NN.BatchNorm2d(planes) #NN.BatchNorm2d
|
62 |
+
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
|
63 |
+
padding=1, bias=False)
|
64 |
+
self.bn2 = NN.BatchNorm2d(planes) #NN.BatchNorm2d
|
65 |
+
self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
|
66 |
+
self.bn3 = NN.BatchNorm2d(planes * self.expansion) #NN.BatchNorm2d
|
67 |
+
self.relu = nn.ReLU(inplace=True)
|
68 |
+
self.downsample = downsample
|
69 |
+
self.stride = stride
|
70 |
+
|
71 |
+
def forward(self, x):
|
72 |
+
residual = x
|
73 |
+
|
74 |
+
out = self.conv1(x)
|
75 |
+
out = self.bn1(out)
|
76 |
+
out = self.relu(out)
|
77 |
+
|
78 |
+
out = self.conv2(out)
|
79 |
+
out = self.bn2(out)
|
80 |
+
out = self.relu(out)
|
81 |
+
|
82 |
+
out = self.conv3(out)
|
83 |
+
out = self.bn3(out)
|
84 |
+
|
85 |
+
if self.downsample is not None:
|
86 |
+
residual = self.downsample(x)
|
87 |
+
|
88 |
+
out += residual
|
89 |
+
out = self.relu(out)
|
90 |
+
|
91 |
+
return out
|
92 |
+
|
93 |
+
|
94 |
+
class ResNet(nn.Module):
|
95 |
+
|
96 |
+
def __init__(self, block, layers, num_classes=1000):
|
97 |
+
self.inplanes = 64
|
98 |
+
super(ResNet, self).__init__()
|
99 |
+
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
|
100 |
+
bias=False)
|
101 |
+
self.bn1 = NN.BatchNorm2d(64) #NN.BatchNorm2d
|
102 |
+
self.relu = nn.ReLU(inplace=True)
|
103 |
+
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
104 |
+
self.layer1 = self._make_layer(block, 64, layers[0])
|
105 |
+
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
|
106 |
+
self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
|
107 |
+
self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
|
108 |
+
#self.avgpool = nn.AvgPool2d(7, stride=1)
|
109 |
+
#self.fc = nn.Linear(512 * block.expansion, num_classes)
|
110 |
+
|
111 |
+
for m in self.modules():
|
112 |
+
if isinstance(m, nn.Conv2d):
|
113 |
+
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
114 |
+
elif isinstance(m, nn.BatchNorm2d):
|
115 |
+
nn.init.constant_(m.weight, 1)
|
116 |
+
nn.init.constant_(m.bias, 0)
|
117 |
+
|
118 |
+
def _make_layer(self, block, planes, blocks, stride=1):
|
119 |
+
downsample = None
|
120 |
+
if stride != 1 or self.inplanes != planes * block.expansion:
|
121 |
+
downsample = nn.Sequential(
|
122 |
+
nn.Conv2d(self.inplanes, planes * block.expansion,
|
123 |
+
kernel_size=1, stride=stride, bias=False),
|
124 |
+
NN.BatchNorm2d(planes * block.expansion), #NN.BatchNorm2d
|
125 |
+
)
|
126 |
+
|
127 |
+
layers = []
|
128 |
+
layers.append(block(self.inplanes, planes, stride, downsample))
|
129 |
+
self.inplanes = planes * block.expansion
|
130 |
+
for i in range(1, blocks):
|
131 |
+
layers.append(block(self.inplanes, planes))
|
132 |
+
|
133 |
+
return nn.Sequential(*layers)
|
134 |
+
|
135 |
+
def forward(self, x):
|
136 |
+
features = []
|
137 |
+
|
138 |
+
x = self.conv1(x)
|
139 |
+
x = self.bn1(x)
|
140 |
+
x = self.relu(x)
|
141 |
+
x = self.maxpool(x)
|
142 |
+
|
143 |
+
x = self.layer1(x)
|
144 |
+
features.append(x)
|
145 |
+
x = self.layer2(x)
|
146 |
+
features.append(x)
|
147 |
+
x = self.layer3(x)
|
148 |
+
features.append(x)
|
149 |
+
x = self.layer4(x)
|
150 |
+
features.append(x)
|
151 |
+
|
152 |
+
return features
|
153 |
+
|
154 |
+
|
155 |
+
def resnet18(pretrained=True, **kwargs):
|
156 |
+
"""Constructs a ResNet-18 model.
|
157 |
+
Args:
|
158 |
+
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
159 |
+
"""
|
160 |
+
model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
|
161 |
+
return model
|
162 |
+
|
163 |
+
|
164 |
+
def resnet34(pretrained=True, **kwargs):
|
165 |
+
"""Constructs a ResNet-34 model.
|
166 |
+
Args:
|
167 |
+
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
168 |
+
"""
|
169 |
+
model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
|
170 |
+
return model
|
171 |
+
|
172 |
+
|
173 |
+
def resnet50(pretrained=True, **kwargs):
|
174 |
+
"""Constructs a ResNet-50 model.
|
175 |
+
Args:
|
176 |
+
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
177 |
+
"""
|
178 |
+
model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
|
179 |
+
|
180 |
+
return model
|
181 |
+
|
182 |
+
|
183 |
+
def resnet101(pretrained=True, **kwargs):
|
184 |
+
"""Constructs a ResNet-101 model.
|
185 |
+
Args:
|
186 |
+
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
187 |
+
"""
|
188 |
+
model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
|
189 |
+
|
190 |
+
return model
|
191 |
+
|
192 |
+
|
193 |
+
def resnet152(pretrained=True, **kwargs):
|
194 |
+
"""Constructs a ResNet-152 model.
|
195 |
+
Args:
|
196 |
+
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
197 |
+
"""
|
198 |
+
model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
|
199 |
+
return model
|
annotator/leres/leres/Resnext_torch.py
ADDED
@@ -0,0 +1,237 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# coding: utf-8
|
3 |
+
import torch.nn as nn
|
4 |
+
|
5 |
+
try:
|
6 |
+
from urllib import urlretrieve
|
7 |
+
except ImportError:
|
8 |
+
from urllib.request import urlretrieve
|
9 |
+
|
10 |
+
__all__ = ['resnext101_32x8d']
|
11 |
+
|
12 |
+
|
13 |
+
model_urls = {
|
14 |
+
'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
|
15 |
+
'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
|
16 |
+
}
|
17 |
+
|
18 |
+
|
19 |
+
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
|
20 |
+
"""3x3 convolution with padding"""
|
21 |
+
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
|
22 |
+
padding=dilation, groups=groups, bias=False, dilation=dilation)
|
23 |
+
|
24 |
+
|
25 |
+
def conv1x1(in_planes, out_planes, stride=1):
|
26 |
+
"""1x1 convolution"""
|
27 |
+
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
|
28 |
+
|
29 |
+
|
30 |
+
class BasicBlock(nn.Module):
|
31 |
+
expansion = 1
|
32 |
+
|
33 |
+
def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
|
34 |
+
base_width=64, dilation=1, norm_layer=None):
|
35 |
+
super(BasicBlock, self).__init__()
|
36 |
+
if norm_layer is None:
|
37 |
+
norm_layer = nn.BatchNorm2d
|
38 |
+
if groups != 1 or base_width != 64:
|
39 |
+
raise ValueError('BasicBlock only supports groups=1 and base_width=64')
|
40 |
+
if dilation > 1:
|
41 |
+
raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
|
42 |
+
# Both self.conv1 and self.downsample layers downsample the input when stride != 1
|
43 |
+
self.conv1 = conv3x3(inplanes, planes, stride)
|
44 |
+
self.bn1 = norm_layer(planes)
|
45 |
+
self.relu = nn.ReLU(inplace=True)
|
46 |
+
self.conv2 = conv3x3(planes, planes)
|
47 |
+
self.bn2 = norm_layer(planes)
|
48 |
+
self.downsample = downsample
|
49 |
+
self.stride = stride
|
50 |
+
|
51 |
+
def forward(self, x):
|
52 |
+
identity = x
|
53 |
+
|
54 |
+
out = self.conv1(x)
|
55 |
+
out = self.bn1(out)
|
56 |
+
out = self.relu(out)
|
57 |
+
|
58 |
+
out = self.conv2(out)
|
59 |
+
out = self.bn2(out)
|
60 |
+
|
61 |
+
if self.downsample is not None:
|
62 |
+
identity = self.downsample(x)
|
63 |
+
|
64 |
+
out += identity
|
65 |
+
out = self.relu(out)
|
66 |
+
|
67 |
+
return out
|
68 |
+
|
69 |
+
|
70 |
+
class Bottleneck(nn.Module):
|
71 |
+
# Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
|
72 |
+
# while original implementation places the stride at the first 1x1 convolution(self.conv1)
|
73 |
+
# according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
|
74 |
+
# This variant is also known as ResNet V1.5 and improves accuracy according to
|
75 |
+
# https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
|
76 |
+
|
77 |
+
expansion = 4
|
78 |
+
|
79 |
+
def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
|
80 |
+
base_width=64, dilation=1, norm_layer=None):
|
81 |
+
super(Bottleneck, self).__init__()
|
82 |
+
if norm_layer is None:
|
83 |
+
norm_layer = nn.BatchNorm2d
|
84 |
+
width = int(planes * (base_width / 64.)) * groups
|
85 |
+
# Both self.conv2 and self.downsample layers downsample the input when stride != 1
|
86 |
+
self.conv1 = conv1x1(inplanes, width)
|
87 |
+
self.bn1 = norm_layer(width)
|
88 |
+
self.conv2 = conv3x3(width, width, stride, groups, dilation)
|
89 |
+
self.bn2 = norm_layer(width)
|
90 |
+
self.conv3 = conv1x1(width, planes * self.expansion)
|
91 |
+
self.bn3 = norm_layer(planes * self.expansion)
|
92 |
+
self.relu = nn.ReLU(inplace=True)
|
93 |
+
self.downsample = downsample
|
94 |
+
self.stride = stride
|
95 |
+
|
96 |
+
def forward(self, x):
|
97 |
+
identity = x
|
98 |
+
|
99 |
+
out = self.conv1(x)
|
100 |
+
out = self.bn1(out)
|
101 |
+
out = self.relu(out)
|
102 |
+
|
103 |
+
out = self.conv2(out)
|
104 |
+
out = self.bn2(out)
|
105 |
+
out = self.relu(out)
|
106 |
+
|
107 |
+
out = self.conv3(out)
|
108 |
+
out = self.bn3(out)
|
109 |
+
|
110 |
+
if self.downsample is not None:
|
111 |
+
identity = self.downsample(x)
|
112 |
+
|
113 |
+
out += identity
|
114 |
+
out = self.relu(out)
|
115 |
+
|
116 |
+
return out
|
117 |
+
|
118 |
+
|
119 |
+
class ResNet(nn.Module):
|
120 |
+
|
121 |
+
def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
|
122 |
+
groups=1, width_per_group=64, replace_stride_with_dilation=None,
|
123 |
+
norm_layer=None):
|
124 |
+
super(ResNet, self).__init__()
|
125 |
+
if norm_layer is None:
|
126 |
+
norm_layer = nn.BatchNorm2d
|
127 |
+
self._norm_layer = norm_layer
|
128 |
+
|
129 |
+
self.inplanes = 64
|
130 |
+
self.dilation = 1
|
131 |
+
if replace_stride_with_dilation is None:
|
132 |
+
# each element in the tuple indicates if we should replace
|
133 |
+
# the 2x2 stride with a dilated convolution instead
|
134 |
+
replace_stride_with_dilation = [False, False, False]
|
135 |
+
if len(replace_stride_with_dilation) != 3:
|
136 |
+
raise ValueError("replace_stride_with_dilation should be None "
|
137 |
+
"or a 3-element tuple, got {}".format(replace_stride_with_dilation))
|
138 |
+
self.groups = groups
|
139 |
+
self.base_width = width_per_group
|
140 |
+
self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
|
141 |
+
bias=False)
|
142 |
+
self.bn1 = norm_layer(self.inplanes)
|
143 |
+
self.relu = nn.ReLU(inplace=True)
|
144 |
+
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
145 |
+
self.layer1 = self._make_layer(block, 64, layers[0])
|
146 |
+
self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
|
147 |
+
dilate=replace_stride_with_dilation[0])
|
148 |
+
self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
|
149 |
+
dilate=replace_stride_with_dilation[1])
|
150 |
+
self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
|
151 |
+
dilate=replace_stride_with_dilation[2])
|
152 |
+
#self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
|
153 |
+
#self.fc = nn.Linear(512 * block.expansion, num_classes)
|
154 |
+
|
155 |
+
for m in self.modules():
|
156 |
+
if isinstance(m, nn.Conv2d):
|
157 |
+
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
158 |
+
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
|
159 |
+
nn.init.constant_(m.weight, 1)
|
160 |
+
nn.init.constant_(m.bias, 0)
|
161 |
+
|
162 |
+
# Zero-initialize the last BN in each residual branch,
|
163 |
+
# so that the residual branch starts with zeros, and each residual block behaves like an identity.
|
164 |
+
# This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
|
165 |
+
if zero_init_residual:
|
166 |
+
for m in self.modules():
|
167 |
+
if isinstance(m, Bottleneck):
|
168 |
+
nn.init.constant_(m.bn3.weight, 0)
|
169 |
+
elif isinstance(m, BasicBlock):
|
170 |
+
nn.init.constant_(m.bn2.weight, 0)
|
171 |
+
|
172 |
+
def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
|
173 |
+
norm_layer = self._norm_layer
|
174 |
+
downsample = None
|
175 |
+
previous_dilation = self.dilation
|
176 |
+
if dilate:
|
177 |
+
self.dilation *= stride
|
178 |
+
stride = 1
|
179 |
+
if stride != 1 or self.inplanes != planes * block.expansion:
|
180 |
+
downsample = nn.Sequential(
|
181 |
+
conv1x1(self.inplanes, planes * block.expansion, stride),
|
182 |
+
norm_layer(planes * block.expansion),
|
183 |
+
)
|
184 |
+
|
185 |
+
layers = []
|
186 |
+
layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
|
187 |
+
self.base_width, previous_dilation, norm_layer))
|
188 |
+
self.inplanes = planes * block.expansion
|
189 |
+
for _ in range(1, blocks):
|
190 |
+
layers.append(block(self.inplanes, planes, groups=self.groups,
|
191 |
+
base_width=self.base_width, dilation=self.dilation,
|
192 |
+
norm_layer=norm_layer))
|
193 |
+
|
194 |
+
return nn.Sequential(*layers)
|
195 |
+
|
196 |
+
def _forward_impl(self, x):
|
197 |
+
# See note [TorchScript super()]
|
198 |
+
features = []
|
199 |
+
x = self.conv1(x)
|
200 |
+
x = self.bn1(x)
|
201 |
+
x = self.relu(x)
|
202 |
+
x = self.maxpool(x)
|
203 |
+
|
204 |
+
x = self.layer1(x)
|
205 |
+
features.append(x)
|
206 |
+
|
207 |
+
x = self.layer2(x)
|
208 |
+
features.append(x)
|
209 |
+
|
210 |
+
x = self.layer3(x)
|
211 |
+
features.append(x)
|
212 |
+
|
213 |
+
x = self.layer4(x)
|
214 |
+
features.append(x)
|
215 |
+
|
216 |
+
#x = self.avgpool(x)
|
217 |
+
#x = torch.flatten(x, 1)
|
218 |
+
#x = self.fc(x)
|
219 |
+
|
220 |
+
return features
|
221 |
+
|
222 |
+
def forward(self, x):
|
223 |
+
return self._forward_impl(x)
|
224 |
+
|
225 |
+
|
226 |
+
|
227 |
+
def resnext101_32x8d(pretrained=True, **kwargs):
|
228 |
+
"""Constructs a ResNet-152 model.
|
229 |
+
Args:
|
230 |
+
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
231 |
+
"""
|
232 |
+
kwargs['groups'] = 32
|
233 |
+
kwargs['width_per_group'] = 8
|
234 |
+
|
235 |
+
model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
|
236 |
+
return model
|
237 |
+
|
annotator/leres/leres/depthmap.py
ADDED
@@ -0,0 +1,546 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Author: thygate
|
2 |
+
# https://github.com/thygate/stable-diffusion-webui-depthmap-script
|
3 |
+
|
4 |
+
from modules import devices
|
5 |
+
from modules.shared import opts
|
6 |
+
from torchvision.transforms import transforms
|
7 |
+
from operator import getitem
|
8 |
+
|
9 |
+
import torch, gc
|
10 |
+
import cv2
|
11 |
+
import numpy as np
|
12 |
+
import skimage.measure
|
13 |
+
|
14 |
+
whole_size_threshold = 1600 # R_max from the paper
|
15 |
+
pix2pixsize = 1024
|
16 |
+
|
17 |
+
def scale_torch(img):
|
18 |
+
"""
|
19 |
+
Scale the image and output it in torch.tensor.
|
20 |
+
:param img: input rgb is in shape [H, W, C], input depth/disp is in shape [H, W]
|
21 |
+
:param scale: the scale factor. float
|
22 |
+
:return: img. [C, H, W]
|
23 |
+
"""
|
24 |
+
if len(img.shape) == 2:
|
25 |
+
img = img[np.newaxis, :, :]
|
26 |
+
if img.shape[2] == 3:
|
27 |
+
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.485, 0.456, 0.406) , (0.229, 0.224, 0.225) )])
|
28 |
+
img = transform(img.astype(np.float32))
|
29 |
+
else:
|
30 |
+
img = img.astype(np.float32)
|
31 |
+
img = torch.from_numpy(img)
|
32 |
+
return img
|
33 |
+
|
34 |
+
def estimateleres(img, model, w, h):
|
35 |
+
# leres transform input
|
36 |
+
rgb_c = img[:, :, ::-1].copy()
|
37 |
+
A_resize = cv2.resize(rgb_c, (w, h))
|
38 |
+
img_torch = scale_torch(A_resize)[None, :, :, :]
|
39 |
+
|
40 |
+
# compute
|
41 |
+
with torch.no_grad():
|
42 |
+
img_torch = img_torch.to(devices.get_device_for("controlnet"))
|
43 |
+
prediction = model.depth_model(img_torch)
|
44 |
+
|
45 |
+
prediction = prediction.squeeze().cpu().numpy()
|
46 |
+
prediction = cv2.resize(prediction, (img.shape[1], img.shape[0]), interpolation=cv2.INTER_CUBIC)
|
47 |
+
|
48 |
+
return prediction
|
49 |
+
|
50 |
+
def generatemask(size):
|
51 |
+
# Generates a Guassian mask
|
52 |
+
mask = np.zeros(size, dtype=np.float32)
|
53 |
+
sigma = int(size[0]/16)
|
54 |
+
k_size = int(2 * np.ceil(2 * int(size[0]/16)) + 1)
|
55 |
+
mask[int(0.15*size[0]):size[0] - int(0.15*size[0]), int(0.15*size[1]): size[1] - int(0.15*size[1])] = 1
|
56 |
+
mask = cv2.GaussianBlur(mask, (int(k_size), int(k_size)), sigma)
|
57 |
+
mask = (mask - mask.min()) / (mask.max() - mask.min())
|
58 |
+
mask = mask.astype(np.float32)
|
59 |
+
return mask
|
60 |
+
|
61 |
+
def resizewithpool(img, size):
|
62 |
+
i_size = img.shape[0]
|
63 |
+
n = int(np.floor(i_size/size))
|
64 |
+
|
65 |
+
out = skimage.measure.block_reduce(img, (n, n), np.max)
|
66 |
+
return out
|
67 |
+
|
68 |
+
def rgb2gray(rgb):
|
69 |
+
# Converts rgb to gray
|
70 |
+
return np.dot(rgb[..., :3], [0.2989, 0.5870, 0.1140])
|
71 |
+
|
72 |
+
def calculateprocessingres(img, basesize, confidence=0.1, scale_threshold=3, whole_size_threshold=3000):
|
73 |
+
# Returns the R_x resolution described in section 5 of the main paper.
|
74 |
+
|
75 |
+
# Parameters:
|
76 |
+
# img :input rgb image
|
77 |
+
# basesize : size the dilation kernel which is equal to receptive field of the network.
|
78 |
+
# confidence: value of x in R_x; allowed percentage of pixels that are not getting any contextual cue.
|
79 |
+
# scale_threshold: maximum allowed upscaling on the input image ; it has been set to 3.
|
80 |
+
# whole_size_threshold: maximum allowed resolution. (R_max from section 6 of the main paper)
|
81 |
+
|
82 |
+
# Returns:
|
83 |
+
# outputsize_scale*speed_scale :The computed R_x resolution
|
84 |
+
# patch_scale: K parameter from section 6 of the paper
|
85 |
+
|
86 |
+
# speed scale parameter is to process every image in a smaller size to accelerate the R_x resolution search
|
87 |
+
speed_scale = 32
|
88 |
+
image_dim = int(min(img.shape[0:2]))
|
89 |
+
|
90 |
+
gray = rgb2gray(img)
|
91 |
+
grad = np.abs(cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=3)) + np.abs(cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3))
|
92 |
+
grad = cv2.resize(grad, (image_dim, image_dim), cv2.INTER_AREA)
|
93 |
+
|
94 |
+
# thresholding the gradient map to generate the edge-map as a proxy of the contextual cues
|
95 |
+
m = grad.min()
|
96 |
+
M = grad.max()
|
97 |
+
middle = m + (0.4 * (M - m))
|
98 |
+
grad[grad < middle] = 0
|
99 |
+
grad[grad >= middle] = 1
|
100 |
+
|
101 |
+
# dilation kernel with size of the receptive field
|
102 |
+
kernel = np.ones((int(basesize/speed_scale), int(basesize/speed_scale)), float)
|
103 |
+
# dilation kernel with size of the a quarter of receptive field used to compute k
|
104 |
+
# as described in section 6 of main paper
|
105 |
+
kernel2 = np.ones((int(basesize / (4*speed_scale)), int(basesize / (4*speed_scale))), float)
|
106 |
+
|
107 |
+
# Output resolution limit set by the whole_size_threshold and scale_threshold.
|
108 |
+
threshold = min(whole_size_threshold, scale_threshold * max(img.shape[:2]))
|
109 |
+
|
110 |
+
outputsize_scale = basesize / speed_scale
|
111 |
+
for p_size in range(int(basesize/speed_scale), int(threshold/speed_scale), int(basesize / (2*speed_scale))):
|
112 |
+
grad_resized = resizewithpool(grad, p_size)
|
113 |
+
grad_resized = cv2.resize(grad_resized, (p_size, p_size), cv2.INTER_NEAREST)
|
114 |
+
grad_resized[grad_resized >= 0.5] = 1
|
115 |
+
grad_resized[grad_resized < 0.5] = 0
|
116 |
+
|
117 |
+
dilated = cv2.dilate(grad_resized, kernel, iterations=1)
|
118 |
+
meanvalue = (1-dilated).mean()
|
119 |
+
if meanvalue > confidence:
|
120 |
+
break
|
121 |
+
else:
|
122 |
+
outputsize_scale = p_size
|
123 |
+
|
124 |
+
grad_region = cv2.dilate(grad_resized, kernel2, iterations=1)
|
125 |
+
patch_scale = grad_region.mean()
|
126 |
+
|
127 |
+
return int(outputsize_scale*speed_scale), patch_scale
|
128 |
+
|
129 |
+
# Generate a double-input depth estimation
|
130 |
+
def doubleestimate(img, size1, size2, pix2pixsize, model, net_type, pix2pixmodel):
|
131 |
+
# Generate the low resolution estimation
|
132 |
+
estimate1 = singleestimate(img, size1, model, net_type)
|
133 |
+
# Resize to the inference size of merge network.
|
134 |
+
estimate1 = cv2.resize(estimate1, (pix2pixsize, pix2pixsize), interpolation=cv2.INTER_CUBIC)
|
135 |
+
|
136 |
+
# Generate the high resolution estimation
|
137 |
+
estimate2 = singleestimate(img, size2, model, net_type)
|
138 |
+
# Resize to the inference size of merge network.
|
139 |
+
estimate2 = cv2.resize(estimate2, (pix2pixsize, pix2pixsize), interpolation=cv2.INTER_CUBIC)
|
140 |
+
|
141 |
+
# Inference on the merge model
|
142 |
+
pix2pixmodel.set_input(estimate1, estimate2)
|
143 |
+
pix2pixmodel.test()
|
144 |
+
visuals = pix2pixmodel.get_current_visuals()
|
145 |
+
prediction_mapped = visuals['fake_B']
|
146 |
+
prediction_mapped = (prediction_mapped+1)/2
|
147 |
+
prediction_mapped = (prediction_mapped - torch.min(prediction_mapped)) / (
|
148 |
+
torch.max(prediction_mapped) - torch.min(prediction_mapped))
|
149 |
+
prediction_mapped = prediction_mapped.squeeze().cpu().numpy()
|
150 |
+
|
151 |
+
return prediction_mapped
|
152 |
+
|
153 |
+
# Generate a single-input depth estimation
|
154 |
+
def singleestimate(img, msize, model, net_type):
|
155 |
+
# if net_type == 0:
|
156 |
+
return estimateleres(img, model, msize, msize)
|
157 |
+
# else:
|
158 |
+
# return estimatemidasBoost(img, model, msize, msize)
|
159 |
+
|
160 |
+
def applyGridpatch(blsize, stride, img, box):
|
161 |
+
# Extract a simple grid patch.
|
162 |
+
counter1 = 0
|
163 |
+
patch_bound_list = {}
|
164 |
+
for k in range(blsize, img.shape[1] - blsize, stride):
|
165 |
+
for j in range(blsize, img.shape[0] - blsize, stride):
|
166 |
+
patch_bound_list[str(counter1)] = {}
|
167 |
+
patchbounds = [j - blsize, k - blsize, j - blsize + 2 * blsize, k - blsize + 2 * blsize]
|
168 |
+
patch_bound = [box[0] + patchbounds[1], box[1] + patchbounds[0], patchbounds[3] - patchbounds[1],
|
169 |
+
patchbounds[2] - patchbounds[0]]
|
170 |
+
patch_bound_list[str(counter1)]['rect'] = patch_bound
|
171 |
+
patch_bound_list[str(counter1)]['size'] = patch_bound[2]
|
172 |
+
counter1 = counter1 + 1
|
173 |
+
return patch_bound_list
|
174 |
+
|
175 |
+
# Generating local patches to perform the local refinement described in section 6 of the main paper.
|
176 |
+
def generatepatchs(img, base_size):
|
177 |
+
|
178 |
+
# Compute the gradients as a proxy of the contextual cues.
|
179 |
+
img_gray = rgb2gray(img)
|
180 |
+
whole_grad = np.abs(cv2.Sobel(img_gray, cv2.CV_64F, 0, 1, ksize=3)) +\
|
181 |
+
np.abs(cv2.Sobel(img_gray, cv2.CV_64F, 1, 0, ksize=3))
|
182 |
+
|
183 |
+
threshold = whole_grad[whole_grad > 0].mean()
|
184 |
+
whole_grad[whole_grad < threshold] = 0
|
185 |
+
|
186 |
+
# We use the integral image to speed-up the evaluation of the amount of gradients for each patch.
|
187 |
+
gf = whole_grad.sum()/len(whole_grad.reshape(-1))
|
188 |
+
grad_integral_image = cv2.integral(whole_grad)
|
189 |
+
|
190 |
+
# Variables are selected such that the initial patch size would be the receptive field size
|
191 |
+
# and the stride is set to 1/3 of the receptive field size.
|
192 |
+
blsize = int(round(base_size/2))
|
193 |
+
stride = int(round(blsize*0.75))
|
194 |
+
|
195 |
+
# Get initial Grid
|
196 |
+
patch_bound_list = applyGridpatch(blsize, stride, img, [0, 0, 0, 0])
|
197 |
+
|
198 |
+
# Refine initial Grid of patches by discarding the flat (in terms of gradients of the rgb image) ones. Refine
|
199 |
+
# each patch size to ensure that there will be enough depth cues for the network to generate a consistent depth map.
|
200 |
+
print("Selecting patches ...")
|
201 |
+
patch_bound_list = adaptiveselection(grad_integral_image, patch_bound_list, gf)
|
202 |
+
|
203 |
+
# Sort the patch list to make sure the merging operation will be done with the correct order: starting from biggest
|
204 |
+
# patch
|
205 |
+
patchset = sorted(patch_bound_list.items(), key=lambda x: getitem(x[1], 'size'), reverse=True)
|
206 |
+
return patchset
|
207 |
+
|
208 |
+
def getGF_fromintegral(integralimage, rect):
|
209 |
+
# Computes the gradient density of a given patch from the gradient integral image.
|
210 |
+
x1 = rect[1]
|
211 |
+
x2 = rect[1]+rect[3]
|
212 |
+
y1 = rect[0]
|
213 |
+
y2 = rect[0]+rect[2]
|
214 |
+
value = integralimage[x2, y2]-integralimage[x1, y2]-integralimage[x2, y1]+integralimage[x1, y1]
|
215 |
+
return value
|
216 |
+
|
217 |
+
# Adaptively select patches
|
218 |
+
def adaptiveselection(integral_grad, patch_bound_list, gf):
|
219 |
+
patchlist = {}
|
220 |
+
count = 0
|
221 |
+
height, width = integral_grad.shape
|
222 |
+
|
223 |
+
search_step = int(32/factor)
|
224 |
+
|
225 |
+
# Go through all patches
|
226 |
+
for c in range(len(patch_bound_list)):
|
227 |
+
# Get patch
|
228 |
+
bbox = patch_bound_list[str(c)]['rect']
|
229 |
+
|
230 |
+
# Compute the amount of gradients present in the patch from the integral image.
|
231 |
+
cgf = getGF_fromintegral(integral_grad, bbox)/(bbox[2]*bbox[3])
|
232 |
+
|
233 |
+
# Check if patching is beneficial by comparing the gradient density of the patch to
|
234 |
+
# the gradient density of the whole image
|
235 |
+
if cgf >= gf:
|
236 |
+
bbox_test = bbox.copy()
|
237 |
+
patchlist[str(count)] = {}
|
238 |
+
|
239 |
+
# Enlarge each patch until the gradient density of the patch is equal
|
240 |
+
# to the whole image gradient density
|
241 |
+
while True:
|
242 |
+
|
243 |
+
bbox_test[0] = bbox_test[0] - int(search_step/2)
|
244 |
+
bbox_test[1] = bbox_test[1] - int(search_step/2)
|
245 |
+
|
246 |
+
bbox_test[2] = bbox_test[2] + search_step
|
247 |
+
bbox_test[3] = bbox_test[3] + search_step
|
248 |
+
|
249 |
+
# Check if we are still within the image
|
250 |
+
if bbox_test[0] < 0 or bbox_test[1] < 0 or bbox_test[1] + bbox_test[3] >= height \
|
251 |
+
or bbox_test[0] + bbox_test[2] >= width:
|
252 |
+
break
|
253 |
+
|
254 |
+
# Compare gradient density
|
255 |
+
cgf = getGF_fromintegral(integral_grad, bbox_test)/(bbox_test[2]*bbox_test[3])
|
256 |
+
if cgf < gf:
|
257 |
+
break
|
258 |
+
bbox = bbox_test.copy()
|
259 |
+
|
260 |
+
# Add patch to selected patches
|
261 |
+
patchlist[str(count)]['rect'] = bbox
|
262 |
+
patchlist[str(count)]['size'] = bbox[2]
|
263 |
+
count = count + 1
|
264 |
+
|
265 |
+
# Return selected patches
|
266 |
+
return patchlist
|
267 |
+
|
268 |
+
def impatch(image, rect):
|
269 |
+
# Extract the given patch pixels from a given image.
|
270 |
+
w1 = rect[0]
|
271 |
+
h1 = rect[1]
|
272 |
+
w2 = w1 + rect[2]
|
273 |
+
h2 = h1 + rect[3]
|
274 |
+
image_patch = image[h1:h2, w1:w2]
|
275 |
+
return image_patch
|
276 |
+
|
277 |
+
class ImageandPatchs:
|
278 |
+
def __init__(self, root_dir, name, patchsinfo, rgb_image, scale=1):
|
279 |
+
self.root_dir = root_dir
|
280 |
+
self.patchsinfo = patchsinfo
|
281 |
+
self.name = name
|
282 |
+
self.patchs = patchsinfo
|
283 |
+
self.scale = scale
|
284 |
+
|
285 |
+
self.rgb_image = cv2.resize(rgb_image, (round(rgb_image.shape[1]*scale), round(rgb_image.shape[0]*scale)),
|
286 |
+
interpolation=cv2.INTER_CUBIC)
|
287 |
+
|
288 |
+
self.do_have_estimate = False
|
289 |
+
self.estimation_updated_image = None
|
290 |
+
self.estimation_base_image = None
|
291 |
+
|
292 |
+
def __len__(self):
|
293 |
+
return len(self.patchs)
|
294 |
+
|
295 |
+
def set_base_estimate(self, est):
|
296 |
+
self.estimation_base_image = est
|
297 |
+
if self.estimation_updated_image is not None:
|
298 |
+
self.do_have_estimate = True
|
299 |
+
|
300 |
+
def set_updated_estimate(self, est):
|
301 |
+
self.estimation_updated_image = est
|
302 |
+
if self.estimation_base_image is not None:
|
303 |
+
self.do_have_estimate = True
|
304 |
+
|
305 |
+
def __getitem__(self, index):
|
306 |
+
patch_id = int(self.patchs[index][0])
|
307 |
+
rect = np.array(self.patchs[index][1]['rect'])
|
308 |
+
msize = self.patchs[index][1]['size']
|
309 |
+
|
310 |
+
## applying scale to rect:
|
311 |
+
rect = np.round(rect * self.scale)
|
312 |
+
rect = rect.astype('int')
|
313 |
+
msize = round(msize * self.scale)
|
314 |
+
|
315 |
+
patch_rgb = impatch(self.rgb_image, rect)
|
316 |
+
if self.do_have_estimate:
|
317 |
+
patch_whole_estimate_base = impatch(self.estimation_base_image, rect)
|
318 |
+
patch_whole_estimate_updated = impatch(self.estimation_updated_image, rect)
|
319 |
+
return {'patch_rgb': patch_rgb, 'patch_whole_estimate_base': patch_whole_estimate_base,
|
320 |
+
'patch_whole_estimate_updated': patch_whole_estimate_updated, 'rect': rect,
|
321 |
+
'size': msize, 'id': patch_id}
|
322 |
+
else:
|
323 |
+
return {'patch_rgb': patch_rgb, 'rect': rect, 'size': msize, 'id': patch_id}
|
324 |
+
|
325 |
+
def print_options(self, opt):
|
326 |
+
"""Print and save options
|
327 |
+
|
328 |
+
It will print both current options and default values(if different).
|
329 |
+
It will save options into a text file / [checkpoints_dir] / opt.txt
|
330 |
+
"""
|
331 |
+
message = ''
|
332 |
+
message += '----------------- Options ---------------\n'
|
333 |
+
for k, v in sorted(vars(opt).items()):
|
334 |
+
comment = ''
|
335 |
+
default = self.parser.get_default(k)
|
336 |
+
if v != default:
|
337 |
+
comment = '\t[default: %s]' % str(default)
|
338 |
+
message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
|
339 |
+
message += '----------------- End -------------------'
|
340 |
+
print(message)
|
341 |
+
|
342 |
+
# save to the disk
|
343 |
+
"""
|
344 |
+
expr_dir = os.path.join(opt.checkpoints_dir, opt.name)
|
345 |
+
util.mkdirs(expr_dir)
|
346 |
+
file_name = os.path.join(expr_dir, '{}_opt.txt'.format(opt.phase))
|
347 |
+
with open(file_name, 'wt') as opt_file:
|
348 |
+
opt_file.write(message)
|
349 |
+
opt_file.write('\n')
|
350 |
+
"""
|
351 |
+
|
352 |
+
def parse(self):
|
353 |
+
"""Parse our options, create checkpoints directory suffix, and set up gpu device."""
|
354 |
+
opt = self.gather_options()
|
355 |
+
opt.isTrain = self.isTrain # train or test
|
356 |
+
|
357 |
+
# process opt.suffix
|
358 |
+
if opt.suffix:
|
359 |
+
suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else ''
|
360 |
+
opt.name = opt.name + suffix
|
361 |
+
|
362 |
+
#self.print_options(opt)
|
363 |
+
|
364 |
+
# set gpu ids
|
365 |
+
str_ids = opt.gpu_ids.split(',')
|
366 |
+
opt.gpu_ids = []
|
367 |
+
for str_id in str_ids:
|
368 |
+
id = int(str_id)
|
369 |
+
if id >= 0:
|
370 |
+
opt.gpu_ids.append(id)
|
371 |
+
#if len(opt.gpu_ids) > 0:
|
372 |
+
# torch.cuda.set_device(opt.gpu_ids[0])
|
373 |
+
|
374 |
+
self.opt = opt
|
375 |
+
return self.opt
|
376 |
+
|
377 |
+
|
378 |
+
def estimateboost(img, model, model_type, pix2pixmodel, max_res=512):
|
379 |
+
global whole_size_threshold
|
380 |
+
|
381 |
+
# get settings
|
382 |
+
if hasattr(opts, 'depthmap_script_boost_rmax'):
|
383 |
+
whole_size_threshold = opts.depthmap_script_boost_rmax
|
384 |
+
|
385 |
+
if model_type == 0: #leres
|
386 |
+
net_receptive_field_size = 448
|
387 |
+
patch_netsize = 2 * net_receptive_field_size
|
388 |
+
elif model_type == 1: #dpt_beit_large_512
|
389 |
+
net_receptive_field_size = 512
|
390 |
+
patch_netsize = 2 * net_receptive_field_size
|
391 |
+
else: #other midas
|
392 |
+
net_receptive_field_size = 384
|
393 |
+
patch_netsize = 2 * net_receptive_field_size
|
394 |
+
|
395 |
+
gc.collect()
|
396 |
+
devices.torch_gc()
|
397 |
+
|
398 |
+
# Generate mask used to smoothly blend the local pathc estimations to the base estimate.
|
399 |
+
# It is arbitrarily large to avoid artifacts during rescaling for each crop.
|
400 |
+
mask_org = generatemask((3000, 3000))
|
401 |
+
mask = mask_org.copy()
|
402 |
+
|
403 |
+
# Value x of R_x defined in the section 5 of the main paper.
|
404 |
+
r_threshold_value = 0.2
|
405 |
+
#if R0:
|
406 |
+
# r_threshold_value = 0
|
407 |
+
|
408 |
+
input_resolution = img.shape
|
409 |
+
scale_threshold = 3 # Allows up-scaling with a scale up to 3
|
410 |
+
|
411 |
+
# Find the best input resolution R-x. The resolution search described in section 5-double estimation of the main paper and section B of the
|
412 |
+
# supplementary material.
|
413 |
+
whole_image_optimal_size, patch_scale = calculateprocessingres(img, net_receptive_field_size, r_threshold_value, scale_threshold, whole_size_threshold)
|
414 |
+
|
415 |
+
# print('wholeImage being processed in :', whole_image_optimal_size)
|
416 |
+
|
417 |
+
# Generate the base estimate using the double estimation.
|
418 |
+
whole_estimate = doubleestimate(img, net_receptive_field_size, whole_image_optimal_size, pix2pixsize, model, model_type, pix2pixmodel)
|
419 |
+
|
420 |
+
# Compute the multiplier described in section 6 of the main paper to make sure our initial patch can select
|
421 |
+
# small high-density regions of the image.
|
422 |
+
global factor
|
423 |
+
factor = max(min(1, 4 * patch_scale * whole_image_optimal_size / whole_size_threshold), 0.2)
|
424 |
+
# print('Adjust factor is:', 1/factor)
|
425 |
+
|
426 |
+
# Check if Local boosting is beneficial.
|
427 |
+
if max_res < whole_image_optimal_size:
|
428 |
+
# print("No Local boosting. Specified Max Res is smaller than R20, Returning doubleestimate result")
|
429 |
+
return cv2.resize(whole_estimate, (input_resolution[1], input_resolution[0]), interpolation=cv2.INTER_CUBIC)
|
430 |
+
|
431 |
+
# Compute the default target resolution.
|
432 |
+
if img.shape[0] > img.shape[1]:
|
433 |
+
a = 2 * whole_image_optimal_size
|
434 |
+
b = round(2 * whole_image_optimal_size * img.shape[1] / img.shape[0])
|
435 |
+
else:
|
436 |
+
a = round(2 * whole_image_optimal_size * img.shape[0] / img.shape[1])
|
437 |
+
b = 2 * whole_image_optimal_size
|
438 |
+
b = int(round(b / factor))
|
439 |
+
a = int(round(a / factor))
|
440 |
+
|
441 |
+
"""
|
442 |
+
# recompute a, b and saturate to max res.
|
443 |
+
if max(a,b) > max_res:
|
444 |
+
print('Default Res is higher than max-res: Reducing final resolution')
|
445 |
+
if img.shape[0] > img.shape[1]:
|
446 |
+
a = max_res
|
447 |
+
b = round(max_res * img.shape[1] / img.shape[0])
|
448 |
+
else:
|
449 |
+
a = round(max_res * img.shape[0] / img.shape[1])
|
450 |
+
b = max_res
|
451 |
+
b = int(b)
|
452 |
+
a = int(a)
|
453 |
+
"""
|
454 |
+
|
455 |
+
img = cv2.resize(img, (b, a), interpolation=cv2.INTER_CUBIC)
|
456 |
+
|
457 |
+
# Extract selected patches for local refinement
|
458 |
+
base_size = net_receptive_field_size * 2
|
459 |
+
patchset = generatepatchs(img, base_size)
|
460 |
+
|
461 |
+
# print('Target resolution: ', img.shape)
|
462 |
+
|
463 |
+
# Computing a scale in case user prompted to generate the results as the same resolution of the input.
|
464 |
+
# Notice that our method output resolution is independent of the input resolution and this parameter will only
|
465 |
+
# enable a scaling operation during the local patch merge implementation to generate results with the same resolution
|
466 |
+
# as the input.
|
467 |
+
"""
|
468 |
+
if output_resolution == 1:
|
469 |
+
mergein_scale = input_resolution[0] / img.shape[0]
|
470 |
+
print('Dynamicly change merged-in resolution; scale:', mergein_scale)
|
471 |
+
else:
|
472 |
+
mergein_scale = 1
|
473 |
+
"""
|
474 |
+
# always rescale to input res for now
|
475 |
+
mergein_scale = input_resolution[0] / img.shape[0]
|
476 |
+
|
477 |
+
imageandpatchs = ImageandPatchs('', '', patchset, img, mergein_scale)
|
478 |
+
whole_estimate_resized = cv2.resize(whole_estimate, (round(img.shape[1]*mergein_scale),
|
479 |
+
round(img.shape[0]*mergein_scale)), interpolation=cv2.INTER_CUBIC)
|
480 |
+
imageandpatchs.set_base_estimate(whole_estimate_resized.copy())
|
481 |
+
imageandpatchs.set_updated_estimate(whole_estimate_resized.copy())
|
482 |
+
|
483 |
+
print('Resulting depthmap resolution will be :', whole_estimate_resized.shape[:2])
|
484 |
+
print('Patches to process: '+str(len(imageandpatchs)))
|
485 |
+
|
486 |
+
# Enumerate through all patches, generate their estimations and refining the base estimate.
|
487 |
+
for patch_ind in range(len(imageandpatchs)):
|
488 |
+
|
489 |
+
# Get patch information
|
490 |
+
patch = imageandpatchs[patch_ind] # patch object
|
491 |
+
patch_rgb = patch['patch_rgb'] # rgb patch
|
492 |
+
patch_whole_estimate_base = patch['patch_whole_estimate_base'] # corresponding patch from base
|
493 |
+
rect = patch['rect'] # patch size and location
|
494 |
+
patch_id = patch['id'] # patch ID
|
495 |
+
org_size = patch_whole_estimate_base.shape # the original size from the unscaled input
|
496 |
+
print('\t Processing patch', patch_ind, '/', len(imageandpatchs)-1, '|', rect)
|
497 |
+
|
498 |
+
# We apply double estimation for patches. The high resolution value is fixed to twice the receptive
|
499 |
+
# field size of the network for patches to accelerate the process.
|
500 |
+
patch_estimation = doubleestimate(patch_rgb, net_receptive_field_size, patch_netsize, pix2pixsize, model, model_type, pix2pixmodel)
|
501 |
+
patch_estimation = cv2.resize(patch_estimation, (pix2pixsize, pix2pixsize), interpolation=cv2.INTER_CUBIC)
|
502 |
+
patch_whole_estimate_base = cv2.resize(patch_whole_estimate_base, (pix2pixsize, pix2pixsize), interpolation=cv2.INTER_CUBIC)
|
503 |
+
|
504 |
+
# Merging the patch estimation into the base estimate using our merge network:
|
505 |
+
# We feed the patch estimation and the same region from the updated base estimate to the merge network
|
506 |
+
# to generate the target estimate for the corresponding region.
|
507 |
+
pix2pixmodel.set_input(patch_whole_estimate_base, patch_estimation)
|
508 |
+
|
509 |
+
# Run merging network
|
510 |
+
pix2pixmodel.test()
|
511 |
+
visuals = pix2pixmodel.get_current_visuals()
|
512 |
+
|
513 |
+
prediction_mapped = visuals['fake_B']
|
514 |
+
prediction_mapped = (prediction_mapped+1)/2
|
515 |
+
prediction_mapped = prediction_mapped.squeeze().cpu().numpy()
|
516 |
+
|
517 |
+
mapped = prediction_mapped
|
518 |
+
|
519 |
+
# We use a simple linear polynomial to make sure the result of the merge network would match the values of
|
520 |
+
# base estimate
|
521 |
+
p_coef = np.polyfit(mapped.reshape(-1), patch_whole_estimate_base.reshape(-1), deg=1)
|
522 |
+
merged = np.polyval(p_coef, mapped.reshape(-1)).reshape(mapped.shape)
|
523 |
+
|
524 |
+
merged = cv2.resize(merged, (org_size[1],org_size[0]), interpolation=cv2.INTER_CUBIC)
|
525 |
+
|
526 |
+
# Get patch size and location
|
527 |
+
w1 = rect[0]
|
528 |
+
h1 = rect[1]
|
529 |
+
w2 = w1 + rect[2]
|
530 |
+
h2 = h1 + rect[3]
|
531 |
+
|
532 |
+
# To speed up the implementation, we only generate the Gaussian mask once with a sufficiently large size
|
533 |
+
# and resize it to our needed size while merging the patches.
|
534 |
+
if mask.shape != org_size:
|
535 |
+
mask = cv2.resize(mask_org, (org_size[1],org_size[0]), interpolation=cv2.INTER_LINEAR)
|
536 |
+
|
537 |
+
tobemergedto = imageandpatchs.estimation_updated_image
|
538 |
+
|
539 |
+
# Update the whole estimation:
|
540 |
+
# We use a simple Gaussian mask to blend the merged patch region with the base estimate to ensure seamless
|
541 |
+
# blending at the boundaries of the patch region.
|
542 |
+
tobemergedto[h1:h2, w1:w2] = np.multiply(tobemergedto[h1:h2, w1:w2], 1 - mask) + np.multiply(merged, mask)
|
543 |
+
imageandpatchs.set_updated_estimate(tobemergedto)
|
544 |
+
|
545 |
+
# output
|
546 |
+
return cv2.resize(imageandpatchs.estimation_updated_image, (input_resolution[1], input_resolution[0]), interpolation=cv2.INTER_CUBIC)
|
annotator/leres/leres/multi_depth_model_woauxi.py
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from . import network_auxi as network
|
2 |
+
from .net_tools import get_func
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
from modules import devices
|
6 |
+
|
7 |
+
class RelDepthModel(nn.Module):
|
8 |
+
def __init__(self, backbone='resnet50'):
|
9 |
+
super(RelDepthModel, self).__init__()
|
10 |
+
if backbone == 'resnet50':
|
11 |
+
encoder = 'resnet50_stride32'
|
12 |
+
elif backbone == 'resnext101':
|
13 |
+
encoder = 'resnext101_stride32x8d'
|
14 |
+
self.depth_model = DepthModel(encoder)
|
15 |
+
|
16 |
+
def inference(self, rgb):
|
17 |
+
with torch.no_grad():
|
18 |
+
input = rgb.to(self.depth_model.device)
|
19 |
+
depth = self.depth_model(input)
|
20 |
+
#pred_depth_out = depth - depth.min() + 0.01
|
21 |
+
return depth #pred_depth_out
|
22 |
+
|
23 |
+
|
24 |
+
class DepthModel(nn.Module):
|
25 |
+
def __init__(self, encoder):
|
26 |
+
super(DepthModel, self).__init__()
|
27 |
+
backbone = network.__name__.split('.')[-1] + '.' + encoder
|
28 |
+
self.encoder_modules = get_func(backbone)()
|
29 |
+
self.decoder_modules = network.Decoder()
|
30 |
+
|
31 |
+
def forward(self, x):
|
32 |
+
lateral_out = self.encoder_modules(x)
|
33 |
+
out_logit = self.decoder_modules(lateral_out)
|
34 |
+
return out_logit
|
annotator/leres/leres/net_tools.py
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import importlib
|
2 |
+
import torch
|
3 |
+
import os
|
4 |
+
from collections import OrderedDict
|
5 |
+
|
6 |
+
|
7 |
+
def get_func(func_name):
|
8 |
+
"""Helper to return a function object by name. func_name must identify a
|
9 |
+
function in this module or the path to a function relative to the base
|
10 |
+
'modeling' module.
|
11 |
+
"""
|
12 |
+
if func_name == '':
|
13 |
+
return None
|
14 |
+
try:
|
15 |
+
parts = func_name.split('.')
|
16 |
+
# Refers to a function in this module
|
17 |
+
if len(parts) == 1:
|
18 |
+
return globals()[parts[0]]
|
19 |
+
# Otherwise, assume we're referencing a module under modeling
|
20 |
+
module_name = 'annotator.leres.leres.' + '.'.join(parts[:-1])
|
21 |
+
module = importlib.import_module(module_name)
|
22 |
+
return getattr(module, parts[-1])
|
23 |
+
except Exception:
|
24 |
+
print('Failed to f1ind function: %s', func_name)
|
25 |
+
raise
|
26 |
+
|
27 |
+
def load_ckpt(args, depth_model, shift_model, focal_model):
|
28 |
+
"""
|
29 |
+
Load checkpoint.
|
30 |
+
"""
|
31 |
+
if os.path.isfile(args.load_ckpt):
|
32 |
+
print("loading checkpoint %s" % args.load_ckpt)
|
33 |
+
checkpoint = torch.load(args.load_ckpt)
|
34 |
+
if shift_model is not None:
|
35 |
+
shift_model.load_state_dict(strip_prefix_if_present(checkpoint['shift_model'], 'module.'),
|
36 |
+
strict=True)
|
37 |
+
if focal_model is not None:
|
38 |
+
focal_model.load_state_dict(strip_prefix_if_present(checkpoint['focal_model'], 'module.'),
|
39 |
+
strict=True)
|
40 |
+
depth_model.load_state_dict(strip_prefix_if_present(checkpoint['depth_model'], "module."),
|
41 |
+
strict=True)
|
42 |
+
del checkpoint
|
43 |
+
if torch.cuda.is_available():
|
44 |
+
torch.cuda.empty_cache()
|
45 |
+
|
46 |
+
|
47 |
+
def strip_prefix_if_present(state_dict, prefix):
|
48 |
+
keys = sorted(state_dict.keys())
|
49 |
+
if not all(key.startswith(prefix) for key in keys):
|
50 |
+
return state_dict
|
51 |
+
stripped_state_dict = OrderedDict()
|
52 |
+
for key, value in state_dict.items():
|
53 |
+
stripped_state_dict[key.replace(prefix, "")] = value
|
54 |
+
return stripped_state_dict
|