File size: 10,710 Bytes
c7a9171
 
a7661dd
c7a9171
 
 
 
 
 
 
 
 
 
 
 
a7661dd
 
c7a9171
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
152b5cf
c7a9171
 
 
152b5cf
 
c7a9171
152b5cf
 
c7a9171
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
082047a
 
c7a9171
 
 
 
 
 
 
 
 
 
f56951b
 
c7a9171
 
 
 
 
 
 
 
 
 
 
 
f56951b
 
c7a9171
ac41e13
c7a9171
e571989
 
c7a9171
9f891c4
 
c7a9171
f56951b
b9e7665
 
 
9f891c4
c7a9171
b9e7665
 
 
c7a9171
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
""" DataStats metric. """

import logging
import functools
from collections import Counter
from multiprocessing import Pool
from contextlib import contextmanager
from typing import List, Any, Dict, Optional
from collections import namedtuple as _namedtuple

import spacy
import datasets
import evaluate
from packaging import version

logger = logging.getLogger(__name__)

# Load the small English spaCy pipeline once, at import time; it is used to
# tokenise summaries and articles when examples are scored.
try:
    _en = spacy.load('en_core_web_sm')
except OSError:
    # The load failed with OSError (typically: the model package is not
    # installed). Download it, then retry the load exactly once.
    spacy.cli.download('en_core_web_sm')
    _en = spacy.load('en_core_web_sm')

@contextmanager
def filter_logging_context():
    """Suppress one specific log message for the duration of a ``with`` block.

    While the context is active, records emitted on the
    ``transformers.modeling_utils`` logger whose message contains
    "This is expected if you are initialising" are dropped. The filter is
    removed again on exit, even if the body raises.
    """

    def _keep_record(record):
        # A logging filter returns a falsy value to drop the record.
        return "This is expected if you are initialising" not in record.msg

    target_logger = datasets.utils.logging.get_logger("transformers.modeling_utils")
    target_logger.addFilter(_keep_record)

    try:
        yield
    finally:
        target_logger.removeFilter(_keep_record)


# BibTeX entry for the Newsroom paper; passed as ``citation`` to
# ``evaluate.MetricInfo`` in ``DataStats._info``.
_CITATION = """\
@article{grusky2018newsroom,
  title={Newsroom: A dataset of 1.3 million summaries with diverse extractive strategies},
  author={Grusky, Max and Naaman, Mor and Artzi, Yoav},
  journal={arXiv preprint arXiv:1804.11283},
  year={2018}
}
"""

# One-paragraph summary of the metric; passed as ``description`` to
# ``evaluate.MetricInfo`` and to ``add_start_docstrings`` on ``DataStats``.
_DESCRIPTION = """\
DataStats examines summarisation strategies using three measures that capture the degree of text overlap between the summary and article, and the rate of compression of the information conveyed.
"""

_KWARGS_DESCRIPTION = """
DataStats metric for text summarisation.

Args:
    summaries (list of str): model-generated summries.
    articles (list of str or list of list of str): Original articles.
    
Returns:
    coverage: Percentage of words in the summary that are from the source article, measuring the extent to which a summary is a derivative of a text.
    density: It is defined as the average length of the extractive fragment to which each summary word belongs.
    compression: It is defined as the word ratio between the articles and its summaries.
    
Examples:
    >>> predictions = ["hello there", "general kenobi"]
    >>> references = ["hello there", "general kenobi"]
    >>> bertscore = evaluate.load("datastats")
    >>> results = bertscore.compute(predictions=predictions, references=references)
"""


def find_ngrams(input_list: List[Any], n: int):
    """Return an iterator over the consecutive ``n``-grams of ``input_list``.

    Each n-gram is a tuple of ``n`` adjacent items. The result is a lazy
    ``zip`` object; it is empty when ``input_list`` has fewer than ``n``
    items or when ``n`` is 0.
    """
    # zip-ping the list against copies of itself shifted by 0..n-1 positions
    # lines up every run of n neighbouring items.
    shifted_views = (input_list[offset:] for offset in range(n))
    return zip(*shifted_views)


def normalize(tokens: List[str], lowercase: bool = False):
    """Convert tokens to plain strings, optionally lowercasing them.

    Args:
        tokens: Sequence of tokens; each one is passed through ``str()``, so
            non-string token objects are accepted.
        lowercase: When True every token is lowercased; when False the
            original casing is preserved.

    Returns:
        A new list of ``str`` with one entry per input token.
    """
    # The flag used to be applied inverted (lowercasing only when it was
    # False); it now does what its name says.
    return [str(t).lower() if lowercase else str(t) for t in tokens]


