"""Gradio app: FLUX image generation plus Replicate image-to-video generation."""

import spaces
import argparse
import os
import time
from os import path
import shutil
from datetime import datetime
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download
import gradio as gr
import torch
from diffusers import FluxPipeline
from diffusers.pipelines.stable_diffusion import safety_checker
from PIL import Image
from transformers import pipeline
import replicate
import logging
import requests
from pathlib import Path
import cv2
import numpy as np
import sys
import io

# Logging setup
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Setup and initialization code
cache_path = path.join(path.dirname(path.abspath(__file__)), "models")
PERSISTENT_DIR = os.environ.get("PERSISTENT_DIR", ".")

# API settings
# SECURITY: the catbox user hash is a credential. It is read from the
# environment when available; the previous hard-coded value is kept only as a
# backward-compatible fallback and should be rotated and removed.
CATBOX_USER_HASH = os.environ.get("CATBOX_USER_HASH", "e7a96fc68dd4c7d2954040cd5")
REPLICATE_API_TOKEN = os.getenv("API_KEY")

# Timeout (seconds) applied to outbound HTTP calls so a stalled remote host
# cannot block a request handler forever.
HTTP_TIMEOUT_SECONDS = 60

# Cache-related environment variables
# NOTE(review): these are set after huggingface_hub/transformers have been
# imported; those libraries may already have resolved their cache locations
# at import time - confirm the variables take effect.
os.environ["TRANSFORMERS_CACHE"] = cache_path
os.environ["HF_HUB_CACHE"] = cache_path
os.environ["HF_HOME"] = cache_path

# CUDA settings
torch.backends.cuda.matmul.allow_tf32 = True

# Korean -> English translator
translator = pipeline(
    "translation",
    model="Helsinki-NLP/opus-mt-ko-en",
    device="cuda" if torch.cuda.is_available() else "cpu",
)

# exist_ok=True makes a separate existence check unnecessary.
os.makedirs(cache_path, exist_ok=True)

# Prompts containing any of these substrings (case-insensitive) are rejected.
INAPPROPRIATE_KEYWORDS = (
    "nude", "naked", "nsfw", "porn", "sex", "explicit", "adult", "xxx",
    "erotic", "sensual", "seductive", "provocative", "intimate",
    "violence", "gore", "blood", "death", "kill", "murder", "torture",
    "drug", "suicide", "abuse", "hate", "discrimination",
)

# MIME types sent to catbox, keyed by lower-cased file extension.
_UPLOAD_MIME_TYPES = {
    ".jpg": "image/jpeg",
    ".jpeg": "image/jpeg",
    ".png": "image/png",
    ".gif": "image/gif",
}


def check_api_key():
    """Check the Replicate API token and export it to the environment.

    Returns:
        True when ``REPLICATE_API_TOKEN`` is set (it is then copied into
        ``os.environ["REPLICATE_API_TOKEN"]``), False otherwise.
    """
    if not REPLICATE_API_TOKEN:
        logger.error("Replicate API key not found")
        return False
    os.environ["REPLICATE_API_TOKEN"] = REPLICATE_API_TOKEN
    logger.info("Replicate API token set successfully")
    return True


def translate_if_korean(text):
    """Translate ``text`` to English when it contains Hangul syllables.

    Only code points U+AC00..U+D7A3 trigger translation; any other text is
    returned unchanged.
    """
    if any(0xAC00 <= ord(char) <= 0xD7A3 for char in text):
        return translator(text)[0]['translation_text']
    return text


def filter_prompt(prompt):
    """Reject prompts that contain a blocked keyword.

    Matching is a case-insensitive substring test, so words that merely
    contain a keyword (e.g. "skill" contains "kill") are rejected as well.

    Returns:
        ``(True, prompt)`` when the prompt is accepted, otherwise
        ``(False, <user-facing rejection message>)``.
    """
    prompt_lower = prompt.lower()
    if any(keyword in prompt_lower for keyword in INAPPROPRIATE_KEYWORDS):
        return False, "부적절한 내용이 포함된 프롬프트입니다."
    return True, prompt


