datnguyentien204's picture
Upload 338 files
8e0b903 verified
raw
history blame
8.59 kB
import numpy as np
import torch
from ..metrics import ap_per_class
def fitness(x):
    """Score each row of ``x`` as a weighted sum of its first ten metric columns.

    Args:
        x: 2-D array with one row per evaluation; only the first ten columns
           are used. NOTE(review): presumably these follow the metric order of
           ``KEYS`` (box P/R/mAP50/mAP, mask P/R/mAP50/mAP, mIoU, FWIoU) --
           confirm against the caller.

    Returns:
        1-D array holding one fitness value per row.
    """
    weights = [
        0.0, 0.0, 0.1, 0.9,  # columns 0-3
        0.0, 0.0, 0.1, 0.9,  # columns 4-7
        0.1, 0.9,            # columns 8-9
    ]
    n_cols = len(weights)
    weighted = x[:, :n_cols] * weights
    return weighted.sum(1)
def ap_per_class_box_and_mask(
    tp_m,
    tp_b,
    conf,
    pred_cls,
    target_cls,
    plot=False,
    save_dir=".",
    names=(),
):
    """Run ``ap_per_class`` once for boxes and once for masks.

    Args:
        tp_m: true-positive matrix of masks.
        tp_b: true-positive matrix of boxes.
        conf, pred_cls, target_cls, plot, save_dir, names: forwarded unchanged
            to both ``ap_per_class`` calls (see ``func: ap_per_class``).

    Returns:
        ``{"boxes": {...}, "masks": {...}}``. Each inner dict has the keys
        "p", "r", "ap", "f1" and "ap_class", taken from the ``ap_per_class``
        return value after dropping its first two elements.
    """

    def _run(tp, prefix):
        # Drop the first two elements; the remainder is indexed below.
        return ap_per_class(
            tp,
            conf,
            pred_cls,
            target_cls,
            plot=plot,
            save_dir=save_dir,
            names=names,
            prefix=prefix,
        )[2:]

    def _as_dict(stats):
        # Positions after slicing: 0=p, 1=r, 2=f1, 3=ap, 4=ap_class.
        # Key order matters: Metrics.update reads these values positionally.
        return {
            "p": stats[0],
            "r": stats[1],
            "ap": stats[3],
            "f1": stats[2],
            "ap_class": stats[4],
        }

    # Boxes are evaluated before masks, matching the plotting/side-effect order.
    box_stats = _run(tp_b, "Box")
    mask_stats = _run(tp_m, "Mask")
    return {"boxes": _as_dict(box_stats), "masks": _as_dict(mask_stats)}
class Metric:
    """Per-class precision / recall / F1 / AP store for one task, plus summaries.

    Every attribute starts as an empty list until ``update`` installs the
    computed arrays. While empty, per-class accessors return ``[]`` and scalar
    summaries return ``0.0``.
    """

    def __init__(self) -> None:
        self.p = []  # precision per class, (nc, )
        self.r = []  # recall per class, (nc, )
        self.f1 = []  # F1 per class, (nc, )
        self.all_ap = []  # AP per class and IoU threshold, (nc, 10)
        self.ap_class_index = []  # class id of each stored row, (nc, )

    @property
    def ap50(self):
        """Per-class AP@0.5 (column 0 of ``all_ap``): (nc, ) or []."""
        if not len(self.all_ap):
            return []
        return self.all_ap[:, 0]

    @property
    def ap(self):
        """Per-class AP@0.5:0.95 (row mean of ``all_ap``): (nc, ) or []."""
        if not len(self.all_ap):
            return []
        return self.all_ap.mean(1)

    @property
    def mp(self):
        """Precision averaged over classes; 0.0 when empty."""
        if not len(self.p):
            return 0.0
        return self.p.mean()

    @property
    def mr(self):
        """Recall averaged over classes; 0.0 when empty."""
        if not len(self.r):
            return 0.0
        return self.r.mean()

    @property
    def map50(self):
        """AP@0.5 averaged over classes; 0.0 when empty."""
        if not len(self.all_ap):
            return 0.0
        return self.all_ap[:, 0].mean()

    @property
    def map(self):
        """AP@0.5:0.95 averaged over classes and thresholds; 0.0 when empty."""
        if not len(self.all_ap):
            return 0.0
        return self.all_ap.mean()

    def mean_results(self):
        """Return the summary tuple ``(mp, mr, map50, map)``."""
        return self.mp, self.mr, self.map50, self.map

    def class_result(self, i):
        """Return ``(p[i], r[i], ap50[i], ap[i])`` for stored row ``i``."""
        return self.p[i], self.r[i], self.ap50[i], self.ap[i]

    def get_maps(self, nc):
        """Return an (nc, ) array of per-class AP@0.5:0.95.

        Classes that do not appear in ``ap_class_index`` get the overall ``map``.
        """
        maps = np.zeros(nc) + self.map
        per_class_ap = self.ap
        for row, cls in enumerate(self.ap_class_index):
            maps[cls] = per_class_ap[row]
        return maps

    def update(self, results):
        """Install new per-class arrays.

        Args:
            results: 5-sequence ``(p, r, all_ap, f1, ap_class_index)``.
        """
        (self.p, self.r, self.all_ap, self.f1, self.ap_class_index) = results
