"""Gradio app for MMAudio: Pixabay video search, video-to-audio and text-to-audio.

Korean prompts/queries are machine-translated to English before being sent to
the Pixabay search API or the MMAudio model.
"""
# NOTE: `spaces` is kept as the very first import, as in the original file
# (Hugging Face ZeroGPU setups generally expect it before torch — keep order).
import spaces
import logging
from datetime import datetime
from pathlib import Path
import gradio as gr
import torch
import torchaudio
import os
import subprocess
import sys
from transformers import pipeline
from pixabay import Image, Video
import tempfile

# Basic setup: install the local mmaudio package if it is not importable yet.
try:
    import mmaudio
except ImportError:
    # Use the running interpreter's pip and fail loudly on a non-zero exit
    # status instead of silently continuing to a confusing ImportError.
    subprocess.check_call([sys.executable, "-m", "pip", "install", "-e", "."])
    import mmaudio

from mmaudio.eval_utils import (ModelConfig, all_model_cfg, generate, load_video, make_video,
                                setup_eval_logging)
from mmaudio.model.flow_matching import FlowMatching
from mmaudio.model.networks import MMAudio, get_my_mmaudio
from mmaudio.model.sequence_config import SequenceConfig
from mmaudio.model.utils.features_utils import FeaturesUtils

# CUDA settings: allow TF32 for CUDA matmul and cuDNN ops.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# Logging: root logger.
log = logging.getLogger()

# Device and dtype used for the network and feature extractors.
device = 'cuda'
dtype = torch.bfloat16

# Model configuration; download_if_needed() fetches checkpoints when missing.
model: ModelConfig = all_model_cfg['large_44k_v2']
model.download_if_needed()
output_dir = Path('./output/gradio')

setup_eval_logging()

# Translator (Korean -> English) and Pixabay API setup.
translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en")
# Prefer the PIXABAY_API_KEY environment variable.
# NOTE(review): the fallback is a credential committed to source; it is kept only
# so existing deployments keep working — rotate the key and remove the fallback.
PIXABAY_API_KEY = os.environ.get("PIXABAY_API_KEY", "33492762-a28a596ec4f286f84cd328b17")
pixabay_video = Video(PIXABAY_API_KEY)

# CSS styles shared by every interface/tab.
custom_css = """
.gradio-container {
    background: linear-gradient(45deg, #1a1a1a, #2a2a2a);
    border-radius: 15px;
    box-shadow: 0 8px 32px rgba(0,0,0,0.3);
}
.input-container, .output-container {
    background: rgba(255,255,255,0.1);
    backdrop-filter: blur(10px);
    border-radius: 10px;
    padding: 20px;
    transform-style: preserve-3d;
    transition: transform 0.3s ease;
}
.input-container:hover {
    transform: translateZ(20px);
}
.gallery-item {
    transition: transform 0.3s ease;
    border-radius: 8px;
    overflow: hidden;
}
.gallery-item:hover {
    transform: scale(1.05);
    box-shadow: 0 4px 15px rgba(0,0,0,0.2);
}
.tabs {
    background: rgba(255,255,255,0.05);
    border-radius: 10px;
    padding: 10px;
}
button {
    background: linear-gradient(45deg, #4a90e2, #357abd);
    border: none;
    border-radius: 5px;
    transition: all 0.3s ease;
}
button:hover {
    transform: translateY(-2px);
    box-shadow: 0 4px 15px rgba(74,144,226,0.3);
}
"""


def get_model() -> tuple[MMAudio, FeaturesUtils, SequenceConfig]:
    """Build the MMAudio network and feature utilities from the global `model` config.

    Returns:
        (net, feature_utils, seq_cfg): the network in eval mode on `device`/`dtype`
        with weights loaded, the feature extractor/decoder utilities, and the
        model's sequence config.
    """
    seq_cfg = model.seq_cfg

    net: MMAudio = get_my_mmaudio(model.model_name).to(device, dtype).eval()
    net.load_weights(torch.load(model.model_path, map_location=device, weights_only=True))
    log.info(f'Loaded weights from {model.model_path}')

    feature_utils = FeaturesUtils(tod_vae_ckpt=model.vae_path,
                                  synchformer_ckpt=model.synchformer_ckpt,
                                  enable_conditions=True,
                                  mode=model.mode,
                                  bigvgan_vocoder_ckpt=model.bigvgan_16k_path,
                                  need_vae_encoder=False)
    feature_utils = feature_utils.to(device, dtype).eval()

    return net, feature_utils, seq_cfg


net, feature_utils, seq_cfg = get_model()

# Unicode ranges identifying Korean script: Hangul Jamo, Hangul Compatibility
# Jamo and precomposed Hangul Syllables. CJK ideographs and kana are excluded so
# Chinese/Japanese text is not routed to the Korean->English model.
_HANGUL_RANGES = ((0x1100, 0x11FF), (0x3131, 0x318E), (0xAC00, 0xD7A3))


def _contains_hangul(text: str) -> bool:
    """Return True if *text* contains at least one Hangul character."""
    return any(lo <= ord(ch) <= hi for ch in text for lo, hi in _HANGUL_RANGES)


