Ashoka74 committed on
Commit
2d16f46
1 Parent(s): d2f41dd

Update app_3.py

Browse files
Files changed (1) hide show
  1. app_3.py +41 -6
app_3.py CHANGED
@@ -780,6 +780,9 @@ def process_bg(input_fg, input_bg, prompt, image_width, image_height, num_sample
780
 
781
  @torch.inference_mode()
782
  def process_relight(input_fg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source):
 
 
 
783
  logging.info(f"Input foreground shape: {input_fg.shape}, dtype: {input_fg.dtype}")
784
  results = process(input_fg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source)
785
  logging.info(f"Results shape: {results.shape}, dtype: {results.dtype}")
@@ -993,6 +996,8 @@ def use_orientation(selected_image:gr.SelectData):
993
  @torch.inference_mode
994
  def process_image(input_image, input_text):
995
  """Main processing function for the Gradio interface"""
 
 
996
 
997
  # Initialize configs
998
  API_TOKEN = "9c8c865e10ec1821bea79d9fa9dc8720"
@@ -1003,7 +1008,6 @@ def process_image(input_image, input_text):
1003
  OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
1004
 
1005
 
1006
-
1007
  # Initialize DDS client
1008
  config = Config(API_TOKEN)
1009
  client = Client(config)
@@ -1013,6 +1017,8 @@ def process_image(input_image, input_text):
1013
  class_name_to_id = {name: id for id, name in enumerate(classes)}
1014
  class_id_to_name = {id: name for name, id in class_name_to_id.items()}
1015
 
 
 
1016
  # Save input image to temp file and get URL
1017
  with tempfile.NamedTemporaryFile(suffix='.jpg', delete=False) as tmpfile:
1018
  cv2.imwrite(tmpfile.name, input_image)
@@ -1078,12 +1084,41 @@ def process_image(input_image, input_text):
1078
 
1079
  # Get original RGB image
1080
  img = input_image.copy()
1081
- H, W, C = img.shape
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1082
 
1083
- # Create RGBA image
1084
- alpha = np.zeros((H, W, 1), dtype=np.uint8)
1085
- alpha[first_mask] = 255
1086
- rgba = np.dstack((img, alpha)).astype(np.uint8)
 
 
1087
 
1088
  # Crop to mask bounds to minimize image size
1089
  # y_indices, x_indices = np.where(first_mask)
 
780
 
781
  @torch.inference_mode()
782
  def process_relight(input_fg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source):
783
+ # Convert input foreground from PIL to NumPy array if it's in PIL format
784
+ if isinstance(input_fg, Image.Image):
785
+ input_fg = np.array(input_fg)
786
  logging.info(f"Input foreground shape: {input_fg.shape}, dtype: {input_fg.dtype}")
787
  results = process(input_fg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source)
788
  logging.info(f"Results shape: {results.shape}, dtype: {results.dtype}")
 
996
  @torch.inference_mode
997
  def process_image(input_image, input_text):
998
  """Main processing function for the Gradio interface"""
999
+ if isinstance(input_image, Image.Image):
1000
+ input_image = np.array(input_image)
1001
 
1002
  # Initialize configs
1003
  API_TOKEN = "9c8c865e10ec1821bea79d9fa9dc8720"
 
1008
  OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
1009
 
1010
 
 
1011
  # Initialize DDS client
1012
  config = Config(API_TOKEN)
1013
  client = Client(config)
 
1017
  class_name_to_id = {name: id for id, name in enumerate(classes)}
1018
  class_id_to_name = {id: name for name, id in class_name_to_id.items()}
1019
 
1020
+
1021
+
1022
  # Save input image to temp file and get URL
1023
  with tempfile.NamedTemporaryFile(suffix='.jpg', delete=False) as tmpfile:
1024
  cv2.imwrite(tmpfile.name, input_image)
 
1084
 
1085
  # Get original RGB image
1086
  img = input_image.copy()
1087
+
1088
+ # rgba = preprocess_image(img)
1089
+
1090
+ alpha = img[..., 3] > 0
1091
+ H, W = alpha.shape
1092
+ # get the bounding box of alpha
1093
+ y, x = np.where(alpha)
1094
+ y0, y1 = max(y.min() - 1, 0), min(y.max() + 1, H)
1095
+ x0, x1 = max(x.min() - 1, 0), min(x.max() + 1, W)
1096
+ image_center = img[y0:y1, x0:x1]
1097
+ # resize the longer side to H * 0.9
1098
+ H, W, _ = image_center.shape
1099
+ if H > W:
1100
+ W = int(W * (height * 0.9) / H)
1101
+ H = int(height * 0.9)
1102
+ else:
1103
+ H = int(H * (width * 0.9) / W)
1104
+ W = int(width * 0.9)
1105
+ image_center = np.array(Image.fromarray(image_center).resize((W, H)))
1106
+ # pad to H, W
1107
+ start_h = (height - H) // 2
1108
+ start_w = (width - W) // 2
1109
+ image = np.zeros((height, width, 4), dtype=np.uint8)
1110
+ image[start_h : start_h + H, start_w : start_w + W] = image_center
1111
+ image = image.astype(np.float32) / 255.0
1112
+ image = image[:, :, :3] * image[:, :, 3:4] + (1 - image[:, :, 3:4]) * 0.5
1113
+ image = (image * 255).clip(0, 255).astype(np.uint8)
1114
+ image = Image.fromarray(image)
1115
 
1116
+ # H, W, C = img.shape
1117
+
1118
+ # # Create RGBA image
1119
+ # alpha = np.zeros((H, W, 1), dtype=np.uint8)
1120
+ # alpha[first_mask] = 255
1121
+ # rgba = np.dstack((img, alpha)).astype(np.uint8)
1122
 
1123
  # Crop to mask bounds to minimize image size
1124
  # y_indices, x_indices = np.where(first_mask)