# 1. Configure logging first so that messages emitted while the heavy
#    dependencies below are imported are already captured.
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# 2. Remaining imports
import importlib
import os
import subprocess
import sys
import time
from datetime import datetime

import gradio as gr

# GPU initialization: pick the compute device once, at import time.
import torch

if torch.cuda.is_available():
    torch.cuda.init()
    device = torch.device('cuda')
    logger.info(f"Using GPU: {torch.cuda.get_device_name(0)}")
else:
    device = torch.device('cpu')
    logger.warning("GPU not available, using CPU")

import requests
from pathlib import Path
import cv2
from PIL import Image
import json
import spaces
import torchaudio
import tempfile

try:
    import mmaudio
except ImportError:
    # Fall back to an editable install of the local checkout. Invoking pip
    # through the running interpreter (argument list, no shell) guarantees the
    # package lands in *this* environment; a failed install still surfaces as
    # the ImportError raised by the retry below.
    install_result = subprocess.run(
        [sys.executable, "-m", "pip", "install", "-e", "."],
        check=False,
    )
    if install_result.returncode != 0:
        logger.error(f"pip install -e . exited with code {install_result.returncode}")
    # Make the import system notice packages installed after start-up.
    importlib.invalidate_caches()
    import mmaudio

from mmaudio.eval_utils import (ModelConfig, all_model_cfg, generate, load_video, make_video,
                                setup_eval_logging)
from mmaudio.model.flow_matching import FlowMatching
from mmaudio.model.networks import MMAudio, get_my_mmaudio
from mmaudio.model.sequence_config import SequenceConfig
from mmaudio.model.utils.features_utils import FeaturesUtils

# Translation model used to turn Korean prompts into English.
from transformers import pipeline

translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en")

# 3. API settings
# NOTE(review): a credential is embedded in source as the fallback value.
# Prefer supplying CATBOX_USER_HASH through the environment and rotating the
# committed hash; the default is kept only for backward compatibility.
CATBOX_USER_HASH = os.getenv("CATBOX_USER_HASH", "30f52c895fd9d9cb387eee489")
REPLICATE_API_TOKEN = os.getenv("API_KEY")

# 4. Audio model settings: bfloat16 on GPU, float32 on CPU.
dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
# 5. get_model definition
def get_model() -> tuple[MMAudio, FeaturesUtils, SequenceConfig]:
    """Build the MMAudio network and its feature utilities.

    Reads the module-level ``model`` configuration, ``device`` and ``dtype``.

    Returns:
        A ``(net, feature_utils, seq_cfg)`` tuple. Both modules are moved to
        ``device``, cast to ``dtype`` when CUDA is available, and put in
        eval mode.
    """
    seq_cfg = model.seq_cfg

    net: MMAudio = get_my_mmaudio(model.model_name).to(device)
    if torch.cuda.is_available():
        net = net.to(dtype)
    net.eval()
    net.load_weights(torch.load(model.model_path, map_location=device, weights_only=True))
    logger.info(f'Loaded weights from {model.model_path}')

    feature_utils = FeaturesUtils(
        tod_vae_ckpt=model.vae_path,
        synchformer_ckpt=model.synchformer_ckpt,
        enable_conditions=True,
        mode=model.mode,
        bigvgan_vocoder_ckpt=model.bigvgan_16k_path,
        need_vae_encoder=False
    ).to(device)
    if torch.cuda.is_available():
        feature_utils = feature_utils.to(dtype)
    feature_utils.eval()

    return net, feature_utils, seq_cfg


# 6. Model initialization (runs once at import time)
model: ModelConfig = all_model_cfg['large_44k_v2']
model.download_if_needed()
output_dir = Path('./output/gradio')

setup_eval_logging()
net, feature_utils, seq_cfg = get_model()

# Upper bounds (seconds) for outbound HTTP calls so that a stalled connection
# cannot block a worker forever.
REQUEST_TIMEOUT_SEC = 60
DOWNLOAD_TIMEOUT_SEC = 300


