aiqcamp committed on
Commit
decba1f
·
verified ·
1 Parent(s): d974483

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -31
app.py CHANGED
@@ -66,42 +66,48 @@ setup_eval_logging()
66
  # Module-level model state: video_to_audio() below reads all three names
  # (and mutates seq_cfg / net) on every call.
  net, feature_utils, seq_cfg = get_model()
67
 
68
 
69
-
70
  @spaces.GPU(duration=120)
71
  @torch.inference_mode()
72
  def video_to_audio(video_path: str, prompt: str, negative_prompt: str = "music",
73
  seed: int = -1, num_steps: int = 25,
74
  cfg_strength: float = 4.5, duration: float = 8):
75
- rng = torch.Generator(device=device)
76
- if seed >= 0:
77
- rng.manual_seed(seed)
78
- else:
79
- rng.seed()
80
- fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=num_steps)
81
-
82
- video_info = load_video(video_path, duration)
83
- clip_frames = video_info.clip_frames
84
- sync_frames = video_info.sync_frames
85
- duration = video_info.duration_sec
86
- clip_frames = clip_frames.unsqueeze(0)
87
- sync_frames = sync_frames.unsqueeze(0)
88
- seq_cfg.duration = duration
89
- net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)
90
-
91
- audios = generate(clip_frames,
92
- sync_frames, [prompt],
93
- negative_text=[negative_prompt],
94
- feature_utils=feature_utils,
95
- net=net,
96
- fm=fm,
97
- rng=rng,
98
- cfg_strength=cfg_strength)
99
- audio = audios.float().cpu()[0]
100
-
101
- video_save_path = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4').name
102
- make_video(video_info, video_save_path, audio, sampling_rate=seq_cfg.sampling_rate)
103
- logger.info(f'Saved video with audio to {video_save_path}')
104
- return video_save_path
 
 
 
 
 
 
 
105
 
106
  def upload_to_catbox(file_path):
107
  """catbox.moe API를 사용하여 파일 업로드"""
 
66
  # Module-level model state: video_to_audio() below reads all three names
  # (and mutates seq_cfg / net) on every call.
  net, feature_utils, seq_cfg = get_model()
67
 
68
 
 
69
  @spaces.GPU(duration=120)
70
  @torch.inference_mode()
71
  def video_to_audio(video_path: str, prompt: str, negative_prompt: str = "music",
72
  seed: int = -1, num_steps: int = 25,
73
  cfg_strength: float = 4.5, duration: float = 8):
74
+ try:
75
+ rng = torch.Generator(device=device)
76
+ if seed >= 0:
77
+ rng.manual_seed(seed)
78
+ else:
79
+ rng.seed()
80
+ fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=num_steps)
81
+
82
+ # duration 파라미터 전달 방식 수정
83
+ video_info = load_video(video_path, static_duration=duration) # static_duration으로 변경
84
+
85
+ clip_frames = video_info.clip_frames
86
+ sync_frames = video_info.sync_frames
87
+ actual_duration = video_info.duration_sec
88
+ clip_frames = clip_frames.unsqueeze(0)
89
+ sync_frames = sync_frames.unsqueeze(0)
90
+ seq_cfg.duration = actual_duration
91
+ net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)
92
+
93
+ audios = generate(clip_frames,
94
+ sync_frames, [prompt],
95
+ negative_text=[negative_prompt],
96
+ feature_utils=feature_utils,
97
+ net=net,
98
+ fm=fm,
99
+ rng=rng,
100
+ cfg_strength=cfg_strength)
101
+ audio = audios.float().cpu()[0]
102
+
103
+ video_save_path = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4').name
104
+ make_video(video_info, video_save_path, audio, sampling_rate=seq_cfg.sampling_rate)
105
+ logger.info(f'Saved video with audio to {video_save_path}')
106
+ return video_save_path
107
+
108
+ except Exception as e:
109
+ logger.error(f"Error in video_to_audio: {str(e)}")
110
+ return video_path # 오류 발생 시 원본 비디오 반환
111
 
112
  def upload_to_catbox(file_path):
113
  """catbox.moe API를 사용하여 파일 업로드"""