File size: 4,957 Bytes
07d760c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import torch
import torch.nn.functional as F

def crop(image, i, j, h, w):
    """
    Cut a rectangular window out of the last two (spatial) dimensions.

    Args:
        image (torch.tensor): Image to be cropped. Size is (C, H, W)
        i (int): Row index of the window's top-left corner.
        j (int): Column index of the window's top-left corner.
        h (int): Window height in pixels.
        w (int): Window width in pixels.
    Returns:
        torch.tensor: ``image[..., i:i + h, j:j + w]``.
    Raises:
        ValueError: if ``image`` is not 3-dimensional.
    """
    ndim = len(image.size())
    if ndim != 3:
        raise ValueError("image should be a 3D tensor")
    row_window = slice(i, i + h)
    col_window = slice(j, j + w)
    return image[..., row_window, col_window]

def resize(image, target_size, interpolation_mode):
    """
    Resize an image to an exact (height, width).

    Args:
        image (torch.tensor): Image to be resized. Size is (C, H, W)
        target_size (tuple(int, int)): Output (height, width).
        interpolation_mode (str): Mode passed to ``F.interpolate``
            (e.g. "nearest", "bilinear", "bicubic", "area").
    Returns:
        torch.tensor: Resized image of size (C, target_size[0], target_size[1]).
    Raises:
        ValueError: if ``target_size`` does not have exactly two entries.
    """
    if len(target_size) != 2:
        raise ValueError(f"target size should be tuple (height, width), instead got {target_size}")
    # F.interpolate rejects align_corners (even False) for non-interpolating
    # modes such as "nearest"/"area", so only pass it where it is accepted.
    if interpolation_mode in ("linear", "bilinear", "bicubic", "trilinear"):
        align_corners = False
    else:
        align_corners = None
    # interpolate expects a batch dimension: add it, resize, then drop it.
    return F.interpolate(
        image.unsqueeze(0), size=target_size, mode=interpolation_mode, align_corners=align_corners
    ).squeeze(0)

def resize_scale(image, target_size, interpolation_mode):
    """
    Uniformly rescale an image so its shorter spatial side becomes target_size[0].

    Args:
        image (torch.tensor): Image to be resized. Size is (C, H, W)
        target_size (tuple(int, int)): Must have two entries, but only
            ``target_size[0]`` is used: it is the desired length of the
            shorter of H and W.
        interpolation_mode (str): Mode passed to ``F.interpolate``.
    Returns:
        torch.tensor: Rescaled image with aspect ratio preserved; the exact
        output H and W are derived by ``F.interpolate`` from the scale factor.
    Raises:
        ValueError: if ``target_size`` does not have exactly two entries.
    """
    if len(target_size) != 2:
        raise ValueError(f"target size should be tuple (height, width), instead got {target_size}")
    H, W = image.size(-2), image.size(-1)
    scale_ = target_size[0] / min(H, W)
    # F.interpolate rejects align_corners (even False) for non-interpolating
    # modes such as "nearest"/"area", so only pass it where it is accepted.
    if interpolation_mode in ("linear", "bilinear", "bicubic", "trilinear"):
        align_corners = False
    else:
        align_corners = None
    return F.interpolate(
        image.unsqueeze(0), scale_factor=scale_, mode=interpolation_mode, align_corners=align_corners
    ).squeeze(0)

def resized_crop(image, i, j, h, w, size, interpolation_mode="bilinear"):
    """
    Cut a window out of an image and resize that window to ``size``.

    Args:
        image (torch.tensor): Image to be cropped. Size is (C, H, W)
        i (int): Row index of the window's top-left corner.
        j (int): Column index of the window's top-left corner.
        h (int): Window height in pixels.
        w (int): Window width in pixels.
        size (tuple(int, int)): (height, width) of the resized output.
        interpolation_mode (str): Mode forwarded to ``resize``.
    Returns:
        image (torch.tensor): Cropped-then-resized image. Size is (C, size[0], size[1])
    Raises:
        ValueError: if ``image`` is not 3-dimensional.
    """
    rank = len(image.size())
    if rank != 3:
        raise ValueError("image should be a 3D torch.tensor")
    window = crop(image, i, j, h, w)
    return resize(window, size, interpolation_mode)

