Spaces:
Build error
Build error
import os | |
from .burst_helpers.burst_base import BURSTBase | |
from .burst_helpers.format_converter import GroundTruthBURSTFormatToTAOFormatConverter, PredictionBURSTFormatToTAOFormatConverter | |
from .. import utils | |
class BURST(BURSTBase):
    """Dataset class for BURST tracking evaluation.

    BURST-format ground truth and predictions are converted to TAO format by
    the ``_postproc_*`` hooks below. Detections are matched by segmentation
    mask IoU. These hooks are presumably invoked by ``BURSTBase``; see that
    class for the evaluation flow.
    """

    @staticmethod
    def get_default_dataset_config():
        """Return the default dataset configuration for BURST.

        Starts from ``BURSTBase.get_default_dataset_config()`` and overrides:
          * ``GT_FOLDER``: ``data/gt/burst/val/`` under ``utils.get_code_path()``
          * ``TRACKERS_FOLDER``: ``data/trackers/burst/class-guided/`` under
            ``utils.get_code_path()``
          * ``EXEMPLAR_GUIDED``: ``False``

        Declared as a ``staticmethod`` so it works both as
        ``BURST.get_default_dataset_config()`` and as
        ``self.get_default_dataset_config()``. Without the decorator the
        instance form raises ``TypeError``, because the function takes no
        arguments.
        """
        tao_config = BURSTBase.get_default_dataset_config()
        code_path = utils.get_code_path()

        # Location of GT data (e.g. 'data/gt/tsunami/exemplar_guided/').
        tao_config['GT_FOLDER'] = os.path.join(
            code_path, 'data/gt/burst/val/')
        # Trackers location
        # (e.g. 'data/trackers/tsunami/exemplar_guided/mask_guided/validation/').
        tao_config['TRACKERS_FOLDER'] = os.path.join(
            code_path, 'data/trackers/burst/class-guided/')
        # Set to True or False.
        tao_config['EXEMPLAR_GUIDED'] = False
        return tao_config

    def _iou_type(self):
        """Return the IoU type used for matching: segmentation masks."""
        return 'mask'

    def _box_or_mask_from_det(self, det):
        """Return the detection's ``'segmentation'`` entry (masks, not boxes)."""
        return det['segmentation']

    def _calculate_area_for_ann(self, ann):
        """Return the area of ``ann["segmentation"]`` via ``pycocotools.mask.area``.

        The segmentation must be in a form ``pycocotools.mask.area`` accepts
        (RLE). pycocotools is imported locally, presumably so that importing
        this module does not require it.
        """
        import pycocotools.mask as cocomask
        return cocomask.area(ann["segmentation"])

    def _calculate_similarities(self, gt_dets_t, tracker_dets_t):
        """Return mask IoU between GT and tracker detections of one timestep.

        ``is_encoded=True`` tells the base-class helper the masks are already
        encoded (presumably COCO RLE; confirm in ``_calculate_mask_ious``).
        ``do_ioa=False`` requests plain IoU rather than intersection-over-area.
        """
        return self._calculate_mask_ious(
            gt_dets_t, tracker_dets_t, is_encoded=True, do_ioa=False)

    def _is_exemplar_guided(self):
        """Return the ``EXEMPLAR_GUIDED`` flag from the active config."""
        return self.config['EXEMPLAR_GUIDED']

    def _postproc_ground_truth_data(self, data):
        """Convert BURST-format ground-truth ``data`` to TAO format."""
        return GroundTruthBURSTFormatToTAOFormatConverter(data).convert()

    def _postproc_prediction_data(self, data):
        """Convert BURST-format prediction ``data`` to TAO format.

        The converter also receives ``self.gt_data`` and the exemplar-guided
        flag. NOTE(review): assumes ground truth has already been loaded into
        ``self.gt_data`` before predictions are post-processed; confirm in
        ``BURSTBase``.
        """
        return PredictionBURSTFormatToTAOFormatConverter(
            self.gt_data, data,
            exemplar_guided=self._is_exemplar_guided()).convert()