# UCDR-Net / app.py — Hugging Face Space by Baptlem
# Commit ce09356: "Update app.py" (8.43 kB)
# This file is adapted from https://huggingface.co/spaces/diffusers/controlnet-canny/blob/main/app.py
# The original license file is LICENSE.ControlNet in this repo.
from diffusers import FlaxStableDiffusionControlNetPipeline, FlaxControlNetModel, FlaxDPMSolverMultistepScheduler
from transformers import CLIPTokenizer, FlaxCLIPTextModel, set_seed
from flax.training.common_utils import shard
from flax.jax_utils import replicate
from diffusers.utils import load_image
import jax.numpy as jnp
import jax
import cv2
from PIL import Image
import numpy as np
import gradio as gr
def create_key(seed=0):
    """Build a JAX PRNG key from an integer seed.

    Args:
        seed: Seed forwarded unchanged to ``jax.random.PRNGKey``.

    Returns:
        The key produced by ``jax.random.PRNGKey(seed)``.
    """
    key = jax.random.PRNGKey(seed)
    return key
def load_controlnet(controlnet_version, controlnet_path="Baptlem/baptlem-controlnet"):
    """Load a Flax ControlNet and its parameters from the Hugging Face Hub.

    Args:
        controlnet_version: Subfolder of the repository that holds the desired
            checkpoint (for example ``"coyo-500k"``).
        controlnet_path: Hub repository id to load from. It defaults to the
            repository this app was built for, so existing callers are
            unaffected.

    Returns:
        The ``(controlnet, controlnet_params)`` tuple returned by
        ``FlaxControlNetModel.from_pretrained``.
    """
    controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
        controlnet_path,
        subfolder=controlnet_version,
        from_flax=True,
        # NOTE(review): load_sb_pipe() loads the base pipeline in bfloat16;
        # confirm that float32 ControlNet weights are intended here.
        dtype=jnp.float32,
    )
    return controlnet, controlnet_params
def load_sb_pipe(controlnet_version, sb_path="runwayml/stable-diffusion-v1-5"):
    """Assemble the Flax Stable Diffusion + ControlNet pipeline.

    Loads the ControlNet checkpoint ``controlnet_version``. From the base
    checkpoint ``sb_path`` it loads a DPM-Solver multistep scheduler (the
    ``scheduler`` subfolder) and the pipeline itself (``flax`` revision, in
    bfloat16). The scheduler is then installed on the pipeline, and both
    parameter sets are stored in the returned parameter tree.

    Args:
        controlnet_version: Subfolder of the ControlNet repository to load.
        sb_path: Hub id of the base Stable Diffusion checkpoint.

    Returns:
        ``(pipe, params)``, where ``params`` contains ``"controlnet"`` and
        ``"scheduler"`` entries.
    """
    controlnet, controlnet_params = load_controlnet(controlnet_version)
    scheduler, scheduler_params = FlaxDPMSolverMultistepScheduler.from_pretrained(
        sb_path, subfolder="scheduler"
    )
    pipeline_kwargs = dict(controlnet=controlnet, revision="flax", dtype=jnp.bfloat16)
    pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
        sb_path, **pipeline_kwargs
    )
    # Replace the default scheduler, and register the separately loaded
    # ControlNet and scheduler state in the pipeline's parameter dict.
    pipe.scheduler = scheduler
    params.update({"controlnet": controlnet_params, "scheduler": scheduler_params})
    return pipe, params
# Hub repository and checkpoint subfolder of the ControlNet weights.
# NOTE(review): `controlnet_path` is not read by any code visible here; the
# loader carries its own copy of the repository id, so keep the two in sync.
controlnet_path = "Baptlem/baptlem-controlnet"
controlnet_version = "coyo-500k"
# Constants
# Lower/upper thresholds passed to cv2.Canny by preprocess_canny().
low_threshold = 100
high_threshold = 200
# Load the pipeline once at import time. pipe_inference() reads these
# module-level `pipe` / `params` on every request.
pipe, params = load_sb_pipe(controlnet_version)
# NOTE(review): the calls below look like memory options of the PyTorch
# diffusers pipelines; confirm the Flax pipeline supports them before
# re-enabling.
# pipe.enable_xformers_memory_efficient_attention()
# pipe.enable_model_cpu_offload()
# pipe.enable_attention_slicing()
print("Loaded models...")
# Lazily populated copy of `params` replicated across devices (see below).
_replicated_params = None


