DamarJati committed on
Commit
cba094e
·
1 Parent(s): b204b04
Files changed (45) hide show
  1. README.md +6 -6
  2. app.py +102 -0
  3. image0.jpeg +0 -0
  4. image1.jpeg +0 -0
  5. lama_cleaner/__init__.py +0 -0
  6. lama_cleaner/__pycache__/__init__.cpython-310.pyc +0 -0
  7. lama_cleaner/__pycache__/helper.cpython-310.pyc +0 -0
  8. lama_cleaner/__pycache__/model_manager.cpython-310.pyc +0 -0
  9. lama_cleaner/__pycache__/schema.cpython-310.pyc +0 -0
  10. lama_cleaner/__pycache__/settings.cpython-310.pyc +0 -0
  11. lama_cleaner/__pycache__/urls.cpython-310.pyc +0 -0
  12. lama_cleaner/__pycache__/wsgi.cpython-310.pyc +0 -0
  13. lama_cleaner/asgi.py +16 -0
  14. lama_cleaner/helper.py +182 -0
  15. lama_cleaner/model/__init__.py +0 -0
  16. lama_cleaner/model/__pycache__/__init__.cpython-310.pyc +0 -0
  17. lama_cleaner/model/__pycache__/base.cpython-310.pyc +0 -0
  18. lama_cleaner/model/__pycache__/ddim_sampler.cpython-310.pyc +0 -0
  19. lama_cleaner/model/__pycache__/fcf.cpython-310.pyc +0 -0
  20. lama_cleaner/model/__pycache__/lama.cpython-310.pyc +0 -0
  21. lama_cleaner/model/__pycache__/ldm.cpython-310.pyc +0 -0
  22. lama_cleaner/model/__pycache__/mat.cpython-310.pyc +0 -0
  23. lama_cleaner/model/__pycache__/opencv2.cpython-310.pyc +0 -0
  24. lama_cleaner/model/__pycache__/plms_sampler.cpython-310.pyc +0 -0
  25. lama_cleaner/model/__pycache__/sd.cpython-310.pyc +0 -0
  26. lama_cleaner/model/__pycache__/utils.cpython-310.pyc +0 -0
  27. lama_cleaner/model/__pycache__/zits.cpython-310.pyc +0 -0
  28. lama_cleaner/model/base.py +183 -0
  29. lama_cleaner/model/ddim_sampler.py +193 -0
  30. lama_cleaner/model/fcf.py +1214 -0
  31. lama_cleaner/model/lama.py +61 -0
  32. lama_cleaner/model/ldm.py +312 -0
  33. lama_cleaner/model/mat.py +1444 -0
  34. lama_cleaner/model/opencv2.py +24 -0
  35. lama_cleaner/model/plms_sampler.py +225 -0
  36. lama_cleaner/model/sd.py +215 -0
  37. lama_cleaner/model/sd_pipeline.py +310 -0
  38. lama_cleaner/model/utils.py +709 -0
  39. lama_cleaner/model/zits.py +427 -0
  40. lama_cleaner/model_manager.py +43 -0
  41. lama_cleaner/schema.py +50 -0
  42. lama_cleaner/settings.py +124 -0
  43. lama_cleaner/urls.py +22 -0
  44. lama_cleaner/wsgi.py +16 -0
  45. requirements.txt +12 -0
README.md CHANGED
@@ -1,12 +1,12 @@
1
  ---
2
- title: Remove Watermark
3
- emoji: 📚
4
- colorFrom: gray
5
- colorTo: green
6
  sdk: gradio
7
  sdk_version: 4.38.1
8
  app_file: app.py
9
  pinned: false
 
10
  ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: Remove-WM
3
+ emoji: 🔍🗑️
4
+ colorFrom: pink
5
+ colorTo: purple
6
  sdk: gradio
7
  sdk_version: 4.38.1
8
  app_file: app.py
9
  pinned: false
10
+ license: apache-2.0
11
  ---
12
+ https://github.com/sponsors/Damarcreative
 
app.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from lama_cleaner.model_manager import ModelManager
4
+ from lama_cleaner.schema import Config, HDStrategy, LDMSampler
5
+ from transformers import AutoProcessor, AutoModelForCausalLM
6
+ import cv2
7
+ import numpy as np
8
+ from PIL import Image, ImageDraw
9
+ import spaces
10
+ import subprocess
11
+
12
# Install flash-attn at startup.
# FLASH_ATTENTION_SKIP_CUDA_BUILD is merged into the inherited environment:
# passing a bare dict as ``env`` would replace the whole environment and drop
# PATH, HOME, proxy settings, etc. from the pip child process.
import os

subprocess.run(
    ['pip', 'install', 'flash-attn', '--no-build-isolation'],
    env={**os.environ, 'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
)

# Initialize Florence model
model_id = 'microsoft/Florence-2-large'
florence_model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True).to("cuda").eval()
florence_processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)

# Device used for the LaMa Cleaner inpainting model
device = 'cuda' if torch.cuda.is_available() else 'cpu'
22
+
23
@spaces.GPU()
def process_image(image, mask, strategy, sampler, fx=1, fy=1):
    """Inpaint the masked area of ``image`` with the "lama" model.

    Args:
        image: colour array; converted here with ``COLOR_BGR2RGB``.
        mask: colour array; converted here to single-channel grayscale.
        strategy: value stored in ``Config.hd_strategy``.
        sampler: value stored in ``Config.ldm_sampler``.
        fx: horizontal scale factor applied to image and mask before inpainting.
        fy: vertical scale factor applied to image and mask before inpainting.

    Returns:
        Whatever the ``ModelManager`` call returns for the prepared inputs.
    """
    rgb_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    gray_mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)

    # Rescale only when a non-identity factor was requested.
    if not (fx == 1 and fy == 1):
        rgb_image = cv2.resize(rgb_image, None, fx=fx, fy=fy, interpolation=cv2.INTER_AREA)
        # Nearest-neighbour keeps the mask free of interpolated in-between values.
        gray_mask = cv2.resize(gray_mask, None, fx=fx, fy=fy, interpolation=cv2.INTER_NEAREST)

    inpaint_config = Config(
        ldm_steps=1,
        ldm_sampler=sampler,
        hd_strategy=strategy,
        hd_strategy_crop_margin=32,
        hd_strategy_crop_trigger_size=200,
        hd_strategy_resize_limit=200,
    )

    inpainter = ModelManager(name="lama", device=device)
    return inpainter(rgb_image, gray_mask, inpaint_config)
44
+
45
def create_mask(image, prediction):
    """Rasterise predicted polygons into a black/white RGBA mask.

    Args:
        image: object exposing ``.size``; the mask is created with that size.
        prediction: mapping whose ``'polygons'`` entry is an iterable of
            polygon groups, each polygon being a flat coordinate sequence
            that is reshaped to (x, y) pairs.

    Returns:
        An "RGBA" PIL image: opaque black background with every polygon of
        three or more points filled opaque white.
    """
    canvas = Image.new("RGBA", image.size, (0, 0, 0, 255))  # Black background
    painter = ImageDraw.Draw(canvas)
    for region in prediction['polygons']:
        for raw_points in region:
            vertices = np.array(raw_points).reshape(-1, 2)
            # Fewer than three vertices cannot enclose an area.
            if vertices.shape[0] >= 3:
                painter.polygon(vertices.reshape(-1).tolist(), fill=(255, 255, 255, 255))  # Selected area is white
    return canvas
57
+
58
@spaces.GPU()
def process_images_florence_lama(image):
    """Detect watermark regions with Florence and inpaint them away.

    Args:
        image: PIL image supplied by the Gradio input component.

    Returns:
        A PIL image built from the inpainting result after a BGR -> RGB
        conversion.

    Raises:
        gr.Error: if no image was supplied.
    """
    if image is None:
        raise gr.Error("Please upload an image before running the watermark remover.")

    # Florence gets a true RGB image. The BGR array is built separately for
    # the OpenCV-based inpainting step; wrapping that BGR array in a PIL image
    # would present the detector with swapped red/blue channels.
    image_pil = image.convert("RGB")
    image_cv = cv2.cvtColor(np.array(image_pil), cv2.COLOR_RGB2BGR)

    # Run Florence to get mask
    text_input = 'watermark'  # Text prompt telling Florence what to segment
    task_prompt = '<REGION_TO_SEGMENTATION>'
    inputs = florence_processor(text=task_prompt + text_input, images=image_pil, return_tensors="pt").to("cuda")
    generated_ids = florence_model.generate(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        max_new_tokens=1024,
        early_stopping=False,
        do_sample=False,
        num_beams=3,
    )
    generated_text = florence_processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    parsed_answer = florence_processor.post_process_generation(
        generated_text,
        task=task_prompt,
        image_size=(image_pil.width, image_pil.height)
    )

    # Create mask and process image with LaMa Cleaner
    mask_image = create_mask(image_pil, parsed_answer[task_prompt])
    result_image = process_image(image_cv, np.array(mask_image), HDStrategy.RESIZE, LDMSampler.ddim)

    # Convert result back to PIL Image
    return Image.fromarray(cv2.cvtColor(result_image, cv2.COLOR_BGR2RGB))
91
+
92
# Define Gradio interface: one input image in, one cleaned image out.
demo = gr.Interface(
    fn=process_images_florence_lama,
    inputs=gr.Image(type="pil", label="Input Image"),
    outputs=gr.Image(type="pil", label="Output Image"),
    title="Watermark Remover.",
    description="Upload images and remove selected watermarks using Florence and LaMa Cleaner."
)
# Launch the Gradio interface only when executed as a script
# (no example images are registered on the interface).
if __name__ == "__main__":
    demo.launch()
image0.jpeg ADDED
image1.jpeg ADDED
lama_cleaner/__init__.py ADDED
File without changes
lama_cleaner/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (145 Bytes). View file
 
lama_cleaner/__pycache__/helper.cpython-310.pyc ADDED
Binary file (4.86 kB). View file
 
lama_cleaner/__pycache__/model_manager.cpython-310.pyc ADDED
Binary file (1.92 kB). View file
 
lama_cleaner/__pycache__/schema.cpython-310.pyc ADDED
Binary file (1.68 kB). View file
 
lama_cleaner/__pycache__/settings.cpython-310.pyc ADDED
Binary file (2.3 kB). View file
 
lama_cleaner/__pycache__/urls.cpython-310.pyc ADDED
Binary file (989 Bytes). View file
 
lama_cleaner/__pycache__/wsgi.cpython-310.pyc ADDED
Binary file (558 Bytes). View file
 
