Ashoka74 committed
Commit 791131a
1 Parent(s): c280274

Update app.py

Files changed (1)
  1. app.py +213 -99
app.py CHANGED
@@ -823,10 +823,10 @@ def compress_image(image):
     return compressed_img
 
 @spaces.GPU(duration=60)
-@torch.inference_mode()
+@torch.inference_mode
 def process_image(input_image, input_text):
     """Main processing function for the Gradio interface"""
-
+
     # Initialize configs
     API_TOKEN = "9c8c865e10ec1821bea79d9fa9dc8720"
     SAM2_CHECKPOINT = "./checkpoints/sam2_hiera_large.pt"
@@ -835,6 +835,8 @@ def process_image(input_image, input_text):
     OUTPUT_DIR = Path("outputs/grounded_sam2_dinox_demo")
     OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
 
+
+
     # Initialize DDS client
     config = Config(API_TOKEN)
     client = Client(config)
@@ -850,102 +852,215 @@ def process_image(input_image, input_text):
     image_url = client.upload_file(tmpfile.name)
     os.remove(tmpfile.name)
 
-    # Run DINO-X detection
-    task = DinoxTask(
-        image_url=image_url,
-        prompts=[TextPrompt(text=input_text)]
-    )
-    client.run_task(task)
-    result = task.result
-    objects = result.objects
-
     # Process detection results
     input_boxes = []
+    masks = []
     confidences = []
     class_names = []
     class_ids = []
 
-    for obj in objects:
-        input_boxes.append(obj.bbox)
-        confidences.append(obj.score)
-        cls_name = obj.category.lower().strip()
-        class_names.append(cls_name)
-        class_ids.append(class_name_to_id[cls_name])
-
-    input_boxes = np.array(input_boxes)
-    class_ids = np.array(class_ids)
-
-    # Initialize SAM2
-    torch.autocast(device_type=DEVICE, dtype=torch.bfloat16).__enter__()
-    if torch.cuda.get_device_properties(0).major >= 8:
-        torch.backends.cuda.matmul.allow_tf32 = True
-        torch.backends.cudnn.allow_tf32 = True
-
-    sam2_model = build_sam2(SAM2_MODEL_CONFIG, SAM2_CHECKPOINT, device=DEVICE)
-    sam2_predictor = SAM2ImagePredictor(sam2_model)
-    sam2_predictor.set_image(input_image)
-
-    # sam2_predictor = run_sam_inference(SAM_IMAGE_MODEL, input_image, detections)
-
-
-    # Get masks from SAM2
-    masks, scores, logits = sam2_predictor.predict(
-        point_coords=None,
-        point_labels=None,
-        box=input_boxes,
-        multimask_output=False,
-    )
-    if masks.ndim == 4:
-        masks = masks.squeeze(1)
-
-    # Create visualization
-    labels = [f"{class_name} {confidence:.2f}"
-              for class_name, confidence in zip(class_names, confidences)]
-
-    detections = sv.Detections(
-        xyxy=input_boxes,
-        mask=masks.astype(bool),
-        class_id=class_ids
-    )
-
-    box_annotator = sv.BoxAnnotator()
-    label_annotator = sv.LabelAnnotator()
-    mask_annotator = sv.MaskAnnotator()
-
-    annotated_frame = input_image.copy()
-    annotated_frame = box_annotator.annotate(scene=annotated_frame, detections=detections)
-    annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)
-    annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
-
-    # Create transparent mask for first detected object
-    if len(detections) > 0:
-        # Get first mask
-        first_mask = detections.mask[0]
-
-        # Get original RGB image
-        img = input_image.copy()
-        H, W, C = img.shape
-
-        # Create RGBA image
-        alpha = np.zeros((H, W, 1), dtype=np.uint8)
-        alpha[first_mask] = 255
-        rgba = np.dstack((img, alpha)).astype(np.uint8)
-
-        # Crop to mask bounds to minimize image size
-        y_indices, x_indices = np.where(first_mask)
-        y_min, y_max = y_indices.min(), y_indices.max()
-        x_min, x_max = x_indices.min(), x_indices.max()
-
-        # Crop the RGBA image
-        cropped_rgba = rgba[y_min:y_max+1, x_min:x_max+1]
-
-        # Set extracted foreground for mask mover
-        mask_mover.set_extracted_fg(cropped_rgba)
-
-        return annotated_frame, cropped_rgba, gr.update(visible=True), gr.update(visible=True)
-
-    return annotated_frame, None, gr.update(visible=False), gr.update(visible=False)
-
+    if len(input_text) == 0:
+        task = DinoxTask(
+            image_url=image_url,
+            prompts=[TextPrompt(text="<prompt_free>")],
+            # targets=[DetectionTarget.BBox, DetectionTarget.Mask]
+        )
+
+        client.run_task(task)
+        predictions = task.result.objects
+        classes = [pred.category for pred in predictions]
+        classes = list(set(classes))
+        class_name_to_id = {name: id for id, name in enumerate(classes)}
+        class_id_to_name = {id: name for name, id in class_name_to_id.items()}
+
+        for idx, obj in enumerate(predictions):
+            input_boxes.append(obj.bbox)
+            masks.append(DetectionTask.rle2mask(DetectionTask.string2rle(obj.mask.counts), obj.mask.size))  # convert mask to np.array using DDS API
+            confidences.append(obj.score)
+            cls_name = obj.category.lower().strip()
+            class_names.append(cls_name)
+            class_ids.append(class_name_to_id[cls_name])
+
+        boxes = np.array(input_boxes)
+        masks = np.array(masks)
+        class_ids = np.array(class_ids)
+        labels = [
+            f"{class_name} {confidence:.2f}"
+            for class_name, confidence
+            in zip(class_names, confidences)
+        ]
+        detections = sv.Detections(
+            xyxy=boxes,
+            mask=masks.astype(bool),
+            class_id=class_ids
+        )
+
+        box_annotator = sv.BoxAnnotator()
+        label_annotator = sv.LabelAnnotator()
+        mask_annotator = sv.MaskAnnotator()
+
+        annotated_frame = input_image.copy()
+        annotated_frame = box_annotator.annotate(scene=annotated_frame, detections=detections)
+        annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)
+        annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
+
+        # Create transparent mask for first detected object
+        if len(detections) > 0:
+            # Get first mask
+            first_mask = detections.mask[0]
+
+            # Get original RGB image
+            img = input_image.copy()
+            H, W, C = img.shape
+
+            # Create RGBA image
+            alpha = np.zeros((H, W, 1), dtype=np.uint8)
+            alpha[first_mask] = 255
+            rgba = np.dstack((img, alpha)).astype(np.uint8)
+
+            # Crop to mask bounds to minimize image size
+            y_indices, x_indices = np.where(first_mask)
+            y_min, y_max = y_indices.min(), y_indices.max()
+            x_min, x_max = x_indices.min(), x_indices.max()
+
+            # Crop the RGBA image
+            cropped_rgba = rgba[y_min:y_max+1, x_min:x_max+1]
+
+            # Set extracted foreground for mask mover
+            mask_mover.set_extracted_fg(cropped_rgba)
+
+            return annotated_frame, cropped_rgba, gr.update(visible=True), gr.update(visible=True)
+
+
+    else:
+        # Run DINO-X detection
+        task = DinoxTask(
+            image_url=image_url,
+            prompts=[TextPrompt(text=input_text)],
+            targets=[DetectionTarget.BBox, DetectionTarget.Mask]
+        )
+
+        client.run_task(task)
+        result = task.result
+        objects = result.objects
+
+
+
+        # for obj in objects:
+        #     input_boxes.append(obj.bbox)
+        #     confidences.append(obj.score)
+        #     cls_name = obj.category.lower().strip()
+        #     class_names.append(cls_name)
+        #     class_ids.append(class_name_to_id[cls_name])
+
+        # input_boxes = np.array(input_boxes)
+        # class_ids = np.array(class_ids)
+
+        predictions = task.result.objects
+        classes = [x.strip().lower() for x in input_text.split('.') if x]
+        class_name_to_id = {name: id for id, name in enumerate(classes)}
+        class_id_to_name = {id: name for name, id in class_name_to_id.items()}
+
+        boxes = []
+        masks = []
+        confidences = []
+        class_names = []
+        class_ids = []
+
+        for idx, obj in enumerate(predictions):
+            boxes.append(obj.bbox)
+            masks.append(DetectionTask.rle2mask(DetectionTask.string2rle(obj.mask.counts), obj.mask.size))  # convert mask to np.array using DDS API
+            confidences.append(obj.score)
+            cls_name = obj.category.lower().strip()
+            class_names.append(cls_name)
+            class_ids.append(class_name_to_id[cls_name])
+
+        boxes = np.array(boxes)
+        masks = np.array(masks)
+        class_ids = np.array(class_ids)
+        labels = [
+            f"{class_name} {confidence:.2f}"
+            for class_name, confidence
+            in zip(class_names, confidences)
+        ]
+
+        # Initialize SAM2
+        # torch.autocast(device_type=DEVICE, dtype=torch.bfloat16).__enter__()
+        # if torch.cuda.get_device_properties(0).major >= 8:
+        #     torch.backends.cuda.matmul.allow_tf32 = True
+        #     torch.backends.cudnn.allow_tf32 = True
+
+        # sam2_model = build_sam2(SAM2_MODEL_CONFIG, SAM2_CHECKPOINT, device=DEVICE)
+        # sam2_predictor = SAM2ImagePredictor(sam2_model)
+        # sam2_predictor.set_image(input_image)
+
+        # sam2_predictor = run_sam_inference(SAM_IMAGE_MODEL, input_image, detections)
+
+
+        # Get masks from SAM2
+        # masks, scores, logits = sam2_predictor.predict(
+        #     point_coords=None,
+        #     point_labels=None,
+        #     box=input_boxes,
+        #     multimask_output=False,
+        # )
+
+        if masks.ndim == 4:
+            masks = masks.squeeze(1)
+
+        # Create visualization
+        # labels = [f"{class_name} {confidence:.2f}"
+        #           for class_name, confidence in zip(class_names, confidences)]
+
+        # detections = sv.Detections(
+        #     xyxy=input_boxes,
+        #     mask=masks.astype(bool),
+        #     class_id=class_ids
+        # )
+
+        detections = sv.Detections(
+            xyxy = boxes,
+            mask = masks.astype(bool),
+            class_id = class_ids,
+        )
+
+        box_annotator = sv.BoxAnnotator()
+        label_annotator = sv.LabelAnnotator()
+        mask_annotator = sv.MaskAnnotator()
+
+        annotated_frame = input_image.copy()
+        annotated_frame = box_annotator.annotate(scene=annotated_frame, detections=detections)
+        annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)
+        annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
+
+        # Create transparent mask for first detected object
+        if len(detections) > 0:
+            # Get first mask
+            first_mask = detections.mask[0]
+
+            # Get original RGB image
+            img = input_image.copy()
+            H, W, C = img.shape
+
+            # Create RGBA image
+            alpha = np.zeros((H, W, 1), dtype=np.uint8)
+            alpha[first_mask] = 255
+            rgba = np.dstack((img, alpha)).astype(np.uint8)
+
+            # Crop to mask bounds to minimize image size
+            y_indices, x_indices = np.where(first_mask)
+            y_min, y_max = y_indices.min(), y_indices.max()
+            x_min, x_max = x_indices.min(), x_indices.max()
+
+            # Crop the RGBA image
+            cropped_rgba = rgba[y_min:y_max+1, x_min:x_max+1]
+
+            # Set extracted foreground for mask mover
+            mask_mover.set_extracted_fg(cropped_rgba)
+
+            return annotated_frame, cropped_rgba, gr.update(visible=True), gr.update(visible=True)
+    return annotated_frame, None, gr.update(visible=False), gr.update(visible=False)
+
 
 block = gr.Blocks().queue()
 with block:
@@ -958,16 +1073,15 @@ with block:
                 input_fg = gr.Image(type="numpy", label="Image", height=480)
             with gr.Row():
                 with gr.Group():
-                    # find_objects_button = gr.Button(value="(Option 1) Segment Object from text")
-                    # text_prompt = gr.Textbox(
-                    #     label="Text Prompt",
-                    #     placeholder="Enter object classes separated by periods (e.g. 'car . person .')",
-                    #     value="couch . table ."
-                    # )
-
+                    find_objects_button = gr.Button(value="(Option 1) Segment Object from text")
+                    text_prompt = gr.Textbox(
+                        label="Text Prompt",
+                        placeholder="Enter object classes separated by periods (e.g. 'car . person .'), leave empty to get all objects",
+                        value=""
+                    )
                     extract_button = gr.Button(value="Remove Background")
             with gr.Row():
-                #extracted_objects = gr.Image(type="numpy", label="Extracted Foreground", height=480)
+                extracted_objects = gr.Image(type="numpy", label="Extracted Foreground", height=480)
                 extracted_fg = gr.Image(type="numpy", label="Extracted Foreground", height=480)
 
                 # output_bg = gr.Image(type="numpy", label="Preprocessed Foreground", height=480)
@@ -1028,11 +1142,11 @@ with block:
     relight_button.click(fn=process_relight, inputs=ips, outputs=[extracted_fg, result_gallery])
     example_quick_prompts.click(lambda x, y: ', '.join(y.split(', ')[:2] + [x[0]]), inputs=[example_quick_prompts, prompt], outputs=prompt, show_progress=False, queue=False)
     example_quick_subjects.click(lambda x: x[0], inputs=example_quick_subjects, outputs=prompt, show_progress=False, queue=False)
-    # find_objects_button.click(
-    #     fn=process_image,
-    #     inputs=[input_fg, text_prompt],
-    #     outputs=[extracted_objects, extracted_fg]
-    # )
+    find_objects_button.click(
+        fn=process_image,
+        inputs=[input_fg, text_prompt],
+        outputs=[extracted_objects, extracted_fg]
+    )
     extract_button.click(
         fn=extract_foreground,
         inputs=[input_fg],
@@ -1169,11 +1283,11 @@ with block:
        outputs=[extracted_fg, x_slider, y_slider]
    )
 
-    # find_objects_button.click(
-    #     fn=process_image,
-    #     inputs=[input_image, text_prompt],
-    #     outputs=[extracted_objects, extracted_fg, x_slider, y_slider]
-    # )
+    find_objects_button.click(
+        fn=process_image,
+        inputs=[input_image, text_prompt],
+        outputs=[extracted_objects, extracted_fg, x_slider, y_slider]
+    )
 
     get_depth_button.click(
         fn=get_depth,
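
Note: both branches of the updated process_image finish with the same cutout step seen in the hunk above: build an alpha channel from the first detection's boolean mask, stack it onto the RGB image, and crop the RGBA result to the mask's bounding box before passing it to mask_mover.set_extracted_fg. Below is a minimal, self-contained sketch of just that step; the helper name and the synthetic image/mask are illustrative, not part of app.py.

import numpy as np

def mask_to_cropped_rgba(img: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """img: HxWx3 uint8 RGB, mask: HxW bool. Returns an RGBA cutout cropped to the mask bounds."""
    H, W, _ = img.shape
    alpha = np.zeros((H, W, 1), dtype=np.uint8)
    alpha[mask] = 255                         # opaque where the mask is set
    rgba = np.dstack((img, alpha))            # HxWx4
    y_idx, x_idx = np.where(mask)             # bounding box of the mask
    y_min, y_max = y_idx.min(), y_idx.max()
    x_min, x_max = x_idx.min(), x_idx.max()
    return rgba[y_min:y_max + 1, x_min:x_max + 1]

if __name__ == "__main__":
    img = np.zeros((64, 64, 3), dtype=np.uint8)      # synthetic stand-in for input_image
    mask = np.zeros((64, 64), dtype=bool)             # synthetic stand-in for detections.mask[0]
    mask[16:48, 20:40] = True
    print(mask_to_cropped_rgba(img, mask).shape)      # (32, 20, 4)

Cropping to the mask bounds keeps the extracted foreground as small as possible, which is what the "# Crop to mask bounds to minimize image size" comment in the diff is after.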