aiqcamp committed (verified)
Commit 46cfad8 · 1 Parent(s): c388283

Update app.py

Files changed (1)
  app.py (+31 -8)
app.py CHANGED
@@ -67,15 +67,25 @@ output_dir = Path('./output/gradio')
 setup_eval_logging()
 net, feature_utils, seq_cfg = get_model()
 
-@spaces.GPU(duration=30)  # limit to 30 seconds
-@torch.inference_mode()
+
 def video_to_audio(video_path: str, prompt: str, negative_prompt: str = "music",
                    seed: int = -1, num_steps: int = 15,
-                   cfg_strength: float = 4.0, target_duration: float = 4.0):
+                   cfg_strength: float = 4.0, target_duration: float = None):  # target_duration is now optional
     try:
         logger.info("Starting audio generation process")
         torch.cuda.empty_cache()
 
+        # check the video duration
+        cap = cv2.VideoCapture(video_path)
+        fps = cap.get(cv2.CAP_PROP_FPS)
+        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+        video_duration = total_frames / fps
+        cap.release()
+
+        # use the actual video duration as target_duration
+        target_duration = video_duration
+        logger.info(f"Video duration: {target_duration} seconds")
+
         rng = torch.Generator(device=device)
         if seed >= 0:
             rng.manual_seed(seed)
@@ -84,8 +94,8 @@ def video_to_audio(video_path: str, prompt: str, negative_prompt: str = "music",
 
         fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=num_steps)
 
-        # fix the load_video call
-        video_info = load_video(video_path, duration_sec=target_duration)  # switched to the duration_sec parameter
+        # call load_video with the video's own duration
+        video_info = load_video(video_path, duration_sec=target_duration)
 
         if video_info is None:
             logger.error("Failed to load video")
@@ -99,16 +109,20 @@ def video_to_audio(video_path: str, prompt: str, negative_prompt: str = "music",
             logger.error("Failed to extract frames from video")
             return video_path
 
-        # memory optimization
+        # trim to the actual number of video frames
         clip_frames = clip_frames[:int(actual_duration * video_info.fps)]
         sync_frames = sync_frames[:int(actual_duration * video_info.fps)]
 
         clip_frames = clip_frames.unsqueeze(0).to(device, dtype=torch.float16)
         sync_frames = sync_frames.unsqueeze(0).to(device, dtype=torch.float16)
 
+        # update the sequence config
         seq_cfg.duration = actual_duration
         net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)
 
+        logger.info(f"Generating audio for {actual_duration} seconds...")
+
+
         logger.info("Generating audio...")
         with torch.cuda.amp.autocast():
             audios = generate(clip_frames,
@@ -356,6 +370,15 @@ def generate_video(image, prompt):
 
     final_path = add_watermark(output_path)
 
+    # check the video duration
+    cap = cv2.VideoCapture(final_path)
+    fps = cap.get(cv2.CAP_PROP_FPS)
+    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+    video_duration = total_frames / fps
+    cap.release()
+
+    logger.info(f"Original video duration: {video_duration} seconds")
+
     # add audio processing
     try:
         logger.info("Starting audio generation process")
@@ -365,8 +388,8 @@ def generate_video(image, prompt):
             negative_prompt="music",
             seed=-1,
             num_steps=20,
-            cfg_strength=4.5,
-            target_duration=6.0
+            cfg_strength=4.5
+            # target_duration removed - the video duration is used automatically
         )
 
         if final_path_with_audio != final_path:
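
The commit pastes the same OpenCV duration probe into both video_to_audio and generate_video. Below is a minimal sketch of factoring it into a shared helper; the name get_video_duration, the fallback value, and the logger setup are illustrative assumptions, not part of this commit.

import logging
import cv2

logger = logging.getLogger(__name__)

def get_video_duration(video_path: str, fallback: float = 4.0) -> float:
    # Read FPS and frame count from container metadata and derive the duration in seconds.
    cap = cv2.VideoCapture(video_path)
    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    finally:
        cap.release()
    if fps <= 0:
        # Some containers report no FPS; fall back instead of dividing by zero.
        logger.warning("Could not read FPS for %s; using fallback duration", video_path)
        return fallback
    return total_frames / fps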
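
With the duration now derived inside video_to_audio, the call in generate_video only passes generation parameters. A sketch of the call site implied by the last hunk; passing final_path as the first positional argument is an assumption based on the surrounding comparison with final_path_with_audio.

final_path_with_audio = video_to_audio(
    final_path,
    prompt=prompt,
    negative_prompt="music",
    seed=-1,
    num_steps=20,
    cfg_strength=4.5
    # target_duration is no longer passed; the function probes the video itself
)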