|
|
|
|
|
from diffusers import FlaxStableDiffusionControlNetPipeline, FlaxControlNetModel, FlaxDPMSolverMultistepScheduler |
|
from transformers import CLIPTokenizer, FlaxCLIPTextModel, set_seed |
|
from flax.training.common_utils import shard |
|
from flax.jax_utils import replicate |
|
from diffusers.utils import load_image |
|
import jax.numpy as jnp |
|
import jax |
|
import cv2 |
|
from PIL import Image |
|
import numpy as np |
|
import gradio as gr |
|
import os |
|
|
|
|
|
if gr.__version__ != "3.28.3":
    # This script pins gradio to 3.28.3; when another version is installed,
    # swap it for the pinned one.
    import subprocess
    import sys

    # Invoke pip through the running interpreter so the package lands in this
    # environment rather than in whichever `pip` happens to be first on PATH.
    # Argument lists (no shell string) avoid shell parsing altogether.
    # check=False keeps this best-effort: a failed pip call does not abort
    # start-up.
    # NOTE: nothing here re-imports gradio, so the `gr` module already
    # imported above stays whatever version was loaded for this process.
    pip_command = [sys.executable, "-m", "pip"]
    subprocess.run(pip_command + ["uninstall", "-y", "gradio"], check=False)
    subprocess.run(pip_command + ["install", "gradio==3.28.3"], check=False)
|
|
|
# Markdown heading shown in the first row of the demo page.
title_description = """
# Unlimited Controlled Domain Randomization Network for Bridging the Sim2Real Gap in Robotics

"""

# Markdown overview of the project: motivation, model and training data.
description = """
While existing ControlNet and public diffusion models are predominantly geared towards high-resolution images (512x512 or above) and intricate artistic detail generation, there's an untapped potential of these models in Automatic Data Augmentation (ADA).
By harnessing the inherent variance in prompt-conditioned generated images, we can significantly boost the visual diversity of training samples for computer vision pipelines.
This is particularly relevant in the field of robotics, where deep learning is increasingly playing a pivotal role in training policies for robotic manipulation from images.

In this HuggingFace sprint, we present UCDR-Net (Unlimited Controlled Domain Randomization Network), a novel CannyEdge mini-ControlNet trained on Stable Diffusion 1.5 with mixed datasets.
Our model generates photorealistic and varied renderings from simplistic robotic simulation images, enabling real-time data augmentation for robotic vision training.

We specifically designed UCDR-Net to be fast and composition preserving, with an emphasis on lower resolution images (128x128) for online data augmentation in typical preprocessing pipelines.
Our choice of Canny Edge version of ControlNet ensures shape and structure preservation in the image, which is crucial for visuomotor policy learning.

We trained ControlNet from scratch using only 128x128 images, preprocessing the training datasets and extracting Canny Edge maps.
We then trained four Control-Nets with different mixtures of 2 datasets (Coyo-700M and Bridge Data) and showcased the results.
* [Coyo-700M](https://github.com/kakaobrain/coyo-dataset)
* [Bridge](https://sites.google.com/view/bridgedata)

Model Description and Training Process: Please refer to the readme file attached to the model repository.

Model Repository: [ControlNet repo](https://huggingface.co/Baptlem/UCDR-Net_models)

"""

# Markdown caption placed next to the raw simulated-trajectory video.
traj_description = """
To demonstrate UCDR-Net's capabilities, we generated a trajectory of our simulated robotic environment and presented the resulting videos for each model.
We batched the frames for each video and performed independent inference for each frame, which explains the "wobbling" effect.
Prompt used for every video: "A robotic arm with a gripper and a small cube on a table, super realistic, industrial background"

"""

# Markdown text placed next to the benchmark figure.
# Wording fixes only: "80Go" -> "80GB", "one of our model on every GPUs" ->
# "one of our models on every GPU", "increazing" -> "increasing".
perfo_description = """
Our model has been benchmarked on a node of 8 A100 80GB GPUs, achieving an impressive 170 FPS image generation rate!

To make the benchmark, we loaded one of our models on every GPU of the node. We then retrieve an episode of our simulation.
For every frame of the episode, we preprocess the image (resize, canny, …) and process the Canny image on the GPUs.
We repeated this procedure for different Batch Size (BS).

We can see that the greater the BS the greater the FPS. By increasing the BS, we take advantage of the parallelization of the GPUs.
"""

