from collections import defaultdict

import numpy as np
import torch
from mmdet.models.task_modules.assigners import BboxOverlaps2D
from mmengine.structures import InstanceData


def average_score_filter(instances_list):
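    """Replace each track's per-frame scores with per-segment average scores.

    For every track id, contiguous runs of frames (segments) are located and
    all scores inside a segment are replaced by that segment's mean score.
    ``instances_list`` is modified in place and also returned.
    """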
    # Extract instance IDs and their scores
    instance_id_to_frames = defaultdict(list)
    instance_id_to_scores = defaultdict(list)
    for frame_idx, instances in enumerate(instances_list):
        track_instances = instances[0].pred_track_instances
        for i, instance_id in enumerate(track_instances.instances_id):
            instance_id_to_frames[instance_id.item()].append(frame_idx)
            instance_id_to_scores[instance_id.item()].append(track_instances.scores[i].cpu().numpy())

    # Compute average scores for each segment of each instance ID
    for instance_id, frames in instance_id_to_frames.items():
        scores = np.array(instance_id_to_scores[instance_id])

        # Identify segments: maximal runs of consecutive frames containing this track
        segments = []
        segment = [frames[0]]
        for idx in range(1, len(frames)):
            if frames[idx] == frames[idx - 1] + 1:
                segment.append(frames[idx])
            else:
                segments.append(segment)
                segment = [frames[idx]]
        segments.append(segment)

        # Compute average score for each segment
        avg_scores = np.copy(scores)
        for segment in segments:
            segment_scores = scores[frames.index(segment[0]):frames.index(segment[-1]) + 1]
            avg_score = np.mean(segment_scores)
            avg_scores[frames.index(segment[0]):frames.index(segment[-1]) + 1] = avg_score

        # Write the per-segment average scores back into instances_list
        for frame_idx, avg_score in zip(frames, avg_scores):
            track_instances = instances_list[frame_idx][0].pred_track_instances
            track_instances.scores[track_instances.instances_id == instance_id] = torch.tensor(
                avg_score, dtype=track_instances.scores.dtype).to(track_instances.scores.device)

    return instances_list


def moving_average_filter(instances_list, window_size=5):
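    """Smooth each track's bounding boxes with a moving-average filter.

    For every track id, contiguous runs of frames (segments) are located and,
    for segments spanning at least ``window_size`` frames, every bbox
    coordinate is smoothed with an edge-padded moving average of length
    ``window_size``. ``instances_list`` is modified in place and also returned.
    """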
    # Helper: smooth each of the four bbox coordinates with a moving average
    def smooth_bbox(bboxes, window_size):
        smoothed_bboxes = np.copy(bboxes)
        half_window = window_size // 2
        for i in range(4):
            # Edge-pad so the 'valid' convolution output has exactly len(bboxes) entries
            # (asymmetric padding keeps the length correct for even window sizes too)
            padded_bboxes = np.pad(bboxes[:, i],
                                   (half_window, window_size - 1 - half_window),
                                   mode='edge')
            smoothed_bboxes[:, i] = np.convolve(padded_bboxes,
                                                np.ones(window_size) / window_size,
                                                mode='valid')
        return smoothed_bboxes

    # Extract bounding boxes and instance IDs
    instance_id_to_frames = defaultdict(list)
    instance_id_to_bboxes = defaultdict(list)
    for frame_idx, instances in enumerate(instances_list):
        track_instances = instances[0].pred_track_instances
        for i, instance_id in enumerate(track_instances.instances_id):
            instance_id_to_frames[instance_id.item()].append(frame_idx)
            instance_id_to_bboxes[instance_id.item()].append(track_instances.bboxes[i].cpu().numpy())

    # Apply moving average filter to each segment
    for instance_id, frames in instance_id_to_frames.items():
        bboxes = np.array(instance_id_to_bboxes[instance_id])

        # Identify segments: maximal runs of consecutive frames containing this track
        segments = []
        segment = [frames[0]]
        for idx in range(1, len(frames)):
            if frames[idx] == frames[idx - 1] + 1:
                segment.append(frames[idx])
            else:
                segments.append(segment)
                segment = [frames[idx]]
        segments.append(segment)

        # Smooth bounding boxes for each segment
        smoothed_bboxes = np.copy(bboxes)
        for segment in segments:
            if len(segment) >= window_size:
                segment_bboxes = bboxes[frames.index(segment[0]):frames.index(segment[-1]) + 1]
                smoothed_segment_bboxes = smooth_bbox(segment_bboxes, window_size)
                smoothed_bboxes[frames.index(segment[0]):frames.index(segment[-1]) + 1] = smoothed_segment_bboxes

        # Write the smoothed bounding boxes back into instances_list
        for frame_idx, smoothed_bbox in zip(frames, smoothed_bboxes):
            track_instances = instances_list[frame_idx][0].pred_track_instances
            track_instances.bboxes[track_instances.instances_id == instance_id] = torch.tensor(
                smoothed_bbox, dtype=track_instances.bboxes.dtype).to(track_instances.bboxes.device)

    return instances_list


def identify_and_remove_giant_bounding_boxes(instances_list, image_size, size_threshold, confidence_threshold,
                                             coverage_threshold, object_num_thr=4, max_objects_in_box=6):
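    """Remove tracks that contain spurious "giant" bounding boxes.

    A box with area of at least ``size_threshold`` is flagged when, relative
    to the other boxes in the same frame whose score exceeds
    ``confidence_threshold``, it either
    (a) covers (IoF > ``coverage_threshold``) at least ``object_num_thr`` of
        them while scoring below all of them,
    (b) exceeds two thirds of the image area, or
    (c) covers more than ``max_objects_in_box`` of them.
    Any track id flagged in any frame is removed from every frame.
    ``image_size`` is ``(width, height)`` in pixels.
    """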
    # Overlap calculator; 'iof' (intersection over foreground) mode is passed at call time
    bbox_overlaps_calculator = BboxOverlaps2D()

    # Initialize data structures
    invalid_instance_ids = set()

    image_width, image_height = image_size
    two_thirds_image_area = (2 / 3) * (image_width * image_height)

    # Step 1: Identify giant bounding boxes and record their instance_ids
    for frame_idx, instances in enumerate(instances_list):
        bounding_boxes = instances[0].pred_track_instances.bboxes
        confidence_scores = instances[0].pred_track_instances.scores
        instance_ids = instances[0].pred_track_instances.instances_id

        N = bounding_boxes.size(0)

        for i in range(N):
            current_box = bounding_boxes[i]
            box_size = (current_box[2] - current_box[0]) * (current_box[3] - current_box[1])

            if box_size < size_threshold:
                continue

            other_boxes = torch.cat([bounding_boxes[:i], bounding_boxes[i + 1:]])
            other_confidences = torch.cat([confidence_scores[:i], confidence_scores[i + 1:]])
            iofs = bbox_overlaps_calculator(other_boxes, current_box.unsqueeze(0), mode='iof', is_aligned=False)

            if iofs.numel() == 0:
                continue

            high_conf_mask = other_confidences > confidence_threshold

            if high_conf_mask.numel() == 0 or torch.sum(high_conf_mask) == 0:
                continue

            high_conf_masked_iofs = iofs[high_conf_mask]

            covered_high_conf_boxes_count = torch.sum(high_conf_masked_iofs > coverage_threshold)

            if covered_high_conf_boxes_count >= object_num_thr and torch.all(
                    confidence_scores[i] < other_confidences[high_conf_mask]):
                invalid_instance_ids.add(instance_ids[i].item())
                continue

            if box_size > two_thirds_image_area:
                invalid_instance_ids.add(instance_ids[i].item())
                continue

            # Flag boxes that cover more than ``max_objects_in_box`` high-confidence detections
            if covered_high_conf_boxes_count > max_objects_in_box:
                invalid_instance_ids.add(instance_ids[i].item())
                continue

    # Step 2: Remove the flagged tracks from every frame
    for frame_idx, instances in enumerate(instances_list):
        track_instances = instances[0].pred_track_instances
        valid_mask = torch.tensor(
            [instance_id.item() not in invalid_instance_ids
             for instance_id in track_instances.instances_id])
        if len(valid_mask) == 0:
            continue
        new_instance_data = InstanceData()
        new_instance_data.bboxes = track_instances.bboxes[valid_mask]
        new_instance_data.scores = track_instances.scores[valid_mask]
        new_instance_data.instances_id = track_instances.instances_id[valid_mask]
        new_instance_data.labels = track_instances.labels[valid_mask]
        if 'masks' in track_instances:
            new_instance_data.masks = track_instances.masks[valid_mask]
        instances[0].pred_track_instances = new_instance_data

    return instances_list


def filter_and_update_tracks(instances_list, image_size, size_threshold=10000, coverage_threshold=0.75,
                             confidence_threshold=0.2, smoothing_window_size=5):
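    """Run the full track post-processing pipeline.

    The list is processed in place: tracks with giant bounding boxes are
    removed, the remaining boxes are smoothed with a moving average, and
    per-frame scores are replaced by per-segment averages. ``image_size`` is
    ``(width, height)`` in pixels.
    """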

    # Step 1: Identify and remove giant bounding boxes
    instances_list = identify_and_remove_giant_bounding_boxes(instances_list, image_size, size_threshold, confidence_threshold, coverage_threshold)

    # Step 2: Smooth the tracked bounding boxes with a moving-average filter
    instances_list = moving_average_filter(instances_list, window_size=smoothing_window_size)

    # Step 3: Replace per-frame scores with per-track segment averages
    instances_list = average_score_filter(instances_list)

    return instances_list
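

# Example usage (illustrative sketch, not part of the original pipeline):
# `instances_list` is assumed to be a per-frame list in which each element is
# a sequence whose first item exposes `pred_track_instances` with `bboxes`,
# `scores`, `instances_id` and `labels`, as produced by an
# MMDetection/MMTracking-style video tracker. `video_results` and the helper
# `run_tracker_on_video` below are hypothetical placeholders.
#
#   video_results = run_tracker_on_video('clip.mp4')  # hypothetical helper
#   video_results = filter_and_update_tracks(
#       video_results,
#       image_size=(1920, 1080),   # (width, height) of the frames in pixels
#       size_threshold=10000,      # minimum box area (px^2) for the giant-box check
#       coverage_threshold=0.75,   # IoF above which a detection counts as covered
#       confidence_threshold=0.2,  # minimum score for a covered detection
#       smoothing_window_size=5,   # moving-average window in frames
#   )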