Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -67,15 +67,25 @@ output_dir = Path('./output/gradio')
|
|
67 |
setup_eval_logging()
|
68 |
net, feature_utils, seq_cfg = get_model()
|
69 |
|
70 |
-
|
71 |
-
@torch.inference_mode()
|
72 |
def video_to_audio(video_path: str, prompt: str, negative_prompt: str = "music",
|
73 |
seed: int = -1, num_steps: int = 15,
|
74 |
-
cfg_strength: float = 4.0, target_duration: float =
|
75 |
try:
|
76 |
logger.info("Starting audio generation process")
|
77 |
torch.cuda.empty_cache()
|
78 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
79 |
rng = torch.Generator(device=device)
|
80 |
if seed >= 0:
|
81 |
rng.manual_seed(seed)
|
@@ -84,8 +94,8 @@ def video_to_audio(video_path: str, prompt: str, negative_prompt: str = "music",
|
|
84 |
|
85 |
fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=num_steps)
|
86 |
|
87 |
-
# load_video
|
88 |
-
video_info = load_video(video_path, duration_sec=target_duration)
|
89 |
|
90 |
if video_info is None:
|
91 |
logger.error("Failed to load video")
|
@@ -99,16 +109,20 @@ def video_to_audio(video_path: str, prompt: str, negative_prompt: str = "music",
|
|
99 |
logger.error("Failed to extract frames from video")
|
100 |
return video_path
|
101 |
|
102 |
-
#
|
103 |
clip_frames = clip_frames[:int(actual_duration * video_info.fps)]
|
104 |
sync_frames = sync_frames[:int(actual_duration * video_info.fps)]
|
105 |
|
106 |
clip_frames = clip_frames.unsqueeze(0).to(device, dtype=torch.float16)
|
107 |
sync_frames = sync_frames.unsqueeze(0).to(device, dtype=torch.float16)
|
108 |
|
|
|
109 |
seq_cfg.duration = actual_duration
|
110 |
net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)
|
111 |
|
|
|
|
|
|
|
112 |
logger.info("Generating audio...")
|
113 |
with torch.cuda.amp.autocast():
|
114 |
audios = generate(clip_frames,
|
@@ -356,6 +370,15 @@ def generate_video(image, prompt):
|
|
356 |
|
357 |
final_path = add_watermark(output_path)
|
358 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
359 |
# 오디오 처리 추가
|
360 |
try:
|
361 |
logger.info("Starting audio generation process")
|
@@ -365,8 +388,8 @@ def generate_video(image, prompt):
|
|
365 |
negative_prompt="music",
|
366 |
seed=-1,
|
367 |
num_steps=20,
|
368 |
-
cfg_strength=4.5
|
369 |
-
target_duration
|
370 |
)
|
371 |
|
372 |
if final_path_with_audio != final_path:
|
|
|
67 |
setup_eval_logging()
|
68 |
net, feature_utils, seq_cfg = get_model()
|
69 |
|
70 |
+
|
|
|
71 |
def video_to_audio(video_path: str, prompt: str, negative_prompt: str = "music",
|
72 |
seed: int = -1, num_steps: int = 15,
|
73 |
+
cfg_strength: float = 4.0, target_duration: float = None): # target_duration을 선택적으로 변경
|
74 |
try:
|
75 |
logger.info("Starting audio generation process")
|
76 |
torch.cuda.empty_cache()
|
77 |
|
78 |
+
# 비디오 길이 확인
|
79 |
+
cap = cv2.VideoCapture(video_path)
|
80 |
+
fps = cap.get(cv2.CAP_PROP_FPS)
|
81 |
+
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
82 |
+
video_duration = total_frames / fps
|
83 |
+
cap.release()
|
84 |
+
|
85 |
+
# 실제 비디오 길이를 target_duration으로 사용
|
86 |
+
target_duration = video_duration
|
87 |
+
logger.info(f"Video duration: {target_duration} seconds")
|
88 |
+
|
89 |
rng = torch.Generator(device=device)
|
90 |
if seed >= 0:
|
91 |
rng.manual_seed(seed)
|
|
|
94 |
|
95 |
fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=num_steps)
|
96 |
|
97 |
+
# 비디오 길이에 맞춰 load_video 호출
|
98 |
+
video_info = load_video(video_path, duration_sec=target_duration)
|
99 |
|
100 |
if video_info is None:
|
101 |
logger.error("Failed to load video")
|
|
|
109 |
logger.error("Failed to extract frames from video")
|
110 |
return video_path
|
111 |
|
112 |
+
# 실제 비디오 프레임 수에 맞춰 조정
|
113 |
clip_frames = clip_frames[:int(actual_duration * video_info.fps)]
|
114 |
sync_frames = sync_frames[:int(actual_duration * video_info.fps)]
|
115 |
|
116 |
clip_frames = clip_frames.unsqueeze(0).to(device, dtype=torch.float16)
|
117 |
sync_frames = sync_frames.unsqueeze(0).to(device, dtype=torch.float16)
|
118 |
|
119 |
+
# sequence config 업데이트
|
120 |
seq_cfg.duration = actual_duration
|
121 |
net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)
|
122 |
|
123 |
+
logger.info(f"Generating audio for {actual_duration} seconds...")
|
124 |
+
|
125 |
+
|
126 |
logger.info("Generating audio...")
|
127 |
with torch.cuda.amp.autocast():
|
128 |
audios = generate(clip_frames,
|
|
|
370 |
|
371 |
final_path = add_watermark(output_path)
|
372 |
|
373 |
+
# 비디오 길이 확인
|
374 |
+
cap = cv2.VideoCapture(final_path)
|
375 |
+
fps = cap.get(cv2.CAP_PROP_FPS)
|
376 |
+
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
377 |
+
video_duration = total_frames / fps
|
378 |
+
cap.release()
|
379 |
+
|
380 |
+
logger.info(f"Original video duration: {video_duration} seconds")
|
381 |
+
|
382 |
# 오디오 처리 추가
|
383 |
try:
|
384 |
logger.info("Starting audio generation process")
|
|
|
388 |
negative_prompt="music",
|
389 |
seed=-1,
|
390 |
num_steps=20,
|
391 |
+
cfg_strength=4.5
|
392 |
+
# target_duration 제거 - 자동으로 비디오 길이 사용
|
393 |
)
|
394 |
|
395 |
if final_path_with_audio != final_path:
|