# re_on_tacred / transformer_re_text_classification2.py
# Repository page header: dfki-nlp, commit c81d27a ("Update transformer_re_text_classification2.py")
"""
Marker-based relation classification task module (`TransformerRETextClassificationTaskModule2`).

workflow:
    Document
        -> (InputEncoding, TargetEncoding) -> TaskEncoding -> TaskBatchEncoding
            -> ModelBatchEncoding -> ModelBatchOutput
        -> TaskOutput
    -> Document
"""
import logging
from typing import Any, Dict, Iterator, List, Optional, Sequence, Set, Tuple, TypedDict, Union
import numpy as np
import torch
from pytorch_ie.annotations import BinaryRelation, LabeledSpan, MultiLabeledBinaryRelation, Span
from pytorch_ie.core import TaskEncoding, TaskModule
from pytorch_ie.documents import TextDocument
from pytorch_ie.models import (
TransformerTextClassificationModelBatchOutput,
TransformerTextClassificationModelStepBatchEncoding,
)
from pytorch_ie.utils.span import get_token_slice, is_contained_in
from pytorch_ie.utils.window import get_window_around_slice
from transformers import AutoTokenizer
from transformers.file_utils import PaddingStrategy
from transformers.tokenization_utils_base import BatchEncoding, TruncationStrategy
from typing_extensions import TypeAlias
# Model inputs of one candidate pair; `encode_input` fills it with the key "input_ids".
TransformerReTextClassificationInputEncoding2: TypeAlias = Dict[str, Any]
# Relation label id(s) of one candidate pair, as produced by `encode_target`.
TransformerReTextClassificationTargetEncoding2: TypeAlias = Sequence[int]
# Ties a document to the input / target encodings of one candidate (head, tail) pair.
TransformerReTextClassificationTaskEncoding2: TypeAlias = TaskEncoding[
    TextDocument,
    TransformerReTextClassificationInputEncoding2,
    TransformerReTextClassificationTargetEncoding2,
]
# Prediction for a single example: parallel lists of relation labels and their probabilities.
# Both keys are optional (total=False).
TransformerReTextClassificationTaskOutput2 = TypedDict(
    "TransformerReTextClassificationTaskOutput2",
    {
        "labels": Sequence[str],
        "probabilities": Sequence[float],
    },
    total=False,
)
# Fully parametrized `TaskModule` base class of the task module below.
_TransformerReTextClassificationTaskModule2: TypeAlias = TaskModule[
    # type parameters, in order:
    # document type, _InputEncoding, _TargetEncoding, _TaskBatchEncoding, _ModelBatchOutput,
    # _TaskOutput
    TextDocument,
    TransformerReTextClassificationInputEncoding2,
    TransformerReTextClassificationTargetEncoding2,
    TransformerTextClassificationModelStepBatchEncoding,
    TransformerTextClassificationModelBatchOutput,
    TransformerReTextClassificationTaskOutput2,
]
# Roles of the two relation arguments; also used as task-encoding metadata keys.
HEAD, TAIL = "head", "tail"
# Marker positions relative to an argument span (before / after it).
START, END = "start", "end"

logger = logging.getLogger(__name__)
class RelationArgument:
    """One argument (head or tail) of a candidate relation plus its token offsets.

    Provides the marker strings that surround the argument in the input text:
    ``[H]`` / ``[/H]`` (``[H:<label>]`` / ``[/H:<label>]`` when ``add_type_to_marker`` is set)
    and the appended marker ``[H=<label>]``; for the tail ``T`` is used instead of ``H``.
    """

    def __init__(
        self,
        entity: LabeledSpan,
        role: str,
        offsets: Tuple[int, int],
        add_type_to_marker: bool,
    ) -> None:
        self.entity = entity
        self.role = role
        assert self.role in (HEAD, TAIL)
        # (start, end) token offsets of the argument
        self.offsets = offsets
        self.add_type_to_marker = add_type_to_marker

    @property
    def is_head(self) -> bool:
        """True if this argument is the relation head."""
        return self.role == HEAD

    @property
    def is_tail(self) -> bool:
        """True if this argument is the relation tail."""
        return self.role == TAIL

    @property
    def _role_letter(self) -> str:
        # single letter identifying the role inside a marker
        return "H" if self.is_head else "T"

    @property
    def as_start_marker(self) -> str:
        """Marker inserted directly before the argument."""
        return self._get_marker(is_start=True)

    @property
    def as_end_marker(self) -> str:
        """Marker inserted directly after the argument."""
        return self._get_marker(is_start=False)

    def _get_marker(self, is_start: bool = True) -> str:
        closing_slash = "" if is_start else "/"
        type_suffix = f":{self.entity.label}" if self.add_type_to_marker else ""
        return f"[{closing_slash}{self._role_letter}{type_suffix}]"

    @property
    def as_append_marker(self) -> str:
        """Marker carrying the entity label, appended after the text."""
        return f"[{self._role_letter}={self.entity.label}]"
