Spaces:
Paused
Paused
rizavelioglu
commited on
Commit
•
835cd00
1
Parent(s):
80cc34d
enable upscaler
Browse files- README.md +1 -1
- app.py +68 -38
- esrgan_model.py +307 -0
README.md
CHANGED
@@ -2,7 +2,7 @@
|
|
2 |
title: TryOffDiff
|
3 |
emoji: 🔥
|
4 |
colorFrom: yellow
|
5 |
-
colorTo:
|
6 |
sdk: gradio
|
7 |
sdk_version: 5.7.0
|
8 |
app_file: app.py
|
|
|
2 |
title: TryOffDiff
|
3 |
emoji: 🔥
|
4 |
colorFrom: yellow
|
5 |
+
colorTo: orange
|
6 |
sdk: gradio
|
7 |
sdk_version: 5.7.0
|
8 |
app_file: app.py
|
app.py
CHANGED
@@ -1,14 +1,17 @@
|
|
1 |
import os
|
|
|
|
|
|
|
|
|
2 |
import gradio as gr
|
|
|
3 |
import spaces
|
4 |
import torch
|
5 |
import torch.nn as nn
|
6 |
-
from diffusers import EulerDiscreteScheduler, AutoencoderKL, UNet2DConditionModel
|
7 |
-
from huggingface_hub import hf_hub_download
|
8 |
-
from transformers import SiglipImageProcessor, SiglipVisionModel
|
9 |
from torchvision.io import read_image
|
10 |
import torchvision.transforms.v2 as transforms
|
11 |
from torchvision.utils import make_grid
|
|
|
12 |
|
13 |
|
14 |
class TryOffDiff(nn.Module):
|
@@ -44,29 +47,26 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
44 |
|
45 |
# Initialize Image Encoder
|
46 |
img_processor = SiglipImageProcessor.from_pretrained(
|
47 |
-
"google/siglip-base-patch16-512",
|
48 |
-
do_resize=False,
|
49 |
-
do_rescale=False,
|
50 |
-
do_normalize=False
|
51 |
)
|
52 |
img_enc = SiglipVisionModel.from_pretrained("google/siglip-base-patch16-512").eval().to(device)
|
53 |
-
img_enc_transform = transforms.Compose(
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
])
|
|
|
|
|
59 |
|
60 |
# Load TryOffDiff Model
|
61 |
path_model = hf_hub_download(
|
62 |
repo_id="rizavelioglu/tryoffdiff",
|
63 |
filename="tryoffdiff.pth", # or one of ["ldm-1", "ldm-2", "ldm-3", ...],
|
64 |
-
force_download=False
|
65 |
)
|
66 |
path_scheduler = hf_hub_download(
|
67 |
-
repo_id="rizavelioglu/tryoffdiff",
|
68 |
-
filename="scheduler/scheduler_config.json",
|
69 |
-
force_download=False
|
70 |
)
|
71 |
net = TryOffDiff()
|
72 |
net.load_state_dict(torch.load(path_model, weights_only=False))
|
@@ -74,19 +74,35 @@ net.eval().to(device)
|
|
74 |
|
75 |
# Initialize VAE (only Decoder will be used)
|
76 |
vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae").eval().to(device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
77 |
torch.cuda.empty_cache()
|
78 |
|
79 |
|
80 |
# Define image generation function
|
81 |
@spaces.GPU(duration=10)
|
82 |
@torch.no_grad()
|
83 |
-
def generate_image(
|
|
|
|
|
84 |
# Configure scheduler
|
85 |
scheduler = EulerDiscreteScheduler.from_pretrained(path_scheduler)
|
86 |
scheduler.is_scale_input_called = True # suppress warning
|
87 |
scheduler.set_timesteps(num_inference_steps)
|
88 |
|
89 |
-
# Set
|
90 |
generator = torch.Generator(device=device).manual_seed(seed)
|
91 |
x = torch.randn(1, 4, 64, 64, generator=generator, device=device)
|
92 |
|
@@ -95,46 +111,50 @@ def generate_image(input_image, seed=42, guidance_scale=2.0, num_inference_steps
|
|
95 |
inputs = {k: v.to(img_enc.device) for k, v in img_processor(images=cond_image, return_tensors="pt").items()}
|
96 |
cond_emb = img_enc(**inputs).last_hidden_state.to(device)
|
97 |
|
98 |
-
# Prepare unconditioned embeddings
|
99 |
uncond_emb = torch.zeros_like(cond_emb) if guidance_scale > 1 else None
|
100 |
|
101 |
-
#
|
102 |
with torch.autocast(device):
|
103 |
for t in scheduler.timesteps:
|
104 |
if guidance_scale > 1:
|
105 |
-
|
106 |
-
|
107 |
-
).chunk(2)
|
108 |
noise_pred = noise_pred[0] + guidance_scale * (noise_pred[1] - noise_pred[0])
|
109 |
else:
|
|
|
110 |
noise_pred = net(x, t, cond_emb)
|
111 |
|
|
|
112 |
scheduler_output = scheduler.step(noise_pred, t, x)
|
113 |
x = scheduler_output.prev_sample
|
114 |
|
115 |
-
# Decode
|
116 |
decoded = vae.decode(1 / 0.18215 * scheduler_output.pred_original_sample).sample
|
117 |
-
images = (decoded / 2 + 0.5).cpu()
|
118 |
-
|
119 |
# Create grid
|
120 |
grid = make_grid(images, nrow=len(images), normalize=True, scale_each=True)
|
|
|
|
|
|
|
121 |
if is_upscale:
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
|
126 |
|
127 |
title = "Virtual Try-Off Generator"
|
128 |
description = r"""
|
129 |
-
This is the demo of the paper <a href="https://arxiv.org/abs/2411.18350">TryOffDiff: Virtual-Try-Off via High-Fidelity Garment Reconstruction using Diffusion Models</a>.
|
130 |
-
<br>Upload an image of a clothed individual to generate a standardized garment image using TryOffDiff.
|
131 |
<br> Check out the <a href="https://rizavelioglu.github.io/tryoffdiff/">project page</a> for more information.
|
132 |
"""
|
133 |
article = r"""
|
134 |
Example images are sampled from the `VITON-HD-test` set, which the models did not see during training.
|
135 |
|
136 |
-
<br>**Citation** <br>If you find our work useful in your research, please consider giving a star ⭐ and
|
137 |
-
a citation:
|
138 |
```
|
139 |
@article{velioglu2024tryoffdiff,
|
140 |
title = {TryOffDiff: Virtual-Try-Off via High-Fidelity Garment Reconstruction using Diffusion Models},
|
@@ -151,18 +171,28 @@ examples = [[f"examples/{img_filename}", 42, 2.0, 20, False] for img_filename in
|
|
151 |
demo = gr.Interface(
|
152 |
fn=generate_image,
|
153 |
inputs=[
|
154 |
-
gr.Image(type="filepath", label="Reference Image"),
|
155 |
gr.Slider(value=42, minimum=0, maximum=1e6, step=1, label="Seed"),
|
156 |
-
gr.Slider(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
157 |
gr.Slider(value=20, minimum=0, maximum=1000, step=10, label="# of Inference Steps"),
|
158 |
-
gr.Checkbox(
|
|
|
|
|
159 |
],
|
160 |
-
outputs=gr.Image(type="pil", label="Generated Garment", height=
|
161 |
title=title,
|
162 |
description=description,
|
163 |
article=article,
|
164 |
examples=examples,
|
165 |
examples_per_page=4,
|
|
|
166 |
)
|
167 |
|
168 |
demo.launch()
|
|
|
1 |
import os
|
2 |
+
from pathlib import Path
|
3 |
+
|
4 |
+
from diffusers import AutoencoderKL, EulerDiscreteScheduler, UNet2DConditionModel
|
5 |
+
from esrgan_model import UpscalerESRGAN
|
6 |
import gradio as gr
|
7 |
+
from huggingface_hub import hf_hub_download
|
8 |
import spaces
|
9 |
import torch
|
10 |
import torch.nn as nn
|
|
|
|
|
|
|
11 |
from torchvision.io import read_image
|
12 |
import torchvision.transforms.v2 as transforms
|
13 |
from torchvision.utils import make_grid
|
14 |
+
from transformers import SiglipImageProcessor, SiglipVisionModel
|
15 |
|
16 |
|
17 |
class TryOffDiff(nn.Module):
|
|
|
47 |
|
48 |
# Initialize Image Encoder
|
49 |
img_processor = SiglipImageProcessor.from_pretrained(
|
50 |
+
"google/siglip-base-patch16-512", do_resize=False, do_rescale=False, do_normalize=False
|
|
|
|
|
|
|
51 |
)
|
52 |
img_enc = SiglipVisionModel.from_pretrained("google/siglip-base-patch16-512").eval().to(device)
|
53 |
+
img_enc_transform = transforms.Compose(
|
54 |
+
[
|
55 |
+
PadToSquare(), # Custom transform to pad the image to a square
|
56 |
+
transforms.Resize((512, 512)),
|
57 |
+
transforms.ToDtype(torch.float32, scale=True),
|
58 |
+
transforms.Normalize(mean=[0.5], std=[0.5]),
|
59 |
+
]
|
60 |
+
)
|
61 |
|
62 |
# Load TryOffDiff Model
|
63 |
path_model = hf_hub_download(
|
64 |
repo_id="rizavelioglu/tryoffdiff",
|
65 |
filename="tryoffdiff.pth", # or one of ["ldm-1", "ldm-2", "ldm-3", ...],
|
66 |
+
force_download=False,
|
67 |
)
|
68 |
path_scheduler = hf_hub_download(
|
69 |
+
repo_id="rizavelioglu/tryoffdiff", filename="scheduler/scheduler_config.json", force_download=False
|
|
|
|
|
70 |
)
|
71 |
net = TryOffDiff()
|
72 |
net.load_state_dict(torch.load(path_model, weights_only=False))
|
|
|
74 |
|
75 |
# Initialize VAE (only Decoder will be used)
|
76 |
vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae").eval().to(device)
|
77 |
+
|
78 |
+
# Initialize the upscaler
|
79 |
+
upscaler = UpscalerESRGAN(
|
80 |
+
model_path=Path(
|
81 |
+
hf_hub_download(
|
82 |
+
repo_id="philz1337x/upscaler",
|
83 |
+
filename="4x-UltraSharp.pth",
|
84 |
+
# revision="011deacac8270114eb7d2eeff4fe6fa9a837be70",
|
85 |
+
)
|
86 |
+
),
|
87 |
+
device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
|
88 |
+
dtype=torch.float32,
|
89 |
+
)
|
90 |
+
|
91 |
torch.cuda.empty_cache()
|
92 |
|
93 |
|
94 |
# Define image generation function
|
95 |
@spaces.GPU(duration=10)
|
96 |
@torch.no_grad()
|
97 |
+
def generate_image(
|
98 |
+
input_image, seed: int = 42, guidance_scale: float = 2.0, num_inference_steps: int = 50, is_upscale: bool = False
|
99 |
+
):
|
100 |
# Configure scheduler
|
101 |
scheduler = EulerDiscreteScheduler.from_pretrained(path_scheduler)
|
102 |
scheduler.is_scale_input_called = True # suppress warning
|
103 |
scheduler.set_timesteps(num_inference_steps)
|
104 |
|
105 |
+
# Set seed for reproducibility
|
106 |
generator = torch.Generator(device=device).manual_seed(seed)
|
107 |
x = torch.randn(1, 4, 64, 64, generator=generator, device=device)
|
108 |
|
|
|
111 |
inputs = {k: v.to(img_enc.device) for k, v in img_processor(images=cond_image, return_tensors="pt").items()}
|
112 |
cond_emb = img_enc(**inputs).last_hidden_state.to(device)
|
113 |
|
114 |
+
# Prepare unconditioned embeddings (only if guidance is enabled)
|
115 |
uncond_emb = torch.zeros_like(cond_emb) if guidance_scale > 1 else None
|
116 |
|
117 |
+
# Diffusion denoising loop with mixed precision for efficiency
|
118 |
with torch.autocast(device):
|
119 |
for t in scheduler.timesteps:
|
120 |
if guidance_scale > 1:
|
121 |
+
# Classifier-Free Guidance (CFG)
|
122 |
+
noise_pred = net(torch.cat([x] * 2), t, torch.cat([uncond_emb, cond_emb])).chunk(2)
|
|
|
123 |
noise_pred = noise_pred[0] + guidance_scale * (noise_pred[1] - noise_pred[0])
|
124 |
else:
|
125 |
+
# Standard prediction
|
126 |
noise_pred = net(x, t, cond_emb)
|
127 |
|
128 |
+
# Scheduler step
|
129 |
scheduler_output = scheduler.step(noise_pred, t, x)
|
130 |
x = scheduler_output.prev_sample
|
131 |
|
132 |
+
# Decode predictions from latent space
|
133 |
decoded = vae.decode(1 / 0.18215 * scheduler_output.pred_original_sample).sample
|
134 |
+
images = (decoded / 2 + 0.5).cpu()
|
135 |
+
|
136 |
# Create grid
|
137 |
grid = make_grid(images, nrow=len(images), normalize=True, scale_each=True)
|
138 |
+
output_image = transforms.ToPILImage()(grid)
|
139 |
+
|
140 |
+
# Optionally upscale the output image
|
141 |
if is_upscale:
|
142 |
+
output_image = upscaler(output_image)
|
143 |
+
|
144 |
+
return output_image
|
145 |
|
146 |
|
147 |
title = "Virtual Try-Off Generator"
|
148 |
description = r"""
|
149 |
+
This is the demo of the paper <a href="https://arxiv.org/abs/2411.18350">TryOffDiff: Virtual-Try-Off via High-Fidelity Garment Reconstruction using Diffusion Models</a>.
|
150 |
+
<br>Upload an image of a clothed individual to generate a standardized garment image using TryOffDiff.
|
151 |
<br> Check out the <a href="https://rizavelioglu.github.io/tryoffdiff/">project page</a> for more information.
|
152 |
"""
|
153 |
article = r"""
|
154 |
Example images are sampled from the `VITON-HD-test` set, which the models did not see during training.
|
155 |
|
156 |
+
<br>**Citation** <br>If you find our work useful in your research, please consider giving a star ⭐ and
|
157 |
+
a citation:
|
158 |
```
|
159 |
@article{velioglu2024tryoffdiff,
|
160 |
title = {TryOffDiff: Virtual-Try-Off via High-Fidelity Garment Reconstruction using Diffusion Models},
|
|
|
171 |
demo = gr.Interface(
|
172 |
fn=generate_image,
|
173 |
inputs=[
|
174 |
+
gr.Image(type="filepath", label="Reference Image", height=1024, width=1024),
|
175 |
gr.Slider(value=42, minimum=0, maximum=1e6, step=1, label="Seed"),
|
176 |
+
gr.Slider(
|
177 |
+
value=2.0,
|
178 |
+
minimum=1,
|
179 |
+
maximum=5,
|
180 |
+
step=0.5,
|
181 |
+
label="Guidance Scale(s)",
|
182 |
+
info="No guidance applied at s=1, hence faster inference.",
|
183 |
+
),
|
184 |
gr.Slider(value=20, minimum=0, maximum=1000, step=10, label="# of Inference Steps"),
|
185 |
+
gr.Checkbox(
|
186 |
+
value=False, label="Upscale Output", info="Upscale output by 4x (2048x2048) using an off-the-shelf model."
|
187 |
+
),
|
188 |
],
|
189 |
+
outputs=gr.Image(type="pil", label="Generated Garment", height=1024, width=1024),
|
190 |
title=title,
|
191 |
description=description,
|
192 |
article=article,
|
193 |
examples=examples,
|
194 |
examples_per_page=4,
|
195 |
+
submit_btn="Generate",
|
196 |
)
|
197 |
|
198 |
demo.launch()
|
esrgan_model.py
ADDED
@@ -0,0 +1,307 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Taken from https://github.com/finegrain-ai/refiners
|
3 |
+
Modified from https://github.com/philz1337x/clarity-upscaler
|
4 |
+
which is a copy of https://github.com/AUTOMATIC1111/stable-diffusion-webui
|
5 |
+
which is a copy of https://github.com/victorca25/iNNfer
|
6 |
+
which is a copy of https://github.com/xinntao/ESRGAN
|
7 |
+
"""
|
8 |
+
|
9 |
+
import math
|
10 |
+
from pathlib import Path
|
11 |
+
from typing import NamedTuple
|
12 |
+
|
13 |
+
import numpy as np
|
14 |
+
import numpy.typing as npt
|
15 |
+
import torch
|
16 |
+
import torch.nn as nn
|
17 |
+
from PIL import Image
|
18 |
+
from huggingface_hub import hf_hub_download
|
19 |
+
|
20 |
+
|
21 |
+
def conv_block(in_nc: int, out_nc: int) -> nn.Sequential:
|
22 |
+
return nn.Sequential(
|
23 |
+
nn.Conv2d(in_nc, out_nc, kernel_size=3, padding=1),
|
24 |
+
nn.LeakyReLU(negative_slope=0.2, inplace=True),
|
25 |
+
)
|
26 |
+
|
27 |
+
|
28 |
+
class ResidualDenseBlock_5C(nn.Module):
|
29 |
+
"""
|
30 |
+
Residual Dense Block
|
31 |
+
The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18)
|
32 |
+
Modified options that can be used:
|
33 |
+
- "Partial Convolution based Padding" arXiv:1811.11718
|
34 |
+
- "Spectral normalization" arXiv:1802.05957
|
35 |
+
- "ICASSP 2020 - ESRGAN+ : Further Improving ESRGAN" N. C.
|
36 |
+
{Rakotonirina} and A. {Rasoanaivo}
|
37 |
+
"""
|
38 |
+
|
39 |
+
def __init__(self, nf: int = 64, gc: int = 32) -> None:
|
40 |
+
super().__init__() # type: ignore[reportUnknownMemberType]
|
41 |
+
|
42 |
+
self.conv1 = conv_block(nf, gc)
|
43 |
+
self.conv2 = conv_block(nf + gc, gc)
|
44 |
+
self.conv3 = conv_block(nf + 2 * gc, gc)
|
45 |
+
self.conv4 = conv_block(nf + 3 * gc, gc)
|
46 |
+
# Wrapped in Sequential because of key in state dict.
|
47 |
+
self.conv5 = nn.Sequential(nn.Conv2d(nf + 4 * gc, nf, kernel_size=3, padding=1))
|
48 |
+
|
49 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
50 |
+
x1 = self.conv1(x)
|
51 |
+
x2 = self.conv2(torch.cat((x, x1), 1))
|
52 |
+
x3 = self.conv3(torch.cat((x, x1, x2), 1))
|
53 |
+
x4 = self.conv4(torch.cat((x, x1, x2, x3), 1))
|
54 |
+
x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
|
55 |
+
return x5 * 0.2 + x
|
56 |
+
|
57 |
+
|
58 |
+
class RRDB(nn.Module):
|
59 |
+
"""
|
60 |
+
Residual in Residual Dense Block
|
61 |
+
(ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks)
|
62 |
+
"""
|
63 |
+
|
64 |
+
def __init__(self, nf: int) -> None:
|
65 |
+
super().__init__() # type: ignore[reportUnknownMemberType]
|
66 |
+
self.RDB1 = ResidualDenseBlock_5C(nf)
|
67 |
+
self.RDB2 = ResidualDenseBlock_5C(nf)
|
68 |
+
self.RDB3 = ResidualDenseBlock_5C(nf)
|
69 |
+
|
70 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
71 |
+
out = self.RDB1(x)
|
72 |
+
out = self.RDB2(out)
|
73 |
+
out = self.RDB3(out)
|
74 |
+
return out * 0.2 + x
|
75 |
+
|
76 |
+
|
77 |
+
class Upsample2x(nn.Module):
|
78 |
+
"""Upsample 2x."""
|
79 |
+
|
80 |
+
def __init__(self) -> None:
|
81 |
+
super().__init__() # type: ignore[reportUnknownMemberType]
|
82 |
+
|
83 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
84 |
+
return nn.functional.interpolate(x, scale_factor=2.0) # type: ignore
|
85 |
+
|
86 |
+
|
87 |
+
class ShortcutBlock(nn.Module):
|
88 |
+
"""Elementwise sum the output of a submodule to its input"""
|
89 |
+
|
90 |
+
def __init__(self, submodule: nn.Module) -> None:
|
91 |
+
super().__init__() # type: ignore[reportUnknownMemberType]
|
92 |
+
self.sub = submodule
|
93 |
+
|
94 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
95 |
+
return x + self.sub(x)
|
96 |
+
|
97 |
+
|
98 |
+
class RRDBNet(nn.Module):
|
99 |
+
def __init__(self, in_nc: int, out_nc: int, nf: int, nb: int) -> None:
|
100 |
+
super().__init__() # type: ignore[reportUnknownMemberType]
|
101 |
+
assert in_nc % 4 != 0 # in_nc is 3
|
102 |
+
|
103 |
+
self.model = nn.Sequential(
|
104 |
+
nn.Conv2d(in_nc, nf, kernel_size=3, padding=1),
|
105 |
+
ShortcutBlock(
|
106 |
+
nn.Sequential(
|
107 |
+
*(RRDB(nf) for _ in range(nb)),
|
108 |
+
nn.Conv2d(nf, nf, kernel_size=3, padding=1),
|
109 |
+
)
|
110 |
+
),
|
111 |
+
Upsample2x(),
|
112 |
+
nn.Conv2d(nf, nf, kernel_size=3, padding=1),
|
113 |
+
nn.LeakyReLU(negative_slope=0.2, inplace=True),
|
114 |
+
Upsample2x(),
|
115 |
+
nn.Conv2d(nf, nf, kernel_size=3, padding=1),
|
116 |
+
nn.LeakyReLU(negative_slope=0.2, inplace=True),
|
117 |
+
nn.Conv2d(nf, nf, kernel_size=3, padding=1),
|
118 |
+
nn.LeakyReLU(negative_slope=0.2, inplace=True),
|
119 |
+
nn.Conv2d(nf, out_nc, kernel_size=3, padding=1),
|
120 |
+
)
|
121 |
+
|
122 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
123 |
+
return self.model(x)
|
124 |
+
|
125 |
+
|
126 |
+
def infer_params(state_dict: dict[str, torch.Tensor]) -> tuple[int, int, int, int, int]:
|
127 |
+
# this code is adapted from https://github.com/victorca25/iNNfer
|
128 |
+
scale2x = 0
|
129 |
+
scalemin = 6
|
130 |
+
n_uplayer = 0
|
131 |
+
out_nc = 0
|
132 |
+
nb = 0
|
133 |
+
|
134 |
+
for block in list(state_dict):
|
135 |
+
parts = block.split(".")
|
136 |
+
n_parts = len(parts)
|
137 |
+
if n_parts == 5 and parts[2] == "sub":
|
138 |
+
nb = int(parts[3])
|
139 |
+
elif n_parts == 3:
|
140 |
+
part_num = int(parts[1])
|
141 |
+
if part_num > scalemin and parts[0] == "model" and parts[2] == "weight":
|
142 |
+
scale2x += 1
|
143 |
+
if part_num > n_uplayer:
|
144 |
+
n_uplayer = part_num
|
145 |
+
out_nc = state_dict[block].shape[0]
|
146 |
+
assert "conv1x1" not in block # no ESRGANPlus
|
147 |
+
|
148 |
+
nf = state_dict["model.0.weight"].shape[0]
|
149 |
+
in_nc = state_dict["model.0.weight"].shape[1]
|
150 |
+
scale = 2**scale2x
|
151 |
+
|
152 |
+
assert out_nc > 0
|
153 |
+
assert nb > 0
|
154 |
+
|
155 |
+
return in_nc, out_nc, nf, nb, scale # 3, 3, 64, 23, 4
|
156 |
+
|
157 |
+
|
158 |
+
Tile = tuple[int, int, Image.Image]
|
159 |
+
Tiles = list[tuple[int, int, list[Tile]]]
|
160 |
+
|
161 |
+
|
162 |
+
# https://github.com/philz1337x/clarity-upscaler/blob/e0cd797198d1e0e745400c04d8d1b98ae508c73b/modules/images.py#L64
|
163 |
+
class Grid(NamedTuple):
|
164 |
+
tiles: Tiles
|
165 |
+
tile_w: int
|
166 |
+
tile_h: int
|
167 |
+
image_w: int
|
168 |
+
image_h: int
|
169 |
+
overlap: int
|
170 |
+
|
171 |
+
|
172 |
+
# adapted from https://github.com/philz1337x/clarity-upscaler/blob/e0cd797198d1e0e745400c04d8d1b98ae508c73b/modules/images.py#L67
|
173 |
+
def split_grid(image: Image.Image, tile_w: int = 512, tile_h: int = 512, overlap: int = 64) -> Grid:
|
174 |
+
w = image.width
|
175 |
+
h = image.height
|
176 |
+
|
177 |
+
non_overlap_width = tile_w - overlap
|
178 |
+
non_overlap_height = tile_h - overlap
|
179 |
+
|
180 |
+
cols = max(1, math.ceil((w - overlap) / non_overlap_width))
|
181 |
+
rows = max(1, math.ceil((h - overlap) / non_overlap_height))
|
182 |
+
|
183 |
+
dx = (w - tile_w) / (cols - 1) if cols > 1 else 0
|
184 |
+
dy = (h - tile_h) / (rows - 1) if rows > 1 else 0
|
185 |
+
|
186 |
+
grid = Grid([], tile_w, tile_h, w, h, overlap)
|
187 |
+
for row in range(rows):
|
188 |
+
row_images: list[Tile] = []
|
189 |
+
y1 = max(min(int(row * dy), h - tile_h), 0)
|
190 |
+
y2 = min(y1 + tile_h, h)
|
191 |
+
for col in range(cols):
|
192 |
+
x1 = max(min(int(col * dx), w - tile_w), 0)
|
193 |
+
x2 = min(x1 + tile_w, w)
|
194 |
+
tile = image.crop((x1, y1, x2, y2))
|
195 |
+
row_images.append((x1, tile_w, tile))
|
196 |
+
grid.tiles.append((y1, tile_h, row_images))
|
197 |
+
|
198 |
+
return grid
|
199 |
+
|
200 |
+
|
201 |
+
# https://github.com/philz1337x/clarity-upscaler/blob/e0cd797198d1e0e745400c04d8d1b98ae508c73b/modules/images.py#L104
|
202 |
+
def combine_grid(grid: Grid):
|
203 |
+
def make_mask_image(r: npt.NDArray[np.float32]) -> Image.Image:
|
204 |
+
r = r * 255 / grid.overlap
|
205 |
+
return Image.fromarray(r.astype(np.uint8), "L")
|
206 |
+
|
207 |
+
mask_w = make_mask_image(
|
208 |
+
np.arange(grid.overlap, dtype=np.float32).reshape((1, grid.overlap)).repeat(grid.tile_h, axis=0)
|
209 |
+
)
|
210 |
+
mask_h = make_mask_image(
|
211 |
+
np.arange(grid.overlap, dtype=np.float32).reshape((grid.overlap, 1)).repeat(grid.image_w, axis=1)
|
212 |
+
)
|
213 |
+
|
214 |
+
combined_image = Image.new("RGB", (grid.image_w, grid.image_h))
|
215 |
+
for y, h, row in grid.tiles:
|
216 |
+
combined_row = Image.new("RGB", (grid.image_w, h))
|
217 |
+
for x, w, tile in row:
|
218 |
+
if x == 0:
|
219 |
+
combined_row.paste(tile, (0, 0))
|
220 |
+
continue
|
221 |
+
|
222 |
+
combined_row.paste(tile.crop((0, 0, grid.overlap, h)), (x, 0), mask=mask_w)
|
223 |
+
combined_row.paste(tile.crop((grid.overlap, 0, w, h)), (x + grid.overlap, 0))
|
224 |
+
|
225 |
+
if y == 0:
|
226 |
+
combined_image.paste(combined_row, (0, 0))
|
227 |
+
continue
|
228 |
+
|
229 |
+
combined_image.paste(
|
230 |
+
combined_row.crop((0, 0, combined_row.width, grid.overlap)),
|
231 |
+
(0, y),
|
232 |
+
mask=mask_h,
|
233 |
+
)
|
234 |
+
combined_image.paste(
|
235 |
+
combined_row.crop((0, grid.overlap, combined_row.width, h)),
|
236 |
+
(0, y + grid.overlap),
|
237 |
+
)
|
238 |
+
|
239 |
+
return combined_image
|
240 |
+
|
241 |
+
|
242 |
+
class UpscalerESRGAN:
|
243 |
+
def __init__(self, model_path: Path, device: torch.device, dtype: torch.dtype):
|
244 |
+
self.model_path = model_path
|
245 |
+
self.device = device
|
246 |
+
self.model = self.load_model(model_path)
|
247 |
+
self.to(device, dtype)
|
248 |
+
|
249 |
+
def __call__(self, img: Image.Image) -> Image.Image:
|
250 |
+
return self.upscale_without_tiling(img)
|
251 |
+
|
252 |
+
def to(self, device: torch.device, dtype: torch.dtype):
|
253 |
+
self.device = device
|
254 |
+
self.dtype = dtype
|
255 |
+
self.model.to(device=device, dtype=dtype)
|
256 |
+
|
257 |
+
def load_model(self, path: Path) -> RRDBNet:
|
258 |
+
filename = path
|
259 |
+
state_dict: dict[str, torch.Tensor] = torch.load(filename, weights_only=True, map_location=self.device) # type: ignore
|
260 |
+
in_nc, out_nc, nf, nb, upscale = infer_params(state_dict)
|
261 |
+
assert upscale == 4, "Only 4x upscaling is supported"
|
262 |
+
model = RRDBNet(in_nc=in_nc, out_nc=out_nc, nf=nf, nb=nb)
|
263 |
+
model.load_state_dict(state_dict)
|
264 |
+
model.eval()
|
265 |
+
|
266 |
+
return model
|
267 |
+
|
268 |
+
def upscale_without_tiling(self, img: Image.Image) -> Image.Image:
|
269 |
+
img_np = np.array(img)
|
270 |
+
img_np = img_np[:, :, ::-1]
|
271 |
+
img_np = np.ascontiguousarray(np.transpose(img_np, (2, 0, 1))) / 255
|
272 |
+
img_t = torch.from_numpy(img_np).float() # type: ignore
|
273 |
+
img_t = img_t.unsqueeze(0).to(device=self.device, dtype=self.dtype)
|
274 |
+
with torch.no_grad():
|
275 |
+
output = self.model(img_t)
|
276 |
+
output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
|
277 |
+
output = 255.0 * np.moveaxis(output, 0, 2)
|
278 |
+
output = output.astype(np.uint8)
|
279 |
+
output = output[:, :, ::-1]
|
280 |
+
return Image.fromarray(output, "RGB")
|
281 |
+
|
282 |
+
# https://github.com/philz1337x/clarity-upscaler/blob/e0cd797198d1e0e745400c04d8d1b98ae508c73b/modules/esrgan_model.py#L208
|
283 |
+
def upscale_with_tiling(self, img: Image.Image) -> Image.Image:
|
284 |
+
img = img.convert("RGB")
|
285 |
+
grid = split_grid(img)
|
286 |
+
newtiles: Tiles = []
|
287 |
+
scale_factor: int = 1
|
288 |
+
|
289 |
+
for y, h, row in grid.tiles:
|
290 |
+
newrow: list[Tile] = []
|
291 |
+
for tiledata in row:
|
292 |
+
x, w, tile = tiledata
|
293 |
+
output = self.upscale_without_tiling(tile)
|
294 |
+
scale_factor = output.width // tile.width
|
295 |
+
newrow.append((x * scale_factor, w * scale_factor, output))
|
296 |
+
newtiles.append((y * scale_factor, h * scale_factor, newrow))
|
297 |
+
|
298 |
+
newgrid = Grid(
|
299 |
+
newtiles,
|
300 |
+
grid.tile_w * scale_factor,
|
301 |
+
grid.tile_h * scale_factor,
|
302 |
+
grid.image_w * scale_factor,
|
303 |
+
grid.image_h * scale_factor,
|
304 |
+
grid.overlap * scale_factor,
|
305 |
+
)
|
306 |
+
output = combine_grid(newgrid)
|
307 |
+
return output
|