kasper-boy committed on
Commit 7b92dfa · verified · 1 Parent(s): 3b8fc31

Rename app.py to app_cpu.py

Files changed (1): app.py → app_cpu.py (+20 -12)
app.py → app_cpu.py RENAMED
@@ -1,12 +1,16 @@
  import gradio as gr
  import torch
- from transformers import CLIPTextModel, CLIPTokenizer
- from diffusers import StableDiffusionPipeline
+ from diffusers import DiffusionPipeline
+ from PIL import Image

- # Load the model and tokenizer
+ # Load the model
  model_id = "stabilityai/stable-diffusion-xl-base-1.0"
- pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float32)
- pipe = pipe.to("cpu")
+ pipe = DiffusionPipeline.from_pretrained(
+     model_id,
+     torch_dtype=torch.float32,  # Use float32 for CPU
+     use_safetensors=True
+ )
+ pipe.to("cpu")

  def generate_image(prompt, negative_prompt, size):
      if not prompt:
@@ -16,13 +20,16 @@ def generate_image(prompt, negative_prompt, size):

      width, height = map(int, size.split('x'))
      generator = torch.Generator("cpu").manual_seed(42)
-
-     # Generate the image
-     result = pipe(prompt, height=height, width=width, negative_prompt=negative_prompt, generator=generator)
-
-     if result is not None and 'images' in result:
-         return result.images[0]
-     else:
+
+     try:
+         result = pipe(prompt=prompt, height=height, width=width, negative_prompt=negative_prompt, generator=generator)
+         if result and hasattr(result, 'images') and len(result.images) > 0:
+             return result.images[0]
+         else:
+             print("Error: No images in the result or result is None")
+             return None
+     except Exception as e:
+         print(f"Error occurred: {e}")
          return None

  with gr.Blocks() as demo:
@@ -41,6 +48,7 @@ with gr.Blocks() as demo:

  demo.launch()

+
  # import gradio as gr
  # import torch
  # from transformers import CLIPTextModel, CLIPTokenizer
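
For context, a minimal standalone sketch of driving the same CPU pipeline that app_cpu.py now builds. The model_id, float32 dtype, use_safetensors flag, and fixed seed come from the diff above; the prompt, image size, num_inference_steps value, and output filename are illustrative assumptions, not part of the commit:

import torch
from diffusers import DiffusionPipeline

# Same setup as app_cpu.py: full-precision SDXL weights pinned to the CPU.
pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float32,  # float32 on CPU; most CPUs lack fast float16 kernels
    use_safetensors=True,
)
pipe.to("cpu")

# Fixed seed mirrors the app's torch.Generator("cpu").manual_seed(42).
generator = torch.Generator("cpu").manual_seed(42)

result = pipe(
    prompt="a watercolor fox in a snowy forest",  # illustrative prompt (assumption)
    negative_prompt="blurry, low quality",        # illustrative (assumption)
    height=512,
    width=512,
    num_inference_steps=20,  # assumption: fewer steps to keep CPU runtime tolerable
    generator=generator,
)
result.images[0].save("output.png")  # hypothetical output path

Wrapping the pipe call in try/except, as the new generate_image does, keeps the Gradio handler from crashing the Space on out-of-memory or scheduler errors; note that SDXL on CPU can still take several minutes per image.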