# Spaces: Running
import os, sys, json, re, time, base64, random, shutil | |
import gradio as gr | |
import numpy as np | |
import requests | |
from requests import Session | |
from requests.adapters import HTTPAdapter | |
from requests.packages.urllib3.util.retry import Retry | |
from datetime import datetime | |
import urllib.parse | |
from groq import Groq | |
from exif import Image | |
from PIL import Image as PILImage, ExifTags as PILExifTags | |
from io import BytesIO | |
import colorsys | |
# MARK: INIT | |
# Alternative upper bound (disabled): np.iinfo(np.int32).max
MAX_SEED = 9999  # upper bound for random seeds and for the seed slider
MAX_IMAGE_SIZE = 2048  # upper bound for the width/height number inputs

# API key and endpoint base URLs are read from the environment.
GROQ_APIKEY_PROMPTENHANCE = os.environ.get("GROQ_APIKEY_PROMPTENHANCE")
API_V1 = os.environ.get("API_V1")
# Endpoints derived from the v1 base URL; the values are percent-decoded.
API_V1_MODELS = urllib.parse.unquote(API_V1 + "/models")
API_V1_IMAGE = urllib.parse.unquote(API_V1 + "/imagine")
API_OLD = urllib.parse.unquote(os.environ.get("API_OLD"))
# Debug helpers:
# print(API_V1_MODELS)
# print(API_V1_IMAGE)
# Generated images are stored in "<this file's folder>/cache/images".
CACHE_DIR = os.path.join(os.path.dirname(__file__), "cache")
IMAGE_DIR = os.path.join(CACHE_DIR, "images")
# Check each directory on its own: an existing cache folder without the
# images sub-folder is repaired instead of being left incomplete.
for _label, _path in (("cache", CACHE_DIR), ("images", IMAGE_DIR)):
    if not os.path.isdir(_path):
        os.makedirs(_path, exist_ok=True)
        print(f"Created {_label} dir on path {_path}")
# Static resources (CSS, JS, icons) live in the "_res" folder next to this file.
RES = os.path.join(os.path.dirname(__file__), "_res")
gr.set_static_paths(
    paths=[
        "_res/assets/",
        "_res/assets/emojis/",
        "_res/assets/favicons/",
    ]
)
custom_css = f"{RES}/_custom.css"
custom_js = f"{RES}/_custom.js"
# Extra <head> markup: Font Awesome from the CDN plus the favicon set.
custom_head = """
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/5.9.0/css/all.min.css"/>
<script src="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/5.9.0/js/all.min.js"></script>
<link rel="apple-touch-icon" sizes="180x180" href="file=_res/assets/favicons/apple-touch-icon.png">
<link rel="icon" type="image/png" sizes="32x32" href="file=_res/assets/favicons/favicon-32x32.png">
<link rel="icon" type="image/png" sizes="16x16" href="file=_res/assets/favicons/favicon-16x16.png">
<link rel="icon" type="image/x-icon" href="file=_res/assets/favicons/favicon.ico">
<link rel="manifest" href="file=_res/assets/favicons/site.webmanifest">
"""
# Soft theme with small radii and a custom neutral palette.
theme = gr.themes.Soft(
    # primary_hue="orange",
    radius_size="sm",
    neutral_hue=gr.themes.Color(
        c50="#cdd6f4",
        c100="#a6adc8",
        c200="#9399b2",
        c300="#7f849c",
        c400="#6c7086",
        c500="#585b70",
        c600="#45475a",
        c700="#313244",
        c800="#1e1e2e",
        c900="#181825",
        c950="#11111b",
    ),
)
title = "Bilder Builder"
def check_api(url):
    """Probe ``url`` with a GET request that times out after 2 seconds.

    Returns:
        The HTTP status code of the response, or ``None`` when the request
        raised a ``requests`` exception (the error is printed).
    """
    try:
        status = requests.get(url, timeout=2).status_code
    except requests.exceptions.RequestException as e:
        print(f"An error occurred: {e}")
        return None
    return status
