# This file is adapted from https://huggingface.co/spaces/diffusers/controlnet-canny/blob/main/app.py
# The original license file is LICENSE.ControlNet in this repo.
"""Gradio demo: Flax Stable Diffusion guided by Canny edge maps through a ControlNet."""
from diffusers import FlaxStableDiffusionControlNetPipeline, FlaxControlNetModel, FlaxDPMSolverMultistepScheduler
from transformers import CLIPTokenizer, FlaxCLIPTextModel, set_seed
from flax.training.common_utils import shard
from flax.jax_utils import replicate
from diffusers.utils import load_image
import jax.numpy as jnp
import jax
import cv2
from PIL import Image
import numpy as np
import gradio as gr


def create_key(seed=0):
    """Return a JAX PRNG key built from ``seed``.

    The seed is cast to ``int`` because ``jax.random.PRNGKey`` rejects
    floating-point seeds, and UI widgets may deliver numbers as floats.
    """
    return jax.random.PRNGKey(int(seed))


def load_controlnet(controlnet_version, controlnet_path="Baptlem/baptlem-controlnet"):
    """Load a Flax ControlNet checkpoint.

    Args:
        controlnet_version: Subfolder of ``controlnet_path`` holding the checkpoint.
        controlnet_path: Hub repo id (or local path) of the ControlNet weights.

    Returns:
        ``(controlnet, controlnet_params)`` as returned by
        ``FlaxControlNetModel.from_pretrained``.
    """
    # NOTE(review): `from_flax` is not a documented option of the Flax
    # `from_pretrained` (`from_pt` is); kept as in the original — confirm it is ignored.
    controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
        controlnet_path,
        subfolder=controlnet_version,
        from_flax=True,
        dtype=jnp.float32,
    )
    return controlnet, controlnet_params


def load_sb_pipe(
    controlnet_version,
    sb_path="runwayml/stable-diffusion-v1-5",
    controlnet_path="Baptlem/baptlem-controlnet",
):
    """Build the Flax ControlNet pipeline and its parameter dict.

    Loads the ControlNet, a DPM-Solver multistep scheduler from ``sb_path``,
    and the base pipeline (``revision="flax"``, bfloat16). It then swaps in the
    scheduler and stores the ControlNet/scheduler params in ``params``.

    Returns:
        ``(pipe, params)``.
    """
    controlnet, controlnet_params = load_controlnet(controlnet_version, controlnet_path)
    scheduler, scheduler_params = FlaxDPMSolverMultistepScheduler.from_pretrained(
        sb_path, subfolder="scheduler"
    )
    pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
        sb_path, controlnet=controlnet, revision="flax", dtype=jnp.bfloat16
    )
    pipe.scheduler = scheduler
    params["controlnet"] = controlnet_params
    params["scheduler"] = scheduler_params
    return pipe, params


controlnet_path = "Baptlem/baptlem-controlnet"
controlnet_version = "coyo-500k"

# Canny edge-detector thresholds (passed to cv2.Canny in preprocess_canny).
low_threshold = 100
high_threshold = 200

pipe, params = load_sb_pipe(controlnet_version, controlnet_path=controlnet_path)
# The weights never change after loading, so copy them to every device once
# here rather than on each request.
p_params = replicate(params)
print("Loaded models...")


