Ashoka74 committed on
Commit
ca60a4a
1 Parent(s): ca892e4

Update app_3.py

Files changed (1)
  1. app_3.py +41 -25
app_3.py CHANGED

@@ -366,8 +366,8 @@ def encode_prompt_pair(positive_prompt, negative_prompt):

     return c, uc

-@spaces.GPU(duration=60)
-@torch.inference_mode()
+# @spaces.GPU(duration=60)
+# @torch.inference_mode()
 @spaces.GPU(duration=60)
 @torch.inference_mode()
 def infer(
@@ -467,29 +467,44 @@ def resize_without_crop(image, target_width, target_height):

 @spaces.GPU(duration=60)
 @torch.inference_mode()
-def run_rmbg(img, sigma=0.0):
-    # Convert RGBA to RGB if needed
-    if img.shape[-1] == 4:
-        # Use white background for alpha composition
-        alpha = img[..., 3:] / 255.0
-        rgb = img[..., :3]
-        white_bg = np.ones_like(rgb) * 255
-        img = (rgb * alpha + white_bg * (1 - alpha)).astype(np.uint8)
+# def run_rmbg(img, sigma=0.0):
+#     # Convert RGBA to RGB if needed
+#     if img.shape[-1] == 4:
+#         # Use white background for alpha composition
+#         alpha = img[..., 3:] / 255.0
+#         rgb = img[..., :3]
+#         white_bg = np.ones_like(rgb) * 255
+#         img = (rgb * alpha + white_bg * (1 - alpha)).astype(np.uint8)

-    H, W, C = img.shape
-    assert C == 3
-    k = (256.0 / float(H * W)) ** 0.5
-    feed = resize_without_crop(img, int(64 * round(W * k)), int(64 * round(H * k)))
-    feed = numpy2pytorch([feed]).to(device=device, dtype=torch.float32)
-    alpha = rmbg(feed)[0][0]
-    alpha = torch.nn.functional.interpolate(alpha, size=(H, W), mode="bilinear")
-    alpha = alpha.movedim(1, -1)[0]
-    alpha = alpha.detach().float().cpu().numpy().clip(0, 1)
+#     H, W, C = img.shape
+#     assert C == 3
+#     k = (256.0 / float(H * W)) ** 0.5
+#     feed = resize_without_crop(img, int(64 * round(W * k)), int(64 * round(H * k)))
+#     feed = numpy2pytorch([feed]).to(device=device, dtype=torch.float32)
+#     alpha = rmbg(feed)[0][0]
+#     alpha = torch.nn.functional.interpolate(alpha, size=(H, W), mode="bilinear")
+#     alpha = alpha.movedim(1, -1)[0]
+#     alpha = alpha.detach().float().cpu().numpy().clip(0, 1)

-    # Create RGBA image
-    rgba = np.dstack((img, alpha * 255)).astype(np.uint8)
-    result = 127 + (img.astype(np.float32) - 127 + sigma) * alpha
-    return result.clip(0, 255).astype(np.uint8), rgba
+#     # Create RGBA image
+#     rgba = np.dstack((img, alpha * 255)).astype(np.uint8)
+#     result = 127 + (img.astype(np.float32) - 127 + sigma) * alpha
+#     return result.clip(0, 255).astype(np.uint8), rgba
+
+@spaces.GPU
+@torch.inference_mode()
+def run_rmbg(image):
+    image_size = image.size
+    input_images = transform_image(image).unsqueeze(0).to("cuda")
+    # Prediction
+    with torch.no_grad():
+        preds = birefnet(input_images)[-1].sigmoid().cpu()
+    pred = preds[0].squeeze()
+    pred_pil = transforms.ToPILImage()(pred)
+    mask = pred_pil.resize(image_size)
+    image.putalpha(mask)
+    return image
+

 @spaces.GPU(duration=60)
 @torch.inference_mode()
@@ -615,9 +630,10 @@ def extract_foreground(image):
     if image is None:
         return None, gr.update(visible=True), gr.update(visible=True)
     logging.info(f"Input image shape: {image.shape}, dtype: {image.dtype}")
-    result, rgba = run_rmbg(image)
+    #result, rgba = run_rmbg(image)
+    result = run_rmbg(image)
     logging.info(f"Result shape: {result.shape}, dtype: {result.dtype}")
-    logging.info(f"RGBA shape: {rgba.shape}, dtype: {rgba.dtype}")
+    #logging.info(f"RGBA shape: {rgba.shape}, dtype: {rgba.dtype}")
     return result, gr.update(visible=True), gr.update(visible=True)

 def update_extracted_fg_height(selected_image: gr.SelectData):
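
The new run_rmbg relies on two globals that this diff does not show: a birefnet segmentation model and a transform_image preprocessing pipeline. A minimal sketch of that assumed setup, following the standard BiRefNet usage published with the model (the checkpoint name, device, and transform parameters are assumptions, not taken from this commit):

# Sketch of the globals the new run_rmbg assumes; neither birefnet nor
# transform_image is defined in this diff, so the values below are assumptions
# based on the usual BiRefNet example code.
from torchvision import transforms
from transformers import AutoModelForImageSegmentation

birefnet = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet", trust_remote_code=True  # assumed checkpoint
)
birefnet.to("cuda")
birefnet.eval()

# Fixed-size input with ImageNet normalization, as in the BiRefNet examples.
transform_image = transforms.Compose([
    transforms.Resize((1024, 1024)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])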
 
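The replacement run_rmbg also changes the interface: it takes a PIL image (image.size, image.putalpha) and returns that image with an alpha channel, whereas the old version took and returned numpy arrays, and extract_foreground still logs result.shape / result.dtype. A hedged usage sketch of how a numpy-based caller might bridge the two representations (the wrapper name is hypothetical, not part of the commit):

import numpy as np
from PIL import Image

def run_rmbg_on_array(image_np):
    # Hypothetical bridge: numpy (H, W, 3) uint8 in, RGBA (H, W, 4) uint8 out,
    # so downstream .shape / .dtype logging keeps working.
    pil_in = Image.fromarray(image_np).convert("RGB")
    pil_out = run_rmbg(pil_in)   # new BiRefNet-based function from this commit
    return np.array(pil_out)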