JohanDL committed
Commit 1fafe10
1 Parent(s): cd51db0

adding demo folder

demo/__pycache__/utils.cpython-311.pyc ADDED
Binary file (11.1 kB).
 
demo/utils.py ADDED
@@ -0,0 +1,183 @@
+ import torch
+ import numpy as np
+ from collections import defaultdict
+ 
+ from mmdet.models.task_modules.assigners import BboxOverlaps2D
+ from mmengine.structures import InstanceData
+ def average_score_filter(instances_list):
+     # Extract instance IDs and their scores
+     instance_id_to_frames = defaultdict(list)
+     instance_id_to_scores = defaultdict(list)
+     for frame_idx, instances in enumerate(instances_list):
+         for i, instance_id in enumerate(instances[0].pred_track_instances.instances_id):
+             instance_id_to_frames[instance_id.item()].append(frame_idx)
+             instance_id_to_scores[instance_id.item()].append(instances[0].pred_track_instances.scores[i].cpu().numpy())
+ 
+     # Compute average scores for each segment of each instance ID
+     for instance_id, frames in instance_id_to_frames.items():
+         scores = np.array(instance_id_to_scores[instance_id])
+ 
+         # Identify segments
+         segments = []
+         segment = [frames[0]]
+         for idx in range(1, len(frames)):
+             if frames[idx] == frames[idx - 1] + 1:
+                 segment.append(frames[idx])
+             else:
+                 segments.append(segment)
+                 segment = [frames[idx]]
+         segments.append(segment)
+ 
+         # Compute average score for each segment
+         avg_scores = np.copy(scores)
+         for segment in segments:
+             segment_scores = scores[frames.index(segment[0]):frames.index(segment[-1]) + 1]
+             avg_score = np.mean(segment_scores)
+             avg_scores[frames.index(segment[0]):frames.index(segment[-1]) + 1] = avg_score
+ 
+         # Update instances_list with average scores
+         for frame_idx, avg_score in zip(frames, avg_scores):
+             instances_list[frame_idx][0].pred_track_instances.scores[
+                 instances_list[frame_idx][0].pred_track_instances.instances_id == instance_id] = torch.tensor(avg_score, dtype=instances_list[frame_idx][0].pred_track_instances.scores.dtype)
+ 
+     return instances_list
+ 
+ 
+ def moving_average_filter(instances_list, window_size=5):
+     # Helper function to compute the moving average
+     def smooth_bbox(bboxes, window_size):
+         smoothed_bboxes = np.copy(bboxes)
+         half_window = window_size // 2
+         for i in range(4):
+             padded_bboxes = np.pad(bboxes[:, i], (half_window, half_window), mode='edge')
+             smoothed_bboxes[:, i] = np.convolve(padded_bboxes, np.ones(window_size) / window_size, mode='valid')
+         return smoothed_bboxes
+ 
+     # Extract bounding boxes and instance IDs
+     instance_id_to_frames = defaultdict(list)
+     instance_id_to_bboxes = defaultdict(list)
+     for frame_idx, instances in enumerate(instances_list):
+         for i, instance_id in enumerate(instances[0].pred_track_instances.instances_id):
+             instance_id_to_frames[instance_id.item()].append(frame_idx)
+             instance_id_to_bboxes[instance_id.item()].append(instances[0].pred_track_instances.bboxes[i].cpu().numpy())
+ 
+     # Apply moving average filter to each segment
+     for instance_id, frames in instance_id_to_frames.items():
+         bboxes = np.array(instance_id_to_bboxes[instance_id])
+ 
+         # Identify segments
+         segments = []
+         segment = [frames[0]]
+         for idx in range(1, len(frames)):
+             if frames[idx] == frames[idx - 1] + 1:
+                 segment.append(frames[idx])
+             else:
+                 segments.append(segment)
+                 segment = [frames[idx]]
+         segments.append(segment)
+ 
+         # Smooth bounding boxes for each segment
+         smoothed_bboxes = np.copy(bboxes)
+         for segment in segments:
+             if len(segment) >= window_size:
+                 segment_bboxes = bboxes[frames.index(segment[0]):frames.index(segment[-1]) + 1]
+                 smoothed_segment_bboxes = smooth_bbox(segment_bboxes, window_size)
+                 smoothed_bboxes[frames.index(segment[0]):frames.index(segment[-1]) + 1] = smoothed_segment_bboxes
+ 
+         # Update instances_list with smoothed bounding boxes
+         for frame_idx, smoothed_bbox in zip(frames, smoothed_bboxes):
+             instances_list[frame_idx][0].pred_track_instances.bboxes[
+                 instances_list[frame_idx][0].pred_track_instances.instances_id == instance_id] = torch.tensor(smoothed_bbox, dtype=instances_list[frame_idx][0].pred_track_instances.bboxes.dtype).to(instances_list[frame_idx][0].pred_track_instances.bboxes.device)
+ 
+     return instances_list
+ 
+ 
+ def identify_and_remove_giant_bounding_boxes(instances_list, image_size, size_threshold, confidence_threshold,
+                                              coverage_threshold, object_num_thr=4, max_objects_in_box=6):
+     # Initialize BboxOverlaps2D; the 'iof' mode is selected per call below
+     bbox_overlaps_calculator = BboxOverlaps2D()
+ 
+     # Initialize data structures
+     invalid_instance_ids = set()
+ 
+     image_width, image_height = image_size
+     two_thirds_image_area = (2 / 3) * (image_width * image_height)
+ 
+     # Step 1: Identify giant bounding boxes and record their instance_ids
+     for frame_idx, instances in enumerate(instances_list):
+         bounding_boxes = instances[0].pred_track_instances.bboxes
+         confidence_scores = instances[0].pred_track_instances.scores
+         instance_ids = instances[0].pred_track_instances.instances_id
+ 
+         N = bounding_boxes.size(0)
+ 
+         for i in range(N):
+             current_box = bounding_boxes[i]
+             box_size = (current_box[2] - current_box[0]) * (current_box[3] - current_box[1])
+ 
+             if box_size < size_threshold:
+                 continue
+ 
+             other_boxes = torch.cat([bounding_boxes[:i], bounding_boxes[i + 1:]])
+             other_confidences = torch.cat([confidence_scores[:i], confidence_scores[i + 1:]])
+             iofs = bbox_overlaps_calculator(other_boxes, current_box.unsqueeze(0), mode='iof', is_aligned=False)
+ 
+             if iofs.numel() == 0:
+                 continue
+ 
+             high_conf_mask = other_confidences > confidence_threshold
+ 
+             if high_conf_mask.numel() == 0 or torch.sum(high_conf_mask) == 0:
+                 continue
+ 
+             high_conf_masked_iofs = iofs[high_conf_mask]
+ 
+             covered_high_conf_boxes_count = torch.sum(high_conf_masked_iofs > coverage_threshold)
+ 
+             if covered_high_conf_boxes_count >= object_num_thr and torch.all(
+                     confidence_scores[i] < other_confidences[high_conf_mask]):
+                 invalid_instance_ids.add(instance_ids[i].item())
+                 continue
+ 
+             if box_size > two_thirds_image_area:
+                 invalid_instance_ids.add(instance_ids[i].item())
+                 continue
+ 
+             # Also invalidate boxes that cover more than max_objects_in_box high-confidence objects
+             if covered_high_conf_boxes_count > max_objects_in_box:
+                 invalid_instance_ids.add(instance_ids[i].item())
+                 continue
+ 
+     # Remove invalid tracks
+     for frame_idx, instances in enumerate(instances_list):
+         valid_mask = torch.tensor(
+             [instance_id.item() not in invalid_instance_ids for instance_id in
+              instances[0].pred_track_instances.instances_id])
+         if len(valid_mask) == 0:
+             continue
+         new_instance_data = InstanceData()
+         new_instance_data.bboxes = instances[0].pred_track_instances.bboxes[valid_mask]
+         new_instance_data.scores = instances[0].pred_track_instances.scores[valid_mask]
+         new_instance_data.instances_id = instances[0].pred_track_instances.instances_id[valid_mask]
+         new_instance_data.labels = instances[0].pred_track_instances.labels[valid_mask]
+         if 'masks' in instances[0].pred_track_instances:
+             new_instance_data.masks = instances[0].pred_track_instances.masks[valid_mask]
+         instances[0].pred_track_instances = new_instance_data
+ 
+     return instances_list
+ 
+ 
+ def filter_and_update_tracks(instances_list, image_size, size_threshold=10000, coverage_threshold=0.75,
+                              confidence_threshold=0.2, smoothing_window_size=5):
+ 
+     # Step 1: Identify and remove giant bounding boxes
+     instances_list = identify_and_remove_giant_bounding_boxes(instances_list, image_size, size_threshold, confidence_threshold, coverage_threshold)
+ 
+     # Step 2: Smooth interpolated bounding boxes
+     instances_list = moving_average_filter(instances_list, window_size=smoothing_window_size)
+ 
+     # Step 3: Compute the track average score
+     instances_list = average_score_filter(instances_list)
+ 
+ 
+     return instances_list
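For reference, a minimal usage sketch of the helpers above, assuming the per-frame layout produced by the demo script below (each entry of instances_list is indexable, and element [0] exposes a pred_track_instances InstanceData with bboxes, scores, instances_id and labels). The _FrameResult wrapper and the toy values are hypothetical stand-ins for mmdet's per-frame track results, not part of the committed code:

import torch
from mmengine.structures import InstanceData
from utils import filter_and_update_tracks

class _FrameResult:
    # Hypothetical stand-in for a per-frame track result: indexing [0]
    # returns an object exposing .pred_track_instances, as utils.py expects.
    def __init__(self, inst):
        self.pred_track_instances = inst

    def __getitem__(self, idx):
        return self

def _frame(bboxes, scores, ids, labels):
    # Build one frame of toy tracking output.
    inst = InstanceData()
    inst.bboxes = torch.tensor(bboxes, dtype=torch.float32)
    inst.scores = torch.tensor(scores, dtype=torch.float32)
    inst.instances_id = torch.tensor(ids)
    inst.labels = torch.tensor(labels)
    return _FrameResult(inst)

# Two consecutive frames of a single track (id 1); values are illustrative only.
instances_list = [
    _frame([[10., 10., 60., 60.]], [0.9], [1], [0]),
    _frame([[12., 11., 62., 61.]], [0.7], [1], [0]),
]
instances_list = filter_and_update_tracks(instances_list, image_size=(1920, 1080))
print(instances_list[0].pred_track_instances.scores)  # scores averaged over the track segment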
demo/video_demo_with_text.py ADDED
@@ -0,0 +1,254 @@
+ import os
+ import sys
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
+ project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
+ sys.path.insert(0, project_root)
+ 
+ import gc
+ import resource
+ import argparse
+ import cv2
+ import tqdm
+ 
+ import torch
+ from torch.multiprocessing import Pool, set_start_method
+ 
+ import mmcv
+ from mmcv.transforms import Compose
+ from mmengine.utils import track_iter_progress
+ from mmdet.apis import init_detector
+ from mmdet.registry import VISUALIZERS
+ from mmcv.ops.nms import batched_nms
+ 
+ import masa
+ from masa.apis import inference_masa, init_masa, inference_detector, build_test_pipeline
+ from masa.models.sam import SamPredictor, sam_model_registry
+ from utils import filter_and_update_tracks
+ 
+ import warnings
+ warnings.filterwarnings('ignore')
+ 
+ # Ensure the right start method for multiprocessing
+ try:
+     set_start_method('spawn')
+ except RuntimeError:
+     pass
+ 
+ def set_file_descriptor_limit(limit):
+     soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
+     resource.setrlimit(resource.RLIMIT_NOFILE, (limit, hard))
+ 
+ # Set the file descriptor limit to 65536
+ set_file_descriptor_limit(65536)
+ 
+ def visualize_frame(args, visualizer, frame, track_result, frame_idx, fps=None):
+     visualizer.add_datasample(
+         name='video_' + str(frame_idx),
+         image=frame[:, :, ::-1],
+         data_sample=track_result[0],
+         draw_gt=False,
+         show=False,
+         out_file=None,
+         pred_score_thr=args.score_thr,
+         fps=fps,)
+     frame = visualizer.get_image()
+     gc.collect()
+     return frame
+ 
+ def parse_args():
+ 
+     parser = argparse.ArgumentParser(description='MASA video demo')
+     parser.add_argument('video', help='Video file')
+     parser.add_argument('--det_config', help='Detector Config file')
+     parser.add_argument('--masa_config', help='Masa Config file')
+     parser.add_argument('--det_checkpoint', help='Detector Checkpoint file')
+     parser.add_argument('--masa_checkpoint', help='Masa Checkpoint file')
+     parser.add_argument('--device', default='cuda:0', help='Device used for inference')
+     parser.add_argument('--score-thr', type=float, default=0.2, help='Bbox score threshold')
+     parser.add_argument('--out', type=str, help='Output video file')
+     parser.add_argument('--save_dir', type=str, help='Output directory for video frames')
+     parser.add_argument('--texts', help='Text prompt')
+     parser.add_argument('--line_width', type=int, default=5, help='Line width')
+     parser.add_argument('--unified', action='store_true', help='Use unified model, which means the masa adapter is built upon the detector model.')
+     parser.add_argument('--detector_type', type=str, default='mmdet', help='Choose detector type')
+     parser.add_argument('--fp16', action='store_true', help='Activate fp16 mode')
+     parser.add_argument('--no-post', action='store_true', help='Do not post-process the results')
+     parser.add_argument('--show_fps', action='store_true', help='Visualize the fps')
+     parser.add_argument('--sam_mask', action='store_true', help='Use SAM to generate mask for segmentation tracking')
+     parser.add_argument('--sam_path', type=str, default='saved_models/pretrain_weights/sam_vit_h_4b8939.pth', help='Default path for SAM models')
+     parser.add_argument('--sam_type', type=str, default='vit_h', help='Default type for SAM models')
+     parser.add_argument(
+         '--wait-time',
+         type=float,
+         default=1,
+         help='The interval of show (s), 0 is block')
+     args = parser.parse_args()
+     return args
+ 
+ def main():
+     args = parse_args()
+     assert args.out, \
+         ('Please specify the output video file with '
+          'the argument "--out"')
+ 
+     # build the model from a config file and a checkpoint file
+     if args.unified:
+         masa_model = init_masa(args.masa_config, args.masa_checkpoint, device=args.device)
+     else:
+         det_model = init_detector(args.det_config, args.det_checkpoint, palette='random', device=args.device)
+         masa_model = init_masa(args.masa_config, args.masa_checkpoint, device=args.device)
+         # build test pipeline
+         det_model.cfg.test_dataloader.dataset.pipeline[
+             0].type = 'mmdet.LoadImageFromNDArray'
+         test_pipeline = Compose(det_model.cfg.test_dataloader.dataset.pipeline)
+ 
+     if args.sam_mask:
+         print('Loading SAM model...')
+         device = args.device
+         sam_model = sam_model_registry[args.sam_type](args.sam_path)
+         sam_predictor = SamPredictor(sam_model.to(device))
+ 
+     video_reader = mmcv.VideoReader(args.video)
+     video_writer = None
+ 
+     #### parsing the text input
+     texts = args.texts
+     if texts is not None:
+         masa_test_pipeline = build_test_pipeline(masa_model.cfg, with_text=True)
+     else:
+         masa_test_pipeline = build_test_pipeline(masa_model.cfg)
+ 
+     if texts is not None:
+         masa_model.cfg.visualizer['texts'] = texts
+     else:
+         masa_model.cfg.visualizer['texts'] = det_model.dataset_meta['classes']
+ 
+     # init visualizer
+     masa_model.cfg.visualizer['save_dir'] = args.save_dir
+     masa_model.cfg.visualizer['line_width'] = args.line_width
+     if args.sam_mask:
+         masa_model.cfg.visualizer['alpha'] = 0.5
+     visualizer = VISUALIZERS.build(masa_model.cfg.visualizer)
+ 
+     if args.out:
+         fourcc = cv2.VideoWriter_fourcc(*'mp4v')
+         video_writer = cv2.VideoWriter(
+             args.out, fourcc, video_reader.fps,
+             (video_reader.width, video_reader.height))
+ 
+     frame_idx = 0
+     instances_list = []
+     frames = []
+     fps_list = []
+     for frame in track_iter_progress((video_reader, len(video_reader))):
+ 
+         # Unified models mean that MASA builds upon and reuses the foundation model's backbone features for tracking
+         if args.unified:
+             track_result = inference_masa(masa_model, frame,
+                                           frame_id=frame_idx,
+                                           video_len=len(video_reader),
+                                           test_pipeline=masa_test_pipeline,
+                                           text_prompt=texts,
+                                           fp16=args.fp16,
+                                           detector_type=args.detector_type,
+                                           show_fps=args.show_fps)
+             if args.show_fps:
+                 track_result, fps = track_result
+         else:
+ 
+             if args.detector_type == 'mmdet':
+                 result = inference_detector(det_model, frame,
+                                             text_prompt=texts,
+                                             test_pipeline=test_pipeline,
+                                             fp16=args.fp16)
+ 
+             # Perform class-agnostic (inter-class) NMS to remove noisy detections
+             det_bboxes, keep_idx = batched_nms(boxes=result.pred_instances.bboxes,
+                                                scores=result.pred_instances.scores,
+                                                idxs=result.pred_instances.labels,
+                                                class_agnostic=True,
+                                                nms_cfg=dict(type='nms',
+                                                             iou_threshold=0.5,
+                                                             class_agnostic=True,
+                                                             split_thr=100000))
+ 
+             det_bboxes = torch.cat([det_bboxes,
+                                     result.pred_instances.scores[keep_idx].unsqueeze(1)],
+                                    dim=1)
+             det_labels = result.pred_instances.labels[keep_idx]
+ 
+             track_result = inference_masa(masa_model, frame, frame_id=frame_idx,
+                                           video_len=len(video_reader),
+                                           test_pipeline=masa_test_pipeline,
+                                           det_bboxes=det_bboxes,
+                                           det_labels=det_labels,
+                                           fp16=args.fp16,
+                                           show_fps=args.show_fps)
+             if args.show_fps:
+                 track_result, fps = track_result
+ 
+         frame_idx += 1
+         if 'masks' in track_result[0].pred_track_instances:
+             if len(track_result[0].pred_track_instances.masks) > 0:
+                 track_result[0].pred_track_instances.masks = torch.stack(track_result[0].pred_track_instances.masks, dim=0)
+                 track_result[0].pred_track_instances.masks = track_result[0].pred_track_instances.masks.cpu().numpy()
+ 
+         track_result[0].pred_track_instances.bboxes = track_result[0].pred_track_instances.bboxes.to(torch.float32)
+         instances_list.append(track_result.to('cpu'))
+         frames.append(frame)
+         if args.show_fps:
+             fps_list.append(fps)
+ 
+     if not args.no_post:
+         instances_list = filter_and_update_tracks(instances_list, (frame.shape[1], frame.shape[0]))
+ 
+     if args.sam_mask:
+         print('Start to generate mask using SAM!')
+         for idx, (frame, track_result) in tqdm.tqdm(enumerate(zip(frames, instances_list))):
+             track_result = track_result.to(device)
+             track_result[0].pred_track_instances.instances_id = track_result[0].pred_track_instances.instances_id.to(device)
+             track_result[0].pred_track_instances = track_result[0].pred_track_instances[(track_result[0].pred_track_instances.scores.float() > args.score_thr).to(device)]
+             input_boxes = track_result[0].pred_track_instances.bboxes
+             if len(input_boxes) == 0:
+                 continue
+             sam_predictor.set_image(frame)
+             transformed_boxes = sam_predictor.transform.apply_boxes_torch(input_boxes, frame.shape[:2])
+             masks, _, _ = sam_predictor.predict_torch(
+                 point_coords=None,
+                 point_labels=None,
+                 boxes=transformed_boxes,
+                 multimask_output=False,
+             )
+             track_result[0].pred_track_instances.masks = masks.squeeze(1).cpu().numpy()
+             instances_list[idx] = track_result
+ 
+ 
+ 
+     if args.out:
+         print('Start to visualize the results...')
+         num_cores = max(1, min(os.cpu_count() - 1, 16))
+         print('Using {} cores for visualization'.format(num_cores))
+ 
+         if args.show_fps:
+             with Pool(processes=num_cores) as pool:
+ 
+                 frames = pool.starmap(
+                     visualize_frame, [(args, visualizer, frame, track_result.to('cpu'), idx, fps) for idx, (frame, fps, track_result) in enumerate(zip(frames, fps_list, instances_list))]
+                 )
+         else:
+             with Pool(processes=num_cores) as pool:
+                 frames = pool.starmap(
+                     visualize_frame, [(args, visualizer, frame, track_result.to('cpu'), idx) for idx, (frame, track_result) in
+                                       enumerate(zip(frames, instances_list))]
+                 )
+         for frame in frames:
+             if args.out:
+                 video_writer.write(frame[:, :, ::-1])
+ 
+     if video_writer:
+         video_writer.release()
+     print('Done')
+ 
+ 
+ if __name__ == '__main__':
+     main()
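For reference, the demo above is typically launched from the repository root along these lines; the config, checkpoint and video paths are placeholders (they depend on which detector and MASA weights you have downloaded), and the exact text-prompt format depends on the underlying open-vocabulary detector:

python demo/video_demo_with_text.py path/to/input_video.mp4 \
    --det_config path/to/detector_config.py \
    --det_checkpoint path/to/detector_checkpoint.pth \
    --masa_config path/to/masa_config.py \
    --masa_checkpoint path/to/masa_checkpoint.pth \
    --texts "person . car" \
    --score-thr 0.3 \
    --out outputs/tracked_video.mp4 \
    --show_fps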