Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -823,10 +823,10 @@ def compress_image(image):
|
|
823 |
return compressed_img
|
824 |
|
825 |
@spaces.GPU(duration=60)
|
826 |
-
@torch.inference_mode
|
827 |
def process_image(input_image, input_text):
|
828 |
"""Main processing function for the Gradio interface"""
|
829 |
-
|
830 |
# Initialize configs
|
831 |
API_TOKEN = "9c8c865e10ec1821bea79d9fa9dc8720"
|
832 |
SAM2_CHECKPOINT = "./checkpoints/sam2_hiera_large.pt"
|
@@ -835,6 +835,8 @@ def process_image(input_image, input_text):
|
|
835 |
OUTPUT_DIR = Path("outputs/grounded_sam2_dinox_demo")
|
836 |
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
|
837 |
|
|
|
|
|
838 |
# Initialize DDS client
|
839 |
config = Config(API_TOKEN)
|
840 |
client = Client(config)
|
@@ -850,102 +852,215 @@ def process_image(input_image, input_text):
|
|
850 |
image_url = client.upload_file(tmpfile.name)
|
851 |
os.remove(tmpfile.name)
|
852 |
|
853 |
-
# Run DINO-X detection
|
854 |
-
task = DinoxTask(
|
855 |
-
image_url=image_url,
|
856 |
-
prompts=[TextPrompt(text=input_text)]
|
857 |
-
)
|
858 |
-
client.run_task(task)
|
859 |
-
result = task.result
|
860 |
-
objects = result.objects
|
861 |
-
|
862 |
# Process detection results
|
863 |
input_boxes = []
|
|
|
864 |
confidences = []
|
865 |
class_names = []
|
866 |
class_ids = []
|
867 |
|
868 |
-
|
869 |
-
|
870 |
-
|
871 |
-
|
872 |
-
|
873 |
-
|
874 |
-
|
875 |
-
|
876 |
-
|
877 |
-
|
878 |
-
|
879 |
-
|
880 |
-
|
881 |
-
|
882 |
-
|
883 |
-
|
884 |
-
|
885 |
-
|
886 |
-
|
887 |
-
|
888 |
-
|
889 |
-
|
890 |
-
|
891 |
-
|
892 |
-
|
893 |
-
|
894 |
-
|
895 |
-
|
896 |
-
|
897 |
-
|
898 |
-
|
899 |
-
|
|
|
|
|
|
|
900 |
|
901 |
-
|
902 |
-
|
903 |
-
|
904 |
|
905 |
-
|
906 |
-
|
907 |
-
|
908 |
-
|
909 |
-
)
|
910 |
|
911 |
-
|
912 |
-
|
913 |
-
|
914 |
-
|
915 |
-
|
916 |
-
|
917 |
-
|
918 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
919 |
|
920 |
-
|
921 |
-
|
922 |
-
#
|
923 |
-
|
|
|
|
|
|
|
|
|
924 |
|
925 |
-
|
926 |
-
|
927 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
928 |
|
929 |
-
|
930 |
-
|
931 |
-
|
932 |
-
|
|
|
933 |
|
934 |
-
|
935 |
-
|
936 |
-
|
937 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
938 |
|
939 |
-
#
|
940 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
941 |
|
942 |
-
|
943 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
944 |
|
945 |
-
|
|
|
|
|
946 |
|
947 |
-
|
948 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
949 |
|
950 |
block = gr.Blocks().queue()
|
951 |
with block:
|
@@ -958,16 +1073,15 @@ with block:
|
|
958 |
input_fg = gr.Image(type="numpy", label="Image", height=480)
|
959 |
with gr.Row():
|
960 |
with gr.Group():
|
961 |
-
|
962 |
-
|
963 |
-
|
964 |
-
|
965 |
-
|
966 |
-
|
967 |
-
|
968 |
extract_button = gr.Button(value="Remove Background")
|
969 |
with gr.Row():
|
970 |
-
|
971 |
extracted_fg = gr.Image(type="numpy", label="Extracted Foreground", height=480)
|
972 |
|
973 |
# output_bg = gr.Image(type="numpy", label="Preprocessed Foreground", height=480)
|
@@ -1028,11 +1142,11 @@ with block:
|
|
1028 |
relight_button.click(fn=process_relight, inputs=ips, outputs=[extracted_fg, result_gallery])
|
1029 |
example_quick_prompts.click(lambda x, y: ', '.join(y.split(', ')[:2] + [x[0]]), inputs=[example_quick_prompts, prompt], outputs=prompt, show_progress=False, queue=False)
|
1030 |
example_quick_subjects.click(lambda x: x[0], inputs=example_quick_subjects, outputs=prompt, show_progress=False, queue=False)
|
1031 |
-
|
1032 |
-
|
1033 |
-
|
1034 |
-
|
1035 |
-
|
1036 |
extract_button.click(
|
1037 |
fn=extract_foreground,
|
1038 |
inputs=[input_fg],
|
@@ -1169,11 +1283,11 @@ with block:
|
|
1169 |
outputs=[extracted_fg, x_slider, y_slider]
|
1170 |
)
|
1171 |
|
1172 |
-
|
1173 |
-
|
1174 |
-
|
1175 |
-
|
1176 |
-
|
1177 |
|
1178 |
get_depth_button.click(
|
1179 |
fn=get_depth,
|
|
|
823 |
return compressed_img
|
824 |
|
825 |
@spaces.GPU(duration=60)
|
826 |
+
@torch.inference_mode
|
827 |
def process_image(input_image, input_text):
|
828 |
"""Main processing function for the Gradio interface"""
|
829 |
+
|
830 |
# Initialize configs
|
831 |
API_TOKEN = "9c8c865e10ec1821bea79d9fa9dc8720"
|
832 |
SAM2_CHECKPOINT = "./checkpoints/sam2_hiera_large.pt"
|
|
|
835 |
OUTPUT_DIR = Path("outputs/grounded_sam2_dinox_demo")
|
836 |
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
|
837 |
|
838 |
+
|
839 |
+
|
840 |
# Initialize DDS client
|
841 |
config = Config(API_TOKEN)
|
842 |
client = Client(config)
|
|
|
852 |
image_url = client.upload_file(tmpfile.name)
|
853 |
os.remove(tmpfile.name)
|
854 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
855 |
# Process detection results
|
856 |
input_boxes = []
|
857 |
+
masks = []
|
858 |
confidences = []
|
859 |
class_names = []
|
860 |
class_ids = []
|
861 |
|
862 |
+
if len(input_text) == 0:
|
863 |
+
task = DinoxTask(
|
864 |
+
image_url=image_url,
|
865 |
+
prompts=[TextPrompt(text="<prompt_free>")],
|
866 |
+
# targets=[DetectionTarget.BBox, DetectionTarget.Mask]
|
867 |
+
)
|
868 |
+
|
869 |
+
client.run_task(task)
|
870 |
+
predictions = task.result.objects
|
871 |
+
classes = [pred.category for pred in predictions]
|
872 |
+
classes = list(set(classes))
|
873 |
+
class_name_to_id = {name: id for id, name in enumerate(classes)}
|
874 |
+
class_id_to_name = {id: name for name, id in class_name_to_id.items()}
|
875 |
+
|
876 |
+
for idx, obj in enumerate(predictions):
|
877 |
+
input_boxes.append(obj.bbox)
|
878 |
+
masks.append(DetectionTask.rle2mask(DetectionTask.string2rle(obj.mask.counts), obj.mask.size)) # convert mask to np.array using DDS API
|
879 |
+
confidences.append(obj.score)
|
880 |
+
cls_name = obj.category.lower().strip()
|
881 |
+
class_names.append(cls_name)
|
882 |
+
class_ids.append(class_name_to_id[cls_name])
|
883 |
+
|
884 |
+
boxes = np.array(input_boxes)
|
885 |
+
masks = np.array(masks)
|
886 |
+
class_ids = np.array(class_ids)
|
887 |
+
labels = [
|
888 |
+
f"{class_name} {confidence:.2f}"
|
889 |
+
for class_name, confidence
|
890 |
+
in zip(class_names, confidences)
|
891 |
+
]
|
892 |
+
detections = sv.Detections(
|
893 |
+
xyxy=boxes,
|
894 |
+
mask=masks.astype(bool),
|
895 |
+
class_id=class_ids
|
896 |
+
)
|
897 |
|
898 |
+
box_annotator = sv.BoxAnnotator()
|
899 |
+
label_annotator = sv.LabelAnnotator()
|
900 |
+
mask_annotator = sv.MaskAnnotator()
|
901 |
|
902 |
+
annotated_frame = input_image.copy()
|
903 |
+
annotated_frame = box_annotator.annotate(scene=annotated_frame, detections=detections)
|
904 |
+
annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)
|
905 |
+
annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
|
|
|
906 |
|
907 |
+
# Create transparent mask for first detected object
|
908 |
+
if len(detections) > 0:
|
909 |
+
# Get first mask
|
910 |
+
first_mask = detections.mask[0]
|
911 |
+
|
912 |
+
# Get original RGB image
|
913 |
+
img = input_image.copy()
|
914 |
+
H, W, C = img.shape
|
915 |
+
|
916 |
+
# Create RGBA image
|
917 |
+
alpha = np.zeros((H, W, 1), dtype=np.uint8)
|
918 |
+
alpha[first_mask] = 255
|
919 |
+
rgba = np.dstack((img, alpha)).astype(np.uint8)
|
920 |
+
|
921 |
+
# Crop to mask bounds to minimize image size
|
922 |
+
y_indices, x_indices = np.where(first_mask)
|
923 |
+
y_min, y_max = y_indices.min(), y_indices.max()
|
924 |
+
x_min, x_max = x_indices.min(), x_indices.max()
|
925 |
+
|
926 |
+
# Crop the RGBA image
|
927 |
+
cropped_rgba = rgba[y_min:y_max+1, x_min:x_max+1]
|
928 |
+
|
929 |
+
# Set extracted foreground for mask mover
|
930 |
+
mask_mover.set_extracted_fg(cropped_rgba)
|
931 |
+
|
932 |
+
return annotated_frame, cropped_rgba, gr.update(visible=True), gr.update(visible=True)
|
933 |
|
934 |
+
|
935 |
+
else:
|
936 |
+
# Run DINO-X detection
|
937 |
+
task = DinoxTask(
|
938 |
+
image_url=image_url,
|
939 |
+
prompts=[TextPrompt(text=input_text)],
|
940 |
+
targets=[DetectionTarget.BBox, DetectionTarget.Mask]
|
941 |
+
)
|
942 |
|
943 |
+
client.run_task(task)
|
944 |
+
result = task.result
|
945 |
+
objects = result.objects
|
946 |
+
|
947 |
+
|
948 |
+
|
949 |
+
# for obj in objects:
|
950 |
+
# input_boxes.append(obj.bbox)
|
951 |
+
# confidences.append(obj.score)
|
952 |
+
# cls_name = obj.category.lower().strip()
|
953 |
+
# class_names.append(cls_name)
|
954 |
+
# class_ids.append(class_name_to_id[cls_name])
|
955 |
+
|
956 |
+
# input_boxes = np.array(input_boxes)
|
957 |
+
# class_ids = np.array(class_ids)
|
958 |
+
|
959 |
+
predictions = task.result.objects
|
960 |
+
classes = [x.strip().lower() for x in input_text.split('.') if x]
|
961 |
+
class_name_to_id = {name: id for id, name in enumerate(classes)}
|
962 |
+
class_id_to_name = {id: name for name, id in class_name_to_id.items()}
|
963 |
|
964 |
+
boxes = []
|
965 |
+
masks = []
|
966 |
+
confidences = []
|
967 |
+
class_names = []
|
968 |
+
class_ids = []
|
969 |
|
970 |
+
for idx, obj in enumerate(predictions):
|
971 |
+
boxes.append(obj.bbox)
|
972 |
+
masks.append(DetectionTask.rle2mask(DetectionTask.string2rle(obj.mask.counts), obj.mask.size)) # convert mask to np.array using DDS API
|
973 |
+
confidences.append(obj.score)
|
974 |
+
cls_name = obj.category.lower().strip()
|
975 |
+
class_names.append(cls_name)
|
976 |
+
class_ids.append(class_name_to_id[cls_name])
|
977 |
+
|
978 |
+
boxes = np.array(boxes)
|
979 |
+
masks = np.array(masks)
|
980 |
+
class_ids = np.array(class_ids)
|
981 |
+
labels = [
|
982 |
+
f"{class_name} {confidence:.2f}"
|
983 |
+
for class_name, confidence
|
984 |
+
in zip(class_names, confidences)
|
985 |
+
]
|
986 |
|
987 |
+
# Initialize SAM2
|
988 |
+
# torch.autocast(device_type=DEVICE, dtype=torch.bfloat16).__enter__()
|
989 |
+
# if torch.cuda.get_device_properties(0).major >= 8:
|
990 |
+
# torch.backends.cuda.matmul.allow_tf32 = True
|
991 |
+
# torch.backends.cudnn.allow_tf32 = True
|
992 |
+
|
993 |
+
# sam2_model = build_sam2(SAM2_MODEL_CONFIG, SAM2_CHECKPOINT, device=DEVICE)
|
994 |
+
# sam2_predictor = SAM2ImagePredictor(sam2_model)
|
995 |
+
# sam2_predictor.set_image(input_image)
|
996 |
+
|
997 |
+
# sam2_predictor = run_sam_inference(SAM_IMAGE_MODEL, input_image, detections)
|
998 |
+
|
999 |
+
|
1000 |
+
# Get masks from SAM2
|
1001 |
+
# masks, scores, logits = sam2_predictor.predict(
|
1002 |
+
# point_coords=None,
|
1003 |
+
# point_labels=None,
|
1004 |
+
# box=input_boxes,
|
1005 |
+
# multimask_output=False,
|
1006 |
+
# )
|
1007 |
|
1008 |
+
if masks.ndim == 4:
|
1009 |
+
masks = masks.squeeze(1)
|
1010 |
+
|
1011 |
+
# Create visualization
|
1012 |
+
# labels = [f"{class_name} {confidence:.2f}"
|
1013 |
+
# for class_name, confidence in zip(class_names, confidences)]
|
1014 |
+
|
1015 |
+
# detections = sv.Detections(
|
1016 |
+
# xyxy=input_boxes,
|
1017 |
+
# mask=masks.astype(bool),
|
1018 |
+
# class_id=class_ids
|
1019 |
+
# )
|
1020 |
+
|
1021 |
+
detections = sv.Detections(
|
1022 |
+
xyxy = boxes,
|
1023 |
+
mask = masks.astype(bool),
|
1024 |
+
class_id = class_ids,
|
1025 |
+
)
|
1026 |
|
1027 |
+
box_annotator = sv.BoxAnnotator()
|
1028 |
+
label_annotator = sv.LabelAnnotator()
|
1029 |
+
mask_annotator = sv.MaskAnnotator()
|
1030 |
|
1031 |
+
annotated_frame = input_image.copy()
|
1032 |
+
annotated_frame = box_annotator.annotate(scene=annotated_frame, detections=detections)
|
1033 |
+
annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)
|
1034 |
+
annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
|
1035 |
+
|
1036 |
+
# Create transparent mask for first detected object
|
1037 |
+
if len(detections) > 0:
|
1038 |
+
# Get first mask
|
1039 |
+
first_mask = detections.mask[0]
|
1040 |
+
|
1041 |
+
# Get original RGB image
|
1042 |
+
img = input_image.copy()
|
1043 |
+
H, W, C = img.shape
|
1044 |
+
|
1045 |
+
# Create RGBA image
|
1046 |
+
alpha = np.zeros((H, W, 1), dtype=np.uint8)
|
1047 |
+
alpha[first_mask] = 255
|
1048 |
+
rgba = np.dstack((img, alpha)).astype(np.uint8)
|
1049 |
+
|
1050 |
+
# Crop to mask bounds to minimize image size
|
1051 |
+
y_indices, x_indices = np.where(first_mask)
|
1052 |
+
y_min, y_max = y_indices.min(), y_indices.max()
|
1053 |
+
x_min, x_max = x_indices.min(), x_indices.max()
|
1054 |
+
|
1055 |
+
# Crop the RGBA image
|
1056 |
+
cropped_rgba = rgba[y_min:y_max+1, x_min:x_max+1]
|
1057 |
+
|
1058 |
+
# Set extracted foreground for mask mover
|
1059 |
+
mask_mover.set_extracted_fg(cropped_rgba)
|
1060 |
+
|
1061 |
+
return annotated_frame, cropped_rgba, gr.update(visible=True), gr.update(visible=True)
|
1062 |
+
return annotated_frame, None, gr.update(visible=False), gr.update(visible=False)
|
1063 |
+
|
1064 |
|
1065 |
block = gr.Blocks().queue()
|
1066 |
with block:
|
|
|
1073 |
input_fg = gr.Image(type="numpy", label="Image", height=480)
|
1074 |
with gr.Row():
|
1075 |
with gr.Group():
|
1076 |
+
find_objects_button = gr.Button(value="(Option 1) Segment Object from text")
|
1077 |
+
text_prompt = gr.Textbox(
|
1078 |
+
label="Text Prompt",
|
1079 |
+
placeholder="Enter object classes separated by periods (e.g. 'car . person .'), leave empty to get all objects",
|
1080 |
+
value=""
|
1081 |
+
)
|
|
|
1082 |
extract_button = gr.Button(value="Remove Background")
|
1083 |
with gr.Row():
|
1084 |
+
extracted_objects = gr.Image(type="numpy", label="Extracted Foreground", height=480)
|
1085 |
extracted_fg = gr.Image(type="numpy", label="Extracted Foreground", height=480)
|
1086 |
|
1087 |
# output_bg = gr.Image(type="numpy", label="Preprocessed Foreground", height=480)
|
|
|
1142 |
relight_button.click(fn=process_relight, inputs=ips, outputs=[extracted_fg, result_gallery])
|
1143 |
example_quick_prompts.click(lambda x, y: ', '.join(y.split(', ')[:2] + [x[0]]), inputs=[example_quick_prompts, prompt], outputs=prompt, show_progress=False, queue=False)
|
1144 |
example_quick_subjects.click(lambda x: x[0], inputs=example_quick_subjects, outputs=prompt, show_progress=False, queue=False)
|
1145 |
+
find_objects_button.click(
|
1146 |
+
fn=process_image,
|
1147 |
+
inputs=[input_fg, text_prompt],
|
1148 |
+
outputs=[extracted_objects, extracted_fg]
|
1149 |
+
)
|
1150 |
extract_button.click(
|
1151 |
fn=extract_foreground,
|
1152 |
inputs=[input_fg],
|
|
|
1283 |
outputs=[extracted_fg, x_slider, y_slider]
|
1284 |
)
|
1285 |
|
1286 |
+
find_objects_button.click(
|
1287 |
+
fn=process_image,
|
1288 |
+
inputs=[input_image, text_prompt],
|
1289 |
+
outputs=[extracted_objects, extracted_fg, x_slider, y_slider]
|
1290 |
+
)
|
1291 |
|
1292 |
get_depth_button.click(
|
1293 |
fn=get_depth,
|