Spaces: Runtime error
gustproof committed
Commit · b54be75
Parent(s):
Duplicate from gustproof/sd-vae-roundtrip

Browse files:
- .gitattributes    +34 -0
- README.md         +13 -0
- app.py            +131 -0
- requirements.txt  +5 -0
.gitattributes
ADDED
@@ -0,0 +1,34 @@
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ckpt filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.mlmodel filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+*.safetensors filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+*.tar.* filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.wasm filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zst filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
README.md
ADDED
@@ -0,0 +1,13 @@
+---
+title: Sd Vae Roundtrip
+emoji: 📊
+colorFrom: red
+colorTo: yellow
+sdk: gradio
+sdk_version: 3.28.0
+app_file: app.py
+pinned: false
+duplicated_from: gustproof/sd-vae-roundtrip
+---
+
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py
ADDED
@@ -0,0 +1,131 @@
+import gradio as gr
+
+from transformers import CLIPTextModel, CLIPTokenizer
+from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler
+from PIL import Image
+import PIL
+import torch
+import numpy as np
+
+model_path = "Linaqruf/anything-v3.0"
+vae = AutoencoderKL.from_pretrained(model_path, subfolder="vae")
+print(f"vae loaded from {model_path}")
+
+
+def snap(w, h, d=64, area=640 * 640):
+    s = min(1.0, (area / w / h) ** 0.5)
+    err = lambda a, b: 1 - min(a, b) / max(a, b)
+    sw, sh = map(lambda x: int((x * s) // d * d), (w, h))
+    return min(
+        (
+            (ww, hh)
+            for ww, hh in [(sw, sh), (sw, sh + d), (sw + d, sh), (sw + d, sh + d)]
+            if ww * hh <= area
+        ),
+        key=lambda wh: err(w / h, wh[0] / wh[1]),
+    )
+
+
+def center_crop_image(image, hx, wx):
+    # Get the original image dimensions (HxW)
+    original_width, original_height = image.size
+
+    # Calculate the coordinates for center cropping
+    if original_width / original_height > wx / hx:
+        ww = original_height * wx / hx
+        left, right, top, bottom = (
+            (original_width - ww) / 2,
+            (original_width + ww) / 2,
+            0,
+            original_height,
+        )
+    else:
+        hh = original_width * hx / wx
+        left, right, top, bottom = (
+            0,
+            original_width,
+            (original_height - hh) / 2,
+            (original_height + hh) / 2,
+        )
+    # Crop the image
+    cropped_image = image.crop((left, top, right, bottom))
+
+    # Resize the cropped image to the target size (hxw)
+    cropped_image = cropped_image.resize((wx, hx), Image.Resampling.LANCZOS)
+
+    return cropped_image
+
+
+def preprocess(image):
+    if isinstance(image, torch.Tensor):
+        return image
+    elif isinstance(image, PIL.Image.Image):
+        image = [image]
+
+    if isinstance(image[0], PIL.Image.Image):
+        image = [np.array(i)[None, :] for i in image]
+        image = np.concatenate(image, axis=0)
+        image = np.array(image).astype(np.float32) / 255.0
+        image = image.transpose(0, 3, 1, 2)
+        image = 2.0 * image - 1.0
+        image = torch.from_numpy(image)
+    elif isinstance(image[0], torch.Tensor):
+        image = torch.cat(image, dim=0)
+    return image
+
+
+def numpy_to_pil(images):
+    """
+    Convert a numpy image or a batch of images to a PIL image.
+    """
+    if images.ndim == 3:
+        images = images[None, ...]
+    images = (images * 255).round().astype("uint8")
+    if images.shape[-1] == 1:
+        # special case for grayscale (single channel) images
+        pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
+    else:
+        pil_images = [Image.fromarray(image) for image in images]
+
+    return pil_images
+
+
+def postprocess_image(sample: torch.FloatTensor, output_type: str = "pil"):
+    if output_type not in ["pt", "np", "pil"]:
+        raise ValueError(
+            f"output_type={output_type} is not supported. Make sure to choose one of ['pt', 'np', or 'pil']"
+        )
+
+    # Equivalent to diffusers.VaeImageProcessor.denormalize
+    sample = (sample / 2 + 0.5).clamp(0, 1)
+    if output_type == "pt":
+        return sample
+
+    # Equivalent to diffusers.VaeImageProcessor.pt_to_numpy
+    sample = sample.cpu().permute(0, 2, 3, 1).numpy()
+    if output_type == "np":
+        return sample
+
+    # Output_type must be 'pil'
+    sample = numpy_to_pil(sample)
+    return sample
+
+
+def vae_roundtrip(image, max_resolution: int):
+    w, h = image.size
+    ww, hh = snap(w, h, area=max_resolution**2)
+    cropped = center_crop_image(image, hh, ww)
+    image = preprocess(cropped)
+    with torch.no_grad():
+        dist = vae.encode(image)[0]
+        res = vae.decode(dist.mean, return_dict=False)[0]
+    return cropped, postprocess_image(res)[0]
+
+
+iface = gr.Interface(
+    fn=vae_roundtrip,
+    inputs=[gr.Image(type="pil"), gr.Slider(384, 1024, step=64, value=640)],
+    outputs=[gr.Image(label="center cropped"), gr.Image(label="after roundtrip")],
+    allow_flagging="never",
+)
+iface.launch()
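
For reference, here is a minimal worked example of the snap() sizing logic used above. The function is copied standalone (importing app.py directly would also load the VAE and launch the Gradio interface): a 1920x1080 input is scaled so the area stays at or below 640*640, both sides are floored to multiples of 64, and the candidate closest to the original aspect ratio is chosen.

def snap(w, h, d=64, area=640 * 640):
    # Same logic as snap() in app.py: scale (w, h) down so the area fits,
    # floor both sides to multiples of d, then pick the candidate whose
    # aspect ratio is closest to the input's.
    s = min(1.0, (area / w / h) ** 0.5)
    err = lambda a, b: 1 - min(a, b) / max(a, b)
    sw, sh = map(lambda x: int((x * s) // d * d), (w, h))
    return min(
        (
            (ww, hh)
            for ww, hh in [(sw, sh), (sw, sh + d), (sw + d, sh), (sw + d, sh + d)]
            if ww * hh <= area
        ),
        key=lambda wh: err(w / h, wh[0] / wh[1]),
    )

print(snap(1920, 1080))  # (832, 448) with the default d=64, area=640*640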
requirements.txt
ADDED
@@ -0,0 +1,5 @@
+transformers
+diffusers
+accelerate
+torch
+numpy
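
With the dependencies above installed, a minimal sketch of the same encode/decode round trip can be run without the Gradio UI. The input filename is hypothetical, and a plain resize to a multiple of 64 stands in for the snap()/center-crop step in app.py.

import numpy as np
import torch
from PIL import Image
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("Linaqruf/anything-v3.0", subfolder="vae")

# "input.png" is a placeholder path; any RGB image works.
img = Image.open("input.png").convert("RGB").resize((512, 512), Image.Resampling.LANCZOS)

# HWC uint8 -> NCHW float in [-1, 1], mirroring preprocess() in app.py
x = torch.from_numpy(np.array(img).astype(np.float32) / 255.0)
x = (2.0 * x - 1.0).permute(2, 0, 1).unsqueeze(0)

with torch.no_grad():
    dist = vae.encode(x)[0]                          # latent distribution
    recon = vae.decode(dist.mean, return_dict=False)[0]

# [-1, 1] -> uint8 HWC, mirroring postprocess_image() in app.py
out = (recon / 2 + 0.5).clamp(0, 1)[0].permute(1, 2, 0).numpy()
Image.fromarray((out * 255).round().astype("uint8")).save("roundtrip.png")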