jhj0517 commited on
Commit
2c719e3
·
1 Parent(s): 60434a4

integrate the function

Browse files
Files changed (2) hide show
  1. app.py +2 -2
  2. modules/sam_inference.py +75 -30
app.py CHANGED
@@ -73,14 +73,14 @@ class App:
73
  output_file = gr.File(label="Generated psd file", scale=9)
74
  btn_open_folder = gr.Button("📁\nOpen PSD folder", scale=1)
75
 
76
- sources = [img_input]
77
  model_params = [dd_models]
78
  auto_mask_hparams = [nb_points_per_side, nb_points_per_batch, sld_pred_iou_thresh,
79
  sld_stability_score_thresh, sld_stability_score_offset, nb_crop_n_layers,
80
  sld_box_nms_thresh, nb_crop_n_points_downscale_factor, nb_min_mask_region_area,
81
  cb_use_m2m]
82
 
83
- btn_generate.click(fn=self.sam_inf.generate_mask_app,
84
  inputs=sources + model_params + auto_mask_hparams, outputs=[gallery_output, output_file])
85
  btn_open_folder.click(fn=lambda: open_folder(os.path.join(OUTPUT_DIR)),
86
  inputs=None, outputs=None)
 
73
  output_file = gr.File(label="Generated psd file", scale=9)
74
  btn_open_folder = gr.Button("📁\nOpen PSD folder", scale=1)
75
 
76
+ sources = [img_input, img_input_prompter, dd_input_modes]
77
  model_params = [dd_models]
78
  auto_mask_hparams = [nb_points_per_side, nb_points_per_batch, sld_pred_iou_thresh,
79
  sld_stability_score_thresh, sld_stability_score_offset, nb_crop_n_layers,
80
  sld_box_nms_thresh, nb_crop_n_points_downscale_factor, nb_min_mask_region_area,
81
  cb_use_m2m]
82
 
83
+ btn_generate.click(fn=self.sam_inf.divide_layer,
84
  inputs=sources + model_params + auto_mask_hparams, outputs=[gallery_output, output_file])
85
  btn_open_folder.click(fn=lambda: open_folder(os.path.join(OUTPUT_DIR)),
86
  inputs=None, outputs=None)
modules/sam_inference.py CHANGED
@@ -1,6 +1,7 @@
1
  from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
2
  from sam2.build_sam import build_sam2
3
  from sam2.sam2_image_predictor import SAM2ImagePredictor
 
4
  import torch
5
  import os
6
  from datetime import datetime
@@ -12,6 +13,7 @@ from modules.model_downloader import (
12
  download_sam_model_url
13
  )
14
  from modules.paths import SAM2_CONFIGS_DIR, MODELS_DIR
 