# MARK: GET PREV. IMAGES | |
def get_gallery_images(dirpath):
    """List the regular files directly inside ``dirpath``, newest first.

    Each entry is returned as ``dirpath + "/" + name``; sub-directories are
    skipped. Ordering is by modification time, most recent first.
    """
    found = []
    for entry in os.listdir(dirpath):
        if os.path.isfile(os.path.join(dirpath, entry)):
            found.append(dirpath + "/" + entry)
    return sorted(found, key=os.path.getmtime, reverse=True)
# MARK: READ EXIF | |
def read_exif(image_path):
    """Return the JSON-decoded EXIF user comment of the image at ``image_path``.

    The ``"concept"`` entry is dropped from the decoded comment when present.
    """
    with open(image_path, "rb") as image_file:
        exif_image = Image(image_file)
        comment = json.loads(exif_image.user_comment)
    if "concept" not in comment:
        return comment
    # Strip the entry the callers are not supposed to see.
    comment.pop("concept")
    return comment
def read_image_exfi_data(image_path):
    """Read the EXIF comment of an image.

    Args:
        image_path: Path of the image file to inspect.

    Returns:
        A 3-tuple ``(None, image_path, comment)`` where ``comment`` is the
        value returned by ``read_exif``.
    """
    print("Imagepath:", image_path)
    # read_exif returns one object (the decoded comment), not a pair, so it
    # must not be unpacked into two names.
    img_exif_comment = read_exif(image_path)
    return None, image_path, img_exif_comment
# MARK: GROQ PROMPT ENHANCE | |
def groq_enhance_process(Prompt=""):
    """Let the Groq chat model enhance ``Prompt`` (or invent a random one).

    Args:
        Prompt: Text to enhance. An empty value is replaced by the literal
            instruction ``"random prompt"``.

    Returns:
        The model's answer with all dots and double quotes removed. When the
        model returns no content, the (possibly substituted) input prompt is
        returned unchanged instead of failing.
    """
    client = Groq(api_key=GROQ_APIKEY_PROMPTENHANCE)
    Prompt = Prompt or "random prompt"

    # The system message is stored as a JSON chat message next to the assets.
    system_prompt_path = os.path.join(RES, "groq_systemmessage_prompt_enhance.json")
    with open(system_prompt_path, "r", encoding="utf-8") as f:
        system_message = json.load(f)

    completion = client.chat.completions.create(
        model="llama-3.1-70b-versatile",
        messages=[system_message, {"role": "user", "content": Prompt}],
        temperature=0.8,
        max_tokens=512,
        top_p=0.9,
        stream=False,
        seed=random.randint(0, MAX_SEED),
        stop=None,
    )

    content = completion.choices[0].message.content
    if not content:
        # Previously this path hit an unbound local variable; keep the
        # caller's prompt so the UI still receives a usable value.
        print("Prompt enhancement returned no content, keeping the input prompt")
        return Prompt
    return re.sub(r"[\.\"]+", "", content)
def image_get_size(image_path):
    """Return ``(width, height)`` in pixels of the image at ``image_path``.

    The file is opened through a context manager so its handle is released
    as soon as the size has been read.
    """
    with PILImage.open(image_path) as img:
        size = img.size
    print("Image size:", size)
    width, height = size
    return width, height
# MARK: DOMINANT COLOR | |
def image_get_dominant_color(image_path):
    """Pick a dark accent colour from the image at ``image_path``.

    The image is reduced to 100x100 RGB pixels. Among all pixels whose HSV
    value is at most 0.5, the one with the highest value wins. If the image
    has no such pixel (a uniformly bright picture), the darkest pixel overall
    is used instead of raising.

    Returns:
        A CSS colour string of the form ``"rgb(r, g, b)"``.
    """
    # Context manager closes the underlying file once the thumbnail exists.
    with PILImage.open(image_path) as src:
        thumb = src.convert("RGB").resize((100, 100), resample=0)

    hsv_pixels = [colorsys.rgb_to_hsv(r / 255, g / 255, b / 255) for r, g, b in thumb.getdata()]

    # Keep only the darker colours; bright ones (including white) are ignored.
    dark_pixels = [hsv for hsv in hsv_pixels if hsv[2] <= 0.5]

    # Determine the dominant colour.
    if dark_pixels:
        dominant_color = max(dark_pixels, key=lambda hsv: hsv[2])
    else:
        dominant_color = min(hsv_pixels, key=lambda hsv: hsv[2])

    red, green, blue = (int(channel * 255) for channel in colorsys.hsv_to_rgb(*dominant_color))
    dominant_color_rgb = f"rgb({red}, {green}, {blue})"
    print(dominant_color_rgb)
    return dominant_color_rgb
