File size: 872 Bytes
46f724b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
```python

from transformers import BertTokenizer, BertForSequenceClassification
import torch
import numpy as np
import json

class Prehibition:
    """Single-text classifier backed by a BERT sequence-classification model.

    The tokenizer and model are loaded once in ``__init__`` via
    ``from_pretrained``; :meth:`predict` then scores one text at a time and
    returns the index of the highest-scoring class.
    """

    # Identifier passed to ``from_pretrained`` when no model name is given;
    # kept verbatim (including its spelling) from the original hard-coded value.
    DEFAULT_MODEL_NAME = 'wyluilipe/prehibiton-themes-clf'

    def __init__(self, model_name=DEFAULT_MODEL_NAME, max_length=512):
        """Load the tokenizer and the classification model.

        Args:
            model_name: Hub identifier or local directory understood by
                ``from_pretrained``. Defaults to ``DEFAULT_MODEL_NAME``.
            max_length: Number of tokens every input is truncated/padded to.
                Defaults to 512, the value previously hard-coded.
        """
        self.max_length = max_length
        self.tokenizer = BertTokenizer.from_pretrained(model_name)
        self.model = BertForSequenceClassification.from_pretrained(model_name)
        # from_pretrained already returns an eval-mode model; calling eval()
        # explicitly makes the inference-only (dropout disabled) intent visible.
        self.model.eval()

    def predict(self, text):
        """Classify ``text`` and return the index of the top-scoring class.

        Args:
            text: The input string to classify. Longer inputs are truncated
                to ``self.max_length`` tokens.

        Returns:
            The argmax over the model's output logits as a Python ``int``.
            Mapping this index to a human-readable label presumably uses
            ``self.model.config.id2label`` — confirm against the checkpoint.
        """
        encoded = self.tokenizer(
            text,
            max_length=self.max_length,
            # Replaces the deprecated ``pad_to_max_length=True``.
            padding='max_length',
            truncation=True,
            return_token_type_ids=False,
            return_tensors='pt',
        )
        # Put inputs on whatever device the model lives on so predict keeps
        # working after ``self.model.to(...)``.
        device = self.model.device
        input_ids = encoded['input_ids'].to(device)
        attention_mask = encoded['attention_mask'].to(device)

        with torch.no_grad():
            output = self.model(input_ids=input_ids, attention_mask=attention_mask)

        # A single text is encoded as a batch of one, so logits[0] holds that
        # text's per-class scores; argmax stays in torch (works on GPU too).
        logits = output['logits']
        return logits[0].argmax(dim=-1).item()

```