multimodalart HF staff commited on
Commit
b50750a
·
verified ·
1 Parent(s): 067f7c8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -6
app.py CHANGED
@@ -26,11 +26,11 @@ USE_TORCH_COMPILE = False
26
  ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD") == "1"
27
  PREVIEW_IMAGES = True
28
 
29
- dtype = torch.float16
30
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
31
  if torch.cuda.is_available():
32
- prior_pipeline = StableCascadePriorPipeline.from_pretrained("diffusers/StableCascade-prior", torch_dtype=torch.bfloat16).to("cuda")
33
- decoder_pipeline = StableCascadeDecoderPipeline.from_pretrained("diffusers/StableCascade-decoder", torch_dtype=torch.bfloat16).to("cuda")
34
 
35
  if ENABLE_CPU_OFFLOAD:
36
  prior_pipeline.enable_model_cpu_offload()
@@ -46,6 +46,7 @@ if torch.cuda.is_available():
46
  if PREVIEW_IMAGES:
47
  previewer = Previewer()
48
  previewer.load_state_dict(torch.load("previewer/previewer_v1_100k.pt")["state_dict"])
 
49
  def callback_prior(i, t, latents):
50
  output = previewer(latents)
51
  output = numpy_to_pil(output.clamp(0, 1).permute(0, 2, 3, 1).cpu().numpy())
@@ -81,9 +82,9 @@ def generate(
81
  num_images_per_prompt: int = 2,
82
  #profile: gr.OAuthProfile | None = None,
83
  ) -> PIL.Image.Image:
84
- prior_pipeline.to("cuda")
85
- decoder_pipeline.to("cuda")
86
- previewer.eval().requires_grad_(False).to(device).to(dtype)
87
  generator = torch.Generator().manual_seed(seed)
88
  prior_output = prior_pipeline(
89
  prompt=prompt,
 
26
  ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD") == "1"
27
  PREVIEW_IMAGES = True
28
 
29
+ dtype = torch.bfloat16
30
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
31
  if torch.cuda.is_available():
32
+ prior_pipeline = StableCascadePriorPipeline.from_pretrained("diffusers/StableCascade-prior", torch_dtype=dtype).to(device)
33
+ decoder_pipeline = StableCascadeDecoderPipeline.from_pretrained("diffusers/StableCascade-decoder", torch_dtype=dtype).to(device)
34
 
35
  if ENABLE_CPU_OFFLOAD:
36
  prior_pipeline.enable_model_cpu_offload()
 
46
  if PREVIEW_IMAGES:
47
  previewer = Previewer()
48
  previewer.load_state_dict(torch.load("previewer/previewer_v1_100k.pt")["state_dict"])
49
+ previewer.eval().requires_grad_(False).to(device).to(dtype)
50
  def callback_prior(i, t, latents):
51
  output = previewer(latents)
52
  output = numpy_to_pil(output.clamp(0, 1).permute(0, 2, 3, 1).cpu().numpy())
 
82
  num_images_per_prompt: int = 2,
83
  #profile: gr.OAuthProfile | None = None,
84
  ) -> PIL.Image.Image:
85
+ #prior_pipeline.to(device)
86
+ #decoder_pipeline.to(device)
87
+ #previewer.eval().requires_grad_(False).to(device).to(dtype)
88
  generator = torch.Generator().manual_seed(seed)
89
  prior_output = prior_pipeline(
90
  prompt=prompt,