Spaces:
Runtime error
Runtime error
import os | |
import time | |
from datetime import datetime | |
import gradio as gr | |
import torch | |
import logging | |
import requests | |
from pathlib import Path | |
import cv2 | |
from PIL import Image | |
import json | |
import spaces | |
import torchaudio | |
import tempfile | |
try:
    import mmaudio
except ImportError:
    # mmaudio is not importable yet: install the package found in the current
    # directory in editable mode, then retry the import.
    import importlib
    import subprocess
    import sys

    # Run pip through the interpreter that is executing this script so the
    # package lands in the right environment; check_call raises if pip fails
    # instead of silently continuing to a second ImportError.
    subprocess.check_call([sys.executable, "-m", "pip", "install", "-e", "."])
    # Make the import system notice the freshly installed package.
    importlib.invalidate_caches()
    import mmaudio
from mmaudio.eval_utils import (ModelConfig, all_model_cfg, generate, load_video, make_video, | |
setup_eval_logging) | |
from mmaudio.model.flow_matching import FlowMatching | |
from mmaudio.model.networks import MMAudio, get_my_mmaudio | |
from mmaudio.model.sequence_config import SequenceConfig | |
from mmaudio.model.utils.features_utils import FeaturesUtils | |
# Audio model configuration
device = 'cuda'
dtype = torch.bfloat16
model: ModelConfig = all_model_cfg['large_44k_v2']
model.download_if_needed()
output_dir = Path('./output/gradio')
setup_eval_logging()
# NOTE(review): get_model is neither defined nor imported in the visible part
# of this file (the original comment said to reuse previously provided code).
# It must be supplied before this line or the module fails with NameError.
net, feature_utils, seq_cfg = get_model()

# Logging configuration
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# API configuration
# The catbox user hash is a credential: allow it to be supplied through the
# environment, falling back to the previously hard-coded value so existing
# deployments keep working unchanged.
CATBOX_USER_HASH = os.getenv("CATBOX_USER_HASH", "30f52c895fd9d9cb387eee489")
REPLICATE_API_TOKEN = os.getenv("API_KEY")
def upload_to_catbox(file_path):
    """Upload a file through the catbox.moe API and return its URL.

    Files whose extension is not one of the known image types are first
    re-encoded as an RGB PNG stored next to the original; that temporary PNG
    is removed again once the upload attempt has finished.

    Args:
        file_path: Path (str or path-like) of the local file to upload.

    Returns:
        The URL text returned by catbox.moe on success, or None when the
        image conversion or the upload fails.
    """
    url = "https://catbox.moe/user/api.php"
    mime_types = {
        '.jpg': 'image/jpeg',
        '.jpeg': 'image/jpeg',
        '.png': 'image/png',
        '.gif': 'image/gif',
        '.webp': 'image/webp',
        '.jfif': 'image/jpeg',
    }
    # Path of the temporary PNG created below (if any); removed in `finally`.
    converted_path = None
    try:
        logger.info(f"Preparing to upload file: {file_path}")
        file_path = str(file_path)
        file_extension = Path(file_path).suffix.lower()

        if file_extension not in mime_types:
            try:
                with Image.open(file_path) as img:
                    # with_suffix only touches the final path component, so a
                    # dot inside a directory name cannot truncate the path.
                    converted_path = str(Path(file_path).with_suffix('.png'))
                    rgb_img = img if img.mode == 'RGB' else img.convert('RGB')
                    rgb_img.save(converted_path, 'PNG')
                file_path = converted_path
                file_extension = '.png'
                logger.info(f"Converted image to PNG: {file_path}")
            except Exception as e:
                logger.error(f"Failed to convert image: {str(e)}")
                return None

        data = {
            'reqtype': 'fileupload',
            'userhash': CATBOX_USER_HASH,
        }
        # Keep the handle in a context manager so it is closed even when the
        # request raises; the timeout prevents a stalled upload from hanging.
        with open(file_path, 'rb') as upload_file:
            files = {
                'fileToUpload': (
                    os.path.basename(file_path),
                    upload_file,
                    mime_types.get(file_extension, 'application/octet-stream'),
                )
            }
            response = requests.post(url, files=files, data=data, timeout=60)

        if response.status_code == 200 and response.text.startswith('http'):
            file_url = response.text
            logger.info(f"File uploaded successfully: {file_url}")
            return file_url

        logger.error(f"File upload error: Upload failed: {response.text}")
        return None
    except Exception as e:
        logger.error(f"File upload error: {str(e)}")
        return None
    finally:
        if converted_path and os.path.exists(converted_path):
            try:
                os.remove(converted_path)
            except OSError as e:
                # Best-effort cleanup: a leftover temp file must not turn a
                # finished upload into a failure.
                logger.warning(f"Could not remove temporary file {converted_path}: {e}")
