wenmengzhou committed on
Commit
65882a7
·
verified ·
1 Parent(s): 1932d9e

split video_generation into two functions

Browse files
Files changed (1) hide show
  1. webgui.py +104 -35
webgui.py CHANGED
@@ -160,9 +160,8 @@ def select_face(det_bboxes, probs):
160
  return sorted_bboxes[0]
161
 
162
  lmk_extractor = LMKExtractor()
163
- @spaces.GPU
164
- def process_video(uploaded_img, uploaded_audio, width, height, length, facemask_dilation_ratio, facecrop_dilation_ratio, context_frames, context_overlap, cfg, steps, sample_rate, fps, device):
165
- #### face mask prepare
166
  face_img = cv2.imread(uploaded_img)
167
  if face_img is None:
168
  raise gr.Error("input image should be uploaded or selected.")
@@ -178,8 +177,7 @@ def process_video(uploaded_img, uploaded_audio, width, height, length, facemask_
178
  r_pad = int((re - rb) * facemask_dilation_ratio)
179
  c_pad = int((ce - cb) * facemask_dilation_ratio)
180
  face_mask[rb - r_pad : re + r_pad, cb - c_pad : ce + c_pad] = 255
181
-
182
- #### face crop
183
  r_pad_crop = int((re - rb) * facecrop_dilation_ratio)
184
  c_pad_crop = int((ce - cb) * facecrop_dilation_ratio)
185
  crop_rect = [max(0, cb - c_pad_crop), max(0, rb - r_pad_crop), min(ce + c_pad_crop, face_img.shape[1]), min(re + r_pad_crop, face_img.shape[0])]
@@ -187,39 +185,14 @@ def process_video(uploaded_img, uploaded_audio, width, height, length, facemask_
187
  face_mask = crop_and_pad(face_mask, crop_rect)
188
  face_img = cv2.resize(face_img, (width, height))
189
  face_mask = cv2.resize(face_mask, (width, height))
 
190
  print('face detect done.')
191
- # ==================== face_locator =====================
192
- '''
193
- driver_video = "./assets/driven_videos/c.mp4"
194
 
195
- input_frames_cv2 = [cv2.resize(center_crop_cv2(pil_to_cv2(i)), (512, 512)) for i in pils_from_video(driver_video)]
196
- ref_det = lmk_extractor(face_img)
197
-
198
- visualizer = FaceMeshVisualizer(draw_iris=False, draw_mouse=False)
199
-
200
- pose_list = []
201
- sequence_driver_det = []
202
- try:
203
- for frame in input_frames_cv2:
204
- result = lmk_extractor(frame)
205
- assert result is not None, "{}, bad video, face not detected".format(driver_video)
206
- sequence_driver_det.append(result)
207
- except:
208
- print("face detection failed")
209
- exit()
210
-
211
- sequence_det_ms = motion_sync(sequence_driver_det, ref_det)
212
- for p in sequence_det_ms:
213
- tgt_musk = visualizer.draw_landmarks((width, height), p)
214
- tgt_musk_pil = Image.fromarray(np.array(tgt_musk).astype(np.uint8)).convert('RGB')
215
- pose_list.append(torch.Tensor(np.array(tgt_musk_pil)).to(dtype=weight_dtype, device="cuda").permute(2,0,1) / 255.0)
216
- '''
217
- # face_mask_tensor = torch.stack(pose_list, dim=1).unsqueeze(0)
218
  face_mask_tensor = torch.Tensor(face_mask).to(dtype=weight_dtype, device="cuda").unsqueeze(0).unsqueeze(0).unsqueeze(0) / 255.0
219
-
220
  ref_image_pil = Image.fromarray(face_img[:, :, [2, 1, 0]])
221
-
222
- #del pose_list, sequence_det_ms, sequence_driver_det, input_frames_cv2
223
 
224
  video = pipe(
225
  ref_image_pil,
@@ -230,7 +203,6 @@ def process_video(uploaded_img, uploaded_audio, width, height, length, facemask_
230
  length,
231
  steps,
232
  cfg,
233
- #generator=generator,
234
  audio_sample_rate=sample_rate,
235
  context_frames=context_frames,
236
  fps=fps,
@@ -250,6 +222,103 @@ def process_video(uploaded_img, uploaded_audio, width, height, length, facemask_
250
  video_clip.write_videofile(str(final_output_path), codec="libx264", audio_codec="aac")
251
 
252
  return final_output_path
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
253
 
254
  with gr.Blocks() as demo:
255
  gr.Markdown('# EchoMimic')
 
160
  return sorted_bboxes[0]
161
 
162
  lmk_extractor = LMKExtractor()
163
+
164
+ def face_detection(uploaded_img, facemask_dilation_ratio, facecrop_dilation_ratio, width, height):
 
165
  face_img = cv2.imread(uploaded_img)
166
  if face_img is None:
167
  raise gr.Error("input image should be uploaded or selected.")
 
177
  r_pad = int((re - rb) * facemask_dilation_ratio)
178
  c_pad = int((ce - cb) * facemask_dilation_ratio)
179
  face_mask[rb - r_pad : re + r_pad, cb - c_pad : ce + c_pad] = 255
180
+
 
181
  r_pad_crop = int((re - rb) * facecrop_dilation_ratio)
182
  c_pad_crop = int((ce - cb) * facecrop_dilation_ratio)
183
  crop_rect = [max(0, cb - c_pad_crop), max(0, rb - r_pad_crop), min(ce + c_pad_crop, face_img.shape[1]), min(re + r_pad_crop, face_img.shape[0])]
 
185
  face_mask = crop_and_pad(face_mask, crop_rect)
186
  face_img = cv2.resize(face_img, (width, height))
187
  face_mask = cv2.resize(face_mask, (width, height))
188
+
189
  print('face detect done.')
190
+ return face_img, face_mask
 
 
191
 
192
+ @spaces.GPU
193
+ def video_pipe(face_img, face_mask, uploaded_audio, width, height, length, context_frames, context_overlap, cfg, steps, sample_rate, fps, device):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
194
  face_mask_tensor = torch.Tensor(face_mask).to(dtype=weight_dtype, device="cuda").unsqueeze(0).unsqueeze(0).unsqueeze(0) / 255.0
 
195
  ref_image_pil = Image.fromarray(face_img[:, :, [2, 1, 0]])
 
 
196
 
197
  video = pipe(
198
  ref_image_pil,
 
203
  length,
204
  steps,
205
  cfg,
 
206
  audio_sample_rate=sample_rate,
207
  context_frames=context_frames,
208
  fps=fps,
 
222
  video_clip.write_videofile(str(final_output_path), codec="libx264", audio_codec="aac")
223
 
224
  return final_output_path
225
+
226
+ def process_video(uploaded_img, uploaded_audio, width, height, length, facemask_dilation_ratio, facecrop_dilation_ratio, context_frames, context_overlap, cfg, steps, sample_rate, fps, device):
227
+ face_img, face_mask = face_detection(uploaded_img, facemask_dilation_ratio, facecrop_dilation_ratio, width, height)
228
+ final_output_path = video_pipe(face_img, face_mask, uploaded_audio, width, height, length, context_frames, context_overlap, cfg, steps, sample_rate, fps, device)
229
+ return final_output_path
230
+
231
+
232
+ # @spaces.GPU
233
+ # def process_video(uploaded_img, uploaded_audio, width, height, length, facemask_dilation_ratio, facecrop_dilation_ratio, context_frames, context_overlap, cfg, steps, sample_rate, fps, device):
234
+ # #### face mask prepare
235
+ # face_img = cv2.imread(uploaded_img)
236
+ # if face_img is None:
237
+ # raise gr.Error("input image should be uploaded or selected.")
238
+ # face_mask = np.zeros((face_img.shape[0], face_img.shape[1])).astype('uint8')
239
+ # det_bboxes, probs = face_detector.detect(face_img)
240
+ # select_bbox = select_face(det_bboxes, probs)
241
+ # if select_bbox is None:
242
+ # face_mask[:, :] = 255
243
+ # else:
244
+ # xyxy = select_bbox[:4]
245
+ # xyxy = np.round(xyxy).astype('int')
246
+ # rb, re, cb, ce = xyxy[1], xyxy[3], xyxy[0], xyxy[2]
247
+ # r_pad = int((re - rb) * facemask_dilation_ratio)
248
+ # c_pad = int((ce - cb) * facemask_dilation_ratio)
249
+ # face_mask[rb - r_pad : re + r_pad, cb - c_pad : ce + c_pad] = 255
250
+
251
+ # #### face crop
252
+ # r_pad_crop = int((re - rb) * facecrop_dilation_ratio)
253
+ # c_pad_crop = int((ce - cb) * facecrop_dilation_ratio)
254
+ # crop_rect = [max(0, cb - c_pad_crop), max(0, rb - r_pad_crop), min(ce + c_pad_crop, face_img.shape[1]), min(re + r_pad_crop, face_img.shape[0])]
255
+ # face_img = crop_and_pad(face_img, crop_rect)
256
+ # face_mask = crop_and_pad(face_mask, crop_rect)
257
+ # face_img = cv2.resize(face_img, (width, height))
258
+ # face_mask = cv2.resize(face_mask, (width, height))
259
+ # print('face detect done.')
260
+ # # ==================== face_locator =====================
261
+ # '''
262
+ # driver_video = "./assets/driven_videos/c.mp4"
263
+
264
+ # input_frames_cv2 = [cv2.resize(center_crop_cv2(pil_to_cv2(i)), (512, 512)) for i in pils_from_video(driver_video)]
265
+ # ref_det = lmk_extractor(face_img)
266
+
267
+ # visualizer = FaceMeshVisualizer(draw_iris=False, draw_mouse=False)
268
+
269
+ # pose_list = []
270
+ # sequence_driver_det = []
271
+ # try:
272
+ # for frame in input_frames_cv2:
273
+ # result = lmk_extractor(frame)
274
+ # assert result is not None, "{}, bad video, face not detected".format(driver_video)
275
+ # sequence_driver_det.append(result)
276
+ # except:
277
+ # print("face detection failed")
278
+ # exit()
279
+
280
+ # sequence_det_ms = motion_sync(sequence_driver_det, ref_det)
281
+ # for p in sequence_det_ms:
282
+ # tgt_musk = visualizer.draw_landmarks((width, height), p)
283
+ # tgt_musk_pil = Image.fromarray(np.array(tgt_musk).astype(np.uint8)).convert('RGB')
284
+ # pose_list.append(torch.Tensor(np.array(tgt_musk_pil)).to(dtype=weight_dtype, device="cuda").permute(2,0,1) / 255.0)
285
+ # '''
286
+ # # face_mask_tensor = torch.stack(pose_list, dim=1).unsqueeze(0)
287
+ # face_mask_tensor = torch.Tensor(face_mask).to(dtype=weight_dtype, device="cuda").unsqueeze(0).unsqueeze(0).unsqueeze(0) / 255.0
288
+
289
+ # ref_image_pil = Image.fromarray(face_img[:, :, [2, 1, 0]])
290
+
291
+ # #del pose_list, sequence_det_ms, sequence_driver_det, input_frames_cv2
292
+
293
+ # video = pipe(
294
+ # ref_image_pil,
295
+ # uploaded_audio,
296
+ # face_mask_tensor,
297
+ # width,
298
+ # height,
299
+ # length,
300
+ # steps,
301
+ # cfg,
302
+ # #generator=generator,
303
+ # audio_sample_rate=sample_rate,
304
+ # context_frames=context_frames,
305
+ # fps=fps,
306
+ # context_overlap=context_overlap
307
+ # ).videos
308
+ # print('video pipe done.')
309
+
310
+ # save_dir = Path("output/tmp")
311
+ # save_dir.mkdir(exist_ok=True, parents=True)
312
+ # output_video_path = save_dir / "output_video.mp4"
313
+ # save_videos_grid(video, str(output_video_path), n_rows=1, fps=fps)
314
+
315
+ # video_clip = VideoFileClip(str(output_video_path))
316
+ # audio_clip = AudioFileClip(uploaded_audio)
317
+ # final_output_path = save_dir / "output_video_with_audio.mp4"
318
+ # video_clip = video_clip.set_audio(audio_clip)
319
+ # video_clip.write_videofile(str(final_output_path), codec="libx264", audio_codec="aac")
320
+
321
+ # return final_output_path
322
 
323
  with gr.Blocks() as demo:
324
  gr.Markdown('# EchoMimic')