"""
workflow:
Document
-> (InputEncoding, TargetEncoding) -> TaskEncoding -> TaskBatchEncoding
-> ModelBatchEncoding -> ModelBatchOutput
-> TaskOutput
-> Document
"""
import logging
from typing import Any, Dict, Iterator, List, Optional, Sequence, Set, Tuple, TypedDict, Union
import numpy as np
import torch
from pytorch_ie.annotations import BinaryRelation, LabeledSpan, MultiLabeledBinaryRelation, Span
from pytorch_ie.core import TaskEncoding, TaskModule
from pytorch_ie.documents import TextDocument
from pytorch_ie.models import (
TransformerTextClassificationModelBatchOutput,
TransformerTextClassificationModelStepBatchEncoding,
)
from pytorch_ie.utils.span import get_token_slice, is_contained_in
from pytorch_ie.utils.window import get_window_around_slice
from transformers import AutoTokenizer
from transformers.file_utils import PaddingStrategy
from transformers.tokenization_utils_base import BatchEncoding, TruncationStrategy
from typing_extensions import TypeAlias
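# Rough usage sketch for the taskmodule defined below (a hedged example that assumes the generic
# pytorch_ie TaskModule interface, i.e. `prepare` and `encode`; `train_docs` is a placeholder list
# of TextDocuments with "entities" and "relations" annotation layers):
#
#   taskmodule = TransformerRETextClassificationTaskModule2(
#       tokenizer_name_or_path="bert-base-cased",
#       add_type_to_marker=True,
#       max_window=128,
#   )
#   taskmodule.prepare(train_docs)  # collects entity/relation labels and registers marker tokens
#   task_encodings = taskmodule.encode(train_docs, encode_target=True)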
TransformerReTextClassificationInputEncoding2: TypeAlias = Dict[str, Any]
TransformerReTextClassificationTargetEncoding2: TypeAlias = Sequence[int]
TransformerReTextClassificationTaskEncoding2: TypeAlias = TaskEncoding[
TextDocument,
TransformerReTextClassificationInputEncoding2,
TransformerReTextClassificationTargetEncoding2,
]
class TransformerReTextClassificationTaskOutput2(TypedDict, total=False):
labels: Sequence[str]
probabilities: Sequence[float]
_TransformerReTextClassificationTaskModule2: TypeAlias = TaskModule[
# _InputEncoding, _TargetEncoding, _TaskBatchEncoding, _ModelBatchOutput, _TaskOutput
TextDocument,
TransformerReTextClassificationInputEncoding2,
TransformerReTextClassificationTargetEncoding2,
TransformerTextClassificationModelStepBatchEncoding,
TransformerTextClassificationModelBatchOutput,
TransformerReTextClassificationTaskOutput2,
]
HEAD = "head"
TAIL = "tail"
START = "start"
END = "end"
logger = logging.getLogger(__name__)
class RelationArgument:
def __init__(
self,
entity: LabeledSpan,
role: str,
offsets: Tuple[int, int],
add_type_to_marker: bool,
) -> None:
self.entity = entity
self.role = role
assert self.role in (HEAD, TAIL)
self.offsets = offsets
self.add_type_to_marker = add_type_to_marker
@property
def is_head(self) -> bool:
return self.role == HEAD
@property
def is_tail(self) -> bool:
return self.role == TAIL
@property
def as_start_marker(self) -> str:
return self._get_marker(is_start=True)
@property
def as_end_marker(self) -> str:
return self._get_marker(is_start=False)
def _get_marker(self, is_start: bool = True) -> str:
return f"[{'' if is_start else '/'}{'H' if self.is_head else 'T'}" + (
f":{self.entity.label}]" if self.add_type_to_marker else "]"
)
@property
def as_append_marker(self) -> str:
return f"[{'H' if self.is_head else 'T'}={self.entity.label}]"
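    # For illustration: for a head argument whose entity is labeled "PER", as_start_marker is
    # "[H:PER]" and as_end_marker is "[/H:PER]" when add_type_to_marker=True (just "[H]" and
    # "[/H]" otherwise), while as_append_marker is always "[H=PER]".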
def _enumerate_entity_pairs(
entities: Sequence[Span],
partition: Optional[Span] = None,
relations: Optional[Sequence[BinaryRelation]] = None,
):
"""Given a list of `entities` iterate all valid pairs of entities, including inverted pairs.
    If a `partition` is provided, restrict the pairs to those contained in it. If `relations`
    are given, yield only pairs for which a predefined relation exists (e.g. for relation
    classification on the train/val/test splits of supervised datasets).
    """
existing_head_tail = {(relation.head, relation.tail) for relation in relations or []}
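    # For illustration: with entities [e1, e2] and relations=None, this yields both (e1, e2)
    # and (e2, e1); with relations=[BinaryRelation(head=e1, tail=e2, label=...)], only
    # (e1, e2) is yielded.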
for head in entities:
if partition is not None and not is_contained_in(
(head.start, head.end), (partition.start, partition.end)
):
continue
for tail in entities:
if partition is not None and not is_contained_in(
(tail.start, tail.end), (partition.start, partition.end)
):
continue
if head == tail:
continue
if relations is not None and (head, tail) not in existing_head_tail:
continue
yield head, tail
@TaskModule.register()
class TransformerRETextClassificationTaskModule2(_TransformerReTextClassificationTaskModule2):
"""Marker based relation extraction. This taskmodule prepares the input token ids in such a way
    that special marker tokens are inserted before and after the candidate head and tail entities.
    The modified token ids can then be passed directly into a transformer-based text classification
    model.

    Parameters:
        partition_annotation: str, optional. If specified, LabeledSpan annotations with this name are
            expected to define partitions of the document that will be processed individually, e.g.
            sentences or sections of the document text.
        none_label: str, defaults to "no_relation". The relation label that indicates dummy/negative
            relations. Predicted relations with that label will not be added to the document(s).
        max_window: int, optional. If specified, use only the tokens in a window of at most this many
            tokens around the head and tail entities and pass only that window into the transformer.
    """
PREPARED_ATTRIBUTES = ["label_to_id", "entity_labels"]
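    # For illustration (assuming add_type_to_marker=True): for the text "Jane works at Acme."
    # with head entity "Jane" (PER) and tail entity "Acme" (ORG), the token ids passed to the
    # model correspond to "[H:PER] Jane [/H:PER] works at [T:ORG] Acme [/T:ORG] ." and, if
    # append_markers=True, "[H=PER] [SEP] [T=ORG] [SEP]" is appended at the end.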
def __init__(
self,
tokenizer_name_or_path: str,
entity_annotation: str = "entities",
relation_annotation: str = "relations",
partition_annotation: Optional[str] = None,
none_label: str = "no_relation",
padding: Union[bool, str, PaddingStrategy] = True,
truncation: Union[bool, str, TruncationStrategy] = True,
max_length: Optional[int] = None,
pad_to_multiple_of: Optional[int] = None,
multi_label: bool = False,
label_to_id: Optional[Dict[str, int]] = None,
add_type_to_marker: bool = False,
single_argument_pair: bool = True,
append_markers: bool = False,
entity_labels: Optional[List[str]] = None,
max_window: Optional[int] = None,
log_first_n_examples: Optional[int] = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
self.save_hyperparameters()
self.entity_annotation = entity_annotation
self.relation_annotation = relation_annotation
self.padding = padding
self.truncation = truncation
self.label_to_id = label_to_id or {}
self.id_to_label = {v: k for k, v in self.label_to_id.items()}
self.max_length = max_length
self.pad_to_multiple_of = pad_to_multiple_of
self.multi_label = multi_label
self.add_type_to_marker = add_type_to_marker
self.single_argument_pair = single_argument_pair
self.append_markers = append_markers
self.entity_labels = entity_labels
self.partition_annotation = partition_annotation
self.none_label = none_label
self.max_window = max_window
self.log_first_n_examples = log_first_n_examples
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)
self.argument_markers = None
self._logged_examples_counter = 0
def _prepare(self, documents: Sequence[TextDocument]) -> None:
entity_labels: Set[str] = set()
relation_labels: Set[str] = set()
for document in documents:
entities: Sequence[LabeledSpan] = document[self.entity_annotation]
relations: Sequence[BinaryRelation] = document[self.relation_annotation]
for entity in entities:
entity_labels.add(entity.label)
for relation in relations:
relation_labels.add(relation.label)
if self.none_label in relation_labels:
relation_labels.remove(self.none_label)
self.label_to_id = {label: i + 1 for i, label in enumerate(sorted(relation_labels))}
self.label_to_id[self.none_label] = 0
self.entity_labels = sorted(entity_labels)
def _post_prepare(self):
self.argument_markers = self._initialize_argument_markers()
self.tokenizer.add_tokens(self.argument_markers, special_tokens=True)
self.argument_markers_to_id = {
marker: self.tokenizer.vocab[marker] for marker in self.argument_markers
}
self.sep_token_id = self.tokenizer.vocab[self.tokenizer.sep_token]
self.id_to_label = {v: k for k, v in self.label_to_id.items()}
def _initialize_argument_markers(self) -> List[str]:
argument_markers: Set[str] = set()
for arg_type in [HEAD, TAIL]:
for arg_pos in [START, END]:
is_head = arg_type == HEAD
is_start = arg_pos == START
argument_markers.add(f"[{'' if is_start else '/'}{'H' if is_head else 'T'}]")
if self.add_type_to_marker:
for entity_type in self.entity_labels: # type: ignore
argument_markers.add(
f"[{'' if is_start else '/'}{'H' if is_head else 'T'}"
f"{':' + entity_type if self.add_type_to_marker else ''}]"
)
if self.append_markers:
for entity_type in self.entity_labels: # type: ignore
argument_markers.add(f"[{'H' if is_head else 'T'}={entity_type}]")
return sorted(list(argument_markers))
def _encode_text(
self,
document: TextDocument,
partition: Optional[Span] = None,
add_special_tokens: bool = True,
) -> BatchEncoding:
text = (
document.text[partition.start : partition.end]
if partition is not None
else document.text
)
encoding = self.tokenizer(
text,
padding=False,
truncation=self.truncation,
max_length=self.max_length,
is_split_into_words=False,
return_offsets_mapping=False,
add_special_tokens=add_special_tokens,
)
return encoding
def encode_input(
self,
document: TextDocument,
is_training: bool = False,
) -> Optional[
Union[
TransformerReTextClassificationTaskEncoding2,
Sequence[TransformerReTextClassificationTaskEncoding2],
]
]:
assert (
self.argument_markers is not None
), "No argument markers available, was `prepare` already called?"
entities: Sequence[Span] = document[self.entity_annotation]
relations: Sequence[BinaryRelation] = document[self.relation_annotation]
partitions: Sequence[Optional[Span]]
if self.partition_annotation is not None:
partitions = document[self.partition_annotation]
else:
# use single dummy partition
partitions = [None]
task_encodings: List[TransformerReTextClassificationTaskEncoding2] = []
for partition_idx, partition in enumerate(partitions):
partition_offset = 0 if partition is None else partition.start
add_special_tokens = self.max_window is None
encoding = self._encode_text(
document=document, partition=partition, add_special_tokens=add_special_tokens
)
            for head, tail in _enumerate_entity_pairs(
entities=entities,
partition=partition,
relations=relations,
):
head_token_slice = get_token_slice(
character_slice=(head.start, head.end),
char_to_token_mapper=encoding.char_to_token,
character_offset=partition_offset,
)
tail_token_slice = get_token_slice(
character_slice=(tail.start, tail.end),
char_to_token_mapper=encoding.char_to_token,
character_offset=partition_offset,
)
# this happens if the head/tail start/end does not match a token start/end
if head_token_slice is None or tail_token_slice is None:
# if statistics is not None:
# statistics["entity_token_alignment_error"][
# relation_mapping.get((head, tail), "TO_PREDICT")
# ] += 1
logger.warning(
f"Skipping invalid example {document.id}, cannot get token slice(s)"
)
continue
input_ids = encoding["input_ids"]
                # get the tokens corresponding to the input_ids; `encoding.encodings[0]` is the
                # single (non-batched) fast-tokenizer encoding of the text
tokens = encoding.encodings[0].tokens
# windowing
if self.max_window is not None:
head_start, head_end = head_token_slice
tail_start, tail_end = tail_token_slice
                    # The actual number of tokens will be lower than max_window because we add the
                    # 4 marker tokens (before/after the head/tail) and the default special tokens
                    # (e.g. CLS and SEP).
num_added_special_tokens = len(
self.tokenizer.build_inputs_with_special_tokens([])
)
max_tokens = self.max_window - 4 - num_added_special_tokens
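                    # e.g. for a BERT-style tokenizer ([CLS] ... [SEP] -> 2 added special tokens)
                    # and max_window=128, this gives max_tokens = 128 - 4 - 2 = 122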
# the slice from the beginning of the first entity to the end of the second is required
slice_required = (min(head_start, tail_start), max(head_end, tail_end))
window_slice = get_window_around_slice(
slice=slice_required,
max_window_size=max_tokens,
available_input_length=len(input_ids),
)
# this happens if slice_required does not fit into max_tokens
if window_slice is None:
# if statistics is not None:
# statistics["out_of_token_window"][
# relation_mapping.get((head, tail), "TO_PREDICT")
# ] += 1
continue
window_start, window_end = window_slice
input_ids = input_ids[window_start:window_end]
head_token_slice = head_start - window_start, head_end - window_start
tail_token_slice = tail_start - window_start, tail_end - window_start
# maybe expand to n-ary relations?
head_arg = RelationArgument(head, HEAD, head_token_slice, self.add_type_to_marker)
tail_arg = RelationArgument(tail, TAIL, tail_token_slice, self.add_type_to_marker)
arg_list = [head_arg, tail_arg]
if head_token_slice[0] < tail_token_slice[0]:
assert (
head_token_slice[1] <= tail_token_slice[0]
), f"the head and tail entities are not allowed to overlap in {document.id}"
else:
assert (
tail_token_slice[1] <= head_token_slice[0]
), f"the head and tail entities are not allowed to overlap in {document.id}"
                    # swap so that arg_list is ordered by position in the text
arg_list.reverse()
first_arg_start_id = self.argument_markers_to_id[arg_list[0].as_start_marker]
first_arg_end_id = self.argument_markers_to_id[arg_list[0].as_end_marker]
second_arg_start_id = self.argument_markers_to_id[arg_list[1].as_start_marker]
second_arg_end_id = self.argument_markers_to_id[arg_list[1].as_end_marker]
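                # insert the start/end marker ids around the first and second argument spans,
                # keeping the surrounding token ids in their original order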
new_input_ids = (
input_ids[: arg_list[0].offsets[0]]
+ [first_arg_start_id]
+ input_ids[arg_list[0].offsets[0] : arg_list[0].offsets[1]]
+ [first_arg_end_id]
+ input_ids[arg_list[0].offsets[1] : arg_list[1].offsets[0]]
+ [second_arg_start_id]
+ input_ids[arg_list[1].offsets[0] : arg_list[1].offsets[1]]
+ [second_arg_end_id]
+ input_ids[arg_list[1].offsets[1] :]
)
if self.append_markers:
new_input_ids.extend(
[
self.argument_markers_to_id[head_arg.as_append_marker],
self.sep_token_id,
self.argument_markers_to_id[tail_arg.as_append_marker],
self.sep_token_id,
]
)
# when windowing is used, we have to add the special tokens manually
if not add_special_tokens:
new_input_ids = self.tokenizer.build_inputs_with_special_tokens(
token_ids_0=new_input_ids
)
# lots of logging from here on
                log_this_example = (
                    self.log_first_n_examples is not None
                    and self._logged_examples_counter < self.log_first_n_examples
                )
if log_this_example:
self._log_example(document, arg_list, new_input_ids, relations, tokens)
task_encodings.append(
TaskEncoding(
document=document,
inputs={"input_ids": new_input_ids},
metadata={
HEAD: head,
TAIL: tail,
},
)
)
return task_encodings
def _log_example(
self,
document: TextDocument,
arg_list: List[RelationArgument],
input_ids: List[int],
relations: Sequence[BinaryRelation],
tokens: List[str],
):
first_arg_start = arg_list[0].as_start_marker
first_arg_end = arg_list[0].as_end_marker
second_arg_start = arg_list[1].as_start_marker
second_arg_end = arg_list[1].as_end_marker
new_tokens = (
tokens[: arg_list[0].offsets[0]]
+ [first_arg_start]
+ tokens[arg_list[0].offsets[0] : arg_list[0].offsets[1]]
+ [first_arg_end]
+ tokens[arg_list[0].offsets[1] : arg_list[1].offsets[0]]
+ [second_arg_start]
+ tokens[arg_list[1].offsets[0] : arg_list[1].offsets[1]]
+ [second_arg_end]
+ tokens[arg_list[1].offsets[1] :]
)
head_idx = 0 if arg_list[0].role == HEAD else 1
tail_idx = 0 if arg_list[0].role == TAIL else 1
if self.append_markers:
head_marker = arg_list[head_idx].as_append_marker
tail_marker = arg_list[tail_idx].as_append_marker
new_tokens.extend(
[head_marker, self.tokenizer.sep_token, tail_marker, self.tokenizer.sep_token]
)
logger.info("*** Example ***")
logger.info("doc id: %s", document.id)
logger.info("tokens: %s", " ".join([str(x) for x in new_tokens]))
logger.info("input_ids: %s", " ".join([str(x) for x in input_ids]))
rel_labels = [relation.label for relation in relations]
rel_label_ids = [self.label_to_id[label] for label in rel_labels]
logger.info("Expected labels: %s (ids = %s)", rel_labels, rel_label_ids)
self._logged_examples_counter += 1
def encode_target(
self,
task_encoding: TransformerReTextClassificationTaskEncoding2,
) -> TransformerReTextClassificationTargetEncoding2:
metadata = task_encoding.metadata
document = task_encoding.document
relations: Sequence[BinaryRelation] = document[self.relation_annotation]
head_tail_to_labels = {
(relation.head, relation.tail): [relation.label] for relation in relations
}
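        # note: if several relations share the same (head, tail) pair, only the label of the
        # last one is kept by this lookup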
labels = head_tail_to_labels.get((metadata[HEAD], metadata[TAIL]), [self.none_label])
target = [self.label_to_id[label] for label in labels]
return target
def unbatch_output(
self, model_output: TransformerTextClassificationModelBatchOutput
) -> Sequence[TransformerReTextClassificationTaskOutput2]:
logits = model_output["logits"]
output_label_probs = logits.sigmoid() if self.multi_label else logits.softmax(dim=-1)
output_label_probs = output_label_probs.detach().cpu().numpy()
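        # output_label_probs has shape (batch_size, num_labels)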
unbatched_output = []
if self.multi_label:
raise NotImplementedError
else:
label_ids = np.argmax(output_label_probs, axis=-1)
for batch_idx, label_id in enumerate(label_ids):
label = self.id_to_label[label_id]
prob = float(output_label_probs[batch_idx, label_id])
result: TransformerReTextClassificationTaskOutput2 = {
"labels": [label],
"probabilities": [prob],
}
unbatched_output.append(result)
return unbatched_output
def create_annotations_from_output(
self,
task_encoding: TransformerReTextClassificationTaskEncoding2,
task_output: TransformerReTextClassificationTaskOutput2,
) -> Iterator[Tuple[str, Union[BinaryRelation, MultiLabeledBinaryRelation]]]:
labels = task_output["labels"]
probabilities = task_output["probabilities"]
if labels != [self.none_label]:
yield (
self.relation_annotation,
BinaryRelation(
head=task_encoding.metadata[HEAD],
tail=task_encoding.metadata[TAIL],
label=labels[0],
score=probabilities[0],
),
)
def collate(
self, task_encodings: Sequence[TransformerReTextClassificationTaskEncoding2]
) -> TransformerTextClassificationModelStepBatchEncoding:
input_features = [task_encoding.inputs for task_encoding in task_encodings]
inputs: Dict[str, torch.Tensor] = self.tokenizer.pad(
input_features,
padding=self.padding,
max_length=self.max_length,
pad_to_multiple_of=self.pad_to_multiple_of,
return_tensors="pt",
)
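        # `inputs` now holds padded tensors for the whole batch (e.g. input_ids and attention_mask)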
if not task_encodings[0].has_targets:
return inputs, None
target_list: List[TransformerReTextClassificationTargetEncoding2] = [
task_encoding.targets for task_encoding in task_encodings
]
targets = torch.tensor(target_list, dtype=torch.int64)
if not self.multi_label:
targets = targets.flatten()
return inputs, targets