jhj0517 committed on
Commit
3423be2
·
1 Parent(s): 62a455e

Add create_filtered_video() and move logic to the inference class

Browse files
Files changed (1) hide show
  1. modules/sam_inference.py +26 -15
modules/sam_inference.py CHANGED
@@ -23,7 +23,8 @@ from modules.mask_utils import (
23
  create_mask_pixelized_image,
24
  create_solid_color_mask_image
25
  )
26
- from modules.video_utils import get_frames_from_dir
 
27
  from modules.utils import save_image
28
  from modules.logger_util import get_logger
29
 
@@ -53,6 +54,7 @@ class SamInference:
53
  self.image_predictor = None
54
  self.video_predictor = None
55
  self.video_inference_state = None
 
56
 
57
  def load_model(self,
58
  model_type: Optional[str] = None,
@@ -79,7 +81,6 @@ class SamInference:
79
  )
80
  except Exception as e:
81
  logger.exception("Error while loading SAM2 model for video predictor")
82
- raise f"Error while loading SAM2 model for video predictor!: {e}"
83
 
84
  try:
85
  self.model = build_sam2(
@@ -89,7 +90,6 @@ class SamInference:
89
  )
90
  except Exception as e:
91
  logger.exception("Error while loading SAM2 model")
92
- raise f"Error while loading SAM2 model!: {e}"
93
 
94
  def init_video_inference_state(self,
95
  vid_input: str,
@@ -101,11 +101,18 @@ class SamInference:
101
  self.current_model_type = model_type
102
  self.load_model(model_type=model_type, load_video_predictor=True)
103
 
 
 
 
 
 
 
 
104
  if self.video_inference_state is not None:
105
  self.video_predictor.reset_state(self.video_inference_state)
106
  self.video_inference_state = None
107
 
108
- self.video_inference_state = self.video_predictor.init_state(video_path=vid_input)
109
 
110
  def generate_mask(self,
111
  image: np.ndarray,
@@ -147,7 +154,6 @@ class SamInference:
147
  )
148
  except Exception as e:
149
  logger.exception(f"Error while predicting image with prompt: {str(e)}")
150
- raise RuntimeError(f"Error while predicting image with prompt: {str(e)}") from e
151
  return masks, scores, logits
152
 
153
  def add_prediction_to_frame(self,
@@ -160,7 +166,6 @@ class SamInference:
160
  if (self.video_predictor is None or
161
  inference_state is None and self.video_inference_state is None):
162
  logger.exception("Error while predicting frame from video, load video predictor first")
163
- raise f"Error while predicting frame from video"
164
 
165
  if inference_state is None:
166
  inference_state = self.video_inference_state
@@ -184,7 +189,6 @@ class SamInference:
184
  inference_state: Optional[Dict] = None,):
185
  if inference_state is None and self.video_inference_state is None:
186
  logger.exception("Error while propagating in video, load video predictor first")
187
- raise f"Error while propagating in video"
188
 
189
  if inference_state is None:
190
  inference_state = self.video_inference_state
@@ -196,7 +200,6 @@ class SamInference:
196
  inference_state=inference_state,
197
  start_frame_idx=0
198
  )
199
- cached_images = inference_state["images"]
200
  images = get_frames_from_dir(vid_dir=TEMP_DIR, as_numpy=True)
201
 
202
  with torch.autocast(device_type=self.device, dtype=torch.float16):
@@ -208,7 +211,6 @@ class SamInference:
208
  }
209
  except Exception as e:
210
  logger.exception(f"Error while propagating in video: {str(e)}")
211
- raise RuntimeError(f"Failed to propagate in video: {str(e)}") from e
212
 
213
  return video_segments
214
 
@@ -255,12 +257,13 @@ class SamInference:
255
 
256
  return image
257
 
258
- def add_filter_to_video(self,
259
- image_prompt_input_data: Dict,
260
- filter_mode: str,
261
- frame_idx: int,
262
- pixel_size: Optional[int] = None,
263
- color_hex: Optional[str] = None,):
 
264
  if self.video_predictor is None or self.video_inference_state is None:
265
  logger.exception("Error while adding filter to preview, load video predictor first")
266
  raise f"Error while adding filter to preview"