def center_crop(image, crop_size):
    """
    Crop the central region of an image.

    Args:
        image (torch.tensor): Image to be cropped. Size is (C, H, W)
        crop_size (int or tuple(int, int)): (height, width) of the crop; a
            single int requests a square crop.
    Returns:
        torch.tensor: Cropped image of size (C, crop_height, crop_width).
    Raises:
        ValueError: if ``image`` is not 3-dimensional or is smaller than
            ``crop_size`` in either spatial dimension.
    """
    if len(image.size()) != 3:
        raise ValueError("image should be a 3D torch.tensor")
    if isinstance(crop_size, int):
        crop_size = (crop_size, crop_size)
    h, w = image.size(-2), image.size(-1)
    th, tw = crop_size
    if h < th or w < tw:
        raise ValueError("height and width must be no smaller than crop_size")
    # round() rounds halves to even, so with an odd margin the extra pixel
    # may end up on either side of the crop.
    i = int(round((h - th) / 2.0))
    j = int(round((w - tw) / 2.0))
    return crop(image, i, j, th, tw)

def center_crop_using_short_edge(image):
    """
    Take the largest centered square whose side equals the shorter spatial edge.

    Args:
        image (torch.tensor): Image to be cropped. Size is (C, H, W)
    Returns:
        torch.tensor: Square crop of size (C, min(H, W), min(H, W)).
    Raises:
        ValueError: if ``image`` is not 3-dimensional.
    """
    if len(image.size()) != 3:
        raise ValueError("image should be a 3D torch.tensor")
    height, width = image.size(-2), image.size(-1)
    side = min(height, width)
    # Along the shorter edge the margin is zero, so its offset is 0.
    top = int(round((height - side) / 2.0))
    left = int(round((width - side) / 2.0))
    return crop(image, top, left, side, side)

class CenterCropResizeImage:
    """
    Center-crop an image to the target aspect ratio, then resize it to the target size.

    The crop keeps the full extent of the limiting dimension: the full height
    for images wider than the target, the full width otherwise. No padding is
    added, and image content keeps its aspect ratio.

    Args:
        size (int or tuple(int, int) or list[int]): Output size. An int ``s``
            means ``(s, s)``; a pair is interpreted as ``(height, width)``.
        interpolation_mode (str): Mode forwarded to ``resize``.
    """
    def __init__(self, size, interpolation_mode="bilinear"):
        if isinstance(size, (tuple, list)):
            if len(size) != 2:
                raise ValueError(f"Size should be a tuple (height, width), instead got {size}")
            # Normalize lists to a tuple so self.size has a single type.
            self.size = tuple(size)
        else:
            self.size = (size, size)
        self.interpolation_mode = interpolation_mode

    def __call__(self, image):
        """
        Args:
            image (torch.Tensor): Image to be resized and cropped. Size is (C, H, W)

        Returns:
            torch.Tensor: Resized and cropped image. Size is (C, target_height, target_width)

        Raises:
            ValueError: if ``image`` is not 3-dimensional.
        """
        if len(image.size()) != 3:
            raise ValueError("image should be a 3D torch.tensor")
        target_height, target_width = self.size
        _, height, width = image.shape

        # Compare aspect ratios (width / height) by cross-multiplying, and derive
        # the crop size with integer floor division. Float products such as
        # height * (tw / th) can land just below an integer and then be
        # truncated one pixel short.
        if width * target_height > height * target_width:
            # Image is wider than target: keep full height, trim width.
            crop_height = height
            crop_width = int(height * target_width // target_height)
        else:
            # Image is taller than (or matches) target: keep full width, trim height.
            crop_height = int(width * target_height // target_width)
            crop_width = width

        # Never produce an empty crop, even for extreme aspect ratios.
        crop_height = max(1, crop_height)
        crop_width = max(1, crop_width)

        # Center on the integer crop size actually used; with an odd margin the
        # extra pixel goes to the bottom/right.
        top = (height - crop_height) // 2
        left = (width - crop_width) // 2

        cropped_image = crop(image, top, left, crop_height, crop_width)
        return resize(cropped_image, self.size, self.interpolation_mode)
    
# Example usage: run this module directly to see the transform on a random tensor.
if __name__ == "__main__":
    # A random (C, H, W) = (3, 480, 640) tensor stands in for a real image.
    demo_input = torch.rand(3, 480, 640)
    demo_transform = CenterCropResizeImage(size=(224, 224), interpolation_mode="bilinear")
    demo_output = demo_transform(demo_input)
    print(f"Original image shape: {demo_input.shape}")
    print(f"Transformed image shape: {demo_output.shape}")