gustproof commited on
Commit
b54be75
·
0 Parent(s):

Duplicate from gustproof/sd-vae-roundtrip

Browse files
Files changed (4) hide show
  1. .gitattributes +34 -0
  2. README.md +13 -0
  3. app.py +131 -0
  4. requirements.txt +5 -0
.gitattributes ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tflite filter=lfs diff=lfs merge=lfs -text
29
+ *.tgz filter=lfs diff=lfs merge=lfs -text
30
+ *.wasm filter=lfs diff=lfs merge=lfs -text
31
+ *.xz filter=lfs diff=lfs merge=lfs -text
32
+ *.zip filter=lfs diff=lfs merge=lfs -text
33
+ *.zst filter=lfs diff=lfs merge=lfs -text
34
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Sd Vae Roundtrip
3
+ emoji: 📊
4
+ colorFrom: red
5
+ colorTo: yellow
6
+ sdk: gradio
7
+ sdk_version: 3.28.0
8
+ app_file: app.py
9
+ pinned: false
10
+ duplicated_from: gustproof/sd-vae-roundtrip
11
+ ---
12
+
13
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
+ from transformers import CLIPTextModel, CLIPTokenizer
4
+ from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler
5
+ from PIL import Image
6
+ import PIL
7
+ import torch
8
+ import numpy as np
9
+
10
+ model_path = "Linaqruf/anything-v3.0"
11
+ vae = AutoencoderKL.from_pretrained(model_path, subfolder="vae")
12
+ print(f"vae loaded from {model_path}")
13
+
14
+
15
+ def snap(w, h, d=64, area=640 * 640):
16
+ s = min(1.0, (area / w / h) ** 0.5)
17
+ err = lambda a, b: 1 - min(a, b) / max(a, b)
18
+ sw, sh = map(lambda x: int((x * s) // d * d), (w, h))
19
+ return min(
20
+ (
21
+ (ww, hh)
22
+ for ww, hh in [(sw, sh), (sw, sh + d), (sw + d, sh), (sw + d, sh + d)]
23
+ if ww * hh <= area
24
+ ),
25
+ key=lambda wh: err(w / h, wh[0] / wh[1]),
26
+ )
27
+
28
+
29
+ def center_crop_image(image, hx, wx):
30
+ # Get the original image dimensions (HxW)
31
+ original_width, original_height = image.size
32
+
33
+ # Calculate the coordinates for center cropping
34
+ if original_width / original_height > wx / hx:
35
+ ww = original_height * wx / hx
36
+ left, right, top, bottom = (
37
+ (original_width - ww) / 2,
38
+ (original_width + ww) / 2,
39
+ 0,
40
+ original_height,
41
+ )
42
+ else:
43
+ hh = original_width * hx / wx
44
+ left, right, top, bottom = (
45
+ 0,
46
+ original_width,
47
+ (original_height - hh) / 2,
48
+ (original_height + hh) / 2,
49
+ )
50
+ # Crop the image
51
+ cropped_image = image.crop((left, top, right, bottom))
52
+
53
+ # Resize the cropped image to the target size (hxw)
54
+ cropped_image = cropped_image.resize((wx, hx), Image.Resampling.LANCZOS)
55
+
56
+ return cropped_image
57
+
58
+
59
+ def preprocess(image):
60
+ if isinstance(image, torch.Tensor):
61
+ return image
62
+ elif isinstance(image, PIL.Image.Image):
63
+ image = [image]
64
+
65
+ if isinstance(image[0], PIL.Image.Image):
66
+ image = [np.array(i)[None, :] for i in image]
67
+ image = np.concatenate(image, axis=0)
68
+ image = np.array(image).astype(np.float32) / 255.0
69
+ image = image.transpose(0, 3, 1, 2)
70
+ image = 2.0 * image - 1.0
71
+ image = torch.from_numpy(image)
72
+ elif isinstance(image[0], torch.Tensor):
73
+ image = torch.cat(image, dim=0)
74
+ return image
75
+
76
+
77
+ def numpy_to_pil(images):
78
+ """
79
+ Convert a numpy image or a batch of images to a PIL image.
80
+ """
81
+ if images.ndim == 3:
82
+ images = images[None, ...]
83
+ images = (images * 255).round().astype("uint8")
84
+ if images.shape[-1] == 1:
85
+ # special case for grayscale (single channel) images
86
+ pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
87
+ else:
88
+ pil_images = [Image.fromarray(image) for image in images]
89
+
90
+ return pil_images
91
+
92
+
93
+ def postprocess_image(sample: torch.FloatTensor, output_type: str = "pil"):
94
+ if output_type not in ["pt", "np", "pil"]:
95
+ raise ValueError(
96
+ f"output_type={output_type} is not supported. Make sure to choose one of ['pt', 'np', or 'pil']"
97
+ )
98
+
99
+ # Equivalent to diffusers.VaeImageProcessor.denormalize
100
+ sample = (sample / 2 + 0.5).clamp(0, 1)
101
+ if output_type == "pt":
102
+ return sample
103
+
104
+ # Equivalent to diffusers.VaeImageProcessor.pt_to_numpy
105
+ sample = sample.cpu().permute(0, 2, 3, 1).numpy()
106
+ if output_type == "np":
107
+ return sample
108
+
109
+ # Output_type must be 'pil'
110
+ sample = numpy_to_pil(sample)
111
+ return sample
112
+
113
+
114
+ def vae_roundtrip(image, max_resolution: int):
115
+ w, h = image.size
116
+ ww, hh = snap(w, h, area=max_resolution**2)
117
+ cropped = center_crop_image(image, hh, ww)
118
+ image = preprocess(cropped)
119
+ with torch.no_grad():
120
+ dist = vae.encode(image)[0]
121
+ res = vae.decode(dist.mean, return_dict=False)[0]
122
+ return cropped, postprocess_image(res)[0]
123
+
124
+
125
+ iface = gr.Interface(
126
+ fn=vae_roundtrip,
127
+ inputs=[gr.Image(type="pil"), gr.Slider(384, 1024, step=64, value=640)],
128
+ outputs=[gr.Image(label="center cropped"), gr.Image(label="after roundtrip")],
129
+ allow_flagging="never",
130
+ )
131
+ iface.launch()
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ transformers
2
+ diffusers
3
+ accelerate
4
+ torch
5
+ numpy