""" workflow: Document -> (InputEncoding, TargetEncoding) -> TaskEncoding -> TaskBatchEncoding -> ModelBatchEncoding -> ModelBatchOutput -> TaskOutput -> Document """ import logging from typing import Any, Dict, Iterator, List, Optional, Sequence, Set, Tuple, TypedDict, Union import numpy as np import torch from pytorch_ie.annotations import BinaryRelation, LabeledSpan, MultiLabeledBinaryRelation, Span from pytorch_ie.core import TaskEncoding, TaskModule from pytorch_ie.documents import TextDocument from pytorch_ie.models import ( TransformerTextClassificationModelBatchOutput, TransformerTextClassificationModelStepBatchEncoding, ) from pytorch_ie.utils.span import get_token_slice, is_contained_in from pytorch_ie.utils.window import get_window_around_slice from transformers import AutoTokenizer from transformers.file_utils import PaddingStrategy from transformers.tokenization_utils_base import BatchEncoding, TruncationStrategy from typing_extensions import TypeAlias TransformerReTextClassificationInputEncoding2: TypeAlias = Dict[str, Any] TransformerReTextClassificationTargetEncoding2: TypeAlias = Sequence[int] TransformerReTextClassificationTaskEncoding2: TypeAlias = TaskEncoding[ TextDocument, TransformerReTextClassificationInputEncoding2, TransformerReTextClassificationTargetEncoding2, ] class TransformerReTextClassificationTaskOutput2(TypedDict, total=False): labels: Sequence[str] probabilities: Sequence[float] _TransformerReTextClassificationTaskModule2: TypeAlias = TaskModule[ # _InputEncoding, _TargetEncoding, _TaskBatchEncoding, _ModelBatchOutput, _TaskOutput TextDocument, TransformerReTextClassificationInputEncoding2, TransformerReTextClassificationTargetEncoding2, TransformerTextClassificationModelStepBatchEncoding, TransformerTextClassificationModelBatchOutput, TransformerReTextClassificationTaskOutput2, ] HEAD = "head" TAIL = "tail" START = "start" END = "end" logger = logging.getLogger(__name__) class RelationArgument: def __init__( self, entity: LabeledSpan, role: str, offsets: Tuple[int, int], add_type_to_marker: bool, ) -> None: self.entity = entity self.role = role assert self.role in (HEAD, TAIL) self.offsets = offsets self.add_type_to_marker = add_type_to_marker @property def is_head(self) -> bool: return self.role == HEAD @property def is_tail(self) -> bool: return self.role == TAIL @property def as_start_marker(self) -> str: return self._get_marker(is_start=True) @property def as_end_marker(self) -> str: return self._get_marker(is_start=False) def _get_marker(self, is_start: bool = True) -> str: return f"[{'' if is_start else '/'}{'H' if self.is_head else 'T'}" + ( f":{self.entity.label}]" if self.add_type_to_marker else "]" ) @property def as_append_marker(self) -> str: return f"[{'H' if self.is_head else 'T'}={self.entity.label}]" def _enumerate_entity_pairs( entities: Sequence[Span], partition: Optional[Span] = None, relations: Optional[Sequence[BinaryRelation]] = None, ): """Given a list of `entities` iterate all valid pairs of entities, including inverted pairs. If a `partition` is provided, restrict pairs to be contained in that. If `relations` are given, return only pairs for which a predefined relation exists (e.g. in the case of relation classification for train,val,test splits in supervised datasets). """ existing_head_tail = {(relation.head, relation.tail) for relation in relations or []} for head in entities: if partition is not None and not is_contained_in( (head.start, head.end), (partition.start, partition.end) ): continue for tail in entities: if partition is not None and not is_contained_in( (tail.start, tail.end), (partition.start, partition.end) ): continue if head == tail: continue if relations is not None and (head, tail) not in existing_head_tail: continue yield head, tail @TaskModule.register() class TransformerRETextClassificationTaskModule2(_TransformerReTextClassificationTaskModule2): """Marker based relation extraction. This taskmodule prepares the input token ids in such a way that before and after the candidate head and tail entities special marker tokens are inserted. Then, the modified token ids can be simply passed into a transformer based text classifier model. parameters: partition_annotation: str, optional. If specified, LabeledSpan annotations with this name are expected to define partitions of the document that will be processed individually, e.g. sentences or sections of the document text. none_label: str, defaults to "no_relation". The relation label that indicate dummy/negative relations. Predicted relations with that label will not be added to the document(s). max_window: int, optional. If specified, use the tokens in a window of maximal this amount of tokens around the center of head and tail entities and pass only that into the transformer. """ PREPARED_ATTRIBUTES = ["label_to_id", "entity_labels"] def __init__( self, tokenizer_name_or_path: str, entity_annotation: str = "entities", relation_annotation: str = "relations", partition_annotation: Optional[str] = None, none_label: str = "no_relation", padding: Union[bool, str, PaddingStrategy] = True, truncation: Union[bool, str, TruncationStrategy] = True, max_length: Optional[int] = None, pad_to_multiple_of: Optional[int] = None, multi_label: bool = False, label_to_id: Optional[Dict[str, int]] = None, add_type_to_marker: bool = False, single_argument_pair: bool = True, append_markers: bool = False, entity_labels: Optional[List[str]] = None, max_window: Optional[int] = None, log_first_n_examples: Optional[int] = None, **kwargs, ) -> None: super().__init__(**kwargs) self.save_hyperparameters() self.entity_annotation = entity_annotation self.relation_annotation = relation_annotation self.padding = padding self.truncation = truncation self.label_to_id = label_to_id or {} self.id_to_label = {v: k for k, v in self.label_to_id.items()} self.max_length = max_length self.pad_to_multiple_of = pad_to_multiple_of self.multi_label = multi_label self.add_type_to_marker = add_type_to_marker self.single_argument_pair = single_argument_pair self.append_markers = append_markers self.entity_labels = entity_labels self.partition_annotation = partition_annotation self.none_label = none_label self.max_window = max_window self.log_first_n_examples = log_first_n_examples self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path) self.argument_markers = None self._logged_examples_counter = 0 def _prepare(self, documents: Sequence[TextDocument]) -> None: entity_labels: Set[str] = set() relation_labels: Set[str] = set() for document in documents: entities: Sequence[LabeledSpan] = document[self.entity_annotation] relations: Sequence[BinaryRelation] = document[self.relation_annotation] for entity in entities: entity_labels.add(entity.label) for relation in relations: relation_labels.add(relation.label) if self.none_label in relation_labels: relation_labels.remove(self.none_label) self.label_to_id = {label: i + 1 for i, label in enumerate(sorted(relation_labels))} self.label_to_id[self.none_label] = 0 self.entity_labels = sorted(entity_labels) def _post_prepare(self): self.argument_markers = self._initialize_argument_markers() self.tokenizer.add_tokens(self.argument_markers, special_tokens=True) self.argument_markers_to_id = { marker: self.tokenizer.vocab[marker] for marker in self.argument_markers } self.sep_token_id = self.tokenizer.vocab[self.tokenizer.sep_token] self.id_to_label = {v: k for k, v in self.label_to_id.items()} def _initialize_argument_markers(self) -> List[str]: argument_markers: Set[str] = set() for arg_type in [HEAD, TAIL]: for arg_pos in [START, END]: is_head = arg_type == HEAD is_start = arg_pos == START argument_markers.add(f"[{'' if is_start else '/'}{'H' if is_head else 'T'}]") if self.add_type_to_marker: for entity_type in self.entity_labels: # type: ignore argument_markers.add( f"[{'' if is_start else '/'}{'H' if is_head else 'T'}" f"{':' + entity_type if self.add_type_to_marker else ''}]" ) if self.append_markers: for entity_type in self.entity_labels: # type: ignore argument_markers.add(f"[{'H' if is_head else 'T'}={entity_type}]") return sorted(list(argument_markers)) def _encode_text( self, document: TextDocument, partition: Optional[Span] = None, add_special_tokens: bool = True, ) -> BatchEncoding: text = ( document.text[partition.start : partition.end] if partition is not None else document.text ) encoding = self.tokenizer( text, padding=False, truncation=self.truncation, max_length=self.max_length, is_split_into_words=False, return_offsets_mapping=False, add_special_tokens=add_special_tokens, ) return encoding def encode_input( self, document: TextDocument, is_training: bool = False, ) -> Optional[ Union[ TransformerReTextClassificationTaskEncoding2, Sequence[TransformerReTextClassificationTaskEncoding2], ] ]: assert ( self.argument_markers is not None ), "No argument markers available, was `prepare` already called?" entities: Sequence[Span] = document[self.entity_annotation] relations: Sequence[BinaryRelation] = document[self.relation_annotation] # if no relations are predefined, use None so that enumerate_entities yields all pairs # should be fixed to a parameter "generate_all" or "restrict_to_existing_relations" if len(relations) == 0: relations = None partitions: Sequence[Optional[Span]] if self.partition_annotation is not None: partitions = document[self.partition_annotation] else: # use single dummy partition partitions = [None] task_encodings: List[TransformerReTextClassificationTaskEncoding2] = [] for partition_idx, partition in enumerate(partitions): partition_offset = 0 if partition is None else partition.start add_special_tokens = self.max_window is None encoding = self._encode_text( document=document, partition=partition, add_special_tokens=add_special_tokens ) for (head, tail,) in _enumerate_entity_pairs( entities=entities, partition=partition, relations=relations, ): head_token_slice = get_token_slice( character_slice=(head.start, head.end), char_to_token_mapper=encoding.char_to_token, character_offset=partition_offset, ) tail_token_slice = get_token_slice( character_slice=(tail.start, tail.end), char_to_token_mapper=encoding.char_to_token, character_offset=partition_offset, ) # this happens if the head/tail start/end does not match a token start/end if head_token_slice is None or tail_token_slice is None: # if statistics is not None: # statistics["entity_token_alignment_error"][ # relation_mapping.get((head, tail), "TO_PREDICT") # ] += 1 logger.warning( f"Skipping invalid example {document.id}, cannot get token slice(s)" ) continue input_ids = encoding["input_ids"] # not sure if this is the correct way to get the tokens corresponding to the input_ids tokens = encoding.encodings[0].tokens # windowing if self.max_window is not None: head_start, head_end = head_token_slice tail_start, tail_end = tail_token_slice # The actual number of tokens will be lower than max_window because we add the # 4 marker tokens (before / after the head /tail) and the default special tokens # (e.g. CLS and SEP). num_added_special_tokens = len( self.tokenizer.build_inputs_with_special_tokens([]) ) max_tokens = self.max_window - 4 - num_added_special_tokens # the slice from the beginning of the first entity to the end of the second is required slice_required = (min(head_start, tail_start), max(head_end, tail_end)) window_slice = get_window_around_slice( slice=slice_required, max_window_size=max_tokens, available_input_length=len(input_ids), ) # this happens if slice_required does not fit into max_tokens if window_slice is None: # if statistics is not None: # statistics["out_of_token_window"][ # relation_mapping.get((head, tail), "TO_PREDICT") # ] += 1 continue window_start, window_end = window_slice input_ids = input_ids[window_start:window_end] head_token_slice = head_start - window_start, head_end - window_start tail_token_slice = tail_start - window_start, tail_end - window_start # maybe expand to n-ary relations? head_arg = RelationArgument(head, HEAD, head_token_slice, self.add_type_to_marker) tail_arg = RelationArgument(tail, TAIL, tail_token_slice, self.add_type_to_marker) arg_list = [head_arg, tail_arg] if head_token_slice[0] < tail_token_slice[0]: assert ( head_token_slice[1] <= tail_token_slice[0] ), f"the head and tail entities are not allowed to overlap in {document.id}" else: assert ( tail_token_slice[1] <= head_token_slice[0] ), f"the head and tail entities are not allowed to overlap in {document.id}" # expand to n-ary relations? arg_list.reverse() first_arg_start_id = self.argument_markers_to_id[arg_list[0].as_start_marker] first_arg_end_id = self.argument_markers_to_id[arg_list[0].as_end_marker] second_arg_start_id = self.argument_markers_to_id[arg_list[1].as_start_marker] second_arg_end_id = self.argument_markers_to_id[arg_list[1].as_end_marker] new_input_ids = ( input_ids[: arg_list[0].offsets[0]] + [first_arg_start_id] + input_ids[arg_list[0].offsets[0] : arg_list[0].offsets[1]] + [first_arg_end_id] + input_ids[arg_list[0].offsets[1] : arg_list[1].offsets[0]] + [second_arg_start_id] + input_ids[arg_list[1].offsets[0] : arg_list[1].offsets[1]] + [second_arg_end_id] + input_ids[arg_list[1].offsets[1] :] ) if self.append_markers: new_input_ids.extend( [ self.argument_markers_to_id[head_arg.as_append_marker], self.sep_token_id, self.argument_markers_to_id[tail_arg.as_append_marker], self.sep_token_id, ] ) # when windowing is used, we have to add the special tokens manually if not add_special_tokens: new_input_ids = self.tokenizer.build_inputs_with_special_tokens( token_ids_0=new_input_ids ) # lots of logging from here on log_this_example = ( relations is not None and self.log_first_n_examples is not None and self._logged_examples_counter <= self.log_first_n_examples ) if log_this_example: self._log_example(document, arg_list, new_input_ids, relations, tokens) task_encodings.append( TaskEncoding( document=document, inputs={"input_ids": new_input_ids}, metadata={ HEAD: head, TAIL: tail, }, ) ) return task_encodings def _log_example( self, document: TextDocument, arg_list: List[RelationArgument], input_ids: List[int], relations: Sequence[BinaryRelation], tokens: List[str], ): first_arg_start = arg_list[0].as_start_marker first_arg_end = arg_list[0].as_end_marker second_arg_start = arg_list[1].as_start_marker second_arg_end = arg_list[1].as_end_marker new_tokens = ( tokens[: arg_list[0].offsets[0]] + [first_arg_start] + tokens[arg_list[0].offsets[0] : arg_list[0].offsets[1]] + [first_arg_end] + tokens[arg_list[0].offsets[1] : arg_list[1].offsets[0]] + [second_arg_start] + tokens[arg_list[1].offsets[0] : arg_list[1].offsets[1]] + [second_arg_end] + tokens[arg_list[1].offsets[1] :] ) head_idx = 0 if arg_list[0].role == HEAD else 1 tail_idx = 0 if arg_list[0].role == TAIL else 1 if self.append_markers: head_marker = arg_list[head_idx].as_append_marker tail_marker = arg_list[tail_idx].as_append_marker new_tokens.extend( [head_marker, self.tokenizer.sep_token, tail_marker, self.tokenizer.sep_token] ) logger.info("*** Example ***") logger.info("doc id: %s", document.id) logger.info("tokens: %s", " ".join([str(x) for x in new_tokens])) logger.info("input_ids: %s", " ".join([str(x) for x in input_ids])) rel_labels = [relation.label for relation in relations] rel_label_ids = [self.label_to_id[label] for label in rel_labels] logger.info("Expected labels: %s (ids = %s)", rel_labels, rel_label_ids) self._logged_examples_counter += 1 def encode_target( self, task_encoding: TransformerReTextClassificationTaskEncoding2, ) -> TransformerReTextClassificationTargetEncoding2: metadata = task_encoding.metadata document = task_encoding.document relations: Sequence[BinaryRelation] = document[self.relation_annotation] head_tail_to_labels = { (relation.head, relation.tail): [relation.label] for relation in relations } labels = head_tail_to_labels.get((metadata[HEAD], metadata[TAIL]), [self.none_label]) target = [self.label_to_id[label] for label in labels] return target def unbatch_output( self, model_output: TransformerTextClassificationModelBatchOutput ) -> Sequence[TransformerReTextClassificationTaskOutput2]: logits = model_output["logits"] output_label_probs = logits.sigmoid() if self.multi_label else logits.softmax(dim=-1) output_label_probs = output_label_probs.detach().cpu().numpy() unbatched_output = [] if self.multi_label: raise NotImplementedError else: label_ids = np.argmax(output_label_probs, axis=-1) for batch_idx, label_id in enumerate(label_ids): label = self.id_to_label[label_id] prob = float(output_label_probs[batch_idx, label_id]) result: TransformerReTextClassificationTaskOutput2 = { "labels": [label], "probabilities": [prob], } unbatched_output.append(result) return unbatched_output def create_annotations_from_output( self, task_encoding: TransformerReTextClassificationTaskEncoding2, task_output: TransformerReTextClassificationTaskOutput2, ) -> Iterator[Tuple[str, Union[BinaryRelation, MultiLabeledBinaryRelation]]]: labels = task_output["labels"] probabilities = task_output["probabilities"] if labels != [self.none_label]: yield ( self.relation_annotation, BinaryRelation( head=task_encoding.metadata[HEAD], tail=task_encoding.metadata[TAIL], label=labels[0], score=probabilities[0], ), ) def collate( self, task_encodings: Sequence[TransformerReTextClassificationTaskEncoding2] ) -> TransformerTextClassificationModelStepBatchEncoding: input_features = [task_encoding.inputs for task_encoding in task_encodings] inputs: Dict[str, torch.Tensor] = self.tokenizer.pad( input_features, padding=self.padding, max_length=self.max_length, pad_to_multiple_of=self.pad_to_multiple_of, return_tensors="pt", ) if not task_encodings[0].has_targets: return inputs, None target_list: List[TransformerReTextClassificationTargetEncoding2] = [ task_encoding.targets for task_encoding in task_encodings ] targets = torch.tensor(target_list, dtype=torch.int64) if not self.multi_label: targets = targets.flatten() return inputs, targets