leedoming commited on
Commit
fd87108
·
verified ·
1 Parent(s): 55609ac

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -24
app.py CHANGED
@@ -56,9 +56,10 @@ def extract_color_histogram(image, mask=None):
56
  try:
57
  img_array = np.array(image)
58
  if mask is not None:
59
- # Apply mask
60
- mask = np.expand_dims(mask, axis=2)
61
- img_array = img_array * mask
 
62
  # Only consider pixels that are part of the clothing item
63
  valid_pixels = img_array[mask[:,:,0] > 0]
64
  else:
@@ -66,24 +67,26 @@ def extract_color_histogram(image, mask=None):
66
 
67
  # Convert to HSV color space for better color representation
68
  if len(valid_pixels) > 0:
69
- img_hsv = Image.fromarray(valid_pixels.reshape(1, -1, 3).astype(np.uint8)).convert('HSV')
 
 
70
  hsv_pixels = np.array(img_hsv)
71
 
72
  # Calculate histogram for each HSV channel
73
- h_hist = np.histogram(hsv_pixels[:,:,0], bins=10, range=(0, 256))[0]
74
- s_hist = np.histogram(hsv_pixels[:,:,1], bins=10, range=(0, 256))[0]
75
- v_hist = np.histogram(hsv_pixels[:,:,2], bins=10, range=(0, 256))[0]
76
 
77
  # Normalize histograms
78
- h_hist = h_hist / h_hist.sum() if h_hist.sum() > 0 else h_hist
79
- s_hist = s_hist / s_hist.sum() if s_hist.sum() > 0 else s_hist
80
- v_hist = v_hist / v_hist.sum() if v_hist.sum() > 0 else v_hist
81
 
82
  return np.concatenate([h_hist, s_hist, v_hist])
83
- return np.zeros(30) # Return zero histogram if no valid pixels
84
  except Exception as e:
85
  logger.error(f"Color histogram extraction error: {e}")
86
- return np.zeros(30)
87
 
88
  def process_segmentation(image):
89
  """Segmentation processing"""
@@ -137,7 +140,7 @@ def extract_features(image, mask=None):
137
  # Extract CLIP features
138
  if mask is not None:
139
  img_array = np.array(image)
140
- mask = np.expand_dims(mask, axis=2)
141
  masked_img = img_array * mask
142
  masked_img[mask[:,:,0] == 0] = 255 # Set background to white
143
  image = Image.fromarray(masked_img.astype(np.uint8))
@@ -151,19 +154,24 @@ def extract_features(image, mask=None):
151
  # Extract color features
152
  color_features = extract_color_histogram(image, mask)
153
 
154
- # Combine features
155
- # Note: We normalize and weight the features to balance their influence
156
- clip_features_normalized = clip_features / np.linalg.norm(clip_features)
157
- color_features_normalized = color_features / np.linalg.norm(color_features)
 
 
 
 
 
 
 
158
 
159
- # Adjust these weights to control the influence of each feature type
160
- clip_weight = 0.7 # CLIP features weight
161
- color_weight = 0.3 # Color features weight
162
 
163
- combined_features = np.concatenate([
164
- clip_features_normalized * clip_weight,
165
- color_features_normalized * color_weight
166
- ])
167
 
168
  return combined_features
169
  except Exception as e:
 
56
  try:
57
  img_array = np.array(image)
58
  if mask is not None:
59
+ # Reshape mask to match image dimensions
60
+ mask = np.expand_dims(mask, axis=-1) # Add channel dimension
61
+ img_array = img_array * mask # Broadcasting will work correctly now
62
+
63
  # Only consider pixels that are part of the clothing item
64
  valid_pixels = img_array[mask[:,:,0] > 0]
65
  else:
 
67
 
68
  # Convert to HSV color space for better color representation
69
  if len(valid_pixels) > 0:
70
+ # Reshape to proper dimensions for PIL Image
71
+ valid_pixels = valid_pixels.reshape(-1, 3)
72
+ img_hsv = Image.fromarray(valid_pixels.astype(np.uint8)).convert('HSV')
73
  hsv_pixels = np.array(img_hsv)
74
 
75
  # Calculate histogram for each HSV channel
76
+ h_hist = np.histogram(hsv_pixels[:,0], bins=8, range=(0, 256))[0]
77
+ s_hist = np.histogram(hsv_pixels[:,1], bins=8, range=(0, 256))[0]
78
+ v_hist = np.histogram(hsv_pixels[:,2], bins=8, range=(0, 256))[0]
79
 
80
  # Normalize histograms
81
+ h_hist = h_hist / (h_hist.sum() + 1e-8) # Add small epsilon to avoid division by zero
82
+ s_hist = s_hist / (s_hist.sum() + 1e-8)
83
+ v_hist = v_hist / (v_hist.sum() + 1e-8)
84
 
85
  return np.concatenate([h_hist, s_hist, v_hist])
86
+ return np.zeros(24) # 8bins * 3channels = 24 features
87
  except Exception as e:
88
  logger.error(f"Color histogram extraction error: {e}")
89
+ return np.zeros(24)
90
 
91
  def process_segmentation(image):
92
  """Segmentation processing"""
 
140
  # Extract CLIP features
141
  if mask is not None:
142
  img_array = np.array(image)
143
+ mask = np.expand_dims(mask, axis=-1)
144
  masked_img = img_array * mask
145
  masked_img[mask[:,:,0] == 0] = 255 # Set background to white
146
  image = Image.fromarray(masked_img.astype(np.uint8))
 
154
  # Extract color features
155
  color_features = extract_color_histogram(image, mask)
156
 
157
+ # CLIP features are 768-dimensional, so we'll resize color features
158
+ # to maintain the same total dimensionality
159
+ clip_features = clip_features[:744] # Trim CLIP features to make room for color
160
+
161
+ # Normalize features
162
+ clip_features_normalized = clip_features / (np.linalg.norm(clip_features) + 1e-8)
163
+ color_features_normalized = color_features / (np.linalg.norm(color_features) + 1e-8)
164
+
165
+ # Adjust weights (total should be 768 to match collection dimensionality)
166
+ clip_weight = 0.7
167
+ color_weight = 0.3
168
 
169
+ combined_features = np.zeros(768) # Initialize with zeros
170
+ combined_features[:744] = clip_features_normalized * clip_weight # First 744 dimensions for CLIP
171
+ combined_features[744:] = color_features_normalized * color_weight # Last 24 dimensions for color
172
 
173
+ # Ensure final normalization
174
+ combined_features = combined_features / (np.linalg.norm(combined_features) + 1e-8)
 
 
175
 
176
  return combined_features
177
  except Exception as e: