# Spaces: Runtime error
"""
workflow:
    Document
        -> (InputEncoding, TargetEncoding) -> TaskEncoding -> TaskBatchEncoding
            -> ModelBatchEncoding -> ModelBatchOutput
        -> TaskOutput
    -> Document
"""
import logging | |
from typing import Any, Dict, Iterator, List, Optional, Sequence, Set, Tuple, TypedDict, Union | |
import numpy as np | |
import torch | |
from pytorch_ie.annotations import BinaryRelation, LabeledSpan, MultiLabeledBinaryRelation, Span | |
from pytorch_ie.core import TaskEncoding, TaskModule | |
from pytorch_ie.documents import TextDocument | |
from pytorch_ie.models import ( | |
TransformerTextClassificationModelBatchOutput, | |
TransformerTextClassificationModelStepBatchEncoding, | |
) | |
from pytorch_ie.utils.span import get_token_slice, is_contained_in | |
from pytorch_ie.utils.window import get_window_around_slice | |
from transformers import AutoTokenizer | |
from transformers.file_utils import PaddingStrategy | |
from transformers.tokenization_utils_base import BatchEncoding, TruncationStrategy | |
from typing_extensions import TypeAlias | |
# Model inputs of a single example, e.g. {"input_ids": [...]}.
TransformerReTextClassificationInputEncoding2: TypeAlias = Dict[str, Any]

# Targets of a single example: relation label ids.
TransformerReTextClassificationTargetEncoding2: TypeAlias = Sequence[int]

# One encoded example: the source document together with its inputs and targets.
TransformerReTextClassificationTaskEncoding2: TypeAlias = TaskEncoding[
    TextDocument,
    TransformerReTextClassificationInputEncoding2,
    TransformerReTextClassificationTargetEncoding2,
]
class TransformerReTextClassificationTaskOutput2(TypedDict, total=False):
    """Decoded prediction for a single example.

    Both keys are optional (``total=False``). ``labels`` holds predicted relation label
    names and ``probabilities`` the scores for those labels, index-aligned.
    """

    labels: Sequence[str]
    probabilities: Sequence[float]
# Parameterized TaskModule base. Generic slots, in order:
#   document type, input encoding, target encoding, task batch encoding,
#   model batch output, task output.
_TransformerReTextClassificationTaskModule2: TypeAlias = TaskModule[
    TextDocument,
    TransformerReTextClassificationInputEncoding2,
    TransformerReTextClassificationTargetEncoding2,
    TransformerTextClassificationModelStepBatchEncoding,
    TransformerTextClassificationModelBatchOutput,
    TransformerReTextClassificationTaskOutput2,
]
# Roles of the two relation arguments (also used as task-encoding metadata keys).
HEAD, TAIL = "head", "tail"
# Positions of a marker relative to its argument.
START, END = "start", "end"

logger = logging.getLogger(__name__)
class RelationArgument:
    """One argument (head or tail) of a candidate binary relation.

    Holds the argument entity, its role and its token offsets, and renders the marker
    strings that represent it in the encoded input.

    Args:
        entity: the argument entity; its ``label`` is used for typed and appended markers.
        role: either ``HEAD`` or ``TAIL``.
        offsets: ``(start, end)`` token offsets of the argument.
        add_type_to_marker: if True, start/end markers include the entity label,
            e.g. ``[H:PER]`` instead of ``[H]``.
    """

    def __init__(
        self,
        entity: LabeledSpan,
        role: str,
        offsets: Tuple[int, int],
        add_type_to_marker: bool,
    ) -> None:
        self.entity = entity
        self.role = role
        assert self.role in (HEAD, TAIL)
        self.offsets = offsets
        self.add_type_to_marker = add_type_to_marker

    def is_head(self) -> bool:
        """Return True if this argument has the head role."""
        return self.role == HEAD

    def is_tail(self) -> bool:
        """Return True if this argument has the tail role."""
        return self.role == TAIL

    def as_start_marker(self) -> str:
        """Return the marker placed before the argument, e.g. ``[H]`` or ``[T:ORG]``."""
        return self._get_marker(is_start=True)

    def as_end_marker(self) -> str:
        """Return the marker placed after the argument, e.g. ``[/H]`` or ``[/T:ORG]``."""
        return self._get_marker(is_start=False)

    def _role_char(self) -> str:
        # `is_head` is a method and must be called: a bare bound method is always
        # truthy, which previously made every marker a head ("H") marker.
        return "H" if self.is_head() else "T"

    def _get_marker(self, is_start: bool = True) -> str:
        """Return the start (``is_start=True``) or end marker of this argument."""
        slash = "" if is_start else "/"
        type_suffix = f":{self.entity.label}" if self.add_type_to_marker else ""
        return f"[{slash}{self._role_char()}{type_suffix}]"

    def as_append_marker(self) -> str:
        """Return the marker appended to the input, e.g. ``[H=PER]``."""
        return f"[{self._role_char()}={self.entity.label}]"
