Vincentqyw
fix: roma
358ab8f
raw
history blame
19 kB
# Copyright 2019-present NAVER Corp.
# CC BY-NC-SA 3.0
# Available only for non-commercial use
import pdb
import numpy as np
from PIL import Image, ImageOps
import torchvision.transforms as tvf
import random
from math import ceil
from . import transforms_tools as F
"""
Example command to try out some transformation chain:
python -m tools.transforms --trfs "Scale(384), ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.1), RandomRotation(10), RandomTilting(0.5, 'all'), RandomScale(240,320), RandomCrop(224)"
"""
def instanciate_transformation(cmd_line):
    """Build a transformation pipeline from its textual description.

    cmd_line: (str)
        Comma-separated list of transformation constructor calls.
        Ex: "Rotate(10), Scale(256)"
        Any non-string value is considered to be an already instantiated
        pipeline and is handed back untouched.

    Returns the object obtained by evaluating ``tvf.Compose([...])`` around
    the given list. When the evaluation fails, the reason is printed and
    None is returned.
    """
    if isinstance(cmd_line, str):
        # NOTE(review): the description is run through eval(); it must only
        # ever come from a trusted source (e.g. the developer's command line).
        cmd_line = "tvf.Compose([%s])" % cmd_line
        try:
            pipeline = eval(cmd_line)
        except Exception as err:
            print(
                "Cannot interpret this transform list: %s\nReason: %s"
                % (cmd_line, err)
            )
            pipeline = None
        return pipeline

    # already instantiated
    return cmd_line
class Scale(object):
    """Rescale the input PIL.Image to a given size.
    Copied from https://github.com/pytorch in torchvision/transforms/transforms.py

    The smallest dimension of the resulting image will be = size.

    if largest == True: same behaviour for the largest dimension.

    if not can_upscale: don't upscale
    if not can_downscale: don't downscale

    Args:
        size (int or pair): target length of the reference edge, or an
            explicit (width, height) pair.
        interpolation: resampling filter handed to ``img.resize``.
        largest (bool): use the largest edge as reference instead of the
            smallest one.
        can_upscale (bool): allow the image to grow.
        can_downscale (bool): allow the image to shrink.
    """

    def __init__(
        self,
        size,
        interpolation=Image.BILINEAR,
        largest=False,
        can_upscale=True,
        can_downscale=True,
    ):
        assert isinstance(size, int) or (len(size) == 2)
        self.size = size
        self.interpolation = interpolation
        self.largest = largest
        self.can_upscale = can_upscale
        self.can_downscale = can_downscale

    def __repr__(self):
        # Use the actual class name: a plain Scale used to print itself as
        # "RandomScale(...)", while subclasses keep printing their own name.
        fmt_str = "%s(%s" % (type(self).__name__, str(self.size))
        if self.largest:
            fmt_str += ", largest=True"
        if not self.can_upscale:
            fmt_str += ", can_upscale=False"
        if not self.can_downscale:
            fmt_str += ", can_downscale=False"
        return fmt_str + ")"

    def get_params(self, imsize):
        """Return the target (width, height) for an image of size `imsize`.

        With an integer `size`, the reference edge (smallest one, or largest
        one when `largest` is set) is mapped to `size` and the other edge is
        scaled proportionally, truncated to an int. With a pair, the pair is
        returned as is.
        """
        w, h = imsize
        if isinstance(self.size, int):

            def is_reference(a, b):
                # True when edge `a` is the one that must match self.size
                return (a >= b) if self.largest else (a <= b)

            if (is_reference(w, h) and w == self.size) or (
                is_reference(h, w) and h == self.size
            ):
                # the reference edge already has the requested length
                ow, oh = w, h
            elif is_reference(w, h):
                ow = self.size
                oh = int(self.size * h / w)
            else:
                oh = self.size
                ow = int(self.size * w / h)
        else:
            ow, oh = self.size
        return ow, oh

    def __call__(self, inp):
        """Resize the image held by `inp` and update its labels accordingly.

        The resize is skipped when it would upscale (resp. downscale) while
        `can_upscale` (resp. `can_downscale`) is False.
        """
        img = F.grab_img(inp)
        w, h = img.size

        target = self.get_params(img.size)
        if target != img.size:
            # growth / shrinkage is judged on the smallest edge
            grows = min(img.size) < min(target)
            shrinks = min(img.size) > min(target)
            if (self.can_upscale and grows) or (self.can_downscale and shrinks):
                img = img.resize(target, self.interpolation)

        # Derive the reported scaling from the size the image really has now.
        # Using the requested size would describe a scaling that never took
        # place whenever the resize above was skipped.
        ow, oh = img.size
        return F.update_img_and_labels(
            inp, img, persp=(ow / w, 0, 0, 0, oh / h, 0, 0, 0)
        )