lama_cleaner/asgi.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ASGI config for lama_cleaner project.
3
+
4
+ It exposes the ASGI callable as a module-level variable named ``application``.
5
+
6
+ For more information on this file, see
7
+ https://docs.djangoproject.com/en/4.1/howto/deployment/asgi/
8
+ """
9
+
10
+ import os
11
+
12
+ from django.core.asgi import get_asgi_application
13
+
14
# Select this project's settings module unless DJANGO_SETTINGS_MODULE is
# already present in the environment (setdefault keeps an existing value).
os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'lama_cleaner.settings')

# Module-level ASGI callable referred to in the module docstring.
application = get_asgi_application()
lama_cleaner/helper.py ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ from typing import List, Optional
4
+
5
+ from urllib.parse import urlparse
6
+ import cv2
7
+ import numpy as np
8
+ import torch
9
+ from loguru import logger
10
+ from torch.hub import download_url_to_file, get_dir
11
+
12
+
13
def get_cache_path_by_url(url):
    """Return the local path under which the file at ``url`` is cached.

    The path is ``<torch hub dir>/checkpoints/<basename of the URL path>``;
    query strings and fragments do not take part in the file name. The
    ``checkpoints`` directory is created when missing.

    Args:
        url: download URL of the model file.

    Returns:
        The cache file path as a string. The file itself may not exist yet.
    """
    model_dir = os.path.join(get_dir(), "checkpoints")
    # Create exactly the checkpoint directory; the previous code created a
    # nested "<checkpoints>/hub/checkpoints" tree as a side effect.
    os.makedirs(model_dir, exist_ok=True)
    filename = os.path.basename(urlparse(url).path)
    return os.path.join(model_dir, filename)
22
+
23
+
24
def download_model(url):
    """Ensure the file at ``url`` is present in the local cache.

    Args:
        url: download URL of the model file.

    Returns:
        Path of the cached file; it is downloaded first when it does not
        exist yet.
    """
    destination = get_cache_path_by_url(url)
    if os.path.exists(destination):
        return destination

    sys.stderr.write('Downloading: "{}" to {}\n'.format(url, destination))
    # No hash prefix is supplied, so the download is not checksum-verified.
    download_url_to_file(url, destination, None, progress=True)
    return destination
31
+
32
+
33
def ceil_modulo(x, mod):
    """Round ``x`` up to the nearest multiple of ``mod``.

    Values that are already a multiple of ``mod`` are returned unchanged.
    """
    quotient, remainder = divmod(x, mod)
    return x if remainder == 0 else (quotient + 1) * mod
37
+
38
+
39
def load_jit_model(url_or_path, device):
    """Load a TorchScript model, move it to ``device`` and set eval mode.

    Args:
        url_or_path: existing local file path, or a URL that is downloaded
            into the local cache first.
        device: target passed to ``.to()``.

    Returns:
        The loaded model in eval mode.

    Raises:
        SystemExit: with code -1 when the model file cannot be loaded.
    """
    if os.path.exists(url_or_path):
        model_path = url_or_path
    else:
        model_path = download_model(url_or_path)
    logger.info(f"Load model from: {model_path}")
    try:
        model = torch.jit.load(model_path).to(device)
    except Exception:
        # Catch Exception only, so Ctrl-C / SystemExit are not swallowed, and
        # log the traceback so the cause of the failure is visible.
        logger.exception(
            f"Failed to load {model_path}, delete model and restart lama-cleaner"
        )
        # sys.exit instead of the site-provided exit(), which may be absent.
        sys.exit(-1)
    model.eval()
    return model
54
+
55
+
56
def load_model(model: torch.nn.Module, url_or_path, device):
    """Load a state dict into ``model``, move it to ``device``, set eval mode.

    Args:
        model: module whose parameters are populated (strict key matching).
        url_or_path: existing local file path, or a URL that is downloaded
            into the local cache first.
        device: target passed to ``model.to()``.

    Returns:
        The same ``model`` instance in eval mode.

    Raises:
        SystemExit: with code -1 when the checkpoint cannot be loaded.
    """
    if os.path.exists(url_or_path):
        model_path = url_or_path
    else:
        model_path = download_model(url_or_path)

    try:
        # NOTE: torch.load unpickles the file; only use trusted checkpoints.
        state_dict = torch.load(model_path, map_location='cpu')
        model.load_state_dict(state_dict, strict=True)
        model.to(device)
        logger.info(f"Load model from: {model_path}")
    except Exception:
        # Catch Exception only, so Ctrl-C / SystemExit are not swallowed, and
        # log the traceback so the cause of the failure is visible.
        logger.exception(
            f"Failed to load {model_path}, delete model and restart lama-cleaner"
        )
        # sys.exit instead of the site-provided exit(), which may be absent.
        sys.exit(-1)
    model.eval()
    return model
74
+
75
+
76
def numpy_to_bytes(image_numpy: np.ndarray, ext: str) -> bytes:
    """Encode an image array into the container format named by ``ext``.

    Args:
        image_numpy: image array handed to ``cv2.imencode``.
        ext: file extension without the leading dot (e.g. ``"png"``).

    Returns:
        The encoded image as bytes.
    """
    # JPEG quality 100 and PNG compression level 0 are both requested; the
    # parameter list is passed regardless of the chosen extension.
    encode_params = [int(cv2.IMWRITE_JPEG_QUALITY), 100, int(cv2.IMWRITE_PNG_COMPRESSION), 0]
    _, encoded = cv2.imencode(f".{ext}", image_numpy, encode_params)
    return encoded.tobytes()
84
+
85
+
86
def load_img(img_bytes, gray: bool = False):
    """Decode encoded image bytes into an array.

    Args:
        img_bytes: encoded image data (bytes-like).
        gray: when True, decode as single-channel grayscale.

    Returns:
        ``(np_img, alpha_channel)``. With ``gray=True`` the image is the
        grayscale decode and ``alpha_channel`` is None. Otherwise the image
        is converted to RGB; ``alpha_channel`` holds the last channel of a
        4-channel decode and is None in every other case.

    Raises:
        ValueError: if the bytes cannot be decoded as an image.
    """
    alpha_channel = None
    nparr = np.frombuffer(img_bytes, np.uint8)
    np_img = cv2.imdecode(nparr, cv2.IMREAD_GRAYSCALE if gray else cv2.IMREAD_UNCHANGED)
    if np_img is None:
        # imdecode signals undecodable data by returning None.
        raise ValueError("Failed to decode image bytes")

    if gray:
        return np_img, alpha_channel

    if len(np_img.shape) == 2:
        # IMREAD_UNCHANGED keeps single-channel images 2-D; BGR2RGB would
        # reject them, so expand explicitly.
        np_img = cv2.cvtColor(np_img, cv2.COLOR_GRAY2RGB)
    elif np_img.shape[2] == 4:
        alpha_channel = np_img[:, :, -1]
        np_img = cv2.cvtColor(np_img, cv2.COLOR_BGRA2RGB)
    else:
        np_img = cv2.cvtColor(np_img, cv2.COLOR_BGR2RGB)

    return np_img, alpha_channel
100
+
101
+
102
def norm_img(np_img):
    """Convert an [H, W] or [H, W, C] image into float32 [C, H, W] scaled by 1/255.

    A 2-D input first receives a trailing channel axis of size one.
    """
    if np_img.ndim == 2:
        np_img = np_img[:, :, np.newaxis]
    channels_first = np.transpose(np_img, (2, 0, 1))
    return channels_first.astype("float32") / 255
108
+
109
+
110
def resize_max_size(
    np_img, size_limit: int, interpolation=cv2.INTER_CUBIC
) -> np.ndarray:
    """Shrink an image so that its longer side equals ``size_limit``.

    Args:
        np_img: image array whose first two dimensions are height and width.
        size_limit: maximum allowed length of the longer side.
        interpolation: OpenCV interpolation flag used for resizing.

    Returns:
        The resized image, or ``np_img`` itself (not a copy) when its longer
        side does not exceed ``size_limit``.
    """
    h, w = np_img.shape[:2]
    longer_side = max(h, w)
    if longer_side <= size_limit:
        return np_img

    ratio = size_limit / longer_side
    # Clamp to one pixel: for extreme aspect ratios the rounded short side
    # would otherwise become 0, which is not a valid dsize for cv2.resize.
    new_w = max(1, int(w * ratio + 0.5))
    new_h = max(1, int(h * ratio + 0.5))
    return cv2.resize(np_img, dsize=(new_w, new_h), interpolation=interpolation)
122
+
123
+
124
def pad_img_to_modulo(
    img: np.ndarray, mod: int, square: bool = False, min_size: Optional[int] = None
):
    """Pad an image at the bottom/right so both sides are multiples of ``mod``.

    Args:
        img: [H, W, C] array; a 2-D [H, W] input gets a trailing channel axis.
        mod: each output side is rounded up to a multiple of this value.
        square: whether to pad the output to a square (the larger side is used).
        min_size: optional lower bound for both output sides; must itself be
            a multiple of ``mod``.

    Returns:
        The padded [H', W', C] array, filled using symmetric reflection.
    """
    if img.ndim == 2:
        img = img[:, :, np.newaxis]
    height, width = img.shape[:2]

    target_h = ceil_modulo(height, mod)
    target_w = ceil_modulo(width, mod)

    if min_size is not None:
        assert min_size % mod == 0
        target_w = max(min_size, target_w)
        target_h = max(min_size, target_h)

    if square:
        target_h = target_w = max(target_h, target_w)

    # Only the trailing edge of each spatial axis is padded; channels untouched.
    pad_widths = ((0, target_h - height), (0, target_w - width), (0, 0))
    return np.pad(img, pad_widths, mode="symmetric")
159
+
160
+
161
def boxes_from_mask(mask: np.ndarray) -> List[np.ndarray]:
    """Return one bounding box per external contour of the thresholded mask.

    Args:
        mask: (h, w, 1) array with values 0~255; values above 127 count as
            foreground.

    Returns:
        A list of integer arrays ``[x1, y1, x2, y2]``, with x clipped to
        ``[0, width]`` and y clipped to ``[0, height]``.
    """
    height, width = mask.shape[:2]
    _, binary = cv2.threshold(mask, 127, 255, 0)
    contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    def _clipped_box(contour):
        # Bounding rectangle as corner coordinates, limited to the image area.
        x, y, w, h = cv2.boundingRect(contour)
        corners = np.array([x, y, x + w, y + h]).astype(int)
        corners[::2] = np.clip(corners[::2], 0, width)
        corners[1::2] = np.clip(corners[1::2], 0, height)
        return corners

    return [_clipped_box(contour) for contour in contours]
lama_cleaner/model/__init__.py ADDED
File without changes
lama_cleaner/model/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (184 Bytes). View file
 
lama_cleaner/model/__pycache__/base.cpython-310.pyc ADDED
Binary file (4.76 kB). View file
 
lama_cleaner/model/__pycache__/ddim_sampler.cpython-310.pyc ADDED
Binary file (4.74 kB). View file
 
lama_cleaner/model/__pycache__/fcf.cpython-310.pyc ADDED
Binary file (33.4 kB). View file
 
lama_cleaner/model/__pycache__/lama.cpython-310.pyc ADDED
Binary file (2.16 kB). View file
 
lama_cleaner/model/__pycache__/ldm.cpython-310.pyc ADDED
Binary file (7.86 kB). View file
 
lama_cleaner/model/__pycache__/mat.cpython-310.pyc ADDED
Binary file (38 kB). View file
 
lama_cleaner/model/__pycache__/opencv2.cpython-310.pyc ADDED
Binary file (1.14 kB). View file
 
lama_cleaner/model/__pycache__/plms_sampler.cpython-310.pyc ADDED
Binary file (7.08 kB). View file
 
lama_cleaner/model/__pycache__/sd.cpython-310.pyc ADDED
Binary file (5.83 kB). View file
 
lama_cleaner/model/__pycache__/utils.cpython-310.pyc ADDED
Binary file (26 kB). View file
 
lama_cleaner/model/__pycache__/zits.cpython-310.pyc ADDED
Binary file (10.5 kB). View file
 
lama_cleaner/model/base.py ADDED
@@ -0,0 +1,183 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import abc
2
+ from typing import Optional
3
+
4
+ import cv2
5
+ import torch
6
+ from loguru import logger
7
+
8
+ from lama_cleaner.helper import boxes_from_mask, resize_max_size, pad_img_to_modulo
9
+ from lama_cleaner.schema import Config, HDStrategy
10
+
11
+
12
class InpaintModel:
    """Base class for inpainting models.

    Subclasses provide ``init_model``, ``is_downloaded`` and ``forward``; this
    class adds padding to a size multiple and the high-resolution strategies
    (crop / resize) selected through ``Config.hd_strategy``.
    """

    # Optional lower bound for the padded side length handed to ``forward``.
    min_size: Optional[int] = None
    # Inputs are padded so that both sides are multiples of this value.
    pad_mod = 8
    # When True, inputs are padded to a square.
    pad_to_square = False

    def __init__(self, device, **kwargs):
        """Store ``device`` and let the subclass build its model.

        Args:
            device: device identifier forwarded to ``init_model``.
            **kwargs: extra options forwarded to ``init_model``.
        """
        self.device = device
        self.init_model(device, **kwargs)

    @abc.abstractmethod
    def init_model(self, device, **kwargs):
        """Create/load the underlying model (implemented by subclasses)."""
        ...

    @staticmethod
    @abc.abstractmethod
    def is_downloaded() -> bool:
        """Report whether the model files are available (implemented by subclasses)."""
        ...

    @abc.abstractmethod
    def forward(self, image, mask, config: Config):
        """Input images and output images have same size
        images: [H, W, C] RGB
        masks: [H, W, 1], 255 marks the masked area
        return: BGR IMAGE
        """
        ...

    def _pad_forward(self, image, mask, config: Config):
        """Pad inputs to the model's size constraints, run ``forward``, crop back.

        Pixels where ``mask < 127`` are restored from the (channel-reversed)
        input so that only the masked area carries model output.

        Args:
            image: [H, W, C] RGB
            mask: [H, W]; the boolean index built from it is applied to
                [H, W, C] arrays
            config: inpainting configuration forwarded to ``forward``.

        Returns:
            BGR IMAGE with the same height/width as ``image``.
        """
        origin_height, origin_width = image.shape[:2]
        pad_image = pad_img_to_modulo(
            image, mod=self.pad_mod, square=self.pad_to_square, min_size=self.min_size
        )
        pad_mask = pad_img_to_modulo(
            mask, mod=self.pad_mod, square=self.pad_to_square, min_size=self.min_size
        )

        logger.info(f"final forward pad size: {pad_image.shape}")

        result = self.forward(pad_image, pad_mask, config)
        # Drop the padding again.
        result = result[0:origin_height, 0:origin_width, :]

        # Keep original pixels outside the mask (image is RGB, result is BGR).
        original_pixel_indices = mask < 127
        result[original_pixel_indices] = image[:, :, ::-1][original_pixel_indices]
        return result

    @torch.no_grad()
    def __call__(self, image, mask, config: Config):
        """
        images: [H, W, C] RGB, not normalized
        masks: [H, W]
        return: BGR IMAGE
        """
        inpaint_result = None
        logger.info(f"hd_strategy: {config.hd_strategy}")
        if config.hd_strategy == HDStrategy.CROP:
            if max(image.shape) > config.hd_strategy_crop_trigger_size:
                logger.info("Run crop strategy")
                boxes = boxes_from_mask(mask)
                crop_result = []
                for box in boxes:
                    crop_image, crop_box = self._run_box(image, mask, box, config)
                    crop_result.append((crop_image, crop_box))

                # Copy before pasting: a plain reversed slice is a view, so
                # writing the crops into it would overwrite the caller's
                # image and hand back a negative-stride array.
                inpaint_result = image[:, :, ::-1].copy()
                for crop_image, crop_box in crop_result:
                    x1, y1, x2, y2 = crop_box
                    inpaint_result[y1:y2, x1:x2, :] = crop_image

        elif config.hd_strategy == HDStrategy.RESIZE:
            if max(image.shape) > config.hd_strategy_resize_limit:
                origin_size = image.shape[:2]
                downsize_image = resize_max_size(
                    image, size_limit=config.hd_strategy_resize_limit
                )
                downsize_mask = resize_max_size(
                    mask, size_limit=config.hd_strategy_resize_limit
                )

                logger.info(
                    f"Run resize strategy, origin size: {image.shape} forward size: {downsize_image.shape}"
                )
                inpaint_result = self._pad_forward(
                    downsize_image, downsize_mask, config
                )

                # only paste masked area result
                inpaint_result = cv2.resize(
                    inpaint_result,
                    (origin_size[1], origin_size[0]),
                    interpolation=cv2.INTER_CUBIC,
                )
                original_pixel_indices = mask < 127
                inpaint_result[original_pixel_indices] = image[:, :, ::-1][
                    original_pixel_indices
                ]

        # No strategy applied (or the image was below the trigger size):
        # run the model on the full image.
        if inpaint_result is None:
            inpaint_result = self._pad_forward(image, mask, config)

        return inpaint_result

    def _crop_box(self, image, mask, box, config: Config):
        """Cut out the region around ``box`` enlarged by the configured margin.

        Args:
            image: [H, W, C] RGB
            mask: [H, W, 1]
            box: [left,top,right,bottom]

        Returns:
            cropped image, cropped mask, [l, t, r, b] of the crop
        """
        box_h = box[3] - box[1]
        box_w = box[2] - box[0]
        cx = (box[0] + box[2]) // 2
        cy = (box[1] + box[3]) // 2
        img_h, img_w = image.shape[:2]

        w = box_w + config.hd_strategy_crop_margin * 2
        h = box_h + config.hd_strategy_crop_margin * 2

        _l = cx - w // 2
        _r = cx + w // 2
        _t = cy - h // 2
        _b = cy + h // 2

        l = max(_l, 0)
        r = min(_r, img_w)
        t = max(_t, 0)
        b = min(_b, img_h)

        # try to get more context when crop around image edge
        if _l < 0:
            r += abs(_l)
        if _r > img_w:
            l -= _r - img_w
        if _t < 0:
            b += abs(_t)
        if _b > img_h:
            t -= _b - img_h

        # The shift above may itself leave the image; clamp once more.
        l = max(l, 0)
        r = min(r, img_w)
        t = max(t, 0)
        b = min(b, img_h)

        crop_img = image[t:b, l:r, :]
        crop_mask = mask[t:b, l:r]

        logger.info(f"box size: ({box_h},{box_w}) crop size: {crop_img.shape}")

        return crop_img, crop_mask, [l, t, r, b]

    def _run_box(self, image, mask, box, config: Config):
        """Inpaint only the region around ``box``.

        Args:
            image: [H, W, C] RGB
            mask: [H, W, 1]
            box: [left,top,right,bottom]

        Returns:
            BGR IMAGE of the crop, [l, t, r, b] of the crop
        """
        crop_img, crop_mask, [l, t, r, b] = self._crop_box(image, mask, box, config)

        return self._pad_forward(crop_img, crop_mask, config), [l, t, r, b]
lama_cleaner/model/ddim_sampler.py ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ from tqdm import tqdm
4
+
5
+ from lama_cleaner.model.utils import make_ddim_timesteps, make_ddim_sampling_parameters, noise_like
6
+
7
+ from loguru import logger
8
+
9
+
10
class DDIMSampler(object):
    """DDIM sampler driving a diffusion ``model``.

    The model is expected to expose ``num_timesteps``, ``betas``,
    ``alphas_cumprod``, ``alphas_cumprod_prev``, ``device`` and
    ``apply_model`` (all read below).
    """

    def __init__(self, model, schedule="linear"):
        """Keep a reference to ``model`` and its number of DDPM timesteps."""
        super().__init__()
        self.model = model
        self.ddpm_num_timesteps = model.num_timesteps
        self.schedule = schedule

    def register_buffer(self, name, attr):
        """Store ``attr`` on the sampler under ``name`` (plain ``setattr``)."""
        setattr(self, name, attr)

    def make_schedule(
        self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True
    ):
        """Compute the DDIM timesteps and all derived schedule values.

        Results are stored as attributes on the sampler via ``register_buffer``.

        Args:
            ddim_num_steps: number of DDIM steps to select.
            ddim_discretize: discretisation method name passed to
                ``make_ddim_timesteps``.
            ddim_eta: eta passed to ``make_ddim_sampling_parameters`` and used
                to scale the sigmas for original-step sampling.
            verbose: forwarded to the schedule helper functions.
        """
        self.ddim_timesteps = make_ddim_timesteps(
            ddim_discr_method=ddim_discretize,
            num_ddim_timesteps=ddim_num_steps,
            # array([1])
            num_ddpm_timesteps=self.ddpm_num_timesteps,
            verbose=verbose,
        )
        alphas_cumprod = self.model.alphas_cumprod  # torch.Size([1000])
        assert (
            alphas_cumprod.shape[0] == self.ddpm_num_timesteps
        ), "alphas have to be defined for each timestep"
        # Detached float32 copy placed on the model's device.
        to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)

        self.register_buffer("betas", to_torch(self.model.betas))
        self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
        self.register_buffer(
            "alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev)
        )

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer(
            "sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod.cpu()))
        )
        self.register_buffer(
            "sqrt_one_minus_alphas_cumprod",
            to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())),
        )
        self.register_buffer(
            "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod.cpu()))
        )
        self.register_buffer(
            "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu()))
        )
        self.register_buffer(
            "sqrt_recipm1_alphas_cumprod",
            to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)),
        )

        # ddim sampling parameters
        ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
            alphacums=alphas_cumprod.cpu(),
            ddim_timesteps=self.ddim_timesteps,
            eta=ddim_eta,
            verbose=verbose,
        )
        self.register_buffer("ddim_sigmas", ddim_sigmas)
        self.register_buffer("ddim_alphas", ddim_alphas)
        self.register_buffer("ddim_alphas_prev", ddim_alphas_prev)
        self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas))
        # Sigmas for sampling over the full (original) DDPM schedule.
        sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
            (1 - self.alphas_cumprod_prev)
            / (1 - self.alphas_cumprod)
            * (1 - self.alphas_cumprod / self.alphas_cumprod_prev)
        )
        self.register_buffer(
            "ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps
        )

    @torch.no_grad()
    def sample(self, steps, conditioning, batch_size, shape):
        """Build a ``steps``-step schedule with eta=0 and run DDIM sampling.

        Args:
            steps: number of DDIM steps.
            conditioning: conditioning passed through to the model.
            batch_size: number of samples to generate.
            shape: ``(C, H, W)`` of one sample.

        Returns:
            The final sample of size ``(batch_size, C, H, W)``.
        """
        self.make_schedule(ddim_num_steps=steps, ddim_eta=0, verbose=False)
        # sampling
        C, H, W = shape
        size = (batch_size, C, H, W)

        # samples: 1,3,128,128
        return self.ddim_sampling(
            conditioning,
            size,
            quantize_denoised=False,
            ddim_use_original_steps=False,
            noise_dropout=0,
            temperature=1.0,
        )

    @torch.no_grad()
    def ddim_sampling(
        self,
        cond,
        shape,
        ddim_use_original_steps=False,
        quantize_denoised=False,
        temperature=1.0,
        noise_dropout=0.0,
    ):
        """Start from Gaussian noise and denoise it over the timestep sequence.

        Args:
            cond: conditioning; its dtype is also used for the initial noise.
            shape: full sample shape, batch size first.
            ddim_use_original_steps: iterate over all DDPM timesteps instead
                of the DDIM subset.
            quantize_denoised: forwarded to ``p_sample_ddim``.
            temperature: forwarded to ``p_sample_ddim``.
            noise_dropout: forwarded to ``p_sample_ddim``.

        Returns:
            The sample after the last step.
        """
        device = self.model.betas.device
        b = shape[0]
        img = torch.randn(shape, device=device, dtype=cond.dtype)
        # An int (full schedule) or the array of selected DDIM timesteps.
        timesteps = (
            self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
        )

        # Timesteps are visited from last to first.
        time_range = (
            reversed(range(0, timesteps))
            if ddim_use_original_steps
            else np.flip(timesteps)
        )
        total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
        logger.info(f"Running DDIM Sampling with {total_steps} timesteps")

        iterator = tqdm(time_range, desc="DDIM Sampler", total=total_steps)

        for i, step in enumerate(iterator):
            # Position in the schedule arrays, counted from the end.
            index = total_steps - i - 1
            ts = torch.full((b,), step, device=device, dtype=torch.long)

            outs = self.p_sample_ddim(
                img,
                cond,
                ts,
                index=index,
                use_original_steps=ddim_use_original_steps,
                quantize_denoised=quantize_denoised,
                temperature=temperature,
                noise_dropout=noise_dropout,
            )
            img, _ = outs

        return img

    @torch.no_grad()
    def p_sample_ddim(
        self,
        x,
        c,
        t,
        index,
        repeat_noise=False,
        use_original_steps=False,
        quantize_denoised=False,
        temperature=1.0,
        noise_dropout=0.0,
    ):
        """Perform one DDIM update step.

        Args:
            x: current sample.
            c: conditioning passed to ``model.apply_model``.
            t: timestep tensor passed to ``model.apply_model``.
            index: position in the schedule arrays for this step.
            repeat_noise: forwarded to ``noise_like``.
            use_original_steps: read schedule values from the model instead of
                the DDIM buffers.
            quantize_denoised: quantize the predicted x_0 with the first-stage
                model.
            temperature: multiplier applied to the injected noise.
            noise_dropout: dropout probability applied to the injected noise.

        Returns:
            ``(x_prev, pred_x0)``: the sample for the previous timestep and the
            current prediction of x_0.
        """
        b, *_, device = *x.shape, x.device
        e_t = self.model.apply_model(x, t, c)

        alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
        alphas_prev = (
            self.model.alphas_cumprod_prev
            if use_original_steps
            else self.ddim_alphas_prev
        )
        sqrt_one_minus_alphas = (
            self.model.sqrt_one_minus_alphas_cumprod
            if use_original_steps
            else self.ddim_sqrt_one_minus_alphas
        )
        # NOTE(review): make_schedule stores ddim_sigmas_for_original_num_steps
        # on the sampler (self), not on self.model — confirm the model defines
        # this attribute too, otherwise use_original_steps=True fails here.
        sigmas = (
            self.model.ddim_sigmas_for_original_num_steps
            if use_original_steps
            else self.ddim_sigmas
        )
        # select parameters corresponding to the currently considered timestep
        a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
        a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
        sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
        sqrt_one_minus_at = torch.full(
            (b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device
        )

        # current prediction for x_0
        pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
        if quantize_denoised:  # unused
            pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
        # direction pointing to x_t
        dir_xt = (1.0 - a_prev - sigma_t ** 2).sqrt() * e_t
        noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
        if noise_dropout > 0.0:  # unused
            noise = torch.nn.functional.dropout(noise, p=noise_dropout)
        x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
        return x_prev, pred_x0
lama_cleaner/model/fcf.py ADDED
@@ -0,0 +1,1214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+
4
+ import cv2
5
+ import torch
6
+ import numpy as np
7
+ import torch.fft as fft
8
+
9
+ from lama_cleaner.schema import Config
10
+
11
+ from lama_cleaner.helper import load_model, get_cache_path_by_url, norm_img, boxes_from_mask, resize_max_size
12
+ from lama_cleaner.model.base import InpaintModel
13
+ from torch import conv2d, nn
14
+ import torch.nn.functional as F
15
+
16
+ from lama_cleaner.model.utils import setup_filter, _parse_scaling, _parse_padding, Conv2dLayer, FullyConnectedLayer, \
17
+ MinibatchStdLayer, activation_funcs, conv2d_resample, bias_act, upsample2d, normalize_2nd_moment, downsample2d
18
+
19
+
20
def upfirdn2d(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1, impl='cuda'):
    """Upsample, pad, filter and downsample ``x`` via the reference implementation.

    ``impl`` is accepted but not used: every call is routed to
    ``_upfirdn2d_ref``.
    """
    assert isinstance(x, torch.Tensor)
    ref_options = dict(up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain)
    return _upfirdn2d_ref(x, f, **ref_options)
23
+
24
+
25
def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1):
    """Slow reference implementation of `upfirdn2d()` using standard PyTorch ops.

    Args:
        x: 4-D tensor ``[batch_size, num_channels, in_height, in_width]``.
        f: float32 filter tensor of rank 1 or 2 without gradient, or None for
            a ``[1, 1]`` filter of ones.
        up: upsampling factor, parsed by ``_parse_scaling``.
        down: downsampling factor, parsed by ``_parse_scaling``.
        padding: padding spec, parsed by ``_parse_padding``; negative values
            crop instead of pad.
        flip_filter: when False the filter is flipped along all its axes
            before convolving.
        gain: scaling applied to the filter as ``gain ** (f.ndim / 2)``.

    Returns:
        The 4-D tensor after upsampling, padding/cropping, filtering and
        downsampling.
    """
    # Validate arguments.
    assert isinstance(x, torch.Tensor) and x.ndim == 4
    if f is None:
        f = torch.ones([1, 1], dtype=torch.float32, device=x.device)
    assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
    assert f.dtype == torch.float32 and not f.requires_grad
    batch_size, num_channels, in_height, in_width = x.shape
    upx, upy = _parse_scaling(up)
    downx, downy = _parse_scaling(down)
    padx0, padx1, pady0, pady1 = _parse_padding(padding)

    # Upsample by inserting zeros.
    x = x.reshape([batch_size, num_channels, in_height, 1, in_width, 1])
    x = torch.nn.functional.pad(x, [0, upx - 1, 0, 0, 0, upy - 1])
    x = x.reshape([batch_size, num_channels, in_height * upy, in_width * upx])

    # Pad or crop.
    # Positive padding values are applied by pad(); negative ones by slicing.
    x = torch.nn.functional.pad(x, [max(padx0, 0), max(padx1, 0), max(pady0, 0), max(pady1, 0)])
    x = x[:, :, max(-pady0, 0): x.shape[2] - max(-pady1, 0), max(-padx0, 0): x.shape[3] - max(-padx1, 0)]

    # Setup filter.
    f = f * (gain ** (f.ndim / 2))
    f = f.to(x.dtype)
    if not flip_filter:
        f = f.flip(list(range(f.ndim)))

    # Convolve with the filter.
    # One filter copy per channel; groups=num_channels filters each channel
    # independently.
    f = f[np.newaxis, np.newaxis].repeat([num_channels, 1] + [1] * f.ndim)
    if f.ndim == 4:
        x = conv2d(input=x, weight=f, groups=num_channels)
    else:
        # 1-D filter: apply it once along each spatial axis.
        x = conv2d(input=x, weight=f.unsqueeze(2), groups=num_channels)
        x = conv2d(input=x, weight=f.unsqueeze(3), groups=num_channels)

    # Downsample by throwing away pixels.
    x = x[:, :, ::downy, ::downx]
    return x
65
+
66
+
67
class EncoderEpilogue(torch.nn.Module):
    """Last encoder stage: optional minibatch-stddev, one 3x3 conv, FC and dropout.

    `forward()` returns both the latent produced by the fully connected layer
    and the convolution output that feeds it.
    """

    def __init__(self,
                 in_channels,  # Number of input channels.
                 cmap_dim,  # Dimensionality of mapped conditioning label, 0 = no label.
                 z_dim,  # Output Latent (Z) dimensionality.
                 resolution,  # Resolution of this block.
                 img_channels,  # Number of input color channels.
                 architecture='resnet',  # Architecture: 'orig', 'skip', 'resnet'.
                 mbstd_group_size=4,  # Group size for the minibatch standard deviation layer, None = entire minibatch.
                 mbstd_num_channels=1,  # Number of features for the minibatch standard deviation layer, 0 = disable.
                 activation='lrelu',  # Activation function: 'relu', 'lrelu', etc.
                 conv_clamp=None,  # Clamp the output of convolution layers to +-X, None = disable clamping.
                 ):
        assert architecture in ['orig', 'skip', 'resnet']
        super().__init__()
        self.in_channels = in_channels
        self.cmap_dim = cmap_dim
        self.resolution = resolution
        self.img_channels = img_channels
        self.architecture = architecture

        if architecture == 'skip':
            # Created for the 'skip' architecture only; forward() does not call it.
            self.fromrgb = Conv2dLayer(self.img_channels, in_channels, kernel_size=1, activation=activation)

        if mbstd_num_channels > 0:
            self.mbstd = MinibatchStdLayer(group_size=mbstd_group_size, num_channels=mbstd_num_channels)
        else:
            self.mbstd = None

        # The conv input is widened by the extra minibatch-stddev feature channels.
        self.conv = Conv2dLayer(in_channels + mbstd_num_channels, in_channels, kernel_size=3,
                                activation=activation, conv_clamp=conv_clamp)
        self.fc = FullyConnectedLayer(in_channels * (resolution ** 2), z_dim, activation=activation)
        self.dropout = torch.nn.Dropout(p=0.5)

    def forward(self, x, cmap, force_fp32=False):
        """Return `(latent, conv_features)` for feature map `x`.

        `force_fp32` is unused: the computation always runs in float32 with a
        contiguous memory layout. `cmap` is only read when `cmap_dim > 0`.
        """
        dtype = torch.float32
        x = x.to(dtype=dtype, memory_format=torch.contiguous_format)

        # Main layers.
        if self.mbstd is not None:
            x = self.mbstd(x)
        const_e = self.conv(x)
        x = self.dropout(self.fc(const_e.flatten(1)))

        # Conditioning: project onto the mapped label and rescale.
        if self.cmap_dim > 0:
            x = (x * cmap).sum(dim=1, keepdim=True) * (1 / np.sqrt(self.cmap_dim))

        assert x.dtype == dtype
        return x, const_e
118
+
119
+
120
class EncoderBlock(torch.nn.Module):
    """One downsampling encoder stage (conv0 -> conv1 with `down=2`).

    The first block of the network (`in_channels == 0`) additionally owns a
    1x1 `fromrgb` layer that lifts the input image into feature space. Note
    that the stored `img_channels` is the constructor argument plus one.
    """

    def __init__(self,
                 in_channels,  # Number of input channels, 0 = first block.
                 tmp_channels,  # Number of intermediate channels.
                 out_channels,  # Number of output channels.
                 resolution,  # Resolution of this block.
                 img_channels,  # Number of input color channels.
                 first_layer_idx,  # Index of the first layer.
                 architecture='skip',  # Architecture: 'orig', 'skip', 'resnet'.
                 activation='lrelu',  # Activation function: 'relu', 'lrelu', etc.
                 resample_filter=[1, 3, 3, 1],  # Low-pass filter to apply when resampling activations.
                 conv_clamp=None,  # Clamp the output of convolution layers to +-X, None = disable clamping.
                 use_fp16=False,  # Use FP16 for this block?
                 fp16_channels_last=False,  # Use channels-last memory format with FP16?
                 freeze_layers=0,  # Freeze-D: Number of layers to freeze.
                 ):
        assert in_channels in [0, tmp_channels]
        assert architecture in ['orig', 'skip', 'resnet']
        super().__init__()
        self.in_channels = in_channels
        self.resolution = resolution
        self.img_channels = img_channels + 1
        self.first_layer_idx = first_layer_idx
        self.architecture = architecture
        self.use_fp16 = use_fp16
        self.channels_last = (use_fp16 and fp16_channels_last)
        self.register_buffer('resample_filter', setup_filter(resample_filter))

        self.num_layers = 0

        def claim_layer():
            # Allocate the next global layer index and report whether that
            # layer lies at or beyond the frozen prefix (i.e. is trainable).
            layer_idx = self.first_layer_idx + self.num_layers
            self.num_layers += 1
            return layer_idx >= freeze_layers

        if in_channels == 0:
            self.fromrgb = Conv2dLayer(self.img_channels, tmp_channels, kernel_size=1, activation=activation,
                                       trainable=claim_layer(), conv_clamp=conv_clamp,
                                       channels_last=self.channels_last)

        self.conv0 = Conv2dLayer(tmp_channels, tmp_channels, kernel_size=3, activation=activation,
                                 trainable=claim_layer(), conv_clamp=conv_clamp,
                                 channels_last=self.channels_last)

        self.conv1 = Conv2dLayer(tmp_channels, out_channels, kernel_size=3, activation=activation, down=2,
                                 trainable=claim_layer(), resample_filter=resample_filter,
                                 conv_clamp=conv_clamp, channels_last=self.channels_last)

        if architecture == 'resnet':
            self.skip = Conv2dLayer(tmp_channels, out_channels, kernel_size=1, bias=False, down=2,
                                    trainable=claim_layer(), resample_filter=resample_filter,
                                    channels_last=self.channels_last)

    def forward(self, x, img, force_fp32=False):
        """Run the block and return `(x, img, feat)`.

        `feat` is a clone of the `conv0` output (taken before downsampling).
        Computation is always float32; `force_fp32` only influences the
        memory format selection.
        """
        dtype = torch.float32
        if self.channels_last and not force_fp32:
            memory_format = torch.channels_last
        else:
            memory_format = torch.contiguous_format

        # Input.
        if x is not None:
            x = x.to(dtype=dtype, memory_format=memory_format)

        # FromRGB (first block only).
        if self.in_channels == 0:
            img = img.to(dtype=dtype, memory_format=memory_format)
            rgb_feat = self.fromrgb(img)
            x = rgb_feat if x is None else x + rgb_feat
            if self.architecture == 'skip':
                img = downsample2d(img, self.resample_filter)
            else:
                img = None

        # Main layers.
        if self.architecture == 'resnet':
            shortcut = self.skip(x, gain=np.sqrt(0.5))
            x = self.conv0(x)
            feat = x.clone()
            x = self.conv1(x, gain=np.sqrt(0.5))
            x = shortcut.add_(x)
        else:
            x = self.conv0(x)
            feat = x.clone()
            x = self.conv1(x)

        assert x.dtype == dtype
        return x, img, feat
207
+
208
+
209
class EncoderNetwork(torch.nn.Module):
    """Stack of `EncoderBlock`s followed by an `EncoderEpilogue` at resolution 4.

    One block is created per power-of-two resolution from `img_resolution`
    down to 8 and stored as attribute `b<res>`; the epilogue is stored as `b4`.
    FP16 is never enabled for the blocks, so `num_fp16_res` is accepted but has
    no effect.
    """

    def __init__(self,
                 c_dim,  # Conditioning label (C) dimensionality.
                 z_dim,  # Input latent (Z) dimensionality.
                 img_resolution,  # Input resolution.
                 img_channels,  # Number of input color channels.
                 architecture='orig',  # Architecture: 'orig', 'skip', 'resnet'.
                 channel_base=16384,  # Overall multiplier for the number of channels.
                 channel_max=512,  # Maximum number of channels in any layer.
                 num_fp16_res=0,  # Use FP16 for the N highest resolutions (currently ignored).
                 conv_clamp=None,  # Clamp the output of convolution layers to +-X, None = disable clamping.
                 cmap_dim=None,  # Dimensionality of mapped conditioning label, None = default.
                 block_kwargs=None,  # Arguments for EncoderBlock, None = no extra arguments.
                 mapping_kwargs=None,  # Arguments for MappingNetwork, None = no extra arguments.
                 epilogue_kwargs=None,  # Arguments for EncoderEpilogue, None = no extra arguments.
                 ):
        super().__init__()
        # None replaces the former shared `{}` defaults (mutable default arguments).
        block_kwargs = {} if block_kwargs is None else block_kwargs
        mapping_kwargs = {} if mapping_kwargs is None else mapping_kwargs
        epilogue_kwargs = {} if epilogue_kwargs is None else epilogue_kwargs

        self.c_dim = c_dim
        self.z_dim = z_dim
        self.img_resolution = img_resolution
        self.img_resolution_log2 = int(np.log2(img_resolution))
        self.img_channels = img_channels
        self.block_resolutions = [2 ** i for i in range(self.img_resolution_log2, 2, -1)]
        channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions + [4]}

        if cmap_dim is None:
            cmap_dim = channels_dict[4]
        if c_dim == 0:
            cmap_dim = 0

        common_kwargs = dict(img_channels=img_channels, architecture=architecture, conv_clamp=conv_clamp)
        cur_layer_idx = 0
        for res in self.block_resolutions:
            # The block at full input resolution has no incoming feature map.
            in_channels = channels_dict[res] if res < img_resolution else 0
            tmp_channels = channels_dict[res]
            out_channels = channels_dict[res // 2]
            # FP16 is hard-disabled; the old per-resolution decision was
            # computed and then unconditionally overwritten with False.
            block = EncoderBlock(in_channels, tmp_channels, out_channels, resolution=res,
                                 first_layer_idx=cur_layer_idx, use_fp16=False, **block_kwargs, **common_kwargs)
            setattr(self, f'b{res}', block)
            cur_layer_idx += block.num_layers
        if c_dim > 0:
            self.mapping = MappingNetwork(z_dim=0, c_dim=c_dim, w_dim=cmap_dim, num_ws=None, w_avg_beta=None,
                                          **mapping_kwargs)
        # The epilogue emits a latent of width 2 * z_dim.
        self.b4 = EncoderEpilogue(channels_dict[4], cmap_dim=cmap_dim, z_dim=z_dim * 2, resolution=4,
                                  **epilogue_kwargs, **common_kwargs)

    def forward(self, img, c, **block_kwargs):
        """Encode `img` and return `(x, z, feats)`.

        Returns:
            x: Output of the epilogue's fully connected path.
            z: All-zero tensor of shape `(batch, z_dim)` on `x`'s device/dtype.
            feats: dict mapping each block resolution to that block's `feat`
                output, plus key 4 for the epilogue's convolution output.
        """
        x = None
        feats = {}
        for res in self.block_resolutions:
            block = getattr(self, f'b{res}')
            x, img, feat = block(x, img, **block_kwargs)
            feats[res] = feat

        cmap = self.mapping(None, c) if self.c_dim > 0 else None
        x, const_e = self.b4(x, cmap)
        feats[4] = const_e

        batch_size, _ = x.shape
        z = torch.zeros((batch_size, self.z_dim), requires_grad=False, dtype=x.dtype,
                        device=x.device)  # Noise for Co-Modulation
        return x, z, feats
276
+
277
+
278
def fma(a, b, c):  # => a * b + c
    """Compute `a * b + c` through the `_FusedMultiplyAdd` autograd function."""
    fused_op = _FusedMultiplyAdd.apply
    return fused_op(a, b, c)
280
+
281
+
282
+ class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c
283
+ @staticmethod
284
+ def forward(ctx, a, b, c): # pylint: disable=arguments-differ
285
+ out = torch.addcmul(c, a, b)
286
+ ctx.save_for_backward(a, b)
287
+ ctx.c_shape = c.shape
288
+ return out
289
+
290
+ @staticmethod
291
+ def backward(ctx, dout): # pylint: disable=arguments-differ
292
+ a, b = ctx.saved_tensors
293
+ c_shape = ctx.c_shape
294
+ da = None
295
+ db = None
296
+ dc = None
297
+
298
+ if ctx.needs_input_grad[0]:
299
+ da = _unbroadcast(dout * b, a.shape)
300
+
301
+ if ctx.needs_input_grad[1]:
302
+ db = _unbroadcast(dout * a, b.shape)
303
+
304
+ if ctx.needs_input_grad[2]:
305
+ dc = _unbroadcast(dout, c_shape)
306
+
307
+ return da, db, dc
308
+
309
+
310
+ def _unbroadcast(x, shape):
311
+ extra_dims = x.ndim - len(shape)
312
+ assert extra_dims >= 0
313
+ dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)]
314
+ if len(dim):
315
+ x = x.sum(dim=dim, keepdim=True)
316
+ if extra_dims:
317
+ x = x.reshape(-1, *x.shape[extra_dims + 1:])
318
+ assert x.shape == shape
319
+ return x
320
+
321
+
322
def modulated_conv2d(
        x,  # Input tensor of shape [batch_size, in_channels, in_height, in_width].
        weight,  # Weight tensor of shape [out_channels, in_channels, kernel_height, kernel_width].
        styles,  # Modulation coefficients of shape [batch_size, in_channels].
        noise=None,  # Optional noise tensor to add to the output activations.
        up=1,  # Integer upsampling factor.
        down=1,  # Integer downsampling factor.
        padding=0,  # Padding with respect to the upsampled image.
        resample_filter=None,
        # Low-pass filter to apply when resampling activations. Must be prepared beforehand by calling upfirdn2d.setup_filter().
        demodulate=True,  # Apply weight demodulation?
        flip_weight=True,  # False = convolution, True = correlation (matches torch.nn.functional.conv2d).
        fused_modconv=True,  # Perform modulation, convolution, and demodulation as a single fused operation?
):
    """Convolve `x` with per-sample modulated (and optionally demodulated) weights.

    Two execution strategies are available:

    * `fused_modconv=True`: the per-sample weights are built explicitly and the
      whole batch is processed as a single grouped convolution.
    * `fused_modconv=False`: the activations are scaled by `styles` before the
      convolution and by the demodulation coefficients after it.

    Returns:
        The convolved activations, with `noise` added when given.
    """
    batch_size = x.shape[0]
    out_channels, in_channels, kh, kw = weight.shape

    # Pre-normalize inputs to avoid FP16 overflow.
    if x.dtype == torch.float16 and demodulate:
        weight = weight * (1 / np.sqrt(in_channels * kh * kw) / weight.norm(float('inf'), dim=[1, 2, 3],
                                                                            keepdim=True))  # max_Ikk
        styles = styles / styles.norm(float('inf'), dim=1, keepdim=True)  # max_I

    # Calculate per-sample weights and demodulation coefficients.
    w = None
    dcoefs = None
    if demodulate or fused_modconv:
        w = weight.unsqueeze(0)  # [NOIkk]
        w = w * styles.reshape(batch_size, 1, -1, 1, 1)  # [NOIkk]
    if demodulate:
        dcoefs = (w.square().sum(dim=[2, 3, 4]) + 1e-8).rsqrt()  # [NO]
    if demodulate and fused_modconv:
        w = w * dcoefs.reshape(batch_size, -1, 1, 1, 1)  # [NOIkk]

    # Execute by scaling the activations before and after the convolution.
    if not fused_modconv:
        x = x * styles.to(x.dtype).reshape(batch_size, -1, 1, 1)
        # `conv2d_resample` is imported as a function, so it is called directly
        # (the former `conv2d_resample.conv2d_resample(...)` attribute access
        # would fail on a function object).
        x = conv2d_resample(x=x, w=weight.to(x.dtype), f=resample_filter, up=up, down=down,
                            padding=padding, flip_weight=flip_weight)
        if demodulate and noise is not None:
            x = fma(x, dcoefs.to(x.dtype).reshape(batch_size, -1, 1, 1), noise.to(x.dtype))
        elif demodulate:
            x = x * dcoefs.to(x.dtype).reshape(batch_size, -1, 1, 1)
        elif noise is not None:
            x = x.add_(noise.to(x.dtype))
        return x

    # Execute as one fused op using grouped convolution: fold the batch into
    # the channel axis and use one group per sample.
    batch_size = int(batch_size)
    x = x.reshape(1, -1, *x.shape[2:])
    w = w.reshape(-1, in_channels, kh, kw)
    x = conv2d_resample(x=x, w=w.to(x.dtype), f=resample_filter, up=up, down=down, padding=padding,
                        groups=batch_size, flip_weight=flip_weight)
    x = x.reshape(batch_size, -1, *x.shape[2:])
    if noise is not None:
        x = x.add_(noise)
    return x
378
+
379
+
380
class SynthesisLayer(torch.nn.Module):
    """Style-modulated convolution layer with optional additive noise.

    `forward()` always applies a leaky ReLU with slope 0.2 followed by the
    gain/clamp; the `bias` parameter is created but not read in `forward()`.
    """

    def __init__(self,
                 in_channels,  # Number of input channels.
                 out_channels,  # Number of output channels.
                 w_dim,  # Intermediate latent (W) dimensionality.
                 resolution,  # Resolution of this layer.
                 kernel_size=3,  # Convolution kernel size.
                 up=1,  # Integer upsampling factor.
                 use_noise=True,  # Enable noise input?
                 activation='lrelu',  # Activation function: 'relu', 'lrelu', etc.
                 resample_filter=[1, 3, 3, 1],  # Low-pass filter to apply when resampling activations.
                 conv_clamp=None,  # Clamp the output of convolution layers to +-X, None = disable clamping.
                 channels_last=False,  # Use channels_last format for the weights?
                 ):
        super().__init__()
        self.resolution = resolution
        self.up = up
        self.use_noise = use_noise
        self.activation = activation
        self.conv_clamp = conv_clamp
        self.register_buffer('resample_filter', setup_filter(resample_filter))
        self.padding = kernel_size // 2
        self.act_gain = activation_funcs[activation].def_gain

        self.affine = FullyConnectedLayer(w_dim, in_channels, bias_init=1)
        weight_shape = [out_channels, in_channels, kernel_size, kernel_size]
        weight_format = torch.channels_last if channels_last else torch.contiguous_format
        self.weight = torch.nn.Parameter(torch.randn(weight_shape).to(memory_format=weight_format))
        if use_noise:
            self.register_buffer('noise_const', torch.randn([resolution, resolution]))
            self.noise_strength = torch.nn.Parameter(torch.zeros([]))
        self.bias = torch.nn.Parameter(torch.zeros([out_channels]))

    def forward(self, x, w, noise_mode='none', fused_modconv=True, gain=1):
        """Apply the modulated convolution, noise, leaky ReLU, gain and clamp.

        Args:
            x: Input activations.
            w: Latent passed through the affine layer to obtain the styles.
            noise_mode: 'random' draws fresh noise, 'const' uses the stored
                buffer, 'none' adds no noise.
            fused_modconv: Forwarded to `modulated_conv2d()`.
            gain: Extra multiplier for the activation gain and the clamp.
        """
        assert noise_mode in ['random', 'const', 'none']
        styles = self.affine(w)

        noise = None
        if self.use_noise:
            if noise_mode == 'random':
                noise = torch.randn([x.shape[0], 1, self.resolution, self.resolution],
                                    device=x.device) * self.noise_strength
            elif noise_mode == 'const':
                noise = self.noise_const * self.noise_strength

        x = modulated_conv2d(x=x, weight=self.weight, styles=styles, noise=noise, up=self.up,
                             padding=self.padding, resample_filter=self.resample_filter,
                             flip_weight=(self.up == 1),  # slightly faster
                             fused_modconv=fused_modconv)

        x = F.leaky_relu(x, negative_slope=0.2, inplace=False)

        act_gain = self.act_gain * gain
        if act_gain != 1:
            x = x * act_gain

        if self.conv_clamp is not None:
            act_clamp = self.conv_clamp * gain
            x = x.clamp(-act_clamp, act_clamp)
        return x
438
+
439
+
440
class ToRGBLayer(torch.nn.Module):
    """Modulated convolution without demodulation, followed by bias and clamp."""

    def __init__(self, in_channels, out_channels, w_dim, kernel_size=1, conv_clamp=None, channels_last=False):
        super().__init__()
        self.conv_clamp = conv_clamp
        self.affine = FullyConnectedLayer(w_dim, in_channels, bias_init=1)
        weight_shape = [out_channels, in_channels, kernel_size, kernel_size]
        weight_format = torch.channels_last if channels_last else torch.contiguous_format
        self.weight = torch.nn.Parameter(torch.randn(weight_shape).to(memory_format=weight_format))
        self.bias = torch.nn.Parameter(torch.zeros([out_channels]))
        # Scales the styles by 1 / sqrt(fan-in of the convolution).
        self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size ** 2))

    def forward(self, x, w, fused_modconv=True):
        """Map features `x` to the output channels, modulated by latent `w`."""
        styles = self.affine(w) * self.weight_gain
        out = modulated_conv2d(x=x, weight=self.weight, styles=styles, demodulate=False,
                               fused_modconv=fused_modconv)
        return bias_act(out, self.bias.to(out.dtype), clamp=self.conv_clamp)
456
+
457
+
458
class SynthesisForeword(torch.nn.Module):
    """First synthesis stage: turns the global code into a 4x4 feature map.

    The global code is expanded by a fully connected layer, reshaped to
    `(-1, z_dim // 2, 4, 4)`, summed with the encoder feature `feats[4]` and
    passed through one `SynthesisLayer`. With the 'skip' architecture a
    `ToRGBLayer` also produces an image.
    """

    def __init__(self,
                 z_dim,  # Output Latent (Z) dimensionality.
                 resolution,  # Resolution of this block.
                 in_channels,
                 img_channels,  # Number of input color channels.
                 architecture='skip',  # Architecture: 'orig', 'skip', 'resnet'.
                 activation='lrelu',  # Activation function: 'relu', 'lrelu', etc.

                 ):
        super().__init__()
        self.in_channels = in_channels
        self.z_dim = z_dim
        self.resolution = resolution
        self.img_channels = img_channels
        self.architecture = architecture

        half_z = self.z_dim // 2
        self.fc = FullyConnectedLayer(self.z_dim, half_z * 4 * 4, activation=activation)
        self.conv = SynthesisLayer(self.in_channels, self.in_channels, w_dim=(z_dim // 2) * 3, resolution=4)

        if architecture == 'skip':
            self.torgb = ToRGBLayer(self.in_channels, self.img_channels, kernel_size=1, w_dim=(z_dim // 2) * 3)

    def forward(self, x, ws, feats, img, force_fp32=False):
        """Return `(x, img)`; `img` is only replaced for the 'skip' architecture.

        `force_fp32` is unused: the computation always runs in float32 with a
        contiguous memory layout.
        """
        dtype = torch.float32
        x_global = x.clone()

        # Expand the global code into a 4x4 feature map.
        x = self.fc(x).view(-1, self.z_dim // 2, 4, 4)
        x = x.to(dtype=dtype, memory_format=torch.contiguous_format)

        # Main layers: add the encoder skip feature at resolution 4.
        x = x + feats[4].clone()

        # Each modulation vector is one latent slot concatenated with the global code.
        conv_mod = torch.cat([ws[:, 0], x_global.clone()], dim=1)
        x = self.conv(x, conv_mod)

        rgb_mod = torch.cat([ws[:, 2 * 2 - 3], x_global.clone()], dim=1)

        if self.architecture == 'skip':
            img = self.torgb(x, rgb_mod)
            img = img.to(dtype=torch.float32, memory_format=torch.contiguous_format)

        assert x.dtype == dtype
        return x, img
514
+
515
+
516
class SELayer(nn.Module):
    """Squeeze-and-excitation gate: rescales each channel by a learned factor.

    The factor is computed from the spatially averaged input through a
    bias-free bottleneck MLP (`channel -> channel // reduction -> channel`)
    ending in a sigmoid.
    """

    def __init__(self, channel, reduction=16):
        super(SELayer, self).__init__()
        squeezed = channel // reduction
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, squeezed, bias=False),
            nn.ReLU(inplace=False),
            nn.Linear(squeezed, channel, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return `x` multiplied by its per-sample, per-channel gate."""
        batch, channels, _, _ = x.size()
        pooled = self.avg_pool(x).view(batch, channels)
        gate = self.fc(pooled).view(batch, channels, 1, 1)
        return x * gate.expand_as(x)
533
+
534
+
535
class FourierUnit(nn.Module):
    """1x1 convolution + ReLU applied in the frequency domain.

    The input is transformed with a real FFT, real and imaginary parts are
    stacked along the channel axis, processed by a 1x1 convolution and ReLU,
    and transformed back with the inverse real FFT.
    """

    def __init__(self, in_channels, out_channels, groups=1, spatial_scale_factor=None, spatial_scale_mode='bilinear',
                 spectral_pos_encoding=False, use_se=False, se_kwargs=None, ffc3d=False, fft_norm='ortho'):
        # bn_layer not used
        super(FourierUnit, self).__init__()
        self.groups = groups

        # Real and imaginary parts double the channel count; the optional
        # positional encoding contributes two more input channels.
        conv_in = in_channels * 2 + (2 if spectral_pos_encoding else 0)
        self.conv_layer = torch.nn.Conv2d(in_channels=conv_in,
                                          out_channels=out_channels * 2,
                                          kernel_size=1, stride=1, padding=0, groups=self.groups, bias=False)
        self.relu = torch.nn.ReLU(inplace=False)

        # squeeze and excitation block
        self.use_se = use_se
        if use_se:
            self.se = SELayer(self.conv_layer.in_channels, **({} if se_kwargs is None else se_kwargs))

        self.spatial_scale_factor = spatial_scale_factor
        self.spatial_scale_mode = spatial_scale_mode
        self.spectral_pos_encoding = spectral_pos_encoding
        self.ffc3d = ffc3d
        self.fft_norm = fft_norm

    def forward(self, x):
        """Filter `x` in the Fourier domain and return a real-valued tensor."""
        batch = x.shape[0]
        rescale = self.spatial_scale_factor is not None

        if rescale:
            orig_size = x.shape[-2:]
            x = F.interpolate(x, scale_factor=self.spatial_scale_factor, mode=self.spatial_scale_mode,
                              align_corners=False)

        # (batch, c, h, w/2+1, 2)
        fft_dim = (-3, -2, -1) if self.ffc3d else (-2, -1)
        spec = torch.fft.rfftn(x, dim=fft_dim, norm=self.fft_norm)
        spec = torch.stack((spec.real, spec.imag), dim=-1)
        spec = spec.permute(0, 1, 4, 2, 3).contiguous()  # (batch, c, 2, h, w/2+1)
        spec = spec.view((batch, -1,) + spec.size()[3:])

        if self.spectral_pos_encoding:
            # Prepend normalised vertical/horizontal coordinate channels.
            height, width = spec.shape[-2:]
            rows = torch.linspace(0, 1, height)[None, None, :, None].expand(batch, 1, height, width).to(spec)
            cols = torch.linspace(0, 1, width)[None, None, None, :].expand(batch, 1, height, width).to(spec)
            spec = torch.cat((rows, cols, spec), dim=1)

        if self.use_se:
            spec = self.se(spec)

        spec = self.relu(self.conv_layer(spec))  # (batch, c*2, h, w/2+1)

        # Split the channel axis back into (real, imag) pairs.
        spec = spec.view((batch, -1, 2,) + spec.size()[2:])
        spec = spec.permute(0, 1, 3, 4, 2).contiguous()  # (batch,c, t, h, w/2+1, 2)
        spec = torch.complex(spec[..., 0], spec[..., 1])

        out_shape = x.shape[-3:] if self.ffc3d else x.shape[-2:]
        output = torch.fft.irfftn(spec, s=out_shape, dim=fft_dim, norm=self.fft_norm)

        if rescale:
            output = F.interpolate(output, size=orig_size, mode=self.spatial_scale_mode, align_corners=False)

        return output
