# (Hugging Face Spaces page header residue — Space status: Sleeping)
import cv2 | |
import numpy as np | |
import torch | |
import gradio as gr | |
from segment_anything import sam_model_registry, SamPredictor | |
from PIL import Image | |
from sklearn.cluster import KMeans | |
# Segment Anything set-up: build the ViT-H model from its checkpoint once at
# import time, move it to the GPU when one is available, and wrap it in the
# predictor that every request below reuses.
sam_checkpoint = "sam_vit_h_4b8939.pth"
model_type = "vit_h"
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# nn.Module.to() returns the module itself, so build-and-move can be chained.
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint).to(device=device)
predictor = SamPredictor(sam)
def extract_dominant_color(img, n_clusters=1):
    """Return the dominant colour of ``img`` as a float64 array of shape (3,).

    ``img`` may be an H x W x 3 image or an N x 3 array of pixels; it is
    flattened to N x 3 before processing.

    With the default ``n_clusters=1`` the single k-means centroid is exactly
    the per-channel mean, so the mean is computed directly instead of running
    k-means (ten random restarts) over every pixel. With ``n_clusters > 1``
    the pixels are clustered with k-means and the centroid of the most
    populated cluster is returned.

    Raises:
        ValueError: if ``img`` contains no pixels.
    """
    # float64 avoids any uint8 overflow and matches k-means' output dtype.
    pixels = np.asarray(img).reshape(-1, 3).astype(np.float64)
    if pixels.shape[0] == 0:
        raise ValueError("extract_dominant_color() needs at least one pixel")
    if n_clusters == 1:
        return pixels.mean(axis=0)
    kmeans = KMeans(n_clusters=n_clusters, n_init=10)
    labels = kmeans.fit_predict(pixels)
    # "Dominant" = the cluster that owns the most pixels.
    largest = np.bincount(labels, minlength=n_clusters).argmax()
    return kmeans.cluster_centers_[largest]
def create_circular_swatch(color, size=200):
    """Draw a filled circle of ``color`` centred on a black square canvas.

    Args:
        color: three channel values, e.g. the float output of
            ``extract_dominant_color``. They are rounded to the nearest
            integer (not truncated, which would bias every swatch darker) and
            clipped to the valid 8-bit range ``[0, 255]``.
        size: side length of the square output image, in pixels.

    Returns:
        ``np.uint8`` array of shape ``(size, size, 3)``.
    """
    swatch = np.zeros((size, size, 3), dtype=np.uint8)
    levels = np.clip(np.rint(np.asarray(color, dtype=np.float64)), 0, 255)
    color_tuple = tuple(int(level) for level in levels)
    radius = size // 2
    # thickness=-1 tells OpenCV to fill the circle.
    cv2.circle(swatch, (radius, radius), radius, color_tuple, -1)
    return swatch
def find_clean_area(img, mask, min_area_size=100*100, canny_low=30,
                    canny_high=100, kernel_size=5):
    """Locate the largest edge-free region of ``img`` inside ``mask``.

    Canny edges are detected on the masked image. Pixels outside the mask are
    zeroed, so the mask outline usually shows up as an edge too. The edges are
    dilated to leave a margin around textured or creased spots and removed
    from the mask. The bounding box of the largest remaining region is
    returned.

    Args:
        img: RGB image (converted with ``cv2.COLOR_RGB2GRAY``).
        mask: H x W array; any non-zero / True entry marks the region of interest.
        min_area_size: minimum contour area, in pixels, for a region to count.
        canny_low, canny_high: hysteresis thresholds passed to ``cv2.Canny``.
        kernel_size: side of the square kernel used to dilate the edges.

    Returns:
        ``(x, y, w, h)`` of the largest clean region, or ``None`` if no region
        reaches ``min_area_size``. The box bounds the region's outer contour,
        so for concave regions it may include pixels that are not clean.
    """
    # cv2.bitwise_and needs an 8-bit mask; accept bool and 0/255 masks too.
    mask = (np.asarray(mask) != 0).astype(np.uint8)
    masked_img = cv2.bitwise_and(img, img, mask=mask)
    gray = cv2.cvtColor(masked_img, cv2.COLOR_RGB2GRAY)
    edges = cv2.Canny(gray, canny_low, canny_high)
    kernel = np.ones((kernel_size, kernel_size), np.uint8)
    dilated_edges = cv2.dilate(edges, kernel, iterations=1)
    clean_mask = (dilated_edges == 0).astype(np.uint8) & mask
    # [-2] picks the contour list under both the OpenCV 3 (3-tuple) and
    # OpenCV 4 (2-tuple) return conventions.
    contours = cv2.findContours(clean_mask, cv2.RETR_EXTERNAL,
                                cv2.CHAIN_APPROX_SIMPLE)[-2]
    largest = max(contours, key=cv2.contourArea, default=None)
    if largest is None or cv2.contourArea(largest) < min_area_size:
        return None
    x, y, w, h = cv2.boundingRect(largest)
    return (x, y, w, h)