# MARK: CLEAR COMPONENTS | |
def clear_components():
    """Return ``None``, the value used to reset the bound output component."""
    cleared = None
    return cleared
def process(Prompt, used_model, image_width, image_height, image_ratio, image_seed, randomize_seed):
    """Request an image for ``Prompt`` from the selected API and save it as PNG.

    Args:
        Prompt: Prompt text. When empty, a random prompt is requested from
            ``groq_enhance_process`` first.
        used_model: Model name. ``"flux-api-old"`` (case-insensitive) selects
            the old API, every other value is sent to the v1 API.
        image_width: Requested width; only sent to the old API.
        image_height: Requested height; only sent to the old API.
        image_ratio: Aspect ratio such as ``"9:16"``, sent to the v1 API as
            ``size``. An empty string falls back to ``"9:16"``.
        image_seed: Seed used when ``randomize_seed`` is off and it is not 0.
        randomize_seed: When truthy, a random seed in ``[0, MAX_SEED]`` is used.

    Returns:
        An 11-tuple of component updates and values: prompt update, image
        path update, hidden/empty update, visibility update, prompt-info
        update, image width, image height, used seed, download update,
        dominant colour string, used seed.

    Raises:
        gr.Error: On timeout, on any other ``requests`` failure, or when the
            API answers with a status other than 200.
    """
    if Prompt == "":
        gr.Info("Kein Prompt angegeben, es wird ein zufälliger Prompt generiert.", duration=12)
        Prompt = groq_enhance_process("random prompt")
    image_ratio = "9:16" if image_ratio == "" else image_ratio
    used_seed = random.randint(0, MAX_SEED) if image_seed == 0 or randomize_seed else image_seed
    # used_model = "turbo" if int(image_width) > 1024 or int(image_height) > 1024 else "flux" # turbo, flux

    timestamp = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
    # Drop punctuation, then collapse every whitespace run into one "_".
    # The prompt box is multi-line, so newlines and tabs must not leak into
    # the file name.
    filename_prompt = re.sub(r"\s+", "_", re.sub(r"[^\w\s-]", "", Prompt).strip())
    filename = timestamp + "_" + filename_prompt[:100] + ".png"
    file_path = os.path.join(IMAGE_DIR, filename)

    # Retry policy: at most 3 attempts, only for the listed status codes,
    # exponential backoff, honouring the Retry-After header.
    retries = Retry(total=3, status_forcelist=[429, 500, 502, 503, 504], backoff_factor=0.3, respect_retry_after_header=True)
    adapter = HTTPAdapter(max_retries=retries)

    API_OLD_REQUEST_URL = f"{API_OLD}{urllib.parse.quote(Prompt)}?model=flux&width={image_width}&height={image_height}&nologo=true&enhance=false&nofeed=true&seed={used_seed}"
    API_V1_REQUEST_URL = f"{API_V1_IMAGE}?prompt={urllib.parse.quote(Prompt)}&size={urllib.parse.quote(image_ratio)}&seed={used_seed}&model={used_model.lower()}"
    REQUEST_URL = API_V1_REQUEST_URL if used_model.lower() != "flux-api-old" else API_OLD_REQUEST_URL
    print(REQUEST_URL)

    try:
        # The session is closed on exit so pooled connections are released;
        # the response body is already fully read at that point.
        with Session() as session:
            session.mount("https://", adapter)
            session.mount("http://", adapter)
            response = session.get(REQUEST_URL, timeout=60)
    except requests.exceptions.Timeout:
        raise gr.Error("⏰ Zeitüberschreitung beim API-Aufruf", duration=15)
    except requests.exceptions.RequestException as e:
        print(f"Unbekannter Fehler beim API-Aufruf: {e}")
        raise gr.Error("Unbekannter Fehler beim API-Aufruf! 🤷♂️", duration=15)

    if response.status_code != 200:
        print("Imagine API Request ERROR")
        raise gr.Error("Imagine API-Aufruf fehlgeschlagen 💥!", duration=15)

    print("Imagine API Request solved")
    print("Save image to: " + file_path)
    img = PILImage.open(BytesIO(response.content))
    img.save(file_path, "PNG")
    # img_exif_comment = "" # read_exif(file_path)
    img_dominant_color = image_get_dominant_color(file_path)
    img_width, img_height = image_get_size(file_path)
    return (
        {"value": Prompt, "__type__": "update"},
        {"value": file_path, "__type__": "update"},
        {"value": None, "visible": False, "__type__": "update"},
        {"visible": True, "__type__": "update"},
        {"value": Prompt, "visible": True, "__type__": "update"},
        img_width,
        img_height,
        used_seed,
        {"value": file_path, "visible": True, "__type__": "update"},
        img_dominant_color,
        used_seed,
    )
