Ashoka74 commited on
Commit
0ec586b
ยท
verified ยท
1 Parent(s): 39b54e9

Update app_merged.py

Browse files
Files changed (1) hide show
  1. app_merged.py +12 -27
app_merged.py CHANGED
@@ -599,8 +599,7 @@ def encode_prompt_pair(positive_prompt, negative_prompt):
599
 
600
  return c, uc
601
 
602
- # @spaces.GPU(duration=60)
603
- # @torch.inference_mode()
604
  @spaces.GPU(duration=60)
605
  @torch.inference_mode()
606
  def infer(
@@ -616,6 +615,8 @@ def infer(
616
  progress=gr.Progress(track_tqdm=True),
617
  ):
618
  #logging.info(f"Input image shape: {image.shape}, dtype: {image.dtype}")
 
 
619
 
620
  # Convert input to PIL if needed
621
  if isinstance(image, np.ndarray):
@@ -698,35 +699,11 @@ def resize_without_crop(image, target_width, target_height):
698
  resized_image = pil_image.resize((target_width, target_height), Image.LANCZOS)
699
  return np.array(resized_image)
700
 
701
- # @spaces.GPU(duration=60)
702
- # @torch.inference_mode()
703
- # def run_rmbg(img, sigma=0.0):
704
- # # Convert RGBA to RGB if needed
705
- # if img.shape[-1] == 4:
706
- # # Use white background for alpha composition
707
- # alpha = img[..., 3:] / 255.0
708
- # rgb = img[..., :3]
709
- # white_bg = np.ones_like(rgb) * 255
710
- # img = (rgb * alpha + white_bg * (1 - alpha)).astype(np.uint8)
711
-
712
- # H, W, C = img.shape
713
- # assert C == 3
714
- # k = (256.0 / float(H * W)) ** 0.5
715
- # feed = resize_without_crop(img, int(64 * round(W * k)), int(64 * round(H * k)))
716
- # feed = numpy2pytorch([feed]).to(device=device, dtype=torch.float32)
717
- # alpha = rmbg(feed)[0][0]
718
- # alpha = torch.nn.functional.interpolate(alpha, size=(H, W), mode="bilinear")
719
- # alpha = alpha.movedim(1, -1)[0]
720
- # alpha = alpha.detach().float().cpu().numpy().clip(0, 1)
721
-
722
- # # Create RGBA image
723
- # rgba = np.dstack((img, alpha * 255)).astype(np.uint8)
724
- # result = 127 + (img.astype(np.float32) - 127 + sigma) * alpha
725
- # return result.clip(0, 255).astype(np.uint8), rgba
726
 
727
  @spaces.GPU
728
  @torch.inference_mode()
729
  def run_rmbg(image):
 
730
  image_size = image.size
731
  input_images = transform_image(image).unsqueeze(0).to("cuda")
732
  # Prediction
@@ -893,6 +870,7 @@ def process(input_fg, prompt, image_width, image_height, num_samples, seed, step
893
  def extract_foreground(image):
894
  if image is None:
895
  return None, gr.update(visible=True), gr.update(visible=True)
 
896
  #logging.info(f"Input image shape: {image.shape}, dtype: {image.dtype}")
897
  #result, rgba = run_rmbg(image)
898
  result = run_rmbg(image)
@@ -1077,6 +1055,9 @@ def use_orientation(selected_image:gr.SelectData):
1077
  def process_image(input_image, input_text):
1078
  """Main processing function for the Gradio interface"""
1079
 
 
 
 
1080
  if isinstance(input_image, Image.Image):
1081
  input_image = np.array(input_image)
1082
 
@@ -1443,6 +1424,8 @@ imageresize = NODE_CLASS_MAPPINGS["ImageResize+"]()
1443
  @spaces.GPU
1444
  def generate_image(prompt, structure_image, style_image, depth_strength=15, style_strength=0.5, progress=gr.Progress(track_tqdm=True)) -> str:
1445
  """Main generation function that processes inputs and returns the path to the generated image."""
 
 
1446
  with torch.inference_mode():
1447
  # Set up CLIP
1448
  clip_switch = cr_clip_input_switch.switch(
@@ -1571,6 +1554,8 @@ def generate_image(prompt, structure_image, style_image, depth_strength=15, styl
1571
  images=get_value_at_index(decoded, 0),
1572
  )
1573
  saved_path = f"output/{saved['ui']['images'][0]['filename']}"
 
 
1574
  return saved_path
1575
 
1576
  # Create Gradio interface
 
599
 
600
  return c, uc
601
 
602
+
 
603
  @spaces.GPU(duration=60)
604
  @torch.inference_mode()
605
  def infer(
 
615
  progress=gr.Progress(track_tqdm=True),
616
  ):
617
  #logging.info(f"Input image shape: {image.shape}, dtype: {image.dtype}")
618
+
619
+ clear_memory()
620
 
621
  # Convert input to PIL if needed
622
  if isinstance(image, np.ndarray):
 
699
  resized_image = pil_image.resize((target_width, target_height), Image.LANCZOS)
700
  return np.array(resized_image)
701
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
702
 
703
  @spaces.GPU
704
  @torch.inference_mode()
705
  def run_rmbg(image):
706
+ clear_memory()
707
  image_size = image.size
708
  input_images = transform_image(image).unsqueeze(0).to("cuda")
709
  # Prediction
 
870
  def extract_foreground(image):
871
  if image is None:
872
  return None, gr.update(visible=True), gr.update(visible=True)
873
+ clear_memory()
874
  #logging.info(f"Input image shape: {image.shape}, dtype: {image.dtype}")
875
  #result, rgba = run_rmbg(image)
876
  result = run_rmbg(image)
 
1055
  def process_image(input_image, input_text):
1056
  """Main processing function for the Gradio interface"""
1057
 
1058
+
1059
+ clear_memory()
1060
+
1061
  if isinstance(input_image, Image.Image):
1062
  input_image = np.array(input_image)
1063
 
 
1424
  @spaces.GPU
1425
  def generate_image(prompt, structure_image, style_image, depth_strength=15, style_strength=0.5, progress=gr.Progress(track_tqdm=True)) -> str:
1426
  """Main generation function that processes inputs and returns the path to the generated image."""
1427
+
1428
+ clear_memory()
1429
  with torch.inference_mode():
1430
  # Set up CLIP
1431
  clip_switch = cr_clip_input_switch.switch(
 
1554
  images=get_value_at_index(decoded, 0),
1555
  )
1556
  saved_path = f"output/{saved['ui']['images'][0]['filename']}"
1557
+
1558
+ clear_memory()
1559
  return saved_path
1560
 
1561
  # Create Gradio interface