Spaces:
Running
Running
import os | |
import re | |
import requests | |
import tempfile | |
import gradio as gr | |
from PIL import Image, ImageDraw | |
from config import theme | |
from public.data.images.loras.flux1 import loras as flux1_loras | |
# os.makedirs(os.getenv("HF_HOME"), exist_ok=True) | |
# UI | |
# NOTE(review): in this copy of the file every line carries a trailing " | |"
# and all indentation has been stripped; the nesting of the `with` blocks
# below has to be restored before the module can run.
with gr.Blocks( | |
theme=theme, | |
fill_width=True, | |
# Every entry of static/css becomes a stylesheet: os.listdir order is
# arbitrary and non-.css files are not filtered out.
css_paths=[os.path.join("static/css", f) for f in os.listdir("static/css")], | |
) as demo: | |
# States | |
# NOTE(review): data_state is never read or written in this file view.
data_state = gr.State() | |
# Per-browser state seeded with an empty "selected_loras" list.
# NOTE(review): the Loras column also defines selected_loras = gr.State([]);
# confirm which of the two the generation code is meant to use.
local_state = gr.BrowserState( | |
{ | |
"selected_loras": [], | |
} | |
) | |
# Three columns: settings (scale 1) | prompt, outputs, playground (scale 3)
# | Loras (scale 1).
with gr.Row(): | |
with gr.Column(scale=1): | |
gr.Label("AllFlux", show_label=False) | |
with gr.Accordion("Settings", open=True): | |
# Output size: 64-2048 px per side in steps of 64, default 1024 x 1024.
# These two sliders also size the outpaint canvas (prepare_image_and_mask).
with gr.Group(): | |
height_slider = gr.Slider( | |
minimum=64, | |
maximum=2048, | |
value=1024, | |
step=64, | |
label="Height", | |
interactive=True, | |
) | |
width_slider = gr.Slider( | |
minimum=64, | |
maximum=2048, | |
value=1024, | |
step=64, | |
label="Width", | |
interactive=True, | |
) | |
with gr.Group(): | |
# Images per run: 1-4.
num_images_slider = gr.Slider( | |
minimum=1, | |
maximum=4, | |
value=1, | |
step=1, | |
label="Number of Images", | |
interactive=True, | |
) | |
# Multi-select toggles; only "Randomize Seed" is on by default.
toggles = gr.CheckboxGroup( | |
choices=["Realtime", "Randomize Seed"], | |
value=["Randomize Seed"], | |
show_label=False, | |
interactive=True, | |
) | |
with gr.Accordion("Advanced", open=False): | |
num_steps_slider = gr.Slider( | |
minimum=1, | |
maximum=100, | |
value=20, | |
step=1, | |
label="Steps", | |
interactive=True, | |
) | |
guidance_scale_slider = gr.Slider( | |
minimum=1, | |
maximum=10, | |
value=3.5, | |
step=0.1, | |
label="Guidance Scale", | |
interactive=True, | |
) | |
# 4294967295 == 2**32 - 1, the largest unsigned 32-bit value.
seed_slider = gr.Slider( | |
minimum=0, | |
maximum=4294967295, | |
value=42, | |
step=1, | |
label="Seed", | |
interactive=True, | |
) | |
# min 2, max 4, step 2: the only selectable factors are 2 and 4.
upscale_slider = gr.Slider( | |
minimum=2, | |
maximum=4, | |
value=2, | |
step=2, | |
label="Upscale", | |
interactive=True, | |
) | |
# NOTE(review): these are Stable-Diffusion-style sampler names; confirm the
# generation code maps each one to a scheduler that supports Flux.
scheduler_dropdown = gr.Dropdown( | |
label="Scheduler", | |
choices=[ | |
"Euler a", | |
"Euler", | |
"LMS", | |
"Heun", | |
"DPM++ 2", | |
"DPM++ 2 a", | |
"DPM++ SDE", | |
"DPM++ SDE Karras", | |
"DDIM", | |
"PLMS", | |
], | |
value="Euler a", | |
interactive=True, | |
) | |
gr.LoginButton() | |
gr.Markdown( | |
""" | |
Yurrrrrrrrrrrr, WIP | |
""" | |
) | |
with gr.Column(scale=3): | |
with gr.Group(): | |
with gr.Row(): | |
prompt = gr.Textbox( | |
show_label=False, | |
placeholder="Enter your prompt here...", | |
lines=3, | |
) | |
with gr.Row(): | |
with gr.Column(scale=3): | |
submit_btn = gr.Button("Submit") | |
with gr.Column(scale=1): | |
# NOTE(review): links to the in-page anchor "#improve-prompt"; no element
# with that id is created in this file view.
ai_improve_btn = gr.Button("💡", link="#improve-prompt") | |
with gr.Group(): | |
output_gallery = gr.Gallery( | |
label="Outputs", interactive=False, height=500 | |
) | |
# Actions on generated images. NOTE(review): no event handlers are
# attached to these buttons (or to submit_btn) in this file view.
with gr.Row(): | |
upscale_selected_btn = gr.Button("Upscale Selected", size="sm") | |
upscale_all_btn = gr.Button("Upscale All", size="sm") | |
create_similar_btn = gr.Button("Create Similar", size="sm") | |
with gr.Accordion("Output History", open=False): | |
with gr.Group(): | |
output_history_gallery = gr.Gallery( | |
show_label=False, interactive=False, height=500 | |
) | |
with gr.Row(): | |
clear_history_btn = gr.Button("Clear All", size="sm") | |
download_history_btn = gr.Button("Download All", size="sm") | |
# Image-conditioned modes (img2img, inpaint, outpaint, ControlNet, ...)
# are laid out as tabs inside this accordion.
with gr.Accordion("Image Playground", open=True): | |
def show_info(content: str | None = None):
    """Render a "Show Info" checkbox that toggles a help note below it.

    Must be called inside a Gradio layout context: the checkbox and the help
    Markdown are created at the point of the call.

    Args:
        content: Optional extra text appended to the help note.

    Returns:
        The ``(checkbox, markdown)`` components that were created.
    """
    message = "Sup, need some help here, please check the community tab."
    if content:
        # Only append when given, so the note never ends in a literal "None".
        message = f"{message} {content}"
    info_checkbox = gr.Checkbox(value=False, label="Show Info", interactive=True)
    info_markdown = gr.Markdown(message, visible=False)
    # The note is rendered up front and only its visibility follows the
    # checkbox; a handler that merely returns a new component is never shown.
    info_checkbox.change(
        fn=lambda checked: gr.update(visible=checked),
        inputs=info_checkbox,
        outputs=info_markdown,
    )
    return info_checkbox, info_markdown
