|
from functools import lru_cache
|
|
|
|
from typing import cast, Any, Callable, Dict, Iterable, List, Optional
|
|
from typing import Sequence, Tuple, Union
|
|
from collections import Counter
|
|
from copy import deepcopy
|
|
from itertools import islice
|
|
import numpy as np
|
|
|
|
import srsly
|
|
from thinc.api import Config, Model, SequenceCategoricalCrossentropy, NumpyOps
|
|
from thinc.types import Floats2d, Ints2d
|
|
|
|
from spacy.pipeline._edit_tree_internals.edit_trees import EditTrees
|
|
from spacy.pipeline._edit_tree_internals.schemas import validate_edit_tree
|
|
from spacy.pipeline.lemmatizer import lemmatizer_score
|
|
from spacy.pipeline.trainable_pipe import TrainablePipe
|
|
from spacy.errors import Errors
|
|
from spacy.language import Language
|
|
from spacy.tokens import Doc, Token
|
|
from spacy.training import Example, validate_examples, validate_get_examples
|
|
from spacy.vocab import Vocab
|
|
from spacy import util
|
|
|
|
|
|
# For `top_k` values above this threshold, `EditTreeLemmatizer.predict` stops
# using repeated argmax (`_scores2guesses_top_k_greater_1`). It switches to a
# single argsort per document (`_scores2guesses_top_k_guardrail`).
TOP_K_GUARDRAIL = 20

# Default model configuration: a "spacy.Tagger.v2" architecture on top of a
# "spacy.HashEmbedCNN.v2" tok2vec layer. It is parsed into a thinc Config below.
default_model_config = """


[model]


@architectures = "spacy.Tagger.v2"




[model.tok2vec]


@architectures = "spacy.HashEmbedCNN.v2"


pretrained_vectors = null


width = 96


depth = 4


embed_size = 2000


window_size = 1


maxout_pieces = 3


subword_features = true


"""

# The parsed "[model]" section; used as the "model" entry of the
# trainable_lemmatizer_v2 factory's default_config.
DEFAULT_EDIT_TREE_LEMMATIZER_MODEL = Config().from_str(default_model_config)["model"]
|
|
|
|
|
|
@Language.factory(
    "trainable_lemmatizer_v2",
    assigns=["token.lemma"],
    requires=[],
    default_config={
        "model": DEFAULT_EDIT_TREE_LEMMATIZER_MODEL,
        "backoff": "orth",
        "min_tree_freq": 3,
        "overwrite": False,
        "top_k": 1,
        "overwrite_labels": True,
        "scorer": {"@scorers": "spacy.lemmatizer_scorer.v1"},
    },
    default_score_weights={"lemma_acc": 1.0},
)
def make_edit_tree_lemmatizer(
    nlp: Language,
    name: str,
    model: Model,
    backoff: Optional[str],
    min_tree_freq: int,
    overwrite: bool,
    top_k: int,
    overwrite_labels: bool,
    scorer: Optional[Callable],
):
    """Pipeline factory: build an EditTreeLemmatizer bound to `nlp.vocab`.

    Every setting from the component config is passed through unchanged as
    a keyword argument; `model` and `name` are passed positionally.
    """
    keyword_settings = dict(
        backoff=backoff,
        min_tree_freq=min_tree_freq,
        overwrite=overwrite,
        top_k=top_k,
        overwrite_labels=overwrite_labels,
        scorer=scorer,
    )
    return EditTreeLemmatizer(nlp.vocab, model, name, **keyword_settings)
|
|
|
|
|
|
|
|
|
|
|
|
def debug(*args):
    """No-op debugging hook: accepts any positional arguments and ignores them."""
    del args  # deliberately unused
