File size: 3,613 Bytes
5c26863
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import cv2
import numpy as np
import torch
import gradio as gr
from segment_anything import sam_model_registry, SamPredictor
from PIL import Image
from sklearn.cluster import KMeans

# Load the Segment Anything (SAM) ViT-H model once at import time and wrap it
# in a single module-level predictor that process_image() reuses per request.
# NOTE(review): set_image() stores per-image state on this shared predictor,
# so concurrent requests could interfere — confirm the Gradio queue serialises calls.
model_type = "vit_h"
sam_checkpoint = "sam_vit_h_4b8939.pth"
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)
predictor = SamPredictor(sam)

def extract_dominant_color(img, mask=None, n_clusters=1):
    """Return the dominant 3-channel color of ``img`` as a float array of length 3.

    Parameters
    ----------
    img : np.ndarray
        Pixel data whose last axis has 3 channels. Any leading shape is
        accepted (an H x W x 3 image or an N x 3 pixel list) because it is
        flattened to ``(-1, 3)``.
    mask : np.ndarray, optional
        One entry per pixel; only pixels where it is nonzero/True contribute.
        When omitted, every pixel is used.
    n_clusters : int, optional
        Number of k-means clusters. With the default of 1 the dominant color
        is the per-channel mean (the point single-cluster k-means converges
        to), so it is computed directly instead of running KMeans 10 times.
        With more clusters, the center of the most populated cluster is
        returned.

    Raises
    ------
    ValueError
        If no pixels are selected (empty input or all-zero mask), or if
        ``n_clusters`` is smaller than 1.
    """
    pixels = np.asarray(img).reshape(-1, 3)
    if mask is not None:
        pixels = pixels[np.asarray(mask).reshape(-1).astype(bool)]
    if pixels.shape[0] == 0:
        raise ValueError("extract_dominant_color: no pixels to analyze")
    if n_clusters < 1:
        raise ValueError("extract_dominant_color: n_clusters must be >= 1")

    if n_clusters == 1:
        return pixels.astype(np.float64).mean(axis=0)

    # KMeans rejects more clusters than samples, so cap the cluster count.
    kmeans = KMeans(n_clusters=min(n_clusters, pixels.shape[0]), n_init=10)
    kmeans.fit(pixels)
    counts = np.bincount(kmeans.labels_)
    return kmeans.cluster_centers_[counts.argmax()]