# Markdown closing paragraph shown in the last row of the page.
conclusion_description = """
UCDR-Net stands as a natural development in bridging the Sim2Real gap in robotics by providing real-time data augmentation for training visual policies.
We are excited to share our work with the HuggingFace community and contribute to the advancement of robotic vision training techniques.

"""
|
|
|
def create_key(seed=0):
    """Build a JAX PRNG key from a seed.

    Args:
        seed: Value handed to ``jax.random.PRNGKey``; defaults to 0.

    Returns:
        The key produced by ``jax.random.PRNGKey(seed)``.
    """
    prng_key = jax.random.PRNGKey(seed)
    return prng_key
|
|
|
def load_controlnet(controlnet_version):
    """Load one ControlNet variant from the "Baptlem/UCDR-Net_models" repo.

    Args:
        controlnet_version: Name of the repository subfolder to load.

    Returns:
        The ``(controlnet, controlnet_params)`` pair returned by
        ``FlaxControlNetModel.from_pretrained``.
    """
    load_options = {
        "subfolder": controlnet_version,
        "from_flax": True,
        "dtype": jnp.float32,
    }
    model, weights = FlaxControlNetModel.from_pretrained(
        "Baptlem/UCDR-Net_models", **load_options
    )
    return model, weights
|
|
|
|
|
def load_sb_pipe(controlnet_version, sb_path="runwayml/stable-diffusion-v1-5"):
    """Assemble the Stable Diffusion + ControlNet pipeline and its parameters.

    Args:
        controlnet_version: ControlNet variant, forwarded to
            ``load_controlnet``.
        sb_path: Model id or path the scheduler and the pipeline are loaded
            from.

    Returns:
        ``(pipe, params)``: the pipeline with the DPM-Solver multistep
        scheduler attached, and its parameter mapping with the
        ``"controlnet"`` and ``"scheduler"`` entries filled in.
    """
    cn_model, cn_weights = load_controlnet(controlnet_version)

    dpm_scheduler, dpm_state = FlaxDPMSolverMultistepScheduler.from_pretrained(
        sb_path, subfolder="scheduler"
    )

    pipeline, weights = FlaxStableDiffusionControlNetPipeline.from_pretrained(
        sb_path, controlnet=cn_model, revision="flax", dtype=jnp.bfloat16
    )

    # Replace the pipeline's scheduler with the separately loaded one and
    # register the separately loaded parameters under their keys.
    pipeline.scheduler = dpm_scheduler
    for entry, value in (("controlnet", cn_weights), ("scheduler", dpm_state)):
        weights[entry] = value
    return pipeline, weights
|
|
|
|
|
|
|
# Repository and variant names for the ControlNet checkpoints.
controlnet_path, controlnet_version = "Baptlem/UCDR-Net_models", "coyo-500k"

# Thresholds passed to cv2.Canny by preprocess_canny.
low_threshold, high_threshold = 100, 200

# Start-up diagnostics: working directory, its contents, gradio version.
working_dir = os.path.abspath('.')
print(working_dir)
print(os.listdir("."))
print(f"Gradio version: {gr.__version__}")

