sabman commited on
Commit
fe0eaf3
1 Parent(s): ae143ff

Update inference_code.py

Browse files
Files changed (1) hide show
  1. inference_code.py +13 -9
inference_code.py CHANGED
@@ -16,16 +16,20 @@ def generate_images(prompt):
16
  prng_seed = jax.random.PRNGKey(-1)
17
  num_inference_steps = 20
18
 
19
- num_samples = jax.device_count()
20
- prompt = num_samples * [prompt]
21
- prompt_ids = pipeline.prepare_inputs(prompt)
22
 
23
- # shard inputs and rng
24
- params = replicate(_params)
25
- prng_seed = jax.random.split(prng_seed, jax.device_count())
26
- prompt_ids = shard(prompt_ids)
27
 
28
- images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
29
- images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
 
 
 
 
 
30
 
31
  return images[0]
 
16
  prng_seed = jax.random.PRNGKey(-1)
17
  num_inference_steps = 20
18
 
19
+ images = pipeline(prompt, width=512, num_inference_steps=20, num_images_per_prompt=1).images
20
+ images = pipeline.numpy_to_pil(np.asarray(images.reshape((1,) + images.shape[-3:])))
 
21
 
22
+
23
+ # num_samples = jax.device_count()
24
+ # prompt = num_samples * [prompt]
25
+ # prompt_ids = pipeline.prepare_inputs(prompt)
26
 
27
+ # # shard inputs and rng
28
+ # params = replicate(_params)
29
+ # prng_seed = jax.random.split(prng_seed, jax.device_count())
30
+ # prompt_ids = shard(prompt_ids)
31
+
32
+ # images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
33
+ # images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
34
 
35
  return images[0]