File size: 23,161 Bytes
386fb69
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
"""
workflow:
    Document
        -> (InputEncoding, TargetEncoding) -> TaskEncoding -> TaskBatchEncoding
            -> ModelBatchEncoding -> ModelBatchOutput
        -> TaskOutput
    -> Document
"""

import logging
from typing import Any, Dict, Iterator, List, Optional, Sequence, Set, Tuple, TypedDict, Union

import numpy as np
import torch
from pytorch_ie.annotations import BinaryRelation, LabeledSpan, MultiLabeledBinaryRelation, Span
from pytorch_ie.core import TaskEncoding, TaskModule
from pytorch_ie.documents import TextDocument
from pytorch_ie.models import (
    TransformerTextClassificationModelBatchOutput,
    TransformerTextClassificationModelStepBatchEncoding,
)
from pytorch_ie.utils.span import get_token_slice, is_contained_in
from pytorch_ie.utils.window import get_window_around_slice
from transformers import AutoTokenizer
from transformers.file_utils import PaddingStrategy
from transformers.tokenization_utils_base import BatchEncoding, TruncationStrategy
from typing_extensions import TypeAlias

# Model inputs of one candidate (head, tail) pair as built by `encode_input`,
# e.g. {"input_ids": [...]} with the argument marker ids already inserted.
TransformerReTextClassificationInputEncoding2: TypeAlias = Dict[str, Any]
# Label id(s) of one candidate pair as built by `encode_target`.
TransformerReTextClassificationTargetEncoding2: TypeAlias = Sequence[int]

# One unit of work of this taskmodule: the source document plus the inputs (and,
# once encoded, the targets) of a single candidate (head, tail) pair.
TransformerReTextClassificationTaskEncoding2: TypeAlias = TaskEncoding[
    TextDocument,
    TransformerReTextClassificationInputEncoding2,
    TransformerReTextClassificationTargetEncoding2,
]


class TransformerReTextClassificationTaskOutput2(TypedDict, total=False):
    """Prediction for one candidate pair: relation label(s) and their probabilities."""

    # predicted relation label(s); `probabilities` holds one entry per label
    labels: Sequence[str]
    probabilities: Sequence[float]


# Fully parametrized `TaskModule` base of `TransformerRETextClassificationTaskModule2`.
_TransformerReTextClassificationTaskModule2: TypeAlias = TaskModule[
    # type parameters, in order: document, input encoding, target encoding,
    # task batch encoding, model batch output, task output
    TextDocument,
    TransformerReTextClassificationInputEncoding2,
    TransformerReTextClassificationTargetEncoding2,
    TransformerTextClassificationModelStepBatchEncoding,
    TransformerTextClassificationModelBatchOutput,
    TransformerReTextClassificationTaskOutput2,
]


# Argument roles of a binary relation; also used as keys of the task-encoding metadata.
HEAD, TAIL = "head", "tail"
# Marker positions relative to an argument span (before / after it).
START, END = "start", "end"


logger = logging.getLogger(__name__)


class RelationArgument:
    """One argument (head or tail) of a candidate binary relation and its marker strings.

    Args:
        entity: the argument entity; its ``label`` is used by typed and appended markers.
        role: either ``HEAD`` or ``TAIL``.
        offsets: ``(start, end)`` token slice of the entity in the (possibly windowed) input.
        add_type_to_marker: if True, start/end markers include the entity label,
            e.g. ``[H:<label>]`` instead of ``[H]``.
    """

    def __init__(
        self,
        entity: LabeledSpan,
        role: str,
        offsets: Tuple[int, int],
        add_type_to_marker: bool,
    ) -> None:
        self.entity = entity
        self.role = role
        assert self.role in (HEAD, TAIL)
        self.offsets = offsets
        self.add_type_to_marker = add_type_to_marker

    @property
    def is_head(self) -> bool:
        """True if this argument plays the head role."""
        return self.role == HEAD

    @property
    def is_tail(self) -> bool:
        """True if this argument plays the tail role."""
        return self.role == TAIL

    @property
    def _role_char(self) -> str:
        # single-letter role tag shared by every marker string
        return "H" if self.is_head else "T"

    @property
    def as_start_marker(self) -> str:
        """Marker inserted before the argument, e.g. ``[H]`` or ``[H:<label>]``."""
        return self._get_marker(is_start=True)

    @property
    def as_end_marker(self) -> str:
        """Marker inserted after the argument, e.g. ``[/H]`` or ``[/H:<label>]``."""
        return self._get_marker(is_start=False)

    def _get_marker(self, is_start: bool = True) -> str:
        slash = "" if is_start else "/"
        type_part = f":{self.entity.label}" if self.add_type_to_marker else ""
        return f"[{slash}{self._role_char}{type_part}]"

    @property
    def as_append_marker(self) -> str:
        """Marker naming the argument's entity label, e.g. ``[H=<label>]``."""
        return f"[{self._role_char}={self.entity.label}]"


