ArchitSharma committed on
Commit
c716076
1 Parent(s): 5a090d9

Upload 16 files

src/__init__.py ADDED
File without changes
src/app_utils.py ADDED
import os
import requests
import random
import _thread as thread
from uuid import uuid4
import urllib.request

import numpy as np
import skimage
from skimage.filters import gaussian
from PIL import Image


def compress_image(image, path_original):
    size = 1920, 1080
    width = 1920
    height = 1080

    name = os.path.basename(path_original).split('.')
    first_name = os.path.join(os.path.dirname(path_original), name[0] + '.jpg')

    if image.size[0] > width and image.size[1] > height:
        image.thumbnail(size, Image.ANTIALIAS)
        image.save(first_name, quality=85)
    elif image.size[0] > width:
        wpercent = (width / float(image.size[0]))
        height = int((float(image.size[1]) * float(wpercent)))
        image = image.resize((width, height), Image.ANTIALIAS)
        image.save(first_name, quality=85)
    elif image.size[1] > height:
        wpercent = (height / float(image.size[1]))
        width = int((float(image.size[0]) * float(wpercent)))
        image = image.resize((width, height), Image.ANTIALIAS)
        image.save(first_name, quality=85)
    else:
        image.save(first_name, quality=85)


def convertToJPG(path_original):
    img = Image.open(path_original)
    name = os.path.basename(path_original).split('.')
    first_name = os.path.join(os.path.dirname(path_original), name[0] + '.jpg')

    if img.format == "JPEG":
        image = img.convert('RGB')
        compress_image(image, path_original)
        img.close()

    elif img.format == "GIF":
        i = img.convert("RGBA")
        bg = Image.new("RGBA", i.size)
        image = Image.composite(i, bg, i)
        compress_image(image, path_original)
        img.close()

    elif img.format == "PNG":
        try:
            image = Image.new("RGB", img.size, (255, 255, 255))
            image.paste(img, img)
            compress_image(image, path_original)
        except ValueError:
            image = img.convert('RGB')
            compress_image(image, path_original)

        img.close()

    elif img.format == "BMP":
        image = img.convert('RGB')
        compress_image(image, path_original)
        img.close()


def blur(image, x0, x1, y0, y1, sigma=1, multichannel=True):
    y0, y1 = min(y0, y1), max(y0, y1)
    x0, x1 = min(x0, x1), max(x0, x1)
    im = image.copy()
    sub_im = im[y0:y1, x0:x1].copy()
    blur_sub_im = gaussian(sub_im, sigma=sigma, multichannel=multichannel)
    blur_sub_im = np.round(255 * blur_sub_im)
    im[y0:y1, x0:x1] = blur_sub_im
    return im


def download(url, filename):
    data = requests.get(url).content
    with open(filename, 'wb') as handler:
        handler.write(data)

    return filename


def generate_random_filename(upload_directory, extension):
    filename = str(uuid4())
    filename = os.path.join(upload_directory, filename + "." + extension)
    return filename


def clean_me(filename):
    if os.path.exists(filename):
        os.remove(filename)


def clean_all(files):
    for me in files:
        clean_me(me)


def create_directory(path):
    os.makedirs(os.path.dirname(path), exist_ok=True)


def get_model_bin(url, output_path):
    # print('Getting model dir: ', output_path)
    if not os.path.exists(output_path):
        create_directory(output_path)

    urllib.request.urlretrieve(url, output_path)

    # cmd = "wget -O %s %s" % (output_path, url)
    # print(cmd)
    # os.system(cmd)

    return output_path


# model_list = [(url, output_path), (url, output_path)]
def get_multi_model_bin(model_list):
    for m in model_list:
        thread.start_new_thread(get_model_bin, m)
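A minimal usage sketch for these helpers. The URL, directory name, and blur coordinates below are illustrative placeholders, not values from this repo, and the example assumes the repo root is on the Python path:

# Hypothetical example: fetch an image, normalize it to JPEG, and blur a region.
from src.app_utils import (
    create_directory, download, generate_random_filename, convertToJPG, blur, clean_all
)
import numpy as np
from PIL import Image

upload_dir = "uploads/"                               # assumed working directory
create_directory(upload_dir)
path = generate_random_filename(upload_dir, "png")
download("https://example.com/photo.png", path)       # placeholder URL
convertToJPG(path)                                    # writes a compressed .jpg alongside the original

img = np.array(Image.open(path))
blurred = blur(img, x0=10, x1=200, y0=10, y1=200, sigma=3)   # blur a rectangular region
clean_all([path])                                     # remove the temporary download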
src/deoldify/__init__.py ADDED
from src.deoldify._device import _Device

device = _Device()
src/deoldify/_device.py ADDED
import os
from enum import Enum
from .device_id import DeviceId

# NOTE: This must be called first before any torch imports in order to work properly!

class DeviceException(Exception):
    pass

class _Device:
    def __init__(self):
        self.set(DeviceId.CPU)

    def is_gpu(self):
        '''Returns `True` if the current device is GPU, `False` otherwise.'''
        return self.current() is not DeviceId.CPU

    def current(self):
        return self._current_device

    def set(self, device: DeviceId):
        if device == DeviceId.CPU:
            os.environ['CUDA_VISIBLE_DEVICES'] = ''
        else:
            os.environ['CUDA_VISIBLE_DEVICES'] = str(device.value)
            import torch
            torch.backends.cudnn.benchmark = False

        self._current_device = device
        return device
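A brief sketch of the intended call order: the device has to be chosen before anything imports torch, which is why `set` writes `CUDA_VISIBLE_DEVICES` first. Assuming the package is importable as `src.deoldify`:

# Hypothetical usage: pick the device before any torch-dependent imports.
from src.deoldify import device
from src.deoldify.device_id import DeviceId

device.set(DeviceId.GPU0)      # or DeviceId.CPU to force CPU inference
print(device.is_gpu())         # True

# Only after this point should torch / fastai-heavy modules be imported.
from src.deoldify.visualize import get_image_colorizer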
src/deoldify/augs.py ADDED
import random

from fastai.vision.image import TfmPixel

# Contributed by Rani Horev. Thank you!
def _noisify(
    x, pct_pixels_min: float = 0.001, pct_pixels_max: float = 0.4, noise_range: int = 30
):
    if noise_range > 255 or noise_range < 0:
        raise Exception("noise_range must be between 0 and 255, inclusively.")

    h, w = x.shape[1:]
    img_size = h * w
    mult = 10000.0
    pct_pixels = (
        random.randrange(int(pct_pixels_min * mult), int(pct_pixels_max * mult)) / mult
    )
    noise_count = int(img_size * pct_pixels)

    for ii in range(noise_count):
        yy = random.randrange(h)
        xx = random.randrange(w)
        noise = random.randrange(-noise_range, noise_range) / 255.0
        x[:, yy, xx].add_(noise)

    return x


noisify = TfmPixel(_noisify)
src/deoldify/critics.py ADDED
from fastai.core import *
from fastai.torch_core import *
from fastai.vision import *
from fastai.vision.gan import AdaptiveLoss, accuracy_thresh_expand

_conv_args = dict(leaky=0.2, norm_type=NormType.Spectral)


def _conv(ni: int, nf: int, ks: int = 3, stride: int = 1, **kwargs):
    return conv_layer(ni, nf, ks=ks, stride=stride, **_conv_args, **kwargs)


def custom_gan_critic(
    n_channels: int = 3, nf: int = 256, n_blocks: int = 3, p: int = 0.15
):
    "Critic to train a `GAN`."
    layers = [_conv(n_channels, nf, ks=4, stride=2), nn.Dropout2d(p / 2)]
    for i in range(n_blocks):
        layers += [
            _conv(nf, nf, ks=3, stride=1),
            nn.Dropout2d(p),
            _conv(nf, nf * 2, ks=4, stride=2, self_attention=(i == 0)),
        ]
        nf *= 2
    layers += [
        _conv(nf, nf, ks=3, stride=1),
        _conv(nf, 1, ks=4, bias=False, padding=0, use_activ=False),
        Flatten(),
    ]
    return nn.Sequential(*layers)


def colorize_crit_learner(
    data: ImageDataBunch,
    loss_critic=AdaptiveLoss(nn.BCEWithLogitsLoss()),
    nf: int = 256,
) -> Learner:
    return Learner(
        data,
        custom_gan_critic(nf=nf),
        metrics=accuracy_thresh_expand,
        loss_func=loss_critic,
        wd=1e-3,
    )
src/deoldify/dataset.py ADDED
import fastai
from fastai import *
from fastai.core import *
from fastai.vision.transform import get_transforms
from fastai.vision.data import ImageImageList, ImageDataBunch, imagenet_stats
from .augs import noisify


def get_colorize_data(
    sz: int,
    bs: int,
    crappy_path: Path,
    good_path: Path,
    random_seed: int = None,
    keep_pct: float = 1.0,
    num_workers: int = 8,
    stats: tuple = imagenet_stats,
    xtra_tfms=[],
) -> ImageDataBunch:

    src = (
        ImageImageList.from_folder(crappy_path, convert_mode='RGB')
        .use_partial_data(sample_pct=keep_pct, seed=random_seed)
        .split_by_rand_pct(0.1, seed=random_seed)
    )

    data = (
        src.label_from_func(lambda x: good_path / x.relative_to(crappy_path))
        .transform(
            get_transforms(
                max_zoom=1.2, max_lighting=0.5, max_warp=0.25, xtra_tfms=xtra_tfms
            ),
            size=sz,
            tfm_y=True,
        )
        .databunch(bs=bs, num_workers=num_workers, no_check=True)
        .normalize(stats, do_y=True)
    )

    data.c = 3
    return data


def get_dummy_databunch() -> ImageDataBunch:
    path = Path('./assets/dummy/')
    return get_colorize_data(
        sz=1, bs=1, crappy_path=path, good_path=path, keep_pct=0.001
    )
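For context, a sketch of how `get_colorize_data` and the `noisify` transform from augs.py might be combined when building training data. The paths, sizes, and the `p=0.8` probability are assumptions for illustration, not values taken from this commit:

# Hypothetical training-data setup: 'crappified' holds degraded grayscale inputs,
# 'original' holds the color targets.
from pathlib import Path
from src.deoldify.dataset import get_colorize_data
from src.deoldify.augs import noisify

data = get_colorize_data(
    sz=64,
    bs=32,
    crappy_path=Path('data/crappified'),
    good_path=Path('data/original'),
    random_seed=42,
    xtra_tfms=[noisify(p=0.8)],   # assumed probability, following fastai's transform call style
)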
src/deoldify/device_id.py ADDED
from enum import IntEnum

class DeviceId(IntEnum):
    GPU0 = 0,
    GPU1 = 1,
    GPU2 = 2,
    GPU3 = 3,
    GPU4 = 4,
    GPU5 = 5,
    GPU6 = 6,
    GPU7 = 7,
    CPU = 99
src/deoldify/filters.py ADDED
from numpy import ndarray
from abc import ABC, abstractmethod
from .critics import colorize_crit_learner
from fastai.core import *
from fastai.vision import *
from fastai.vision.image import *
from fastai.vision.data import *
from fastai import *
import math
from scipy import misc
import cv2
from PIL import Image as PilImage


class IFilter(ABC):
    @abstractmethod
    def filter(
        self, orig_image: PilImage, filtered_image: PilImage, render_factor: int
    ) -> PilImage:
        pass


class BaseFilter(IFilter):
    def __init__(self, learn: Learner, stats: tuple = imagenet_stats):
        super().__init__()
        self.learn = learn
        self.device = next(self.learn.model.parameters()).device
        self.norm, self.denorm = normalize_funcs(*stats)

    def _transform(self, image: PilImage) -> PilImage:
        return image

    def _scale_to_square(self, orig: PilImage, targ: int) -> PilImage:
        # A simple stretch to fit a square really makes a big difference in rendering quality/consistency.
        # I've tried padding to the square as well (reflect, symmetric, constant, etc). Not as good!
        targ_sz = (targ, targ)
        return orig.resize(targ_sz, resample=PIL.Image.BILINEAR)

    def _get_model_ready_image(self, orig: PilImage, sz: int) -> PilImage:
        result = self._scale_to_square(orig, sz)
        result = self._transform(result)
        return result

    def _model_process(self, orig: PilImage, sz: int) -> PilImage:
        model_image = self._get_model_ready_image(orig, sz)
        x = pil2tensor(model_image, np.float32)
        x = x.to(self.device)
        x.div_(255)
        x, y = self.norm((x, x), do_x=True)

        try:
            result = self.learn.pred_batch(
                ds_type=DatasetType.Valid, batch=(x[None], y[None]), reconstruct=True
            )
        except RuntimeError as rerr:
            if 'memory' not in str(rerr):
                raise rerr
            print('Warning: render_factor was set too high, and an out-of-memory error resulted. Returning original image.')
            return model_image

        out = result[0]
        out = self.denorm(out.px, do_x=False)
        out = image2np(out * 255).astype(np.uint8)
        return PilImage.fromarray(out)

    def _unsquare(self, image: PilImage, orig: PilImage) -> PilImage:
        targ_sz = orig.size
        image = image.resize(targ_sz, resample=PIL.Image.BILINEAR)
        return image


class ColorizerFilter(BaseFilter):
    def __init__(self, learn: Learner, stats: tuple = imagenet_stats):
        super().__init__(learn=learn, stats=stats)
        self.render_base = 16

    def filter(
        self, orig_image: PilImage, filtered_image: PilImage, render_factor: int, post_process: bool = True) -> PilImage:
        render_sz = render_factor * self.render_base
        model_image = self._model_process(orig=filtered_image, sz=render_sz)
        raw_color = self._unsquare(model_image, orig_image)

        if post_process:
            return self._post_process(raw_color, orig_image)
        else:
            return raw_color

    def _transform(self, image: PilImage) -> PilImage:
        return image.convert('LA').convert('RGB')

    # This takes advantage of the fact that human eyes are much less sensitive to
    # imperfections in chrominance compared to luminance. This means we can
    # save a lot on memory and processing in the model, yet get a great high
    # resolution result at the end. This is primarily intended just for
    # inference.
    def _post_process(self, raw_color: PilImage, orig: PilImage) -> PilImage:
        color_np = np.asarray(raw_color)
        orig_np = np.asarray(orig)
        color_yuv = cv2.cvtColor(color_np, cv2.COLOR_BGR2YUV)
        # do a black-and-white transform first to get better luminance values
        orig_yuv = cv2.cvtColor(orig_np, cv2.COLOR_BGR2YUV)
        hires = np.copy(orig_yuv)
        hires[:, :, 1:3] = color_yuv[:, :, 1:3]
        final = cv2.cvtColor(hires, cv2.COLOR_YUV2BGR)
        final = PilImage.fromarray(final)
        return final


class MasterFilter(BaseFilter):
    def __init__(self, filters: [IFilter], render_factor: int):
        self.filters = filters
        self.render_factor = render_factor

    def filter(
        self, orig_image: PilImage, filtered_image: PilImage, render_factor: int = None, post_process: bool = True) -> PilImage:
        render_factor = self.render_factor if render_factor is None else render_factor
        for filter in self.filters:
            filtered_image = filter.filter(orig_image, filtered_image, render_factor, post_process)

        return filtered_image
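The chrominance/luminance idea behind `_post_process`, shown in isolation: the model only needs to supply the U and V (color) channels, while the full-resolution luminance comes from the original. A standalone sketch, assuming `orig` and `raw_color` are equally sized uint8 BGR arrays:

import cv2
import numpy as np

def merge_chrominance(orig: np.ndarray, raw_color: np.ndarray) -> np.ndarray:
    """Keep orig's luminance, graft on raw_color's chrominance (illustrative only)."""
    orig_yuv = cv2.cvtColor(orig, cv2.COLOR_BGR2YUV)
    color_yuv = cv2.cvtColor(raw_color, cv2.COLOR_BGR2YUV)
    merged = np.copy(orig_yuv)
    merged[:, :, 1:3] = color_yuv[:, :, 1:3]   # take only the U and V channels
    return cv2.cvtColor(merged, cv2.COLOR_YUV2BGR)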
src/deoldify/generators.py ADDED
from fastai.vision import *
from fastai.vision.learner import cnn_config
from .unet import DynamicUnetWide, DynamicUnetDeep
from .loss import FeatureLoss
from .dataset import *

# Weights are implicitly read from ./models/ folder
def gen_inference_wide(
    root_folder: Path, weights_name: str, nf_factor: int = 2, arch=models.resnet101) -> Learner:
    data = get_dummy_databunch()
    learn = gen_learner_wide(
        data=data, gen_loss=F.l1_loss, nf_factor=nf_factor, arch=arch
    )
    learn.path = root_folder
    learn.load(weights_name)
    learn.model.eval()
    return learn


def gen_learner_wide(
    data: ImageDataBunch, gen_loss, arch=models.resnet101, nf_factor: int = 2
) -> Learner:
    return unet_learner_wide(
        data,
        arch=arch,
        wd=1e-3,
        blur=True,
        norm_type=NormType.Spectral,
        self_attention=True,
        y_range=(-3.0, 3.0),
        loss_func=gen_loss,
        nf_factor=nf_factor,
    )


# The code below is meant to be merged into fastaiv1 ideally
def unet_learner_wide(
    data: DataBunch,
    arch: Callable,
    pretrained: bool = True,
    blur_final: bool = True,
    norm_type: Optional[NormType] = NormType,
    split_on: Optional[SplitFuncOrIdxList] = None,
    blur: bool = False,
    self_attention: bool = False,
    y_range: Optional[Tuple[float, float]] = None,
    last_cross: bool = True,
    bottle: bool = False,
    nf_factor: int = 1,
    **kwargs: Any
) -> Learner:
    "Build Unet learner from `data` and `arch`."
    meta = cnn_config(arch)
    body = create_body(arch, pretrained)
    model = to_device(
        DynamicUnetWide(
            body,
            n_classes=data.c,
            blur=blur,
            blur_final=blur_final,
            self_attention=self_attention,
            y_range=y_range,
            norm_type=norm_type,
            last_cross=last_cross,
            bottle=bottle,
            nf_factor=nf_factor,
        ),
        data.device,
    )
    learn = Learner(data, model, **kwargs)
    learn.split(ifnone(split_on, meta['split']))
    if pretrained:
        learn.freeze()
    apply_init(model[2], nn.init.kaiming_normal_)
    return learn


# ----------------------------------------------------------------------

# Weights are implicitly read from ./models/ folder
def gen_inference_deep(
    root_folder: Path, weights_name: str, arch=models.resnet34, nf_factor: float = 1.5) -> Learner:
    data = get_dummy_databunch()
    learn = gen_learner_deep(
        data=data, gen_loss=F.l1_loss, arch=arch, nf_factor=nf_factor
    )
    learn.path = root_folder
    learn.load(weights_name)
    learn.model.eval()
    return learn


def gen_learner_deep(
    data: ImageDataBunch, gen_loss, arch=models.resnet34, nf_factor: float = 1.5
) -> Learner:
    return unet_learner_deep(
        data,
        arch,
        wd=1e-3,
        blur=True,
        norm_type=NormType.Spectral,
        self_attention=True,
        y_range=(-3.0, 3.0),
        loss_func=gen_loss,
        nf_factor=nf_factor,
    )


# The code below is meant to be merged into fastaiv1 ideally
def unet_learner_deep(
    data: DataBunch,
    arch: Callable,
    pretrained: bool = True,
    blur_final: bool = True,
    norm_type: Optional[NormType] = NormType,
    split_on: Optional[SplitFuncOrIdxList] = None,
    blur: bool = False,
    self_attention: bool = False,
    y_range: Optional[Tuple[float, float]] = None,
    last_cross: bool = True,
    bottle: bool = False,
    nf_factor: float = 1.5,
    **kwargs: Any
) -> Learner:
    "Build Unet learner from `data` and `arch`."
    meta = cnn_config(arch)
    body = create_body(arch, pretrained)
    model = to_device(
        DynamicUnetDeep(
            body,
            n_classes=data.c,
            blur=blur,
            blur_final=blur_final,
            self_attention=self_attention,
            y_range=y_range,
            norm_type=norm_type,
            last_cross=last_cross,
            bottle=bottle,
            nf_factor=nf_factor,
        ),
        data.device,
    )
    learn = Learner(data, model, **kwargs)
    learn.split(ifnone(split_on, meta['split']))
    if pretrained:
        learn.freeze()
    apply_init(model[2], nn.init.kaiming_normal_)
    return learn


# -----------------------------
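A short sketch of how these factories are used for inference; `ColorizeStable_gen` matches the weight name referenced later in visualize.py, and the weights file is assumed to already sit under `root_folder/models/`:

# Hypothetical: expects ./models/ColorizeStable_gen.pth to exist.
from pathlib import Path
from src.deoldify.generators import gen_inference_wide

learn = gen_inference_wide(root_folder=Path('./'), weights_name='ColorizeStable_gen')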
src/deoldify/layers.py ADDED
from fastai.layers import *
from fastai.torch_core import *
from torch.nn.parameter import Parameter
from torch.autograd import Variable


# The code below is meant to be merged into fastaiv1 ideally


def custom_conv_layer(
    ni: int,
    nf: int,
    ks: int = 3,
    stride: int = 1,
    padding: int = None,
    bias: bool = None,
    is_1d: bool = False,
    norm_type: Optional[NormType] = NormType.Batch,
    use_activ: bool = True,
    leaky: float = None,
    transpose: bool = False,
    init: Callable = nn.init.kaiming_normal_,
    self_attention: bool = False,
    extra_bn: bool = False,
):
    "Create a sequence of convolutional (`ni` to `nf`), ReLU (if `use_activ`) and batchnorm (if `bn`) layers."
    if padding is None:
        padding = (ks - 1) // 2 if not transpose else 0
    bn = norm_type in (NormType.Batch, NormType.BatchZero) or extra_bn == True
    if bias is None:
        bias = not bn
    conv_func = nn.ConvTranspose2d if transpose else nn.Conv1d if is_1d else nn.Conv2d
    conv = init_default(
        conv_func(ni, nf, kernel_size=ks, bias=bias, stride=stride, padding=padding),
        init,
    )
    if norm_type == NormType.Weight:
        conv = weight_norm(conv)
    elif norm_type == NormType.Spectral:
        conv = spectral_norm(conv)
    layers = [conv]
    if use_activ:
        layers.append(relu(True, leaky=leaky))
    if bn:
        layers.append((nn.BatchNorm1d if is_1d else nn.BatchNorm2d)(nf))
    if self_attention:
        layers.append(SelfAttention(nf))
    return nn.Sequential(*layers)
src/deoldify/loss.py ADDED
from fastai import *
from fastai.core import *
from fastai.torch_core import *
from fastai.callbacks import hook_outputs
import torchvision.models as models


class FeatureLoss(nn.Module):
    def __init__(self, layer_wgts=[20, 70, 10]):
        super().__init__()

        self.m_feat = models.vgg16_bn(True).features.cuda().eval()
        requires_grad(self.m_feat, False)
        blocks = [
            i - 1
            for i, o in enumerate(children(self.m_feat))
            if isinstance(o, nn.MaxPool2d)
        ]
        layer_ids = blocks[2:5]
        self.loss_features = [self.m_feat[i] for i in layer_ids]
        self.hooks = hook_outputs(self.loss_features, detach=False)
        self.wgts = layer_wgts
        self.metric_names = ['pixel'] + [f'feat_{i}' for i in range(len(layer_ids))]
        self.base_loss = F.l1_loss

    def _make_features(self, x, clone=False):
        self.m_feat(x)
        return [(o.clone() if clone else o) for o in self.hooks.stored]

    def forward(self, input, target):
        out_feat = self._make_features(target, clone=True)
        in_feat = self._make_features(input)
        self.feat_losses = [self.base_loss(input, target)]
        self.feat_losses += [
            self.base_loss(f_in, f_out) * w
            for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)
        ]

        self.metrics = dict(zip(self.metric_names, self.feat_losses))
        return sum(self.feat_losses)

    def __del__(self):
        self.hooks.remove()


# Refactored code, originally from https://github.com/VinceMarron/style_transfer
class WassFeatureLoss(nn.Module):
    def __init__(self, layer_wgts=[5, 15, 2], wass_wgts=[3.0, 0.7, 0.01]):
        super().__init__()
        self.m_feat = models.vgg16_bn(True).features.cuda().eval()
        requires_grad(self.m_feat, False)
        blocks = [
            i - 1
            for i, o in enumerate(children(self.m_feat))
            if isinstance(o, nn.MaxPool2d)
        ]
        layer_ids = blocks[2:5]
        self.loss_features = [self.m_feat[i] for i in layer_ids]
        self.hooks = hook_outputs(self.loss_features, detach=False)
        self.wgts = layer_wgts
        self.wass_wgts = wass_wgts
        self.metric_names = (
            ['pixel']
            + [f'feat_{i}' for i in range(len(layer_ids))]
            + [f'wass_{i}' for i in range(len(layer_ids))]
        )
        self.base_loss = F.l1_loss

    def _make_features(self, x, clone=False):
        self.m_feat(x)
        return [(o.clone() if clone else o) for o in self.hooks.stored]

    def _calc_2_moments(self, tensor):
        chans = tensor.shape[1]
        tensor = tensor.view(1, chans, -1)
        n = tensor.shape[2]
        mu = tensor.mean(2)
        tensor = (tensor - mu[:, :, None]).squeeze(0)
        # Prevents a nasty bug that happens very occasionally: divide by zero. Why this happens is unclear.
        if n == 0:
            return None, None
        cov = torch.mm(tensor, tensor.t()) / float(n)
        return mu, cov

    def _get_style_vals(self, tensor):
        mean, cov = self._calc_2_moments(tensor)
        if mean is None:
            return None, None, None
        eigvals, eigvects = torch.symeig(cov, eigenvectors=True)
        eigroot_mat = torch.diag(torch.sqrt(eigvals.clamp(min=0)))
        root_cov = torch.mm(torch.mm(eigvects, eigroot_mat), eigvects.t())
        tr_cov = eigvals.clamp(min=0).sum()
        return mean, tr_cov, root_cov

    def _calc_l2wass_dist(
        self, mean_stl, tr_cov_stl, root_cov_stl, mean_synth, cov_synth
    ):
        tr_cov_synth = torch.symeig(cov_synth, eigenvectors=True)[0].clamp(min=0).sum()
        mean_diff_squared = (mean_stl - mean_synth).pow(2).sum()
        cov_prod = torch.mm(torch.mm(root_cov_stl, cov_synth), root_cov_stl)
        var_overlap = torch.sqrt(
            torch.symeig(cov_prod, eigenvectors=True)[0].clamp(min=0) + 1e-8
        ).sum()
        dist = mean_diff_squared + tr_cov_stl + tr_cov_synth - 2 * var_overlap
        return dist

    def _single_wass_loss(self, pred, targ):
        mean_test, tr_cov_test, root_cov_test = targ
        mean_synth, cov_synth = self._calc_2_moments(pred)
        loss = self._calc_l2wass_dist(
            mean_test, tr_cov_test, root_cov_test, mean_synth, cov_synth
        )
        return loss

    def forward(self, input, target):
        out_feat = self._make_features(target, clone=True)
        in_feat = self._make_features(input)
        self.feat_losses = [self.base_loss(input, target)]
        self.feat_losses += [
            self.base_loss(f_in, f_out) * w
            for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)
        ]

        styles = [self._get_style_vals(i) for i in out_feat]

        if styles[0][0] is not None:
            self.feat_losses += [
                self._single_wass_loss(f_pred, f_targ) * w
                for f_pred, f_targ, w in zip(in_feat, styles, self.wass_wgts)
            ]

        self.metrics = dict(zip(self.metric_names, self.feat_losses))
        return sum(self.feat_losses)

    def __del__(self):
        self.hooks.remove()
src/deoldify/save.py ADDED
from fastai.basic_train import Learner, LearnerCallback
from fastai.vision.gan import GANLearner


class GANSaveCallback(LearnerCallback):
    """A `LearnerCallback` that saves the generator learner every `save_iters` iterations while training `learn`."""

    def __init__(
        self,
        learn: GANLearner,
        learn_gen: Learner,
        filename: str,
        save_iters: int = 1000,
    ):
        super().__init__(learn)
        self.learn_gen = learn_gen
        self.filename = filename
        self.save_iters = save_iters

    def on_batch_end(self, iteration: int, epoch: int, **kwargs) -> None:
        if iteration == 0:
            return

        if iteration % self.save_iters == 0:
            self._save_gen_learner(iteration=iteration, epoch=epoch)

    def _save_gen_learner(self, iteration: int, epoch: int):
        filename = '{}_{}_{}'.format(self.filename, epoch, iteration)
        self.learn_gen.save(filename)
src/deoldify/unet.py ADDED
from fastai.layers import *
from .layers import *
from fastai.torch_core import *
from fastai.callbacks.hooks import *
from fastai.vision import *


# The code below is meant to be merged into fastaiv1 ideally

__all__ = ['DynamicUnetDeep', 'DynamicUnetWide']


def _get_sfs_idxs(sizes: Sizes) -> List[int]:
    "Get the indexes of the layers where the size of the activation changes."
    feature_szs = [size[-1] for size in sizes]
    sfs_idxs = list(
        np.where(np.array(feature_szs[:-1]) != np.array(feature_szs[1:]))[0]
    )
    if feature_szs[0] != feature_szs[1]:
        sfs_idxs = [0] + sfs_idxs
    return sfs_idxs


class CustomPixelShuffle_ICNR(nn.Module):
    "Upsample by `scale` from `ni` filters to `nf` (default `ni`), using `nn.PixelShuffle`, `icnr` init, and `weight_norm`."

    def __init__(
        self,
        ni: int,
        nf: int = None,
        scale: int = 2,
        blur: bool = False,
        leaky: float = None,
        **kwargs
    ):
        super().__init__()
        nf = ifnone(nf, ni)
        self.conv = custom_conv_layer(
            ni, nf * (scale ** 2), ks=1, use_activ=False, **kwargs
        )
        icnr(self.conv[0].weight)
        self.shuf = nn.PixelShuffle(scale)
        # Blurring over (h*w) kernel
        # "Super-Resolution using Convolutional Neural Networks without Any Checkerboard Artifacts"
        # - https://arxiv.org/abs/1806.02658
        self.pad = nn.ReplicationPad2d((1, 0, 1, 0))
        self.blur = nn.AvgPool2d(2, stride=1)
        self.relu = relu(True, leaky=leaky)

    def forward(self, x):
        x = self.shuf(self.relu(self.conv(x)))
        return self.blur(self.pad(x)) if self.blur else x


class UnetBlockDeep(nn.Module):
    "A quasi-UNet block, using `PixelShuffle_ICNR upsampling`."

    def __init__(
        self,
        up_in_c: int,
        x_in_c: int,
        hook: Hook,
        final_div: bool = True,
        blur: bool = False,
        leaky: float = None,
        self_attention: bool = False,
        nf_factor: float = 1.0,
        **kwargs
    ):
        super().__init__()
        self.hook = hook
        self.shuf = CustomPixelShuffle_ICNR(
            up_in_c, up_in_c // 2, blur=blur, leaky=leaky, **kwargs
        )
        self.bn = batchnorm_2d(x_in_c)
        ni = up_in_c // 2 + x_in_c
        nf = int((ni if final_div else ni // 2) * nf_factor)
        self.conv1 = custom_conv_layer(ni, nf, leaky=leaky, **kwargs)
        self.conv2 = custom_conv_layer(
            nf, nf, leaky=leaky, self_attention=self_attention, **kwargs
        )
        self.relu = relu(leaky=leaky)

    def forward(self, up_in: Tensor) -> Tensor:
        s = self.hook.stored
        up_out = self.shuf(up_in)
        ssh = s.shape[-2:]
        if ssh != up_out.shape[-2:]:
            up_out = F.interpolate(up_out, s.shape[-2:], mode='nearest')
        cat_x = self.relu(torch.cat([up_out, self.bn(s)], dim=1))
        return self.conv2(self.conv1(cat_x))


class DynamicUnetDeep(SequentialEx):
    "Create a U-Net from a given architecture."

    def __init__(
        self,
        encoder: nn.Module,
        n_classes: int,
        blur: bool = False,
        blur_final=True,
        self_attention: bool = False,
        y_range: Optional[Tuple[float, float]] = None,
        last_cross: bool = True,
        bottle: bool = False,
        norm_type: Optional[NormType] = NormType.Batch,
        nf_factor: float = 1.0,
        **kwargs
    ):
        extra_bn = norm_type == NormType.Spectral
        imsize = (256, 256)
        sfs_szs = model_sizes(encoder, size=imsize)
        sfs_idxs = list(reversed(_get_sfs_idxs(sfs_szs)))
        self.sfs = hook_outputs([encoder[i] for i in sfs_idxs], detach=False)
        x = dummy_eval(encoder, imsize).detach()

        ni = sfs_szs[-1][1]
        middle_conv = nn.Sequential(
            custom_conv_layer(
                ni, ni * 2, norm_type=norm_type, extra_bn=extra_bn, **kwargs
            ),
            custom_conv_layer(
                ni * 2, ni, norm_type=norm_type, extra_bn=extra_bn, **kwargs
            ),
        ).eval()
        x = middle_conv(x)
        layers = [encoder, batchnorm_2d(ni), nn.ReLU(), middle_conv]

        for i, idx in enumerate(sfs_idxs):
            not_final = i != len(sfs_idxs) - 1
            up_in_c, x_in_c = int(x.shape[1]), int(sfs_szs[idx][1])
            do_blur = blur and (not_final or blur_final)
            sa = self_attention and (i == len(sfs_idxs) - 3)
            unet_block = UnetBlockDeep(
                up_in_c,
                x_in_c,
                self.sfs[i],
                final_div=not_final,
                blur=blur,
                self_attention=sa,
                norm_type=norm_type,
                extra_bn=extra_bn,
                nf_factor=nf_factor,
                **kwargs
            ).eval()
            layers.append(unet_block)
            x = unet_block(x)

        ni = x.shape[1]
        if imsize != sfs_szs[0][-2:]:
            layers.append(PixelShuffle_ICNR(ni, **kwargs))
        if last_cross:
            layers.append(MergeLayer(dense=True))
            ni += in_channels(encoder)
            layers.append(res_block(ni, bottle=bottle, norm_type=norm_type, **kwargs))
        layers += [
            custom_conv_layer(ni, n_classes, ks=1, use_activ=False, norm_type=norm_type)
        ]
        if y_range is not None:
            layers.append(SigmoidRange(*y_range))
        super().__init__(*layers)

    def __del__(self):
        if hasattr(self, "sfs"):
            self.sfs.remove()


# ------------------------------------------------------
class UnetBlockWide(nn.Module):
    "A quasi-UNet block, using `PixelShuffle_ICNR upsampling`."

    def __init__(
        self,
        up_in_c: int,
        x_in_c: int,
        n_out: int,
        hook: Hook,
        final_div: bool = True,
        blur: bool = False,
        leaky: float = None,
        self_attention: bool = False,
        **kwargs
    ):
        super().__init__()
        self.hook = hook
        up_out = x_out = n_out // 2
        self.shuf = CustomPixelShuffle_ICNR(
            up_in_c, up_out, blur=blur, leaky=leaky, **kwargs
        )
        self.bn = batchnorm_2d(x_in_c)
        ni = up_out + x_in_c
        self.conv = custom_conv_layer(
            ni, x_out, leaky=leaky, self_attention=self_attention, **kwargs
        )
        self.relu = relu(leaky=leaky)

    def forward(self, up_in: Tensor) -> Tensor:
        s = self.hook.stored
        up_out = self.shuf(up_in)
        ssh = s.shape[-2:]
        if ssh != up_out.shape[-2:]:
            up_out = F.interpolate(up_out, s.shape[-2:], mode='nearest')
        cat_x = self.relu(torch.cat([up_out, self.bn(s)], dim=1))
        return self.conv(cat_x)


class DynamicUnetWide(SequentialEx):
    "Create a U-Net from a given architecture."

    def __init__(
        self,
        encoder: nn.Module,
        n_classes: int,
        blur: bool = False,
        blur_final=True,
        self_attention: bool = False,
        y_range: Optional[Tuple[float, float]] = None,
        last_cross: bool = True,
        bottle: bool = False,
        norm_type: Optional[NormType] = NormType.Batch,
        nf_factor: int = 1,
        **kwargs
    ):

        nf = 512 * nf_factor
        extra_bn = norm_type == NormType.Spectral
        imsize = (256, 256)
        sfs_szs = model_sizes(encoder, size=imsize)
        sfs_idxs = list(reversed(_get_sfs_idxs(sfs_szs)))
        self.sfs = hook_outputs([encoder[i] for i in sfs_idxs], detach=False)
        x = dummy_eval(encoder, imsize).detach()

        ni = sfs_szs[-1][1]
        middle_conv = nn.Sequential(
            custom_conv_layer(
                ni, ni * 2, norm_type=norm_type, extra_bn=extra_bn, **kwargs
            ),
            custom_conv_layer(
                ni * 2, ni, norm_type=norm_type, extra_bn=extra_bn, **kwargs
            ),
        ).eval()
        x = middle_conv(x)
        layers = [encoder, batchnorm_2d(ni), nn.ReLU(), middle_conv]

        for i, idx in enumerate(sfs_idxs):
            not_final = i != len(sfs_idxs) - 1
            up_in_c, x_in_c = int(x.shape[1]), int(sfs_szs[idx][1])
            do_blur = blur and (not_final or blur_final)
            sa = self_attention and (i == len(sfs_idxs) - 3)

            n_out = nf if not_final else nf // 2

            unet_block = UnetBlockWide(
                up_in_c,
                x_in_c,
                n_out,
                self.sfs[i],
                final_div=not_final,
                blur=blur,
                self_attention=sa,
                norm_type=norm_type,
                extra_bn=extra_bn,
                **kwargs
            ).eval()
            layers.append(unet_block)
            x = unet_block(x)

        ni = x.shape[1]
        if imsize != sfs_szs[0][-2:]:
            layers.append(PixelShuffle_ICNR(ni, **kwargs))
        if last_cross:
            layers.append(MergeLayer(dense=True))
            ni += in_channels(encoder)
            layers.append(res_block(ni, bottle=bottle, norm_type=norm_type, **kwargs))
        layers += [
            custom_conv_layer(ni, n_classes, ks=1, use_activ=False, norm_type=norm_type)
        ]
        if y_range is not None:
            layers.append(SigmoidRange(*y_range))
        super().__init__(*layers)

    def __del__(self):
        if hasattr(self, "sfs"):
            self.sfs.remove()
src/deoldify/visualize.py ADDED
import cv2
import gc
import requests
from io import BytesIO
import base64
from scipy import misc
from PIL import Image
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from typing import Tuple

import torch
from fastai.core import *
from fastai.vision import *

from .filters import IFilter, MasterFilter, ColorizerFilter
from .generators import gen_inference_deep, gen_inference_wide


# class LoadedModel
class ModelImageVisualizer:
    def __init__(self, filter: IFilter, results_dir: str = None):
        self.filter = filter
        self.results_dir = None if results_dir is None else Path(results_dir)
        if self.results_dir is not None:
            self.results_dir.mkdir(parents=True, exist_ok=True)

    def _clean_mem(self):
        torch.cuda.empty_cache()
        # gc.collect()

    def _open_pil_image(self, path: Path) -> Image:
        return Image.open(path).convert('RGB')

    def _get_image_from_url(self, url: str) -> Image:
        response = requests.get(url, timeout=30, headers={'Accept': '*/*;q=0.8'})
        img = Image.open(BytesIO(response.content)).convert('RGB')
        return img

    def plot_transformed_image_from_url(
        self,
        url: str,
        path: str = 'test_images/image.png',
        results_dir: Path = None,
        figsize: Tuple[int, int] = (20, 20),
        render_factor: int = None,
        display_render_factor: bool = False,
        compare: bool = False,
        post_process: bool = True,
        watermarked: bool = True,
    ) -> Path:
        img = self._get_image_from_url(url)
        img.save(path)
        return self.plot_transformed_image(
            path=path,
            results_dir=results_dir,
            figsize=figsize,
            render_factor=render_factor,
            display_render_factor=display_render_factor,
            compare=compare,
            post_process=post_process,
            watermarked=watermarked,
        )

    def plot_transformed_image(
        self,
        path: str,
        results_dir: Path = None,
        figsize: Tuple[int, int] = (20, 20),
        render_factor: int = None,
        display_render_factor: bool = False,
        compare: bool = False,
        post_process: bool = True,
        watermarked: bool = True,
    ) -> Path:
        path = Path(path)
        if results_dir is None:
            results_dir = Path(self.results_dir)
        result = self.get_transformed_image(
            path, render_factor, post_process=post_process, watermarked=watermarked
        )
        orig = self._open_pil_image(path)
        if compare:
            self._plot_comparison(
                figsize, render_factor, display_render_factor, orig, result
            )
        else:
            self._plot_solo(figsize, render_factor, display_render_factor, result)

        orig.close()
        result_path = self._save_result_image(path, result, results_dir=results_dir)
        result.close()
        return result_path

    def plot_transformed_pil_image(
        self,
        input_image: Image,
        figsize: Tuple[int, int] = (20, 20),
        render_factor: int = None,
        display_render_factor: bool = False,
        compare: bool = False,
        post_process: bool = True,
    ) -> Image:

        result = self.get_transformed_pil_image(
            input_image, render_factor, post_process=post_process
        )

        if compare:
            self._plot_comparison(
                figsize, render_factor, display_render_factor, input_image, result
            )
        else:
            self._plot_solo(figsize, render_factor, display_render_factor, result)

        return result

    def _plot_comparison(
        self,
        figsize: Tuple[int, int],
        render_factor: int,
        display_render_factor: bool,
        orig: Image,
        result: Image,
    ):
        fig, axes = plt.subplots(1, 2, figsize=figsize)
        self._plot_image(
            orig,
            axes=axes[0],
            figsize=figsize,
            render_factor=render_factor,
            display_render_factor=False,
        )
        self._plot_image(
            result,
            axes=axes[1],
            figsize=figsize,
            render_factor=render_factor,
            display_render_factor=display_render_factor,
        )

    def _plot_solo(
        self,
        figsize: Tuple[int, int],
        render_factor: int,
        display_render_factor: bool,
        result: Image,
    ):
        fig, axes = plt.subplots(1, 1, figsize=figsize)
        self._plot_image(
            result,
            axes=axes,
            figsize=figsize,
            render_factor=render_factor,
            display_render_factor=display_render_factor,
        )

    def _save_result_image(self, source_path: Path, image: Image, results_dir=None) -> Path:
        if results_dir is None:
            results_dir = Path(self.results_dir)
        result_path = results_dir / source_path.name
        image.save(result_path)
        return result_path

    def get_transformed_image(
        self, path: Path, render_factor: int = None, post_process: bool = True,
        watermarked: bool = True,
    ) -> Image:
        self._clean_mem()
        orig_image = self._open_pil_image(path)
        filtered_image = self.filter.filter(
            orig_image, orig_image, render_factor=render_factor, post_process=post_process
        )

        return filtered_image

    def get_transformed_pil_image(
        self, input_image: Image, render_factor: int = None, post_process: bool = True,
    ) -> Image:
        self._clean_mem()
        filtered_image = self.filter.filter(
            input_image, input_image, render_factor=render_factor, post_process=post_process
        )

        return filtered_image

    def _plot_image(
        self,
        image: Image,
        render_factor: int,
        axes: Axes = None,
        figsize=(20, 20),
        display_render_factor=False,
    ):
        if axes is None:
            _, axes = plt.subplots(figsize=figsize)
        axes.imshow(np.asarray(image) / 255)
        axes.axis('off')
        if render_factor is not None and display_render_factor:
            plt.text(
                10,
                10,
                'render_factor: ' + str(render_factor),
                color='white',
                backgroundcolor='black',
            )

    def _get_num_rows_columns(self, num_images: int, max_columns: int) -> Tuple[int, int]:
        columns = min(num_images, max_columns)
        rows = num_images // columns
        rows = rows if rows * columns == num_images else rows + 1
        return rows, columns


def get_image_colorizer(
    root_folder: Path = Path('./'), render_factor: int = 35, artistic: bool = True
) -> ModelImageVisualizer:
    if artistic:
        return get_artistic_image_colorizer(root_folder=root_folder, render_factor=render_factor)
    else:
        return get_stable_image_colorizer(root_folder=root_folder, render_factor=render_factor)


def get_stable_image_colorizer(
    root_folder: Path = Path('./'),
    weights_name: str = 'ColorizeStable_gen',
    results_dir='output',
    render_factor: int = 35
) -> ModelImageVisualizer:
    learn = gen_inference_wide(root_folder=root_folder, weights_name=weights_name)
    filtr = MasterFilter([ColorizerFilter(learn=learn)], render_factor=render_factor)
    vis = ModelImageVisualizer(filtr, results_dir=results_dir)
    return vis


def get_artistic_image_colorizer(
    root_folder: Path = Path('./'),
    weights_name: str = 'ColorizeArtistic_gen',
    results_dir='output',
    render_factor: int = 35
) -> ModelImageVisualizer:
    learn = gen_inference_deep(root_folder=root_folder, weights_name=weights_name)
    filtr = MasterFilter([ColorizerFilter(learn=learn)], render_factor=render_factor)
    vis = ModelImageVisualizer(filtr, results_dir=results_dir)
    return vis
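Putting the pieces together, a hedged end-to-end sketch: the input path is a placeholder, and the corresponding `.pth` weights are assumed to already be present under `./models/`:

# Hypothetical inference session using the artistic colorizer.
from src.deoldify import device
from src.deoldify.device_id import DeviceId

device.set(DeviceId.CPU)   # select the device before the torch-heavy imports below

from src.deoldify.visualize import get_image_colorizer

colorizer = get_image_colorizer(artistic=True, render_factor=35)
result_path = colorizer.plot_transformed_image(
    'test_images/old_photo.jpg',   # placeholder input path
    render_factor=35,
    compare=True,
)
print(result_path)   # saved under the 'output' results directory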
src/st_style.py ADDED
button_style = """
<style>
div.stButton > button:first-child {
    background-color: rgb(255, 75, 75);
    color: rgb(255, 255, 255);
}
div.stButton > button:hover {
    background-color: rgb(255, 75, 75);
    color: rgb(255, 255, 255);
}
div.stButton > button:active {
    background-color: rgb(255, 75, 75);
    color: rgb(255, 255, 255);
}
div.stButton > button:focus {
    background-color: rgb(255, 75, 75);
    color: rgb(255, 255, 255);
}
.css-1cpxqw2:focus:not(:active) {
    background-color: rgb(255, 75, 75);
    border-color: rgb(255, 75, 75);
    color: rgb(255, 255, 255);
}
</style>
"""

style = """
<style>
#MainMenu {
    visibility: hidden;
}
footer {
    visibility: hidden;
}
header {
    visibility: hidden;
}
</style>
"""


def apply_prod_style(st):
    return st.markdown(style, unsafe_allow_html=True)
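A minimal sketch of how these style strings would be applied in a Streamlit app; the page content itself is illustrative only:

# Hypothetical Streamlit entry point using the styles above.
import streamlit as st
from src.st_style import button_style, apply_prod_style

apply_prod_style(st)                               # hide the menu, header, and footer
st.markdown(button_style, unsafe_allow_html=True)  # style buttons red
st.button("Colorize!")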