UCDR-Net / app.py
Baptlem's picture
Update app.py
1d71dff
raw
history blame
13.6 kB
# This file is adapted from https://huggingface.co/spaces/diffusers/controlnet-canny/blob/main/app.py
# The original license file is LICENSE.ControlNet in this repo.
import functools
import os
import subprocess
import sys

import cv2
import gradio as gr
import jax
import jax.numpy as jnp
import numpy as np
from diffusers import FlaxStableDiffusionControlNetPipeline, FlaxControlNetModel, FlaxDPMSolverMultistepScheduler
from diffusers.utils import load_image
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from PIL import Image
from transformers import CLIPTokenizer, FlaxCLIPTextModel, set_seed
# The UI below is written against gradio 3.28.3, so try to install that
# version when a different one is present.
if gr.__version__ != "3.28.3":
    # NOTE: gradio has already been imported above, so reinstalling it cannot
    # change the module this process is running (that is why the original
    # attempt "doesn't work"); the new version is only picked up after a
    # restart. Invoke pip through the current interpreter so the package lands
    # in the environment actually in use rather than whichever `pip` is first
    # on PATH. check=False keeps this best-effort, like the original os.system.
    subprocess.run([sys.executable, "-m", "pip", "uninstall", "-y", "gradio"], check=False)
    subprocess.run([sys.executable, "-m", "pip", "install", "gradio==3.28.3"], check=False)
    print(f"gradio {gr.__version__} is loaded; 3.28.3 was (re)installed and "
          "will only be used after the app restarts.")
# Markdown blocks rendered in the Gradio UI by create_demo.
title_description = """
# SynDRoM
## Synthetic Data augmentation for Robotic Manipulation
"""
description = """
Our project uses a diffusion model to change the texture of our robotic arm simulation.
To do so, we first get our simulated images. Then, we process these images to get Canny edge maps. Finally, we generate brand new images using ControlNet.
This lets us change the texture of our simulation while keeping the image composition.
Our objective for the sprint is to perform data augmentation using ControlNet, so we are looking for a model that can augment an image quickly.
To do so, we trained several ControlNets from scratch on different datasets:
* [Coyo-700M](https://github.com/kakaobrain/coyo-dataset)
* [Bridge](https://sites.google.com/view/bridgedata)

One way to accelerate diffusion model inference is simply to generate small images, so we decided to work with low-resolution images.
After downloading the datasets, we resized the images to a resolution of 128.
The smaller side of each image (width or height) is resized to 128 and the other side is scaled to keep the initial aspect ratio.
Then, we computed the Canny edge map of each image. We applied this preprocessing to every dataset we used during the sprint.
We trained four different ControlNets, each on a differently processed version of the datasets. You can find a description of the processing in the README file attached to the model repo:
[Our ControlNet repo](https://huggingface.co/Baptlem/baptlem-controlnet)
For now, we benchmarked our model on a node of 4 Titan RTX 24 GB GPUs. We were able to generate a batch of 4 images in an average time of 1.3 seconds!
We also have access to nodes composed of 8 A100 80 GB GPUs. The benchmark on one of these nodes will come soon.
"""
traj_description = """
We generated a trajectory of our simulated environment and then processed it with each of our models.
We made these videos on our Titan RTX node.
The prompt we use for every video is "A robotic arm with a gripper and a small cube on a table, super realistic, industrial background"
"""
perfo_description = """
The table on the right shows the performance of our models running on different nodes.
To run the benchmark, we loaded one of our models on every GPU of the node. We then retrieved an episode of our simulation.
For every frame of the episode, we preprocessed the image (resize, Canny, ...) and processed the Canny image on the GPUs.
We repeated this procedure for different batch sizes (BS).
We can see that the greater the BS, the greater the FPS: by increasing the BS, we take advantage of the parallelization of the GPUs.
"""
def create_key(seed=0):
    """Return a JAX PRNG key derived from ``seed``.

    Args:
        seed: Integer seed (defaults to 0).

    Returns:
        The key built by ``jax.random.PRNGKey(seed)``.
    """
    prng_key = jax.random.PRNGKey(seed)
    return prng_key