def _enumerate_entity_pairs(
    entities: Sequence[Span],
    partition: Optional[Span] = None,
    relations: Optional[Sequence[BinaryRelation]] = None,
) -> Iterator[Tuple[Span, Span]]:
    """Given a list of `entities` iterate all valid pairs of entities, including inverted pairs.

    If a `partition` is provided, restrict pairs to be contained in that. If `relations` are
    given, return only pairs for which a predefined relation exists (e.g. in the case of relation
    classification for train,val,test splits in supervised datasets). An empty `relations`
    sequence therefore yields no pairs at all.

    Pairs are yielded in the order of `entities` (outer loop: head, inner loop: tail). An entity
    is never paired with itself (nor with an entity that compares equal to it).
    """
    existing_head_tail = {(relation.head, relation.tail) for relation in relations or []}

    # The partition check depends on a single entity only, so filter once up front instead of
    # re-checking every tail for every head.
    if partition is None:
        candidates = list(entities)
    else:
        partition_start_end = (partition.start, partition.end)
        candidates = [
            entity
            for entity in entities
            if is_contained_in((entity.start, entity.end), partition_start_end)
        ]

    for head in candidates:
        for tail in candidates:
            if head == tail:
                continue
            if relations is not None and (head, tail) not in existing_head_tail:
                continue
            yield head, tail
