import spaces
import logging
from datetime import datetime
from pathlib import Path
import gradio as gr
import torch
import torchaudio
import os
import requests
from transformers import pipeline
import tempfile

# Basic setup: install mmaudio from the local checkout if it is not importable yet
try:
    import mmaudio
except ImportError:
    os.system("pip install -e .")
    import mmaudio

from mmaudio.eval_utils import (ModelConfig, all_model_cfg, generate, load_video, make_video,
                                setup_eval_logging)
from mmaudio.model.flow_matching import FlowMatching
from mmaudio.model.networks import MMAudio, get_my_mmaudio
from mmaudio.model.sequence_config import SequenceConfig
from mmaudio.model.utils.features_utils import FeaturesUtils

# CUDA settings
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# Logging setup
log = logging.getLogger()

# Device and dtype settings
device = 'cuda'
dtype = torch.bfloat16

# Model configuration
model: ModelConfig = all_model_cfg['large_44k_v2']
model.download_if_needed()
output_dir = Path('./output/gradio')
setup_eval_logging()

# Translator and Pixabay API setup
translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en")
PIXABAY_API_KEY = "33492762-a28a596ec4f286f84cd328b17"


def search_pixabay_videos(query, api_key):
    """Search Pixabay for videos matching the query and return large-format video URLs."""
    base_url = "https://pixabay.com/api/videos/"
    params = {
        "key": api_key,
        "q": query,
        "per_page": 80
    }
    response = requests.get(base_url, params=params)
    if response.status_code == 200:
        data = response.json()
        return [video['videos']['large']['url'] for video in data.get('hits', [])]
    return []


# CSS style definitions
custom_css = """
.gradio-container {
    background: linear-gradient(45deg, #1a1a1a, #2a2a2a);
    border-radius: 15px;
    box-shadow: 0 8px 32px rgba(0,0,0,0.3);
}
.input-container, .output-container {
    background: rgba(255,255,255,0.1);
    backdrop-filter: blur(10px);
    border-radius: 10px;
    padding: 20px;
    transform-style: preserve-3d;
    transition: transform 0.3s ease;
}
.input-container:hover {
    transform: translateZ(20px);
}
.gallery-item {
    transition: transform 0.3s ease;
    border-radius: 8px;
    overflow: hidden;
}
.gallery-item:hover {
    transform: scale(1.05);
    box-shadow: 0 4px 15px rgba(0,0,0,0.2);
}
.tabs {
    background: rgba(255,255,255,0.05);
    border-radius: 10px;
    padding: 10px;
}
button {
    background: linear-gradient(45deg, #4a90e2, #357abd);
    border: none;
    border-radius: 5px;
    transition: all 0.3s ease;
}
button:hover {
    transform: translateY(-2px);
    box-shadow: 0 4px 15px rgba(74,144,226,0.3);
}
"""


def get_model() -> tuple[MMAudio, FeaturesUtils, SequenceConfig]:
    seq_cfg = model.seq_cfg

    net: MMAudio = get_my_mmaudio(model.model_name).to(device, dtype).eval()
    net.load_weights(torch.load(model.model_path, map_location=device, weights_only=True))
    log.info(f'Loaded weights from {model.model_path}')

    feature_utils = FeaturesUtils(tod_vae_ckpt=model.vae_path,
                                  synchformer_ckpt=model.synchformer_ckpt,
                                  enable_conditions=True,
                                  mode=model.mode,
                                  bigvgan_vocoder_ckpt=model.bigvgan_16k_path,
                                  need_vae_encoder=False)
    feature_utils = feature_utils.to(device, dtype).eval()

    return net, feature_utils, seq_cfg


net, feature_utils, seq_cfg = get_model()


def translate_prompt(text):
    # Translate Korean text (Hangul code points U+3131..U+D7A3) to English; pass English through
    if text and any(0x3131 <= ord(char) <= 0xD7A3 for char in text):
        translation = translator(text)[0]['translation_text']
        return translation
    return text


def search_videos(query):
    query = translate_prompt(query)
    return search_pixabay_videos(query, PIXABAY_API_KEY)


@spaces.GPU
@torch.inference_mode()
def video_to_audio(video: gr.Video, prompt: str, negative_prompt: str, seed: int, num_steps: int,
                   cfg_strength: float, duration: float):
    prompt = translate_prompt(prompt)
    negative_prompt = translate_prompt(negative_prompt)

    rng = torch.Generator(device=device)
    rng.manual_seed(seed)
    fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=num_steps)

    clip_frames, sync_frames, duration = load_video(video, duration)
    clip_frames = clip_frames.unsqueeze(0)
    sync_frames = sync_frames.unsqueeze(0)
    seq_cfg.duration = duration
    net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)

    audios = generate(clip_frames, sync_frames, [prompt],
                      negative_text=[negative_prompt],
                      feature_utils=feature_utils,
                      net=net, fm=fm, rng=rng,
                      cfg_strength=cfg_strength)
    audio = audios.float().cpu()[0]

    # Mux the generated audio back onto the input video
    video_save_path = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4').name
    make_video(video, video_save_path, audio,
               sampling_rate=seq_cfg.sampling_rate,
               duration_sec=seq_cfg.duration)
    return video_save_path


@spaces.GPU
@torch.inference_mode()
def text_to_audio(prompt: str, negative_prompt: str, seed: int, num_steps: int,
                  cfg_strength: float, duration: float):
    prompt = translate_prompt(prompt)
    negative_prompt = translate_prompt(negative_prompt)

    rng = torch.Generator(device=device)
    rng.manual_seed(seed)
    fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=num_steps)

    # Text-to-audio uses no visual conditioning
    clip_frames = sync_frames = None
    seq_cfg.duration = duration
    net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)

    audios = generate(clip_frames, sync_frames, [prompt],
                      negative_text=[negative_prompt],
                      feature_utils=feature_utils,
                      net=net, fm=fm, rng=rng,
                      cfg_strength=cfg_strength)
    audio = audios.float().cpu()[0]

    audio_save_path = tempfile.NamedTemporaryFile(delete=False, suffix='.flac').name
    torchaudio.save(audio_save_path, audio, seq_cfg.sampling_rate)
    return audio_save_path


# Interface definitions
video_search_tab = gr.Interface(
    fn=search_videos,
    inputs=gr.Textbox(label="Search query"),
    outputs=gr.Gallery(label="Search results", columns=4, rows=20),
    css=custom_css
)

video_to_audio_tab = gr.Interface(
    fn=video_to_audio,
    inputs=[
        gr.Video(label="Video"),
        gr.Textbox(label="Prompt"),
        gr.Textbox(label="Negative prompt", value="music"),
        gr.Number(label="Seed", value=0),
        gr.Number(label="Num steps", value=25),
        gr.Number(label="Guidance strength", value=4.5),
        gr.Number(label="Duration (sec)", value=8),
    ],
    outputs="playable_video",
    css=custom_css
)

text_to_audio_tab = gr.Interface(
    fn=text_to_audio,
    inputs=[
        gr.Textbox(label="Prompt"),
        gr.Textbox(label="Negative prompt"),
        gr.Number(label="Seed", value=0),
        gr.Number(label="Num steps", value=25),
        gr.Number(label="Guidance strength", value=4.5),
        gr.Number(label="Duration (sec)", value=8),
    ],
    outputs="audio",
    css=custom_css
)

# Main entry point
if __name__ == "__main__":
    gr.TabbedInterface(
        [video_search_tab, video_to_audio_tab, text_to_audio_tab],
        ["Video Search", "Video-to-Audio", "Text-to-Audio"],
        css=custom_css
    ).launch(allowed_paths=[output_dir])
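
# A minimal smoke-test sketch, kept commented out so it never runs on import.
# Assumptions (not part of the app above): this file is saved as app.py, it is
# run locally where the `spaces.GPU` decorator is a no-op and a CUDA device is
# available, and we bypass the Gradio UI by calling text_to_audio directly:
#
#     from app import text_to_audio  # hypothetical module name for this file
#     flac_path = text_to_audio("gentle rain on a tin roof", "music",
#                               seed=0, num_steps=25, cfg_strength=4.5,
#                               duration=8.0)
#     print(f"Generated audio written to {flac_path}")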