tonyassi commited on
Commit
9cb2064
Β·
verified Β·
1 Parent(s): fdc2668

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +58 -67
app.py CHANGED
@@ -1,22 +1,24 @@
1
  # app.py
2
  import os
 
3
  import spaces
4
  import gradio as gr
5
- from PIL import Image
6
  import torch
7
-
8
  from diffusers import AutoPipelineForInpainting, AutoencoderKL
9
 
10
- # -----------------------------
11
- # Pure-CPU helpers (no CUDA here)
12
- # -----------------------------
13
- from PIL import Image, ImageChops
14
- import math
15
 
16
- def _round_up(x, m=8):
17
  return int(math.ceil(x / m) * m)
18
 
19
- def autocrop_content(img: Image.Image, bg_color=(255, 255, 255), tol=12) -> Image.Image:
 
 
 
 
20
  if img.mode in ("RGBA", "LA"):
21
  alpha = img.split()[-1]
22
  bbox = alpha.getbbox()
@@ -28,42 +30,46 @@ def autocrop_content(img: Image.Image, bg_color=(255, 255, 255), tol=12) -> Imag
28
  bbox = mask.getbbox()
29
  return img.crop(bbox) if bbox else img
30
 
31
- def squarify_image(img: Image.Image, color="white") -> Image.Image:
32
- # 1) trim margins
33
- img = autocrop_content(img, bg_color=(255, 255, 255), tol=12)
34
-
35
- w, h = img.size
36
- # 2) target square side rounded **up** to /8
37
- side = _round_up(max(w, h), 8)
 
 
 
38
 
39
  bg = Image.new("RGB", (side, side), color=color)
