# Hugging Face Spaces page status (captured with this source): Sleeping
import jax
import numpy as np
import torch
from diffusers import DiffusionPipeline
from flax.jax_utils import replicate
from flax.training.common_utils import shard
# Model identifier handed to ``DiffusionPipeline.from_pretrained`` (a Hub repo
# id or local path).
model_path = "sabman/map-diffuser-v3"

# ``from_flax=True`` asks diffusers to load Flax-format weights into this
# pipeline; ``safety_checker=None`` disables output filtering. Previously the
# pipeline was moved to "cuda" unconditionally, which fails on hosts without a
# CUDA device; fall back to CPU there so the module still loads (slowly).
_device = "cuda" if torch.cuda.is_available() else "cpu"
pipeline = DiffusionPipeline.from_pretrained(
    model_path,
    from_flax=True,
    safety_checker=None,
).to(_device)
# Example prompt:
#   "create a map with traffic signals, busway and residential buildings, in water color style"
def generate_images(prompt, *, num_inference_steps=20, width=512, seed=None):
    """Generate a single image for *prompt* with the module-level ``pipeline``.

    Args:
        prompt: Text prompt passed straight to the pipeline.
        num_inference_steps: Number of inference steps (default 20, unchanged
            from the original hard-coded value).
        width: Output width in pixels (default 512). Height is not passed, so
            the pipeline's own default applies.
        seed: Optional integer seed. When given, a ``torch.Generator`` seeded
            with it is passed to the pipeline so output is reproducible. When
            ``None`` (the default), sampling is unseeded, which matches the
            original behaviour.

    Returns:
        The first entry of the pipeline output's ``.images`` list (presumably
        a PIL image; confirm against the diffusers version in use).
    """
    # The original built ``jax.random.PRNGKey(-1)`` here but never passed it
    # to the (PyTorch) pipeline, so it had no effect. Seeding a PyTorch
    # diffusers pipeline goes through ``generator=`` instead.
    generator = None
    if seed is not None:
        generator = torch.Generator(device=pipeline.device).manual_seed(seed)

    result = pipeline(
        prompt,
        width=width,
        num_inference_steps=num_inference_steps,
        num_images_per_prompt=1,
        generator=generator,
    )
    return result.images[0]