def load_controlnet(controlnet_version, repo_id="Baptlem/baptlem-controlnet"):
    """Load one Flax ControlNet checkpoint and its parameters.

    Args:
        controlnet_version: Subfolder of ``repo_id`` holding the checkpoint
            (e.g. ``"coyo-500k"``).
        repo_id: Hub repository id (or local path) containing the checkpoints.
            Defaults to the repo this app was built for.

    Returns:
        ``(controlnet, controlnet_params)`` as returned by
        ``FlaxControlNetModel.from_pretrained`` with ``dtype=jnp.float32``.
    """
    controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
        repo_id,
        subfolder=controlnet_version,
        # NOTE(review): FlaxModelMixin.from_pretrained documents ``from_pt``,
        # not ``from_flax``; verify this kwarg has any effect.
        from_flax=True,
        dtype=jnp.float32,
    )
    return controlnet, controlnet_params
@functools.lru_cache(maxsize=1)
def load_sb_pipe(controlnet_version, sb_path="runwayml/stable-diffusion-v1-5"):
    """Load the Flax Stable Diffusion + ControlNet pipeline with its parameters.

    The result is memoized with a single cache slot. Repeated requests for the
    same ControlNet reuse the already-loaded pipeline instead of running every
    ``from_pretrained`` call again on each request. Requesting a different
    ControlNet evicts the previous entry, so at most one pipeline is kept alive
    between calls. The returned objects are shared across calls, so callers
    must not mutate ``params``.

    Args:
        controlnet_version: Subfolder of the ControlNet repo to load
            (passed to ``load_controlnet``).
        sb_path: Stable Diffusion checkpoint (Hub id or local path).

    Returns:
        ``(pipe, params)``. ``pipe`` uses a DPM-Solver multistep scheduler, and
        ``params`` also holds the ControlNet weights under ``"controlnet"`` and
        the scheduler state under ``"scheduler"``.
    """
    controlnet, controlnet_params = load_controlnet(controlnet_version)
    scheduler, scheduler_params = FlaxDPMSolverMultistepScheduler.from_pretrained(
        sb_path,
        subfolder="scheduler",
    )
    pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
        sb_path,
        controlnet=controlnet,
        revision="flax",
        dtype=jnp.bfloat16,
    )
    # Swap in the DPM-Solver scheduler and attach the separately loaded
    # ControlNet / scheduler state to the pipeline's parameter dict.
    pipe.scheduler = scheduler
    params["controlnet"] = controlnet_params
    params["scheduler"] = scheduler_params
    return pipe, params