600
+
601
+
602
class SpectralTransform(nn.Module):
    """Global branch built from a `FourierUnit` plus an optional local one.

    The input is optionally average-pooled (stride 2), reduced to
    `out_channels // 2` channels, processed by the Fourier unit(s), and
    expanded to `out_channels` by a final 1x1 convolution.
    """

    def __init__(self, in_channels, out_channels, stride=1, groups=1, enable_lfu=True, **fu_kwargs):
        # bn_layer not used
        super(SpectralTransform, self).__init__()
        self.enable_lfu = enable_lfu
        self.downsample = nn.AvgPool2d(kernel_size=(2, 2), stride=2) if stride == 2 else nn.Identity()
        self.stride = stride

        half = out_channels // 2
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, half, kernel_size=1, groups=groups, bias=False),
            # nn.BatchNorm2d(out_channels // 2),
            nn.ReLU(inplace=True),
        )
        self.fu = FourierUnit(half, half, groups, **fu_kwargs)
        if self.enable_lfu:
            self.lfu = FourierUnit(half, half, groups)
        self.conv2 = torch.nn.Conv2d(half, out_channels, kernel_size=1, groups=groups, bias=False)

    def forward(self, x):
        """Return `conv2(x' + fu(x') + lfu_term)` where `x'` is the reduced input."""
        x = self.conv1(self.downsample(x))
        output = self.fu(x)

        xs = 0
        if self.enable_lfu:
            n, c, h, w = x.shape
            split_no = 2
            split_s = h // split_no
            # Take the first quarter of the channels and move spatial chunks
            # (first along height, then along width) onto the channel axis.
            quarter = x[:, :c // 4]
            xs = torch.cat(torch.split(quarter, split_s, dim=-2), dim=1).contiguous()
            xs = torch.cat(torch.split(xs, split_s, dim=-1), dim=1).contiguous()
            xs = self.lfu(xs)
            # Tile the result back over both spatial axes.
            xs = xs.repeat(1, 1, split_no, split_no).contiguous()

        return self.conv2(x + output + xs)
650
+
651
+
652
class FFC(nn.Module):
    """Convolution over a (local, global) channel split with four cross paths.

    Input and output channels are divided into a local part and a global part
    according to `ratio_gin` / `ratio_gout`. Local->local, local->global and
    global->local use ordinary convolutions; global->global uses
    `SpectralTransform`. Any path with zero channels on either side becomes
    `nn.Identity`.
    """

    def __init__(self, in_channels, out_channels, kernel_size,
                 ratio_gin, ratio_gout, stride=1, padding=0,
                 dilation=1, groups=1, bias=False, enable_lfu=True,
                 padding_type='reflect', gated=False, **spectral_kwargs):
        super(FFC, self).__init__()

        assert stride == 1 or stride == 2, "Stride should be 1 or 2."
        self.stride = stride

        in_cg = int(in_channels * ratio_gin)
        in_cl = in_channels - in_cg
        out_cg = int(out_channels * ratio_gout)
        out_cl = out_channels - out_cg

        self.ratio_gin = ratio_gin
        self.ratio_gout = ratio_gout
        self.global_in_num = in_cg

        def spatial_path(cin, cout):
            # A real convolution only when both ends carry channels.
            layer_cls = nn.Conv2d if (cin != 0 and cout != 0) else nn.Identity
            return layer_cls(cin, cout, kernel_size,
                             stride, padding, dilation, groups, bias, padding_mode=padding_type)

        self.convl2l = spatial_path(in_cl, out_cl)
        self.convl2g = spatial_path(in_cl, out_cg)
        self.convg2l = spatial_path(in_cg, out_cl)

        if in_cg != 0 and out_cg != 0:
            spectral_cls = SpectralTransform
        else:
            spectral_cls = nn.Identity
        self.convg2g = spectral_cls(
            in_cg, out_cg, stride, 1 if groups == 1 else groups // 2, enable_lfu, **spectral_kwargs)

        self.gated = gated
        if in_cg != 0 and out_cl != 0 and self.gated:
            gate_cls = nn.Conv2d
        else:
            gate_cls = nn.Identity
        self.gate = gate_cls(in_channels, 2, 1)

    def forward(self, x, fname=None):
        """Return `(out_local, out_global)`; a missing part is the integer 0.

        `x` is either a `(local, global)` tuple or a single tensor that is
        treated as the local part. `fname` is not used.
        """
        if type(x) is tuple:
            x_l, x_g = x
        else:
            x_l, x_g = x, 0

        if self.gated:
            gate_inputs = [x_l]
            if torch.is_tensor(x_g):
                gate_inputs.append(x_g)
            gates = torch.sigmoid(self.gate(torch.cat(gate_inputs, dim=1)))
            g2l_gate, l2g_gate = gates.chunk(2, dim=1)
        else:
            g2l_gate = l2g_gate = 1

        spec_x = self.convg2g(x_g)

        out_xl = out_xg = 0
        if self.ratio_gout != 1:
            out_xl = self.convl2l(x_l) + self.convg2l(x_g) * g2l_gate
        if self.ratio_gout != 0:
            out_xg = self.convl2g(x_l) * l2g_gate + spec_x

        return out_xl, out_xg
714
+
715
+
716
class FFC_BN_ACT(nn.Module):
    """`FFC` followed by an activation on each of the two output parts.

    No normalisation layer is instantiated, so `norm_layer` is accepted but
    has no effect. A part with no output channels (`ratio_gout` of 1 for the
    local part, 0 for the global part) gets `nn.Identity` as activation.
    """

    def __init__(self, in_channels, out_channels,
                 kernel_size, ratio_gin, ratio_gout,
                 stride=1, padding=0, dilation=1, groups=1, bias=False,
                 norm_layer=nn.SyncBatchNorm, activation_layer=nn.Identity,
                 padding_type='reflect',
                 enable_lfu=True, **kwargs):
        super(FFC_BN_ACT, self).__init__()
        self.ffc = FFC(in_channels, out_channels, kernel_size,
                       ratio_gin, ratio_gout, stride, padding, dilation,
                       groups, bias, enable_lfu, padding_type=padding_type, **kwargs)

        local_act = activation_layer if ratio_gout != 1 else nn.Identity
        global_act = activation_layer if ratio_gout != 0 else nn.Identity
        self.act_l = local_act(inplace=True)
        self.act_g = global_act(inplace=True)

    def forward(self, x, fname=None):
        """Return the activated `(local, global)` outputs of the wrapped `FFC`."""
        x_l, x_g = self.ffc(x, fname=fname)
        return self.act_l(x_l), self.act_g(x_g)
744
+
745
+
746
class FFCResnetBlock(nn.Module):
    """Residual block made of two identically configured `FFC_BN_ACT` layers.

    With `inline=True` the block takes and returns a single tensor whose last
    `global_in_num` channels form the global part; otherwise it works on a
    `(local, global)` tuple. `spatial_transform_kwargs` is accepted but unused.
    """

    def __init__(self, dim, padding_type, norm_layer, activation_layer=nn.ReLU, dilation=1,
                 spatial_transform_kwargs=None, inline=False, ratio_gin=0.75, ratio_gout=0.75):
        super().__init__()
        layer_cfg = dict(kernel_size=3, padding=dilation, dilation=dilation,
                         norm_layer=norm_layer,
                         activation_layer=activation_layer,
                         padding_type=padding_type,
                         ratio_gin=ratio_gin, ratio_gout=ratio_gout)
        self.conv1 = FFC_BN_ACT(dim, dim, **layer_cfg)
        self.conv2 = FFC_BN_ACT(dim, dim, **layer_cfg)
        self.inline = inline

    def forward(self, x, fname=None):
        """Apply both layers and add the input back to each part."""
        if self.inline:
            n_global = self.conv1.ffc.global_in_num
            x_l, x_g = x[:, :-n_global], x[:, -n_global:]
        elif type(x) is tuple:
            x_l, x_g = x
        else:
            x_l, x_g = x, 0

        id_l, id_g = x_l, x_g

        for layer in (self.conv1, self.conv2):
            x_l, x_g = layer((x_l, x_g), fname=fname)

        out = (id_l + x_l, id_g + x_g)
        if self.inline:
            out = torch.cat(out, dim=1)
        return out
778
+
779
+
780
class ConcatTupleLayer(nn.Module):
    """Merge a `(local, global)` tuple into one tensor along the channel axis."""

    def forward(self, x):
        """Return the concatenation, or the local part alone when the global
        part is not a tensor. At least one part must be a tensor."""
        assert isinstance(x, tuple)
        local_part, global_part = x
        global_is_tensor = torch.is_tensor(global_part)
        assert torch.is_tensor(local_part) or global_is_tensor
        if global_is_tensor:
            return torch.cat(x, dim=1)
        return local_part
788
+
789
+
790
class FFCBlock(torch.nn.Module):
    """Wraps an `FFCResnetBlock` so it can be applied to a single tensor.

    The activation is `nn.Identity` when `activation == 'linear'` and
    `nn.ReLU` for any other value.
    """

    def __init__(self,
                 dim,  # Number of output/input channels.
                 kernel_size,  # Width and height of the convolution kernel.
                 padding,
                 ratio_gin=0.75,
                 ratio_gout=0.75,
                 activation='linear',  # Activation function: 'relu', 'lrelu', etc.
                 ):
        super().__init__()
        self.activation = nn.Identity if activation == 'linear' else nn.ReLU
        self.padding = padding
        self.kernel_size = kernel_size
        self.ffc_block = FFCResnetBlock(dim=dim,
                                        padding_type='reflect',
                                        norm_layer=nn.SyncBatchNorm,
                                        activation_layer=self.activation,
                                        dilation=1,
                                        ratio_gin=ratio_gin,
                                        ratio_gout=ratio_gout)

        self.concat_layer = ConcatTupleLayer()

    def forward(self, gen_ft, mask, fname=None):
        """Return `gen_ft` (as float) plus the residual FFC output.

        The last `global_in_num` channels of the input are treated as the
        global part. `mask` is not used.
        """
        features = gen_ft.float()
        n_global = self.ffc_block.conv1.ffc.global_in_num

        local_in, global_in = features[:, :-n_global], features[:, -n_global:]
        local_out, global_out = self.ffc_block((local_in, global_in), fname=fname)

        # Residual connection around the block, applied per part.
        merged = self.concat_layer((local_in + local_out, global_in + global_out))

        # Second residual connection around the whole module.
        return merged + gen_ft.float()
827
+
828
+
829
class FFCSkipLayer(torch.nn.Module):
    """Skip-connection layer that applies one `FFCBlock` to its input.

    The block is built with `activation=nn.ReLU` and padding
    `kernel_size // 2`.
    """

    def __init__(self,
                 dim,  # Number of input/output channels.
                 kernel_size=3,  # Convolution kernel size.
                 ratio_gin=0.75,
                 ratio_gout=0.75,
                 ):
        super().__init__()
        self.padding = kernel_size // 2
        self.ffc_act = FFCBlock(dim=dim, kernel_size=kernel_size, activation=nn.ReLU,
                                padding=self.padding, ratio_gin=ratio_gin, ratio_gout=ratio_gout)

    def forward(self, gen_ft, mask, fname=None):
        """Forward all arguments to the wrapped `FFCBlock` and return its output."""
        return self.ffc_act(gen_ft, mask, fname=fname)
845
+
846
+
847
class SynthesisBlock(torch.nn.Module):
    """One resolution level of the synthesis network.

    The block optionally upsamples its input (``conv0``), merges it with the
    feature map stored under ``feats[self.resolution]`` (passing the sum
    through the FFC skip layers when any are configured), applies ``conv1``
    and, for the 'skip' architecture or the last block, accumulates an RGB
    image through ``torgb``.
    """

    def __init__(self,
                 in_channels,  # Number of input channels, 0 = first block.
                 out_channels,  # Number of output channels.
                 w_dim,  # Intermediate latent (W) dimensionality.
                 resolution,  # Resolution of this block.
                 img_channels,  # Number of output color channels.
                 is_last,  # Is this the last block?
                 architecture='skip',  # Architecture: 'orig', 'skip', 'resnet'.
                 resample_filter=[1, 3, 3, 1],  # Low-pass filter to apply when resampling activations.
                 conv_clamp=None,  # Clamp the output of convolution layers to +-X, None = disable clamping.
                 use_fp16=False,  # Use FP16 for this block?
                 fp16_channels_last=False,  # Use channels-last memory format with FP16?
                 **layer_kwargs,  # Arguments for SynthesisLayer.
                 ):
        assert architecture in ['orig', 'skip', 'resnet']
        super().__init__()
        self.in_channels = in_channels
        self.w_dim = w_dim
        self.resolution = resolution
        self.img_channels = img_channels
        self.is_last = is_last
        self.architecture = architecture
        self.use_fp16 = use_fp16
        self.channels_last = (use_fp16 and fp16_channels_last)
        self.register_buffer('resample_filter', setup_filter(resample_filter))
        self.num_conv = 0
        self.num_torgb = 0
        # Number of FFC skip layers to build per resolution: none below 32, one from 32 upwards.
        self.res_ffc = {4: 0, 8: 0, 16: 0, 32: 1, 64: 1, 128: 1, 256: 1, 512: 1}

        # `ffc_skip` only exists for non-first blocks at resolution >= 8;
        # forward() calls len(self.ffc_skip) whenever in_channels != 0.
        if in_channels != 0 and resolution >= 8:
            self.ffc_skip = nn.ModuleList()
            for _ in range(self.res_ffc[resolution]):
                self.ffc_skip.append(FFCSkipLayer(dim=out_channels))

        if in_channels == 0:
            self.const = torch.nn.Parameter(torch.randn([out_channels, resolution, resolution]))

        # All modulated layers take style vectors of width w_dim * 3.
        # NOTE(review): presumably w concatenated with a wider global feature vector - confirm against the caller.
        if in_channels != 0:
            self.conv0 = SynthesisLayer(in_channels, out_channels, w_dim=w_dim * 3, resolution=resolution, up=2,
                                        resample_filter=resample_filter, conv_clamp=conv_clamp,
                                        channels_last=self.channels_last, **layer_kwargs)
            self.num_conv += 1

        self.conv1 = SynthesisLayer(out_channels, out_channels, w_dim=w_dim * 3, resolution=resolution,
                                    conv_clamp=conv_clamp, channels_last=self.channels_last, **layer_kwargs)
        self.num_conv += 1

        if is_last or architecture == 'skip':
            self.torgb = ToRGBLayer(out_channels, img_channels, w_dim=w_dim * 3,
                                    conv_clamp=conv_clamp, channels_last=self.channels_last)
            self.num_torgb += 1

        if in_channels != 0 and architecture == 'resnet':
            self.skip = Conv2dLayer(in_channels, out_channels, kernel_size=1, bias=False, up=2,
                                    resample_filter=resample_filter, channels_last=self.channels_last)

    def forward(self, x, mask, feats, img, ws, fname=None, force_fp32=False, fused_modconv=None, **layer_kwargs):
        """Run the block and return ``(x, img)``.

        ``feats`` is indexed by ``self.resolution`` to obtain the skip
        features, and ``ws`` is indexed 0/1/2 for conv0, conv1 and torgb.
        ``fname`` is accepted but not used in this method.
        """
        dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32
        # The line below unconditionally overrides the choice above: the block always computes in float32.
        dtype = torch.float32
        memory_format = torch.channels_last if self.channels_last and not force_fp32 else torch.contiguous_format
        if fused_modconv is None:
            fused_modconv = (not self.training) and (dtype == torch.float32 or int(x.shape[0]) == 1)

        x = x.to(dtype=dtype, memory_format=memory_format)
        # Cloned so that later arithmetic cannot modify the tensor stored in `feats`.
        x_skip = feats[self.resolution].clone().to(dtype=dtype, memory_format=memory_format)

        # Main layers.
        if self.in_channels == 0:
            # First block: no upsampling conv and no skip fusion (x_skip is unused here).
            x = self.conv1(x, ws[1], fused_modconv=fused_modconv, **layer_kwargs)
        elif self.architecture == 'resnet':
            # Residual shortcut is taken from the input before conv0 upsamples it.
            y = self.skip(x, gain=np.sqrt(0.5))
            x = self.conv0(x, ws[0].clone(), fused_modconv=fused_modconv, **layer_kwargs)
            if len(self.ffc_skip) > 0:
                # Resize the mask to the spatial size of the skip features before the FFC layers.
                mask = F.interpolate(mask, size=x_skip.shape[2:], )
                z = x + x_skip
                for fres in self.ffc_skip:
                    z = fres(z, mask)
                x = x + z
            else:
                x = x + x_skip
            x = self.conv1(x, ws[1].clone(), fused_modconv=fused_modconv, gain=np.sqrt(0.5), **layer_kwargs)
            x = y.add_(x)
        else:
            x = self.conv0(x, ws[0].clone(), fused_modconv=fused_modconv, **layer_kwargs)
            if len(self.ffc_skip) > 0:
                # Resize the mask to the spatial size of the skip features before the FFC layers.
                mask = F.interpolate(mask, size=x_skip.shape[2:], )
                z = x + x_skip
                for fres in self.ffc_skip:
                    z = fres(z, mask)
                x = x + z
            else:
                x = x + x_skip
            x = self.conv1(x, ws[1].clone(), fused_modconv=fused_modconv, **layer_kwargs)
        # ToRGB.
        if img is not None:
            # Bring the running image up to this block's resolution before adding to it.
            img = upsample2d(img, self.resample_filter)
        if self.is_last or self.architecture == 'skip':
            y = self.torgb(x, ws[2].clone(), fused_modconv=fused_modconv)
            y = y.to(dtype=torch.float32, memory_format=torch.contiguous_format)
            # In-place accumulation into the (already upsampled) running image.
            img = img.add_(y) if img is not None else y

        x = x.to(dtype=dtype)
        assert x.dtype == dtype
        assert img is None or img.dtype == torch.float32
        return x, img
953
+
954
+
955
class SynthesisNetwork(torch.nn.Module):
    """Stack of SynthesisBlocks (resolution 8 up to ``img_resolution``) behind a SynthesisForeword."""

    def __init__(self,
                 w_dim,  # Intermediate latent (W) dimensionality.
                 z_dim,  # Output Latent (Z) dimensionality.
                 img_resolution,  # Output image resolution.
                 img_channels,  # Number of color channels.
                 channel_base=16384,  # Overall multiplier for the number of channels.
                 channel_max=512,  # Maximum number of channels in any layer.
                 num_fp16_res=0,  # Use FP16 for the N highest resolutions.
                 **block_kwargs,  # Arguments for SynthesisBlock.
                 ):
        # Resolution must be a power of two, at least 4.
        assert img_resolution >= 4 and img_resolution & (img_resolution - 1) == 0
        super().__init__()
        self.w_dim = w_dim
        self.img_resolution = img_resolution
        self.img_resolution_log2 = int(np.log2(img_resolution))
        self.img_channels = img_channels
        self.block_resolutions = [2 ** exp for exp in range(3, self.img_resolution_log2 + 1)]

        channels_dict = {}
        for res in self.block_resolutions:
            channels_dict[res] = min(channel_base // res, channel_max)
        # Still evaluated as in the original code, but fp16 is force-disabled below.
        fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8)

        self.foreword = SynthesisForeword(img_channels=img_channels,
                                          in_channels=min(channel_base // 4, channel_max),
                                          z_dim=z_dim * 2,
                                          resolution=4)

        self.num_ws = self.img_resolution_log2 * 2 - 2
        for res in self.block_resolutions:
            half_res = res // 2
            if half_res in channels_dict:
                in_channels = channels_dict[half_res] if res > 4 else 0
            else:
                # The previous resolution has no block of its own; derive its width directly.
                in_channels = min(channel_base // half_res, channel_max)
            # Every block runs in full precision regardless of `num_fp16_res`.
            use_fp16 = False
            block = SynthesisBlock(in_channels, channels_dict[res], w_dim=w_dim, resolution=res,
                                   img_channels=img_channels, is_last=(res == self.img_resolution),
                                   use_fp16=use_fp16, **block_kwargs)
            setattr(self, f'b{res}', block)

    def forward(self, x_global, mask, feats, ws, fname=None, **block_kwargs):
        """Run the foreword and every block in order; return the final image."""
        img = None

        x, img = self.foreword(x_global, ws, feats, img)

        for res in self.block_resolutions:
            block = getattr(self, f'b{res}')
            base = int(np.log2(res)) * 2
            # One modulation vector each for conv0, conv1 and torgb: the matching
            # slice of `ws` concatenated with a fresh copy of the global code.
            mod_vectors = tuple(
                torch.cat([ws[:, base - offset], x_global.clone()], dim=1)
                for offset in (5, 4, 3)
            )
            x, img = block(x, mask, feats, img, mod_vectors, fname=fname, **block_kwargs)
        return img
1017
+
1018
+
1019
class MappingNetwork(torch.nn.Module):
    """Fully connected stack mapping a latent ``z`` (and optional label ``c``) to ``w``."""

    def __init__(self,
                 z_dim,  # Input latent (Z) dimensionality, 0 = no latent.
                 c_dim,  # Conditioning label (C) dimensionality, 0 = no label.
                 w_dim,  # Intermediate latent (W) dimensionality.
                 num_ws,  # Number of intermediate latents to output, None = do not broadcast.
                 num_layers=8,  # Number of mapping layers.
                 embed_features=None,  # Label embedding dimensionality, None = same as w_dim.
                 layer_features=None,  # Number of intermediate features in the mapping layers, None = same as w_dim.
                 activation='lrelu',  # Activation function: 'relu', 'lrelu', etc.
                 lr_multiplier=0.01,  # Learning rate multiplier for the mapping layers.
                 w_avg_beta=0.995,  # Decay for tracking the moving average of W during training, None = do not track.
                 ):
        super().__init__()
        self.z_dim = z_dim
        self.c_dim = c_dim
        self.w_dim = w_dim
        self.num_ws = num_ws
        self.num_layers = num_layers
        self.w_avg_beta = w_avg_beta

        # Without labels there is no embedding at all; otherwise default to w_dim.
        if c_dim == 0:
            embed_features = 0
        elif embed_features is None:
            embed_features = w_dim
        hidden = w_dim if layer_features is None else layer_features
        widths = [z_dim + embed_features] + [hidden] * (num_layers - 1) + [w_dim]

        if c_dim > 0:
            self.embed = FullyConnectedLayer(c_dim, embed_features)
        for idx in range(num_layers):
            fc = FullyConnectedLayer(widths[idx], widths[idx + 1],
                                     activation=activation, lr_multiplier=lr_multiplier)
            setattr(self, f'fc{idx}', fc)

        if num_ws is not None and w_avg_beta is not None:
            self.register_buffer('w_avg', torch.zeros([w_dim]))

    def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, skip_w_avg_update=False):
        """Map ``z``/``c`` to ``w``, optionally broadcasting to ``num_ws`` copies and truncating."""
        record = torch.autograd.profiler.record_function

        # Embed, normalize, and concat inputs.
        x = None
        with record('input'):
            if self.z_dim > 0:
                x = normalize_2nd_moment(z.to(torch.float32))
            if self.c_dim > 0:
                y = normalize_2nd_moment(self.embed(c.to(torch.float32)))
                x = y if x is None else torch.cat([x, y], dim=1)

        # Main layers.
        for idx in range(self.num_layers):
            x = getattr(self, f'fc{idx}')(x)

        # Update moving average of W.
        track_average = self.w_avg_beta is not None and self.training and not skip_w_avg_update
        if track_average:
            with record('update_w_avg'):
                batch_mean = x.detach().mean(dim=0)
                self.w_avg.copy_(batch_mean.lerp(self.w_avg, self.w_avg_beta))

        # Broadcast.
        if self.num_ws is not None:
            with record('broadcast'):
                x = x.unsqueeze(1).repeat([1, self.num_ws, 1])

        # Apply truncation.
        if truncation_psi != 1:
            with record('truncate'):
                assert self.w_avg_beta is not None
                if self.num_ws is not None and truncation_cutoff is not None:
                    # Only the first `truncation_cutoff` broadcast copies are pulled towards the average.
                    head = x[:, :truncation_cutoff]
                    x[:, :truncation_cutoff] = self.w_avg.lerp(head, truncation_psi)
                else:
                    x = self.w_avg.lerp(x, truncation_psi)
        return x
1093
+
1094
+
1095
class Generator(torch.nn.Module):
    """Full inpainting generator: encoder, mapping network and synthesis network.

    The ``*_kwargs`` arguments default to ``None`` (treated as an empty dict)
    rather than a shared mutable ``{}`` default.
    """

    def __init__(self,
                 z_dim,  # Input latent (Z) dimensionality.
                 c_dim,  # Conditioning label (C) dimensionality.
                 w_dim,  # Intermediate latent (W) dimensionality.
                 img_resolution,  # Output resolution.
                 img_channels,  # Number of output color channels.
                 encoder_kwargs=None,  # Arguments for EncoderNetwork, None = no extra arguments.
                 mapping_kwargs=None,  # Arguments for MappingNetwork, None = no extra arguments.
                 synthesis_kwargs=None,  # Arguments for SynthesisNetwork, None = no extra arguments.
                 ):
        super().__init__()
        # Resolve the None sentinels; the caller's dicts are only unpacked, never modified.
        encoder_kwargs = {} if encoder_kwargs is None else encoder_kwargs
        mapping_kwargs = {} if mapping_kwargs is None else mapping_kwargs
        synthesis_kwargs = {} if synthesis_kwargs is None else synthesis_kwargs

        self.z_dim = z_dim
        self.c_dim = c_dim
        self.w_dim = w_dim
        self.img_resolution = img_resolution
        self.img_channels = img_channels
        self.encoder = EncoderNetwork(c_dim=c_dim, z_dim=z_dim, img_resolution=img_resolution,
                                      img_channels=img_channels, **encoder_kwargs)
        self.synthesis = SynthesisNetwork(z_dim=z_dim, w_dim=w_dim, img_resolution=img_resolution,
                                          img_channels=img_channels, **synthesis_kwargs)
        # The mapping network must emit as many latents as the synthesis network consumes.
        self.num_ws = self.synthesis.num_ws
        self.mapping = MappingNetwork(z_dim=z_dim, c_dim=c_dim, w_dim=w_dim, num_ws=self.num_ws, **mapping_kwargs)

    def forward(self, img, c, fname=None, truncation_psi=1, truncation_cutoff=None, **synthesis_kwargs):
        """Encode ``img``, map the latent to ``ws`` and synthesize the output image.

        The last channel of ``img`` is sliced out and passed to the synthesis
        network as the mask; the full ``img`` tensor goes to the encoder.
        """
        mask = img[:, -1].unsqueeze(1)
        x_global, z, feats = self.encoder(img, c)
        ws = self.mapping(z, c, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff)
        img = self.synthesis(x_global, mask, feats, ws, fname=fname, **synthesis_kwargs)
        return img
1125
+
1126
+
1127
# Location of the generator weights loaded by FcF.init_model. Set the
# FCF_MODEL_URL environment variable to override the default release URL.
FCF_MODEL_URL = os.environ.get(
    "FCF_MODEL_URL",
    "https://github.com/Sanster/models/releases/download/add_fcf/places_512_G.pth",
)
1131
+
1132
+
1133
class FcF(InpaintModel):
    """Inpainting model backed by the FcF ``Generator`` at 512x512 resolution."""

    min_size = 512
    pad_mod = 512
    pad_to_square = True

    def init_model(self, device, **kwargs):
        """Seed the RNGs, build the generator and load its weights onto ``device``.

        ``kwargs`` is accepted for interface compatibility and is not used.
        """
        seed = 0
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

        # Kept in a separate name so the method's own **kwargs is not shadowed.
        net_kwargs = {'channel_base': 1 * 32768, 'channel_max': 512, 'num_fp16_res': 4, 'conv_clamp': 256}
        G = Generator(z_dim=512, c_dim=0, w_dim=512, img_resolution=512, img_channels=3,
                      synthesis_kwargs=net_kwargs, encoder_kwargs=net_kwargs, mapping_kwargs={'num_layers': 2})
        self.model = load_model(G, FCF_MODEL_URL, device)
        self.label = torch.zeros([1, self.model.c_dim], device=device)

    @staticmethod
    def is_downloaded() -> bool:
        """Return True when the cache path for FCF_MODEL_URL exists."""
        return os.path.exists(get_cache_path_by_url(FCF_MODEL_URL))

    @torch.no_grad()
    def __call__(self, image, mask, config: Config):
        """
        images: [H, W, C] RGB, not normalized
        masks: [H, W]
        return: BGR IMAGE

        A 512x512 input is processed in one pass. Any other size is handled
        per mask box: each box is cropped, resized to at most 512, inpainted,
        resized back and pasted into a BGR copy of the input. The caller's
        ``image`` array is left untouched.
        """
        if image.shape[0] == 512 and image.shape[1] == 512:
            return self._pad_forward(image, mask, config)

        boxes = boxes_from_mask(mask)
        crop_result = []
        # NOTE(review): this permanently overrides the caller's crop margin on the shared config object.
        config.hd_strategy_crop_margin = 128
        for box in boxes:
            crop_image, crop_mask, crop_box = self._crop_box(image, mask, box, config)
            origin_size = crop_image.shape[:2]
            resize_image = resize_max_size(crop_image, size_limit=512)
            resize_mask = resize_max_size(crop_mask, size_limit=512)
            inpaint_result = self._pad_forward(resize_image, resize_mask, config)

            # only paste masked area result
            inpaint_result = cv2.resize(inpaint_result, (origin_size[1], origin_size[0]), interpolation=cv2.INTER_CUBIC)

            # Pixels outside the mask keep their original (channel-flipped) values.
            original_pixel_indices = crop_mask < 127
            inpaint_result[original_pixel_indices] = crop_image[:, :, ::-1][original_pixel_indices]

            crop_result.append((inpaint_result, crop_box))

        # image[:, :, ::-1] is a view that shares memory with `image`; copy it so
        # that pasting the crops does not overwrite the caller's input array.
        inpaint_result = image[:, :, ::-1].copy()
        for crop_image, crop_box in crop_result:
            x1, y1, x2, y2 = crop_box
            inpaint_result[y1:y2, x1:x2, :] = crop_image

        return inpaint_result

    def forward(self, image, mask, config: Config):
        """Input images and output images have same size
        images: [H, W, C] RGB
        masks: [H, W] mask area == 255
        return: BGR IMAGE
        """

        image = norm_img(image)  # [0, 1]
        image = image * 2 - 1  # [0, 1] -> [-1, 1]
        # Binarize the mask at 120 before normalizing it.
        mask = (mask > 120) * 255
        mask = norm_img(mask)

        image = torch.from_numpy(image).unsqueeze(0).to(self.device)
        mask = torch.from_numpy(mask).unsqueeze(0).to(self.device)

        # The generator input is the shifted mask followed by the image with the hole zeroed out.
        erased_img = image * (1 - mask)
        input_image = torch.cat([0.5 - mask, erased_img], dim=1)

        output = self.model(input_image, self.label, truncation_psi=0.1, noise_mode='none')
        # Map the generator output from [-1, 1] to uint8 [0, 255].
        output = (output.permute(0, 2, 3, 1) * 127.5 + 127.5).round().clamp(0, 255).to(torch.uint8)
        output = output[0].cpu().numpy()
        cur_res = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
        return cur_res
lama_cleaner/model/lama.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import cv2
4
+ import numpy as np
5
+ import torch
6
+ from loguru import logger
7
+
8
+ from lama_cleaner.helper import pad_img_to_modulo, download_model, norm_img, get_cache_path_by_url
9
+ from lama_cleaner.model.base import InpaintModel
10
+ from lama_cleaner.schema import Config
11
+
12
# Download location of the LaMa model used when the LAMA_MODEL environment
# variable does not point at a local file. Set LAMA_MODEL_URL to override it.
LAMA_MODEL_URL = os.environ.get(
    "LAMA_MODEL_URL",
    "https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
)
16
+
17
+ #"https://drive.google.com/file/d/1bMD06F9hkkS1oi8cEmb4cSjXz54Pxs6A/view?usp=sharing" #big-lama.pt file
18
+
19
+
20
class LaMa(InpaintModel):
    """Inpainting model that runs a TorchScript LaMa network."""

    pad_mod = 8

    def init_model(self, device, **kwargs):
        """Load the TorchScript model from $LAMA_MODEL if set, else from LAMA_MODEL_URL."""
        model_path = os.environ.get("LAMA_MODEL")
        if model_path:
            # An explicitly configured path must exist; there is no fallback to the download.
            if not os.path.exists(model_path):
                raise FileNotFoundError(
                    f"lama torchscript model not found: {model_path}"
                )
        else:
            model_path = download_model(LAMA_MODEL_URL)

        logger.info(f"Load LaMa model from: {model_path}")
        scripted = torch.jit.load(model_path, map_location="cpu")
        scripted = scripted.to(device)
        scripted.eval()
        self.model = scripted
        self.model_path = model_path

    @staticmethod
    def is_downloaded() -> bool:
        """Return True when the cache path for LAMA_MODEL_URL exists."""
        cached_path = get_cache_path_by_url(LAMA_MODEL_URL)
        return os.path.exists(cached_path)

    def forward(self, image, mask, config: Config):
        """Input image and output image have same size
        image: [H, W, C] RGB
        mask: [H, W]
        return: BGR IMAGE
        """
        image = norm_img(image)
        mask = norm_img(mask)

        # Any non-zero mask value marks a pixel to inpaint.
        mask = (mask > 0) * 1
        image_batch = torch.from_numpy(image).unsqueeze(0).to(self.device)
        mask_batch = torch.from_numpy(mask).unsqueeze(0).to(self.device)

        prediction = self.model(image_batch, mask_batch)

        # First batch element, channels moved last, scaled and clipped to uint8.
        rgb = prediction[0].permute(1, 2, 0).detach().cpu().numpy()
        rgb = np.clip(rgb * 255, 0, 255).astype("uint8")
        return cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
lama_cleaner/model/ldm.py ADDED
@@ -0,0 +1,312 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import numpy as np
4
+ import torch
5
+ from loguru import logger
6
+
7
+ from lama_cleaner.model.base import InpaintModel
8
+ from lama_cleaner.model.ddim_sampler import DDIMSampler
9
+ from lama_cleaner.model.plms_sampler import PLMSSampler
10
+ from lama_cleaner.schema import Config, LDMSampler
11
+
12
+ torch.manual_seed(42)
13
+ import torch.nn as nn
14
+ from lama_cleaner.helper import (
15
+ download_model,
16
+ norm_img,
17
+ get_cache_path_by_url,
18
+ load_jit_model,
19
+ )
20
+ from lama_cleaner.model.utils import (
21
+ make_beta_schedule,
22
+ timestep_embedding,
23
+ )
24
+
25
# Download locations of the three TorchScript parts loaded by LDM.init_model.
# Each one can be overridden through the environment variable of the same name.

# Conditioning-stage encoder.
LDM_ENCODE_MODEL_URL = os.environ.get(
    "LDM_ENCODE_MODEL_URL",
    "https://github.com/Sanster/models/releases/download/add_ldm/cond_stage_model_encode.pt",
)

# Conditioning-stage decoder.
LDM_DECODE_MODEL_URL = os.environ.get(
    "LDM_DECODE_MODEL_URL",
    "https://github.com/Sanster/models/releases/download/add_ldm/cond_stage_model_decode.pt",
)

# Diffusion network.
LDM_DIFFUSION_MODEL_URL = os.environ.get(
    "LDM_DIFFUSION_MODEL_URL",
    "https://github.com/Sanster/models/releases/download/add_ldm/diffusion.pt",
)
39
+
40
+
41
class DDPM(nn.Module):
    """Holder for the Gaussian diffusion schedule buffers.

    The constructor stores the configuration and calls ``register_schedule``,
    which precomputes the beta/alpha derived quantities and registers them as
    buffers on the module.
    """

    # classic DDPM with Gaussian diffusion, in image space
    def __init__(
        self,
        device,
        timesteps=1000,
        beta_schedule="linear",
        linear_start=0.0015,
        linear_end=0.0205,
        cosine_s=0.008,
        original_elbo_weight=0.0,
        v_posterior=0.0,  # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
        l_simple_weight=1.0,
        parameterization="eps",  # all assuming fixed variance schedules
        use_positional_encodings=False,
    ):
        super().__init__()
        self.device = device
        self.parameterization = parameterization
        self.use_positional_encodings = use_positional_encodings

        self.v_posterior = v_posterior
        self.original_elbo_weight = original_elbo_weight
        self.l_simple_weight = l_simple_weight

        # Dispatches to the subclass override when one exists.
        self.register_schedule(
            beta_schedule=beta_schedule,
            timesteps=timesteps,
            linear_start=linear_start,
            linear_end=linear_end,
            cosine_s=cosine_s,
        )

    def register_schedule(
        self,
        given_betas=None,
        beta_schedule="linear",
        timesteps=1000,
        linear_start=1e-4,
        linear_end=2e-2,
        cosine_s=8e-3,
    ):
        """Compute the diffusion schedule and register every derived buffer.

        NOTE(review): ``given_betas`` is accepted but never read; the betas
        always come from ``make_beta_schedule``.
        """
        # NOTE(review): betas is presumably a 1-D numpy array (np.cumprod and
        # the 1-tuple shape unpacking below rely on it) - confirm in make_beta_schedule.
        betas = make_beta_schedule(
            self.device,
            beta_schedule,
            timesteps,
            linear_start=linear_start,
            linear_end=linear_end,
            cosine_s=cosine_s,
        )
        alphas = 1.0 - betas
        alphas_cumprod = np.cumprod(alphas, axis=0)
        # Same sequence shifted right by one step, with 1.0 as the value before the first step.
        alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])

        # The real number of steps is taken from the schedule, overriding the argument.
        (timesteps,) = betas.shape
        self.num_timesteps = int(timesteps)
        self.linear_start = linear_start
        self.linear_end = linear_end
        assert (
            alphas_cumprod.shape[0] == self.num_timesteps
        ), "alphas have to be defined for each timestep"

        # Converts an array to a float32 tensor on the model device.
        to_torch = lambda x: torch.tensor(x, dtype=torch.float32).to(self.device)

        self.register_buffer("betas", to_torch(betas))
        self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
        self.register_buffer("alphas_cumprod_prev", to_torch(alphas_cumprod_prev))

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod)))
        self.register_buffer(
            "sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod))
        )
        self.register_buffer(
            "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod))
        )
        self.register_buffer(
            "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod))
        )
        self.register_buffer(
            "sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod - 1))
        )

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        posterior_variance = (1 - self.v_posterior) * betas * (
            1.0 - alphas_cumprod_prev
        ) / (1.0 - alphas_cumprod) + self.v_posterior * betas
        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
        self.register_buffer("posterior_variance", to_torch(posterior_variance))
        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
        self.register_buffer(
            "posterior_log_variance_clipped",
            to_torch(np.log(np.maximum(posterior_variance, 1e-20))),
        )
        self.register_buffer(
            "posterior_mean_coef1",
            to_torch(betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)),
        )
        self.register_buffer(
            "posterior_mean_coef2",
            to_torch(
                (1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod)
            ),
        )

        # Per-timestep loss weights; the formula depends on what the network predicts.
        if self.parameterization == "eps":
            lvlb_weights = self.betas**2 / (
                2
                * self.posterior_variance
                * to_torch(alphas)
                * (1 - self.alphas_cumprod)
            )
        elif self.parameterization == "x0":
            lvlb_weights = (
                0.5
                * np.sqrt(torch.Tensor(alphas_cumprod))
                / (2.0 * 1 - torch.Tensor(alphas_cumprod))
            )
        else:
            raise NotImplementedError("mu not supported")
        # TODO how to choose this term
        # The first weight is replaced by the second one.
        lvlb_weights[0] = lvlb_weights[1]
        # persistent=False keeps this buffer out of the state dict.
        self.register_buffer("lvlb_weights", lvlb_weights, persistent=False)
        # NOTE(review): this only fails when EVERY weight is NaN (.all());
        # confirm whether .any() was intended.
        assert not torch.isnan(self.lvlb_weights).all()
