# Page header captured from the Hugging Face file viewer:
# author: animrods · commit 527b2cf "Update app.py" (verified) · links: raw, history, blame · 7.21 kB
import gradio as gr
import torch
import numpy as np
import diffusers
import os
import random
import spaces
from PIL import Image
from diffusers import AutoPipelineForText2Image
from diffusers.utils import load_image

# Hugging Face access token from the environment (None when HF_TOKEN is unset).
hf_token = os.environ.get("HF_TOKEN")

# Inference device is hard-coded to CUDA; the CPU fallback is left commented out.
device = "cuda"  # if torch.cuda.is_available() else "cpu"

# BRIA-2.3 text-to-image pipeline in half precision, moved straight to `device`.
pipe = AutoPipelineForText2Image.from_pretrained(
    "briaai/BRIA-2.3",
    torch_dtype=torch.float16,
    force_zeros_for_empty_prompt=False,
).to(device)

# Attach Bria's IP-Adapter weights (models/ip_adapter_bria.bin in the
# briaai/Image-Prompt repo) so the pipeline accepts `ip_adapter_image`.
pipe.load_ip_adapter(
    "briaai/Image-Prompt",
    subfolder="models",
    weight_name="ip_adapter_bria.bin",
)
# Second move, presumably so any modules attached by load_ip_adapter also end
# up on `device` — confirm whether it is still required.
pipe.to(device)

# Unused alternative negative prompt, kept for reference:
# default_negative_prompt= "" #"Logo,Watermark,Text,Ugly,Morbid,Extra fingers,Poorly drawn hands,Mutation,Blurry,Extra limbs,Gross proportions,Missing arms,Mutated hands,Long neck,Duplicate,Mutilated,Mutilated hands,Poorly drawn face,Deformed,Bad anatomy,Cloned face,Malformed limbs,Missing legs,Too many fingers"

# Upper bound for seeds: used as the seed slider maximum and for random seeds.
MAX_SEED = np.iinfo(np.int32).max
def _load_ip_adapter_images(files, center_crop):
    """Load the uploaded reference images as PIL images for the IP-Adapter.

    Args:
        files: Value of the multi-file ``gr.File`` input. Each entry is handed
            to ``load_image``. Entries that are not plain strings but carry a
            ``.name`` attribute (tempfile wrappers in some Gradio versions) are
            reduced to that path first. A single non-list value is wrapped in
            a list.
        center_crop: When False, every image is resized to 224x224.

    Returns:
        A list of PIL images, in upload order.

    Raises:
        gr.Error: If no file was uploaded (``files`` is empty or None).
    """
    if not files:
        raise gr.Error("Please upload at least one input image.")
    if not isinstance(files, (list, tuple)):
        files = [files]
    images = []
    for item in files:
        path = item if isinstance(item, str) else getattr(item, "name", item)
        image = load_image(path)
        if not center_crop:
            # Matches the "Center Crop image" checkbox text: when unchecked the
            # input is resized to a square. When checked the image is passed
            # unchanged; presumably the adapter's image processor crops it.
            image = image.resize((224, 224))
        images.append(image)
    return images


