Spaces:
Running
on
Zero
Running
on
Zero
File size: 4,957 Bytes
07d760c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 |
import torch
import torch.nn.functional as F
def crop(image, i, j, h, w):
    """
    Slice a rectangular window out of an image tensor.

    Args:
        image (torch.tensor): Image to be cropped. Size is (C, H, W)
        i (int): Row index of the window's top edge.
        j (int): Column index of the window's left edge.
        h (int): Window height in pixels.
        w (int): Window width in pixels.
    Returns:
        torch.tensor: ``image`` restricted to rows ``i:i+h`` and columns
        ``j:j+w`` (a slicing view, so no data is copied).
    Raises:
        ValueError: if ``image`` is not 3-dimensional.
    """
    if image.dim() != 3:
        raise ValueError("image should be a 3D tensor")
    rows = slice(i, i + h)
    cols = slice(j, j + w)
    return image[..., rows, cols]
def resize(image, target_size, interpolation_mode):
    """
    Resize an image tensor to an exact spatial size.

    Args:
        image (torch.tensor): Image to be resized. Size is (C, H, W)
        target_size (tuple(int, int)): Output (height, width).
        interpolation_mode (str): Mode forwarded to ``F.interpolate``
            (e.g. "nearest", "bilinear", "bicubic", "area").
    Returns:
        torch.tensor: Resized image. Size is (C, target_size[0], target_size[1])
    Raises:
        ValueError: if ``target_size`` does not have exactly two entries.
    """
    if len(target_size) != 2:
        raise ValueError(f"target size should be tuple (height, width), instead got {target_size}")
    # F.interpolate raises if align_corners is given for non-interpolating
    # modes such as "nearest" or "area", so only pass it where it applies.
    if interpolation_mode in ("linear", "bilinear", "bicubic", "trilinear"):
        extra = {"align_corners": False}
    else:
        extra = {}
    # F.interpolate expects a batch dimension; add it and strip it again.
    return F.interpolate(image.unsqueeze(0), size=target_size, mode=interpolation_mode, **extra).squeeze(0)
def resize_scale(image, target_size, interpolation_mode):
    """
    Resize an image so that its shorter spatial edge equals ``target_size[0]``.

    The aspect ratio is preserved: the longer edge is scaled by the same
    factor and floored to a whole number of pixels.

    Args:
        image (torch.tensor): Image to be resized. Size is (C, H, W)
        target_size (tuple(int, int)): Only ``target_size[0]`` is used, as the
            desired length of the shorter edge; ``target_size[1]`` is ignored
            (only the pair's length is validated).
        interpolation_mode (str): Mode forwarded to ``F.interpolate``.
    Returns:
        torch.tensor: Resized image whose shorter spatial edge is ``target_size[0]``.
    Raises:
        ValueError: if ``target_size`` does not have exactly two entries.
    """
    if len(target_size) != 2:
        raise ValueError(f"target size should be tuple (height, width), instead got {target_size}")
    H, W = image.size(-2), image.size(-1)
    short_edge = min(H, W)
    new_short = int(target_size[0])
    # Derive the output size with exact integer arithmetic rather than handing
    # F.interpolate a float scale_factor: it floors H * scale, and float
    # rounding can leave the short edge one pixel short (49 * (1 / 49) < 1).
    new_h = new_short if H == short_edge else H * new_short // short_edge
    new_w = new_short if W == short_edge else W * new_short // short_edge
    # "nearest"/"area" reject an explicit align_corners argument.
    if interpolation_mode in ("linear", "bilinear", "bicubic", "trilinear"):
        extra = {"align_corners": False}
    else:
        extra = {}
    return F.interpolate(image.unsqueeze(0), size=(new_h, new_w), mode=interpolation_mode, **extra).squeeze(0)
def resized_crop(image, i, j, h, w, size, interpolation_mode="bilinear"):
    """
    Cut a window out of an image and rescale it to a fixed size.

    Args:
        image (torch.tensor): Image to be cropped. Size is (C, H, W)
        i (int): Row of the window's upper-left corner.
        j (int): Column of the window's upper-left corner.
        h (int): Height of the window.
        w (int): Width of the window.
        size (tuple(int, int)): (height, width) of the output image.
        interpolation_mode (str): Mode forwarded to ``resize``.
    Returns:
        torch.tensor: Cropped and resized image. Size is (C, size[0], size[1])
    Raises:
        ValueError: if ``image`` is not 3D, or (from ``resize``) if ``size``
            does not have exactly two entries.
    """
    if image.dim() != 3:
        raise ValueError("image should be a 3D torch.tensor")
    window = crop(image, i, j, h, w)
    return resize(window, size, interpolation_mode)
