UCDR-Net / app.py
Baptlem's picture
Update app.py
cd1e8dc
raw
history blame
8.41 kB
# This file is adapted from https://huggingface.co/spaces/diffusers/controlnet-canny/blob/main/app.py
# The original license file is LICENSE.ControlNet in this repo.
from diffusers import FlaxStableDiffusionControlNetPipeline, FlaxControlNetModel, FlaxDPMSolverMultistepScheduler
from transformers import CLIPTokenizer, FlaxCLIPTextModel, set_seed
from flax.training.common_utils import shard
from flax.jax_utils import replicate
from diffusers.utils import load_image
import jax.numpy as jnp
import jax
import cv2
from PIL import Image
import numpy as np
import gradio as gr
def create_key(seed=0):
    """Return a ``jax.random.PRNGKey`` built from the integer *seed* (default 0)."""
    random_module = jax.random
    key = random_module.PRNGKey(seed)
    return key
def load_controlnet(controlnet_version, controlnet_path="Baptlem/baptlem-controlnet"):
    """Load a Flax ControlNet model and its parameters.

    Args:
        controlnet_version: Subfolder of ``controlnet_path`` that holds the
            checkpoint to load (e.g. ``"coyo-500k"``).
        controlnet_path: Hub repository id (or local path) containing the
            ControlNet checkpoints. Defaults to the repository this loader
            has always used, so existing callers are unaffected.

    Returns:
        ``(controlnet, controlnet_params)`` as returned by
        ``FlaxControlNetModel.from_pretrained``.
    """
    controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
        controlnet_path,
        subfolder=controlnet_version,
        from_flax=True,
        dtype=jnp.float32,
    )
    return controlnet, controlnet_params
def load_sb_pipe(controlnet_version, sb_path="runwayml/stable-diffusion-v1-5"):
    """Build the Flax Stable Diffusion + ControlNet pipeline.

    Args:
        controlnet_version: ControlNet checkpoint subfolder, forwarded to
            ``load_controlnet``.
        sb_path: Hub id (or local path) of the base Stable Diffusion model.
            It is used both for the pipeline weights and for the scheduler
            config.

    Returns:
        ``(pipe, params)``: the pipeline, with its scheduler replaced by a
        DPM-Solver multistep scheduler, and the params dict to pass to
        ``pipe(params=...)``.
    """
    controlnet, controlnet_params = load_controlnet(controlnet_version)
    # The scheduler config is read from the same base model as the pipeline.
    scheduler, scheduler_params = FlaxDPMSolverMultistepScheduler.from_pretrained(
        sb_path,
        subfolder="scheduler",
    )
    pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
        sb_path,
        controlnet=controlnet,
        dtype=jnp.float32,
        from_pt=True,
    )
    pipe.scheduler = scheduler
    # Store the separately loaded ControlNet weights and scheduler state in the
    # params dict that is later passed to pipe(...).
    params["controlnet"] = controlnet_params
    params["scheduler"] = scheduler_params
    return pipe, params
# Hub repository and checkpoint subfolder of the fine-tuned ControlNet.
controlnet_path = "Baptlem/baptlem-controlnet"
controlnet_version = "coyo-500k"
# Constants
# Canny hysteresis thresholds read by preprocess_canny().
low_threshold = 100
high_threshold = 200

pipe, params = load_sb_pipe(controlnet_version)

# These memory-saving helpers belong to the PyTorch DiffusionPipeline API and
# are not part of the Flax pipeline. Calling a missing one raises
# AttributeError at import time, so invoke each only if the pipeline defines it.
for _method_name in (
    "enable_xformers_memory_efficient_attention",
    "enable_model_cpu_offload",
    "enable_attention_slicing",
):
    _enable = getattr(pipe, _method_name, None)
    if callable(_enable):
        _enable()
