jhj0517 commited on
Commit
029295b
1 Parent(s): 71c08fe

Enable image restoration in video creation

Browse files
modules/live_portrait/live_portrait_inferencer.py CHANGED
@@ -257,6 +257,7 @@ class LivePortraitInferencer:
257
  crop_factor: float = 2.3,
258
  src_image: Optional[str] = None,
259
  driving_vid_path: Optional[str] = None,
 
260
  progress: gr.Progress = gr.Progress()
261
  ):
262
  if self.pipeline is None or model_type != self.model_type:
@@ -328,7 +329,10 @@ class LivePortraitInferencer:
328
  np.uint8)
329
 
330
  out_frame_path = get_auto_incremental_file_path(os.path.join(self.output_dir, "temp", "video_frames", "out"), "png")
331
- save_image(out, out_frame_path)
 
 
 
332
 
333
  progress(i/total_length, desc=f"Generating frames {i}/{total_length} ..")
334
 
 
257
  crop_factor: float = 2.3,
258
  src_image: Optional[str] = None,
259
  driving_vid_path: Optional[str] = None,
260
+ enable_image_restoration: bool = False,
261
  progress: gr.Progress = gr.Progress()
262
  ):
263
  if self.pipeline is None or model_type != self.model_type:
 
329
  np.uint8)
330
 
331
  out_frame_path = get_auto_incremental_file_path(os.path.join(self.output_dir, "temp", "video_frames", "out"), "png")
332
+ out_frame_path = save_image(out, out_frame_path)
333
+
334
+ if enable_image_restoration:
335
+ out_frame_path = self.resrgan_inferencer.restore_image(out_frame_path)
336
 
337
  progress(i/total_length, desc=f"Generating frames {i}/{total_length} ..")
338