def process_prompt(prompt):
    """Pre-process a prompt: translate Korean text, then apply the filter.

    ``None`` is treated as an empty prompt so the helpers never receive a
    non-string value.

    Returns:
        The ``(is_safe, text)`` tuple produced by ``filter_prompt``.
    """
    translated_prompt = translate_if_korean(prompt or "")
    is_safe, filtered_prompt = filter_prompt(translated_prompt)
    return is_safe, filtered_prompt


class timer:
    """Context manager that prints how long the wrapped block took."""

    def __init__(self, method_name="timed process"):
        self.method = method_name

    def __enter__(self):
        self.start = time.time()
        print(f"{self.method} starts")

    def __exit__(self, exc_type, exc_val, exc_tb):
        end = time.time()
        print(f"{self.method} took {str(round(end - self.start, 2))}s")


# Model initialization
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe.load_lora_weights(hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"))
pipe.fuse_lora(lora_scale=0.125)
pipe.to(device="cuda", dtype=torch.bfloat16)
# NOTE(review): this only sets an attribute on the pipeline object; confirm
# that FluxPipeline actually consults ``safety_checker`` during inference.
pipe.safety_checker = safety_checker.StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")


def upload_to_catbox(image_path):
    """Upload an image to catbox.moe and return its public URL.

    Args:
        image_path: Path of a .jpg/.jpeg/.png/.gif file on disk.

    Returns:
        The URL returned by catbox, or ``None`` when the file type is not
        supported or the upload fails for any reason.
    """
    try:
        logger.info("Preparing to upload image: %s", image_path)
        url = "https://catbox.moe/user/api.php"

        file_extension = Path(image_path).suffix.lower()
        mime_type = _UPLOAD_MIME_TYPES.get(file_extension)
        if mime_type is None:
            logger.error("Unsupported file type: %s", file_extension)
            return None

        data = {
            'reqtype': 'fileupload',
            'userhash': CATBOX_USER_HASH
        }

        # The context manager closes the file even if the request raises.
        with open(image_path, 'rb') as image_file:
            files = {
                'fileToUpload': (os.path.basename(image_path), image_file, mime_type)
            }
            response = requests.post(url, files=files, data=data, timeout=HTTP_TIMEOUT_SECONDS)

        if response.status_code == 200 and response.text.startswith('http'):
            image_url = response.text
            logger.info("Image uploaded successfully: %s", image_url)
            return image_url

        logger.error("Image upload error: Upload failed: %s", response.text)
        return None
    except Exception as e:
        logger.error("Image upload error: %s", e)
        return None


def add_watermark(video_path):
    """Burn a "GiniGEN.AI" text watermark into a video using OpenCV.

    Args:
        video_path: Path of the source video.

    Returns:
        The path of the watermarked copy ("watermarked_output.mp4"), or
        ``video_path`` unchanged when the video cannot be read or written.
    """
    cap = None
    out = None
    try:
        cap = cv2.VideoCapture(video_path)
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        # Keep the frame rate as a float: truncating e.g. 29.97 to 29 would
        # change the playback speed of the output.
        fps = cap.get(cv2.CAP_PROP_FPS)

        if not cap.isOpened() or width <= 0 or height <= 0 or fps <= 0:
            logger.error("Error adding watermark: cannot read video %s", video_path)
            return video_path

        text = "GiniGEN.AI"
        font = cv2.FONT_HERSHEY_SIMPLEX
        font_scale = height * 0.05 / 30
        thickness = 2
        color = (255, 255, 255)

        (text_width, _text_height), _ = cv2.getTextSize(text, font, font_scale, thickness)
        margin = int(height * 0.02)
        # Anchor the text in the bottom-right corner.
        x_pos = width - text_width - margin
        y_pos = height - margin

        output_path = "watermarked_output.mp4"
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
        if not out.isOpened():
            logger.error("Error adding watermark: cannot open writer for %s", output_path)
            return video_path

        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            cv2.putText(frame, text, (x_pos, y_pos), font, font_scale, color, thickness)
            out.write(frame)

        return output_path
    except Exception as e:
        logger.error("Error adding watermark: %s", e)
        return video_path
    finally:
        # Release native handles on every path, including errors.
        if cap is not None:
            cap.release()
        if out is not None:
            out.release()


def _save_replicate_output(output, destination):
    """Write a Replicate model output (file-like object or URL) to ``destination``.

    Raises:
        ValueError: if ``output`` is an empty list/tuple.
        TypeError: if ``output`` is neither file-like nor a URL string.
        requests.RequestException: if downloading the URL fails.
    """
    if isinstance(output, (list, tuple)):
        if not output:
            raise ValueError("Replicate returned an empty output list")
        output = output[0]

    if hasattr(output, 'read'):
        data = output.read()
    elif isinstance(output, str):
        response = requests.get(output, timeout=HTTP_TIMEOUT_SECONDS)
        response.raise_for_status()
        data = response.content
    else:
        raise TypeError(f"Unsupported Replicate output type: {type(output).__name__}")

    with open(destination, "wb") as file:
        file.write(data)


def generate_video(image, prompt):
    """Generate a watermarked video from a first-frame image and a prompt.

    The image is uploaded to catbox, the "minimax/video-01-live" model is run
    through Replicate, and the result is watermarked.

    Args:
        image: Path of the first-frame image file.
        prompt: Text prompt passed to the model.

    Returns:
        The path of the resulting video on success, otherwise a
        human-readable error message string.
    """
    logger.info("Starting video generation")
    try:
        if not check_api_key():
            return "Replicate API key not properly configured"

        if not image:
            logger.error("No image provided")
            return "Please upload an image"

        image_url = upload_to_catbox(image)
        if not image_url:
            return "Failed to upload image"

        input_data = {
            "prompt": prompt,
            "first_frame_image": image_url
        }

        try:
            # Run through the explicitly configured client instead of
            # constructing one and discarding it.
            client = replicate.Client(api_token=REPLICATE_API_TOKEN)
            output = client.run(
                "minimax/video-01-live",
                input=input_data
            )

            temp_file = "temp_output.mp4"
            _save_replicate_output(output, temp_file)

            final_video = add_watermark(temp_file)
            return final_video
        except Exception as api_error:
            logger.error("API call failed: %s", api_error)
            return f"API call failed: {str(api_error)}"
    except Exception as e:
        logger.error("Unexpected error: %s", e)
        return f"Unexpected error: {str(e)}"


def save_image(image):
    """Save an image as a timestamped PNG under the "temp" directory.

    Args:
        image: A ``PIL.Image.Image``, or an array accepted by
            ``Image.fromarray``.

    Returns:
        The path of the written file, or ``None`` if saving failed.
    """
    try:
        # Store in a temporary directory
        temp_dir = "temp"
        os.makedirs(temp_dir, exist_ok=True)

        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filepath = os.path.join(temp_dir, f"temp_{timestamp}.png")

        if not isinstance(image, Image.Image):
            image = Image.fromarray(image)
        if image.mode != 'RGB':
            image = image.convert('RGB')

        image.save(filepath, format='PNG', optimize=True, quality=100)
        return filepath
    except Exception as e:
        logger.error("Error in save_image: %s", e)
        return None


css = """
footer {display: none}
.gradio-container {max-width: 1200px !important}
#gallery {
    margin: 20px auto;
    padding: 20px;
}
#gallery img {
    width: 300px !important;
    height: 300px !important;
    object-fit: cover;
    border-radius: 8px;
}
.gallery-item {
    margin: 0 !important;
    padding: 5px !important;
}
#video_player {
    margin: 20px auto;
    max-width: 800px;
}
.title {
    text-align: center;
    font-size: 1.5em;
    margin: 10px 0;
}
"""


def get_random_seed():
    """Return a random integer seed in the range [0, 1000000)."""
    return torch.randint(0, 1000000, (1,)).item()


def create_thumbnail_gallery():
    """Return the relative paths of the ten example thumbnails (0.jpg-9.jpg)."""
    return [f"image/{i}.jpg" for i in range(10)]


def check_image_files():
    """Verify that image/0..9.jpg and image/0..9.mp4 exist next to this file.

    Returns:
        True when every file is present; otherwise logs the missing names
        and returns False.
    """
    current_dir = os.path.dirname(os.path.abspath(__file__))
    missing_files = []

    for i in range(10):  # check indices 0 through 9
        for extension in ("jpg", "mp4"):
            if not os.path.exists(os.path.join(current_dir, f"image/{i}.{extension}")):
                missing_files.append(f"{i}.{extension}")

    if missing_files:
        logger.error("Missing files: %s", ', '.join(missing_files))
        return False
    return True


def load_gallery_images():
    """Load image/0..9.jpg (next to this file) as PIL images.

    Missing files are logged and skipped. If an error occurs, it is logged
    and the images loaded so far are returned.
    """
    gallery_images = []
    current_dir = os.path.dirname(os.path.abspath(__file__))
    try:
        for i in range(10):  # load indices 0 through 9
            image_path = os.path.join(current_dir, f"image/{i}.jpg")
            if os.path.exists(image_path):
                # copy() reads the pixel data so the file handle can be
                # closed immediately instead of staying open.
                with Image.open(image_path) as img:
                    gallery_images.append(img.copy())
            else:
                logger.warning("Image not found: %s", image_path)
    except Exception as e:
        logger.error("Error loading gallery images: %s", e)

    return gallery_images


@spaces.GPU
def process_and_save_image(height, width, steps, scales, prompt, seed):
    """Generate an image with the FLUX pipeline from the UI settings.

    Args:
        height, width: Output size in pixels.
        steps: Number of inference steps.
        scales: Guidance scale.
        prompt: User prompt (translated from Korean when needed).
        seed: Seed for the random generator.

    Returns:
        A PIL image, or ``None`` when the prompt is rejected or generation
        fails (the error is logged).
    """
    is_safe, translated_prompt = process_prompt(prompt)
    if not is_safe:
        gr.Warning("부적절한 내용이 포함된 프롬프트입니다.")
        return None

    with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16), timer("inference"):
        try:
            generated_image = pipe(
                prompt=[translated_prompt],
                generator=torch.Generator().manual_seed(int(seed)),
                num_inference_steps=int(steps),
                guidance_scale=float(scales),
                height=int(height),
                width=int(width),
                max_sequence_length=256
            ).images[0]

            if not isinstance(generated_image, Image.Image):
                generated_image = Image.fromarray(generated_image)
            if generated_image.mode != 'RGB':
                generated_image = generated_image.convert('RGB')

            # Round-trip through an in-memory PNG before returning.
            img_byte_arr = io.BytesIO()
            generated_image.save(img_byte_arr, format='PNG')
            return Image.open(io.BytesIO(img_byte_arr.getvalue()))
        except Exception as e:
            logger.error("Error in image generation: %s", e)
            return None


