aiqcamp commited on
Commit
d974483
·
verified ·
1 Parent(s): b1e6cb0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -15
app.py CHANGED
@@ -1,9 +1,14 @@
 
 
 
 
 
 
1
  import os
2
  import time
3
  from datetime import datetime
4
  import gradio as gr
5
  import torch
6
- import logging
7
  import requests
8
  from pathlib import Path
9
  import cv2
@@ -26,9 +31,15 @@ from mmaudio.model.networks import MMAudio, get_my_mmaudio
26
  from mmaudio.model.sequence_config import SequenceConfig
27
  from mmaudio.model.utils.features_utils import FeaturesUtils
28
 
 
 
 
29
 
 
 
 
30
 
31
- # 먼저 get_model 함수 정의
32
  def get_model() -> tuple[MMAudio, FeaturesUtils, SequenceConfig]:
33
  seq_cfg = model.seq_cfg
34
 
@@ -46,10 +57,7 @@ def get_model() -> tuple[MMAudio, FeaturesUtils, SequenceConfig]:
46
 
47
  return net, feature_utils, seq_cfg
48
 
49
- # 다음 모델 설정 및 초기화
50
- device = 'cuda'
51
- dtype = torch.bfloat16
52
-
53
  model: ModelConfig = all_model_cfg['large_44k_v2']
54
  model.download_if_needed()
55
  output_dir = Path('./output/gradio')
@@ -58,15 +66,6 @@ setup_eval_logging()
58
  net, feature_utils, seq_cfg = get_model()
59
 
60
 
61
- # 로깅 설정
62
- logging.basicConfig(level=logging.INFO)
63
- logger = logging.getLogger(__name__)
64
-
65
- # API 설정
66
- CATBOX_USER_HASH = "30f52c895fd9d9cb387eee489"
67
- REPLICATE_API_TOKEN = os.getenv("API_KEY")
68
-
69
-
70
 
71
  @spaces.GPU(duration=120)
72
  @torch.inference_mode()
 
1
+ # 1. 먼저 로깅 설정
2
+ import logging
3
+ logging.basicConfig(level=logging.INFO)
4
+ logger = logging.getLogger(__name__)
5
+
6
+ # 2. 나머지 imports
7
  import os
8
  import time
9
  from datetime import datetime
10
  import gradio as gr
11
  import torch
 
12
  import requests
13
  from pathlib import Path
14
  import cv2
 
31
  from mmaudio.model.sequence_config import SequenceConfig
32
  from mmaudio.model.utils.features_utils import FeaturesUtils
33
 
34
+ # 3. API 설정
35
+ CATBOX_USER_HASH = "30f52c895fd9d9cb387eee489"
36
+ REPLICATE_API_TOKEN = os.getenv("API_KEY")
37
 
38
+ # 4. 오디오 모델 설정
39
+ device = 'cuda'
40
+ dtype = torch.bfloat16
41
 
42
+ # 5. get_model 함수 정의
43
  def get_model() -> tuple[MMAudio, FeaturesUtils, SequenceConfig]:
44
  seq_cfg = model.seq_cfg
45
 
 
57
 
58
  return net, feature_utils, seq_cfg
59
 
60
+ # 6. 모델 초기화
 
 
 
61
  model: ModelConfig = all_model_cfg['large_44k_v2']
62
  model.download_if_needed()
63
  output_dir = Path('./output/gradio')
 
66
  net, feature_utils, seq_cfg = get_model()
67
 
68
 
 
 
 
 
 
 
 
 
 
69
 
70
  @spaces.GPU(duration=120)
71
  @torch.inference_mode()