def get_inference_models():
    """Collect the models of every API that currently answers with HTTP 200.

    Returns:
        A tuple ``(models, selected_model)``. ``models`` is the list of
        selectable model names. ``selected_model`` is the default choice:
        ``"FLUX"`` when the v1 API is reachable, otherwise ``"FLUX-API-OLD"``
        when only the old API is reachable, and ``None`` when neither API
        responded (this case used to fail on an unbound variable).
    """
    models = []
    selected_model = None
    if check_api(API_OLD) == 200:
        models.append("FLUX-API-OLD")
        selected_model = "FLUX-API-OLD"
    if check_api(API_V1_IMAGE) == 200:
        models_new_api = ["FLUX", "FLUX-Realism", "FLUX-Pixel", "FLUX-Anime", "FLUX-3D", "FLUX-Disney", "ANY-DARK", "Stable-Diffusion-XL-Base", "Stable-Diffusion-XL-Lightning"]
        models.extend(models_new_api)
        # The v1 default takes precedence over the old API's default.
        selected_model = "FLUX"
    return models, selected_model
# MARK: Gradio BLOCKS UI | |
# Build the Gradio UI; theme, head markup, CSS, JS and title are passed in from module-level names.
# NOTE(review): the container nesting of this layout depends on indentation that is not visible here - confirm against the deployed file before restructuring.
with gr.Blocks(theme=theme, head=custom_head, css=custom_css, js=custom_js, title=title) as demo: | |
# Header row: the title and a short greeting rendered as Markdown/HTML.
with gr.Row(elem_classes="row-header"): | |
gr.Markdown(f""" | |
<h1>{title}</h1> | |
<p><span style="font-weight: 600">LG Sebastian</span> <i class="winking-hand-emoji"></i> gib dem Space gerne ein <i class="heart-beat-emoji"></i></p> | |
""", elem_classes="md-header",) | |
with gr.Tab("Bilder Builder"): | |
with gr.Row(): | |
with gr.Column(scale=2): # min_width=420, | |
with gr.Row(): | |
# Prompt input; the placeholder text describes the three action buttons.
placeholder_text = "[???] Generiert dir einen zufälligen Prompt.\n[STERN] optimiert deinen eignen Prompt.\n[RUN] generiert dein Bild." | |
text_prompt = gr.Textbox(label="Prompt", show_label=False, lines=12, max_lines=18, placeholder=placeholder_text, elem_id="prompt_input", elem_classes="prompt-input hide-progress", autofocus=True) | |
with gr.Row(): | |
# Action buttons: random prompt, enhance prompt, generate image.
random_prompt_button = gr.Button("", variant="secondary", elem_id="random_prompt_btn", elem_classes="random-prompt-btn", icon="_res/assets/star_light_48.png") | |
enhance_prompt_button = gr.Button("", variant="secondary", elem_id="enhance_prompt_btn", elem_classes="enhance-prompt-btn", icon="_res/assets/star_light_48.png") | |
run_button = gr.Button("Erstellen", variant="primary", elem_id="run_btn", elem_classes="run-btn") | |
with gr.Row(elem_classes="image_size_selctor_wrapper"): | |
with gr.Column(scale=1): | |
with gr.Row(): | |
# Dropdown choices and the default value come from get_inference_models().
inference_models, selected_model = get_inference_models() | |
select_model = gr.Dropdown(choices=inference_models, value=selected_model, label="Model", elem_id="select_model", elem_classes="select-model") | |
# with gr.Row(): | |
# Width and height inputs are created hidden (visible=False).
image_width = gr.Number(label="Breite", minimum=256, maximum=MAX_IMAGE_SIZE, value=576, step=32, elem_id="image_width_selector", elem_classes="image-width-selector", scale=1, visible=False) | |
image_height = gr.Number(label="Höhe", minimum=256, maximum=MAX_IMAGE_SIZE, value=1024, step=32, elem_id="image_height_selector", elem_classes="image-height-selector", scale=1, visible=False) | |
with gr.Row(): | |
# Aspect-ratio selector (portrait choices by default) and the button that swaps the orientation.
image_ratio_buttons = gr.Radio(["9:16", "3:4", "2:3", "1:1"], value="9:16", label="Hochformat", show_label=True, info="Seitenverhältniss drehen", interactive=True, elem_id="image_ratio_buttons", elem_classes="image-ratio-buttons", container=True, scale=2) | |
switch_width_height = gr.Button("", size="sm", elem_id="switch_width_height", elem_classes="switch-ratio-btn", variant="primary", scale=1) | |
with gr.Column(): | |
# Seed controls: the checkbox starts checked and the slider starts non-interactive.
randomize_seed = gr.Checkbox(label="Randomize seed", value=True, elem_classes="random-seed-cb toggle-btn") | |
image_seed = gr.Slider(label="Seed", info="Jeder Seed generiert ein anderes Bild mit dem selben Prompt", minimum=0, step=1, value=42, maximum=MAX_SEED, elem_id="image_seed", elem_classes="image-seed hide-progress", interactive=False) | |
with gr.Column(scale=4): # min_width=600, | |
with gr.Row(): | |
with gr.Column(scale=1): | |
with gr.Row(): | |
# Output image component; it is configured with type="filepath".
output_image = gr.Image(show_label=False, min_width=320, scale=3, elem_id="output_image", elem_classes="output-image", type="filepath", format="webp") | |
# gallery = gr.Gallery(label="Bisher erstellte Bilder", show_label=True, value=get_gallery_images(IMAGE_DIR), elem_id="gallery", columns=[4], object_fit="cover", height="auto", interactive=False, format="webp") | |
# Image information column, created hidden (visible=False).
with gr.Column(scale=1, visible=False, elem_classes="image-info-wrapper") as image_info_wrapper: | |
with gr.Group(): | |
image_informations = gr.Markdown("""## Bildinformationen""", visible=True) | |
with gr.Row(elem_classes="img-size-wrapper"): | |
image_info_tb_width = gr.Textbox(label="Breite", lines=1, max_lines=1, interactive=False, show_copy_button=True, elem_classes="image-info-tb-width") | |
image_info_tb_height = gr.Textbox(label="Höhe", lines=1, max_lines=1, interactive=False, show_copy_button=True, elem_classes="image-info-tb-height") | |
with gr.Row(elem_classes="img-seed-wrapper"): | |
image_info_tb_seed = gr.Textbox(label="Seed", lines=1, max_lines=1, interactive=False, show_copy_button=True, elem_classes="image-info-tb-seed") | |
# Prompt text box, download button, output URL and JSON comment are all created hidden (visible=False).
image_info_tb_prompt = gr.Textbox("Bild Prompt", lines=4, max_lines=8, interactive=False, elem_classes="hide-progress", show_copy_button=True, visible=False) | |
image_download_button = gr.DownloadButton("Bild herunterladen", value=None, elem_classes="download-button", variant="primary", visible=False) | |
output_url = gr.Textbox(label="Output URL", show_label=True, interactive=False, visible=False) | |
outpu_image_comment = gr.Json(visible=False) | |
# Text box that receives the dominant colour string returned by process().
output_dominant_image_color = gr.Textbox(show_label=False, elem_id="dominant_image_color", visible=True, elem_classes="output-dominant-image-color") | |
def switch_image_size_values(image_width, image_height):
    """Swap the two size values: returns ``(image_height, image_width)``."""
    swapped = (image_height, image_width)
    return swapped
