Spaces:
Runtime error
Runtime error
jhj0517
commited on
Commit
•
cfa5142
1
Parent(s):
baa2a55
Add inference script
Browse files- modules/sam_inference.py +119 -0
modules/sam_inference.py
ADDED
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
|
2 |
+
from sam2.build_sam import build_sam2
|
3 |
+
from sam2.sam2_image_predictor import SAM2ImagePredictor
|
4 |
+
import torch
|
5 |
+
import os
|
6 |
+
from datetime import datetime
|
7 |
+
import numpy as np
|
8 |
+
|
9 |
+
from modules.model_downloader import (
|
10 |
+
AVAILABLE_MODELS,
|
11 |
+
DEFAULT_MODEL_TYPE,
|
12 |
+
OUTPUT_DIR,
|
13 |
+
is_sam_exist,
|
14 |
+
download_sam_model_url
|
15 |
+
)
|
16 |
+
from modules.paths import SAM2_CONFIGS_DIR, MODELS_DIR
|
17 |
+
from modules.mask_utils import (
|
18 |
+
save_psd_with_masks,
|
19 |
+
create_mask_combined_images,
|
20 |
+
create_mask_gallery
|
21 |
+
)
|
22 |
+
|
23 |
+
CONFIGS = {
|
24 |
+
"sam2_hiera_tiny": os.path.join(SAM2_CONFIGS_DIR, "sam2_hiera_t.yaml"),
|
25 |
+
"sam2_hiera_small": os.path.join(SAM2_CONFIGS_DIR, "sam2_hiera_s.yaml"),
|
26 |
+
"sam2_hiera_base_plus": os.path.join(SAM2_CONFIGS_DIR, "sam2_hiera_b+.yaml"),
|
27 |
+
"sam2_hiera_large": os.path.join(SAM2_CONFIGS_DIR, "sam2_hiera_l.yaml"),
|
28 |
+
}
|
29 |
+
|
30 |
+
|
31 |
+
class SamInference:
|
32 |
+
def __init__(self,
|
33 |
+
model_dir: str = MODELS_DIR,
|
34 |
+
output_dir: str = OUTPUT_DIR
|
35 |
+
):
|
36 |
+
self.model = None
|
37 |
+
self.available_models = list(AVAILABLE_MODELS.keys())
|
38 |
+
self.model_type = DEFAULT_MODEL_TYPE
|
39 |
+
self.model_dir = model_dir
|
40 |
+
self.output_dir = output_dir
|
41 |
+
self.model_path = os.path.join(self.model_dir, AVAILABLE_MODELS[DEFAULT_MODEL_TYPE][0])
|
42 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
43 |
+
self.mask_generator = None
|
44 |
+
self.image_predictor = None
|
45 |
+
|
46 |
+
# Tunable Parameters , All default values by https://github.com/facebookresearch/segment-anything-2/blob/main/notebooks/automatic_mask_generator_example.ipynb
|
47 |
+
self.maskgen_hparams = {
|
48 |
+
"points_per_side": 64,
|
49 |
+
"points_per_batch": 128,
|
50 |
+
"pred_iou_thresh": 0.7,
|
51 |
+
"stability_score_thresh": 0.92,
|
52 |
+
"stability_score_offset": 0.7,
|
53 |
+
"crop_n_layers": 1,
|
54 |
+
"box_nms_thresh": 0.7,
|
55 |
+
"crop_n_points_downscale_factor": 2,
|
56 |
+
"min_mask_region_area": 25.0,
|
57 |
+
"use_m2m": True,
|
58 |
+
}
|
59 |
+
|
60 |
+
def load_model(self):
|
61 |
+
config = CONFIGS[self.model_type]
|
62 |
+
filename, url = AVAILABLE_MODELS[self.model_type]
|
63 |
+
model_path = os.path.join(self.model_dir, filename)
|
64 |
+
|
65 |
+
if not is_sam_exist(self.model_type):
|
66 |
+
print(f"\nLayer Divider Extension : No SAM2 model found, downloading {self.model_type} model...")
|
67 |
+
download_sam_model_url(self.model_type)
|
68 |
+
print("\nLayer Divider Extension : applying configs to model..")
|
69 |
+
|
70 |
+
try:
|
71 |
+
self.model = build_sam2(
|
72 |
+
config_file=config,
|
73 |
+
ckpt_path=model_path,
|
74 |
+
device=self.device
|
75 |
+
)
|
76 |
+
self.image_predictor = SAM2ImagePredictor(sam_model=self.model)
|
77 |
+
self.mask_generator = SAM2AutomaticMaskGenerator(
|
78 |
+
model=self.model,
|
79 |
+
**self.maskgen_hparams
|
80 |
+
)
|
81 |
+
except Exception as e:
|
82 |
+
print(f"Layer Divider Extension : Error while Loading SAM2 model! {e}")
|
83 |
+
|
84 |
+
def generate_mask(self,
|
85 |
+
image: np.ndarray):
|
86 |
+
return self.mask_generator.generate(image)
|
87 |
+
|
88 |
+
def generate_mask_app(self,
|
89 |
+
image: np.ndarray,
|
90 |
+
model_type: str,
|
91 |
+
*params
|
92 |
+
):
|
93 |
+
maskgen_hparams = {
|
94 |
+
'points_per_side': int(params[0]),
|
95 |
+
'points_per_batch': int(params[1]),
|
96 |
+
'pred_iou_thresh': float(params[2]),
|
97 |
+
'stability_score_thresh': float(params[3]),
|
98 |
+
'stability_score_offset': float(params[4]),
|
99 |
+
'crop_n_layers': int(params[5]),
|
100 |
+
'box_nms_thresh': float(params[6]),
|
101 |
+
'crop_n_points_downscale_factor': int(params[7]),
|
102 |
+
'min_mask_region_area': int(params[8]),
|
103 |
+
'use_m2m': bool(params[9])
|
104 |
+
}
|
105 |
+
timestamp = datetime.now().strftime("%m%d%H%M%S")
|
106 |
+
output_file_name = f"result-{timestamp}.psd"
|
107 |
+
output_path = os.path.join(self.output_dir, "psd", output_file_name)
|
108 |
+
|
109 |
+
if self.model is None or self.mask_generator is None or self.model_type != model_type or self.maskgen_hparams != maskgen_hparams:
|
110 |
+
self.model_type = model_type
|
111 |
+
self.maskgen_hparams = maskgen_hparams
|
112 |
+
self.load_model()
|
113 |
+
|
114 |
+
masks = self.mask_generator.generate(image)
|
115 |
+
|
116 |
+
save_psd_with_masks(image, masks, output_path)
|
117 |
+
combined_image = create_mask_combined_images(image, masks)
|
118 |
+
gallery = create_mask_gallery(image, masks)
|
119 |
+
return [combined_image] + gallery, output_path
|