saburq commited on
Commit
41b7f97
1 Parent(s): 0b04592

return just one image

Browse files
Files changed (2) hide show
  1. app.py +1 -1
  2. inference_code.py +4 -4
app.py CHANGED
@@ -10,7 +10,7 @@ def generate_image_predictions(prompt):
10
  iface = gr.Interface(
11
  fn=generate_image_predictions,
12
  inputs=gr.components.Textbox(label="Enter a text prompt here"),
13
- outputs=[gr.components.Image(label="Output Image") for i in range(4)],
14
  title="Map Diffuser",
15
  description="Generates four images from a given text prompt.",
16
  examples=[["Satellite image of amsterdam with industrial area and highways"], [
 
10
  iface = gr.Interface(
11
  fn=generate_image_predictions,
12
  inputs=gr.components.Textbox(label="Enter a text prompt here"),
13
+ outputs=gr.components.Image(label="Output Image"),
14
  title="Map Diffuser",
15
  description="Generates four images from a given text prompt.",
16
  examples=[["Satellite image of amsterdam with industrial area and highways"], [
inference_code.py CHANGED
@@ -4,11 +4,11 @@ from flax.jax_utils import replicate
4
  from flax.training.common_utils import shard
5
  from diffusers import FlaxStableDiffusionPipeline
6
 
 
 
7
 
8
  # prompt = "create a map with traffic signals, busway and residential buildings, in water color style"
9
  def generate_images(prompt):
10
- model_path = "sabman/map-diffuser-v3"
11
- pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(model_path, dtype=jax.numpy.bfloat16)
12
 
13
  prng_seed = jax.random.PRNGKey(-1)
14
  num_inference_steps = 50
@@ -18,11 +18,11 @@ def generate_images(prompt):
18
  prompt_ids = pipeline.prepare_inputs(prompt)
19
 
20
  # shard inputs and rng
21
- params = replicate(params)
22
  prng_seed = jax.random.split(prng_seed, jax.device_count())
23
  prompt_ids = shard(prompt_ids)
24
 
25
  images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
26
  images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
27
 
28
- return [images[0], images[1], images[2], images[3]]
 
4
  from flax.training.common_utils import shard
5
  from diffusers import FlaxStableDiffusionPipeline
6
 
7
+ model_path = "sabman/map-diffuser-v3"
8
+ pipeline, _params = FlaxStableDiffusionPipeline.from_pretrained(model_path, dtype=jax.numpy.bfloat16)
9
 
10
  # prompt = "create a map with traffic signals, busway and residential buildings, in water color style"
11
  def generate_images(prompt):
 
 
12
 
13
  prng_seed = jax.random.PRNGKey(-1)
14
  num_inference_steps = 50
 
18
  prompt_ids = pipeline.prepare_inputs(prompt)
19
 
20
  # shard inputs and rng
21
+ params = replicate(_params)
22
  prng_seed = jax.random.split(prng_seed, jax.device_count())
23
  prompt_ids = shard(prompt_ids)
24
 
25
  images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
26
  images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
27
 
28
+ return images[0]