File size: 8,534 Bytes
a0b78f4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
from __future__ import annotations

import itertools
import re
from string import punctuation

import Levenshtein
from errant.alignment import Alignment
from errant.edit import Edit


def get_rule_edits(alignment: Alignment) -> list[Edit]:
    """Turn a word-level alignment into a list of Edits using merging rules.

    Matches produce no edits, transpositions are kept as-is (one edit per
    alignment entry), and all other groups are passed through the merging
    rules in `process_seq` before being converted to Edit objects.
    """
    edits: list[Edit] = []
    for op, group in group_alignment(alignment, "new"):
        members = list(group)
        # Exact matches ("M") never yield edits.
        if op == "M":
            continue
        # Transpositions ("T") are always split into individual edits;
        # everything else (D/I/S groups) goes through the merging rules.
        spans = members if op == "T" else process_seq(members, alignment)
        for entry in spans:
            # entry[1:] is the (o_start, o_end, c_start, c_end) span.
            edits.append(Edit(alignment.orig, alignment.cor, entry[1:]))
    return edits


def group_alignment(alignment: Alignment, mode: str = "default") -> list[tuple[str, list[tuple]]]:
    """
    Does initial alignment grouping:
    1. Make groups of MDM, MIM od MSM.
    2. In remaining operations, make groups of Ms, groups of Ts, and D/I/Ss.
    Do not group what was on the sides of M[DIS]M: SSMDMS -> [SS, MDM, S], not [MDM, SSS].
    3. Sort groups by the order in which they appear in the alignment.
    """
    if mode == "new":
        op_groups = []
        # Format operation types sequence as string to use regex sequence search
        all_ops_seq = "".join([op[0][0] for op in alignment.align_seq])
        # Find M[DIS]M groups and merge (need them to detect hyphen vs. space spelling)
        ungrouped_ids = list(range(len(alignment.align_seq)))
        for match in re.finditer("M[DIS]M", all_ops_seq):
            start, end = match.start(), match.end()
            op_groups.append(("MSM", alignment.align_seq[start:end]))
            for idx in range(start, end):
                ungrouped_ids.remove(idx)
        # Group remaining operations by default rules (groups of M, T and rest)
        if ungrouped_ids:
            def get_group_type(operation):
                return operation if operation in {"M", "T"} else "DIS"
            curr_group = [alignment.align_seq[ungrouped_ids[0]]]
            last_oper_type = get_group_type(curr_group[0][0][0])
            for i, idx in enumerate(ungrouped_ids[1:], start=1):
                operation = alignment.align_seq[idx]
                oper_type = get_group_type(operation[0][0])
                if (oper_type == last_oper_type and
                        (idx - ungrouped_ids[i-1] == 1 or oper_type in {"M", "T"})):
                    curr_group.append(operation)
                else:
                    op_groups.append((last_oper_type, curr_group))
                    curr_group = [operation]
                last_oper_type = oper_type
            if curr_group:
                op_groups.append((last_oper_type, curr_group))
        # Sort groups by the start id of the first group entry
        op_groups = sorted(op_groups, key=lambda x: x[1][0][1])
    else:
        grouped = itertools.groupby(alignment.align_seq,
                                    lambda x: x[0][0] if x[0][0] in {"M", "T"} else False)
        op_groups = [(op, list(group)) for op, group in grouped]
    return op_groups