class Fragments:
    """Token overlap ("extractive fragments") between a summary and a text.

    On construction the summary and text are tokenised (whitespace split when
    given as strings), normalised, and greedily matched; the resulting
    matches back the coverage, density and compression scores.
    """

    # One shared fragment: start index in the summary, start index in the
    # text, and the number of consecutive tokens that agree from there.
    Match = _namedtuple("Match", ("summary", "text", "length"))

    def __init__(self, summary, text, lowercase: bool = False):
        """Tokenise, normalise and match ``summary`` against ``text``.

        Args:
            summary: A string (split on whitespace) or a token sequence.
            text: A string (split on whitespace) or a token sequence.
            lowercase: Forwarded unchanged to ``normalize`` for both inputs.
        """
        self.summary = summary.split() if isinstance(summary, str) else summary
        self.text = text.split() if isinstance(text, str) else text
        self._norm_summary = normalize(self.summary, lowercase)
        self._norm_text = normalize(self.text, lowercase)
        self._match(self._norm_summary, self._norm_text)

    def overlaps(self):
        """
        Return a list of Fragments.Match objects between summary and text.
        This is a list of named tuples of the form (summary, text, length):
        """
        return self._matches

    def strings(self, min_length=0, summary_base=True):
        """Return the token slices covered by each match.

        Args:
            min_length: Only matches strictly longer than this are returned.
            summary_base: Slice the summary tokens when True, the text tokens
                when False.

        Returns:
            A list of token lists, one per retained match.
        """
        # Compute the strings against the summary or the text?
        base = self.summary if summary_base else self.text
        fragments = []
        for match in self.overlaps():
            # Filter out fragments at or below the minimum length.
            if match.length > min_length:
                # Each match carries a separate start index per sequence; use
                # the one that belongs to the sequence being sliced.
                start = match.summary if summary_base else match.text
                fragments.append(base[start : start + match.length])
        return fragments

    def coverage(self, summary_base=True):
        """
        Return the COVERAGE score of the summary and text.

        This is the total matched length divided by the number of summary
        tokens (or text tokens when ``summary_base`` is False); 0 when that
        token count is zero.
        """
        numerator = sum(o.length for o in self.overlaps())
        denominator = len(self.summary) if summary_base else len(self.text)
        if denominator == 0:
            return 0
        return numerator / denominator

    def density(self, summary_base=True):
        """
        Return the DENSITY score of summary and text.

        This is the sum of squared match lengths divided by the number of
        summary tokens (or text tokens when ``summary_base`` is False); 0 when
        that token count is zero.
        """
        numerator = sum(o.length ** 2 for o in self.overlaps())
        denominator = len(self.summary) if summary_base else len(self.text)
        if denominator == 0:
            return 0
        return numerator / denominator

    def compression(self, text_to_summary=True):
        """
        Return compression ratio between summary and text.

        ``len(text) / len(summary)`` when ``text_to_summary`` is True,
        otherwise the inverse; 0 when the divisor is zero.
        """
        text_len, summary_len = len(self.text), len(self.summary)
        if text_to_summary:
            numerator, denominator = text_len, summary_len
        else:
            numerator, denominator = summary_len, text_len
        if denominator == 0:
            return 0
        return numerator / denominator

    def _match(self, a, b):
        """
        Raw procedure for matching summary in text, described in paper.

        For each position in ``a`` (summary) the text ``b`` is scanned for
        the longest run of tokens equal to the run starting at that position.
        If one is found it is recorded and ``a`` advances past it; otherwise
        ``a`` advances by one token. Results are stored in ``self._matches``.
        """
        self._matches = []
        a_start = b_start = 0
        while a_start < len(a):
            best_match = None
            best_match_length = 0
            while b_start < len(b):
                if a[a_start] == b[b_start]:
                    # Extend the run for as long as both sequences agree.
                    a_end = a_start
                    b_end = b_start
                    while a_end < len(a) and b_end < len(b) \
                            and b[b_end] == a[a_end]:
                        b_end += 1
                        a_end += 1
                    length = a_end - a_start
                    # Strict comparison keeps the earliest of equally long runs.
                    if length > best_match_length:
                        best_match = Fragments.Match(a_start, b_start, length)
                        best_match_length = length
                    # Resume scanning the text after the run just examined.
                    b_start = b_end
                else:
                    b_start += 1
            # Every summary position is matched against the whole text.
            b_start = 0
            if best_match is not None:
                # A recorded match always has length >= 1.
                self._matches.append(best_match)
                a_start += best_match_length
            else:
                a_start += 1