def _get_replicated_params():
    """Return the module-level ``params`` replicated across devices, computed once.

    ``params`` is loaded once at import time, so the replicated copy can be
    reused across requests. Re-replicating it would transfer the full weights
    to every device on each call.
    """
    global _replicated_params
    if _replicated_params is None:
        _replicated_params = replicate(params)
    return _replicated_params


def pipe_inference(
    image,
    prompt,
    is_canny=False,
    num_samples=4,
    resolution=128,
    num_inference_steps=50,
    guidance_scale=7.5,
    seed=0,
    negative_prompt="",
):
    """Generate images with the ControlNet pipeline, conditioned on ``image``.

    Args:
        image: Input picture (a numpy array, or anything ``np.array`` accepts).
        prompt: Text prompt, repeated for every sample.
        is_canny: If True, ``image`` is used directly as the control map.
            Otherwise Canny edges are extracted from it first.
        num_samples: Number of images to generate. Must be a positive multiple
            of ``jax.device_count()`` because the batch is sharded across
            devices.
        resolution: Target length of the shorter image side before inference.
        num_inference_steps: Number of scheduler steps.
        guidance_scale: Classifier-free guidance scale.
        seed: Integer seed for the JAX PRNG.
        negative_prompt: Negative text prompt, repeated for every sample.

    Returns:
        A list of ``num_samples`` PIL images.

    Raises:
        ValueError: If ``image`` is None, or if ``num_samples`` is not a
            positive multiple of the device count.
    """
    print("Entered pipe...")
    if image is None:
        raise ValueError("An input image is required.")

    # UI sliders may deliver floats or ints. List repetition, cv2 and the PRNG
    # need ints. The pipeline only broadcasts `guidance_scale` per device when
    # it is a float.
    num_samples = int(num_samples)
    resolution = int(resolution)
    num_inference_steps = int(num_inference_steps)
    guidance_scale = float(guidance_scale)
    seed = int(seed)

    num_devices = jax.device_count()
    if num_samples <= 0 or num_samples % num_devices:
        raise ValueError(
            f"num_samples ({num_samples}) must be a positive multiple of the "
            f"number of devices ({num_devices})."
        )

    if not isinstance(image, np.ndarray):
        image = np.array(image)
    control_image = resize_image(image, resolution)
    if not is_canny:
        # preprocess_canny returns (input image, edge map); only the edge map
        # is used as the control image.
        _, control_image = preprocess_canny(control_image, resolution)

    # One PRNG key per device.
    rng = jax.random.split(create_key(seed), num_devices)

    prompt_ids = shard(pipe.prepare_text_inputs([prompt] * num_samples))
    negative_prompt_ids = shard(pipe.prepare_text_inputs([negative_prompt] * num_samples))
    control_image = shard(pipe.prepare_image_inputs([control_image] * num_samples))

    print("Inference...")
    output = pipe(
        prompt_ids=prompt_ids,
        image=control_image,
        params=_get_replicated_params(),
        prng_seed=rng,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        neg_prompt_ids=negative_prompt_ids,
        jit=True,
    ).images
    print("Finished inference...")
    # Presumably laid out as (devices, samples_per_device, H, W, C). Collapse
    # the leading axes into a single batch axis before converting to PIL.
    images = output.reshape((num_samples,) + output.shape[-3:])
    return pipe.numpy_to_pil(np.asarray(images))
def resize_image(image, resolution):
    """Resize ``image`` so that its shorter side equals ``resolution``.

    The aspect ratio is preserved, with the longer side truncated to an int.
    Nearest-neighbour interpolation is used, so no new pixel values are
    introduced.

    Args:
        image: Image as a numpy array, or anything ``np.array`` can convert.
        resolution: Target length, in pixels, of the shorter side.

    Returns:
        The resized image as a ``PIL.Image.Image``.
    """
    arr = image if isinstance(image, np.ndarray) else np.array(image)
    height, width = arr.shape[:2]
    aspect = width / height
    # cv2.resize takes the destination size as (width, height).
    if aspect > 1:
        target = (int(resolution * aspect), resolution)
    elif aspect < 1:
        target = (resolution, int(resolution / aspect))
    else:
        target = (resolution, resolution)
    return Image.fromarray(cv2.resize(arr, target, interpolation=cv2.INTER_NEAREST))
