# Copied from the Hugging Face Hub file view of app.py.
# Last commit by animrods: "Update app.py" (0cd49b7, verified); file size 8.59 kB.
import gradio as gr
import torch
import numpy as np
import diffusers
import os
import random
import spaces
from PIL import Image
# Hugging Face token read from the environment.
# NOTE(review): not referenced anywhere else in this file; either pass it to
# from_pretrained / load_ip_adapter (if the repos are gated) or drop it.
hf_token = os.environ.get("HF_TOKEN")
from diffusers import AutoPipelineForText2Image
from diffusers.utils import load_image
# Target device is hard-coded to CUDA; the CPU fallback is commented out
# (presumably because this Space always runs on a GPU via `spaces` — confirm).
device = "cuda" #if torch.cuda.is_available() else "cpu"
# Load the BRIA-2.3 text-to-image pipeline in float16 and move it to `device`.
# NOTE(review): force_zeros_for_empty_prompt=False presumably stops the
# pipeline from zeroing negative-prompt embeddings when no negative prompt is
# given (see diffusers SDXL pipeline docs) — confirm for BRIA-2.3.
pipe = AutoPipelineForText2Image.from_pretrained("briaai/BRIA-2.3", torch_dtype=torch.float16, force_zeros_for_empty_prompt=False).to(device)
# Attach BRIA's IP-Adapter weights so the pipeline accepts image prompts
# (predict() passes them as `ip_adapter_image`).
pipe.load_ip_adapter("briaai/Image-Prompt", subfolder='models', weight_name="ip_adapter_bria.bin")
# Move everything to `device` again after loading the adapter.
# NOTE(review): likely redundant if load_ip_adapter already places its modules
# on pipe.device; harmless either way.
pipe.to(device)
# default_negative_prompt= "" #"Logo,Watermark,Text,Ugly,Morbid,Extra fingers,Poorly drawn hands,Mutation,Blurry,Extra limbs,Gross proportions,Missing arms,Mutated hands,Long neck,Duplicate,Mutilated,Mutilated hands,Poorly drawn face,Deformed,Bad anatomy,Cloned face,Malformed limbs,Missing legs,Too many fingers"
# Largest seed the UI slider and seed randomisation will produce
# (the maximum signed 32-bit integer).
MAX_SEED = np.iinfo(np.int32).max
def _ip_adapter_scale_for_mode(mode, ip_adapter_scale):
    """Return the IP-Adapter scale spec for `mode` (wrapped in a list by the caller).

    "Style-Only", "Style2" and "Style3" return per-block dicts that apply
    `ip_adapter_scale` only to the listed down/mid/up UNet attention layers
    (per the diffusers IP-Adapter docs, blocks omitted from the dict get
    scale 0). Any other mode -- "Regular", predict()'s "Basic" default, or
    None when the UI dropdown has no selection -- returns the plain float,
    which applies the scale uniformly.
    """
    s = ip_adapter_scale
    if mode == "Style-Only":
        return {"down": {"block_2": [s, 0.0]}, "up": {"block_0": [0.0, s, 0.0]}, "mid": s}
    if mode == "Style2":
        return {"down": {"block_2": [s, s]}, "up": {"block_0": [0.0, s, 0.0]}}
    if mode == "Style3":
        return {"down": {"block_2": [s, 0.0], "block_1": [0.0, s]}, "up": {"block_0": [0.0, s, 0.0]}}
    return s


