File size: 8,432 Bytes
cd1e8dc
628e6c3
cd1e8dc
 
 
 
 
 
 
 
 
 
628e6c3
 
cd1e8dc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0a95bff
cd1e8dc
 
 
 
 
a5e9129
 
 
cd1e8dc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a5e9129
 
 
ccf1a03
cd1e8dc
 
 
 
 
 
 
 
 
 
 
ccf1a03
cd1e8dc
 
 
9702a1f
cd1e8dc
 
9702a1f
cd1e8dc
 
a5e9129
 
cd1e8dc
 
9702a1f
a5e9129
cd1e8dc
 
 
 
7b66f42
cd1e8dc
 
 
 
 
 
 
 
 
788a013
dc72f49
788a013
 
 
 
cd1e8dc
788a013
 
 
 
cd1e8dc
 
68696f0
71f4cfe
 
5461399
cd1e8dc
 
 
 
 
 
 
71f4cfe
 
cd1e8dc
 
 
71f4cfe
 
 
783c45d
cd1e8dc
 
 
783c45d
cd1e8dc
 
 
a5e9129
cd1e8dc
628e6c3
 
 
 
 
 
 
 
 
cd1e8dc
 
628e6c3
 
 
 
 
cd1e8dc
628e6c3
 
 
 
 
 
 
 
 
 
 
 
cd1e8dc
 
 
 
 
 
628e6c3
 
 
 
 
 
 
 
3f2581f
628e6c3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cd1e8dc
628e6c3
cd1e8dc
 
 
628e6c3
 
 
cd1e8dc
628e6c3
 
 
 
 
 
88bd8ea
 
628e6c3
 
a5e9129
 
 
628e6c3
a5e9129
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
# This file is adapted from https://huggingface.co/spaces/diffusers/controlnet-canny/blob/main/app.py
# The original license file is LICENSE.ControlNet in this repo.
from diffusers import FlaxStableDiffusionControlNetPipeline, FlaxControlNetModel, FlaxDPMSolverMultistepScheduler
from transformers import CLIPTokenizer, FlaxCLIPTextModel, set_seed
from flax.training.common_utils import shard
from flax.jax_utils import replicate    
from diffusers.utils import load_image
import jax.numpy as jnp
import jax
import cv2
from PIL import Image
import numpy as np
import gradio as gr

def create_key(seed=0):
    """Return a JAX PRNG key derived from ``seed``.

    Args:
        seed: Seed handed to ``jax.random.PRNGKey``. Defaults to 0.

    Returns:
        The key produced by ``jax.random.PRNGKey(seed)``.
    """
    prng_key = jax.random.PRNGKey(seed)
    return prng_key

def load_controlnet(controlnet_version, controlnet_path="Baptlem/baptlem-controlnet"):
    """Load a Flax ControlNet model together with its parameters.

    Args:
        controlnet_version: Name of the subfolder inside the repository that
            holds the checkpoint variant to load.
        controlnet_path: Repository id (or path) passed to
            ``FlaxControlNetModel.from_pretrained``. Defaults to the
            repository this function previously hard-coded, so existing
            single-argument callers behave exactly as before.

    Returns:
        A ``(controlnet, controlnet_params)`` tuple as returned by
        ``FlaxControlNetModel.from_pretrained``.
    """
    controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
        controlnet_path,
        subfolder=controlnet_version,
        from_flax=True,
        dtype=jnp.float32,
    )
    return controlnet, controlnet_params


def load_sb_pipe(controlnet_version, sb_path="runwayml/stable-diffusion-v1-5"):
    """Assemble the Stable Diffusion + ControlNet pipeline and its parameters.

    Args:
        controlnet_version: Checkpoint variant forwarded to ``load_controlnet``.
        sb_path: Repository id (or path) of the Stable Diffusion weights, used
            for both the scheduler and the pipeline.

    Returns:
        A ``(pipe, params)`` tuple. ``pipe`` has its scheduler replaced by the
        DPM-Solver multistep scheduler loaded here; ``params`` additionally
        carries the ``"controlnet"`` and ``"scheduler"`` entries.
    """
    controlnet_model, controlnet_weights = load_controlnet(controlnet_version)

    dpm_scheduler, dpm_scheduler_state = FlaxDPMSolverMultistepScheduler.from_pretrained(
        sb_path, subfolder="scheduler"
    )

    pipeline, pipeline_params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
        sb_path, controlnet=controlnet_model, revision="flax", dtype=jnp.bfloat16
    )

    # Swap in the DPM-Solver scheduler and register the separately loaded
    # parameter trees under the keys the pipeline call reads them from.
    pipeline.scheduler = dpm_scheduler
    pipeline_params["controlnet"] = controlnet_weights
    pipeline_params["scheduler"] = dpm_scheduler_state
    return pipeline, pipeline_params

    