165
+
166
+
167
class LatentDiffusion(DDPM):
    """DDPM schedule holder that also wraps the diffusion network used for sampling."""

    def __init__(
        self,
        diffusion_model,
        device,
        cond_stage_key="image",
        cond_stage_trainable=False,
        concat_mode=True,
        scale_factor=1.0,
        scale_by_std=False,
        *args,
        **kwargs,
    ):
        # These two are set before the parent constructor runs because it calls
        # register_schedule(), which reads num_timesteps_cond.
        self.num_timesteps_cond = 1
        self.scale_by_std = scale_by_std
        super().__init__(device, *args, **kwargs)

        self.diffusion_model = diffusion_model
        self.cond_stage_key = cond_stage_key
        self.cond_stage_trainable = cond_stage_trainable
        self.concat_mode = concat_mode
        self.scale_factor = scale_factor
        self.num_downs = 2

    def make_cond_schedule(
        self,
    ):
        """Build ``cond_ids``: filled with the last timestep, with the first entries spread evenly."""
        last_step = self.num_timesteps - 1
        n_cond = self.num_timesteps_cond

        self.cond_ids = torch.full(
            size=(self.num_timesteps,),
            fill_value=last_step,
            dtype=torch.long,
        )
        evenly_spaced = torch.linspace(0, last_step, n_cond)
        ids = torch.round(evenly_spaced).long()
        self.cond_ids[:n_cond] = ids

    def register_schedule(
        self,
        given_betas=None,
        beta_schedule="linear",
        timesteps=1000,
        linear_start=1e-4,
        linear_end=2e-2,
        cosine_s=8e-3,
    ):
        """Register the parent schedule, then the conditioning schedule when it is shortened."""
        super().register_schedule(
            given_betas=given_betas,
            beta_schedule=beta_schedule,
            timesteps=timesteps,
            linear_start=linear_start,
            linear_end=linear_end,
            cosine_s=cosine_s,
        )

        self.shorten_cond_schedule = self.num_timesteps_cond > 1
        if not self.shorten_cond_schedule:
            return
        self.make_cond_schedule()

    def apply_model(self, x_noisy, t, cond):
        """Embed the timesteps (256 features) and run the diffusion network."""
        # x_recon = self.model(x_noisy, t, cond['c_concat'][0])  # cond['c_concat'][0].shape 1,4,128,128
        t_emb = timestep_embedding(x_noisy.device, t, 256, repeat_only=False)
        return self.diffusion_model(x_noisy, t_emb, cond)
225
+
226
+
227
class LDM(InpaintModel):
    """Latent-diffusion inpainting model assembled from three TorchScript parts."""

    pad_mod = 32

    def __init__(self, device, fp16: bool = True, **kwargs):
        # fp16 has to be stored before the base constructor runs: init_model() reads it.
        self.fp16 = fp16
        super().__init__(device)
        self.device = device

    def init_model(self, device, **kwargs):
        """Load the diffusion, decode and encode networks, optionally in half precision on CUDA."""
        self.diffusion_model = load_jit_model(LDM_DIFFUSION_MODEL_URL, device)
        self.cond_stage_model_decode = load_jit_model(LDM_DECODE_MODEL_URL, device)
        self.cond_stage_model_encode = load_jit_model(LDM_ENCODE_MODEL_URL, device)

        use_half = self.fp16 and "cuda" in str(device)
        if use_half:
            for attr in ("diffusion_model", "cond_stage_model_decode", "cond_stage_model_encode"):
                setattr(self, attr, getattr(self, attr).half())

        self.model = LatentDiffusion(self.diffusion_model, device)

    @staticmethod
    def is_downloaded() -> bool:
        """Return True when the cache paths of all three model files exist."""
        model_paths = [
            get_cache_path_by_url(url)
            for url in (LDM_DIFFUSION_MODEL_URL, LDM_DECODE_MODEL_URL, LDM_ENCODE_MODEL_URL)
        ]
        return all(map(os.path.exists, model_paths))

    @torch.cuda.amp.autocast()
    def forward(self, image, mask, config: Config):
        """
        image: [H, W, C] RGB
        mask: [H, W, 1]
        return: BGR IMAGE
        """
        # image [1,3,512,512] float32
        # mask: [1,1,512,512] float32
        # masked_image: [1,3,512,512] float32
        sampler_kind = config.ldm_sampler
        if sampler_kind == LDMSampler.ddim:
            sampler = DDIMSampler(self.model)
        elif sampler_kind == LDMSampler.plms:
            sampler = PLMSSampler(self.model)
        else:
            raise ValueError()

        steps = config.ldm_steps
        image = norm_img(image)
        mask = norm_img(mask)

        # Binarize the normalized mask in place.
        mask[mask < 0.5] = 0
        mask[mask >= 0.5] = 1

        image = torch.from_numpy(image).unsqueeze(0).to(self.device)
        mask = torch.from_numpy(mask).unsqueeze(0).to(self.device)
        masked_image = (1 - mask) * image

        # Both inputs are rescaled with _norm (x * 2 - 1) before encoding.
        mask = self._norm(mask)
        masked_image = self._norm(masked_image)

        c = self.cond_stage_model_encode(masked_image)
        torch.cuda.empty_cache()

        # Conditioning = encoded masked image plus the mask resized to the latent size.
        latent_hw = c.shape[-2:]
        cc = torch.nn.functional.interpolate(mask, size=latent_hw)  # 1,1,128,128
        c = torch.cat((c, cc), dim=1)  # 1,4,128,128

        shape = (c.shape[1] - 1,) + c.shape[2:]
        samples_ddim = sampler.sample(
            steps=steps, conditioning=c, batch_size=c.shape[0], shape=shape
        )
        torch.cuda.empty_cache()
        # samples_ddim: 1, 3, 128, 128 float32
        x_samples_ddim = self.cond_stage_model_decode(samples_ddim)
        torch.cuda.empty_cache()

        # image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
        # mask = torch.clamp((mask + 1.0) / 2.0, min=0.0, max=1.0)
        inpainted_image = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)

        # inpainted = (1 - mask) * image + mask * predicted_image
        as_hwc = inpainted_image.cpu().numpy().transpose(0, 2, 3, 1)[0] * 255
        # Reverse the channel axis to turn RGB into BGR.
        return as_hwc.astype(np.uint8)[:, :, ::-1]

    def _norm(self, tensor):
        """Rescale with ``tensor * 2 - 1``."""
        return tensor * 2.0 - 1.0
lama_cleaner/model/mat.py ADDED
@@ -0,0 +1,1444 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+
4
+ import cv2
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ import torch.utils.checkpoint as checkpoint
10
+
11
+ from lama_cleaner.helper import load_model, get_cache_path_by_url, norm_img
12
+ from lama_cleaner.model.base import InpaintModel
13
+ from lama_cleaner.model.utils import setup_filter, Conv2dLayer, FullyConnectedLayer, conv2d_resample, bias_act, \
14
+ upsample2d, activation_funcs, MinibatchStdLayer, to_2tuple, normalize_2nd_moment
15
+ from lama_cleaner.schema import Config
16
+
17
+
18
class ModulatedConv2d(nn.Module):
    """Convolution whose weights are scaled per sample by a style vector, with optional demodulation."""

    def __init__(self,
                 in_channels,  # Number of input channels.
                 out_channels,  # Number of output channels.
                 kernel_size,  # Width and height of the convolution kernel.
                 style_dim,  # Dimension of the style code.
                 demodulate=True,  # Perform demodulation.
                 up=1,  # Integer upsampling factor.
                 down=1,  # Integer downsampling factor.
                 resample_filter=[1, 3, 3, 1],  # Low-pass filter to apply when resampling activations.
                 conv_clamp=None,  # Clamp the output to +-X, None = disable clamping.
                 ):
        super().__init__()
        self.demodulate = demodulate

        weight_shape = [1, out_channels, in_channels, kernel_size, kernel_size]
        self.weight = torch.nn.Parameter(torch.randn(weight_shape))
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        fan_in = in_channels * (kernel_size ** 2)
        self.weight_gain = 1 / np.sqrt(fan_in)
        self.padding = self.kernel_size // 2
        self.up = up
        self.down = down
        self.register_buffer('resample_filter', setup_filter(resample_filter))
        self.conv_clamp = conv_clamp

        # Maps the style code to one scale per input channel.
        self.affine = FullyConnectedLayer(style_dim, in_channels, bias_init=1)

    def forward(self, x, style):
        """Apply the style-modulated convolution to ``x`` and return the result."""
        n, c_in, h, w = x.shape
        k = self.kernel_size
        c_out = self.out_channels

        # Modulate: a separate weight tensor for every sample in the batch.
        scales = self.affine(style).view(n, 1, c_in, 1, 1)
        weight = self.weight * self.weight_gain * scales

        if self.demodulate:
            # Normalize each output filter; 1e-8 guards the reciprocal square root.
            decoefs = (weight.pow(2).sum(dim=[2, 3, 4]) + 1e-8).rsqrt()
            weight = weight * decoefs.view(n, c_out, 1, 1, 1)

        # Fold the batch into the channel axis and run one grouped convolution.
        weight = weight.view(n * c_out, c_in, k, k)
        x = x.view(1, n * c_in, h, w)
        x = conv2d_resample(x=x, w=weight, f=self.resample_filter, up=self.up, down=self.down,
                            padding=self.padding, groups=n)

        return x.view(n, c_out, *x.shape[2:])
61
+
62
+
63
class StyleConv(torch.nn.Module):
    """Style-modulated convolution followed by optional noise, bias and activation.

    Wraps a ``ModulatedConv2d`` and then applies ``bias_act`` with the
    configured activation, gain and clamp.
    """

    def __init__(self,
                 in_channels,  # Number of input channels.
                 out_channels,  # Number of output channels.
                 style_dim,  # Intermediate latent (W) dimensionality.
                 resolution,  # Resolution of this layer.
                 kernel_size=3,  # Convolution kernel size.
                 up=1,  # Integer upsampling factor.
                 use_noise=False,  # Enable noise input?
                 activation='lrelu',  # Activation function: 'relu', 'lrelu', etc.
                 resample_filter=[1, 3, 3, 1],  # Low-pass filter to apply when resampling activations.
                 conv_clamp=None,  # Clamp the output of convolution layers to +-X, None = disable clamping.
                 demodulate=True,  # Perform demodulation.
                 ):
        super().__init__()

        self.conv = ModulatedConv2d(in_channels=in_channels,
                                    out_channels=out_channels,
                                    kernel_size=kernel_size,
                                    style_dim=style_dim,
                                    demodulate=demodulate,
                                    up=up,
                                    resample_filter=resample_filter,
                                    conv_clamp=conv_clamp)

        self.use_noise = use_noise
        self.resolution = resolution
        if use_noise:
            # Fixed noise map used by noise_mode='const', plus a learnable scale.
            self.register_buffer('noise_const', torch.randn([resolution, resolution]))
            self.noise_strength = torch.nn.Parameter(torch.zeros([]))

        self.bias = torch.nn.Parameter(torch.zeros([out_channels]))
        self.activation = activation
        self.act_gain = activation_funcs[activation].def_gain
        self.conv_clamp = conv_clamp

    def forward(self, x, style, noise_mode='random', gain=1):
        """Run the modulated convolution, add noise if enabled, then bias/activate.

        Args:
            x: input feature map passed to the modulated convolution.
            style: style code passed to the modulated convolution.
            noise_mode: 'random' draws fresh noise per call, 'const' uses the
                stored ``noise_const`` buffer, 'none' adds no noise.
            gain: extra multiplier applied to the activation gain and clamp.

        Returns:
            The output of ``bias_act`` for the (optionally noised) features.

        Raises:
            AssertionError: if ``noise_mode`` is not one of the three modes.
        """
        x = self.conv(x, style)

        assert noise_mode in ['random', 'const', 'none']

        # Previously 'none' with use_noise=True fell through to `x + noise`
        # with `noise` unbound; now that combination simply adds nothing.
        if self.use_noise and noise_mode != 'none':
            if noise_mode == 'random':
                xh, xw = x.size()[-2:]
                noise = torch.randn([x.shape[0], 1, xh, xw], device=x.device) \
                        * self.noise_strength
            else:  # noise_mode == 'const'
                noise = self.noise_const * self.noise_strength
            x = x + noise

        act_gain = self.act_gain * gain
        act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None
        out = bias_act(x, self.bias, act=self.activation, gain=act_gain, clamp=act_clamp)

        return out
118
+
119
+
120
class ToRGB(torch.nn.Module):
    """Style-modulated projection of features to image channels.

    Applies a ``ModulatedConv2d``, adds a bias through ``bias_act`` and, when a
    skip image is supplied, adds it to the result (upsampling it first if its
    shape differs from the new output).
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 style_dim,
                 kernel_size=1,
                 resample_filter=[1, 3, 3, 1],
                 conv_clamp=None,
                 demodulate=False):
        super().__init__()

        conv_kwargs = dict(in_channels=in_channels,
                           out_channels=out_channels,
                           kernel_size=kernel_size,
                           style_dim=style_dim,
                           demodulate=demodulate,
                           resample_filter=resample_filter,
                           conv_clamp=conv_clamp)
        self.conv = ModulatedConv2d(**conv_kwargs)
        self.bias = torch.nn.Parameter(torch.zeros([out_channels]))
        # Filter handed to upsample2d when the skip image has a different shape.
        self.register_buffer('resample_filter', setup_filter(resample_filter))
        self.conv_clamp = conv_clamp

    def forward(self, x, style, skip=None):
        """Project ``x`` to image channels and accumulate onto ``skip``.

        Args:
            x: input features for the modulated convolution.
            style: style code for the modulated convolution.
            skip: optional image to add to the output; None means no skip.

        Returns:
            The projected image, plus ``skip`` when one is given.
        """
        projected = bias_act(self.conv(x, style), self.bias, clamp=self.conv_clamp)

        if skip is None:
            return projected

        # Shapes differ: bring the skip image up via upsample2d before adding.
        if skip.shape != projected.shape:
            skip = upsample2d(skip, self.resample_filter)
        return projected + skip
152
+
153
+
154
def get_style_code(a, b):
    """Concatenate two tensors along dimension 1 and return the result."""
    pieces = [a, b]
    return torch.cat(pieces, dim=1)
156
+
157
+
158
class DecBlockFirst(nn.Module):
    """First decoder block: expands a vector to a 4x4 map, then styles it.

    A fully connected layer produces ``in_channels * 16`` values that are
    viewed as a 4x4 feature map, the encoder feature ``E_features[2]`` is
    added, and a ``StyleConv`` plus a ``ToRGB`` produce features and an image.
    """

    def __init__(self, in_channels, out_channels, activation, style_dim, use_noise, demodulate, img_channels):
        super().__init__()
        # Input width is 2 * in_channels; output is reshaped to (.., 4, 4) in forward().
        self.fc = FullyConnectedLayer(in_features=in_channels * 2,
                                      out_features=in_channels * 4 ** 2,
                                      activation=activation)
        self.conv = StyleConv(in_channels=in_channels,
                              out_channels=out_channels,
                              style_dim=style_dim,
                              resolution=4,
                              kernel_size=3,
                              use_noise=use_noise,
                              activation=activation,
                              demodulate=demodulate,
                              )
        self.toRGB = ToRGB(in_channels=out_channels,
                           out_channels=img_channels,
                           style_dim=style_dim,
                           kernel_size=1,
                           demodulate=False,
                           )

    def forward(self, x, ws, gs, E_features, noise_mode='random'):
        """Return ``(features, image)`` for the first decoder stage.

        Args:
            x: input whose first dimension is the batch size.
            ws: latent codes indexed as ``ws[:, i]``; entries 0 and 1 are used.
            gs: tensor concatenated to each selected ``ws`` entry along dim 1.
            E_features: indexable collection; entry ``2`` is added to the features.
            noise_mode: forwarded to the style convolution.
        """
        batch = x.shape[0]
        feat = self.fc(x).view(batch, -1, 4, 4)
        feat = feat + E_features[2]

        feat = self.conv(feat, get_style_code(ws[:, 0], gs), noise_mode=noise_mode)
        img = self.toRGB(feat, get_style_code(ws[:, 1], gs), skip=None)

        return feat, img
189
+
190
+
191
class DecBlockFirstV2(nn.Module):
    """First decoder block variant that starts from a feature map.

    A plain ``Conv2dLayer`` replaces the fully connected expansion; afterwards
    the encoder feature ``E_features[2]`` is added and a ``StyleConv`` plus a
    ``ToRGB`` produce features and an image.
    """

    def __init__(self, in_channels, out_channels, activation, style_dim, use_noise, demodulate, img_channels):
        super().__init__()
        self.conv0 = Conv2dLayer(in_channels=in_channels,
                                 out_channels=in_channels,
                                 kernel_size=3,
                                 activation=activation,
                                 )
        self.conv1 = StyleConv(in_channels=in_channels,
                               out_channels=out_channels,
                               style_dim=style_dim,
                               resolution=4,
                               kernel_size=3,
                               use_noise=use_noise,
                               activation=activation,
                               demodulate=demodulate,
                               )
        self.toRGB = ToRGB(in_channels=out_channels,
                           out_channels=img_channels,
                           style_dim=style_dim,
                           kernel_size=1,
                           demodulate=False,
                           )

    def forward(self, x, ws, gs, E_features, noise_mode='random'):
        """Return ``(features, image)`` for the first decoder stage.

        Args:
            x: input feature map for ``conv0``.
            ws: latent codes indexed as ``ws[:, i]``; entries 0 and 1 are used.
            gs: tensor concatenated to each selected ``ws`` entry along dim 1.
            E_features: indexable collection; entry ``2`` is added to the features.
            noise_mode: forwarded to the style convolution.
        """
        feat = self.conv0(x) + E_features[2]

        conv_style = get_style_code(ws[:, 0], gs)
        feat = self.conv1(feat, conv_style, noise_mode=noise_mode)

        rgb_style = get_style_code(ws[:, 1], gs)
        img = self.toRGB(feat, rgb_style, skip=None)

        return feat, img
225
+
226
+
227
class DecBlock(nn.Module):
    """Decoder block: upsampling style conv, skip addition, style conv, ToRGB.

    ``res`` selects both the layer resolution (``2 ** res``) and which three
    consecutive latent codes of ``ws`` the block consumes.
    """

    def __init__(self, res, in_channels, out_channels, activation, style_dim, use_noise, demodulate,
                 img_channels):  # res = 2, ..., resolution_log2
        super().__init__()
        self.res = res

        # Options shared by both style convolutions.
        shared = dict(out_channels=out_channels,
                      style_dim=style_dim,
                      resolution=2 ** res,
                      kernel_size=3,
                      use_noise=use_noise,
                      activation=activation,
                      demodulate=demodulate)

        self.conv0 = StyleConv(in_channels=in_channels, up=2, **shared)
        self.conv1 = StyleConv(in_channels=out_channels, **shared)
        self.toRGB = ToRGB(in_channels=out_channels,
                           out_channels=img_channels,
                           style_dim=style_dim,
                           kernel_size=1,
                           demodulate=False,
                           )

    def forward(self, x, img, ws, gs, E_features, noise_mode='random'):
        """Return ``(features, image)`` for this decoder stage.

        Args:
            x: features from the previous stage.
            img: image from the previous stage, used as the ToRGB skip input.
            ws: latent codes indexed as ``ws[:, i]``.
            gs: tensor concatenated to each selected ``ws`` entry along dim 1.
            E_features: indexable collection; entry ``self.res`` is added
                after the first convolution.
            noise_mode: forwarded to both style convolutions.
        """
        # The block uses ws indices 2*res-5, 2*res-4 and 2*res-3.
        first = self.res * 2 - 5

        x = self.conv0(x, get_style_code(ws[:, first], gs), noise_mode=noise_mode)
        x = x + E_features[self.res]
        x = self.conv1(x, get_style_code(ws[:, first + 1], gs), noise_mode=noise_mode)
        img = self.toRGB(x, get_style_code(ws[:, first + 2], gs), skip=img)

        return x, img
269
+
270
+
271
class MappingNet(torch.nn.Module):
    """Stack of fully connected layers mapping (z, c) to an intermediate latent.

    The latent ``z`` and the embedded label ``c`` are each normalised with
    ``normalize_2nd_moment``, concatenated, and passed through ``num_layers``
    layers stored as attributes ``fc0``, ``fc1``, ... . The output can be
    broadcast to ``num_ws`` copies and truncated towards a tracked average.
    """

    def __init__(self,
                 z_dim,  # Input latent (Z) dimensionality, 0 = no latent.
                 c_dim,  # Conditioning label (C) dimensionality, 0 = no label.
                 w_dim,  # Intermediate latent (W) dimensionality.
                 num_ws,  # Number of intermediate latents to output, None = do not broadcast.
                 num_layers=8,  # Number of mapping layers.
                 embed_features=None,  # Label embedding dimensionality, None = same as w_dim.
                 layer_features=None,  # Number of intermediate features in the mapping layers, None = same as w_dim.
                 activation='lrelu',  # Activation function: 'relu', 'lrelu', etc.
                 lr_multiplier=0.01,  # Learning rate multiplier for the mapping layers.
                 w_avg_beta=0.995,  # Decay for tracking the moving average of W during training, None = do not track.
                 ):
        super().__init__()
        self.z_dim = z_dim
        self.c_dim = c_dim
        self.w_dim = w_dim
        self.num_ws = num_ws
        self.num_layers = num_layers
        self.w_avg_beta = w_avg_beta

        # Resolve defaults; without labels there is no embedding at all.
        embed_features = w_dim if embed_features is None else embed_features
        if c_dim == 0:
            embed_features = 0
        layer_features = w_dim if layer_features is None else layer_features

        hidden = [layer_features] * (num_layers - 1)
        features_list = [z_dim + embed_features, *hidden, w_dim]

        if c_dim > 0:
            self.embed = FullyConnectedLayer(c_dim, embed_features)
        for idx in range(num_layers):
            fc = FullyConnectedLayer(features_list[idx], features_list[idx + 1],
                                     activation=activation, lr_multiplier=lr_multiplier)
            setattr(self, f'fc{idx}', fc)

        # The running average is only allocated when broadcasting and tracking are both on.
        if num_ws is not None and w_avg_beta is not None:
            self.register_buffer('w_avg', torch.zeros([w_dim]))

    def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, skip_w_avg_update=False):
        """Map ``z`` and/or ``c`` to the intermediate latent.

        Args:
            z: latent input; read only when ``z_dim > 0``.
            c: label input; read only when ``c_dim > 0``.
            truncation_psi: interpolation weight between ``w_avg`` and the
                output; 1 disables truncation.
            truncation_cutoff: when broadcasting, number of leading latents to
                truncate; None truncates the whole output.
            skip_w_avg_update: if True, do not update ``w_avg`` in training mode.

        Returns:
            The mapped latent, repeated along a new dim 1 when ``num_ws`` is set.
        """
        record = torch.autograd.profiler.record_function

        # Embed, normalize, and concat inputs.
        x = None
        with record('input'):
            if self.z_dim > 0:
                x = normalize_2nd_moment(z.to(torch.float32))
            if self.c_dim > 0:
                y = normalize_2nd_moment(self.embed(c.to(torch.float32)))
                x = y if x is None else torch.cat([x, y], dim=1)

        # Main layers.
        for idx in range(self.num_layers):
            x = getattr(self, f'fc{idx}')(x)

        # Update moving average of W.
        track = self.w_avg_beta is not None and self.training and not skip_w_avg_update
        if track:
            with record('update_w_avg'):
                batch_mean = x.detach().mean(dim=0)
                self.w_avg.copy_(batch_mean.lerp(self.w_avg, self.w_avg_beta))

        # Broadcast.
        if self.num_ws is not None:
            with record('broadcast'):
                x = x.unsqueeze(1).repeat([1, self.num_ws, 1])

        # Apply truncation.
        if truncation_psi != 1:
            with record('truncate'):
                assert self.w_avg_beta is not None
                truncate_all = self.num_ws is None or truncation_cutoff is None
                if truncate_all:
                    x = self.w_avg.lerp(x, truncation_psi)
                else:
                    # In-place update of the leading latents only.
                    x[:, :truncation_cutoff] = self.w_avg.lerp(x[:, :truncation_cutoff], truncation_psi)

        return x
346
+
347
+
348
class DisFromRGB(nn.Module):
    """Discriminator input stem: a single 1x1 ``Conv2dLayer`` with activation."""

    def __init__(self, in_channels, out_channels, activation):  # res = 2, ..., resolution_log2
        super().__init__()
        self.conv = Conv2dLayer(in_channels=in_channels,
                                out_channels=out_channels,
                                kernel_size=1,
                                activation=activation,
                                )

    def forward(self, x):
        """Return ``self.conv`` applied to ``x``."""
        out = self.conv(x)
        return out
359
+
360
+
361
class DisBlock(nn.Module):
    """Residual downsampling block of the discriminator.

    The main path is a 3x3 conv followed by a 3x3 conv with ``down=2``; the
    skip path is a bias-free 1x1 conv with ``down=2``. Both paths are called
    with gain ``sqrt(0.5)`` before being summed.
    """

    def __init__(self, in_channels, out_channels, activation):  # res = 2, ..., resolution_log2
        super().__init__()
        self.conv0 = Conv2dLayer(in_channels=in_channels,
                                 out_channels=in_channels,
                                 kernel_size=3,
                                 activation=activation,
                                 )
        self.conv1 = Conv2dLayer(in_channels=in_channels,
                                 out_channels=out_channels,
                                 kernel_size=3,
                                 down=2,
                                 activation=activation,
                                 )
        self.skip = Conv2dLayer(in_channels=in_channels,
                                out_channels=out_channels,
                                kernel_size=1,
                                down=2,
                                bias=False,
                                )

    def forward(self, x):
        """Return the sum of the skip path and the two-conv main path."""
        half_gain = np.sqrt(0.5)

        shortcut = self.skip(x, gain=half_gain)
        main = self.conv1(self.conv0(x), gain=half_gain)

        return shortcut + main
389
+
390
+
391
class Discriminator(torch.nn.Module):
    """Convolutional discriminator over an image concatenated with its mask.

    The input stem takes ``img_channels + 1`` channels, followed by one
    ``DisBlock`` per resolution step down to 4x4, an optional minibatch
    standard-deviation layer, a final conv and two fully connected layers.
    When ``c_dim > 0`` the output is projected onto a mapped label embedding.
    """

    def __init__(self,
                 c_dim,  # Conditioning label (C) dimensionality.
                 img_resolution,  # Input resolution.
                 img_channels,  # Number of input color channels.
                 channel_base=32768,  # Overall multiplier for the number of channels.
                 channel_max=512,  # Maximum number of channels in any layer.
                 channel_decay=1,
                 cmap_dim=None,  # Dimensionality of mapped conditioning label, None = default.
                 activation='lrelu',
                 mbstd_group_size=4,  # Group size for the minibatch standard deviation layer, None = entire minibatch.
                 mbstd_num_channels=1,  # Number of features for the minibatch standard deviation layer, 0 = disable.
                 ):
        super().__init__()
        self.c_dim = c_dim
        self.img_resolution = img_resolution
        self.img_channels = img_channels

        # The resolution must be a power of two and at least 4.
        resolution_log2 = int(np.log2(img_resolution))
        assert img_resolution == 2 ** resolution_log2 and img_resolution >= 4
        self.resolution_log2 = resolution_log2

        def nf(stage):
            # Channel count for a stage, clipped to [1, channel_max].
            return np.clip(int(channel_base / 2 ** (stage * channel_decay)), 1, channel_max)

        # Identity test against None (was `== None`).
        if cmap_dim is None:
            cmap_dim = nf(2)
        if c_dim == 0:
            cmap_dim = 0
        self.cmap_dim = cmap_dim

        if c_dim > 0:
            self.mapping = MappingNet(z_dim=0, c_dim=c_dim, w_dim=cmap_dim, num_ws=None, w_avg_beta=None)

        # +1 input channel for the mask concatenated in forward().
        Dis = [DisFromRGB(img_channels + 1, nf(resolution_log2), activation)]
        for res in range(resolution_log2, 2, -1):
            Dis.append(DisBlock(nf(res), nf(res - 1), activation))

        if mbstd_num_channels > 0:
            Dis.append(MinibatchStdLayer(group_size=mbstd_group_size, num_channels=mbstd_num_channels))
        Dis.append(Conv2dLayer(nf(2) + mbstd_num_channels, nf(2), kernel_size=3, activation=activation))
        self.Dis = nn.Sequential(*Dis)

        self.fc0 = FullyConnectedLayer(nf(2) * 4 ** 2, nf(2), activation=activation)
        self.fc1 = FullyConnectedLayer(nf(2), 1 if cmap_dim == 0 else cmap_dim)

    def forward(self, images_in, masks_in, c):
        """Score a batch of images given their masks and labels.

        Args:
            images_in: image tensor, concatenated after the mask along dim 1.
            masks_in: mask tensor; 0.5 is subtracted before concatenation.
            c: conditioning labels; read only when ``c_dim > 0``.

        Returns:
            The output of ``fc1``, reduced to one value per sample by a scaled
            dot product with the label embedding when ``cmap_dim > 0``.
        """
        x = torch.cat([masks_in - 0.5, images_in], dim=1)
        x = self.Dis(x)
        x = self.fc1(self.fc0(x.flatten(start_dim=1)))

        if self.c_dim > 0:
            cmap = self.mapping(None, c)

        # cmap_dim is forced to 0 whenever c_dim == 0, so cmap is bound here.
        if self.cmap_dim > 0:
            x = (x * cmap).sum(dim=1, keepdim=True) * (1 / np.sqrt(self.cmap_dim))

        return x
449
+
450
+
451
def nf(stage, channel_base=32768, channel_decay=1.0, channel_max=512):
    """Return the channel count for resolution ``2 ** stage`` from a fixed table.

    ``channel_base``, ``channel_decay`` and ``channel_max`` are accepted but
    not read. Raises ``KeyError`` for resolutions that are not in the table.
    """
    resolution = 2 ** stage
    channels_by_resolution = {
        4: 512,
        8: 512,
        16: 512,
        32: 512,
        64: 512,
        128: 256,
        256: 128,
        512: 64,
    }
    return channels_by_resolution[resolution]
454
+
455
+
456
class Mlp(nn.Module):
    """Two fully connected layers; the first uses the 'lrelu' activation.

    ``act_layer`` and ``drop`` are accepted but not read by this class.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        # Falsy sizes (None or 0) fall back to the input width.
        if not out_features:
            out_features = in_features
        if not hidden_features:
            hidden_features = in_features
        self.fc1 = FullyConnectedLayer(in_features=in_features, out_features=hidden_features, activation='lrelu')
        self.fc2 = FullyConnectedLayer(in_features=hidden_features, out_features=out_features)

    def forward(self, x):
        """Return ``fc2(fc1(x))``."""
        return self.fc2(self.fc1(x))