def preprocess_canny(image, resolution=128):
    """Compute a 3-channel Canny edge map of ``image``.

    Edges are detected with ``cv2.Canny`` using the module-level
    ``low_threshold`` / ``high_threshold`` values. The single-channel result
    is copied into three identical channels.

    Args:
        image: Image as a numpy array, or anything ``np.array`` can convert.
        resolution: Currently unused; kept for interface compatibility.

    Returns:
        ``(input_image, edge_image)``, both ``PIL.Image.Image``. The first is
        the input unchanged; the second is the 3-channel edge map.
    """
    pixels = np.array(image) if not isinstance(image, np.ndarray) else image
    edges = cv2.Canny(pixels, low_threshold, high_threshold)
    edges_rgb = np.repeat(edges[..., np.newaxis], 3, axis=-1)
    return Image.fromarray(pixels), Image.fromarray(edges_rgb)
def create_demo(process, max_images=12, default_num_images=4):
    """Build the Gradio Blocks UI around ``process``.

    Args:
        process: Callback wired to the prompt's submit event and the Run
            button. It receives, in order: image, prompt, is_canny,
            num_samples, resolution, num_steps, guidance_scale, seed and
            negative prompt. Its return value is shown in the output gallery.
        max_images: Upper bound of the "Images" slider.
        default_num_images: Initial value of the "Images" slider.

    Returns:
        The ``gr.Blocks`` app, not yet launched.
    """
    with gr.Blocks() as demo:
        with gr.Row():
            gr.Markdown('## Control Stable Diffusion with Canny Edge Maps')
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(source='upload', type='numpy')
                prompt = gr.Textbox(label='Prompt')
                # A button's caption comes from `value`, not `label`.
                run_button = gr.Button(value='Run')
                with gr.Accordion('Advanced options', open=False):
                    is_canny = gr.Checkbox(label='Is canny', value=False)
                    num_samples = gr.Slider(label='Images',
                                            minimum=1,
                                            maximum=max_images,
                                            value=default_num_images,
                                            step=1)
                    # NOTE: Canny threshold sliders are disabled.
                    # preprocess_canny() uses the module-level low_threshold
                    # and high_threshold constants instead.
                    # Resolution is pinned to 128 (minimum == maximum).
                    resolution = gr.Slider(label='Resolution',
                                           minimum=128,
                                           maximum=128,
                                           value=128,
                                           step=1)
                    num_steps = gr.Slider(label='Steps',
                                          minimum=1,
                                          maximum=100,
                                          value=20,
                                          step=1)
                    guidance_scale = gr.Slider(label='Guidance Scale',
                                               minimum=0.1,
                                               maximum=30.0,
                                               value=7.5,
                                               step=0.1)
                    seed = gr.Slider(label='Seed',
                                     minimum=-1,
                                     maximum=2147483647,
                                     step=1,
                                     randomize=True)
                    n_prompt = gr.Textbox(
                        label='Negative Prompt',
                        value=
                        'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality'
                    )
            with gr.Column():
                result = gr.Gallery(label='Output',
                                    show_label=False,
                                    elem_id='gallery').style(grid=2,
                                                             height='auto')
        # Order must match the positional parameters of `process`.
        inputs = [
            input_image,
            prompt,
            is_canny,
            num_samples,
            resolution,
            num_steps,
            guidance_scale,
            seed,
            n_prompt,
        ]
        prompt.submit(fn=process, inputs=inputs, outputs=result)
        run_button.click(fn=process,
                         inputs=inputs,
                         outputs=result,
                         api_name='canny')
    return demo
if __name__ == '__main__':
    # Build the UI around the inference function, then serve it with request
    # queuing enabled.
    demo = create_demo(pipe_inference)
    demo.queue().launch()