# image_to_video / app.py
# 1. Set up logging first
import logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# 2. Remaining imports
import os
import time
from datetime import datetime
import gradio as gr
# GPU initialization
import torch
if torch.cuda.is_available():
    torch.cuda.init()
    device = torch.device('cuda')
    logger.info(f"Using GPU: {torch.cuda.get_device_name(0)}")
else:
    device = torch.device('cpu')
    logger.warning("GPU not available, using CPU")
import requests
from pathlib import Path
import cv2
from PIL import Image
import json
import spaces
import torchaudio
import tempfile
try:
    import mmaudio
except ImportError:
    os.system("pip install -e .")
    import mmaudio
from mmaudio.eval_utils import (ModelConfig, all_model_cfg, generate, load_video,
                                make_video, setup_eval_logging)
from mmaudio.model.flow_matching import FlowMatching
from mmaudio.model.networks import MMAudio, get_my_mmaudio
from mmaudio.model.sequence_config import SequenceConfig
from mmaudio.model.utils.features_utils import FeaturesUtils
# ์ƒ๋‹จ์— ๋ฒˆ์—ญ ๋ชจ๋ธ import ์ถ”๊ฐ€
from transformers import pipeline
translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en")
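# Illustrative call (the sample string below is an assumption, not from the source):
#   translator("해변의 일몰")[0]['translation_text']  # -> e.g. "Sunset on the beach"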
# 3. API configuration
CATBOX_USER_HASH = "30f52c895fd9d9cb387eee489"
REPLICATE_API_TOKEN = os.getenv("API_KEY")  # unused below; generate_video() re-reads API_KEY itself
# 4. Audio model configuration
dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
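# bfloat16 roughly halves memory use relative to float32 and is well supported
# on recent NVIDIA GPUs; float32 is the safe fallback on CPU.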
# 5. Define get_model
def get_model() -> tuple[MMAudio, FeaturesUtils, SequenceConfig]:
    # Relies on the module-level `model` config defined below in section 6;
    # this works because the function is only called after that assignment.
    seq_cfg = model.seq_cfg
    net: MMAudio = get_my_mmaudio(model.model_name).to(device)
    if torch.cuda.is_available():
        net = net.to(dtype)
    net.eval()
    net.load_weights(torch.load(model.model_path, map_location=device, weights_only=True))
    logger.info(f'Loaded weights from {model.model_path}')

    feature_utils = FeaturesUtils(
        tod_vae_ckpt=model.vae_path,
        synchformer_ckpt=model.synchformer_ckpt,
        enable_conditions=True,
        mode=model.mode,
        bigvgan_vocoder_ckpt=model.bigvgan_16k_path,
        need_vae_encoder=False
    ).to(device)
    if torch.cuda.is_available():
        feature_utils = feature_utils.to(dtype)
    feature_utils.eval()

    return net, feature_utils, seq_cfg
# 6. Model initialization
model: ModelConfig = all_model_cfg['large_44k_v2']
model.download_if_needed()
output_dir = Path('./output/gradio')
setup_eval_logging()
net, feature_utils, seq_cfg = get_model()
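# The network, feature utilities, and sequence config are loaded once at import
# time, so every Gradio request served by this process reuses the same weights.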
@spaces.GPU(duration=30)
@torch.inference_mode()
def video_to_audio(video_path: str, prompt: str, negative_prompt: str = "music",
                   seed: int = -1, num_steps: int = 15,
                   cfg_strength: float = 4.0, target_duration: float = None):
    try:
        logger.info("Starting audio generation process")
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # Probe the video length with OpenCV
        cap = cv2.VideoCapture(video_path)
        fps = cap.get(cv2.CAP_PROP_FPS)
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        video_duration = total_frames / fps
        cap.release()

        # Always use the measured video length as target_duration
        # (any caller-supplied value is overridden here)
        target_duration = video_duration
        logger.info(f"Video duration: {target_duration} seconds")

        rng = torch.Generator(device=device)
        if seed >= 0:
            rng.manual_seed(seed)
        else:
            rng.seed()

        fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=num_steps)

        # Load the video for the measured duration
        video_info = load_video(video_path, duration_sec=target_duration)
        if video_info is None:
            logger.error("Failed to load video")
            return video_path

        clip_frames = video_info.clip_frames
        sync_frames = video_info.sync_frames
        actual_duration = video_info.duration_sec
        if clip_frames is None or sync_frames is None:
            logger.error("Failed to extract frames from video")
            return video_path

        # Trim to the actual number of frames in the video
        clip_frames = clip_frames[:int(actual_duration * video_info.fps)]
        sync_frames = sync_frames[:int(actual_duration * video_info.fps)]
        clip_frames = clip_frames.unsqueeze(0).to(device, dtype=torch.float16)
        sync_frames = sync_frames.unsqueeze(0).to(device, dtype=torch.float16)

        # Update the sequence config to match the clip length
        seq_cfg.duration = actual_duration
        net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)

        logger.info(f"Generating audio for {actual_duration} seconds...")
        with torch.cuda.amp.autocast():
            audios = generate(clip_frames,
                              sync_frames,
                              [prompt],
                              negative_text=[negative_prompt],
                              feature_utils=feature_utils,
                              net=net,
                              fm=fm,
                              rng=rng,
                              cfg_strength=cfg_strength)
        if audios is None:
            logger.error("Failed to generate audio")
            return video_path

        audio = audios.float().cpu()[0]
        output_path = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4').name
        logger.info(f"Creating final video with audio at {output_path}")
        make_video(video_info, output_path, audio, sampling_rate=seq_cfg.sampling_rate)
        torch.cuda.empty_cache()

        if not os.path.exists(output_path):
            logger.error("Failed to create output video")
            return video_path

        logger.info(f'Successfully saved video with audio to {output_path}')
        return output_path
    except Exception as e:
        logger.error(f"Error in video_to_audio: {str(e)}")
        torch.cuda.empty_cache()
        return video_path
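# Minimal usage sketch (hypothetical path; on any failure the function returns
# the input path unchanged, so callers can compare paths to detect success):
#   out_path = video_to_audio("clip.mp4", prompt="waves crashing on rocks")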
def upload_to_catbox(file_path):
    """Upload a file via the catbox.moe API and return its public URL."""
    try:
        logger.info(f"Preparing to upload file: {file_path}")
        url = "https://catbox.moe/user/api.php"
        mime_types = {
            '.jpg': 'image/jpeg',
            '.jpeg': 'image/jpeg',
            '.png': 'image/png',
            '.gif': 'image/gif',
            '.webp': 'image/webp',
            '.jfif': 'image/jpeg'
        }
        file_extension = Path(file_path).suffix.lower()

        # Convert unsupported formats to PNG before uploading
        if file_extension not in mime_types:
            try:
                img = Image.open(file_path)
                if img.mode != 'RGB':
                    img = img.convert('RGB')
                new_path = file_path.rsplit('.', 1)[0] + '.png'
                img.save(new_path, 'PNG')
                file_path = new_path
                file_extension = '.png'
                logger.info(f"Converted image to PNG: {file_path}")
            except Exception as e:
                logger.error(f"Failed to convert image: {str(e)}")
                return None

        data = {
            'reqtype': 'fileupload',
            'userhash': CATBOX_USER_HASH
        }
        # Use a context manager so the file handle is always closed
        with open(file_path, 'rb') as f:
            files = {
                'fileToUpload': (
                    os.path.basename(file_path),
                    f,
                    mime_types.get(file_extension, 'application/octet-stream')
                )
            }
            response = requests.post(url, files=files, data=data)

        if response.status_code == 200 and response.text.startswith('http'):
            file_url = response.text
            logger.info(f"File uploaded successfully: {file_url}")
            return file_url
        else:
            raise Exception(f"Upload failed: {response.text}")
    except Exception as e:
        logger.error(f"File upload error: {str(e)}")
        return None
    finally:
        # Clean up the temporary PNG if one was created
        if 'new_path' in locals() and os.path.exists(new_path):
            try:
                os.remove(new_path)
            except OSError:
                pass
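# On success, catbox.moe returns the bare direct-link URL as the response body
# (e.g. https://files.catbox.moe/<id>.png); the startswith('http') check above
# depends on that plain-text contract.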
def add_watermark(video_path):
    """Add a text watermark to a video using OpenCV."""
    try:
        cap = cv2.VideoCapture(video_path)
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = int(cap.get(cv2.CAP_PROP_FPS))

        text = "GiniGEN.AI"
        font = cv2.FONT_HERSHEY_SIMPLEX
        font_scale = height * 0.05 / 30  # scale the text with the frame height
        thickness = 2
        color = (255, 255, 255)

        (text_width, text_height), _ = cv2.getTextSize(text, font, font_scale, thickness)
        margin = int(height * 0.02)
        x_pos = width - text_width - margin   # bottom-right corner
        y_pos = height - margin

        output_path = "watermarked_output.mp4"
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))

        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            cv2.putText(frame, text, (x_pos, y_pos), font, font_scale, color, thickness)
            out.write(frame)

        cap.release()
        out.release()
        return output_path
    except Exception as e:
        logger.error(f"Error adding watermark: {str(e)}")
        return video_path
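# Note: cv2.VideoWriter encodes frames only and silently drops any audio track,
# which is why generate_video() watermarks first and adds the MMAudio
# soundtrack afterwards.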
def generate_video(image, prompt):
    logger.info("Starting video generation with API")
    try:
        API_KEY = os.getenv("API_KEY", "").strip()
        if not API_KEY:
            return "API key not properly configured"

        temp_dir = "temp_videos"
        os.makedirs(temp_dir, exist_ok=True)

        # Upload the first-frame image so the API can fetch it by URL
        image_url = None
        if image:
            image_url = upload_to_catbox(image)
            if not image_url:
                return "Failed to upload image"
            logger.info(f"Input image URL: {image_url}")

        generation_url = "https://api.minimaxi.chat/v1/video_generation"
        headers = {
            'authorization': f'Bearer {API_KEY}',
            'Content-Type': 'application/json'
        }
        payload = {
            "model": "video-01",
            "prompt": prompt if prompt else "",
            "prompt_optimizer": True
        }
        if image_url:
            payload["first_frame_image"] = image_url

        logger.info(f"Sending request with payload: {payload}")
        response = requests.post(generation_url, headers=headers, json=payload)
        if not response.ok:
            error_msg = f"Failed to create video generation task: {response.text}"
            logger.error(error_msg)
            return error_msg

        response_data = response.json()
        task_id = response_data.get('task_id')
        if not task_id:
            return "Failed to get task ID from response"

        # Poll the task status every 10 seconds, up to 30 attempts (~5 minutes)
        query_url = "https://api.minimaxi.chat/v1/query/video_generation"
        max_attempts = 30
        attempt = 0
        while attempt < max_attempts:
            time.sleep(10)
            query_response = requests.get(
                f"{query_url}?task_id={task_id}",
                headers={'authorization': f'Bearer {API_KEY}'}
            )
            if not query_response.ok:
                attempt += 1
                continue

            status_data = query_response.json()
            status = status_data.get('status')

            if status == 'Success':
                file_id = status_data.get('file_id')
                if not file_id:
                    return "Failed to get file ID"

                retrieve_url = "https://api.minimaxi.chat/v1/files/retrieve"
                params = {'file_id': file_id}
                file_response = requests.get(
                    retrieve_url,
                    headers={'authorization': f'Bearer {API_KEY}'},
                    params=params
                )
                if not file_response.ok:
                    return "Failed to retrieve video file"

                try:
                    file_data = file_response.json()
                    download_url = file_data.get('file', {}).get('download_url')
                    if not download_url:
                        return "Failed to get download URL"

                    result_info = {
                        "timestamp": datetime.now().strftime("%Y%m%d_%H%M%S"),
                        "input_image": image_url,
                        "output_video_url": download_url,
                        "prompt": prompt
                    }
                    logger.info(f"Video generation result: {json.dumps(result_info, indent=2)}")

                    video_response = requests.get(download_url)
                    if not video_response.ok:
                        return "Failed to download video"

                    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
                    output_path = os.path.join(temp_dir, f"output_{timestamp}.mp4")
                    with open(output_path, 'wb') as f:
                        f.write(video_response.content)

                    final_path = add_watermark(output_path)

                    # Check the downloaded video's duration
                    cap = cv2.VideoCapture(final_path)
                    fps = cap.get(cv2.CAP_PROP_FPS)
                    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
                    video_duration = total_frames / fps
                    cap.release()
                    logger.info(f"Original video duration: {video_duration} seconds")

                    # Add the audio track
                    try:
                        logger.info("Starting audio generation process")
                        final_path_with_audio = video_to_audio(
                            final_path,
                            prompt=prompt,
                            negative_prompt="music",
                            seed=-1,
                            num_steps=20,
                            cfg_strength=4.5
                            # target_duration omitted: the video's own length is used
                        )
                        if final_path_with_audio != final_path:
                            logger.info("Audio generation successful")
                            # Remove intermediate files once the final video exists
                            try:
                                if output_path != final_path:
                                    os.remove(output_path)
                                if final_path != final_path_with_audio:
                                    os.remove(final_path)
                            except Exception as e:
                                logger.warning(f"Error cleaning up temporary files: {str(e)}")
                            return final_path_with_audio
                        else:
                            logger.warning("Audio generation skipped, using original video")
                            return final_path
                    except Exception as e:
                        logger.error(f"Error in audio processing: {str(e)}")
                        return final_path  # on audio failure, return the watermarked-only video
                except Exception as e:
                    logger.error(f"Error processing video file: {str(e)}")
                    return "Error processing video file"
            elif status == 'Fail':
                return "Video generation failed"

            attempt += 1

        return "Timeout waiting for video generation"
    except Exception as e:
        logger.error(f"Error in video generation: {str(e)}")
        return f"Error in video generation process: {str(e)}"
css = """
footer {
visibility: hidden;
}
.gradio-container {max-width: 1200px !important}
"""
with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange", css=css) as demo:
    with gr.Row():
        with gr.Column(scale=3):
            video_prompt = gr.Textbox(
                label="Video Description",
                placeholder="Enter video description...",
                lines=3
            )
            upload_image = gr.Image(type="filepath", label="Upload First Frame Image")
            video_generate_btn = gr.Button("🎬 Generate Video")
        with gr.Column(scale=4):
            video_output = gr.Video(label="Generated Video")

    # Handler for the Generate button
    def process_and_generate_video(image, prompt):
        if image is None:
            return "Please upload an image"
        try:
            # Detect Korean characters in the prompt and translate to English
            contains_korean = any(ord('가') <= ord(char) <= ord('힣') for char in prompt)
            if contains_korean:
                translated = translator(prompt)[0]['translation_text']
                logger.info(f"Translated prompt from '{prompt}' to '{translated}'")
                prompt = translated

            # Normalize the upload to an RGB PNG before passing it on
            img = Image.open(image)
            if img.mode != 'RGB':
                img = img.convert('RGB')
            temp_path = f"temp_{int(time.time())}.png"
            img.save(temp_path, 'PNG')

            result = generate_video(temp_path, prompt)
            try:
                os.remove(temp_path)
            except OSError:
                pass
            return result
        except Exception as e:
            logger.error(f"Error processing image: {str(e)}")
            return "Error processing image"

    video_generate_btn.click(
        process_and_generate_video,
        inputs=[upload_image, video_prompt],
        outputs=video_output
    )
if __name__ == "__main__":
    # Confirm GPU availability before launching
    if torch.cuda.is_available():
        logger.info(f"Using GPU: {torch.cuda.get_device_name(0)}")
    else:
        logger.warning("GPU not available, using CPU")
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False)