40
- bg.paste(img, ((side - w) // 2, (side - h) // 2))
41
- return bg
 
 
42
 
43
- def divisible_by_8(image: Image.Image) -> Image.Image:
 
 
 
44
  w, h = image.size
45
- # round **up** so we never shrink (keeps content + avoids 1012-style errors)
46
- nw = _round_up(w, 8)
47
- nh = _round_up(h, 8)
48
  if (nw, nh) == (w, h):
49
  return image
50
  return image.resize((nw, nh), Image.LANCZOS)
51
 
52
- # -----------------------------
53
  # Lazy singletons (created inside GPU context)
54
- # -----------------------------
55
  PIPELINE = None
56
- IP_LOADED = False
57
 
58
  def _get_pipeline(device: str):
59
  """
60
  Create & cache the diffusers pipeline once we actually have a GPU (ZeroGPU).
61
  No CUDA calls should happen before this is executed.
62
  """
63
- global PIPELINE, IP_LOADED
64
-
65
  if PIPELINE is not None:
66
- # ensure it's on the current device (ZeroGPU gives you a device per call)
67
  PIPELINE.to(device)
68
  return PIPELINE
69
 
@@ -75,11 +81,10 @@ def _get_pipeline(device: str):
75
  if not ip_adapter_repo:
76
  raise RuntimeError("Missing env var IP_ADAPTER (e.g. 'h94/IP-Adapter').")
77
 
78
- # Build VAE & pipeline WITHOUT sending to CUDA yet
79
- # (dtype is fine; just don't .to('cuda') at import time)
80
  vae = AutoencoderKL.from_pretrained(
81
  "madebyollin/sdxl-vae-fp16-fix",
82
- torch_dtype=torch.float16
83
  )
84
 
85
  pipe = AutoPipelineForInpainting.from_pretrained(
@@ -90,65 +95,56 @@ def _get_pipeline(device: str):
90
  use_safetensors=True,
91
  )
92
 
93
- # Load IP-Adapter weights
94
- # (this only attaches modules; not a CUDA op)
95
  pipe.load_ip_adapter(
96
  ip_adapter_repo,
97
  subfolder="sdxl_models",
98
  weight_name="ip-adapter_sdxl.bin",
99
  )
100
 
101
- # NOW move the whole pipeline to the GPU that ZeroGPU just handed us
102
  pipe.to(device)
103
-
104
  PIPELINE = pipe
105
- IP_LOADED = True
106
  return PIPELINE
107
 
108
- # -----------------------------
109
  # Main generate (GPU section)
110
- # -----------------------------
111
- # Increase duration if you need >60s (100 steps on SDXL often does).
112
  @spaces.GPU(duration=180)
113
  def generate(person: Image.Image, clothing: Image.Image) -> Image.Image:
114
  """
115
  This function is called *after* ZeroGPU allocates a CUDA device.
116
  All CUDA/ONNXRuntime initializations must happen here (or deeper).
117
  """
118
- # Import segmentation modules *inside* the GPU function so any CUDA/ORT provider
119
- # decisions happen after the GPU exists. If these libs choose ORT providers,
120
- # do it based on torch.cuda.is_available().
121
  from SegBody import segment_body
122
  from SegCloth import segment_clothing
 
 
123
  try:
124
- import onnxruntime as ort # some seg libs use ORT under the hood
125
- # If ZeroGPU gave us a CUDA device, ORT can try CUDA; else fallback to CPU.
126
- # (If the seg modules create sessions themselves, they should use similar logic.)
127
- if torch.cuda.is_available():
128
- _ = ort.get_device() # just to ensure ORT is importable
129
- else:
130
- # As a defensive fallback, you can force CPU by env (only if needed)
131
  os.environ.setdefault("ORT_DISABLE_CUDA", "1")
132
  except Exception:
133
- # If onnxruntime isn't used, that's fine.
134
  pass
135
 
136
  device = "cuda" if torch.cuda.is_available() else "cpu"
137
  pipe = _get_pipeline(device)
138
 
139
- # --- Preprocess on CPU (cheap ops)
140
  person = person.copy()
141
  clothing = clothing.copy()
142
 
 
143
  person.thumbnail((1024, 1024))
144
- person = divisible_by_8(person)
 
145
 
 
146
  clothing.thumbnail((1024, 1024))
147
- clothing = divisible_by_8(clothing)
148
-
149
- image = squarify_image(person)
150
 
151
- # --- Segmentation (runs after GPU allocation; modules can use GPU if they want)
152
  seg_image, mask_image = segment_body(image, face=False)
153
  seg_cloth = segment_clothing(
154
  clothing,
@@ -158,9 +154,7 @@ def generate(person: Image.Image, clothing: Image.Image) -> Image.Image:
158
  # --- Diffusion
159
  pipe.set_ip_adapter_scale(1.0)
160
  result = pipe(
161
- prompt=(
162
- "photorealistic, perfect body, beautiful skin, realistic skin, natural skin"
163
- ),
164
  negative_prompt=(
165
  "ugly, bad quality, bad anatomy, deformed body, deformed hands, "
166
  "deformed feet, deformed face, deformed clothing, deformed skin, "
@@ -176,19 +170,16 @@ def generate(person: Image.Image, clothing: Image.Image) -> Image.Image:
176
  num_inference_steps=100,
177
  ).images[0]
178
 
179
- # Crop back to original (pre-squared) person dims
180
- final = result.crop((0, 0, person.width, person.height))
181
  return final
182
 
183
- # -----------------------------
184
  # Gradio UI
185
- # -----------------------------
186
  iface = gr.Interface(
187
  fn=generate,
188
- inputs=[
189
- gr.Image(label="Person", type="pil"),
190
- gr.Image(label="Clothing", type="pil"),
191
- ],
192
  outputs=[gr.Image(label="Result")],
193
  title="Fashion Try-On",
194
  description="""
 
1
  # app.py
2
  import os
3
+ import math
4
  import spaces
5
  import gradio as gr
 
6
  import torch
7
+ from PIL import Image, ImageChops
8
  from diffusers import AutoPipelineForInpainting, AutoencoderKL
9
 
10
+ # =============================
11
+ # Helpers (CPU-only; no CUDA)
12
+ # =============================
 
 
13
 
14
+ def _round_up(x: int, m: int = 8) -> int:
15
  return int(math.ceil(x / m) * m)
16
 
17
+ def autocrop_content(img: Image.Image, bg_color=(255, 255, 255), tol: int = 12) -> Image.Image:
18
+ """
19
+ Trim uniform white (or near-white) margins before centering/padding.
20
+ Handles RGBA via alpha bbox; for RGB compares to a solid background.
21
+ """
22
  if img.mode in ("RGBA", "LA"):
23
  alpha = img.split()[-1]
24
  bbox = alpha.getbbox()
 
30
  bbox = mask.getbbox()
31
  return img.crop(bbox) if bbox else img
32
 
33
+ def square_pad_meta(
34
+ img: Image.Image, color: str = "white", multiple: int = 8
35
+ ) -> tuple[Image.Image, int, int, int, int, int]:
36
+ """
37
+ Autocrop -> center-pad to a square whose side is rounded UP to `multiple`.
38
+ Returns (square_img, left, top, orig_w, orig_h, side).
39
+ """
40
+ img = autocrop_content(img, (255, 255, 255), tol=12)
41
+ orig_w, orig_h = img.size
42
+ side = _round_up(max(orig_w, orig_h), multiple)
43
 
44
  bg = Image.new("RGB", (side, side), color=color)
45
+ left = (side - orig_w) // 2
46
+ top = (side - orig_h) // 2
47
+ bg.paste(img, (left, top))
48
+ return bg, left, top, orig_w, orig_h, side
49
 
50
+ def resize_to_multiple(image: Image.Image, m: int = 8) -> Image.Image:
51
+ """
52
+ Resize **up** so width/height are multiples of m (avoids 1012x1012 errors).
53
+ """
54
  w, h = image.size
55
+ nw = _round_up(w, m)
56
+ nh = _round_up(h, m)
 
57
  if (nw, nh) == (w, h):
58
  return image
59
  return image.resize((nw, nh), Image.LANCZOS)
60
 
61
+ # =============================
62
  # Lazy singletons (created inside GPU context)
63
+ # =============================
64
  PIPELINE = None
 
65
 
66
  def _get_pipeline(device: str):
67
  """
68
  Create & cache the diffusers pipeline once we actually have a GPU (ZeroGPU).
69
  No CUDA calls should happen before this is executed.
70
  """
71
+ global PIPELINE
 
72
  if PIPELINE is not None:
 
73
  PIPELINE.to(device)
74
  return PIPELINE
75
 
 
81
  if not ip_adapter_repo:
82
  raise RuntimeError("Missing env var IP_ADAPTER (e.g. 'h94/IP-Adapter').")
83
 
84
+ # Build VAE & pipeline WITHOUT touching CUDA yet.
 
85
  vae = AutoencoderKL.from_pretrained(
86
  "madebyollin/sdxl-vae-fp16-fix",
87
+ torch_dtype=torch.float16,
88
  )
89
 
90
  pipe = AutoPipelineForInpainting.from_pretrained(
 
95
  use_safetensors=True,
96
  )
97
 
98
+ # Attach IP-Adapter weights (no CUDA op yet)
 
99
  pipe.load_ip_adapter(
100
  ip_adapter_repo,
101
  subfolder="sdxl_models",
102
  weight_name="ip-adapter_sdxl.bin",
103
  )
104
 
105
+ # NOW move the whole pipeline to the device ZeroGPU assigned
106
  pipe.to(device)
 
107
  PIPELINE = pipe
 
108
  return PIPELINE
109
 
110
+ # =============================
111
  # Main generate (GPU section)
112
+ # =============================
 
113
  @spaces.GPU(duration=180)
114
  def generate(person: Image.Image, clothing: Image.Image) -> Image.Image:
115
  """
116
  This function is called *after* ZeroGPU allocates a CUDA device.
117
  All CUDA/ONNXRuntime initializations must happen here (or deeper).
118
  """
119
+ # Import segmentation modules here so they initialize after GPU exists.
 
 
120
  from SegBody import segment_body
121
  from SegCloth import segment_clothing
122
+
123
+ # If onnxruntime is used under the hood, ensure it doesn't try CUDA without a GPU.
124
  try:
125
+ import onnxruntime as ort # noqa: F401
126
+ if not torch.cuda.is_available():
 
 
 
 
 
127
  os.environ.setdefault("ORT_DISABLE_CUDA", "1")
128
  except Exception:
 
129
  pass
130
 
131
  device = "cuda" if torch.cuda.is_available() else "cpu"
132
  pipe = _get_pipeline(device)
133
 
134
+ # --- Preprocess (CPU)
135
  person = person.copy()
136
  clothing = clothing.copy()
137
 
138
+ # Keep person within 1024, then square-pad to /8 and remember offsets.
139
  person.thumbnail((1024, 1024))
140
+ square_img, left, top, ow, oh, side = square_pad_meta(person, color="white", multiple=8)
141
+ image = square_img # feed this square to seg & pipeline (already /8-compliant)
142
 
143
+ # Clothing can be smaller; make dimensions /8 to be safe.
144
  clothing.thumbnail((1024, 1024))
145
+ clothing = resize_to_multiple(clothing, 8)
 
 
146
 
147
+ # --- Segmentation (after GPU allocation; modules can use GPU if they choose)
148
  seg_image, mask_image = segment_body(image, face=False)
149
  seg_cloth = segment_clothing(
150
  clothing,
 
154
  # --- Diffusion
155
  pipe.set_ip_adapter_scale(1.0)
156
  result = pipe(
157
+ prompt="photorealistic, perfect body, beautiful skin, realistic skin, natural skin",
 
 
158
  negative_prompt=(
159
  "ugly, bad quality, bad anatomy, deformed body, deformed hands, "
160
  "deformed feet, deformed face, deformed clothing, deformed skin, "
 
170
  num_inference_steps=100,
171
  ).images[0]
172
 
173
+ # Crop back to the original (post-thumbnail) person frame using the paste offsets.
174
+ final = result.crop((left, top, left + ow, top + oh))
175
  return final
176
 
177
+ # =============================
178
  # Gradio UI
179
+ # =============================
180
  iface = gr.Interface(
181
  fn=generate,
182
+ inputs=[gr.Image(label="Person", type="pil"), gr.Image(label="Clothing", type="pil")],
 
 
 
183
  outputs=[gr.Image(label="Result")],
184
  title="Fashion Try-On",
185
  description="""