image_to_video / app.py
aiqcamp's picture
Update app.py
9ae8acd verified
raw
history blame
10.7 kB
import os
import time
from datetime import datetime
import gradio as gr
import torch
import logging
import requests
from pathlib import Path
import cv2
from PIL import Image
import json
import spaces
import torchaudio
import tempfile
try:
    import mmaudio
except ImportError:
    # mmaudio is not importable: install the bundled package (the current
    # directory) in editable mode with the *same* interpreter running this app.
    # check_call raises if pip fails, instead of silently continuing like
    # os.system did and then failing later with a confusing ImportError.
    import importlib
    import subprocess
    import sys

    subprocess.check_call([sys.executable, "-m", "pip", "install", "-e", "."])
    # Let the import system notice files/finders that pip just added.
    importlib.invalidate_caches()
    import mmaudio
from mmaudio.eval_utils import (ModelConfig, all_model_cfg, generate, load_video, make_video,
setup_eval_logging)
from mmaudio.model.flow_matching import FlowMatching
from mmaudio.model.networks import MMAudio, get_my_mmaudio
from mmaudio.model.sequence_config import SequenceConfig
from mmaudio.model.utils.features_utils import FeaturesUtils
# ์˜ค๋””์˜ค ๋ชจ๋ธ ์„ค์ •
device = 'cuda'
dtype = torch.bfloat16
model: ModelConfig = all_model_cfg['large_44k_v2']
model.download_if_needed()
output_dir = Path('./output/gradio')
setup_eval_logging()
net, feature_utils, seq_cfg = get_model() # get_model ํ•จ์ˆ˜๋Š” ์ด์ „์— ์ œ๊ณต๋œ ์ฝ”๋“œ ์‚ฌ์šฉ
# Logging setup.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# API settings.
# catbox.moe account hash sent with every upload. It can be overridden through
# the CATBOX_USER_HASH environment variable so deployments do not have to rely
# on the value committed to source; the old literal remains the default.
CATBOX_USER_HASH = os.getenv("CATBOX_USER_HASH", "30f52c895fd9d9cb387eee489")
# NOTE(review): despite its name this holds the value of API_KEY (the key
# generate_video() sends to the MiniMax API), and nothing in this file reads it.
# generate_video() re-reads API_KEY on each call instead.
REPLICATE_API_TOKEN = os.getenv("API_KEY")
def upload_to_catbox(file_path, timeout=60):
    """Upload a file to catbox.moe and return its public URL.

    A file whose extension is not one of the known image types is first
    re-encoded as an RGB PNG next to the original. That temporary PNG is
    uploaded instead and deleted afterwards.

    Args:
        file_path: Path of the file to upload.
        timeout: Seconds to wait for the catbox.moe HTTP request before
            giving up (previously the request could block forever).

    Returns:
        The URL returned by catbox.moe, or ``None`` if conversion or upload
        failed. Failures are logged rather than raised.
    """
    mime_types = {
        '.jpg': 'image/jpeg',
        '.jpeg': 'image/jpeg',
        '.png': 'image/png',
        '.gif': 'image/gif',
        '.webp': 'image/webp',
        '.jfif': 'image/jpeg'
    }
    converted_path = None  # temporary PNG we created and must clean up
    try:
        logger.info(f"Preparing to upload file: {file_path}")
        url = "https://catbox.moe/user/api.php"
        file_extension = Path(file_path).suffix.lower()

        if file_extension not in mime_types:
            try:
                with Image.open(file_path) as img:
                    rgb_img = img if img.mode == 'RGB' else img.convert('RGB')
                    # with_suffix only touches the file name, unlike rsplit('.'),
                    # which broke on directory names containing a dot.
                    converted_path = str(Path(file_path).with_suffix('.png'))
                    rgb_img.save(converted_path, 'PNG')
            except Exception as e:
                logger.error(f"Failed to convert image: {str(e)}")
                return None
            file_path = converted_path
            file_extension = '.png'
            logger.info(f"Converted image to PNG: {file_path}")

        data = {
            'reqtype': 'fileupload',
            'userhash': CATBOX_USER_HASH
        }
        # The context manager closes the upload handle (it used to leak).
        with open(file_path, 'rb') as upload_file:
            files = {
                'fileToUpload': (
                    os.path.basename(file_path),
                    upload_file,
                    mime_types.get(file_extension, 'application/octet-stream')
                )
            }
            response = requests.post(url, files=files, data=data, timeout=timeout)

        # catbox returns the URL as plain text, possibly with trailing whitespace.
        file_url = response.text.strip()
        if response.status_code == 200 and file_url.startswith('http'):
            logger.info(f"File uploaded successfully: {file_url}")
            return file_url
        raise RuntimeError(f"Upload failed: {response.text}")
    except Exception as e:
        logger.error(f"File upload error: {str(e)}")
        return None
    finally:
        if converted_path and os.path.exists(converted_path):
            try:
                os.remove(converted_path)
            except OSError as e:
                # Best effort: a leftover temp PNG must not mask the upload result.
                logger.warning(f"Could not remove temporary file {converted_path}: {e}")