def pipe_inference(
    image,
    prompt,
    is_canny=False,
    num_samples=4,
    resolution=128,
    num_inference_steps=50,
    guidance_scale=7.5,
    seed=0,
    negative_prompt="",
):
    """Run the ControlNet pipeline on one input image and return gallery items.

    Args:
        image: Input image, as a numpy array or anything ``np.array`` accepts
            (e.g. a PIL image). When ``is_canny`` is True it is used directly
            as the control map. Otherwise a Canny edge map is computed from it.
        prompt: Text prompt, repeated ``num_samples`` times.
        is_canny: Whether ``image`` already is an edge map.
        num_samples: Number of images to generate. The batch is sharded over
            ``jax.device_count()`` devices, so it must be a multiple of that.
        resolution: Target length of the shorter image side.
        num_inference_steps: Number of scheduler steps.
        guidance_scale: Classifier-free guidance scale.
        seed: Seed for the JAX PRNG.
        negative_prompt: Negative prompt, repeated ``num_samples`` times.

    Returns:
        A list containing the original input, then the computed edge map
        (only when ``is_canny`` is False), then the generated PIL images.

    Raises:
        ValueError: If ``num_samples`` is < 1 or not divisible by the number
            of JAX devices.
    """
    # Gradio sliders may deliver floats. List repetition, resizing and the
    # pipeline's static jit arguments need ints. The pipeline only broadcasts
    # guidance_scale itself when it is a Python float.
    num_samples = int(num_samples)
    resolution = int(resolution)
    num_inference_steps = int(num_inference_steps)
    seed = int(seed)
    guidance_scale = float(guidance_scale)

    num_devices = jax.device_count()
    if num_samples < 1 or num_samples % num_devices:
        raise ValueError(
            f"num_samples must be a positive multiple of the device count "
            f"({num_devices}), got {num_samples}"
        )

    if not isinstance(image, np.ndarray):
        image = np.array(image)

    resized_image = resize_image(image, resolution)
    if is_canny:
        # prepare_image_inputs expects PIL images, not numpy arrays.
        control_image = Image.fromarray(resized_image)
    else:
        # preprocess_canny returns (resized image, edge map). Only the edge
        # map is the control input.
        _, control_image = preprocess_canny(resized_image, resolution)

    # The pmapped (jit=True) pipeline needs one PRNG key per device.
    rng = jax.random.split(create_key(seed), num_devices)

    prompt_ids = pipe.prepare_text_inputs([prompt] * num_samples)
    negative_prompt_ids = pipe.prepare_text_inputs([negative_prompt] * num_samples)
    processed_image = pipe.prepare_image_inputs([control_image] * num_samples)

    p_params = replicate(params)
    prompt_ids = shard(prompt_ids)
    negative_prompt_ids = shard(negative_prompt_ids)
    processed_image = shard(processed_image)

    output = pipe(
        prompt_ids=prompt_ids,
        image=processed_image,
        params=p_params,
        prng_seed=rng,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        neg_prompt_ids=negative_prompt_ids,
        jit=True,
    )
    # With jit=True the images carry a leading device axis. Flatten to one
    # entry per sample, keeping the trailing (H, W, C) dims, then convert to PIL.
    generated = np.asarray(output.images)
    generated = generated.reshape((num_samples,) + generated.shape[-3:])

    all_outputs = [image]
    if not is_canny:
        all_outputs.append(control_image)
    all_outputs.extend(pipe.numpy_to_pil(generated))
    return all_outputs
def resize_image(image, resolution):
    """Resize *image* so that its shorter side equals *resolution*.

    The aspect ratio is preserved, with the longer side truncated to an int.
    Nearest-neighbour interpolation is used.

    Args:
        image: Image array of shape ``(H, W)`` or ``(H, W, C)``.
        resolution: Target length of the shorter side, in pixels.

    Returns:
        The resized image array.

    Raises:
        ValueError: If the image has a zero height or width.
    """
    # Only the first two dims are spatial; colour images have a channel axis.
    h, w = image.shape[:2]
    if h == 0 or w == 0:
        raise ValueError(f"cannot resize an empty image of shape {image.shape}")
    ratio = w / h
    # cv2.resize takes dsize as (width, height).
    if ratio > 1:
        dsize = (int(resolution * ratio), resolution)
    elif ratio < 1:
        dsize = (resolution, int(resolution / ratio))
    else:
        dsize = (resolution, resolution)
    return cv2.resize(image, dsize, interpolation=cv2.INTER_NEAREST)