@TaskModule.register()
class TransformerRETextClassificationTaskModule2(_TransformerReTextClassificationTaskModule2):
    """Marker based relation extraction.

    This taskmodule prepares the input token ids in such a way that before and after the
    candidate head and tail entities special marker tokens are inserted. Then, the modified
    token ids can be simply passed into a transformer based text classifier model.

    Parameters:
        tokenizer_name_or_path: str. Passed to `AutoTokenizer.from_pretrained`.
        entity_annotation: str, defaults to "entities". Name of the document annotation layer
            holding the candidate entities.
        relation_annotation: str, defaults to "relations". Name of the document annotation
            layer holding the relations.
        partition_annotation: str, optional. If specified, LabeledSpan annotations with this
            name are expected to define partitions of the document that will be processed
            individually, e.g. sentences or sections of the document text.
        none_label: str, defaults to "no_relation". The relation label that indicates
            dummy/negative relations. Predicted relations with that label will not be added to
            the document(s).
        padding, pad_to_multiple_of: passed to `tokenizer.pad` when collating a batch.
        truncation: passed to the tokenizer when encoding the (partition) text.
        max_length: int, optional. Passed to the tokenizer when encoding the text and when
            padding a batch.
        multi_label: bool, defaults to False. Apply sigmoid instead of softmax to the logits;
            decoding multi-label outputs is not implemented yet.
        label_to_id: dict, optional. Mapping from relation labels to ids; computed by
            `prepare` if not given.
        add_type_to_marker: bool, defaults to False. Include the entity label in the start /
            end markers, e.g. "[H:<label>]" instead of "[H]".
        single_argument_pair: bool, defaults to True. Stored, but not used by this taskmodule.
        append_markers: bool, defaults to False. Append "[H=<label>] <sep> [T=<label>] <sep>"
            after the marked-up text.
        entity_labels: list of str, optional. All entity labels; computed by `prepare` if not
            given.
        max_window: int, optional. If specified, use the tokens in a window of maximal this
            amount of tokens around the center of head and tail entities and pass only that
            into the transformer. The inserted / appended marker tokens and the special tokens
            of the tokenizer are counted towards this limit.
        log_first_n_examples: int, optional. If specified, log the first n encoded examples
            (only for documents that contain relations).
    """

    PREPARED_ATTRIBUTES = ["label_to_id", "entity_labels"]

    def __init__(
        self,
        tokenizer_name_or_path: str,
        entity_annotation: str = "entities",
        relation_annotation: str = "relations",
        partition_annotation: Optional[str] = None,
        none_label: str = "no_relation",
        padding: Union[bool, str, PaddingStrategy] = True,
        truncation: Union[bool, str, TruncationStrategy] = True,
        max_length: Optional[int] = None,
        pad_to_multiple_of: Optional[int] = None,
        multi_label: bool = False,
        label_to_id: Optional[Dict[str, int]] = None,
        add_type_to_marker: bool = False,
        single_argument_pair: bool = True,
        append_markers: bool = False,
        entity_labels: Optional[List[str]] = None,
        max_window: Optional[int] = None,
        log_first_n_examples: Optional[int] = None,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.save_hyperparameters()

        self.entity_annotation = entity_annotation
        self.relation_annotation = relation_annotation
        self.padding = padding
        self.truncation = truncation
        self.label_to_id = label_to_id or {}
        self.id_to_label = {v: k for k, v in self.label_to_id.items()}
        self.max_length = max_length
        self.pad_to_multiple_of = pad_to_multiple_of
        self.multi_label = multi_label
        self.add_type_to_marker = add_type_to_marker
        self.single_argument_pair = single_argument_pair
        self.append_markers = append_markers
        self.entity_labels = entity_labels
        self.partition_annotation = partition_annotation
        self.none_label = none_label
        self.max_window = max_window
        self.log_first_n_examples = log_first_n_examples

        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)

        # set in `_post_prepare`; `encode_input` refuses to run before that
        self.argument_markers = None

        self._logged_examples_counter = 0

    def _prepare(self, documents: Sequence[TextDocument]) -> None:
        """Collect the entity and relation labels occurring in `documents`.

        Relation labels get ids starting at 1 in sorted order; `none_label` always gets id 0.
        """
        entity_labels: Set[str] = set()
        relation_labels: Set[str] = set()
        for document in documents:
            entities: Sequence[LabeledSpan] = document[self.entity_annotation]
            relations: Sequence[BinaryRelation] = document[self.relation_annotation]
            entity_labels.update(entity.label for entity in entities)
            relation_labels.update(relation.label for relation in relations)

        relation_labels.discard(self.none_label)

        self.label_to_id = {label: i + 1 for i, label in enumerate(sorted(relation_labels))}
        self.label_to_id[self.none_label] = 0

        self.entity_labels = sorted(entity_labels)

    def _post_prepare(self):
        """Register the argument markers as tokens and build the id lookup tables."""
        self.argument_markers = self._initialize_argument_markers()
        self.tokenizer.add_tokens(self.argument_markers, special_tokens=True)

        # `convert_tokens_to_ids` also resolves tokens registered via `add_tokens` (which the
        # `vocab` attribute of a slow tokenizer may not contain) and does not rebuild a full
        # vocabulary dict for every lookup like the `vocab` property of fast tokenizers does.
        self.argument_markers_to_id = {
            marker: self.tokenizer.convert_tokens_to_ids(marker)
            for marker in self.argument_markers
        }
        self.sep_token_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.sep_token)

        self.id_to_label = {v: k for k, v in self.label_to_id.items()}

    def _initialize_argument_markers(self) -> List[str]:
        """Return the sorted list of all marker strings that may be inserted into the input."""
        argument_markers: Set[str] = set()
        for arg_type in [HEAD, TAIL]:
            role_letter = "H" if arg_type == HEAD else "T"
            for arg_pos in [START, END]:
                slash = "" if arg_pos == START else "/"
                argument_markers.add(f"[{slash}{role_letter}]")
                if self.add_type_to_marker:
                    for entity_type in self.entity_labels:  # type: ignore
                        argument_markers.add(f"[{slash}{role_letter}:{entity_type}]")
            if self.append_markers:
                for entity_type in self.entity_labels:  # type: ignore
                    argument_markers.add(f"[{role_letter}={entity_type}]")

        return sorted(argument_markers)

    def _encode_text(
        self,
        document: TextDocument,
        partition: Optional[Span] = None,
        add_special_tokens: bool = True,
    ) -> BatchEncoding:
        """Tokenize the text of `partition` (or the whole document text if it is None)."""
        text = (
            document.text[partition.start : partition.end]
            if partition is not None
            else document.text
        )
        encoding = self.tokenizer(
            text,
            padding=False,
            truncation=self.truncation,
            max_length=self.max_length,
            is_split_into_words=False,
            return_offsets_mapping=False,
            add_special_tokens=add_special_tokens,
        )
        return encoding

    @staticmethod
    def _get_argument_marker_strings(arg_list: List[RelationArgument]) -> List[str]:
        """Return start / end markers of both arguments in insertion (positional) order."""
        first_arg, second_arg = arg_list[0], arg_list[1]
        return [
            first_arg.as_start_marker,
            first_arg.as_end_marker,
            second_arg.as_start_marker,
            second_arg.as_end_marker,
        ]

    @staticmethod
    def _insert_argument_markers(
        sequence: List[Any], arg_list: List[RelationArgument], markers: Sequence[Any]
    ) -> List[Any]:
        """Return a copy of `sequence` with the four `markers` placed around both arguments.

        `arg_list` holds the two arguments in positional order (the first one starts earlier),
        `markers` is (first_start, first_end, second_start, second_end).
        """
        first_start, first_end = arg_list[0].offsets
        second_start, second_end = arg_list[1].offsets
        first_start_marker, first_end_marker, second_start_marker, second_end_marker = markers
        return (
            sequence[:first_start]
            + [first_start_marker]
            + sequence[first_start:first_end]
            + [first_end_marker]
            + sequence[first_end:second_start]
            + [second_start_marker]
            + sequence[second_start:second_end]
            + [second_end_marker]
            + sequence[second_end:]
        )

    def encode_input(
        self,
        document: TextDocument,
        is_training: bool = False,
    ) -> Optional[
        Union[
            TransformerReTextClassificationTaskEncoding2,
            Sequence[TransformerReTextClassificationTaskEncoding2],
        ]
    ]:
        """Create one task encoding per candidate (head, tail) entity pair of `document`.

        If the document contains relations, only pairs with a relation are encoded; otherwise
        all entity pairs (per partition) are. Pairs whose entity boundaries do not align with
        token boundaries, or (with `max_window`) that do not fit into the window, are skipped.
        """
        assert (
            self.argument_markers is not None
        ), "No argument markers available, was `prepare` already called?"

        entities: Sequence[Span] = document[self.entity_annotation]
        relations: Sequence[BinaryRelation] = document[self.relation_annotation]
        # if no relations are predefined, use None so that `_enumerate_entity_pairs` yields all
        # pairs (should be fixed to a parameter "generate_all" or "restrict_to_existing_relations")
        if len(relations) == 0:
            relations = None

        partitions: Sequence[Optional[Span]]
        if self.partition_annotation is not None:
            partitions = document[self.partition_annotation]
        else:
            # use single dummy partition
            partitions = [None]

        # with windowing, the special tokens (e.g. CLS and SEP) are added manually after the
        # window has been cut out
        add_special_tokens = self.max_window is None
        max_tokens: Optional[int] = None
        if self.max_window is not None:
            # The window has to leave room for the 4 marker tokens (before / after the head /
            # tail), the 4 appended tokens (if `append_markers`) and the default special tokens.
            num_added_special_tokens = len(self.tokenizer.build_inputs_with_special_tokens([]))
            num_appended_tokens = 4 if self.append_markers else 0
            max_tokens = self.max_window - 4 - num_appended_tokens - num_added_special_tokens

        task_encodings: List[TransformerReTextClassificationTaskEncoding2] = []
        for partition in partitions:
            partition_offset = 0 if partition is None else partition.start
            encoding = self._encode_text(
                document=document, partition=partition, add_special_tokens=add_special_tokens
            )
            full_input_ids = encoding["input_ids"]

            for head, tail in _enumerate_entity_pairs(
                entities=entities,
                partition=partition,
                relations=relations,
            ):
                head_token_slice = get_token_slice(
                    character_slice=(head.start, head.end),
                    char_to_token_mapper=encoding.char_to_token,
                    character_offset=partition_offset,
                )
                tail_token_slice = get_token_slice(
                    character_slice=(tail.start, tail.end),
                    char_to_token_mapper=encoding.char_to_token,
                    character_offset=partition_offset,
                )
                # this happens if the head/tail start/end does not match a token start/end
                if head_token_slice is None or tail_token_slice is None:
                    logger.warning(
                        f"Skipping invalid example {document.id}, cannot get token slice(s)"
                    )
                    continue

                input_ids = full_input_ids
                window_slice: Optional[Tuple[int, int]] = None
                if max_tokens is not None:
                    head_start, head_end = head_token_slice
                    tail_start, tail_end = tail_token_slice
                    # the slice from the beginning of the first entity to the end of the second
                    # is required
                    slice_required = (min(head_start, tail_start), max(head_end, tail_end))
                    window_slice = get_window_around_slice(
                        slice=slice_required,
                        max_window_size=max_tokens,
                        available_input_length=len(full_input_ids),
                    )
                    # this happens if slice_required does not fit into max_tokens
                    if window_slice is None:
                        continue

                    window_start, window_end = window_slice
                    input_ids = full_input_ids[window_start:window_end]
                    # offsets are relative to the window from here on
                    head_token_slice = head_start - window_start, head_end - window_start
                    tail_token_slice = tail_start - window_start, tail_end - window_start

                # maybe expand to n-ary relations?
                head_arg = RelationArgument(head, HEAD, head_token_slice, self.add_type_to_marker)
                tail_arg = RelationArgument(tail, TAIL, tail_token_slice, self.add_type_to_marker)
                arg_list = [head_arg, tail_arg]

                if head_token_slice[0] < tail_token_slice[0]:
                    assert (
                        head_token_slice[1] <= tail_token_slice[0]
                    ), f"the head and tail entities are not allowed to overlap in {document.id}"
                else:
                    assert (
                        tail_token_slice[1] <= head_token_slice[0]
                    ), f"the head and tail entities are not allowed to overlap in {document.id}"
                    # markers are inserted in positional order, so the tail comes first here
                    arg_list.reverse()

                marker_ids = [
                    self.argument_markers_to_id[marker]
                    for marker in self._get_argument_marker_strings(arg_list)
                ]
                new_input_ids = self._insert_argument_markers(input_ids, arg_list, marker_ids)

                if self.append_markers:
                    new_input_ids.extend(
                        [
                            self.argument_markers_to_id[head_arg.as_append_marker],
                            self.sep_token_id,
                            self.argument_markers_to_id[tail_arg.as_append_marker],
                            self.sep_token_id,
                        ]
                    )

                # when windowing is used, we have to add the special tokens manually
                if not add_special_tokens:
                    new_input_ids = self.tokenizer.build_inputs_with_special_tokens(
                        token_ids_0=new_input_ids
                    )

                # log at most `log_first_n_examples` examples
                log_this_example = (
                    relations is not None
                    and self.log_first_n_examples is not None
                    and self._logged_examples_counter < self.log_first_n_examples
                )
                if log_this_example:
                    # tokens (fast tokenizers only) must be aligned with the argument offsets,
                    # so cut them to the same window as the input ids
                    tokens = encoding.encodings[0].tokens
                    if window_slice is not None:
                        tokens = tokens[window_slice[0] : window_slice[1]]
                    self._log_example(document, arg_list, new_input_ids, relations, tokens)

                task_encodings.append(
                    TaskEncoding(
                        document=document,
                        inputs={"input_ids": new_input_ids},
                        metadata={
                            HEAD: head,
                            TAIL: tail,
                        },
                    )
                )

        return task_encodings

    def _log_example(
        self,
        document: TextDocument,
        arg_list: List[RelationArgument],
        input_ids: List[int],
        relations: Sequence[BinaryRelation],
        tokens: List[str],
    ):
        """Log the marked-up tokens, the input ids and the expected labels of one example.

        `tokens` must be aligned with the offsets of the arguments in `arg_list`. Special
        tokens added to `input_ids` after windowing are not shown in the token list.
        """
        new_tokens = self._insert_argument_markers(
            tokens, arg_list, self._get_argument_marker_strings(arg_list)
        )

        if arg_list[0].role == HEAD:
            head_arg, tail_arg = arg_list[0], arg_list[1]
        else:
            head_arg, tail_arg = arg_list[1], arg_list[0]

        if self.append_markers:
            new_tokens.extend(
                [
                    head_arg.as_append_marker,
                    self.tokenizer.sep_token,
                    tail_arg.as_append_marker,
                    self.tokenizer.sep_token,
                ]
            )

        logger.info("*** Example ***")
        logger.info("doc id: %s", document.id)
        logger.info("tokens: %s", " ".join([str(x) for x in new_tokens]))
        logger.info("input_ids: %s", " ".join([str(x) for x in input_ids]))
        # only the relations between the current head and tail are expected for this example
        rel_labels = [
            relation.label
            for relation in relations
            if relation.head == head_arg.entity and relation.tail == tail_arg.entity
        ]
        rel_label_ids = [self.label_to_id[label] for label in rel_labels]
        logger.info("Expected labels: %s (ids = %s)", rel_labels, rel_label_ids)

        self._logged_examples_counter += 1

    def encode_target(
        self,
        task_encoding: TransformerReTextClassificationTaskEncoding2,
    ) -> TransformerReTextClassificationTargetEncoding2:
        """Return the label id(s) of the relation between the head and tail of `task_encoding`.

        Pairs without a relation in the document get the id of `none_label`. If several
        relations share the same head and tail, only the last one is used.
        """
        metadata = task_encoding.metadata
        document = task_encoding.document

        relations: Sequence[BinaryRelation] = document[self.relation_annotation]

        head_tail_to_labels = {
            (relation.head, relation.tail): [relation.label] for relation in relations
        }

        labels = head_tail_to_labels.get((metadata[HEAD], metadata[TAIL]), [self.none_label])
        target = [self.label_to_id[label] for label in labels]

        return target

    def unbatch_output(
        self, model_output: TransformerTextClassificationModelBatchOutput
    ) -> Sequence[TransformerReTextClassificationTaskOutput2]:
        """Turn the batched logits into one prediction (label + probability) per example.

        Raises:
            NotImplementedError: if `multi_label` is set.
        """
        if self.multi_label:
            # decoding (sigmoid) multi-label outputs is not implemented yet
            raise NotImplementedError

        logits = model_output["logits"]
        output_label_probs = logits.softmax(dim=-1).detach().cpu().numpy()

        label_ids = np.argmax(output_label_probs, axis=-1)
        unbatched_output = []
        for batch_idx, label_id in enumerate(label_ids):
            label = self.id_to_label[int(label_id)]
            prob = float(output_label_probs[batch_idx, label_id])
            result: TransformerReTextClassificationTaskOutput2 = {
                "labels": [label],
                "probabilities": [prob],
            }
            unbatched_output.append(result)

        return unbatched_output

    def create_annotations_from_output(
        self,
        task_encoding: TransformerReTextClassificationTaskEncoding2,
        task_output: TransformerReTextClassificationTaskOutput2,
    ) -> Iterator[Tuple[str, Union[BinaryRelation, MultiLabeledBinaryRelation]]]:
        """Yield the predicted relation unless it is the `none_label`."""
        labels = task_output["labels"]
        probabilities = task_output["probabilities"]
        if labels != [self.none_label]:
            yield (
                self.relation_annotation,
                BinaryRelation(
                    head=task_encoding.metadata[HEAD],
                    tail=task_encoding.metadata[TAIL],
                    label=labels[0],
                    score=probabilities[0],
                ),
            )

    def collate(
        self, task_encodings: Sequence[TransformerReTextClassificationTaskEncoding2]
    ) -> TransformerTextClassificationModelStepBatchEncoding:
        """Pad the inputs into a tensor batch and stack the targets (None if there are none)."""
        input_features = [task_encoding.inputs for task_encoding in task_encodings]

        inputs: Dict[str, torch.Tensor] = self.tokenizer.pad(
            input_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )

        if not task_encodings[0].has_targets:
            return inputs, None

        target_list: List[TransformerReTextClassificationTargetEncoding2] = [
            task_encoding.targets for task_encoding in task_encodings
        ]

        targets = torch.tensor(target_list, dtype=torch.int64)

        if not self.multi_label:
            targets = targets.flatten()

        return inputs, targets