def _enumerate_entity_pairs(
    entities: Sequence[Span],
    partition: Optional[Span] = None,
    relations: Optional[Sequence[BinaryRelation]] = None,
) -> Iterator[Tuple[Span, Span]]:
    """Given a list of `entities` iterate all valid pairs of entities, including inverted pairs.

    Pairs of equal entities are skipped. Pairs are yielded in the order of `entities`
    (outer loop: head, inner loop: tail).

    Args:
        entities: the candidate relation arguments.
        partition: if provided, restrict pairs to entities contained in that span.
        relations: if given, return only pairs for which a predefined relation exists, i.e. that
            are the (head, tail) of one of these relations (e.g. in the case of relation
            classification for train,val,test splits in supervised datasets).

    Yields:
        (head, tail) tuples.
    """
    if partition is not None:
        # filter once instead of re-checking every tail for every head
        partition_slice = (partition.start, partition.end)
        entities = [
            entity
            for entity in entities
            if is_contained_in((entity.start, entity.end), partition_slice)
        ]

    existing_head_tail = (
        None if relations is None else {(relation.head, relation.tail) for relation in relations}
    )
    for head in entities:
        for tail in entities:
            if head == tail:
                continue
            if existing_head_tail is not None and (head, tail) not in existing_head_tail:
                continue
            yield head, tail


