Kabatubare committed
Commit eaaecdb · verified · 1 Parent(s): 3339520

Update app.py

Files changed (1)
  1. app.py +20 -25
app.py CHANGED
@@ -13,41 +13,39 @@ SAFETY_CHECKER = os.environ.get("SAFETY_CHECKER", "0") == "1"
 base = "stabilityai/stable-diffusion-xl-base-1.0"
 repo = "ByteDance/SDXL-Lightning"
 checkpoints = {
-    "Warp 1" : ["sdxl_lightning_1step_unet_x0.safetensors", 1],
-    "Warp 2" : ["sdxl_lightning_2step_unet.safetensors", 2],
-    "Warp 4" : ["sdxl_lightning_4step_unet.safetensors", 4],
-    "Warp 8" : ["sdxl_lightning_8step_unet.safetensors", 8],
+    "Warp 1": ["sdxl_lightning_1step_unet_x0.safetensors", 1],
+    "Warp 2": ["sdxl_lightning_2step_unet.safetensors", 2],
+    "Warp 4": ["sdxl_lightning_4step_unet.safetensors", 4],
+    "Warp 8": ["sdxl_lightning_8step_unet.safetensors", 8],
 }
 loaded = None
 
-
 # Ensure model and scheduler are initialized in GPU-enabled function
 if torch.cuda.is_available():
     pipe = StableDiffusionXLPipeline.from_pretrained(base, torch_dtype=torch.float16, variant="fp16").to("cuda")
 
 if SAFETY_CHECKER:
     from safety_checker import StableDiffusionSafetyChecker
-    from transformers import CLIPFeatureExtractor
+    from transformers import CLIPFeatureExtractor, CLIPProcessor
 
     safety_checker = StableDiffusionSafetyChecker.from_pretrained(
         "CompVis/stable-diffusion-safety-checker"
     ).to("cuda")
-    feature_extractor = CLIPFeatureExtractor.from_pretrained(
+    # Updated to use CLIPProcessor as CLIPFeatureExtractor is deprecated
+    feature_extractor = CLIPProcessor.from_pretrained(
         "openai/clip-vit-base-patch32"
     )
 
-    def check_nsfw_images(
-        images: list[Image.Image],
-    ) -> tuple[list[Image.Image], list[bool]]:
+    def check_nsfw_images(images: list[Image.Image]) -> tuple[list[Image.Image], list[bool]]:
         safety_checker_input = feature_extractor(images, return_tensors="pt").to("cuda")
         has_nsfw_concepts = safety_checker(
-            images=[images],
-            clip_input=safety_checker_input.pixel_values.to("cuda")
+            images=images,
+            clip_input=safety_checker_input.pixel_values.to("cuda"),
         )
 
         return images, has_nsfw_concepts
 
-# Function
+# Function
 @spaces.GPU(enable_queue=True)
 def generate_image(prompt, ckpt):
     global loaded
@@ -57,10 +55,10 @@ def generate_image(prompt, ckpt):
     num_inference_steps = checkpoints[ckpt][1]
 
     if loaded != num_inference_steps:
-        pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing", prediction_type="sample" if num_inference_steps==1 else "epsilon")
+        pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing", prediction_type="sample" if num_inference_steps == 1 else "epsilon")
         pipe.unet.load_state_dict(load_file(hf_hub_download(repo, checkpoint), device="cuda"))
         loaded = num_inference_steps
-    
+
     results = pipe(prompt, num_inference_steps=num_inference_steps, guidance_scale=0)
 
     if SAFETY_CHECKER:
@@ -71,8 +69,11 @@ def generate_image(prompt, ckpt):
         return images[0]
     return results.images[0]
 
-
-# Update the dropdown choices to match the corrected keys in the 'checkpoints' dictionary
+# Gradio Interface
+description = """
+Welcome aboard the Starship SDXL Enterprise! Our mission: to explore strange new AI generations, to seek out new visual frontiers and computational boundaries, to boldly generate images like no one has seen before. Utilizing the cutting-edge SDXL-Lightning model, we're at the forefront of text-to-image technology, ready to transform your imaginative prompts into visual spectacles. Whether you're navigating the uncharted territories of outer space or delving into the realms of fantasy, your adventure begins now. Model powered by the pioneering intellects at ByteDance. Journey safely through the stars!
+"""
+
 with gr.Blocks(css="style.css") as demo:
     gr.HTML("<h1><center>🌌 Starfleet Command: Text-to-Image Warp Drive - SDXL-Lightning ⚡</center></h1>")
     gr.Markdown(description)
@@ -83,13 +84,7 @@ with gr.Blocks(css="style.css") as demo:
     submit = gr.Button(scale=1, variant='primary')
    img = gr.Image(label='The Universe, As You Envision It')
 
-    prompt.submit(fn=generate_image,
-                  inputs=[prompt, ckpt],
-                  outputs=img,
-                  )
-    submit.click(fn=generate_image,
-                 inputs=[prompt, ckpt],
-                 outputs=img,
-                 )
-
+    prompt.submit(fn=generate_image, inputs=[prompt, ckpt], outputs=img)
+    submit.click(fn=generate_image, inputs=[prompt, ckpt], outputs=img)
+
 demo.queue().launch()
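A note on the scheduler hunk: the settings match the ByteDance/SDXL-Lightning model card, which calls for trailing timestep spacing on all distilled checkpoints and prediction_type="sample" only for the 1-step x0 checkpoint. A minimal standalone sketch of that pairing, shown here with the 4-step checkpoint (names taken from the diff; assumes a CUDA device):

# Minimal sketch: load one SDXL-Lightning UNet into the base SDXL pipeline,
# mirroring what generate_image() does for the "Warp 4" entry.
import torch
from diffusers import StableDiffusionXLPipeline, EulerDiscreteScheduler
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

base = "stabilityai/stable-diffusion-xl-base-1.0"
repo = "ByteDance/SDXL-Lightning"
ckpt, steps = "sdxl_lightning_4step_unet.safetensors", 4

pipe = StableDiffusionXLPipeline.from_pretrained(
    base, torch_dtype=torch.float16, variant="fp16"
).to("cuda")
pipe.unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device="cuda"))

# Trailing spacing is required for the distilled checkpoints; only the
# 1-step x0 checkpoint uses prediction_type="sample".
pipe.scheduler = EulerDiscreteScheduler.from_config(
    pipe.scheduler.config,
    timestep_spacing="trailing",
    prediction_type="sample" if steps == 1 else "epsilon",
)

image = pipe("a starship at warp", num_inference_steps=steps, guidance_scale=0).images[0]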
 
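One subtlety in the CLIPProcessor swap above: CLIPProcessor.__call__ takes text as its first positional argument, so the positional call feature_extractor(images, return_tensors="pt") inside check_nsfw_images only lines up with an image-only preprocessor. A hedged sketch of the keyword-explicit form, assuming CLIPImageProcessor (the documented successor to the deprecated CLIPFeatureExtractor) rather than the full processor:

# Sketch only: CLIPImageProcessor is an assumption here (the commit itself
# loads CLIPProcessor); it is the image-only successor to CLIPFeatureExtractor.
from PIL import Image
from transformers import CLIPImageProcessor

feature_extractor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")

images = [Image.new("RGB", (1024, 1024))]  # stand-in for pipeline outputs

# Keyword arguments sidestep CLIPProcessor's text-first call signature.
safety_checker_input = feature_extractor(images=images, return_tensors="pt")
pixel_values = safety_checker_input.pixel_values  # tensor fed to the safety checker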
 
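For reference on check_nsfw_images: StableDiffusionSafetyChecker returns an (images, has_nsfw_concepts) pair, so callers conventionally unpack both values and black out flagged frames. A sketch of how the SAFETY_CHECKER branch of generate_image (elided from this diff) might consume the helper, assuming it returns per-image booleans as its annotation states and that flagged outputs are redacted as in the upstream SDXL-Lightning demo:

# Sketch under the assumptions stated above; the real branch body is not shown here.
from PIL import Image

def filter_results(results):
    images, has_nsfw_concepts = check_nsfw_images(results.images)
    for i, flagged in enumerate(has_nsfw_concepts):
        if flagged:
            # Image.new defaults to black, matching the demo's redaction behavior.
            images[i] = Image.new("RGB", images[i].size)
    return images[0]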