# Repo id and default checkpoint subfolder (the same repo id that
# load_controlnet loads from; these two names are not referenced elsewhere in
# this file).
controlnet_path = "Baptlem/baptlem-controlnet"
controlnet_version = "coyo-500k"
# Constants
# Hysteresis thresholds passed to cv2.Canny by preprocess_canny.
low_threshold = 100
high_threshold = 200
# Startup diagnostics: working directory, its contents, and the gradio version.
print(os.path.abspath('.'))
print(os.listdir("."))
print("Gradio version:", gr.__version__)
# No model is loaded at import time: pipe_inference loads the pipeline when a
# request arrives, so the log message must not claim otherwise.
print("Models will be loaded on demand at inference time...")
def pipe_inference(
    image,
    prompt,
    is_canny=False,
    num_samples=4,
    resolution=128,
    num_inference_steps=50,
    guidance_scale=7.5,
    model="coyo-500k",
    seed=0,
    negative_prompt="",
):
    """Generate images with the ControlNet selected by ``model``.

    Args:
        image: Input image (numpy array or anything ``np.array`` accepts). It is
            resized so that its smaller side equals ``resolution``.
        prompt: Text prompt used for every sample.
        is_canny: If True, ``image`` is used directly as the conditioning image.
            Otherwise a Canny edge map is computed from it first.
        num_samples: Number of images to return.
        resolution: Target length, in pixels, of the image's smaller side.
        num_inference_steps: Number of denoising steps.
        guidance_scale: Classifier-free guidance scale.
        model: ControlNet checkpoint subfolder, forwarded to ``load_sb_pipe``.
        seed: Seed for the JAX PRNG key.
        negative_prompt: Negative prompt used for every sample.

    Returns:
        A list of ``num_samples`` PIL images (empty when ``num_samples < 1``).
    """
    # Gradio sliders may deliver integral values as int and others as float;
    # normalize the ones that must be a specific Python type.
    num_samples = int(num_samples)
    if num_samples < 1:
        return []

    print("Loading pipe")
    pipe, params = load_sb_pipe(model)

    if not isinstance(image, np.ndarray):
        image = np.array(image)
    processed_image = resize_image(image, resolution)  # -> PIL
    if not is_canny:
        _, processed_image = preprocess_canny(processed_image, resolution)

    # shard()/replicate() split work across the local devices, and shard()
    # needs the leading batch axis to divide evenly. Pad the batch up to the
    # next multiple of the device count and drop the surplus images at the end.
    n_devices = jax.local_device_count()
    batch_size = num_samples + (-num_samples) % n_devices

    rng = jax.random.split(create_key(int(seed)), n_devices)
    prompt_ids = shard(pipe.prepare_text_inputs([prompt] * batch_size))
    negative_prompt_ids = shard(pipe.prepare_text_inputs([negative_prompt] * batch_size))
    image_inputs = shard(pipe.prepare_image_inputs([processed_image] * batch_size))
    p_params = replicate(params)

    print("Inference...")
    output = pipe(
        prompt_ids=prompt_ids,
        image=image_inputs,
        params=p_params,
        prng_seed=rng,
        num_inference_steps=int(num_inference_steps),
        # float() matters: an int guidance scale would not be turned into a
        # per-device array by the pipeline when jit=True.
        guidance_scale=float(guidance_scale),
        neg_prompt_ids=negative_prompt_ids,
        jit=True,
    ).images
    print("Finished inference...")

    # The leading axes hold the sharded batch; flatten them to (batch, H, W, C)
    # before converting, then keep only the requested number of images.
    images = np.asarray(output).reshape((batch_size,) + output.shape[-3:])
    return pipe.numpy_to_pil(images[:num_samples])
def resize_image(image, resolution):
    """Resize ``image`` so that its smaller side equals ``resolution``.

    The aspect ratio is kept (the longer side is truncated to an int), and
    nearest-neighbour interpolation is used.

    Args:
        image: numpy array, or anything ``np.array`` can convert.
        resolution: Target length, in pixels, of the smaller side.

    Returns:
        The resized image as a ``PIL.Image.Image``.
    """
    pixels = image if isinstance(image, np.ndarray) else np.array(image)
    height, width = pixels.shape[:2]
    aspect = width / height
    if aspect > 1:
        target = (int(resolution * aspect), resolution)  # landscape: height fixed
    elif aspect < 1:
        target = (resolution, int(resolution / aspect))  # portrait: width fixed
    else:
        target = (resolution, resolution)
    # cv2.resize takes the destination size as (width, height).
    return Image.fromarray(cv2.resize(pixels, target, interpolation=cv2.INTER_NEAREST))
def preprocess_canny(image, resolution=128, canny_low_threshold=None, canny_high_threshold=None):
    """Compute a 3-channel Canny edge map from ``image``.

    Args:
        image: numpy array, or anything ``np.array`` can convert (e.g. PIL).
        resolution: Unused; kept so existing positional callers keep working.
        canny_low_threshold: Lower hysteresis threshold for ``cv2.Canny``.
            Defaults to the module-level ``low_threshold``.
        canny_high_threshold: Upper hysteresis threshold for ``cv2.Canny``.
            Defaults to the module-level ``high_threshold``.

    Returns:
        ``(source_image, edge_image)``: the input converted to a PIL image, and
        the edge map replicated into three channels as a PIL image.
    """
    if not isinstance(image, np.ndarray):
        image = np.array(image)
    low = low_threshold if canny_low_threshold is None else canny_low_threshold
    high = high_threshold if canny_high_threshold is None else canny_high_threshold
    edges = cv2.Canny(image, low, high)
    # Stack the single-channel edge map into an (H, W, 3) image.
    edges_rgb = np.stack([edges, edges, edges], axis=2)
    return Image.fromarray(image), Image.fromarray(edges_rgb)