def _probe_video_duration(video_path):
    """Return the duration of ``video_path`` in seconds, as reported by OpenCV.

    Raises:
        ValueError: if OpenCV reports a non-positive FPS or frame count
            (unreadable file or missing container metadata).
    """
    cap = cv2.VideoCapture(video_path)
    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    finally:
        cap.release()

    if fps <= 0 or total_frames <= 0:
        raise ValueError(f"Cannot determine duration of video: {video_path}")
    return total_frames / fps


@spaces.GPU(duration=30)
@torch.inference_mode()
def video_to_audio(video_path: str, prompt: str, negative_prompt: str = "music",
                   seed: int = -1, num_steps: int = 15,
                   cfg_strength: float = 4.0, target_duration: float = None):
    """Generate a soundtrack for a video and mux it into a new MP4.

    Args:
        video_path: Path of the (silent) input video.
        prompt: Text describing the desired audio.
        negative_prompt: Text describing audio to steer away from.
        seed: RNG seed; a negative value draws a random seed.
        num_steps: Number of Euler steps for the flow-matching sampler.
        cfg_strength: Classifier-free guidance strength passed to ``generate``.
        target_duration: Accepted for interface compatibility; the value is
            always replaced by the duration measured from the video itself.

    Returns:
        Path of the new video with audio on success. On any failure the
        original ``video_path`` is returned unchanged, so callers can detect
        failure by comparing the result with their input.
    """
    try:
        logger.info("Starting audio generation process")
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # The real video length is always used as the target duration.
        target_duration = _probe_video_duration(video_path)
        logger.info(f"Video duration: {target_duration} seconds")

        rng = torch.Generator(device=device)
        if seed >= 0:
            rng.manual_seed(seed)
        else:
            rng.seed()

        fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=num_steps)

        # Load the video for exactly the measured duration.
        video_info = load_video(video_path, duration_sec=target_duration)
        if video_info is None:
            logger.error("Failed to load video")
            return video_path

        clip_frames = video_info.clip_frames
        sync_frames = video_info.sync_frames
        actual_duration = video_info.duration_sec

        if clip_frames is None or sync_frames is None:
            logger.error("Failed to extract frames from video")
            return video_path

        # Trim both frame streams to the number of frames of the real video.
        frame_limit = int(actual_duration * video_info.fps)
        clip_frames = clip_frames[:frame_limit]
        sync_frames = sync_frames[:frame_limit]

        # NOTE(review): frames are cast to float16 while the networks use
        # `dtype` (bfloat16 on GPU); confirm autocast reconciles the two.
        clip_frames = clip_frames.unsqueeze(0).to(device, dtype=torch.float16)
        sync_frames = sync_frames.unsqueeze(0).to(device, dtype=torch.float16)

        # Update the shared sequence config to the actual duration.
        seq_cfg.duration = actual_duration
        net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)

        logger.info(f"Generating audio for {actual_duration} seconds...")

        logger.info("Generating audio...")
        with torch.cuda.amp.autocast():
            audios = generate(clip_frames,
                              sync_frames,
                              [prompt],
                              negative_text=[negative_prompt],
                              feature_utils=feature_utils,
                              net=net,
                              fm=fm,
                              rng=rng,
                              cfg_strength=cfg_strength)

        if audios is None:
            logger.error("Failed to generate audio")
            return video_path

        audio = audios.float().cpu()[0]

        output_path = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4').name
        logger.info(f"Creating final video with audio at {output_path}")

        make_video(video_info, output_path, audio, sampling_rate=seq_cfg.sampling_rate)

        torch.cuda.empty_cache()

        if not os.path.exists(output_path):
            logger.error("Failed to create output video")
            return video_path

        logger.info(f'Successfully saved video with audio to {output_path}')
        return output_path

    except Exception as e:
        # Top-level boundary: audio is optional, so never propagate.
        logger.error(f"Error in video_to_audio: {str(e)}")
        torch.cuda.empty_cache()
        return video_path


