Jordan Legg committed on
Commit 3d05f5b • 1 Parent(s): b9bd528

simplify code

Files changed (1)
  1. app.py +30 -42
app.py CHANGED
@@ -3,14 +3,14 @@ import numpy as np
 import random
 import spaces
 import torch
-from diffusers import DiffusionPipeline
+from diffusers import FluxPipeline
 
+# Check for CUDA and set device
 device = "cuda" if torch.cuda.is_available() else "cpu"
+print(f"Using device: {device}")
 
-# Load the model in FP16
-pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.float16)
-
-# Move the pipeline to GPU if available
+# Load the model
+pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.float16)
 pipe = pipe.to(device)
 
 # Convert text encoders to full precision
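
Note: the first hunk swaps the generic DiffusionPipeline loader for the FLUX-specific FluxPipeline class. As a self-contained sketch of what this load path does (model id and fp16 dtype taken from the commit; the prompt, seed, and output filename are illustrative only):

import torch
from diffusers import FluxPipeline

# Load FLUX.1-schnell in half precision and move it to the GPU.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.float16
)
pipe = pipe.to("cuda")

# schnell is a timestep-distilled model: it is meant to be sampled in
# ~4 steps with guidance disabled (guidance_scale=0.0), as in this app.
image = pipe(
    "a tiny astronaut hatching from an egg on the moon",
    num_inference_steps=4,
    guidance_scale=0.0,
    generator=torch.Generator("cuda").manual_seed(42),
).images[0]
image.save("astronaut.png")
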
@@ -18,48 +18,36 @@ pipe.text_encoder = pipe.text_encoder.to(torch.float32)
 if hasattr(pipe, 'text_encoder_2'):
     pipe.text_encoder_2 = pipe.text_encoder_2.to(torch.float32)
 
-# Enable memory efficient attention if available and on CUDA
-if device == "cuda" and hasattr(pipe, 'enable_xformers_memory_efficient_attention'):
-    try:
-        pipe.enable_xformers_memory_efficient_attention()
-        print("xformers memory efficient attention enabled")
-    except Exception as e:
-        print(f"Could not enable memory efficient attention: {e}")
-
-# Compile the UNet for potential speedups if on CUDA
-if device == "cuda":
-    try:
-        pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
-        print("UNet compiled for potential speedups")
-    except Exception as e:
-        print(f"Could not compile UNet: {e}")
-
 MAX_SEED = np.iinfo(np.int32).max
 MAX_IMAGE_SIZE = 2048
 
 @spaces.GPU()
 def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024, num_inference_steps=4, progress=gr.Progress(track_tqdm=True)):
-    if randomize_seed:
-        seed = random.randint(0, MAX_SEED)
-    generator = torch.Generator(device=device).manual_seed(seed)
-
-    # Use full precision for text encoding
-    with torch.no_grad():
-        text_inputs = pipe.tokenizer(prompt, return_tensors="pt").to(device)
-        text_embeddings = pipe.text_encoder(text_inputs.input_ids)[0]
-
-    # Use mixed precision for the rest of the pipeline
-    with torch.inference_mode(), torch.autocast(device_type='cuda', dtype=torch.float16):
-        image = pipe(
-            prompt_embeds=text_embeddings,
-            width=width,
-            height=height,
-            num_inference_steps=num_inference_steps,
-            generator=generator,
-            guidance_scale=0.0
-        ).images[0]
-
-    return image, seed
+    try:
+        if randomize_seed:
+            seed = random.randint(0, MAX_SEED)
+        generator = torch.Generator(device=device).manual_seed(seed)
+
+        # Use full precision for text encoding
+        with torch.no_grad():
+            text_inputs = pipe.tokenizer(prompt, return_tensors="pt").to(device)
+            text_embeddings = pipe.text_encoder(text_inputs.input_ids)[0]
+
+        # Use mixed precision for the rest of the pipeline
+        with torch.autocast(device_type=device, dtype=torch.float16):
+            image = pipe(
+                prompt_embeds=text_embeddings,
+                width=width,
+                height=height,
+                num_inference_steps=num_inference_steps,
+                generator=generator,
+                guidance_scale=0.0
+            ).images[0]
+
+        return image, seed
+    except Exception as e:
+        print(f"Error during inference: {e}")
+        return None, seed
 
 examples = [
     "a tiny astronaut hatching from an egg on the moon",