class RandomScale(Scale):
    """Rescale the input PIL.Image to a random size.
    Copied from https://github.com/pytorch in torchvision/transforms/transforms.py

    Args:
        min_size (int or float): min size of the smaller edge of the picture.
            A float is interpreted as a fraction of the smaller edge.
        max_size (int or float): max size of the smaller edge of the picture.
            Must have the same type as min_size.
        ar (float or tuple):
            max change of aspect ratio (width/height). A single number `a`
            stands for the interval (min(1/a, a), max(1/a, a)).
        can_upscale (bool): when False, max_size is capped by the current
            smaller edge.
        can_downscale (bool): forwarded to Scale.
        interpolation (int, optional): Desired interpolation. Default is
            ``PIL.Image.BILINEAR``
    """

    def __init__(
        self,
        min_size,
        max_size,
        ar=1,
        can_upscale=False,
        can_downscale=True,
        interpolation=Image.BILINEAR,
    ):
        # The fixed size of the parent class is unused: get_params is overridden.
        Scale.__init__(
            self,
            0,
            interpolation=interpolation,
            can_upscale=can_upscale,
            can_downscale=can_downscale,
        )

        assert type(min_size) == type(
            max_size
        ), "min_size and max_size can only be 2 ints or 2 floats"
        assert (isinstance(min_size, int) and min_size >= 1) or (
            isinstance(min_size, float) and min_size > 0
        )
        assert isinstance(max_size, (int, float)) and min_size <= max_size
        self.min_size = min_size
        self.max_size = max_size

        if type(ar) in (float, int):
            # a scalar describes a symmetric interval around 1
            inverse = 1 / ar
            ar = (min(inverse, ar), max(inverse, ar))
        assert 0.2 < ar[0] <= ar[1] < 5
        self.ar = ar

    def get_params(self, imsize):
        """Draw a random target (width, height) for an image of size `imsize`."""
        w, h = imsize
        short_edge = min(w, h)

        # Floats are fractions of the smaller edge, ints are absolute sizes.
        if isinstance(self.min_size, float):
            min_size = int(self.min_size * short_edge + 0.5)
        elif isinstance(self.min_size, int):
            min_size = self.min_size
        if isinstance(self.max_size, float):
            max_size = int(self.max_size * short_edge + 0.5)
        elif isinstance(self.max_size, int):
            max_size = self.max_size

        if not self.can_upscale:
            max_size = min(max_size, short_edge)

        size = int(0.5 + F.rand_log_uniform(min_size, max_size))
        ar = F.rand_log_uniform(*self.ar)  # change of aspect ratio

        if w < h:
            # image is taller: the width is the reference edge
            ow = size
            oh = int(0.5 + size * h / w / ar)
            if oh < min_size:
                ow = int(0.5 + ow * float(min_size) / oh)
                oh = min_size
        else:
            # image is wider: the height is the reference edge
            oh = size
            ow = int(0.5 + size * w / h * ar)
            if ow < min_size:
                oh = int(0.5 + oh * float(min_size) / ow)
                ow = min_size

        assert ow >= min_size, "image too small (width=%d < min_size=%d)" % (
            ow,
            min_size,
        )
        assert oh >= min_size, "image too small (height=%d < min_size=%d)" % (
            oh,
            min_size,
        )
        return ow, oh
