"""Gradio app that segments a garment with SAM and summarises its colour and texture.

The user supplies an image plus an (x, y) point chosen with two sliders.
Segment Anything predicts a mask from that point, and the app returns:

* the input image with the predicted mask tinted,
* a circular swatch of the garment's dominant colour,
* a fixed-size crop of the largest edge-free ("clean") region inside the mask.
"""

import cv2
import numpy as np
import torch
import gradio as gr
from segment_anything import sam_model_registry, SamPredictor
from PIL import Image
from sklearn.cluster import KMeans

# Load SAM model
sam_checkpoint = "sam_vit_h_4b8939.pth"
model_type = "vit_h"
device = "cuda" if torch.cuda.is_available() else "cpu"
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)
predictor = SamPredictor(sam)

# Output size of the "detailed fabric" crop as (width, height) -- the order
# cv2.resize expects. The numpy array shape is therefore (height, width, 3).
DETAIL_SIZE = (400, 600)


def extract_dominant_color(img, n_clusters=1):
    """Return the dominant colour of *img* as a length-3 float array.

    Args:
        img: anything reshapeable to ``(-1, 3)`` -- e.g. an ``(H, W, 3)``
            image or an ``(N, 3)`` array of pixels.
        n_clusters: number of k-means clusters. The centre of the most
            populated cluster is returned; with the default of 1 this is
            the mean colour of all pixels.

    Returns:
        The colour as a float array in the same channel order as *img*.
        An empty input yields ``[0, 0, 0]`` instead of raising.
    """
    pixels = np.asarray(img).reshape(-1, 3)
    if len(pixels) == 0:
        return np.zeros(3)
    # k-means cannot use more clusters than there are samples.
    n_clusters = max(1, min(int(n_clusters), len(pixels)))
    kmeans = KMeans(n_clusters=n_clusters, n_init=10)
    labels = kmeans.fit_predict(pixels)
    counts = np.bincount(labels, minlength=n_clusters)
    return kmeans.cluster_centers_[np.argmax(counts)]