468
+
469
+
470
def window_partition(x, window_size):
    """Cut a (B, H, W, C) tensor into square windows.

    Args:
        x: (B, H, W, C)
        window_size (int): side length of each window
    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    rows = H // window_size
    cols = W // window_size

    # Split both spatial axes into (window index, offset inside the window) ...
    tiles = x.view(B, rows, window_size, cols, window_size, C)
    # ... then bring the two window indices next to the batch axis.
    tiles = tiles.permute(0, 1, 3, 2, 4, 5).contiguous()
    return tiles.view(-1, window_size, window_size, C)
482
+
483
+
484
def window_reverse(windows, window_size: int, H: int, W: int):
    """Reassemble windows into a (B, H, W, C) tensor.

    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image
    Returns:
        x: (B, H, W, C)
    """
    # Batch size = total windows / windows per image (float division, then truncation).
    windows_per_image = H * W / window_size / window_size
    B = int(windows.shape[0] / windows_per_image)

    rows = H // window_size
    cols = W // window_size
    grid = windows.view(B, rows, cols, window_size, window_size, -1)
    # Interleave window indices with in-window offsets before flattening.
    grid = grid.permute(0, 1, 3, 2, 4, 5).contiguous()
    return grid.view(B, H, W, -1)
499
+
500
+
501
class Conv2dLayerPartial(nn.Module):
    """``Conv2dLayer`` wrapper that rescales its output using a validity mask.

    When a mask is given, an all-ones kernel counts the mask values under each
    window; the conv output is multiplied by ``kernel_size**2 / count`` where
    the count is positive and by zero elsewhere, and the clamped count is
    returned as the updated mask.
    """

    def __init__(self,
                 in_channels,  # Number of input channels.
                 out_channels,  # Number of output channels.
                 kernel_size,  # Width and height of the convolution kernel.
                 bias=True,  # Apply additive bias before the activation function?
                 activation='linear',  # Activation function: 'relu', 'lrelu', etc.
                 up=1,  # Integer upsampling factor.
                 down=1,  # Integer downsampling factor.
                 resample_filter=[1, 3, 3, 1],  # Low-pass filter to apply when resampling activations.
                 conv_clamp=None,  # Clamp the output to +-X, None = disable clamping.
                 trainable=True,  # Update the weights of this layer during training?
                 ):
        super().__init__()
        self.conv = Conv2dLayer(in_channels, out_channels, kernel_size, bias, activation, up, down, resample_filter,
                                conv_clamp, trainable)

        # Plain tensor attribute (not a buffer); moved to the input's type lazily in forward().
        self.weight_maskUpdater = torch.ones(1, 1, kernel_size, kernel_size)
        self.slide_winsize = kernel_size ** 2
        self.stride = down
        # 'Same'-style padding for odd kernels, none for even kernels.
        self.padding = kernel_size // 2 if kernel_size % 2 == 1 else 0

    def forward(self, x, mask=None):
        """Apply the convolution, optionally with mask-based rescaling.

        Args:
            x: input features for the wrapped convolution.
            mask: optional mask convolved with a single-channel all-ones kernel.

        Returns:
            ``(features, updated_mask)``; the mask is None when none was given.
        """
        if mask is None:
            return self.conv(x), None

        with torch.no_grad():
            ones_kernel = self.weight_maskUpdater
            if ones_kernel.type() != x.type():
                ones_kernel = ones_kernel.to(x)
                self.weight_maskUpdater = ones_kernel
            coverage = F.conv2d(mask, ones_kernel, bias=None, stride=self.stride,
                                padding=self.padding)
            mask_ratio = self.slide_winsize / (coverage + 1e-8)
            update_mask = torch.clamp(coverage, 0, 1)  # 0 or 1
            mask_ratio = torch.mul(mask_ratio, update_mask)

        out = torch.mul(self.conv(x), mask_ratio)
        return out, update_mask
539
+
540
+
541
class WindowAttention(nn.Module):
    r""" Window based multi-head self attention.

    Queries and keys are computed from L2-normalised tokens, values from the
    raw tokens. An additive attention mask and a per-token validity mask can
    both be supplied.

    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        down_ratio (int, optional): Accepted but not read by this class.
        qkv_bias (bool, optional): Accepted but not read by this class.
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Accepted but not read by this class.
        proj_drop (float, optional): Accepted but not read by this class.
    """

    def __init__(self, dim, window_size, num_heads, down_ratio=1, qkv_bias=True, qk_scale=None, attn_drop=0.,
                 proj_drop=0.):

        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        per_head = dim // num_heads
        self.scale = qk_scale or per_head ** -0.5

        # Four dim -> dim projections: query, key, value and output.
        self.q = FullyConnectedLayer(in_features=dim, out_features=dim)
        self.k = FullyConnectedLayer(in_features=dim, out_features=dim)
        self.v = FullyConnectedLayer(in_features=dim, out_features=dim)
        self.proj = FullyConnectedLayer(in_features=dim, out_features=dim)

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask_windows=None, mask=None):
        """
        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask_windows: optional per-token mask; entries equal to 0 add -100
                to the attention logits of that key, entries equal to 1 add 0
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        Returns:
            (attended features, updated mask_windows); the mask is returned
            unchanged (None) when none was given
        """
        B_, N, C = x.shape
        heads = self.num_heads
        head_dim = C // heads

        unit_x = F.normalize(x, p=2.0, dim=-1)
        q = self.q(unit_x).reshape(B_, N, heads, head_dim).permute(0, 2, 1, 3)
        k = self.k(unit_x).view(B_, -1, heads, head_dim).permute(0, 2, 3, 1)  # already transposed for q @ k
        v = self.v(x).view(B_, -1, heads, head_dim).permute(0, 2, 1, 3)

        attn = (q @ k) * self.scale

        if mask is not None:
            nW = mask.shape[0]
            # Broadcast the per-window mask over batch and heads.
            attn = attn.view(B_ // nW, nW, heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, heads, N, N)

        if mask_windows is not None:
            key_valid = mask_windows.squeeze(-1).unsqueeze(1).unsqueeze(1)
            key_bias = key_valid.masked_fill(key_valid == 0, float(-100.0)).masked_fill(
                key_valid == 1, float(0.0))
            attn = attn + key_bias
            with torch.no_grad():
                # A window becomes fully valid as soon as it holds any positive mask mass.
                window_any = torch.clamp(torch.sum(mask_windows, dim=1, keepdim=True), 0, 1)
                mask_windows = window_any.repeat(1, N, 1)

        attn = self.softmax(attn)

        out = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        out = self.proj(out)
        return out, mask_windows
602
+
603
+
604
class SwinTransformerBlock(nn.Module):
    r""" Windowed attention block with optional cyclic shift and mask tracking.

    Tokens are reshaped to a grid, optionally rolled by ``shift_size``, split
    into windows, attended, merged back and un-rolled. The attention output is
    concatenated with the block input, fused by a fully connected layer and
    passed through an MLP.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        num_heads (int): Number of attention heads.
        down_ratio (int): Forwarded to the attention module; forced to 1 when shifting.
        window_size (int): Window size.
        shift_size (int): Shift size for the shifted-window variant.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): Forwarded to the attention module. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Forwarded to the attention module and the MLP. Default: 0.0
        attn_drop (float, optional): Forwarded to the attention module. Default: 0.0
        drop_path (float, optional): Accepted but not read by this class. Default: 0.0
        act_layer (nn.Module, optional): Forwarded to the MLP. Default: nn.GELU
        norm_layer (nn.Module, optional): Accepted but not read by this class. Default: nn.LayerNorm
    """

    def __init__(self, dim, input_resolution, num_heads, down_ratio=1, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.mlp_ratio = mlp_ratio

        smallest_side = min(input_resolution)
        if smallest_side <= window_size:
            # if window size is larger than input resolution, we don't partition windows
            shift_size = 0
            window_size = smallest_side
        self.window_size = window_size
        self.shift_size = shift_size
        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"

        if self.shift_size > 0:
            down_ratio = 1
        self.attn = WindowAttention(dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
                                    down_ratio=down_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop,
                                    proj_drop=drop)

        # Merges [block input, attention output] back to `dim` channels.
        self.fuse = FullyConnectedLayer(in_features=dim * 2, out_features=dim, activation='lrelu')

        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        # Pre-compute the shift mask for the nominal resolution (None when not shifting).
        attn_mask = self.calculate_mask(self.input_resolution) if self.shift_size > 0 else None
        self.register_buffer("attn_mask", attn_mask)

    def calculate_mask(self, x_size):
        """Build the additive attention mask for the shifted-window case.

        Regions of the grid are labelled with distinct ids; token pairs inside
        a window get 0 when their ids match and -100 otherwise.
        """
        H, W = x_size
        img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1

        # The same three bands are used along both spatial axes.
        bands = (slice(0, -self.window_size),
                 slice(-self.window_size, -self.shift_size),
                 slice(-self.shift_size, None))
        regions = [(rows, cols) for rows in bands for cols in bands]
        for region_id, (rows, cols) in enumerate(regions):
            img_mask[:, rows, cols, :] = region_id

        labels = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
        labels = labels.view(-1, self.window_size * self.window_size)
        diff = labels.unsqueeze(1) - labels.unsqueeze(2)
        return diff.masked_fill(diff != 0, float(-100.0)).masked_fill(diff == 0, float(0.0))

    def forward(self, x, x_size, mask=None):
        """Apply the block to a token sequence.

        Args:
            x: tokens unpacked as (B, L, C) and viewed as (B, H, W, C).
            x_size: ``(H, W)`` of the token grid.
            mask: optional mask viewed as (B, H, W, 1); None disables mask tracking.

        Returns:
            ``(tokens, mask)`` with tokens viewed as (B, H*W, C) before the
            fuse/MLP layers and the mask as (B, H*W, 1) or None.
        """
        H, W = x_size
        B, L, C = x.shape
        ws = self.window_size
        shift = self.shift_size
        has_mask = mask is not None

        shortcut = x
        x = x.view(B, H, W, C)
        if has_mask:
            mask = mask.view(B, H, W, 1)

        # cyclic shift
        if shift > 0:
            shifted_x = torch.roll(x, shifts=(-shift, -shift), dims=(1, 2))
            if has_mask:
                shifted_mask = torch.roll(mask, shifts=(-shift, -shift), dims=(1, 2))
        else:
            shifted_x = x
            if has_mask:
                shifted_mask = mask

        # partition windows
        x_windows = window_partition(shifted_x, ws).view(-1, ws * ws, C)  # nW*B, window_size*window_size, C
        if has_mask:
            mask_windows = window_partition(shifted_mask, ws).view(-1, ws * ws, 1)
        else:
            mask_windows = None

        # Use the cached mask at the nominal resolution, otherwise rebuild it for this size.
        if self.input_resolution == x_size:
            shift_mask = self.attn_mask
        else:
            shift_mask = self.calculate_mask(x_size).to(x.device)
        attn_windows, mask_windows = self.attn(x_windows, mask_windows, mask=shift_mask)

        # merge windows
        attn_windows = attn_windows.view(-1, ws, ws, C)
        shifted_x = window_reverse(attn_windows, ws, H, W)  # B H' W' C
        if has_mask:
            mask_windows = mask_windows.view(-1, ws, ws, 1)
            shifted_mask = window_reverse(mask_windows, ws, H, W)

        # reverse cyclic shift
        if shift > 0:
            x = torch.roll(shifted_x, shifts=(shift, shift), dims=(1, 2))
            if has_mask:
                mask = torch.roll(shifted_mask, shifts=(shift, shift), dims=(1, 2))
        else:
            x = shifted_x
            if has_mask:
                mask = shifted_mask
        x = x.view(B, H * W, C)
        if has_mask:
            mask = mask.view(B, H * W, 1)

        # FFN
        x = self.fuse(torch.cat([shortcut, x], dim=-1))
        x = self.mlp(x)

        return x, mask
742
+
743
+
744
class PatchMerging(nn.Module):
    """Downsample a token sequence with a 3x3 partial convolution."""

    def __init__(self, in_channels, out_channels, down=2):
        super().__init__()
        self.conv = Conv2dLayerPartial(in_channels=in_channels,
                                       out_channels=out_channels,
                                       kernel_size=3,
                                       activation='lrelu',
                                       down=down,
                                       )
        self.down = down

    def forward(self, x, x_size, mask=None):
        """Return ``(tokens, new_size, mask)`` after the downsampling convolution.

        Tokens (and the mask, when given) are converted to feature maps,
        convolved, and converted back; ``x_size`` is scaled by ``1 / down``.
        """
        feat = token2feature(x, x_size)
        mask_map = None if mask is None else token2feature(mask, x_size)

        feat, mask_map = self.conv(feat, mask_map)

        if self.down != 1:
            scale = 1 / self.down
            height, width = x_size[0], x_size[1]
            x_size = (int(height * scale), int(width * scale))

        tokens = feature2token(feat)
        mask_tokens = None if mask_map is None else feature2token(mask_map)
        return tokens, x_size, mask_tokens
767
+
768
+
769
class PatchUpsampling(nn.Module):
    """Upsample a token sequence with a 3x3 partial convolution."""

    def __init__(self, in_channels, out_channels, up=2):
        super().__init__()
        self.conv = Conv2dLayerPartial(in_channels=in_channels,
                                       out_channels=out_channels,
                                       kernel_size=3,
                                       activation='lrelu',
                                       up=up,
                                       )
        self.up = up

    def forward(self, x, x_size, mask=None):
        """Return ``(tokens, new_size, mask)`` after the upsampling convolution.

        Tokens (and the mask, when given) are converted to feature maps,
        convolved, and converted back; ``x_size`` is multiplied by ``up``.
        """
        feat = token2feature(x, x_size)
        mask_map = None if mask is None else token2feature(mask, x_size)

        feat, mask_map = self.conv(feat, mask_map)

        if self.up != 1:
            height, width = x_size[0], x_size[1]
            x_size = (int(height * self.up), int(width * self.up))

        tokens = feature2token(feat)
        mask_tokens = None if mask_map is None else feature2token(mask_map)
        return tokens, x_size, mask_tokens
791
+
792
+
793
class BasicLayer(nn.Module):
    """ One stage: optional resampling, attention blocks, conv and residual.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size.
        down_ratio (int): Forwarded to every block.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): Forwarded to every block. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Forwarded to every block. Default: 0.0
        attn_drop (float, optional): Forwarded to every block. Default: 0.0
        drop_path (float | list[float], optional): Per-block value when a list is given. Default: 0.0
        norm_layer (nn.Module, optional): Forwarded to every block. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Resampling module applied at the START of forward(). Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
    """

    def __init__(self, dim, input_resolution, depth, num_heads, window_size, down_ratio=1,
                 mlp_ratio=2., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):

        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        # Already-built resampling module, or None when the stage keeps its size.
        self.downsample = downsample

        # build blocks: even indices are unshifted, odd indices shift by half a window
        blocks = []
        for i in range(depth):
            shift = window_size // 2 if i % 2 else 0
            block_drop_path = drop_path[i] if isinstance(drop_path, list) else drop_path
            blocks.append(
                SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
                                     num_heads=num_heads, down_ratio=down_ratio, window_size=window_size,
                                     shift_size=shift,
                                     mlp_ratio=mlp_ratio,
                                     qkv_bias=qkv_bias, qk_scale=qk_scale,
                                     drop=drop, attn_drop=attn_drop,
                                     drop_path=block_drop_path,
                                     norm_layer=norm_layer))
        self.blocks = nn.ModuleList(blocks)

        self.conv = Conv2dLayerPartial(in_channels=dim, out_channels=dim, kernel_size=3, activation='lrelu')

    def forward(self, x, x_size, mask=None):
        """Return ``(tokens, x_size, mask)`` after running the whole stage.

        The tokens entering the blocks (after the optional resampling) are
        added back to the conv output as a residual.
        """
        if self.downsample is not None:
            x, x_size, mask = self.downsample(x, x_size, mask)
        identity = x

        for blk in self.blocks:
            if self.use_checkpoint:
                x, mask = checkpoint.checkpoint(blk, x, x_size, mask)
            else:
                x, mask = blk(x, x_size, mask)

        # The partial conv works on feature maps, so convert tokens (and mask) first.
        mask_map = token2feature(mask, x_size) if mask is not None else mask
        feat, mask_map = self.conv(token2feature(x, x_size), mask_map)

        x = feature2token(feat) + identity
        mask = feature2token(mask_map) if mask_map is not None else mask_map
        return x, x_size, mask
859
+
860
+
861
class ToToken(nn.Module):
    """Run an image/mask pair through a single ``Conv2dLayerPartial``.

    ``stride`` is accepted for signature compatibility but is not used here.
    """

    def __init__(self, in_channels=3, dim=128, kernel_size=5, stride=1):
        super().__init__()
        proj_kwargs = dict(
            in_channels=in_channels,
            out_channels=dim,
            kernel_size=kernel_size,
            activation='lrelu',
        )
        self.proj = Conv2dLayerPartial(**proj_kwargs)

    def forward(self, x, mask):
        """Return the ``(features, mask)`` pair produced by ``self.proj``."""
        features, updated_mask = self.proj(x, mask)
        return features, updated_mask
872
+
873
+
874
class EncFromRGB(nn.Module):
    """Two stacked ``Conv2dLayer``s: a 1x1 conv changing the channel count,
    followed by a 3x3 conv that keeps it."""

    def __init__(self, in_channels, out_channels, activation):
        super().__init__()
        # 1x1 conv: in_channels -> out_channels
        self.conv0 = Conv2dLayer(
            in_channels=in_channels, out_channels=out_channels,
            kernel_size=1, activation=activation)
        # 3x3 conv: out_channels -> out_channels
        self.conv1 = Conv2dLayer(
            in_channels=out_channels, out_channels=out_channels,
            kernel_size=3, activation=activation)

    def forward(self, x):
        """Apply ``conv0`` then ``conv1`` to ``x``."""
        return self.conv1(self.conv0(x))
893
+
894
+
895
class ConvBlockDown(nn.Module):
    """Encoder block: a 3x3 ``Conv2dLayer`` built with ``down=2`` followed
    by a second 3x3 ``Conv2dLayer`` at the new channel count."""

    def __init__(self, in_channels, out_channels, activation):
        super().__init__()
        # First conv changes the channel count and is created with down=2.
        self.conv0 = Conv2dLayer(
            in_channels=in_channels, out_channels=out_channels,
            kernel_size=3, activation=activation, down=2)
        # Second conv keeps the channel count.
        self.conv1 = Conv2dLayer(
            in_channels=out_channels, out_channels=out_channels,
            kernel_size=3, activation=activation)

    def forward(self, x):
        """Apply ``conv0`` then ``conv1`` to ``x``."""
        return self.conv1(self.conv0(x))
916
+
917
+
918
def token2feature(x, x_size):
    """Turn a token tensor ``(B, N, C)`` into a feature map ``(B, C, h, w)``.

    ``x_size`` is the ``(h, w)`` pair; ``h * w`` must equal ``N``.
    """
    batch, _, channels = x.shape
    height, width = x_size
    return x.permute(0, 2, 1).reshape(batch, channels, height, width)
923
+
924
+
925
def feature2token(x):
    """Flatten a feature map ``(B, C, H, W)`` into tokens ``(B, H*W, C)``.

    Uses ``reshape`` instead of ``view`` so that non-contiguous inputs
    (e.g. permuted or channels-last tensors) are accepted; for contiguous
    inputs the result is identical to the previous ``view``-based version.
    """
    B, C, _, _ = x.shape  # also enforces a 4-D input
    return x.reshape(B, C, -1).transpose(1, 2)
929
+
930
+
931
class Encoder(nn.Module):
    """Convolutional encoder producing one feature map per resolution level.

    Blocks are registered as ``EncConv_Block_{res}x{res}`` for every
    ``res = 2 ** i`` with ``i`` running from ``res_log2`` down to 4.
    ``patch_size``, ``channels`` and ``drop_path_rate`` are accepted but not
    used by this class.
    """

    def __init__(self, res_log2, img_channels, activation, patch_size=5, channels=16, drop_path_rate=0.1):
        super().__init__()
        self.resolution = []

        # Walk from the input resolution down to 16x16 (log2 == 4).
        for level in range(res_log2, 3, -1):
            side = 2 ** level
            self.resolution.append(side)
            is_input_level = level == res_log2
            stage = (
                EncFromRGB(img_channels * 2 + 1, nf(level), activation)
                if is_input_level
                else ConvBlockDown(nf(level + 1), nf(level), activation)
            )
            setattr(self, 'EncConv_Block_%dx%d' % (side, side), stage)

    def forward(self, x):
        """Return ``{log2(res): features}`` for every registered resolution."""
        features = {}
        for side in self.resolution:
            stage = getattr(self, 'EncConv_Block_%dx%d' % (side, side))
            x = stage(x)
            features[int(np.log2(side))] = x
        return features
954
+
955
+
956
class ToStyle(nn.Module):
    """Reduce a feature map to a style vector.

    Three 3x3 ``Conv2dLayer``s (each built with ``down=2``) are followed by
    global average pooling and a ``FullyConnectedLayer``. ``drop_rate`` is
    accepted but unused (no dropout layer is created).
    """

    def __init__(self, in_channels, out_channels, activation, drop_rate):
        super().__init__()
        down_layers = [
            Conv2dLayer(in_channels=in_channels, out_channels=in_channels,
                        kernel_size=3, activation=activation, down=2)
            for _ in range(3)
        ]
        self.conv = nn.Sequential(*down_layers)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = FullyConnectedLayer(
            in_features=in_channels, out_features=out_channels, activation=activation)

    def forward(self, x):
        """Return ``fc(flatten(pool(conv(x))))``."""
        pooled = self.pool(self.conv(x))
        return self.fc(pooled.flatten(start_dim=1))
981
+
982
+
983
class DecBlockFirstV2(nn.Module):
    """First decoder block: plain conv, encoder skip, styled conv, then RGB.

    ``res`` is both the key used to fetch this block's skip from
    ``E_features`` and the exponent of the ``StyleConv`` resolution
    (``2 ** res``).
    """

    def __init__(self, res, in_channels, out_channels, activation, style_dim, use_noise, demodulate, img_channels):
        super().__init__()
        self.res = res

        self.conv0 = Conv2dLayer(
            in_channels=in_channels, out_channels=in_channels,
            kernel_size=3, activation=activation)
        self.conv1 = StyleConv(
            in_channels=in_channels, out_channels=out_channels,
            style_dim=style_dim, resolution=2 ** res, kernel_size=3,
            use_noise=use_noise, activation=activation, demodulate=demodulate)
        self.toRGB = ToRGB(
            in_channels=out_channels, out_channels=img_channels,
            style_dim=style_dim, kernel_size=1, demodulate=False)

    def forward(self, x, ws, gs, E_features, noise_mode='random'):
        """Return ``(features, img)``.

        ``ws[:, 0]`` styles ``conv1`` and ``ws[:, 1]`` styles ``toRGB``;
        both are combined with ``gs`` via ``get_style_code``.
        """
        feat = self.conv0(x) + E_features[self.res]
        conv_style = get_style_code(ws[:, 0], gs)
        feat = self.conv1(feat, conv_style, noise_mode=noise_mode)
        rgb_style = get_style_code(ws[:, 1], gs)
        rgb = self.toRGB(feat, rgb_style, skip=None)
        return feat, rgb
1019
+
1020
+
1021
class DecBlock(nn.Module):
    """Decoder block: ``StyleConv`` built with ``up=2``, encoder skip, a
    second ``StyleConv``, and a ``ToRGB`` that receives the running image
    as ``skip``."""

    def __init__(self, res, in_channels, out_channels, activation, style_dim, use_noise, demodulate,
                 img_channels):
        super().__init__()
        self.res = res

        self.conv0 = StyleConv(
            in_channels=in_channels, out_channels=out_channels,
            style_dim=style_dim, resolution=2 ** res, kernel_size=3, up=2,
            use_noise=use_noise, activation=activation, demodulate=demodulate)
        self.conv1 = StyleConv(
            in_channels=out_channels, out_channels=out_channels,
            style_dim=style_dim, resolution=2 ** res, kernel_size=3,
            use_noise=use_noise, activation=activation, demodulate=demodulate)
        self.toRGB = ToRGB(
            in_channels=out_channels, out_channels=img_channels,
            style_dim=style_dim, kernel_size=1, demodulate=False)

    def forward(self, x, img, ws, gs, E_features, noise_mode='random'):
        """Return ``(features, img)``.

        This block consumes three consecutive entries of ``ws`` starting at
        index ``self.res * 2 - 9`` (conv0, conv1, toRGB in that order).
        """
        first = self.res * 2 - 9
        feat = self.conv0(x, get_style_code(ws[:, first], gs), noise_mode=noise_mode)
        feat = feat + E_features[self.res]
        feat = self.conv1(feat, get_style_code(ws[:, first + 1], gs), noise_mode=noise_mode)
        rgb = self.toRGB(feat, get_style_code(ws[:, first + 2], gs), skip=img)
        return feat, rgb
1063
+
1064
+
1065
class Decoder(nn.Module):
    """Chain of decoder blocks from 16x16 up to ``2 ** res_log2``.

    ``Dec_16x16`` is a ``DecBlockFirstV2``; every level from 5 to
    ``res_log2`` gets a ``DecBlock`` registered as ``Dec_{res}x{res}``.
    """

    def __init__(self, res_log2, activation, style_dim, use_noise, demodulate, img_channels):
        super().__init__()
        shared = (activation, style_dim, use_noise, demodulate, img_channels)
        self.Dec_16x16 = DecBlockFirstV2(4, nf(4), nf(4), *shared)
        for level in range(5, res_log2 + 1):
            side = 2 ** level
            setattr(self, 'Dec_%dx%d' % (side, side),
                    DecBlock(level, nf(level - 1), nf(level), *shared))
        self.res_log2 = res_log2

    def forward(self, x, ws, gs, E_features, noise_mode='random'):
        """Run every decoder block in order and return the final image."""
        x, img = self.Dec_16x16(x, ws, gs, E_features, noise_mode=noise_mode)
        for level in range(5, self.res_log2 + 1):
            side = 2 ** level
            stage = getattr(self, 'Dec_%dx%d' % (side, side))
            x, img = stage(x, img, ws, gs, E_features, noise_mode=noise_mode)
        return img
1081
+
1082
+
1083
class DecStyleBlock(nn.Module):
    """Decoder block driven by one pre-computed style vector.

    Same layer layout as ``DecBlock`` (``StyleConv`` with ``up=2``, skip add,
    ``StyleConv``, ``ToRGB``), but the caller supplies ``style`` and ``skip``
    directly.
    """

    def __init__(self, res, in_channels, out_channels, activation, style_dim, use_noise, demodulate, img_channels):
        super().__init__()
        self.res = res

        self.conv0 = StyleConv(
            in_channels=in_channels, out_channels=out_channels,
            style_dim=style_dim, resolution=2 ** res, kernel_size=3, up=2,
            use_noise=use_noise, activation=activation, demodulate=demodulate)
        self.conv1 = StyleConv(
            in_channels=out_channels, out_channels=out_channels,
            style_dim=style_dim, resolution=2 ** res, kernel_size=3,
            use_noise=use_noise, activation=activation, demodulate=demodulate)
        self.toRGB = ToRGB(
            in_channels=out_channels, out_channels=img_channels,
            style_dim=style_dim, kernel_size=1, demodulate=False)

    def forward(self, x, img, style, skip, noise_mode='random'):
        """Return ``(features, img)``; the same ``style`` feeds all three layers."""
        feat = self.conv0(x, style, noise_mode=noise_mode)
        feat = feat + skip
        feat = self.conv1(feat, style, noise_mode=noise_mode)
        rgb = self.toRGB(feat, style, skip=img)
        return feat, rgb
1121
+
1122
+
1123
class FirstStage(nn.Module):
    """First (coarse) stage of the synthesis network.

    Pipeline, as built below: a partial-conv stem, ``down_time`` strided
    partial convs, five ``BasicLayer`` transformer groups (two of them
    preceded by ``PatchMerging``, two by ``PatchUpsampling``), and
    ``down_time`` ``DecStyleBlock``s. ``forward`` returns the generated
    image composited with the input so that pixels where ``masks_in == 1``
    are taken from ``images_in``.
    """

    def __init__(self, img_channels, img_resolution=256, dim=180, w_dim=512, use_noise=False, demodulate=True,
                 activation='lrelu'):
        super().__init__()
        res = 64

        # Stem takes the image concatenated with a one-channel mask.
        self.conv_first = Conv2dLayerPartial(in_channels=img_channels + 1, out_channels=dim, kernel_size=3,
                                             activation=activation)
        self.enc_conv = nn.ModuleList()
        down_time = int(np.log2(img_resolution // res))
        # build the number of (Swin Transformer) layers according to the image size
        for i in range(down_time):  # from input size to 64
            self.enc_conv.append(
                Conv2dLayerPartial(in_channels=dim, out_channels=dim, kernel_size=3, down=2, activation=activation)
            )

        # from 64 -> 16 -> 64
        depths = [2, 3, 4, 3, 2]
        ratios = [1, 1 / 2, 1 / 2, 2, 2]
        num_heads = 6
        window_sizes = [8, 16, 16, 16, 8]
        drop_path_rate = 0.1
        # One drop-path rate per transformer block, linearly spaced from 0 to drop_path_rate.
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]

        self.tran = nn.ModuleList()
        for i, depth in enumerate(depths):
            res = int(res * ratios[i])
            # ratio < 1 -> merge patches, ratio > 1 -> upsample, ratio == 1 -> no resampling.
            if ratios[i] < 1:
                merge = PatchMerging(dim, dim, down=int(1 / ratios[i]))
            elif ratios[i] > 1:
                merge = PatchUpsampling(dim, dim, up=ratios[i])
            else:
                merge = None
            self.tran.append(
                BasicLayer(dim=dim, input_resolution=[res, res], depth=depth, num_heads=num_heads,
                           window_size=window_sizes[i], drop_path=dpr[sum(depths[:i]):sum(depths[:i + 1])],
                           downsample=merge)
            )

        # global style
        down_conv = []
        for i in range(int(np.log2(16))):
            down_conv.append(
                Conv2dLayer(in_channels=dim, out_channels=dim, kernel_size=3, down=2, activation=activation))
        down_conv.append(nn.AdaptiveAvgPool2d((1, 1)))
        self.down_conv = nn.Sequential(*down_conv)
        self.to_style = FullyConnectedLayer(in_features=dim, out_features=dim * 2, activation=activation)
        self.ws_style = FullyConnectedLayer(in_features=w_dim, out_features=dim, activation=activation)
        self.to_square = FullyConnectedLayer(in_features=dim, out_features=16 * 16, activation=activation)

        # Style fed to the decoder is cat([gs, ws]) in forward: dim * 2 + dim.
        style_dim = dim * 3
        self.dec_conv = nn.ModuleList()
        for i in range(down_time):  # from 64 to input size
            res = res * 2
            self.dec_conv.append(
                DecStyleBlock(res, dim, dim, activation, style_dim, use_noise, demodulate, img_channels))

    def forward(self, images_in, masks_in, ws, noise_mode='random'):
        """Return the first-stage image.

        Args:
            images_in: input image batch.
            masks_in: mask batch; the final blend keeps ``images_in`` where
                the mask is 1 and the generated image where it is 0.
            ws: latent codes; only ``ws[:, -1]`` is used here.
            noise_mode: forwarded to every decoder block.
        """
        x = torch.cat([masks_in - 0.5, images_in * masks_in], dim=1)

        skips = []
        x, mask = self.conv_first(x, masks_in)  # input size
        skips.append(x)
        for i, block in enumerate(self.enc_conv):  # input size to 64
            x, mask = block(x, mask)
            # The output of the last encoder conv is not stored as a skip.
            if i != len(self.enc_conv) - 1:
                skips.append(x)

        x_size = x.size()[-2:]
        x = feature2token(x)
        mask = feature2token(mask)
        mid = len(self.tran) // 2
        for i, block in enumerate(self.tran):  # 64 to 16
            if i < mid:
                x, x_size, mask = block(x, x_size, mask)
                skips.append(x)
            elif i > mid:
                x, x_size, mask = block(x, x_size, None)
                # mid - i is negative: picks the skips appended by the
                # first-half transformer layers, most recent first.
                x = x + skips[mid - i]
            else:
                x, x_size, mask = block(x, x_size, None)

                # Dropout is applied with training=True unconditionally, so the
                # mixing map is random even in eval mode.
                mul_map = torch.ones_like(x) * 0.5
                mul_map = F.dropout(mul_map, training=True)
                ws = self.ws_style(ws[:, -1])
                add_n = self.to_square(ws).unsqueeze(1)
                # Resize the 16*16 vector to the current token count.
                add_n = F.interpolate(add_n, size=x.size(1), mode='linear', align_corners=False).squeeze(1).unsqueeze(
                    -1)
                x = x * mul_map + add_n * (1 - mul_map)
                gs = self.to_style(self.down_conv(token2feature(x, x_size)).flatten(start_dim=1))
                style = torch.cat([gs, ws], dim=1)

        x = token2feature(x, x_size).contiguous()
        img = None
        for i, block in enumerate(self.dec_conv):
            # Encoder skips are consumed in reverse order.
            x, img = block(x, img, style, skips[len(self.dec_conv) - i - 1], noise_mode=noise_mode)

        # ensemble
        img = img * (1 - masks_in) + images_in * masks_in

        return img
1224
+
1225
+
1226
class SynthesisNet(nn.Module):
    """Two-stage synthesis network: ``FirstStage`` followed by an
    encoder / style / decoder refinement stage.

    ``channel_base``, ``channel_decay`` and ``channel_max`` are accepted but
    not used in this class. ``FirstStage`` is always built with
    ``use_noise=False`` regardless of the ``use_noise`` argument.
    """

    def __init__(self,
                 w_dim,  # Intermediate latent (W) dimensionality.
                 img_resolution,  # Output image resolution.
                 img_channels=3,  # Number of color channels.
                 channel_base=32768,  # Overall multiplier for the number of channels.
                 channel_decay=1.0,
                 channel_max=512,  # Maximum number of channels in any layer.
                 activation='lrelu',  # Activation function: 'relu', 'lrelu', etc.
                 drop_rate=0.5,
                 use_noise=False,
                 demodulate=True,
                 ):
        super().__init__()
        resolution_log2 = int(np.log2(img_resolution))
        assert img_resolution == 2 ** resolution_log2 and img_resolution >= 4

        self.num_layers = 2 * (resolution_log2 - 3)
        self.img_resolution = img_resolution
        self.resolution_log2 = resolution_log2

        # Stage one.
        self.first_stage = FirstStage(
            img_channels, img_resolution=img_resolution, w_dim=w_dim,
            use_noise=False, demodulate=demodulate)

        # Stage two.
        self.enc = Encoder(resolution_log2, img_channels, activation, patch_size=5, channels=16)
        self.to_square = FullyConnectedLayer(in_features=w_dim, out_features=16 * 16, activation=activation)
        global_style_dim = nf(2) * 2
        self.to_style = ToStyle(
            in_channels=nf(4), out_channels=global_style_dim,
            activation=activation, drop_rate=drop_rate)
        self.dec = Decoder(
            resolution_log2, activation, w_dim + global_style_dim,
            use_noise, demodulate, img_channels)

    def forward(self, images_in, masks_in, ws, noise_mode='random', return_stg1=False):
        """Return the final image, plus the stage-one image if ``return_stg1``.

        Pixels where ``masks_in == 1`` are copied from ``images_in`` in the
        final result.
        """
        stage1 = self.first_stage(images_in, masks_in, ws, noise_mode=noise_mode)

        # Encoder input: mask, stage-one composite, and the visible pixels.
        visible = images_in * masks_in
        composite = visible + stage1 * (1 - masks_in)
        enc_in = torch.cat([masks_in - 0.5, composite, visible], dim=1)
        E_features = self.enc(enc_in)

        # Mix the 16x16 features with a map derived from ws[:, 0]; dropout is
        # applied with training=True unconditionally.
        bottleneck = E_features[4]
        mix = F.dropout(torch.ones_like(bottleneck) * 0.5, training=True)
        injected = self.to_square(ws[:, 0]).view(-1, 16, 16).unsqueeze(1)
        injected = F.interpolate(injected, size=bottleneck.size()[-2:], mode='bilinear', align_corners=False)
        bottleneck = bottleneck * mix + injected * (1 - mix)
        E_features[4] = bottleneck

        # Global style vector and decoding.
        gs = self.to_style(bottleneck)
        decoded = self.dec(bottleneck, ws, gs, E_features, noise_mode=noise_mode)

        # Keep the known pixels from the input.
        img = decoded * (1 - masks_in) + visible

        if return_stg1:
            return img, stage1
        return img
1287
+
1288
+
1289
class Generator(nn.Module):
    """Mapping network + synthesis network wrapper.

    Args:
        z_dim: Input latent (Z) dimensionality, 0 = no latent.
        c_dim: Conditioning label (C) dimensionality, 0 = no label.
        w_dim: Intermediate latent (W) dimensionality.
        img_resolution: Resolution of the generated image.
        img_channels: Number of input color channels.
        synthesis_kwargs: Extra keyword arguments for ``SynthesisNet``
            (``None`` means none).
        mapping_kwargs: Extra keyword arguments for ``MappingNet``
            (``None`` means none).
    """

    def __init__(self,
                 z_dim,  # Input latent (Z) dimensionality, 0 = no latent.
                 c_dim,  # Conditioning label (C) dimensionality, 0 = no label.
                 w_dim,  # Intermediate latent (W) dimensionality.
                 img_resolution,  # resolution of generated image
                 img_channels,  # Number of input color channels.
                 synthesis_kwargs=None,  # Arguments for SynthesisNetwork.
                 mapping_kwargs=None,  # Arguments for MappingNetwork.
                 ):
        super().__init__()
        # Avoid shared mutable default arguments.
        synthesis_kwargs = {} if synthesis_kwargs is None else synthesis_kwargs
        mapping_kwargs = {} if mapping_kwargs is None else mapping_kwargs

        self.z_dim = z_dim
        self.c_dim = c_dim
        self.w_dim = w_dim
        self.img_resolution = img_resolution
        self.img_channels = img_channels

        self.synthesis = SynthesisNet(w_dim=w_dim,
                                      img_resolution=img_resolution,
                                      img_channels=img_channels,
                                      **synthesis_kwargs)
        # The mapping network emits one w per synthesis layer.
        self.mapping = MappingNet(z_dim=z_dim,
                                  c_dim=c_dim,
                                  w_dim=w_dim,
                                  num_ws=self.synthesis.num_layers,
                                  **mapping_kwargs)

    def forward(self, images_in, masks_in, z, c, truncation_psi=1, truncation_cutoff=None, skip_w_avg_update=False,
                noise_mode='none', return_stg1=False):
        """Map ``z``/``c`` to ``ws`` and run the synthesis network.

        Returns:
            The synthesized image, or ``(image, stage_one_image)`` when
            ``return_stg1`` is true. Previously ``return_stg1`` was accepted
            but silently ignored; it is now forwarded to ``self.synthesis``.
        """
        ws = self.mapping(z, c, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff,
                          skip_w_avg_update=skip_w_avg_update)
        return self.synthesis(images_in, masks_in, ws, noise_mode=noise_mode, return_stg1=return_stg1)
1322
+
1323
+
1324
class Discriminator(torch.nn.Module):
    """Discriminator with two parallel branches.

    ``Dis``/``fc0``/``fc1`` score the final image; ``Dis_stg1``/``fc0_stg1``/
    ``fc1_stg1`` (built with half the channel counts) score the stage-one
    image. When ``c_dim > 0`` both scores are projected onto a mapped
    conditioning label.
    """

    def __init__(self,
                 c_dim,  # Conditioning label (C) dimensionality.
                 img_resolution,  # Input resolution.
                 img_channels,  # Number of input color channels.
                 channel_base=32768,  # Overall multiplier for the number of channels.
                 channel_max=512,  # Maximum number of channels in any layer.
                 channel_decay=1,
                 cmap_dim=None,  # Dimensionality of mapped conditioning label, None = default.
                 activation='lrelu',
                 mbstd_group_size=4,  # Group size for the minibatch standard deviation layer, None = entire minibatch.
                 mbstd_num_channels=1,  # Number of features for the minibatch standard deviation layer, 0 = disable.
                 ):
        super().__init__()
        self.c_dim = c_dim
        self.img_resolution = img_resolution
        self.img_channels = img_channels

        resolution_log2 = int(np.log2(img_resolution))
        assert img_resolution == 2 ** resolution_log2 and img_resolution >= 4
        self.resolution_log2 = resolution_log2

        # Identity comparison is the correct way to test for None.
        if cmap_dim is None:
            cmap_dim = nf(2)
        # Without labels there is nothing to project onto.
        if c_dim == 0:
            cmap_dim = 0
        self.cmap_dim = cmap_dim

        if c_dim > 0:
            self.mapping = MappingNet(z_dim=0, c_dim=c_dim, w_dim=cmap_dim, num_ws=None, w_avg_beta=None)

        # Main branch; the input is the image concatenated with a one-channel mask.
        Dis = [DisFromRGB(img_channels + 1, nf(resolution_log2), activation)]
        for res in range(resolution_log2, 2, -1):
            Dis.append(DisBlock(nf(res), nf(res - 1), activation))

        if mbstd_num_channels > 0:
            Dis.append(MinibatchStdLayer(group_size=mbstd_group_size, num_channels=mbstd_num_channels))
        Dis.append(Conv2dLayer(nf(2) + mbstd_num_channels, nf(2), kernel_size=3, activation=activation))
        self.Dis = nn.Sequential(*Dis)

        self.fc0 = FullyConnectedLayer(nf(2) * 4 ** 2, nf(2), activation=activation)
        self.fc1 = FullyConnectedLayer(nf(2), 1 if cmap_dim == 0 else cmap_dim)

        # for 64x64
        Dis_stg1 = [DisFromRGB(img_channels + 1, nf(resolution_log2) // 2, activation)]
        for res in range(resolution_log2, 2, -1):
            Dis_stg1.append(DisBlock(nf(res) // 2, nf(res - 1) // 2, activation))

        if mbstd_num_channels > 0:
            Dis_stg1.append(MinibatchStdLayer(group_size=mbstd_group_size, num_channels=mbstd_num_channels))
        Dis_stg1.append(Conv2dLayer(nf(2) // 2 + mbstd_num_channels, nf(2) // 2, kernel_size=3, activation=activation))
        self.Dis_stg1 = nn.Sequential(*Dis_stg1)

        self.fc0_stg1 = FullyConnectedLayer(nf(2) // 2 * 4 ** 2, nf(2) // 2, activation=activation)
        self.fc1_stg1 = FullyConnectedLayer(nf(2) // 2, 1 if cmap_dim == 0 else cmap_dim)

    def forward(self, images_in, masks_in, images_stg1, c):
        """Return ``(score, score_stg1)`` for the final and stage-one images."""
        x = self.Dis(torch.cat([masks_in - 0.5, images_in], dim=1))
        x = self.fc1(self.fc0(x.flatten(start_dim=1)))

        x_stg1 = self.Dis_stg1(torch.cat([masks_in - 0.5, images_stg1], dim=1))
        x_stg1 = self.fc1_stg1(self.fc0_stg1(x_stg1.flatten(start_dim=1)))

        if self.c_dim > 0:
            cmap = self.mapping(None, c)

        # cmap_dim is forced to 0 whenever c_dim == 0, so ``cmap`` is always
        # defined when this branch runs.
        if self.cmap_dim > 0:
            x = (x * cmap).sum(dim=1, keepdim=True) * (1 / np.sqrt(self.cmap_dim))
            x_stg1 = (x_stg1 * cmap).sum(dim=1, keepdim=True) * (1 / np.sqrt(self.cmap_dim))

        return x, x_stg1
1395
+
1396
+
1397
# URL of the MAT generator weights. The MAT_MODEL_URL environment variable,
# when set, overrides the default release URL below.
MAT_MODEL_URL = os.environ.get(
    "MAT_MODEL_URL",
    "https://github.com/Sanster/models/releases/download/add_mat/Places_512_FullData_G.pth",
)
1401
+
1402
+
1403
class MAT(InpaintModel):
    """``InpaintModel`` wrapper around the MAT ``Generator``."""

    min_size = 512
    pad_mod = 512
    pad_to_square = True

    def init_model(self, device, **kwargs):
        """Seed the RNGs, load the generator and create fixed ``z``/label inputs."""
        seed = 240  # pick up a random number
        for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
            seed_fn(seed)

        generator = Generator(z_dim=512, c_dim=0, w_dim=512, img_resolution=512, img_channels=3)
        self.model = load_model(generator, MAT_MODEL_URL, device)
        # Latent drawn once with NumPy and reused for every call: shape [1, 512].
        latent = np.random.randn(1, generator.z_dim)
        self.z = torch.from_numpy(latent).to(device)
        self.label = torch.zeros([1, self.model.c_dim], device=device)

    @staticmethod
    def is_downloaded() -> bool:
        """Return whether the cache path for ``MAT_MODEL_URL`` exists."""
        cached_path = get_cache_path_by_url(MAT_MODEL_URL)
        return os.path.exists(cached_path)

    def forward(self, image, mask, config: Config):
        """Input images and output images have same size
        images: [H, W, C] RGB
        masks: [H, W] mask area == 255
        return: BGR IMAGE
        """
        # norm_img output is rescaled from [0, 1] to [-1, 1].
        image = norm_img(image) * 2 - 1

        # Binarize at 127, then invert so the masked area becomes 0.
        binary_mask = (mask > 127) * 255
        mask = norm_img(255 - binary_mask)

        image_t = torch.from_numpy(image).unsqueeze(0).to(self.device)
        mask_t = torch.from_numpy(mask).unsqueeze(0).to(self.device)

        output = self.model(image_t, mask_t, self.z, self.label, truncation_psi=1, noise_mode='none')
        # NCHW in [-1, 1] -> NHWC uint8 in [0, 255].
        output = output.permute(0, 2, 3, 1) * 127.5 + 127.5
        output = output.round().clamp(0, 255).to(torch.uint8)
        rgb = output[0].cpu().numpy()
        return cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
lama_cleaner/model/opencv2.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ from lama_cleaner.model.base import InpaintModel
3
+ from lama_cleaner.schema import Config
4
+
5
# Maps the flag name read from ``config.cv2_flag`` to the corresponding
# OpenCV inpainting flag constant.
flag_map = {
    "INPAINT_NS": cv2.INPAINT_NS,
    "INPAINT_TELEA": cv2.INPAINT_TELEA
}
9
+
10
class OpenCV2(InpaintModel):
    """``InpaintModel`` backed by ``cv2.inpaint``; no weights to download."""

    pad_mod = 1

    @staticmethod
    def is_downloaded() -> bool:
        """Always ``True``: this model has nothing to download."""
        return True

    def forward(self, image, mask, config: Config):
        """Input image and output image have same size
        image: [H, W, C] RGB
        mask: [H, W, 1]
        return: BGR IMAGE
        """
        # Reverse the channel axis (RGB -> BGR) before handing off to OpenCV.
        bgr = image[:, :, ::-1]
        radius = config.cv2_radius
        method = flag_map[config.cv2_flag]
        return cv2.inpaint(bgr, mask, inpaintRadius=radius, flags=method)
lama_cleaner/model/plms_sampler.py ADDED
@@ -0,0 +1,225 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # From: https://github.com/CompVis/latent-diffusion/blob/main/ldm/models/diffusion/plms.py
2
+ import torch
3
+ import numpy as np
4
+ from lama_cleaner.model.utils import make_ddim_timesteps, make_ddim_sampling_parameters, noise_like
5
+ from tqdm import tqdm
6
+
7
+
8
class PLMSSampler(object):
    """PLMS (pseudo linear multistep) sampler for a diffusion ``model``.

    The model is expected to expose ``num_timesteps``, ``betas``,
    ``alphas_cumprod``, ``alphas_cumprod_prev``, ``device``, ``apply_model``
    and ``q_sample`` (all are read or called below).
    """

    def __init__(self, model, schedule="linear", **kwargs):
        super().__init__()
        self.model = model
        self.ddpm_num_timesteps = model.num_timesteps
        self.schedule = schedule

    def register_buffer(self, name, attr):
        """Store ``attr`` on the sampler under ``name`` (plain ``setattr``)."""
        setattr(self, name, attr)

    def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
        """Pre-compute the timestep subset and all alpha/sigma tables.

        Raises:
            ValueError: if ``ddim_eta`` is not 0.
        """
        if ddim_eta != 0:
            raise ValueError('ddim_eta must be 0 for PLMS')
        self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
                                                  num_ddpm_timesteps=self.ddpm_num_timesteps, verbose=verbose)
        alphas_cumprod = self.model.alphas_cumprod
        assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
        # Detached float32 copy on the model's device.
        to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)

        self.register_buffer('betas', to_torch(self.model.betas))
        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
        self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))

        # ddim sampling parameters
        ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
                                                                                   ddim_timesteps=self.ddim_timesteps,
                                                                                   eta=ddim_eta, verbose=verbose)
        self.register_buffer('ddim_sigmas', ddim_sigmas)
        self.register_buffer('ddim_alphas', ddim_alphas)
        self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
        self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
        sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
            (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
                    1 - self.alphas_cumprod / self.alphas_cumprod_prev))
        self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)

    @torch.no_grad()
    def sample(self,
               steps,
               batch_size,
               shape,
               conditioning=None,
               callback=None,
               normals_sequence=None,
               img_callback=None,
               quantize_x0=False,
               eta=0.,
               mask=None,
               x0=None,
               temperature=1.,
               noise_dropout=0.,
               score_corrector=None,
               corrector_kwargs=None,
               verbose=False,
               x_T=None,
               log_every_t=100,
               unconditional_guidance_scale=1.,
               unconditional_conditioning=None,
               # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
               **kwargs
               ):
        """Build the schedule for ``steps`` and run PLMS sampling.

        ``shape`` is ``(C, H, W)``; the returned samples have shape
        ``(batch_size, C, H, W)``. A batch-size mismatch with
        ``conditioning`` only prints a warning.
        """
        if conditioning is not None:
            if isinstance(conditioning, dict):
                cbs = conditioning[list(conditioning.keys())[0]].shape[0]
                if cbs != batch_size:
                    print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
            else:
                if conditioning.shape[0] != batch_size:
                    print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")

        self.make_schedule(ddim_num_steps=steps, ddim_eta=eta, verbose=verbose)
        # sampling
        C, H, W = shape
        size = (batch_size, C, H, W)
        print(f'Data shape for PLMS sampling is {size}')

        samples = self.plms_sampling(conditioning, size,
                                     callback=callback,
                                     img_callback=img_callback,
                                     quantize_denoised=quantize_x0,
                                     mask=mask, x0=x0,
                                     ddim_use_original_steps=False,
                                     noise_dropout=noise_dropout,
                                     temperature=temperature,
                                     score_corrector=score_corrector,
                                     corrector_kwargs=corrector_kwargs,
                                     x_T=x_T,
                                     log_every_t=log_every_t,
                                     unconditional_guidance_scale=unconditional_guidance_scale,
                                     unconditional_conditioning=unconditional_conditioning,
                                     )
        return samples

    @torch.no_grad()
    def plms_sampling(self, cond, shape,
                      x_T=None, ddim_use_original_steps=False,
                      callback=None, timesteps=None, quantize_denoised=False,
                      mask=None, x0=None, img_callback=None, log_every_t=100,
                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                      unconditional_guidance_scale=1., unconditional_conditioning=None, ):
        """Iterate the PLMS update from the last timestep down to the first.

        Starts from ``x_T`` (or Gaussian noise of ``shape`` when it is
        ``None``) and returns the final sample. When ``mask`` is given,
        ``x0`` is required and the region where ``mask`` is 1 is replaced
        at every step by a noised copy of ``x0``.
        """
        device = self.model.betas.device
        b = shape[0]
        if x_T is None:
            img = torch.randn(shape, device=device)
        else:
            img = x_T

        if timesteps is None:
            timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
        elif not ddim_use_original_steps:
            subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
            timesteps = self.ddim_timesteps[:subset_end]

        time_range = list(reversed(range(0, timesteps))) if ddim_use_original_steps else np.flip(timesteps)
        total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
        print(f"Running PLMS Sampling with {total_steps} timesteps")

        iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps)
        # History of the most recent noise predictions (at most three are kept).
        old_eps = []

        for i, step in enumerate(iterator):
            index = total_steps - i - 1
            ts = torch.full((b,), step, device=device, dtype=torch.long)
            # On the last iteration ts_next repeats the final timestep.
            ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long)

            if mask is not None:
                assert x0 is not None
                img_orig = self.model.q_sample(x0, ts)  # TODO: deterministic forward pass?
                img = img_orig * mask + (1. - mask) * img

            outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
                                      quantize_denoised=quantize_denoised, temperature=temperature,
                                      noise_dropout=noise_dropout, score_corrector=score_corrector,
                                      corrector_kwargs=corrector_kwargs,
                                      unconditional_guidance_scale=unconditional_guidance_scale,
                                      unconditional_conditioning=unconditional_conditioning,
                                      old_eps=old_eps, t_next=ts_next)
            img, pred_x0, e_t = outs
            old_eps.append(e_t)
            if len(old_eps) >= 4:
                old_eps.pop(0)
            if callback:
                callback(i)
            if img_callback:
                img_callback(pred_x0, i)

        return img

    @torch.no_grad()
    def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                      unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None):
        """Perform one PLMS step.

        Returns:
            ``(x_prev, pred_x0, e_t)`` where ``e_t`` is the raw noise
            prediction at ``t`` (the caller appends it to its history).

        ``old_eps=None`` and ``corrector_kwargs=None`` (their defaults) are
        treated as an empty history / no extra kwargs; previously both
        defaults raised ``TypeError`` when reached.
        """
        if old_eps is None:
            old_eps = []
        if corrector_kwargs is None:
            corrector_kwargs = {}

        b, *_, device = *x.shape, x.device

        def get_model_output(x, t):
            # Plain prediction, or classifier-free guidance when an
            # unconditional conditioning and a scale != 1 are supplied.
            if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
                e_t = self.model.apply_model(x, t, c)
            else:
                x_in = torch.cat([x] * 2)
                t_in = torch.cat([t] * 2)
                c_in = torch.cat([unconditional_conditioning, c])
                e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
                e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)

            if score_corrector is not None:
                assert self.model.parameterization == "eps"
                e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)

            return e_t

        alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
        alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
        sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
        # NOTE(review): make_schedule registers ddim_sigmas_for_original_num_steps
        # on the sampler, not on the model -- confirm the model also exposes it
        # before calling with use_original_steps=True.
        sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas

        def get_x_prev_and_pred_x0(e_t, index):
            # select parameters corresponding to the currently considered timestep
            a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
            a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
            sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
            sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device)

            # current prediction for x_0
            pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
            if quantize_denoised:
                pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
            # direction pointing to x_t
            dir_xt = (1. - a_prev - sigma_t ** 2).sqrt() * e_t
            noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
            if noise_dropout > 0.:
                noise = torch.nn.functional.dropout(noise, p=noise_dropout)
            x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
            return x_prev, pred_x0

        e_t = get_model_output(x, t)
        if len(old_eps) == 0:
            # Pseudo Improved Euler (2nd order)
            x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
            e_t_next = get_model_output(x_prev, t_next)
            e_t_prime = (e_t + e_t_next) / 2
        elif len(old_eps) == 1:
            # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
            e_t_prime = (3 * e_t - old_eps[-1]) / 2
        elif len(old_eps) == 2:
            # 3rd order Pseudo Linear Multistep (Adams-Bashforth)
            e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
        elif len(old_eps) >= 3:
            # 4th order Pseudo Linear Multistep (Adams-Bashforth)
            e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24

        x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)

        return x_prev, pred_x0, e_t
lama_cleaner/model/sd.py ADDED
@@ -0,0 +1,215 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+
3
+ import PIL.Image
4
+ import cv2
5
+ import numpy as np
6
+ import torch
7
+ from diffusers import PNDMScheduler, DDIMScheduler
8
+ from loguru import logger
9
+ from transformers import FeatureExtractionMixin, ImageFeatureExtractionMixin
10
+
11
+ from lama_cleaner.helper import norm_img
12
+
13
+ from lama_cleaner.model.base import InpaintModel
14
+ from lama_cleaner.schema import Config, SDSampler
15
+
16
+
17
+ #
18
+ #
19
+ # def preprocess_image(image):
20
+ # w, h = image.size
21
+ # w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
22
+ # image = image.resize((w, h), resample=PIL.Image.LANCZOS)
23
+ # image = np.array(image).astype(np.float32) / 255.0
24
+ # image = image[None].transpose(0, 3, 1, 2)
25
+ # image = torch.from_numpy(image)
26
+ # # [-1, 1]
27
+ # return 2.0 * image - 1.0
28
+ #
29
+ #
30
+ # def preprocess_mask(mask):
31
+ # mask = mask.convert("L")
32
+ # w, h = mask.size
33
+ # w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
34
+ # mask = mask.resize((w // 8, h // 8), resample=PIL.Image.NEAREST)
35
+ # mask = np.array(mask).astype(np.float32) / 255.0
36
+ # mask = np.tile(mask, (4, 1, 1))
37
+ # mask = mask[None].transpose(0, 1, 2, 3) # what does this step do?
38
+ # mask = 1 - mask # repaint white, keep black
39
+ # mask = torch.from_numpy(mask)
40
+ # return mask
41
+
42
class DummyFeatureExtractorOutput:
    """Bare-bones stand-in for the object a feature extractor returns.

    It only stores ``pixel_values`` and answers the ``.to(device)`` call
    that is chained onto the extractor output.
    """

    def __init__(self, pixel_values):
        self.pixel_values = pixel_values

    def to(self, device):
        """Return this same object; ``device`` is accepted but not used."""
        return self
48
+
49
+
50
class DummyFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
    """Feature extractor that ignores its input.

    Used together with `DummySafetyChecker` when the NSFW checker is
    disabled: whatever is passed in, the result wraps an empty tensor.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def __call__(self, *args, **kwargs):
        """Return a `DummyFeatureExtractorOutput` holding an empty (0, 3) tensor."""
        placeholder = torch.empty(0, 3)
        return DummyFeatureExtractorOutput(placeholder)