def process_seq(seq: list[tuple], alignment: Alignment) -> list[tuple]:
    """Applies merging rules to previously formed alignment groups (`seq`).

    Recursively scans every (start, end) sub-span of `seq`, largest spans
    first, and either merges the span into a single edit, splits it, or
    leaves it alone.

    Args:
        seq: Alignment entries of the form (op, o_start, o_end, c_start, c_end).
        alignment: The full Alignment; supplies the orig/cor spaCy tokens.

    Returns:
        The merged/split list of alignment entries.
    """
    # Return single alignments
    if len(seq) <= 1:
        return seq
    # Get the ops for the whole sequence
    ops = [op[0] for op in seq]

    # Get indices of all start-end combinations in the seq: 012 = 01, 02, 12
    combos = list(itertools.combinations(range(0, len(seq)), 2))
    # Sort them starting with largest spans first
    combos.sort(key=lambda x: x[1] - x[0], reverse=True)
    # Loop through combos
    for start, end in combos:
        # Ignore ranges that do NOT contain a substitution, deletion or insertion.
        if not any(type_ in ops[start:end + 1] for type_ in ["D", "I", "S"]):
            continue
        # Merge all D xor I ops. (95% of human multi-token edits contain S).
        if set(ops[start:end + 1]) == {"D"} or set(ops[start:end + 1]) == {"I"}:
            return (process_seq(seq[:start], alignment)
                    + merge_edits(seq[start:end + 1])
                    + process_seq(seq[end + 1:], alignment))
        # Get the tokens in orig and cor.
        o = alignment.orig[seq[start][1]:seq[end][2]]
        c = alignment.cor[seq[start][3]:seq[end][4]]
        if ops[start:end + 1] in [["M", "D", "M"], ["M", "I", "M"], ["M", "S", "M"]]:
            # merge hyphens
            # NOTE(review): o[start + 1]/c[start + 1] index into the local token
            # slices with a seq-relative offset; this is only correct because
            # M[DIS]M groups are always processed with start == 0 — confirm if
            # grouping ever changes.
            if (o[start + 1].text == "-" or c[start + 1].text == "-") and len(o) != len(c):
                return (process_seq(seq[:start], alignment)
                        + merge_edits(seq[start:end + 1])
                        + process_seq(seq[end + 1:], alignment))
            # if it is not a hyphen-space edit, return only punct edit
            return seq[start + 1: end]
        # Merge possessive suffixes: [friends -> friend 's]
        if o[-1].tag_ == "POS" or c[-1].tag_ == "POS":
            return (process_seq(seq[:end - 1], alignment)
                    + merge_edits(seq[end - 1:end + 1])
                    + process_seq(seq[end + 1:], alignment))
        # Case changes
        if o[-1].lower == c[-1].lower:
            # Merge first token I or D: [Cat -> The big cat]
            if (start == 0 and
                    (len(o) == 1 and c[0].text[0].isupper()) or
                    (len(c) == 1 and o[0].text[0].isupper())):
                return (merge_edits(seq[start:end + 1])
                        + process_seq(seq[end + 1:], alignment))
            # Merge with previous punctuation: [, we -> . We], [we -> . We]
            if (len(o) > 1 and is_punct(o[-2])) or \
                    (len(c) > 1 and is_punct(c[-2])):
                return (process_seq(seq[:end - 1], alignment)
                        + merge_edits(seq[end - 1:end + 1])
                        + process_seq(seq[end + 1:], alignment))
        # Merge whitespace/hyphens: [acat -> a cat], [sub - way -> subway]
        s_str = re.sub("['-]", "", "".join([tok.lower_ for tok in o]))
        t_str = re.sub("['-]", "", "".join([tok.lower_ for tok in c]))
        if s_str == t_str or s_str.replace(" ", "") == t_str.replace(" ", ""):
            return (process_seq(seq[:start], alignment)
                    + merge_edits(seq[start:end + 1])
                    + process_seq(seq[end + 1:], alignment))
        # Merge same POS or auxiliary/infinitive/phrasal verbs:
        # [to eat -> eating], [watch -> look at]
        # BUG FIX: use the string attribute `pos_`; spaCy's `pos` is an integer
        # id, so comparing it against {"AUX", "PART", "VERB"} could never match
        # and this merge rule was dead code.
        pos_set = set([tok.pos_ for tok in o] + [tok.pos_ for tok in c])
        if len(o) != len(c) and (len(pos_set) == 1 or pos_set.issubset({"AUX", "PART", "VERB"})):
            return (process_seq(seq[:start], alignment)
                    + merge_edits(seq[start:end + 1])
                    + process_seq(seq[end + 1:], alignment))
        # Split rules take effect when we get to smallest chunks
        if end - start < 2:
            # Split adjacent substitutions
            if len(o) == len(c) == 2:
                return (process_seq(seq[:start + 1], alignment)
                        + process_seq(seq[start + 1:], alignment))
            # Split similar substitutions at sequence boundaries
            if ((ops[start] == "S" and char_cost(o[0].text, c[0].text) > 0.75) or
                    (ops[end] == "S" and char_cost(o[-1].text, c[-1].text) > 0.75)):
                return (process_seq(seq[:start + 1], alignment)
                        + process_seq(seq[start + 1:], alignment))
            # Split final determiners
            # BUG FIX: `pos_` (string label) instead of the integer `pos` id,
            # which could never equal "DET".
            if (end == len(seq) - 1 and
                    ((ops[-1] in {"D", "S"} and o[-1].pos_ == "DET") or
                     (ops[-1] in {"I", "S"} and c[-1].pos_ == "DET"))):
                return process_seq(seq[:-1], alignment) + [seq[-1]]
    return seq


def is_punct(token) -> bool:
    """Report whether the token's text occurs in `string.punctuation`."""
    text = token.text
    return text in punctuation


def char_cost(a: str, b: str) -> float:
    """Return the character-level similarity of two strings.

    Uses Levenshtein ratio, i.e. a value in [0, 1] where 1.0 means the
    strings are identical.
    """
    similarity = Levenshtein.ratio(a, b)
    return similarity


def merge_edits(seq: list[tuple]) -> list[tuple]:
    """Collapse an alignment subsequence into one "X"-labelled span.

    The resulting span runs from the first entry's start offsets to the
    last entry's end offsets. An empty input is returned unchanged.
    """
    if not seq:
        return seq
    first, last = seq[0], seq[-1]
    return [("X", first[1], last[2], first[3], last[4])]