class RandomCrop(object):
    """Crop the given PIL Image at a random location.
    Copied from https://github.com/pytorch in torchvision/transforms/transforms.py

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made.
        padding (int or sequence, optional): Optional padding on each border
            of the image. Default is 0, i.e no padding. If a sequence of length
            4 is provided, it is used to pad left, top, right, bottom borders
            respectively.
    """

    def __init__(self, size, padding=0):
        if isinstance(size, int):
            self.size = (int(size), int(size))
        else:
            self.size = size
        self.padding = padding

    def __repr__(self):
        return "RandomCrop(%s)" % str(self.size)

    @staticmethod
    def get_params(img, output_size):
        """Pick the crop window for `img`.

        Args:
            img: object exposing a ``size`` attribute as (width, height).
            output_size: requested crop as (height, width).

        Returns:
            (x, y, crop_width, crop_height), where (x, y) is the top-left
            corner of the window.

        Raises:
            AssertionError: if the image is smaller than the requested crop.
        """
        w, h = img.size
        th, tw = output_size
        assert h >= th and w >= tw, "Image of %dx%d is too small for crop %dx%d" % (
            w,
            h,
            tw,
            th,
        )

        # np.random.randint excludes its upper bound, so +1 is required for
        # the last valid offset (window flush with the bottom / right border)
        # to be reachable. The guards avoid drawing a number when there is
        # no freedom along an axis.
        y = np.random.randint(0, h - th + 1) if h > th else 0
        x = np.random.randint(0, w - tw + 1) if w > tw else 0
        return x, y, tw, th

    def __call__(self, inp):
        """Pad (optionally) then crop the image held by `inp`, and update its labels."""
        img = F.grab_img(inp)

        padl = padt = 0
        if self.padding:
            if F.is_pil_image(img):
                img = ImageOps.expand(img, border=self.padding, fill=0)
            else:
                assert isinstance(img, F.DummyImg)
                img = img.expand(border=self.padding)
            # remember the left / top padding: it shifts the image content
            if isinstance(self.padding, int):
                padl = padt = self.padding
            else:
                padl, padt = self.padding[0:2]

        i, j, tw, th = self.get_params(img, self.size)
        img = img.crop((i, j, i + tw, j + th))

        # pure translation: padding moves content forward, cropping backward
        return F.update_img_and_labels(
            inp, img, persp=(1, 0, padl - i, 0, 1, padt - j, 0, 0)
        )
class CenterCrop(RandomCrop):
    """Crops the given PIL Image at the center.
    Copied from https://github.com/pytorch in torchvision/transforms/transforms.py

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made.
    """

    @staticmethod
    def get_params(img, output_size):
        """Return (x, y, crop_width, crop_height) of the centered window.

        `img.size` is read as (width, height) and `output_size` as
        (height, width). Offsets are half the slack on each axis, rounded
        with int(0.5 + value).
        """
        img_w, img_h = img.size
        crop_h, crop_w = output_size

        def centered_offset(full, part):
            return int(0.5 + ((full - part) / 2.0))

        top = centered_offset(img_h, crop_h)
        left = centered_offset(img_w, crop_w)
        return left, top, crop_w, crop_h
class RandomRotation(object):
    """Rotate the input PIL.Image by a random angle.
    Copied from https://github.com/pytorch in torchvision/transforms/transforms.py

    Args:
        degrees (float):
            maximum rotation angle; the angle is drawn uniformly in
            [-degrees, degrees].
        interpolation (int, optional): Desired interpolation. Default is
            ``PIL.Image.BILINEAR``
    """

    def __init__(self, degrees, interpolation=Image.BILINEAR):
        self.degrees = degrees
        self.interpolation = interpolation

    def __call__(self, inp):
        """Rotate the image held by `inp` and update its labels accordingly."""
        source = F.grab_img(inp)
        w, h = source.size

        angle = np.random.uniform(-self.degrees, self.degrees)
        rotated = source.rotate(angle, resample=self.interpolation)
        w2, h2 = rotated.size

        # Compose: move the input centre to the origin, rotate, then move
        # the origin to the centre of the output image.
        to_origin = F.translate(-w / 2, -h / 2)
        spin = F.rotate(-angle * np.pi / 180)
        trf = F.persp_mul(to_origin, spin)
        to_center = F.translate(w2 / 2, h2 / 2)
        trf = F.persp_mul(trf, to_center)

        return F.update_img_and_labels(inp, rotated, persp=trf)