def create_circular_swatch(color, size=200):
    """Return a ``size`` x ``size`` x 3 uint8 image holding a filled disc of ``color``.

    ``color`` is any 3-element sequence of channel values, written into the
    canvas in the order given. Float components (e.g. a cluster mean) are
    rounded to the nearest integer and clipped to [0, 255]; plain ``int()``
    would truncate and bias the swatch toward darker values. Pixels outside
    the disc stay black.
    """
    swatch = np.zeros((size, size, 3), dtype=np.uint8)
    channels = np.clip(np.rint(np.asarray(color, dtype=np.float64)), 0, 255)
    fill = tuple(int(c) for c in channels)
    center = (size // 2, size // 2)
    # Negative thickness asks OpenCV for a filled circle.
    cv2.circle(swatch, center, size // 2, fill, -1)
    return swatch

def find_clean_area(img, mask, min_area_size=100*100):
    """Locate the bounding box of the largest edge-free region inside ``mask``.

    The image is masked, converted to grayscale and run through Canny. The
    edges are dilated with a 5x5 kernel, and the remaining edge-free pixels
    inside the mask form candidate regions. The largest outer contour whose
    area is at least ``min_area_size`` wins.

    Parameters
    ----------
    img : np.ndarray
        3-channel uint8 image (converted with ``COLOR_RGB2GRAY``, so it is
        treated as RGB).
    mask : np.ndarray
        H x W array; any nonzero/True entry marks the object. Boolean masks
        are accepted and normalised to uint8 0/1 here, because
        ``cv2.bitwise_and`` does not accept bool masks.
    min_area_size : float
        Minimum contour area, in pixels, for a region to qualify.

    Returns
    -------
    tuple[int, int, int, int] | None
        ``(x, y, w, h)`` of the winning contour's bounding rectangle, or
        ``None`` when no region is large enough. Caveats:
        - The rectangle bounds a possibly non-convex contour, so it can
          include pixels that are not clean.
        - Areas are measured on outer contours (``RETR_EXTERNAL``), so
          enclosed holes count toward the area.
    """
    mask_u8 = (np.asarray(mask) != 0).astype(np.uint8)

    # Black out everything outside the object before edge detection.
    masked_img = cv2.bitwise_and(img, img, mask=mask_u8)
    gray = cv2.cvtColor(masked_img, cv2.COLOR_RGB2GRAY)
    edges = cv2.Canny(gray, 30, 100)

    # Thicken edges so pixels near an edge are also excluded from "clean".
    kernel = np.ones((5, 5), np.uint8)
    dilated_edges = cv2.dilate(edges, kernel, iterations=1)

    clean_mask = (dilated_edges == 0).astype(np.uint8) & mask_u8
    contours, _ = cv2.findContours(clean_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    # Single pass, one contourArea call per contour; the first maximum wins.
    best_contour, best_area = None, -1.0
    for contour in contours:
        area = cv2.contourArea(contour)
        if area >= min_area_size and area > best_area:
            best_contour, best_area = contour, area

    if best_contour is None:
        return None

    x, y, w, h = cv2.boundingRect(best_contour)
    return (x, y, w, h)

def process_image(image, input_point, crop_size=(400, 600)):
    """Segment the object at ``input_point`` and summarise its color and fabric.

    Parameters
    ----------
    image : np.ndarray
        H x W x 3 uint8 image as delivered by the Gradio ``Image`` component.
        It is handed to ``predictor.set_image`` and to ``COLOR_RGB2GRAY``
        inside ``find_clean_area``, so it is presumably RGB.
    input_point : np.ndarray
        Shape (1, 2) array with one ``(x, y)`` prompt point in pixels.
    crop_size : tuple[int, int]
        ``(width, height)`` of the fabric crop, in OpenCV ``dsize`` order.
        The default is a 400-wide, 600-tall crop.

    Returns
    -------
    tuple
        ``(visualization, swatch, detailed_fabric)``:
        - ``visualization``: a copy of ``image`` with the mask tinted.
        - ``swatch``: a disc of the object's dominant color.
        - ``detailed_fabric``: a crop of the cleanest region resized to
          ``crop_size``, or a black placeholder of the same shape when no
          clean region is found.
    """
    predictor.set_image(image)
    masks, _, _ = predictor.predict(
        point_coords=input_point,
        point_labels=np.array([1]),  # label 1 = foreground prompt (SAM convention)
        multimask_output=False,
    )
    mask = masks[0].astype(bool)
    mask_u8 = mask.astype(np.uint8)

    # Dominant color over object pixels only. Using the whole masked image
    # would let every blacked-out background pixel drag the color toward black.
    if mask.any():
        dominant_color = extract_dominant_color(image[mask])
    else:
        dominant_color = np.zeros(3)
    swatch = create_circular_swatch(dominant_color)

    # Crop the cleanest region and scale it to a fixed size.
    crop_w, crop_h = crop_size
    clean_area = find_clean_area(image, mask_u8)
    if clean_area is not None:
        x, y, w, h = clean_area
        detailed_fabric = cv2.resize(image[y:y + h, x:x + w], (crop_w, crop_h))
    else:
        # cv2 dsize is (width, height) but numpy shapes are (rows, cols), so
        # the placeholder must be (crop_h, crop_w) to match the resized crop.
        detailed_fabric = np.zeros((crop_h, crop_w, 3), dtype=np.uint8)

    # Blend [0, 0, 255] (blue if the image is RGB) 50/50 over masked pixels.
    visualization = image.copy()
    overlay = np.array([0, 0, 255])
    visualization[mask] = (visualization[mask] * 0.5 + overlay * 0.5).astype(np.uint8)

    return visualization, swatch, detailed_fabric

def gradio_interface(input_image, click_x, click_y):
    """Gradio callback: validate the inputs and run ``process_image``.

    Raises ``gr.Error`` (shown in the UI) when no image was supplied. The
    slider values are clamped into the image bounds, because the sliders'
    fixed 0-1000 range does not depend on the uploaded image's size.

    Returns the ``(visualization, swatch, detailed_fabric)`` tuple from
    ``process_image``.
    """
    if input_image is None:
        raise gr.Error("Please upload an image first.")

    height, width = input_image.shape[:2]
    x = min(max(float(click_x), 0.0), width - 1)
    y = min(max(float(click_y), 0.0), height - 1)
    input_point = np.array([[x, y]])

    visualization, swatch, detailed_fabric = process_image(input_image, input_point)
    return visualization, swatch, detailed_fabric

# Create the Gradio interface. The two sliders provide the (x, y) pixel
# coordinate of the prompt point passed to gradio_interface(); the UI has no
# click-to-select interaction, so the description tells users to use them.
iface = gr.Interface(
    fn=gradio_interface,
    inputs=[
        gr.Image(type="numpy"),
        gr.Slider(0, 1000, label="Click X"),
        gr.Slider(0, 1000, label="Click Y"),
    ],
    outputs=[
        gr.Image(type="numpy", label="Segmentation"),
        gr.Image(type="numpy", label="Color Swatch"),
        gr.Image(type="numpy", label="Detailed Fabric"),
    ],
    title="Fabric Analyzer",
    description=(
        "Upload an image or choose from examples. Use the Click X / Click Y "
        "sliders to pick a point on the garment, then submit to analyze "
        "fabric and color."
    ),
    # Example image files are resolved relative to the working directory.
    examples=[
        ["blue_shirt.png", 400, 500],
        ["polo.png", 400, 500],
        ["dress.jpg", 400, 500],
    ],
)

if __name__ == "__main__":
    iface.launch()