File size: 6,140 Bytes
b45ac3d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
import os
from PIL import Image 
from cv2 import imread, cvtColor, COLOR_BGR2GRAY, COLOR_BGR2BGRA, COLOR_BGRA2RGB, threshold, THRESH_BINARY_INV, findContours, RETR_EXTERNAL, CHAIN_APPROX_SIMPLE, contourArea, minEnclosingCircle
import numpy as np
import torch
import matplotlib.pyplot as plt

def convert_images_to_grayscale(folder_path):
    """Convert every image in `folder_path` to grayscale, saving in place.

    Each image is converted to single-channel luminance ('L') and then back
    to 'RGB' so the saved file keeps a 3-channel layout for the downstream
    cv2-based steps in this file.

    Args:
        folder_path: Directory containing the images to convert.
    """
    # Guard: nothing to do if the directory is missing.
    if not os.path.isdir(folder_path):
        print(f"The folder path {folder_path} does not exist.")
        return

    for filename in os.listdir(folder_path):
        # Case-insensitive extension check so files like "IMG.PNG" are not
        # silently skipped (the original check was case-sensitive).
        if filename.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.gif')):
            image_path = os.path.join(folder_path, filename)

            # Context manager ensures the handle is closed before the
            # in-place save reopens the same path.
            with Image.open(image_path) as img:
                grayscale_img = img.convert('L').convert('RGB')
                grayscale_img.save(image_path)

def crop_center_largest_contour(folder_path):
    """Crop each image in `folder_path` around its largest non-white region.

    For every readable image: threshold out near-white pixels, find the
    largest contour, take its minimum enclosing circle, shrink the radius,
    and crop to that circle's bounding box. The crop overwrites the
    original file. Unreadable entries and all-white images are skipped.

    Args:
        folder_path: Directory of images to crop in place.
    """
    for each_image in os.listdir(folder_path):
        image_path = os.path.join(folder_path, each_image)
        image = imread(image_path)
        # imread returns None for non-image entries (stray files,
        # subdirectories) -- skip them instead of crashing on cvtColor.
        if image is None:
            continue
        gray_image = cvtColor(image, COLOR_BGR2GRAY)

        # Threshold the image to get the non-white pixels: anything darker
        # than 255 becomes foreground (255) in the inverted mask.
        _, binary_mask = threshold(gray_image, 254, 255, THRESH_BINARY_INV)

        contours, _ = findContours(binary_mask, RETR_EXTERNAL, CHAIN_APPROX_SIMPLE)
        # An all-white image yields no contours; max() would raise ValueError.
        if not contours:
            continue
        largest_contour = max(contours, key=contourArea)

        # Get the minimum enclosing circle of the largest contour.
        (x, y), radius = minEnclosingCircle(largest_contour)
        center = (int(x), int(y))
        radius = int(radius / 3)  # Divide by three (arbitrary) to make shape better

        # Crop to the circle's bounding box, clamped to the image borders.
        x_min = max(0, center[0] - radius)
        x_max = min(image.shape[1], center[0] + radius)
        y_min = max(0, center[1] - radius)
        y_max = min(image.shape[0], center[1] + radius)
        cropped_image = image[y_min:y_max, x_min:x_max]
        # OpenCV stores channels as BGR; convert so PIL saves true colors.
        cropped_image_rgba = cvtColor(cropped_image, COLOR_BGR2BGRA)
        cropped_pil_image = Image.fromarray(cvtColor(cropped_image_rgba, COLOR_BGRA2RGB))
        cropped_pil_image.save(image_path)

def calculate_variance(patch):
    """Return the variance of all pixel values in `patch`.

    Args:
        patch: Image patch (anything `numpy` can turn into an array,
            e.g. a PIL image or a nested list).

    Returns:
        Population variance over every element, as a numpy scalar.
    """
    pixel_values = np.asarray(patch)
    return np.var(pixel_values)

