import spaces import argparse import os import time from os import path import shutil from datetime import datetime from safetensors.torch import load_file from huggingface_hub import hf_hub_download import gradio as gr import torch from diffusers import FluxPipeline from diffusers.pipelines.stable_diffusion import safety_checker from PIL import Image from transformers import pipeline import replicate import logging import requests from pathlib import Path import cv2 import numpy as np import sys import io # 로깅 설정 logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) # Setup and initialization code cache_path = path.join(path.dirname(path.abspath(__file__)), "models") PERSISTENT_DIR = os.environ.get("PERSISTENT_DIR", ".") # API 설정 CATBOX_USER_HASH = "e7a96fc68dd4c7d2954040cd5" REPLICATE_API_TOKEN = os.getenv("API_KEY") # 환경 변수 설정 os.environ["TRANSFORMERS_CACHE"] = cache_path os.environ["HF_HUB_CACHE"] = cache_path os.environ["HF_HOME"] = cache_path # CUDA 설정 torch.backends.cuda.matmul.allow_tf32 = True # 번역기 초기화 부분 수정 translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en", device="cuda" if torch.cuda.is_available() else "cpu") if not path.exists(cache_path): os.makedirs(cache_path, exist_ok=True) def check_api_key(): """API 키 확인 및 설정""" if not REPLICATE_API_TOKEN: logger.error("Replicate API key not found") return False os.environ["REPLICATE_API_TOKEN"] = REPLICATE_API_TOKEN logger.info("Replicate API token set successfully") return True def translate_if_korean(text): """한글이 포함된 경우 영어로 번역""" if any(ord(char) >= 0xAC00 and ord(char) <= 0xD7A3 for char in text): translation = translator(text)[0]['translation_text'] return translation return text def filter_prompt(prompt): inappropriate_keywords = [ "nude", "naked", "nsfw", "porn", "sex", "explicit", "adult", "xxx", "erotic", "sensual", "seductive", "provocative", "intimate", "violence", "gore", "blood", "death", "kill", "murder", "torture", "drug", "suicide", "abuse", "hate", "discrimination" ] prompt_lower = prompt.lower() for keyword in inappropriate_keywords: if keyword in prompt_lower: return False, "부적절한 내용이 포함된 프롬프트입니다." return True, prompt def process_prompt(prompt): """프롬프트 전처리 (번역 및 필터링)""" translated_prompt = translate_if_korean(prompt) is_safe, filtered_prompt = filter_prompt(translated_prompt) return is_safe, filtered_prompt class timer: def __init__(self, method_name="timed process"): self.method = method_name def __enter__(self): self.start = time.time() print(f"{self.method} starts") def __exit__(self, exc_type, exc_val, exc_tb): end = time.time() print(f"{self.method} took {str(round(end - self.start, 2))}s") # Model initialization if not path.exists(cache_path): os.makedirs(cache_path, exist_ok=True) pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16) pipe.load_lora_weights(hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors")) pipe.fuse_lora(lora_scale=0.125) pipe.to(device="cuda", dtype=torch.bfloat16) pipe.safety_checker = safety_checker.StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker") def upload_to_catbox(image_path): """catbox.moe API를 사용하여 이미지 업로드""" try: logger.info(f"Preparing to upload image: {image_path}") url = "https://catbox.moe/user/api.php" file_extension = Path(image_path).suffix.lower() if file_extension not in ['.jpg', '.jpeg', '.png', '.gif']: logger.error(f"Unsupported file type: {file_extension}") return None files = { 'fileToUpload': ( os.path.basename(image_path), open(image_path, 'rb'), 'image/jpeg' if file_extension in ['.jpg', '.jpeg'] else 'image/png' ) } data = { 'reqtype': 'fileupload', 'userhash': CATBOX_USER_HASH } response = requests.post(url, files=files, data=data) if response.status_code == 200 and response.text.startswith('http'): image_url = response.text logger.info(f"Image uploaded successfully: {image_url}") return image_url else: raise Exception(f"Upload failed: {response.text}") except Exception as e: logger.error(f"Image upload error: {str(e)}") return None def add_watermark(video_path): """OpenCV를 사용하여 비디오에 워터마크 추가""" try: cap = cv2.VideoCapture(video_path) width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) fps = int(cap.get(cv2.CAP_PROP_FPS)) text = "GiniGEN.AI" font = cv2.FONT_HERSHEY_SIMPLEX font_scale = height * 0.05 / 30 thickness = 2 color = (255, 255, 255) (text_width, text_height), _ = cv2.getTextSize(text, font, font_scale, thickness) margin = int(height * 0.02) x_pos = width - text_width - margin y_pos = height - margin output_path = "watermarked_output.mp4" fourcc = cv2.VideoWriter_fourcc(*'mp4v') out = cv2.VideoWriter(output_path, fourcc, fps, (width, height)) while cap.isOpened(): ret, frame = cap.read() if not ret: break cv2.putText(frame, text, (x_pos, y_pos), font, font_scale, color, thickness) out.write(frame) cap.release() out.release() return output_path except Exception as e: logger.error(f"Error adding watermark: {str(e)}") return video_path def generate_video(image, prompt): logger.info("Starting video generation") try: if not check_api_key(): return "Replicate API key not properly configured" if not image: logger.error("No image provided") return "Please upload an image" image_url = upload_to_catbox(image) if not image_url: return "Failed to upload image" input_data = { "prompt": prompt, "first_frame_image": image_url } try: replicate.Client(api_token=REPLICATE_API_TOKEN) output = replicate.run( "minimax/video-01-live", input=input_data ) temp_file = "temp_output.mp4" if hasattr(output, 'read'): with open(temp_file, "wb") as file: file.write(output.read()) elif isinstance(output, str): response = requests.get(output) with open(temp_file, "wb") as file: file.write(response.content) final_video = add_watermark(temp_file) return final_video except Exception as api_error: logger.error(f"API call failed: {str(api_error)}") return f"API call failed: {str(api_error)}" except Exception as e: logger.error(f"Unexpected error: {str(e)}") return f"Unexpected error: {str(e)}" def save_image(image): """Save the generated image temporarily""" try: # 임시 디렉토리에 저장 temp_dir = "temp" if not os.path.exists(temp_dir): os.makedirs(temp_dir, exist_ok=True) timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") filepath = os.path.join(temp_dir, f"temp_{timestamp}.png") if not isinstance(image, Image.Image): image = Image.fromarray(image) if image.mode != 'RGB': image = image.convert('RGB') image.save(filepath, format='PNG', optimize=True, quality=100) return filepath except Exception as e: logger.error(f"Error in save_image: {str(e)}") return None css = """ footer {display: none} .gradio-container {max-width: 1200px !important} #gallery { margin: 20px auto; padding: 20px; } #gallery img { width: 300px !important; height: 300px !important; object-fit: cover; border-radius: 8px; } .gallery-item { margin: 0 !important; padding: 5px !important; } #video_player { margin: 20px auto; max-width: 800px; } .title { text-align: center; font-size: 1.5em; margin: 10px 0; } """ def get_random_seed(): return torch.randint(0, 1000000, (1,)).item() def create_thumbnail_gallery(): # 0부터 9까지의 이미지 리스트 생성 return [ "image/0.jpg", "image/1.jpg", "image/2.jpg", "image/3.jpg", "image/4.jpg", "image/5.jpg", "image/6.jpg", "image/7.jpg", "image/8.jpg", "image/9.jpg" ] def check_image_files(): current_dir = os.path.dirname(os.path.abspath(__file__)) missing_files = [] for i in range(10): # 0부터 9까지 확인 image_path = os.path.join(current_dir, f"image/{i}.jpg") video_path = os.path.join(current_dir, f"image/{i}.mp4") if not os.path.exists(image_path): missing_files.append(f"{i}.jpg") if not os.path.exists(video_path): missing_files.append(f"{i}.mp4") if missing_files: logger.error(f"Missing files: {', '.join(missing_files)}") return False return True def load_gallery_images(): gallery_images = [] current_dir = os.path.dirname(os.path.abspath(__file__)) try: for i in range(10): # 0부터 9까지 로드 image_path = os.path.join(current_dir, f"image/{i}.jpg") if os.path.exists(image_path): img = Image.open(image_path) gallery_images.append(img) else: logger.warning(f"Image not found: {image_path}") except Exception as e: logger.error(f"Error loading gallery images: {str(e)}") return gallery_images # UI 부분 수정 with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange", css=css) as demo: gr.HTML('