def upload_to_catbox(file_path):
    """Upload an image file through the catbox.moe API.

    Files whose extension is not a known image type are first converted to a
    temporary PNG next to the original; that PNG is removed before returning.

    Args:
        file_path: Path of the image to upload.

    Returns:
        The public URL returned by catbox, or ``None`` on any failure.
    """
    converted_path = None
    try:
        logger.info(f"Preparing to upload file: {file_path}")
        url = "https://catbox.moe/user/api.php"

        mime_types = {
            '.jpg': 'image/jpeg',
            '.jpeg': 'image/jpeg',
            '.png': 'image/png',
            '.gif': 'image/gif',
            '.webp': 'image/webp',
            '.jfif': 'image/jpeg'
        }

        file_extension = Path(file_path).suffix.lower()

        if file_extension not in mime_types:
            try:
                # with_suffix only touches the final path component, unlike
                # rsplit('.') which breaks on dotted directory names.
                converted_path = str(Path(file_path).with_suffix('.png'))
                with Image.open(file_path) as img:
                    rgb_img = img if img.mode == 'RGB' else img.convert('RGB')
                    rgb_img.save(converted_path, 'PNG')
                file_path = converted_path
                file_extension = '.png'
                logger.info(f"Converted image to PNG: {file_path}")
            except Exception as e:
                logger.error(f"Failed to convert image: {str(e)}")
                return None

        data = {
            'reqtype': 'fileupload',
            'userhash': CATBOX_USER_HASH
        }

        # Keep the handle inside a context manager so it is always closed.
        with open(file_path, 'rb') as upload_file:
            files = {
                'fileToUpload': (
                    os.path.basename(file_path),
                    upload_file,
                    mime_types.get(file_extension, 'application/octet-stream')
                )
            }
            response = requests.post(url, files=files, data=data,
                                     timeout=REQUEST_TIMEOUT_SEC)

        if response.status_code == 200 and response.text.startswith('http'):
            file_url = response.text
            logger.info(f"File uploaded successfully: {file_url}")
            return file_url

        logger.error(f"File upload error: Upload failed: {response.text}")
        return None

    except Exception as e:
        logger.error(f"File upload error: {str(e)}")
        return None

    finally:
        # Best-effort removal of the temporary PNG created above.
        if converted_path is not None and os.path.exists(converted_path):
            try:
                os.remove(converted_path)
            except OSError as e:
                logger.warning(f"Could not remove temporary file {converted_path}: {str(e)}")


def add_watermark(video_path):
    """Burn the "GiniGEN.AI" text into the bottom-right corner of a video.

    The watermarked copy is written next to the input as
    ``<stem>_watermarked.mp4`` so that concurrent requests do not overwrite
    each other's output.

    Args:
        video_path: Path of the input video.

    Returns:
        Path of the watermarked video, or the original ``video_path`` if the
        video could not be read or written.
    """
    cap = None
    out = None
    try:
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            logger.error(f"Error adding watermark: cannot open {video_path}")
            return video_path

        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        # Keep the fractional frame rate: truncating e.g. 29.97 to 29 would
        # change the playback duration.
        fps = cap.get(cv2.CAP_PROP_FPS)
        if fps <= 0:
            logger.error(f"Error adding watermark: invalid FPS for {video_path}")
            return video_path

        text = "GiniGEN.AI"
        font = cv2.FONT_HERSHEY_SIMPLEX
        font_scale = height * 0.05 / 30
        thickness = 2
        color = (255, 255, 255)

        (text_width, text_height), _ = cv2.getTextSize(text, font, font_scale, thickness)
        margin = int(height * 0.02)
        x_pos = width - text_width - margin
        y_pos = height - margin

        source = Path(video_path)
        output_path = str(source.with_name(f"{source.stem}_watermarked.mp4"))
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
        if not out.isOpened():
            logger.error(f"Error adding watermark: cannot open writer for {output_path}")
            return video_path

        while True:
            ret, frame = cap.read()
            if not ret:
                break
            cv2.putText(frame, text, (x_pos, y_pos), font, font_scale, color, thickness)
            out.write(frame)

        return output_path

    except Exception as e:
        logger.error(f"Error adding watermark: {str(e)}")
        return video_path

    finally:
        # Runs before the return value reaches the caller, so the output file
        # is fully flushed by the time it is used.
        if cap is not None:
            cap.release()
        if out is not None:
            out.release()