class RandomTilting(object):
    """Apply a random tilting (left, right, up, down) to the input PIL.Image
    Copied from https://github.com/pytorch in torchvision/transforms/transforms.py

    Args:
        magnitude (float):
            maximum magnitude of the random skew (value between 0 and 1)
        directions (string):
            tilting directions allowed (all, left, right, up, down)
            examples: "all", "left,right", "up-down-right"
    """

    # position in this tuple == direction code used by _tilted_plane
    _DIRECTIONS = ("left", "right", "up", "down")

    def __init__(self, magnitude, directions="all"):
        self.magnitude = magnitude
        # normalise separators so that the names can be split on whitespace
        self.directions = directions.lower().replace(",", " ").replace("-", " ")

    def __repr__(self):
        return "RandomTilt(%g, '%s')" % (self.magnitude, self.directions)

    def _allowed_directions(self):
        """Return the list of direction codes permitted by `self.directions`."""
        if self.directions == "all":
            return [0, 1, 2, 3]
        choices = []
        for d in self.directions.split():
            try:
                choices.append(self._DIRECTIONS.index(d))
            except ValueError:
                raise ValueError("Tilting direction %s not recognized" % d)
        return choices

    @staticmethod
    def _tilted_plane(skew_direction, skew_amount, x1, y1, x2, y2):
        """Return the 4 corners of the skewed plane for one direction code.

        Corners are listed as Top Left, Top Right, Bottom Right, Bottom Left,
        each one as a (horizontal, vertical) pair.
        """
        if skew_direction == 0:
            # Left Tilt
            return [
                (y1, x1 - skew_amount),  # Top Left
                (y2, x1),  # Top Right
                (y2, x2),  # Bottom Right
                (y1, x2 + skew_amount),  # Bottom Left
            ]
        if skew_direction == 1:
            # Right Tilt
            return [
                (y1, x1),  # Top Left
                (y2, x1 - skew_amount),  # Top Right
                (y2, x2 + skew_amount),  # Bottom Right
                (y1, x2),  # Bottom Left
            ]
        if skew_direction == 2:
            # Forward Tilt
            return [
                (y1 - skew_amount, x1),  # Top Left
                (y2 + skew_amount, x1),  # Top Right
                (y2, x2),  # Bottom Right
                (y1, x2),  # Bottom Left
            ]
        if skew_direction == 3:
            # Backward Tilt
            return [
                (y1, x1),  # Top Left
                (y2, x1),  # Top Right
                (y2 + skew_amount, x2),  # Bottom Right
                (y1 - skew_amount, x2),  # Bottom Left
            ]
        raise ValueError("Unknown tilting direction code %r" % (skew_direction,))

    def __call__(self, inp):
        """Tilt the image held by `inp` and update its labels accordingly.

        Raises:
            ValueError: if `directions` contains an unknown direction name.
        """
        img = F.grab_img(inp)
        w, h = img.size

        # x* hold vertical extents and y* horizontal ones; corners are
        # written as (y, x) pairs, i.e. (horizontal, vertical).
        x1, y1, x2, y2 = 0, 0, h, w
        original_plane = [(y1, x1), (y2, x1), (y2, x2), (y1, x2)]

        max_skew_amount = int(ceil(max(w, h) * self.magnitude))
        if max_skew_amount == 0:
            # Nothing to skew (zero magnitude): random.randint(1, 0) would
            # raise, so hand the image back with an identity transform.
            return F.update_img_and_labels(
                inp, img, persp=(1, 0, 0, 0, 1, 0, 0, 0)
            )
        skew_amount = random.randint(1, max_skew_amount)
        skew_direction = random.choice(self._allowed_directions())
        # print('randomtitlting: ', skew_amount, skew_direction) # to debug random

        new_plane = self._tilted_plane(skew_direction, skew_amount, x1, y1, x2, y2)

        # To calculate the coefficients required by PIL for the perspective skew,
        # see the following Stack Overflow discussion: https://goo.gl/sSgJdj
        matrix = []
        for p1, p2 in zip(new_plane, original_plane):
            matrix.append([p1[0], p1[1], 1, 0, 0, 0, -p2[0] * p1[0], -p2[0] * p1[1]])
            matrix.append([0, 0, 0, p1[0], p1[1], 1, -p2[1] * p1[0], -p2[1] * p1[1]])

        # np.float was only an alias of the builtin float and has been removed
        # from NumPy; float64 is the very same dtype.
        A = np.array(matrix, dtype=np.float64)
        B = np.array(original_plane).reshape(8)
        homography = tuple(np.dot(np.linalg.pinv(A), B))

        img = img.transform(
            img.size, Image.PERSPECTIVE, homography, resample=Image.BICUBIC
        )

        # The coefficients given to PIL are inverted here before being used
        # to update the labels.
        # NOTE(review): the inverse is truncated to its first 8 entries
        # without dividing by the 9th one — confirm this matches what
        # F.update_img_and_labels expects.
        homography = np.linalg.pinv(
            np.float32(homography + (1,)).reshape(3, 3)
        ).ravel()[:8]
        return F.update_img_and_labels(inp, img, persp=tuple(homography))