def add_watermark(video_path):
    """Draw the "GiniGEN.AI" text watermark onto every frame using OpenCV.

    The watermarked copy is written next to the input as
    "<input stem>_watermarked.mp4", so different inputs never overwrite each
    other's output.

    Args:
        video_path: Path of the video file to watermark.

    Returns:
        Path of the watermarked video, or `video_path` unchanged when the
        video cannot be read or written.
    """
    output_path = None
    try:
        cap = None
        out = None
        try:
            cap = cv2.VideoCapture(video_path)
            if not cap.isOpened():
                raise IOError(f"Cannot open video: {video_path}")

            width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            # Keep the frame rate as a float: truncating e.g. 29.97 to 29
            # changes the playback duration of the output.
            fps = cap.get(cv2.CAP_PROP_FPS)

            text = "GiniGEN.AI"
            font = cv2.FONT_HERSHEY_SIMPLEX
            font_scale = height * 0.05 / 30
            thickness = 2
            color = (255, 255, 255)
            (text_width, text_height), _ = cv2.getTextSize(text, font, font_scale, thickness)

            # Anchor the text in the bottom-right corner with a small margin.
            margin = int(height * 0.02)
            x_pos = width - text_width - margin
            y_pos = height - margin

            source = Path(video_path)
            output_path = str(source.with_name(f"{source.stem}_watermarked.mp4"))
            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
            out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
            if not out.isOpened():
                raise IOError(f"Cannot open video writer: {output_path}")

            frames_written = 0
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                cv2.putText(frame, text, (x_pos, y_pos), font, font_scale, color, thickness)
                out.write(frame)
                frames_written += 1

            if frames_written == 0:
                raise IOError(f"No frames could be read from: {video_path}")
        finally:
            # Release both handles on every path so the files are closed
            # before the caller uses (or deletes) them.
            if cap is not None:
                cap.release()
            if out is not None:
                out.release()
    except Exception as e:
        logger.error(f"Error adding watermark: {str(e)}")
        # Do not leave a partial or empty output file behind.
        if output_path and os.path.exists(output_path):
            try:
                os.remove(output_path)
            except OSError:
                logger.warning(f"Could not remove incomplete file {output_path}")
        return video_path
    return output_path
# Timeout (seconds) applied to every HTTP request made while generating a video.
_REQUEST_TIMEOUT = 60
# Delay between two status queries and the maximum number of queries.
_POLL_INTERVAL_SECONDS = 10
_MAX_POLL_ATTEMPTS = 30


def _remove_quietly(path):
    """Delete `path`, logging instead of raising when removal fails."""
    try:
        os.remove(path)
    except OSError as e:
        logger.warning(f"Could not remove temporary file {path}: {e}")