def crop_least_variant_patch(folder_path):
    """Replace each image with its most homogeneous (lowest-variance) patch.

    Slides a square window (20% of the image height, stride 20% of the
    window) over each image, skips patches that are mostly pure white, and
    saves the patch with the smallest pixel variance back over the original
    file. Images with no qualifying patch are left unchanged.

    Args:
        folder_path: Directory of images to process in place. Images are
            assumed to be 3-channel RGB (earlier pipeline steps save RGB).
    """
    for each_image in os.listdir(folder_path):
        image_path = os.path.join(folder_path, each_image)
        image = Image.open(image_path)
        # Define window size; clamp to >= 1 so small images don't produce
        # a zero-size window or a zero stride (range() raises on step 0).
        width, height = image.size
        window_size = max(1, round(height * .2))
        stride = max(1, round(window_size * .2))
        min_variance = float('inf')
        best_patch = None
        # Slide the window across the image, left-to-right, top-to-bottom.
        for x in range(0, width - window_size + 1, stride):
            for y in range(0, height - window_size + 1, stride):
                patch = image.crop((x, y, x + window_size, y + window_size))
                patch_w, patch_h = patch.size
                total_pixels = patch_w * patch_h
                # Count pure-white RGB pixels (axis=2 assumes 3 channels).
                white_pixels = np.sum(np.all(np.array(patch) == [255, 255, 255], axis=2))
                # Ignore patches that are mostly white background.
                if white_pixels < (total_pixels / 2):
                    variance = calculate_variance(patch)
                    if variance < min_variance:
                        # Track the most homogeneous patch seen so far.
                        min_variance = variance
                        best_patch = patch
        # Explicit None check instead of catching AttributeError as
        # control flow.
        if best_patch is not None:
            best_patch.save(image_path)
        else:
            print("No good homogenous patch to save.")

def extract_embeddings(transformation_chain, model: torch.nn.Module):
    """Build a batch-processing function that computes CLS embeddings.

    Args:
        transformation_chain: Callable applied to each raw image to
            produce a model-ready tensor.
        model: Model whose `last_hidden_state[:, 0]` (first token) is used
            as the embedding; must expose a `.device` attribute.

    Returns:
        A function mapping a batch dict with key "image" to a dict with
        key "embeddings".
    """
    target_device = model.device

    def pp(batch):
        transformed = [transformation_chain(img) for img in batch["image"]]
        pixel_values = torch.stack(transformed).to(target_device)
        # Inference only -- no gradients needed.
        with torch.no_grad():
            outputs = model(pixel_values=pixel_values)
            cls_embeddings = outputs.last_hidden_state[:, 0].cpu()
        return {"embeddings": cls_embeddings}

    return pp

def compute_scores(emb_one, emb_two):
    """Return cosine similarity between two (batches of) embeddings.

    Args:
        emb_one: First embedding tensor.
        emb_two: Second embedding tensor (broadcastable against `emb_one`).

    Returns:
        Similarity score(s) as a plain Python list of floats.
    """
    similarity = torch.nn.functional.cosine_similarity(emb_one, emb_two)
    return similarity.numpy().tolist()


def fetch_similar(image, transformation_chain, device, model, all_candidate_embeddings, candidate_ids, top_k=3):
    """Fetches the `top_k` similar images with `image` as the query.

    Args:
        image: Raw query image accepted by `transformation_chain`.
        transformation_chain: Preprocessing callable producing a tensor.
        device: Device the model expects its inputs on.
        model: Embedding model exposing `last_hidden_state`.
        all_candidate_embeddings: Tensor of candidate CLS embeddings.
        candidate_ids: Identifiers aligned with the embedding rows, shaped
            like "<label>_<suffix>".
        top_k: Number of most-similar candidates to return.

    Returns:
        List of integer labels parsed from the best-matching ids.
    """
    # Embed the query image as a batch of one.
    query_tensor = transformation_chain(image).unsqueeze(0).to(device)
    with torch.no_grad():
        query_embeddings = model(pixel_values=query_tensor).last_hidden_state[:, 0].cpu()

    # Score every candidate against the query in a single broadcasted
    # cosine-similarity call, then map each id to its score.
    sim_scores = torch.nn.functional.cosine_similarity(
        all_candidate_embeddings, query_embeddings
    ).numpy().tolist()
    similarity_mapping = dict(zip(candidate_ids, sim_scores))

    # Rank ids by descending similarity and keep the best `top_k`.
    ranked = sorted(similarity_mapping.items(), key=lambda kv: kv[1], reverse=True)
    top_entries = [cid for cid, _ in ranked[:top_k]]

    # The numeric label is the prefix before the first underscore.
    return [int(cid.split("_")[0]) for cid in top_entries]

def plot_images(images):
    """Plot the query image followed by its similar images in a grid.

    The first image is titled as the query; each subsequent one is
    numbered "Similar Image # i". Axes are hidden for all subplots.

    Args:
        images: Sequence of array-like images to display (6 per row).
    """
    plt.figure(figsize=(20, 10))
    columns = 6
    rows = int(len(images) / columns + 1)
    for idx, img in enumerate(images):
        axis = plt.subplot(rows, columns, idx + 1)
        # First slot is the query; the rest are its nearest neighbours.
        title = "Query Image\n" if idx == 0 else "Similar Image # " + str(idx)
        axis.set_title(title)
        plt.imshow(np.array(img).astype("int"))
        plt.axis("off")