File size: 4,649 Bytes
9c6b911
8a4a728
a944252
fcd15ea
 
a944252
020c3bd
8a4a728
020c3bd
 
8a4a728
a944252
020c3bd
 
a944252
020c3bd
fcd15ea
 
 
 
 
 
 
 
 
 
 
8a4a728
 
 
 
 
 
 
 
 
 
 
fcd15ea
 
 
 
 
 
 
 
8a4a728
 
03c8589
fcd15ea
8a4a728
 
 
 
03c8589
8a4a728
a944252
8a4a728
 
 
 
03c8589
a944252
fcd15ea
03c8589
 
a944252
03c8589
 
 
 
fcd15ea
03c8589
 
a944252
03c8589
 
 
 
 
a944252
 
03c8589
 
a944252
 
 
 
 
 
 
03c8589
 
 
 
 
a944252
 
 
 
 
 
fcd15ea
 
 
 
8b7a053
d1fbaa3
 
 
 
 
 
 
 
 
 
 
8b7a053
 
 
9418c93
 
 
9c6b911
9418c93
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
"""This module provides functions for calculating hierarchical variants of precision, recall and F1."""

from typing import List, Dict, Tuple, Set


def find_ancestors(node: str, hierarchy: Dict[str, Set[str]]) -> Set[str]:
    """
    Find the ancestors of a given node in a hierarchy.

    The hierarchy is walked upwards from ``node``: its parents are collected,
    then the parents of those parents, and so on. Each ancestor is expanded at
    most once, so the walk terminates even when the hierarchy contains a cycle
    and does not redo work when several nodes share an ancestor.

    Args:
        node (str): The node for which to find ancestors.
        hierarchy (Dict[str, Set[str]]): A dictionary representing the hierarchy, where the keys are nodes and the values are their parents.

    Returns:
        Set[str]: A set of ancestors of the given node. Empty when the node has
        no entry in the hierarchy. If the node lies on a cycle, it is reachable
        from itself and therefore appears in its own ancestor set.
    """
    ancestors: Set[str] = set()
    nodes_to_visit = [node]
    while nodes_to_visit:
        current_node = nodes_to_visit.pop()
        # Only follow parents seen for the first time; anything already in
        # `ancestors` has been (or is about to be) expanded, and re-queuing it
        # is what makes a cyclic hierarchy loop forever.
        unseen_parents = set(hierarchy.get(current_node, ())) - ancestors
        ancestors.update(unseen_parents)
        nodes_to_visit.extend(unseen_parents)
    return ancestors


def extend_with_ancestors(classes: set, hierarchy: dict) -> set:
    """
    Return the given classes together with all of their ancestors.

    Args:
        classes (set): The classes to start from. The input is not modified.
        hierarchy (dict): The hierarchy of classes, passed through to
            ``find_ancestors`` for each class.

    Returns:
        set: A new set holding every input class plus every ancestor that
        ``find_ancestors`` reports for it.
    """
    closure = set(classes)
    # One ancestor set per class, merged into the copy in a single call.
    closure.update(*(find_ancestors(member, hierarchy) for member in classes))
    return closure


def _extend_with_weighted_ancestors(
    codes: List[str],
    hierarchy: Dict[str, Dict[str, float]],
) -> Dict[str, float]:
    """Map each code (weight 1.0) and each ancestor listed for it (largest weight seen) to a weight."""
    weights: Dict[str, float] = {}
    for code in codes:
        weights[code] = 1.0  # Full weight for exact match
        for ancestor, ancestor_weight in hierarchy.get(code, {}).items():
            # An ancestor shared by several codes keeps the largest weight.
            weights[ancestor] = max(weights.get(ancestor, 0), ancestor_weight)
    return weights


def calculate_hierarchical_precision_recall(
    reference_codes: List[str],
    predicted_codes: List[str],
    hierarchy: Dict[str, Dict[str, float]],
) -> Tuple[float, float]:
    """
    Calculates the hierarchical precision and recall given the reference codes, predicted codes, and hierarchy definition.

    Both code lists are first extended with the ancestors listed for each code
    in ``hierarchy``: a code itself gets weight 1.0, an ancestor gets the
    largest weight any of the codes assigns to it. Codes without an entry in
    ``hierarchy`` contribute only themselves. The overlap is the sum, over
    every node present in both extended sets, of the smaller of its two
    weights. Precision divides the overlap by the total predicted weight,
    recall divides it by the total reference weight.

    Args:
        reference_codes (List[str]): The list of reference codes.
        predicted_codes (List[str]): The list of predicted codes.
        hierarchy (Dict[str, Dict[str, float]]): The hierarchy definition where keys are nodes and values are dictionaries mapping ancestor nodes to their weights.

    Returns:
        Tuple[float, float]: A tuple containing the hierarchical precision and recall floating point values.
        Each value is 0.0 when its denominator (the total predicted or total reference weight) is zero,
        e.g. for an empty code list.
    """
    extended_real = _extend_with_weighted_ancestors(reference_codes, hierarchy)
    extended_predicted = _extend_with_weighted_ancestors(predicted_codes, hierarchy)

    # The weighted overlap is symmetric in the two sets, so a single sum
    # serves as the numerator of both precision and recall.
    correct_weights = sum(
        min(weight, extended_real[code])
        for code, weight in extended_predicted.items()
        if code in extended_real
    )

    total_predicted_weights = sum(extended_predicted.values())
    total_real_weights = sum(extended_real.values())

    # Calculate hierarchical precision and recall using weighted sums
    hP = correct_weights / total_predicted_weights if total_predicted_weights else 0.0
    hR = correct_weights / total_real_weights if total_real_weights else 0.0

    return hP, hR


def hierarchical_f_measure(hP: float, hR: float, beta: float = 1.0) -> float:
    """
    Calculate the hierarchical F-measure.

    Computes ``(beta**2 + 1) * hP * hR / (beta**2 * hP + hR)``.

    Args:
        hP (float): The hierarchical precision.
        hR (float): The hierarchical recall.
        beta (float, optional): The beta value for F-measure calculation. Default is 1.0.

    Returns:
        float: The hierarchical F-measure, or 0.0 when the denominator
        ``beta**2 * hP + hR`` is zero.
    """
    denominator = beta**2 * hP + hR
    # Guard on the actual divisor: checking `hP + hR` alone misses the case
    # beta == 0 with hR == 0 and hP > 0, which would divide by zero.
    if denominator == 0:
        return 0.0
    return (beta**2 + 1) * hP * hR / denominator


# Example usage (each hierarchy entry maps a node to a dict of ancestor -> weight;
# the weights below are illustrative):
# reference_codes = ["1111", "1112", "1113", "1114"]
# predicted_codes = ["1111", "1113", "1120", "1211"]
# hierarchy_dict = {'1111': {'111': 1.0, '11': 1.0, '1': 1.0}, '1112': {'111': 1.0, '11': 1.0, '1': 1.0}, '1113': {'111': 1.0, '11': 1.0, '1': 1.0}, '1114': {'111': 1.0, '11': 1.0, '1': 1.0} ...}
# result = calculate_hierarchical_precision_recall(reference_codes, predicted_codes, hierarchy_dict)
# print(result)