fallenshock committed on
Commit 02b6647
1 Parent(s): 53c1364

Update app.py


changed model loading

Files changed (1)
  1. app.py +18 -16
app.py CHANGED
@@ -18,7 +18,10 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
 # device = "cpu"
 # model_type = 'SD3'
 
-# pipe = StableDiffusion3Pipeline.from_pretrained("stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16)
+pipe_sd3 = StableDiffusion3Pipeline.from_pretrained("stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16, token=os.getenv('HF_ACCESS_TOK'))
+pipe_flux = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.float16, token=os.getenv('HF_ACCESS_TOK'))
+
+
 # scheduler = pipe.scheduler
 # pipe = pipe.to(device)
 loaded_model = 'None'
@@ -63,7 +66,7 @@ def get_examples():
     return case
 
 
-@spaces.GPU(duration=95)
+@spaces.GPU(duration=60)
 def FlowEditRun(
     image_src: str,
     model_type: str,
@@ -84,22 +87,21 @@ def FlowEditRun(
     if not len(tar_prompt):
         raise gr.Error("target prompt cannot be empty")
 
-    global pipe
-    global scheduler
-    global loaded_model
+    # global pipe_sd3
+    # global scheduler
+    # global loaded_model
 
     # reload model only if different from the loaded model
-    if loaded_model != model_type:
-
-        if model_type == 'FLUX':
-            # pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.float16)
-            pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.float16, token=os.getenv('HF_ACCESS_TOK'))
-            loaded_model = 'FLUX'
-        elif model_type == 'SD3':
-            pipe = StableDiffusion3Pipeline.from_pretrained("stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16, token=os.getenv('HF_ACCESS_TOK'))
-            loaded_model = 'SD3'
-        else:
-            raise NotImplementedError(f"Model type {model_type} not implemented")
+    # if loaded_model != model_type:
+
+    if model_type == 'FLUX':
+        # pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.float16, token=os.getenv('HF_ACCESS_TOK'))
+        pipe = pipe_flux.clone()  # still on CPU
+    elif model_type == 'SD3':
+        # pipe = StableDiffusion3Pipeline.from_pretrained("stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16, token=os.getenv('HF_ACCESS_TOK'))
+        pipe = pipe_sd3.clone()  # still on CPU
+    else:
+        raise NotImplementedError(f"Model type {model_type} not implemented")
 
     scheduler = pipe.scheduler
     pipe = pipe.to(device)
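
Taken together, the change preloads both pipelines on CPU at import time and lets FlowEditRun pick one and move it to the GPU inside the @spaces.GPU window, instead of calling from_pretrained on every request. Below is a minimal, self-contained sketch of that pattern, not the exact app.py code: the handler name run, its arguments, and the final pipe(prompt) call are placeholders, and it reuses the preloaded pipeline directly rather than going through the clone() call used in app.py.

# Sketch of the load-once / move-per-request pattern, assuming a ZeroGPU Space.
# Model IDs, dtype, and the HF_ACCESS_TOK variable mirror app.py above.
import os

import spaces
import torch
from diffusers import FluxPipeline, StableDiffusion3Pipeline

device = "cuda" if torch.cuda.is_available() else "cpu"

# Load both pipelines once at import time; they stay on CPU until a request arrives.
pipe_sd3 = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    torch_dtype=torch.float16,
    token=os.getenv("HF_ACCESS_TOK"),
)
pipe_flux = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.float16,
    token=os.getenv("HF_ACCESS_TOK"),
)


@spaces.GPU(duration=60)  # the GPU is only requested for the length of one call
def run(model_type: str, prompt: str):
    # Select the preloaded pipeline instead of calling from_pretrained per request.
    if model_type == "FLUX":
        pipe = pipe_flux
    elif model_type == "SD3":
        pipe = pipe_sd3
    else:
        raise NotImplementedError(f"Model type {model_type} not implemented")

    pipe = pipe.to(device)  # move to the GPU only inside the @spaces.GPU window
    # Placeholder generation call; app.py runs its FlowEdit logic here instead.
    return pipe(prompt).images[0]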