|
```python |
|
|
|
from transformers import BertTokenizer, BertForSequenceClassification |
|
import torch |
|
import numpy as np |
|
import json |
|
|
|
class Prehibition:
    """Text classifier wrapping a fine-tuned BERT sequence-classification model.

    The tokenizer and model weights are loaded with ``from_pretrained`` from
    ``model_name`` (by default the ``wyluilipe/prehibiton-themes-clf``
    checkpoint).
    """

    def __init__(self, model_name='wyluilipe/prehibiton-themes-clf', max_length=512):
        """Load the tokenizer and the classification model.

        Args:
            model_name: Model id or local path passed to ``from_pretrained``
                for both the tokenizer and the model.
            max_length: Number of tokens each input is truncated / padded to
                in :meth:`predict`. Defaults to 512.
        """
        self.max_length = max_length
        self.tokenizer = BertTokenizer.from_pretrained(model_name)
        self.model = BertForSequenceClassification.from_pretrained(model_name)
        # Make inference mode explicit so dropout is disabled for predictions.
        self.model.eval()

    def predict(self, text):
        """Classify a single text and return the predicted class index.

        Args:
            text: Input string. It is tokenized, truncated and padded to
                ``self.max_length`` tokens.

        Returns:
            int: Index of the highest-scoring logit. Presumably this maps to
            a human-readable label via ``self.model.config.id2label`` --
            NOTE(review): confirm against the checkpoint.
        """
        encoded = self.tokenizer(
            [text],
            max_length=self.max_length,
            # Replacement for the deprecated ``pad_to_max_length=True``.
            padding='max_length',
            truncation=True,
            return_token_type_ids=False,
            return_tensors='pt',
        )

        # Send inputs to wherever the model lives so this keeps working
        # after the model has been moved to a GPU.
        device = self.model.device
        input_ids = encoded['input_ids'].to(device)
        attention_mask = encoded['attention_mask'].to(device)

        with torch.no_grad():
            logits = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
            ).logits

        # Use torch's argmax over the class dimension rather than np.argmax,
        # which needs a CPU tensor and flattens the batch. The batch holds a
        # single text, so .item() yields one Python int.
        return logits.argmax(dim=-1).item()
|
|
|
``` |