# This file is adapted from https://huggingface.co/spaces/diffusers/controlnet-canny/blob/main/app.py
# The original license file is LICENSE.ControlNet in this repo.
import functools
import os

from diffusers import FlaxStableDiffusionControlNetPipeline, FlaxControlNetModel, FlaxDPMSolverMultistepScheduler
from transformers import CLIPTokenizer, FlaxCLIPTextModel, set_seed
from flax.training.common_utils import shard
from flax.jax_utils import replicate
from diffusers.utils import load_image
import jax.numpy as jnp
import jax
import cv2
from PIL import Image
import numpy as np
import gradio as gr

description = """
Our project is to use diffusion model to change the texture of our robotic arm simulation.
To do so, we first get our simulated images. After, we process these images to get Canny Edge maps. Finally, we can get brand new images by using ControlNet.
Therefore, we are able to change our simulation texture, and still keeping the image composition.
Our objectif for the sprint is to perform data augmentation using ControlNet. We then look for having a model that can augment an image quickly.
For now, we benchmarked our model on a node of 4 Titan RTX 24Go. We were able to generate a batch of 4 images in a average time of 1.3 seconds!
We also have access to nodes composed of 8 A100 80Go GPUs. The benchmark on one of these nodes will come soon.
"""

controlnet_path = "Baptlem/baptlem-controlnet"
controlnet_version = "coyo-500k"

# Constants: default Canny edge-detection thresholds used by preprocess_canny.
low_threshold = 100
high_threshold = 200


def create_key(seed=0):
    """Return a JAX PRNG key built from ``seed``."""
    return jax.random.PRNGKey(seed)


def load_controlnet(controlnet_version):
    """Load the Flax ControlNet model and its params.

    Args:
        controlnet_version: subfolder of the ``controlnet_path`` Hub repository
            holding the weights (e.g. ``"coyo-500k"``).

    Returns:
        ``(controlnet, controlnet_params)`` as returned by
        ``FlaxControlNetModel.from_pretrained``.
    """
    controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
        controlnet_path,
        subfolder=controlnet_version,
        from_flax=True,
        dtype=jnp.float32,
    )
    return controlnet, controlnet_params


def load_sb_pipe(controlnet_version, sb_path="runwayml/stable-diffusion-v1-5"):
    """Build the Flax Stable Diffusion ControlNet pipeline.

    The pipeline from ``sb_path`` is loaded in bfloat16 with the requested
    ControlNet plugged in. Its scheduler is replaced by a DPM-Solver
    multistep scheduler loaded from the same repository.

    Returns:
        ``(pipe, params)``; ``params`` is a dict that also holds the
        ``"controlnet"`` and ``"scheduler"`` params.
    """
    controlnet, controlnet_params = load_controlnet(controlnet_version)

    scheduler, scheduler_params = FlaxDPMSolverMultistepScheduler.from_pretrained(
        sb_path,
        subfolder="scheduler",
    )

    pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
        sb_path,
        controlnet=controlnet,
        revision="flax",
        dtype=jnp.bfloat16,
    )

    pipe.scheduler = scheduler
    params["controlnet"] = controlnet_params
    params["scheduler"] = scheduler_params
    return pipe, params


@functools.lru_cache(maxsize=1)
def _load_replicated_pipe(model):
    """Load the pipeline for ``model`` and replicate its params across devices.

    Cached so that consecutive requests using the same model reuse the loaded
    weights instead of reloading the whole pipeline on every call. With
    ``maxsize=1`` only the most recently used model stays in memory.
    """
    print("Loading pipe")
    pipe, params = load_sb_pipe(model)
    return pipe, replicate(params)


print(os.path.abspath('.'))
print(os.listdir("."))
print("Gradio version:", gr.__version__)
print("Loaded models...")