with gr.Tabs(): | |
with gr.Tab("Img 2 Img"): | |
with gr.Group(): | |
img2img_img = gr.Image(show_label=False, interactive=True) | |
# Strength in [0, 1] in steps of 0.1, default 1.0.
img2img_strength_slider = gr.Slider( | |
minimum=0, | |
maximum=1, | |
value=1.0, | |
step=0.1, | |
label="Strength", | |
interactive=True, | |
) | |
show_info() | |
with gr.Tab("Inpaint"): | |
with gr.Group(): | |
inpaint_img = gr.ImageMask( | |
show_label=False, interactive=True, type="pil" | |
) | |
# NOTE(review): no click handler for this button in this file view.
generate_mask_btn = gr.Button( | |
"Remove Background", size="sm" | |
) | |
use_fill_pipe_inpaint = gr.Checkbox( | |
value=True, | |
label="Use Fill Pipeline 🧪", | |
interactive=True, | |
) | |
show_info() | |
# On upload, set the editor height to the first layer's height + 96 px
# (presumably headroom for the editor toolbar).
# NOTE(review): assumes x["layers"] is non-empty right after an upload; if
# the image only arrives in x["background"] this raises IndexError — confirm.
inpaint_img.upload( | |
fn=lambda x: ( | |
gr.update(height=x["layers"][0].height + 96) | |
if x is not None | |
else None | |
), | |
inputs=inpaint_img, | |
outputs=inpaint_img, | |
) | |
with gr.Tab("Outpaint"): | |
outpaint_img = gr.Image( | |
show_label=False, interactive=True, type="pil" | |
) | |
with gr.Row(equal_height=True): | |
with gr.Column(scale=3): | |
# NOTE(review): named after a single choice; its value is not passed to
# the mask preview, which sizes the canvas from the Width/Height sliders.
ratio_9_16 = gr.Radio( | |
label="Image Ratio", | |
choices=["9:16", "16:9", "1:1", "Height & Width"], | |
value="9:16", | |
container=True, | |
interactive=True, | |
) | |
with gr.Column(scale=1): | |
# Values match the alignment branches of prepare_image_and_mask.
mask_position = gr.Dropdown( | |
choices=[ | |
"Middle", | |
"Left", | |
"Right", | |
"Top", | |
"Bottom", | |
], | |
value="Middle", | |
label="Alignment", | |
interactive=True, | |
) | |
with gr.Group():
    resize_options = gr.Radio(
        choices=["Full", "75%", "50%", "33%", "25%", "Custom"],
        value="Full",
        label="Resize",
        interactive=True,
    )
    # Percentage used by the "Custom" option. This is a real component
    # rather than an empty gr.State, so the mask preview always receives a
    # number. It stays hidden until "Custom" is selected.
    resize_option_custom = gr.Slider(
        minimum=1,
        maximum=100,
        value=50,
        step=1,
        label="Custom Size %",
        interactive=True,
        visible=False,
    )


