Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -25,25 +25,10 @@ from mmaudio.model.flow_matching import FlowMatching
|
|
25 |
from mmaudio.model.networks import MMAudio, get_my_mmaudio
|
26 |
from mmaudio.model.sequence_config import SequenceConfig
|
27 |
from mmaudio.model.utils.features_utils import FeaturesUtils
|
28 |
-
# ์ค๋์ค ๋ชจ๋ธ ์ค์
|
29 |
-
device = 'cuda'
|
30 |
-
dtype = torch.bfloat16
|
31 |
|
32 |
-
model: ModelConfig = all_model_cfg['large_44k_v2']
|
33 |
-
model.download_if_needed()
|
34 |
-
output_dir = Path('./output/gradio')
|
35 |
|
36 |
-
setup_eval_logging()
|
37 |
-
net, feature_utils, seq_cfg = get_model() # get_model ํจ์๋ ์ด์ ์ ์ ๊ณต๋ ์ฝ๋ ์ฌ์ฉ
|
38 |
-
|
39 |
-
# ๋ก๊น
์ค์
|
40 |
-
logging.basicConfig(level=logging.INFO)
|
41 |
-
logger = logging.getLogger(__name__)
|
42 |
-
|
43 |
-
# API ์ค์
|
44 |
-
CATBOX_USER_HASH = "30f52c895fd9d9cb387eee489"
|
45 |
-
REPLICATE_API_TOKEN = os.getenv("API_KEY")
|
46 |
|
|
|
47 |
def get_model() -> tuple[MMAudio, FeaturesUtils, SequenceConfig]:
|
48 |
seq_cfg = model.seq_cfg
|
49 |
|
@@ -61,6 +46,28 @@ def get_model() -> tuple[MMAudio, FeaturesUtils, SequenceConfig]:
|
|
61 |
|
62 |
return net, feature_utils, seq_cfg
|
63 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
64 |
@spaces.GPU(duration=120)
|
65 |
@torch.inference_mode()
|
66 |
def video_to_audio(video_path: str, prompt: str, negative_prompt: str = "music",
|
|
|
25 |
from mmaudio.model.networks import MMAudio, get_my_mmaudio
|
26 |
from mmaudio.model.sequence_config import SequenceConfig
|
27 |
from mmaudio.model.utils.features_utils import FeaturesUtils
|
|
|
|
|
|
|
28 |
|
|
|
|
|
|
|
29 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
|
31 |
+
# ๋จผ์ get_model ํจ์ ์ ์
|
32 |
def get_model() -> tuple[MMAudio, FeaturesUtils, SequenceConfig]:
|
33 |
seq_cfg = model.seq_cfg
|
34 |
|
|
|
46 |
|
47 |
return net, feature_utils, seq_cfg
|
48 |
|
49 |
+
# ๊ทธ ๋ค์ ๋ชจ๋ธ ์ค์ ๋ฐ ์ด๊ธฐํ
|
50 |
+
device = 'cuda'
|
51 |
+
dtype = torch.bfloat16
|
52 |
+
|
53 |
+
model: ModelConfig = all_model_cfg['large_44k_v2']
|
54 |
+
model.download_if_needed()
|
55 |
+
output_dir = Path('./output/gradio')
|
56 |
+
|
57 |
+
setup_eval_logging()
|
58 |
+
net, feature_utils, seq_cfg = get_model()
|
59 |
+
|
60 |
+
|
61 |
+
# ๋ก๊น
์ค์
|
62 |
+
logging.basicConfig(level=logging.INFO)
|
63 |
+
logger = logging.getLogger(__name__)
|
64 |
+
|
65 |
+
# API ์ค์
|
66 |
+
CATBOX_USER_HASH = "30f52c895fd9d9cb387eee489"
|
67 |
+
REPLICATE_API_TOKEN = os.getenv("API_KEY")
|
68 |
+
|
69 |
+
|
70 |
+
|
71 |
@spaces.GPU(duration=120)
|
72 |
@torch.inference_mode()
|
73 |
def video_to_audio(video_path: str, prompt: str, negative_prompt: str = "music",
|