def create_circular_swatch(color, size=200):
    """Draw a filled circle of *color* on a black ``size x size`` uint8 canvas."""
    swatch = np.zeros((size, size, 3), dtype=np.uint8)
    # Round (rather than truncate) and clip so float cluster centres map to
    # the nearest valid 8-bit colour.
    color_tuple = tuple(int(c) for c in np.clip(np.rint(color), 0, 255))
    cv2.circle(swatch, (size // 2, size // 2), size // 2, color_tuple, -1)
    return swatch


def find_clean_area(img, mask, min_area_size=100 * 100):
    """Find the bounding box of the largest edge-free region inside *mask*.

    Canny edges are computed on the masked image and dilated with a 5x5
    kernel. Pixels that lie inside *mask* and are not near any edge form the
    "clean" mask, and its external contours are the candidate regions.

    Args:
        img: image converted to grayscale with ``cv2.COLOR_RGB2GRAY``, so it
            is expected to be RGB.
        mask: uint8 mask with the same height/width as *img*; non-zero
            marks the region of interest.
        min_area_size: minimum contour area (in pixels) for a region to count.

    Returns:
        ``(x, y, w, h)`` of the largest qualifying contour's bounding box, or
        ``None`` if no contour reaches *min_area_size*.
    """
    # Apply mask to image
    masked_img = cv2.bitwise_and(img, img, mask=mask)
    gray = cv2.cvtColor(masked_img, cv2.COLOR_RGB2GRAY)
    edges = cv2.Canny(gray, 30, 100)
    kernel = np.ones((5, 5), np.uint8)
    # The mask border (garment vs. black background) also produces edges, so
    # dilation keeps the clean region away from the garment outline as well.
    dilated_edges = cv2.dilate(edges, kernel, iterations=1)
    # Compare against 0 instead of bit-ANDing the raw mask, so any non-zero
    # mask value (1, 255, ...) counts as "inside".
    clean_mask = ((dilated_edges == 0) & (mask > 0)).astype(np.uint8)

    # [-2] selects the contour list on both OpenCV 3 (3 return values) and
    # OpenCV 4 (2 return values).
    contours = cv2.findContours(
        clean_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
    )[-2]
    valid_contours = [c for c in contours if cv2.contourArea(c) >= min_area_size]
    if not valid_contours:
        return None
    largest_contour = max(valid_contours, key=cv2.contourArea)
    x, y, w, h = cv2.boundingRect(largest_contour)
    return (x, y, w, h)


def process_image(image, input_point):
    """Segment the garment at *input_point* and build the three outputs.

    Args:
        image: image passed to ``SamPredictor.set_image``; presumably an RGB
            uint8 ``(H, W, 3)`` array as supplied by ``gr.Image(type="numpy")``
            -- TODO confirm for RGBA/grayscale uploads.
        input_point: ``(1, 2)`` array holding the (x, y) prompt point.

    Returns:
        ``(visualization, swatch, detailed_fabric)``: the image with the mask
        tinted, a colour swatch, and a ``DETAIL_SIZE`` crop of the cleanest
        region (all black when no clean region is found).
    """
    predictor.set_image(image)
    input_label = np.array([1])
    masks, _, _ = predictor.predict(
        point_coords=input_point,
        point_labels=input_label,
        multimask_output=False,
    )
    mask = masks[0].astype(bool)
    mask_u8 = mask.astype(np.uint8)

    # Use only the garment's own pixels. Feeding the whole masked image would
    # include every blacked-out background pixel and darken the colour.
    dominant_color = extract_dominant_color(image[mask])
    swatch = create_circular_swatch(dominant_color)

    # Find clean area and crop it to a fixed output size.
    clean_area = find_clean_area(image, mask_u8)
    if clean_area is not None:
        x, y, w, h = clean_area
        detailed_fabric = cv2.resize(image[y:y + h, x:x + w], DETAIL_SIZE)
    else:
        # Same (height, width, 3) shape the resize branch produces.
        width, height = DETAIL_SIZE
        detailed_fabric = np.zeros((height, width, 3), dtype=np.uint8)

    # Blend masked pixels 50/50 with the tint colour.
    visualization = image.copy()
    tint = np.array([0, 0, 255])
    visualization[mask] = (visualization[mask] * 0.5 + tint * 0.5).astype(np.uint8)
    return visualization, swatch, detailed_fabric


def gradio_interface(input_image, click_x, click_y):
    """Gradio callback: validate inputs, clamp the point, and run the analysis.

    Raises:
        gr.Error: if no image was provided.
    """
    if input_image is None:
        raise gr.Error("Please upload an image or choose an example first.")
    height, width = input_image.shape[:2]
    # The sliders span 0-1000 regardless of image size; clamp the prompt so
    # it always lies inside the image.
    x = min(max(float(click_x), 0.0), float(width - 1))
    y = min(max(float(click_y), 0.0), float(height - 1))
    input_point = np.array([[x, y]])
    visualization, swatch, detailed_fabric = process_image(input_image, input_point)
    return visualization, swatch, detailed_fabric


# Create Gradio interface
iface = gr.Interface(
    fn=gradio_interface,
    inputs=[
        gr.Image(type="numpy"),
        gr.Slider(0, 1000, label="Click X"),
        gr.Slider(0, 1000, label="Click Y"),
    ],
    outputs=[
        gr.Image(type="numpy", label="Segmentation"),
        gr.Image(type="numpy", label="Color Swatch"),
        gr.Image(type="numpy", label="Detailed Fabric"),
    ],
    title="Fabric Analyzer",
    description=(
        "Upload an image or choose from examples, then set Click X / Click Y "
        "to a point on the garment to analyze its fabric and color."
    ),
    examples=[
        ["blue_shirt.png", 400, 500],
        ["polo.png", 400, 500],
        ["dress.jpg", 400, 500],
    ],
)

if __name__ == "__main__":
    iface.launch()