import gradio as gr
from diffusers import AutoencoderKL
from PIL import Image
import PIL
import torch
import numpy as np

# Load only the VAE of the checkpoint; the text encoder, UNet and scheduler
# are not needed for an encode/decode roundtrip.
model_path = "Linaqruf/anything-v3.0"
vae = AutoencoderKL.from_pretrained(model_path, subfolder="vae")
print(f"vae loaded from {model_path}")


def snap(w, h, d=64, area=640 * 640):
    """Scale (w, h) down so the total area stays at or below `area`, snapping
    both sides to multiples of `d` and picking the candidate whose aspect
    ratio is closest to the original."""
    s = min(1.0, (area / w / h) ** 0.5)
    err = lambda a, b: 1 - min(a, b) / max(a, b)
    sw, sh = map(lambda x: int((x * s) // d * d), (w, h))
    return min(
        (
            (ww, hh)
            for ww, hh in [(sw, sh), (sw, sh + d), (sw + d, sh), (sw + d, sh + d)]
            if ww * hh <= area
        ),
        key=lambda wh: err(w / h, wh[0] / wh[1]),
    )
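
# Illustrative check (my numbers, not part of the original app): snap(1920, 1080)
# scales a 16:9 frame down to (832, 448): both sides are multiples of 64, the
# area stays under 640 * 640, and the aspect ratio remains close to 16:9.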


def center_crop_image(image, hx, wx):
    # PIL's Image.size is (width, height)
    original_width, original_height = image.size
    # Calculate the coordinates for center cropping
    if original_width / original_height > wx / hx:
        # Image is wider than the target aspect ratio: crop the width
        ww = original_height * wx / hx
        left, right, top, bottom = (
            (original_width - ww) / 2,
            (original_width + ww) / 2,
            0,
            original_height,
        )
    else:
        # Image is taller than the target aspect ratio: crop the height
        hh = original_width * hx / wx
        left, right, top, bottom = (
            0,
            original_width,
            (original_height - hh) / 2,
            (original_height + hh) / 2,
        )
    # Crop the image
    cropped_image = image.crop((left, top, right, bottom))
    # Resize the cropped image to the target size (hx by wx)
    cropped_image = cropped_image.resize((wx, hx), Image.Resampling.LANCZOS)
    return cropped_image


def preprocess(image):
    # Convert a PIL image (or a list of PIL images / tensors) to a float
    # tensor in NCHW layout, scaled to [-1, 1] as expected by the VAE encoder.
    if isinstance(image, torch.Tensor):
        return image
    elif isinstance(image, PIL.Image.Image):
        image = [image]
    if isinstance(image[0], PIL.Image.Image):
        image = [np.array(i)[None, :] for i in image]
        image = np.concatenate(image, axis=0)
        image = np.array(image).astype(np.float32) / 255.0
        image = image.transpose(0, 3, 1, 2)
        image = 2.0 * image - 1.0
        image = torch.from_numpy(image)
    elif isinstance(image[0], torch.Tensor):
        image = torch.cat(image, dim=0)
    return image


def numpy_to_pil(images):
    """
    Convert a numpy image or a batch of images to a PIL image.
    """
    if images.ndim == 3:
        images = images[None, ...]
    images = (images * 255).round().astype("uint8")
    if images.shape[-1] == 1:
        # special case for grayscale (single channel) images
        pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
    else:
        pil_images = [Image.fromarray(image) for image in images]
    return pil_images


def postprocess_image(sample: torch.FloatTensor, output_type: str = "pil"):
    if output_type not in ["pt", "np", "pil"]:
        raise ValueError(
            f"output_type={output_type} is not supported. Make sure to choose one of ['pt', 'np', or 'pil']"
        )
    # Equivalent to diffusers.VaeImageProcessor.denormalize
    sample = (sample / 2 + 0.5).clamp(0, 1)
    if output_type == "pt":
        return sample
    # Equivalent to diffusers.VaeImageProcessor.pt_to_numpy
    sample = sample.cpu().permute(0, 2, 3, 1).numpy()
    if output_type == "np":
        return sample
    # output_type must be 'pil'
    sample = numpy_to_pil(sample)
    return sample


def vae_roundtrip(image, max_resolution: int):
    # Snap the requested resolution to VAE-friendly dimensions, center crop,
    # then encode to latents and decode straight back.
    w, h = image.size
    ww, hh = snap(w, h, area=max_resolution**2)
    cropped = center_crop_image(image, hh, ww)
    image = preprocess(cropped)
    with torch.no_grad():
        dist = vae.encode(image)[0]
        # Use the mean of the latent distribution for a deterministic reconstruction.
        res = vae.decode(dist.mean, return_dict=False)[0]
    return cropped, postprocess_image(res)[0]
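
# Illustrative usage outside the Gradio UI (assumes a local "input.png" exists):
#     img = Image.open("input.png").convert("RGB")
#     cropped, reconstructed = vae_roundtrip(img, 640)
#     reconstructed.save("roundtrip.png")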


iface = gr.Interface(
    fn=vae_roundtrip,
    inputs=[gr.Image(type="pil"), gr.Slider(384, 1024, step=64, value=640)],
    outputs=[gr.Image(label="center cropped"), gr.Image(label="after roundtrip")],
    allow_flagging="never",
)
iface.launch()
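
# Note: on a Hugging Face Space this script runs automatically as the app's
# entry point; run locally, gr.Interface.launch() serves the same demo on a
# local Gradio server instead.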