Spaces:
Runtime error
Runtime error
Rex Cheng
commited on
Commit
•
b0ec3f5
1
Parent(s):
164c335
test
Browse files- app.py +5 -6
- demo.py +9 -9
- mmaudio/eval_utils.py +20 -58
app.py
CHANGED
@@ -67,7 +67,10 @@ def video_to_audio(video: gr.Video, prompt: str, negative_prompt: str, seed: int
|
|
67 |
rng.manual_seed(seed)
|
68 |
fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=num_steps)
|
69 |
|
70 |
-
|
|
|
|
|
|
|
71 |
clip_frames = clip_frames.unsqueeze(0)
|
72 |
sync_frames = sync_frames.unsqueeze(0)
|
73 |
seq_cfg.duration = duration
|
@@ -87,11 +90,7 @@ def video_to_audio(video: gr.Video, prompt: str, negative_prompt: str, seed: int
|
|
87 |
video_save_path = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4').name
|
88 |
# output_dir.mkdir(exist_ok=True, parents=True)
|
89 |
# video_save_path = output_dir / f'{current_time_string}.mp4'
|
90 |
-
make_video(
|
91 |
-
video_save_path,
|
92 |
-
audio,
|
93 |
-
sampling_rate=seq_cfg.sampling_rate,
|
94 |
-
duration_sec=seq_cfg.duration)
|
95 |
log.info(f'Saved video to {video_save_path}')
|
96 |
return video_save_path
|
97 |
|
|
|
67 |
rng.manual_seed(seed)
|
68 |
fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=num_steps)
|
69 |
|
70 |
+
video_info = load_video(video, duration)
|
71 |
+
clip_frames = video_info.clip_frames
|
72 |
+
sync_frames = video_info.sync_frames
|
73 |
+
duration = video_info.duration_sec
|
74 |
clip_frames = clip_frames.unsqueeze(0)
|
75 |
sync_frames = sync_frames.unsqueeze(0)
|
76 |
seq_cfg.duration = duration
|
|
|
90 |
video_save_path = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4').name
|
91 |
# output_dir.mkdir(exist_ok=True, parents=True)
|
92 |
# video_save_path = output_dir / f'{current_time_string}.mp4'
|
93 |
+
make_video(video_info, video_save_path, audio, sampling_rate=seq_cfg.sampling_rate)
|
|
|
|
|
|
|
|
|
94 |
log.info(f'Saved video to {video_save_path}')
|
95 |
return video_save_path
|
96 |
|
demo.py
CHANGED
@@ -5,8 +5,8 @@ from pathlib import Path
|
|
5 |
import torch
|
6 |
import torchaudio
|
7 |
|
8 |
-
from mmaudio.eval_utils import (ModelConfig, all_model_cfg, generate,
|
9 |
-
|
10 |
from mmaudio.model.flow_matching import FlowMatching
|
11 |
from mmaudio.model.networks import MMAudio, get_my_mmaudio
|
12 |
from mmaudio.model.utils.features_utils import FeaturesUtils
|
@@ -81,12 +81,16 @@ def main():
|
|
81 |
synchformer_ckpt=model.synchformer_ckpt,
|
82 |
enable_conditions=True,
|
83 |
mode=model.mode,
|
84 |
-
bigvgan_vocoder_ckpt=model.bigvgan_16k_path
|
|
|
85 |
feature_utils = feature_utils.to(device, dtype).eval()
|
86 |
|
87 |
if video_path is not None:
|
88 |
log.info(f'Using video {video_path}')
|
89 |
-
|
|
|
|
|
|
|
90 |
if mask_away_clip:
|
91 |
clip_frames = None
|
92 |
else:
|
@@ -121,11 +125,7 @@ def main():
|
|
121 |
log.info(f'Audio saved to {save_path}')
|
122 |
if video_path is not None and not skip_video_composite:
|
123 |
video_save_path = output_dir / f'{video_path.stem}.mp4'
|
124 |
-
make_video(
|
125 |
-
video_save_path,
|
126 |
-
audio,
|
127 |
-
sampling_rate=seq_cfg.sampling_rate,
|
128 |
-
duration_sec=seq_cfg.duration)
|
129 |
log.info(f'Video saved to {output_dir / video_save_path}')
|
130 |
|
131 |
log.info('Memory usage: %.2f GB', torch.cuda.max_memory_allocated() / (2**30))
|
|
|
5 |
import torch
|
6 |
import torchaudio
|
7 |
|
8 |
+
from mmaudio.eval_utils import (ModelConfig, all_model_cfg, generate, load_video, make_video,
|
9 |
+
setup_eval_logging)
|
10 |
from mmaudio.model.flow_matching import FlowMatching
|
11 |
from mmaudio.model.networks import MMAudio, get_my_mmaudio
|
12 |
from mmaudio.model.utils.features_utils import FeaturesUtils
|
|
|
81 |
synchformer_ckpt=model.synchformer_ckpt,
|
82 |
enable_conditions=True,
|
83 |
mode=model.mode,
|
84 |
+
bigvgan_vocoder_ckpt=model.bigvgan_16k_path,
|
85 |
+
need_vae_encoder=False)
|
86 |
feature_utils = feature_utils.to(device, dtype).eval()
|
87 |
|
88 |
if video_path is not None:
|
89 |
log.info(f'Using video {video_path}')
|
90 |
+
video_info = load_video(video_path, duration)
|
91 |
+
clip_frames = video_info.clip_frames
|
92 |
+
sync_frames = video_info.sync_frames
|
93 |
+
duration = video_info.duration_sec
|
94 |
if mask_away_clip:
|
95 |
clip_frames = None
|
96 |
else:
|
|
|
125 |
log.info(f'Audio saved to {save_path}')
|
126 |
if video_path is not None and not skip_video_composite:
|
127 |
video_save_path = output_dir / f'{video_path.stem}.mp4'
|
128 |
+
make_video(video_info, video_save_path, audio, sampling_rate=seq_cfg.sampling_rate)
|
|
|
|
|
|
|
|
|
129 |
log.info(f'Video saved to {output_dir / video_save_path}')
|
130 |
|
131 |
log.info('Memory usage: %.2f GB', torch.cuda.max_memory_allocated() / (2**30))
|
mmaudio/eval_utils.py
CHANGED
@@ -3,12 +3,11 @@ import logging
|
|
3 |
from pathlib import Path
|
4 |
from typing import Optional
|
5 |
|
6 |
-
import av
|
7 |
import torch
|
8 |
from colorlog import ColoredFormatter
|
9 |
from torchvision.transforms import v2
|
10 |
-
from torio.io import StreamingMediaDecoder, StreamingMediaEncoder
|
11 |
|
|
|
12 |
from mmaudio.model.flow_matching import FlowMatching
|
13 |
from mmaudio.model.networks import MMAudio
|
14 |
from mmaudio.model.sequence_config import (CONFIG_16K, CONFIG_44K, SequenceConfig)
|
@@ -154,7 +153,7 @@ def setup_eval_logging(log_level: int = logging.INFO):
|
|
154 |
log.addHandler(stream)
|
155 |
|
156 |
|
157 |
-
def load_video(video_path: Path, duration_sec: float) ->
|
158 |
_CLIP_SIZE = 384
|
159 |
_CLIP_FPS = 8.0
|
160 |
|
@@ -175,26 +174,15 @@ def load_video(video_path: Path, duration_sec: float) -> tuple[torch.Tensor, tor
|
|
175 |
v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
|
176 |
])
|
177 |
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
format='rgb24',
|
184 |
-
)
|
185 |
-
reader.add_basic_video_stream(
|
186 |
-
frames_per_chunk=int(_SYNC_FPS * duration_sec),
|
187 |
-
buffer_chunk_size=-1,
|
188 |
-
frame_rate=_SYNC_FPS,
|
189 |
-
format='rgb24',
|
190 |
-
)
|
191 |
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
-
sync_chunk = data_chunk[1]
|
196 |
-
assert clip_chunk is not None
|
197 |
-
assert sync_chunk is not None
|
198 |
|
199 |
clip_frames = clip_transform(clip_chunk)
|
200 |
sync_frames = sync_transform(sync_chunk)
|
@@ -215,41 +203,15 @@ def load_video(video_path: Path, duration_sec: float) -> tuple[torch.Tensor, tor
|
|
215 |
clip_frames = clip_frames[:int(_CLIP_FPS * duration_sec)]
|
216 |
sync_frames = sync_frames[:int(_SYNC_FPS * duration_sec)]
|
217 |
|
218 |
-
|
219 |
-
|
220 |
-
|
221 |
-
|
222 |
-
|
|
|
|
|
|
|
223 |
|
224 |
-
av_video = av.open(video_path)
|
225 |
-
frame_rate = av_video.streams.video[0].guessed_rate
|
226 |
|
227 |
-
|
228 |
-
|
229 |
-
reader.add_basic_video_stream(
|
230 |
-
frames_per_chunk=approx_max_length,
|
231 |
-
buffer_chunk_size=-1,
|
232 |
-
format='rgb24',
|
233 |
-
)
|
234 |
-
reader.fill_buffer()
|
235 |
-
video_chunk = reader.pop_chunks()[0]
|
236 |
-
assert video_chunk is not None
|
237 |
-
|
238 |
-
h, w = video_chunk.shape[-2:]
|
239 |
-
video_chunk = video_chunk[:int(frame_rate * duration_sec)]
|
240 |
-
|
241 |
-
writer = StreamingMediaEncoder(output_path)
|
242 |
-
writer.add_audio_stream(
|
243 |
-
sample_rate=sampling_rate,
|
244 |
-
num_channels=audio.shape[0],
|
245 |
-
encoder='aac', # 'flac' does not work for some reason?
|
246 |
-
)
|
247 |
-
writer.add_video_stream(frame_rate=frame_rate,
|
248 |
-
width=w,
|
249 |
-
height=h,
|
250 |
-
format='rgb24',
|
251 |
-
encoder='libx264',
|
252 |
-
encoder_format='yuv420p')
|
253 |
-
with writer.open():
|
254 |
-
writer.write_audio_chunk(0, audio.float().transpose(0, 1))
|
255 |
-
writer.write_video_chunk(1, video_chunk)
|
|
|
3 |
from pathlib import Path
|
4 |
from typing import Optional
|
5 |
|
|
|
6 |
import torch
|
7 |
from colorlog import ColoredFormatter
|
8 |
from torchvision.transforms import v2
|
|
|
9 |
|
10 |
+
from mmaudio.data.av_utils import VideoInfo, read_frames, reencode_with_audio
|
11 |
from mmaudio.model.flow_matching import FlowMatching
|
12 |
from mmaudio.model.networks import MMAudio
|
13 |
from mmaudio.model.sequence_config import (CONFIG_16K, CONFIG_44K, SequenceConfig)
|
|
|
153 |
log.addHandler(stream)
|
154 |
|
155 |
|
156 |
+
def load_video(video_path: Path, duration_sec: float, load_all_frames: bool = True) -> VideoInfo:
|
157 |
_CLIP_SIZE = 384
|
158 |
_CLIP_FPS = 8.0
|
159 |
|
|
|
174 |
v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
|
175 |
])
|
176 |
|
177 |
+
output_frames, all_frames, orig_fps = read_frames(video_path,
|
178 |
+
list_of_fps=[_CLIP_FPS, _SYNC_FPS],
|
179 |
+
start_sec=0,
|
180 |
+
end_sec=duration_sec,
|
181 |
+
need_all_frames=load_all_frames)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
182 |
|
183 |
+
clip_chunk, sync_chunk = output_frames
|
184 |
+
clip_chunk = torch.from_numpy(clip_chunk).permute(0, 3, 1, 2)
|
185 |
+
sync_chunk = torch.from_numpy(sync_chunk).permute(0, 3, 1, 2)
|
|
|
|
|
|
|
186 |
|
187 |
clip_frames = clip_transform(clip_chunk)
|
188 |
sync_frames = sync_transform(sync_chunk)
|
|
|
203 |
clip_frames = clip_frames[:int(_CLIP_FPS * duration_sec)]
|
204 |
sync_frames = sync_frames[:int(_SYNC_FPS * duration_sec)]
|
205 |
|
206 |
+
video_info = VideoInfo(
|
207 |
+
duration_sec=duration_sec,
|
208 |
+
fps=orig_fps,
|
209 |
+
clip_frames=clip_frames,
|
210 |
+
sync_frames=sync_frames,
|
211 |
+
all_frames=all_frames if load_all_frames else None,
|
212 |
+
)
|
213 |
+
return video_info
|
214 |
|
|
|
|
|
215 |
|
216 |
+
def make_video(video_info: VideoInfo, output_path: Path, audio: torch.Tensor, sampling_rate: int):
|
217 |
+
reencode_with_audio(video_info, output_path, audio, sampling_rate)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|