@TaskModule.register()
class TransformerRETextClassificationTaskModule2(_TransformerReTextClassificationTaskModule2):
    """Marker based relation extraction. This taskmodule prepares the input token ids in such a way
    that before and after the candidate head and tail entities special marker tokens are inserted.
    Then, the modified token ids can be simply passed into a transformer based text classifier
    model.

    Candidate (head, tail) pairs are built from the entity annotation layer and restricted to pairs
    for which the relation annotation layer of the document contains a relation.

    parameters:

        tokenizer_name_or_path: str. Passed to `AutoTokenizer.from_pretrained`. Token alignment
            uses `BatchEncoding.char_to_token`, which `transformers` only offers for fast
            tokenizers.
        entity_annotation: str, defaults to "entities". Name of the entity (LabeledSpan) layer.
        relation_annotation: str, defaults to "relations". Name of the BinaryRelation layer.
        partition_annotation: str, optional. If specified, LabeledSpan annotations with this name are
            expected to define partitions of the document that will be processed individually, e.g.
            sentences or sections of the document text.
        none_label: str, defaults to "no_relation". The relation label that indicates dummy/negative
            relations. Predicted relations with that label will not be added to the document(s).
        padding, pad_to_multiple_of: passed to `tokenizer.pad` when collating a batch.
        truncation, max_length: used when tokenizing the text; `max_length` is also passed to
            `tokenizer.pad` when collating.
        multi_label: bool, defaults to False. Multi-label prediction is not implemented
            (`unbatch_output` raises NotImplementedError).
        label_to_id: dict, optional. Relation label vocabulary; computed by `prepare` otherwise.
        add_type_to_marker: bool, defaults to False. If True, the entity label is part of the
            start/end markers, e.g. "[H:<label>]" instead of "[H]".
        single_argument_pair: bool, defaults to True. Stored, but not used by this class.
        append_markers: bool, defaults to False. If True, "[H=<label>] SEP [T=<label>] SEP" is
            appended to the marked input ids.
        entity_labels: list of str, optional. Entity label vocabulary; computed by `prepare`
            otherwise. Required if `add_type_to_marker` or `append_markers` is enabled.
        max_window: int, optional. If specified, use the tokens in a window of maximal this amount of
            tokens around the center of head and tail entities and pass only that into the
            transformer. The budget includes the marker and special tokens.
        log_first_n_examples: int, optional. If specified, log that many encoded examples.
    """

    PREPARED_ATTRIBUTES = ["label_to_id", "entity_labels"]

    def __init__(
        self,
        tokenizer_name_or_path: str,
        entity_annotation: str = "entities",
        relation_annotation: str = "relations",
        partition_annotation: Optional[str] = None,
        none_label: str = "no_relation",
        padding: Union[bool, str, PaddingStrategy] = True,
        truncation: Union[bool, str, TruncationStrategy] = True,
        max_length: Optional[int] = None,
        pad_to_multiple_of: Optional[int] = None,
        multi_label: bool = False,
        label_to_id: Optional[Dict[str, int]] = None,
        add_type_to_marker: bool = False,
        single_argument_pair: bool = True,
        append_markers: bool = False,
        entity_labels: Optional[List[str]] = None,
        max_window: Optional[int] = None,
        log_first_n_examples: Optional[int] = None,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.save_hyperparameters()

        self.entity_annotation = entity_annotation
        self.relation_annotation = relation_annotation
        self.padding = padding
        self.truncation = truncation
        self.label_to_id = label_to_id or {}
        self.id_to_label = {v: k for k, v in self.label_to_id.items()}
        self.max_length = max_length
        self.pad_to_multiple_of = pad_to_multiple_of
        self.multi_label = multi_label
        self.add_type_to_marker = add_type_to_marker
        self.single_argument_pair = single_argument_pair
        self.append_markers = append_markers
        self.entity_labels = entity_labels
        self.partition_annotation = partition_annotation
        self.none_label = none_label
        self.max_window = max_window
        self.log_first_n_examples = log_first_n_examples

        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)

        # filled by `_post_prepare`; `encode_input` refuses to run while this is None
        self.argument_markers: Optional[List[str]] = None

        # number of examples logged so far (see `log_first_n_examples`)
        self._logged_examples_counter = 0

    def _prepare(self, documents: Sequence[TextDocument]) -> None:
        """Collect the entity and relation label vocabularies from `documents`.

        `none_label` always gets id 0, the remaining relation labels are sorted and numbered
        starting from 1.
        """
        entity_labels: Set[str] = set()
        relation_labels: Set[str] = set()
        for document in documents:
            entities: Sequence[LabeledSpan] = document[self.entity_annotation]
            relations: Sequence[BinaryRelation] = document[self.relation_annotation]
            entity_labels.update(entity.label for entity in entities)
            relation_labels.update(relation.label for relation in relations)

        relation_labels.discard(self.none_label)

        self.label_to_id = {label: i + 1 for i, label in enumerate(sorted(relation_labels))}
        self.label_to_id[self.none_label] = 0

        self.entity_labels = sorted(entity_labels)

    def _post_prepare(self):
        """Register the argument marker tokens with the tokenizer and cache their ids."""
        self.argument_markers = self._initialize_argument_markers()
        self.tokenizer.add_tokens(self.argument_markers, special_tokens=True)

        # `convert_tokens_to_ids` resolves added tokens; `tokenizer.vocab` may not contain them
        # (slow tokenizers) or rebuilds the full vocabulary dict on every access (fast tokenizers).
        self.argument_markers_to_id = {
            marker: self.tokenizer.convert_tokens_to_ids(marker)
            for marker in self.argument_markers
        }
        self.sep_token_id = self.tokenizer.sep_token_id

        self.id_to_label = {v: k for k, v in self.label_to_id.items()}

    def _initialize_argument_markers(self) -> List[str]:
        """Return the sorted list of all marker strings that `RelationArgument` can produce.

        Raises:
            ValueError: if typed or appended markers are requested but `entity_labels` is None.
        """
        needs_entity_labels = self.add_type_to_marker or self.append_markers
        if needs_entity_labels and self.entity_labels is None:
            raise ValueError(
                "entity_labels are required when add_type_to_marker or append_markers is enabled"
            )
        entity_labels: List[str] = self.entity_labels or []

        # the formats must stay in sync with RelationArgument._get_marker / as_append_marker
        argument_markers: Set[str] = set()
        for role_char in ("H", "T"):
            for slash in ("", "/"):  # start marker, end marker
                argument_markers.add(f"[{slash}{role_char}]")
                if self.add_type_to_marker:
                    argument_markers.update(
                        f"[{slash}{role_char}:{entity_type}]" for entity_type in entity_labels
                    )
            if self.append_markers:
                argument_markers.update(
                    f"[{role_char}={entity_type}]" for entity_type in entity_labels
                )

        return sorted(argument_markers)

    def _encode_text(
        self,
        document: TextDocument,
        partition: Optional[Span] = None,
        add_special_tokens: bool = True,
    ) -> BatchEncoding:
        """Tokenize the document text, or only its `partition` slice, without padding.

        Truncation follows `self.truncation` / `self.max_length`.
        NOTE(review): truncation happens before the argument markers are inserted, so the final
        input can exceed `max_length` by the number of marker tokens.
        """
        text = (
            document.text[partition.start : partition.end]
            if partition is not None
            else document.text
        )
        encoding = self.tokenizer(
            text,
            padding=False,
            truncation=self.truncation,
            max_length=self.max_length,
            is_split_into_words=False,
            return_offsets_mapping=False,
            add_special_tokens=add_special_tokens,
        )
        return encoding

    def encode_input(
        self,
        document: TextDocument,
        is_training: bool = False,
    ) -> Optional[
        Union[
            TransformerReTextClassificationTaskEncoding2,
            Sequence[TransformerReTextClassificationTaskEncoding2],
        ]
    ]:
        """Create one task encoding per candidate (head, tail) entity pair of `document`.

        The text of each partition (or the whole text) is tokenized once. For every candidate pair,
        marker ids are inserted before and after the head and tail tokens. With `max_window`, the
        ids are first cut to a window around both entities and the tokenizer's special tokens are
        added afterwards. Pairs whose entities do not align with token boundaries or do not fit
        into the window are skipped. `is_training` is not used.

        Returns:
            A (possibly empty) list of task encodings; their metadata holds the head and tail
            entities under the keys `HEAD` and `TAIL`.
        """
        assert (
            self.argument_markers is not None
        ), "No argument markers available, was `prepare` already called?"

        entities: Sequence[Span] = document[self.entity_annotation]
        # NOTE(review): the relations restrict the candidate pairs (see `_enumerate_entity_pairs`):
        # only pairs with an existing relation (possibly labeled `none_label`) are encoded, so a
        # document with an empty relation layer yields no task encodings.
        relations: Sequence[BinaryRelation] = document[self.relation_annotation]

        partitions: Sequence[Optional[Span]]
        if self.partition_annotation is not None:
            partitions = document[self.partition_annotation]
        else:
            # use single dummy partition
            partitions = [None]

        # when windowing is used, we have to add the special tokens manually after cutting
        add_special_tokens = self.max_window is None
        max_tokens: Optional[int] = None
        if self.max_window is not None:
            # The number of text tokens has to be lower than max_window because we add the
            # 4 marker tokens (before / after the head / tail), the 4 appended tokens (if
            # append_markers is enabled) and the default special tokens (e.g. CLS and SEP).
            num_added_special_tokens = len(self.tokenizer.build_inputs_with_special_tokens([]))
            num_marker_tokens = 8 if self.append_markers else 4
            max_tokens = self.max_window - num_marker_tokens - num_added_special_tokens

        task_encodings: List[TransformerReTextClassificationTaskEncoding2] = []
        for partition in partitions:
            partition_offset = 0 if partition is None else partition.start
            encoding = self._encode_text(
                document=document, partition=partition, add_special_tokens=add_special_tokens
            )
            partition_input_ids = encoding["input_ids"]

            for head, tail in _enumerate_entity_pairs(
                entities=entities,
                partition=partition,
                relations=relations,
            ):
                head_token_slice = get_token_slice(
                    character_slice=(head.start, head.end),
                    char_to_token_mapper=encoding.char_to_token,
                    character_offset=partition_offset,
                )
                tail_token_slice = get_token_slice(
                    character_slice=(tail.start, tail.end),
                    char_to_token_mapper=encoding.char_to_token,
                    character_offset=partition_offset,
                )
                # this happens if the head/tail start/end does not match a token start/end
                if head_token_slice is None or tail_token_slice is None:
                    logger.warning(
                        "Skipping invalid example %s, cannot get token slice(s)", document.id
                    )
                    continue

                # token window that is passed to the model (the whole partition by default)
                window_start, window_end = 0, len(partition_input_ids)
                if max_tokens is not None:
                    head_start, head_end = head_token_slice
                    tail_start, tail_end = tail_token_slice
                    # the slice from the beginning of the first entity to the end of the second
                    # is required
                    slice_required = (min(head_start, tail_start), max(head_end, tail_end))
                    window_slice = get_window_around_slice(
                        slice=slice_required,
                        max_window_size=max_tokens,
                        available_input_length=len(partition_input_ids),
                    )
                    # this happens if slice_required does not fit into max_tokens
                    if window_slice is None:
                        continue

                    window_start, window_end = window_slice
                    head_token_slice = head_start - window_start, head_end - window_start
                    tail_token_slice = tail_start - window_start, tail_end - window_start
                input_ids = partition_input_ids[window_start:window_end]

                # maybe expand to n-ary relations?
                head_arg = RelationArgument(head, HEAD, head_token_slice, self.add_type_to_marker)
                tail_arg = RelationArgument(tail, TAIL, tail_token_slice, self.add_type_to_marker)
                # order the arguments by their position in the input
                arg_list = [head_arg, tail_arg]
                if head_token_slice[0] < tail_token_slice[0]:
                    assert (
                        head_token_slice[1] <= tail_token_slice[0]
                    ), f"the head and tail entities are not allowed to overlap in {document.id}"
                else:
                    assert (
                        tail_token_slice[1] <= head_token_slice[0]
                    ), f"the head and tail entities are not allowed to overlap in {document.id}"
                    arg_list.reverse()

                new_input_ids = self._insert_argument_markers(
                    input_ids,
                    arg_list,
                    [
                        (
                            self.argument_markers_to_id[arg.as_start_marker],
                            self.argument_markers_to_id[arg.as_end_marker],
                        )
                        for arg in arg_list
                    ],
                )

                if self.append_markers:
                    new_input_ids.extend(
                        [
                            self.argument_markers_to_id[head_arg.as_append_marker],
                            self.sep_token_id,
                            self.argument_markers_to_id[tail_arg.as_append_marker],
                            self.sep_token_id,
                        ]
                    )

                # when windowing is used, we have to add the special tokens manually
                if not add_special_tokens:
                    new_input_ids = self.tokenizer.build_inputs_with_special_tokens(
                        token_ids_0=new_input_ids
                    )

                if (
                    self.log_first_n_examples is not None
                    and self._logged_examples_counter < self.log_first_n_examples
                ):
                    # tokens are only needed for logging; cut them exactly like the ids so they
                    # line up with the (window-relative) argument offsets
                    tokens = encoding.tokens()[window_start:window_end]
                    self._log_example(document, arg_list, new_input_ids, relations, tokens)

                task_encodings.append(
                    TaskEncoding(
                        document=document,
                        inputs={"input_ids": new_input_ids},
                        metadata={
                            HEAD: head,
                            TAIL: tail,
                        },
                    )
                )

        return task_encodings

    @staticmethod
    def _insert_argument_markers(
        sequence: List[Any],
        arg_list: List[RelationArgument],
        start_end_markers: Sequence[Tuple[Any, Any]],
    ) -> List[Any]:
        """Return a new list: `sequence` with a start and an end marker around each argument.

        `arg_list` must be ordered by position with non-overlapping offsets;
        `start_end_markers[i]` is the (start, end) marker pair for `arg_list[i]`.
        """
        result: List[Any] = []
        position = 0
        for arg, (start_marker, end_marker) in zip(arg_list, start_end_markers):
            arg_start, arg_end = arg.offsets
            result.extend(sequence[position:arg_start])
            result.append(start_marker)
            result.extend(sequence[arg_start:arg_end])
            result.append(end_marker)
            position = arg_end
        result.extend(sequence[position:])
        return result

    def _log_example(
        self,
        document: TextDocument,
        arg_list: List[RelationArgument],
        input_ids: List[int],
        relations: Sequence[BinaryRelation],
        tokens: List[str],
    ):
        """Log the marked tokens, the input ids and the expected label(s) of one example.

        `tokens` must be aligned with the token offsets of `arg_list`, i.e. already cut to the
        window if windowing is used.
        """
        new_tokens = self._insert_argument_markers(
            tokens, arg_list, [(arg.as_start_marker, arg.as_end_marker) for arg in arg_list]
        )

        if arg_list[0].is_head:
            head_arg, tail_arg = arg_list[0], arg_list[1]
        else:
            head_arg, tail_arg = arg_list[1], arg_list[0]

        if self.append_markers:
            new_tokens.extend(
                [
                    head_arg.as_append_marker,
                    self.tokenizer.sep_token,
                    tail_arg.as_append_marker,
                    self.tokenizer.sep_token,
                ]
            )

        # only the relation(s) of this (head, tail) pair are expected for this example;
        # mirrors the `none_label` fallback of `encode_target`
        rel_labels = [
            relation.label
            for relation in relations
            if relation.head == head_arg.entity and relation.tail == tail_arg.entity
        ] or [self.none_label]
        rel_label_ids = [self.label_to_id.get(label) for label in rel_labels]

        logger.info("*** Example ***")
        logger.info("doc id: %s", document.id)
        logger.info("tokens: %s", " ".join([str(x) for x in new_tokens]))
        logger.info("input_ids: %s", " ".join([str(x) for x in input_ids]))
        logger.info("Expected labels: %s (ids = %s)", rel_labels, rel_label_ids)

        self._logged_examples_counter += 1

    def encode_target(
        self,
        task_encoding: TransformerReTextClassificationTaskEncoding2,
    ) -> TransformerReTextClassificationTargetEncoding2:
        """Return the label id(s) of the (head, tail) pair stored in the task encoding's metadata.

        Pairs without a relation in the document get `none_label`. If several relations share the
        same (head, tail), the last one wins.
        """
        metadata = task_encoding.metadata
        document = task_encoding.document

        relations: Sequence[BinaryRelation] = document[self.relation_annotation]

        head_tail_to_labels = {
            (relation.head, relation.tail): [relation.label] for relation in relations
        }

        labels = head_tail_to_labels.get((metadata[HEAD], metadata[TAIL]), [self.none_label])
        target = [self.label_to_id[label] for label in labels]

        return target

    def unbatch_output(
        self, model_output: TransformerTextClassificationModelBatchOutput
    ) -> Sequence[TransformerReTextClassificationTaskOutput2]:
        """Convert batched logits into one (label, probability) result per example.

        Raises:
            NotImplementedError: if `multi_label` is enabled.
        """
        if self.multi_label:
            raise NotImplementedError("multi-label prediction is not implemented")

        logits = model_output["logits"]
        output_label_probs = logits.softmax(dim=-1).detach().cpu().numpy()

        unbatched_output: List[TransformerReTextClassificationTaskOutput2] = []
        label_ids = np.argmax(output_label_probs, axis=-1)
        for batch_idx, label_id in enumerate(label_ids):
            label = self.id_to_label[int(label_id)]
            prob = float(output_label_probs[batch_idx, label_id])
            result: TransformerReTextClassificationTaskOutput2 = {
                "labels": [label],
                "probabilities": [prob],
            }
            unbatched_output.append(result)

        return unbatched_output

    def create_annotations_from_output(
        self,
        task_encoding: TransformerReTextClassificationTaskEncoding2,
        task_output: TransformerReTextClassificationTaskOutput2,
    ) -> Iterator[Tuple[str, Union[BinaryRelation, MultiLabeledBinaryRelation]]]:
        """Yield a scored BinaryRelation for the pair unless `none_label` was predicted."""
        labels = task_output["labels"]
        probabilities = task_output["probabilities"]
        if labels != [self.none_label]:
            yield (
                self.relation_annotation,
                BinaryRelation(
                    head=task_encoding.metadata[HEAD],
                    tail=task_encoding.metadata[TAIL],
                    label=labels[0],
                    score=probabilities[0],
                ),
            )

    def collate(
        self, task_encodings: Sequence[TransformerReTextClassificationTaskEncoding2]
    ) -> TransformerTextClassificationModelStepBatchEncoding:
        """Pad the input ids of `task_encodings` into tensors and stack their targets, if any.

        Returns `(inputs, None)` if the first task encoding has no targets.
        """
        input_features = [task_encoding.inputs for task_encoding in task_encodings]

        inputs: Dict[str, torch.Tensor] = self.tokenizer.pad(
            input_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )

        if not task_encodings[0].has_targets:
            return inputs, None

        target_list: List[TransformerReTextClassificationTargetEncoding2] = [
            task_encoding.targets for task_encoding in task_encodings
        ]
        targets = torch.tensor(target_list, dtype=torch.int64)

        if not self.multi_label:
            targets = targets.flatten()

        return inputs, targets