# swatchme / app.py
# Uploaded by peterhartwigCF ("Upload 6 files", commit 5c26863, 3.61 kB).
import cv2
import numpy as np
import torch
import gradio as gr
from segment_anything import sam_model_registry, SamPredictor
from PIL import Image
from sklearn.cluster import KMeans
# --- Model setup --------------------------------------------------------------
# Build the Segment Anything (ViT-H) model once at import time and wrap it in a
# SamPredictor; process_image() reuses this single module-level predictor.
model_type = "vit_h"
sam_checkpoint = "sam_vit_h_4b8939.pth"  # relative path; presumably resolved against the CWD
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)
predictor = SamPredictor(sam)
def extract_dominant_color(img, n_clusters=1):
    """Return a representative color for the pixels in *img*.

    Args:
        img: Array whose values are grouped in 3 channels, e.g. an (H, W, 3)
            image or an (N, 3) array of pixels; it is flattened to (-1, 3).
        n_clusters: Number of k-means clusters. With the default of 1 the
            result is the per-channel mean, which is exactly the centroid a
            single-cluster k-means converges to, so it is computed directly
            instead of fitting KMeans. With more clusters, the center of the
            most populated cluster is returned.

    Returns:
        A float64 array of shape (3,).

    Raises:
        ValueError: If *img* contains no pixels or n_clusters < 1.
    """
    pixels = np.asarray(img).reshape(-1, 3).astype(np.float64)
    if pixels.shape[0] == 0:
        raise ValueError("extract_dominant_color() needs at least one pixel")
    if n_clusters < 1:
        raise ValueError("n_clusters must be >= 1")
    if n_clusters == 1:
        # Skip the (10x re-initialised) KMeans fit: its single centroid is the mean.
        return pixels.mean(axis=0)
    kmeans = KMeans(n_clusters=n_clusters, n_init=10)
    labels = kmeans.fit_predict(pixels)
    # "Dominant" = the cluster that owns the most pixels.
    counts = np.bincount(labels, minlength=n_clusters)
    return kmeans.cluster_centers_[np.argmax(counts)]
