|
from typing import List |
|
|
|
import PIL.Image |
|
import PIL.ImageOps |
|
from packaging import version |
|
from PIL import Image |
|
|
|
|
|
# Map interpolation names to Pillow resampling filters. The lookup location depends on
# the installed Pillow release: before 9.1.0 the filters are read as constants on
# ``PIL.Image``; from 9.1.0 onward they are read from ``PIL.Image.Resampling``.
# Only the base version is compared, so pre-/post-release suffixes are ignored.
if version.parse(version.parse(PIL.__version__).base_version) < version.parse("9.1.0"):
    PIL_INTERPOLATION = {
        "linear": PIL.Image.LINEAR,
        "bilinear": PIL.Image.BILINEAR,
        "bicubic": PIL.Image.BICUBIC,
        "lanczos": PIL.Image.LANCZOS,
        "nearest": PIL.Image.NEAREST,
    }
else:
    PIL_INTERPOLATION = {
        # "linear" is an alias for the bilinear filter.
        "linear": PIL.Image.Resampling.BILINEAR,
        "bilinear": PIL.Image.Resampling.BILINEAR,
        "bicubic": PIL.Image.Resampling.BICUBIC,
        "lanczos": PIL.Image.Resampling.LANCZOS,
        "nearest": PIL.Image.Resampling.NEAREST,
    }
|
|
|
|
|
def pt_to_pil(images):
    """
    Turn a batch of torch image tensors into a list of PIL images.

    The input is halved, shifted by 0.5 and clamped to ``[0, 1]``, then moved to
    the CPU, reordered with ``permute(0, 2, 3, 1)``, cast to float and handed to
    ``numpy_to_pil`` as a numpy array.

    NOTE(review): presumably the input is a ``(N, C, H, W)`` tensor with values in
    ``[-1, 1]`` -- confirm against callers.
    """
    # Rescale and clip so every value lies in [0, 1].
    rescaled = (images / 2 + 0.5).clamp(0, 1)
    # Move axis 1 to the last position before leaving torch.
    reordered = rescaled.cpu().permute(0, 2, 3, 1)
    return numpy_to_pil(reordered.float().numpy())
|
|
|
|
|
def numpy_to_pil(images):
    """
    Convert a numpy image or a batch of images to a list of PIL images.

    Args:
        images: A 3-dimensional array (single image) or a batch of such arrays.
            The last axis is treated as the channel axis; a channel count of 1
            produces grayscale (``"L"``) images. Values are multiplied by 255,
            rounded and cast to ``uint8``.
            NOTE(review): presumably values are floats in ``[0, 1]`` -- confirm
            against callers.

    Returns:
        A list with one PIL image per entry along the first (batch) axis.
    """
    if images.ndim == 3:
        # Promote a single image to a batch of one.
        images = images[None, ...]
    images = (images * 255).round().astype("uint8")
    if images.shape[-1] == 1:
        # Grayscale: drop ONLY the trailing channel axis. A bare ``squeeze()``
        # would also collapse a height or width of 1, handing PIL an array of
        # lower rank (transposed or unusable for 1-pixel-high/1x1 images).
        pil_images = [Image.fromarray(image.squeeze(axis=-1), mode="L") for image in images]
    else:
        pil_images = [Image.fromarray(image) for image in images]

    return pil_images
|
|
|
|
|
def make_image_grid(images: List[PIL.Image.Image], rows: int, cols: int, resize: int = None) -> PIL.Image.Image:
    """
    Paste a list of images into one ``rows`` x ``cols`` grid image.

    Images are placed left to right, top to bottom. When ``resize`` is given,
    every image is first resized to a ``resize`` x ``resize`` square. The cell
    size is taken from the first image, and the grid is created in "RGB" mode.
    The number of images must equal ``rows * cols``.
    """
    assert len(images) == rows * cols

    if resize is not None:
        target_size = (resize, resize)
        images = [picture.resize(target_size) for picture in images]

    # Every cell uses the dimensions of the first image.
    cell_w, cell_h = images[0].size
    canvas = Image.new("RGB", size=(cols * cell_w, rows * cell_h))

    for index, picture in enumerate(images):
        # Row-major placement: the quotient selects the row, the remainder the column.
        row, col = divmod(index, cols)
        canvas.paste(picture, box=(col * cell_w, row * cell_h))
    return canvas
|
|