print("Loaded models...")
|
def pipe_inference(
    image,
    prompt,
    is_canny=False,
    num_samples=4,
    resolution=128,
    num_inference_steps=50,
    guidance_scale=7.5,
    model="coyo-500k",
    seed=0,
    negative_prompt="",
):
    """Generate ControlNet-conditioned images for one input image and prompt.

    The pipeline is rebuilt through ``load_sb_pipe(model)`` on every call.

    Args:
        image: Input picture (NumPy array or anything ``np.array`` accepts).
        prompt: Text prompt, repeated for every sample.
        is_canny: When true, ``image`` is used as the conditioning image
            after resizing; otherwise a Canny edge map is computed from it.
        num_samples: Number of images to return.
        resolution: Target size handed to ``resize_image``.
        num_inference_steps: Number of denoising steps.
        guidance_scale: Guidance scale forwarded to the pipeline.
        model: ControlNet variant name forwarded to ``load_sb_pipe``.
        seed: Seed for the PRNG key.
        negative_prompt: Negative prompt, repeated for every sample.

    Returns:
        The first ``num_samples`` generated images, converted with
        ``pipe.numpy_to_pil``.

    Raises:
        ValueError: If ``image`` is ``None``.
    """
    if image is None:
        # Fail early with a readable message instead of an obscure shape
        # error further down in resize_image.
        raise ValueError("pipe_inference needs an input image, got None")

    # List repetition and the resize target size need real integers, so
    # coerce in case the caller hands over floats such as 4.0.
    num_samples = int(num_samples)
    resolution = int(resolution)
    num_inference_steps = int(num_inference_steps)
    seed = int(seed)

    print("Loading pipe")
    pipe, params = load_sb_pipe(model)

    # resize_image converts non-array inputs itself, so no conversion here.
    processed_image = resize_image(image, resolution)

    if not is_canny:
        # Only the edge map is needed; the resized copy is discarded.
        _, processed_image = preprocess_canny(processed_image, resolution)

    rng = create_key(seed)
    rng = jax.random.split(rng, jax.device_count())

    # shard() splits the batch evenly across the local devices, so round the
    # batch up to a multiple of the device count and trim the surplus after
    # inference. With a batch that already divides evenly nothing changes.
    n_devices = jax.local_device_count()
    batch_size = -(-num_samples // n_devices) * n_devices

    prompt_ids = pipe.prepare_text_inputs([prompt] * batch_size)
    negative_prompt_ids = pipe.prepare_text_inputs([negative_prompt] * batch_size)
    processed_image = pipe.prepare_image_inputs([processed_image] * batch_size)

    p_params = replicate(params)
    prompt_ids = shard(prompt_ids)
    negative_prompt_ids = shard(negative_prompt_ids)
    processed_image = shard(processed_image)

    print("Inference...")
    output = pipe(
        prompt_ids=prompt_ids,
        image=processed_image,
        params=p_params,
        prng_seed=rng,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        neg_prompt_ids=negative_prompt_ids,
        jit=True,
    ).images
    print("Finished inference...")

    # Collapse the leading device/batch axes into one flat batch axis, keep
    # only the requested number of samples, then convert to PIL.
    flat_images = np.asarray(output.reshape((batch_size,) + output.shape[-3:]))
    all_outputs = pipe.numpy_to_pil(flat_images[:num_samples])
    return all_outputs
|
|
|
def resize_image(image, resolution):
    """Rescale an image so its shorter side equals ``resolution``.

    The aspect ratio is kept (up to integer truncation of the longer side)
    and nearest-neighbour interpolation is used.

    Args:
        image: NumPy array, or anything ``np.array`` can convert to one.
        resolution: Length in pixels of the shorter side after resizing.

    Returns:
        The resized picture as a ``PIL.Image``.
    """
    pixels = image if isinstance(image, np.ndarray) else np.array(image)
    height, width = pixels.shape[:2]
    aspect = width / height

    # cv2.resize takes the target size as (width, height).
    if aspect > 1:
        target_size = (int(resolution * aspect), resolution)
    elif aspect < 1:
        target_size = (resolution, int(resolution / aspect))
    else:
        target_size = (resolution, resolution)

    scaled = cv2.resize(pixels, target_size, interpolation=cv2.INTER_NEAREST)
    return Image.fromarray(scaled)
|
|
|
|
|
def preprocess_canny(image, resolution=128):
    """Compute a three-channel Canny edge map for an image.

    Edge detection uses the module-level ``low_threshold`` and
    ``high_threshold`` values.

    Args:
        image: NumPy array, or anything ``np.array`` can convert to one.
        resolution: Accepted but not used by this function.

    Returns:
        ``(resized_image, processed_image)``: the input picture and its edge
        map, both as ``PIL.Image`` objects.
    """
    pixels = image if isinstance(image, np.ndarray) else np.array(image)

    edges = cv2.Canny(pixels, low_threshold, high_threshold)
    # Copy the single-channel edge map into three identical channels.
    edges_rgb = np.repeat(edges[:, :, None], 3, axis=2)

    return Image.fromarray(pixels), Image.fromarray(edges_rgb)
|
|
|
|
|
def create_demo(process, max_images=12, default_num_images=4):
    """Build the Gradio Blocks page for the UCDR-Net demo.

    Args:
        process: Callback run when the prompt is submitted or the Run button
            is clicked. It receives, in order: input image, prompt,
            "is canny" flag, number of images, resolution, steps, guidance
            scale, model name, seed and negative prompt; its result fills
            the output gallery.
        max_images: Upper bound of the "Images" slider.
        default_num_images: Initial value of the "Images" slider.

    Returns:
        The assembled ``gr.Blocks`` object (not launched).
    """
    # NOTE(review): the extracted source carried no indentation, so the
    # nesting of the `with` blocks below was reconstructed - confirm the
    # layout against the running app.
    model_names = ("coyo-500k", "bridge-2M", "coyo1M-bridge2M", "coyo2M-bridge325k")

    with gr.Blocks() as demo:
        with gr.Row():
            gr.Markdown(title_description)

        # Controls on the left, generated images on the right.
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(source='upload', type='numpy')
                prompt = gr.Textbox(label='Prompt')
                run_button = gr.Button(label='Run')
                with gr.Accordion('Advanced options', open=False):
                    is_canny = gr.Checkbox(label='Is canny', value=False)
                    num_samples = gr.Slider(label='Images',
                                            minimum=1,
                                            maximum=max_images,
                                            value=default_num_images,
                                            step=1)
                    # Resolution is fixed: minimum and maximum are both 128.
                    resolution = gr.Slider(label='Resolution',
                                           minimum=128,
                                           maximum=128,
                                           value=128,
                                           step=1)
                    num_steps = gr.Slider(label='Steps',
                                          minimum=1,
                                          maximum=100,
                                          value=20,
                                          step=1)
                    guidance_scale = gr.Slider(label='Guidance Scale',
                                               minimum=0.1,
                                               maximum=30.0,
                                               value=7.5,
                                               step=0.1)
                    model = gr.Dropdown(choices=list(model_names),
                                        value="coyo-500k",
                                        label="Model used for inference",
                                        info="Find every models at https://huggingface.co/Baptlem/UCDR-Net_models")
                    seed = gr.Slider(label='Seed',
                                     minimum=-1,
                                     maximum=2147483647,
                                     step=1,
                                     randomize=True)
                    n_prompt = gr.Textbox(
                        label='Negative Prompt',
                        value=
                        'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality'
                    )
            with gr.Column():
                result = gr.Gallery(label='Output',
                                    show_label=False,
                                    elem_id='gallery').style(grid=2,
                                                             height='auto')

        with gr.Row():
            gr.Video("./trajectory_hf/trajectory_coyo2M-bridge325k_64.avi",
                     format="avi",
                     interactive=False).style(height=512,
                                              width=512)

        with gr.Row():
            gr.Markdown(description)

        # Raw simulated trajectory next to its explanation.
        with gr.Row():
            with gr.Column():
                gr.Markdown(traj_description)
            with gr.Column():
                gr.Video("./trajectory_hf/trajectory.avi",
                         format="avi",
                         interactive=False)

        # One caption/video row per model, in the same order as the dropdown.
        for variant in model_names:
            with gr.Row():
                with gr.Column():
                    gr.Markdown(f"Trajectory processed with {variant} model :")
                with gr.Column():
                    gr.Video(f"./trajectory_hf/trajectory_{variant}.avi",
                             format="avi",
                             interactive=False)

        with gr.Row():
            with gr.Column():
                gr.Markdown(perfo_description)
            with gr.Column():
                gr.Image("./perfo_rtx.png",
                         interactive=False)

        with gr.Row():
            gr.Markdown(conclusion_description)

        # Order must match the positional parameters of `process`.
        inputs = [
            input_image,
            prompt,
            is_canny,
            num_samples,
            resolution,
            num_steps,
            guidance_scale,
            model,
            seed,
            n_prompt,
        ]
        prompt.submit(fn=process, inputs=inputs, outputs=result)
        run_button.click(fn=process,
                         inputs=inputs,
                         outputs=result,
                         api_name='canny')

    return demo
|
|
|
if __name__ == '__main__':
    # Build the page around the inference function and serve it with the
    # request queue enabled.
    demo = create_demo(pipe_inference)
    demo.queue().launch()
|
|
|
|