def _auth_headers(api_key):
    """Return the bearer-token header used by every MiniMax API call."""
    return {'authorization': f'Bearer {api_key}'}


def _finalize_generated_video(download_url, image_url, prompt, temp_dir):
    """Download the generated video, watermark it and try to add audio.

    Returns the path of the best available result (with audio if generation
    succeeded, otherwise watermark only), or an error message string if the
    download itself fails.
    """
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    result_info = {
        "timestamp": timestamp,
        "input_image": image_url,
        "output_video_url": download_url,
        "prompt": prompt
    }
    logger.info(f"Video generation result: {json.dumps(result_info, indent=2)}")

    video_response = requests.get(download_url, timeout=DOWNLOAD_TIMEOUT_SEC)
    if not video_response.ok:
        return "Failed to download video"

    output_path = os.path.join(temp_dir, f"output_{timestamp}.mp4")
    with open(output_path, 'wb') as f:
        f.write(video_response.content)

    final_path = add_watermark(output_path)

    # The duration is informational only; a probe failure must not discard
    # an otherwise usable video.
    try:
        video_duration = _probe_video_duration(final_path)
        logger.info(f"Original video duration: {video_duration} seconds")
    except ValueError as e:
        logger.warning(f"Could not determine video duration: {str(e)}")

    # Audio generation; the duration is derived from the video automatically.
    try:
        logger.info("Starting audio generation process")
        final_path_with_audio = video_to_audio(
            final_path,
            prompt=prompt,
            negative_prompt="music",
            seed=-1,
            num_steps=20,
            cfg_strength=4.5
        )
    except Exception as e:
        logger.error(f"Error in audio processing: {str(e)}")
        return final_path  # fall back to the watermark-only video

    # video_to_audio signals failure by returning its input path.
    if final_path_with_audio == final_path:
        logger.warning("Audio generation skipped, using original video")
        return final_path

    logger.info("Audio generation successful")
    try:
        if output_path != final_path:
            os.remove(output_path)
        if final_path != final_path_with_audio:
            os.remove(final_path)
    except Exception as e:
        logger.warning(f"Error cleaning up temporary files: {str(e)}")
    return final_path_with_audio


def generate_video(image, prompt):
    """Create a video with the MiniMax API, then watermark it and add audio.

    Args:
        image: Optional path of the first-frame image; uploaded to catbox to
            obtain a public URL when given.
        prompt: Text description of the video (may be empty).

    Returns:
        Path of the resulting video file on success, otherwise a
        human-readable error message string.
    """
    logger.info("Starting video generation with API")
    try:
        API_KEY = os.getenv("API_KEY", "").strip()
        if not API_KEY:
            return "API key not properly configured"

        temp_dir = "temp_videos"
        os.makedirs(temp_dir, exist_ok=True)

        image_url = None
        if image:
            image_url = upload_to_catbox(image)
            if not image_url:
                return "Failed to upload image"
            logger.info(f"Input image URL: {image_url}")

        generation_url = "https://api.minimaxi.chat/v1/video_generation"
        headers = {
            'authorization': f'Bearer {API_KEY}',
            'Content-Type': 'application/json'
        }

        payload = {
            "model": "video-01",
            "prompt": prompt if prompt else "",
            "prompt_optimizer": True
        }
        if image_url:
            payload["first_frame_image"] = image_url

        logger.info(f"Sending request with payload: {payload}")
        response = requests.post(generation_url, headers=headers, json=payload,
                                 timeout=REQUEST_TIMEOUT_SEC)

        if not response.ok:
            error_msg = f"Failed to create video generation task: {response.text}"
            logger.error(error_msg)
            return error_msg

        response_data = response.json()
        task_id = response_data.get('task_id')
        if not task_id:
            return "Failed to get task ID from response"

        query_url = "https://api.minimaxi.chat/v1/query/video_generation"
        max_attempts = 30

        # Poll every 10 seconds; transport errors and non-2xx replies count
        # as a used attempt rather than aborting the whole request.
        for _ in range(max_attempts):
            time.sleep(10)
            try:
                query_response = requests.get(
                    query_url,
                    params={'task_id': task_id},
                    headers=_auth_headers(API_KEY),
                    timeout=REQUEST_TIMEOUT_SEC
                )
            except requests.RequestException as e:
                logger.warning(f"Status query failed, retrying: {str(e)}")
                continue

            if not query_response.ok:
                continue

            status_data = query_response.json()
            status = status_data.get('status')

            if status == 'Success':
                file_id = status_data.get('file_id')
                if not file_id:
                    return "Failed to get file ID"

                retrieve_url = "https://api.minimaxi.chat/v1/files/retrieve"
                file_response = requests.get(
                    retrieve_url,
                    headers=_auth_headers(API_KEY),
                    params={'file_id': file_id},
                    timeout=REQUEST_TIMEOUT_SEC
                )
                if not file_response.ok:
                    return "Failed to retrieve video file"

                try:
                    file_data = file_response.json()
                    download_url = file_data.get('file', {}).get('download_url')
                    if not download_url:
                        return "Failed to get download URL"
                    return _finalize_generated_video(download_url, image_url, prompt, temp_dir)
                except Exception as e:
                    logger.error(f"Error processing video file: {str(e)}")
                    return "Error processing video file"

            elif status == 'Fail':
                return "Video generation failed"

        return "Timeout waiting for video generation"

    except Exception as e:
        logger.error(f"Error in video generation: {str(e)}")
        return f"Error in video generation process: {str(e)}"


