File size: 2,607 Bytes
173d81c
f0c6f33
65c7bf6
 
 
 
173d81c
 
 
 
 
 
 
65c7bf6
 
173d81c
 
9a43936
 
 
 
 
 
 
 
 
 
 
65c7bf6
 
 
9a43936
 
 
 
 
 
 
 
65c7bf6
9a43936
 
 
65c7bf6
9a43936
 
65c7bf6
9a43936
 
 
 
 
 
 
 
 
 
65c7bf6
9a43936
 
 
 
65c7bf6
 
9a43936
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
from transformers import BertForTokenClassification, BertTokenizer, AutoConfig
import torch
from typing import Dict, List, Any

class EndpointHandler:
    """Inference endpoint that underlines model-selected tokens in input text.

    Runs a ``BertForTokenClassification`` model over the input in chunks of at
    most ``MAX_CHUNK_TOKENS`` WordPiece tokens (plus [CLS]/[SEP]), then returns
    the text rebuilt from the WordPiece tokens with every contiguous run of
    tokens predicted as ``LINK_LABEL`` wrapped in ``<u>...</u>``.
    """

    # Label id whose tokens get wrapped in <u>...</u>.
    # NOTE(review): presumably the "link" class of the model; confirm against
    # ``self.config.id2label``.
    LINK_LABEL = 1

    # 510 content tokens + [CLS] + [SEP]; assumes a 512-position BERT model.
    MAX_CHUNK_TOKENS = 510

    def __init__(
        self,
        path: str = "dejanseo/LinkBERT",
        tokenizer_path: str = "bert-large-cased",
    ):
        """Load the model (with its saved config) and the tokenizer.

        Args:
            path: Hub id or local directory of the token-classification model.
            tokenizer_path: Hub id or directory of the WordPiece tokenizer. It
                should match the vocabulary the model was trained with; the
                default keeps the previously hard-coded ``"bert-large-cased"``.
        """
        # Load the configuration from the saved model.
        self.config = AutoConfig.from_pretrained(path)
        self.model = BertForTokenClassification.from_pretrained(
            path,
            config=self.config,
        )
        self.model.eval()  # Inference mode: disables dropout.
        self.tokenizer = BertTokenizer.from_pretrained(tokenizer_path)

    @staticmethod
    def _chunk_tokens(tokens: List[str], max_length: int) -> List[List[str]]:
        """Split WordPiece ``tokens`` into consecutive lists of <= ``max_length``.

        Each boundary is moved back so the next chunk does not start on a
        ``##`` continuation piece, which keeps words intact. A single word
        longer than ``max_length`` pieces is still split, since there is no
        other option.

        Raises:
            ValueError: if ``max_length`` is less than 1.
        """
        if max_length < 1:
            raise ValueError(f"max_length must be >= 1, got {max_length}")
        chunks: List[List[str]] = []
        start, total = 0, len(tokens)
        while start < total:
            end = min(start + max_length, total)
            if end < total:
                cut = end
                # Walk back to the first piece of the word straddling the cut.
                while cut > start and tokens[cut].startswith("##"):
                    cut -= 1
                if cut > start:
                    end = cut
            chunks.append(tokens[start:end])
            start = end
        return chunks

    @staticmethod
    def _render_tokens(
        tokens: List[str], predictions: List[int], link_label: int = 1
    ) -> str:
        """Rebuild text from WordPiece ``tokens``, wrapping link runs in ``<u>``.

        Whole-word tokens are separated by a single space, and ``##`` pieces
        are glued to the previous token without their ``##`` prefix. Each
        maximal run of consecutive tokens whose prediction equals
        ``link_label`` is wrapped once in ``<u>...</u>``, and every token
        appears exactly once in the output.
        """
        parts: List[str] = []
        in_link = False
        for index, (token, pred) in enumerate(zip(tokens, predictions)):
            is_piece = token.startswith("##")
            text = token[2:] if is_piece else token
            is_link = pred == link_label
            if in_link and not is_link:
                parts.append("</u>")  # Close before the separating space.
            if index > 0 and not is_piece:
                parts.append(" ")
            if is_link and not in_link:
                parts.append("<u>")  # Open after the separating space.
            parts.append(text)
            in_link = is_link
        if in_link:
            parts.append("</u>")
        return "".join(parts)

    def split_into_chunks(self, text: str, max_length: int = 510) -> List[str]:
        """Split ``text`` into strings of at most ``max_length`` WordPiece tokens.

        Chunks start on word boundaries (see ``_chunk_tokens``). The default of
        510 leaves room for [CLS] and [SEP] in a 512-position window.
        """
        tokens = self.tokenizer.tokenize(text)
        return [
            self.tokenizer.convert_tokens_to_string(chunk)
            for chunk in self._chunk_tokens(tokens, max_length)
        ]

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Annotate ``data["inputs"]`` and return ``[{"text": annotated_text}]``.

        The text is tokenized once, and each chunk is fed to the model as token
        ids. This avoids a detokenize/re-tokenize round trip, so every chunk is
        guaranteed to fit the window. Predictions from all chunks are rendered
        in a single pass, so words and link spans are not split at chunk
        boundaries.
        """
        text = data.get("inputs", "")
        if not text:
            # Covers both "" and an explicit None.
            return [{"text": ""}]

        all_tokens: List[str] = []
        all_predictions: List[int] = []
        for chunk in self._chunk_tokens(
            self.tokenizer.tokenize(text), self.MAX_CHUNK_TOKENS
        ):
            ids = self.tokenizer.build_inputs_with_special_tokens(
                self.tokenizer.convert_tokens_to_ids(chunk)
            )
            input_ids = torch.tensor(
                [ids], dtype=torch.long, device=self.model.device
            )
            with torch.no_grad():
                logits = self.model(input_ids).logits
            # Drop the [CLS]/[SEP] positions so predictions align with `chunk`.
            all_predictions.extend(logits.argmax(dim=-1)[0, 1:-1].tolist())
            all_tokens.extend(chunk)

        final_text = self._render_tokens(
            all_tokens, all_predictions, self.LINK_LABEL
        )
        # Return the processed text in a structured format.
        return [{"text": final_text}]