Jordan Legg committed
Commit d53c0bb
1 Parent(s): 6af450a

loading chain
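Why the reordering in the diff below matters: per the new comment in the diff, spaces must be imported before anything that can initialize CUDA (torch included), which is also why the device probe moves after the import. A minimal sketch of that required ordering, not the full file:

# Sketch: import order for ZeroGPU Spaces
import spaces  # first, before anything can initialize CUDA
import torch   # safe to import once spaces is loaded

device = "cuda" if torch.cuda.is_available() else "cpu"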
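The deferred-loading half of the change caches the models in module globals and loads them on first use inside the GPU-decorated handler. One caveat: Gradio can serve overlapping requests, and the committed ensure_models_loaded() takes no lock. A hedged sketch of a double-checked-lock variant; the lock is an addition for illustration, not part of this commit, and load_models() is the loader defined in app.py:

import threading

_load_lock = threading.Lock()
vae, pipe = None, None

def ensure_models_loaded():
    global vae, pipe
    if vae is None or pipe is None:          # fast path, no lock taken
        with _load_lock:
            if vae is None or pipe is None:  # re-check under the lock
                vae, pipe = load_models()    # load_models() as defined in app.py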
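For context on the two dtype constants in the diff: the VAE is kept in float32 for numerically stable encode/decode, while the main pipeline runs in bfloat16 to halve weight memory. A sketch of what load_models() plausibly looks like; the SD 1.5 VAE ID appears in the diff, but the FLUX.1-schnell pipeline ID is an assumption inferred from the bfloat16 dtype and the 4-step default:

import torch
from diffusers import AutoencoderKL, DiffusionPipeline

def load_models():
    # VAE in float32 (vae_dtype) for stable image encode/decode
    vae = AutoencoderKL.from_pretrained(
        "runwayml/stable-diffusion-v1-5", subfolder="vae",
        torch_dtype=torch.float32,
    )
    # Main pipeline in bfloat16 (flux_dtype) to save memory
    pipe = DiffusionPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-schnell",  # assumed checkpoint, not in the diff
        torch_dtype=torch.bfloat16,
    )
    return vae, pipe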

Files changed (1)
  1. app.py +17 -3
app.py CHANGED
@@ -1,3 +1,6 @@
+# Import spaces first to avoid CUDA initialization conflicts
+import spaces
+
 import gradio as gr
 import numpy as np
 import random
@@ -5,15 +8,16 @@ import torch
 from PIL import Image
 from torchvision import transforms
 from diffusers import DiffusionPipeline, AutoencoderKL
-import spaces
 
 # Define constants
 flux_dtype = torch.bfloat16
 vae_dtype = torch.float32
-device = "cuda" if torch.cuda.is_available() else "cpu"
 MAX_SEED = np.iinfo(np.int32).max
 MAX_IMAGE_SIZE = 2048
 
+# Move device selection after spaces import
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
 def load_models():
     # Load the initial VAE model for preprocessing in float32
     vae_model_name = "runwayml/stable-diffusion-v1-5"
@@ -28,7 +32,13 @@ def load_models():
 
     return vae, pipe
 
-vae, pipe = load_models()
+# Defer model loading until it's needed
+vae, pipe = None, None
+
+def ensure_models_loaded():
+    global vae, pipe
+    if vae is None or pipe is None:
+        vae, pipe = load_models()
 
 def preprocess_image(image, image_size):
     preprocess = transforms.Compose([
@@ -52,6 +62,8 @@ def encode_image(image, vae):
 
 @spaces.GPU()
 def infer(prompt, init_image=None, seed=42, randomize_seed=False, width=1024, height=1024, num_inference_steps=4, progress=gr.Progress(track_tqdm=True)):
+    ensure_models_loaded()
+
     if randomize_seed:
         seed = random.randint(0, MAX_SEED)
     generator = torch.Generator(device=device).manual_seed(seed)
@@ -113,6 +125,8 @@ def infer(prompt, init_image=None, seed=42, randomize_seed=False, width=1024, height=1024, num_inference_steps=4, progress=gr.Progress(track_tqdm=True)):
         print(f"Error during inference: {e}")
         return fallback_image, seed
 
+# ... (rest of the Gradio interface code remains the same)
+
 # Define example prompts
 examples = [
     "a tiny astronaut hatching from an egg on the moon",