def _download_and_finalize(file_id, api_key, temp_dir, image_url, prompt):
    """Download a finished video, watermark it and add generated audio.

    Returns the path of the final video file, or an error message string when
    the file cannot be retrieved, downloaded or processed.
    """
    retrieve_url = "https://api.minimaxi.chat/v1/files/retrieve"
    file_response = requests.get(
        retrieve_url,
        headers={'authorization': f'Bearer {api_key}'},
        params={'file_id': file_id},
        timeout=_REQUEST_TIMEOUT,
    )
    if not file_response.ok:
        return "Failed to retrieve video file"

    try:
        file_data = file_response.json()
        download_url = file_data.get('file', {}).get('download_url')
        if not download_url:
            return "Failed to get download URL"

        result_info = {
            "timestamp": datetime.now().strftime("%Y%m%d_%H%M%S"),
            "input_image": image_url,
            "output_video_url": download_url,
            "prompt": prompt
        }
        logger.info(f"Video generation result: {json.dumps(result_info, indent=2)}")

        video_response = requests.get(download_url, timeout=_REQUEST_TIMEOUT)
        if not video_response.ok:
            return "Failed to download video"

        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        output_path = os.path.join(temp_dir, f"output_{timestamp}.mp4")
        with open(output_path, 'wb') as f:
            f.write(video_response.content)

        final_path = add_watermark(output_path)
    except Exception as e:
        logger.error(f"Error processing video file: {str(e)}")
        return "Error processing video file"

    # Add audio to the watermarked video.
    # NOTE(review): video_to_audio is not defined in the visible part of this
    # file; if it is missing, the NameError is caught below and the
    # watermarked video is returned without audio.
    try:
        final_path_with_audio = video_to_audio(
            final_path,
            prompt=prompt,
            negative_prompt="music",
            seed=-1,
            num_steps=25,
            cfg_strength=4.5,
            duration=8
        )
    except Exception as e:
        logger.error(f"Error in audio processing: {str(e)}")
        # Audio failed: return the video that only has the watermark.
        return final_path

    # Clean up intermediate files. This is kept outside the audio `try` so a
    # failed delete cannot discard a successfully produced result.
    if output_path != final_path:
        _remove_quietly(output_path)
    if final_path != final_path_with_audio:
        _remove_quietly(final_path)
    return final_path_with_audio


def generate_video(image, prompt):
    """Generate a video from a first-frame image and a text prompt.

    The image (if given) is uploaded with upload_to_catbox, a "video-01"
    generation task is created on the MiniMax API, and the task is polled
    until it finishes. The resulting video is downloaded into "temp_videos",
    watermarked and passed to video_to_audio.

    Args:
        image: Path of the first-frame image, or a falsy value for none.
        prompt: Text description of the video; may be empty or None.

    Returns:
        Path of the generated video file on success, otherwise a
        human-readable error message string.
    """
    logger.info("Starting video generation with API")
    try:
        API_KEY = os.getenv("API_KEY", "").strip()
        if not API_KEY:
            return "API key not properly configured"

        temp_dir = "temp_videos"
        os.makedirs(temp_dir, exist_ok=True)

        image_url = None
        if image:
            image_url = upload_to_catbox(image)
            if not image_url:
                return "Failed to upload image"
            logger.info(f"Input image URL: {image_url}")

        generation_url = "https://api.minimaxi.chat/v1/video_generation"
        headers = {
            'authorization': f'Bearer {API_KEY}',
            'Content-Type': 'application/json'
        }
        payload = {
            "model": "video-01",
            "prompt": prompt if prompt else "",
            "prompt_optimizer": True
        }
        if image_url:
            payload["first_frame_image"] = image_url

        logger.info(f"Sending request with payload: {payload}")
        response = requests.post(
            generation_url, headers=headers, json=payload, timeout=_REQUEST_TIMEOUT
        )
        if not response.ok:
            error_msg = f"Failed to create video generation task: {response.text}"
            logger.error(error_msg)
            return error_msg

        response_data = response.json()
        task_id = response_data.get('task_id')
        if not task_id:
            return "Failed to get task ID from response"

        query_url = "https://api.minimaxi.chat/v1/query/video_generation"
        # Every iteration counts as an attempt, whatever the outcome, so the
        # loop is bounded even while the task stays in a pending state.
        for _ in range(_MAX_POLL_ATTEMPTS):
            time.sleep(_POLL_INTERVAL_SECONDS)
            try:
                query_response = requests.get(
                    query_url,
                    params={'task_id': task_id},
                    headers={'authorization': f'Bearer {API_KEY}'},
                    timeout=_REQUEST_TIMEOUT,
                )
            except requests.RequestException as e:
                # Treat a network hiccup like a failed query and retry.
                logger.warning(f"Status query failed: {e}")
                continue
            if not query_response.ok:
                continue

            status_data = query_response.json()
            status = status_data.get('status')
            if status == 'Success':
                file_id = status_data.get('file_id')
                if not file_id:
                    return "Failed to get file ID"
                return _download_and_finalize(file_id, API_KEY, temp_dir, image_url, prompt)
            # NOTE(review): 'Fail' is taken to be the terminal failure status
            # of the query endpoint - confirm against the API reference.
            if status == 'Fail':
                logger.error(f"Video generation task failed: {status_data}")
                return "Video generation failed"

        return "Timed out waiting for video generation"
    except Exception as e:
        logger.error(f"Error in video generation: {str(e)}")
        return f"Error in video generation: {str(e)}"