def translate_prompt(text):
    """Translate *text* to English if it contains Hangul; otherwise return it unchanged.

    Empty strings and None are returned as-is.
    """
    if text and _contains_hangul(text):
        return translator(text)[0]['translation_text']
    return text


def _best_video_url(hit):
    """Return the largest non-empty rendition URL of one Pixabay video hit, or None."""
    renditions = hit.get('videos', {})
    # Pixabay returns an empty 'url' when a rendition size is unavailable,
    # so fall back to progressively smaller sizes.
    for size in ('large', 'medium', 'small', 'tiny'):
        url = renditions.get(size, {}).get('url')
        if url:
            return url
    return None


def search_videos(query):
    """Search Pixabay videos for *query* (Korean is translated first).

    Returns:
        A list of video URLs (up to 80) for the results gallery.
    """
    query = translate_prompt(query)
    videos = pixabay_video.search(q=query, per_page=80)
    # The response is parsed JSON: hits are dicts, not objects with attributes.
    urls = (_best_video_url(hit) for hit in videos.get('hits', []))
    return [url for url in urls if url]


def _temp_path(suffix: str) -> str:
    """Create an empty temp file that is kept after closing and return its path."""
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
        return tmp.name


def _generate_audio(clip_frames, sync_frames, prompt, negative_prompt, seed, num_steps,
                    cfg_strength, duration):
    """Run MMAudio sampling and return the first generated waveform (float32, on CPU).

    clip_frames/sync_frames are batched video features, or None for text-only
    generation. Prompts are translated to English if they contain Hangul.
    """
    prompt = translate_prompt(prompt)
    negative_prompt = translate_prompt(negative_prompt)

    rng = torch.Generator(device=device)
    # Gradio numbers may arrive as floats; torch seeds and step counts must be ints.
    rng.manual_seed(int(seed))
    fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=int(num_steps))

    # seq_cfg and net are shared module-level state; the sequence lengths are read
    # after updating the duration (presumably they are derived from it).
    seq_cfg.duration = duration
    net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)

    audios = generate(clip_frames,
                      sync_frames, [prompt],
                      negative_text=[negative_prompt],
                      feature_utils=feature_utils,
                      net=net,
                      fm=fm,
                      rng=rng,
                      cfg_strength=cfg_strength)
    return audios.float().cpu()[0]


@spaces.GPU
@torch.inference_mode()
def video_to_audio(video: gr.Video, prompt: str, negative_prompt: str, seed: int, num_steps: int,
                   cfg_strength: float, duration: float):
    """Generate audio for *video* and return the path of an .mp4 with the audio muxed in.

    The duration actually used is the one returned by load_video.
    """
    clip_frames, sync_frames, duration = load_video(video, duration)
    audio = _generate_audio(clip_frames.unsqueeze(0), sync_frames.unsqueeze(0), prompt,
                            negative_prompt, seed, num_steps, cfg_strength, duration)

    video_save_path = _temp_path('.mp4')
    make_video(video,
               video_save_path,
               audio,
               sampling_rate=seq_cfg.sampling_rate,
               duration_sec=seq_cfg.duration)
    return video_save_path


@spaces.GPU
@torch.inference_mode()
def text_to_audio(prompt: str, negative_prompt: str, seed: int, num_steps: int,
                  cfg_strength: float, duration: float):
    """Generate audio from text alone and return the path of a .flac file."""
    audio = _generate_audio(None, None, prompt, negative_prompt, seed, num_steps,
                            cfg_strength, duration)

    audio_save_path = _temp_path('.flac')
    torchaudio.save(audio_save_path, audio, seq_cfg.sampling_rate)
    return audio_save_path


# Interface definitions.
video_search_tab = gr.Interface(
    fn=search_videos,
    inputs=gr.Textbox(label="검색어 입력"),
    outputs=gr.Gallery(label="검색 결과", columns=4, rows=20),
    css=custom_css
)

video_to_audio_tab = gr.Interface(
    fn=video_to_audio,
    inputs=[
        gr.Video(label="비디오"),
        gr.Textbox(label="프롬프트"),
        gr.Textbox(label="네거티브 프롬프트", value="music"),
        # precision=0 makes Gradio deliver ints for the seed and step count.
        gr.Number(label="시드", value=0, precision=0),
        gr.Number(label="스텝 수", value=25, precision=0),
        gr.Number(label="가이드 강도", value=4.5),
        gr.Number(label="길이(초)", value=8),
    ],
    outputs="playable_video",
    css=custom_css
)

text_to_audio_tab = gr.Interface(
    fn=text_to_audio,
    inputs=[
        gr.Textbox(label="프롬프트"),
        gr.Textbox(label="네거티브 프롬프트"),
        # precision=0 makes Gradio deliver ints for the seed and step count.
        gr.Number(label="시드", value=0, precision=0),
        gr.Number(label="스텝 수", value=25, precision=0),
        gr.Number(label="가이드 강도", value=4.5),
        gr.Number(label="길이(초)", value=8),
    ],
    outputs="audio",
    css=custom_css
)

# Main entry point.
if __name__ == "__main__":
    gr.TabbedInterface(
        [video_search_tab, video_to_audio_tab, text_to_audio_tab],
        ["비디오 검색", "비디오-오디오 변환", "텍스트-오디오 변환"],
        css=custom_css
    ).launch(allowed_paths=[output_dir])