Spaces:
Runtime error
Runtime error
jhj0517
commited on
Commit
•
2c719e3
1
Parent(s):
60434a4
integrate the function
Browse files- app.py +2 -2
- modules/sam_inference.py +75 -30
app.py
CHANGED
@@ -73,14 +73,14 @@ class App:
|
|
73 |
output_file = gr.File(label="Generated psd file", scale=9)
|
74 |
btn_open_folder = gr.Button("📁\nOpen PSD folder", scale=1)
|
75 |
|
76 |
-
sources = [img_input]
|
77 |
model_params = [dd_models]
|
78 |
auto_mask_hparams = [nb_points_per_side, nb_points_per_batch, sld_pred_iou_thresh,
|
79 |
sld_stability_score_thresh, sld_stability_score_offset, nb_crop_n_layers,
|
80 |
sld_box_nms_thresh, nb_crop_n_points_downscale_factor, nb_min_mask_region_area,
|
81 |
cb_use_m2m]
|
82 |
|
83 |
-
btn_generate.click(fn=self.sam_inf.
|
84 |
inputs=sources + model_params + auto_mask_hparams, outputs=[gallery_output, output_file])
|
85 |
btn_open_folder.click(fn=lambda: open_folder(os.path.join(OUTPUT_DIR)),
|
86 |
inputs=None, outputs=None)
|
|
|
73 |
output_file = gr.File(label="Generated psd file", scale=9)
|
74 |
btn_open_folder = gr.Button("📁\nOpen PSD folder", scale=1)
|
75 |
|
76 |
+
sources = [img_input, img_input_prompter, dd_input_modes]
|
77 |
model_params = [dd_models]
|
78 |
auto_mask_hparams = [nb_points_per_side, nb_points_per_batch, sld_pred_iou_thresh,
|
79 |
sld_stability_score_thresh, sld_stability_score_offset, nb_crop_n_layers,
|
80 |
sld_box_nms_thresh, nb_crop_n_points_downscale_factor, nb_min_mask_region_area,
|
81 |
cb_use_m2m]
|
82 |
|
83 |
+
btn_generate.click(fn=self.sam_inf.divide_layer,
|
84 |
inputs=sources + model_params + auto_mask_hparams, outputs=[gallery_output, output_file])
|
85 |
btn_open_folder.click(fn=lambda: open_folder(os.path.join(OUTPUT_DIR)),
|
86 |
inputs=None, outputs=None)
|
modules/sam_inference.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1 |
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
|
2 |
from sam2.build_sam import build_sam2
|
3 |
from sam2.sam2_image_predictor import SAM2ImagePredictor
|
|
|
4 |
import torch
|
5 |
import os
|
6 |
from datetime import datetime
|
@@ -12,6 +13,7 @@ from modules.model_downloader import (
|
|
12 |
download_sam_model_url
|
13 |
)
|
14 |
from modules.paths import SAM2_CONFIGS_DIR, MODELS_DIR
|
|
|
15 |
from modules.mask_utils import (
|
16 |
save_psd_with_masks,
|
17 |
create_mask_combined_images,
|
@@ -62,42 +64,85 @@ class SamInference:
|
|
62 |
print(f"Error while Loading SAM2 model! {e}")
|
63 |
|
64 |
def generate_mask(self,
|
65 |
-
image: np.ndarray
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
66 |
return self.mask_generator.generate(image)
|
67 |
|
68 |
-
def
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
maskgen_hparams = {
|
74 |
-
'points_per_side': int(params[0]),
|
75 |
-
'points_per_batch': int(params[1]),
|
76 |
-
'pred_iou_thresh': float(params[2]),
|
77 |
-
'stability_score_thresh': float(params[3]),
|
78 |
-
'stability_score_offset': float(params[4]),
|
79 |
-
'crop_n_layers': int(params[5]),
|
80 |
-
'box_nms_thresh': float(params[6]),
|
81 |
-
'crop_n_points_downscale_factor': int(params[7]),
|
82 |
-
'min_mask_region_area': int(params[8]),
|
83 |
-
'use_m2m': bool(params[9])
|
84 |
-
}
|
85 |
-
timestamp = datetime.now().strftime("%m%d%H%M%S")
|
86 |
-
output_file_name = f"result-{timestamp}.psd"
|
87 |
-
output_path = os.path.join(self.output_dir, "psd", output_file_name)
|
88 |
-
|
89 |
if self.model is None or self.model_type != model_type:
|
90 |
self.model_type = model_type
|
91 |
self.load_model()
|
|
|
|
|
92 |
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
97 |
|
98 |
-
|
|
|
|
|
99 |
|
100 |
-
|
101 |
-
combined_image = create_mask_combined_images(image, masks)
|
102 |
-
gallery = create_mask_gallery(image, masks)
|
103 |
-
return [combined_image] + gallery, output_path
|
|
|
1 |
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
|
2 |
from sam2.build_sam import build_sam2
|
3 |
from sam2.sam2_image_predictor import SAM2ImagePredictor
|
4 |
+
from typing import Dict, List
|
5 |
import torch
|
6 |
import os
|
7 |
from datetime import datetime
|
|
|
13 |
download_sam_model_url
|
14 |
)
|
15 |
from modules.paths import SAM2_CONFIGS_DIR, MODELS_DIR
|
16 |
+
from modules.constants import BOX_PROMPT_MODE, AUTOMATIC_MODE
|
17 |
from modules.mask_utils import (
|
18 |
save_psd_with_masks,
|
19 |
create_mask_combined_images,
|
|
|
64 |
print(f"Error while Loading SAM2 model! {e}")
|
65 |
|
66 |
def generate_mask(self,
|
67 |
+
image: np.ndarray,
|
68 |
+
model_type: str,
|
69 |
+
**params):
|
70 |
+
if self.model is None or self.model_type != model_type:
|
71 |
+
self.model_type = model_type
|
72 |
+
self.load_model()
|
73 |
+
self.mask_generator = SAM2AutomaticMaskGenerator(
|
74 |
+
model=self.model,
|
75 |
+
**params
|
76 |
+
)
|
77 |
return self.mask_generator.generate(image)
|
78 |
|
79 |
+
def predict_image(self,
|
80 |
+
image: np.ndarray,
|
81 |
+
model_type: str,
|
82 |
+
box: np.ndarray,
|
83 |
+
**params):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
84 |
if self.model is None or self.model_type != model_type:
|
85 |
self.model_type = model_type
|
86 |
self.load_model()
|
87 |
+
self.image_predictor = SAM2ImagePredictor(sam_model=self.model)
|
88 |
+
self.image_predictor.set_image(image)
|
89 |
|
90 |
+
masks, scores, logits = self.image_predictor.predict(
|
91 |
+
box=box,
|
92 |
+
multimask_output=params["multimask_output"],
|
93 |
)
|
94 |
+
print(f"masks: {masks}")
|
95 |
+
print(f"scores: {scores}")
|
96 |
+
print(f"logits: {logits}")
|
97 |
+
return masks, scores, logits
|
98 |
+
|
99 |
+
def divide_layer(self,
|
100 |
+
image_input: np.ndarray,
|
101 |
+
image_prompt_input_data: Dict,
|
102 |
+
input_mode: str,
|
103 |
+
model_type: str,
|
104 |
+
*params):
|
105 |
+
timestamp = datetime.now().strftime("%m%d%H%M%S")
|
106 |
+
output_file_name = f"result-{timestamp}.psd"
|
107 |
+
output_path = os.path.join(self.output_dir, "psd", output_file_name)
|
108 |
+
|
109 |
+
if input_mode == AUTOMATIC_MODE:
|
110 |
+
image = image_input
|
111 |
+
maskgen_hparams = {
|
112 |
+
'points_per_side': int(params[0]),
|
113 |
+
'points_per_batch': int(params[1]),
|
114 |
+
'pred_iou_thresh': float(params[2]),
|
115 |
+
'stability_score_thresh': float(params[3]),
|
116 |
+
'stability_score_offset': float(params[4]),
|
117 |
+
'crop_n_layers': int(params[5]),
|
118 |
+
'box_nms_thresh': float(params[6]),
|
119 |
+
'crop_n_points_downscale_factor': int(params[7]),
|
120 |
+
'min_mask_region_area': int(params[8]),
|
121 |
+
'use_m2m': bool(params[9])
|
122 |
+
}
|
123 |
+
|
124 |
+
generated_masks = self.generate_mask(
|
125 |
+
image=image,
|
126 |
+
model_type=model_type,
|
127 |
+
**maskgen_hparams
|
128 |
+
)
|
129 |
+
|
130 |
+
elif input_mode == BOX_PROMPT_MODE:
|
131 |
+
image = image_prompt_input_data["image"]
|
132 |
+
box = image_prompt_input_data["points"]
|
133 |
+
predict_image_hparams = {
|
134 |
+
"multimask_output": params[0]
|
135 |
+
}
|
136 |
+
|
137 |
+
generated_masks, scores, logits = self.predict_image(
|
138 |
+
image=image,
|
139 |
+
model_type=model_type,
|
140 |
+
box=box,
|
141 |
+
**predict_image_hparams
|
142 |
+
)
|
143 |
|
144 |
+
save_psd_with_masks(image, generated_masks, output_path)
|
145 |
+
mask_combined_image = create_mask_combined_images(image, generated_masks)
|
146 |
+
gallery = create_mask_gallery(image, generated_masks)
|
147 |
|
148 |
+
return [mask_combined_image] + gallery, output_path
|
|
|
|
|
|