# ControlLight / app.py
# Hugging Face Space file (uploader: Nahrawy); last commit "Update app.py" (5832738), 4.45 kB.
import gradio as gr
import jax
import numpy as np
import jax.numpy as jnp
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from PIL import Image
from diffusers import FlaxStableDiffusionControlNetPipeline, FlaxControlNetModel
import cv2
# Load the ControlNet checkpoint (pinned to a fixed revision) and then the
# Flax Stable Diffusion v1-5 pipeline built around it; both use bfloat16.
controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
    "Nahrawy/controlnet-VIDIT-FAID",
    revision="615ba4a457b95a0eba813bcc8caf842c03a4f7bd",
    dtype=jnp.bfloat16,
)
pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    revision="flax",
    controlnet=controlnet,
    dtype=jnp.bfloat16,
)
def create_key(seed: int = 0):
    """Return a JAX PRNG key built from ``seed``.

    Args:
        seed: Integer seed for the key; defaults to 0.

    Returns:
        The key produced by ``jax.random.PRNGKey(seed)``.
    """
    return jax.random.PRNGKey(seed)
def process_mask(image: np.ndarray) -> np.ndarray:
    """Convert the conditioning image to a 512x512 single-channel array.

    Args:
        image: Image array as delivered by the Gradio image input. Colour
            input is treated as RGB (or RGBA when it has four channels);
            a 2-D array is taken to be grayscale already.

    Returns:
        The grayscale image resized to 512x512.
    """
    if image.ndim == 2:
        # Already single-channel; cvtColor would reject it.
        gray = image
    elif image.shape[2] == 4:
        gray = cv2.cvtColor(image, cv2.COLOR_RGBA2GRAY)
    else:
        # Gradio hands over RGB arrays, not OpenCV's native BGR ordering, so
        # the RGB conversion is the one that applies the right channel weights.
        gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    return cv2.resize(gray, (512, 512))
def infer(prompts, negative_prompts, image):
    """Generate images for a prompt, conditioned on a depth-map image.

    Args:
        prompts: Positive text prompt (one string, repeated per sample).
        negative_prompts: Negative text prompt (one string, repeated per
            sample).
        image: Conditioning image array from the Gradio image input.

    Returns:
        A list of PIL images, one per JAX device.

    Raises:
        gr.Error: If no conditioning image was supplied.
    """
    if image is None:
        # Surface a readable message in the UI instead of an OpenCV traceback.
        raise gr.Error("Please provide a depth map image.")

    params["controlnet"] = controlnet_params

    # One sample per device: shard() reshapes every batch to
    # (device_count, -1, ...), so the batch size must match the device count.
    # On a single device this is the same as generating one sample.
    num_samples = jax.device_count()
    rng = jax.random.split(create_key(0), num_samples)

    mask = Image.fromarray(process_mask(image))

    prompt_ids = pipe.prepare_text_inputs([prompts] * num_samples)
    negative_prompt_ids = pipe.prepare_text_inputs([negative_prompts] * num_samples)
    processed_image = pipe.prepare_image_inputs([mask] * num_samples)

    # Replicate the weights and split the inputs across devices.
    p_params = replicate(params)
    prompt_ids = shard(prompt_ids)
    negative_prompt_ids = shard(negative_prompt_ids)
    processed_image = shard(processed_image)

    output = pipe(
        prompt_ids=prompt_ids,
        image=processed_image,
        params=p_params,
        prng_seed=rng,
        num_inference_steps=50,
        neg_prompt_ids=negative_prompt_ids,
        jit=True,
    ).images

    # Collapse the leading device/batch axes into one batch axis, keeping the
    # last three dimensions of each image.
    flat = output.reshape((num_samples,) + output.shape[-3:])
    return pipe.numpy_to_pil(np.asarray(flat))
# Example inputs shown under the demo. Each row pairs a conditioning image
# with a prompt and a negative prompt: the four dog prompts use '0.png' and
# the four house prompts use '2.png'. The image list used to hold nine entries
# against eight prompts, so zip() silently dropped the last image and paired
# the first house prompt with '0.png'.
e_images = ['0.png'] * 4 + ['2.png'] * 4
e_prompts = [
    'a dog in the middle of the road, shadow on the ground,light direction north-east',
    'a dog in the middle of the road, shadow on the ground,light direction north-west',
    'a dog in the middle of the road, shadow on the ground,light direction south-west',
    'a dog in the middle of the road, shadow on the ground,light direction south-east',
    'a red rural house, light temperature 5500, shadow on the ground, light direction north',
    'a red rural house, light temperature 4500, shadow on the ground, light direction east',
    'a red rural house, light temperature 3500, shadow on the ground, light direction south',
    'a red rural house, light temperature 2500, shadow on the ground, light direction west',
]
# Every example shares the same negative prompt.
e_negative_prompts = [
    'monochromatic, unrealistic, bad looking, full of glitches'
] * len(e_prompts)

# Fail loudly rather than let zip() truncate if the lists drift apart again.
if not (len(e_images) == len(e_prompts) == len(e_negative_prompts)):
    raise ValueError("example images, prompts and negative prompts must have equal length")

# Rows follow the order of the UI inputs: prompt, negative prompt, image.
examples = [
    [prompt, negative_prompt, image]
    for image, prompt, negative_prompt in zip(e_images, e_prompts, e_negative_prompts)
]
# Markdown heading rendered at the top of the page. `title` was referenced
# below without ever being defined, which raised NameError at startup.
title = "# ControlLight"

# Page layout: prompt boxes, input image beside the output gallery, a Run
# button, and a row of examples.
with gr.Blocks() as demo:
    gr.Markdown(title)
    prompts = gr.Textbox(label='prompts')
    negative_prompts = gr.Textbox(label='negative_prompts')
    with gr.Row():
        with gr.Column():
            in_image = gr.Image(label="Depth Map Conditioning")
        with gr.Column():
            out_image = gr.Gallery(label="Generated Image")
    with gr.Row():
        btn = gr.Button("Run")
    with gr.Row():
        # cache_examples=True asks Gradio to cache the outputs of `infer`
        # for the examples.
        gr.Examples(
            examples=examples,
            inputs=[prompts, negative_prompts, in_image],
            outputs=out_image,
            fn=infer,
            cache_examples=True,
        )
    btn.click(fn=infer, inputs=[prompts, negative_prompts, in_image], outputs=out_image)

demo.launch()