# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TODO: Add a description here."""
from __future__ import annotations
import re
from collections import Counter, namedtuple
from typing import Iterable
from tqdm.auto import tqdm
from sage.evaluation.ruerrant_wrapper import classifier
from sage.evaluation.ruerrant_wrapper import merger
from errant.annotator import Annotator
from errant.commands.compare_m2 import process_edits
from errant.commands.compare_m2 import evaluate_edits
from errant.commands.compare_m2 import merge_dict
from errant.edit import Edit
import spacy
from spacy.tokenizer import Tokenizer
from spacy.util import compile_prefix_regex, compile_infix_regex, compile_suffix_regex
import evaluate
import datasets

# TODO: Add BibTeX citation
_CITATION = """\
@InProceedings{huggingface:module,
    title = {A great new module},
    authors={huggingface, Inc.},
    year={2020}
}
"""

# Description of the module
_DESCRIPTION = """\
ERRANT-based metric for Russian grammatical error correction (GEC).
It runs the original ERRANT annotator with the ruerrant merger and classifier from the
`sage` library and reports precision, recall and F1 for each error category
(e.g. CASE, PUNCT, SPELL, YO).
"""

# Description of the arguments of the module
_KWARGS_DESCRIPTION = """
Computes precision, recall and F1 of corrections against gold references,
per error category, using the ERRANT annotation scheme adapted for Russian.
Args:
    sources: list of source (original) texts. Each source
        should be a string with tokens separated by spaces.
    corrections: list of gold corrections, one per source text. Each
        correction should be a string with tokens separated by spaces.
    answers: list of evaluated (system) corrections, one per source text. Each
        answer should be a string with tokens separated by spaces.
Returns:
    dict mapping each error category (e.g. "CASE", "PUNCT", "SPELL", "YO")
    to a (precision, recall, F1) tuple.
Examples:
    >>> ru_errant = evaluate.load("ru_errant")  # or the full Hub id / local path of this module
    >>> results = ru_errant.compute(
    ...     sources=["он пришол домой"],
    ...     corrections=["Он пришёл домой"],
    ...     answers=["Он пришёл домой"],
    ... )
"""


def update_spacy_tokenizer(nlp):
    """
    Changes the spaCy tokenizer so that it handles additional patterns.
    """
    # Add "](" as an infix pattern so this sequence is split off as a separate token
    # (the last default infix rule is dropped).
    infix_re = compile_infix_regex(nlp.Defaults.infixes[:-1] + [r"\]\("])
    simple_url_re = re.compile(r"^https?://")
    nlp.tokenizer = Tokenizer(
        nlp.vocab,
        # Also split off escaped quotes and stray backslashes at token boundaries.
        prefix_search=compile_prefix_regex(nlp.Defaults.prefixes + [r'\\"']).search,
        suffix_search=compile_suffix_regex(nlp.Defaults.suffixes + [r'\\']).search,
        infix_finditer=infix_re.finditer,
        token_match=None,
        url_match=simple_url_re.match,
    )
    return nlp
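
# Example of the customized tokenizer in isolation (a sketch, assuming the
# "ru_core_news_lg" model is installed; the sample string is illustrative only).
# With the extra infix pattern, the "](" sequence is split off as a token of its own:
#
#     nlp = update_spacy_tokenizer(spacy.load("ru_core_news_lg"))
#     tokens = [t.text for t in nlp("слово](https://example.com)")]
#     # "](" now appears in `tokens` as a separate token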


def annotate_errors(annotator: Annotator, orig, cor, merging: str = "rules") -> list[Edit]:
    """
    Replaces `Annotator.annotate()` to allow multiple errors per token.
    This is necessary to parse combined errors, e.g.:
        ["werd", "Word"] >>> Errors: ["SPELL", "CASE"]
    `orig` and `cor` are spaCy `Doc`s produced by `annotator.parse()`.
    The `classify()` method called inside is the ruerrant classifier imported from
    `sage` (which overrides the original ERRANT classifier).
    """
    alignment = annotator.align(orig, cor, False)
    edits = annotator.merge(alignment, merging)
    classified_edits = []
    for edit in edits:
        # A single aligned token pair may yield several classified edits.
        classified_edits.extend(annotator.classify(edit))
    # Sort to get a deterministic edit order.
    return sorted(classified_edits, key=lambda x: (x[0], x[2]))
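
# Illustrative use of the helper above (a sketch, assuming the "ru_core_news_lg"
# model and the `sage` package are installed; mirrors the setup in RuErrant below):
#
#     annotator = Annotator("ru",
#                           nlp=update_spacy_tokenizer(spacy.load("ru_core_news_lg")),
#                           merger=merger,
#                           classifier=classifier)
#     edits = annotate_errors(annotator,
#                             annotator.parse("он пришол"),
#                             annotator.parse("Он пришёл"))
#     # a single token correction may yield several edits, e.g. SPELL and CASE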


@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class RuErrant(evaluate.Metric):
    """ERRANT-based evaluation metric for Russian grammatical error correction."""

    def _info(self):
        return evaluate.MetricInfo(
            # This is the description that will appear on the modules page.
            module_type="metric",
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            # This defines the format of each prediction and reference
            features=datasets.Features(
                {
                    "sources": datasets.Value("string", id="sequence"),
                    "corrections": datasets.Value("string", id="sequence"),
                    "answers": datasets.Value("string", id="sequence"),
                }
            ),
            # Homepage and codebase of the underlying ruerrant implementation
            homepage="https://github.com/ai-forever/sage",
            codebase_urls=["https://github.com/ai-forever/sage"],
            reference_urls=["https://github.com/ai-forever/sage"],
        )

    def _download_and_prepare(self, dl_manager):
        # Build a Russian ERRANT annotator on top of the "ru_core_news_lg" spaCy
        # pipeline, using the ruerrant merger and classifier from `sage`.
        self.annotator = Annotator("ru",
                                   nlp=update_spacy_tokenizer(spacy.load("ru_core_news_lg")),
                                   merger=merger,
                                   classifier=classifier)

    def _compute(self, sources, corrections, answers):
        """
        Evaluates iterables of source texts, gold corrections and system corrections
        with the ERRANT metric.

        Args:
            sources (Iterable[str]): an iterable of source texts;
            corrections (Iterable[str]): an iterable of gold corrections for the source texts;
            answers (Iterable[str]): an iterable of evaluated (system) corrections for the source texts;
        Returns:
            dict[str, tuple[float, ...]]: a dict mapping error categories to the corresponding
                P, R, F1 metric values.
        """
        best_dict = Counter({"tp": 0, "fp": 0, "fn": 0})
        best_cats = {}
        sents = zip(sources, corrections, answers)
        pb = tqdm(sents, desc="Calculating errant metric", total=len(sources))
        for sent_id, sent in enumerate(pb):
            src = self.annotator.parse(sent[0])
            ref = self.annotator.parse(sent[1])
            hyp = self.annotator.parse(sent[2])
            # Align hyp and ref corrections with the source and annotate errors
            hyp_edits = annotate_errors(self.annotator, src, hyp)
            ref_edits = annotate_errors(self.annotator, src, ref)
            # Process the edits for detection/correction based on args
            ProcessingArgs = namedtuple("ProcessingArgs",
                                        ["dt", "ds", "single", "multi", "filt", "cse"],
                                        defaults=[False, False, False, False, [], True])
            processing_args = ProcessingArgs()
            hyp_dict = process_edits(hyp_edits, processing_args)
            ref_dict = process_edits(ref_edits, processing_args)
            # Evaluate edits and get best TP, FP, FN hyp+ref combo.
            EvaluationArgs = namedtuple("EvaluationArgs",
                                        ["beta", "verbose"],
                                        defaults=[1.0, False])
            evaluation_args = EvaluationArgs()
            count_dict, cat_dict = evaluate_edits(
                hyp_dict, ref_dict, best_dict, sent_id, evaluation_args)
            # Merge these dicts with best_dict and best_cats
            best_dict += Counter(count_dict)  # corpus-level TP, FP, FN
            best_cats = merge_dict(best_cats, cat_dict)  # corpus-level error-type-wise TP, FP, FN
        cat_prf = {}
        for cat, values in best_cats.items():
            tp, fp, fn = values  # fp - extra corrections, fn - missed corrections
            p = float(tp) / (tp + fp) if tp + fp else 1.0
            r = float(tp) / (tp + fn) if tp + fn else 1.0
            f = (2 * p * r) / (p + r) if p + r else 0.0
            cat_prf[cat] = (p, r, f)
        # Categories with no observed edits count as perfectly handled.
        for error_category in ["CASE", "PUNCT", "SPELL", "YO"]:
            if error_category not in cat_prf:
                cat_prf[error_category] = (1.0, 1.0, 1.0)
        return cat_prf
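

# Minimal usage sketch, not part of the metric itself. Assumptions: the script is
# loaded by its local path (replace with the Hub id if loading from the Hub), the
# "ru_core_news_lg" spaCy model and the `sage` package are installed, and the
# example sentences are illustrative only.
if __name__ == "__main__":
    ru_errant = evaluate.load("ru_errant.py")  # assumed local path of this module
    scores = ru_errant.compute(
        sources=["он пришол домой"],
        corrections=["Он пришёл домой"],
        answers=["Он пришёл домой"],
    )
    # Each entry maps an error category to its (precision, recall, F1) tuple.
    for category, (precision, recall, f1) in scores.items():
        print(f"{category}: P={precision:.2f} R={recall:.2f} F1={f1:.2f}")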