File size: 809 Bytes
2424f5b d620c51 2424f5b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 |
import torch
import transformers
from transformers import AutoModelForSequenceClassification, AutoTokenizer
class EndpointHandler:
    """Callable wrapper around a text-classification pipeline.

    The model and tokenizer are loaded once in ``__init__``; each call runs
    the pipeline on the request payload and rewrites the model's raw labels
    into readable ones (``LABEL_0`` -> ``human-written``, anything else ->
    ``AI-generated``).

    NOTE(review): the ``__init__(path)`` / ``__call__(data)`` shape looks like
    the Hugging Face Inference Endpoints custom-handler convention -- confirm
    against the deployment.
    """

    # Raw label emitted by the model for the "human" class; every other raw
    # label is reported as AI-generated.
    HUMAN_RAW_LABEL = "LABEL_0"
    HUMAN_LABEL = "human-written"
    AI_LABEL = "AI-generated"

    # Tokenizer cap passed to the pipeline on every call.
    MAX_LENGTH = 512

    def __init__(self, path=""):
        """Load the tokenizer and model from *path* and build the pipeline.

        Args:
            path: Directory or identifier handed to ``from_pretrained`` for
                both the tokenizer and the sequence-classification model.
        """
        tokenizer = AutoTokenizer.from_pretrained(path)
        model = AutoModelForSequenceClassification.from_pretrained(path)
        model.eval()
        self.pipeline = transformers.pipeline(
            "text-classification", model=model, tokenizer=tokenizer
        )

    def __call__(self, data):
        """Classify the request payload and return relabelled predictions.

        Args:
            data: Request mapping. The value under ``"inputs"`` is passed to
                the pipeline; when that key is absent the mapping itself is
                passed instead. The mapping is not modified.

        Returns:
            Whatever the pipeline returned (a list of prediction dicts, or a
            single prediction dict), with each ``"label"`` value replaced by
            ``"human-written"`` or ``"AI-generated"``.
        """
        # ``get`` rather than ``pop`` so the caller's payload is left intact.
        inputs = data.get("inputs", data)
        result = self.pipeline(
            inputs, truncation=True, padding=False, max_length=self.MAX_LENGTH
        )

        # A single dict input yields a single dict instead of a list;
        # iterating it directly would walk its keys and fail on
        # ``item['label']``.
        predictions = [result] if isinstance(result, dict) else result
        for item in predictions:
            if item["label"] == self.HUMAN_RAW_LABEL:
                item["label"] = self.HUMAN_LABEL
            else:
                item["label"] = self.AI_LABEL
        return result