Spaces:
Running
Running
File size: 4,595 Bytes
4dfb78b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 |
import collections.abc as collections
from pathlib import Path
from typing import Optional, Tuple
import cv2
import kornia
import numpy as np
import torch
from omegaconf import OmegaConf
class ImagePreprocessor:
    """Resize and optionally square-pad an image tensor according to a config.

    The instance is callable: it takes an image tensor whose last two
    dimensions are (height, width) and returns a dict holding the processed
    image together with the bookkeeping needed to map coordinates between the
    original and the processed image.
    """

    default_conf = {
        # Target edge length (int) or an explicit 2-element size that is used
        # unchanged as the output size; None disables resizing.
        "resize": None,  # target edge length, None for no resizing
        # If set, each output edge is floored to a multiple of this value.
        "edge_divisible_by": None,
        # Which edge `resize` refers to: "short", "long", "vert" or "horz".
        "side": "long",
        # The next three options are forwarded to kornia's resize.
        "interpolation": "bilinear",
        "align_corners": None,
        "antialias": True,
        # Zero-pad the bottom/right so the output is square.
        "square_pad": False,
        # With square_pad: also return a boolean mask of the valid region.
        "add_padding_mask": False,
    }

    def __init__(self, conf) -> None:
        """Merge the user config `conf` on top of `default_conf`.

        Struct mode is enabled on the defaults before merging (see the
        OmegaConf documentation for how this treats keys that are not part of
        `default_conf`).
        """
        super().__init__()
        default_conf = OmegaConf.create(self.default_conf)
        OmegaConf.set_struct(default_conf, True)
        self.conf = OmegaConf.merge(default_conf, conf)

    def __call__(self, img: torch.Tensor, interpolation: Optional[str] = None) -> dict:
        """Resize and preprocess an image, return image and resize scale.

        Args:
            img: image tensor of shape (..., H, W).
            interpolation: overrides `conf.interpolation` for this call; only
                used when resizing is enabled.

        Returns:
            A dict with the keys
            - "image": the resized (and possibly square-padded) image,
            - "scales": tensor (scale_x, scale_y) = new size / original size,
              cast to the dtype and device of the image,
            - "transform": 3x3 NumPy matrix diag(scale_x, scale_y, 1),
            - "image_size": NumPy array (width, height) of the target size,
            - "original_image_size": NumPy array (width, height) of the input,
            - "padding_mask": only with `square_pad` and `add_padding_mask`;
              boolean tensor of shape (..., 1, S, S), True on valid pixels.
        """
        h, w = img.shape[-2:]
        size = h, w
        if self.conf.resize is not None:
            if interpolation is None:
                interpolation = self.conf.interpolation
            size = self.get_new_image_size(h, w)
            img = kornia.geometry.transform.resize(
                img,
                size,
                side=self.conf.side,
                antialias=self.conf.antialias,
                align_corners=self.conf.align_corners,
                interpolation=interpolation,
            )
        # Keep the ratios as Python floats: the NumPy transform is built from
        # them directly, so it never has to read values back from a tensor
        # that may live on another device.
        scale_x = img.shape[-1] / w
        scale_y = img.shape[-2] / h
        scale = torch.Tensor([scale_x, scale_y]).to(img)
        T = np.diag([scale_x, scale_y, 1.0])
        data = {
            "scales": scale,
            "image_size": np.array(size[::-1]),
            "transform": T,
            "original_image_size": np.array([w, h]),
        }
        if self.conf.square_pad:
            new_h, new_w = img.shape[-2:]
            sl = max(new_h, new_w)
            padded = torch.zeros(
                *img.shape[:-2], sl, sl, device=img.device, dtype=img.dtype
            )
            # The ellipsis addresses the trailing (H, W) axes whatever the
            # number of leading dimensions, so batched inputs work as well.
            padded[..., :new_h, :new_w] = img
            data["image"] = padded
            if self.conf.add_padding_mask:
                mask = torch.zeros(
                    *img.shape[:-3], 1, sl, sl, device=img.device, dtype=torch.bool
                )
                mask[..., :new_h, :new_w] = True
                data["padding_mask"] = mask
        else:
            data["image"] = img
        return data

    def load_image(self, image_path: Path) -> dict:
        """Load the image at `image_path` with the module-level `load_image`
        and preprocess it with this instance."""
        return self(load_image(image_path))

    def get_new_image_size(
        self,
        h: int,
        w: int,
    ) -> Tuple[int, int]:
        """Compute the resize target for an image of height `h` and width `w`.

        If `conf.resize` is iterable it must have two elements and is returned
        unchanged as a tuple. Otherwise it is the length of the edge selected
        by `conf.side`; the other edge follows from the aspect ratio
        (truncated to int), and both edges are floored to a multiple of
        `conf.edge_divisible_by` when that option is set.

        Returns:
            The target size as a (height, width) tuple.

        Raises:
            ValueError: if `conf.side` is not a supported value.
        """
        side = self.conf.side
        if isinstance(self.conf.resize, collections.Iterable):
            assert len(self.conf.resize) == 2
            return tuple(self.conf.resize)
        side_size = self.conf.resize
        aspect_ratio = w / h
        if side not in ("short", "long", "vert", "horz"):
            raise ValueError(
                f"side can be one of 'short', 'long', 'vert', and 'horz'. Got '{side}'"
            )
        if side == "vert":
            size = side_size, int(side_size * aspect_ratio)
        elif side == "horz":
            size = int(side_size / aspect_ratio), side_size
        elif (side == "short") ^ (aspect_ratio < 1.0):
            # The requested edge is the vertical one: either the short edge
            # of a landscape image or the long edge of a portrait image.
            size = side_size, int(side_size * aspect_ratio)
        else:
            size = int(side_size / aspect_ratio), side_size
        if self.conf.edge_divisible_by is not None:
            df = self.conf.edge_divisible_by
            # Return a tuple here too, so every path matches the annotation.
            size = tuple(int(x // df * df) for x in size)
        return size
def read_image(path: Path, grayscale: bool = False) -> np.ndarray:
    """Load the image stored at `path` with OpenCV, as RGB or grayscale.

    Args:
        path: location of the image file.
        grayscale: read a single-channel image instead of a colour one.

    Returns:
        The decoded image array. For colour images the last axis is reversed
        so that the channel order matches the RGB promised above (OpenCV
        decodes colour images as BGR).

    Raises:
        FileNotFoundError: if nothing exists at `path`.
        IOError: if OpenCV returns no image for the file.
    """
    if not Path(path).exists():
        raise FileNotFoundError(f"No image at path {path}.")

    read_flag = cv2.IMREAD_GRAYSCALE if grayscale else cv2.IMREAD_COLOR
    pixels = cv2.imread(str(path), read_flag)
    if pixels is None:
        raise IOError(f"Could not read image at {path}.")

    # Grayscale images have no channel axis to flip.
    return pixels if grayscale else pixels[..., ::-1]
def numpy_image_to_torch(image: np.ndarray) -> torch.Tensor:
    """Convert a NumPy image to a channel-first float tensor scaled by 1/255.

    A 3-D array is treated as HxWxC and reordered to CxHxW; a 2-D array gets
    a leading channel axis of size one. The values are divided by 255.0
    (presumably the input is 8-bit data - the dtype is not checked here).

    Raises:
        ValueError: if the array is neither 2-D nor 3-D.
    """
    if image.ndim not in (2, 3):
        raise ValueError(f"Not an image: {image.shape}")

    if image.ndim == 2:
        channel_first = image[None]  # add channel axis
    else:
        channel_first = image.transpose((2, 0, 1))  # HxWxC to CxHxW
    return torch.tensor(channel_first / 255.0, dtype=torch.float)
def load_image(path: Path, grayscale=False) -> torch.Tensor:
    """Read the image at `path` with `read_image` and convert the result to
    a tensor with `numpy_image_to_torch`."""
    return numpy_image_to_torch(read_image(path, grayscale=grayscale))
|