|
|
|
|
|
import os
import subprocess
import sys

import cv2
import gradio as gr
import jax
import jax.numpy as jnp
import numpy as np
from diffusers import FlaxStableDiffusionControlNetPipeline, FlaxControlNetModel, FlaxDPMSolverMultistepScheduler
from diffusers.utils import load_image
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from PIL import Image
from transformers import CLIPTokenizer, FlaxCLIPTextModel, set_seed
|
|
|
|
|
# Pin the Gradio version this UI was written against. create_demo uses
# Gallery(...).style(...) and Image(source=...), which presumably require
# Gradio 3.x — TODO confirm the exact minimum/maximum supported versions.
if gr.__version__ != "3.28.3":
    # Invoke pip through the running interpreter (not whatever `pip` is first
    # on PATH, which may belong to another Python) and pass an argument list so
    # no shell is involved. Return codes are ignored: this is best effort.
    subprocess.run([sys.executable, "-m", "pip", "uninstall", "-y", "gradio"], check=False)
    subprocess.run([sys.executable, "-m", "pip", "install", "gradio==3.28.3"], check=False)
    # The `gr` module imported above is NOT replaced by the reinstall; the
    # pinned version only takes effect once the process is restarted.
    print("Reinstalled gradio==3.28.3; restart the app for it to take effect.")
|
|
|
title_description = """ |
|
# SynDRoM |
|
## Synthetic Data augmentation for Robotic Manipulation |
|
|
|
""" |
|
|
|
description = """ |
|
Our project uses a diffusion model to change the texture of our robotic arm simulation.
|
To do so, we first collect images from our simulation. Next, we process these images to get Canny edge maps. Finally, we generate brand-new images with ControlNet.
|
This lets us change the textures of our simulation while keeping the image composition.
|
|
|
|
|
Our objective for the sprint is to perform data augmentation using ControlNet, so we are looking for a model that can augment an image quickly.
|
To do so, we trained several ControlNets from scratch on different datasets:
|
* [Coyo-700M](https://github.com/kakaobrain/coyo-dataset) |
|
* [Bridge](https://sites.google.com/view/bridgedata) |
|
|
|
A simple way to accelerate diffusion-model inference is to generate small images, so we decided to work with low-resolution images.
|
After downloading the datasets, we processed them by resizing the images to a resolution of 128.
|
The smaller side of each image (width or height) is resized to 128 pixels, and the other side is resized to preserve the original aspect ratio.
|
Then, we compute the Canny edge map of each image. We applied this preprocessing to every dataset used during the sprint.
|
|
|
|
|
We trained four different ControlNets, processing the datasets differently for each one. The processing is described in the README file of the model repository:
|
[Our ControlNet repo](https://huggingface.co/Baptlem/baptlem-controlnet) |
|
|
|
So far, we have benchmarked our model on a node with 4 Titan RTX 24 GB GPUs. We were able to generate a batch of 4 images in an average time of 1.3 seconds!
|
We also have access to nodes with 8 A100 80 GB GPUs. Benchmarks on one of these nodes are coming soon.
|
|
|
|
|
""" |
|
|
|
traj_description = """ |
|
We generated a trajectory in our simulated environment and then processed it with each of our models.
|
We made these videos on our Titan RTX node. |
|
The prompt we use for every video is "A robotic arm with a gripper and a small cube on a table, super realistic, industrial background" |
|
""" |
|
|
|
|
|
perfo_description = """ |
|
The table on the right shows the performance of our models running on different nodes.
|
To run the benchmark, we loaded one of our models on every GPU of the node. We then retrieved an episode of our simulation.
|
For every frame of the episode, we preprocess the image (resize, Canny, ...) and process the Canny image with the model on the GPUs.
|
We repeated this procedure for different batch sizes (BS).
|
|
|
We can see that the larger the BS, the higher the FPS: by increasing the BS, we take better advantage of GPU parallelism.
|
|
|
""" |
|
|
|
|
|
def create_key(seed=0):
    """Build a JAX PRNG key from an integer seed.

    Args:
        seed: Integer seed; defaults to 0.

    Returns:
        The key produced by ``jax.random.PRNGKey`` for ``seed``.
    """
    key = jax.random.PRNGKey(seed)
    return key