# Custom stylesheet for the Gradio page: hides the footer and caps the
# container width at 1200px.
css = "\n".join(
    [
        "",
        "footer {display: none}",
        ".gradio-container {max-width: 1200px !important}",
        "",
    ]
)
# Gradio UI: prompt and first-frame image on the left, generated video on the right.
with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange", css=css) as demo:
    gr.HTML('<div style="text-align: center; font-size: 1.5em; margin: 10px 0;">🎥 Image to Video Generator</div>')
    with gr.Row():
        with gr.Column(scale=3):
            video_prompt = gr.Textbox(
                label="Video Description",
                placeholder="Enter video description...",
                lines=3
            )
            upload_image = gr.Image(type="filepath", label="Upload First Frame Image")
            video_generate_btn = gr.Button("🎬 Generate Video")
        with gr.Column(scale=4):
            video_output = gr.Video(label="Generated Video")

    def process_and_generate_video(image, prompt):
        """Normalise the uploaded image to an RGB PNG and generate a video.

        Args:
            image: File path of the uploaded image, or None when nothing
                was uploaded.
            prompt: Text description entered by the user.

        Returns:
            Path of the generated video file.

        Raises:
            gr.Error: When no image was uploaded, the image cannot be
                processed, or generation did not produce a video file. The
                output component is a video, so a failure is shown as an
                error message rather than returned as a value.
        """
        if image is None:
            raise gr.Error("Please upload an image")

        temp_path = None
        try:
            # mkstemp gives a unique name, so simultaneous requests cannot
            # overwrite each other's temporary image.
            fd, temp_path = tempfile.mkstemp(prefix="temp_", suffix=".png")
            os.close(fd)
            with Image.open(image) as img:
                rgb_img = img if img.mode == 'RGB' else img.convert('RGB')
                rgb_img.save(temp_path, 'PNG')
            result = generate_video(temp_path, prompt)
        except Exception as e:
            logger.error(f"Error processing image: {str(e)}")
            raise gr.Error("Error processing image") from e
        finally:
            if temp_path is not None:
                try:
                    os.remove(temp_path)
                except OSError:
                    # Best-effort cleanup of the temporary image.
                    pass

        # generate_video reports failures as message strings; anything that
        # is not an existing file is such a message.
        if not (isinstance(result, str) and os.path.isfile(result)):
            raise gr.Error(str(result) if result else "Video generation failed")
        return result

    video_generate_btn.click(
        process_and_generate_video,
        inputs=[upload_image, video_prompt],
        outputs=video_output
    )
# Start the web server only when this file is run as a script.
if __name__ == "__main__":
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
    }
    demo.launch(**launch_options)