15
  from modules.mask_utils import (
16
  save_psd_with_masks,
17
  create_mask_combined_images,
@@ -62,42 +64,85 @@ class SamInference:
62
  print(f"Error while Loading SAM2 model! {e}")
63
 
64
  def generate_mask(self,
65
- image: np.ndarray):
 
 
 
 
 
 
 
 
 
66
  return self.mask_generator.generate(image)
67
 
68
- def generate_mask_app(self,
69
- image: np.ndarray,
70
- model_type: str,
71
- *params
72
- ):
73
- maskgen_hparams = {
74
- 'points_per_side': int(params[0]),
75
- 'points_per_batch': int(params[1]),
76
- 'pred_iou_thresh': float(params[2]),
77
- 'stability_score_thresh': float(params[3]),
78
- 'stability_score_offset': float(params[4]),
79
- 'crop_n_layers': int(params[5]),
80
- 'box_nms_thresh': float(params[6]),
81
- 'crop_n_points_downscale_factor': int(params[7]),
82
- 'min_mask_region_area': int(params[8]),
83
- 'use_m2m': bool(params[9])
84
- }
85
- timestamp = datetime.now().strftime("%m%d%H%M%S")
86
- output_file_name = f"result-{timestamp}.psd"
87
- output_path = os.path.join(self.output_dir, "psd", output_file_name)
88
-
89
  if self.model is None or self.model_type != model_type:
90
  self.model_type = model_type
91
  self.load_model()
 
 
92
 
93
- self.mask_generator = SAM2AutomaticMaskGenerator(
94
- model=self.model,
95
- **maskgen_hparams
96
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
 
98
- masks = self.mask_generator.generate(image)
 
 
99
 
100
- save_psd_with_masks(image, masks, output_path)
101
- combined_image = create_mask_combined_images(image, masks)
102
- gallery = create_mask_gallery(image, masks)
103
- return [combined_image] + gallery, output_path
 
1
  from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
2
  from sam2.build_sam import build_sam2
3
  from sam2.sam2_image_predictor import SAM2ImagePredictor
4
+ from typing import Dict, List
5
  import torch
6
  import os
7
  from datetime import datetime
 
13
  download_sam_model_url
14
  )
15
  from modules.paths import SAM2_CONFIGS_DIR, MODELS_DIR
16
+ from modules.constants import BOX_PROMPT_MODE, AUTOMATIC_MODE
17
  from modules.mask_utils import (
18
  save_psd_with_masks,
19
  create_mask_combined_images,
 
64
  print(f"Error while Loading SAM2 model! {e}")
65
 
66
  def generate_mask(self,
67
+ image: np.ndarray,
68
+ model_type: str,
69
+ **params):
70
+ if self.model is None or self.model_type != model_type:
71
+ self.model_type = model_type
72
+ self.load_model()
73
+ self.mask_generator = SAM2AutomaticMaskGenerator(
74
+ model=self.model,
75
+ **params
76
+ )
77
  return self.mask_generator.generate(image)
78
 
79
+ def predict_image(self,
80
+ image: np.ndarray,
81
+ model_type: str,
82
+ box: np.ndarray,
83
+ **params):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
  if self.model is None or self.model_type != model_type:
85
  self.model_type = model_type
86
  self.load_model()
87
+ self.image_predictor = SAM2ImagePredictor(sam_model=self.model)
88
+ self.image_predictor.set_image(image)
89
 
90
+ masks, scores, logits = self.image_predictor.predict(
91
+ box=box,
92
+ multimask_output=params["multimask_output"],
93
  )
94
+ print(f"masks: {masks}")
95
+ print(f"scores: {scores}")
96
+ print(f"logits: {logits}")
97
+ return masks, scores, logits
98
+
99
+ def divide_layer(self,
100
+ image_input: np.ndarray,
101
+ image_prompt_input_data: Dict,
102
+ input_mode: str,
103
+ model_type: str,
104
+ *params):
105
+ timestamp = datetime.now().strftime("%m%d%H%M%S")
106
+ output_file_name = f"result-{timestamp}.psd"
107
+ output_path = os.path.join(self.output_dir, "psd", output_file_name)
108
+
109
+ if input_mode == AUTOMATIC_MODE:
110
+ image = image_input
111
+ maskgen_hparams = {
112
+ 'points_per_side': int(params[0]),
113
+ 'points_per_batch': int(params[1]),
114
+ 'pred_iou_thresh': float(params[2]),
115
+ 'stability_score_thresh': float(params[3]),
116
+ 'stability_score_offset': float(params[4]),
117
+ 'crop_n_layers': int(params[5]),
118
+ 'box_nms_thresh': float(params[6]),
119
+ 'crop_n_points_downscale_factor': int(params[7]),
120
+ 'min_mask_region_area': int(params[8]),
121
+ 'use_m2m': bool(params[9])
122
+ }
123
+
124
+ generated_masks = self.generate_mask(
125
+ image=image,
126
+ model_type=model_type,
127
+ **maskgen_hparams
128
+ )
129
+
130
+ elif input_mode == BOX_PROMPT_MODE:
131
+ image = image_prompt_input_data["image"]
132
+ box = image_prompt_input_data["points"]
133
+ predict_image_hparams = {
134
+ "multimask_output": params[0]
135
+ }
136
+
137
+ generated_masks, scores, logits = self.predict_image(
138
+ image=image,
139
+ model_type=model_type,
140
+ box=box,
141
+ **predict_image_hparams
142
+ )
143
 
144
+ save_psd_with_masks(image, generated_masks, output_path)
145
+ mask_combined_image = create_mask_combined_images(image, generated_masks)
146
+ gallery = create_mask_gallery(image, generated_masks)
147
 
148
+ return [mask_combined_image] + gallery, output_path