def pipe_inference(
    image,
    prompt,
    is_canny=False,
    num_samples=4,
    resolution=128,
    num_inference_steps=50,
    guidance_scale=7.5,
    model="coyo-500k",
    seed=0,
    negative_prompt="",
):
    """Generate ``num_samples`` ControlNet images conditioned on ``image``.

    Args:
        image: input image (numpy array or anything ``np.array`` accepts).
        prompt: text prompt.
        is_canny: if True, ``image`` is already a Canny edge map and is used
            as-is (after resizing). Otherwise Canny edges are computed here.
        num_samples: number of images to return.
        resolution: target length of the image's shorter side.
        num_inference_steps: number of diffusion steps.
        guidance_scale: classifier-free guidance scale.
        model: ControlNet subfolder name (see ``load_controlnet``).
        seed: seed for the JAX PRNG.
        negative_prompt: negative text prompt.

    Returns:
        A list of ``num_samples`` PIL images.

    Raises:
        gr.Error: if no input image was provided.
    """
    if image is None:
        raise gr.Error("Please upload an input image.")

    # Gradio sliders may deliver floats. These values are used as list-repeat
    # counts, pixel sizes and step counts, so they must be ints.
    num_samples = int(num_samples)
    resolution = int(resolution)
    num_inference_steps = int(num_inference_steps)
    seed = int(seed)

    pipe, p_params = _load_replicated_pipe(model)

    if not isinstance(image, np.ndarray):
        image = np.array(image)
    processed_image = resize_image(image, resolution)  # -> PIL
    if not is_canny:
        _, processed_image = preprocess_canny(processed_image, resolution)

    # `shard` splits the leading batch axis evenly across devices, so the
    # batch must be a multiple of the device count. Pad it up here and drop
    # the surplus images after inference.
    num_devices = jax.device_count()
    batch_size = -(-num_samples // num_devices) * num_devices

    rng = jax.random.split(create_key(seed), num_devices)
    prompt_ids = shard(pipe.prepare_text_inputs([prompt] * batch_size))
    negative_prompt_ids = shard(pipe.prepare_text_inputs([negative_prompt] * batch_size))
    processed_image = shard(pipe.prepare_image_inputs([processed_image] * batch_size))

    print("Inference...")
    output = pipe(
        prompt_ids=prompt_ids,
        image=processed_image,
        params=p_params,
        prng_seed=rng,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        neg_prompt_ids=negative_prompt_ids,
        jit=True,
    ).images
    print("Finished inference...")

    # Merge the per-device axis back into a single batch axis, then keep only
    # the images that were actually requested.
    images = np.asarray(output.reshape((batch_size,) + output.shape[-3:]))
    return pipe.numpy_to_pil(images[:num_samples])


def resize_image(image, resolution):
    """Resize ``image`` so its shorter side equals ``resolution``.

    The aspect ratio is kept (the longer side is truncated to an int) and
    nearest-neighbour interpolation is used.

    Returns:
        The resized image as a ``PIL.Image``.
    """
    if not isinstance(image, np.ndarray):
        image = np.array(image)
    h, w = image.shape[:2]
    ratio = w / h
    if ratio > 1:
        size = (int(resolution * ratio), resolution)
    elif ratio < 1:
        size = (resolution, int(resolution / ratio))
    else:
        size = (resolution, resolution)
    resized_image = cv2.resize(image, size, interpolation=cv2.INTER_NEAREST)
    return Image.fromarray(resized_image)


def preprocess_canny(image, resolution=128, low_threshold=low_threshold, high_threshold=high_threshold):
    """Compute a 3-channel Canny edge map of ``image``.

    Args:
        image: input image (numpy array or PIL image).
        resolution: unused; kept for interface compatibility.
        low_threshold: first Canny hysteresis threshold (defaults to the
            module-level ``low_threshold``).
        high_threshold: second Canny hysteresis threshold (defaults to the
            module-level ``high_threshold``).

    Returns:
        ``(original_image, edge_map)``, both as ``PIL.Image``. The single
        Canny channel is repeated 3 times so the edge map is RGB-shaped.
    """
    if not isinstance(image, np.ndarray):
        image = np.array(image)

    edges = cv2.Canny(image, low_threshold, high_threshold)
    edges = np.repeat(edges[:, :, None], 3, axis=2)

    return Image.fromarray(image), Image.fromarray(edges)


def create_demo(process, max_images=12, default_num_images=4):
    """Build the Gradio Blocks UI that calls ``process`` on submit/click.

    ``process`` receives the UI inputs in ``pipe_inference``'s positional
    order and must return a list of images for the output gallery.
    """
    with gr.Blocks() as demo:
        with gr.Row():
            gr.Markdown('## Control Stable Diffusion with Canny Edge Maps')
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(source='upload', type='numpy')
                prompt = gr.Textbox(label='Prompt')
                run_button = gr.Button(value='Run')
                with gr.Accordion('Advanced options', open=False):
                    is_canny = gr.Checkbox(
                        label='Is canny', value=False)
                    num_samples = gr.Slider(label='Images',
                                            minimum=1,
                                            maximum=max_images,
                                            value=default_num_images,
                                            step=1)
                    # Canny thresholds are not exposed in the UI; the
                    # module-level low_threshold/high_threshold are used.
                    resolution = gr.Slider(label='Resolution',
                                           minimum=128,
                                           maximum=128,
                                           value=128,
                                           step=1)
                    num_steps = gr.Slider(label='Steps',
                                          minimum=1,
                                          maximum=100,
                                          value=20,
                                          step=1)
                    guidance_scale = gr.Slider(label='Guidance Scale',
                                               minimum=0.1,
                                               maximum=30.0,
                                               value=7.5,
                                               step=0.1)
                    model = gr.Dropdown(choices=["coyo-500k", "bridge-2M", "coyo2M-bridge3M"],
                                        value="coyo-500k",
                                        label="Model used for inference",
                                        info="Find every models at https://huggingface.co/Baptlem/baptlem-controlnet")
                    seed = gr.Slider(label='Seed',
                                     minimum=-1,
                                     maximum=2147483647,
                                     step=1,
                                     randomize=True)
                    n_prompt = gr.Textbox(
                        label='Negative Prompt',
                        value=
                        'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality'
                    )
            with gr.Column():
                result = gr.Gallery(label='Output',
                                    show_label=False,
                                    elem_id='gallery').style(grid=2,
                                                             height='auto')
        with gr.Row():
            gr.Markdown(description)
            gr.Video(value=".trajectory_hf/trajectory.avi", format="avi", interactive=False)

        # Order must match pipe_inference's positional parameters.
        inputs = [
            input_image,
            prompt,
            is_canny,
            num_samples,
            resolution,
            num_steps,
            guidance_scale,
            model,
            seed,
            n_prompt,
        ]
        prompt.submit(fn=process, inputs=inputs, outputs=result)
        run_button.click(fn=process,
                         inputs=inputs,
                         outputs=result,
                         api_name='canny')
    return demo


if __name__ == '__main__':
    demo = create_demo(pipe_inference)
    demo.queue().launch()