File size: 4,443 Bytes
a46c388
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12c6f34
5832738
 
 
 
 
 
 
 
a46c388
5832738
 
 
13f1240
 
 
 
8d1c7e2
 
5832738
 
 
 
 
8d1c7e2
a46c388
 
 
ed800ac
 
a46c388
 
 
 
 
 
 
 
 
12c6f34
a46c388
 
de0c08b
a46c388
 
 
02b3506
a46c388
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import gradio as gr
import jax
import numpy as np
import jax.numpy as jnp
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from PIL import Image
from diffusers import FlaxStableDiffusionControlNetPipeline, FlaxControlNetModel
import cv2

# Load the ControlNet checkpoint (pinned to an exact commit revision) and the
# Stable Diffusion v1-5 Flax pipeline that wraps it; both use bfloat16.
controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
    "Nahrawy/controlnet-VIDIT-FAID",
    revision="615ba4a457b95a0eba813bcc8caf842c03a4f7bd",
    dtype=jnp.bfloat16,
)
pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    revision="flax",
    controlnet=controlnet,
    dtype=jnp.bfloat16,
)

def create_key(seed=0):
    """Build and return a ``jax.random.PRNGKey`` for the integer *seed*."""
    key = jax.random.PRNGKey(seed)
    return key

def process_mask(image, size=512):
    """Convert a conditioning image to a single-channel ``size`` x ``size`` array.

    Args:
        image: Image array as produced by ``gr.Image`` (numpy, RGB channel
            order by default). 2-D grayscale, (H, W, 1) and RGBA inputs are
            accepted as well.
        size: Side length in pixels of the square output. Defaults to 512,
            the value this function previously hard-coded.

    Returns:
        A 2-D array of shape ``(size, size)``.

    Raises:
        ValueError: if ``image`` is ``None`` or has an unsupported shape.
    """
    if image is None:
        raise ValueError("process_mask() received no image")
    image = np.asarray(image)

    channels = image.shape[2] if image.ndim == 3 else None
    if image.ndim == 2:
        # Already single-channel; nothing to convert.
        mask = image
    elif channels == 1:
        # Drop the trailing singleton axis; copy so OpenCV gets a contiguous array.
        mask = np.ascontiguousarray(image[:, :, 0])
    elif channels == 3:
        # gr.Image delivers RGB, not OpenCV's native BGR, so use RGB weights.
        mask = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    elif channels == 4:
        mask = cv2.cvtColor(image, cv2.COLOR_RGBA2GRAY)
    else:
        raise ValueError(f"unsupported image shape {image.shape}")

    # cv2.resize takes dsize as (width, height).
    return cv2.resize(mask, (size, size))



def infer(prompts, negative_prompts, image, seed=0):
    """Generate images from a text prompt conditioned on a depth-map image.

    Args:
        prompts: Positive prompt text (a single string from the UI textbox).
        negative_prompts: Negative prompt text (a single string).
        image: Conditioning image array from ``gr.Image``. ``process_mask``
            turns it into a single-channel map.
        seed: Integer seed for the PRNG key (default 0, matching the previous
            fixed seed).

    Returns:
        A list of PIL images from ``pipe.numpy_to_pil``, one per local JAX
        device.

    Raises:
        gr.Error: if no conditioning image was provided.
    """
    if image is None:
        # Surface a readable message in the UI instead of an opaque cv2 error.
        raise gr.Error("Please provide a depth map conditioning image.")

    params["controlnet"] = controlnet_params

    # flax's shard() reshapes the leading axis to jax.local_device_count() and
    # replicate() copies params to every local device. The batch size and the
    # number of PRNG keys must therefore both equal the local device count. A
    # hard-coded batch of 1 only worked on single-device machines.
    num_samples = jax.local_device_count()
    rng = jax.random.split(create_key(seed), num_samples)

    mask = Image.fromarray(process_mask(image))

    prompt_ids = shard(pipe.prepare_text_inputs([prompts] * num_samples))
    negative_prompt_ids = shard(
        pipe.prepare_text_inputs([negative_prompts] * num_samples)
    )
    processed_image = shard(pipe.prepare_image_inputs([mask] * num_samples))
    p_params = replicate(params)

    output = pipe(
        prompt_ids=prompt_ids,
        image=processed_image,
        params=p_params,
        prng_seed=rng,
        num_inference_steps=50,
        neg_prompt_ids=negative_prompt_ids,
        jit=True,
    ).images

    # Flatten the leading device/batch axes, keeping the trailing three
    # (image) dimensions, so numpy_to_pil sees one image per sample.
    images = np.asarray(output.reshape((num_samples,) + output.shape[-3:]))
    return pipe.numpy_to_pil(images)

# Example table for gr.Examples: each row is [prompt, negative_prompt, image].
# Going by the original pairing, '0.png' is the depth map for the four dog
# prompts and '2.png' the one for the four house prompts (TODO confirm).
# The previous list had five '0.png' entries, which paired the first house
# prompt with the dog map. It also had one surplus image that zip() silently
# dropped.
_NEGATIVE_PROMPT = 'monochromatic, unrealistic, bad looking, full of glitches'

e_images = ['0.png'] * 4 + ['2.png'] * 4
e_prompts = ['a dog in the middle of the road, shadow on the ground,light direction north-east',
             'a dog in the middle of the road, shadow on the ground,light direction north-west',
             'a dog in the middle of the road, shadow on the ground,light direction south-west',
             'a dog in the middle of the road, shadow on the ground,light direction south-east',
             'a red rural house, shadow on the ground, light direction north',
             'a red rural house, shadow on the ground, light direction east',
             'a red rural house, shadow on the ground, light direction south',
             'a red rural house, shadow on the ground, light direction west']
e_negative_prompts = [_NEGATIVE_PROMPT] * len(e_prompts)

# zip() truncates silently, so fail loudly if the columns ever drift apart.
if not len(e_images) == len(e_prompts) == len(e_negative_prompts):
    raise ValueError(
        f"example columns differ in length: {len(e_images)} images, "
        f"{len(e_prompts)} prompts, {len(e_negative_prompts)} negative prompts"
    )

examples = [
    [prompt, negative_prompt, image]
    for image, prompt, negative_prompt in zip(e_images, e_prompts, e_negative_prompts)
]
    
title = " # ControlLight: Light control through ControlNet and Depth Maps conditioning"

# Layout, top to bottom:
#   - title and the two prompt boxes;
#   - the conditioning image beside the output gallery;
#   - the Run button;
#   - the examples table.
with gr.Blocks() as demo:
    gr.Markdown(title)
    prompts = gr.Textbox(label='prompts')
    negative_prompts = gr.Textbox(label='negative_prompts')

    with gr.Row():
        with gr.Column():
            in_image = gr.Image(label="Depth Map Conditioning")
        with gr.Column():
            out_image = gr.Gallery(label="Generated Image")

    with gr.Row():
        btn = gr.Button("Run")

    # The examples table and the Run button feed the same inputs, in
    # infer()'s positional order.
    infer_inputs = [prompts, negative_prompts, in_image]

    with gr.Row():
        gr.Examples(
            examples=examples,
            inputs=infer_inputs,
            outputs=out_image,
            fn=infer,
            cache_examples=True,
        )

    btn.click(fn=infer, inputs=infer_inputs, outputs=out_image)

demo.launch()