jhj0517 committed
Commit 0f36b51
1 Parent(s): 8ab6ed9

Add loading video predictor

Files changed (1): modules/sam_inference.py (+25 -7)
modules/sam_inference.py CHANGED
@@ -1,5 +1,5 @@
 from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
-from sam2.build_sam import build_sam2
+from sam2.build_sam import build_sam2, build_sam2_video_predictor
 from sam2.sam2_image_predictor import SAM2ImagePredictor
 from typing import Dict, List, Optional
 import torch
@@ -46,7 +46,8 @@ class SamInference:
         self.image_predictor = None
         self.video_predictor = None
 
-    def load_model(self):
+    def load_model(self,
+                   load_video_predictor: bool = False):
         config = CONFIGS[self.model_type]
         filename, url = AVAILABLE_MODELS[self.model_type]
         model_path = os.path.join(self.model_dir, filename)
@@ -56,6 +57,17 @@ class SamInference:
             download_sam_model_url(self.model_type)
         logger.info(f"Applying configs to model..")
 
+        if load_video_predictor:
+            try:
+                self.model = build_sam2_video_predictor(
+                    config_file=config,
+                    ckpt_path=model_path,
+                    device=self.device
+                )
+            except Exception as e:
+                logger.exception("Error while loading SAM2 model for video predictor")
+                raise f"Error while loading SAM2 model for video predictor!: {e}"
+
         try:
             self.model = build_sam2(
                 config_file=config,
@@ -63,8 +75,8 @@ class SamInference:
                 device=self.device
             )
         except Exception as e:
-            logger.exception("Error while auto generating masks")
-            raise f"Error while Loading SAM2 model! {e}"
+            logger.exception("Error while loading SAM2 model")
+            raise f"Error while loading SAM2 model!: {e}"
 
     def generate_mask(self,
                       image: np.ndarray,
@@ -81,7 +93,7 @@ class SamInference:
             generated_masks = self.mask_generator.generate(image)
         except Exception as e:
             logger.exception("Error while auto generating masks")
-            raise f"Error while auto generating masks: {e}"
+            raise f"Error while auto generating masks: str({e})"
         return generated_masks
 
     def predict_image(self,
@@ -106,9 +118,13 @@ class SamInference:
             )
         except Exception as e:
             logger.exception("Error while predicting image with prompt")
-            raise f"Error while predicting image with prompt: {e}"
+            raise f"Error while predicting image with prompt: {str(e)}"
         return masks, scores, logits
 
+    def predict_video(self,
+                      video_input):
+        pass
+
     def divide_layer(self,
                      image_input: np.ndarray,
                      image_prompt_input_data: Dict,
@@ -119,6 +135,7 @@ class SamInference:
         output_file_name = f"result-{timestamp}.psd"
         output_path = os.path.join(self.output_dir, "psd", output_file_name)
 
+        # Pre-processed gradio components
         hparams = {
             'points_per_side': int(params[0]),
             'points_per_batch': int(params[1]),
@@ -171,8 +188,9 @@ class SamInference:
         save_psd_with_masks(image, generated_masks, output_path)
         mask_combined_image = create_mask_combined_images(image, generated_masks)
         gallery = create_mask_gallery(image, generated_masks)
+        gallery = [mask_combined_image] + gallery
 
-        return [mask_combined_image] + gallery, output_path
+        return gallery, output_path
 
     @staticmethod
     def format_to_auto_result(
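
For context, a minimal usage sketch of the new flag follows. It is not part of the commit; it assumes the module is importable as modules.sam_inference and that SamInference can be constructed with its default arguments (model_type, model_dir, and device are defined elsewhere in the file and not shown in this diff).

# Hypothetical usage sketch of the load_model() change in this commit.
from modules.sam_inference import SamInference

inference = SamInference()   # assumed default constructor

# Previous behaviour, unchanged: build the image model via build_sam2().
inference.load_model()

# New in this commit: build the video predictor via build_sam2_video_predictor().
# Note that the branch has no early return, so the later build_sam2() call in
# load_model() still runs and reassigns inference.model to the image model.
inference.load_model(load_video_predictor=True)

# predict_video() is added only as a stub (pass) in this commit, so there is
# no video inference entry point to call yet.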