peterhartwigCF commited on
Commit
5c26863
1 Parent(s): 5a159c7

Upload 6 files

Browse files
Files changed (7) hide show
  1. .gitattributes +1 -0
  2. app.py +111 -0
  3. blue_shirt.png +3 -0
  4. dress.jpg +0 -0
  5. polo.png +0 -0
  6. requirements.txt +6 -0
  7. sam_vit_h_4b8939.pth +3 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ blue_shirt.png filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+ import torch
4
+ import gradio as gr
5
+ from segment_anything import sam_model_registry, SamPredictor
6
+ from PIL import Image
7
+ from sklearn.cluster import KMeans
8
+
9
+ # Load SAM model
10
+ sam_checkpoint = "sam_vit_h_4b8939.pth"
11
+ model_type = "vit_h"
12
+ device = "cuda" if torch.cuda.is_available() else "cpu"
13
+ sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
14
+ sam.to(device=device)
15
+ predictor = SamPredictor(sam)
16
+
17
+ def extract_dominant_color(img):
18
+ pixels = img.reshape(-1, 3)
19
+ kmeans = KMeans(n_clusters=1, n_init=10)
20
+ kmeans.fit(pixels)
21
+ return kmeans.cluster_centers_[0]
22
+
23
+ def create_circular_swatch(color, size=200):
24
+ swatch = np.zeros((size, size, 3), dtype=np.uint8)
25
+ color_tuple = tuple(map(int, color))
26
+ cv2.circle(swatch, (size//2, size//2), size//2, color_tuple, -1)
27
+ return swatch
28
+
29
+ def find_clean_area(img, mask, min_area_size=100*100):
30
+ # Apply mask to image
31
+ masked_img = cv2.bitwise_and(img, img, mask=mask)
32
+
33
+ gray = cv2.cvtColor(masked_img, cv2.COLOR_RGB2GRAY)
34
+ edges = cv2.Canny(gray, 30, 100)
35
+
36
+ kernel = np.ones((5,5), np.uint8)
37
+ dilated_edges = cv2.dilate(edges, kernel, iterations=1)
38
+
39
+ clean_mask = (dilated_edges == 0).astype(np.uint8) & mask
40
+
41
+ contours, _ = cv2.findContours(clean_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
42
+
43
+ valid_contours = [c for c in contours if cv2.contourArea(c) >= min_area_size]
44
+
45
+ if not valid_contours:
46
+ return None
47
+
48
+ largest_contour = max(valid_contours, key=cv2.contourArea)
49
+ x, y, w, h = cv2.boundingRect(largest_contour)
50
+
51
+ return (x, y, w, h)
52
+
53
+ def process_image(image, input_point):
54
+ predictor.set_image(image)
55
+ input_label = np.array([1])
56
+ masks, _, _ = predictor.predict(
57
+ point_coords=input_point,
58
+ point_labels=input_label,
59
+ multimask_output=False,
60
+ )
61
+ mask = masks[0]
62
+
63
+ # Extract dominant color and create swatch
64
+ masked_image = cv2.bitwise_and(image, image, mask=mask.astype(np.uint8))
65
+ dominant_color = extract_dominant_color(masked_image)
66
+ swatch = create_circular_swatch(dominant_color)
67
+
68
+ # Find clean area and crop
69
+ clean_area = find_clean_area(image, mask.astype(np.uint8))
70
+ if clean_area is not None:
71
+ x, y, w, h = clean_area
72
+ detailed_fabric = image[y:y+h, x:x+w]
73
+ detailed_fabric = cv2.resize(detailed_fabric, (400, 600))
74
+ else:
75
+ detailed_fabric = np.zeros((400, 600, 3), dtype=np.uint8)
76
+
77
+ # Visualize mask on image
78
+ visualization = image.copy()
79
+ visualization[mask] = visualization[mask] * 0.5 + np.array([0, 0, 255]) * 0.5
80
+
81
+ return visualization, swatch, detailed_fabric
82
+
83
+ def gradio_interface(input_image, click_x, click_y):
84
+ input_point = np.array([[click_x, click_y]])
85
+ visualization, swatch, detailed_fabric = process_image(input_image, input_point)
86
+ return visualization, swatch, detailed_fabric
87
+
88
+ # Create Gradio interface
89
+ iface = gr.Interface(
90
+ fn=gradio_interface,
91
+ inputs=[
92
+ gr.Image(type="numpy"),
93
+ gr.Slider(0, 1000, label="Click X"),
94
+ gr.Slider(0, 1000, label="Click Y")
95
+ ],
96
+ outputs=[
97
+ gr.Image(type="numpy", label="Segmentation"),
98
+ gr.Image(type="numpy", label="Color Swatch"),
99
+ gr.Image(type="numpy", label="Detailed Fabric")
100
+ ],
101
+ title="Fabric Analyzer",
102
+ description="Upload an image or choose from examples. Click on the garment to analyze fabric and color.",
103
+ examples=[
104
+ ["blue_shirt.png", 400, 500],
105
+ ["polo.png", 400, 500],
106
+ ["dress.jpg", 400, 500]
107
+ ]
108
+ )
109
+
110
+ if __name__ == "__main__":
111
+ iface.launch()
blue_shirt.png ADDED

Git LFS Details

  • SHA256: 70c785c26cf4131e2e11d8c405e0167ee53d3e32d9893456723b64e4b0176bb4
  • Pointer size: 132 Bytes
  • Size of remote file: 1.42 MB
dress.jpg ADDED
polo.png ADDED
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ opencv-python-headless
4
+ scikit-learn
5
+ gradio
6
+ segment-anything
sam_vit_h_4b8939.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a7bf3b02f3ebf1267aba913ff637d9a2d5c33d3173bb679e46d9f338c26f262e
3
+ size 2564550879