import logging
from transformers import Pipeline
import numpy as np
import torch
import nltk

nltk.download("averaged_perceptron_tagger")
nltk.download("averaged_perceptron_tagger_eng")
nltk.download("stopwords")
from nltk.chunk import conlltags2tree
from nltk import pos_tag
from nltk.tree import Tree
import torch.nn.functional as F
import re, string

stop_words = set(nltk.corpus.stopwords.words("english"))
DEBUG = False
# ASCII punctuation plus additional typographic marks (used for membership tests)
punctuation = string.punctuation + "«»—…“”–’‘´•°„―‣◦§¶†‡‰′″〈〉"

# List of additional "strange" punctuation marks
# additional_punctuation = "‘’“”„«»•–—―‣◦…§¶†‡‰′″〈〉"


WHITESPACE_RULES = {
    "fr": {
        "pct_no_ws_before": [".", ",", ")", "]", "}", "°", "...", ".-", "%"],
        "pct_no_ws_after": ["(", "[", "{"],
        "pct_no_ws_before_after": ["'", "-"],
        "pct_number": [".", ","],
    },
    "de": {
        "pct_no_ws_before": [
            ".",
            ",",
            ")",
            "]",
            "}",
            "°",
            "...",
            "?",
            "!",
            ":",
            ";",
            ".-",
            "%",
        ],
        "pct_no_ws_after": ["(", "[", "{"],
        "pct_no_ws_before_after": ["'", "-"],
        "pct_number": [".", ","],
    },
    "other": {
        "pct_no_ws_before": [
            ".",
            ",",
            ")",
            "]",
            "}",
            "°",
            "...",
            "?",
            "!",
            ":",
            ";",
            ".-",
            "%",
        ],
        "pct_no_ws_after": ["(", "[", "{"],
        "pct_no_ws_before_after": ["'", "-"],
        "pct_number": [".", ","],
    },
}


def tokenize(text: str, language: str = "other") -> list[str]:
    """Apply whitespace rules to the given text and language, separating it into tokens.

    Args:
        text (str): The input text to separate into a list of tokens.
        language (str): Language of the text.

    Returns:
        list[str]: List of tokens with punctuation as separate tokens.
    """
    # text = add_spaces_around_punctuation(text)
    if not text:
        return []

    if language not in WHITESPACE_RULES:
        # Default behavior for languages without specific rules:
        # tokenize using standard whitespace splitting
        language = "other"

    wsrules = WHITESPACE_RULES[language]
    tokenized_text = []
    current_token = ""

    for char in text:
        if char in wsrules["pct_no_ws_before_after"]:
            if current_token:
                tokenized_text.append(current_token)
            tokenized_text.append(char)
            current_token = ""
        elif char in wsrules["pct_no_ws_before"] or char in wsrules["pct_no_ws_after"]:
            if current_token:
                tokenized_text.append(current_token)
            tokenized_text.append(char)
            current_token = ""
        elif char.isspace():
            if current_token:
                tokenized_text.append(current_token)
                current_token = ""
        else:
            current_token += char

    if current_token:
        tokenized_text.append(current_token)

    return tokenized_text
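
# Illustrative examples of the whitespace rules above:
#   tokenize("Hello, world!")    -> ["Hello", ",", "world", "!"]
#   tokenize("l'exemple", "fr")  -> ["l", "'", "exemple"]
# Note: the loop compares single characters, so multi-character entries in the
# rules (e.g. "..." or ".-") never match as a unit; an ellipsis "..." comes out
# as three separate "." tokens.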


def normalize_text(text):
    # Remove spaces and tabs for the search but keep newline characters
    return re.sub(r"[ \t]+", "", text)


def find_entity_indices(article_text, search_text):
    # Normalize texts by removing spaces and tabs
    normalized_article = normalize_text(article_text)
    normalized_search = normalize_text(search_text)

    # Initialize a list to hold all start and end indices
    indices = []

    # Find all occurrences of the search text in the normalized article text
    start_index = 0
    while True:
        start_index = normalized_article.find(normalized_search, start_index)
        if start_index == -1:
            break

        # Calculate the actual start and end indices in the original article text
        original_chars = 0
        original_start_index = 0
        for i in range(start_index):
            while article_text[original_start_index] in (" ", "\t"):
                original_start_index += 1
            if article_text[original_start_index] not in (" ", "\t", "\n"):
                original_chars += 1
            original_start_index += 1

        original_end_index = original_start_index
        search_chars = 0
        while search_chars < len(normalized_search):
            if article_text[original_end_index] not in (" ", "\t", "\n"):
                search_chars += 1
            original_end_index += 1  # Increment to include the last character

        # Append the found indices to the list
        if article_text[original_start_index] == " ":
            original_start_index += 1
        indices.append((original_start_index, original_end_index))

        # Move start_index to the next position to continue searching
        start_index += 1

    return indices
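
# Illustrative example: whitespace is ignored while matching, but the returned
# offsets refer to the original text, so
#   find_entity_indices("New  York is big", "New York")  ->  [(0, 9)]
# where text[0:9] == "New  York" (including the double space).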


def get_entities(tokens, tags, confidences, text):

    tags = [tag.replace("S-", "B-").replace("E-", "I-") for tag in tags]
    pos_tags = [pos for token, pos in pos_tag(tokens)]

    for i in range(1, len(tags)):
        # If a 'B-' tag directly follows an 'I-' tag, treat it as a continuation
        # of the running entity and convert it to 'I-'
        if tags[i].startswith("B-") and tags[i - 1].startswith("I-"):
            tags[i] = "I-" + tags[i][2:]

    conlltags = [(token, pos, tg) for token, pos, tg in zip(tokens, pos_tags, tags)]
    ne_tree = conlltags2tree(conlltags)

    entities = []
    idx: int = 0
    already_done = []
    for subtree in ne_tree:
        # skipping 'O' tags
        if isinstance(subtree, Tree):
            original_label = subtree.label()
            original_string = " ".join([token for token, pos in subtree.leaves()])

            for indices in find_entity_indices(text, original_string):
                entity_start_position = indices[0]
                entity_end_position = indices[1]
                if (
                    "_".join(
                        [original_label, original_string, str(entity_start_position)]
                    )
                    in already_done
                ):
                    continue
                else:
                    already_done.append(
                        "_".join(
                            [
                                original_label,
                                original_string,
                                str(entity_start_position),
                            ]
                        )
                    )
                if len(text[entity_start_position:entity_end_position].strip()) < len(
                    text[entity_start_position:entity_end_position]
                ):
                    entity_start_position = (
                        entity_start_position
                        + len(text[entity_start_position:entity_end_position])
                        - len(text[entity_start_position:entity_end_position].strip())
                    )

                entities.append(
                    {
                        "type": original_label,
                        "confidence_ner": round(
                            np.average(confidences[idx : idx + len(subtree)]) * 100, 2
                        ),
                        "index": (idx, idx + len(subtree)),
                        "surface": text[
                            entity_start_position:entity_end_position
                        ],  # original_string,
                        "lOffset": entity_start_position,
                        "rOffset": entity_end_position,
                    }
                )

            idx += len(subtree)

        else:
            token, pos = subtree
            # Not part of a named entity: advance the token index by one
            idx += 1

    return entities
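
# Each entity returned by get_entities() is a dict of the following shape
# (values shown are illustrative):
#   {
#       "type": "pers",
#       "confidence_ner": 99.12,        # percentage, averaged over subtree tokens
#       "index": (3, 5),                # token span in the tagged sentence
#       "surface": "Albert Einstein",
#       "lOffset": 17,                  # character offsets into the original text
#       "rOffset": 32,
#   }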


def realign(
    text_sentence, out_label_preds, softmax_scores, tokenizer, reverted_label_map
):
    preds_list, words_list, confidence_list = [], [], []
    word_ids = tokenizer(text_sentence, is_split_into_words=True).word_ids()
    for idx, word in enumerate(text_sentence):
        try:
            beginning_index = word_ids.index(idx)
            preds_list.append(reverted_label_map[out_label_preds[beginning_index]])
            confidence_list.append(max(softmax_scores[beginning_index]))
        except Exception:  # the sentence was longer than max_length and got truncated
            preds_list.append("O")
            confidence_list.append(0.0)
        words_list.append(word)

    return words_list, preds_list, confidence_list


def add_spaces_around_punctuation(text):
    # Add a space before and after every punctuation character
    # (the module-level `punctuation` already includes string.punctuation)
    return re.sub(r"([{}])".format(re.escape(punctuation)), r" \1 ", text)


def attach_comp_to_closest(entities):
    # Define valid entity types that can receive a "comp.function" or "comp.name" attachment
    valid_entity_types = {"org", "pers", "org.ent", "pers.ind"}

    # Separate "comp.function" and "comp.name" entities from other entities
    comp_entities = [ent for ent in entities if ent["type"].startswith("comp")]
    other_entities = [ent for ent in entities if not ent["type"].startswith("comp")]

    for comp_entity in comp_entities:
        closest_entity = None
        min_distance = float("inf")

        # Find the closest non-"comp" entity that is valid for attaching
        for other_entity in other_entities:
            # Calculate distance between the comp entity and the other entity
            if comp_entity["lOffset"] > other_entity["rOffset"]:
                distance = comp_entity["lOffset"] - other_entity["rOffset"]
            elif comp_entity["rOffset"] < other_entity["lOffset"]:
                distance = other_entity["lOffset"] - comp_entity["rOffset"]
            else:
                distance = 0  # They overlap or touch

            # Ensure the entity type is valid and check for minimal distance
            if (
                distance < min_distance
                and other_entity["type"].split(".")[0] in valid_entity_types
            ):
                min_distance = distance
                closest_entity = other_entity

        # Attach the "comp.function" or "comp.name" if a valid entity is found
        if closest_entity:
            suffix = comp_entity["type"].split(".")[
                -1
            ]  # Extract the suffix (e.g., 'name', 'function')
            closest_entity[suffix] = comp_entity["surface"]  # Attach the text

    return other_entities
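
# Illustrative example for attach_comp_to_closest(): given a "comp.function"
# entity with surface "president" sitting next to a "pers.ind" entity
# "Abraham Lincoln", the suffix of the comp type ("function") becomes a key on
# the closest valid entity:
#   {"type": "pers.ind", "surface": "Abraham Lincoln", ..., "function": "president"}
# and the comp entity itself is dropped from the returned list.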


def conflicting_context(comp_entity, target_entity):
    """
    Determines if there is a conflict between the comp_entity and the target entity.
    Prevents incorrect name and function attachments by using a rule-based approach.
    """
    # Case 1: Check for correct function attachment to person or organization entities
    if comp_entity["type"].startswith("comp.function"):
        if not ("pers" in target_entity["type"] or "org" in target_entity["type"]):
            return True  # Conflict: Function should only attach to persons or organizations

    # Case 2: Avoid attaching comp.* entities to non-person, non-organization types (like locations)
    if "loc" in target_entity["type"]:
        return True  # Conflict: comp.* entities should not attach to locations or similar types

    return False  # No conflict


def extract_name_from_text(text, partial_name):
    """
    Extracts the full name from the entity's text based on the partial name.
    This function assumes that the full name starts with capitalized letters and does not
    include any words that come after the partial name.
    """
    # Split the text and partial name into words
    words = tokenize(text)
    partial_words = partial_name.split()

    if DEBUG:
        print("text:", text)
    if DEBUG:
        print("partial_name:", partial_name)

    # Find the position of the partial name in the word list
    for i, word in enumerate(words):
        if DEBUG:
            print(words, "---", words[i : i + len(partial_words)])
        if words[i : i + len(partial_words)] == partial_words:
            # Initialize full name with the partial name
            full_name = partial_words[:]

            if DEBUG:
                print("full_name:", full_name)

            # Check previous words and only add capitalized words (skip lowercase words)
            j = i - 1
            while j >= 0 and words[j][0].isupper():
                full_name.insert(0, words[j])
                j -= 1
                if DEBUG:
                    print("full_name:", full_name)

            # Return only the full name up to the partial name (ignore words after the name)
            return " ".join(full_name).strip()  # Join the words to form the full name

    # If not found, return the original text (as a fallback)
    return text.strip()


def repair_names_in_entities(entities):
    """
    This function repairs the names in the entities by extracting the full name
    from the text of the entity if a partial name (e.g., 'Washington') is incorrectly attached.
    """
    for entity in entities:
        if "name" in entity and "pers" in entity["type"]:
            name = entity["name"]
            text = entity["surface"]

            # Check if the attached name is part of the entity's text
            if name in text:
                # Extract the full name from the text by splitting around the attached name
                full_name = extract_name_from_text(entity["surface"], name)
                entity["name"] = (
                    full_name  # Replace the partial name with the full name
                )
        # if "name" not in entity:
        #     entity["name"] = entity["surface"]

    return entities


def clean_coarse_entities(entities):
    """
    This function removes entities that are not useful for the NEL process.
    """
    # Define a set of entity types that are considered useful for NEL
    useful_types = {
        "pers",  # Person
        "loc",  # Location
        "org",  # Organization
        "date",  # Product
        "time",  # Time
    }

    # Filter out entities that are not in the useful_types set unless they are comp.* entities
    cleaned_entities = [
        entity
        for entity in entities
        if entity["type"] in useful_types or "comp" in entity["type"]
    ]

    return cleaned_entities


def postprocess_entities(entities):
    # Step 1: Filter entities with the same surface text, keeping the one whose
    # 'type' field has the most dots (i.e. the most fine-grained type)
    entity_map = {}

    # Loop over the entities and prioritize the one with the most dots
    for entity in entities:
        entity_text = entity["surface"]
        num_dots = entity["type"].count(".")

        # If the entity text is new, or this entity has more dots, update the map
        if (
            entity_text not in entity_map
            or entity_map[entity_text]["type"].count(".") < num_dots
        ):
            entity_map[entity_text] = entity

    # Collect the filtered entities from the map
    filtered_entities = list(entity_map.values())

    # Step 2: Attach "comp.function" entities to the closest other entities
    filtered_entities = attach_comp_to_closest(filtered_entities)
    if DEBUG:
        print("After attach_comp_to_closest:", filtered_entities, "\n")
    filtered_entities = repair_names_in_entities(filtered_entities)
    if DEBUG:
        print("After repair_names_in_entities:", filtered_entities, "\n")

    # Step 3: Remove entities that are not useful for NEL
    # filtered_entities = clean_coarse_entities(filtered_entities)

    # filtered_entities = remove_blacklisted_entities(filtered_entities)

    return filtered_entities


def remove_included_entities(entities):
    # Loop through entities and remove those whose text is included in another with the same label
    final_entities = []
    for i, entity in enumerate(entities):
        is_included = False
        for other_entity in entities:
            if entity["surface"] != other_entity["surface"]:
                if "comp" in other_entity["type"]:
                    # Check if entity's text is a substring of another entity's text
                    if entity["surface"] in other_entity["surface"]:
                        is_included = True
                        break
                elif (
                    entity["type"].split(".")[0] in other_entity["type"].split(".")[0]
                    or other_entity["type"].split(".")[0]
                    in entity["type"].split(".")[0]
                ):
                    if entity["surface"] in other_entity["surface"]:
                        is_included = True
        if not is_included:
            final_entities.append(entity)
    return final_entities
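
# Illustrative example for remove_included_entities(): if both
# {"surface": "Einstein", "type": "pers"} and
# {"surface": "Albert Einstein", "type": "pers.ind"} are present, the shorter
# surface is contained in the longer one with a matching coarse type, so only
# "Albert Einstein" is kept.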


def refine_entities_with_coarse(all_entities, coarse_entities):
    """
    Looks through all entities and refines them based on the coarse entities.
    If a surface match is found in the coarse entities and the types match,
    the entity's confidence_ner and type are updated based on the coarse entity.
    """
    # Create a dictionary for coarse entities based on surface and type for quick lookup
    coarse_lookup = {}
    for coarse_entity in coarse_entities:
        key = (coarse_entity["surface"], coarse_entity["type"].split(".")[0])
        coarse_lookup[key] = coarse_entity

    # Iterate through all entities and compare with the coarse entities
    for entity in all_entities:
        key = (
            entity["surface"],
            entity["type"].split(".")[0],
        )  # Use the coarse type for comparison

        if key in coarse_lookup:
            coarse_entity = coarse_lookup[key]
            # If a match is found, update the confidence_ner and type in the entity
            if entity["confidence_ner"] < coarse_entity["confidence_ner"]:
                entity["confidence_ner"] = coarse_entity["confidence_ner"]
                entity["type"] = coarse_entity[
                    "type"
                ]  # Update the type if the confidence is higher

    # Entities are updated in place; finally collapse every type to its coarse form
    for entity in all_entities:
        entity["type"] = entity["type"].split(".")[0]
    return all_entities


def remove_trailing_stopwords(entities):
    """
    This function removes stopwords and punctuation from both the beginning and end of each entity's text
    and repairs the lOffset and rOffset accordingly.
    """
    if DEBUG:
        print(f"Initial entities: {len(entities)}")
    new_entities = []
    # Iterate over a snapshot so that entities.remove() below cannot skip items
    for entity in list(entities):
        if "comp" not in entity["type"]:
            entity_text = entity["surface"]
            original_len = len(entity_text)

            # Initial offsets
            lOffset = entity.get("lOffset", 0)
            rOffset = entity.get("rOffset", original_len)

            # Remove stopwords and punctuation from the beginning
            i = 0
            while entity_text.strip() and (
                entity_text.split()[0].lower() in stop_words
                or entity_text[0] in punctuation
            ):
                if entity_text.split()[0].lower() in stop_words:
                    stopword_len = (
                        len(entity_text.split()[0]) + 1
                    )  # Adjust length for stopword and following space
                    entity_text = entity_text[stopword_len:]  # Remove leading stopword
                    lOffset += stopword_len  # Adjust the left offset
                    if DEBUG:
                        print(
                            f"Removed leading stopword from entity: {entity['surface']} --> {entity_text} ({entity['type']}"
                        )
                elif entity_text[0] in punctuation:
                    entity_text = entity_text[1:]  # Remove leading punctuation
                    lOffset += 1  # Adjust the left offset
                    if DEBUG:
                        print(
                            f"Removed leading punctuation from entity: {entity['surface']} --> {entity_text} ({entity['type']}"
                        )
                i += 1

            i = 0
            # Remove stopwords and punctuation from the end
            iteration = 0
            max_iterations = len(entity_text)  # Prevent infinite loops

            while entity_text and iteration < max_iterations:
                # Check if the last word is a stopword or the last character is punctuation
                last_word = entity_text.split()[-1] if entity_text.split() else ""
                last_char = entity_text[-1]

                if last_word.lower() in stop_words:
                    # Remove trailing stopword and adjust rOffset
                    stopword_len = len(last_word) + 1  # Include space before stopword
                    entity_text = entity_text[:-stopword_len].rstrip()
                    rOffset -= stopword_len
                    if DEBUG:
                        print(
                            f"Removed trailing stopword from entity: {entity_text} (rOffset={rOffset})"
                        )

                elif last_char in punctuation:
                    # Remove trailing punctuation and adjust rOffset
                    entity_text = entity_text[:-1].rstrip()
                    rOffset -= 1
                    if DEBUG:
                        print(
                            f"Removed trailing punctuation from entity: {entity_text} (rOffset={rOffset})"
                        )
                else:
                    # Exit loop if neither stopwords nor punctuation are found
                    break

                iteration += 1
                # print(f"ITERATION: {iteration} [{entity['surface']}] for {entity_text}")

            if len(entity_text.strip()) == 1:
                entities.remove(entity)
                if DEBUG:
                    print(f"Skipping entity: {entity_text}")
                continue
            # Skip certain entities based on rules
            if entity_text in string.punctuation:
                if DEBUG:
                    print(f"Skipping entity: {entity_text}")
                entities.remove(entity)
                continue
            # check now if its in stopwords
            if entity_text.lower() in stop_words:
                if DEBUG:
                    print(f"Skipping entity: {entity_text}")
                entities.remove(entity)
                continue
            # check now if the entire entity is a list of stopwords:
            if all([word.lower() in stop_words for word in entity_text.split()]):
                if DEBUG:
                    print(f"Skipping entity: {entity_text}")
                entities.remove(entity)
                continue
            # Check if the entire entity is made up of stopwords characters
            if all(
                [char.lower() in stop_words for char in entity_text if char.isalpha()]
            ):
                if DEBUG:
                    print(
                        f"Skipping entity: {entity_text} (all characters are stopwords)"
                    )
                entities.remove(entity)
                continue
            # check now if all entity is in a list of punctuation
            if all([word in string.punctuation for word in entity_text.split()]):
                if DEBUG:
                    print(
                        f"Skipping entity: {entity_text} (all characters are punctuation)"
                    )
                entities.remove(entity)
                continue
            if all(
                [
                    char.lower() in string.punctuation
                    for char in entity_text
                    if char.isalpha()
                ]
            ):
                if DEBUG:
                    print(
                        f"Skipping entity: {entity_text} (all characters are punctuation)"
                    )
                entities.remove(entity)
                continue

            # Drop purely numeric surfaces unless the entity type is time-related
            if entity_text.isdigit() and "time" not in entity["type"]:
                if DEBUG:
                    print(f"Skipping entity: {entity_text}")
                entities.remove(entity)
                continue

            if entity_text.startswith(" "):
                entity_text = entity_text[1:]
                # update lOffset, rOffset
                lOffset += 1
            if entity_text.endswith(" "):
                entity_text = entity_text[:-1]
                # update lOffset, rOffset
                rOffset -= 1

            # Update the entity surface and offsets
            entity["surface"] = entity_text
            entity["lOffset"] = lOffset
            entity["rOffset"] = rOffset

            # Remove the entity if the surface is empty after cleaning
            if len(entity["surface"].strip()) == 0:
                if DEBUG:
                    print(f"Deleted entity: {entity['surface']}")
                entities.remove(entity)
            else:
                new_entities.append(entity)
        else:
            # Keep comp.* entities as-is so that attach_comp_to_closest() can
            # still attach them to their closest entity later in the pipeline
            new_entities.append(entity)

    if DEBUG:
        print(f"Remained entities: {len(new_entities)}")
    return new_entities
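
# Illustrative example for remove_trailing_stopwords(): an entity with surface
# "the Geneva," (lOffset=10, rOffset=21) loses the leading stopword and the
# trailing comma, becoming surface "Geneva" with lOffset=14 and rOffset=20.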

class MultitaskTokenClassificationPipeline(Pipeline):

    def _sanitize_parameters(self, **kwargs):
        preprocess_kwargs = {}
        if "text" in kwargs:
            preprocess_kwargs["text"] = kwargs["text"]
        self.label_map = self.model.config.label_map
        self.id2label = {
            task: {id_: label for label, id_ in labels.items()}
            for task, labels in self.label_map.items()
        }
        return preprocess_kwargs, {}, {}

    def preprocess(self, text, **kwargs):

        tokenized_inputs = self.tokenizer(
            text, padding="max_length", truncation=True, max_length=512
        )

        text_sentence = tokenize(add_spaces_around_punctuation(text))
        return tokenized_inputs, text_sentence, text

    def _forward(self, inputs):
        inputs, text_sentences, text = inputs
        input_ids = torch.tensor([inputs["input_ids"]], dtype=torch.long).to(
            self.model.device
        )
        attention_mask = torch.tensor([inputs["attention_mask"]], dtype=torch.long).to(
            self.model.device
        )
        with torch.no_grad():
            outputs = self.model(input_ids, attention_mask)
        return outputs, text_sentences, text

    def is_within(self, entity1, entity2):
        """Check if entity1 is fully within the bounds of entity2."""
        return (
            entity1["lOffset"] >= entity2["lOffset"]
            and entity1["rOffset"] <= entity2["rOffset"]
        )

    def postprocess(self, outputs, **kwargs):
        """Turn the raw model outputs into a cleaned, deduplicated list of entity dicts.

        :param outputs: tuple of (model outputs, tokenized sentence, original text) from _forward
        :param kwargs: unused
        :return: list of entity dictionaries with type, confidence score and character offsets
        """
        tokens_result, text_sentence, text = outputs

        predictions = {}
        confidence_scores = {}
        for task, logits in tokens_result.logits.items():
            predictions[task] = torch.argmax(logits, dim=-1).tolist()[0]
            confidence_scores[task] = F.softmax(logits, dim=-1).tolist()[0]

        entities = {}
        for task in predictions.keys():
            words_list, preds_list, confidence_list = realign(
                text_sentence,
                predictions[task],
                confidence_scores[task],
                self.tokenizer,
                self.id2label[task],
            )

            entities[task] = get_entities(words_list, preds_list, confidence_list, text)

        # add titles to comp entities
        # from pprint import pprint

        # print("Before:")
        # pprint(entities)

        all_entities = []
        coarse_entities = []
        for key in entities:
            if key in ["NE-COARSE-LIT"]:
                coarse_entities = entities[key]
            all_entities.extend(entities[key])

        if DEBUG:
            print(all_entities)
        # print("After remove_included_entities:")
        all_entities = remove_included_entities(all_entities)
        if DEBUG:
            print("After remove_included_entities:", all_entities)
        all_entities = remove_trailing_stopwords(all_entities)
        if DEBUG:
            print("After remove_trailing_stopwords:", all_entities)
        all_entities = postprocess_entities(all_entities)
        if DEBUG:
            print("After postprocess_entities:", all_entities)
        all_entities = refine_entities_with_coarse(all_entities, coarse_entities)
        if DEBUG:
            print("After refine_entities_with_coarse:", all_entities)
        # print("After attach_comp_to_closest:")
        # pprint(all_entities)
        # print("\n")
        return all_entities
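

if __name__ == "__main__":
    # Illustrative usage sketch, not part of the original module. The checkpoint
    # name below is a placeholder: it is assumed to point at a model whose config
    # defines the multitask `label_map` read in _sanitize_parameters(), whose
    # forward pass returns one logits tensor per task, and whose custom code is
    # loadable via trust_remote_code. Adjust the Auto class to however the
    # checkpoint actually registers its model.
    from transformers import AutoModel, AutoTokenizer

    checkpoint = "your-org/your-multitask-ner-checkpoint"  # placeholder name
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = AutoModel.from_pretrained(checkpoint, trust_remote_code=True)

    nlp = MultitaskTokenClassificationPipeline(model=model, tokenizer=tokenizer)
    for entity in nlp("Albert Einstein worked at the patent office in Bern."):
        print(entity)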