|
|
|
def load_controlnet(controlnet_version, repo_id="Baptlem/baptlem-controlnet"):
    """Load one ControlNet checkpoint with ``FlaxControlNetModel.from_pretrained``.

    Args:
        controlnet_version: Subfolder of ``repo_id`` holding the checkpoint,
            e.g. ``"coyo-500k"``.
        repo_id: Hub repository id (or local directory) containing the
            checkpoints. Defaults to the project's ControlNet repository.

    Returns:
        ``(controlnet, controlnet_params)``: the model object and its parameter
        tree, loaded in float32.
    """
    controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
        repo_id,
        subfolder=controlnet_version,
        # NOTE(review): `from_flax` does not appear to be a documented option of
        # FlaxModelMixin.from_pretrained (`from_pt` is); presumably it is
        # forwarded and ignored — TODO confirm before removing.
        from_flax=True,
        dtype=jnp.float32,
    )
    return controlnet, controlnet_params
|
|
|
|
|
def load_sb_pipe(controlnet_version, sb_path="runwayml/stable-diffusion-v1-5"):
    """Build the Flax Stable Diffusion + ControlNet pipeline and its params.

    Args:
        controlnet_version: ControlNet checkpoint name, forwarded to
            load_controlnet.
        sb_path: Hub id or local path of the base Stable Diffusion model.

    Returns:
        ``(pipe, params)``: a pipeline that uses a DPM-Solver multistep
        scheduler, and its parameter dict with the "controlnet" and
        "scheduler" entries filled in.
    """
    controlnet, controlnet_params = load_controlnet(controlnet_version)

    scheduler, scheduler_state = FlaxDPMSolverMultistepScheduler.from_pretrained(
        sb_path,
        subfolder="scheduler",
    )

    # Passing the ControlNet and the scheduler as components makes
    # from_pretrained use them instead of loading (and then discarding) the
    # repository defaults. Components passed this way get no entry in
    # `params`, so their params/state are added explicitly below.
    pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
        sb_path,
        controlnet=controlnet,
        scheduler=scheduler,
        revision="flax",
        dtype=jnp.bfloat16,
    )
    params["controlnet"] = controlnet_params
    params["scheduler"] = scheduler_state
    return pipe, params
|
|
|
|
|
|
|
# Hub repository holding the trained ControlNet checkpoints, and the name of
# the checkpoint treated as the default.
controlnet_path, controlnet_version = "Baptlem/baptlem-controlnet", "coyo-500k"

# Lower / upper hysteresis thresholds passed to cv2.Canny in preprocess_canny.
low_threshold, high_threshold = 100, 200
|
|
|
# Startup diagnostics: the UI reads its assets through paths relative to the
# working directory (e.g. "./trajectory_hf/..." and "./perfo_rtx.png" in
# create_demo), so log where the app runs from and what is there.
print(os.path.abspath('.'))
print(os.listdir("."))
print("Gradio version:", gr.__version__)

