# vae-roundtrip / app.py
# gustproof's picture
# Duplicate from gustproof/sd-vae-roundtrip
# b54be75
import gradio as gr
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler
from PIL import Image
import PIL
import torch
import numpy as np
# Model identifier handed to from_pretrained; only its "vae" subfolder is loaded.
model_path = "Linaqruf/anything-v3.0"
# Loaded once at import time; vae_roundtrip reads this module-level object.
vae = AutoencoderKL.from_pretrained(model_path, subfolder="vae")
print(f"vae loaded from {model_path}")
def snap(w, h, d=64, area=640 * 640):
    """Choose a (width, height) on a grid of ``d`` that fits within ``area``.

    The input size is scaled down (never up) so that ``w * h`` is at most
    ``area``, each side is floored to a multiple of ``d``, and the four
    neighbouring grid sizes are compared. Among those that have non-zero
    sides and at most ``area`` pixels, the one whose aspect ratio is closest
    to ``w / h`` is returned.

    Args:
        w: Source width (must be non-zero).
        h: Source height (must be non-zero).
        d: Grid step; both returned sides are multiples of it.
        area: Maximum number of pixels in the returned size.

    Returns:
        A ``(width, height)`` tuple of positive ints.

    Raises:
        ValueError: If ``area`` is smaller than a single ``d`` x ``d`` tile,
            so no valid size exists.
    """
    # Only shrink: an image already within the budget keeps its scale.
    s = min(1.0, (area / w / h) ** 0.5)

    def err(a, b):
        # Relative mismatch between two aspect ratios, in [0, 1).
        return 1 - min(a, b) / max(a, b)

    sw, sh = (int((x * s) // d * d) for x in (w, h))
    # A side smaller than d floors to 0. Such candidates are skipped: a zero
    # height would divide by zero below, and a zero width is not a usable size.
    candidates = [
        (ww, hh)
        for ww, hh in ((sw, sh), (sw, sh + d), (sw + d, sh), (sw + d, sh + d))
        if ww > 0 and hh > 0 and ww * hh <= area
    ]
    if candidates:
        return min(candidates, key=lambda wh: err(w / h, wh[0] / wh[1]))

    if d * d > area:
        raise ValueError(
            f"area={area} is too small to fit a single {d}x{d} tile"
        )
    # Extreme aspect ratio: the short side rounds to 0 and bumping it to d
    # overshoots the budget. Pin the short side to d and give the long side
    # the largest multiple of d that still fits.
    long_side = int(area // d // d * d)
    return (long_side, d) if w >= h else (d, long_side)
def center_crop_image(image, hx, wx):
    """Center-crop ``image`` to the aspect ratio ``wx:hx`` and resize it.

    Args:
        image: Object exposing ``size`` as (width, height) plus ``crop`` and
            ``resize`` (used here with a PIL image).
        hx: Target height.
        wx: Target width.

    Returns:
        The cropped region resized to ``(wx, hx)`` with Lanczos resampling.
    """
    src_w, src_h = image.size

    if src_w / src_h > wx / hx:
        # Source is relatively wider than the target: keep the full height
        # and trim equal amounts from the left and right.
        keep_w = src_h * wx / hx
        box = ((src_w - keep_w) / 2, 0, (src_w + keep_w) / 2, src_h)
    else:
        # Source is relatively taller (or equal): keep the full width and
        # trim equal amounts from the top and bottom.
        keep_h = src_w * hx / wx
        box = (0, (src_h - keep_h) / 2, src_w, (src_h + keep_h) / 2)

    # box is (left, top, right, bottom)
    return image.crop(box).resize((wx, hx), Image.Resampling.LANCZOS)
def preprocess(image):
    """Turn an image input into a torch tensor batch.

    * A ``torch.Tensor`` is returned as-is.
    * A PIL image, or a sequence whose first element is a PIL image, is
      stacked along a new leading axis, scaled from 0..255 to -1..1 as
      float32, and moved from channels-last to channels-first.
    * A sequence whose first element is a tensor is concatenated on dim 0.
    * Any other sequence is returned unchanged.
    """
    if isinstance(image, torch.Tensor):
        return image

    batch = [image] if isinstance(image, PIL.Image.Image) else image
    first = batch[0]

    if isinstance(first, PIL.Image.Image):
        stacked = np.concatenate([np.array(frame)[None, :] for frame in batch], axis=0)
        unit = np.array(stacked).astype(np.float32) / 255.0
        # (N, H, W, C) -> (N, C, H, W), then map [0, 1] onto [-1, 1].
        signed = 2.0 * unit.transpose(0, 3, 1, 2) - 1.0
        return torch.from_numpy(signed)

    if isinstance(first, torch.Tensor):
        return torch.cat(batch, dim=0)

    return batch
def numpy_to_pil(images):
    """
    Convert a numpy image or a batch of images to a list of PIL images.

    A 3-D array is treated as a single image and given a leading batch axis.
    Values are multiplied by 255, rounded and cast to uint8, so the input is
    presumably expected to hold floats in [0, 1]. When the last axis has
    length 1 the images are built in mode "L" (grayscale).

    Returns:
        A list with one PIL image per batch entry.
    """
    if images.ndim == 3:
        images = images[None, ...]
    images = (images * 255).round().astype("uint8")
    if images.shape[-1] == 1:
        # special case for grayscale (single channel) images.
        # Drop only the channel axis: a bare squeeze() would also collapse a
        # height or width of 1 and hand PIL a 1-D (or 0-d) array.
        pil_images = [Image.fromarray(image[..., 0], mode="L") for image in images]
    else:
        pil_images = [Image.fromarray(image) for image in images]
    return pil_images
def postprocess_image(sample: torch.FloatTensor, output_type: str = "pil"):
    """Map a decoded tensor from [-1, 1] to [0, 1] and convert it.

    Args:
        sample: Tensor indexed as (batch, channel, height, width).
        output_type: "pt" for a tensor, "np" for a channels-last numpy
            array, "pil" for a list of PIL images.

    Raises:
        ValueError: If ``output_type`` is not one of the three above.
    """
    supported = ("pt", "np", "pil")
    if output_type not in supported:
        raise ValueError(
            f"output_type={output_type} is not supported. Make sure to choose one of ['pt', 'np', or 'pil']"
        )

    # Same as diffusers.VaeImageProcessor.denormalize
    denormalized = torch.clamp(sample / 2 + 0.5, 0, 1)
    if output_type == "pt":
        return denormalized

    # Same as diffusers.VaeImageProcessor.pt_to_numpy
    as_numpy = denormalized.cpu().permute(0, 2, 3, 1).numpy()
    # Anything that is not "np" at this point can only be "pil".
    return as_numpy if output_type == "np" else numpy_to_pil(as_numpy)
def vae_roundtrip(image, max_resolution: int):
    """Encode an image with the module-level ``vae`` and decode it again.

    The image is center-cropped and resized to the size chosen by ``snap``
    with a pixel budget of ``max_resolution ** 2``, converted to a tensor,
    encoded, and decoded from the mean of the encoder's output distribution.

    Args:
        image: PIL image from the Gradio input, or ``None`` when the user
            submitted without providing one.
        max_resolution: Side length whose square is the pixel budget.

    Returns:
        A tuple ``(cropped, reconstructed)`` of the cropped input and the
        first decoded image.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        # Without this guard the UI shows an opaque AttributeError.
        raise gr.Error("Please provide an input image.")

    w, h = image.size
    ww, hh = snap(w, h, area=max_resolution**2)
    cropped = center_crop_image(image, hh, ww)
    image = preprocess(cropped)
    with torch.no_grad():
        dist = vae.encode(image)[0]
        # Decode the distribution's mean rather than a random sample.
        res = vae.decode(dist.mean, return_dict=False)[0]
    return cropped, postprocess_image(res)[0]
# Gradio UI wired to vae_roundtrip: a PIL image input and a max-resolution
# slider (384 to 1024 in steps of 64, default 640) as inputs; the center-cropped
# image and the roundtrip result as outputs. Flagging is turned off.
iface = gr.Interface(
fn=vae_roundtrip,
inputs=[gr.Image(type="pil"), gr.Slider(384, 1024, step=64, value=640)],
outputs=[gr.Image(label="center cropped"), gr.Image(label="after roundtrip")],
allow_flagging="never",
)
# Launches the app at module level (there is no __main__ guard).
iface.launch()