Pierre Chapuis committed
Commit ae1f8f9
1 Parent(s): f3f480b

clean up ESRGAN code

Files changed (2)
  1. src/enhancer.py +0 -1
  2. src/esrgan_model.py +118 -881
src/enhancer.py CHANGED
@@ -26,7 +26,6 @@ class ESRGANUpscaler(MultiUpscaler):
     ) -> None:
         super().__init__(checkpoints=checkpoints, device=device, dtype=dtype)
         self.esrgan = UpscalerESRGAN(checkpoints.esrgan, device=self.device, dtype=self.dtype)
-        self.esrgan.to(device=device, dtype=dtype)
 
     def to(self, device: torch.device, dtype: torch.dtype):
         self.esrgan.to(device=device, dtype=dtype)
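The dropped call was redundant rather than a behavior change: as the `src/esrgan_model.py` hunks below show, `UpscalerESRGAN.__init__` now ends with `self.to(device, dtype)`, so the model is already on the requested device and dtype after construction. A minimal sketch of the resulting contract (checkpoint path, device, and import layout are illustrative assumptions, not from the diff):

```python
from pathlib import Path

import torch

from src.esrgan_model import UpscalerESRGAN  # assumes src/ is importable as a package

# Construction now places the model; no follow-up .to(...) call is needed.
esrgan = UpscalerESRGAN(Path("checkpoints/esrgan.pth"), device=torch.device("cpu"), dtype=torch.float32)

# Later moves still go through the wrapper's to(), which forwards to the model.
esrgan.to(device=torch.device("cpu"), dtype=torch.float32)
```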
src/esrgan_model.py CHANGED
@@ -1,4 +1,3 @@
-# type: ignore
 """
 Modified from https://github.com/philz1337x/clarity-upscaler
 which is a copy of https://github.com/AUTOMATIC1111/stable-diffusion-webui
@@ -7,215 +6,21 @@ which is a copy of https://github.com/xinntao/ESRGAN
 """
 
 import math
-import os
-from collections import OrderedDict, namedtuple
 from pathlib import Path
+from typing import NamedTuple
 
 import numpy as np
+import numpy.typing as npt
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 from PIL import Image
 
-####################
-# RRDBNet Generator
-####################
 
-
-class RRDBNet(nn.Module):
-    def __init__(
-        self,
-        in_nc,
-        out_nc,
-        nf,
-        nb,
-        nr=3,
-        gc=32,
-        upscale=4,
-        norm_type=None,
-        act_type="leakyrelu",
-        mode="CNA",
-        upsample_mode="upconv",
-        convtype="Conv2D",
-        finalact=None,
-        gaussian_noise=False,
-        plus=False,
-    ):
-        super(RRDBNet, self).__init__()
-        n_upscale = int(math.log(upscale, 2))
-        if upscale == 3:
-            n_upscale = 1
-
-        self.resrgan_scale = 0
-        if in_nc % 16 == 0:
-            self.resrgan_scale = 1
-        elif in_nc != 4 and in_nc % 4 == 0:
-            self.resrgan_scale = 2
-
-        fea_conv = conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type=None, convtype=convtype)
-        rb_blocks = [
-            RRDB(
-                nf,
-                nr,
-                kernel_size=3,
-                gc=32,
-                stride=1,
-                bias=1,
-                pad_type="zero",
-                norm_type=norm_type,
-                act_type=act_type,
-                mode="CNA",
-                convtype=convtype,
-                gaussian_noise=gaussian_noise,
-                plus=plus,
-            )
-            for _ in range(nb)
-        ]
-        LR_conv = conv_block(
-            nf,
-            nf,
-            kernel_size=3,
-            norm_type=norm_type,
-            act_type=None,
-            mode=mode,
-            convtype=convtype,
-        )
-
-        if upsample_mode == "upconv":
-            upsample_block = upconv_block
-        elif upsample_mode == "pixelshuffle":
-            upsample_block = pixelshuffle_block
-        else:
-            raise NotImplementedError(f"upsample mode [{upsample_mode}] is not found")
-        if upscale == 3:
-            upsampler = upsample_block(nf, nf, 3, act_type=act_type, convtype=convtype)
-        else:
-            upsampler = [upsample_block(nf, nf, act_type=act_type, convtype=convtype) for _ in range(n_upscale)]
-        HR_conv0 = conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type, convtype=convtype)
-        HR_conv1 = conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None, convtype=convtype)
-
-        outact = act(finalact) if finalact else None
-
-        self.model = sequential(
-            fea_conv,
-            ShortcutBlock(sequential(*rb_blocks, LR_conv)),
-            *upsampler,
-            HR_conv0,
-            HR_conv1,
-            outact,
-        )
-
-    def forward(self, x, outm=None):
-        if self.resrgan_scale == 1:
-            feat = pixel_unshuffle(x, scale=4)
-        elif self.resrgan_scale == 2:
-            feat = pixel_unshuffle(x, scale=2)
-        else:
-            feat = x
-
-        return self.model(feat)
-
-
-class RRDB(nn.Module):
-    """
-    Residual in Residual Dense Block
-    (ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks)
-    """
-
-    def __init__(
-        self,
-        nf,
-        nr=3,
-        kernel_size=3,
-        gc=32,
-        stride=1,
-        bias=1,
-        pad_type="zero",
-        norm_type=None,
-        act_type="leakyrelu",
-        mode="CNA",
-        convtype="Conv2D",
-        spectral_norm=False,
-        gaussian_noise=False,
-        plus=False,
-    ):
-        super(RRDB, self).__init__()
-        # This is for backwards compatibility with existing models
-        if nr == 3:
-            self.RDB1 = ResidualDenseBlock_5C(
-                nf,
-                kernel_size,
-                gc,
-                stride,
-                bias,
-                pad_type,
-                norm_type,
-                act_type,
-                mode,
-                convtype,
-                spectral_norm=spectral_norm,
-                gaussian_noise=gaussian_noise,
-                plus=plus,
-            )
-            self.RDB2 = ResidualDenseBlock_5C(
-                nf,
-                kernel_size,
-                gc,
-                stride,
-                bias,
-                pad_type,
-                norm_type,
-                act_type,
-                mode,
-                convtype,
-                spectral_norm=spectral_norm,
-                gaussian_noise=gaussian_noise,
-                plus=plus,
-            )
-            self.RDB3 = ResidualDenseBlock_5C(
-                nf,
-                kernel_size,
-                gc,
-                stride,
-                bias,
-                pad_type,
-                norm_type,
-                act_type,
-                mode,
-                convtype,
-                spectral_norm=spectral_norm,
-                gaussian_noise=gaussian_noise,
-                plus=plus,
-            )
-        else:
-            RDB_list = [
-                ResidualDenseBlock_5C(
-                    nf,
-                    kernel_size,
-                    gc,
-                    stride,
-                    bias,
-                    pad_type,
-                    norm_type,
-                    act_type,
-                    mode,
-                    convtype,
-                    spectral_norm=spectral_norm,
-                    gaussian_noise=gaussian_noise,
-                    plus=plus,
-                )
-                for _ in range(nr)
-            ]
-            self.RDBs = nn.Sequential(*RDB_list)
-
-    def forward(self, x):
-        if hasattr(self, "RDB1"):
-            out = self.RDB1(x)
-            out = self.RDB2(out)
-            out = self.RDB3(out)
-        else:
-            out = self.RDBs(x)
-        return out * 0.2 + x
+def conv_block(in_nc: int, out_nc: int) -> nn.Sequential:
+    return nn.Sequential(
+        nn.Conv2d(in_nc, out_nc, kernel_size=3, padding=1),
+        nn.LeakyReLU(negative_slope=0.2, inplace=True),
+    )
 
 
 class ResidualDenseBlock_5C(nn.Module):
@@ -229,642 +34,100 @@ class ResidualDenseBlock_5C(nn.Module):
     {Rakotonirina} and A. {Rasoanaivo}
     """
 
-    def __init__(
-        self,
-        nf=64,
-        kernel_size=3,
-        gc=32,
-        stride=1,
-        bias=1,
-        pad_type="zero",
-        norm_type=None,
-        act_type="leakyrelu",
-        mode="CNA",
-        convtype="Conv2D",
-        spectral_norm=False,
-        gaussian_noise=False,
-        plus=False,
-    ):
-        super(ResidualDenseBlock_5C, self).__init__()
-
-        self.noise = GaussianNoise() if gaussian_noise else None
-        self.conv1x1 = conv1x1(nf, gc) if plus else None
+    def __init__(self, nf: int = 64, gc: int = 32) -> None:
+        super().__init__()  # type: ignore[reportUnknownMemberType]
 
-        self.conv1 = conv_block(
-            nf,
-            gc,
-            kernel_size,
-            stride,
-            bias=bias,
-            pad_type=pad_type,
-            norm_type=norm_type,
-            act_type=act_type,
-            mode=mode,
-            convtype=convtype,
-            spectral_norm=spectral_norm,
-        )
-        self.conv2 = conv_block(
-            nf + gc,
-            gc,
-            kernel_size,
-            stride,
-            bias=bias,
-            pad_type=pad_type,
-            norm_type=norm_type,
-            act_type=act_type,
-            mode=mode,
-            convtype=convtype,
-            spectral_norm=spectral_norm,
-        )
-        self.conv3 = conv_block(
-            nf + 2 * gc,
-            gc,
-            kernel_size,
-            stride,
-            bias=bias,
-            pad_type=pad_type,
-            norm_type=norm_type,
-            act_type=act_type,
-            mode=mode,
-            convtype=convtype,
-            spectral_norm=spectral_norm,
-        )
-        self.conv4 = conv_block(
-            nf + 3 * gc,
-            gc,
-            kernel_size,
-            stride,
-            bias=bias,
-            pad_type=pad_type,
-            norm_type=norm_type,
-            act_type=act_type,
-            mode=mode,
-            convtype=convtype,
-            spectral_norm=spectral_norm,
-        )
-        if mode == "CNA":
-            last_act = None
-        else:
-            last_act = act_type
-        self.conv5 = conv_block(
-            nf + 4 * gc,
-            nf,
-            3,
-            stride,
-            bias=bias,
-            pad_type=pad_type,
-            norm_type=norm_type,
-            act_type=last_act,
-            mode=mode,
-            convtype=convtype,
-            spectral_norm=spectral_norm,
-        )
-
-    def forward(self, x):
+        self.conv1 = conv_block(nf, gc)
+        self.conv2 = conv_block(nf + gc, gc)
+        self.conv3 = conv_block(nf + 2 * gc, gc)
+        self.conv4 = conv_block(nf + 3 * gc, gc)
+        # Wrapped in Sequential because of key in state dict.
+        self.conv5 = nn.Sequential(nn.Conv2d(nf + 4 * gc, nf, kernel_size=3, padding=1))
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         x1 = self.conv1(x)
         x2 = self.conv2(torch.cat((x, x1), 1))
-        if self.conv1x1:
-            x2 = x2 + self.conv1x1(x)
         x3 = self.conv3(torch.cat((x, x1, x2), 1))
         x4 = self.conv4(torch.cat((x, x1, x2, x3), 1))
-        if self.conv1x1:
-            x4 = x4 + x2
         x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
-        if self.noise:
-            return self.noise(x5.mul(0.2) + x)
-        else:
-            return x5 * 0.2 + x
+        return x5 * 0.2 + x
 
 
-####################
-# ESRGANplus
-####################
-
-
-class GaussianNoise(nn.Module):
-    def __init__(self, sigma=0.1, is_relative_detach=False):
-        super().__init__()
-        self.sigma = sigma
-        self.is_relative_detach = is_relative_detach
-        self.noise = torch.tensor(0, dtype=torch.float)
-
-    def forward(self, x):
-        if self.training and self.sigma != 0:
-            self.noise = self.noise.to(device=x.device, dtype=x.device)
-            scale = self.sigma * x.detach() if self.is_relative_detach else self.sigma * x
-            sampled_noise = self.noise.repeat(*x.size()).normal_() * scale
-            x = x + sampled_noise
-        return x
-
-
-def conv1x1(in_planes, out_planes, stride=1):
-    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
-
-
-####################
-# SRVGGNetCompact
-####################
-
-
-class SRVGGNetCompact(nn.Module):
-    """A compact VGG-style network structure for super-resolution.
-    This class is copied from https://github.com/xinntao/Real-ESRGAN
-    """
-
-    def __init__(
-        self,
-        num_in_ch=3,
-        num_out_ch=3,
-        num_feat=64,
-        num_conv=16,
-        upscale=4,
-        act_type="prelu",
-    ):
-        super(SRVGGNetCompact, self).__init__()
-        self.num_in_ch = num_in_ch
-        self.num_out_ch = num_out_ch
-        self.num_feat = num_feat
-        self.num_conv = num_conv
-        self.upscale = upscale
-        self.act_type = act_type
-
-        self.body = nn.ModuleList()
-        # the first conv
-        self.body.append(nn.Conv2d(num_in_ch, num_feat, 3, 1, 1))
-        # the first activation
-        if act_type == "relu":
-            activation = nn.ReLU(inplace=True)
-        elif act_type == "prelu":
-            activation = nn.PReLU(num_parameters=num_feat)
-        elif act_type == "leakyrelu":
-            activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
-        self.body.append(activation)
-
-        # the body structure
-        for _ in range(num_conv):
-            self.body.append(nn.Conv2d(num_feat, num_feat, 3, 1, 1))
-            # activation
-            if act_type == "relu":
-                activation = nn.ReLU(inplace=True)
-            elif act_type == "prelu":
-                activation = nn.PReLU(num_parameters=num_feat)
-            elif act_type == "leakyrelu":
-                activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
-            self.body.append(activation)
-
-        # the last conv
-        self.body.append(nn.Conv2d(num_feat, num_out_ch * upscale * upscale, 3, 1, 1))
-        # upsample
-        self.upsampler = nn.PixelShuffle(upscale)
-
-    def forward(self, x):
-        out = x
-        for i in range(0, len(self.body)):
-            out = self.body[i](out)
-
-        out = self.upsampler(out)
-        # add the nearest upsampled image, so that the network learns the residual
-        base = F.interpolate(x, scale_factor=self.upscale, mode="nearest")
-        out += base
-        return out
-
-
-####################
-# Upsampler
-####################
-
-
-class Upsample(nn.Module):
-    r"""Upsamples a given multi-channel 1D (temporal), 2D (spatial) or 3D (volumetric) data.
-    The input data is assumed to be of the form
-    `minibatch x channels x [optional depth] x [optional height] x width`.
-    """
-
-    def __init__(self, size=None, scale_factor=None, mode="nearest", align_corners=None):
-        super(Upsample, self).__init__()
-        if isinstance(scale_factor, tuple):
-            self.scale_factor = tuple(float(factor) for factor in scale_factor)
-        else:
-            self.scale_factor = float(scale_factor) if scale_factor else None
-        self.mode = mode
-        self.size = size
-        self.align_corners = align_corners
-
-    def forward(self, x):
-        return nn.functional.interpolate(
-            x,
-            size=self.size,
-            scale_factor=self.scale_factor,
-            mode=self.mode,
-            align_corners=self.align_corners,
-        )
-
-    def extra_repr(self):
-        if self.scale_factor is not None:
-            info = f"scale_factor={self.scale_factor}"
-        else:
-            info = f"size={self.size}"
-        info += f", mode={self.mode}"
-        return info
-
-
-def pixel_unshuffle(x, scale):
-    """Pixel unshuffle.
-    Args:
-        x (Tensor): Input feature with shape (b, c, hh, hw).
-        scale (int): Downsample ratio.
-    Returns:
-        Tensor: the pixel unshuffled feature.
-    """
-    b, c, hh, hw = x.size()
-    out_channel = c * (scale**2)
-    assert hh % scale == 0 and hw % scale == 0
-    h = hh // scale
-    w = hw // scale
-    x_view = x.view(b, c, h, scale, w, scale)
-    return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w)
-
-
-def pixelshuffle_block(
-    in_nc,
-    out_nc,
-    upscale_factor=2,
-    kernel_size=3,
-    stride=1,
-    bias=True,
-    pad_type="zero",
-    norm_type=None,
-    act_type="relu",
-    convtype="Conv2D",
-):
-    """
-    Pixel shuffle layer
-    (Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional
-    Neural Network, CVPR17)
-    """
-    conv = conv_block(
-        in_nc,
-        out_nc * (upscale_factor**2),
-        kernel_size,
-        stride,
-        bias=bias,
-        pad_type=pad_type,
-        norm_type=None,
-        act_type=None,
-        convtype=convtype,
-    )
-    pixel_shuffle = nn.PixelShuffle(upscale_factor)
-
-    n = norm(norm_type, out_nc) if norm_type else None
-    a = act(act_type) if act_type else None
-    return sequential(conv, pixel_shuffle, n, a)
-
-
-def upconv_block(
-    in_nc,
-    out_nc,
-    upscale_factor=2,
-    kernel_size=3,
-    stride=1,
-    bias=True,
-    pad_type="zero",
-    norm_type=None,
-    act_type="relu",
-    mode="nearest",
-    convtype="Conv2D",
-):
-    """Upconv layer"""
-    upscale_factor = (1, upscale_factor, upscale_factor) if convtype == "Conv3D" else upscale_factor
-    upsample = Upsample(scale_factor=upscale_factor, mode=mode)
-    conv = conv_block(
-        in_nc,
-        out_nc,
-        kernel_size,
-        stride,
-        bias=bias,
-        pad_type=pad_type,
-        norm_type=norm_type,
-        act_type=act_type,
-        convtype=convtype,
-    )
-    return sequential(upsample, conv)
-
-
-####################
-# Basic blocks
-####################
-
-
-def make_layer(basic_block, num_basic_block, **kwarg):
-    """Make layers by stacking the same blocks.
-    Args:
-        basic_block (nn.module): nn.module class for basic block. (block)
-        num_basic_block (int): number of blocks. (n_layers)
-    Returns:
-        nn.Sequential: Stacked blocks in nn.Sequential.
-    """
-    layers = []
-    for _ in range(num_basic_block):
-        layers.append(basic_block(**kwarg))
-    return nn.Sequential(*layers)
-
-
-def act(act_type, inplace=True, neg_slope=0.2, n_prelu=1, beta=1.0):
-    """activation helper"""
-    act_type = act_type.lower()
-    if act_type == "relu":
-        layer = nn.ReLU(inplace)
-    elif act_type in ("leakyrelu", "lrelu"):
-        layer = nn.LeakyReLU(neg_slope, inplace)
-    elif act_type == "prelu":
-        layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope)
-    elif act_type == "tanh":  # [-1, 1] range output
-        layer = nn.Tanh()
-    elif act_type == "sigmoid":  # [0, 1] range output
-        layer = nn.Sigmoid()
-    else:
-        raise NotImplementedError(f"activation layer [{act_type}] is not found")
-    return layer
-
-
-class Identity(nn.Module):
-    def __init__(self, *kwargs):
-        super(Identity, self).__init__()
-
-    def forward(self, x, *kwargs):
-        return x
-
-
-def norm(norm_type, nc):
-    """Return a normalization layer"""
-    norm_type = norm_type.lower()
-    if norm_type == "batch":
-        layer = nn.BatchNorm2d(nc, affine=True)
-    elif norm_type == "instance":
-        layer = nn.InstanceNorm2d(nc, affine=False)
-    elif norm_type == "none":
-
-        def norm_layer(x):
-            return Identity()
-    else:
-        raise NotImplementedError(f"normalization layer [{norm_type}] is not found")
-    return layer
-
-
-def pad(pad_type, padding):
-    """padding layer helper"""
-    pad_type = pad_type.lower()
-    if padding == 0:
-        return None
-    if pad_type == "reflect":
-        layer = nn.ReflectionPad2d(padding)
-    elif pad_type == "replicate":
-        layer = nn.ReplicationPad2d(padding)
-    elif pad_type == "zero":
-        layer = nn.ZeroPad2d(padding)
-    else:
-        raise NotImplementedError(f"padding layer [{pad_type}] is not implemented")
-    return layer
-
-
-def get_valid_padding(kernel_size, dilation):
-    kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1)
-    padding = (kernel_size - 1) // 2
-    return padding
+class RRDB(nn.Module):
+    """
+    Residual in Residual Dense Block
+    (ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks)
+    """
+
+    def __init__(self, nf: int) -> None:
+        super().__init__()  # type: ignore[reportUnknownMemberType]
+        self.RDB1 = ResidualDenseBlock_5C(nf)
+        self.RDB2 = ResidualDenseBlock_5C(nf)
+        self.RDB3 = ResidualDenseBlock_5C(nf)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        out = self.RDB1(x)
+        out = self.RDB2(out)
+        out = self.RDB3(out)
+        return out * 0.2 + x
+
+
+class Upsample2x(nn.Module):
+    """Upsample 2x."""
+
+    def __init__(self) -> None:
+        super().__init__()  # type: ignore[reportUnknownMemberType]
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return nn.functional.interpolate(x, scale_factor=2.0)  # type: ignore
 
 
 class ShortcutBlock(nn.Module):
     """Elementwise sum the output of a submodule to its input"""
 
-    def __init__(self, submodule):
-        super(ShortcutBlock, self).__init__()
+    def __init__(self, submodule: nn.Module) -> None:
+        super().__init__()  # type: ignore[reportUnknownMemberType]
         self.sub = submodule
 
-    def forward(self, x):
-        output = x + self.sub(x)
-        return output
-
-    def __repr__(self):
-        return "Identity + \n|" + self.sub.__repr__().replace("\n", "\n|")
-
-
-def sequential(*args):
-    """Flatten Sequential. It unwraps nn.Sequential."""
-    if len(args) == 1:
-        if isinstance(args[0], OrderedDict):
-            raise NotImplementedError("sequential does not support OrderedDict input.")
-        return args[0]  # No sequential is needed.
-    modules = []
-    for module in args:
-        if isinstance(module, nn.Sequential):
-            for submodule in module.children():
-                modules.append(submodule)
-        elif isinstance(module, nn.Module):
-            modules.append(module)
-    return nn.Sequential(*modules)
-
-
-def conv_block(
-    in_nc,
-    out_nc,
-    kernel_size,
-    stride=1,
-    dilation=1,
-    groups=1,
-    bias=True,
-    pad_type="zero",
-    norm_type=None,
-    act_type="relu",
-    mode="CNA",
-    convtype="Conv2D",
-    spectral_norm=False,
-):
-    """Conv layer with padding, normalization, activation"""
-    assert mode in ["CNA", "NAC", "CNAC"], f"Wrong conv mode [{mode}]"
-    padding = get_valid_padding(kernel_size, dilation)
-    p = pad(pad_type, padding) if pad_type and pad_type != "zero" else None
-    padding = padding if pad_type == "zero" else 0
-
-    if convtype == "PartialConv2D":
-        # this is definitely not going to work, but PartialConv2d doesn't work anyway and this shuts up static analyzer
-        from torchvision.ops import PartialConv2d
-
-        c = PartialConv2d(
-            in_nc,
-            out_nc,
-            kernel_size=kernel_size,
-            stride=stride,
-            padding=padding,
-            dilation=dilation,
-            bias=bias,
-            groups=groups,
-        )
-    elif convtype == "DeformConv2D":
-        from torchvision.ops import DeformConv2d  # not tested
-
-        c = DeformConv2d(
-            in_nc,
-            out_nc,
-            kernel_size=kernel_size,
-            stride=stride,
-            padding=padding,
-            dilation=dilation,
-            bias=bias,
-            groups=groups,
-        )
-    elif convtype == "Conv3D":
-        c = nn.Conv3d(
-            in_nc,
-            out_nc,
-            kernel_size=kernel_size,
-            stride=stride,
-            padding=padding,
-            dilation=dilation,
-            bias=bias,
-            groups=groups,
-        )
-    else:
-        c = nn.Conv2d(
-            in_nc,
-            out_nc,
-            kernel_size=kernel_size,
-            stride=stride,
-            padding=padding,
-            dilation=dilation,
-            bias=bias,
-            groups=groups,
-        )
-
-    if spectral_norm:
-        c = nn.utils.spectral_norm(c)
-
-    a = act(act_type) if act_type else None
-    if "CNA" in mode:
-        n = norm(norm_type, out_nc) if norm_type else None
-        return sequential(p, c, n, a)
-    elif mode == "NAC":
-        if norm_type is None and act_type is not None:
-            a = act(act_type, inplace=False)
-        n = norm(norm_type, in_nc) if norm_type else None
-        return sequential(n, a, p, c)
-
-
-def load_models(
-    model_path: Path,
-    command_path: str = None,
-) -> list:
-    """
-    A one-and done loader to try finding the desired models in specified directories.
-
-    @param download_name: Specify to download from model_url immediately.
-    @param model_url: If no other models are found, this will be downloaded on upscale.
-    @param model_path: The location to store/find models in.
-    @param command_path: A command-line argument to search for models in first.
-    @param ext_filter: An optional list of filename extensions to filter by
-    @return: A list of paths containing the desired model(s)
-    """
-    output = []
-
-    try:
-        places = []
-        if command_path is not None and command_path != model_path:
-            pretrained_path = os.path.join(command_path, "experiments/pretrained_models")
-            if os.path.exists(pretrained_path):
-                print(f"Appending path: {pretrained_path}")
-                places.append(pretrained_path)
-            elif os.path.exists(command_path):
-                places.append(command_path)
-
-        places.append(model_path)
-
-    except Exception:
-        pass
-
-    return output
-
-
-def mod2normal(state_dict):
-    # this code is copied from https://github.com/victorca25/iNNfer
-    if "conv_first.weight" in state_dict:
-        crt_net = {}
-        items = list(state_dict)
-
-        crt_net["model.0.weight"] = state_dict["conv_first.weight"]
-        crt_net["model.0.bias"] = state_dict["conv_first.bias"]
-
-        for k in items.copy():
-            if "RDB" in k:
-                ori_k = k.replace("RRDB_trunk.", "model.1.sub.")
-                if ".weight" in k:
-                    ori_k = ori_k.replace(".weight", ".0.weight")
-                elif ".bias" in k:
-                    ori_k = ori_k.replace(".bias", ".0.bias")
-                crt_net[ori_k] = state_dict[k]
-                items.remove(k)
-
-        crt_net["model.1.sub.23.weight"] = state_dict["trunk_conv.weight"]
-        crt_net["model.1.sub.23.bias"] = state_dict["trunk_conv.bias"]
-        crt_net["model.3.weight"] = state_dict["upconv1.weight"]
-        crt_net["model.3.bias"] = state_dict["upconv1.bias"]
-        crt_net["model.6.weight"] = state_dict["upconv2.weight"]
-        crt_net["model.6.bias"] = state_dict["upconv2.bias"]
-        crt_net["model.8.weight"] = state_dict["HRconv.weight"]
-        crt_net["model.8.bias"] = state_dict["HRconv.bias"]
-        crt_net["model.10.weight"] = state_dict["conv_last.weight"]
-        crt_net["model.10.bias"] = state_dict["conv_last.bias"]
-        state_dict = crt_net
-    return state_dict
-
-
-def resrgan2normal(state_dict, nb=23):
-    # this code is copied from https://github.com/victorca25/iNNfer
-    if "conv_first.weight" in state_dict and "body.0.rdb1.conv1.weight" in state_dict:
-        re8x = 0
-        crt_net = {}
-        items = list(state_dict)
-
-        crt_net["model.0.weight"] = state_dict["conv_first.weight"]
-        crt_net["model.0.bias"] = state_dict["conv_first.bias"]
-
-        for k in items.copy():
-            if "rdb" in k:
-                ori_k = k.replace("body.", "model.1.sub.")
-                ori_k = ori_k.replace(".rdb", ".RDB")
-                if ".weight" in k:
-                    ori_k = ori_k.replace(".weight", ".0.weight")
-                elif ".bias" in k:
-                    ori_k = ori_k.replace(".bias", ".0.bias")
-                crt_net[ori_k] = state_dict[k]
-                items.remove(k)
-
-        crt_net[f"model.1.sub.{nb}.weight"] = state_dict["conv_body.weight"]
-        crt_net[f"model.1.sub.{nb}.bias"] = state_dict["conv_body.bias"]
-        crt_net["model.3.weight"] = state_dict["conv_up1.weight"]
-        crt_net["model.3.bias"] = state_dict["conv_up1.bias"]
-        crt_net["model.6.weight"] = state_dict["conv_up2.weight"]
-        crt_net["model.6.bias"] = state_dict["conv_up2.bias"]
-
-        if "conv_up3.weight" in state_dict:
-            # modification supporting: https://github.com/ai-forever/Real-ESRGAN/blob/main/RealESRGAN/rrdbnet_arch.py
-            re8x = 3
-            crt_net["model.9.weight"] = state_dict["conv_up3.weight"]
-            crt_net["model.9.bias"] = state_dict["conv_up3.bias"]
-
-        crt_net[f"model.{8+re8x}.weight"] = state_dict["conv_hr.weight"]
-        crt_net[f"model.{8+re8x}.bias"] = state_dict["conv_hr.bias"]
-        crt_net[f"model.{10+re8x}.weight"] = state_dict["conv_last.weight"]
-        crt_net[f"model.{10+re8x}.bias"] = state_dict["conv_last.bias"]
-
-        state_dict = crt_net
-    return state_dict
-
-
-def infer_params(state_dict):
-    # this code is copied from https://github.com/victorca25/iNNfer
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return x + self.sub(x)
+
+
+class RRDBNet(nn.Module):
+    def __init__(self, in_nc: int, out_nc: int, nf: int, nb: int) -> None:
+        super().__init__()  # type: ignore[reportUnknownMemberType]
+        assert in_nc % 4 != 0  # in_nc is 3
+
+        self.model = nn.Sequential(
+            nn.Conv2d(in_nc, nf, kernel_size=3, padding=1),
+            ShortcutBlock(
+                nn.Sequential(
+                    *(RRDB(nf) for _ in range(nb)),
+                    nn.Conv2d(nf, nf, kernel_size=3, padding=1),
+                )
+            ),
+            Upsample2x(),
+            nn.Conv2d(nf, nf, kernel_size=3, padding=1),
+            nn.LeakyReLU(negative_slope=0.2, inplace=True),
+            Upsample2x(),
+            nn.Conv2d(nf, nf, kernel_size=3, padding=1),
+            nn.LeakyReLU(negative_slope=0.2, inplace=True),
+            nn.Conv2d(nf, nf, kernel_size=3, padding=1),
+            nn.LeakyReLU(negative_slope=0.2, inplace=True),
+            nn.Conv2d(nf, out_nc, kernel_size=3, padding=1),
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.model(x)
+
+
+def infer_params(state_dict: dict[str, torch.Tensor]) -> tuple[int, int, int, int, int]:
+    # this code is adapted from https://github.com/victorca25/iNNfer
     scale2x = 0
     scalemin = 6
     n_uplayer = 0
-    plus = False
+    out_nc = 0
+    nb = 0
 
     for block in list(state_dict):
         parts = block.split(".")
@@ -878,65 +141,66 @@ def infer_params(state_dict):
             if part_num > n_uplayer:
                 n_uplayer = part_num
                 out_nc = state_dict[block].shape[0]
-        if not plus and "conv1x1" in block:
-            plus = True
+        assert "conv1x1" not in block  # no ESRGANPlus
 
     nf = state_dict["model.0.weight"].shape[0]
    in_nc = state_dict["model.0.weight"].shape[1]
-    out_nc = out_nc
     scale = 2**scale2x
 
-    return in_nc, out_nc, nf, nb, plus, scale
+    assert out_nc > 0
+    assert nb > 0
+
+    return in_nc, out_nc, nf, nb, scale  # 3, 3, 64, 23, 4
+
+
+Tile = tuple[int, int, Image.Image]
+Tiles = list[tuple[int, int, list[Tile]]]
 
 
 # https://github.com/philz1337x/clarity-upscaler/blob/e0cd797198d1e0e745400c04d8d1b98ae508c73b/modules/images.py#L64
-Grid = namedtuple("Grid", ["tiles", "tile_w", "tile_h", "image_w", "image_h", "overlap"])
+class Grid(NamedTuple):
+    tiles: Tiles
+    tile_w: int
+    tile_h: int
+    image_w: int
+    image_h: int
+    overlap: int
 
 
-# https://github.com/philz1337x/clarity-upscaler/blob/e0cd797198d1e0e745400c04d8d1b98ae508c73b/modules/images.py#L67
-def split_grid(image, tile_w=512, tile_h=512, overlap=64):
+# adapted from https://github.com/philz1337x/clarity-upscaler/blob/e0cd797198d1e0e745400c04d8d1b98ae508c73b/modules/images.py#L67
+def split_grid(image: Image.Image, tile_w: int = 512, tile_h: int = 512, overlap: int = 64) -> Grid:
     w = image.width
     h = image.height
 
     non_overlap_width = tile_w - overlap
     non_overlap_height = tile_h - overlap
 
-    cols = math.ceil((w - overlap) / non_overlap_width)
-    rows = math.ceil((h - overlap) / non_overlap_height)
+    cols = max(1, math.ceil((w - overlap) / non_overlap_width))
+    rows = max(1, math.ceil((h - overlap) / non_overlap_height))
 
     dx = (w - tile_w) / (cols - 1) if cols > 1 else 0
     dy = (h - tile_h) / (rows - 1) if rows > 1 else 0
 
     grid = Grid([], tile_w, tile_h, w, h, overlap)
     for row in range(rows):
-        row_images = []
-
-        y = int(row * dy)
-
-        if y + tile_h >= h:
-            y = h - tile_h
-
+        row_images: list[Tile] = []
+        y1 = max(min(int(row * dy), h - tile_h), 0)
+        y2 = min(y1 + tile_h, h)
        for col in range(cols):
-            x = int(col * dx)
-
-            if x + tile_w >= w:
-                x = w - tile_w
-
-            tile = image.crop((x, y, x + tile_w, y + tile_h))
-
-            row_images.append([x, tile_w, tile])
-
-        grid.tiles.append([y, tile_h, row_images])
+            x1 = max(min(int(col * dx), w - tile_w), 0)
+            x2 = min(x1 + tile_w, w)
+            tile = image.crop((x1, y1, x2, y2))
+            row_images.append((x1, tile_w, tile))
+        grid.tiles.append((y1, tile_h, row_images))
 
     return grid
 
 
 # https://github.com/philz1337x/clarity-upscaler/blob/e0cd797198d1e0e745400c04d8d1b98ae508c73b/modules/images.py#L104
-def combine_grid(grid):
-    def make_mask_image(r):
+def combine_grid(grid: Grid):
+    def make_mask_image(r: npt.NDArray[np.float32]) -> Image.Image:
         r = r * 255 / grid.overlap
-        r = r.astype(np.uint8)
-        return Image.fromarray(r, "L")
+        return Image.fromarray(r.astype(np.uint8), "L")
 
     mask_w = make_mask_image(
         np.arange(grid.overlap, dtype=np.float32).reshape((1, grid.overlap)).repeat(grid.tile_h, axis=0)
@@ -975,10 +239,10 @@ def combine_grid(grid):
 
 class UpscalerESRGAN:
     def __init__(self, model_path: Path, device: torch.device, dtype: torch.dtype):
-        self.device = device
-        self.dtype = dtype
         self.model_path = model_path
+        self.device = device
         self.model = self.load_model(model_path)
+        self.to(device, dtype)
 
     def __call__(self, img: Image.Image) -> Image.Image:
         return self.upscale_without_tiling(img)
@@ -988,51 +252,25 @@ class UpscalerESRGAN:
         self.dtype = dtype
         self.model.to(device=device, dtype=dtype)
 
-    def load_model(self, path: Path) -> SRVGGNetCompact | RRDBNet:
+    def load_model(self, path: Path) -> RRDBNet:
         filename = path
-        state_dict = torch.load(filename, weights_only=True, map_location=self.device)
-
-        if "params_ema" in state_dict:
-            state_dict = state_dict["params_ema"]
-        elif "params" in state_dict:
-            state_dict = state_dict["params"]
-            num_conv = 16 if "realesr-animevideov3" in filename else 32
-            model = SRVGGNetCompact(
-                num_in_ch=3,
-                num_out_ch=3,
-                num_feat=64,
-                num_conv=num_conv,
-                upscale=4,
-                act_type="prelu",
-            )
-            model.load_state_dict(state_dict)
-            model.eval()
-            return model
-
-        if "body.0.rdb1.conv1.weight" in state_dict and "conv_first.weight" in state_dict:
-            nb = 6 if "RealESRGAN_x4plus_anime_6B" in filename else 23
-            state_dict = resrgan2normal(state_dict, nb)
-        elif "conv_first.weight" in state_dict:
-            state_dict = mod2normal(state_dict)
-        elif "model.0.weight" not in state_dict:
-            raise Exception("The file is not a recognized ESRGAN model.")
-
-        in_nc, out_nc, nf, nb, plus, mscale = infer_params(state_dict)
-
-        model = RRDBNet(in_nc=in_nc, out_nc=out_nc, nf=nf, nb=nb, upscale=mscale, plus=plus)
+        state_dict: dict[str, torch.Tensor] = torch.load(filename, weights_only=True, map_location=self.device)  # type: ignore
+        in_nc, out_nc, nf, nb, upscale = infer_params(state_dict)
+        assert upscale == 4, "Only 4x upscaling is supported"
+        model = RRDBNet(in_nc=in_nc, out_nc=out_nc, nf=nf, nb=nb)
         model.load_state_dict(state_dict)
         model.eval()
 
         return model
 
     def upscale_without_tiling(self, img: Image.Image) -> Image.Image:
-        img = np.array(img)
-        img = img[:, :, ::-1]
-        img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255
-        img = torch.from_numpy(img).float()
-        img = img.unsqueeze(0).to(device=self.device, dtype=self.dtype)
+        img_np = np.array(img)
+        img_np = img_np[:, :, ::-1]
+        img_np = np.ascontiguousarray(np.transpose(img_np, (2, 0, 1))) / 255
+        img_t = torch.from_numpy(img_np).float()  # type: ignore
+        img_t = img_t.unsqueeze(0).to(device=self.device, dtype=self.dtype)
         with torch.no_grad():
-            output = self.model(img)
+            output = self.model(img_t)
         output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
         output = 255.0 * np.moveaxis(output, 0, 2)
         output = output.astype(np.uint8)
@@ -1041,20 +279,19 @@ class UpscalerESRGAN:
 
     # https://github.com/philz1337x/clarity-upscaler/blob/e0cd797198d1e0e745400c04d8d1b98ae508c73b/modules/esrgan_model.py#L208
     def upscale_with_tiling(self, img: Image.Image) -> Image.Image:
+        img = img.convert("RGB")
         grid = split_grid(img)
-        newtiles = []
-        scale_factor = 1
+        newtiles: Tiles = []
+        scale_factor: int = 1
 
         for y, h, row in grid.tiles:
-            newrow = []
+            newrow: list[Tile] = []
             for tiledata in row:
                 x, w, tile = tiledata
-
                 output = self.upscale_without_tiling(tile)
                 scale_factor = output.width // tile.width
-
-                newrow.append([x * scale_factor, w * scale_factor, output])
-            newtiles.append([y * scale_factor, h * scale_factor, newrow])
+                newrow.append((x * scale_factor, w * scale_factor, output))
+            newtiles.append((y * scale_factor, h * scale_factor, newrow))
 
         newgrid = Grid(
             newtiles,
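For intuition on the tiling helpers above: `split_grid` lays 512×512 tiles over the image with a 64 px overlap, and the `max(1, ...)` guard added in this commit keeps `cols` and `rows` positive when the image is smaller than a single tile. A small standalone sketch of the arithmetic (illustrative values only, mirroring the computation in `split_grid`):

```python
import math


def tile_counts(w: int, h: int, tile: int = 512, overlap: int = 64) -> tuple[int, int]:
    """Tile columns and rows, as computed in split_grid."""
    step = tile - overlap  # stride between tile origins
    cols = max(1, math.ceil((w - overlap) / step))
    rows = max(1, math.ceil((h - overlap) / step))
    return cols, rows


assert tile_counts(2048, 1536) == (5, 4)  # large image: an overlapping 5x4 grid
assert tile_counts(300, 200) == (1, 1)  # smaller than one tile: the guard kicks in
```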
 
 
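End to end, the rewrite reduces the public surface of `src/esrgan_model.py` to `UpscalerESRGAN` plus the grid helpers, and `load_model` now accepts only plain 4x RRDBNet checkpoints (`infer_params` asserts the shape instead of dispatching across model families). A usage sketch under those constraints (checkpoint path and import layout are assumptions):

```python
from pathlib import Path

import torch
from PIL import Image

from src.esrgan_model import UpscalerESRGAN  # assumes src/ is importable as a package

upscaler = UpscalerESRGAN(
    Path("checkpoints/esrgan_4x.pth"),  # hypothetical path; must be a 4x RRDBNet state dict
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    dtype=torch.float32,
)

img = Image.open("input.png").convert("RGB")
upscaled = upscaler.upscale_with_tiling(img)  # 512 px tiles, 4x upscale, blended overlaps
upscaled.save("output.png")
```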