# Hugging Face Space page header captured with this source; status shown: "Build error".
import cv2
import gradio as gr
import jax
import jax.numpy as jnp
import numpy as np
from diffusers import FlaxStableDiffusionControlNetPipeline, FlaxControlNetModel
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from PIL import Image
# Load the pretrained ControlNet (pinned to a fixed revision) and the Flax
# revision of Stable Diffusion v1-5. Both calls return the model/pipeline
# object together with its parameters, and both request bfloat16.
controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
    "Nahrawy/controlnet-VIDIT-FAID",
    revision="615ba4a457b95a0eba813bcc8caf842c03a4f7bd",
    dtype=jnp.bfloat16,
)
pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    revision="flax",
    controlnet=controlnet,
    dtype=jnp.bfloat16,
)
def create_key(seed=0):
    """Return a JAX PRNG key built from ``seed`` (defaults to 0)."""
    prng_key = jax.random.PRNGKey(seed)
    return prng_key
def process_mask(image):
    """Convert the conditioning image to a 512x512 single-channel array.

    Args:
        image: Array-like image as delivered by the Gradio ``Image`` input.
            Gradio hands over RGB (not BGR) channel order, so the RGB
            conversion codes are used. Two-dimensional (already grayscale)
            and four-channel (RGBA) inputs are accepted as well.

    Returns:
        The grayscale image resized to 512x512.

    Raises:
        ValueError: If ``image`` is ``None`` (no image was supplied).
    """
    if image is None:
        raise ValueError("process_mask() needs an image, got None")

    pixels = np.asarray(image)
    if pixels.ndim == 2:
        # Already single-channel: cvtColor would reject it, so pass through.
        gray = pixels
    elif pixels.shape[-1] == 4:
        gray = cv2.cvtColor(pixels, cv2.COLOR_RGBA2GRAY)
    else:
        # RGB2GRAY rather than BGR2GRAY: the latter swaps the red and blue
        # luminance weights for RGB-ordered input.
        gray = cv2.cvtColor(pixels, cv2.COLOR_RGB2GRAY)
    return cv2.resize(gray, (512, 512))
def infer(prompts, negative_prompts, image):
    """Generate images from a prompt pair and a conditioning image.

    Args:
        prompts: Prompt text; repeated once per generated sample.
        negative_prompts: Negative prompt text; repeated once per sample.
        image: Conditioning image from the Gradio ``Image`` input. It is
            reduced to a 512x512 grayscale mask by ``process_mask``.

    Returns:
        A list of PIL images, one per local JAX device.

    Raises:
        gr.Error: If no conditioning image was provided.
    """
    if image is None:
        # Surface a readable message in the UI instead of an opaque failure
        # inside the image preprocessing.
        raise gr.Error("Please provide a conditioning image before running.")

    params["controlnet"] = controlnet_params

    # One sample per device. `shard` reshapes the batch to
    # (device_count, -1, ...), so a fixed batch of 1 only works on a
    # single-device host; tying the batch to the device count works on both.
    num_samples = jax.device_count()

    # Fixed seed: identical inputs always produce identical outputs.
    rng = jax.random.split(create_key(0), num_samples)

    mask = Image.fromarray(process_mask(image))

    prompt_ids = pipe.prepare_text_inputs([prompts] * num_samples)
    negative_prompt_ids = pipe.prepare_text_inputs([negative_prompts] * num_samples)
    processed_image = pipe.prepare_image_inputs([mask] * num_samples)

    # Replicate the parameters and split the inputs across devices, as
    # required for the jitted (jit=True) pipeline call below.
    p_params = replicate(params)
    prompt_ids = shard(prompt_ids)
    negative_prompt_ids = shard(negative_prompt_ids)
    processed_image = shard(processed_image)

    output = pipe(
        prompt_ids=prompt_ids,
        image=processed_image,
        params=p_params,
        prng_seed=rng,
        num_inference_steps=50,
        neg_prompt_ids=negative_prompt_ids,
        jit=True,
    ).images

    # Collapse the leading device/batch axes into one flat batch of images.
    flat_images = output.reshape((num_samples,) + output.shape[-3:])
    return pipe.numpy_to_pil(np.asarray(flat_images))
# Example inputs for the demo. The three lists are parallel: entry i of each
# list belongs to the same example.
e_images = [
    '0.png',
    '1.png',
    '2.png',
]
e_prompts = [
    'a dog in the middle of the road, shadow on the ground,light direction north-east',
    'a skyscraper in the middle of an intersection, shadow on the ground, light direction east',
    'a red rural house, light temperature 5500, shadow on the ground, light direction south-west',
]
# Every example uses the same negative prompt.
e_negative_prompts = [
    'monochromatic, unrealistic, bad looking, full of glitches',
] * len(e_prompts)

# Rows are ordered [prompt, negative prompt, image].
examples = [
    [prompt, negative_prompt, image]
    for image, prompt, negative_prompt in zip(e_images, e_prompts, e_negative_prompts)
]
# Heading rendered at the top of the page. No definition of `title` is present
# in the visible file, so one is supplied here to avoid a NameError at startup.
title = "# ControlLight: light control through ControlNet and depth-map conditioning"

# Page layout: text inputs on top, conditioning image next to the result
# gallery, then the run button and the examples.
with gr.Blocks() as demo:
    gr.Markdown(title)
    prompts = gr.Textbox(label='prompts')
    negative_prompts = gr.Textbox(label='negative_prompts')
    with gr.Row():
        with gr.Column():
            in_image = gr.Image(label="Depth Map Conditioning")
        with gr.Column():
            out_image = gr.Gallery(label="Generated Image")
    with gr.Row():
        btn = gr.Button("Run")
    with gr.Row():
        # cache_examples=True is passed together with fn=infer so example
        # outputs come from the same function the button uses.
        gr.Examples(
            examples=examples,
            inputs=[prompts, negative_prompts, in_image],
            outputs=out_image,
            fn=infer,
            cache_examples=True,
        )
    btn.click(fn=infer, inputs=[prompts, negative_prompts, in_image], outputs=out_image)

demo.launch()