ginipick commited on
Commit
5e60b44
1 Parent(s): d3e5f6a

Upload 2 files

Browse files
Files changed (2) hide show
  1. app (28).py +427 -0
  2. requirements (9).txt +16 -0
app (28).py ADDED
@@ -0,0 +1,427 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import spaces
2
+ import argparse
3
+ import os
4
+ import time
5
+ from os import path
6
+ import shutil
7
+ from datetime import datetime
8
+ from safetensors.torch import load_file
9
+ from huggingface_hub import hf_hub_download
10
+ import gradio as gr
11
+ import torch
12
+ from diffusers import FluxPipeline
13
+ from diffusers.pipelines.stable_diffusion import safety_checker
14
+ from PIL import Image
15
+ from transformers import pipeline
16
+ import replicate
17
+ import logging
18
+ import requests
19
+ from pathlib import Path
20
+ import cv2
21
+ import numpy as np
22
+ import sys
23
+ import io
24
+ # 로깅 설정
25
+ logging.basicConfig(level=logging.INFO)
26
+ logger = logging.getLogger(__name__)
27
+
28
+ # Setup and initialization code
29
+ cache_path = path.join(path.dirname(path.abspath(__file__)), "models")
30
+ PERSISTENT_DIR = os.environ.get("PERSISTENT_DIR", ".")
31
+
32
+
33
+ # API 설정
34
+ CATBOX_USER_HASH = "e7a96fc68dd4c7d2954040cd5"
35
+ REPLICATE_API_TOKEN = os.getenv("API_KEY")
36
+
37
+ # 환경 변수 설정
38
+ os.environ["TRANSFORMERS_CACHE"] = cache_path
39
+ os.environ["HF_HUB_CACHE"] = cache_path
40
+ os.environ["HF_HOME"] = cache_path
41
+
42
+ # CUDA 설정
43
+ torch.backends.cuda.matmul.allow_tf32 = True
44
+
45
+ # 번역기 초기화
46
+ translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en")
47
+
48
+
49
+ if not path.exists(cache_path):
50
+ os.makedirs(cache_path, exist_ok=True)
51
+
52
+ def check_api_key():
53
+ """API 키 확인 및 설정"""
54
+ if not REPLICATE_API_TOKEN:
55
+ logger.error("Replicate API key not found")
56
+ return False
57
+ os.environ["REPLICATE_API_TOKEN"] = REPLICATE_API_TOKEN
58
+ logger.info("Replicate API token set successfully")
59
+ return True
60
+
61
+ def translate_if_korean(text):
62
+ """한글이 포함된 경우 영어로 번역"""
63
+ if any(ord(char) >= 0xAC00 and ord(char) <= 0xD7A3 for char in text):
64
+ translation = translator(text)[0]['translation_text']
65
+ return translation
66
+ return text
67
+
68
+ def filter_prompt(prompt):
69
+ inappropriate_keywords = [
70
+ "nude", "naked", "nsfw", "porn", "sex", "explicit", "adult", "xxx",
71
+ "erotic", "sensual", "seductive", "provocative", "intimate",
72
+ "violence", "gore", "blood", "death", "kill", "murder", "torture",
73
+ "drug", "suicide", "abuse", "hate", "discrimination"
74
+ ]
75
+
76
+ prompt_lower = prompt.lower()
77
+ for keyword in inappropriate_keywords:
78
+ if keyword in prompt_lower:
79
+ return False, "부적절한 내용이 포함된 프롬프트입니다."
80
+ return True, prompt
81
+
82
+ def process_prompt(prompt):
83
+ """프롬프트 전처리 (번역 및 필터링)"""
84
+ translated_prompt = translate_if_korean(prompt)
85
+ is_safe, filtered_prompt = filter_prompt(translated_prompt)
86
+ return is_safe, filtered_prompt
87
+
88
+ class timer:
89
+ def __init__(self, method_name="timed process"):
90
+ self.method = method_name
91
+ def __enter__(self):
92
+ self.start = time.time()
93
+ print(f"{self.method} starts")
94
+ def __exit__(self, exc_type, exc_val, exc_tb):
95
+ end = time.time()
96
+ print(f"{self.method} took {str(round(end - self.start, 2))}s")
97
+
98
+ # Model initialization
99
+ if not path.exists(cache_path):
100
+ os.makedirs(cache_path, exist_ok=True)
101
+
102
+ pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
103
+ pipe.load_lora_weights(hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"))
104
+ pipe.fuse_lora(lora_scale=0.125)
105
+ pipe.to(device="cuda", dtype=torch.bfloat16)
106
+ pipe.safety_checker = safety_checker.StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
107
+
108
+ def upload_to_catbox(image_path):
109
+ """catbox.moe API를 사용하여 이미지 업로드"""
110
+ try:
111
+ logger.info(f"Preparing to upload image: {image_path}")
112
+ url = "https://catbox.moe/user/api.php"
113
+
114
+ file_extension = Path(image_path).suffix.lower()
115
+ if file_extension not in ['.jpg', '.jpeg', '.png', '.gif']:
116
+ logger.error(f"Unsupported file type: {file_extension}")
117
+ return None
118
+
119
+ files = {
120
+ 'fileToUpload': (
121
+ os.path.basename(image_path),
122
+ open(image_path, 'rb'),
123
+ 'image/jpeg' if file_extension in ['.jpg', '.jpeg'] else 'image/png'
124
+ )
125
+ }
126
+
127
+ data = {
128
+ 'reqtype': 'fileupload',
129
+ 'userhash': CATBOX_USER_HASH
130
+ }
131
+
132
+ response = requests.post(url, files=files, data=data)
133
+
134
+ if response.status_code == 200 and response.text.startswith('http'):
135
+ image_url = response.text
136
+ logger.info(f"Image uploaded successfully: {image_url}")
137
+ return image_url
138
+ else:
139
+ raise Exception(f"Upload failed: {response.text}")
140
+
141
+ except Exception as e:
142
+ logger.error(f"Image upload error: {str(e)}")
143
+ return None
144
+
145
+ def add_watermark(video_path):
146
+ """OpenCV를 사용하여 비디오에 워터마크 추가"""
147
+ try:
148
+ cap = cv2.VideoCapture(video_path)
149
+ width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
150
+ height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
151
+ fps = int(cap.get(cv2.CAP_PROP_FPS))
152
+
153
+ text = "GiniGEN.AI"
154
+ font = cv2.FONT_HERSHEY_SIMPLEX
155
+ font_scale = height * 0.05 / 30
156
+ thickness = 2
157
+ color = (255, 255, 255)
158
+
159
+ (text_width, text_height), _ = cv2.getTextSize(text, font, font_scale, thickness)
160
+ margin = int(height * 0.02)
161
+ x_pos = width - text_width - margin
162
+ y_pos = height - margin
163
+
164
+ output_path = "watermarked_output.mp4"
165
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
166
+ out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
167
+
168
+ while cap.isOpened():
169
+ ret, frame = cap.read()
170
+ if not ret:
171
+ break
172
+ cv2.putText(frame, text, (x_pos, y_pos), font, font_scale, color, thickness)
173
+ out.write(frame)
174
+
175
+ cap.release()
176
+ out.release()
177
+
178
+ return output_path
179
+
180
+ except Exception as e:
181
+ logger.error(f"Error adding watermark: {str(e)}")
182
+ return video_path
183
+
184
+ def generate_video(image, prompt):
185
+ logger.info("Starting video generation")
186
+ try:
187
+ if not check_api_key():
188
+ return "Replicate API key not properly configured"
189
+
190
+ if not image:
191
+ logger.error("No image provided")
192
+ return "Please upload an image"
193
+
194
+ image_url = upload_to_catbox(image)
195
+ if not image_url:
196
+ return "Failed to upload image"
197
+
198
+ input_data = {
199
+ "prompt": prompt,
200
+ "first_frame_image": image_url
201
+ }
202
+
203
+ try:
204
+ replicate.Client(api_token=REPLICATE_API_TOKEN)
205
+ output = replicate.run(
206
+ "minimax/video-01-live",
207
+ input=input_data
208
+ )
209
+
210
+ temp_file = "temp_output.mp4"
211
+
212
+ if hasattr(output, 'read'):
213
+ with open(temp_file, "wb") as file:
214
+ file.write(output.read())
215
+ elif isinstance(output, str):
216
+ response = requests.get(output)
217
+ with open(temp_file, "wb") as file:
218
+ file.write(response.content)
219
+
220
+ final_video = add_watermark(temp_file)
221
+ return final_video
222
+
223
+ except Exception as api_error:
224
+ logger.error(f"API call failed: {str(api_error)}")
225
+ return f"API call failed: {str(api_error)}"
226
+
227
+ except Exception as e:
228
+ logger.error(f"Unexpected error: {str(e)}")
229
+ return f"Unexpected error: {str(e)}"
230
+
231
+ def save_image(image):
232
+ """Save the generated image temporarily"""
233
+ try:
234
+ # 임시 디렉토리에 저장
235
+ temp_dir = "temp"
236
+ if not os.path.exists(temp_dir):
237
+ os.makedirs(temp_dir, exist_ok=True)
238
+
239
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
240
+ filepath = os.path.join(temp_dir, f"temp_{timestamp}.png")
241
+
242
+ if not isinstance(image, Image.Image):
243
+ image = Image.fromarray(image)
244
+
245
+ if image.mode != 'RGB':
246
+ image = image.convert('RGB')
247
+
248
+ image.save(filepath, format='PNG', optimize=True, quality=100)
249
+
250
+ return filepath
251
+ except Exception as e:
252
+ logger.error(f"Error in save_image: {str(e)}")
253
+ return None
254
+
255
+
256
+
257
+ css = """
258
+ footer {
259
+ visibility: hidden;
260
+ }
261
+ """
262
+
263
+
264
+ # Gradio 인터페이스 생성
265
+ with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange", css=css) as demo:
266
+ gr.HTML('<div class="title">AI Image & Video Generator</div>')
267
+
268
+ with gr.Tabs():
269
+ with gr.Tab("Image Generation"):
270
+ with gr.Row():
271
+ with gr.Column(scale=3):
272
+ img_prompt = gr.Textbox(
273
+ label="Image Description",
274
+ placeholder="이미지 설명을 입력하세요... (한글 입력 가능)",
275
+ lines=3
276
+ )
277
+
278
+ with gr.Accordion("Advanced Settings", open=False):
279
+ with gr.Row():
280
+ height = gr.Slider(
281
+ label="Height",
282
+ minimum=256,
283
+ maximum=1152,
284
+ step=64,
285
+ value=1024
286
+ )
287
+ width = gr.Slider(
288
+ label="Width",
289
+ minimum=256,
290
+ maximum=1152,
291
+ step=64,
292
+ value=1024
293
+ )
294
+
295
+ with gr.Row():
296
+ steps = gr.Slider(
297
+ label="Inference Steps",
298
+ minimum=6,
299
+ maximum=25,
300
+ step=1,
301
+ value=8
302
+ )
303
+ scales = gr.Slider(
304
+ label="Guidance Scale",
305
+ minimum=0.0,
306
+ maximum=5.0,
307
+ step=0.1,
308
+ value=3.5
309
+ )
310
+
311
+ def get_random_seed():
312
+ return torch.randint(0, 1000000, (1,)).item()
313
+
314
+ seed = gr.Number(
315
+ label="Seed",
316
+ value=get_random_seed(),
317
+ precision=0
318
+ )
319
+
320
+ randomize_seed = gr.Button("🎲 Randomize Seed", elem_classes=["generate-btn"])
321
+
322
+ generate_btn = gr.Button(
323
+ "✨ Generate Image",
324
+ elem_classes=["generate-btn"]
325
+ )
326
+
327
+ with gr.Column(scale=4):
328
+ img_output = gr.Image(
329
+ label="Generated Image",
330
+ type="pil",
331
+ format="png"
332
+ )
333
+
334
+
335
+ with gr.Tab("Amazing Video Generation"):
336
+ with gr.Row():
337
+ with gr.Column(scale=3):
338
+ video_prompt = gr.Textbox(
339
+ label="Video Description",
340
+ placeholder="비디오 설명을 입력하세요... (한글 입력 가능)",
341
+ lines=3
342
+ )
343
+ upload_image = gr.Image(
344
+ type="filepath",
345
+ label="Upload First Frame Image"
346
+ )
347
+ video_generate_btn = gr.Button(
348
+ "🎬 Generate Video",
349
+ elem_classes=["generate-btn"]
350
+ )
351
+
352
+ with gr.Column(scale=4):
353
+ video_output = gr.Video(label="Generated Video")
354
+
355
+ @spaces.GPU
356
+ def process_and_save_image(height, width, steps, scales, prompt, seed):
357
+ is_safe, translated_prompt = process_prompt(prompt)
358
+ if not is_safe:
359
+ gr.Warning("부적절한 내용이 포함된 프롬프트입니다.")
360
+ return None
361
+
362
+ with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16), timer("inference"):
363
+ try:
364
+ generated_image = pipe(
365
+ prompt=[translated_prompt],
366
+ generator=torch.Generator().manual_seed(int(seed)),
367
+ num_inference_steps=int(steps),
368
+ guidance_scale=float(scales),
369
+ height=int(height),
370
+ width=int(width),
371
+ max_sequence_length=256
372
+ ).images[0]
373
+
374
+ if not isinstance(generated_image, Image.Image):
375
+ generated_image = Image.fromarray(generated_image)
376
+
377
+ if generated_image.mode != 'RGB':
378
+ generated_image = generated_image.convert('RGB')
379
+
380
+ img_byte_arr = io.BytesIO()
381
+ generated_image.save(img_byte_arr, format='PNG')
382
+
383
+ return Image.open(io.BytesIO(img_byte_arr.getvalue()))
384
+ except Exception as e:
385
+ logger.error(f"Error in image generation: {str(e)}")
386
+ return None
387
+
388
+
389
+
390
+ def process_and_generate_video(image, prompt):
391
+ is_safe, translated_prompt = process_prompt(prompt)
392
+ if not is_safe:
393
+ gr.Warning("부적절한 내용이 포함된 프롬프트입니다.")
394
+ return None
395
+ return generate_video(image, translated_prompt)
396
+
397
+ def update_seed():
398
+ return get_random_seed()
399
+
400
+ generate_btn.click(
401
+ process_and_save_image,
402
+ inputs=[height, width, steps, scales, img_prompt, seed],
403
+ outputs=img_output
404
+ )
405
+
406
+ video_generate_btn.click(
407
+ process_and_generate_video,
408
+ inputs=[upload_image, video_prompt],
409
+ outputs=video_output
410
+ )
411
+
412
+ randomize_seed.click(
413
+ update_seed,
414
+ outputs=[seed]
415
+ )
416
+
417
+ generate_btn.click(
418
+ update_seed,
419
+ outputs=[seed]
420
+ )
421
+
422
+ if __name__ == "__main__":
423
+ demo.launch(
424
+ server_name="0.0.0.0",
425
+ server_port=7860,
426
+ share=True
427
+ )
requirements (9).txt ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ accelerate
2
+ diffusers==0.30.0
3
+ invisible_watermark
4
+ torch
5
+ transformers==4.43.3
6
+ xformers
7
+ sentencepiece
8
+ peft
9
+ gradio
10
+ replicate
11
+ requests
12
+ python-dotenv
13
+ Pillow
14
+ opencv-python-headless
15
+ numpy
16
+ sacremoses