RandomTilt = RandomTilting  # redefinition
class Tilt(object):
    """Apply a known tilting to an image

    The 8 perspective coefficients are given as positional arguments and are
    handed as is to ``img.transform(..., Image.PERSPECTIVE, ...)``.
    """

    def __init__(self, *homography):
        assert len(homography) == 8
        self.homography = homography

    def __call__(self, inp):
        """Warp the image held by `inp` and update its labels accordingly."""
        source = F.grab_img(inp)
        coeffs = self.homography
        # print(coeffs)

        warped = source.transform(
            source.size, Image.PERSPECTIVE, coeffs, resample=Image.BICUBIC
        )

        # Complete the coefficients into a 3x3 matrix, invert it, and keep
        # the first 8 entries of the inverse for the label update.
        as_matrix = np.float32(coeffs + (1,)).reshape(3, 3)
        inverse = np.linalg.pinv(as_matrix).ravel()[:8]
        return F.update_img_and_labels(inp, warped, persp=tuple(inverse))
class StillTransform(object):
    """Takes and return an image, without changing its shape or geometry.

    Subclasses implement `_transform`; `__call__` always reports an identity
    geometric transform.
    """

    def _transform(self, img):
        """Return the transformed image (to be provided by subclasses)."""
        raise NotImplementedError()

    def __call__(self, inp):
        """Apply `_transform` to the image held by `inp`."""
        original = F.grab_img(inp)

        # transform the image (size should not change)
        try:
            result = self._transform(original)
        except TypeError:
            # best effort: on a TypeError the image is passed through as is
            result = original

        identity = (1, 0, 0, 0, 1, 0, 0, 0)
        return F.update_img_and_labels(inp, result, persp=identity)
class PixelNoise(StillTransform):
    """Takes an image, and add random white noise.

    Args:
        ampl (number): width of the uniform noise interval, 0 <= ampl < 255.
    """

    def __init__(self, ampl=20):
        StillTransform.__init__(self)
        assert 0 <= ampl < 255
        self.ampl = ampl

    def __repr__(self):
        return "PixelNoise(%g)" % self.ampl

    def _transform(self, img):
        """Return a new PIL image with uniform noise added to every value."""
        pixels = np.float32(img)
        half_range = self.ampl / 2
        # The interval is centred on 0.5 rather than 0; values are clipped to
        # [0, 255] and converted to uint8 below.
        noise = np.random.uniform(
            0.5 - half_range, 0.5 + half_range, size=pixels.shape
        )
        pixels += noise
        return Image.fromarray(np.uint8(pixels.clip(0, 255)))