@spaces.GPU(enable_queue=True)
def predict(prompt, upload_images, ip_adapter_scale=0.5, negative_prompt="", seed=100, randomize_seed=False, center_crop=False, width=1024, height=1024, guidance_scale=5.0, num_inference_steps=50, mode="Basic", progress=gr.Progress(track_tqdm=True)):
    """Generate one image from a text prompt plus one or more image prompts.

    Args:
        prompt: Text prompt. May be empty; the UI suggests an empty prompt
            (or e.g. "high quality") for image variations.
        upload_images: Non-empty sequence of image sources accepted by
            `diffusers.utils.load_image`. In this app these are the files
            from the `gr.File` input.
        ip_adapter_scale: Strength of the image prompt (0.0-1.0 in the UI).
        negative_prompt: Negative text prompt.
        seed: Seed for the torch generator. Ignored when `randomize_seed`
            is true.
        randomize_seed: If true, draw a new seed in [0, MAX_SEED].
        center_crop: Accepted for UI compatibility but currently unused.
        width, height: Output size in pixels.
        guidance_scale: Classifier-free guidance scale passed to the pipeline.
        num_inference_steps: Number of denoising steps.
        mode: IP-Adapter scaling mode (see `_ip_adapter_scale_for_mode`).
        progress: Gradio progress tracker. With track_tqdm=True, Gradio
            mirrors tqdm bars (such as the pipeline's) in the UI.

    Returns:
        Tuple of (first image from the pipeline output, seed actually used).

    Raises:
        gr.Error: if no input images were provided.
    """
    # Fail with a readable UI message instead of a TypeError on None, or an
    # empty image list reaching the pipeline.
    if not upload_images:
        raise gr.Error("Please upload at least one input image.")

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # Gradio sliders deliver JSON numbers that may arrive as floats;
    # torch.Generator.manual_seed requires an int.
    seed = int(seed)

    ip_adapter_images = [load_image(source) for source in upload_images]
    # NOTE(review): `center_crop` is not applied; the images go to the
    # adapter's image processor unchanged.

    # Keep the generator on the same device as the pipeline.
    generator = torch.Generator(device=device).manual_seed(seed)
    # One IP-Adapter is loaded, so the scale list has exactly one entry.
    pipe.set_ip_adapter_scale([_ip_adapter_scale_for_mode(mode, ip_adapter_scale)])
    image = pipe(
        prompt=prompt,
        # Outer list: one entry per loaded adapter; inner list: all uploaded images.
        ip_adapter_image=[ip_adapter_images],
        negative_prompt=negative_prompt,
        # diffusers expects integer sizes and step counts.
        height=int(height),
        width=int(width),
        num_inference_steps=int(num_inference_steps),
        guidance_scale=guidance_scale,
        num_images_per_prompt=1,
        generator=generator,
    ).images[0]
    return image, seed
def swap_to_gallery(images):
    """Replace the file picker with a preview gallery after an upload.

    Returns three Gradio updates, in the order of the outputs wired to the
    upload event: the gallery (filled with `images` and shown), the
    clear-button column (shown) and the file input (hidden).
    """
    gallery_update = gr.update(value=images, visible=True)
    clear_column_update = gr.update(visible=True)
    file_input_update = gr.update(visible=False)
    return gallery_update, clear_column_update, file_input_update
def remove_back_to_files():
    """Restore the file picker: hide the gallery and clear-button column.

    Returns three Gradio updates, in the order of the outputs wired to the
    clear button's click event: gallery (hidden), clear-button column
    (hidden), file input (shown).
    """
    gallery_update = gr.update(visible=False)
    clear_column_update = gr.update(visible=False)
    file_input_update = gr.update(visible=True)
    return gallery_update, clear_column_update, file_input_update
