danieldux commited on
Commit
8b7a053
·
1 Parent(s): 5b469a4

Refactor hierarchical precision, recall, and F-measure calculations***

Browse files
Files changed (1) hide show
  1. ham.py +44 -32
ham.py CHANGED
@@ -1,32 +1,44 @@
1
- # After review with respect to equations
2
-
3
-
4
- def hierarchical_precision_recall_fmeasure(
5
- true_labels, predicted_labels, ancestors, beta=1.0
6
- ):
7
- # Initialize counters for true positives, predicted, and true conditions
8
- true_positive_sum = predicted_sum = true_sum = 0
9
-
10
- # Process each instance
11
- for true, predicted in zip(true_labels, predicted_labels):
12
- # Extend the sets with ancestors
13
- extended_true = true.union(
14
- *[ancestors[label] for label in true if label in ancestors]
15
- )
16
- extended_predicted = predicted.union(
17
- *[ancestors[label] for label in predicted if label in ancestors]
18
- )
19
-
20
- # Update counters
21
- true_positive_sum += len(extended_true.intersection(extended_predicted))
22
- predicted_sum += len(extended_predicted)
23
- true_sum += len(extended_true)
24
-
25
- # Calculate hierarchical precision and recall
26
- hP = true_positive_sum / predicted_sum if predicted_sum else 0
27
- hR = true_positive_sum / true_sum if true_sum else 0
28
-
29
- # Calculate hierarchical F-measure
30
- hF = ((beta**2 + 1) * hP * hR) / (beta**2 * hP + hR) if (hP + hR) else 0
31
-
32
- return hP, hR, hF
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ def ancestors(class_label, hierarchy):
2
+ """Return all ancestors of a given class label, excluding the root."""
3
+ if class_label not in hierarchy or not hierarchy[class_label]:
4
+ return set()
5
+ else:
6
+ # Recursively get all ancestors for each parent
7
+ anc = set(hierarchy[class_label])
8
+ for parent in hierarchy[class_label]:
9
+ anc.update(ancestors(parent, hierarchy))
10
+ return anc
11
+
12
+
13
+ def extend_with_ancestors(class_labels, hierarchy):
14
+ """Extend a set of class labels with their ancestors."""
15
+ extended_set = set(class_labels)
16
+ for label in class_labels:
17
+ extended_set.update(ancestors(label, hierarchy))
18
+ return extended_set
19
+
20
+
21
+ def hierarchical_precision_recall(true_labels, predicted_labels, hierarchy):
22
+ """Calculate hierarchical precision and recall."""
23
+ true_extended = [extend_with_ancestors(ci, hierarchy) for ci in true_labels]
24
+ predicted_extended = [
25
+ extend_with_ancestors(c_prime_i, hierarchy) for c_prime_i in predicted_labels
26
+ ]
27
+
28
+ intersect_sum = sum(
29
+ len(ci & c_prime_i) for ci, c_prime_i in zip(true_extended, predicted_extended)
30
+ )
31
+ predicted_sum = sum(len(c_prime_i) for c_prime_i in predicted_extended)
32
+ true_sum = sum(len(ci) for ci in true_extended)
33
+
34
+ hP = intersect_sum / predicted_sum if predicted_sum > 0 else 0
35
+ hR = intersect_sum / true_sum if true_sum > 0 else 0
36
+
37
+ return hP, hR
38
+
39
+
40
+ def hierarchical_f_measure(hP, hR, beta=1.0):
41
+ """Calculate the hierarchical F-measure."""
42
+ if hP + hR == 0:
43
+ return 0
44
+ return (beta**2 + 1) * hP * hR / (beta**2 * hP + hR)