@spaces.GPU(enable_queue=True)
def predict(prompt, files, ip_adapter_scale=0.5, negative_prompt="", seed=100, randomize_seed=False, center_crop=False, width=1024, height=1024, guidance_scale=5.0, num_inference_steps=50, progress=gr.Progress(track_tqdm=True)):
    """Generate one image from a text prompt plus uploaded reference images.

    Args:
        prompt: Text prompt; may be left empty for image variations.
        files: Value of the multi-file ``gr.File`` input holding the reference
            images for the IP-Adapter.
        ip_adapter_scale: Image-prompt strength passed to
            ``pipe.set_ip_adapter_scale``.
        negative_prompt: Negative text prompt forwarded to the pipeline.
        seed: Seed for the torch generator. It is replaced by a random value
            when ``randomize_seed`` is set.
        randomize_seed: Draw a fresh seed in ``[0, MAX_SEED]``.
        center_crop: When False, reference images are resized to 224x224.
        width, height: Output size in pixels.
        guidance_scale, num_inference_steps: Forwarded to the pipeline.
        progress: Gradio progress tracker that mirrors the pipeline's tqdm bars.

    Returns:
        ``(image, seed)``: the generated image and the integer seed actually
        used, so the UI can show it.

    Raises:
        gr.Error: If no reference image was uploaded.
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # Gradio sliders may deliver floats; manual_seed and the pipeline sizes need ints.
    seed = int(seed)
    ip_adapter_images = _load_ip_adapter_images(files, center_crop)
    generator = torch.Generator(device=device).manual_seed(seed)
    pipe.set_ip_adapter_scale([ip_adapter_scale])
    image = pipe(
        prompt=prompt,
        # One entry per loaded IP-Adapter; the inner list holds all of its reference images.
        ip_adapter_image=[ip_adapter_images],
        negative_prompt=negative_prompt,
        height=int(height),
        width=int(width),
        num_inference_steps=int(num_inference_steps),
        guidance_scale=guidance_scale,
        num_images_per_prompt=1,
        generator=generator,
    ).images[0]
    return image, seed
# Example rows. The columns follow the (commented-out) gr.Examples inputs:
# prompt, image file, image-input scale, negative prompt, seed,
# randomize seed, center crop, width, height.
examples = [
    [text, image_file, scale, "", 1000, False, False, 1152, 896]
    for text, image_file, scale in (
        ("high quality", "example1.png", 1.0),
        ("capybara", "example2.png", 0.7),
    )
]
# Stylesheet handed to gr.Blocks: centers the main column at up to 1024px,
# anchors the result image to the top and lets its container fill the height.
css = "\n".join(
    [
        "",
        "#col-container {",
        "margin: 0 auto;",
        "max-width: 1024px;",
        "}",
        "#result img{",
        "object-position: top;",
        "}",
        "#result .image-container{",
        "height: 100%",
        "}",
        "",
    ]
)
def swap_to_gallery(images):
    """Show the uploaded files in the preview gallery after an upload.

    Wired to ``files.upload``. Returns updates for
    ``[uploaded_files, clear_button, files]``:
    - the gallery displays ``images``;
    - the clear-button column becomes visible;
    - the file picker is hidden.
    """
    return gr.update(value=images, visible=True), gr.update(visible=True), gr.update(visible=False)


def remove_back_to_files():
    """Reverse ``swap_to_gallery``: hide gallery and clear button, show the file picker.

    Returns updates for ``[uploaded_files, clear_button, files]``.
    """
    return gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)


with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("\n# Bria's Image-Prompt-Adapter\n")
        with gr.Row():
            with gr.Column():
                # Reference images for the IP-Adapter; this is what predict() reads.
                files = gr.File(
                    label="Input Image/s",
                    file_types=["image"],
                    file_count="multiple"
                )
                # Read-only preview shown in place of `files` once something is uploaded.
                uploaded_files = gr.Gallery(label="Your images", visible=False, columns=5, rows=1, height=200)
                # Hidden until an upload happens. The ClearButton empties `files`,
                # and its click handler swaps the file picker back in.
                with gr.Column(visible=False) as clear_button:
                    remove_and_reupload = gr.ClearButton(value="Remove and upload new ones", components=files, size="sm")
                ip_adapter_scale = gr.Slider(
                    label="Image Input Scale",
                    info="Use 1 for creating image variations",
                    minimum=0.0,
                    maximum=1.0,
                    step=0.05,
                    value=1.0,
                )
            with gr.Column():
                result = gr.Image(label="Result", elem_id="result", format="png")
                prompt = gr.Text(
                    label="Prompt",
                    show_label=True,
                    lines=1,
                    placeholder="Enter your prompt",
                    container=True,
                    info='For image variation, leave empty or try a prompt like: "high quality".'
                )
                with gr.Row():
                    width = gr.Slider(
                        label="Width",
                        minimum=256,
                        maximum=2048,
                        step=32,
                        value=1024,
                    )
                    height = gr.Slider(
                        label="Height",
                        minimum=256,
                        maximum=2048,
                        step=32,
                        value=1024,
                    )
                run_button = gr.Button("Run", scale=0)
        with gr.Accordion("Advanced Settings", open=False):
            negative_prompt = gr.Text(
                label="Negative prompt",
                max_lines=1,
                placeholder="Enter a negative prompt",
            )
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=1000,
            )
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
            center_crop = gr.Checkbox(label="Center Crop image", value=False, info="If not checked, the IP-Adapter image input would be resized to a square.")
            with gr.Row():
                guidance_scale = gr.Slider(
                    label="Guidance scale",
                    minimum=0.0,
                    maximum=10.0,
                    step=0.1,
                    value=7.0,
                )
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
                    maximum=100,
                    step=1,
                    value=25,
                )

    # Upload -> show gallery; clear -> back to the file picker.
    files.upload(fn=swap_to_gallery, inputs=files, outputs=[uploaded_files, clear_button, files])
    remove_and_reupload.click(fn=remove_back_to_files, outputs=[uploaded_files, clear_button, files])

    # gr.Examples(
    #     examples=examples,
    #     fn=predict,
    #     inputs=[prompt, files, ip_adapter_scale, negative_prompt, seed, randomize_seed, center_crop, width, height],
    #     outputs=[result, seed],
    #     cache_examples="lazy"
    # )

    # Run on button click or Enter in the prompt box. The input order must
    # match predict()'s positional parameters. The seed slider is also an
    # output, so it shows the seed actually used.
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=predict,
        inputs=[prompt, files, ip_adapter_scale, negative_prompt, seed, randomize_seed, center_crop, width, height, guidance_scale, num_inference_steps],
        outputs=[result, seed]
    )

demo.queue(max_size=25, api_open=False).launch(show_api=False)