# vae-roundtrip / app.py
# Duplicated from gustproof/sd-vae-roundtrip
import gradio as gr
from diffusers import AutoencoderKL
from PIL import Image
import PIL
import torch
import numpy as np

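# Only the VAE component of the checkpoint is loaded; the UNet, text encoder,
# and scheduler are never needed for a plain encode/decode roundtrip.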
model_path = "Linaqruf/anything-v3.0"
vae = AutoencoderKL.from_pretrained(model_path, subfolder="vae")
print(f"vae loaded from {model_path}")
def snap(w, h, d=64, area=640 * 640):
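    """Choose output dimensions: scale (w, h) down so the area is at most `area`,
    snap both sides to multiples of `d`, and among the nearby multiples pick the
    pair whose aspect ratio is closest to the original image's."""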
    s = min(1.0, (area / w / h) ** 0.5)
    err = lambda a, b: 1 - min(a, b) / max(a, b)
    sw, sh = map(lambda x: int((x * s) // d * d), (w, h))
    return min(
        (
            (ww, hh)
            for ww, hh in [(sw, sh), (sw, sh + d), (sw + d, sh), (sw + d, sh + d)]
            if ww * hh <= area
        ),
        key=lambda wh: err(w / h, wh[0] / wh[1]),
    )
def center_crop_image(image, hx, wx):
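    """Center-crop `image` to the aspect ratio wx:hx, then resize to exactly (wx, hx)."""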
    # PIL reports size as (width, height)
    original_width, original_height = image.size
    # Calculate the coordinates for center cropping
    if original_width / original_height > wx / hx:
        # Image is wider than the target aspect ratio: crop the sides
        ww = original_height * wx / hx
        left, right, top, bottom = (
            (original_width - ww) / 2,
            (original_width + ww) / 2,
            0,
            original_height,
        )
    else:
        # Image is taller than the target aspect ratio: crop the top and bottom
        hh = original_width * hx / wx
        left, right, top, bottom = (
            0,
            original_width,
            (original_height - hh) / 2,
            (original_height + hh) / 2,
        )
    # Crop the image
    cropped_image = image.crop((left, top, right, bottom))
    # Resize the cropped image to the target size (wx, hx)
    cropped_image = cropped_image.resize((wx, hx), Image.Resampling.LANCZOS)
    return cropped_image
def preprocess(image):
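    """Convert a PIL image (or a list of images / tensors) to a float tensor in
    NCHW layout, scaled to [-1, 1], as expected by the VAE encoder."""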
    if isinstance(image, torch.Tensor):
        return image
    elif isinstance(image, PIL.Image.Image):
        image = [image]
    if isinstance(image[0], PIL.Image.Image):
        image = [np.array(i)[None, :] for i in image]
        image = np.concatenate(image, axis=0)
        image = np.array(image).astype(np.float32) / 255.0
        image = image.transpose(0, 3, 1, 2)
        image = 2.0 * image - 1.0
        image = torch.from_numpy(image)
    elif isinstance(image[0], torch.Tensor):
        image = torch.cat(image, dim=0)
    return image
def numpy_to_pil(images):
"""
Convert a numpy image or a batch of images to a PIL image.
"""
if images.ndim == 3:
images = images[None, ...]
images = (images * 255).round().astype("uint8")
if images.shape[-1] == 1:
# special case for grayscale (single channel) images
pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
else:
pil_images = [Image.fromarray(image) for image in images]
return pil_images
def postprocess_image(sample: torch.FloatTensor, output_type: str = "pil"):
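    """Map a decoded sample from [-1, 1] back to an image in the requested format
    ("pt" tensor, "np" array, or "pil" image)."""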
    if output_type not in ["pt", "np", "pil"]:
        raise ValueError(
            f"output_type={output_type} is not supported. Make sure to choose one of ['pt', 'np', 'pil']"
        )
    # Equivalent to diffusers.VaeImageProcessor.denormalize
    sample = (sample / 2 + 0.5).clamp(0, 1)
    if output_type == "pt":
        return sample
    # Equivalent to diffusers.VaeImageProcessor.pt_to_numpy
    sample = sample.cpu().permute(0, 2, 3, 1).numpy()
    if output_type == "np":
        return sample
    # output_type must be 'pil'
    sample = numpy_to_pil(sample)
    return sample
def vae_roundtrip(image, max_resolution: int):
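    """Crop/resize the input, encode it with the VAE, decode the latent mean,
    and return both the cropped input and the reconstruction."""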
    w, h = image.size
    # Snap to a VAE-friendly size no larger than max_resolution**2 pixels
    ww, hh = snap(w, h, area=max_resolution**2)
    cropped = center_crop_image(image, hh, ww)
    image = preprocess(cropped)
    with torch.no_grad():
        # Encode to a latent distribution, then decode its mean (no sampling)
        dist = vae.encode(image)[0]
        res = vae.decode(dist.mean, return_dict=False)[0]
    return cropped, postprocess_image(res)[0]
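# Gradio UI: upload an image, pick the maximum resolution, and compare the
# center-cropped input with its VAE reconstruction side by side.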
iface = gr.Interface(
    fn=vae_roundtrip,
    inputs=[gr.Image(type="pil"), gr.Slider(384, 1024, step=64, value=640)],
    outputs=[gr.Image(label="center cropped"), gr.Image(label="after roundtrip")],
    allow_flagging="never",
)
iface.launch()