import gradio as gr

from diffusers import AutoencoderKL
from PIL import Image
import PIL
import torch
import numpy as np

# Load only the VAE submodule of the Anything v3.0 checkpoint; the text
# encoder, UNet, and scheduler are not needed for an encode/decode roundtrip.
model_path = "Linaqruf/anything-v3.0"
vae = AutoencoderKL.from_pretrained(model_path, subfolder="vae")
print(f"vae loaded from {model_path}")
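# Optional: if a CUDA device is available, the VAE could be moved there with
# vae.to("cuda"), moving the input tensor in vae_roundtrip accordingly; as
# written, everything below runs on the CPU.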


def snap(w, h, d=64, area=640 * 640):
    """Scale (w, h) down to fit within `area` pixels, snapped to multiples of d.

    From the grid-aligned candidates that still fit the area budget, return the
    one whose aspect ratio is closest to the original w/h.
    """
    s = min(1.0, (area / w / h) ** 0.5)  # uniform downscale factor; never upscale
    err = lambda a, b: 1 - min(a, b) / max(a, b)  # relative aspect-ratio mismatch
    sw, sh = map(lambda x: int((x * s) // d * d), (w, h))  # floor to the d-grid
    return min(
        (
            (ww, hh)
            for ww, hh in [(sw, sh), (sw, sh + d), (sw + d, sh), (sw + d, sh + d)]
            if ww * hh <= area
        ),
        key=lambda wh: err(w / h, wh[0] / wh[1]),
    )
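# Illustrative example (not executed): with the defaults, snap(1920, 1080)
# scales the frame to fit within roughly 640*640 pixels and returns (832, 448),
# the 64-aligned candidate whose aspect ratio is closest to 16:9 under the budget.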


def center_crop_image(image, hx, wx):
    # Get the original image dimensions; PIL's Image.size is (width, height)
    original_width, original_height = image.size

    # Calculate the coordinates for center cropping
    if original_width / original_height > wx / hx:
        ww = original_height * wx / hx
        left, right, top, bottom = (
            (original_width - ww) / 2,
            (original_width + ww) / 2,
            0,
            original_height,
        )
    else:
        hh = original_width * hx / wx
        left, right, top, bottom = (
            0,
            original_width,
            (original_height - hh) / 2,
            (original_height + hh) / 2,
        )
    # Crop the image
    cropped_image = image.crop((left, top, right, bottom))

    # Resize the cropped image to the target size; PIL's resize takes (width, height)
    cropped_image = cropped_image.resize((wx, hx), Image.Resampling.LANCZOS)

    return cropped_image
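# Illustrative example (not executed): center-cropping a 1920x1080 photo to an
# 832x448 target trims roughly 23 rows from the top and bottom (to reach the
# 832:448 aspect ratio) and then LANCZOS-resizes the result down to 832x448.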


def preprocess(image):
    """Convert a PIL image (or a list of them) to a float tensor in [-1, 1], NCHW."""
    if isinstance(image, torch.Tensor):
        return image
    elif isinstance(image, PIL.Image.Image):
        image = [image]

    if isinstance(image[0], PIL.Image.Image):
        image = [np.array(i)[None, :] for i in image]
        image = np.concatenate(image, axis=0)
        image = np.array(image).astype(np.float32) / 255.0
        image = image.transpose(0, 3, 1, 2)
        image = 2.0 * image - 1.0
        image = torch.from_numpy(image)
    elif isinstance(image[0], torch.Tensor):
        image = torch.cat(image, dim=0)
    return image
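# Illustrative example (not executed): a single 832x448 RGB PIL image leaves
# preprocess() as a torch.float32 tensor of shape (1, 3, 448, 832) with values
# in [-1, 1], the layout AutoencoderKL.encode expects.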


def numpy_to_pil(images):
    """
    Convert a numpy image or a batch of images to a PIL image.
    """
    if images.ndim == 3:
        images = images[None, ...]
    images = (images * 255).round().astype("uint8")
    if images.shape[-1] == 1:
        # special case for grayscale (single channel) images
        pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
    else:
        pil_images = [Image.fromarray(image) for image in images]

    return pil_images
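# numpy_to_pil mirrors the helper of the same name on diffusers' VaeImageProcessor,
# so images come back the same way a stock diffusers pipeline would return them.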


def postprocess_image(sample: torch.FloatTensor, output_type: str = "pil"):
    if output_type not in ["pt", "np", "pil"]:
        raise ValueError(
            f"output_type={output_type} is not supported. Choose one of 'pt', 'np', or 'pil'."
        )

    # Equivalent to diffusers.VaeImageProcessor.denormalize
    sample = (sample / 2 + 0.5).clamp(0, 1)
    if output_type == "pt":
        return sample

    # Equivalent to diffusers.VaeImageProcessor.pt_to_numpy
    sample = sample.cpu().permute(0, 2, 3, 1).numpy()
    if output_type == "np":
        return sample

    # output_type must be 'pil'
    sample = numpy_to_pil(sample)
    return sample


def vae_roundtrip(image, max_resolution: int):
    w, h = image.size
    # Pick a 64-aligned target size within the resolution budget, then crop,
    # normalize, and push the image through the VAE encoder and decoder.
    ww, hh = snap(w, h, area=max_resolution**2)
    cropped = center_crop_image(image, hh, ww)
    image = preprocess(cropped)
    with torch.no_grad():
        dist = vae.encode(image)[0]  # posterior latent distribution
        res = vae.decode(dist.mean, return_dict=False)[0]  # reconstruct from its mean
    return cropped, postprocess_image(res)[0]
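# Note: vae.config.scaling_factor is not applied here; it would cancel out in a
# plain encode/decode roundtrip and only matters when latents are handed to a
# diffusion UNet.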


# The slider sets max_resolution; snap() caps the working image at max_resolution**2 pixels.
iface = gr.Interface(
    fn=vae_roundtrip,
    inputs=[gr.Image(type="pil"), gr.Slider(384, 1024, step=64, value=640, label="max resolution")],
    outputs=[gr.Image(label="center cropped"), gr.Image(label="after roundtrip")],
    allow_flagging="never",
)
iface.launch()
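# Note: launch() binds to the default local host and port; launch(share=True)
# could be used instead to get a temporary public link when running outside a
# hosted Space.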