def resize_options_render(resize_option):
    """Return an update that shows the custom-size slider only for "Custom".

    Args:
        resize_option: The currently selected resize choice.

    Returns:
        A ``gr.update`` toggling the visibility of ``resize_option_custom``.
    """
    return gr.update(visible=resize_option == "Custom")


resize_options.change(
    fn=resize_options_render,
    inputs=resize_options,
    outputs=resize_option_custom,
)
with gr.Accordion("Advanced settings", open=False): | |
with gr.Group(): | |
# Percentage of the resized source's width/height by which the masked
# region reaches into the source (see prepare_image_and_mask).
mask_overlap_slider = gr.Slider( | |
label="Mask Overlap %", | |
minimum=1, | |
maximum=50, | |
value=10, | |
step=1, | |
interactive=True, | |
) | |
# Per-side overlap toggles. When a side is off, prepare_image_and_mask
# still masks a 2 px strip there (0 px if the image is aligned flush
# against that side).
with gr.Row(): | |
overlap_top = gr.Checkbox( | |
value=True, | |
label="Overlap Top", | |
interactive=True, | |
) | |
overlap_right = gr.Checkbox( | |
value=True, | |
label="Overlap Right", | |
interactive=True, | |
) | |
with gr.Row(): | |
overlap_left = gr.Checkbox( | |
value=True, | |
label="Overlap Left", | |
interactive=True, | |
) | |
overlap_bottom = gr.Checkbox( | |
value=True, | |
label="Overlap Bottom", | |
interactive=True, | |
) | |
mask_preview_btn = gr.Button( | |
"Preview", interactive=True | |
) | |
# Preview output; created hidden (visible=False).
mask_preview_img = gr.Image( | |
show_label=False, visible=False, interactive=True | |
) | |
def prepare_image_and_mask(
    image,
    width,
    height,
    overlap_percentage,
    resize_option,
    custom_resize_percentage,
    alignment,
    overlap_left,
    overlap_right,
    overlap_top,
    overlap_bottom,
):
    """Place ``image`` on a white ``width`` x ``height`` canvas and build its mask.

    The source is first scaled to fit inside the canvas with its aspect ratio
    kept. It is then scaled again by the chosen resize percentage, never
    below 64 px per side, and pasted at the position given by ``alignment``.

    Args:
        image: Source PIL image.
        width, height: Canvas size in pixels (cast to ``int``).
        overlap_percentage: How far the masked region reaches into the
            resized source on sides whose overlap flag is set, as a percentage
            of the source's width/height (at least 1 px).
        resize_option: "Full", "75%", "50%", "33%" or "25%"; any other value
            uses ``custom_resize_percentage``.
        custom_resize_percentage: Percentage used for "Custom"; ``None`` is
            treated as 100.
        alignment: "Middle", "Left", "Right", "Top" or "Bottom"; unknown
            values fall back to "Middle".
        overlap_left, overlap_right, overlap_top, overlap_bottom: Per-side
            overlap flags. A disabled side still masks a 2 px strip of the
            source, or none when the source is aligned flush to that side.

    Returns:
        ``(background, mask)``: the RGB canvas with the source pasted in, and
        an ``"L"`` mask that is 255 everywhere except a 0-filled rectangle
        over the kept interior of the source.
    """
    target_size = (int(width), int(height))

    # Fit the source inside the canvas, keeping its aspect ratio. The 1 px
    # floor avoids a zero-size resize for extreme aspect ratios.
    scale_factor = min(target_size[0] / image.width, target_size[1] / image.height)
    fitted_size = (
        max(1, int(image.width * scale_factor)),
        max(1, int(image.height * scale_factor)),
    )
    source = image.resize(fitted_size, Image.LANCZOS)

    preset_percentages = {"Full": 100, "75%": 75, "50%": 50, "33%": 33, "25%": 25}
    if resize_option in preset_percentages:
        resize_percentage = preset_percentages[resize_option]
    elif custom_resize_percentage is None:
        # "Custom" without a value (e.g. an empty gr.State): keep full size
        # instead of failing on ``None / 100``.
        resize_percentage = 100
    else:
        resize_percentage = custom_resize_percentage

    # Apply the resize percentage, never going below 64 px per side.
    resize_factor = resize_percentage / 100
    new_width = max(int(source.width * resize_factor), 64)
    new_height = max(int(source.height * resize_factor), 64)
    source = source.resize((new_width, new_height), Image.LANCZOS)

    # Overlap in pixels, at least 1 px.
    overlap_x = max(int(new_width * (overlap_percentage / 100)), 1)
    overlap_y = max(int(new_height * (overlap_percentage / 100)), 1)

    # Paste offset for each alignment, clamped to [0, free space].
    free_x = target_size[0] - new_width
    free_y = target_size[1] - new_height
    offsets = {
        "Middle": (free_x // 2, free_y // 2),
        "Left": (0, free_y // 2),
        "Right": (free_x, free_y // 2),
        "Top": (free_x // 2, 0),
        "Bottom": (free_x // 2, free_y),
    }
    # An unknown alignment would otherwise leave the margins unbound.
    margin_x, margin_y = offsets.get(alignment, offsets["Middle"])
    margin_x = max(0, min(margin_x, free_x))
    margin_y = max(0, min(margin_y, free_y))

    background = Image.new("RGB", target_size, (255, 255, 255))
    background.paste(source, (margin_x, margin_y))

    mask = Image.new("L", target_size, 255)
    white_gaps_patch = 2

    def inset(enabled, overlap, flush):
        """Pixels by which the kept rectangle is pulled in from one source edge."""
        if enabled:
            return overlap
        # A disabled side still masks a thin strip (to avoid visible seams,
        # presumably) unless the source touches the canvas edge on that side.
        return 0 if flush else white_gaps_patch

    left_overlap = margin_x + inset(overlap_left, overlap_x, alignment == "Left")
    right_overlap = (
        margin_x + new_width - inset(overlap_right, overlap_x, alignment == "Right")
    )
    top_overlap = margin_y + inset(overlap_top, overlap_y, alignment == "Top")
    bottom_overlap = (
        margin_y + new_height - inset(overlap_bottom, overlap_y, alignment == "Bottom")
    )

    ImageDraw.Draw(mask).rectangle(
        [(left_overlap, top_overlap), (right_overlap, bottom_overlap)],
        fill=0,
    )
    return background, mask
def preview_image_and_mask(image, *mask_args):
    """Show the outpaint canvas with the masked area tinted red.

    Args:
        image: The uploaded outpaint source (PIL image) or ``None``.
        *mask_args: The remaining arguments of ``prepare_image_and_mask``,
            in its parameter order.

    Returns:
        An update that fills ``mask_preview_img`` and makes it visible.

    Raises:
        gr.Error: If no image has been uploaded yet.
    """
    if image is None:
        raise gr.Error("Please upload an image to outpaint first.")
    background, mask = prepare_image_and_mask(image, *mask_args)
    # Blend a red layer at ~50% wherever the mask is 255; the 0-filled kept
    # region shows the canvas unchanged.
    tint = Image.new("RGB", background.size, (255, 0, 0))
    preview = Image.composite(tint, background, mask.point(lambda value: value // 2))
    return gr.update(value=preview, visible=True)


# Only the preview image is an output: the user's uploaded outpaint source
# must not be replaced by the mask.
mask_preview_btn.click(
    fn=preview_image_and_mask,
    inputs=[
        outpaint_img,
        width_slider,
        height_slider,
        mask_overlap_slider,
        resize_options,
        resize_option_custom,
        mask_position,
        overlap_left,
        overlap_right,
        overlap_top,
        overlap_bottom,
    ],
    outputs=mask_preview_img,
)
# Clearing the preview hides it again.
mask_preview_img.clear(
    fn=lambda: gr.update(visible=False),
    outputs=mask_preview_img,
)
# Outpaint: toggle for the Fill pipeline (flagged 🧪, presumably experimental).
use_fill_pipe_outpaint = gr.Checkbox( | |
value=True, | |
label="Use Fill Pipeline 🧪", | |
interactive=True, | |
) | |
show_info() | |
with gr.Tab("In-Context"): | |
with gr.Group(): | |
incontext_img = gr.Image(show_label=False, interactive=True) | |
# https://huggingface.co/spaces/Yuanshi/OminiControl | |
show_info(content="1024 res is in beta") | |
with gr.Tab("IP-Adapter"): | |
with gr.Group(): | |
ip_adapter_img = gr.Image( | |
show_label=False, interactive=True | |
) | |
# Reference-image scale in [0, 1], default 0.7.
ip_adapter_img_scale = gr.Slider( | |
minimum=0, | |
maximum=1, | |
value=0.7, | |
step=0.1, | |
label="Scale", | |
interactive=True, | |
) | |
# https://huggingface.co/InstantX/FLUX.1-dev-IP-Adapter | |
show_info(content="1024 res is in beta") | |
# ControlNet-style tabs (Canny ... Low Quality) share one shape: a
# conditioning image, a conditioning scale in [0, 1] (step 0.05) and a
# "Preprocessed" checkbox. The checkbox presumably marks the image as an
# already-extracted control map; confirm with the generation code.
with gr.Tab("Canny"): | |
with gr.Group(): | |
canny_img = gr.Image(show_label=False, interactive=True) | |
with gr.Row(equal_height=True): | |
with gr.Column(scale=3): | |
canny_controlnet_conditioning_scale = gr.Slider( | |
minimum=0, | |
maximum=1, | |
value=0.65, | |
step=0.05, | |
label="ControlNet Conditioning Scale", | |
interactive=True, | |
) | |
with gr.Column(scale=1): | |
canny_img_is_preprocessed = gr.Checkbox( | |
value=True, | |
label="Preprocessed", | |
interactive=True, | |
) | |
with gr.Tab("Tile"): | |
with gr.Group(): | |
tile_img = gr.Image(show_label=False, interactive=True) | |
with gr.Row(equal_height=True): | |
with gr.Column(scale=3): | |
tile_controlnet_conditioning_scale = gr.Slider( | |
minimum=0, | |
maximum=1, | |
value=0.45, | |
step=0.05, | |
label="ControlNet Conditioning Scale", | |
interactive=True, | |
) | |
with gr.Column(scale=1): | |
tile_img_is_preprocessed = gr.Checkbox( | |
value=True, | |
label="Preprocessed", | |
interactive=True, | |
) | |
with gr.Tab("Depth"): | |
with gr.Group(): | |
depth_img = gr.Image(show_label=False, interactive=True) | |
with gr.Row(equal_height=True): | |
with gr.Column(scale=3): | |
depth_controlnet_conditioning_scale = gr.Slider( | |
minimum=0, | |
maximum=1, | |
value=0.55, | |
step=0.05, | |
label="ControlNet Conditioning Scale", | |
interactive=True, | |
) | |
with gr.Column(scale=1): | |
depth_img_is_preprocessed = gr.Checkbox( | |
value=True, | |
label="Preprocessed", | |
interactive=True, | |
) | |
with gr.Tab("Blur"): | |
with gr.Group(): | |
blur_img = gr.Image(show_label=False, interactive=True) | |
with gr.Row(equal_height=True): | |
with gr.Column(scale=3): | |
blur_controlnet_conditioning_scale = gr.Slider( | |
minimum=0, | |
maximum=1, | |
value=0.45, | |
step=0.05, | |
label="ControlNet Conditioning Scale", | |
interactive=True, | |
) | |
with gr.Column(scale=1): | |
blur_img_is_preprocessed = gr.Checkbox( | |
value=True, | |
label="Preprocessed", | |
interactive=True, | |
) | |
with gr.Tab("Pose"): | |
with gr.Group(): | |
pose_img = gr.Image(show_label=False, interactive=True) | |
with gr.Row(equal_height=True): | |
with gr.Column(scale=3): | |
pose_controlnet_conditioning_scale = gr.Slider( | |
minimum=0, | |
maximum=1, | |
value=0.55, | |
step=0.05, | |
label="ControlNet Conditioning Scale", | |
interactive=True, | |
) | |
with gr.Column(scale=1): | |
pose_img_is_preprocessed = gr.Checkbox( | |
value=True, | |
label="Preprocessed", | |
interactive=True, | |
) | |
with gr.Tab("Gray"): | |
with gr.Group(): | |
gray_img = gr.Image(show_label=False, interactive=True) | |
with gr.Row(equal_height=True): | |
with gr.Column(scale=3): | |
gray_controlnet_conditioning_scale = gr.Slider( | |
minimum=0, | |
maximum=1, | |
value=0.45, | |
step=0.05, | |
label="ControlNet Conditioning Scale", | |
interactive=True, | |
) | |
with gr.Column(scale=1): | |
gray_img_is_preprocessed = gr.Checkbox( | |
value=True, | |
label="Preprocessed", | |
interactive=True, | |
) | |
with gr.Tab("Low Quality"): | |
with gr.Group(): | |
low_quality_img = gr.Image( | |
show_label=False, interactive=True | |
) | |
with gr.Row(equal_height=True): | |
with gr.Column(scale=3): | |
low_quality_controlnet_conditioning_scale = ( | |
gr.Slider( | |
minimum=0, | |
maximum=1, | |
value=0.4, | |
step=0.05, | |
label="ControlNet Conditioning Scale", | |
interactive=True, | |
) | |
) | |
with gr.Column(scale=1): | |
low_quality_img_is_preprocessed = gr.Checkbox( | |
value=True, | |
label="Preprocessed", | |
interactive=True, | |
) | |
# Disabled: embeds of the official black-forest-labs Canny/Depth Spaces.
# with gr.Tab("Official Canny"): | |
# with gr.Group(): | |
# gr.HTML( | |
# """ | |
# <script | |
# type="module" | |
# src="https://gradio.s3-us-west-2.amazonaws.com/5.6.0/gradio.js" | |
# ></script> | |
# <gradio-app src="https://black-forest-labs-flux-1-canny-dev.hf.space"></gradio-app> | |
# """ | |
# ) | |
# with gr.Tab("Official Depth"): | |
# with gr.Group(): | |
# gr.HTML( | |
# """ | |
# <script | |
# type="module" | |
# src="https://gradio.s3-us-west-2.amazonaws.com/5.6.0/gradio.js" | |
# ></script> | |
# <gradio-app src="https://black-forest-labs-flux-1-depth-dev.hf.space"></gradio-app> | |
# """ | |
# ) | |
# Embeds an external LoRA-training Space through the gradio.js web component.
# NOTE(review): browsers may not execute <script> tags inserted via gr.HTML;
# if the embed stays blank, load gradio.js through gr.Blocks(head=...) — confirm.
with gr.Tab("Auto Trainer"): | |
gr.HTML( | |
""" | |
<script | |
type="module" | |
src="https://gradio.s3-us-west-2.amazonaws.com/4.42.0/gradio.js" | |
></script> | |
<gradio-app src="https://autotrain-projects-train-flux-lora-ease.hf.space"></gradio-app> | |
""" | |
) | |
# NOTE(review): no handler reads resize_mode_radio in this file view.
resize_mode_radio = gr.Radio( | |
label="Resize Mode", | |
choices=["Crop & Resize", "Resize Only", "Resize & Fill"], | |
value="Resize & Fill", | |
interactive=True, | |
) | |
# NOTE(review): this HTML has no <script> tag, so the <gradio-app> element
# depends on gradio.js having been loaded elsewhere on the page.
with gr.Accordion("Prompt Generator", open=False): | |
gr.HTML( | |
""" | |
<gradio-app src="https://gokaygokay-flux-prompt-generator.hf.space"></gradio-app> | |
""" | |
) | |
with gr.Column(scale=1): | |
# Loras | |
with gr.Accordion("Loras", open=True): | |
# LoRA entries chosen by the user (list of dicts).
selected_loras = gr.State([]) | |
# Gallery of built-in Flux LoRAs, one (image, title) tile per flux1_loras
# entry, in list order.
lora_selector = gr.Gallery( | |
show_label=False, | |
value=[(l["image"], l["title"]) for l in flux1_loras], | |
container=False, | |
columns=3, | |
show_download_button=False, | |
show_fullscreen_button=False, | |
allow_preview=False, | |
) | |
# Free-text LoRA reference: a URL or a Hugging Face "user/repo" path.
# NOTE(review): add_lora_btn has no click handler in this file view.
with gr.Group(): | |
lora_selected = gr.Textbox( | |
show_label=False, | |
placeholder="Select a Lora to apply...", | |
container=False, | |
) | |
add_lora_btn = gr.Button("Add Lora", size="sm") | |
gr.Markdown( | |
"*You can add a Lora by entering a URL or a Hugging Face repo path." | |
) | |
# Helpers used by add_lora to resolve remote LoRA references.
_HTTP_TIMEOUT = 60  # requests connect/read timeout, in seconds
_HF_REPO_RE = re.compile(r"^[A-Za-z0-9_.-]+/[A-Za-z0-9_.-]+$")


def _auth_headers(env_var):
    """Return a Bearer ``Authorization`` header built from ``env_var``, or ``{}``."""
    token = os.getenv(env_var)
    return {"Authorization": f"Bearer {token}"} if token else {}


def _get_json(url, headers):
    """GET ``url`` and return its decoded JSON body; raise on HTTP errors."""
    response = requests.get(url, headers=headers, timeout=_HTTP_TIMEOUT)
    response.raise_for_status()
    return response.json()


def _download_to_temp(url, file_name, headers):
    """Stream ``url`` into the temp directory as ``file_name``; return the path."""
    # basename() keeps a server-supplied name from pointing outside tempdir.
    file_path = os.path.join(tempfile.gettempdir(), os.path.basename(file_name))
    with requests.get(
        url, headers=headers, stream=True, timeout=_HTTP_TIMEOUT
    ) as response:
        # Fail instead of saving an HTML error page as a .safetensors file.
        response.raise_for_status()
        with open(file_path, "wb") as f:
            # Stream in 1 MiB chunks; LoRA files can be hundreds of MB.
            for chunk in response.iter_content(chunk_size=1 << 20):
                f.write(chunk)
    return file_path


def _lora_from_civitai(url):
    """Resolve a CivitAI model URL to a LoRA entry, downloading its weights.

    Raises:
        ValueError: If the URL has no model id, the model is not a Flux
            LoRA, or no ``.safetensors`` file is listed.
        requests.RequestException: On network/HTTP failures.
    """
    model_match = re.search(r"/models/(\d+)", url)
    if model_match is None:
        raise ValueError("No model id found in the URL")
    version_match = re.search(r"modelVersionId=(\d+)", url)
    headers = _auth_headers("CIVITAI_TOKEN")

    if version_match:
        version_data = _get_json(
            f"https://civitai.com/api/v1/model-versions/{version_match.group(1)}",
            headers,
        )
        # Presumably the parent model's name sits under "model" on this
        # endpoint; fall back to the version's own name.
        title = (version_data.get("model") or {}).get("name") or version_data.get(
            "name"
        )
    else:
        model_data = _get_json(
            f"https://civitai.com/api/v1/models/{model_match.group(1)}", headers
        )
        versions = model_data.get("modelVersions") or []
        if not versions:
            raise ValueError("This model has no versions")
        # No explicit version: use the first one listed.
        version_data = versions[0]
        title = model_data.get("name")

    # Require "flux" in the base model name; a bare "1" would also accept
    # SD 1.x models.
    base_model = str(version_data.get("baseModel") or "")
    if "flux" not in base_model.lower():
        raise ValueError("This LoRA is not compatible with Flux base model")

    safetensor_file = next(
        (
            f
            for f in version_data.get("files") or []
            if f.get("name", "").endswith(".safetensors")
        ),
        None,
    )
    if safetensor_file is None:
        raise ValueError("No .safetensor file found")

    download_url = safetensor_file["downloadUrl"]
    token = os.getenv("CIVITAI_TOKEN")
    if token:
        # requests drops the Authorization header when a redirect changes the
        # host, so also pass the token as a query parameter, joining with "&"
        # if the URL already has a query string.
        separator = "&" if "?" in download_url else "?"
        download_url = f"{download_url}{separator}token={token}"

    weight = None
    strength_match = re.search(
        r"strength[:\s]+(\d*\.?\d+)",
        version_data.get("description") or "",
        re.IGNORECASE,
    )
    if strength_match:
        weight = float(strength_match.group(1))

    # Prefix with the version id so different models' files cannot collide.
    version_id = version_data.get("id", model_match.group(1))
    file_name = f"civitai-{version_id}-{os.path.basename(safetensor_file['name'])}"
    return {
        "title": title,
        "weights": _download_to_temp(download_url, file_name, headers),
        "info": ", ".join(version_data.get("trainedWords") or []),
        "weight": weight,
    }


def _lora_from_hf(repo_id):
    """Resolve a Hugging Face ``user/repo`` to a LoRA entry, downloading its weights.

    Raises:
        ValueError: If the repo is not tagged as Flux or has no
            ``.safetensors`` file.
        requests.RequestException: On network/HTTP failures.
    """
    headers = _auth_headers("HF_TOKEN")
    data = _get_json(f"https://huggingface.co/api/models/{repo_id}", headers)

    # Accept any Flux-related tag. Requiring exactly "flux-lora" rejects the
    # many Flux LoRAs tagged only "flux" or "base_model:...FLUX.1-dev".
    if "tags" in data and not any("flux" in str(tag).lower() for tag in data["tags"]):
        raise ValueError("This model is not tagged as a Flux LoRA")

    # The model-info response lists repo files under "siblings".
    safetensor_path = next(
        (
            sibling["rfilename"]
            for sibling in data.get("siblings") or []
            if sibling.get("rfilename", "").endswith(".safetensors")
        ),
        None,
    )
    if safetensor_path is None:
        raise ValueError("No .safetensor file found")
    download_url = f"https://huggingface.co/{repo_id}/resolve/main/{safetensor_path}"

    card = data.get("cardData") or {}
    weight = None
    if "weight" in card:
        try:
            weight = float(card["weight"])
        except (TypeError, ValueError):
            pass  # unparsable hint: leave the weight unknown

    # Trigger words come from the model card only. Repo tags ("diffusers",
    # "lora", "license:...") are not prompt trigger words.
    trigger_words = card.get("trigger_words") or []
    if isinstance(trigger_words, str):
        trigger_words = [trigger_words]
    trigger_words = [str(word) for word in trigger_words]
    instance_prompt = card.get("instance_prompt")
    if instance_prompt and str(instance_prompt) not in trigger_words:
        trigger_words.append(str(instance_prompt))

    # Prefix with the repo id: many repos ship the same generic file name
    # (e.g. pytorch_lora_weights.safetensors).
    file_name = f"{repo_id.replace('/', '--')}--{os.path.basename(safetensor_path)}"
    return {
        "title": data.get("name", repo_id.split("/")[-1]),
        "weights": _download_to_temp(download_url, file_name, headers),
        "info": ", ".join(trigger_words) if trigger_words else None,
        "weight": weight,
    }


# update the selected_loras state with the new lora
def add_lora(lora_selected, selected_loras=None):
    """Resolve a LoRA reference and return the selected list with it appended.

    Args:
        lora_selected: One of:
            - an ``int``: index into ``flux1_loras`` (the gallery's entries,
              in the same order);
            - a CivitAI model URL (``.../models/<id>``, optionally with
              ``modelVersionId=<id>``);
            - a Hugging Face repo path ``user/repo``.
        selected_loras: The current list of LoRA dicts (the value held by the
            ``selected_loras`` state). ``None`` starts a new list. The input
            list is not mutated.

    Returns:
        A new list: the given entries plus one dict with keys ``title``,
        ``weights`` (local ``.safetensors`` path, or the built-in entry's
        ``weights`` value), ``info`` (trigger words) and ``weight``
        (suggested strength, ``None`` when unknown).

    Raises:
        gr.Error: If the reference is empty or unsupported, or if resolving
            or downloading it fails.
    """
    loras = list(selected_loras or [])
    if isinstance(lora_selected, int):
        entry = flux1_loras[lora_selected]
        lora = {
            "title": entry["title"],
            "weights": entry["weights"],
            "info": entry["trigger_word"],
            "weight": None,
        }
    else:
        reference = lora_selected.strip() if isinstance(lora_selected, str) else ""
        if not reference:
            raise gr.Error("Select a Lora or enter a URL / repo path first.")
        if reference.startswith("http"):
            if "civitai.com/models/" not in reference:
                raise gr.Error("Only CivitAI model URLs are supported.")
            try:
                lora = _lora_from_civitai(reference)
            except Exception as e:  # UI boundary: surface any failure to the user
                raise gr.Error(f"Error processing CivitAI URL: {str(e)}") from e
        elif _HF_REPO_RE.match(reference):
            try:
                lora = _lora_from_hf(reference)
            except Exception as e:  # UI boundary: surface any failure to the user
                raise gr.Error(f"Error processing Hugging Face repo: {str(e)}") from e
        else:
            raise gr.Error(
                "Enter a CivitAI model URL or a Hugging Face repo path (user/repo)."
            )
    loras.append(lora)
    return loras
# render the selected_loras state as sliders
def render_selected_loras(selected_loras, state=None):
    """Create one weight slider per selected LoRA.

    Must run inside a Gradio layout context (e.g. a ``@gr.render`` function),
    since it instantiates components.

    Args:
        selected_loras: List of LoRA dicts with ``title``, ``info`` and an
            optional ``weight`` (``None``/missing means 0.8).
        state: Optional ``gr.State`` component holding that list. When given,
            moving a slider writes its value to ``"weight"`` of the matching
            entry (matched by position) and updates the state.

    Returns:
        The created sliders, in the same order as ``selected_loras``.
    """

    def update_lora_weight(value, loras, index):
        """Return a copy of ``loras`` with entry ``index`` set to ``value``."""
        updated = [dict(lora) for lora in (loras or [])]
        if 0 <= index < len(updated):
            updated[index]["weight"] = value
        return updated

    sliders = []
    for index, lora in enumerate(selected_loras or []):
        weight = lora.get("weight")
        value = 0.8 if weight is None else weight
        # An explicit fractional range: Gradio's default step of 1 makes a
        # 0.8 weight unreachable once the slider is moved.
        lora_slider = gr.Slider(
            minimum=0,
            maximum=max(2.0, value),
            step=0.05,
            label=lora["title"],
            value=value,
            interactive=True,
            info=lora.get("info"),
        )
        if state is not None:
            # Event inputs must be components, so the state component is
            # passed here. The slider's position is bound via a default
            # argument to avoid late binding in the loop.
            lora_slider.change(
                fn=lambda value, loras, index=index: update_lora_weight(
                    value, loras, index
                ),
                inputs=[lora_slider, state],
                outputs=state,
            )
        sliders.append(lora_slider)
    return sliders
# Start the server only when run as a script, not when the module is imported.
if __name__ == "__main__":
    demo.launch()