aiqcamp commited on
Commit
b1e6cb0
ยท
verified ยท
1 Parent(s): 9002f84

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -16
app.py CHANGED
@@ -25,25 +25,10 @@ from mmaudio.model.flow_matching import FlowMatching
25
  from mmaudio.model.networks import MMAudio, get_my_mmaudio
26
  from mmaudio.model.sequence_config import SequenceConfig
27
  from mmaudio.model.utils.features_utils import FeaturesUtils
28
- # ์˜ค๋””์˜ค ๋ชจ๋ธ ์„ค์ •
29
- device = 'cuda'
30
- dtype = torch.bfloat16
31
 
32
- model: ModelConfig = all_model_cfg['large_44k_v2']
33
- model.download_if_needed()
34
- output_dir = Path('./output/gradio')
35
 
36
- setup_eval_logging()
37
- net, feature_utils, seq_cfg = get_model() # get_model ํ•จ์ˆ˜๋Š” ์ด์ „์— ์ œ๊ณต๋œ ์ฝ”๋“œ ์‚ฌ์šฉ
38
-
39
- # ๋กœ๊น… ์„ค์ •
40
- logging.basicConfig(level=logging.INFO)
41
- logger = logging.getLogger(__name__)
42
-
43
- # API ์„ค์ •
44
- CATBOX_USER_HASH = "30f52c895fd9d9cb387eee489"
45
- REPLICATE_API_TOKEN = os.getenv("API_KEY")
46
 
 
47
  def get_model() -> tuple[MMAudio, FeaturesUtils, SequenceConfig]:
48
  seq_cfg = model.seq_cfg
49
 
@@ -61,6 +46,28 @@ def get_model() -> tuple[MMAudio, FeaturesUtils, SequenceConfig]:
61
 
62
  return net, feature_utils, seq_cfg
63
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  @spaces.GPU(duration=120)
65
  @torch.inference_mode()
66
  def video_to_audio(video_path: str, prompt: str, negative_prompt: str = "music",
 
25
  from mmaudio.model.networks import MMAudio, get_my_mmaudio
26
  from mmaudio.model.sequence_config import SequenceConfig
27
  from mmaudio.model.utils.features_utils import FeaturesUtils
 
 
 
28
 
 
 
 
29
 
 
 
 
 
 
 
 
 
 
 
30
 
31
+ # ๋จผ์ € get_model ํ•จ์ˆ˜ ์ •์˜
32
  def get_model() -> tuple[MMAudio, FeaturesUtils, SequenceConfig]:
33
  seq_cfg = model.seq_cfg
34
 
 
46
 
47
  return net, feature_utils, seq_cfg
48
 
49
+ # ๊ทธ ๋‹ค์Œ ๋ชจ๋ธ ์„ค์ • ๋ฐ ์ดˆ๊ธฐํ™”
50
+ device = 'cuda'
51
+ dtype = torch.bfloat16
52
+
53
+ model: ModelConfig = all_model_cfg['large_44k_v2']
54
+ model.download_if_needed()
55
+ output_dir = Path('./output/gradio')
56
+
57
+ setup_eval_logging()
58
+ net, feature_utils, seq_cfg = get_model()
59
+
60
+
61
+ # ๋กœ๊น… ์„ค์ •
62
+ logging.basicConfig(level=logging.INFO)
63
+ logger = logging.getLogger(__name__)
64
+
65
+ # API ์„ค์ •
66
+ CATBOX_USER_HASH = "30f52c895fd9d9cb387eee489"
67
+ REPLICATE_API_TOKEN = os.getenv("API_KEY")
68
+
69
+
70
+
71
  @spaces.GPU(duration=120)
72
  @torch.inference_mode()
73
  def video_to_audio(video_path: str, prompt: str, negative_prompt: str = "music",