Learner commited on
Commit
669cb62
1 Parent(s): a260e25

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -11
app.py CHANGED
@@ -1,9 +1,5 @@
1
  import gradio as gr
2
-
3
- import requests
4
  from PIL import Image
5
- from pathlib import Path
6
- from io import BytesIO
7
 
8
  # Diffusers
9
  from diffusers import (
@@ -11,18 +7,14 @@ from diffusers import (
11
  FlaxStableDiffusionControlNetPipeline
12
  )
13
  from diffusers.utils import load_image
14
-
15
- # Pytorch
16
  import torch
17
-
18
  # Numpy
19
  import numpy as np
20
-
21
  # Jax
22
  import jax
23
  import jax.numpy as jnp
24
  from jax import pmap
25
-
26
  # Flax
27
  import flax
28
  from flax.jax_utils import replicate
@@ -53,11 +45,11 @@ def infer(prompts, negative_prompts, image):
53
  num_samples = 1 # jax.device_count()
54
  rng = create_key(0)
55
  rng = jax.random.split(rng, jax.device_count())
56
- #battlemap_image = Image.open(image)
57
 
58
  prompt_ids = pipe.prepare_text_inputs([prompts] * num_samples)
59
  negative_prompt_ids = pipe.prepare_text_inputs([negative_prompts] * num_samples)
60
- processed_image = pipe.prepare_image_inputs([image] * num_samples) #battlemap_image
61
 
62
  p_params = replicate(params)
63
  prompt_ids = shard(prompt_ids)
 
1
  import gradio as gr
 
 
2
  from PIL import Image
 
 
3
 
4
  # Diffusers
5
  from diffusers import (
 
7
  FlaxStableDiffusionControlNetPipeline
8
  )
9
  from diffusers.utils import load_image
10
+ # PyTorch
 
11
  import torch
 
12
  # Numpy
13
  import numpy as np
 
14
  # Jax
15
  import jax
16
  import jax.numpy as jnp
17
  from jax import pmap
 
18
  # Flax
19
  import flax
20
  from flax.jax_utils import replicate
 
45
  num_samples = 1 # jax.device_count()
46
  rng = create_key(0)
47
  rng = jax.random.split(rng, jax.device_count())
48
+ battlemap_image = Image.open(image).copy()
49
 
50
  prompt_ids = pipe.prepare_text_inputs([prompts] * num_samples)
51
  negative_prompt_ids = pipe.prepare_text_inputs([negative_prompts] * num_samples)
52
+ processed_image = pipe.prepare_image_inputs([battlemap_image] * num_samples) #battlemap_image
53
 
54
  p_params = replicate(params)
55
  prompt_ids = shard(prompt_ids)