def pipe_inference(
    image,
    prompt,
    is_canny=False,
    num_samples=4,
    resolution=128,
    num_inference_steps=50,
    guidance_scale=7.5,
    seed=0,
    negative_prompt="",
):
    """Generate ``num_samples`` images conditioned on ``image`` and ``prompt``.

    Args:
        image: Conditioning image (anything ``np.array`` accepts).
        prompt: Text prompt.
        is_canny: True if ``image`` is already a Canny edge map. Otherwise the
            edges are extracted with :func:`preprocess_canny`.
        num_samples: Number of images to return (at least 1).
        resolution: Length the shorter image side is resized to.
        num_inference_steps: Number of scheduler steps.
        guidance_scale: Classifier-free guidance scale.
        seed: PRNG seed.
        negative_prompt: Negative prompt text.

    Returns:
        A list of ``num_samples`` PIL images.

    Raises:
        gr.Error: If no image was supplied.
    """
    print("Entered pipe...")
    if image is None:
        raise gr.Error("Please upload an image.")
    if not isinstance(image, np.ndarray):
        image = np.array(image)

    # Widgets may hand back floats. List repetition, cv2 sizes and the jitted
    # step count need ints. The pipeline only broadcasts guidance_scale across
    # devices when it is a Python float.
    num_samples = max(1, int(num_samples))
    resolution = int(resolution)
    num_inference_steps = int(num_inference_steps)
    guidance_scale = float(guidance_scale)

    processed_image = resize_image(image, resolution)
    if not is_canny:
        _, processed_image = preprocess_canny(processed_image, resolution)

    # shard() splits the leading batch axis evenly across devices. Pad the batch
    # up to a multiple of the device count, then drop the extra images below.
    n_devices = jax.device_count()
    batch_size = -(-num_samples // n_devices) * n_devices

    rng = jax.random.split(create_key(seed), n_devices)

    prompt_ids = shard(pipe.prepare_text_inputs([prompt] * batch_size))
    negative_prompt_ids = shard(pipe.prepare_text_inputs([negative_prompt] * batch_size))
    image_inputs = shard(pipe.prepare_image_inputs([processed_image] * batch_size))

    print("Inference...")
    output = pipe(
        prompt_ids=prompt_ids,
        image=image_inputs,
        params=p_params,
        prng_seed=rng,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        neg_prompt_ids=negative_prompt_ids,
        jit=True,
    ).images

    # Merge the per-device leading axes into a single batch axis, then drop padding.
    images = np.asarray(output.reshape((batch_size,) + output.shape[-3:]))[:num_samples]
    return pipe.numpy_to_pil(images)


def resize_image(image, resolution):
    """Resize ``image`` so its shorter side is ``resolution``, keeping aspect ratio.

    Uses nearest-neighbour interpolation and returns a PIL image.

    Raises:
        ValueError: If the image has zero height or width.
    """
    if not isinstance(image, np.ndarray):
        image = np.array(image)
    resolution = int(resolution)
    h, w = image.shape[:2]
    if h == 0 or w == 0:
        raise ValueError(f"Cannot resize an empty image of shape {image.shape}.")
    ratio = w / h
    if ratio > 1:
        size = (int(resolution * ratio), resolution)
    elif ratio < 1:
        size = (resolution, int(resolution / ratio))
    else:
        size = (resolution, resolution)
    # cv2.resize takes dsize as (width, height).
    resized_image = cv2.resize(image, size, interpolation=cv2.INTER_NEAREST)
    return Image.fromarray(resized_image)


def preprocess_canny(image, resolution=128, *, canny_low=None, canny_high=None):
    """Extract a Canny edge map from ``image``.

    Args:
        image: Input image (anything ``np.array`` accepts).
        resolution: Unused. Kept so existing callers keep working.
        canny_low: Lower hysteresis threshold; defaults to ``low_threshold``.
        canny_high: Upper hysteresis threshold; defaults to ``high_threshold``.

    Returns:
        ``(resized_image, processed_image)``: the input as a PIL image, and the
        edge map replicated to 3 channels as a PIL image.
    """
    if not isinstance(image, np.ndarray):
        image = np.array(image)
    low = low_threshold if canny_low is None else canny_low
    high = high_threshold if canny_high is None else canny_high
    edges = cv2.Canny(image, low, high)
    # Stack the single-channel edge map into three identical channels.
    edges = np.repeat(edges[:, :, None], 3, axis=2)
    return Image.fromarray(image), Image.fromarray(edges)


def create_demo(process, max_images=12, default_num_images=4):
    """Build the Gradio Blocks UI that calls ``process`` with the form inputs.

    The inputs are passed in the order of ``pipe_inference``'s parameters:
    image, prompt, is_canny, num_samples, resolution, steps, guidance scale,
    seed, negative prompt.
    """
    with gr.Blocks() as demo:
        with gr.Row():
            gr.Markdown('## Control Stable Diffusion with Canny Edge Maps')
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(source='upload', type='numpy')
                prompt = gr.Textbox(label='Prompt')
                run_button = gr.Button(value='Run')
                with gr.Accordion('Advanced options', open=False):
                    is_canny = gr.Checkbox(label='Is canny', value=False)
                    num_samples = gr.Slider(label='Images',
                                            minimum=1,
                                            maximum=max_images,
                                            value=default_num_images,
                                            step=1)
                    # Canny thresholds are fixed by the module-level
                    # low_threshold / high_threshold constants.
                    resolution = gr.Slider(label='Resolution',
                                           minimum=128,
                                           maximum=128,
                                           value=128,
                                           step=1)
                    num_steps = gr.Slider(label='Steps',
                                          minimum=1,
                                          maximum=100,
                                          value=20,
                                          step=1)
                    guidance_scale = gr.Slider(label='Guidance Scale',
                                               minimum=0.1,
                                               maximum=30.0,
                                               value=7.5,
                                               step=0.1)
                    seed = gr.Slider(label='Seed',
                                     minimum=-1,
                                     maximum=2147483647,
                                     step=1,
                                     randomize=True)
                    n_prompt = gr.Textbox(
                        label='Negative Prompt',
                        value=
                        'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality'
                    )
            with gr.Column():
                result = gr.Gallery(label='Output',
                                    show_label=False,
                                    elem_id='gallery').style(grid=2,
                                                             height='auto')
        inputs = [
            input_image,
            prompt,
            is_canny,
            num_samples,
            resolution,
            num_steps,
            guidance_scale,
            seed,
            n_prompt,
        ]
        prompt.submit(fn=process, inputs=inputs, outputs=result)
        run_button.click(fn=process,
                         inputs=inputs,
                         outputs=result,
                         api_name='canny')
    return demo


if __name__ == '__main__':
    demo = create_demo(pipe_inference)
    demo.queue().launch()