Spaces:
Sleeping
Sleeping
peterhartwigCF
commited on
Commit
•
5c26863
1
Parent(s):
5a159c7
Upload 6 files
Browse files- .gitattributes +1 -0
- app.py +111 -0
- blue_shirt.png +3 -0
- dress.jpg +0 -0
- polo.png +0 -0
- requirements.txt +6 -0
- sam_vit_h_4b8939.pth +3 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
blue_shirt.png filter=lfs diff=lfs merge=lfs -text
|
app.py
ADDED
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
import gradio as gr
|
5 |
+
from segment_anything import sam_model_registry, SamPredictor
|
6 |
+
from PIL import Image
|
7 |
+
from sklearn.cluster import KMeans
|
8 |
+
|
9 |
+
# Load SAM model
|
10 |
+
sam_checkpoint = "sam_vit_h_4b8939.pth"
|
11 |
+
model_type = "vit_h"
|
12 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
13 |
+
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
|
14 |
+
sam.to(device=device)
|
15 |
+
predictor = SamPredictor(sam)
|
16 |
+
|
17 |
+
def extract_dominant_color(img):
|
18 |
+
pixels = img.reshape(-1, 3)
|
19 |
+
kmeans = KMeans(n_clusters=1, n_init=10)
|
20 |
+
kmeans.fit(pixels)
|
21 |
+
return kmeans.cluster_centers_[0]
|
22 |
+
|
23 |
+
def create_circular_swatch(color, size=200):
|
24 |
+
swatch = np.zeros((size, size, 3), dtype=np.uint8)
|
25 |
+
color_tuple = tuple(map(int, color))
|
26 |
+
cv2.circle(swatch, (size//2, size//2), size//2, color_tuple, -1)
|
27 |
+
return swatch
|
28 |
+
|
29 |
+
def find_clean_area(img, mask, min_area_size=100*100):
|
30 |
+
# Apply mask to image
|
31 |
+
masked_img = cv2.bitwise_and(img, img, mask=mask)
|
32 |
+
|
33 |
+
gray = cv2.cvtColor(masked_img, cv2.COLOR_RGB2GRAY)
|
34 |
+
edges = cv2.Canny(gray, 30, 100)
|
35 |
+
|
36 |
+
kernel = np.ones((5,5), np.uint8)
|
37 |
+
dilated_edges = cv2.dilate(edges, kernel, iterations=1)
|
38 |
+
|
39 |
+
clean_mask = (dilated_edges == 0).astype(np.uint8) & mask
|
40 |
+
|
41 |
+
contours, _ = cv2.findContours(clean_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
42 |
+
|
43 |
+
valid_contours = [c for c in contours if cv2.contourArea(c) >= min_area_size]
|
44 |
+
|
45 |
+
if not valid_contours:
|
46 |
+
return None
|
47 |
+
|
48 |
+
largest_contour = max(valid_contours, key=cv2.contourArea)
|
49 |
+
x, y, w, h = cv2.boundingRect(largest_contour)
|
50 |
+
|
51 |
+
return (x, y, w, h)
|
52 |
+
|
53 |
+
def process_image(image, input_point):
|
54 |
+
predictor.set_image(image)
|
55 |
+
input_label = np.array([1])
|
56 |
+
masks, _, _ = predictor.predict(
|
57 |
+
point_coords=input_point,
|
58 |
+
point_labels=input_label,
|
59 |
+
multimask_output=False,
|
60 |
+
)
|
61 |
+
mask = masks[0]
|
62 |
+
|
63 |
+
# Extract dominant color and create swatch
|
64 |
+
masked_image = cv2.bitwise_and(image, image, mask=mask.astype(np.uint8))
|
65 |
+
dominant_color = extract_dominant_color(masked_image)
|
66 |
+
swatch = create_circular_swatch(dominant_color)
|
67 |
+
|
68 |
+
# Find clean area and crop
|
69 |
+
clean_area = find_clean_area(image, mask.astype(np.uint8))
|
70 |
+
if clean_area is not None:
|
71 |
+
x, y, w, h = clean_area
|
72 |
+
detailed_fabric = image[y:y+h, x:x+w]
|
73 |
+
detailed_fabric = cv2.resize(detailed_fabric, (400, 600))
|
74 |
+
else:
|
75 |
+
detailed_fabric = np.zeros((400, 600, 3), dtype=np.uint8)
|
76 |
+
|
77 |
+
# Visualize mask on image
|
78 |
+
visualization = image.copy()
|
79 |
+
visualization[mask] = visualization[mask] * 0.5 + np.array([0, 0, 255]) * 0.5
|
80 |
+
|
81 |
+
return visualization, swatch, detailed_fabric
|
82 |
+
|
83 |
+
def gradio_interface(input_image, click_x, click_y):
|
84 |
+
input_point = np.array([[click_x, click_y]])
|
85 |
+
visualization, swatch, detailed_fabric = process_image(input_image, input_point)
|
86 |
+
return visualization, swatch, detailed_fabric
|
87 |
+
|
88 |
+
# Create Gradio interface
|
89 |
+
iface = gr.Interface(
|
90 |
+
fn=gradio_interface,
|
91 |
+
inputs=[
|
92 |
+
gr.Image(type="numpy"),
|
93 |
+
gr.Slider(0, 1000, label="Click X"),
|
94 |
+
gr.Slider(0, 1000, label="Click Y")
|
95 |
+
],
|
96 |
+
outputs=[
|
97 |
+
gr.Image(type="numpy", label="Segmentation"),
|
98 |
+
gr.Image(type="numpy", label="Color Swatch"),
|
99 |
+
gr.Image(type="numpy", label="Detailed Fabric")
|
100 |
+
],
|
101 |
+
title="Fabric Analyzer",
|
102 |
+
description="Upload an image or choose from examples. Click on the garment to analyze fabric and color.",
|
103 |
+
examples=[
|
104 |
+
["blue_shirt.png", 400, 500],
|
105 |
+
["polo.png", 400, 500],
|
106 |
+
["dress.jpg", 400, 500]
|
107 |
+
]
|
108 |
+
)
|
109 |
+
|
110 |
+
if __name__ == "__main__":
|
111 |
+
iface.launch()
|
blue_shirt.png
ADDED
Git LFS Details
|
dress.jpg
ADDED
polo.png
ADDED
requirements.txt
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch
|
2 |
+
torchvision
|
3 |
+
opencv-python-headless
|
4 |
+
scikit-learn
|
5 |
+
gradio
|
6 |
+
segment-anything
|
sam_vit_h_4b8939.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a7bf3b02f3ebf1267aba913ff637d9a2d5c33d3173bb679e46d9f338c26f262e
|
3 |
+
size 2564550879
|