# Spaces: Sleeping
import jax | |
import numpy as np | |
from flax.jax_utils import replicate | |
from flax.training.common_utils import shard | |
from diffusers import FlaxStableDiffusionPipeline | |
# Hub repository id of the fine-tuned Stable Diffusion checkpoint to serve.
model_path = "sabman/map-diffuser-v3"

# Load the Flax pipeline once, at import time. `from_pretrained` hands back the
# pipeline object and its parameter pytree as two separate values; the params
# are kept private here and consumed by `generate_images`.
# NOTE(review): in diffusers' Flax API `dtype` is documented as the computation
# dtype — confirm whether the stored weights themselves stay float32.
pipeline, _params = FlaxStableDiffusionPipeline.from_pretrained(
    model_path,
    dtype=jax.numpy.bfloat16,
)
# prompt = "create a map with traffic signals, busway and residential buildings, in water color style" | |
# Device-replicated copy of `_params`, filled lazily on first use by
# `_get_replicated_params()` so the weights are replicated only once.
_replicated_params = None


def _get_replicated_params():
    """Return `_params` replicated across local devices, computing it only once."""
    global _replicated_params
    if _replicated_params is None:
        # `replicate` copies every weight onto each local device; redoing that
        # on every request (as the original code did) is pure overhead.
        _replicated_params = replicate(_params)
    return _replicated_params


def generate_images(prompt, seed=-1, num_inference_steps=50):
    """Generate an image for ``prompt`` with the module-level Flax pipeline.

    The prompt is duplicated once per JAX device, the inputs and RNG keys are
    sharded across devices, and the pipeline is run with ``jit=True``. Only the
    first resulting image is returned.

    Args:
        prompt: Text prompt to condition the generation on.
        seed: Integer seed passed to ``jax.random.PRNGKey``. The default (-1)
            matches the previously hard-coded value, so a given prompt yields
            the same image on every call unless a different seed is supplied.
        num_inference_steps: Number of denoising steps (default 50, as before).

    Returns:
        The first image produced by ``pipeline.numpy_to_pil`` (presumably a
        ``PIL.Image.Image`` — confirm against the diffusers version in use).
    """
    num_samples = jax.device_count()
    # One copy of the prompt per device so every device gets a batch element.
    prompt_ids = pipeline.prepare_inputs(num_samples * [prompt])

    # One RNG key per device, derived from the single user-facing seed.
    prng_keys = jax.random.split(jax.random.PRNGKey(seed), num_samples)
    # shard() splits the leading batch axis into (num_devices, per_device, ...).
    prompt_ids = shard(prompt_ids)

    images = pipeline(
        prompt_ids,
        _get_replicated_params(),
        prng_keys,
        num_inference_steps,
        jit=True,
    ).images
    # Merge the per-device leading axes back into one batch axis, keeping the
    # trailing three (image) dimensions — presumably (H, W, C).
    images = images.reshape((num_samples,) + images.shape[-3:])
    return pipeline.numpy_to_pil(np.asarray(images))[0]