def process_and_generate_video(image, prompt):
    """Validate the prompt, then generate a video for the UI.

    Returns:
        The path of the generated video, or ``None`` when the prompt is
        rejected or generation fails. Failure messages are shown to the user
        as a warning rather than being passed to the video component as if
        they were file paths.
    """
    is_safe, translated_prompt = process_prompt(prompt)
    if not is_safe:
        gr.Warning("부적절한 내용이 포함된 프롬프트입니다.")
        return None

    result = generate_video(image, translated_prompt)
    if result and os.path.isfile(result):
        return result

    # generate_video reports failures as message strings.
    gr.Warning(str(result))
    return None


def update_seed():
    """Return a fresh random seed for the seed input."""
    return get_random_seed()


def show_video(evt: gr.SelectData):
    """Return the example video matching the selected gallery thumbnail.

    Returns ``None`` when the corresponding file does not exist.
    """
    video_num = evt.index  # zero-based index of the selected item
    video_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), f"image/{video_num}.mp4")
    if os.path.exists(video_path):
        return video_path
    return None


# UI
with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange", css=css) as demo:
    # NOTE(review): the ".title" CSS rule above is not applied to these
    # banners; presumably wrapper markup was lost - confirm intended HTML.
    gr.HTML('\n🎥 Dokdo✨ Digital Odyssey from Korea, Designing Original\n')
    gr.HTML('\n😄 Enjoy the amazing free video creation and enhancement services!\n')

    with gr.Tabs():
        # First tab: Example Gallery
        with gr.Tab("Example Gallery"):
            with gr.Row():
                gallery = gr.Gallery(
                    value=create_thumbnail_gallery(),
                    columns=[5],  # five thumbnails per row
                    rows=[2],  # two rows
                    height="auto",
                    show_label=False,
                    elem_id="gallery"
                )

            with gr.Row():
                video_player = gr.Video(
                    label="Selected Video",
                    elem_id="video_player",
                    interactive=False,
                    autoplay=True
                )

        # Second tab: Image Generation
        with gr.Tab("Image Generation & Enhanced"):
            with gr.Row():
                with gr.Column(scale=3):
                    img_prompt = gr.Textbox(
                        label="Image Description",
                        placeholder="이미지 설명을 입력하세요... (한글 입력 가능)",
                        lines=3
                    )

                    with gr.Accordion("Advanced Settings", open=False):
                        with gr.Row():
                            height = gr.Slider(label="Height", minimum=256, maximum=1152, step=64, value=1024)
                            width = gr.Slider(label="Width", minimum=256, maximum=1152, step=64, value=1024)

                        with gr.Row():
                            steps = gr.Slider(label="Inference Steps", minimum=6, maximum=25, step=1, value=8)
                            scales = gr.Slider(label="Guidance Scale", minimum=0.0, maximum=5.0, step=0.1, value=3.5)

                        seed = gr.Number(label="Seed", value=get_random_seed(), precision=0)
                        randomize_seed = gr.Button("🎲 Randomize Seed", elem_classes=["generate-btn"])

                    generate_btn = gr.Button("✨ Generate Image", elem_classes=["generate-btn"])

                with gr.Column(scale=4):
                    img_output = gr.Image(label="Generated Image", type="pil", format="png")

        # Third tab: Video Generation
        with gr.Tab("Amazing Video Generation"):
            with gr.Row():
                with gr.Column(scale=3):
                    video_prompt = gr.Textbox(
                        label="Video Description",
                        placeholder="비디오 설명을 입력하세요... (한글 입력 가능)",
                        lines=3
                    )
                    upload_image = gr.Image(type="filepath", label="Upload First Frame Image")
                    video_generate_btn = gr.Button("🎬 Generate Video", elem_classes=["generate-btn"])

                with gr.Column(scale=4):
                    video_output = gr.Video(label="Generated Video")

    # Event wiring
    gallery.select(fn=show_video, outputs=video_player)

    generate_btn.click(
        process_and_save_image,
        inputs=[height, width, steps, scales, img_prompt, seed],
        outputs=img_output
    )

    video_generate_btn.click(
        process_and_generate_video,
        inputs=[upload_image, video_prompt],
        outputs=video_output
    )

    randomize_seed.click(update_seed, outputs=[seed])
    generate_btn.click(update_seed, outputs=[seed])

if __name__ == "__main__":
    # Make sure the example images and videos exist before launching.
    if not check_image_files():
        print("Error: Required image and video files (0.jpg through 9.jpg and 0.mp4 through 9.mp4) are missing!")
        sys.exit(1)

    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False
    )