Update app_3.py
app_3.py CHANGED
@@ -780,6 +780,9 @@ def process_bg(input_fg, input_bg, prompt, image_width, image_height, num_sample
 
 @torch.inference_mode()
 def process_relight(input_fg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source):
+    # Convert input foreground from PIL to NumPy array if it's in PIL format
+    if isinstance(input_fg, Image.Image):
+        input_fg = np.array(input_fg)
     logging.info(f"Input foreground shape: {input_fg.shape}, dtype: {input_fg.dtype}")
     results = process(input_fg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source)
     logging.info(f"Results shape: {results.shape}, dtype: {results.dtype}")
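This hunk and the next one apply the same guard: a Gradio image component can hand the handler either a `PIL.Image.Image` or a NumPy array depending on its `type` setting, so the code normalizes to NumPy before logging shapes and calling `process`. A minimal standalone sketch of that guard; `to_numpy` is a hypothetical helper name, not something defined in app_3.py:

```python
import numpy as np
from PIL import Image

def to_numpy(image):
    """Return the image as a NumPy array, converting from PIL if needed."""
    if isinstance(image, Image.Image):
        image = np.array(image)  # PIL -> HxWxC uint8 array
    return image
```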
@@ -993,6 +996,8 @@ def use_orientation(selected_image:gr.SelectData):
 @torch.inference_mode
 def process_image(input_image, input_text):
     """Main processing function for the Gradio interface"""
+    if isinstance(input_image, Image.Image):
+        input_image = np.array(input_image)
 
     # Initialize configs
     API_TOKEN = "9c8c865e10ec1821bea79d9fa9dc8720"
@@ -1003,7 +1008,6 @@ def process_image(input_image, input_text):
     OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
 
 
-
     # Initialize DDS client
     config = Config(API_TOKEN)
     client = Client(config)
@@ -1013,6 +1017,8 @@ def process_image(input_image, input_text):
     class_name_to_id = {name: id for id, name in enumerate(classes)}
     class_id_to_name = {id: name for name, id in class_name_to_id.items()}
 
+
+
     # Save input image to temp file and get URL
     with tempfile.NamedTemporaryFile(suffix='.jpg', delete=False) as tmpfile:
         cv2.imwrite(tmpfile.name, input_image)
@@ -1078,12 +1084,41 @@ def process_image(input_image, input_text):
 
     # Get original RGB image
     img = input_image.copy()
-
+
+    # rgba = preprocess_image(img)
+
+    alpha = img[..., 3] > 0
+    H, W = alpha.shape
+    # get the bounding box of alpha
+    y, x = np.where(alpha)
+    y0, y1 = max(y.min() - 1, 0), min(y.max() + 1, H)
+    x0, x1 = max(x.min() - 1, 0), min(x.max() + 1, W)
+    image_center = img[y0:y1, x0:x1]
+    # resize the longer side to H * 0.9
+    H, W, _ = image_center.shape
+    if H > W:
+        W = int(W * (height * 0.9) / H)
+        H = int(height * 0.9)
+    else:
+        H = int(H * (width * 0.9) / W)
+        W = int(width * 0.9)
+    image_center = np.array(Image.fromarray(image_center).resize((W, H)))
+    # pad to H, W
+    start_h = (height - H) // 2
+    start_w = (width - W) // 2
+    image = np.zeros((height, width, 4), dtype=np.uint8)
+    image[start_h : start_h + H, start_w : start_w + W] = image_center
+    image = image.astype(np.float32) / 255.0
+    image = image[:, :, :3] * image[:, :, 3:4] + (1 - image[:, :, 3:4]) * 0.5
+    image = (image * 255).clip(0, 255).astype(np.uint8)
+    image = Image.fromarray(image)
 
-    #
-
-
-
+    # H, W, C = img.shape
+
+    # # Create RGBA image
+    # alpha = np.zeros((H, W, 1), dtype=np.uint8)
+    # alpha[first_mask] = 255
+    # rgba = np.dstack((img, alpha)).astype(np.uint8)
 
     # Crop to mask bounds to minimize image size
     # y_indices, x_indices = np.where(first_mask)
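The final hunk's added block crops the RGBA input to the bounding box of its alpha channel, resizes the crop so the longer side fills 90% of the target canvas, centers it on a fresh canvas, and composites it over 50% gray. The sketch below restates that logic as a self-contained function so it can be read and tested in isolation; `center_on_gray`, `rgba`, `target_h`, and `target_w` are hypothetical names standing in for the `height` and `width` values the diff reads from the enclosing scope, and the input is assumed to be an HxWx4 uint8 array with at least one opaque pixel.

```python
import numpy as np
from PIL import Image

def center_on_gray(rgba: np.ndarray, target_h: int, target_w: int) -> Image.Image:
    # Bounding box of the non-transparent pixels.
    alpha = rgba[..., 3] > 0
    H, W = alpha.shape
    y, x = np.where(alpha)
    y0, y1 = max(y.min() - 1, 0), min(y.max() + 1, H)
    x0, x1 = max(x.min() - 1, 0), min(x.max() + 1, W)
    crop = rgba[y0:y1, x0:x1]

    # Resize so the longer side fills 90% of the target canvas.
    h, w, _ = crop.shape
    if h > w:
        w = int(w * (target_h * 0.9) / h)
        h = int(target_h * 0.9)
    else:
        h = int(h * (target_w * 0.9) / w)
        w = int(target_w * 0.9)
    crop = np.array(Image.fromarray(crop).resize((w, h)))

    # Paste the crop into the center of an empty RGBA canvas.
    canvas = np.zeros((target_h, target_w, 4), dtype=np.uint8)
    start_h = (target_h - h) // 2
    start_w = (target_w - w) // 2
    canvas[start_h:start_h + h, start_w:start_w + w] = crop

    # Composite over 50% gray using the alpha channel, then back to uint8 RGB.
    canvas = canvas.astype(np.float32) / 255.0
    rgb = canvas[:, :, :3] * canvas[:, :, 3:4] + (1 - canvas[:, :, 3:4]) * 0.5
    return Image.fromarray((rgb * 255).clip(0, 255).astype(np.uint8))
```

Compositing against 0.5 keeps everything outside the subject mid-gray rather than black, matching what the committed block does before handing the image on.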