def switch_image_ratio_buttons(ratio_value):
    """Flip an aspect ratio such as ``"9:16"`` and build the radio update for it.

    Returns:
        An update dict carrying the flipped ratio as ``value`` together with
        the choices and label matching the new orientation. For a square
        ratio the choices are taken from ``image_ratio_buttons.choices``.
    """
    parts = ratio_value.split(":")
    second = int(parts[1])
    first = int(parts[0])
    flipped = f"{second}:{first}"

    if second > first:
        # The flipped ratio is wider than tall: landscape.
        choices, label = ["16:9", "4:3", "3:2", "1:1"], "Querformat"
    elif second < first:
        # The flipped ratio is taller than wide: portrait.
        choices, label = ["9:16", "3:4", "2:3", "1:1"], "Hochformat"
    else:
        choices, label = image_ratio_buttons.choices, "Quadratisch"

    return {"choices": choices, "value": flipped, "label": label, "__type__": "update"}
def calculate_ratio_values(image_ratio_buttons):
    """Turn an aspect ratio such as ``"16:9"`` into pixel dimensions.

    The longer side is fixed at 1024; the shorter side is truncated to an
    int and then rounded to a multiple of 8.

    Returns:
        ``(label_update, width, height)`` where ``label_update`` is an update
        dict naming the orientation.
    """
    parts = image_ratio_buttons.split(":")
    w_part, h_part = int(parts[0]), int(parts[1])
    long_side = 1024

    def short_side(numerator, denominator):
        # Scale, truncate, then snap to the 8-pixel grid.
        scaled = int(long_side * numerator / denominator)
        return round(scaled / 8) * 8

    if w_part > h_part:
        label, width, height = "Querformat", long_side, short_side(h_part, w_part)
    elif w_part < h_part:
        label, width, height = "Hochformat", short_side(w_part, h_part), long_side
    else:
        label, width, height = "Quadratisch", 1024, 1024

    return {"label": label, "__type__": "update"}, width, height
