aiqcamp commited on
Commit
ca607ce
·
verified ·
1 Parent(s): 9ae8acd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +54 -1
app.py CHANGED
@@ -44,6 +44,59 @@ logger = logging.getLogger(__name__)
44
  CATBOX_USER_HASH = "30f52c895fd9d9cb387eee489"
45
  REPLICATE_API_TOKEN = os.getenv("API_KEY")
46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  def upload_to_catbox(file_path):
48
  """catbox.moe API를 사용하여 파일 업로드"""
49
  try:
@@ -287,7 +340,7 @@ footer {display: none}
287
  """
288
 
289
  with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange", css=css) as demo:
290
- gr.HTML('<div style="text-align: center; font-size: 1.5em; margin: 10px 0;">🎥 Image to Video Generator</div>')
291
 
292
  with gr.Row():
293
  with gr.Column(scale=3):
 
44
  CATBOX_USER_HASH = "30f52c895fd9d9cb387eee489"
45
  REPLICATE_API_TOKEN = os.getenv("API_KEY")
46
 
47
+ def get_model() -> tuple[MMAudio, FeaturesUtils, SequenceConfig]:
48
+ seq_cfg = model.seq_cfg
49
+
50
+ net: MMAudio = get_my_mmaudio(model.model_name).to(device, dtype).eval()
51
+ net.load_weights(torch.load(model.model_path, map_location=device, weights_only=True))
52
+ logger.info(f'Loaded weights from {model.model_path}')
53
+
54
+ feature_utils = FeaturesUtils(tod_vae_ckpt=model.vae_path,
55
+ synchformer_ckpt=model.synchformer_ckpt,
56
+ enable_conditions=True,
57
+ mode=model.mode,
58
+ bigvgan_vocoder_ckpt=model.bigvgan_16k_path,
59
+ need_vae_encoder=False)
60
+ feature_utils = feature_utils.to(device, dtype).eval()
61
+
62
+ return net, feature_utils, seq_cfg
63
+
64
+ @spaces.GPU(duration=120)
65
+ @torch.inference_mode()
66
+ def video_to_audio(video_path: str, prompt: str, negative_prompt: str = "music",
67
+ seed: int = -1, num_steps: int = 25,
68
+ cfg_strength: float = 4.5, duration: float = 8):
69
+ rng = torch.Generator(device=device)
70
+ if seed >= 0:
71
+ rng.manual_seed(seed)
72
+ else:
73
+ rng.seed()
74
+ fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=num_steps)
75
+
76
+ video_info = load_video(video_path, duration)
77
+ clip_frames = video_info.clip_frames
78
+ sync_frames = video_info.sync_frames
79
+ duration = video_info.duration_sec
80
+ clip_frames = clip_frames.unsqueeze(0)
81
+ sync_frames = sync_frames.unsqueeze(0)
82
+ seq_cfg.duration = duration
83
+ net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)
84
+
85
+ audios = generate(clip_frames,
86
+ sync_frames, [prompt],
87
+ negative_text=[negative_prompt],
88
+ feature_utils=feature_utils,
89
+ net=net,
90
+ fm=fm,
91
+ rng=rng,
92
+ cfg_strength=cfg_strength)
93
+ audio = audios.float().cpu()[0]
94
+
95
+ video_save_path = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4').name
96
+ make_video(video_info, video_save_path, audio, sampling_rate=seq_cfg.sampling_rate)
97
+ logger.info(f'Saved video with audio to {video_save_path}')
98
+ return video_save_path
99
+
100
  def upload_to_catbox(file_path):
101
  """catbox.moe API를 사용하여 파일 업로드"""
102
  try:
 
340
  """
341
 
342
  with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange", css=css) as demo:
343
+
344
 
345
  with gr.Row():
346
  with gr.Column(scale=3):