Update app.py
Browse files
app.py
CHANGED
@@ -44,6 +44,59 @@ logger = logging.getLogger(__name__)
|
|
44 |
CATBOX_USER_HASH = "30f52c895fd9d9cb387eee489"
|
45 |
REPLICATE_API_TOKEN = os.getenv("API_KEY")
|
46 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
47 |
def upload_to_catbox(file_path):
|
48 |
"""catbox.moe API를 사용하여 파일 업로드"""
|
49 |
try:
|
@@ -287,7 +340,7 @@ footer {display: none}
|
|
287 |
"""
|
288 |
|
289 |
with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange", css=css) as demo:
|
290 |
-
|
291 |
|
292 |
with gr.Row():
|
293 |
with gr.Column(scale=3):
|
|
|
44 |
CATBOX_USER_HASH = "30f52c895fd9d9cb387eee489"
|
45 |
REPLICATE_API_TOKEN = os.getenv("API_KEY")
|
46 |
|
47 |
+
def get_model() -> tuple[MMAudio, FeaturesUtils, SequenceConfig]:
    """Build the MMAudio network and its feature utilities.

    Reads checkpoint paths and settings from the module-level ``model``
    config and places both objects on the module-level ``device``/``dtype``
    in eval mode.

    Returns:
        A ``(net, feature_utils, seq_cfg)`` tuple, where ``seq_cfg`` is
        ``model.seq_cfg``.
    """
    sequence_config = model.seq_cfg

    # Instantiate the network, then load its checkpoint (tensors only).
    network: MMAudio = get_my_mmaudio(model.model_name).to(device, dtype).eval()
    state_dict = torch.load(model.model_path, map_location=device, weights_only=True)
    network.load_weights(state_dict)
    logger.info(f'Loaded weights from {model.model_path}')

    # Gather the feature-extractor settings in one place before constructing.
    utils_kwargs = {
        'tod_vae_ckpt': model.vae_path,
        'synchformer_ckpt': model.synchformer_ckpt,
        'enable_conditions': True,
        'mode': model.mode,
        'bigvgan_vocoder_ckpt': model.bigvgan_16k_path,
        'need_vae_encoder': False,
    }
    features = FeaturesUtils(**utils_kwargs).to(device, dtype).eval()

    return network, features, sequence_config
|
63 |
+
|
64 |
+
@spaces.GPU(duration=120)
@torch.inference_mode()
def video_to_audio(video_path: str, prompt: str, negative_prompt: str = "music",
                   seed: int = -1, num_steps: int = 25,
                   cfg_strength: float = 4.5, duration: float = 8) -> str:
    """Generate audio for a video and return the path of the muxed result.

    NOTE(review): this function reads the module-level names ``net``,
    ``feature_utils`` and ``seq_cfg``; they are presumably assigned from
    ``get_model()`` at import time -- confirm that assignment exists.

    Args:
        video_path: Path of the input video, passed to ``load_video``.
        prompt: Text prompt passed to ``generate``.
        negative_prompt: Negative text prompt passed to ``generate``.
        seed: Seed for the random generator; a negative value requests a
            fresh non-deterministic seed.
        num_steps: Number of steps given to ``FlowMatching``.
        cfg_strength: ``cfg_strength`` value passed to ``generate``.
        duration: Requested duration passed to ``load_video``; the value
            actually used afterwards is ``video_info.duration_sec``.

    Returns:
        Path of a newly created ``.mp4`` temp file written by ``make_video``.
        The file is not deleted automatically.
    """
    rng = torch.Generator(device=device)
    if seed >= 0:
        rng.manual_seed(seed)
    else:
        rng.seed()
    fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=num_steps)

    video_info = load_video(video_path, duration)
    # Use the duration reported by the loader rather than the requested one.
    duration = video_info.duration_sec
    # unsqueeze(0) adds a leading dimension to each frame tensor.
    clip_frames = video_info.clip_frames.unsqueeze(0)
    sync_frames = video_info.sync_frames.unsqueeze(0)

    # The shared sequence config is mutated, then the network is told the
    # resulting sequence lengths.
    seq_cfg.duration = duration
    net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)

    audios = generate(clip_frames,
                      sync_frames, [prompt],
                      negative_text=[negative_prompt],
                      feature_utils=feature_utils,
                      net=net,
                      fm=fm,
                      rng=rng,
                      cfg_strength=cfg_strength)
    audio = audios.float().cpu()[0]

    # Reserve a unique output path, and close the handle right away so it is
    # not leaked and the file is not held open while make_video writes to it.
    with tempfile.NamedTemporaryFile(delete=False, suffix='.mp4') as tmp_file:
        video_save_path = tmp_file.name
    make_video(video_info, video_save_path, audio, sampling_rate=seq_cfg.sampling_rate)
    logger.info(f'Saved video with audio to {video_save_path}')
    return video_save_path
|
99 |
+
|
100 |
def upload_to_catbox(file_path):
|
101 |
"""catbox.moe API를 사용하여 파일 업로드"""
|
102 |
try:
|
|
|
340 |
"""
|
341 |
|
342 |
with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange", css=css) as demo:
|
343 |
+
|
344 |
|
345 |
with gr.Row():
|
346 |
with gr.Column(scale=3):
|