def add_watermark(video_path, output_path=None):
    """Burn the text "GiniGEN.AI" into the bottom-right corner of every frame using OpenCV.

    Args:
        video_path: Path of the input video.
        output_path: Where to write the watermarked MP4 (``mp4v`` codec).
            Defaults to ``<input without extension>_watermarked.mp4``. The
            old fixed "watermarked_output.mp4" let concurrent requests
            overwrite each other's output.

    Returns:
        The path of the watermarked video. If the input cannot be read or the
        output cannot be written, the error is logged and ``video_path`` is
        returned unchanged.
    """
    if output_path is None:
        output_path = os.path.splitext(video_path)[0] + "_watermarked.mp4"

    cap = None
    out = None
    try:
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            raise IOError(f"Cannot open video: {video_path}")

        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        # Keep the fractional frame rate (e.g. 29.97). int() truncation changed
        # the output's duration. Fall back to 30 when the container reports
        # no usable fps (0 or NaN both fail the > 0 test).
        fps = cap.get(cv2.CAP_PROP_FPS)
        if not fps > 0:
            fps = 30.0

        text = "GiniGEN.AI"
        font = cv2.FONT_HERSHEY_SIMPLEX
        font_scale = height * 0.05 / 30  # scale the text with the frame height
        thickness = 2
        color = (255, 255, 255)
        (text_width, text_height), _ = cv2.getTextSize(text, font, font_scale, thickness)
        margin = int(height * 0.02)
        # Anchor is the text's bottom-left corner; clamp so it stays on-frame.
        x_pos = max(0, width - text_width - margin)
        y_pos = height - margin

        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
        if not out.isOpened():
            raise IOError(f"Cannot open video writer for: {output_path}")

        while True:
            ret, frame = cap.read()
            if not ret:
                break
            cv2.putText(frame, text, (x_pos, y_pos), font, font_scale, color, thickness)
            out.write(frame)
        return output_path
    except Exception as e:
        logger.error(f"Error adding watermark: {str(e)}")
        return video_path
    finally:
        # Release on every path. The writer only finalizes the file on release,
        # and this runs before the caller receives output_path.
        if cap is not None:
            cap.release()
        if out is not None:
            out.release()
_MINIMAX_API_BASE = "https://api.minimaxi.chat/v1"
_REQUEST_TIMEOUT_SECONDS = 60   # per HTTP request to MiniMax / the download host
_POLL_INTERVAL_SECONDS = 10     # wait before each task-status query
_MAX_POLL_ATTEMPTS = 30         # total status queries before giving up


class _VideoGenerationError(Exception):
    """Expected failure whose message generate_video() returns to its caller verbatim."""


def _create_generation_task(auth_headers, prompt, image_url):
    """Submit a MiniMax ``video-01`` generation task and return its task id."""
    headers = dict(auth_headers, **{'Content-Type': 'application/json'})
    payload = {
        "model": "video-01",
        "prompt": prompt if prompt else "",
        "prompt_optimizer": True
    }
    if image_url:
        payload["first_frame_image"] = image_url
    logger.info(f"Sending request with payload: {payload}")
    response = requests.post(
        f"{_MINIMAX_API_BASE}/video_generation",
        headers=headers,
        json=payload,
        timeout=_REQUEST_TIMEOUT_SECONDS,
    )
    if not response.ok:
        error_msg = f"Failed to create video generation task: {response.text}"
        logger.error(error_msg)
        raise _VideoGenerationError(error_msg)
    task_id = response.json().get('task_id')
    if not task_id:
        raise _VideoGenerationError("Failed to get task ID from response")
    return task_id