class ColorJitter(StillTransform):
    """Randomly change the brightness, contrast and saturation of an image.
    Copied from https://github.com/pytorch in torchvision/transforms/transforms.py

    Args:
        brightness (float): How much to jitter brightness. brightness_factor
            is chosen uniformly from [max(0, 1 - brightness), 1 + brightness].
        contrast (float): How much to jitter contrast. contrast_factor
            is chosen uniformly from [max(0, 1 - contrast), 1 + contrast].
        saturation (float): How much to jitter saturation. saturation_factor
            is chosen uniformly from [max(0, 1 - saturation), 1 + saturation].
        hue(float): How much to jitter hue. hue_factor is chosen uniformly from
            [-hue, hue]. Should be >=0 and <= 0.5.
    """

    def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
        self.brightness = brightness
        self.contrast = contrast
        self.saturation = saturation
        self.hue = hue

    def __repr__(self):
        settings = (self.brightness, self.contrast, self.saturation, self.hue)
        return "ColorJitter(%g,%g,%g,%g)" % settings

    @staticmethod
    def get_params(brightness, contrast, saturation, hue):
        """Get a randomized transform to be applied on image.

        Arguments are same as that of __init__. Only the adjustments whose
        strength is > 0 are included.

        Returns:
            Transform which randomly adjusts brightness, contrast,
            saturation and hue in a random order.
        """
        steps = []

        # Each factor lives in its own variable: the lambdas read them when
        # they are called, not when they are created.
        if brightness > 0:
            b_low, b_high = max(0, 1 - brightness), 1 + brightness
            b_factor = np.random.uniform(b_low, b_high)
            steps.append(tvf.Lambda(lambda img: F.adjust_brightness(img, b_factor)))

        if contrast > 0:
            c_low, c_high = max(0, 1 - contrast), 1 + contrast
            c_factor = np.random.uniform(c_low, c_high)
            steps.append(tvf.Lambda(lambda img: F.adjust_contrast(img, c_factor)))

        if saturation > 0:
            s_low, s_high = max(0, 1 - saturation), 1 + saturation
            s_factor = np.random.uniform(s_low, s_high)
            steps.append(tvf.Lambda(lambda img: F.adjust_saturation(img, s_factor)))

        if hue > 0:
            h_factor = np.random.uniform(-hue, hue)
            steps.append(tvf.Lambda(lambda img: F.adjust_hue(img, h_factor)))

        # shuffle in place so that the adjustments are applied in random order
        np.random.shuffle(steps)
        return tvf.Compose(steps)

    def _transform(self, img):
        """Draw a fresh random jitter and apply it to `img`."""
        jitter = self.get_params(
            self.brightness, self.contrast, self.saturation, self.hue
        )
        return jitter(img)
if __name__ == "__main__":
    # Small visual test bench: shows the input image next to several random
    # outputs of the requested transformation chain.
    import argparse
    import os

    parser = argparse.ArgumentParser("Script to try out and visualize transformations")
    parser.add_argument("--img", type=str, default="imgs/test.png", help="input image")
    parser.add_argument(
        "--trfs", type=str, required=True, help="list of transformations"
    )
    parser.add_argument(
        "--layout", type=int, nargs=2, default=(3, 3), help="nb of rows,cols"
    )
    args = parser.parse_args()

    # "$HERE" in the image path stands for the directory of this file
    args.img = args.img.replace("$HERE", os.path.dirname(__file__))
    img = dict(img=Image.open(args.img))

    trfs = instanciate_transformation(args.trfs)

    from matplotlib import pyplot as pl

    pl.ion()
    pl.subplots_adjust(0, 0, 1, 1)

    n_rows, n_cols = args.layout

    while True:
        for cell in range(n_rows * n_cols):
            pl.subplot(n_rows, n_cols, cell + 1)

            # the first cell shows the untouched input, the others a new draw
            shown = img if cell == 0 else trfs(img.copy())
            if isinstance(shown, dict):
                shown = shown["img"]

            pl.imshow(shown)
            pl.xlabel("%d x %d" % shown.size)
            pl.xticks(())
            pl.yticks(())

        # pause in the debugger after each full grid; continuing redraws it
        pdb.set_trace()