def create_circular_swatch(color, size=200):
    """Draw a filled circle of *color* centred on a black square canvas.

    Args:
        color: Three channel values of any numeric type (e.g. the float array
            returned by extract_dominant_color). They are rounded to the
            nearest integer and clipped to [0, 255] before drawing, so 127.9
            becomes 128 instead of being truncated to 127.
        size: Width and height of the square canvas in pixels.

    Returns:
        A (size, size, 3) uint8 array. Channels are written in the order given
        in *color*; no color-space conversion is performed.
    """
    swatch = np.zeros((size, size, 3), dtype=np.uint8)
    channels = np.clip(np.rint(np.asarray(color, dtype=np.float64)), 0, 255)
    # Plain Python ints: OpenCV's color argument can reject numpy integer scalars.
    color_tuple = tuple(int(c) for c in channels)
    centre = (size // 2, size // 2)
    cv2.circle(swatch, centre, size // 2, color_tuple, -1)  # thickness -1 = filled
    return swatch
def find_clean_area(img, mask, min_area_size=100*100, canny_low=30,
                    canny_high=100, dilate_kernel_size=5):
    """Locate the largest edge-free region of *img* inside *mask*.

    Canny edges are detected on the masked grayscale image and dilated. The
    pixels that are both inside the mask and away from any dilated edge form
    the "clean" region. Because pixels outside the mask are zeroed first, the
    garment outline typically produces edges too, so the mask border is
    excluded as well. Among the external contours of the clean region, the
    one with the largest contour area is chosen.

    Args:
        img: (H, W, 3) RGB image (converted with COLOR_RGB2GRAY).
        mask: (H, W) mask; any non-zero / True pixel counts as inside. Boolean,
            0/1 and 0/255 masks are all accepted.
        min_area_size: Minimum contour area (pixels) for a region to qualify.
        canny_low, canny_high: Hysteresis thresholds passed to cv2.Canny.
        dilate_kernel_size: Side of the square kernel used to dilate edges,
            i.e. how wide a band around each edge is treated as not clean.

    Returns:
        (x, y, w, h) bounding rectangle of the chosen region, or None when no
        region reaches min_area_size.

    Note:
        The bounding rectangle of a non-rectangular region can still include
        pixels that lie outside the mask or near edges.
    """
    # OpenCV requires an 8-bit single-channel mask; normalise to 0/1.
    mask = (np.asarray(mask) != 0).astype(np.uint8)
    masked_img = cv2.bitwise_and(img, img, mask=mask)
    gray = cv2.cvtColor(masked_img, cv2.COLOR_RGB2GRAY)
    edges = cv2.Canny(gray, canny_low, canny_high)
    kernel = np.ones((dilate_kernel_size, dilate_kernel_size), np.uint8)
    dilated_edges = cv2.dilate(edges, kernel, iterations=1)
    clean_mask = (dilated_edges == 0).astype(np.uint8) & mask
    # findContours returns (image, contours, hierarchy) on OpenCV 3.x and
    # (contours, hierarchy) on 4.x; [-2] selects the contours on both.
    contours = cv2.findContours(clean_mask, cv2.RETR_EXTERNAL,
                                cv2.CHAIN_APPROX_SIMPLE)[-2]
    # Single pass: each contour's area is computed once; ties keep the first.
    best_contour, best_area = None, -1.0
    for contour in contours:
        area = cv2.contourArea(contour)
        if area >= min_area_size and area > best_area:
            best_contour, best_area = contour, area
    if best_contour is None:
        return None
    x, y, w, h = cv2.boundingRect(best_contour)
    return (x, y, w, h)
def process_image(image, input_point):
    """Segment the object at *input_point* and build the three output images.

    Args:
        image: RGB image array as delivered by gr.Image(type="numpy") and
            passed straight to SamPredictor.set_image (assumed (H, W, 3) uint8).
        input_point: (1, 2) array with one (x, y) prompt point in pixels.

    Returns:
        Tuple (visualization, swatch, detailed_fabric):
          * visualization: copy of *image* with masked pixels blended 50/50
            with (0, 0, 255), which is blue for RGB input.
          * swatch: circular swatch of the mean color of the masked pixels,
            or an all-black swatch when the mask is empty.
          * detailed_fabric: clean-area crop resized to 400 wide x 600 high,
            or a black image of that same (600, 400, 3) shape otherwise.
    """
    # SAM with a single foreground point (label 1) and one output mask.
    predictor.set_image(image)
    input_label = np.array([1])
    masks, _, _ = predictor.predict(
        point_coords=input_point,
        point_labels=input_label,
        multimask_output=False,
    )
    mask = masks[0].astype(bool)

    # Color from the garment pixels only. Feeding the whole masked image would
    # also average in every zeroed background pixel and darken the swatch.
    garment_pixels = image[mask]
    if garment_pixels.size:
        swatch = create_circular_swatch(extract_dominant_color(garment_pixels))
    else:
        # Empty mask: black circle on black canvas = all-black default-size swatch.
        swatch = create_circular_swatch((0, 0, 0))

    # Crop the largest clean (edge-free) patch of the garment.
    clean_area = find_clean_area(image, mask.astype(np.uint8))
    if clean_area is not None:
        x, y, w, h = clean_area
        # cv2.resize takes dsize as (width, height) -> result shape (600, 400, 3).
        detailed_fabric = cv2.resize(image[y:y + h, x:x + w], (400, 600))
    else:
        # Same (rows, cols, channels) shape as the resized crop above.
        detailed_fabric = np.zeros((600, 400, 3), dtype=np.uint8)

    # Overlay the mask for display (truncating cast back to uint8, as before).
    visualization = image.copy()
    blended = visualization[mask] * 0.5 + np.array([0, 0, 255]) * 0.5
    visualization[mask] = blended.astype(np.uint8)
    return visualization, swatch, detailed_fabric
def gradio_interface(input_image, click_x, click_y):
    """Gradio callback: analyse *input_image* at the slider point (click_x, click_y).

    The slider range is independent of the image size, so the point is clamped
    to the image bounds before being handed to SAM.

    Returns:
        The (visualization, swatch, detailed_fabric) tuple from process_image.

    Raises:
        gr.Error: If no image has been provided.
    """
    if input_image is None:
        raise gr.Error("Please upload an image or pick an example first.")
    height, width = input_image.shape[:2]
    x = min(max(float(click_x), 0.0), width - 1)
    y = min(max(float(click_y), 0.0), height - 1)
    input_point = np.array([[x, y]])
    return process_image(input_image, input_point)
# Create Gradio interface. The point to analyse is chosen with the two sliders,
# in image pixel coordinates (x = column, y = row); the image itself is not
# clickable in this gr.Interface.
iface = gr.Interface(
    fn=gradio_interface,
    inputs=[
        gr.Image(type="numpy"),
        # Explicit step=1 so every integer pixel coordinate can be selected.
        gr.Slider(0, 1000, step=1, label="Click X"),
        gr.Slider(0, 1000, step=1, label="Click Y"),
    ],
    outputs=[
        gr.Image(type="numpy", label="Segmentation"),
        gr.Image(type="numpy", label="Color Swatch"),
        gr.Image(type="numpy", label="Detailed Fabric"),
    ],
    title="Fabric Analyzer",
    description=(
        "Upload an image or choose from examples. Use the Click X / Click Y "
        "sliders to pick a point (in pixels) on the garment to analyze fabric "
        "and color."
    ),
    examples=[
        ["blue_shirt.png", 400, 500],
        ["polo.png", 400, 500],
        ["dress.jpg", 400, 500],
    ],
)

if __name__ == "__main__":
    iface.launch()