Culda commited on
Commit
8bb37ad
1 Parent(s): 4c2b067
Files changed (1) hide show
  1. app.py +6 -4
app.py CHANGED
@@ -16,18 +16,20 @@ HF_TOKEN = os.getenv("HF_TOKEN")
16
  #
17
  # login()
18
 
19
- dtype = torch.bfloat16
20
  device = "cuda" if torch.cuda.is_available() else "cpu"
21
 
22
  base_model = "black-forest-labs/FLUX.1-dev"
23
  controlnet_model = "YishaoAI/flux-dev-controlnet-canny-kid-clothes"
24
 
25
- controlnet = FluxControlNetModel.from_pretrained(controlnet_model, torch_dtype=dtype)
 
 
26
  pipe = FluxControlNetInpaintPipeline.from_pretrained(
27
- base_model, controlnet=controlnet, torch_dtype=dtype
28
  ).to(device)
29
 
30
- # pipe.enable_model_cpu_offload()
31
 
32
  canny = CannyDetector()
33
 
 
16
  #
17
  # login()
18
 
19
+ dtype = torch.float16
20
  device = "cuda" if torch.cuda.is_available() else "cpu"
21
 
22
  base_model = "black-forest-labs/FLUX.1-dev"
23
  controlnet_model = "YishaoAI/flux-dev-controlnet-canny-kid-clothes"
24
 
25
+ controlnet = FluxControlNetModel.from_pretrained(
26
+ controlnet_model, torch_dtype=dtype, device_map="auto"
27
+ )
28
  pipe = FluxControlNetInpaintPipeline.from_pretrained(
29
+ base_model, controlnet=controlnet, torch_dtype=dtype, device_map="auto"
30
  ).to(device)
31
 
32
+ pipe.enable_model_cpu_offload()
33
 
34
  canny = CannyDetector()
35