def _action_button_updates(enabled):
    """Build one ``interactive`` update per action button (three in total)."""
    return tuple({"interactive": enabled, "__type__": "update"} for _ in range(3))


# Event listeners have to be registered inside the Blocks context.
with demo:
    action_buttons = [run_button, enhance_prompt_button, random_prompt_button]

    # Orientation switch: swap width/height and flip the ratio selector.
    switch_width_height.click(
        fn=switch_image_size_values,
        inputs=[image_width, image_height],
        outputs=[image_width, image_height],
        show_progress="hidden",
        show_api=False,
    )
    switch_width_height.click(
        fn=switch_image_ratio_buttons,
        inputs=[image_ratio_buttons],
        outputs=[image_ratio_buttons],
        show_progress="hidden",
        show_api=False,
    )
    # Picking a ratio recomputes the width/height values.
    image_ratio_buttons.input(
        fn=calculate_ratio_values,
        inputs=[image_ratio_buttons],
        outputs=[image_ratio_buttons, image_width, image_height],
        show_progress="hidden",
        show_api=False,
    )

    # Run: lock the buttons, clear the old image, generate, unlock again.
    lock_step = run_button.click(fn=lambda: _action_button_updates(False), outputs=action_buttons, show_api=False)
    clear_step = lock_step.then(fn=clear_components, outputs=[output_image], show_api=False)
    generate_step = clear_step.then(
        fn=process,
        inputs=[text_prompt, select_model, image_width, image_height, image_ratio_buttons, image_seed, randomize_seed],
        outputs=[
            text_prompt,
            output_image,
            output_url,
            image_informations,
            image_info_tb_prompt,
            image_info_tb_width,
            image_info_tb_height,
            image_info_tb_seed,
            image_download_button,
            output_dominant_image_color,
            image_seed,
        ],
    )
    generate_step.then(fn=lambda: _action_button_updates(True), outputs=action_buttons, show_api=False)

    # The seed slider is only editable while "Randomize seed" is unchecked.
    randomize_seed.input(
        lambda x: {"interactive": not x, "__type__": "update"},
        inputs=[randomize_seed],
        outputs=[image_seed],
        show_api=False,
    )

    enhance_prompt_button.click(fn=groq_enhance_process, inputs=[text_prompt], outputs=[text_prompt], show_api=False)
    random_prompt_button.click(fn=groq_enhance_process, inputs=None, outputs=[text_prompt], show_api=False)

# MARK: Gradio LAUNCH
demo.launch(show_api=True)