def create_demo(process, max_images=12, default_num_images=4):
    """Build the Gradio Blocks UI for the demo.

    Args:
        process: Callback run when the prompt is submitted or "Run" is clicked.
            It receives, positionally: image, prompt, is_canny, num_samples,
            resolution, num_steps, guidance_scale, model, seed, negative prompt.
            Its return value is shown in the output gallery.
        max_images: Maximum of the "Images" slider.
        default_num_images: Initial value of the "Images" slider.

    Returns:
        The (not yet launched) ``gr.Blocks`` app.
    """
    # ControlNet checkpoints offered in the dropdown. Each one also has a
    # pre-rendered trajectory video under ./trajectory_hf/.
    model_names = ["coyo-500k", "bridge-2M", "coyo1M-bridge2M", "coyo2M-bridge325k"]

    with gr.Blocks() as demo:
        with gr.Row():
            gr.Markdown(title_description)
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(source='upload', type='numpy')
                prompt = gr.Textbox(label='Prompt')
                # A gradio Button's displayed text is its ``value``.
                run_button = gr.Button(value='Run')
                with gr.Accordion('Advanced options', open=False):
                    is_canny = gr.Checkbox(label='Is canny', value=False)
                    num_samples = gr.Slider(label='Images',
                                            minimum=1,
                                            maximum=max_images,
                                            value=default_num_images,
                                            step=1)
                    # Canny thresholds are not exposed in the UI; the module-level
                    # defaults used by preprocess_canny apply.
                    # Resolution is pinned to 128 (minimum == maximum).
                    resolution = gr.Slider(label='Resolution',
                                           minimum=128,
                                           maximum=128,
                                           value=128,
                                           step=1)
                    num_steps = gr.Slider(label='Steps',
                                          minimum=1,
                                          maximum=100,
                                          value=20,
                                          step=1)
                    guidance_scale = gr.Slider(label='Guidance Scale',
                                               minimum=0.1,
                                               maximum=30.0,
                                               value=7.5,
                                               step=0.1)
                    model = gr.Dropdown(choices=model_names,
                                        value="coyo-500k",
                                        label="Model used for inference",
                                        info="Find every models at https://huggingface.co/Baptlem/baptlem-controlnet")
                    seed = gr.Slider(label='Seed',
                                     minimum=-1,
                                     maximum=2147483647,
                                     step=1,
                                     randomize=True)
                    n_prompt = gr.Textbox(
                        label='Negative Prompt',
                        value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality',
                    )
            with gr.Column():
                result = gr.Gallery(label='Output',
                                    show_label=False,
                                    elem_id='gallery').style(grid=2, height='auto')
        with gr.Row():
            gr.Markdown(description)
        with gr.Row():
            with gr.Column():
                gr.Markdown(traj_description)
            with gr.Column():
                gr.Video("./trajectory_hf/trajectory.avi",
                         format="avi",
                         interactive=False)
        # One row per model: a caption and the trajectory rendered with it.
        for name in model_names:
            with gr.Row():
                with gr.Column():
                    gr.Markdown(f"Trajectory processed with {name} model :")
                with gr.Column():
                    gr.Video(f"./trajectory_hf/trajectory_{name}.avi",
                             format="avi",
                             interactive=False)
        with gr.Row():
            with gr.Column():
                gr.Markdown(perfo_description)
            with gr.Column():
                gr.Image("./perfo_rtx.png",
                         interactive=False)

        # Order must match the positional parameters of ``process``.
        inputs = [
            input_image,
            prompt,
            is_canny,
            num_samples,
            resolution,
            num_steps,
            guidance_scale,
            model,
            seed,
            n_prompt,
        ]
        prompt.submit(fn=process, inputs=inputs, outputs=result)
        run_button.click(fn=process,
                         inputs=inputs,
                         outputs=result,
                         api_name='canny')
    return demo
if __name__ == '__main__':
    # Build the UI around pipe_inference and serve it with request queuing.
    demo = create_demo(pipe_inference)
    demo.queue().launch()