jhj0517 committed on
Commit
2878798
1 Parent(s): a6933f9

Add video predictor initialization

Browse files
Files changed (1) hide show
  1. modules/sam_inference.py +42 -1
modules/sam_inference.py CHANGED
@@ -45,6 +45,7 @@ class SamInference:
45
  self.mask_generator = None
46
  self.image_predictor = None
47
  self.video_predictor = None
 
48
 
49
  def load_model(self,
50
  load_video_predictor: bool = False):
@@ -59,7 +60,8 @@ class SamInference:
59
 
60
  if load_video_predictor:
61
  try:
62
- self.model = build_sam2_video_predictor(
 
63
  config_file=config,
64
  ckpt_path=model_path,
65
  device=self.device
@@ -78,6 +80,16 @@ class SamInference:
78
  logger.exception("Error while loading SAM2 model")
79
  raise f"Error while loading SAM2 model!: {e}"
80
 
 
 
 
 
 
 
 
 
 
 
81
  def generate_mask(self,
82
  image: np.ndarray,
83
  model_type: str,
@@ -121,10 +133,39 @@ class SamInference:
121
  raise f"Error while predicting image with prompt: {str(e)}"
122
  return masks, scores, logits
123
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
  def predict_video(self,
125
  video_input):
126
  pass
127
 
 
 
 
 
 
128
  def divide_layer(self,
129
  image_input: np.ndarray,
130
  image_prompt_input_data: Dict,
 
45
  self.mask_generator = None
46
  self.image_predictor = None
47
  self.video_predictor = None
48
+ self.video_inference_state = None
49
 
50
  def load_model(self,
51
  load_video_predictor: bool = False):
 
60
 
61
  if load_video_predictor:
62
  try:
63
+ self.model = None
64
+ self.video_predictor = build_sam2_video_predictor(
65
  config_file=config,
66
  ckpt_path=model_path,
67
  device=self.device
 
80
  logger.exception("Error while loading SAM2 model")
81
  raise f"Error while loading SAM2 model!: {e}"
82
 
83
def init_video_inference_state(self,
                               vid_input: str):
    """Prepare a fresh video inference state for ``vid_input``.

    Loads the video predictor on first use, resets any state left over from
    a previously initialised video, and stores the newly created state on
    ``self.video_inference_state``.

    Args:
        vid_input: Forwarded to the predictor's ``init_state`` as
            ``video_path``.
    """
    if self.video_predictor is None:
        self.load_model(load_video_predictor=True)

    if self.video_inference_state is not None:
        self.video_predictor.reset_state(self.video_inference_state)
        # Drop the stale reference so that a failure in init_state below
        # cannot leave the previous video's state looking valid.
        self.video_inference_state = None

    # Keep the returned state: without this assignment the attribute stays
    # None forever and predict_frame's "state is None" guard always fires.
    self.video_inference_state = self.video_predictor.init_state(video_path=vid_input)
92
+
93
  def generate_mask(self,
94
  image: np.ndarray,
95
  model_type: str,
 
133
  raise f"Error while predicting image with prompt: {str(e)}"
134
  return masks, scores, logits
135
 
136
def predict_frame(self,
                  frame_idx: int,
                  obj_id: int,
                  inference_state: Dict,
                  points: np.ndarray,
                  labels: np.ndarray):
    """Add point prompts for one object on one frame and return the predictor output.

    Args:
        frame_idx: Frame index forwarded to the video predictor.
        obj_id: Object id forwarded to the video predictor.
        inference_state: State object forwarded to the video predictor.
        points: Prompt points forwarded to the video predictor.
        labels: Prompt labels forwarded to the video predictor.

    Returns:
        The 3-tuple returned by ``video_predictor.add_new_points_or_box``,
        passed through unchanged.
        NOTE(review): upstream SAM2 appears to return
        ``(frame_idx, obj_ids, mask_logits)``, so the first element may be a
        frame index rather than masks - confirm against the installed version.

    Raises:
        RuntimeError: If no video inference state / predictor has been
            initialised, or if the underlying predictor call fails.
    """
    if self.video_predictor is None or self.video_inference_state is None:
        # logger.error, not logger.exception: there is no active exception
        # here, so .exception() would append a meaningless "NoneType: None".
        logger.error("Error while predicting frame from video, load video predictor first")
        # Raising a bare string is itself a TypeError in Python 3; raise a
        # real exception so the message reaches the caller.
        raise RuntimeError("Error while predicting frame from video")

    try:
        result = self.video_predictor.add_new_points_or_box(
            inference_state=inference_state,
            frame_idx=frame_idx,
            obj_id=obj_id,
            points=points,
            labels=labels,
        )
        # Unpack inside the try so an unexpected return arity is reported
        # the same way as before.
        first, obj_ids, mask_logits = result
    except Exception as e:
        logger.exception("Error while predicting frame with prompt")
        raise RuntimeError(f"Error while predicting frame with prompt: {str(e)}") from e

    return first, obj_ids, mask_logits
159
+
160
  def predict_video(self,
161
  video_input):
162
  pass
163
 
164
def add_filter_to_preview(self,
                          image: np.ndarray,
                          ):
    """Placeholder for applying a filter to a preview image.

    Not implemented yet: ``image`` is accepted but unused, and the method
    always returns ``None``.
    """
    return None
168
+
169
  def divide_layer(self,
170
  image_input: np.ndarray,
171
  image_prompt_input_data: Dict,