diopside committed on
Commit
48fd1f5
·
verified ·
1 Parent(s): 4f1e83c

Utilize HF's "balanced" device_map + dynamically pair diffusion components to relevant execution cores

Browse files

When using ZeroGPU, PyTorch throws an exception that actually stems from an OOM condition.

By utilizing balanced mode and explicitly pairing the diffusion components, we avoid that OOM.

Distribution approach (example): text encoder on GPU 1 (~16.6 GB); everything else on GPU 2 (cuda:2) — ~44.5 GB total, including the ControlNet (~4.23 GB), VAE (~254 MB), and Transformer (~40 GB).

This keeps overall memory usage efficiently split across the GPUs while ensuring that all components that need to interact directly reside on the same device.

Files changed (1) hide show
  1. app.py +34 -7
app.py CHANGED
@@ -71,12 +71,39 @@ def use_output_as_input(output_image):
71
  base_model = "Qwen/Qwen-Image"
72
  controlnet_model = "InstantX/Qwen-Image-ControlNet-Inpainting"
73
 
74
- controlnet = QwenImageControlNetModel.from_pretrained(controlnet_model, torch_dtype=torch.bfloat16)
75
-
76
  pipe = QwenImageControlNetInpaintPipeline.from_pretrained(
77
- base_model, controlnet=controlnet, torch_dtype=torch.bfloat16
 
 
 
78
  )
79
- pipe.to("cuda")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
 
81
 
82
  @spaces.GPU(duration=150)
@@ -93,7 +120,7 @@ def infer(edit_images,
93
 
94
  image = edit_images["background"]
95
  mask = edit_images["layers"][0]
96
-
97
  if randomize_seed:
98
  seed = random.randint(0, MAX_SEED)
99
 
@@ -113,7 +140,7 @@ def infer(edit_images,
113
  width=image.size[0],
114
  height=image.size[1],
115
  true_cfg_scale=true_cfg_scale,
116
- generator=torch.Generator(device="cuda").manual_seed(seed)
117
  ).images[0]
118
 
119
  return [image, result_image], seed
@@ -140,7 +167,7 @@ css = """
140
 
141
 
142
  with gr.Blocks(css=css, theme=gr.themes.Citrus()) as demo:
143
- gr.HTML("<h1 style='text-align: center'>Qwen-Image with InstantX Inpainting ControlNet</style>")
144
  gr.Markdown(
145
  "Inpaint images with [InstantX/Qwen-Image-ControlNet-Inpainting](https://huggingface.co/InstantX/Qwen-Image-ControlNet-Inpainting)"
146
  )
 
71
  base_model = "Qwen/Qwen-Image"
72
  controlnet_model = "InstantX/Qwen-Image-ControlNet-Inpainting"
73
 
74
+ # First create the pipeline with device_map="balanced"
 
75
  pipe = QwenImageControlNetInpaintPipeline.from_pretrained(
76
+ base_model,
77
+ controlnet=None, # We'll add the controlnet later
78
+ torch_dtype=torch.bfloat16,
79
+ device_map="balanced"
80
  )
81
+
82
+ pipe_device_map = pipe.hf_device_map
83
+ print("Initial device map:", pipe_device_map)
84
+ # Expected output: {'transformer': 0, 'text_encoder': 1, 'vae': 2}
85
+
86
+ # Move the controlnet to the same device as the VAE (cuda:2)
87
+ vae_device = pipe_device_map['vae']
88
+ vae_device = f"cuda:{vae_device}" # This is where the VAE is in the balanced config
89
+ controlnet = QwenImageControlNetModel.from_pretrained(
90
+ controlnet_model,
91
+ torch_dtype=torch.bfloat16
92
+ ).to(vae_device)
93
+
94
+ # Attach the controlnet to the pipeline
95
+ pipe.controlnet = controlnet
96
+
97
+ pipe.enable_vae_slicing()
98
+ pipe.enable_vae_tiling()
99
+
100
+ print("Controlnet device:", next(pipe.controlnet.parameters()).device)
101
+ print("VAE device:", next(pipe.vae.parameters()).device)
102
+
103
+
104
+ # Create a helper function to get a generator on the correct device
105
+ def get_generator(seed):
106
+ return torch.Generator(device=vae_device).manual_seed(seed)
107
 
108
 
109
  @spaces.GPU(duration=150)
 
120
 
121
  image = edit_images["background"]
122
  mask = edit_images["layers"][0]
123
+
124
  if randomize_seed:
125
  seed = random.randint(0, MAX_SEED)
126
 
 
140
  width=image.size[0],
141
  height=image.size[1],
142
  true_cfg_scale=true_cfg_scale,
143
+ generator=get_generator(seed)
144
  ).images[0]
145
 
146
  return [image, result_image], seed
 
167
 
168
 
169
  with gr.Blocks(css=css, theme=gr.themes.Citrus()) as demo:
170
+ gr.HTML("<h1 style='text-align: center'>Qwen-Image + InstantX Inpainting ControlNet</style>")
171
  gr.Markdown(
172
  "Inpaint images with [InstantX/Qwen-Image-ControlNet-Inpainting](https://huggingface.co/InstantX/Qwen-Image-ControlNet-Inpainting)"
173
  )