class Metrics:
    """Metric for boxes and masks."""

    # Keys read from each inner results dict, in the positional order that
    # Metric.update() unpacks: (p, r, ap, f1, ap_class).
    _RESULT_KEYS = ("p", "r", "ap", "f1", "ap_class")

    def __init__(self) -> None:
        self.metric_box = Metric()  # box metrics
        self.metric_mask = Metric()  # mask metrics

    def update(self, results):
        """Load box and mask results.

        Args:
            results: ``{"boxes": {...}, "masks": {...}}``. Each inner dict must
                provide the keys "p", "r", "ap", "f1" and "ap_class".

        Values are looked up by key rather than taken from ``dict.values()``.
        A change in the inner dicts' key order, or an extra key, therefore
        cannot silently shuffle or break the positional tuple passed to
        ``Metric.update``.
        """
        self.metric_box.update([results["boxes"][k] for k in self._RESULT_KEYS])
        self.metric_mask.update([results["masks"][k] for k in self._RESULT_KEYS])

    def mean_results(self):
        """Return the 8-tuple: box ``(mp, mr, map50, map)`` followed by the mask one."""
        return self.metric_box.mean_results() + self.metric_mask.mean_results()

    def class_result(self, i):
        """Return the 8-tuple: box ``(p, r, ap50, ap)`` for row ``i``, then the mask one."""
        return self.metric_box.class_result(i) + self.metric_mask.class_result(i)

    def get_maps(self, nc):
        """Return per-class box mAP plus per-class mask mAP.

        NOTE: ``Metric.get_maps`` returns NumPy arrays, so ``+`` adds the two
        (nc, ) arrays elementwise; it does not concatenate them.
        """
        return self.metric_box.get_maps(nc) + self.metric_mask.get_maps(nc)

    @property
    def ap_class_index(self):
        # boxes and masks have the same ap_class_index
        return self.metric_box.ap_class_index