css = """
footer {
    visibility: hidden;
}
.gradio-container {max-width: 1200px !important}
"""
def process_and_generate_video(image, prompt):
    """Gradio click handler: normalise the inputs and run video generation.

    Korean prompts are translated to English with the module-level
    ``translator``. The uploaded image is re-encoded as an RGB PNG in a
    uniquely named temporary file, which is removed again before returning.

    Args:
        image: File path of the uploaded first-frame image, or ``None``.
        prompt: Video description typed by the user (may be empty or ``None``).

    Returns:
        Whatever ``generate_video`` returns (a video path or an error message),
        or a short error message if the inputs could not be processed.
    """
    if image is None:
        return "Please upload an image"

    temp_path = None
    try:
        prompt = prompt or ""

        # Detect Hangul syllables and translate the prompt if any are present.
        contains_korean = any('가' <= char <= '힣' for char in prompt)
        if contains_korean:
            translated = translator(prompt)[0]['translation_text']
            logger.info(f"Translated prompt from '{prompt}' to '{translated}'")
            prompt = translated

        # A unique temp file avoids collisions between concurrent requests
        # (a timestamp-based name only has one-second resolution).
        with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp:
            temp_path = tmp.name
        with Image.open(image) as img:
            rgb_img = img if img.mode == 'RGB' else img.convert('RGB')
            rgb_img.save(temp_path, 'PNG')

        return generate_video(temp_path, prompt)

    except Exception as e:
        logger.error(f"Error processing image: {str(e)}")
        return "Error processing image"

    finally:
        # Best-effort cleanup, executed on success and failure alike.
        if temp_path is not None and os.path.exists(temp_path):
            try:
                os.remove(temp_path)
            except OSError as e:
                logger.warning(f"Could not remove temporary file {temp_path}: {str(e)}")


with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange", css=css) as demo:
    with gr.Row():
        # Left column: inputs and the trigger button.
        with gr.Column(scale=3):
            video_prompt = gr.Textbox(
                label="Video Description",
                placeholder="Enter video description...",
                lines=3
            )
            upload_image = gr.Image(type="filepath", label="Upload First Frame Image")
            video_generate_btn = gr.Button("🎬 Generate Video")

        # Right column: the generated result.
        with gr.Column(scale=4):
            video_output = gr.Video(label="Generated Video")

    video_generate_btn.click(
        process_and_generate_video,
        inputs=[upload_image, video_prompt],
        outputs=video_output
    )

if __name__ == "__main__":
    # Report the compute device once more at start-up.
    if torch.cuda.is_available():
        logger.info(f"Using GPU: {torch.cuda.get_device_name(0)}")
    else:
        logger.warning("GPU not available, using CPU")

    demo.launch(server_name="0.0.0.0", server_port=7860, share=False)