def _wait_for_file_id(auth_headers, task_id):
    """Poll a generation task until it finishes and return the produced file id.

    Every poll counts against _MAX_POLL_ATTEMPTS whatever status it reports.
    Previously only failed HTTP responses were counted, so a task stuck in a
    non-terminal or failed state polled forever.
    """
    query_url = f"{_MINIMAX_API_BASE}/query/video_generation"
    for attempt in range(1, _MAX_POLL_ATTEMPTS + 1):
        time.sleep(_POLL_INTERVAL_SECONDS)
        try:
            query_response = requests.get(
                query_url,
                headers=auth_headers,
                params={'task_id': task_id},
                timeout=_REQUEST_TIMEOUT_SECONDS,
            )
            query_response.raise_for_status()
            status_data = query_response.json()
        except (requests.RequestException, ValueError) as e:
            # Transient network / HTTP / JSON problem: retry on the next poll.
            logger.warning(f"Status query {attempt}/{_MAX_POLL_ATTEMPTS} failed: {e}")
            continue

        status = status_data.get('status')
        logger.info(f"Task {task_id} status: {status} (poll {attempt}/{_MAX_POLL_ATTEMPTS})")
        if status == 'Success':
            file_id = status_data.get('file_id')
            if not file_id:
                raise _VideoGenerationError("Failed to get file ID")
            return file_id
        if status == 'Fail':
            raise _VideoGenerationError(f"Video generation failed: {status_data}")
    raise _VideoGenerationError("Timed out waiting for video generation")


def _get_download_url(auth_headers, file_id):
    """Resolve a MiniMax file id to the URL the video can be downloaded from."""
    file_response = requests.get(
        f"{_MINIMAX_API_BASE}/files/retrieve",
        headers=auth_headers,
        params={'file_id': file_id},
        timeout=_REQUEST_TIMEOUT_SECONDS,
    )
    if not file_response.ok:
        raise _VideoGenerationError("Failed to retrieve video file")
    download_url = file_response.json().get('file', {}).get('download_url')
    if not download_url:
        raise _VideoGenerationError("Failed to get download URL")
    return download_url


def _download_video(download_url, target_dir):
    """Download the generated video into ``target_dir`` and return the local path."""
    video_response = requests.get(download_url, timeout=_REQUEST_TIMEOUT_SECONDS)
    if not video_response.ok:
        raise _VideoGenerationError("Failed to download video")
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_path = os.path.join(target_dir, f"output_{timestamp}.mp4")
    with open(output_path, 'wb') as f:
        f.write(video_response.content)
    return output_path


def _discard_file(path):
    """Best-effort removal of an intermediate file; failures are only logged."""
    try:
        os.remove(path)
    except OSError as e:
        logger.warning(f"Could not remove temporary file {path}: {e}")


def _watermark_and_add_audio(video_path, prompt):
    """Watermark the downloaded video, then try to add a generated soundtrack.

    Each intermediate file is deleted once a newer version replaces it. If
    audio generation fails, the watermarked (silent) video is returned.
    """
    watermarked_path = add_watermark(video_path)
    if watermarked_path != video_path:
        _discard_file(video_path)
    try:
        final_path = video_to_audio(
            watermarked_path,
            prompt=prompt or "",
            negative_prompt="music",
            seed=-1,
            num_steps=25,
            cfg_strength=4.5,
            duration=8
        )
    except Exception as e:
        logger.error(f"Error in audio processing: {str(e)}")
        return watermarked_path
    # Compare as strings: video_to_audio may return a Path. A plain != between a
    # Path and a str is always true, which could delete the very file being returned.
    if str(final_path) != str(watermarked_path):
        _discard_file(watermarked_path)
    return final_path