@@ -299,6 +302,14 @@ class SamInference:
299
 
300
  save_image(image=filtered_image, output_dir=TEMP_OUT_DIR)
301
 
 
 
 
 
 
 
 
 
302
  def divide_layer(self,
303
  image_input: np.ndarray,
304
  image_prompt_input_data: Dict,
 
23
  create_mask_pixelized_image,
24
  create_solid_color_mask_image
25
  )
26
+ from modules.video_utils import (get_frames_from_dir, create_video_from_frames, get_video_info, extract_frames,
27
+ extract_sound, clean_temp_dir)
28
  from modules.utils import save_image
29
  from modules.logger_util import get_logger
30
 
 
54
  self.image_predictor = None
55
  self.video_predictor = None
56
  self.video_inference_state = None
57
+ self.video_info = None
58
 
59
  def load_model(self,
60
  model_type: Optional[str] = None,
 
81
  )
82
  except Exception as e:
83
  logger.exception("Error while loading SAM2 model for video predictor")
 
84
 
85
  try:
86
  self.model = build_sam2(
 
90
  )
91
  except Exception as e:
92
  logger.exception("Error while loading SAM2 model")
 
93
 
94
  def init_video_inference_state(self,
95
  vid_input: str,
 
101
  self.current_model_type = model_type
102
  self.load_model(model_type=model_type, load_video_predictor=True)
103
 
104
+ self.video_info = get_video_info(vid_input)
105
+ frames_temp_dir = TEMP_DIR
106
+ clean_temp_dir(frames_temp_dir)
107
+ extract_frames(vid_input, frames_temp_dir)
108
+ if self.video_info.has_sound:
109
+ extract_sound(vid_input, frames_temp_dir)
110
+
111
  if self.video_inference_state is not None:
112
  self.video_predictor.reset_state(self.video_inference_state)
113
  self.video_inference_state = None
114
 
115
+ self.video_inference_state = self.video_predictor.init_state(video_path=frames_temp_dir)
116
 
117
  def generate_mask(self,
118
  image: np.ndarray,
 
154
  )
155
  except Exception as e:
156
  logger.exception(f"Error while predicting image with prompt: {str(e)}")
 
157
  return masks, scores, logits
158
 
159
  def add_prediction_to_frame(self,
 
166
  if (self.video_predictor is None or
167
  inference_state is None and self.video_inference_state is None):
168
  logger.exception("Error while predicting frame from video, load video predictor first")
 
169
 
170
  if inference_state is None:
171
  inference_state = self.video_inference_state
 
189
  inference_state: Optional[Dict] = None,):
190
  if inference_state is None and self.video_inference_state is None:
191
  logger.exception("Error while propagating in video, load video predictor first")
 
192
 
193
  if inference_state is None:
194
  inference_state = self.video_inference_state
 
200
  inference_state=inference_state,
201
  start_frame_idx=0
202
  )
 
203
  images = get_frames_from_dir(vid_dir=TEMP_DIR, as_numpy=True)
204
 
205
  with torch.autocast(device_type=self.device, dtype=torch.float16):
 
211
  }
212
  except Exception as e:
213
  logger.exception(f"Error while propagating in video: {str(e)}")
 
214
 
215
  return video_segments
216
 
 
257
 
258
  return image
259
 
260
+ def create_filtered_video(self,
261
+ image_prompt_input_data: Dict,
262
+ filter_mode: str,
263
+ frame_idx: int,
264
+ pixel_size: Optional[int] = None,
265
+ color_hex: Optional[str] = None
266
+ ):
267
  if self.video_predictor is None or self.video_inference_state is None:
268
  logger.exception("Error while adding filter to preview, load video predictor first")
269
  raise f"Error while adding filter to preview"
 
302
 
303
  save_image(image=filtered_image, output_dir=TEMP_OUT_DIR)
304
 
305
+ out_video = create_video_from_frames(
306
+ frames_dir=TEMP_DIR,
307
+ frame_rate=self.video_info.frame_rate,
308
+ output_dir=self.output_dir,
309
+ )
310
+
311
+ return out_video, out_video
312
+
313
  def divide_layer(self,
314
  image_input: np.ndarray,
315
  image_prompt_input_data: Dict,