56
+
57
+
58
class DummySafetyChecker:
    """Safety checker replacement that never flags anything."""

    def __init__(self, *args, **kwargs):
        # Accepts and discards any constructor arguments.
        pass

    def __call__(self, clip_input, images):
        """Hand ``images`` back untouched, paired with ``False``."""
        flagged = False
        return images, flagged
64
+
65
+
66
class SD(InpaintModel):
    """Stable Diffusion inpainting model driven by `StableDiffusionInpaintPipeline`.

    Subclasses supply `model_id_or_path`, which `init_model` passes to
    `from_pretrained`.
    """

    pad_mod = 64  # current diffusers only support 64 https://github.com/huggingface/diffusers/pull/505
    min_size = 512

    def init_model(self, device: torch.device, **kwargs):
        """Load the inpainting pipeline and move it to ``device``.

        Keyword arguments read here: ``sd_run_local``, ``sd_disable_nsfw``,
        ``hf_access_token``, ``sd_cpu_textencoder`` and, optionally,
        ``callbacks`` (stored on the instance and forwarded to the pipeline).
        """
        from .sd_pipeline import StableDiffusionInpaintPipeline

        model_kwargs = {"local_files_only": kwargs['sd_run_local']}
        if kwargs['sd_disable_nsfw']:
            logger.info("Disable Stable Diffusion Model NSFW checker")
            model_kwargs.update(dict(
                feature_extractor=DummyFeatureExtractor(),
                safety_checker=DummySafetyChecker(),
            ))

        # Pick half precision only when the model will really run on CUDA.
        # Checking `torch.cuda.is_available()` alone would load fp16 weights
        # for a CPU run on a machine that merely has a GPU installed.
        use_gpu = torch.cuda.is_available() and torch.device(device).type == "cuda"

        self.model = StableDiffusionInpaintPipeline.from_pretrained(
            self.model_id_or_path,
            revision="fp16" if use_gpu else "main",
            torch_dtype=torch.float16 if use_gpu else torch.float32,
            use_auth_token=kwargs["hf_access_token"],
            **model_kwargs
        )
        # https://huggingface.co/docs/diffusers/v0.3.0/en/api/pipelines/stable_diffusion#diffusers.StableDiffusionInpaintPipeline.enable_attention_slicing
        self.model.enable_attention_slicing()
        self.model = self.model.to(device)

        if kwargs['sd_cpu_textencoder']:
            logger.info("Run Stable Diffusion TextEncoder on CPU")
            self.model.text_encoder = self.model.text_encoder.to(torch.device('cpu'), non_blocking=True)
            self.model.text_encoder = self.model.text_encoder.to(torch.float32, non_blocking=True)

        self.callbacks = kwargs.pop("callbacks", None)

    @torch.cuda.amp.autocast()
    def forward(self, image, mask, config: Config):
        """Input image and output image have same size
        image: [H, W, C] RGB
        mask: [H, W, 1] 255 means area to repaint
        return: BGR IMAGE
        """
        # A fresh scheduler is built on every call from the sampler named in config.
        if config.sd_sampler == SDSampler.ddim:
            scheduler = DDIMScheduler(
                beta_start=0.00085,
                beta_end=0.012,
                beta_schedule="scaled_linear",
                clip_sample=False,
                set_alpha_to_one=False,
            )
        elif config.sd_sampler == SDSampler.pndm:
            PNDM_kwargs = {
                "tensor_format": "pt",
                "beta_schedule": "scaled_linear",
                "beta_start": 0.00085,
                "beta_end": 0.012,
                "num_train_timesteps": 1000,
                "skip_prk_steps": True,
            }
            scheduler = PNDMScheduler(**PNDM_kwargs)
        else:
            raise ValueError(config.sd_sampler)

        self.model.scheduler = scheduler

        # Seed every RNG in use so a given `sd_seed` reproduces the same result.
        seed = config.sd_seed
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

        if config.sd_mask_blur != 0:
            # GaussianBlur needs an odd kernel size; it also drops the
            # trailing channel axis, which is restored right after.
            k = 2 * config.sd_mask_blur + 1
            mask = cv2.GaussianBlur(mask, (k, k), 0)[:, :, np.newaxis]

        output = self.model(
            prompt=config.prompt,
            init_image=PIL.Image.fromarray(image),
            mask_image=PIL.Image.fromarray(mask[:, :, -1], mode="L"),
            strength=config.sd_strength,
            num_inference_steps=config.sd_steps,
            guidance_scale=config.sd_guidance_scale,
            output_type="np.array",
            callbacks=self.callbacks,
        ).images[0]

        # The pipeline returns floats in [0, 1]; convert to uint8 BGR.
        output = (output * 255).round().astype("uint8")
        output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
        return output

    @torch.no_grad()
    def __call__(self, image, mask, config: Config):
        """
        images: [H, W, C] RGB, not normalized
        masks: [H, W]
        return: BGR IMAGE
        """
        img_h, img_w = image.shape[:2]

        if config.use_croper:
            logger.info("use croper")
            l, t, w, h = (
                config.croper_x,
                config.croper_y,
                config.croper_width,
                config.croper_height,
            )
            r = l + w
            b = t + h

            # Clamp the crop rectangle to the image bounds.
            l = max(l, 0)
            r = min(r, img_w)
            t = max(t, 0)
            b = min(b, img_h)

            crop_img = image[t:b, l:r, :]
            crop_mask = mask[t:b, l:r]

            # `_pad_forward` is not defined in this class (presumably inherited
            # from `InpaintModel` -- confirm there).
            crop_image = self._pad_forward(crop_img, crop_mask, config)

            # Copy the channel-reversed image: `image[:, :, ::-1]` alone is a
            # view, so pasting the crop into it would overwrite the caller's
            # input and return a negative-stride array.
            inpaint_result = image[:, :, ::-1].copy()
            inpaint_result[t:b, l:r, :] = crop_image
        else:
            inpaint_result = self._pad_forward(image, mask, config)

        return inpaint_result

    @staticmethod
    def is_downloaded() -> bool:
        # model will be downloaded when app start, and can't switch in frontend settings
        return True
208
+
209
+
210
class SD14(SD):
    # Model identifier that `SD.init_model` passes to `from_pretrained`.
    model_id_or_path = "CompVis/stable-diffusion-v1-4"
212
+
213
+
214
class SD15(SD):
    # Model identifier that `SD.init_model` passes to `from_pretrained`.
    # NOTE(review): confirm this repository id actually resolves on the Hub.
    model_id_or_path = "CompVis/stable-diffusion-v1-5"
lama_cleaner/model/sd_pipeline.py ADDED
@@ -0,0 +1,310 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import inspect
2
+ from typing import List, Optional, Union, Callable
3
+
4
+ import numpy as np
5
+ import torch
6
+
7
+ import PIL
8
+ from diffusers import DiffusionPipeline, AutoencoderKL, UNet2DConditionModel, DDIMScheduler, PNDMScheduler
9
+ from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker, StableDiffusionPipelineOutput
10
+ from diffusers.utils import logging
11
+ from tqdm.auto import tqdm
12
+ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
13
+
14
+ logger = logging.get_logger(__name__)
15
+
16
+
17
def preprocess_image(image):
    """Turn an image object into a 4-D float tensor scaled to [-1, 1].

    ``image`` must expose ``.size`` as ``(w, h)`` and a ``.resize`` method
    (used here with ``PIL.Image.LANCZOS``). Both sides are rounded down to a
    multiple of 32 before conversion; the array is then given a leading batch
    axis and its last axis is moved to position 1.
    """
    width, height = image.size
    # resize to integer multiple of 32
    width -= width % 32
    height -= height % 32
    resized = image.resize((width, height), resample=PIL.Image.LANCZOS)

    pixels = np.array(resized).astype(np.float32) / 255.0
    batch = torch.from_numpy(pixels[None].transpose(0, 3, 1, 2))
    # [0, 1] -> [-1, 1]
    return 2.0 * batch - 1.0