def _crop_detail(image, mask_u8, detail_size):
    """Return the largest clean area inside the mask, resized to ``detail_size`` (w, h)."""
    width, height = detail_size
    clean_area = find_clean_area(image, mask_u8)
    if clean_area is None:
        # Same (rows, cols) shape that cv2.resize yields for dsize=(width, height).
        return np.zeros((height, width, 3), dtype=np.uint8)
    x, y, w, h = clean_area
    return cv2.resize(image[y:y + h, x:x + w], (width, height))


def process_image(image, input_point, detail_size=(400, 600)):
    """Segment the object under ``input_point`` and analyse its colour and texture.

    Args:
        image: uint8 image, presumably RGB H x W x 3 as delivered by
            ``gr.Image(type="numpy")``.
        input_point: ``[[x, y]]`` pixel coordinates of one foreground prompt for SAM.
        detail_size: ``(width, height)`` of the fabric close-up, in
            ``cv2.resize`` argument order.

    Returns:
        ``(visualization, swatch, detailed_fabric)``:

        - the input with the segmented region tinted;
        - a circular swatch of the region's dominant colour (black if the mask
          is empty);
        - a crop of the largest edge-free area inside the mask, resized to
          ``detail_size`` (all black if no such area exists).
    """
    predictor.set_image(image)
    masks, _, _ = predictor.predict(
        point_coords=input_point,
        point_labels=np.array([1]),  # 1 marks a foreground prompt point
        multimask_output=False,
    )
    mask = masks[0].astype(bool)
    mask_u8 = mask.astype(np.uint8)

    # Use only pixels inside the mask. Averaging the whole masked image would
    # also count every zeroed background pixel and drag the colour to black.
    garment_pixels = image[mask]
    if garment_pixels.size:
        dominant_color = extract_dominant_color(garment_pixels)
    else:
        dominant_color = np.zeros(3)
    swatch = create_circular_swatch(dominant_color)

    detailed_fabric = _crop_detail(image, mask_u8, detail_size)

    # 50/50 blend of the masked pixels with (0, 0, 255): blue in RGB order.
    visualization = image.copy()
    tinted = visualization[mask] * 0.5 + np.array([0, 0, 255]) * 0.5
    visualization[mask] = tinted.astype(visualization.dtype)
    return visualization, swatch, detailed_fabric
def gradio_interface(input_image, click_x, click_y):
    """Gradio callback: analyse ``input_image`` at the slider-selected point.

    The slider range is not tied to the uploaded image's size, so the point
    is clamped to the image bounds before being handed to SAM.

    Returns:
        ``(visualization, swatch, detailed_fabric)`` from ``process_image``.

    Raises:
        gr.Error: if no image was provided (Gradio shows the message to the user).
    """
    if input_image is None:
        raise gr.Error("Please upload an image or choose an example first.")
    height, width = input_image.shape[:2]
    x = min(max(float(click_x), 0.0), width - 1)
    y = min(max(float(click_y), 0.0), height - 1)
    input_point = np.array([[x, y]])
    return process_image(input_image, input_point)
# Gradio UI: an image plus two sliders that pick the SAM prompt point (in
# pixels), producing three image outputs.
iface = gr.Interface(
    fn=gradio_interface,
    inputs=[
        gr.Image(type="numpy", label="Input Image"),
        # step=1 so any pixel coordinate can be selected.
        gr.Slider(0, 1000, step=1, label="Click X"),
        gr.Slider(0, 1000, step=1, label="Click Y"),
    ],
    outputs=[
        gr.Image(type="numpy", label="Segmentation"),
        gr.Image(type="numpy", label="Color Swatch"),
        gr.Image(type="numpy", label="Detailed Fabric"),
    ],
    title="Fabric Analyzer",
    # The UI has no click input: the point comes from the two sliders.
    description=(
        "Upload an image or choose from examples. Use the Click X / Click Y "
        "sliders to pick a point (in pixels) on the garment to analyze its "
        "fabric and color."
    ),
    examples=[
        ["blue_shirt.png", 400, 500],
        ["polo.png", 400, 500],
        ["dress.jpg", 400, 500],
    ],
)

if __name__ == "__main__":
    iface.launch()