|
|
|
|
|
|
class EditTreeLemmatizer(TrainablePipe):
    """
    Lemmatizer that lemmatizes each word using a predicted edit tree.

    Word forms are "true-cased" before edit trees are extracted or applied.
    A sentence-initial token that is not tagged PROPN is lowercased first
    (see `_get_true_cased_form`).
    """

    def __init__(
        self,
        vocab: Vocab,
        model: Model,
        name: str = "trainable_lemmatizer",
        *,
        backoff: Optional[str] = "orth",
        min_tree_freq: int = 3,
        overwrite: bool = False,
        top_k: int = 1,
        overwrite_labels: bool = True,
        scorer: Optional[Callable] = lemmatizer_score,
    ):
        """
        Construct an edit tree lemmatizer.

        vocab (Vocab): the shared vocabulary; its string store backs the trees.
        model (Model): the model that scores the labels (edit trees) per token.
        name (str): component name, used in error messages.
        backoff (Optional[str]): backoff to use when the predicted edit trees
            are not applicable. Must be an attribute of Token or None (leave the
            lemma unset).
        min_tree_freq (int): prune trees that are applied less than this
            frequency in the training data.
        overwrite (bool): overwrite existing lemma annotations.
        top_k (int): try to apply at most the k most probable edit trees.
        overwrite_labels (bool): if True, `initialize` (re)builds the label
            set from the given labels or from the training data; if False, the
            existing labels are kept. Defaults to True, like the factory.
        scorer (Optional[Callable]): the scoring method.
        """
        self.vocab = vocab
        self.model = model
        self.name = name
        self.backoff = backoff
        self.min_tree_freq = min_tree_freq
        self.overwrite = overwrite
        self.top_k = top_k
        self.overwrite_labels = overwrite_labels

        self.trees = EditTrees(self.vocab.strings)
        # Maps an edit tree id to its label index (its position in cfg["labels"]).
        self.tree2label: Dict[int, int] = {}

        # cfg["labels"] holds edit tree ids; the label index is the position.
        self.cfg: Dict[str, Any] = {"labels": []}
        self.scorer = scorer
        self.numpy_ops = NumpyOps()

    def get_loss(
        self, examples: Iterable[Example], scores: List[Floats2d]
    ) -> Tuple[float, List[Floats2d]]:
        """Compute the loss and its gradient for a batch of examples.

        For every predicted token, the aligned gold lemma is turned into an
        edit tree from the token's true-cased form. Tokens without a gold
        lemma get the missing value -1 and do not contribute to the loss.

        examples (Iterable[Example]): the batch of examples.
        scores (List[Floats2d]): the model's scores, one array per example.
        RETURNS (Tuple[float, List[Floats2d]]): the loss and the gradient.
        RAISES ValueError: if the loss is NaN (Errors.E910).
        """
        validate_examples(examples, "EditTreeLemmatizer.get_loss")
        loss_func = SequenceCategoricalCrossentropy(normalize=False, missing_value=-1)

        truths = []
        for eg in examples:
            eg_truths = []
            for predicted, gold_lemma, gold_pos, gold_sent_start in zip(
                eg.predicted,
                eg.get_aligned("LEMMA", as_string=True),
                eg.get_aligned("POS", as_string=True),
                eg.get_aligned_sent_starts(),
            ):
                if gold_lemma is None:
                    label = -1
                else:
                    form = self._get_true_cased_form(
                        predicted.text, gold_sent_start, gold_pos
                    )
                    tree_id = self.trees.add(form, gold_lemma)
                    # NOTE(review): trees that are not labels (e.g. pruned by
                    # min_tree_freq) are trained towards label 0 rather than
                    # treated as missing (-1). Behaviour kept as-is; confirm intent.
                    label = self.tree2label.get(tree_id, 0)
                eg_truths.append(label)

            truths.append(eg_truths)

        d_scores, loss = loss_func(scores, truths)
        if self.model.ops.xp.isnan(loss):
            raise ValueError(Errors.E910.format(name=self.name))

        return float(loss), d_scores

    def predict(self, docs: Iterable[Doc]) -> List[Ints2d]:
        """Predict one edit tree id per token, or -1 where no candidate applies.

        docs (Iterable[Doc]): the documents to predict.
        RETURNS (List[Ints2d]): one array of tree ids per document.
        """
        # `docs` is traversed several times below; materialize it once so a
        # one-shot iterator (e.g. a generator) is not silently exhausted.
        docs = list(docs)
        n_docs = len(docs)

        # Repeated argmax is efficient for the small top_k values useful for
        # lemmatization; for very large top_k a single argsort is cheaper.
        if self.top_k == 1:
            scores2guesses = self._scores2guesses_top_k_equals_1
        elif self.top_k <= TOP_K_GUARDRAIL:
            scores2guesses = self._scores2guesses_top_k_greater_1
        else:
            scores2guesses = self._scores2guesses_top_k_guardrail

        if not any(len(doc) for doc in docs):
            # No tokens in any doc: skip the model and return empty guesses.
            n_labels = len(self.cfg["labels"])
            guesses: List[Ints2d] = [self.model.ops.alloc2i(0, n_labels) for _ in docs]
            assert len(guesses) == n_docs
            return guesses

        scores = self.model.predict(docs)
        assert len(scores) == n_docs
        guesses = scores2guesses(docs, scores)
        assert len(guesses) == n_docs
        return guesses

    def _scores2guesses_top_k_equals_1(self, docs, scores):
        """Take the best-scoring tree per token, or -1 if it cannot be applied."""
        guesses = []
        for doc, doc_scores in zip(docs, scores):
            doc_guesses = self.numpy_ops.asarray(doc_scores.argmax(axis=1))

            doc_compat_guesses = []
            for i, token in enumerate(doc):
                tree_id = self.cfg["labels"][doc_guesses[i]]
                form: str = self._get_true_cased_form_of_token(token)
                if self.trees.apply(tree_id, form) is None:
                    tree_id = -1
                doc_compat_guesses.append(tree_id)
            guesses.append(np.array(doc_compat_guesses))

        return guesses

    def _scores2guesses_top_k_greater_1(self, docs, scores):
        """Try the top-k trees per token in score order, using repeated argmax.

        The first tree that can be applied to the token's form wins; -1 if
        none of the k candidates applies.
        """
        guesses = []
        top_k = min(self.top_k, len(self.labels))
        masked_score = np.finfo(np.float32).min
        for doc, doc_scores in zip(docs, scores):
            doc_scores = self.numpy_ops.asarray(doc_scores)
            doc_compat_guesses = []
            for i, token in enumerate(doc):
                form: str = self._get_true_cased_form_of_token(token)
                # Mask rejected candidates in a copy of the row so the model's
                # score array is not modified in place.
                token_scores = doc_scores[i].copy()
                tree_id = -1
                for _ in range(top_k):
                    candidate = int(token_scores.argmax())
                    candidate_tree_id = self.cfg["labels"][candidate]
                    if self.trees.apply(candidate_tree_id, form) is not None:
                        tree_id = candidate_tree_id
                        break
                    token_scores[candidate] = masked_score
                doc_compat_guesses.append(tree_id)
            guesses.append(np.array(doc_compat_guesses))

        return guesses

    def _scores2guesses_top_k_guardrail(self, docs, scores):
        """Like `_scores2guesses_top_k_greater_1`, but with one argsort per doc.

        Used for large top_k values, where repeated argmax becomes slow.
        """
        guesses = []
        for doc, doc_scores in zip(docs, scores):
            doc_scores = self.numpy_ops.asarray(doc_scores)
            # Indices of the top_k highest scores per token, in descending order.
            doc_guesses = np.argsort(doc_scores)[..., : -self.top_k - 1 : -1]

            doc_compat_guesses = []
            for token, candidates in zip(doc, doc_guesses):
                form: str = self._get_true_cased_form_of_token(token)
                tree_id = -1
                for candidate in candidates:
                    candidate_tree_id = self.cfg["labels"][candidate]
                    if self.trees.apply(candidate_tree_id, form) is not None:
                        tree_id = candidate_tree_id
                        break
                doc_compat_guesses.append(tree_id)

            guesses.append(np.array(doc_compat_guesses))

        return guesses

    def set_annotations(self, docs: Iterable[Doc], batch_tree_ids):
        """Write the predicted lemmas onto the tokens.

        A lemma is only set when `overwrite` is enabled or the token has none.
        For tree id -1, the `backoff` attribute is copied (if configured).
        """
        for i, doc in enumerate(docs):
            doc_tree_ids = batch_tree_ids[i]
            if hasattr(doc_tree_ids, "get"):
                # Presumably a device (cupy) array; `.get()` copies it to host.
                doc_tree_ids = doc_tree_ids.get()
            for j, tree_id in enumerate(doc_tree_ids):
                token = doc[j]
                if not (self.overwrite or token.lemma == 0):
                    continue
                if tree_id == -1:
                    if self.backoff is not None:
                        token.lemma = getattr(token, self.backoff)
                else:
                    form = self._get_true_cased_form_of_token(token)
                    token.lemma_ = self.trees.apply(tree_id, form) or form

    @property
    def labels(self) -> Tuple[int, ...]:
        """Returns the labels currently added to the component."""
        return tuple(self.cfg["labels"])

    @property
    def hide_labels(self) -> bool:
        return True

    @property
    def label_data(self) -> Dict:
        """Return the trees (string ids resolved to text) and the labels.

        The result has the "trees" and "labels" keys that `_add_labels` reads.
        """
        trees = []
        for tree_id in range(len(self.trees)):
            tree = self.trees[tree_id]
            if "orig" in tree:
                tree["orig"] = self.vocab.strings[tree["orig"]]
            if "subst" in tree:
                tree["subst"] = self.vocab.strings[tree["subst"]]
            trees.append(tree)
        return dict(trees=trees, labels=tuple(self.cfg["labels"]))

    def initialize(
        self,
        get_examples: Callable[[], Iterable[Example]],
        *,
        nlp: Optional[Language] = None,
        labels: Optional[Dict] = None,
    ):
        """Set up the labels (if `overwrite_labels`) and initialize the model.

        get_examples (Callable[[], Iterable[Example]]): returns training examples.
        nlp (Optional[Language]): the pipeline (unused here).
        labels (Optional[Dict]): label data as produced by `label_data`; when
            None, labels are extracted from the examples.
        """
        validate_get_examples(get_examples, "EditTreeLemmatizer.initialize")

        if self.overwrite_labels:
            if labels is None:
                self._labels_from_data(get_examples)
            else:
                self._add_labels(labels)

        # Sample for the initialization of shapes.
        doc_sample = []
        label_sample = []
        for example in islice(get_examples(), 10):
            doc_sample.append(example.x)
            gold_labels: List[List[float]] = []
            for token in example.reference:
                if token.lemma == 0:
                    gold_label = None
                else:
                    # Use the true-cased form, like _labels_from_data and
                    # get_loss, so the sampled gold labels match the label set.
                    form = self._get_true_cased_form_of_token(token)
                    gold_label = self._pair2label(form, token.lemma_)

                gold_labels.append(
                    [
                        1.0 if label == gold_label else 0.0
                        for label in self.cfg["labels"]
                    ]
                )

            gold_labels = cast(Floats2d, gold_labels)
            label_sample.append(self.model.ops.asarray(gold_labels, dtype="float32"))

        self._require_labels()
        assert len(doc_sample) > 0, Errors.E923.format(name=self.name)
        assert len(label_sample) > 0, Errors.E923.format(name=self.name)

        self.model.initialize(X=doc_sample, Y=label_sample)

    def from_bytes(self, bytes_data, *, exclude=tuple()):
        """Load the component from a bytestring; returns self."""
        deserializers = {
            "cfg": lambda b: self.cfg.update(srsly.json_loads(b)),
            "model": lambda b: self.model.from_bytes(b),
            "vocab": lambda b: self.vocab.from_bytes(b, exclude=exclude),
            "trees": lambda b: self.trees.from_bytes(b),
        }

        util.from_bytes(bytes_data, deserializers, exclude)
        # tree2label is not serialized; derive it from the restored labels so
        # that get_loss works when training resumes on a loaded component.
        self._rebuild_tree2label()

        return self

    def to_bytes(self, *, exclude=tuple()):
        """Serialize the component to a bytestring."""
        serializers = {
            "cfg": lambda: srsly.json_dumps(self.cfg),
            "model": lambda: self.model.to_bytes(),
            "vocab": lambda: self.vocab.to_bytes(exclude=exclude),
            "trees": lambda: self.trees.to_bytes(),
        }

        return util.to_bytes(serializers, exclude)

    def to_disk(self, path, exclude=tuple()):
        """Serialize the component to a directory."""
        path = util.ensure_path(path)
        serializers = {
            "cfg": lambda p: srsly.write_json(p, self.cfg),
            "model": lambda p: self.model.to_disk(p),
            "vocab": lambda p: self.vocab.to_disk(p, exclude=exclude),
            "trees": lambda p: self.trees.to_disk(p),
        }
        util.to_disk(path, serializers, exclude)

    def from_disk(self, path, exclude=tuple()):
        """Load the component from a directory; returns self.

        RAISES ValueError: (Errors.E149) if the model file cannot be loaded.
        """

        def load_model(p):
            try:
                with open(p, "rb") as mfile:
                    self.model.from_bytes(mfile.read())
            except AttributeError:
                raise ValueError(Errors.E149) from None

        deserializers = {
            "cfg": lambda p: self.cfg.update(srsly.read_json(p)),
            "model": load_model,
            "vocab": lambda p: self.vocab.from_disk(p, exclude=exclude),
            "trees": lambda p: self.trees.from_disk(p),
        }

        util.from_disk(path, deserializers, exclude)
        # See from_bytes: tree2label must match the restored labels.
        self._rebuild_tree2label()
        return self

    def _rebuild_tree2label(self):
        """Recompute the tree-id -> label-index map from cfg["labels"]."""
        self.tree2label = {
            tree_id: label for label, tree_id in enumerate(self.cfg["labels"])
        }

    def _add_labels(self, labels: Dict):
        """Replace labels and trees with `labels` (the `label_data` format).

        RAISES ValueError: if a key is missing (E857) or a tree is invalid (E1026).
        """
        if "labels" not in labels:
            raise ValueError(Errors.E857.format(name="labels"))
        if "trees" not in labels:
            raise ValueError(Errors.E857.format(name="trees"))

        self.cfg["labels"] = list(labels["labels"])
        trees = []
        for tree in labels["trees"]:
            errors = validate_edit_tree(tree)
            if errors:
                raise ValueError(Errors.E1026.format(errors="\n".join(errors)))

            tree = dict(tree)
            # Convert the text of substitution nodes back to string-store ids.
            if "orig" in tree:
                tree["orig"] = self.vocab.strings[tree["orig"]]
            if "subst" in tree:
                tree["subst"] = self.vocab.strings[tree["subst"]]

            trees.append(tree)

        self.trees.from_json(trees)

        # Start from a fresh map so no mappings from earlier labels survive.
        self._rebuild_tree2label()

    def _labels_from_data(self, get_examples: Callable[[], Iterable[Example]]):
        """Add a label for every edit tree seen at least min_tree_freq times."""
        # Count trees in a throwaway vocab/tree store so pruned trees never
        # reach self.trees or self.vocab.strings.
        vocab = Vocab()
        trees = EditTrees(vocab.strings)
        tree_freqs: Counter = Counter()
        repr_pairs: Dict = {}
        for example in get_examples():
            for token in example.reference:
                if token.lemma != 0:
                    form = self._get_true_cased_form_of_token(token)
                    tree_id = trees.add(form, token.lemma_)
                    tree_freqs[tree_id] += 1
                    repr_pairs[tree_id] = (form, token.lemma_)

        # Re-add each frequent tree to our own store via a representative pair.
        for tree_id, freq in tree_freqs.items():
            if freq >= self.min_tree_freq:
                form, lemma = repr_pairs[tree_id]
                self._pair2label(form, lemma, add_label=True)

    @staticmethod
    @lru_cache()
    def _get_true_cased_form(
        token: str, is_sent_start: Optional[bool], pos: Optional[str]
    ) -> str:
        """Lowercase a sentence-initial form unless it is tagged PROPN.

        This is a static method, so the cache is keyed on the arguments only.
        An lru_cache on an instance method would also key on `self` and keep
        every instance alive.
        """
        if is_sent_start and pos != "PROPN":
            return token.lower()
        return token

    def _get_true_cased_form_of_token(self, token: Token) -> str:
        """True-cased form of `token`, based on its sentence start and POS."""
        return self._get_true_cased_form(token.text, token.is_sent_start, token.pos_)

    def _pair2label(self, form, lemma, add_label=False):
        """
        Look up the edit tree identifier for a form/label pair. If the edit
        tree is unknown and "add_label" is set, the edit tree will be added to
        the labels.
        """
        # Note: this registers the tree in self.trees even when no label is added.
        tree_id = self.trees.add(form, lemma)
        if tree_id not in self.tree2label:
            if not add_label:
                return None

            self.tree2label[tree_id] = len(self.cfg["labels"])
            self.cfg["labels"].append(tree_id)
        return self.tree2label[tree_id]
|
|
|