def _enumerate_entity_pairs(
    entities: Sequence[Span],
    partition: Optional[Span] = None,
    relations: Optional[Sequence[BinaryRelation]] = None,
) -> Iterator[Tuple[Span, Span]]:
    """Given a list of `entities`, iterate all valid (head, tail) pairs of distinct
    entities, including inverted pairs.

    If a `partition` is provided, restrict pairs to entities contained in it. If `relations`
    are given (even an empty sequence), yield only pairs for which a relation with exactly that
    head and tail exists (e.g. in the case of relation classification for train, val, test
    splits in supervised datasets).

    Pairs are yielded in `entities` order: the outer loop runs over heads, the inner over tails.
    """
    if partition is None:
        # materialize so the nested loop can iterate the entities twice
        candidates = list(entities)
    else:
        # check containment once per entity instead of once per (head, tail) pair
        partition_offsets = (partition.start, partition.end)
        candidates = [
            entity
            for entity in entities
            if is_contained_in((entity.start, entity.end), partition_offsets)
        ]

    existing_head_tail = (
        None
        if relations is None
        else {(relation.head, relation.tail) for relation in relations}
    )

    for head in candidates:
        for tail in candidates:
            if head == tail:
                continue
            if existing_head_tail is not None and (head, tail) not in existing_head_tail:
                continue
            yield head, tail