25
+
26
+
27
def preprocess_mask(mask):
    """Convert a mask image into a latent-resolution keep-mask tensor.

    The mask is converted to mode ``"L"``, its sides are rounded down to a
    multiple of 32, and it is resized (nearest neighbour) to one eighth of
    that size. The result is scaled to [0, 1], tiled four times along a new
    leading axis, given a batch axis and inverted, so the returned tensor has
    shape ``(1, 4, h // 8, w // 8)`` with 1 where the input was black and 0
    where it was white.
    """
    mask = mask.convert("L")
    w, h = mask.size
    w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
    mask = mask.resize((w // 8, h // 8), resample=PIL.Image.NEAREST)
    mask = np.array(mask).astype(np.float32) / 255.0
    mask = np.tile(mask, (4, 1, 1))
    # Add the batch axis. The identity transpose(0, 1, 2, 3) that used to
    # follow left the axis order unchanged and has been dropped.
    mask = mask[None]
    mask = 1 - mask  # repaint white, keep black
    mask = torch.from_numpy(mask)
    return mask
38
+
39
+
40
class StableDiffusionInpaintPipeline(DiffusionPipeline):
    r"""
    Pipeline for text-guided image inpainting using Stable Diffusion. *This is an experimental feature*.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
        text_encoder ([`CLIPTextModel`]):
            Frozen text-encoder. Stable Diffusion uses the text portion of
            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
        tokenizer (`CLIPTokenizer`):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
        safety_checker ([`StableDiffusionSafetyChecker`]):
            Classification module that estimates whether generated images could be considered offensive or harmful.
            Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
        feature_extractor ([`CLIPFeatureExtractor`]):
            Model that extracts features from generated images to be used as inputs for the `safety_checker`.
    """

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: Union[DDIMScheduler, PNDMScheduler],
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPFeatureExtractor,
    ):
        super().__init__()
        scheduler = scheduler.set_format("pt")
        logger.info("`StableDiffusionInpaintPipeline` is experimental and will very likely change in the future.")
        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )

    def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
        r"""
        Enable sliced attention computation.

        When this option is enabled, the attention module will split the input tensor in slices, to compute attention
        in several steps. This is useful to save some memory in exchange for a small speed decrease.

        Args:
            slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
                When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
                a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
                `attention_head_dim` must be a multiple of `slice_size`.
        """
        if slice_size == "auto":
            # half the attention head size is usually a good trade-off between
            # speed and memory
            slice_size = self.unet.config.attention_head_dim // 2
        self.unet.set_attention_slice(slice_size)

    def disable_attention_slicing(self):
        r"""
        Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
        back to computing attention in one step.
        """
        # set slice_size = `None` to disable `set_attention_slice`
        # (the method is `enable_attention_slicing`; the earlier call to
        # `enable_attention_slice` named a method that does not exist).
        self.enable_attention_slicing(None)

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]],
        init_image: Union[torch.FloatTensor, PIL.Image.Image],
        mask_image: Union[torch.FloatTensor, PIL.Image.Image],
        strength: float = 0.8,
        num_inference_steps: Optional[int] = 50,
        guidance_scale: Optional[float] = 7.5,
        eta: Optional[float] = 0.0,
        generator: Optional[torch.Generator] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callbacks: Optional[List[Callable[[int], None]]] = None
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`):
                The prompt or prompts to guide the image generation.
            init_image (`torch.FloatTensor` or `PIL.Image.Image`):
                `Image`, or tensor representing an image batch, that will be used as the starting point for the
                process. This is the image whose masked region will be inpainted.
            mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
                `Image`, or tensor representing an image batch, to mask `init_image`. White pixels in the mask will be
                replaced by noise and therefore repainted, while black pixels will be preserved. The mask image will be
                converted to a single channel (luminance) before use.
            strength (`float`, *optional*, defaults to 0.8):
                Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength`
                is 1, the denoising process will be run on the masked area for the full number of iterations specified
                in `num_inference_steps`. `init_image` will be used as a reference for the masked area, adding more
                noise to that region the larger the `strength`. If `strength` is 0, no inpainting will occur.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The reference number of denoising steps. More denoising steps usually lead to a higher quality image at
                the expense of slower inference. This parameter will be modulated by `strength`, as explained above.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                [`schedulers.DDIMScheduler`], will be ignored for others.
            generator (`torch.Generator`, *optional*):
                A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
                deterministic.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `nd.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.
            callbacks (`List[Callable[[int], None]]`, *optional*):
                Callables invoked after every denoising step with the zero-based index of that step.

        Returns:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
            When returning a tuple, the first element is a list with the generated images, and the second element is a
            list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
            (nsfw) content, according to the `safety_checker`.
        """
        if isinstance(prompt, str):
            batch_size = 1
        elif isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if strength < 0 or strength > 1:
            raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")

        # set timesteps
        accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
        extra_set_kwargs = {}
        offset = 0
        if accepts_offset:
            offset = 1
            extra_set_kwargs["offset"] = 1

        self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)

        # preprocess image
        init_image = preprocess_image(init_image).to(self.device)

        # encode the init image into latents and scale the latents
        init_latent_dist = self.vae.encode(init_image.to(self.device)).latent_dist
        init_latents = init_latent_dist.sample(generator=generator)

        init_latents = 0.18215 * init_latents

        # Expand init_latents for batch_size
        init_latents = torch.cat([init_latents] * batch_size)
        init_latents_orig = init_latents

        # preprocess mask
        mask = preprocess_mask(mask_image).to(self.device)
        mask = torch.cat([mask] * batch_size)

        # check sizes
        if not mask.shape == init_latents.shape:
            raise ValueError("The mask and init_image should be the same size!")

        # get the original timestep using init_timestep
        init_timestep = int(num_inference_steps * strength) + offset
        init_timestep = min(init_timestep, num_inference_steps)
        timesteps = self.scheduler.timesteps[-init_timestep]
        timesteps = torch.tensor([timesteps] * batch_size, dtype=torch.long, device=self.device)

        # add noise to latents using the timesteps
        noise = torch.randn(init_latents.shape, generator=generator, device=self.device)
        init_latents = self.scheduler.add_noise(init_latents, noise, timesteps)

        # get prompt text embeddings
        text_input = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        # The text encoder may sit on a different device than the rest of the
        # pipeline, so feed it on its own device and move the result back.
        text_encoder_device = self.text_encoder.device

        text_embeddings = self.text_encoder(text_input.input_ids.to(text_encoder_device, non_blocking=True))[0].to(self.device, non_blocking=True)

        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0
        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance:
            max_length = text_input.input_ids.shape[-1]
            uncond_input = self.tokenizer(
                [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
            )
            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(text_encoder_device, non_blocking=True))[0].to(self.device, non_blocking=True)

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]
        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        latents = init_latents
        t_start = max(num_inference_steps - init_timestep + offset, 0)
        # tqdm wraps the sized timestep sequence (not the `enumerate` iterator)
        # so the progress bar knows how many steps there are.
        for i, t in enumerate(tqdm(self.scheduler.timesteps[t_start:])):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            # predict the noise residual
            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

            # masking: where mask is 1 keep the re-noised original latents,
            # elsewhere keep the freshly denoised latents
            init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, t)
            latents = (init_latents_proper * mask) + (latents * (1 - mask))

            if callbacks is not None:
                for callback in callbacks:
                    callback(i)

        # scale and decode the image latents with vae
        latents = 1 / 0.18215 * latents
        image = self.vae.decode(latents).sample

        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).numpy()

        # run safety checker
        safety_cheker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
        image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_cheker_input.pixel_values)

        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
lama_cleaner/model/utils.py ADDED
@@ -0,0 +1,709 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import Any
3
+
4
+ import torch
5
+ import numpy as np
6
+ import collections
7
+ from itertools import repeat
8
+
9
+ from torch import conv2d, conv_transpose2d
10
+
11
+
12
def make_beta_schedule(device, schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
    """Build a beta schedule and return it as a NumPy array.

    Args:
        device: Device on which the intermediate tensors of the ``"cosine"``
            schedule are placed. The other schedules do not use it.
        schedule: One of ``"linear"``, ``"cosine"``, ``"sqrt_linear"`` or
            ``"sqrt"``.
        n_timestep: Number of betas to produce.
        linear_start: First value for the linspace-based schedules.
        linear_end: Last value for the linspace-based schedules.
        cosine_s: Offset used by the ``"cosine"`` schedule.

    Returns:
        A float64 ``numpy.ndarray`` holding ``n_timestep`` betas.

    Raises:
        ValueError: If ``schedule`` is not one of the supported names.
    """
    if schedule == "linear":
        # Interpolate between the square roots, then square the result.
        betas = (
            torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
        )

    elif schedule == "cosine":
        timesteps = (torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s).to(device)
        alphas = timesteps / (1 + cosine_s) * np.pi / 2
        alphas = torch.cos(alphas).pow(2).to(device)
        alphas = alphas / alphas[0]
        betas = 1 - alphas[1:] / alphas[:-1]
        # Clamp with torch: `betas` is a tensor that may live on `device`,
        # so NumPy's clip is not the right tool here.
        betas = torch.clamp(betas, min=0, max=0.999)

    elif schedule == "sqrt_linear":
        betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
    elif schedule == "sqrt":
        betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
    else:
        raise ValueError(f"schedule '{schedule}' unknown.")
    # `.cpu()` is a no-op for CPU tensors and is required before `.numpy()`
    # when the cosine schedule was computed on another device.
    return betas.cpu().numpy()
33
+
34
+
35
def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
    """Pick the alphas for the DDIM steps and derive the sigma schedule.

    Returns ``(sigmas, alphas, alphas_prev)`` where ``alphas`` is
    ``alphacums`` indexed at ``ddim_timesteps`` and ``alphas_prev`` is the
    same selection shifted by one step, starting from ``alphacums[0]``.
    When ``verbose`` is true the selected values are printed.
    """
    # select alphas for computing the variance schedule
    alphas = alphacums[ddim_timesteps]
    earlier = alphacums[ddim_timesteps[:-1]].tolist()
    alphas_prev = np.asarray([alphacums[0], *earlier])

    # according to the formula provided in https://arxiv.org/abs/2010.02502
    ratio = (1 - alphas_prev) / (1 - alphas)
    step_term = 1 - alphas / alphas_prev
    sigmas = eta * np.sqrt(ratio * step_term)

    if verbose:
        print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
        print(f'For the chosen value of eta, which is {eta}, '
              f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
    return sigmas, alphas, alphas_prev
47
+
48
+
49
def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
    """Select which DDPM timesteps the DDIM sampler visits.

    ``'uniform'`` takes every ``num_ddpm_timesteps // num_ddim_timesteps``-th
    step; ``'quad'`` spaces the steps quadratically up to 80% of the DDPM
    range. The returned array is the selection shifted by +1. Any other
    method name raises ``NotImplementedError``.
    """
    if ddim_discr_method == 'uniform':
        stride = num_ddpm_timesteps // num_ddim_timesteps
        selected = np.asarray(list(range(0, num_ddpm_timesteps, stride)))
    elif ddim_discr_method == 'quad':
        upper = np.sqrt(num_ddpm_timesteps * .8)
        grid = np.linspace(0, upper, num_ddim_timesteps)
        selected = (grid ** 2).astype(int)
    else:
        raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')

    # add one to get the final alpha values right (the ones from first scale to data during sampling)
    steps_out = selected + 1
    if verbose:
        print(f'Selected timesteps for ddim sampler: {steps_out}')
    return steps_out
64
+
65
+
66
def noise_like(shape, device, repeat=False):
    """Draw standard-normal noise of the given shape on ``device``.

    With ``repeat=True`` a single sample of shape ``(1, *shape[1:])`` is
    drawn and tiled ``shape[0]`` times along the first axis, so every entry
    along that axis is identical.
    """
    if repeat:
        single = torch.randn((1, *shape[1:]), device=device)
        tiling = (1,) * (len(shape) - 1)
        return single.repeat(shape[0], *tiling)
    return torch.randn(shape, device=device)
70
+
71
+
72
def timestep_embedding(device, timesteps, dim, max_period=10000, repeat_only=False):
    """
    Build sinusoidal embeddings for a batch of timesteps.

    :param device: device on which the frequency table is placed.
    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :param repeat_only: accepted but not used by this implementation.
    :return: an [N x dim] Tensor of positional embeddings (cosines first,
             then sines, zero-padded by one column when dim is odd).
    """
    half = dim // 2
    exponents = -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
    freqs = torch.exp(exponents).to(device=device)

    angles = timesteps[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(angles), torch.sin(angles)], dim=-1)

    if dim % 2:
        # Odd dim: append one zero column to reach the requested width.
        padding = torch.zeros_like(embedding[:, :1])
        embedding = torch.cat([embedding, padding], dim=-1)
    return embedding
92
+
93
+
94
+ ###### MAT and FcF #######
95
+
96
+
97
def normalize_2nd_moment(x, dim=1, eps=1e-8):
    """Scale ``x`` by the reciprocal square root of its mean square along ``dim``.

    ``eps`` is added to the mean square before the reciprocal square root.
    """
    mean_square = x.square().mean(dim=dim, keepdim=True)
    return x * (mean_square + eps).rsqrt()
99
+
100
+
101
class EasyDict(dict):
    """A dict whose items can also be read, written and deleted as attributes."""

    def __getattr__(self, name: str) -> Any:
        # Missing keys surface as AttributeError so attribute lookups
        # fail the way attribute lookups are expected to.
        try:
            value = self[name]
        except KeyError:
            raise AttributeError(name)
        return value

    def __setattr__(self, name: str, value: Any) -> None:
        self[name] = value

    def __delattr__(self, name: str) -> None:
        del self[name]
115
+
116
+
117
def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None):
    """Slow reference implementation of `bias_act()` using standard PyTorch ops.

    Adds the optional bias ``b`` along ``dim``, applies the activation looked
    up in ``activation_funcs[act]``, scales by ``gain`` and, when ``clamp`` is
    non-negative, clamps the result to ``[-clamp, clamp]``. ``alpha`` and
    ``gain`` fall back to the defaults stored on the activation spec.
    """
    assert isinstance(x, torch.Tensor)
    assert clamp is None or clamp >= 0
    spec = activation_funcs[act]
    alpha = float(spec.def_alpha if alpha is None else alpha)
    gain = float(spec.def_gain if gain is None else gain)
    # A negative clamp value means "do not clamp".
    clamp = float(-1 if clamp is None else clamp)

    # Add bias.
    if b is not None:
        assert isinstance(b, torch.Tensor) and b.ndim == 1
        assert 0 <= dim < x.ndim
        assert b.shape[0] == x.shape[dim]
        # Reshape b so that it broadcasts along every axis except `dim`.
        broadcast_shape = [-1 if axis == dim else 1 for axis in range(x.ndim)]
        x = x + b.reshape(broadcast_shape)

    # Evaluate activation function.
    x = spec.func(x, alpha=alpha)

    # Scale by gain.
    if gain != 1:
        x = x * gain

    # Clamp.
    if clamp >= 0:
        x = x.clamp(-clamp, clamp)  # pylint: disable=invalid-unary-operand-type
    return x
147
+
148
+
149
def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='ref'):
    r"""Bias, activation and gain in a single call.

    Adds bias `b` to activation tensor `x`, evaluates activation function `act`,
    and scales the result by `gain`. Each of the steps is optional. This
    version always runs the reference implementation, whichever `impl` is given.

    Args:
        x:      Input activation tensor. Can be of any shape.
        b:      Bias vector, or `None` to disable. Must be a 1D tensor whose
                length matches the dimension of `x` corresponding to `dim`.
        dim:    The dimension in `x` corresponding to the elements of `b`.
                The value of `dim` is ignored if `b` is not specified.
        act:    Name of the activation function to evaluate, or `"linear"` to disable.
                See `activation_funcs` for a full list. `None` is not allowed.
        alpha:  Shape parameter for the activation function, or `None` to use the default.
        gain:   Scaling factor for the output tensor, or `None` to use default.
                See `activation_funcs` for the default scaling of each activation function.
                If unsure, consider specifying 1.
        clamp:  Clamp the output values to `[-clamp, +clamp]`, or `None` to disable
                the clamping (default).
        impl:   Name of the implementation. Must be `"ref"` (default) or `"cuda"`;
                both are served by the reference implementation.

    Returns:
        Tensor of the same shape and datatype as `x`.
    """
    assert isinstance(x, torch.Tensor)
    assert impl in ('ref', 'cuda')
    result = _bias_act_ref(x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp)
    return result
182
+
183
+
184
+ def _get_filter_size(f):
185
+ if f is None:
186
+ return 1, 1
187
+
188
+ assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
189
+ fw = f.shape[-1]
190
+ fh = f.shape[0]
191
+
192
+ fw = int(fw)
193
+ fh = int(fh)
194
+ assert fw >= 1 and fh >= 1
195
+ return fw, fh
196
+
197
+
198
+ def _get_weight_shape(w):
199
+ shape = [int(sz) for sz in w.shape]
200
+ return shape
201
+
202
+
203
+ def _parse_scaling(scaling):
204
+ if isinstance(scaling, int):
205
+ scaling = [scaling, scaling]
206
+ assert isinstance(scaling, (list, tuple))
207
+ assert all(isinstance(x, int) for x in scaling)
208
+ sx, sy = scaling
209
+ assert sx >= 1 and sy >= 1
210
+ return sx, sy
211
+
212
+
213
+ def _parse_padding(padding):
214
+ if isinstance(padding, int):
215
+ padding = [padding, padding]
216
+ assert isinstance(padding, (list, tuple))
217
+ assert all(isinstance(x, int) for x in padding)
218
+ if len(padding) == 2:
219
+ padx, pady = padding
220
+ padding = [padx, padx, pady, pady]
221
+ padx0, padx1, pady0, pady1 = padding
222
+ return padx0, padx1, pady0, pady1
223
+
224
+
225
def setup_filter(f, device=torch.device('cpu'), normalize=True, flip_filter=False, gain=1, separable=None):
    r"""Convenience function to setup 2D FIR filter for `upfirdn2d()`.

    The input is never modified: normalization is done out of place, so a
    float32 tensor or ndarray passed as `f` keeps its original values.

    Args:
        f:           Torch tensor, numpy array, or python list of the shape
                     `[filter_height, filter_width]` (non-separable),
                     `[filter_taps]` (separable),
                     `[]` (impulse), or
                     `None` (identity).
        device:      Result device (default: cpu).
        normalize:   Normalize the filter so that it retains the magnitude
                     for constant input signal (DC)? (default: True).
        flip_filter: Flip the filter? (default: False).
        gain:        Overall scaling factor for signal magnitude (default: 1).
        separable:   Return a separable filter? (default: select automatically;
                     a 1D filter is kept separable when it has 8 or more taps).

    Returns:
        Float32 tensor of the shape
        `[filter_height, filter_width]` (non-separable) or
        `[filter_taps]` (separable).
    """
    # Validate.
    if f is None:
        f = 1
    f = torch.as_tensor(f, dtype=torch.float32)
    assert f.ndim in [0, 1, 2]
    assert f.numel() > 0
    if f.ndim == 0:
        f = f[np.newaxis]

    # Separable?
    if separable is None:
        separable = (f.ndim == 1 and f.numel() >= 8)
    if f.ndim == 1 and not separable:
        # Expand the 1D taps into the equivalent 2D filter (outer product).
        f = f.ger(f)
    assert f.ndim == (1 if separable else 2)

    # Apply normalize, flip, gain, and device.
    if normalize:
        # Out-of-place on purpose: `torch.as_tensor` may return a tensor that
        # shares storage with the caller's input, and `f /= ...` would then
        # overwrite the caller's data.
        f = f / f.sum()
    if flip_filter:
        f = f.flip(list(range(f.ndim)))
    f = f * (gain ** (f.ndim / 2))
    f = f.to(device=device)
    return f
270
+
271
+
272
+ def _ntuple(n):
273
+ def parse(x):
274
+ if isinstance(x, collections.abc.Iterable):
275
+ return x
276
+ return tuple(repeat(x, n))
277
+
278
+ return parse
279
+
280
+
281
# Converter that turns a non-iterable value into a pair and leaves iterables untouched.
to_2tuple = _ntuple(2)

# Registry of activation functions, keyed by name.
# Each entry provides:
#   func      - callable evaluating the activation; extra keyword arguments are
#               swallowed via `**_`, and only 'lrelu' consumes `alpha`.
#   def_alpha - default shape parameter for the activation.
#   def_gain  - default output gain (read by `Conv2dLayer` as its activation gain).
#   cuda_idx, ref, has_2nd_grad - not read anywhere in the visible code;
#               presumably metadata kept from a fused CUDA implementation - TODO confirm.
activation_funcs = {
    'linear': EasyDict(func=lambda x, **_: x, def_alpha=0, def_gain=1, cuda_idx=1, ref='', has_2nd_grad=False),
    'relu': EasyDict(func=lambda x, **_: torch.nn.functional.relu(x), def_alpha=0, def_gain=np.sqrt(2), cuda_idx=2,
                     ref='y', has_2nd_grad=False),
    'lrelu': EasyDict(func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha), def_alpha=0.2,
                      def_gain=np.sqrt(2), cuda_idx=3, ref='y', has_2nd_grad=False),
    'tanh': EasyDict(func=lambda x, **_: torch.tanh(x), def_alpha=0, def_gain=1, cuda_idx=4, ref='y',
                     has_2nd_grad=True),
    'sigmoid': EasyDict(func=lambda x, **_: torch.sigmoid(x), def_alpha=0, def_gain=1, cuda_idx=5, ref='y',
                        has_2nd_grad=True),
    'elu': EasyDict(func=lambda x, **_: torch.nn.functional.elu(x), def_alpha=0, def_gain=1, cuda_idx=6, ref='y',
                    has_2nd_grad=True),
    'selu': EasyDict(func=lambda x, **_: torch.nn.functional.selu(x), def_alpha=0, def_gain=1, cuda_idx=7, ref='y',
                     has_2nd_grad=True),
    'softplus': EasyDict(func=lambda x, **_: torch.nn.functional.softplus(x), def_alpha=0, def_gain=1, cuda_idx=8,
                         ref='y', has_2nd_grad=True),
    'swish': EasyDict(func=lambda x, **_: torch.sigmoid(x) * x, def_alpha=0, def_gain=np.sqrt(2), cuda_idx=9, ref='x',
                      has_2nd_grad=True),
}
302
+
303
+
304
def upfirdn2d(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1, impl='cuda'):
    r"""Pad, upsample, filter, and downsample a batch of 2D images.

    Performs the following sequence of operations for each channel:

    1. Upsample the image by inserting N-1 zeros after each pixel (`up`).

    2. Pad the image with the specified number of zeros on each side (`padding`).
       Negative padding corresponds to cropping the image.

    3. Convolve the image with the specified 2D FIR filter (`f`), shrinking it
       so that the footprint of all output pixels lies within the input image.

    4. Downsample the image by keeping every Nth pixel (`down`).

    This function is a thin pass-through: every argument except `impl` is
    forwarded unchanged to `_upfirdn2d_ref`, which performs the work with
    standard PyTorch ops and defines which argument forms are accepted.

    Args:
        x:           Input tensor of the shape
                     `[batch_size, num_channels, in_height, in_width]`.
        f:           Float32 FIR filter of the shape
                     `[filter_height, filter_width]` (non-separable),
                     `[filter_taps]` (separable), or
                     `None` (identity).
        up:          Integer upsampling factor (default: 1).
        down:        Integer downsampling factor (default: 1).
        padding:     Padding with respect to the upsampled image (default: 0).
        flip_filter: False = convolution, True = correlation (default: False).
        gain:        Overall scaling factor for signal magnitude (default: 1).
        impl:        Accepted for interface compatibility (default: `'cuda'`)
                     but not consulted; the reference implementation is
                     always used.

    Returns:
        Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
    """
    # `impl` is deliberately left out: there is only one implementation to call.
    ref_kwargs = dict(up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain)
    return _upfirdn2d_ref(x, f, **ref_kwargs)
347
+
348
+
349
+ def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1):
350
+ """Slow reference implementation of `upfirdn2d()` using standard PyTorch ops.
351
+ """
352
+ # Validate arguments.
353
+ assert isinstance(x, torch.Tensor) and x.ndim == 4
354
+ if f is None:
355
+ f = torch.ones([1, 1], dtype=torch.float32, device=x.device)
356
+ assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
357
+ assert f.dtype == torch.float32 and not f.requires_grad
358
+ batch_size, num_channels, in_height, in_width = x.shape
359
+ # upx, upy = _parse_scaling(up)
360
+ # downx, downy = _parse_scaling(down)
361
+
362
+ upx, upy = up, up
363
+ downx, downy = down, down
364
+
365
+ # padx0, padx1, pady0, pady1 = _parse_padding(padding)
366
+ padx0, padx1, pady0, pady1 = padding[0], padding[1], padding[2], padding[3]
367
+
368
+ # Upsample by inserting zeros.
369
+ x = x.reshape([batch_size, num_channels, in_height, 1, in_width, 1])
370
+ x = torch.nn.functional.pad(x, [0, upx - 1, 0, 0, 0, upy - 1])
371
+ x = x.reshape([batch_size, num_channels, in_height * upy, in_width * upx])
372
+
373
+ # Pad or crop.
374
+ x = torch.nn.functional.pad(x, [max(padx0, 0), max(padx1, 0), max(pady0, 0), max(pady1, 0)])
375
+ x = x[:, :, max(-pady0, 0): x.shape[2] - max(-pady1, 0), max(-padx0, 0): x.shape[3] - max(-padx1, 0)]
376
+
377
+ # Setup filter.
378
+ f = f * (gain ** (f.ndim / 2))
379
+ f = f.to(x.dtype)
380
+ if not flip_filter:
381
+ f = f.flip(list(range(f.ndim)))
382
+
383
+ # Convolve with the filter.
384
+ f = f[np.newaxis, np.newaxis].repeat([num_channels, 1] + [1] * f.ndim)
385
+ if f.ndim == 4:
386
+ x = conv2d(input=x, weight=f, groups=num_channels)
387
+ else:
388
+ x = conv2d(input=x, weight=f.unsqueeze(2), groups=num_channels)
389
+ x = conv2d(input=x, weight=f.unsqueeze(3), groups=num_channels)
390
+
391
+ # Downsample by throwing away pixels.
392
+ x = x[:, :, ::downy, ::downx]
393
+ return x
394
+
395
+
396
def downsample2d(x, f, down=2, padding=0, flip_filter=False, gain=1, impl='cuda'):
    r"""Downsample a batch of 2D images using the given 2D FIR filter.

    By default, the result is padded so that its shape is a fraction of the input.
    User-specified padding is applied on top of that, with negative values
    indicating cropping. Pixels outside the image are assumed to be zero.

    Args:
        x:           Input tensor of the shape
                     `[batch_size, num_channels, in_height, in_width]`.
        f:           Float32 FIR filter of the shape
                     `[filter_height, filter_width]` (non-separable),
                     `[filter_taps]` (separable), or
                     `None` (identity).
        down:        Integer downsampling factor. Validated with
                     `_parse_scaling` and then forwarded unchanged to
                     `upfirdn2d` (default: 2).
        padding:     Padding with respect to the input. Can be a single number or a
                     list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
                     (default: 0).
        flip_filter: False = convolution, True = correlation (default: False).
        gain:        Overall scaling factor for signal magnitude (default: 1).
        impl:        Forwarded to `upfirdn2d` (default: `'cuda'`).

    Returns:
        Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
    """
    downx, downy = _parse_scaling(down)
    if isinstance(padding, (list, tuple)):
        # Sequence forms need expanding; adding an int to the raw list would raise.
        padx0, padx1, pady0, pady1 = _parse_padding(padding)
    else:
        # Scalar padding is applied unchanged to all four sides.
        padx0 = padx1 = pady0 = pady1 = padding

    fw, fh = _get_filter_size(f)
    # Extra padding that compensates for the filter footprint and the stride.
    p = [
        padx0 + (fw - downx + 1) // 2,
        padx1 + (fw - downx) // 2,
        pady0 + (fh - downy + 1) // 2,
        pady1 + (fh - downy) // 2,
    ]
    return upfirdn2d(x, f, down=down, padding=p, flip_filter=flip_filter, gain=gain, impl=impl)
434
+
435
+
436
def upsample2d(x, f, up=2, padding=0, flip_filter=False, gain=1, impl='cuda'):
    r"""Upsample a batch of 2D images using the given 2D FIR filter.

    The padding handed to `upfirdn2d` is the user-specified `padding` plus a
    filter- and factor-dependent amount, and the gain is multiplied by the
    product of the two upsampling factors. Negative padding values indicate
    cropping.

    Args:
        x:           Input tensor of the shape
                     `[batch_size, num_channels, in_height, in_width]`.
        f:           Float32 FIR filter of the shape
                     `[filter_height, filter_width]` (non-separable),
                     `[filter_taps]` (separable), or
                     `None` (identity).
        up:          Integer upsampling factor. Validated with
                     `_parse_scaling` and then forwarded unchanged to
                     `upfirdn2d` (default: 2).
        padding:     Padding with respect to the output. Can be a single int or a
                     list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
                     (default: 0).
        flip_filter: False = convolution, True = correlation (default: False).
        gain:        Overall scaling factor for signal magnitude (default: 1).
        impl:        Forwarded to `upfirdn2d` (default: `'cuda'`).

    Returns:
        Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
    """
    factor_x, factor_y = _parse_scaling(up)
    x_before, x_after, y_before, y_after = _parse_padding(padding)
    filter_w, filter_h = _get_filter_size(f)

    # Extra padding that compensates for the filter footprint after zero insertion.
    x_before += (filter_w + factor_x - 1) // 2
    x_after += (filter_w - factor_x) // 2
    y_before += (filter_h + factor_y - 1) // 2
    y_after += (filter_h - factor_y) // 2

    return upfirdn2d(
        x,
        f,
        up=up,
        padding=[x_before, x_after, y_before, y_after],
        flip_filter=flip_filter,
        gain=gain * factor_x * factor_y,
        impl=impl,
    )
474
+
475
+
476
class MinibatchStdLayer(torch.nn.Module):
    """Append per-group standard-deviation statistics to the input as extra channels.

    Args:
        group_size:   Upper bound on the number of samples per group, or
                      `None` to use the whole batch as one group.
        num_channels: Number of statistic channels appended to the input.
    """

    def __init__(self, group_size, num_channels=1):
        super().__init__()
        self.group_size = group_size
        self.num_channels = num_channels

    def forward(self, x):
        """Return `x` with `num_channels` stddev channels concatenated on dim 1."""
        batch, channels, height, width = x.shape
        if self.group_size is None:
            group = batch
        else:
            # Group size is capped by the batch size; computed with tensor ops.
            group = torch.min(torch.as_tensor(self.group_size), torch.as_tensor(batch))
        n_stats = self.num_channels
        per_stat = channels // n_stats

        # [GnFcHW] Split the batch into groups of `group` samples and the
        # channels into `n_stats` groups of `per_stat` channels.
        grouped = x.reshape(group, -1, n_stats, per_stat, height, width)
        centered = grouped - grouped.mean(dim=0)  # [GnFcHW] Remove the group mean.
        variance = centered.square().mean(dim=0)  # [nFcHW] Variance over the group.
        stddev = (variance + 1e-8).sqrt()  # [nFcHW] Stddev over the group.
        stats = stddev.mean(dim=[2, 3, 4])  # [nF] Average over channels and pixels.
        stats = stats.reshape(-1, n_stats, 1, 1)  # [nF11] Add missing dimensions.
        stats = stats.repeat(group, 1, height, width)  # [NFHW] Replicate over group and pixels.
        return torch.cat([x, stats], dim=1)  # [NCHW] Append as new channels.
499
+
500
+
501
class FullyConnectedLayer(torch.nn.Module):
    """Linear layer whose weight and bias are rescaled at call time.

    The weight is stored divided by `lr_multiplier` and multiplied by
    `lr_multiplier / sqrt(in_features)` in `forward`; the bias is multiplied
    by `lr_multiplier` in `forward`.
    """

    def __init__(self,
                 in_features,  # Number of input features.
                 out_features,  # Number of output features.
                 bias=True,  # Apply additive bias before the activation function?
                 activation='linear',  # Activation function: 'relu', 'lrelu', etc.
                 lr_multiplier=1,  # Learning rate multiplier.
                 bias_init=0,  # Initial value for the additive bias.
                 ):
        super().__init__()
        initial_weight = torch.randn([out_features, in_features]) / lr_multiplier
        self.weight = torch.nn.Parameter(initial_weight)
        if bias:
            self.bias = torch.nn.Parameter(torch.full([out_features], np.float32(bias_init)))
        else:
            self.bias = None
        self.activation = activation

        self.weight_gain = lr_multiplier / np.sqrt(in_features)
        self.bias_gain = lr_multiplier

    def forward(self, x):
        """Apply the scaled linear map, then the bias and activation."""
        scaled_weight = self.weight * self.weight_gain
        scaled_bias = self.bias
        if scaled_bias is not None and self.bias_gain != 1:
            scaled_bias = scaled_bias * self.bias_gain

        x = x.matmul(scaled_weight.t())
        if self.activation == 'linear' and scaled_bias is not None:
            # Plain affine case: add the bias along the last dimension directly.
            broadcast_shape = [1] * (x.ndim - 1) + [-1]
            return x + scaled_bias.reshape(broadcast_shape)
        return bias_act(x, scaled_bias, act=self.activation, dim=x.ndim - 1)
532
+
533
+
534
def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True):
    """Dispatch to `conv2d()` or `conv_transpose2d()` with optional kernel flipping.

    Args:
        x:           Input tensor.
        w:           4D weight tensor.
        stride:      Stride passed to the underlying op.
        padding:     Padding passed to the underlying op.
        groups:      Number of convolution groups.
        transpose:   Use `conv_transpose2d()` instead of `conv2d()`?
        flip_weight: True = correlation (what `conv2d()` computes natively),
                     False = convolution (kernel mirrored along both spatial axes).

    Returns:
        The convolved tensor. On the special 1x1 path described in the body
        the result is converted to `channels_last` memory format.
    """
    out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)

    # conv2d() actually performs correlation, so a convolution needs a flipped kernel.
    if not flip_weight:
        w = w.flip([2, 3])

    # Workaround performance pitfall in cuDNN 8.0.5, triggered when using
    # 1x1 kernel + memory_format=channels_last + less than 64 channels.
    plain_1x1 = kw == 1 and kh == 1 and stride == 1 and padding in [0, [0, 0], (0, 0)] and not transpose
    if plain_1x1 and x.stride()[1] == 1 and min(out_channels, in_channels_per_group) < 64:
        if out_channels <= 4 and groups == 1:
            # Very few output channels: express the 1x1 conv as a matrix product.
            in_shape = x.shape
            flat_weight = w.squeeze(3).squeeze(2)
            x = flat_weight @ x.reshape([in_shape[0], in_channels_per_group, -1])
            x = x.reshape([in_shape[0], out_channels, in_shape[2], in_shape[3]])
        else:
            # Otherwise run the convolution on contiguous copies.
            x = x.to(memory_format=torch.contiguous_format)
            w = w.to(memory_format=torch.contiguous_format)
            x = conv2d(x, w, groups=groups)
        return x.to(memory_format=torch.channels_last)

    # Otherwise => execute the regular (transposed) convolution.
    if transpose:
        return conv_transpose2d(x, w, stride=stride, padding=padding, groups=groups)
    return conv2d(x, w, stride=stride, padding=padding, groups=groups)
560
+
561
+
562
def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False):
    r"""2D convolution with optional up/downsampling.

    Padding is performed only once at the beginning, not between the operations.

    Args:
        x:              Input tensor of shape
                        `[batch_size, in_channels, in_height, in_width]`.
        w:              Weight tensor of shape
                        `[out_channels, in_channels//groups, kernel_height, kernel_width]`.
                        Must have the same dtype as `x`.
        f:              Low-pass filter for up/downsampling. Must be prepared beforehand by
                        calling setup_filter(). None = identity (default).
        up:             Integer upsampling factor (default: 1).
        down:           Integer downsampling factor (default: 1).
        padding:        Padding with respect to the upsampled image. Can be a single number
                        or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
                        (default: 0).
        groups:         Split input channels into N groups (default: 1).
        flip_weight:    False = convolution, True = correlation (default: True).
        flip_filter:    False = convolution, True = correlation (default: False).

    Returns:
        Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
    """
    # Validate arguments.
    assert isinstance(x, torch.Tensor) and (x.ndim == 4)
    assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype)
    assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32)
    assert isinstance(up, int) and (up >= 1)
    assert isinstance(down, int) and (down >= 1)
    out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)
    fw, fh = _get_filter_size(f)
    if isinstance(padding, (list, tuple)):
        # Sequence forms are expanded; the arithmetic below needs four scalars.
        px0, px1, py0, py1 = _parse_padding(padding)
    else:
        # Scalar padding is applied unchanged to all four sides.
        px0, px1, py0, py1 = padding, padding, padding, padding

    # Adjust padding to account for up/downsampling.
    if up > 1:
        px0 += (fw + up - 1) // 2
        px1 += (fw - up) // 2
        py0 += (fh + up - 1) // 2
        py1 += (fh - up) // 2
    if down > 1:
        px0 += (fw - down + 1) // 2
        px1 += (fw - down) // 2
        py0 += (fh - down + 1) // 2
        py1 += (fh - down) // 2

    # Fast path: 1x1 convolution with downsampling only => downsample first, then convolve.
    if kw == 1 and kh == 1 and (down > 1 and up == 1):
        x = upfirdn2d(x=x, f=f, down=down, padding=[px0, px1, py0, py1], flip_filter=flip_filter)
        x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
        return x

    # Fast path: 1x1 convolution with upsampling only => convolve first, then upsample.
    if kw == 1 and kh == 1 and (up > 1 and down == 1):
        x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
        x = upfirdn2d(x=x, f=f, up=up, padding=[px0, px1, py0, py1], gain=up ** 2, flip_filter=flip_filter)
        return x

    # Fast path: downsampling only => use strided convolution.
    if down > 1 and up == 1:
        x = upfirdn2d(x=x, f=f, padding=[px0, px1, py0, py1], flip_filter=flip_filter)
        x = _conv2d_wrapper(x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight)
        return x

    # Fast path: upsampling with optional downsampling => use transpose strided convolution.
    if up > 1:
        # conv_transpose2d() expects the in/out channel axes swapped (per group).
        if groups == 1:
            w = w.transpose(0, 1)
        else:
            w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw)
            w = w.transpose(1, 2)
            w = w.reshape(groups * in_channels_per_group, out_channels // groups, kh, kw)
        px0 -= kw - 1
        px1 -= kw - up
        py0 -= kh - 1
        py1 -= kh - up
        # Part of the (negative) padding is passed to the transposed convolution
        # itself; the remainder is applied by upfirdn2d().
        pxt = max(min(-px0, -px1), 0)
        pyt = max(min(-py0, -py1), 0)
        x = _conv2d_wrapper(x=x, w=w, stride=up, padding=[pyt, pxt], groups=groups, transpose=True,
                            flip_weight=(not flip_weight))
        x = upfirdn2d(x=x, f=f, padding=[px0 + pxt, px1 + pxt, py0 + pyt, py1 + pyt], gain=up ** 2,
                      flip_filter=flip_filter)
        if down > 1:
            # Explicit zero padding: relying on upfirdn2d()'s scalar default
            # fails in a reference implementation that indexes `padding`.
            x = upfirdn2d(x=x, f=f, down=down, padding=[0, 0, 0, 0], flip_filter=flip_filter)
        return x

    # Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d.
    if up == 1 and down == 1:
        if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0:
            return _conv2d_wrapper(x=x, w=w, padding=[py0, px0], groups=groups, flip_weight=flip_weight)

    # Fallback: Generic reference implementation.
    x = upfirdn2d(x=x, f=(f if up > 1 else None), up=up, padding=[px0, px1, py0, py1], gain=up ** 2,
                  flip_filter=flip_filter)
    x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
    if down > 1:
        # Explicit zero padding for the same reason as in the upsampling path.
        x = upfirdn2d(x=x, f=f, down=down, padding=[0, 0, 0, 0], flip_filter=flip_filter)
    return x
662
+
663
+
664
class Conv2dLayer(torch.nn.Module):
    """Convolution layer with optional resampling, bias, activation and clamping.

    The stored weight is multiplied by `1 / sqrt(in_channels * kernel_size**2)`
    on every forward pass, and the activation's default gain from
    `activation_funcs` is applied to the output.
    """

    def __init__(self,
                 in_channels,  # Number of input channels.
                 out_channels,  # Number of output channels.
                 kernel_size,  # Width and height of the convolution kernel.
                 bias=True,  # Apply additive bias before the activation function?
                 activation='linear',  # Activation function: 'relu', 'lrelu', etc.
                 up=1,  # Integer upsampling factor.
                 down=1,  # Integer downsampling factor.
                 resample_filter=[1, 3, 3, 1],  # Low-pass filter to apply when resampling activations.
                 conv_clamp=None,  # Clamp the output to +-X, None = disable clamping.
                 channels_last=False,  # Expect the input to have memory_format=channels_last?
                 trainable=True,  # Update the weights of this layer during training?
                 ):
        super().__init__()
        self.activation = activation
        self.up = up
        self.down = down
        self.register_buffer('resample_filter', setup_filter(resample_filter))
        self.conv_clamp = conv_clamp
        self.padding = kernel_size // 2
        self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size ** 2))
        self.act_gain = activation_funcs[activation].def_gain

        if channels_last:
            memory_format = torch.channels_last
        else:
            memory_format = torch.contiguous_format
        weight = torch.randn([out_channels, in_channels, kernel_size, kernel_size])
        weight = weight.to(memory_format=memory_format)
        bias = torch.zeros([out_channels]) if bias else None

        if trainable:
            # Trainable layers expose weight/bias as parameters.
            self.weight = torch.nn.Parameter(weight)
            self.bias = None if bias is None else torch.nn.Parameter(bias)
        else:
            # Frozen layers keep them as buffers instead.
            self.register_buffer('weight', weight)
            if bias is None:
                self.bias = None
            else:
                self.register_buffer('bias', bias)

    def forward(self, x, gain=1):
        """Convolve `x` (with resampling), then apply bias, activation, gain and clamp.

        `gain` scales both the activation gain and the clamp threshold.
        """
        scaled_weight = self.weight * self.weight_gain
        x = conv2d_resample(
            x=x,
            w=scaled_weight,
            f=self.resample_filter,
            up=self.up,
            down=self.down,
            padding=self.padding,
        )

        clamp = None if self.conv_clamp is None else self.conv_clamp * gain
        return bias_act(x, self.bias, act=self.activation, gain=self.act_gain * gain, clamp=clamp)
lama_cleaner/model/zits.py ADDED
@@ -0,0 +1,427 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+
4
+ import cv2
5
+ import skimage
6
+ import torch
7
+ import torch.nn.functional as F
8
+
9
+ from lama_cleaner.helper import get_cache_path_by_url, load_jit_model
10
+ from lama_cleaner.schema import Config
11
+ import numpy as np
12
+
13
+ from lama_cleaner.model.base import InpaintModel
14
+
15
# Download locations of the four ZITS model files. Each URL can be overridden
# through an environment variable of the same name; the second argument is the
# default used when that variable is not set.

# Model loaded as `ZITS.inpaint`.
ZITS_INPAINT_MODEL_URL = os.environ.get(
    "ZITS_INPAINT_MODEL_URL",
    "https://github.com/Sanster/models/releases/download/add_zits/zits-inpaint-0717.pt",
)

# Model loaded as `ZITS.edge_line`.
ZITS_EDGE_LINE_MODEL_URL = os.environ.get(
    "ZITS_EDGE_LINE_MODEL_URL",
    "https://github.com/Sanster/models/releases/download/add_zits/zits-edge-line-0717.pt",
)

# Model loaded as `ZITS.structure_upsample`.
ZITS_STRUCTURE_UPSAMPLE_MODEL_URL = os.environ.get(
    "ZITS_STRUCTURE_UPSAMPLE_MODEL_URL",
    "https://github.com/Sanster/models/releases/download/add_zits/zits-structure-upsample-0717.pt",
)

# Model loaded as `ZITS.wireframe`.
ZITS_WIRE_FRAME_MODEL_URL = os.environ.get(
    "ZITS_WIRE_FRAME_MODEL_URL",
    "https://github.com/Sanster/models/releases/download/add_zits/zits-wireframe-0717.pt",
)
34
+
35
+
36
def resize(img, height, width, center_crop=False):
    """Resize `img` to `height` x `width`, optionally center-cropping to a square first.

    Args:
        img: Image array whose first two dimensions are (rows, columns).
        height: Target number of rows.
        width: Target number of columns.
        center_crop: When True and the image is not square, crop the largest
            centered square before resizing.

    Returns:
        The resized image as returned by `cv2.resize`.
    """
    imgh, imgw = img.shape[0:2]

    if center_crop and imgh != imgw:
        # center crop
        side = np.minimum(imgh, imgw)
        j = (imgh - side) // 2
        i = (imgw - side) // 2
        img = img[j : j + side, i : i + side, ...]

    # Area interpolation when shrinking along both axes, bilinear otherwise.
    if imgh > height and imgw > width:
        inter = cv2.INTER_AREA
    else:
        inter = cv2.INTER_LINEAR
    # cv2.resize takes dsize as (width, height), i.e. columns first.
    img = cv2.resize(img, (width, height), interpolation=inter)

    return img
53
+
54
+
55
def to_tensor(img, scale=True, norm=False):
    """Convert an HWC (or HW) numpy image into a CHW float tensor.

    Args:
        img: Numpy array of shape `[H, W, C]` or `[H, W]`; a 2D input is
            treated as a single-channel image.
        scale: Divide the values by 255 when True.
        norm: Apply `(value - 0.5) / 0.5` to every channel when True. Works
            for any number of channels.

    Returns:
        Float tensor of shape `[C, H, W]`.
    """
    if img.ndim == 2:
        img = img[:, :, np.newaxis]

    img_t = torch.from_numpy(img).permute(2, 0, 1).float()
    if scale:
        img_t = img_t.div(255)

    if norm:
        # Mean and std are 0.5 for every channel, so scalar broadcasting is
        # enough and does not tie the function to 3-channel inputs.
        img_t = (img_t - 0.5) / 0.5
    return img_t
70
+
71
+
72
def load_masked_position_encoding(mask):
    """Compute position and direction encodings for the masked region.

    The mask is resized to 256x256 and the unmasked area is grown one 3x3
    dilation step at a time; every masked pixel records the step at which it
    was reached (`pos`) and, per step, which of four corner-shaped filters
    reached it (`direct`).

    Args:
        mask: 2D array in which values greater than 0 (after resizing) mark
            masked pixels. NOTE(review): callers in this file pass a uint8
            0/255 mask - confirm for other callers.

    Returns:
        A tuple `(rel_pos, abs_pos, direct)`:
            rel_pos: int32 array with the step index scaled by
                `pos_num / (str_size / 2)` and clipped to `[0, pos_num - 1]`,
                resized to the original mask size when that differs from
                256x256.
            abs_pos: int32 `[256, 256]` array of raw step indices.
            direct: int32 array with 4 channels of 0/1 direction flags,
                resized like `rel_pos`.
    """
    ones_filter = np.ones((3, 3), dtype=np.float32)
    # Each filter covers one 2x2 corner of the 3x3 neighbourhood.
    d_filters = (
        np.array([[1, 1, 0], [1, 1, 0], [0, 0, 0]], dtype=np.float32),
        np.array([[0, 0, 0], [1, 1, 0], [1, 1, 0]], dtype=np.float32),
        np.array([[0, 1, 1], [0, 1, 1], [0, 0, 0]], dtype=np.float32),
        np.array([[0, 0, 0], [0, 1, 1], [0, 1, 1]], dtype=np.float32),
    )
    str_size = 256
    pos_num = 128

    ori_mask = mask.copy()
    ori_h, ori_w = ori_mask.shape[0:2]
    ori_mask = ori_mask / 255
    mask = cv2.resize(mask, (str_size, str_size), interpolation=cv2.INTER_AREA)
    mask[mask > 0] = 255
    h, w = mask.shape[0:2]
    # mask3 is 1 on known pixels and 0 on pixels that still have to be reached.
    mask3 = 1.0 - (mask / 255.0)
    pos = np.zeros((h, w), dtype=np.int32)
    direct = np.zeros((h, w, 4), dtype=np.int32)
    i = 0
    while np.sum(1 - mask3) > 0:
        i += 1
        mask3_ = cv2.filter2D(mask3, -1, ones_filter)
        mask3_[mask3_ > 0] = 1
        sub_mask = mask3_ - mask3
        newly_reached = sub_mask == 1
        if not newly_reached.any():
            # Nothing grew in this step (the mask covers the whole image), so
            # no later step can make progress either; stop instead of
            # looping forever.
            break
        pos[newly_reached] = i

        for channel, d_filter in enumerate(d_filters):
            m = cv2.filter2D(mask3, -1, d_filter)
            m[m > 0] = 1
            m = m - mask3
            direct[m == 1, channel] = 1

        mask3 = mask3_

    abs_pos = pos.copy()
    rel_pos = pos / (str_size / 2)  # to 0~1 maybe larger than 1
    rel_pos = (rel_pos * pos_num).astype(np.int32)
    rel_pos = np.clip(rel_pos, 0, pos_num - 1)

    if ori_w != w or ori_h != h:
        # Bring the encodings back to the input resolution and clear
        # everything outside the original mask.
        rel_pos = cv2.resize(rel_pos, (ori_w, ori_h), interpolation=cv2.INTER_NEAREST)
        rel_pos[ori_mask == 0] = 0
        direct = cv2.resize(direct, (ori_w, ori_h), interpolation=cv2.INTER_NEAREST)
        direct[ori_mask == 0, :] = 0

    return rel_pos, abs_pos, direct
133
+
134
+
135
def load_image(img, mask, device, sigma256=3.0):
    """Build the batch dictionary consumed by the ZITS models.

    Args:
        img: [H, W, C] RGB image.
        mask: [H, W] mask; values above 127 are treated as the masked
            region and binarized to 255.
        device: Device every tensor in the batch is moved to.
        sigma256: Gaussian sigma passed to the Canny edge detector that runs
            on the 256x256 grayscale image.

    Returns:
        Dict with the tensors "images", "img_256", "masks", "mask_256",
        "mask_512", "edge_256", "img_512", "rel_pos", "abs_pos" and "direct"
        (each with a leading batch dimension of 1), plus the original image
        height "h" and width "w".
    """
    imgh, imgw, _ = img.shape

    def batched(tensor):
        # Add the batch dimension and move to the target device.
        return tensor.unsqueeze(0).to(device)

    img_256 = resize(img, 256, 256)

    mask = (mask > 127).astype(np.uint8) * 255
    mask_256 = cv2.resize(mask, (256, 256), interpolation=cv2.INTER_AREA)
    mask_256[mask_256 > 0] = 255
    mask_512 = cv2.resize(mask, (512, 512), interpolation=cv2.INTER_AREA)
    mask_512[mask_512 > 0] = 255

    # Edges come from skimage's Canny detector with its default hysteresis
    # thresholds:
    # https://scikit-image.org/docs/stable/api/skimage.feature.html#skimage.feature.canny
    gray_256 = skimage.color.rgb2gray(img_256)
    edge_256 = skimage.feature.canny(gray_256, sigma=sigma256, mask=None).astype(float)

    # The 512x512 image is the input of the wireframe (line) model.
    img_512 = resize(img, 512, 512)

    rel_pos, abs_pos, direct = load_masked_position_encoding(mask)

    return {
        "images": batched(to_tensor(img.copy())),
        "img_256": batched(to_tensor(img_256, norm=True)),
        "masks": batched(to_tensor(mask)),
        "mask_256": batched(to_tensor(mask_256)),
        "mask_512": batched(to_tensor(mask_512)),
        "edge_256": batched(to_tensor(edge_256, scale=False)),
        "img_512": batched(to_tensor(img_512)),
        "rel_pos": batched(torch.LongTensor(rel_pos)),
        "abs_pos": batched(torch.LongTensor(abs_pos)),
        "direct": batched(torch.LongTensor(direct)),
        "h": imgh,
        "w": imgw,
    }
190
+
191
+
192
def to_device(data, device):
    """Move tensors contained in `data` to `device`.

    - A tensor is returned moved to `device`.
    - A dict is updated in place (only its tensor values are moved) and the
      same dict is returned.
    - A list yields a new list with every element processed recursively.
    - Any other value yields `None`.
    """
    if isinstance(data, torch.Tensor):
        return data.to(device)

    if isinstance(data, dict):
        tensor_keys = [key for key in data if isinstance(data[key], torch.Tensor)]
        for key in tensor_keys:
            data[key] = data[key].to(device)
        return data

    if isinstance(data, list):
        return [to_device(item, device) for item in data]

    return None
202
+
203
+
204
+ class ZITS(InpaintModel):
205
+ min_size = 256
206
+ pad_mod = 32
207
+ pad_to_square = True
208
+
209
    def __init__(self, device, **kwargs):
        """Initialize the ZITS inpainting model wrapper.

        Args:
            device: Passed to the base class initializer and stored on the
                instance.
            **kwargs: Accepted but not used by this method.
        """
        super().__init__(device)
        self.device = device
        # Passed as `iterations` to `sample_edge_line_logits` in
        # `wireframe_edge_and_line`.
        self.sample_edge_line_iterations = 1
217
+ self.sample_edge_line_iterations = 1
218
+
219
    def init_model(self, device, **kwargs):
        """Load the four ZITS models onto `device`.

        Each model is loaded with `load_jit_model` from its module-level URL
        and stored as `wireframe`, `edge_line`, `structure_upsample` and
        `inpaint`. `**kwargs` is accepted but not used.
        """
        self.wireframe = load_jit_model(ZITS_WIRE_FRAME_MODEL_URL, device)
        self.edge_line = load_jit_model(ZITS_EDGE_LINE_MODEL_URL, device)
        self.structure_upsample = load_jit_model(
            ZITS_STRUCTURE_UPSAMPLE_MODEL_URL, device
        )
        self.inpaint = load_jit_model(ZITS_INPAINT_MODEL_URL, device)
226
+
227
+ @staticmethod
228
+ def is_downloaded() -> bool:
229
+ model_paths = [
230
+ get_cache_path_by_url(ZITS_WIRE_FRAME_MODEL_URL),
231
+ get_cache_path_by_url(ZITS_EDGE_LINE_MODEL_URL),
232
+ get_cache_path_by_url(ZITS_STRUCTURE_UPSAMPLE_MODEL_URL),
233
+ get_cache_path_by_url(ZITS_INPAINT_MODEL_URL),
234
+ ]
235
+ return all([os.path.exists(it) for it in model_paths])
236
+
237
+ def wireframe_edge_and_line(self, items, enable: bool):
238
+ # 最终向 items 中添加 edge 和 line key
239
+ if not enable:
240
+ items["edge"] = torch.zeros_like(items["masks"])
241
+ items["line"] = torch.zeros_like(items["masks"])
242
+ return
243
+
244
+ start = time.time()
245
+ try:
246
+ line_256 = self.wireframe_forward(
247
+ items["img_512"],
248
+ h=256,
249
+ w=256,
250
+ masks=items["mask_512"],
251
+ mask_th=0.85,
252
+ )
253
+ except:
254
+ line_256 = torch.zeros_like(items["mask_256"])
255
+
256
+ print(f"wireframe_forward time: {(time.time() - start) * 1000:.2f}ms")
257
+
258
+ # np_line = (line[0][0].numpy() * 255).astype(np.uint8)
259
+ # cv2.imwrite("line.jpg", np_line)
260
+
261
+ start = time.time()
262
+ edge_pred, line_pred = self.sample_edge_line_logits(
263
+ context=[items["img_256"], items["edge_256"], line_256],
264
+ mask=items["mask_256"].clone(),
265
+ iterations=self.sample_edge_line_iterations,
266
+ add_v=0.05,
267
+ mul_v=4,
268
+ )
269
+ print(f"sample_edge_line_logits time: {(time.time() - start) * 1000:.2f}ms")
270
+
271
+ # np_edge_pred = (edge_pred[0][0].numpy() * 255).astype(np.uint8)
272
+ # cv2.imwrite("edge_pred.jpg", np_edge_pred)
273
+ # np_line_pred = (line_pred[0][0].numpy() * 255).astype(np.uint8)
274
+ # cv2.imwrite("line_pred.jpg", np_line_pred)
275
+ # exit()
276
+
277
+ input_size = min(items["h"], items["w"])
278
+ if input_size != 256 and input_size > 256:
279
+ while edge_pred.shape[2] < input_size:
280
+ edge_pred = self.structure_upsample(edge_pred)
281
+ edge_pred = torch.sigmoid((edge_pred + 2) * 2)
282
+
283
+ line_pred = self.structure_upsample(line_pred)
284
+ line_pred = torch.sigmoid((line_pred + 2) * 2)
285
+
286
+ edge_pred = F.interpolate(
287
+ edge_pred,
288
+ size=(input_size, input_size),
289
+ mode="bilinear",
290
+ align_corners=False,
291
+ )
292
+ line_pred = F.interpolate(
293
+ line_pred,
294
+ size=(input_size, input_size),
295
+ mode="bilinear",
296
+ align_corners=False,
297
+ )
298
+
299
+ # np_edge_pred = (edge_pred[0][0].numpy() * 255).astype(np.uint8)
300
+ # cv2.imwrite("edge_pred_upsample.jpg", np_edge_pred)
301
+ # np_line_pred = (line_pred[0][0].numpy() * 255).astype(np.uint8)
302
+ # cv2.imwrite("line_pred_upsample.jpg", np_line_pred)
303
+ # exit()
304
+
305
+ items["edge"] = edge_pred.detach()
306
+ items["line"] = line_pred.detach()
307
+
308
+ @torch.no_grad()
309
+ def forward(self, image, mask, config: Config):
310
+ """Input images and output images have same size
311
+ images: [H, W, C] RGB
312
+ masks: [H, W]
313
+ return: BGR IMAGE
314
+ """
315
+ mask = mask[:, :, 0]
316
+ items = load_image(image, mask, device=self.device)
317
+
318
+ self.wireframe_edge_and_line(items, config.zits_wireframe)
319
+
320
+ inpainted_image = self.inpaint(
321
+ items["images"],
322
+ items["masks"],
323
+ items["edge"],
324
+ items["line"],
325
+ items["rel_pos"],
326
+ items["direct"],
327
+ )
328
+
329
+ inpainted_image = inpainted_image * 255.0
330
+ inpainted_image = (
331
+ inpainted_image.cpu().permute(0, 2, 3, 1)[0].numpy().astype(np.uint8)
332
+ )
333
+ inpainted_image = inpainted_image[:, :, ::-1]
334
+
335
+ # cv2.imwrite("inpainted.jpg", inpainted_image)
336
+ # exit()
337
+
338
+ return inpainted_image
339
+
340
+ def wireframe_forward(self, images, h, w, masks, mask_th=0.925):
341
+ lcnn_mean = torch.tensor([109.730, 103.832, 98.681]).reshape(1, 3, 1, 1)
342
+ lcnn_std = torch.tensor([22.275, 22.124, 23.229]).reshape(1, 3, 1, 1)
343
+ images = images * 255.0
344
+ # the masks value of lcnn is 127.5
345
+ masked_images = images * (1 - masks) + torch.ones_like(images) * masks * 127.5
346
+ masked_images = (masked_images - lcnn_mean) / lcnn_std
347
+
348
+ def to_int(x):
349
+ return tuple(map(int, x))
350
+
351
+ lines_tensor = []
352
+ lmap = np.zeros((h, w))
353
+
354
+ output_masked = self.wireframe(masked_images)
355
+
356
+ output_masked = to_device(output_masked, "cpu")
357
+ if output_masked["num_proposals"] == 0:
358
+ lines_masked = []
359
+ scores_masked = []
360
+ else:
361
+ lines_masked = output_masked["lines_pred"].numpy()
362
+ lines_masked = [
363
+ [line[1] * h, line[0] * w, line[3] * h, line[2] * w]
364
+ for line in lines_masked
365
+ ]
366
+ scores_masked = output_masked["lines_score"].numpy()
367
+
368
+ for line, score in zip(lines_masked, scores_masked):
369
+ if score > mask_th:
370
+ rr, cc, value = skimage.draw.line_aa(
371
+ *to_int(line[0:2]), *to_int(line[2:4])
372
+ )
373
+ lmap[rr, cc] = np.maximum(lmap[rr, cc], value)
374
+
375
+ lmap = np.clip(lmap * 255, 0, 255).astype(np.uint8)
376
+ lines_tensor.append(to_tensor(lmap).unsqueeze(0))
377
+
378
+ lines_tensor = torch.cat(lines_tensor, dim=0)
379
+ return lines_tensor.detach().to(self.device)
380
+
381
+ def sample_edge_line_logits(
382
+ self, context, mask=None, iterations=1, add_v=0, mul_v=4
383
+ ):
384
+ [img, edge, line] = context
385
+
386
+ img = img * (1 - mask)
387
+ edge = edge * (1 - mask)
388
+ line = line * (1 - mask)
389
+
390
+ for i in range(iterations):
391
+ edge_logits, line_logits = self.edge_line(img, edge, line, masks=mask)
392
+
393
+ edge_pred = torch.sigmoid(edge_logits)
394
+ line_pred = torch.sigmoid((line_logits + add_v) * mul_v)
395
+ edge = edge + edge_pred * mask
396
+ edge[edge >= 0.25] = 1
397
+ edge[edge < 0.25] = 0
398
+ line = line + line_pred * mask
399
+
400
+ b, _, h, w = edge_pred.shape
401
+ edge_pred = edge_pred.reshape(b, -1, 1)
402
+ line_pred = line_pred.reshape(b, -1, 1)
403
+ mask = mask.reshape(b, -1)
404
+
405
+ edge_probs = torch.cat([1 - edge_pred, edge_pred], dim=-1)
406
+ line_probs = torch.cat([1 - line_pred, line_pred], dim=-1)
407
+ edge_probs[:, :, 1] += 0.5
408
+ line_probs[:, :, 1] += 0.5
409
+ edge_max_probs = edge_probs.max(dim=-1)[0] + (1 - mask) * (-100)
410
+ line_max_probs = line_probs.max(dim=-1)[0] + (1 - mask) * (-100)
411
+
412
+ indices = torch.sort(
413
+ edge_max_probs + line_max_probs, dim=-1, descending=True
414
+ )[1]
415
+
416
+ for ii in range(b):
417
+ keep = int((i + 1) / iterations * torch.sum(mask[ii, ...]))
418
+
419
+ assert torch.sum(mask[ii][indices[ii, :keep]]) == keep, "Error!!!"
420
+ mask[ii][indices[ii, :keep]] = 0
421
+
422
+ mask = mask.reshape(b, 1, h, w)
423
+ edge = edge * (1 - mask)
424
+ line = line * (1 - mask)
425
+
426
+ edge, line = edge.to(torch.float32), line.to(torch.float32)
427
+ return edge, line
lama_cleaner/model_manager.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from lama_cleaner.model.fcf import FcF
2
+ from lama_cleaner.model.lama import LaMa
3
+ from lama_cleaner.model.ldm import LDM
4
+ from lama_cleaner.model.mat import MAT
5
+ from lama_cleaner.model.sd import SD14
6
+ from lama_cleaner.model.zits import ZITS
7
+ from lama_cleaner.model.opencv2 import OpenCV2
8
+ from lama_cleaner.schema import Config
9
+
10
# Registry mapping the public model name to the class that implements it.
models = {
    "lama": LaMa,
    "ldm": LDM,
    "zits": ZITS,
    "mat": MAT,
    "fcf": FcF,
    "sd1.4": SD14,
    "cv2": OpenCV2,
}
11
+
12
+
13
class ModelManager:
    """Owns the currently active inpainting model and allows swapping it.

    Model classes are looked up by name in the module-level ``models``
    registry and instantiated with the device and keyword arguments given at
    construction time.
    """

    def __init__(self, name: str, device, **kwargs):
        """Instantiate the model registered under ``name``.

        Args:
            name: Key into the ``models`` registry.
            device: Device forwarded to the model constructor.
            **kwargs: Extra constructor arguments, remembered so that
                ``switch`` can reuse them.

        Raises:
            NotImplementedError: If ``name`` is not a registered model.
        """
        self.name = name
        self.device = device
        self.kwargs = kwargs
        self.model = self.init_model(name, device, **kwargs)

    def init_model(self, name: str, device, **kwargs):
        """Build and return a new instance of the model called ``name``.

        Raises:
            NotImplementedError: If ``name`` is not a registered model.
        """
        if name not in models:
            raise NotImplementedError(f"Not supported model: {name}")
        return models[name](device, **kwargs)

    def is_downloaded(self, name: str) -> bool:
        """Return the ``is_downloaded()`` result of the model class ``name``.

        Raises:
            NotImplementedError: If ``name`` is not a registered model.
        """
        if name not in models:
            raise NotImplementedError(f"Not supported model: {name}")
        return models[name].is_downloaded()

    def __call__(self, image, mask, config: Config):
        """Run the active model and return its result."""
        return self.model(image, mask, config)

    def switch(self, new_name: str):
        """Replace the active model with the one registered as ``new_name``.

        Does nothing when ``new_name`` is already active. If the new model
        cannot be created the exception propagates and the manager keeps its
        previous model and name.

        Raises:
            NotImplementedError: If ``new_name`` is not a registered model.
        """
        if new_name == self.name:
            return
        # Assign only after construction succeeded so a failure leaves the
        # manager in its previous, usable state.
        self.model = self.init_model(new_name, self.device, **self.kwargs)
        self.name = new_name
lama_cleaner/schema.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from enum import Enum
2
+
3
+ from pydantic import BaseModel
4
+
5
+
6
class HDStrategy(str, Enum):
    """String-valued enumeration of the three HD strategy names.

    Members compare equal to their plain string value because of the ``str``
    mix-in. Presumably these are the accepted values of
    ``Config.hd_strategy`` -- TODO confirm against the callers.
    """

    ORIGINAL = "Original"
    RESIZE = "Resize"
    CROP = "Crop"
10
+
11
+
12
class LDMSampler(str, Enum):
    """String-valued enumeration of the sampler names offered for LDM.

    ``plms`` is the default of ``Config.ldm_sampler``.
    """

    ddim = "ddim"
    plms = "plms"
15
+
16
+
17
class SDSampler(str, Enum):
    """String-valued enumeration of the sampler names offered for SD.

    ``ddim`` is the default of ``Config.sd_sampler``.
    """

    ddim = "ddim"
    pndm = "pndm"
20
+
21
+
22
class Config(BaseModel):
    """Per-request inpainting options, validated by pydantic.

    Fields without a default (``ldm_steps`` and the ``hd_strategy*`` group)
    must be supplied by the caller; everything else is optional.
    """

    # -- LDM ---------------------------------------------------------------
    ldm_steps: int
    ldm_sampler: str = LDMSampler.plms

    # -- ZITS --------------------------------------------------------------
    zits_wireframe: bool = True

    # -- HD strategy -------------------------------------------------------
    hd_strategy: str
    hd_strategy_crop_margin: int
    hd_strategy_crop_trigger_size: int
    hd_strategy_resize_limit: int

    prompt: str = ""

    # -- Croper ------------------------------------------------------------
    # These values are always expressed at the scale of the original image.
    use_croper: bool = False
    croper_x: int = None
    croper_y: int = None
    croper_height: int = None
    croper_width: int = None

    # -- sd ----------------------------------------------------------------
    sd_mask_blur: int = 0
    sd_strength: float = 0.75
    sd_steps: int = 50
    sd_guidance_scale: float = 7.5
    sd_sampler: str = SDSampler.ddim
    # -1 means: use a random seed
    sd_seed: int = 42

    # -- cv2 ---------------------------------------------------------------
    cv2_flag: str = "INPAINT_NS"
    cv2_radius: int = 4
lama_cleaner/settings.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Django settings for lama_cleaner project.
3
+
4
+ Generated by 'django-admin startproject' using Django 4.1.2.
5
+
6
+ For more information on this file, see
7
+ https://docs.djangoproject.com/en/4.1/topics/settings/
8
+
9
+ For the full list of settings and their values, see
10
+ https://docs.djangoproject.com/en/4.1/ref/settings/
11
+ """
12
+
13
+ from pathlib import Path
14
+
15
import os

# Build paths inside the project like this: BASE_DIR / 'subdir'.
BASE_DIR = Path(__file__).resolve().parent.parent


# Quick-start development settings - unsuitable for production
# See https://docs.djangoproject.com/en/4.1/howto/deployment/checklist/
#
# The three security-sensitive settings below can be overridden from the
# environment so a deployment does not have to edit (or ship) this file's
# development values. With no variables set they keep their previous values.

# SECURITY WARNING: keep the secret key used in production secret!
# Set DJANGO_SECRET_KEY in production; the fallback is the development key
# that is committed with the source and therefore not secret.
SECRET_KEY = os.environ.get(
    "DJANGO_SECRET_KEY",
    'django-insecure-=x2n@zasb2nkq$)frp(&h*tsozyka+jb5(&3^7@u5@ven@-sdu',
)

# SECURITY WARNING: don't run with debug turned on in production!
# Set DJANGO_DEBUG to 0/false/no/off to disable; defaults to enabled.
DEBUG = os.environ.get("DJANGO_DEBUG", "true").strip().lower() in {
    "1",
    "true",
    "yes",
    "on",
}

# Comma-separated host names in DJANGO_ALLOWED_HOSTS; defaults to empty.
ALLOWED_HOSTS = [
    host.strip()
    for host in os.environ.get("DJANGO_ALLOWED_HOSTS", "").split(",")
    if host.strip()
]
29
+
30
+
31
# Application definition: Django's stock contrib apps plus the project's
# own "inpainting" app.

INSTALLED_APPS = [
    "django.contrib.admin",
    "django.contrib.auth",
    "django.contrib.contenttypes",
    "django.contrib.sessions",
    "django.contrib.messages",
    "django.contrib.staticfiles",
    "inpainting",
]

MIDDLEWARE = [
    "django.middleware.security.SecurityMiddleware",
    "django.contrib.sessions.middleware.SessionMiddleware",
    "django.middleware.common.CommonMiddleware",
    "django.middleware.csrf.CsrfViewMiddleware",
    "django.contrib.auth.middleware.AuthenticationMiddleware",
    "django.contrib.messages.middleware.MessageMiddleware",
    "django.middleware.clickjacking.XFrameOptionsMiddleware",
]

ROOT_URLCONF = "lama_cleaner.urls"

TEMPLATES = [
    {
        "BACKEND": "django.template.backends.django.DjangoTemplates",
        "DIRS": [],
        "APP_DIRS": True,
        "OPTIONS": {
            "context_processors": [
                "django.template.context_processors.debug",
                "django.template.context_processors.request",
                "django.contrib.auth.context_processors.auth",
                "django.contrib.messages.context_processors.messages",
            ],
        },
    },
]

WSGI_APPLICATION = "lama_cleaner.wsgi.application"


# Database: a single SQLite file stored next to the project package.
# https://docs.djangoproject.com/en/4.1/ref/settings/#databases

DATABASES = {
    "default": {
        "ENGINE": "django.db.backends.sqlite3",
        "NAME": BASE_DIR / "db.sqlite3",
    }
}


# Password validation
# https://docs.djangoproject.com/en/4.1/ref/settings/#auth-password-validators

AUTH_PASSWORD_VALIDATORS = [
    {
        "NAME": "django.contrib.auth.password_validation.UserAttributeSimilarityValidator",
    },
    {
        "NAME": "django.contrib.auth.password_validation.MinimumLengthValidator",
    },
    {
        "NAME": "django.contrib.auth.password_validation.CommonPasswordValidator",
    },
    {
        "NAME": "django.contrib.auth.password_validation.NumericPasswordValidator",
    },
]


# Internationalization
# https://docs.djangoproject.com/en/4.1/topics/i18n/

LANGUAGE_CODE = "en-us"

TIME_ZONE = "UTC"

USE_I18N = True

USE_TZ = True


# Static files (CSS, JavaScript, Images)
# https://docs.djangoproject.com/en/4.1/howto/static-files/

STATIC_URL = "static/"

# Default primary key field type
# https://docs.djangoproject.com/en/4.1/ref/settings/#default-auto-field

DEFAULT_AUTO_FIELD = "django.db.models.BigAutoField"
lama_cleaner/urls.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """lama_cleaner URL Configuration
2
+
3
+ The `urlpatterns` list routes URLs to views. For more information please see:
4
+ https://docs.djangoproject.com/en/4.1/topics/http/urls/
5
+ Examples:
6
+ Function views
7
+ 1. Add an import: from my_app import views
8
+ 2. Add a URL to urlpatterns: path('', views.home, name='home')
9
+ Class-based views
10
+ 1. Add an import: from other_app.views import Home
11
+ 2. Add a URL to urlpatterns: path('', Home.as_view(), name='home')
12
+ Including another URLconf
13
+ 1. Import the include() function: from django.urls import include, path
14
+ 2. Add a URL to urlpatterns: path('blog/', include('blog.urls'))
15
+ """
16
+ from django.contrib import admin
17
+ from django.urls import path,include
18
+
19
# Top-level routes: the Django admin site and the URLconf of the
# "inpainting" app, mounted under /inpainting/.
urlpatterns = [
    path("admin/", admin.site.urls),
    path("inpainting/", include("inpainting.urls")),
]
lama_cleaner/wsgi.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ WSGI config for lama_cleaner project.
3
+
4
+ It exposes the WSGI callable as a module-level variable named ``application``.
5
+
6
+ For more information on this file, see
7
+ https://docs.djangoproject.com/en/4.1/howto/deployment/wsgi/
8
+ """
9
+
10
+ import os
11
+
12
+ from django.core.wsgi import get_wsgi_application
13
+
14
# Select this project's settings module unless the environment has already
# chosen one (setdefault never overwrites an existing value).
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "lama_cleaner.settings")

# Module-level WSGI callable that WSGI servers import and invoke.
application = get_wsgi_application()
requirements.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ opencv-python==4.6.0.66
2
+ pytest==7.1.3
3
+ torch==2.2.0
4
+ pydantic==1.10.2
5
+ loguru==0.6.0
6
+ tqdm==4.64.1
7
+ Pillow==9.2.0
8
+ diffusers==0.4.2
9
+ transformers
10
+ scikit-image==0.19.3
11
+ gradio
12
+ timm