# Hugging Face Spaces app (the Space runs on ZeroGPU hardware)
import spaces
import argparse
import os
import time
from os import path
import shutil
from datetime import datetime
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download
import gradio as gr
import torch
from diffusers import FluxPipeline
from diffusers.pipelines.stable_diffusion import safety_checker
from PIL import Image
from transformers import pipeline
import replicate
import logging
import requests
from pathlib import Path
import cv2
import numpy as np
import sys
import io
# Logging setup
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Setup and initialization code
cache_path = path.join(path.dirname(path.abspath(__file__)), "models")
PERSISTENT_DIR = os.environ.get("PERSISTENT_DIR", ".")

# API configuration
CATBOX_USER_HASH = "e7a96fc68dd4c7d2954040cd5"
REPLICATE_API_TOKEN = os.getenv("API_KEY")

# Hugging Face cache environment variables
os.environ["TRANSFORMERS_CACHE"] = cache_path
os.environ["HF_HUB_CACHE"] = cache_path
os.environ["HF_HOME"] = cache_path

# CUDA settings
torch.backends.cuda.matmul.allow_tf32 = True

# Korean-to-English translator
translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en")
if not path.exists(cache_path):
    os.makedirs(cache_path, exist_ok=True)
def check_api_key():
    """Check the Replicate API key and export it for the client."""
    if not REPLICATE_API_TOKEN:
        logger.error("Replicate API key not found")
        return False
    os.environ["REPLICATE_API_TOKEN"] = REPLICATE_API_TOKEN
    logger.info("Replicate API token set successfully")
    return True
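# check_api_key re-exports the API_KEY value as REPLICATE_API_TOKEN, the
# environment variable the replicate client library reads.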
def translate_if_korean(text):
    """Translate the text to English if it contains Korean characters."""
    if any(0xAC00 <= ord(char) <= 0xD7A3 for char in text):
        translation = translator(text)[0]['translation_text']
        return translation
    return text
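# U+AC00..U+D7A3 is the Unicode "Hangul Syllables" block, so any precomposed
# Korean syllable in the text triggers translation.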
def filter_prompt(prompt):
    inappropriate_keywords = [
        "nude", "naked", "nsfw", "porn", "sex", "explicit", "adult", "xxx",
        "erotic", "sensual", "seductive", "provocative", "intimate",
        "violence", "gore", "blood", "death", "kill", "murder", "torture",
        "drug", "suicide", "abuse", "hate", "discrimination"
    ]
    prompt_lower = prompt.lower()
    for keyword in inappropriate_keywords:
        if keyword in prompt_lower:
            return False, "부적절한 내용이 포함된 프롬프트입니다."
    return True, prompt
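# The filter is a plain substring match, so it can over-block benign prompts
# (e.g. "kill" matches inside "skill"); on a hit it returns a Korean rejection message.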
def process_prompt(prompt):
    """Preprocess the prompt (translation and filtering)."""
    translated_prompt = translate_if_korean(prompt)
    is_safe, filtered_prompt = filter_prompt(translated_prompt)
    return is_safe, filtered_prompt
class timer:
    def __init__(self, method_name="timed process"):
        self.method = method_name

    def __enter__(self):
        self.start = time.time()
        print(f"{self.method} starts")

    def __exit__(self, exc_type, exc_val, exc_tb):
        end = time.time()
        print(f"{self.method} took {round(end - self.start, 2)}s")
# Model initialization
if not path.exists(cache_path):
    os.makedirs(cache_path, exist_ok=True)

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe.load_lora_weights(hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"))
pipe.fuse_lora(lora_scale=0.125)
pipe.to(device="cuda", dtype=torch.bfloat16)
pipe.safety_checker = safety_checker.StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
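# The Hyper-SD LoRA fused above (scale 0.125) is what makes the 8-step default in the UI practical.
# Note: FluxPipeline does not appear to invoke the attached StableDiffusionSafetyChecker on its own;
# the assignment only stores it as an attribute.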
def upload_to_catbox(image_path):
    """Upload an image using the catbox.moe API and return its public URL."""
    try:
        logger.info(f"Preparing to upload image: {image_path}")
        url = "https://catbox.moe/user/api.php"
        file_extension = Path(image_path).suffix.lower()
        if file_extension not in ['.jpg', '.jpeg', '.png', '.gif']:
            logger.error(f"Unsupported file type: {file_extension}")
            return None
        # Open the file in a context manager so the handle is always closed.
        with open(image_path, 'rb') as image_file:
            files = {
                'fileToUpload': (
                    os.path.basename(image_path),
                    image_file,
                    'image/jpeg' if file_extension in ['.jpg', '.jpeg'] else 'image/png'
                )
            }
            data = {
                'reqtype': 'fileupload',
                'userhash': CATBOX_USER_HASH
            }
            response = requests.post(url, files=files, data=data)
        if response.status_code == 200 and response.text.startswith('http'):
            image_url = response.text
            logger.info(f"Image uploaded successfully: {image_url}")
            return image_url
        else:
            raise Exception(f"Upload failed: {response.text}")
    except Exception as e:
        logger.error(f"Image upload error: {str(e)}")
        return None
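# The Replicate model used below takes the first frame as a URL (first_frame_image),
# so the local image is uploaded to catbox.moe first to obtain a public link.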
def add_watermark(video_path):
    """Add a watermark to the video using OpenCV."""
    try:
        cap = cv2.VideoCapture(video_path)
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = int(cap.get(cv2.CAP_PROP_FPS))

        text = "GiniGEN.AI"
        font = cv2.FONT_HERSHEY_SIMPLEX
        font_scale = height * 0.05 / 30
        thickness = 2
        color = (255, 255, 255)

        (text_width, text_height), _ = cv2.getTextSize(text, font, font_scale, thickness)
        margin = int(height * 0.02)
        x_pos = width - text_width - margin
        y_pos = height - margin

        output_path = "watermarked_output.mp4"
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))

        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            cv2.putText(frame, text, (x_pos, y_pos), font, font_scale, color, thickness)
            out.write(frame)

        cap.release()
        out.release()
        return output_path
    except Exception as e:
        logger.error(f"Error adding watermark: {str(e)}")
        return video_path
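# cv2.VideoWriter only writes video frames, so the watermarked copy is re-encoded
# with the mp4v codec and any audio track from the source is dropped.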
def generate_video(image, prompt):
    logger.info("Starting video generation")
    try:
        if not check_api_key():
            return "Replicate API key not properly configured"
        if not image:
            logger.error("No image provided")
            return "Please upload an image"
        image_url = upload_to_catbox(image)
        if not image_url:
            return "Failed to upload image"
        input_data = {
            "prompt": prompt,
            "first_frame_image": image_url
        }
        try:
            # This Client instance is not stored; replicate.run() uses the default client
            # configured through the REPLICATE_API_TOKEN environment variable set above.
            replicate.Client(api_token=REPLICATE_API_TOKEN)
            output = replicate.run(
                "minimax/video-01-live",
                input=input_data
            )
            temp_file = "temp_output.mp4"
            if hasattr(output, 'read'):
                with open(temp_file, "wb") as file:
                    file.write(output.read())
            elif isinstance(output, str):
                response = requests.get(output)
                with open(temp_file, "wb") as file:
                    file.write(response.content)
            final_video = add_watermark(temp_file)
            return final_video
        except Exception as api_error:
            logger.error(f"API call failed: {str(api_error)}")
            return f"API call failed: {str(api_error)}"
    except Exception as e:
        logger.error(f"Unexpected error: {str(e)}")
        return f"Unexpected error: {str(e)}"
def save_image(image):
    """Save the generated image temporarily."""
    try:
        # Save into a temporary directory
        temp_dir = "temp"
        if not os.path.exists(temp_dir):
            os.makedirs(temp_dir, exist_ok=True)
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filepath = os.path.join(temp_dir, f"temp_{timestamp}.png")
        if not isinstance(image, Image.Image):
            image = Image.fromarray(image)
        if image.mode != 'RGB':
            image = image.convert('RGB')
        image.save(filepath, format='PNG', optimize=True)
        return filepath
    except Exception as e:
        logger.error(f"Error in save_image: {str(e)}")
        return None
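# Note: save_image is defined for exporting temporary PNGs but is not currently
# wired to any of the Gradio event handlers below.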
css = """ | |
footer { | |
visibility: hidden; | |
} | |
""" | |
# Gradio 인터페이스 생성 | |
with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange", css=css) as demo: | |
gr.HTML('<div class="title">🎥 Dokdo✨ Digital Odyssey from Korea, Designing Original</div>') | |
gr.HTML('<div class="title">😄 Enjoy the amazing free video creation and enhancement services!</div>') | |
    with gr.Tabs():
        with gr.Tab("Image Generation"):
            with gr.Row():
                with gr.Column(scale=3):
                    img_prompt = gr.Textbox(
                        label="Image Description",
                        placeholder="이미지 설명을 입력하세요... (한글 입력 가능)",
                        lines=3
                    )
                    with gr.Accordion("Advanced Settings", open=False):
                        with gr.Row():
                            height = gr.Slider(
                                label="Height",
                                minimum=256,
                                maximum=1152,
                                step=64,
                                value=1024
                            )
                            width = gr.Slider(
                                label="Width",
                                minimum=256,
                                maximum=1152,
                                step=64,
                                value=1024
                            )
                        with gr.Row():
                            steps = gr.Slider(
                                label="Inference Steps",
                                minimum=6,
                                maximum=25,
                                step=1,
                                value=8
                            )
                            scales = gr.Slider(
                                label="Guidance Scale",
                                minimum=0.0,
                                maximum=5.0,
                                step=0.1,
                                value=3.5
                            )

                        def get_random_seed():
                            return torch.randint(0, 1000000, (1,)).item()

                        seed = gr.Number(
                            label="Seed",
                            value=get_random_seed(),
                            precision=0
                        )
                        randomize_seed = gr.Button("🎲 Randomize Seed", elem_classes=["generate-btn"])

                    generate_btn = gr.Button(
                        "✨ Generate Image",
                        elem_classes=["generate-btn"]
                    )
                with gr.Column(scale=4):
                    img_output = gr.Image(
                        label="Generated Image",
                        type="pil",
                        format="png"
                    )

        with gr.Tab("Amazing Video Generation"):
            with gr.Row():
                with gr.Column(scale=3):
                    video_prompt = gr.Textbox(
                        label="Video Description",
                        placeholder="비디오 설명을 입력하세요... (한글 입력 가능)",
                        lines=3
                    )
                    upload_image = gr.Image(
                        type="filepath",
                        label="Upload First Frame Image"
                    )
                    video_generate_btn = gr.Button(
                        "🎬 Generate Video",
                        elem_classes=["generate-btn"]
                    )
                with gr.Column(scale=4):
                    video_output = gr.Video(label="Generated Video")
    # The @spaces.GPU decorator is assumed to be needed here: the `import spaces`
    # above and the ZeroGPU deployment imply GPU work must run inside a decorated function.
    @spaces.GPU
    def process_and_save_image(height, width, steps, scales, prompt, seed):
        is_safe, translated_prompt = process_prompt(prompt)
        if not is_safe:
            gr.Warning("부적절한 내용이 포함된 프롬프트입니다.")
            return None
        with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16), timer("inference"):
            try:
                generated_image = pipe(
                    prompt=[translated_prompt],
                    generator=torch.Generator().manual_seed(int(seed)),
                    num_inference_steps=int(steps),
                    guidance_scale=float(scales),
                    height=int(height),
                    width=int(width),
                    max_sequence_length=256
                ).images[0]
                if not isinstance(generated_image, Image.Image):
                    generated_image = Image.fromarray(generated_image)
                if generated_image.mode != 'RGB':
                    generated_image = generated_image.convert('RGB')
                # Round-trip through an in-memory PNG so a clean PIL image is handed to gr.Image.
                img_byte_arr = io.BytesIO()
                generated_image.save(img_byte_arr, format='PNG')
                return Image.open(io.BytesIO(img_byte_arr.getvalue()))
            except Exception as e:
                logger.error(f"Error in image generation: {str(e)}")
                return None
    def process_and_generate_video(image, prompt):
        is_safe, translated_prompt = process_prompt(prompt)
        if not is_safe:
            gr.Warning("부적절한 내용이 포함된 프롬프트입니다.")
            return None
        return generate_video(image, translated_prompt)

    def update_seed():
        return get_random_seed()
    generate_btn.click(
        process_and_save_image,
        inputs=[height, width, steps, scales, img_prompt, seed],
        outputs=img_output
    )
    video_generate_btn.click(
        process_and_generate_video,
        inputs=[upload_image, video_prompt],
        outputs=video_output
    )
    randomize_seed.click(
        update_seed,
        outputs=[seed]
    )
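    # generate_btn gets a second handler below so each image generation also
    # refreshes the seed shown in the UI.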
    generate_btn.click(
        update_seed,
        outputs=[seed]
    )
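# When hosted on Hugging Face Spaces the server name, port, and share settings are
# typically managed by the platform; the values below mainly matter for local runs.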
if __name__ == "__main__": | |
demo.launch( | |
server_name="0.0.0.0", | |
server_port=7860, | |
share=True | |
) |