# examples = [
# ["high quality", ["example1.png"], 1.0, "", 1000, False, False, 1152, 896, 5.0, 30, "Regular"],
# ["capybara", ["example2.png"], 0.7, "", 1000, False, False, 1152, 896, 5.0, 30, "Style-Only"],
# ]
# Page CSS passed to gr.Blocks(css=...):
#   #col-container: centres the main column and caps its width at 1024px.
#   #result:        the output image component (elem_id="result"); the image
#                   is anchored to the top of its box and its container
#                   fills the available height.
css="""
#col-container {
margin: 0 auto;
max-width: 1024px;
}
#result img{
object-position: top;
}
#result .image-container{
height: 100%
}
"""
# NOTE(review): this copy of the file lost its indentation; the nesting of the
# layout contexts below is a reconstruction -- confirm against the deployed Space.
# Build the UI: image prompts and adapter controls next to the result image,
# then prompt, output size, Run button and an "Advanced Settings" accordion,
# and finally the event wiring.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown(f"""
# Bria's Multi-Image-Prompt-Adapter
""")
        with gr.Row():
            # Left column: image prompts and IP-Adapter controls.
            with gr.Column():
                # ip_adapter_images = gr.Gallery(label="Input Images", elem_id="image-gallery").style(grid=[2], preview=True)
                # ip_adapter_images = gr.Gallery(label="Input Images", elem_id="image-gallery", show_label=True)#.style(grid=[2])
                # ip_adapter_images = gr.Gallery(columns=4, interactive=True, label="Input Images")
                # Multi-file image picker; its value is passed to predict() as upload_images.
                files = gr.File(
                    label="Input Image/s",
                    file_types=["image"],
                    file_count="multiple"
                )
                # Preview gallery, hidden until files are uploaded (see files.upload below).
                uploaded_files = gr.Gallery(label="Your images", visible=False, columns=5, rows=1, height=200)
                # Hidden column holding the button that clears `files`.
                with gr.Column(visible=False) as clear_button:
                    remove_and_reupload = gr.ClearButton(value="Remove and upload new ones", components=files, size="sm")
                with gr.Row():
                    ip_adapter_scale = gr.Slider(
                        label="Image Input Scale",
                        info="Use 1 for creating image variations",
                        minimum=0.0,
                        maximum=1.0,
                        step=0.05,
                        value=1.0,
                    )
                    # NOTE(review): no default value, so predict() receives None until a
                    # mode is chosen; predict() treats that like "Regular" (uniform scale).
                    mode = gr.Dropdown(
                        ["Regular", "Style-Only"], label="Mode",#, info="Mode"
                    )
            # Right column: generated image.
            with gr.Column():
                result = gr.Image(label="Result", elem_id="result", format="png")
        prompt = gr.Text(
            label="Prompt",
            show_label=True,
            lines=1,
            placeholder="Enter your prompt",
            container=True,
            info='For image variation, leave empty or try a prompt like: "high quality".'
        )
        # Output size in pixels (multiples of 32 between 256 and 2048).
        with gr.Row():
            width = gr.Slider(
                label="Width",
                minimum=256,
                maximum=2048,
                step=32,
                value=1024,
            )
            height = gr.Slider(
                label="Height",
                minimum=256,
                maximum=2048,
                step=32,
                value=1024,
            )
        run_button = gr.Button("Run", scale=0)
        with gr.Accordion("Advanced Settings", open=False):
            negative_prompt = gr.Text(
                label="Negative prompt",
                max_lines=1,
                placeholder="Enter a negative prompt",
            )
            # Also an output of predict(), so it displays the seed actually used.
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=1000,
            )
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
            # NOTE(review): predict() does not read center_crop, so this checkbox
            # currently has no effect despite its info text.
            center_crop = gr.Checkbox(label="Center Crop image", value=False, info="If not checked, the IP-Adapter image input would be resized to a square.")
            # with gr.Row():
            # width = gr.Slider(
            # label="Width",
            # minimum=256,
            # maximum=2048,
            # step=32,
            # value=1024,
            # )
            # height = gr.Slider(
            # label="Height",
            # minimum=256,
            # maximum=2048,
            # step=32,
            # value=1024,
            # )
            with gr.Row():
                guidance_scale = gr.Slider(
                    label="Guidance scale",
                    minimum=0.0,
                    maximum=10.0,
                    step=0.1,
                    value=7.0,
                )
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
                    maximum=100,
                    step=1,
                    value=25,
                )
    # After an upload, swap the file picker for the preview gallery.
    files.upload(fn=swap_to_gallery, inputs=files, outputs=[uploaded_files, clear_button, files])
    # Clearing the files restores the picker and hides the gallery/clear button.
    remove_and_reupload.click(fn=remove_back_to_files, outputs=[uploaded_files, clear_button, files])
    # # gr.Examples(
    # # examples=examples,
    # # fn=predict,
    # # inputs=[prompt, files, ip_adapter_scale, negative_prompt, seed, randomize_seed, center_crop, width, height, guidance_scale, num_inference_steps, mode],
    # # outputs=[result, seed],
    # # cache_examples="lazy"
    # # )
    # Generate on Run click or Enter in the prompt box; inputs are listed in
    # predict()'s positional order, and the used seed is written back to `seed`.
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=predict,
        inputs=[prompt, files, ip_adapter_scale, negative_prompt, seed, randomize_seed, center_crop, width, height, guidance_scale, num_inference_steps, mode],
        outputs=[result, seed]
    )
# Enable Gradio's request queue (at most 25 pending events; api_open=False
# keeps direct REST calls from bypassing the queue), then start the server
# with the API docs page hidden.
demo.queue(max_size=25,api_open=False).launch(show_api=False)
# NOTE(review): stale leftover -- `image_blocks` is not defined in this file section.
# image_blocks.queue(max_size=25,api_open=False).launch(show_api=False)