def center_crop(image, crop_size):
    """
    Crop the central region of an image.

    Args:
        image (torch.tensor): Image to be cropped. Size is (C, H, W)
        crop_size (int or tuple(int, int)): (height, width) of the region to
            keep; a single int requests a square crop.
    Returns:
        torch.tensor: Cropped image. Size is (C, crop_height, crop_width)
    Raises:
        ValueError: if ``image`` is not 3D or is smaller than ``crop_size``.
    """
    if len(image.size()) != 3:
        raise ValueError("image should be a 3D torch.tensor")
    # Accept a bare int, as CenterCropResizeImage does for its size; unpacking
    # an int below would otherwise fail with a TypeError.
    if isinstance(crop_size, int):
        crop_size = (crop_size, crop_size)
    h, w = image.size(-2), image.size(-1)
    th, tw = crop_size
    if h < th or w < tw:
        raise ValueError("height and width must be no smaller than crop_size")
    # Python's round() rounds halves to even, so with an odd surplus the extra
    # pixel may end up on either side.
    i = int(round((h - th) / 2.0))
    j = int(round((w - tw) / 2.0))
    return crop(image, i, j, th, tw)
def center_crop_using_short_edge(image):
    """
    Crop the largest centered square from an image.

    The square's side equals the image's shorter spatial edge, so only the
    longer axis is trimmed.

    Args:
        image (torch.tensor): Image to be cropped. Size is (C, H, W)
    Returns:
        torch.tensor: Square crop. Size is (C, min(H, W), min(H, W))
    Raises:
        ValueError: if ``image`` is not 3D.
    """
    if image.dim() != 3:
        raise ValueError("image should be a 3D torch.tensor")
    height, width = image.shape[-2:]
    side = min(height, width)
    # The offset along the short axis is zero; the long axis is centered.
    top = int(round((height - side) / 2.0))
    left = int(round((width - side) / 2.0))
    return crop(image, top, left, side, side)
class CenterCropResizeImage:
    """
    Center-crop an image to the target aspect ratio, then resize it to the target size.

    The crop is the largest centered region whose aspect ratio matches
    ``size``, so only the excess along one axis is discarded and nothing is
    padded. The crop is then resized to exactly ``size``.

    Args:
        size (int or sequence of two ints): Output (height, width); a single
            int yields a square output.
        interpolation_mode (str): Mode forwarded to ``resize``.
    """

    def __init__(self, size, interpolation_mode="bilinear"):
        if isinstance(size, (tuple, list)):
            if len(size) != 2:
                raise ValueError(f"Size should be a tuple (height, width), instead got {size}")
            # Normalize lists to tuples; previously a list was wrapped as
            # (list, list), which crashed later in __call__.
            self.size = tuple(size)
        else:
            self.size = (size, size)
        self.interpolation_mode = interpolation_mode

    def _center_crop_box(self, height, width):
        """Return (top, left, crop_height, crop_width) of the centered target-aspect crop."""
        target_height, target_width = self.size
        # Compare aspect ratios by cross-multiplying instead of dividing floats.
        if width * target_height > height * target_width:
            # Image is wider than the target: keep full height, trim width.
            crop_height = height
            crop_width = round(height * target_width / target_height)
        else:
            # Image is taller than (or as wide as) the target: keep full width, trim height.
            crop_width = width
            crop_height = round(width * target_height / target_width)
        # Round rather than truncate. Clamp to [1, image size] so extreme
        # target aspect ratios never produce an empty crop.
        crop_height = min(height, max(1, crop_height))
        crop_width = min(width, max(1, crop_width))
        top = (height - crop_height) // 2
        left = (width - crop_width) // 2
        return top, left, crop_height, crop_width

    def __call__(self, image):
        """
        Args:
            image (torch.Tensor): Image to be resized and cropped. Size is (C, H, W)
        Returns:
            torch.Tensor: Resized and cropped image. Size is (C, target_height, target_width)
        """
        _, height, width = image.shape
        top, left, crop_height, crop_width = self._center_crop_box(height, width)
        cropped_image = crop(image, top, left, crop_height, crop_width)
        return resize(cropped_image, self.size, self.interpolation_mode)
# Example usage
if __name__ == "__main__":
    # Random sample image tensor laid out as (C, H, W)
    sample_image = torch.rand(3, 480, 640)
    # Center-crop to the 1:1 target aspect ratio, then resize to 224x224
    transform = CenterCropResizeImage(size=(224, 224), interpolation_mode="bilinear")
    transformed_image = transform(sample_image)
    print(f"Original image shape: {sample_image.shape}")
    print(f"Transformed image shape: {transformed_image.shape}")