class Semantic_Metrics:
    """Accumulate semantic-segmentation IoU statistics across batches.

    ``results()`` returns ``(miou, fwiou)``:
      * miou  - mean of the per-(image, class) IoUs recorded by ``update``.
                A class absent from both prediction and target records an IoU
                of 0.
      * fwiou - frequency-weighted IoU from the accumulated per-class counts.
                Each class is weighted by its share of target pixels.
    """

    def __init__(self, nc, device):
        self.nc = nc  # number of classes
        self.device = device  # device holding the accumulator tensors
        self.reset()

    def update(self, pred_masks, target_masks):
        """Accumulate IoU statistics for one batch.

        Args:
            pred_masks: tensor of shape (batch, classes, h, w) holding per-class
                scores. Each pixel is assigned to its arg-max class.
            target_masks: tensor indexable as ``target_masks[b, c]`` giving the
                (h, w) ground-truth mask of class ``c`` in image ``b``.
                Presumably binary (0/1); the logical ops treat any non-zero
                value as foreground.
        """
        nb, nc, _, _ = pred_masks.shape
        device = pred_masks.device
        # (nc, 1, 1) class ids. Comparing them with the per-pixel arg-max
        # yields the one-hot (nc, h, w) prediction directly on ``device``.
        class_ids = torch.arange(nc, device=device).view(nc, 1, 1)
        for b in range(nb):
            labels = pred_masks[b].argmax(0)  # (h, w) winning class per pixel
            pred = (labels.unsqueeze(0) == class_ids)[: self.nc]
            target = target_masks[b, : self.nc].to(device)
            # Reduce all classes at once on-device. This needs one host sync
            # (the .tolist() calls) per image instead of two per class.
            inter = torch.logical_and(pred, target).flatten(1).sum(1)
            union = torch.logical_or(pred, target).flatten(1).sum(1)
            bits = target.int().flatten(1).sum(1)
            # record class pixel counts, intersection counts, union counts
            self.c_bit_counts += bits.to(self.device)
            self.c_intersection_counts += inter.to(self.device)
            self.c_union_counts += union.to(self.device)
            self.iou.extend(
                0. if u == 0 else i / u
                for i, u in zip(inter.tolist(), union.tolist())
            )

    def results(self):
        """Return ``(miou, fwiou)`` over everything accumulated since the last reset."""
        # self.iou already holds one entry per (image, class) pair, so the mean
        # divides by its length only. Dividing by nc again would shrink mIoU
        # by a factor of nc.
        miou = 0. if not self.iou else np.sum(self.iou) / len(self.iou)
        # Per-class IoU from accumulated counts; the +1 smooths empty unions.
        c_iou = self.c_intersection_counts / (self.c_union_counts + 1)
        total_c_bit_counts = self.c_bit_counts.sum()
        if total_c_bit_counts == 0:
            return miou, 0.0
        freq_ious = (self.c_bit_counts / total_c_bit_counts) * c_iou
        return miou, freq_ious.sum().item()

    def reset(self):
        """Clear all accumulated statistics."""
        self.iou = []  # one IoU per (image, class) pair seen by update()
        self.c_bit_counts = torch.zeros(self.nc, dtype=torch.long, device=self.device)
        self.c_intersection_counts = torch.zeros(self.nc, dtype=torch.long, device=self.device)
        self.c_union_counts = torch.zeros(self.nc, dtype=torch.long, device=self.device)
# Result column names (presumably logged once per epoch), grouped as:
# training losses, box / mask / semantic metrics, validation losses,
# learning rates.
KEYS = [
    # training losses
    "train/box_loss",
    "train/seg_loss",
    "train/cls_loss",
    "train/dfl_loss",
    "train/fcl_loss",
    "train/dic_loss",
    # box metrics (B)
    "metrics/precision(B)",
    "metrics/recall(B)",
    "metrics/mAP_0.5(B)",
    "metrics/mAP_0.5:0.95(B)",
    # mask metrics (M)
    "metrics/precision(M)",
    "metrics/recall(M)",
    "metrics/mAP_0.5(M)",
    "metrics/mAP_0.5:0.95(M)",
    # semantic-segmentation metrics (S)
    "metrics/MIOUS(S)",
    "metrics/FWIOUS(S)",
    # validation losses
    "val/box_loss",
    "val/seg_loss",
    "val/cls_loss",
    "val/dfl_loss",
    "val/fcl_loss",
    "val/dic_loss",
    # learning rates
    "x/lr0",
    "x/lr1",
    "x/lr2",
]

# Column names for the best-epoch summary.
BEST_KEYS = [
    "best/epoch",
    # box metrics
    "best/precision(B)",
    "best/recall(B)",
    "best/mAP_0.5(B)",
    "best/mAP_0.5:0.95(B)",
    # mask metrics
    "best/precision(M)",
    "best/recall(M)",
    "best/mAP_0.5(M)",
    "best/mAP_0.5:0.95(M)",
    # semantic-segmentation metrics
    "best/MIOUS(S)",
    "best/FWIOUS(S)",
]