def generate_video(image, prompt):
    """Generate a video from a first-frame image and a prompt via the MiniMax API.

    Steps:
    1. Upload the image to catbox.moe.
    2. Create a MiniMax ``video-01`` task.
    3. Poll the task until it finishes.
    4. Download the result.
    5. Watermark the video and try to add generated audio.

    Args:
        image: Path of the first-frame image, or a falsy value to generate
            from the prompt alone.
        prompt: Text description of the video; empty/``None`` sends "".

    Returns:
        The path of the final video on success. On failure, a human-readable
        error message string; it never returns ``None``.
    """
    logger.info("Starting video generation with API")
    try:
        api_key = os.getenv("API_KEY", "").strip()
        if not api_key:
            return "API key not properly configured"
        temp_dir = "temp_videos"
        os.makedirs(temp_dir, exist_ok=True)

        image_url = None
        if image:
            image_url = upload_to_catbox(image)
            if not image_url:
                return "Failed to upload image"
            logger.info(f"Input image URL: {image_url}")

        auth_headers = {'authorization': f'Bearer {api_key}'}
        task_id = _create_generation_task(auth_headers, prompt, image_url)
        file_id = _wait_for_file_id(auth_headers, task_id)

        try:
            download_url = _get_download_url(auth_headers, file_id)
            result_info = {
                "timestamp": datetime.now().strftime("%Y%m%d_%H%M%S"),
                "input_image": image_url,
                "output_video_url": download_url,
                "prompt": prompt
            }
            logger.info(f"Video generation result: {json.dumps(result_info, indent=2)}")
            output_path = _download_video(download_url, temp_dir)
            return _watermark_and_add_audio(output_path, prompt)
        except _VideoGenerationError:
            raise  # already carries the message meant for the caller
        except Exception as e:
            logger.error(f"Error processing video file: {str(e)}")
            return "Error processing video file"
    except _VideoGenerationError as e:
        return str(e)
    except Exception as e:
        logger.error(f"Error in video generation: {str(e)}")
        return f"Error in video generation: {str(e)}"
# Custom page CSS for the Gradio app: hide the footer and cap the container width.
css = (
    "\n"
    "footer {display: none}\n"
    ".gradio-container {max-width: 1200px !important}\n"
)
with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange", css=css) as demo:
    # Emoji repaired from mis-encoded text ("๐ŸŽฅ" -> 🎥, "๐ŸŽฌ" -> 🎬).
    gr.HTML('<div style="text-align: center; font-size: 1.5em; margin: 10px 0;">🎥 Image to Video Generator</div>')
    with gr.Row():
        with gr.Column(scale=3):
            video_prompt = gr.Textbox(
                label="Video Description",
                placeholder="Enter video description...",
                lines=3
            )
            upload_image = gr.Image(type="filepath", label="Upload First Frame Image")
            video_generate_btn = gr.Button("🎬 Generate Video")
        with gr.Column(scale=4):
            video_output = gr.Video(label="Generated Video")

    def process_and_generate_video(image, prompt):
        """Normalize the uploaded image to an RGB PNG and run generate_video on it.

        Args:
            image: Filepath of the uploaded image (``gr.Image(type="filepath")``),
                or ``None`` when nothing was uploaded.
            prompt: Text from the description box.

        Returns:
            Whatever generate_video returns, or an error message string.
        """
        if image is None:
            return "Please upload an image"
        # A unique temp file per call. The old temp_<seconds>.png name collided
        # when two users submitted within the same second.
        fd, temp_path = tempfile.mkstemp(prefix="temp_", suffix=".png")
        os.close(fd)
        try:
            with Image.open(image) as img:
                rgb_img = img if img.mode == 'RGB' else img.convert('RGB')
                rgb_img.save(temp_path, 'PNG')
            return generate_video(temp_path, prompt)
        except Exception as e:
            logger.error(f"Error processing image: {str(e)}")
            return "Error processing image"
        finally:
            # Clean up even when conversion/generation raised.
            try:
                os.remove(temp_path)
            except OSError as e:
                logger.warning(f"Could not remove temporary file {temp_path}: {e}")

    video_generate_btn.click(
        process_and_generate_video,
        inputs=[upload_image, video_prompt],
        outputs=video_output
    )
if __name__ == "__main__":
    # Listen on all interfaces at port 7860, with no public share link.
    launch_kwargs = dict(server_name="0.0.0.0", server_port=7860, share=False)
    demo.launch(**launch_kwargs)