class TransformerRETextClassificationTaskModule2(_TransformerReTextClassificationTaskModule2):
    """Marker based relation extraction. This taskmodule prepares the input token ids in such a way
    that special marker tokens are inserted before and after the candidate head and tail entities.
    Then, the modified token ids can simply be passed into a transformer based text classifier
    model.

    parameters:
        tokenizer_name_or_path: str. Passed to `AutoTokenizer.from_pretrained`.
        entity_annotation: str, defaults to "entities". Name of the annotation layer that holds
            the entities, i.e. the candidate relation arguments.
        relation_annotation: str, defaults to "relations". Name of the annotation layer that holds
            the binary relations.
        partition_annotation: str, optional. If specified, LabeledSpan annotations with this name
            are expected to define partitions of the document that will be processed individually,
            e.g. sentences or sections of the document text.
        none_label: str, defaults to "no_relation". The relation label that indicates
            dummy/negative relations. Predicted relations with that label will not be added to the
            document(s).
        padding, pad_to_multiple_of: passed to `tokenizer.pad` in `collate`.
        truncation: passed to the tokenizer when encoding the text.
        max_length: passed to the tokenizer when encoding the text and to `tokenizer.pad`.
        multi_label: bool, defaults to False. Multi-label decoding is not implemented
            (`unbatch_output` raises NotImplementedError).
        label_to_id: optional mapping from relation label to id. `_prepare` sets it from the data,
            mapping the none label to 0.
        add_type_to_marker: bool, defaults to False. If True, the start/end markers also contain
            the entity label, e.g. "[H:PER]" instead of "[H]".
        single_argument_pair: bool, defaults to True. Stored as an attribute; not used by the code
            in this class.
        append_markers: bool, defaults to False. If True, "[H=<head label>] SEP [T=<tail label>] SEP"
            (SEP being the tokenizer's separator token) is appended to the input ids.
        entity_labels: optional list of entity labels. `_prepare` sets it from the data.
        max_window: int, optional. If specified, only the tokens in a window around the head and
            tail entities are passed into the transformer. The window is chosen so that the final
            input, including the marker and special tokens, has at most this many tokens.
        log_first_n_examples: int, optional. If specified, log that many encoded examples.
    """

    PREPARED_ATTRIBUTES = ["label_to_id", "entity_labels"]

    def __init__(
        self,
        tokenizer_name_or_path: str,
        entity_annotation: str = "entities",
        relation_annotation: str = "relations",
        partition_annotation: Optional[str] = None,
        none_label: str = "no_relation",
        padding: Union[bool, str, PaddingStrategy] = True,
        truncation: Union[bool, str, TruncationStrategy] = True,
        max_length: Optional[int] = None,
        pad_to_multiple_of: Optional[int] = None,
        multi_label: bool = False,
        label_to_id: Optional[Dict[str, int]] = None,
        add_type_to_marker: bool = False,
        single_argument_pair: bool = True,
        append_markers: bool = False,
        entity_labels: Optional[List[str]] = None,
        max_window: Optional[int] = None,
        log_first_n_examples: Optional[int] = None,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.save_hyperparameters()

        self.entity_annotation = entity_annotation
        self.relation_annotation = relation_annotation
        self.padding = padding
        self.truncation = truncation
        self.label_to_id = label_to_id or {}
        self.id_to_label = {v: k for k, v in self.label_to_id.items()}
        self.max_length = max_length
        self.pad_to_multiple_of = pad_to_multiple_of
        self.multi_label = multi_label
        self.add_type_to_marker = add_type_to_marker
        self.single_argument_pair = single_argument_pair
        self.append_markers = append_markers
        self.entity_labels = entity_labels
        self.partition_annotation = partition_annotation
        self.none_label = none_label
        self.max_window = max_window
        self.log_first_n_examples = log_first_n_examples

        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)

        # set in _post_prepare()
        self.argument_markers: Optional[List[str]] = None

        self._logged_examples_counter = 0

    def _prepare(self, documents: Sequence[TextDocument]) -> None:
        """Collect entity and relation labels from `documents` and build `label_to_id`."""
        entity_labels: Set[str] = set()
        relation_labels: Set[str] = set()
        for document in documents:
            entities: Sequence[LabeledSpan] = document[self.entity_annotation]
            relations: Sequence[BinaryRelation] = document[self.relation_annotation]

            for entity in entities:
                entity_labels.add(entity.label)

            for relation in relations:
                relation_labels.add(relation.label)

        if self.none_label in relation_labels:
            relation_labels.remove(self.none_label)

        # the none label always gets id 0, the others follow in sorted order
        self.label_to_id = {label: i + 1 for i, label in enumerate(sorted(relation_labels))}
        self.label_to_id[self.none_label] = 0

        self.entity_labels = sorted(entity_labels)

    def _post_prepare(self):
        """Register the argument markers as tokens and cache the derived id lookups."""
        self.argument_markers = self._initialize_argument_markers()
        self.tokenizer.add_tokens(self.argument_markers, special_tokens=True)

        # use the tokenizer's public id lookup API, which also covers added tokens
        self.argument_markers_to_id = {
            marker: self.tokenizer.convert_tokens_to_ids(marker)
            for marker in self.argument_markers
        }
        self.sep_token_id = self.tokenizer.sep_token_id

        self.id_to_label = {v: k for k, v in self.label_to_id.items()}

    def _initialize_argument_markers(self) -> List[str]:
        """Return the sorted list of all marker strings the encoding may insert."""
        argument_markers: Set[str] = set()
        for arg_type in [HEAD, TAIL]:
            role_char = "H" if arg_type == HEAD else "T"
            for arg_pos in [START, END]:
                slash = "" if arg_pos == START else "/"
                argument_markers.add(f"[{slash}{role_char}]")
                if self.add_type_to_marker:
                    for entity_type in self.entity_labels:  # type: ignore
                        argument_markers.add(f"[{slash}{role_char}:{entity_type}]")
            if self.append_markers:
                for entity_type in self.entity_labels:  # type: ignore
                    argument_markers.add(f"[{role_char}={entity_type}]")

        return sorted(argument_markers)

    def _encode_text(
        self,
        document: TextDocument,
        partition: Optional[Span] = None,
        add_special_tokens: bool = True,
    ) -> BatchEncoding:
        """Tokenize the document text, or only the `partition` of it if given."""
        text = (
            document.text[partition.start : partition.end]
            if partition is not None
            else document.text
        )
        encoding = self.tokenizer(
            text,
            padding=False,
            truncation=self.truncation,
            max_length=self.max_length,
            is_split_into_words=False,
            return_offsets_mapping=False,
            add_special_tokens=add_special_tokens,
        )
        return encoding

    @staticmethod
    def _insert_argument_markers(
        sequence: List[Any],
        arg_list: List[RelationArgument],
        markers: List[Any],
    ) -> List[Any]:
        """Return a new list with the four `markers` (start/end of the first argument, then
        start/end of the second one) inserted around the two arguments of `arg_list`.

        `arg_list` has to be ordered by position and its argument offsets must not overlap;
        `sequence` is not modified.
        """
        first, second = arg_list
        first_start, first_end, second_start, second_end = markers
        return (
            sequence[: first.offsets[0]]
            + [first_start]
            + sequence[first.offsets[0] : first.offsets[1]]
            + [first_end]
            + sequence[first.offsets[1] : second.offsets[0]]
            + [second_start]
            + sequence[second.offsets[0] : second.offsets[1]]
            + [second_end]
            + sequence[second.offsets[1] :]
        )

    def encode_input(
        self,
        document: TextDocument,
        is_training: bool = False,
    ) -> Optional[
        Union[
            TransformerReTextClassificationTaskEncoding2,
            Sequence[TransformerReTextClassificationTaskEncoding2],
        ]
    ]:
        """Create one task encoding per valid (head, tail) entity pair of `document`.

        Pairs whose arguments do not align with token boundaries, or that do not fit into
        `max_window`, are skipped.
        """
        assert (
            self.argument_markers is not None
        ), "No argument markers available, was `prepare` already called?"

        entities: Sequence[Span] = document[self.entity_annotation]
        relations: Sequence[BinaryRelation] = document[self.relation_annotation]

        partitions: Sequence[Optional[Span]]
        if self.partition_annotation is not None:
            partitions = document[self.partition_annotation]
        else:
            # use single dummy partition
            partitions = [None]

        # with windowing, the special tokens (e.g. CLS and SEP) are added manually
        # after the window was cut out
        add_special_tokens = self.max_window is None
        max_tokens: Optional[int] = None
        if self.max_window is not None:
            num_added_special_tokens = len(self.tokenizer.build_inputs_with_special_tokens([]))
            # 4 marker tokens (before / after the head / tail) and, if enabled, the 4 appended
            # tokens "[H=..] SEP [T=..] SEP" are added on top of the window content
            num_marker_tokens = 8 if self.append_markers else 4
            max_tokens = self.max_window - num_marker_tokens - num_added_special_tokens

        task_encodings: List[TransformerReTextClassificationTaskEncoding2] = []
        for partition in partitions:
            partition_offset = 0 if partition is None else partition.start
            encoding = self._encode_text(
                document=document, partition=partition, add_special_tokens=add_special_tokens
            )

            for head, tail in _enumerate_entity_pairs(
                entities=entities,
                partition=partition,
                relations=relations,
            ):
                head_token_slice = get_token_slice(
                    character_slice=(head.start, head.end),
                    char_to_token_mapper=encoding.char_to_token,
                    character_offset=partition_offset,
                )
                tail_token_slice = get_token_slice(
                    character_slice=(tail.start, tail.end),
                    char_to_token_mapper=encoding.char_to_token,
                    character_offset=partition_offset,
                )

                # this happens if the head/tail start/end does not match a token start/end
                if head_token_slice is None or tail_token_slice is None:
                    logger.warning(
                        "Skipping invalid example %s, cannot get token slice(s)", document.id
                    )
                    continue

                input_ids = encoding["input_ids"]

                # windowing
                if max_tokens is not None:
                    head_start, head_end = head_token_slice
                    tail_start, tail_end = tail_token_slice
                    # the slice from the beginning of the first entity to the end of the second
                    # is required
                    slice_required = (min(head_start, tail_start), max(head_end, tail_end))
                    window_slice = get_window_around_slice(
                        slice=slice_required,
                        max_window_size=max_tokens,
                        available_input_length=len(input_ids),
                    )
                    # this happens if slice_required does not fit into max_tokens
                    if window_slice is None:
                        continue

                    window_start, window_end = window_slice
                    input_ids = input_ids[window_start:window_end]
                    # make the argument offsets relative to the window
                    head_token_slice = head_start - window_start, head_end - window_start
                    tail_token_slice = tail_start - window_start, tail_end - window_start

                # maybe expand to n-ary relations?
                head_arg = RelationArgument(head, HEAD, head_token_slice, self.add_type_to_marker)
                tail_arg = RelationArgument(tail, TAIL, tail_token_slice, self.add_type_to_marker)

                # order the arguments by their position in the text
                arg_list = [head_arg, tail_arg]
                if head_token_slice[0] < tail_token_slice[0]:
                    assert (
                        head_token_slice[1] <= tail_token_slice[0]
                    ), f"the head and tail entities are not allowed to overlap in {document.id}"
                else:
                    assert (
                        tail_token_slice[1] <= head_token_slice[0]
                    ), f"the head and tail entities are not allowed to overlap in {document.id}"
                    arg_list.reverse()

                # the `as_*_marker` members are methods and have to be called; using the bound
                # method objects as dict keys raised a KeyError for every example
                new_input_ids = self._insert_argument_markers(
                    input_ids,
                    arg_list,
                    [
                        self.argument_markers_to_id[arg_list[0].as_start_marker()],
                        self.argument_markers_to_id[arg_list[0].as_end_marker()],
                        self.argument_markers_to_id[arg_list[1].as_start_marker()],
                        self.argument_markers_to_id[arg_list[1].as_end_marker()],
                    ],
                )

                if self.append_markers:
                    new_input_ids.extend(
                        [
                            self.argument_markers_to_id[head_arg.as_append_marker()],
                            self.sep_token_id,
                            self.argument_markers_to_id[tail_arg.as_append_marker()],
                            self.sep_token_id,
                        ]
                    )

                # when windowing is used, we have to add the special tokens manually
                if not add_special_tokens:
                    new_input_ids = self.tokenizer.build_inputs_with_special_tokens(
                        token_ids_0=new_input_ids
                    )

                # log exactly the first `log_first_n_examples` examples
                if (
                    self.log_first_n_examples is not None
                    and self._logged_examples_counter < self.log_first_n_examples
                ):
                    # tokens aligned with `input_ids`, i.e. with the (possibly windowed)
                    # argument offsets
                    tokens = self.tokenizer.convert_ids_to_tokens(input_ids)
                    self._log_example(document, arg_list, new_input_ids, relations, tokens)

                task_encodings.append(
                    TaskEncoding(
                        document=document,
                        inputs={"input_ids": new_input_ids},
                        metadata={
                            HEAD: head,
                            TAIL: tail,
                        },
                    )
                )

        return task_encodings

    def _log_example(
        self,
        document: TextDocument,
        arg_list: List[RelationArgument],
        input_ids: List[int],
        relations: Sequence[BinaryRelation],
        tokens: List[str],
    ):
        """Log one encoded example: its tokens with markers, its input ids and the expected
        label(s) of its (head, tail) pair.

        `tokens` has to be aligned with the argument offsets stored in `arg_list`.
        """
        new_tokens = self._insert_argument_markers(
            tokens,
            arg_list,
            [
                arg_list[0].as_start_marker(),
                arg_list[0].as_end_marker(),
                arg_list[1].as_start_marker(),
                arg_list[1].as_end_marker(),
            ],
        )

        if arg_list[0].role == HEAD:
            head_arg, tail_arg = arg_list[0], arg_list[1]
        else:
            head_arg, tail_arg = arg_list[1], arg_list[0]

        if self.append_markers:
            new_tokens.extend(
                [
                    head_arg.as_append_marker(),
                    self.tokenizer.sep_token,
                    tail_arg.as_append_marker(),
                    self.tokenizer.sep_token,
                ]
            )

        logger.info("*** Example ***")
        logger.info("doc id: %s", document.id)
        logger.info("tokens: %s", " ".join([str(x) for x in new_tokens]))
        logger.info("input_ids: %s", " ".join([str(x) for x in input_ids]))

        # labels of the relations between exactly this head and tail (or the none label)
        rel_labels = [
            relation.label
            for relation in relations
            if relation.head == head_arg.entity and relation.tail == tail_arg.entity
        ] or [self.none_label]
        # .get: logging must not fail on a label unknown to label_to_id
        rel_label_ids = [self.label_to_id.get(label) for label in rel_labels]
        logger.info("Expected labels: %s (ids = %s)", rel_labels, rel_label_ids)

        self._logged_examples_counter += 1

    def encode_target(
        self,
        task_encoding: TransformerReTextClassificationTaskEncoding2,
    ) -> TransformerReTextClassificationTargetEncoding2:
        """Return the label id(s) of the relation between the encoded head and tail, or the
        id of the none label if the document has no such relation."""
        metadata = task_encoding.metadata
        document = task_encoding.document

        relations: Sequence[BinaryRelation] = document[self.relation_annotation]

        head_tail_to_labels = {
            (relation.head, relation.tail): [relation.label] for relation in relations
        }

        labels = head_tail_to_labels.get((metadata[HEAD], metadata[TAIL]), [self.none_label])
        target = [self.label_to_id[label] for label in labels]

        return target

    def unbatch_output(
        self, model_output: TransformerTextClassificationModelBatchOutput
    ) -> Sequence[TransformerReTextClassificationTaskOutput2]:
        """Decode the batch logits into one label/probability result per example.

        Raises:
            NotImplementedError: if `multi_label` is set.
        """
        if self.multi_label:
            raise NotImplementedError

        logits = model_output["logits"]
        output_label_probs = logits.softmax(dim=-1).detach().cpu().numpy()

        unbatched_output = []
        label_ids = np.argmax(output_label_probs, axis=-1)
        for batch_idx, label_id in enumerate(label_ids):
            label = self.id_to_label[int(label_id)]
            prob = float(output_label_probs[batch_idx, label_id])
            result: TransformerReTextClassificationTaskOutput2 = {
                "labels": [label],
                "probabilities": [prob],
            }
            unbatched_output.append(result)

        return unbatched_output

    def create_annotations_from_output(
        self,
        task_encoding: TransformerReTextClassificationTaskEncoding2,
        task_output: TransformerReTextClassificationTaskOutput2,
    ) -> Iterator[Tuple[str, Union[BinaryRelation, MultiLabeledBinaryRelation]]]:
        """Yield a scored BinaryRelation for the encoded pair unless the none label was predicted."""
        labels = task_output["labels"]
        probabilities = task_output["probabilities"]
        if labels != [self.none_label]:
            yield (
                self.relation_annotation,
                BinaryRelation(
                    head=task_encoding.metadata[HEAD],
                    tail=task_encoding.metadata[TAIL],
                    label=labels[0],
                    score=probabilities[0],
                ),
            )

    def collate(
        self, task_encodings: Sequence[TransformerReTextClassificationTaskEncoding2]
    ) -> TransformerTextClassificationModelStepBatchEncoding:
        """Pad the inputs into a batch of tensors; targets are None if the first task
        encoding has no targets."""
        input_features = [task_encoding.inputs for task_encoding in task_encodings]

        inputs: Dict[str, torch.Tensor] = self.tokenizer.pad(
            input_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )

        if not task_encodings[0].has_targets:
            return inputs, None

        target_list: List[TransformerReTextClassificationTargetEncoding2] = [
            task_encoding.targets for task_encoding in task_encodings
        ]

        targets = torch.tensor(target_list, dtype=torch.int64)

        if not self.multi_label:
            targets = targets.flatten()

        return inputs, targets