Ashoka74 committed
Commit 57dcba1
1 Parent(s): 48416d0

Update app.py

Files changed (1):
  1. app.py +7 -3
app.py CHANGED

@@ -120,6 +120,7 @@ unet = UNet2DConditionModel.from_pretrained(sd15_name, subfolder="unet")
 # Load model directly
 from transformers import AutoModelForImageSegmentation
 rmbg = AutoModelForImageSegmentation.from_pretrained("briaai/RMBG-1.4", trust_remote_code=True)
+rmbg = rmbg.to(device=device, dtype=torch.float32) # Keep this as float32
 
 model = DepthAnythingV2(encoder='vits', features=64, out_channels=[48, 96, 192, 384])
 model.load_state_dict(torch.load('checkpoints/depth_anything_v2_vits.pth', map_location=device))
@@ -128,6 +129,7 @@ model.eval()
 
 # Change UNet
 
+
 with torch.no_grad():
     new_conv_in = torch.nn.Conv2d(8, unet.conv_in.out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding)
     new_conv_in.weight.zero_()
@@ -314,7 +316,7 @@ def encode_prompt_pair(positive_prompt, negative_prompt):
 
     return c, uc
 
-
+@spaces.GPU(duration=60)
 @torch.inference_mode()
 def pytorch2numpy(imgs, quant=True):
     results = []
@@ -331,7 +333,7 @@ def pytorch2numpy(imgs, quant=True):
         results.append(y)
     return results
 
-
+@spaces.GPU(duration=60)
 @torch.inference_mode()
 def numpy2pytorch(imgs):
     h = torch.from_numpy(np.stack(imgs, axis=0)).float() / 127.0 - 1.0 # so that 127 must be strictly 0.0
@@ -359,7 +361,7 @@ def resize_without_crop(image, target_width, target_height):
     resized_image = pil_image.resize((target_width, target_height), Image.LANCZOS)
     return np.array(resized_image)
 
-
+@spaces.GPU(duration=60)
 @torch.inference_mode()
 def run_rmbg(img, sigma=0.0):
     # Convert RGBA to RGB if needed
@@ -384,6 +386,8 @@ def run_rmbg(img, sigma=0.0):
     rgba = np.dstack((img, alpha * 255)).astype(np.uint8)
     result = 127 + (img.astype(np.float32) - 127 + sigma) * alpha
     return result.clip(0, 255).astype(np.uint8), rgba
+
+@spaces.GPU(duration=60)
 @torch.inference_mode()
 def process(input_fg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source):
     clear_memory()
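The recurring addition is the Hugging Face ZeroGPU decorator: on a ZeroGPU Space, `@spaces.GPU(duration=60)` attaches a GPU to the decorated function for at most 60 seconds per call, and CUDA is only guaranteed to be usable inside such functions, which is why each GPU-touching helper in app.py gets its own decorator. A minimal sketch of the pattern, assuming it runs on a ZeroGPU Space where the `spaces` package is available (the function below is illustrative, not part of app.py):

```python
import spaces  # Hugging Face ZeroGPU helper package, available on Spaces
import torch

@spaces.GPU(duration=60)   # request a GPU for up to 60 seconds per call
@torch.inference_mode()    # same stacking order as in app.py
def double_on_gpu(x: torch.Tensor) -> torch.Tensor:
    # Inside the decorated call CUDA is initialized: move data in, compute,
    # and move the result back out before the allocation is released.
    return (x.to('cuda') * 2.0).cpu()
```

Stacking `@spaces.GPU` outermost, as the commit does, means the GPU is acquired before the `torch.inference_mode()`-wrapped body runs.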
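The other substantive change pins RMBG-1.4 to float32 even if the rest of the pipeline runs at lower precision; the inline comment ("Keep this as float32") marks it as deliberate, and a plausible reading is that half-precision casts can degrade the matting masks this model produces. A hedged sketch of that mixed-precision layout, using stand-in modules rather than the real models:

```python
import torch
import torch.nn as nn

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Stand-ins for the real models; app.py loads the actual UNet and RMBG-1.4.
unet = nn.Conv2d(8, 320, kernel_size=3, padding=1)   # heavy generative backbone
rmbg = nn.Conv2d(3, 1, kernel_size=3, padding=1)     # small matting head

unet = unet.to(device=device, dtype=torch.float16)   # assumption: fp16 elsewhere
rmbg = rmbg.to(device=device, dtype=torch.float32)   # as in the commit: keep fp32
```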