jhj0517 commited on
Commit
60434a4
1 Parent(s): 97f1bae

Remove meaningless attribute

Browse files
Files changed (1) hide show
  1. modules/sam_inference.py +8 -30
modules/sam_inference.py CHANGED
@@ -40,20 +40,7 @@ class SamInference:
40
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
41
  self.mask_generator = None
42
  self.image_predictor = None
43
-
44
- # Tunable Parameters , All default values by https://github.com/facebookresearch/segment-anything-2/blob/main/notebooks/automatic_mask_generator_example.ipynb
45
- self.maskgen_hparams = {
46
- "points_per_side": 64,
47
- "points_per_batch": 128,
48
- "pred_iou_thresh": 0.7,
49
- "stability_score_thresh": 0.92,
50
- "stability_score_offset": 0.7,
51
- "crop_n_layers": 1,
52
- "box_nms_thresh": 0.7,
53
- "crop_n_points_downscale_factor": 2,
54
- "min_mask_region_area": 25.0,
55
- "use_m2m": True,
56
- }
57
 
58
  def load_model(self):
59
  config = CONFIGS[self.model_type]
@@ -61,9 +48,9 @@ class SamInference:
61
  model_path = os.path.join(self.model_dir, filename)
62
 
63
  if not is_sam_exist(self.model_type):
64
- print(f"\nLayer Divider Extension : No SAM2 model found, downloading {self.model_type} model...")
65
  download_sam_model_url(self.model_type)
66
- print("\nLayer Divider Extension : applying configs to model..")
67
 
68
  try:
69
  self.model = build_sam2(
@@ -72,17 +59,7 @@ class SamInference:
72
  device=self.device
73
  )
74
  except Exception as e:
75
- print(f"Layer Divider Extension : Error while Loading SAM2 model! {e}")
76
-
77
- def set_predictors(self):
78
- if self.model is None:
79
- self.load_model()
80
-
81
- self.image_predictor = SAM2ImagePredictor(sam_model=self.model)
82
- self.mask_generator = SAM2AutomaticMaskGenerator(
83
- model=self.model,
84
- **self.maskgen_hparams
85
- )
86
 
87
  def generate_mask(self,
88
  image: np.ndarray):
@@ -113,9 +90,10 @@ class SamInference:
113
  self.model_type = model_type
114
  self.load_model()
115
 
116
- if self.mask_generator is None or self.maskgen_hparams != maskgen_hparams:
117
- self.maskgen_hparams = maskgen_hparams
118
- self.set_predictors()
 
119
 
120
  masks = self.mask_generator.generate(image)
121
 
 
40
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
41
  self.mask_generator = None
42
  self.image_predictor = None
43
+ self.video_predictor = None
 
 
 
 
 
 
 
 
 
 
 
 
 
44
 
45
  def load_model(self):
46
  config = CONFIGS[self.model_type]
 
48
  model_path = os.path.join(self.model_dir, filename)
49
 
50
  if not is_sam_exist(self.model_type):
51
+ print(f"\nNo SAM2 model found, downloading {self.model_type} model...")
52
  download_sam_model_url(self.model_type)
53
+ print("\nApplying configs to model..")
54
 
55
  try:
56
  self.model = build_sam2(
 
59
  device=self.device
60
  )
61
  except Exception as e:
62
+ print(f"Error while Loading SAM2 model! {e}")
 
 
 
 
 
 
 
 
 
 
63
 
64
  def generate_mask(self,
65
  image: np.ndarray):
 
90
  self.model_type = model_type
91
  self.load_model()
92
 
93
+ self.mask_generator = SAM2AutomaticMaskGenerator(
94
+ model=self.model,
95
+ **maskgen_hparams
96
+ )
97
 
98
  masks = self.mask_generator.generate(image)
99