jhj0517 commited on
Commit
cfa5142
1 Parent(s): baa2a55

Add inference script

Browse files
Files changed (1) hide show
  1. modules/sam_inference.py +119 -0
modules/sam_inference.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
2
+ from sam2.build_sam import build_sam2
3
+ from sam2.sam2_image_predictor import SAM2ImagePredictor
4
+ import torch
5
+ import os
6
+ from datetime import datetime
7
+ import numpy as np
8
+
9
+ from modules.model_downloader import (
10
+ AVAILABLE_MODELS,
11
+ DEFAULT_MODEL_TYPE,
12
+ OUTPUT_DIR,
13
+ is_sam_exist,
14
+ download_sam_model_url
15
+ )
16
+ from modules.paths import SAM2_CONFIGS_DIR, MODELS_DIR
17
+ from modules.mask_utils import (
18
+ save_psd_with_masks,
19
+ create_mask_combined_images,
20
+ create_mask_gallery
21
+ )
22
+
23
+ CONFIGS = {
24
+ "sam2_hiera_tiny": os.path.join(SAM2_CONFIGS_DIR, "sam2_hiera_t.yaml"),
25
+ "sam2_hiera_small": os.path.join(SAM2_CONFIGS_DIR, "sam2_hiera_s.yaml"),
26
+ "sam2_hiera_base_plus": os.path.join(SAM2_CONFIGS_DIR, "sam2_hiera_b+.yaml"),
27
+ "sam2_hiera_large": os.path.join(SAM2_CONFIGS_DIR, "sam2_hiera_l.yaml"),
28
+ }
29
+
30
+
31
+ class SamInference:
32
+ def __init__(self,
33
+ model_dir: str = MODELS_DIR,
34
+ output_dir: str = OUTPUT_DIR
35
+ ):
36
+ self.model = None
37
+ self.available_models = list(AVAILABLE_MODELS.keys())
38
+ self.model_type = DEFAULT_MODEL_TYPE
39
+ self.model_dir = model_dir
40
+ self.output_dir = output_dir
41
+ self.model_path = os.path.join(self.model_dir, AVAILABLE_MODELS[DEFAULT_MODEL_TYPE][0])
42
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
43
+ self.mask_generator = None
44
+ self.image_predictor = None
45
+
46
+ # Tunable Parameters , All default values by https://github.com/facebookresearch/segment-anything-2/blob/main/notebooks/automatic_mask_generator_example.ipynb
47
+ self.maskgen_hparams = {
48
+ "points_per_side": 64,
49
+ "points_per_batch": 128,
50
+ "pred_iou_thresh": 0.7,
51
+ "stability_score_thresh": 0.92,
52
+ "stability_score_offset": 0.7,
53
+ "crop_n_layers": 1,
54
+ "box_nms_thresh": 0.7,
55
+ "crop_n_points_downscale_factor": 2,
56
+ "min_mask_region_area": 25.0,
57
+ "use_m2m": True,
58
+ }
59
+
60
+ def load_model(self):
61
+ config = CONFIGS[self.model_type]
62
+ filename, url = AVAILABLE_MODELS[self.model_type]
63
+ model_path = os.path.join(self.model_dir, filename)
64
+
65
+ if not is_sam_exist(self.model_type):
66
+ print(f"\nLayer Divider Extension : No SAM2 model found, downloading {self.model_type} model...")
67
+ download_sam_model_url(self.model_type)
68
+ print("\nLayer Divider Extension : applying configs to model..")
69
+
70
+ try:
71
+ self.model = build_sam2(
72
+ config_file=config,
73
+ ckpt_path=model_path,
74
+ device=self.device
75
+ )
76
+ self.image_predictor = SAM2ImagePredictor(sam_model=self.model)
77
+ self.mask_generator = SAM2AutomaticMaskGenerator(
78
+ model=self.model,
79
+ **self.maskgen_hparams
80
+ )
81
+ except Exception as e:
82
+ print(f"Layer Divider Extension : Error while Loading SAM2 model! {e}")
83
+
84
+ def generate_mask(self,
85
+ image: np.ndarray):
86
+ return self.mask_generator.generate(image)
87
+
88
+ def generate_mask_app(self,
89
+ image: np.ndarray,
90
+ model_type: str,
91
+ *params
92
+ ):
93
+ maskgen_hparams = {
94
+ 'points_per_side': int(params[0]),
95
+ 'points_per_batch': int(params[1]),
96
+ 'pred_iou_thresh': float(params[2]),
97
+ 'stability_score_thresh': float(params[3]),
98
+ 'stability_score_offset': float(params[4]),
99
+ 'crop_n_layers': int(params[5]),
100
+ 'box_nms_thresh': float(params[6]),
101
+ 'crop_n_points_downscale_factor': int(params[7]),
102
+ 'min_mask_region_area': int(params[8]),
103
+ 'use_m2m': bool(params[9])
104
+ }
105
+ timestamp = datetime.now().strftime("%m%d%H%M%S")
106
+ output_file_name = f"result-{timestamp}.psd"
107
+ output_path = os.path.join(self.output_dir, "psd", output_file_name)
108
+
109
+ if self.model is None or self.mask_generator is None or self.model_type != model_type or self.maskgen_hparams != maskgen_hparams:
110
+ self.model_type = model_type
111
+ self.maskgen_hparams = maskgen_hparams
112
+ self.load_model()
113
+
114
+ masks = self.mask_generator.generate(image)
115
+
116
+ save_psd_with_masks(image, masks, output_path)
117
+ combined_image = create_mask_combined_images(image, masks)
118
+ gallery = create_mask_gallery(image, masks)
119
+ return [combined_image] + gallery, output_path