# ControlLight / app.py
# Author: Nahrawy
# Last commit: 8d1c7e2 ("Update app.py"), 3.42 kB
import gradio as gr
import jax
import numpy as np
import jax.numpy as jnp
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from PIL import Image
from diffusers import FlaxStableDiffusionControlNetPipeline, FlaxControlNetModel
import cv2
# Load the ControlNet weights (pinned to a specific Hub commit) and the
# Stable Diffusion v1-5 Flax pipeline, both in bfloat16. `params` holds the
# pipeline weights; the ControlNet weights are kept separately in
# `controlnet_params`.
_CONTROLNET_REPO = "Nahrawy/controlnet-VIDIT-FAID"
_CONTROLNET_REVISION = "615ba4a457b95a0eba813bcc8caf842c03a4f7bd"
_SD_REPO = "runwayml/stable-diffusion-v1-5"

controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
    _CONTROLNET_REPO,
    revision=_CONTROLNET_REVISION,
    dtype=jnp.bfloat16,
)
pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
    _SD_REPO,
    controlnet=controlnet,
    revision="flax",
    dtype=jnp.bfloat16,
)
def create_key(seed=0):
    """Build a JAX PRNG key from an integer *seed* (0 when omitted)."""
    key = jax.random.PRNGKey(seed)
    return key
def process_mask(image, size=(512, 512)):
    """Convert a conditioning image into a single-channel array of *size*.

    Args:
        image: Image as an array. Single-channel ``(H, W)`` / ``(H, W, 1)``
            input is used as-is; 3- or 4-channel input is treated as RGB /
            RGBA (the channel order Gradio's ``gr.Image`` delivers) and
            converted to grayscale.
        size: Target ``(width, height)`` handed to ``cv2.resize``.
            Defaults to 512x512.

    Returns:
        The grayscale array resized to *size*.

    Raises:
        ValueError: If *image* is None (nothing was uploaded).
    """
    if image is None:
        raise ValueError("No conditioning image was provided.")
    image = np.asarray(image)
    if image.ndim == 3 and image.shape[2] == 1:
        image = image[:, :, 0]
    if image.ndim == 3:
        # Inputs arrive in RGB(A) order; converting with the BGR code would
        # swap the red/blue luminance weights.
        code = cv2.COLOR_RGBA2GRAY if image.shape[2] == 4 else cv2.COLOR_RGB2GRAY
        image = cv2.cvtColor(image, code)
    return cv2.resize(image, tuple(size))
def infer(prompts, negative_prompts, image, seed=0, num_inference_steps=50):
    """Generate images from a prompt pair and a conditioning image.

    Args:
        prompts: Positive text prompt.
        negative_prompts: Negative text prompt.
        image: Conditioning image array (from the Gradio image input). It is
            reduced to a 512x512 grayscale mask by ``process_mask``.
        seed: Integer seed for the JAX PRNG. Defaults to 0 (reproducible).
        num_inference_steps: Denoising steps passed to the pipeline.
            Defaults to 50.

    Returns:
        List of PIL images, one per JAX device.

    Raises:
        gr.Error: If no conditioning image was supplied.
    """
    if image is None:
        raise gr.Error("Please provide a conditioning image.")

    params["controlnet"] = controlnet_params

    # shard() splits the leading batch axis across devices, so the batch size
    # must be a multiple of the device count: use one sample per device.
    num_samples = jax.device_count()
    rng = jax.random.split(create_key(int(seed)), num_samples)

    mask = Image.fromarray(process_mask(image))
    prompt_ids = shard(pipe.prepare_text_inputs([prompts] * num_samples))
    negative_prompt_ids = shard(pipe.prepare_text_inputs([negative_prompts] * num_samples))
    processed_image = shard(pipe.prepare_image_inputs([mask] * num_samples))

    # NOTE(review): replicate() copies the full weights to every device on
    # each request; caching the replicated params would avoid that cost.
    p_params = replicate(params)

    output = pipe(
        prompt_ids=prompt_ids,
        image=processed_image,
        params=p_params,
        prng_seed=rng,
        num_inference_steps=num_inference_steps,
        neg_prompt_ids=negative_prompt_ids,
        jit=True,
    ).images
    # Collapse the (devices, per-device batch) leading axes into one batch axis.
    output = np.asarray(output.reshape((num_samples,) + output.shape[-3:]))
    return pipe.numpy_to_pil(output)
# Example rows for the demo: each row is [prompt, negative prompt, image path].
# The image paths are relative, so they are resolved against the working
# directory.
e_images = ['0.png', '1.png', '2.png']
e_prompts = [
    'a dog in the middle of the road, shadow on the ground,light direction north-east',
    'a skyscraper in the middle of an intersection, shadow on the ground, light direction east',
    'a red rural house, light temperature 5500, shadow on the ground, light direction south-west',
]
_NEGATIVE_PROMPT = 'monochromatic, unrealistic, bad looking, full of glitches'
e_negative_prompts = [_NEGATIVE_PROMPT] * 3

examples = [
    [prompt, negative_prompt, image]
    for image, prompt, negative_prompt in zip(e_images, e_prompts, e_negative_prompts)
]
# Markdown heading rendered at the top of the demo page.
title = "# ControlLight"

with gr.Blocks() as demo:
    gr.Markdown(title)
    prompts = gr.Textbox(label='prompts')
    negative_prompts = gr.Textbox(label='negative_prompts')
    with gr.Row():
        with gr.Column():
            in_image = gr.Image(label="Depth Map Conditioning")
        with gr.Column():
            out_image = gr.Gallery(label="Generated Image")
    with gr.Row():
        btn = gr.Button("Run")
    with gr.Row():
        # cache_examples=True runs `infer` on every example when the app starts.
        gr.Examples(
            examples=examples,
            inputs=[prompts, negative_prompts, in_image],
            outputs=out_image,
            fn=infer,
            cache_examples=True,
        )
    # Event listeners must be registered inside the Blocks context.
    btn.click(fn=infer, inputs=[prompts, negative_prompts, in_image], outputs=out_image)

demo.launch()