def preprocess_canny(image, resolution=128):
    """Resize *image* and compute its Canny edge map.

    Args:
        image: 8-bit image array of shape ``(H, W)`` or ``(H, W, C)``.
        resolution: Target length of the shorter side after resizing.

    Returns:
        ``(resized_image, edge_image)`` as PIL images. The edge map is
        replicated into three identical channels.

    Raises:
        ValueError: If the image has a zero height or width.
    """
    # Only the first two dims are spatial; colour images have a channel axis.
    h, w = image.shape[:2]
    if h == 0 or w == 0:
        raise ValueError(f"cannot process an empty image of shape {image.shape}")
    ratio = w / h
    # cv2.resize takes dsize as (width, height); the shorter side becomes `resolution`.
    if ratio > 1:
        dsize = (int(resolution * ratio), resolution)
    elif ratio < 1:
        dsize = (resolution, int(resolution / ratio))
    else:
        dsize = (resolution, resolution)
    resized_image = cv2.resize(image, dsize, interpolation=cv2.INTER_NEAREST)

    # Thresholds come from the module-level constants.
    edges = cv2.Canny(resized_image, low_threshold, high_threshold)
    # Canny yields a single-channel map; stack it into three identical channels.
    edges = np.repeat(edges[:, :, None], 3, axis=2)

    return Image.fromarray(resized_image), Image.fromarray(edges)
def create_demo(process, max_images=12, default_num_images=4):
    """Build the Gradio Blocks UI around an inference callable.

    Args:
        process: Callable invoked on submit/click. It receives, in order:
            image, prompt, is_canny, num_samples, resolution, num_steps,
            guidance_scale, seed, negative prompt. It returns the list of
            images shown in the gallery.
        max_images: Upper bound of the "Images" slider.
        default_num_images: Initial value of the "Images" slider.

    Returns:
        The ``gr.Blocks`` demo; it is not launched here.
    """
    with gr.Blocks() as demo:
        with gr.Row():
            gr.Markdown('## Control Stable Diffusion with Canny Edge Maps')
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(source='upload', type='numpy')
                prompt = gr.Textbox(label='Prompt')
                # A button's caption is its `value`, not its `label`.
                run_button = gr.Button(value='Run')
                with gr.Accordion('Advanced options', open=False):
                    is_canny = gr.Checkbox(
                        label='Is canny', value=False)
                    num_samples = gr.Slider(label='Images',
                                            minimum=1,
                                            maximum=max_images,
                                            value=default_num_images,
                                            step=1)
                    # The Canny thresholds are not exposed in the UI; they are
                    # fixed by the module-level low_threshold/high_threshold.
                    # The resolution slider is pinned to 128.
                    resolution = gr.Slider(label='Resolution',
                                           minimum=128,
                                           maximum=128,
                                           value=128,
                                           step=1)
                    num_steps = gr.Slider(label='Steps',
                                          minimum=1,
                                          maximum=100,
                                          value=20,
                                          step=1)
                    guidance_scale = gr.Slider(label='Guidance Scale',
                                               minimum=0.1,
                                               maximum=30.0,
                                               value=7.5,
                                               step=0.1)
                    seed = gr.Slider(label='Seed',
                                     minimum=-1,
                                     maximum=2147483647,
                                     step=1,
                                     randomize=True)
                    n_prompt = gr.Textbox(
                        label='Negative Prompt',
                        value=
                        'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality'
                    )
            with gr.Column():
                result = gr.Gallery(label='Output',
                                    show_label=False,
                                    elem_id='gallery').style(grid=2,
                                                             height='auto')
        # Order must match the positional parameters of `process`.
        inputs = [
            input_image,
            prompt,
            is_canny,
            num_samples,
            resolution,
            num_steps,
            guidance_scale,
            seed,
            n_prompt,
        ]
        prompt.submit(fn=process, inputs=inputs, outputs=result)
        run_button.click(fn=process,
                         inputs=inputs,
                         outputs=result,
                         api_name='canny')
    return demo
if __name__ == '__main__':
    # Wire the Flax ControlNet inference function defined in this file into
    # the UI and serve it. queue() enables Gradio's request queue for these
    # long-running generations.
    demo = create_demo(pipe_inference)
    demo.queue().launch()