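# app.py - last commit 36a4fe5 "Update app.py" (verified) by ginipick; 23.4 kB.
# (Header recovered from the Hugging Face file viewer: "raw", "history" and
# "blame" were page links and "ginipick's picture" was the author avatar caption.)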
import gradio as gr
import hashlib
import re
from datetime import datetime
import os
import json
import schedule
import threading
import time
import dns.resolver
from huggingface_hub import Repository
import spaces
import argparse
from os import path
import shutil
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download
import torch
from diffusers import FluxPipeline
from diffusers.pipelines.stable_diffusion import safety_checker
from PIL import Image
from transformers import pipeline
import replicate
import logging
import requests
from pathlib import Path
import cv2
import numpy as np
import sys
import io
# Logging configuration
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Hugging Face dataset repo that stores the membership database (data.json).
# It is cloned into LOCAL_DIR at import time using HF_TOKEN.
HF_TOKEN = os.getenv("HF_TOKEN")
REPO_ID = "ginigen/MEMBERSHIP"
LOCAL_DIR = "./my_dataset"
repo = Repository(
    local_dir=LOCAL_DIR,
    clone_from=REPO_ID,
    use_auth_token=HF_TOKEN,
    repo_type="dataset"
)
DATA_FILE = os.path.join(LOCAL_DIR, "data.json")

# Logged-in user record, filled in by login().
# NOTE(review): this is one module-level dict shared by every visitor of the
# app, not per-session state; consider gr.State for real per-user sessions.
current_user = {"email": None, "points": 0, "is_admin": 0}

# Admin credentials. ADMIN_PASS may be supplied via the environment; the
# literal fallback keeps existing deployments working, but it is committed to
# source and should be rotated and removed.
ADMIN_EMAIL = "arxivgpt@gmail.com"
ADMIN_PASS = os.getenv("ADMIN_PASS", "Arxiv4837!@")

# Model cache directory next to this file; created before anything uses it.
cache_path = path.join(path.dirname(path.abspath(__file__)), "models")
os.makedirs(cache_path, exist_ok=True)
PERSISTENT_DIR = os.environ.get("PERSISTENT_DIR", ".")

# API settings (secrets may be overridden from the environment).
CATBOX_USER_HASH = os.getenv("CATBOX_USER_HASH", "e7a96fc68dd4c7d2954040cd5")
REPLICATE_API_TOKEN = os.getenv("API_KEY")

# Cache-location environment variables.
# NOTE(review): transformers / huggingface_hub / diffusers are imported above,
# before these are set; if those libraries resolve cache paths at import time,
# these assignments will not affect them. Confirm, and move them above the
# imports if the custom cache location matters.
os.environ["TRANSFORMERS_CACHE"] = cache_path
os.environ["HF_HUB_CACHE"] = cache_path
os.environ["HF_HOME"] = cache_path

# Korean -> English translator used to normalize prompts (runs on CPU).
translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en", device="cpu")
def load_data(data_file=None):
    """Load the membership database from its JSON file.

    Args:
        data_file: Optional path override; defaults to the module's DATA_FILE.

    Returns:
        The parsed dict, guaranteed to contain a "users" key (every caller
        indexes data["users"]). If the file is missing, unreadable, or does
        not hold a JSON object, {"users": {}} is returned and the problem is
        printed.
    """
    if data_file is None:
        data_file = DATA_FILE
    try:
        if os.path.exists(data_file):
            with open(data_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
            if isinstance(data, dict):
                data.setdefault("users", {})
                return data
            print(f"Load error: {data_file} does not contain a JSON object")
    except Exception as e:
        print(f"Load error: {e}")
    return {"users": {}}
def save_data(data):
    """Write *data* to DATA_FILE, then commit and push it to the dataset repo.

    The JSON is written to a temporary sibling file and swapped in with
    os.replace, so a failure mid-write cannot leave a truncated data.json
    behind.

    Returns:
        True if the file was written and committed/pushed, False otherwise
        (the error is printed).
    """
    tmp_file = DATA_FILE + ".tmp"
    try:
        with open(tmp_file, 'w', encoding='utf-8') as f:
            json.dump(data, f, ensure_ascii=False, indent=4)
        os.replace(tmp_file, DATA_FILE)
        repo.git_add()
        repo.git_commit("Update data.json")
        repo.git_push()
        return True
    except Exception as e:
        print(f"Save error: {e}")
        return False
    finally:
        # Only present if json.dump or os.replace failed; best-effort cleanup.
        if os.path.exists(tmp_file):
            try:
                os.remove(tmp_file)
            except OSError as cleanup_error:
                print(f"Save cleanup error: {cleanup_error}")
def init_db():
    """Ensure the admin account exists in the membership database.

    When ADMIN_EMAIL is missing, it is created with 999 points and
    is_admin = 1, and the database is saved.

    Returns:
        True on success; False if loading failed or the new admin account
        could not be saved (previously this still reported success).
    """
    print("Initializing database...")
    try:
        data = load_data()
        users = data.setdefault("users", {})
        if ADMIN_EMAIL not in users:
            users[ADMIN_EMAIL] = {
                "password": hash_password(ADMIN_PASS),
                "registration_date": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                "points": 999,
                "is_admin": 1
            }
            if not save_data(data):
                print("Init error: failed to save the admin account")
                return False
            print("Admin account created")
        print("Database initialized successfully")
        return True
    except Exception as e:
        print(f"Init error: {e}")
        return False
def hash_password(password: str) -> str:
    """Return the hex SHA-256 digest of *password*, or "" for an empty/None value.

    NOTE(review): unsalted SHA-256 is weak for password storage. Switching to
    a salted KDF would invalidate every stored hash, so that needs a migration.
    """
    if password:
        digest = hashlib.sha256(password.encode("utf-8"))
        return digest.hexdigest()
    return ""
def is_valid_email(email: str) -> bool:
    """Return True if *email* has the shape ``local@domain.tld``.

    Uses re.fullmatch so the entire string must match; the previous
    ``re.match(r'^...$')`` form also accepted a trailing newline because ``$``
    matches just before one. Non-string input (e.g. None) returns False.
    """
    if not isinstance(email, str):
        return False
    pattern = r'[\w\.-]+@[\w\.-]+\.\w+'
    return re.fullmatch(pattern, email) is not None
def is_email_domain_valid(email: str) -> bool:
    """Return True if the domain after the last "@" in *email* has an MX record.

    Input with no "@" or an empty domain returns False without a DNS lookup.
    Any DNS failure (NXDOMAIN, no answer, timeout, no reachable nameservers,
    ...) is reported as False instead of escaping to the caller; previously
    only three specific exceptions were caught.
    """
    _, sep, domain = email.rpartition("@")
    if not sep or not domain:
        return False
    try:
        records = dns.resolver.resolve(domain, 'MX')
    except dns.exception.DNSException:
        return False
    return len(records) > 0
def register(email, password):
    """Create a new member account that starts with 15 points.

    Both inputs are stripped *before* validation, so the checked values are
    exactly what gets stored. Previously a whitespace-only password passed
    the length check but was stored as hash_password("") == "", which let a
    blank password log in.

    Returns:
        A status message for the UI.
    """
    email = (email or "").strip()
    password = (password or "").strip()
    if not email or not password:
        return "Please fill all fields."
    if not is_valid_email(email):
        return "Invalid email format. Please try again."
    if not is_email_domain_valid(email):
        return "This email domain seems invalid. Please use a different address."
    if len(password) < 6:
        return "Password must be at least 6 characters long."
    data = load_data()
    users = data.setdefault("users", {})
    if email in users:
        return "This email is already registered."
    users[email] = {
        "password": hash_password(password),
        "registration_date": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        "points": 15,  # new members start with 15 points
        "is_admin": 0
    }
    if save_data(data):
        return "Registration successful!"
    return "Error occurred during registration."
def login(email, password):
    """Check credentials and record the user in the module-level current_user.

    Inputs are stripped first and blank values are rejected, so an empty
    password can never match a stored hash. A failed credential check resets
    current_user to the logged-out state.

    NOTE(review): current_user is one dict shared by every session of the
    app, so a login (or a failed login) here affects all visitors.

    Returns:
        {"value": message} where message is "Logged in as ADMIN.",
        "Welcome! You have N points.", "Please enter both email and
        password." or "Wrong email or password.".
    """
    global current_user
    email = (email or "").strip()
    password = (password or "").strip()
    if not email or not password:
        return {"value": "Please enter both email and password."}
    data = load_data()
    user_data = data["users"].get(email)
    if user_data and user_data.get("password") == hash_password(password):
        current_user = {
            "email": email,
            "points": user_data["points"],
            "is_admin": user_data.get("is_admin", 0)
        }
        if current_user["is_admin"] == 1:
            return {"value": "Logged in as ADMIN."}
        return {"value": f"Welcome! You have {user_data['points']} points."}
    current_user = {"email": None, "points": 0, "is_admin": 0}
    return {"value": "Wrong email or password."}
def use_point(points=5):
    """Deduct *points* from the logged-in user's balance and save it.

    Returns:
        "Points used! Remaining points: N" on success. On failure returns
        "You need to log in first." (no session, or the account was deleted
        after login; previously a KeyError), "Not enough points. Required:
        N points", or "Error occurred while using points." when saving failed.
    """
    email = current_user.get("email")
    if not email:
        return "You need to log in first."
    data = load_data()
    user_data = data["users"].get(email)
    if user_data is None:
        # The account was removed (e.g. by the admin) after this login.
        return "You need to log in first."
    if user_data["points"] < points:
        return f"Not enough points. Required: {points} points"
    user_data["points"] -= points
    current_user["points"] = user_data["points"]
    if save_data(data):
        return f"Points used! Remaining points: {user_data['points']}"
    return "Error occurred while using points."
def get_profile():
    """Return (email, registration_date, points message) for the current user.

    If nobody is logged in, or the stored record is missing or empty, the
    first two fields are "" and the third explains why.
    """
    email = current_user["email"]
    if not email:
        return "", "", "You need to log in first."
    record = load_data()["users"].get(email)
    if not record:
        return "", "", "Profile not found."
    points_message = f"Current points: {record['points']}"
    return email, record["registration_date"], points_message
def delete_user(email):
    """Delete a member account. Only the admin may do this.

    The email is stripped before lookup. The admin account itself cannot be
    deleted.

    Returns:
        A status message. A save failure is now reported as such instead of
        the misleading "User not found.".
    """
    if not is_admin():
        return "Only admin can delete users."
    email = (email or "").strip()
    if email == ADMIN_EMAIL:
        return "Cannot delete the admin account."
    data = load_data()
    if email not in data["users"]:
        return "User not found."
    del data["users"][email]
    if save_data(data):
        return f"User {email} has been deleted."
    return "Error occurred while deleting the user."
def is_admin():
    """Return True only when current_user is ADMIN_EMAIL *and* carries is_admin == 1."""
    email_matches = current_user.get("email") == ADMIN_EMAIL
    flag_set = current_user.get("is_admin") == 1
    return email_matches and flag_set
def get_all_users():
    """Render every non-admin member as an HTML table for the admin page.

    Returns:
        "Only admin can access this page." for non-admins, "No registered
        users." when there are no non-admin accounts (counted explicitly
        rather than assuming the admin is one of the entries), otherwise the
        table HTML. Cell values are HTML-escaped, and the email passed to the
        deleteUser() JS handler is JSON-encoded and then attribute-escaped, so
        a crafted address cannot inject markup.
    """
    import html

    if not is_admin():
        return "Only admin can access this page."
    users = load_data()["users"]
    members = [(email, info) for email, info in users.items() if email != ADMIN_EMAIL]
    if not members:
        return "No registered users."

    header_style = "padding: 12px; text-align: left; border: 1px solid #ddd;"
    cell_style = "padding: 12px; border: 1px solid #ddd;"
    parts = [
        "<table style='width:100%; border-collapse: collapse;'>",
        "<tr style='background-color: #f2f2f2;'>",
        f"<th style='{header_style}'>Email</th>",
        f"<th style='{header_style}'>Registration Date</th>",
        f"<th style='{header_style}'>Points</th>",
        "<th style='padding: 12px; text-align: center; border: 1px solid #ddd;'>Action</th></tr>",
    ]
    for email, info in members:
        safe_email = html.escape(email)
        js_email = html.escape(json.dumps(email))
        reg_date = html.escape(str(info.get("registration_date", "")))
        points = html.escape(str(info.get("points", 0)))
        parts.append("<tr style='border: 1px solid #ddd;'>")
        parts.append(f"<td style='{cell_style}'>{safe_email}</td>")
        parts.append(f"<td style='{cell_style}'>{reg_date}</td>")
        parts.append(f"<td style='{cell_style}'>{points}</td>")
        parts.append(
            "<td style='padding: 12px; text-align: center; border: 1px solid #ddd;'>"
            f"<button onclick='deleteUser({js_email})' style='padding: 5px 10px; "
            "background-color: #ff4444; color: white; border: none; border-radius: 3px; "
            "cursor: pointer;'>Delete</button></td>"
        )
        parts.append("</tr>")
    parts.append("</table>")
    return "".join(parts)
@spaces.GPU
def setup_torch():
    """Allow TF32 for CUDA matrix multiplies.

    Decorated with spaces.GPU (presumably so the call runs with a GPU
    allocated on HF Spaces - confirm). NOTE(review): nothing in the visible
    source calls this function.
    """
    matmul_settings = torch.backends.cuda.matmul
    matmul_settings.allow_tf32 = True
def translate_if_korean(text):
    """Translate *text* to English when it contains a precomposed Hangul syllable.

    Only code points U+AC00..U+D7A3 trigger the module-level translator;
    any other text (including bare Hangul jamo) is returned unchanged.
    """
    contains_hangul = any(0xAC00 <= ord(ch) <= 0xD7A3 for ch in text)
    if not contains_hangul:
        return text
    first_result = translator(text)[0]
    return first_result['translation_text']
def filter_prompt(prompt):
    """Screen *prompt* against a blocklist of unsafe keywords.

    A keyword matches case-insensitively at the *start of a word*. Inflected
    forms ("killing", "nudes") are still blocked, while harmless words that
    merely contain a keyword ("skill", "whatever", "chateau", "Essex") are
    no longer rejected as they were by plain substring matching.

    Returns:
        (True, prompt) when allowed; otherwise (False, message), where the
        message is Korean for "The prompt contains inappropriate content."
        (repaired from a mis-encoded literal).
    """
    inappropriate_keywords = [
        "nude", "naked", "nsfw", "porn", "sex", "explicit", "adult", "xxx",
        "erotic", "sensual", "seductive", "provocative", "intimate",
        "violence", "gore", "blood", "death", "kill", "murder", "torture",
        "drug", "suicide", "abuse", "hate", "discrimination"
    ]
    pattern = r"\b(?:" + "|".join(map(re.escape, inappropriate_keywords)) + ")"
    if re.search(pattern, prompt, flags=re.IGNORECASE):
        return False, "부적절한 내용이 포함된 프롬프트입니다."
    return True, prompt
def process_prompt(prompt):
    """Translate a Korean prompt to English, then run it through the safety filter.

    Returns filter_prompt's (is_safe, text) pair: text is the (translated)
    prompt when safe, or the rejection message otherwise.
    """
    english_prompt = translate_if_korean(prompt)
    verdict, text = filter_prompt(english_prompt)
    return verdict, text
def add_watermark(video_path, output_path="watermarked_output.mp4"):
    """Draw a white "GiniGEN.AI" label in the bottom-right corner of every frame.

    Args:
        video_path: Input video file.
        output_path: Destination MP4 (mp4v fourcc); defaults to the previous
            fixed name so existing callers are unaffected.

    Returns:
        output_path on success. video_path is returned unchanged if the input
        cannot be opened, the writer cannot be created, or any error occurs.
        Capture and writer handles are released in every case.
    """
    cap = None
    out = None
    try:
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            logger.error(f"Error adding watermark: cannot open {video_path}")
            return video_path
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        # Keep fractional rates such as 29.97 (int() shortened the video);
        # fall back to 30 when the container reports 0.
        fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
        text = "GiniGEN.AI"
        font = cv2.FONT_HERSHEY_SIMPLEX
        font_scale = height * 0.05 / 30  # text height scales with frame height
        thickness = 2
        color = (255, 255, 255)
        (text_width, _), _ = cv2.getTextSize(text, font, font_scale, thickness)
        margin = int(height * 0.02)
        # Clamp so text wider than the frame still starts on-screen.
        x_pos = max(width - text_width - margin, 0)
        y_pos = height - margin
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
        if not out.isOpened():
            logger.error(f"Error adding watermark: cannot open writer for {output_path}")
            return video_path
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            cv2.putText(frame, text, (x_pos, y_pos), font, font_scale, color, thickness)
            out.write(frame)
        return output_path
    except Exception as e:
        logger.error(f"Error adding watermark: {str(e)}")
        return video_path
    finally:
        if cap is not None:
            cap.release()
        if out is not None:
            out.release()
def check_api_key(timeout=10):
    """Validate the Replicate token held in the API_KEY environment variable.

    Side effect: copies the key into os.environ["REPLICATE_API_TOKEN"]
    (presumably for the replicate client - confirm). The validation request
    is bounded by *timeout* seconds; previously it had no timeout and could
    hang the generation handler indefinitely.

    Returns:
        True if GET https://api.replicate.com/v1/account answers 200,
        otherwise False (the reason is logged).
    """
    api_key = os.getenv("API_KEY")
    if not api_key:
        logger.error("API_KEY environment variable not found")
        return False
    os.environ["REPLICATE_API_TOKEN"] = api_key
    try:
        response = requests.get(
            "https://api.replicate.com/v1/account",
            headers={"Authorization": f"Bearer {api_key}"},
            timeout=timeout,
        )
    except requests.RequestException as e:
        logger.error(f"API key validation error: {str(e)}")
        return False
    if response.status_code == 200:
        logger.info("Replicate API token validated successfully")
        return True
    logger.error(f"API key validation failed with status code: {response.status_code}")
    return False
def _image_to_data_uri(image_path):
    """Return the image file at *image_path* as a base64 ``data:`` URI.

    The MIME type is guessed from the file extension (fallback image/png)
    instead of always claiming PNG.
    """
    import base64
    import mimetypes

    mime_type = mimetypes.guess_type(image_path)[0] or "image/png"
    with open(image_path, 'rb') as img_file:
        encoded = base64.b64encode(img_file.read()).decode('utf-8')
    return f"data:{mime_type};base64,{encoded}"


def _wait_for_prediction(prediction, timeout, poll_interval=1.0):
    """Poll a Replicate prediction until it is terminal or *timeout* seconds pass.

    On timeout the prediction is cancelled (best effort) and returned with
    its last observed, non-terminal status.
    """
    deadline = time.monotonic() + timeout
    while prediction.status not in ("succeeded", "failed", "canceled"):
        if time.monotonic() >= deadline:
            try:
                prediction.cancel()
            except Exception as cancel_error:
                logger.warning(f"Could not cancel prediction {prediction.id}: {cancel_error}")
            break
        prediction = replicate.predictions.get(prediction.id)
        time.sleep(poll_interval)
    return prediction


def _download_file(url, dest_path, timeout=60):
    """Stream *url* into *dest_path*; raises on connection errors or HTTP error status."""
    with requests.get(url, stream=True, timeout=timeout) as response:
        response.raise_for_status()
        with open(dest_path, "wb") as f:
            for chunk in response.iter_content(chunk_size=8192):
                if chunk:
                    f.write(chunk)


def generate_video(image, prompt, timeout=900):
    """Generate a video with Replicate's minimax/video-01-live model and watermark it.

    Args:
        image: Optional path to a first-frame image (gr.Image type="filepath").
        prompt: Text prompt (already translated/filtered by the caller).
        timeout: Maximum seconds to wait for the prediction before it is
            cancelled (previously polling never gave up).

    Returns:
        Path of the finished video on success; otherwise an error message
        that contains "error" or "failed" (callers detect failure by those
        words). The downloaded temp file has a unique name and is deleted
        afterwards unless it is the path being returned.
    """
    import tempfile

    logger.info("Starting video generation")
    temp_file = None
    keep_temp_file = False
    try:
        if not check_api_key():
            return "API key error: Replicate API key not properly configured"
        input_data = {"prompt": prompt}
        if image:
            try:
                input_data["first_frame_image"] = _image_to_data_uri(image)
            except Exception as img_error:
                logger.error(f"Error processing image: {str(img_error)}")
                return f"Error processing image: {str(img_error)}"
        try:
            prediction = replicate.predictions.create(
                model="minimax/video-01-live",
                input=input_data
            )
            prediction = _wait_for_prediction(prediction, timeout)
            output = prediction.output
            if isinstance(output, (list, tuple)):
                output = output[0] if output else None
            if prediction.status == "succeeded" and output:
                # Unique name so concurrent requests do not overwrite each other.
                fd, temp_file = tempfile.mkstemp(suffix=".mp4")
                os.close(fd)
                try:
                    _download_file(output, temp_file)
                except Exception as download_error:
                    logger.error(f"Error downloading video: {str(download_error)}")
                    return f"Error downloading video: {str(download_error)}"
                final_video = add_watermark(temp_file)
                # add_watermark returns its input on failure; never delete the
                # file we are about to hand back.
                keep_temp_file = final_video == temp_file
                return final_video
            if prediction.status in ("succeeded", "failed", "canceled"):
                error_msg = f"Prediction failed with status: {prediction.status}"
            else:
                error_msg = (
                    f"Prediction failed: timed out after {timeout} seconds "
                    f"(status: {prediction.status})"
                )
            if getattr(prediction, 'error', None):
                error_msg += f" Error: {prediction.error}"
            logger.error(error_msg)
            return error_msg
        except Exception as api_error:
            logger.error(f"API call failed: {str(api_error)}")
            return f"API call failed: {str(api_error)}"
    except Exception as e:
        logger.error(f"Unexpected error: {str(e)}")
        return f"Unexpected error: {str(e)}"
    finally:
        if temp_file and not keep_temp_file and os.path.exists(temp_file):
            try:
                os.remove(temp_file)
            except OSError as cleanup_error:
                logger.warning(f"Error cleaning up temporary file: {str(cleanup_error)}")
def _refund_points(points):
    """Best effort: give *points* back to the logged-in user after a failed request."""
    email = current_user.get("email")
    if not email:
        return
    data = load_data()
    user_data = data["users"].get(email)
    if user_data is None:
        logger.warning(f"Refund skipped: user {email} no longer exists")
        return
    user_data["points"] += points
    current_user["points"] = user_data["points"]
    if not save_data(data):
        logger.error(f"Failed to save refund of {points} points for {email}")


def process_and_generate_video(image, prompt):
    """Gradio handler: charge 5 points, generate a video, refund on any failure.

    Returns:
        Path to the generated video file.

    Raises:
        gr.Error: when the user is not logged in, lacks points, the deduction
            could not be saved, the prompt is rejected, or generation fails.
            The message is raised rather than returned, because the Gradio
            Video output treats a returned str as a file path.
    """
    cost = 5
    result = use_point(cost)  # deduct the points up front
    # Continue only if the deduction actually succeeded. Previously a failed
    # save ("Error occurred while using points.") still produced a video.
    if not result.startswith("Points used!"):
        raise gr.Error(result)
    try:
        is_safe, translated_prompt = process_prompt(prompt)
    except Exception as e:
        _refund_points(cost)
        raise gr.Error(f"Error: {str(e)}") from e
    if not is_safe:
        # Rejected before any work was done, so the user gets the points back.
        _refund_points(cost)
        raise gr.Error("부적절한 내용이 포함된 프롬프트입니다.")
    try:
        video_result = generate_video(image, translated_prompt)
    except Exception as e:
        _refund_points(cost)
        raise gr.Error(f"Error: {str(e)}") from e
    # Success means an existing file path; any other string is an error
    # message, even one without "error"/"failed" in it.
    if isinstance(video_result, str) and os.path.isfile(video_result):
        return video_result
    _refund_points(cost)
    raise gr.Error(f"Error: {video_result}")
# CSS style definitions for the Gradio UI.
# The CSS comments were mis-encoded; they are rewritten in ASCII English
# (comments do not affect the rendered styles).
CUSTOM_CSS = """
/* One row: Email + Password + two buttons (Login, Sign Up) + result box */
/* Whole-row container */
#row-container {
    display: flex;
    align-items: center;
    gap: 8px;
    background: linear-gradient(135deg, #b2fefa 0%, #8fd3f4 100%);
    padding: 10px;
    border-radius: 8px;
    box-shadow: 0 3px 8px rgba(0,0,0,0.15);
    margin-bottom: 15px;
}
/* Same size for text boxes and buttons */
.same-size {
    width: 130px !important;
    height: 34px !important;
    font-size: 0.9rem !important;
    text-align: center !important;
    padding: 5px 6px !important;
    margin: 0 !important;
    border-radius: 5px;
    border: 1px solid #ddd;
}
/* Buttons get a slightly different background */
button.same-size {
    background-color: #6666ff;
    color: white;
    border: none;
    cursor: pointer;
}
button.same-size:hover {
    background-color: #4a4acc;
}
/* Result box uses the same height */
.result-box {
    width: 160px !important;
    height: 34px !important;
    font-size: 0.9rem !important;
    padding: 5px 6px !important;
    border: 1px solid #ccc;
    border-radius: 5px;
    background-color: #fff;
    text-align: left;
    overflow: hidden;
    white-space: nowrap;
    text-overflow: ellipsis;
}
/* Top header */
.header-container {
    display: flex;
    justify-content: space-between;
    align-items: center;
    padding: 6px;
    background-color: #f5f5f5;
    margin-bottom: 10px;
}
.user-badge {
    padding: 3px 6px;
    border-radius: 15px;
    background-color: #4CAF50;
    color: white;
    font-size: 0.8em;
}
.admin-badge {
    padding: 3px 6px;
    border-radius: 15px;
    background-color: #ff4444;
    color: white;
    font-size: 0.8em;
    cursor: pointer;
}
.points-display {
    background: linear-gradient(135deg, #6e8efb 0%, #5d7df9 100%);
    padding: 20px;
    border-radius: 15px;
    margin: 20px 0;
    box-shadow: 0 4px 15px rgba(110, 142, 251, 0.2);
}
.points-text {
    color: white;
    font-size: 1.5em;
    text-align: center;
    margin: 0;
    text-shadow: 0 2px 4px rgba(0, 0, 0, 0.1);
}
footer {
    visibility: hidden;
}
.gradio-container {
    background: linear-gradient(135deg, #f6f8ff 0%, #e9f0ff 100%);
}
.gr-button {
    border: 2px solid rgba(100, 100, 255, 0.2);
    background: linear-gradient(135deg, #6e8efb 0%, #5d7df9 100%);
    box-shadow: 0 4px 15px rgba(110, 142, 251, 0.2);
}
.gr-button:hover {
    background: linear-gradient(135deg, #5d7df9 0%, #4a6af8 100%);
    box-shadow: 0 4px 20px rgba(110, 142, 251, 0.3);
}
.gr-input, .gr-box {
    border-radius: 12px;
    box-shadow: 0 4px 15px rgba(0, 0, 0, 0.05);
    border: 2px solid rgba(100, 100, 255, 0.1);
}
"""
# Build the Gradio interface
with gr.Blocks(theme="soft", css=CUSTOM_CSS) as demo:
    gr.HTML("""<a href="https://visitorbadge.io/status?path=https%3A%2F%2Fginigen-Dokdo-membership.hf.space">
<img src="https://api.visitorbadge.io/api/visitors?path=https%3A%2F%2Fginigen-Dokdo-membership.hf.space&countColor=%23263759" />
</a>""")
    gr.Markdown("## 'Dokdo membership' Image to Video generation")

    # Top header: pricing notice plus the logged-in badge and admin entry point.
    with gr.Row(elem_classes="header-container"):
        gr.Markdown("##### Upon Dokdo membership registration, get 15 points. Video creation costs 5 points. Need more points through public contributions or paid services? Contact ginipicks@gmail.com.")
        with gr.Column(scale=1):
            user_info = gr.HTML(value="", elem_classes="user-badge")
            admin_button = gr.Button("Admin Page", visible=False, elem_classes="admin-badge")

    # Login / sign-up row
    with gr.Row(elem_id="row-container"):
        email_box = gr.Textbox(
            placeholder="Email",
            show_label=False,
            elem_classes=["same-size"]
        )
        pass_box = gr.Textbox(
            placeholder="Password",
            show_label=False,
            type="password",
            elem_classes=["same-size"]
        )
        login_btn = gr.Button(
            "Login",
            elem_classes=["same-size"]
        )
        signup_btn = gr.Button(
            "Sign Up",
            elem_classes=["same-size"]
        )
        auth_output = gr.Textbox(
            show_label=False,
            elem_classes=["result-box"],
            interactive=False,
            placeholder="Message"
        )

    # Studio, revealed after a successful login
    with gr.Column(elem_id="studio-section", visible=False) as studio_container:
        gr.Markdown("### My Studio")
        with gr.Group(elem_classes="points-display"):
            points_display = gr.Markdown("", elem_classes="points-text")
        # Video generation
        with gr.Row():
            with gr.Column(scale=3):
                video_prompt = gr.Textbox(
                    label="Video Description",
                    placeholder="비디오 설명을 입력하세요... (한글 입력 가능)",
                    lines=3
                )
                upload_image = gr.Image(type="filepath", label="Upload First Frame Image")
                video_generate_btn = gr.Button("🎬 Generate Video (5 points)")
            with gr.Column(scale=4):
                video_output = gr.Video(label="Generated Video")
        refresh_profile_button = gr.Button("Refresh Points")

    # Admin page (get_all_users / delete_user re-check admin rights server-side)
    with gr.Column(visible=False) as admin_container:
        gr.Markdown("### Admin Page")
        admin_refresh = gr.Button("Get All Users")
        admin_output = gr.HTML()
        # Visible delete controls. The previous hidden textbox/button had no
        # elem_id, so the deleteUser() script below could never find them.
        with gr.Row():
            delete_user_email = gr.Textbox(label="Email to delete", elem_id="delete_user_email")
            delete_and_refresh = gr.Button("Delete User", elem_id="delete_and_refresh")
        delete_status = gr.Textbox(label="Result")

    # Event handlers
    def points_markdown():
        """Return the current user's points line, formatted like the login view."""
        return f"### {get_profile()[2]}"

    def login_and_update(email, password):
        """Log in and return component updates for the header, studio and message box."""
        import html

        res = login(email, password)
        msg = res["value"]
        if "Welcome!" in msg or "ADMIN" in msg:
            # Escape the typed email before embedding it in HTML.
            user_badge = f"<div>Logged in: {html.escape((email or '').strip())}</div>"
            is_admin_flag = "ADMIN" in msg
            return {
                user_info: user_badge,
                studio_container: gr.Column(visible=True),
                admin_button: gr.Button(visible=is_admin_flag),
                points_display: points_markdown(),
                auth_output: msg
            }
        if current_user.get("email"):
            # Input was rejected before checking credentials; session unchanged.
            return {auth_output: msg}
        # Wrong credentials cleared the session, so hide the logged-in UI too.
        return {
            user_info: "",
            studio_container: gr.Column(visible=False),
            admin_button: gr.Button(visible=False),
            auth_output: msg
        }

    # Event wiring
    login_btn.click(
        fn=login_and_update,
        inputs=[email_box, pass_box],
        outputs=[
            user_info,
            studio_container,
            admin_button,
            points_display,
            auth_output
        ]
    )
    signup_btn.click(
        fn=register,
        inputs=[email_box, pass_box],
        outputs=auth_output
    )
    # Refresh the balance after each generation attempt (charge or refund).
    video_generate_btn.click(
        fn=process_and_generate_video,
        inputs=[upload_image, video_prompt],
        outputs=video_output
    ).then(
        fn=points_markdown,
        outputs=[points_display]
    )
    refresh_profile_button.click(
        fn=points_markdown,
        outputs=[points_display]
    )

    def toggle_admin_page():
        """Reveal the admin page container."""
        return {admin_container: gr.Column(visible=True)}

    admin_button.click(fn=toggle_admin_page, outputs=[admin_container])
    admin_refresh.click(fn=get_all_users, outputs=admin_output)

    # Member-deletion JS used by the per-row Delete buttons from get_all_users().
    # NOTE(review): browsers do not execute <script> tags inserted via
    # innerHTML, which is likely how gr.HTML renders content, so deleteUser()
    # may be undefined; the visible "Email to delete" controls work without it.
    gr.HTML("""
<script>
function deleteUser(email) {
    if (confirm('Are you sure you want to delete this user?')) {
        const deleteEmail = document.getElementById('delete_user_email');
        if (deleteEmail) {
            deleteEmail.value = email;
        }
        const deleteBtn = document.getElementById('delete_and_refresh');
        if (deleteBtn) {
            deleteBtn.click();
        }
    }
}
</script>
""")

    def delete_user_and_refresh(email):
        """Delete *email* (admin only) and return (status message, refreshed user table)."""
        msg = delete_user(email)
        return msg, get_all_users()

    delete_and_refresh.click(
        fn=delete_user_and_refresh,
        inputs=delete_user_email,
        outputs=[delete_status, admin_output]
    )
# Entry point: print a banner, make sure the admin account exists, then serve the UI.
if __name__ == "__main__":
    startup_banner = "\n=== Application Startup ==="
    print(startup_banner)
    init_db()
    demo.launch(debug=True)