jhj0517 committed on
Commit
002d880
·
1 Parent(s): 7036f3f

Add `cb_multimask_output` and fix result type

Browse files
Files changed (1) hide show
  1. modules/sam_inference.py +19 -18
modules/sam_inference.py CHANGED
@@ -103,25 +103,27 @@ class SamInference:
103
  output_file_name = f"result-{timestamp}.psd"
104
  output_path = os.path.join(self.output_dir, "psd", output_file_name)
105
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
  if input_mode == AUTOMATIC_MODE:
107
  image = image_input
108
- maskgen_hparams = {
109
- 'points_per_side': int(params[0]),
110
- 'points_per_batch': int(params[1]),
111
- 'pred_iou_thresh': float(params[2]),
112
- 'stability_score_thresh': float(params[3]),
113
- 'stability_score_offset': float(params[4]),
114
- 'crop_n_layers': int(params[5]),
115
- 'box_nms_thresh': float(params[6]),
116
- 'crop_n_points_downscale_factor': int(params[7]),
117
- 'min_mask_region_area': int(params[8]),
118
- 'use_m2m': bool(params[9])
119
- }
120
 
121
  generated_masks = self.generate_mask(
122
  image=image,
123
  model_type=model_type,
124
- **maskgen_hparams
125
  )
126
 
127
  elif input_mode == BOX_PROMPT_MODE:
@@ -129,15 +131,12 @@ class SamInference:
129
  image = np.array(image.convert("RGB"))
130
  box = image_prompt_input_data["points"]
131
  box = np.array([[x1, y1, x2, y2] for x1, y1, _, x2, y2, _ in box])
132
- predict_image_hparams = {
133
- "multimask_output": params[0]
134
- }
135
 
136
  predicted_masks, scores, logits = self.predict_image(
137
  image=image,
138
  model_type=model_type,
139
  box=box,
140
- **predict_image_hparams
141
  )
142
  generated_masks = self.format_to_auto_result(predicted_masks)
143
 
@@ -152,5 +151,7 @@ class SamInference:
152
  masks: np.ndarray
153
  ):
154
  place_holder = 0
155
- result = [{"segmentation": mask, "area": place_holder} for mask in masks]
 
 
156
  return result
 
103
  output_file_name = f"result-{timestamp}.psd"
104
  output_path = os.path.join(self.output_dir, "psd", output_file_name)
105
 
106
+ hparams = {
107
+ 'points_per_side': int(params[0]),
108
+ 'points_per_batch': int(params[1]),
109
+ 'pred_iou_thresh': float(params[2]),
110
+ 'stability_score_thresh': float(params[3]),
111
+ 'stability_score_offset': float(params[4]),
112
+ 'crop_n_layers': int(params[5]),
113
+ 'box_nms_thresh': float(params[6]),
114
+ 'crop_n_points_downscale_factor': int(params[7]),
115
+ 'min_mask_region_area': int(params[8]),
116
+ 'use_m2m': bool(params[9]),
117
+ 'multimask_output': bool(params[10])
118
+ }
119
+
120
  if input_mode == AUTOMATIC_MODE:
121
  image = image_input
 
 
 
 
 
 
 
 
 
 
 
 
122
 
123
  generated_masks = self.generate_mask(
124
  image=image,
125
  model_type=model_type,
126
+ **hparams
127
  )
128
 
129
  elif input_mode == BOX_PROMPT_MODE:
 
131
  image = np.array(image.convert("RGB"))
132
  box = image_prompt_input_data["points"]
133
  box = np.array([[x1, y1, x2, y2] for x1, y1, _, x2, y2, _ in box])
 
 
 
134
 
135
  predicted_masks, scores, logits = self.predict_image(
136
  image=image,
137
  model_type=model_type,
138
  box=box,
139
+ multimask_output=hparams["multimask_output"]
140
  )
141
  generated_masks = self.format_to_auto_result(predicted_masks)
142
 
 
151
  masks: np.ndarray
152
  ):
153
  place_holder = 0
154
+ if len(masks) == 1:
155
+ return [{"segmentation": mask, "area": place_holder} for mask in masks]
156
+ result = [{"segmentation": mask[0], "area": place_holder} for mask in masks]
157
  return result