yerang commited on
Commit
cb12676
1 Parent(s): 36cb39e

Upload stf_utils.py

Browse files
Files changed (1) hide show
  1. stf_utils.py +125 -0
stf_utils.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import os
3
+ from concurrent.futures import ThreadPoolExecutor
4
+ from pydub import AudioSegment
5
+ import cv2
6
+ from pathlib import Path
7
+ import subprocess
8
+ from pathlib import Path
9
+ import av
10
+ import imageio
11
+ import numpy as np
12
+ from rich.progress import track
13
+ from tqdm import tqdm
14
+
15
+ import stf_alternative
16
+
17
+
18
+
19
+ def exec_cmd(cmd):
20
+ subprocess.run(
21
+ cmd, shell=True, check=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT
22
+ )
23
+
24
+
25
+ def images2video(images, wfp, **kwargs):
26
+ fps = kwargs.get("fps", 24)
27
+ video_format = kwargs.get("format", "mp4") # default is mp4 format
28
+ codec = kwargs.get("codec", "libx264") # default is libx264 encoding
29
+ quality = kwargs.get("quality") # video quality
30
+ pixelformat = kwargs.get("pixelformat", "yuv420p") # video pixel format
31
+ image_mode = kwargs.get("image_mode", "rgb")
32
+ macro_block_size = kwargs.get("macro_block_size", 2)
33
+ ffmpeg_params = ["-crf", str(kwargs.get("crf", 18))]
34
+
35
+ writer = imageio.get_writer(
36
+ wfp,
37
+ fps=fps,
38
+ format=video_format,
39
+ codec=codec,
40
+ quality=quality,
41
+ ffmpeg_params=ffmpeg_params,
42
+ pixelformat=pixelformat,
43
+ macro_block_size=macro_block_size,
44
+ )
45
+
46
+ n = len(images)
47
+ for i in track(range(n), description="writing", transient=True):
48
+ if image_mode.lower() == "bgr":
49
+ writer.append_data(images[i][..., ::-1])
50
+ else:
51
+ writer.append_data(images[i])
52
+
53
+ writer.close()
54
+
55
+ # print(f':smiley: Dump to {wfp}\n', style="bold green")
56
+ print(f"Dump to {wfp}\n")
57
+
58
+
59
+ def merge_audio_video(video_fp, audio_fp, wfp):
60
+ if osp.exists(video_fp) and osp.exists(audio_fp):
61
+ cmd = f"ffmpeg -i {video_fp} -i {audio_fp} -c:v copy -c:a aac {wfp} -y"
62
+ exec_cmd(cmd)
63
+ print(f"merge {video_fp} and {audio_fp} to {wfp}")
64
+ else:
65
+ print(f"video_fp: {video_fp} or audio_fp: {audio_fp} not exists!")
66
+
67
+
68
+
69
+
70
+ class STFPipeline:
71
+ def __init__(self,
72
+ stf_path: str = "../stf/",
73
+ device: str = "cuda:0",
74
+ template_video_path: str = "templates/front_one_piece_dress_nodded_cut.webm",
75
+ config_path: str = "front_config.json",
76
+ checkpoint_path: str = "089.pth",
77
+ root_path: str = "works"
78
+
79
+ ):
80
+
81
+ config_path = os.path.join(stf_path, config_path)
82
+ checkpoint_path = os.path.join(stf_path, checkpoint_path)
83
+ work_root_path = os.path.join(stf_path, root_path)
84
+
85
+ model = stf_alternative.create_model(
86
+ config_path=config_path,
87
+ checkpoint_path=checkpoint_path,
88
+ work_root_path=work_root_path,
89
+ device=device,
90
+ wavlm_path="microsoft/wavlm-large",
91
+ )
92
+ self.template = stf_alternative.Template(
93
+ model=model,
94
+ config_path=config_path,
95
+ template_video_path=template_video_path,
96
+ )
97
+
98
+
99
+ def execute(self, audio: str):
100
+ Path("dubbing").mkdir(exist_ok=True)
101
+ save_path = os.path.join("dubbing", Path(audio).stem+"--lip.mp4")
102
+ reader = iter(self.template._get_reader(num_skip_frames=0))
103
+ audio_segment = AudioSegment.from_file(audio)
104
+ pivot = 0
105
+ results = []
106
+ with ThreadPoolExecutor(4) as p:
107
+ try:
108
+
109
+ gen_infer = self.template.gen_infer_concurrent(
110
+ p,
111
+ audio_segment,
112
+ pivot,
113
+ )
114
+ for idx, (it, chunk) in enumerate(gen_infer, pivot):
115
+ frame = next(reader)
116
+ composed = self.template.compose(idx, frame, it)
117
+ frame_name = f"{idx}".zfill(5)+".jpg"
118
+ results.append(it['pred'])
119
+ pivot = idx + 1
120
+ except StopIteration as e:
121
+ pass
122
+
123
+ images2video(results, save_path)
124
+
125
+ return save_path