import gradio as gr
import jax
from PIL import Image
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from diffusers import FlaxControlNetModel, FlaxStableDiffusionControlNetPipeline
import jax.numpy as jnp
import numpy as np


title = "🧨 ControlNet on Segment Anything 🤗"
description = "This is a demo of ControlNet conditioned on Segment Anything"

examples = [["a modern main room of a house", "low quality", "condition_image_1.png", 50, 4]]

controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
    "mfidabel/controlnet-segment-anything", dtype=jnp.float32
)

pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=controlnet, revision="flax", dtype=jnp.float32
)

# Add the ControlNet params to the pipeline params, then replicate them onto every device for pmap
params["controlnet"] = controlnet_params
p_params = replicate(params)
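

# Sketch (hypothetical helpers, not called anywhere in this app) of the array
# shapes that `replicate` and `shard` produce for the pmapped pipeline:
# `replicate` gives every device a full copy of the params (new leading axis of
# size num_devices), while `shard` splits a batch of num_devices * k examples
# into shape (num_devices, k, ...). These numpy-only versions only illustrate
# the shapes; the real flax helpers also place the data on the devices.
# Assumption: the batch size divides evenly by num_devices.
def _replicate_like(array, num_devices):
    import numpy as np  # local import keeps this sketch self-contained

    return np.stack([np.asarray(array)] * num_devices)


def _shard_like(batch, num_devices):
    import numpy as np  # local import keeps this sketch self-contained

    batch = np.asarray(batch)
    if batch.shape[0] % num_devices != 0:
        raise ValueError("batch size must be divisible by num_devices")
    # Insert a leading device axis: (N, ...) -> (num_devices, N // num_devices, ...)
    return batch.reshape((num_devices, -1) + batch.shape[1:])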


# Inference Function
def infer(prompts, negative_prompts, image, num_inference_steps, seed):
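    # Gradio "number" inputs arrive as floats; the PRNG seed and step count must be ints
    rng = jax.random.PRNGKey(int(seed))
    num_inference_steps = int(num_inference_steps)
    # The Gradio image input is a numpy array; convert() guarantees a 3-channel RGB image
    image = Image.fromarray(image).convert("RGB")
    # One sample per device, each with its own PRNG key
    num_samples = jax.device_count()
    p_rng = jax.random.split(rng, num_samples)

    # Tokenize the prompts and preprocess the condition image, one copy per sample
    prompt_ids = pipe.prepare_text_inputs([prompts] * num_samples)
    negative_prompt_ids = pipe.prepare_text_inputs([negative_prompts] * num_samples)
    processed_image = pipe.prepare_image_inputs([image] * num_samples)

    # Split each batch along a new leading device axis for pmap
    prompt_ids = shard(prompt_ids)
    negative_prompt_ids = shard(negative_prompt_ids)
    processed_image = shard(processed_image)
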
    output = pipe(
        prompt_ids=prompt_ids,
        image=processed_image,
        params=p_params,
        prng_seed=p_rng,
        num_inference_steps=num_inference_steps,
        neg_prompt_ids=negative_prompt_ids,
        jit=True,
    ).images

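    # `output` has shape (num_devices, batch_per_device, H, W, 3) with floats in [0, 1];
    # each device holds one image, which is converted to uint8 for the gallery
    final_image = [(x[0] * 255).round().astype(np.uint8) for x in output]
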
    return final_image
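

# Sketch (hypothetical helper, not called by `infer`): assuming the pipeline
# output layout (num_devices, batch_per_device, H, W, 3) with float values in
# [0, 1], this variant flattens both leading axes so that no image is dropped
# if a device ever receives more than one sample, and clips values to [0, 1]
# before scaling to uint8.
def _to_uint8_gallery(images):
    import numpy as np  # local import keeps this sketch self-contained

    images = np.asarray(images)
    # (num_devices, batch_per_device, H, W, 3) -> (num_devices * batch_per_device, H, W, 3)
    flat = images.reshape((-1,) + images.shape[2:])
    return [(np.clip(img, 0.0, 1.0) * 255).round().astype(np.uint8) for img in flat]
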
    

# Build and launch the Gradio UI; the gallery layout options go to the constructor
# because `.style()` is deprecated in later Gradio 3.x releases and removed in 4.x
gr.Interface(
    fn=infer,
    inputs=["text", "text", "image", "number", "number"],
    outputs=gr.Gallery(
        label="Generated images",
        show_label=False,
        elem_id="gallery",
        columns=2,
        rows=2,
        object_fit="contain",
        height="auto",
        preview=True,
    ),
    title=title,
    description=description,
    examples=examples,
).launch()