aiqcamp commited on
Commit
f6c9d00
ยท
verified ยท
1 Parent(s): 46cfad8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -13
app.py CHANGED
@@ -8,7 +8,17 @@ import os
8
  import time
9
  from datetime import datetime
10
  import gradio as gr
 
 
11
  import torch
 
 
 
 
 
 
 
 
12
  import requests
13
  from pathlib import Path
14
  import cv2
@@ -37,25 +47,34 @@ translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en")
37
  CATBOX_USER_HASH = "30f52c895fd9d9cb387eee489"
38
  REPLICATE_API_TOKEN = os.getenv("API_KEY")
39
 
 
40
  # 4. ์˜ค๋””์˜ค ๋ชจ๋ธ ์„ค์ •
41
- device = 'cuda'
42
- dtype = torch.bfloat16
43
 
44
  # 5. get_model ํ•จ์ˆ˜ ์ •์˜
45
  def get_model() -> tuple[MMAudio, FeaturesUtils, SequenceConfig]:
46
  seq_cfg = model.seq_cfg
47
 
48
- net: MMAudio = get_my_mmaudio(model.model_name).to(device, dtype).eval()
 
 
 
 
49
  net.load_weights(torch.load(model.model_path, map_location=device, weights_only=True))
50
  logger.info(f'Loaded weights from {model.model_path}')
51
 
52
- feature_utils = FeaturesUtils(tod_vae_ckpt=model.vae_path,
53
- synchformer_ckpt=model.synchformer_ckpt,
54
- enable_conditions=True,
55
- mode=model.mode,
56
- bigvgan_vocoder_ckpt=model.bigvgan_16k_path,
57
- need_vae_encoder=False)
58
- feature_utils = feature_utils.to(device, dtype).eval()
 
 
 
 
 
59
 
60
  return net, feature_utils, seq_cfg
61
 
@@ -67,13 +86,16 @@ output_dir = Path('./output/gradio')
67
  setup_eval_logging()
68
  net, feature_utils, seq_cfg = get_model()
69
 
70
-
 
71
  def video_to_audio(video_path: str, prompt: str, negative_prompt: str = "music",
72
  seed: int = -1, num_steps: int = 15,
73
- cfg_strength: float = 4.0, target_duration: float = None): # target_duration์„ ์„ ํƒ์ ์œผ๋กœ ๋ณ€๊ฒฝ
74
  try:
75
  logger.info("Starting audio generation process")
76
- torch.cuda.empty_cache()
 
 
77
 
78
  # ๋น„๋””์˜ค ๊ธธ์ด ํ™•์ธ
79
  cap = cv2.VideoCapture(video_path)
@@ -493,4 +515,10 @@ with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange", css=css) as demo:
493
  )
494
 
495
  if __name__ == "__main__":
 
 
 
 
 
 
496
  demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
 
8
  import time
9
  from datetime import datetime
10
  import gradio as gr
11
+
12
+ # GPU ์ดˆ๊ธฐํ™” ์„ค์ •
13
  import torch
14
+ if torch.cuda.is_available():
15
+ torch.cuda.init()
16
+ device = torch.device('cuda')
17
+ logger.info(f"Using GPU: {torch.cuda.get_device_name(0)}")
18
+ else:
19
+ device = torch.device('cpu')
20
+ logger.warning("GPU not available, using CPU")
21
+
22
  import requests
23
  from pathlib import Path
24
  import cv2
 
47
  CATBOX_USER_HASH = "30f52c895fd9d9cb387eee489"
48
  REPLICATE_API_TOKEN = os.getenv("API_KEY")
49
 
50
+
51
  # 4. ์˜ค๋””์˜ค ๋ชจ๋ธ ์„ค์ •
52
+ dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
 
53
 
54
  # 5. get_model ํ•จ์ˆ˜ ์ •์˜
55
  def get_model() -> tuple[MMAudio, FeaturesUtils, SequenceConfig]:
56
  seq_cfg = model.seq_cfg
57
 
58
+ net: MMAudio = get_my_mmaudio(model.model_name).to(device)
59
+ if torch.cuda.is_available():
60
+ net = net.to(dtype)
61
+ net.eval()
62
+
63
  net.load_weights(torch.load(model.model_path, map_location=device, weights_only=True))
64
  logger.info(f'Loaded weights from {model.model_path}')
65
 
66
+ feature_utils = FeaturesUtils(
67
+ tod_vae_ckpt=model.vae_path,
68
+ synchformer_ckpt=model.synchformer_ckpt,
69
+ enable_conditions=True,
70
+ mode=model.mode,
71
+ bigvgan_vocoder_ckpt=model.bigvgan_16k_path,
72
+ need_vae_encoder=False
73
+ ).to(device)
74
+
75
+ if torch.cuda.is_available():
76
+ feature_utils = feature_utils.to(dtype)
77
+ feature_utils.eval()
78
 
79
  return net, feature_utils, seq_cfg
80
 
 
86
  setup_eval_logging()
87
  net, feature_utils, seq_cfg = get_model()
88
 
89
+ @spaces.GPU(duration=30)
90
+ @torch.inference_mode()
91
  def video_to_audio(video_path: str, prompt: str, negative_prompt: str = "music",
92
  seed: int = -1, num_steps: int = 15,
93
+ cfg_strength: float = 4.0, target_duration: float = None):
94
  try:
95
  logger.info("Starting audio generation process")
96
+ if torch.cuda.is_available():
97
+ torch.cuda.empty_cache()
98
+
99
 
100
  # ๋น„๋””์˜ค ๊ธธ์ด ํ™•์ธ
101
  cap = cv2.VideoCapture(video_path)
 
515
  )
516
 
517
  if __name__ == "__main__":
518
+ # GPU ์ดˆ๊ธฐํ™” ํ™•์ธ
519
+ if torch.cuda.is_available():
520
+ logger.info(f"Using GPU: {torch.cuda.get_device_name(0)}")
521
+ else:
522
+ logger.warning("GPU not available, using CPU")
523
+
524
  demo.launch(server_name="0.0.0.0", server_port=7860, share=False)