rizavelioglu committed
Commit 835cd00
1 Parent(s): 80cc34d

enable upscaler
Files changed (3)
  1. README.md +1 -1
  2. app.py +68 -38
  3. esrgan_model.py +307 -0
README.md CHANGED
@@ -2,7 +2,7 @@
 title: TryOffDiff
 emoji: 🔥
 colorFrom: yellow
-colorTo: yellow
+colorTo: orange
 sdk: gradio
 sdk_version: 5.7.0
 app_file: app.py
app.py CHANGED
@@ -1,14 +1,17 @@
 import os
+from pathlib import Path
+
+from diffusers import AutoencoderKL, EulerDiscreteScheduler, UNet2DConditionModel
+from esrgan_model import UpscalerESRGAN
 import gradio as gr
+from huggingface_hub import hf_hub_download
 import spaces
 import torch
 import torch.nn as nn
-from diffusers import EulerDiscreteScheduler, AutoencoderKL, UNet2DConditionModel
-from huggingface_hub import hf_hub_download
-from transformers import SiglipImageProcessor, SiglipVisionModel
 from torchvision.io import read_image
 import torchvision.transforms.v2 as transforms
 from torchvision.utils import make_grid
+from transformers import SiglipImageProcessor, SiglipVisionModel


 class TryOffDiff(nn.Module):
@@ -44,29 +47,26 @@ device = "cuda" if torch.cuda.is_available() else "cpu"

 # Initialize Image Encoder
 img_processor = SiglipImageProcessor.from_pretrained(
-    "google/siglip-base-patch16-512",
-    do_resize=False,
-    do_rescale=False,
-    do_normalize=False
+    "google/siglip-base-patch16-512", do_resize=False, do_rescale=False, do_normalize=False
 )
 img_enc = SiglipVisionModel.from_pretrained("google/siglip-base-patch16-512").eval().to(device)
-img_enc_transform = transforms.Compose([
-    PadToSquare(),  # Custom transform to pad the image to a square
-    transforms.Resize((512, 512)),
-    transforms.ToDtype(torch.float32, scale=True),
-    transforms.Normalize(mean=[0.5], std=[0.5]),
-])
+img_enc_transform = transforms.Compose(
+    [
+        PadToSquare(),  # Custom transform to pad the image to a square
+        transforms.Resize((512, 512)),
+        transforms.ToDtype(torch.float32, scale=True),
+        transforms.Normalize(mean=[0.5], std=[0.5]),
+    ]
+)

 # Load TryOffDiff Model
 path_model = hf_hub_download(
     repo_id="rizavelioglu/tryoffdiff",
     filename="tryoffdiff.pth",  # or one of ["ldm-1", "ldm-2", "ldm-3", ...],
-    force_download=False
+    force_download=False,
 )
 path_scheduler = hf_hub_download(
-    repo_id="rizavelioglu/tryoffdiff",
-    filename="scheduler/scheduler_config.json",
-    force_download=False
+    repo_id="rizavelioglu/tryoffdiff", filename="scheduler/scheduler_config.json", force_download=False
 )
 net = TryOffDiff()
 net.load_state_dict(torch.load(path_model, weights_only=False))
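For context, a minimal sketch of how these objects produce the SigLIP conditioning embedding inside `generate_image`; the lines that build `cond_image` fall outside the hunks shown, and `"person.jpg"` is a hypothetical input path:

```python
import torch
from torchvision.io import read_image

# img_enc_transform pads to a square, resizes to 512x512, and normalizes to [-1, 1];
# the SiglipImageProcessor has resize/rescale/normalize disabled, so it only packs
# the already-preprocessed tensor into model inputs.
cond_image = img_enc_transform(read_image("person.jpg"))
inputs = {k: v.to(img_enc.device) for k, v in img_processor(images=cond_image, return_tensors="pt").items()}
with torch.no_grad():
    cond_emb = img_enc(**inputs).last_hidden_state  # expected shape [1, 1024, 768] for this checkpoint
```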
@@ -74,19 +74,35 @@ net.eval().to(device)

 # Initialize VAE (only Decoder will be used)
 vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae").eval().to(device)
+
+# Initialize the upscaler
+upscaler = UpscalerESRGAN(
+    model_path=Path(
+        hf_hub_download(
+            repo_id="philz1337x/upscaler",
+            filename="4x-UltraSharp.pth",
+            # revision="011deacac8270114eb7d2eeff4fe6fa9a837be70",
+        )
+    ),
+    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
+    dtype=torch.float32,
+)
+
 torch.cuda.empty_cache()


 # Define image generation function
 @spaces.GPU(duration=10)
 @torch.no_grad()
-def generate_image(input_image, seed=42, guidance_scale=2.0, num_inference_steps=50, is_upscale=False):
+def generate_image(
+    input_image, seed: int = 42, guidance_scale: float = 2.0, num_inference_steps: int = 50, is_upscale: bool = False
+):
     # Configure scheduler
     scheduler = EulerDiscreteScheduler.from_pretrained(path_scheduler)
     scheduler.is_scale_input_called = True  # suppress warning
     scheduler.set_timesteps(num_inference_steps)

-    # Set random seed
+    # Set seed for reproducibility
     generator = torch.Generator(device=device).manual_seed(seed)
     x = torch.randn(1, 4, 64, 64, generator=generator, device=device)

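The seeded `torch.Generator` above is what makes sampling reproducible: the same seed always yields the same initial latent. A self-contained illustration, using a CPU generator for portability:

```python
import torch

g1 = torch.Generator(device="cpu").manual_seed(42)
g2 = torch.Generator(device="cpu").manual_seed(42)
# Identical seeds produce identical 1x4x64x64 starting latents.
assert torch.equal(torch.randn(1, 4, 64, 64, generator=g1),
                   torch.randn(1, 4, 64, 64, generator=g2))
```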
@@ -95,46 +111,50 @@ def generate_image(input_image, seed=42, guidance_scale=2.0, num_inference_steps
     inputs = {k: v.to(img_enc.device) for k, v in img_processor(images=cond_image, return_tensors="pt").items()}
     cond_emb = img_enc(**inputs).last_hidden_state.to(device)

-    # Prepare unconditioned embeddings
+    # Prepare unconditioned embeddings (only if guidance is enabled)
     uncond_emb = torch.zeros_like(cond_emb) if guidance_scale > 1 else None

-    # Denoising loop with mixed precision
+    # Diffusion denoising loop with mixed precision for efficiency
     with torch.autocast(device):
         for t in scheduler.timesteps:
             if guidance_scale > 1:
-                noise_pred = net(
-                    torch.cat([x] * 2), t, torch.cat([uncond_emb, cond_emb])
-                ).chunk(2)
+                # Classifier-Free Guidance (CFG)
+                noise_pred = net(torch.cat([x] * 2), t, torch.cat([uncond_emb, cond_emb])).chunk(2)
                 noise_pred = noise_pred[0] + guidance_scale * (noise_pred[1] - noise_pred[0])
             else:
+                # Standard prediction
                 noise_pred = net(x, t, cond_emb)

+            # Scheduler step
             scheduler_output = scheduler.step(noise_pred, t, x)
             x = scheduler_output.prev_sample

-    # Decode preds
+    # Decode predictions from latent space
     decoded = vae.decode(1 / 0.18215 * scheduler_output.pred_original_sample).sample
-    images = (decoded / 2 + 0.5).cpu()
-
+    images = (decoded / 2 + 0.5).cpu()
+
     # Create grid
     grid = make_grid(images, nrow=len(images), normalize=True, scale_each=True)
+    output_image = transforms.ToPILImage()(grid)
+
+    # Optionally upscale the output image
     if is_upscale:
-        pass
-    else:
-        return transforms.ToPILImage()(grid)
+        output_image = upscaler(output_image)
+
+    return output_image


 title = "Virtual Try-Off Generator"
 description = r"""
-This is the demo of the paper <a href="https://arxiv.org/abs/2411.18350">TryOffDiff: Virtual-Try-Off via High-Fidelity Garment Reconstruction using Diffusion Models</a>.
-<br>Upload an image of a clothed individual to generate a standardized garment image using TryOffDiff.
+This is the demo of the paper <a href="https://arxiv.org/abs/2411.18350">TryOffDiff: Virtual-Try-Off via High-Fidelity Garment Reconstruction using Diffusion Models</a>.
+<br>Upload an image of a clothed individual to generate a standardized garment image using TryOffDiff.
 <br> Check out the <a href="https://rizavelioglu.github.io/tryoffdiff/">project page</a> for more information.
 """
 article = r"""
 Example images are sampled from the `VITON-HD-test` set, which the models did not see during training.

-<br>**Citation** <br>If you find our work useful in your research, please consider giving a star ⭐ and
-a citation:
+<br>**Citation** <br>If you find our work useful in your research, please consider giving a star ⭐ and
+a citation:
 ```
 @article{velioglu2024tryoffdiff,
     title = {TryOffDiff: Virtual-Try-Off via High-Fidelity Garment Reconstruction using Diffusion Models},
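The guidance branch above implements standard classifier-free guidance: one batched forward pass over `[uncond, cond]`, then a linear blend of the two noise predictions. A toy illustration of the arithmetic with dummy tensors:

```python
import torch

s = 2.0  # guidance_scale
uncond, cond = torch.randn(2, 4), torch.randn(2, 4)
guided = uncond + s * (cond - uncond)
# Equivalent form: extrapolate past the conditional prediction.
assert torch.allclose(guided, (1 - s) * uncond + s * cond)
```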
@@ -151,18 +171,28 @@ examples = [[f"examples/{img_filename}", 42, 2.0, 20, False] for img_filename in
 demo = gr.Interface(
     fn=generate_image,
     inputs=[
-        gr.Image(type="filepath", label="Reference Image"),
+        gr.Image(type="filepath", label="Reference Image", height=1024, width=1024),
         gr.Slider(value=42, minimum=0, maximum=1e6, step=1, label="Seed"),
-        gr.Slider(value=2.0, minimum=1, maximum=5, step=0.5, label="Guidance Scale(s)", info="No guidance applied at s=1, hence faster inference."),
+        gr.Slider(
+            value=2.0,
+            minimum=1,
+            maximum=5,
+            step=0.5,
+            label="Guidance Scale(s)",
+            info="No guidance applied at s=1, hence faster inference.",
+        ),
         gr.Slider(value=20, minimum=0, maximum=1000, step=10, label="# of Inference Steps"),
-        gr.Checkbox(value=False, label="Upscale Output")
+        gr.Checkbox(
+            value=False, label="Upscale Output", info="Upscale output by 4x (2048x2048) using an off-the-shelf model."
+        ),
     ],
-    outputs=gr.Image(type="pil", label="Generated Garment", height=512, width=512),
+    outputs=gr.Image(type="pil", label="Generated Garment", height=1024, width=1024),
     title=title,
     description=description,
     article=article,
     examples=examples,
     examples_per_page=4,
+    submit_btn="Generate",
 )

 demo.launch()
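Each row of `examples` (built in the hunk header above) supplies the five inputs positionally; a sketch with a hypothetical filename:

```python
# [Reference Image, Seed, Guidance Scale(s), # of Inference Steps, Upscale Output]
examples = [["examples/00001_00.jpg", 42, 2.0, 20, False]]  # hypothetical VITON-HD-test image
```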
 
esrgan_model.py ADDED
@@ -0,0 +1,307 @@
+"""
+Taken from https://github.com/finegrain-ai/refiners
+Modified from https://github.com/philz1337x/clarity-upscaler
+which is a copy of https://github.com/AUTOMATIC1111/stable-diffusion-webui
+which is a copy of https://github.com/victorca25/iNNfer
+which is a copy of https://github.com/xinntao/ESRGAN
+"""
+
+import math
+from pathlib import Path
+from typing import NamedTuple
+
+import numpy as np
+import numpy.typing as npt
+import torch
+import torch.nn as nn
+from PIL import Image
+from huggingface_hub import hf_hub_download
+
+
+def conv_block(in_nc: int, out_nc: int) -> nn.Sequential:
+    return nn.Sequential(
+        nn.Conv2d(in_nc, out_nc, kernel_size=3, padding=1),
+        nn.LeakyReLU(negative_slope=0.2, inplace=True),
+    )
+
+
+class ResidualDenseBlock_5C(nn.Module):
+    """
+    Residual Dense Block
+    The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18)
+    Modified options that can be used:
+    - "Partial Convolution based Padding" arXiv:1811.11718
+    - "Spectral normalization" arXiv:1802.05957
+    - "ICASSP 2020 - ESRGAN+ : Further Improving ESRGAN" N. C.
+      {Rakotonirina} and A. {Rasoanaivo}
+    """
+
+    def __init__(self, nf: int = 64, gc: int = 32) -> None:
+        super().__init__()  # type: ignore[reportUnknownMemberType]
+
+        self.conv1 = conv_block(nf, gc)
+        self.conv2 = conv_block(nf + gc, gc)
+        self.conv3 = conv_block(nf + 2 * gc, gc)
+        self.conv4 = conv_block(nf + 3 * gc, gc)
+        # Wrapped in Sequential because of key in state dict.
+        self.conv5 = nn.Sequential(nn.Conv2d(nf + 4 * gc, nf, kernel_size=3, padding=1))
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x1 = self.conv1(x)
+        x2 = self.conv2(torch.cat((x, x1), 1))
+        x3 = self.conv3(torch.cat((x, x1, x2), 1))
+        x4 = self.conv4(torch.cat((x, x1, x2, x3), 1))
+        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
+        return x5 * 0.2 + x
+
+
+class RRDB(nn.Module):
+    """
+    Residual in Residual Dense Block
+    (ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks)
+    """
+
+    def __init__(self, nf: int) -> None:
+        super().__init__()  # type: ignore[reportUnknownMemberType]
+        self.RDB1 = ResidualDenseBlock_5C(nf)
+        self.RDB2 = ResidualDenseBlock_5C(nf)
+        self.RDB3 = ResidualDenseBlock_5C(nf)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        out = self.RDB1(x)
+        out = self.RDB2(out)
+        out = self.RDB3(out)
+        return out * 0.2 + x
+
+
+class Upsample2x(nn.Module):
+    """Upsample 2x."""
+
+    def __init__(self) -> None:
+        super().__init__()  # type: ignore[reportUnknownMemberType]
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return nn.functional.interpolate(x, scale_factor=2.0)  # type: ignore
+
+
+class ShortcutBlock(nn.Module):
+    """Elementwise sum the output of a submodule to its input"""
+
+    def __init__(self, submodule: nn.Module) -> None:
+        super().__init__()  # type: ignore[reportUnknownMemberType]
+        self.sub = submodule
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return x + self.sub(x)
+
+
+class RRDBNet(nn.Module):
+    def __init__(self, in_nc: int, out_nc: int, nf: int, nb: int) -> None:
+        super().__init__()  # type: ignore[reportUnknownMemberType]
+        assert in_nc % 4 != 0  # in_nc is 3
+
+        self.model = nn.Sequential(
+            nn.Conv2d(in_nc, nf, kernel_size=3, padding=1),
+            ShortcutBlock(
+                nn.Sequential(
+                    *(RRDB(nf) for _ in range(nb)),
+                    nn.Conv2d(nf, nf, kernel_size=3, padding=1),
+                )
+            ),
+            Upsample2x(),
+            nn.Conv2d(nf, nf, kernel_size=3, padding=1),
+            nn.LeakyReLU(negative_slope=0.2, inplace=True),
+            Upsample2x(),
+            nn.Conv2d(nf, nf, kernel_size=3, padding=1),
+            nn.LeakyReLU(negative_slope=0.2, inplace=True),
+            nn.Conv2d(nf, nf, kernel_size=3, padding=1),
+            nn.LeakyReLU(negative_slope=0.2, inplace=True),
+            nn.Conv2d(nf, out_nc, kernel_size=3, padding=1),
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.model(x)
+
+
+def infer_params(state_dict: dict[str, torch.Tensor]) -> tuple[int, int, int, int, int]:
+    # this code is adapted from https://github.com/victorca25/iNNfer
+    scale2x = 0
+    scalemin = 6
+    n_uplayer = 0
+    out_nc = 0
+    nb = 0
+
+    for block in list(state_dict):
+        parts = block.split(".")
+        n_parts = len(parts)
+        if n_parts == 5 and parts[2] == "sub":
+            nb = int(parts[3])
+        elif n_parts == 3:
+            part_num = int(parts[1])
+            if part_num > scalemin and parts[0] == "model" and parts[2] == "weight":
+                scale2x += 1
+            if part_num > n_uplayer:
+                n_uplayer = part_num
+                out_nc = state_dict[block].shape[0]
+        assert "conv1x1" not in block  # no ESRGANPlus
+
+    nf = state_dict["model.0.weight"].shape[0]
+    in_nc = state_dict["model.0.weight"].shape[1]
+    scale = 2**scale2x
+
+    assert out_nc > 0
+    assert nb > 0
+
+    return in_nc, out_nc, nf, nb, scale  # 3, 3, 64, 23, 4
+
+
+Tile = tuple[int, int, Image.Image]
+Tiles = list[tuple[int, int, list[Tile]]]
+
+
+# https://github.com/philz1337x/clarity-upscaler/blob/e0cd797198d1e0e745400c04d8d1b98ae508c73b/modules/images.py#L64
+class Grid(NamedTuple):
+    tiles: Tiles
+    tile_w: int
+    tile_h: int
+    image_w: int
+    image_h: int
+    overlap: int
+
+
+# adapted from https://github.com/philz1337x/clarity-upscaler/blob/e0cd797198d1e0e745400c04d8d1b98ae508c73b/modules/images.py#L67
+def split_grid(image: Image.Image, tile_w: int = 512, tile_h: int = 512, overlap: int = 64) -> Grid:
+    w = image.width
+    h = image.height
+
+    non_overlap_width = tile_w - overlap
+    non_overlap_height = tile_h - overlap
+
+    cols = max(1, math.ceil((w - overlap) / non_overlap_width))
+    rows = max(1, math.ceil((h - overlap) / non_overlap_height))
+
+    dx = (w - tile_w) / (cols - 1) if cols > 1 else 0
+    dy = (h - tile_h) / (rows - 1) if rows > 1 else 0
+
+    grid = Grid([], tile_w, tile_h, w, h, overlap)
+    for row in range(rows):
+        row_images: list[Tile] = []
+        y1 = max(min(int(row * dy), h - tile_h), 0)
+        y2 = min(y1 + tile_h, h)
+        for col in range(cols):
+            x1 = max(min(int(col * dx), w - tile_w), 0)
+            x2 = min(x1 + tile_w, w)
+            tile = image.crop((x1, y1, x2, y2))
+            row_images.append((x1, tile_w, tile))
+        grid.tiles.append((y1, tile_h, row_images))
+
+    return grid
+
+
+# https://github.com/philz1337x/clarity-upscaler/blob/e0cd797198d1e0e745400c04d8d1b98ae508c73b/modules/images.py#L104
+def combine_grid(grid: Grid):
+    def make_mask_image(r: npt.NDArray[np.float32]) -> Image.Image:
+        r = r * 255 / grid.overlap
+        return Image.fromarray(r.astype(np.uint8), "L")
+
+    mask_w = make_mask_image(
+        np.arange(grid.overlap, dtype=np.float32).reshape((1, grid.overlap)).repeat(grid.tile_h, axis=0)
+    )
+    mask_h = make_mask_image(
+        np.arange(grid.overlap, dtype=np.float32).reshape((grid.overlap, 1)).repeat(grid.image_w, axis=1)
+    )
+
+    combined_image = Image.new("RGB", (grid.image_w, grid.image_h))
+    for y, h, row in grid.tiles:
+        combined_row = Image.new("RGB", (grid.image_w, h))
+        for x, w, tile in row:
+            if x == 0:
+                combined_row.paste(tile, (0, 0))
+                continue
+
+            combined_row.paste(tile.crop((0, 0, grid.overlap, h)), (x, 0), mask=mask_w)
+            combined_row.paste(tile.crop((grid.overlap, 0, w, h)), (x + grid.overlap, 0))
+
+        if y == 0:
+            combined_image.paste(combined_row, (0, 0))
+            continue
+
+        combined_image.paste(
+            combined_row.crop((0, 0, combined_row.width, grid.overlap)),
+            (0, y),
+            mask=mask_h,
+        )
+        combined_image.paste(
+            combined_row.crop((0, grid.overlap, combined_row.width, h)),
+            (0, y + grid.overlap),
+        )
+
+    return combined_image
+
+
+class UpscalerESRGAN:
+    def __init__(self, model_path: Path, device: torch.device, dtype: torch.dtype):
+        self.model_path = model_path
+        self.device = device
+        self.model = self.load_model(model_path)
+        self.to(device, dtype)
+
+    def __call__(self, img: Image.Image) -> Image.Image:
+        return self.upscale_without_tiling(img)
+
+    def to(self, device: torch.device, dtype: torch.dtype):
+        self.device = device
+        self.dtype = dtype
+        self.model.to(device=device, dtype=dtype)
+
+    def load_model(self, path: Path) -> RRDBNet:
+        filename = path
+        state_dict: dict[str, torch.Tensor] = torch.load(filename, weights_only=True, map_location=self.device)  # type: ignore
+        in_nc, out_nc, nf, nb, upscale = infer_params(state_dict)
+        assert upscale == 4, "Only 4x upscaling is supported"
+        model = RRDBNet(in_nc=in_nc, out_nc=out_nc, nf=nf, nb=nb)
+        model.load_state_dict(state_dict)
+        model.eval()
+
+        return model
+
+    def upscale_without_tiling(self, img: Image.Image) -> Image.Image:
+        img_np = np.array(img)
+        img_np = img_np[:, :, ::-1]
+        img_np = np.ascontiguousarray(np.transpose(img_np, (2, 0, 1))) / 255
+        img_t = torch.from_numpy(img_np).float()  # type: ignore
+        img_t = img_t.unsqueeze(0).to(device=self.device, dtype=self.dtype)
+        with torch.no_grad():
+            output = self.model(img_t)
+        output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
+        output = 255.0 * np.moveaxis(output, 0, 2)
+        output = output.astype(np.uint8)
+        output = output[:, :, ::-1]
+        return Image.fromarray(output, "RGB")
+
+    # https://github.com/philz1337x/clarity-upscaler/blob/e0cd797198d1e0e745400c04d8d1b98ae508c73b/modules/esrgan_model.py#L208
+    def upscale_with_tiling(self, img: Image.Image) -> Image.Image:
+        img = img.convert("RGB")
+        grid = split_grid(img)
+        newtiles: Tiles = []
+        scale_factor: int = 1
+
+        for y, h, row in grid.tiles:
+            newrow: list[Tile] = []
+            for tiledata in row:
+                x, w, tile = tiledata
+                output = self.upscale_without_tiling(tile)
+                scale_factor = output.width // tile.width
+                newrow.append((x * scale_factor, w * scale_factor, output))
+            newtiles.append((y * scale_factor, h * scale_factor, newrow))

+        newgrid = Grid(
+            newtiles,
+            grid.tile_w * scale_factor,
+            grid.tile_h * scale_factor,
+            grid.image_w * scale_factor,
+            grid.image_h * scale_factor,
+            grid.overlap * scale_factor,
+        )
+        output = combine_grid(newgrid)
+        return output
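A minimal standalone usage sketch of the new module, assuming a local `input.png`; the checkpoint repo and filename mirror the ones used in `app.py` above:

```python
from pathlib import Path

import torch
from PIL import Image
from huggingface_hub import hf_hub_download

from esrgan_model import UpscalerESRGAN, split_grid

upscaler = UpscalerESRGAN(
    model_path=Path(hf_hub_download(repo_id="philz1337x/upscaler", filename="4x-UltraSharp.pth")),
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    dtype=torch.float32,
)
img = Image.open("input.png").convert("RGB")
out = upscaler(img)  # __call__ delegates to upscale_without_tiling; 4x per side

# upscale_with_tiling() bounds memory for large inputs. With the defaults
# (512px tiles, 64px overlap) a 2048x2048 image splits into a 5x5 grid:
# stride = 512 - 64 = 448, so cols = rows = ceil((2048 - 64) / 448) = 5.
grid = split_grid(Image.new("RGB", (2048, 2048)))
assert len(grid.tiles) == 5 and len(grid.tiles[0][2]) == 5
```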