# No model is loaded at import time: pipe_inference loads the selected
# model's pipeline when it runs.
print("Models are loaded on demand by pipe_inference.")
|
def pipe_inference(
    image,
    prompt,
    is_canny=False,
    num_samples=4,
    resolution=128,
    num_inference_steps=50,
    guidance_scale=7.5,
    model="coyo-500k",
    seed=0,
    negative_prompt="",
):
    """Generate images with the ControlNet pipeline, conditioned on ``image``.

    Args:
        image: Input image (NumPy array or anything ``np.array`` accepts).
        prompt: Text prompt.
        is_canny: If True, ``image`` is used directly as the conditioning
            image; otherwise its Canny edge map is computed first.
        num_samples: Number of images to return.
        resolution: Target length of the shorter image side (see resize_image).
        num_inference_steps: Number of denoising steps.
        guidance_scale: Guidance scale forwarded to the pipeline.
        model: ControlNet checkpoint name, forwarded to load_sb_pipe.
        seed: Integer seed for the JAX PRNG.
        negative_prompt: Negative text prompt.

    Returns:
        ``num_samples`` generated images, converted with ``pipe.numpy_to_pil``.

    Raises:
        ValueError: If ``image`` is None or ``num_samples`` is less than 1.
    """
    if image is None:
        raise ValueError("An input image is required.")
    # UI widgets may deliver floats; list repetition, step counts and PRNG
    # seeds need ints.
    num_samples = int(num_samples)
    num_inference_steps = int(num_inference_steps)
    seed = int(seed)
    if num_samples < 1:
        raise ValueError(f"num_samples must be at least 1, got {num_samples}.")

    pipe, p_params = _get_replicated_pipe(model)

    conditioning = resize_image(image, int(resolution))
    if not is_canny:
        _, conditioning = preprocess_canny(conditioning, resolution)

    # shard() splits the leading batch axis evenly across all devices, so the
    # batch must be a multiple of the device count: round up here and trim the
    # surplus outputs at the end.
    n_devices = jax.device_count()
    batch_size = -(-num_samples // n_devices) * n_devices

    rng = jax.random.split(create_key(seed), n_devices)
    prompt_ids = shard(pipe.prepare_text_inputs([prompt] * batch_size))
    negative_prompt_ids = shard(pipe.prepare_text_inputs([negative_prompt] * batch_size))
    image_inputs = shard(pipe.prepare_image_inputs([conditioning] * batch_size))

    print("Inference...")
    output = pipe(
        prompt_ids=prompt_ids,
        image=image_inputs,
        params=p_params,
        prng_seed=rng,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        neg_prompt_ids=negative_prompt_ids,
        jit=True,
    ).images
    print("Finished inference...")

    # Merge the leading (sharded) axes into one batch axis, keeping the last
    # three axes as the per-image shape, then drop the padding samples.
    images = np.asarray(output.reshape((batch_size,) + output.shape[-3:]))[:num_samples]
    return pipe.numpy_to_pil(images)


# {model name: (pipe, params replicated across devices)}, holding at most one
# model. Loading Stable Diffusion + ControlNet weights is slow, so they are
# reused across requests instead of being reloaded on every call.
_PIPE_CACHE = {}


def _get_replicated_pipe(model):
    """Return ``(pipe, replicated params)`` for ``model``, loading it on a cache miss."""
    if model not in _PIPE_CACHE:
        # Drop the previously loaded model first so its parameters can be
        # freed before the next model is loaded.
        _PIPE_CACHE.clear()
        print("Loading pipe")
        pipe, params = load_sb_pipe(model)
        _PIPE_CACHE[model] = (pipe, replicate(params))
    return _PIPE_CACHE[model]
|
|
|
def resize_image(image, resolution):
    """Resize ``image`` so that its shorter side is ``resolution`` pixels.

    The longer side is scaled to keep the aspect ratio (truncated to an int).
    Nearest-neighbour interpolation is used.

    Args:
        image: NumPy array whose first two axes are height and width, or
            anything ``np.array`` accepts (e.g. a PIL image).
        resolution: Target length of the shorter side. Coerced to ``int``
            because UI sliders may deliver floats, and cv2.resize needs an
            integer target size.

    Returns:
        The resized image as a PIL image.
    """
    if not isinstance(image, np.ndarray):
        image = np.array(image)
    resolution = int(resolution)

    h, w = image.shape[:2]
    ratio = w / h
    if ratio > 1:  # wider than tall: height is the shorter side
        size = (int(resolution * ratio), resolution)
    elif ratio < 1:  # taller than wide: width is the shorter side
        size = (resolution, int(resolution / ratio))
    else:
        size = (resolution, resolution)

    # cv2.resize takes the target size as (width, height).
    resized_image = cv2.resize(image, size, interpolation=cv2.INTER_NEAREST)
    return Image.fromarray(resized_image)
|
|
|
|
|
def preprocess_canny(image, resolution=128, canny_low_threshold=None, canny_high_threshold=None):
    """Compute a 3-channel Canny edge map of ``image``.

    Args:
        image: NumPy array or anything ``np.array`` accepts (e.g. a PIL image).
        resolution: Unused; kept for backward compatibility. The caller in
            pipe_inference resizes with resize_image beforehand.
        canny_low_threshold: Lower Canny threshold; defaults to the
            module-level ``low_threshold``.
        canny_high_threshold: Upper Canny threshold; defaults to the
            module-level ``high_threshold``.

    Returns:
        ``(resized_image, processed_image)``: the input (pixels unchanged) as a
        PIL image, and the edge map replicated to 3 channels as a PIL image.
    """
    if not isinstance(image, np.ndarray):
        image = np.array(image)
    # Read the module-level defaults at call time, as before.
    if canny_low_threshold is None:
        canny_low_threshold = low_threshold
    if canny_high_threshold is None:
        canny_high_threshold = high_threshold

    edges = cv2.Canny(image, canny_low_threshold, canny_high_threshold)
    # Replicate the single edge channel to 3 channels (presumably because the
    # ControlNet expects RGB conditioning images — TODO confirm).
    edges_rgb = np.repeat(edges[:, :, None], 3, axis=2)

    return Image.fromarray(image), Image.fromarray(edges_rgb)
|
|
|
|
|
def create_demo(process, max_images=12, default_num_images=4):
    """Build the Gradio Blocks UI for the ControlNet demo.

    Args:
        process: Callback run when "Run" is clicked or the prompt is submitted.
            It receives, positionally: input image, prompt, is_canny,
            num_samples, resolution, num_steps, guidance_scale, model, seed,
            negative prompt. Its return value is shown in the output gallery.
        max_images: Upper bound of the "Images" slider.
        default_num_images: Initial value of the "Images" slider.

    Returns:
        The ``gr.Blocks`` app (not yet launched).
    """
    # ControlNet checkpoints selectable in the UI. Each one also has a
    # pre-rendered trajectory video at ./trajectory_hf/trajectory_<name>.avi.
    model_names = ["coyo-500k", "bridge-2M", "coyo1M-bridge2M", "coyo2M-bridge325k"]

    with gr.Blocks() as demo:
        with gr.Row():
            gr.Markdown(title_description)
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(source='upload', type='numpy')
                prompt = gr.Textbox(label='Prompt')
                # A Button's caption is its `value`; `label` does not set it.
                run_button = gr.Button(value='Run')
                with gr.Accordion('Advanced options', open=False):
                    is_canny = gr.Checkbox(label='Is canny', value=False)
                    num_samples = gr.Slider(label='Images',
                                            minimum=1,
                                            maximum=max_images,
                                            value=default_num_images,
                                            step=1)
                    # Canny thresholds are not exposed in the UI; the
                    # module-level low_threshold / high_threshold are used.
                    # Resolution is pinned to 128 (minimum == maximum), the
                    # size the training images were resized to.
                    resolution = gr.Slider(label='Resolution',
                                           minimum=128,
                                           maximum=128,
                                           value=128,
                                           step=1)
                    num_steps = gr.Slider(label='Steps',
                                          minimum=1,
                                          maximum=100,
                                          value=20,
                                          step=1)
                    guidance_scale = gr.Slider(label='Guidance Scale',
                                               minimum=0.1,
                                               maximum=30.0,
                                               value=7.5,
                                               step=0.1)
                    model = gr.Dropdown(choices=model_names,
                                        value="coyo-500k",
                                        label="Model used for inference",
                                        info="Find every models at https://huggingface.co/Baptlem/baptlem-controlnet")
                    seed = gr.Slider(label='Seed',
                                     minimum=-1,
                                     maximum=2147483647,
                                     step=1,
                                     randomize=True)
                    n_prompt = gr.Textbox(
                        label='Negative Prompt',
                        value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality'
                    )
            with gr.Column():
                result = gr.Gallery(label='Output',
                                    show_label=False,
                                    elem_id='gallery').style(grid=2,
                                                             height='auto')

        with gr.Row():
            gr.Markdown(description)

        with gr.Row():
            with gr.Column():
                gr.Markdown(traj_description)
            with gr.Column():
                gr.Video("./trajectory_hf/trajectory.avi",
                         format="avi",
                         interactive=False)

        # One row per model: caption on the left, its processed trajectory
        # video on the right.
        for model_name in model_names:
            with gr.Row():
                with gr.Column():
                    gr.Markdown(f"Trajectory processed with {model_name} model :")
                with gr.Column():
                    gr.Video(f"./trajectory_hf/trajectory_{model_name}.avi",
                             format="avi",
                             interactive=False)

        with gr.Row():
            with gr.Column():
                gr.Markdown(perfo_description)
            with gr.Column():
                gr.Image("./perfo_rtx.png",
                         interactive=False)

        # Order must match the positional parameters of `process`.
        inputs = [
            input_image,
            prompt,
            is_canny,
            num_samples,
            resolution,
            num_steps,
            guidance_scale,
            model,
            seed,
            n_prompt,
        ]
        prompt.submit(fn=process, inputs=inputs, outputs=result)
        run_button.click(fn=process,
                         inputs=inputs,
                         outputs=result,
                         api_name='canny')

    return demo
|
|
|
if __name__ == '__main__':
    # Build the UI around the inference callback, enable request queuing,
    # then start the server.
    demo = create_demo(pipe_inference)
    demo.queue().launch()
|
|
|
|