class DataStatsMetric(object):
    """Per-example and corpus-level summary statistics.

    Combines the ``Fragments`` scores (coverage, density, compression) with
    n-gram novelty and repetition rates of the summary.
    """

    def __init__(
        self,
        n_gram: int = 3,
        n_workers: int = 24,
        lowercase: bool = False,
        tokenize: bool = True
    ):
        """
        Data Statistics metric

        Args:
            n_gram (int): Compute statistics for n-grams up to and including this length.
            n_workers (int): Number of processes to use if using multiprocessing.
            lowercase (bool): Case-handling flag forwarded to ``Fragments``.
            tokenize (bool): Whether to tokenize the input.
        """
        self.n_gram = n_gram
        self.n_workers = n_workers
        self.lowercase = lowercase
        self.tokenize = tokenize

    def evaluate_example(self, summary, input_text):
        """Score a single (summary, article) pair.

        Args:
            summary: The summary; a string when ``self.tokenize`` is True.
            input_text: The source article; a string when ``self.tokenize``
                is True.

        Returns:
            dict with ``coverage``, ``density``, ``compression`` and
            ``summary_length``, plus, for each ``i`` in ``1..n_gram``,
            ``percentage_novel_{i}-gram`` and
            ``percentage_repeated_{i}-gram_in_summ``. The two n-gram keys are
            omitted for any ``i`` where the summary has no ``i``-grams.
        """
        if self.tokenize:
            input_text = _en(input_text, disable=["tagger", "parser", "ner", "textcat"])
            input_text = [tok.text for tok in input_text]
            summary = _en(summary, disable=["tagger", "parser", "ner", "textcat"])
            summary = [tok.text for tok in summary]
        fragments = Fragments(summary, input_text, lowercase=self.lowercase)
        score_dict = {
            "coverage": fragments.coverage(),
            "density": fragments.density(),
            "compression": fragments.compression(),
        }
        tokenized_summary = fragments._norm_summary
        tokenized_text = fragments._norm_text
        score_dict["summary_length"] = len(tokenized_summary)
        for i in range(1, self.n_gram + 1):
            summ_ngrams = list(find_ngrams(tokenized_summary, i))
            summ_ngrams_set = set(summ_ngrams)
            if not summ_ngrams_set:
                # No i-grams in the summary: both ratios are undefined, so
                # the keys are left out for this i.
                continue
            input_ngrams_set = set(find_ngrams(tokenized_text, i))
            intersect = summ_ngrams_set.intersection(input_ngrams_set)
            n_distinct = float(len(summ_ngrams_set))
            score_dict[f"percentage_novel_{i}-gram"] = (
                len(summ_ngrams_set) - len(intersect)
            ) / n_distinct
            ngram_counts = Counter(summ_ngrams)
            repeated = [key for key, val in ngram_counts.items() if val > 1]
            score_dict[f"percentage_repeated_{i}-gram_in_summ"] = len(repeated) / n_distinct
        return score_dict

    def evaluate_batch(self, summaries, input_texts, aggregate=True):
        """Score many pairs in parallel worker processes.

        Args:
            summaries: Sequence of summaries.
            input_texts: Sequence of source articles, aligned with
                ``summaries``.
            aggregate: When True return corpus-level averages; when False
                return the list of per-example score dicts.

        Returns:
            A ``Counter`` mapping each score name to its sum over all
            examples divided by ``len(input_texts)``, or the raw list of
            per-example dicts when ``aggregate`` is False.
        """
        # The context manager tears the pool down even if a worker raises;
        # starmap has already collected every result by the time it exits.
        with Pool(processes=self.n_workers) as pool:
            results = pool.starmap(self.evaluate_example, zip(summaries, input_texts))
        if not aggregate:
            return results
        corpus_score_dict = Counter()
        for example_scores in results:
            # Counter.update adds values key-by-key instead of replacing them.
            corpus_score_dict.update(example_scores)
        n_examples = float(len(input_texts))
        for key in corpus_score_dict:
            corpus_score_dict[key] /= n_examples
        return corpus_score_dict

    @property
    def supports_multi_ref(self):
        """Always False."""
        return False


@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class DataStats(evaluate.Metric):
    # `evaluate` wrapper around DataStatsMetric: declares the metric's
    # metadata and input schema, and exposes the corpus-level coverage,
    # density and compression scores.

    name = 'DataStats'

    def _info(self):
        # Metadata and input features (one string prediction and one string
        # reference per example).
        return evaluate.MetricInfo(
            description=_DESCRIPTION,
            citation=_CITATION,
            homepage="",
            inputs_description=_KWARGS_DESCRIPTION,
            features=[
                datasets.Features(
                    {
                        "predictions": datasets.Value("string", id="sequence"),
                        "references": datasets.Value("string", id="sequence"),
                    }
                ),
            ],
            # Previously pointed at the unrelated bert_score repository.
            codebase_urls=["https://github.com/lil-lab/newsroom"],
            reference_urls=[
                "https://github.com/lil-lab/newsroom",
                "https://arxiv.org/pdf/2007.12626",
            ],
        )

    def _compute(
        self,
        predictions,
        references,
        n_gram: int = 3,
        n_workers: int = 4,
        lowercase: bool = False,
        tokenize: bool = True,
        **kwargs,
    ):
        # predictions are the summaries and references the source articles.
        # n_gram, n_workers, lowercase and tokenize are forwarded positionally
        # to DataStatsMetric; extra keyword arguments are accepted and ignored.
        # Returns a dict with float "coverage", "density" and "compression"
        # averaged over all examples.
        datastats = DataStatsMetric(n_gram, n_workers, lowercase, tokenize)
        results = datastats.evaluate_batch(predictions, references)
        # The aggregate holds more keys (n-gram statistics, summary length);
        # only the three headline scores are exposed.
        return {
            "coverage": float(results['coverage']),
            "density": float(results['density']),
            "compression": float(results['compression']),
        }