fffiloni committed on
Commit 6a8fc54 · verified · 1 Parent(s): d16d939

Update app.py

Files changed (1)
  1. app.py +355 -277
app.py CHANGED
@@ -1,277 +1,355 @@
- import os
- import random
- from pathlib import Path
- import numpy as np
- import torch
- from diffusers import AutoencoderKL, DDIMScheduler
- from PIL import Image
- from src.models.unet_2d_condition import UNet2DConditionModel
- from src.models.unet_3d_emo import EMOUNet3DConditionModel
- from src.models.whisper.audio2feature import load_audio_model
- from src.pipelines.pipeline_echomimicv2 import EchoMimicV2Pipeline
- from src.utils.util import save_videos_grid
- from src.models.pose_encoder import PoseEncoder
- from src.utils.dwpose_util import draw_pose_select_v2
- from moviepy.editor import VideoFileClip, AudioFileClip
-
- import gradio as gr
- from datetime import datetime
- from torchao.quantization import quantize_, int8_weight_only
- import gc
-
- total_vram_in_gb = torch.cuda.get_device_properties(0).total_memory / 1073741824
- print(f'\033[32mCUDA版本:{torch.version.cuda}\033[0m')
- print(f'\033[32mPytorch版本:{torch.__version__}\033[0m')
- print(f'\033[32m显卡型号:{torch.cuda.get_device_name()}\033[0m')
- print(f'\033[32m显存大小:{total_vram_in_gb:.2f}GB\033[0m')
- print(f'\033[32m精度:float16\033[0m')
- dtype = torch.float16
- if torch.cuda.is_available():
-     device = "cuda"
- else:
-     print("cuda not available, using cpu")
-     device = "cpu"
-
- ffmpeg_path = os.getenv('FFMPEG_PATH')
- if ffmpeg_path is None:
-     print("please download ffmpeg-static and export to FFMPEG_PATH. \nFor example: export FFMPEG_PATH=./ffmpeg-4.4-amd64-static")
- elif ffmpeg_path not in os.getenv('PATH'):
-     print("add ffmpeg to path")
-     os.environ["PATH"] = f"{ffmpeg_path}:{os.environ['PATH']}"
-
-
- def generate(image_input, audio_input, pose_input, width, height, length, steps, sample_rate, cfg, fps, context_frames, context_overlap, quantization_input, seed):
-     gc.collect()
-     torch.cuda.empty_cache()
-     torch.cuda.ipc_collect()
-     timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
-     save_dir = Path("outputs")
-     save_dir.mkdir(exist_ok=True, parents=True)
-
-     ############# model_init started #############
-     ## vae init
-     vae = AutoencoderKL.from_pretrained("./pretrained_weights/sd-vae-ft-mse").to(device, dtype=dtype)
-     if quantization_input:
-         quantize_(vae, int8_weight_only())
-         print("使用int8量化")
-
-     ## reference net init
-     reference_unet = UNet2DConditionModel.from_pretrained("./pretrained_weights/sd-image-variations-diffusers", subfolder="unet", use_safetensors=False).to(dtype=dtype, device=device)
-     reference_unet.load_state_dict(torch.load("./pretrained_weights/reference_unet.pth", weights_only=True))
-     if quantization_input:
-         quantize_(reference_unet, int8_weight_only())
-
-     ## denoising net init
-     if os.path.exists("./pretrained_weights/motion_module.pth"):
-         print('using motion module')
-     else:
-         exit("motion module not found")
-     ### stage1 + stage2
-     denoising_unet = EMOUNet3DConditionModel.from_pretrained_2d(
-         "./pretrained_weights/sd-image-variations-diffusers",
-         "./pretrained_weights/motion_module.pth",
-         subfolder="unet",
-         unet_additional_kwargs = {
-             "use_inflated_groupnorm": True,
-             "unet_use_cross_frame_attention": False,
-             "unet_use_temporal_attention": False,
-             "use_motion_module": True,
-             "cross_attention_dim": 384,
-             "motion_module_resolutions": [
-                 1,
-                 2,
-                 4,
-                 8
-             ],
-             "motion_module_mid_block": True,
-             "motion_module_decoder_only": False,
-             "motion_module_type": "Vanilla",
-             "motion_module_kwargs": {
-                 "num_attention_heads": 8,
-                 "num_transformer_block": 1,
-                 "attention_block_types": [
-                     'Temporal_Self',
-                     'Temporal_Self'
-                 ],
-                 "temporal_position_encoding": True,
-                 "temporal_position_encoding_max_len": 32,
-                 "temporal_attention_dim_div": 1,
-             }
-         },
-     ).to(dtype=dtype, device=device)
-     denoising_unet.load_state_dict(torch.load("./pretrained_weights/denoising_unet.pth", weights_only=True), strict=False)
-
-     # pose net init
-     pose_net = PoseEncoder(320, conditioning_channels=3, block_out_channels=(16, 32, 96, 256)).to(dtype=dtype, device=device)
-     pose_net.load_state_dict(torch.load("./pretrained_weights/pose_encoder.pth", weights_only=True))
-
-     ### load audio processor params
-     audio_processor = load_audio_model(model_path="./pretrained_weights/audio_processor/tiny.pt", device=device)
-
-     ############# model_init finished #############
-     sched_kwargs = {
-         "beta_start": 0.00085,
-         "beta_end": 0.012,
-         "beta_schedule": "linear",
-         "clip_sample": False,
-         "steps_offset": 1,
-         "prediction_type": "v_prediction",
-         "rescale_betas_zero_snr": True,
-         "timestep_spacing": "trailing"
-     }
-     scheduler = DDIMScheduler(**sched_kwargs)
-
-     pipe = EchoMimicV2Pipeline(
-         vae=vae,
-         reference_unet=reference_unet,
-         denoising_unet=denoising_unet,
-         audio_guider=audio_processor,
-         pose_encoder=pose_net,
-         scheduler=scheduler,
-     )
-
-     pipe = pipe.to(device, dtype=dtype)
-
-     if seed is not None and seed > -1:
-         generator = torch.manual_seed(seed)
-     else:
-         seed = random.randint(100, 1000000)
-         generator = torch.manual_seed(seed)
-
-     inputs_dict = {
-         "refimg": image_input,
-         "audio": audio_input,
-         "pose": pose_input,
-     }
-
-     print('Pose:', inputs_dict['pose'])
-     print('Reference:', inputs_dict['refimg'])
-     print('Audio:', inputs_dict['audio'])
-
-     save_name = f"{save_dir}/{timestamp}"
-
-     ref_image_pil = Image.open(inputs_dict['refimg']).resize((width, height))
-     audio_clip = AudioFileClip(inputs_dict['audio'])
-
-     length = min(length, int(audio_clip.duration * fps), len(os.listdir(inputs_dict['pose'])))
-
-     start_idx = 0
-
-     pose_list = []
-     for index in range(start_idx, start_idx + length):
-         tgt_musk = np.zeros((width, height, 3)).astype('uint8')
-         tgt_musk_path = os.path.join(inputs_dict['pose'], "{}.npy".format(index))
-         detected_pose = np.load(tgt_musk_path, allow_pickle=True).tolist()
-         imh_new, imw_new, rb, re, cb, ce = detected_pose['draw_pose_params']
-         im = draw_pose_select_v2(detected_pose, imh_new, imw_new, ref_w=800)
-         im = np.transpose(np.array(im),(1, 2, 0))
-         tgt_musk[rb:re,cb:ce,:] = im
-
-         tgt_musk_pil = Image.fromarray(np.array(tgt_musk)).convert('RGB')
-         pose_list.append(torch.Tensor(np.array(tgt_musk_pil)).to(dtype=dtype, device=device).permute(2,0,1) / 255.0)
-
-     poses_tensor = torch.stack(pose_list, dim=1).unsqueeze(0)
-     audio_clip = AudioFileClip(inputs_dict['audio'])
-
-     audio_clip = audio_clip.set_duration(length / fps)
-     video = pipe(
-         ref_image_pil,
-         inputs_dict['audio'],
-         poses_tensor[:,:,:length,...],
-         width,
-         height,
-         length,
-         steps,
-         cfg,
-         generator=generator,
-         audio_sample_rate=sample_rate,
-         context_frames=context_frames,
-         fps=fps,
-         context_overlap=context_overlap,
-         start_idx=start_idx,
-     ).videos
-
-     final_length = min(video.shape[2], poses_tensor.shape[2], length)
-     video_sig = video[:, :, :final_length, :, :]
-
-     save_videos_grid(
-         video_sig,
-         save_name + "_woa_sig.mp4",
-         n_rows=1,
-         fps=fps,
-     )
-
-     video_clip_sig = VideoFileClip(save_name + "_woa_sig.mp4",)
-     video_clip_sig = video_clip_sig.set_audio(audio_clip)
-     video_clip_sig.write_videofile(save_name + "_sig.mp4", codec="libx264", audio_codec="aac", threads=2)
-     video_output = save_name + "_sig.mp4"
-     seed_text = gr.update(visible=True, value=seed)
-     return video_output, seed_text
-
-
- with gr.Blocks(theme=gr.themes.Soft()) as demo:
-     gr.Markdown("""
-     <div>
-         <h2 style="font-size: 30px;text-align: center;">EchoMimicV2</h2>
-     </div>
-     <div style="text-align: center;">
-         <a href="https://github.com/antgroup/echomimic_v2">🌐 Github</a> |
-         <a href="https://arxiv.org/abs/2411.10061">📜 arXiv </a>
-     </div>
-     <div style="text-align: center; font-weight: bold; color: red;">
-         ⚠️ 该演示仅供学术研究和体验使用。
-     </div>
-
-     """)
-     with gr.Column():
-         with gr.Row():
-             with gr.Column():
-                 with gr.Group():
-                     image_input = gr.Image(label="图像输入(自动缩放)", type="filepath")
-                     audio_input = gr.Audio(label="音频输入", type="filepath")
-                     pose_input = gr.Textbox(label="姿态输入(目录地址)", placeholder="请输入姿态数据的目录地址", value="assets/halfbody_demo/pose/01")
-                 with gr.Group():
-                     with gr.Row():
-                         width = gr.Number(label="宽度(16的倍数,推荐768)", value=768)
-                         height = gr.Number(label="高度(16的倍数,推荐768)", value=768)
-                         length = gr.Number(label="视频长度,推荐240)", value=240)
-                     with gr.Row():
-                         steps = gr.Number(label="步骤(推荐30)", value=20)
-                         sample_rate = gr.Number(label="采样率(推荐16000)", value=16000)
-                         cfg = gr.Number(label="cfg(推荐2.5)", value=2.5, step=0.1)
-                     with gr.Row():
-                         fps = gr.Number(label="帧率(推荐24)", value=24)
-                         context_frames = gr.Number(label="上下文框架(推荐12)", value=12)
-                         context_overlap = gr.Number(label="上下文重叠(推荐3)", value=3)
-                     with gr.Row():
-                         quantization_input = gr.Checkbox(label="int8量化(推荐显存12G的用户开启,并使用不超过5秒的音频)", value=False)
-                         seed = gr.Number(label="种子(-1为随机)", value=-1)
-                 generate_button = gr.Button("🎬 生成视频")
-             with gr.Column():
-                 video_output = gr.Video(label="输出视频")
-                 seed_text = gr.Textbox(label="种子", interactive=False, visible=False)
-         gr.Examples(
-             examples=[
-                 ["EMTD_dataset/ref_imgs_by_FLUX/man/0001.png", "assets/halfbody_demo/audio/chinese/echomimicv2_man.wav"],
-                 ["EMTD_dataset/ref_imgs_by_FLUX/woman/0077.png", "assets/halfbody_demo/audio/chinese/echomimicv2_woman.wav"],
-                 ["EMTD_dataset/ref_imgs_by_FLUX/man/0003.png", "assets/halfbody_demo/audio/chinese/fighting.wav"],
-                 ["EMTD_dataset/ref_imgs_by_FLUX/woman/0033.png", "assets/halfbody_demo/audio/chinese/good.wav"],
-                 ["EMTD_dataset/ref_imgs_by_FLUX/man/0010.png", "assets/halfbody_demo/audio/chinese/news.wav"],
-                 ["EMTD_dataset/ref_imgs_by_FLUX/man/1168.png", "assets/halfbody_demo/audio/chinese/no_smoking.wav"],
-                 ["EMTD_dataset/ref_imgs_by_FLUX/woman/0057.png", "assets/halfbody_demo/audio/chinese/ultraman.wav"]
-             ],
-             inputs=[image_input, audio_input],
-             label="预设人物及音频",
-         )
-
-     generate_button.click(
-         generate,
-         inputs=[image_input, audio_input, pose_input, width, height, length, steps, sample_rate, cfg, fps, context_frames, context_overlap, quantization_input, seed],
-         outputs=[video_output, seed_text],
-     )
-
-
-
- if __name__ == "__main__":
-     demo.queue()
-     demo.launch(inbrowser=True)
+ import os
+ import random
+ from pathlib import Path
+ import numpy as np
+ import torch
+ from diffusers import AutoencoderKL, DDIMScheduler
+ from PIL import Image
+ from src.models.unet_2d_condition import UNet2DConditionModel
+ from src.models.unet_3d_emo import EMOUNet3DConditionModel
+ from src.models.whisper.audio2feature import load_audio_model
+ from src.pipelines.pipeline_echomimicv2 import EchoMimicV2Pipeline
+ from src.utils.util import save_videos_grid
+ from src.models.pose_encoder import PoseEncoder
+ from src.utils.dwpose_util import draw_pose_select_v2
+ from moviepy.editor import VideoFileClip, AudioFileClip
+
+ import gradio as gr
+ from datetime import datetime
+ from torchao.quantization import quantize_, int8_weight_only
+ import gc
+
+ import requests
+ import tarfile
+
+ def download_and_setup_ffmpeg():
+     url = "https://www.johnvansickle.com/ffmpeg/old-releases/ffmpeg-4.4-amd64-static.tar.xz"
+     download_path = "ffmpeg-4.4-amd64-static.tar.xz"
+     extract_dir = "ffmpeg-4.4-amd64-static"
+
+     try:
+         # Download the file
+         response = requests.get(url, stream=True)
+         response.raise_for_status() # Check for HTTP request errors
+         with open(download_path, "wb") as file:
+             for chunk in response.iter_content(chunk_size=8192):
+                 file.write(chunk)
+
+         # Extract the tar.xz file
+         with tarfile.open(download_path, "r:xz") as tar:
+             tar.extractall(path=extract_dir)
+
+         # Set the FFMPEG_PATH environment variable
+         ffmpeg_binary_path = os.path.join(extract_dir, "ffmpeg-4.4-amd64-static", "ffmpeg")
+         os.environ["FFMPEG_PATH"] = ffmpeg_binary_path
+
+         return f"FFmpeg downloaded and setup successfully! Path: {ffmpeg_binary_path}"
+     except Exception as e:
+         return f"An error occurred: {str(e)}"
+
+ download_and_setup_ffmpeg()
+
+ from huggingface_hub import snapshot_download
+
+ # Create the main "pretrained_weights" folder
+ os.makedirs("pretrained_weights", exist_ok=True)
+
+ # List of subdirectories to create inside "pretrained_weights"
+ subfolders = [
+     "sd-vae-ft-mse",
+     "sd-image-variations-diffusers",
+     "audio_processor"
+ ]
+
+ # Create each subdirectory
+ for subfolder in subfolders:
+     os.makedirs(os.path.join("pretrained_weights", subfolder), exist_ok=True)
+
+ snapshot_download(
+     repo_id = "BadToBest/EchoMimicV2",
+     local_dir="./pretrained_weights"
+ )
+ snapshot_download(
+     repo_id = "stabilityai/sd-vae-ft-mse",
+     local_dir="./pretrained_weights/sd-vae-ft-mse"
+ )
+ snapshot_download(
+     repo_id = "lambdalabs/sd-image-variations-diffusers",
+     local_dir="./pretrained_weights/sd-image-variations-diffusers"
+ )
+
+ # Download and place the Whisper model in the "audio_processor" folder
+ def download_whisper_model():
+     url = "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt"
+     save_path = os.path.join("pretrained_weights", "audio_processor", "tiny.pt")
+
+     try:
+         # Download the file
+         response = requests.get(url, stream=True)
+         response.raise_for_status() # Check for HTTP request errors
+         with open(save_path, "wb") as file:
+             for chunk in response.iter_content(chunk_size=8192):
+                 file.write(chunk)
+         print(f"Whisper model downloaded and saved to {save_path}")
+     except Exception as e:
+         print(f"An error occurred while downloading the model: {str(e)}")
+
+ # Download the Whisper model
+ download_whisper_model()
+
+ total_vram_in_gb = torch.cuda.get_device_properties(0).total_memory / 1073741824
+ print(f'\033[32mCUDA版本:{torch.version.cuda}\033[0m')
+ print(f'\033[32mPytorch版本:{torch.__version__}\033[0m')
+ print(f'\033[32m显卡型号:{torch.cuda.get_device_name()}\033[0m')
+ print(f'\033[32m显存大小:{total_vram_in_gb:.2f}GB\033[0m')
+ print(f'\033[32m精度:float16\033[0m')
+ dtype = torch.float16
+ if torch.cuda.is_available():
+     device = "cuda"
+ else:
+     print("cuda not available, using cpu")
+     device = "cpu"
+
+ ffmpeg_path = os.getenv('FFMPEG_PATH')
+ if ffmpeg_path is None:
+     print("please download ffmpeg-static and export to FFMPEG_PATH. \nFor example: export FFMPEG_PATH=./ffmpeg-4.4-amd64-static")
+ elif ffmpeg_path not in os.getenv('PATH'):
+     print("add ffmpeg to path")
+     os.environ["PATH"] = f"{ffmpeg_path}:{os.environ['PATH']}"
+
+
+ def generate(image_input, audio_input, pose_input, width, height, length, steps, sample_rate, cfg, fps, context_frames, context_overlap, quantization_input, seed):
+     gc.collect()
+     torch.cuda.empty_cache()
+     torch.cuda.ipc_collect()
+     timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+     save_dir = Path("outputs")
+     save_dir.mkdir(exist_ok=True, parents=True)
+
+     ############# model_init started #############
+     ## vae init
+     vae = AutoencoderKL.from_pretrained("./pretrained_weights/sd-vae-ft-mse").to(device, dtype=dtype)
+     if quantization_input:
+         quantize_(vae, int8_weight_only())
+         print("使用int8量化")
+
+     ## reference net init
+     reference_unet = UNet2DConditionModel.from_pretrained("./pretrained_weights/sd-image-variations-diffusers", subfolder="unet", use_safetensors=False).to(dtype=dtype, device=device)
+     reference_unet.load_state_dict(torch.load("./pretrained_weights/reference_unet.pth", weights_only=True))
+     if quantization_input:
+         quantize_(reference_unet, int8_weight_only())
+
+     ## denoising net init
+     if os.path.exists("./pretrained_weights/motion_module.pth"):
+         print('using motion module')
+     else:
+         exit("motion module not found")
+     ### stage1 + stage2
+     denoising_unet = EMOUNet3DConditionModel.from_pretrained_2d(
+         "./pretrained_weights/sd-image-variations-diffusers",
+         "./pretrained_weights/motion_module.pth",
+         subfolder="unet",
+         unet_additional_kwargs = {
+             "use_inflated_groupnorm": True,
+             "unet_use_cross_frame_attention": False,
+             "unet_use_temporal_attention": False,
+             "use_motion_module": True,
+             "cross_attention_dim": 384,
+             "motion_module_resolutions": [
+                 1,
+                 2,
+                 4,
+                 8
+             ],
+             "motion_module_mid_block": True,
+             "motion_module_decoder_only": False,
+             "motion_module_type": "Vanilla",
+             "motion_module_kwargs": {
+                 "num_attention_heads": 8,
+                 "num_transformer_block": 1,
+                 "attention_block_types": [
+                     'Temporal_Self',
+                     'Temporal_Self'
+                 ],
+                 "temporal_position_encoding": True,
+                 "temporal_position_encoding_max_len": 32,
+                 "temporal_attention_dim_div": 1,
+             }
+         },
+     ).to(dtype=dtype, device=device)
+     denoising_unet.load_state_dict(torch.load("./pretrained_weights/denoising_unet.pth", weights_only=True), strict=False)
+
+     # pose net init
+     pose_net = PoseEncoder(320, conditioning_channels=3, block_out_channels=(16, 32, 96, 256)).to(dtype=dtype, device=device)
+     pose_net.load_state_dict(torch.load("./pretrained_weights/pose_encoder.pth", weights_only=True))
+
+     ### load audio processor params
+     audio_processor = load_audio_model(model_path="./pretrained_weights/audio_processor/tiny.pt", device=device)
+
+     ############# model_init finished #############
+     sched_kwargs = {
+         "beta_start": 0.00085,
+         "beta_end": 0.012,
+         "beta_schedule": "linear",
+         "clip_sample": False,
+         "steps_offset": 1,
+         "prediction_type": "v_prediction",
+         "rescale_betas_zero_snr": True,
+         "timestep_spacing": "trailing"
+     }
+     scheduler = DDIMScheduler(**sched_kwargs)
+
+     pipe = EchoMimicV2Pipeline(
+         vae=vae,
+         reference_unet=reference_unet,
+         denoising_unet=denoising_unet,
+         audio_guider=audio_processor,
+         pose_encoder=pose_net,
+         scheduler=scheduler,
+     )
+
+     pipe = pipe.to(device, dtype=dtype)
+
+     if seed is not None and seed > -1:
+         generator = torch.manual_seed(seed)
+     else:
+         seed = random.randint(100, 1000000)
+         generator = torch.manual_seed(seed)
+
+     inputs_dict = {
+         "refimg": image_input,
+         "audio": audio_input,
+         "pose": pose_input,
+     }
+
+     print('Pose:', inputs_dict['pose'])
+     print('Reference:', inputs_dict['refimg'])
+     print('Audio:', inputs_dict['audio'])
+
+     save_name = f"{save_dir}/{timestamp}"
+
+     ref_image_pil = Image.open(inputs_dict['refimg']).resize((width, height))
+     audio_clip = AudioFileClip(inputs_dict['audio'])
+
+     length = min(length, int(audio_clip.duration * fps), len(os.listdir(inputs_dict['pose'])))
+
+     start_idx = 0
+
+     pose_list = []
+     for index in range(start_idx, start_idx + length):
+         tgt_musk = np.zeros((width, height, 3)).astype('uint8')
+         tgt_musk_path = os.path.join(inputs_dict['pose'], "{}.npy".format(index))
+         detected_pose = np.load(tgt_musk_path, allow_pickle=True).tolist()
+         imh_new, imw_new, rb, re, cb, ce = detected_pose['draw_pose_params']
+         im = draw_pose_select_v2(detected_pose, imh_new, imw_new, ref_w=800)
+         im = np.transpose(np.array(im),(1, 2, 0))
+         tgt_musk[rb:re,cb:ce,:] = im
+
+         tgt_musk_pil = Image.fromarray(np.array(tgt_musk)).convert('RGB')
+         pose_list.append(torch.Tensor(np.array(tgt_musk_pil)).to(dtype=dtype, device=device).permute(2,0,1) / 255.0)
+
+     poses_tensor = torch.stack(pose_list, dim=1).unsqueeze(0)
+     audio_clip = AudioFileClip(inputs_dict['audio'])
+
+     audio_clip = audio_clip.set_duration(length / fps)
+     video = pipe(
+         ref_image_pil,
+         inputs_dict['audio'],
+         poses_tensor[:,:,:length,...],
+         width,
+         height,
+         length,
+         steps,
+         cfg,
+         generator=generator,
+         audio_sample_rate=sample_rate,
+         context_frames=context_frames,
+         fps=fps,
+         context_overlap=context_overlap,
+         start_idx=start_idx,
+     ).videos
+
+     final_length = min(video.shape[2], poses_tensor.shape[2], length)
+     video_sig = video[:, :, :final_length, :, :]
+
+     save_videos_grid(
+         video_sig,
+         save_name + "_woa_sig.mp4",
+         n_rows=1,
+         fps=fps,
+     )
+
+     video_clip_sig = VideoFileClip(save_name + "_woa_sig.mp4",)
+     video_clip_sig = video_clip_sig.set_audio(audio_clip)
+     video_clip_sig.write_videofile(save_name + "_sig.mp4", codec="libx264", audio_codec="aac", threads=2)
+     video_output = save_name + "_sig.mp4"
+     seed_text = gr.update(visible=True, value=seed)
+     return video_output, seed_text
+
+
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
+     gr.Markdown("""
+     <div>
+         <h2 style="font-size: 30px;text-align: center;">EchoMimicV2</h2>
+     </div>
+     <div style="text-align: center;">
+         <a href="https://github.com/antgroup/echomimic_v2">🌐 Github</a> |
+         <a href="https://arxiv.org/abs/2411.10061">📜 arXiv </a>
+     </div>
+     <div style="text-align: center; font-weight: bold; color: red;">
+         ⚠️ 该演示仅供学术研究和体验使用。
+     </div>
+
+     """)
+     with gr.Column():
+         with gr.Row():
+             with gr.Column():
+                 with gr.Group():
+                     image_input = gr.Image(label="图像输入(自动缩放)", type="filepath")
+                     audio_input = gr.Audio(label="音频输入", type="filepath")
+                     pose_input = gr.Textbox(label="姿态输入(目录地址)", placeholder="请输入姿态数据的目录地址", value="assets/halfbody_demo/pose/01")
+                 with gr.Group():
+                     with gr.Row():
+                         width = gr.Number(label="宽度(16的倍数,推荐768)", value=768)
+                         height = gr.Number(label="高度(16的倍数,推荐768)", value=768)
+                         length = gr.Number(label="视频长度,推荐240)", value=240)
+                     with gr.Row():
+                         steps = gr.Number(label="步骤(推荐30)", value=20)
+                         sample_rate = gr.Number(label="采样率(推荐16000)", value=16000)
+                         cfg = gr.Number(label="cfg(推荐2.5)", value=2.5, step=0.1)
+                     with gr.Row():
+                         fps = gr.Number(label="帧率(推荐24)", value=24)
+                         context_frames = gr.Number(label="上下文框架(推荐12)", value=12)
+                         context_overlap = gr.Number(label="上下文重叠(推荐3)", value=3)
+                     with gr.Row():
+                         quantization_input = gr.Checkbox(label="int8量化(推荐显存12G的用户开启,并使用不超过5秒的音频)", value=False)
+                         seed = gr.Number(label="种子(-1为随机)", value=-1)
+                 generate_button = gr.Button("🎬 生成视频")
+             with gr.Column():
+                 video_output = gr.Video(label="输出视频")
+                 seed_text = gr.Textbox(label="种子", interactive=False, visible=False)
+         gr.Examples(
+             examples=[
+                 ["EMTD_dataset/ref_imgs_by_FLUX/man/0001.png", "assets/halfbody_demo/audio/chinese/echomimicv2_man.wav"],
+                 ["EMTD_dataset/ref_imgs_by_FLUX/woman/0077.png", "assets/halfbody_demo/audio/chinese/echomimicv2_woman.wav"],
+                 ["EMTD_dataset/ref_imgs_by_FLUX/man/0003.png", "assets/halfbody_demo/audio/chinese/fighting.wav"],
+                 ["EMTD_dataset/ref_imgs_by_FLUX/woman/0033.png", "assets/halfbody_demo/audio/chinese/good.wav"],
+                 ["EMTD_dataset/ref_imgs_by_FLUX/man/0010.png", "assets/halfbody_demo/audio/chinese/news.wav"],
+                 ["EMTD_dataset/ref_imgs_by_FLUX/man/1168.png", "assets/halfbody_demo/audio/chinese/no_smoking.wav"],
+                 ["EMTD_dataset/ref_imgs_by_FLUX/woman/0057.png", "assets/halfbody_demo/audio/chinese/ultraman.wav"]
+             ],
+             inputs=[image_input, audio_input],
+             label="预设人物及音频",
+         )
+
+     generate_button.click(
+         generate,
+         inputs=[image_input, audio_input, pose_input, width, height, length, steps, sample_rate, cfg, fps, context_frames, context_overlap, quantization_input, seed],
+         outputs=[video_output, seed_text],
+     )
+
+
+
+ if __name__ == "__main__":
+     demo.queue()
+     demo.launch(inbrowser=True)
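
The new startup code fetches a static ffmpeg build, the EchoMimicV2 weights, and the Whisper tiny checkpoint before the UI launches. Below is a minimal post-setup sanity check; it only assumes paths that already appear in this app.py, and it invokes ffmpeg via FFMPEG_PATH directly because download_and_setup_ffmpeg() points that variable at the binary itself. The check_setup helper is an illustrative sketch, not part of the commit.

```python
import os
import subprocess

# Files that generate() loads at request time; the snapshot_download of
# BadToBest/EchoMimicV2 is expected to place the .pth files here.
REQUIRED_FILES = [
    "pretrained_weights/reference_unet.pth",
    "pretrained_weights/denoising_unet.pth",
    "pretrained_weights/motion_module.pth",
    "pretrained_weights/pose_encoder.pth",
    "pretrained_weights/audio_processor/tiny.pt",
]

def check_setup():
    # download_and_setup_ffmpeg() sets FFMPEG_PATH to the ffmpeg binary itself,
    # so it can be executed directly here.
    ffmpeg = os.environ.get("FFMPEG_PATH", "ffmpeg")
    try:
        subprocess.run([ffmpeg, "-version"], check=True, capture_output=True)
        print(f"ffmpeg OK: {ffmpeg}")
    except (OSError, subprocess.CalledProcessError) as err:
        print(f"ffmpeg check failed: {err}")

    missing = [p for p in REQUIRED_FILES if not os.path.exists(p)]
    print("missing weight files:", missing if missing else "none")

if __name__ == "__main__":
    check_setup()
```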
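The Whisper download URL embeds the checkpoint's SHA-256 in its path. If download_whisper_model() ever needs hardening, one optional follow-up (an assumption, not something this commit adds) is to compare the saved file against that digest:

```python
import hashlib

def verify_whisper_checkpoint(
    path="pretrained_weights/audio_processor/tiny.pt",
    expected_sha256="65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9",
):
    # Hash in 1 MiB chunks so the checkpoint never has to fit in memory at once.
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            digest.update(chunk)
    ok = digest.hexdigest() == expected_sha256
    print(f"{path}: {'checksum OK' if ok else 'checksum mismatch'}")
    return ok
```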
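The three snapshot_download calls mirror whole repositories into ./pretrained_weights. If bandwidth or disk is a concern, huggingface_hub supports narrowing a snapshot with allow_patterns; the patterns below are an assumption about the EchoMimicV2 repository layout and would need to be checked against its actual file list:

```python
from huggingface_hub import snapshot_download

# Hypothetical narrowed download: fetch only the checkpoint files the app loads.
snapshot_download(
    repo_id="BadToBest/EchoMimicV2",
    local_dir="./pretrained_weights",
    allow_patterns=["*.pth", "audio_processor/*"],
)
```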