# Hugging Face repository id of the ControlNet weights. Not referenced by the
# visible code below; load_controlnet carries the same repo id itself.
controlnet_path = "Baptlem/baptlem-controlnet"
# Subfolder (checkpoint variant) passed to load_sb_pipe at import time.
controlnet_version = "coyo-500k"

# Constants
# Hysteresis thresholds handed to cv2.Canny in preprocess_canny.
low_threshold = 100
high_threshold = 200

# Models are loaded once at import time; pipe_inference reads these globals.
pipe, params = load_sb_pipe(controlnet_version)

# Disabled optimisation hooks, presumably left over from the app this file was
# adapted from (see the header comment) — TODO confirm they can be removed.
# pipe.enable_xformers_memory_efficient_attention()
# pipe.enable_model_cpu_offload()
# pipe.enable_attention_slicing()
print("Loaded models...")
def pipe_inference(
    image,
    prompt,
    is_canny=False,
    num_samples=4,
    resolution=128,
    num_inference_steps=50,
    guidance_scale=7.5,
    seed=0,
    negative_prompt="",
):
    """Generate images from a text prompt conditioned on an edge map.

    Args:
        image: Input picture as a NumPy array, or anything ``np.array`` accepts.
        prompt: Text prompt; repeated once per generated sample.
        is_canny: When true, the resized ``image`` is used directly as the
            conditioning image. Otherwise a Canny edge map is computed from it
            with ``preprocess_canny`` first.
        num_samples: Number of images to return.
        resolution: Target size forwarded to ``resize_image``.
        num_inference_steps: Forwarded to the pipeline call.
        guidance_scale: Forwarded to the pipeline call.
        seed: Integer seed for the JAX PRNG key.
        negative_prompt: Negative prompt; repeated once per generated sample.

    Returns:
        The result of ``pipe.numpy_to_pil`` for the first ``num_samples``
        generated images (presumably a list of PIL images).

    Raises:
        ValueError: If ``image`` is ``None`` or ``num_samples`` is below 1.
    """
    if image is None:
        raise ValueError("pipe_inference requires an input image, got None")

    # UI sliders may deliver numbers as floats; list repetition, array shapes
    # and the PRNG seed all need real integers.
    num_samples = int(num_samples)
    resolution = int(resolution)
    num_inference_steps = int(num_inference_steps)
    seed = int(seed)
    if num_samples < 1:
        raise ValueError(f"num_samples must be at least 1, got {num_samples}")

    print("Entered pipe...")
    if not isinstance(image, np.ndarray):
        image = np.array(image)

    processed_image = resize_image(image, resolution)

    if not is_canny:
        # Only the edge map is needed; the echoed input image is discarded.
        _, processed_image = preprocess_canny(processed_image, resolution)

    # shard() splits the leading batch axis evenly across devices, so the
    # batch must be a multiple of the device count. Round up here and drop the
    # surplus images after generation.
    num_devices = jax.device_count()
    batch_size = -(-num_samples // num_devices) * num_devices

    # One PRNG key per device.
    rng = jax.random.split(create_key(seed), num_devices)

    prompt_ids = pipe.prepare_text_inputs([prompt] * batch_size)
    negative_prompt_ids = pipe.prepare_text_inputs([negative_prompt] * batch_size)
    processed_image = pipe.prepare_image_inputs([processed_image] * batch_size)

    # NOTE: the parameters are re-replicated on every call.
    p_params = replicate(params)
    prompt_ids = shard(prompt_ids)
    negative_prompt_ids = shard(negative_prompt_ids)
    processed_image = shard(processed_image)
    print("Inference...")
    output = pipe(
        prompt_ids=prompt_ids,
        image=processed_image,
        params=p_params,
        prng_seed=rng,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        neg_prompt_ids=negative_prompt_ids,
        jit=True,
    ).images
    print("Finished inference...")

    # Merge the leading device/batch axes into one flat batch, then keep only
    # the number of images the caller asked for.
    flat_images = np.asarray(output.reshape((batch_size,) + output.shape[-3:]))
    all_outputs = pipe.numpy_to_pil(flat_images[:num_samples])
    return all_outputs

def resize_image(image, resolution):
    """Resize ``image`` so its shorter side equals ``resolution``.

    The aspect ratio is kept (the longer side is truncated to an integer) and
    nearest-neighbour interpolation is used.

    Args:
        image: NumPy array, or anything ``np.array`` accepts.
        resolution: Target length of the shorter side, in pixels.

    Returns:
        The resized picture wrapped with ``Image.fromarray``.
    """
    pixels = image if isinstance(image, np.ndarray) else np.array(image)
    height, width = pixels.shape[:2]
    aspect = width / height

    # cv2.resize takes the target size as (width, height).
    if aspect > 1:
        target_size = (int(resolution * aspect), resolution)
    elif aspect < 1:
        target_size = (resolution, int(resolution / aspect))
    else:
        target_size = (resolution, resolution)

    scaled = cv2.resize(pixels, target_size, interpolation=cv2.INTER_NEAREST)
    return Image.fromarray(scaled)
    
    
def preprocess_canny(image, resolution=128):
    """Compute a three-channel Canny edge map for ``image``.

    Args:
        image: NumPy array, or anything ``np.array`` accepts.
        resolution: Accepted for interface compatibility; not used here.

    Returns:
        A ``(resized_image, processed_image)`` tuple of ``Image.fromarray``
        results: the input pixels as given, and the edge map with its single
        channel repeated three times.
    """
    pixels = image if isinstance(image, np.ndarray) else np.array(image)

    # Thresholds come from the module-level low_threshold / high_threshold.
    edges = cv2.Canny(pixels, low_threshold, high_threshold)

    # Canny yields one channel; replicate it along a new last axis to get 3.
    edge_channel = edges[:, :, None]
    edges_rgb = np.concatenate([edge_channel] * 3, axis=2)

    return Image.fromarray(pixels), Image.fromarray(edges_rgb)


def create_demo(process, max_images=12, default_num_images=4):
    """Build the Gradio Blocks UI for the Canny ControlNet demo.

    Args:
        process: Callback wired to both the prompt submit event and the run
            button. It receives the input components in this order: image,
            prompt, is-canny flag, number of images, resolution, steps,
            guidance scale, seed, negative prompt.
        max_images: Upper bound of the "Images" slider.
        default_num_images: Initial value of the "Images" slider.

    Returns:
        The constructed ``gr.Blocks`` object (not launched).
    """
    default_negative_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality'

    with gr.Blocks() as demo:
        with gr.Row():
            gr.Markdown('## Control Stable Diffusion with Canny Edge Maps')

        with gr.Row():
            # Left column: inputs and advanced settings.
            with gr.Column():
                image_in = gr.Image(source='upload', type='numpy')
                prompt_box = gr.Textbox(label='Prompt')
                run_btn = gr.Button(label='Run')

                with gr.Accordion('Advanced options', open=False):
                    canny_flag = gr.Checkbox(label='Is canny', value=False)
                    samples_slider = gr.Slider(
                        label='Images',
                        minimum=1,
                        maximum=max_images,
                        value=default_num_images,
                        step=1,
                    )
                    # Disabled controls for the Canny thresholds:
                    # canny_low_threshold = gr.Slider(
                    #     label='Canny low threshold',
                    #     minimum=1, maximum=255, value=100, step=1)
                    # canny_high_threshold = gr.Slider(
                    #     label='Canny high threshold',
                    #     minimum=1, maximum=255, value=200, step=1)
                    # The resolution slider is pinned: minimum == maximum == 128.
                    resolution_slider = gr.Slider(
                        label='Resolution',
                        minimum=128,
                        maximum=128,
                        value=128,
                        step=1,
                    )
                    steps_slider = gr.Slider(
                        label='Steps',
                        minimum=1,
                        maximum=100,
                        value=20,
                        step=1,
                    )
                    guidance_slider = gr.Slider(
                        label='Guidance Scale',
                        minimum=0.1,
                        maximum=30.0,
                        value=7.5,
                        step=0.1,
                    )
                    seed_slider = gr.Slider(
                        label='Seed',
                        minimum=-1,
                        maximum=2147483647,
                        step=1,
                        randomize=True,
                    )
                    negative_box = gr.Textbox(
                        label='Negative Prompt',
                        value=default_negative_prompt,
                    )

            # Right column: generated images.
            with gr.Column():
                gallery = gr.Gallery(
                    label='Output',
                    show_label=False,
                    elem_id='gallery',
                )
                result = gallery.style(grid=2, height='auto')

        # Order matters: it is the order in which values reach `process`.
        inputs = [
            image_in,
            prompt_box,
            canny_flag,
            samples_slider,
            resolution_slider,
            steps_slider,
            guidance_slider,
            seed_slider,
            negative_box,
        ]
        prompt_box.submit(fn=process, inputs=inputs, outputs=result)
        run_btn.click(
            fn=process,
            inputs=inputs,
            outputs=result,
            api_name='canny',
        )

    return demo

if __name__ == '__main__':
    # Build the UI around the inference entry point, then serve it with
    # Gradio's queue() enabled before launching.
    demo = create_demo(pipe_inference)
    demo.queue().launch()