aiqcamp commited on
Commit
7a8cebd
·
verified ·
1 Parent(s): e76dd8a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -17
app.py CHANGED
@@ -65,11 +65,11 @@ output_dir = Path('./output/gradio')
65
  setup_eval_logging()
66
  net, feature_utils, seq_cfg = get_model()
67
 
68
- @spaces.GPU(duration=60)
69
  @torch.inference_mode()
70
  def video_to_audio(video_path: str, prompt: str, negative_prompt: str = "music",
71
- seed: int = -1, num_steps: int = 20,
72
- cfg_strength: float = 4.5, target_duration: float = 6.0):
73
  try:
74
  logger.info("Starting audio generation process")
75
  torch.cuda.empty_cache()
@@ -83,16 +83,12 @@ def video_to_audio(video_path: str, prompt: str, negative_prompt: str = "music",
83
  fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=num_steps)
84
 
85
  # load_video 함수 호출 수정
86
- video_info = load_video(video_path) # duration 파라미터 제거
87
 
88
  if video_info is None:
89
  logger.error("Failed to load video")
90
  return video_path
91
 
92
- # 비디오 길이 조정이 필요한 경우 여기서 처리
93
- if hasattr(video_info, 'set_duration'):
94
- video_info.set_duration(target_duration)
95
-
96
  clip_frames = video_info.clip_frames
97
  sync_frames = video_info.sync_frames
98
  actual_duration = video_info.duration_sec
@@ -101,6 +97,10 @@ def video_to_audio(video_path: str, prompt: str, negative_prompt: str = "music",
101
  logger.error("Failed to extract frames from video")
102
  return video_path
103
 
 
 
 
 
104
  clip_frames = clip_frames.unsqueeze(0).to(device, dtype=torch.float16)
105
  sync_frames = sync_frames.unsqueeze(0).to(device, dtype=torch.float16)
106
 
@@ -108,15 +108,16 @@ def video_to_audio(video_path: str, prompt: str, negative_prompt: str = "music",
108
  net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)
109
 
110
  logger.info("Generating audio...")
111
- audios = generate(clip_frames,
112
- sync_frames,
113
- [prompt],
114
- negative_text=[negative_prompt],
115
- feature_utils=feature_utils,
116
- net=net,
117
- fm=fm,
118
- rng=rng,
119
- cfg_strength=cfg_strength)
 
120
 
121
  if audios is None:
122
  logger.error("Failed to generate audio")
 
65
  setup_eval_logging()
66
  net, feature_utils, seq_cfg = get_model()
67
 
68
+ @spaces.GPU(duration=30) # 30초로 제한
69
  @torch.inference_mode()
70
  def video_to_audio(video_path: str, prompt: str, negative_prompt: str = "music",
71
+ seed: int = -1, num_steps: int = 15,
72
+ cfg_strength: float = 4.0, target_duration: float = 4.0):
73
  try:
74
  logger.info("Starting audio generation process")
75
  torch.cuda.empty_cache()
 
83
  fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=num_steps)
84
 
85
  # load_video 함수 호출 수정
86
+ video_info = load_video(video_path, duration_sec=target_duration) # duration_sec 파라미터로 변경
87
 
88
  if video_info is None:
89
  logger.error("Failed to load video")
90
  return video_path
91
 
 
 
 
 
92
  clip_frames = video_info.clip_frames
93
  sync_frames = video_info.sync_frames
94
  actual_duration = video_info.duration_sec
 
97
  logger.error("Failed to extract frames from video")
98
  return video_path
99
 
100
+ # 메모리 최적화
101
+ clip_frames = clip_frames[:int(actual_duration * video_info.fps)]
102
+ sync_frames = sync_frames[:int(actual_duration * video_info.fps)]
103
+
104
  clip_frames = clip_frames.unsqueeze(0).to(device, dtype=torch.float16)
105
  sync_frames = sync_frames.unsqueeze(0).to(device, dtype=torch.float16)
106
 
 
108
  net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)
109
 
110
  logger.info("Generating audio...")
111
+ with torch.cuda.amp.autocast():
112
+ audios = generate(clip_frames,
113
+ sync_frames,
114
+ [prompt],
115
+ negative_text=[negative_prompt],
116
+ feature_utils=feature_utils,
117
+ net=net,
118
+ fm=fm,
119
+ rng=rng,
120
+ cfg_strength=cfg_strength)
121
 
122
  if audios is None:
123
  logger.error("Failed to generate audio")