MayZhou's picture
Update handler.py
d620c51 verified
raw
history blame contribute delete
809 Bytes
import torch
import transformers
from transformers import AutoModelForSequenceClassification, AutoTokenizer
class EndpointHandler:
    """Inference handler that labels text as human-written or AI-generated.

    Follows the ``EndpointHandler(path)`` / ``handler(data)`` shape used by
    Hugging Face custom inference handlers (presumably this is how it is
    loaded — TODO confirm against the deployment).
    """

    # Raw class names emitted by the model -> labels returned to clients.
    # Labels missing from this table are passed through unchanged instead of
    # being silently reported as one of the two classes.
    LABEL_MAP = {
        "LABEL_0": "human-written",
        "LABEL_1": "AI-generated",
    }
    # Maximum tokenized length fed to the model; longer inputs are truncated.
    MAX_LENGTH = 512

    def __init__(self, path=""):
        """Load tokenizer and model from ``path`` and build the pipeline.

        Args:
            path: Directory or model id passed to ``from_pretrained``.
        """
        tokenizer = AutoTokenizer.from_pretrained(path)
        model = AutoModelForSequenceClassification.from_pretrained(path)
        # Inference only; from_pretrained already does this, kept explicit.
        model.eval()
        self.pipeline = transformers.pipeline(
            "text-classification", model=model, tokenizer=tokenizer
        )

    def __call__(self, data):
        """Classify the request's text(s).

        Args:
            data: Request payload (a dict). Its ``"inputs"`` value is
                classified; if that key is absent, the payload itself is
                handed to the pipeline. ``data`` is not modified.

        Returns:
            The pipeline output, with every ``"label"`` rewritten through
            ``LABEL_MAP``. Both a flat list of prediction dicts and a list
            of such lists are handled.
        """
        # .get rather than .pop: same value, but the caller's dict is not
        # mutated.
        inputs = data.get("inputs", data)
        result = self.pipeline(
            inputs,
            truncation=True,
            padding=False,
            max_length=self.MAX_LENGTH,
        )
        self._relabel(result)
        return result

    @classmethod
    def _relabel(cls, node):
        """Rewrite ``"label"`` values in place, descending into lists."""
        if isinstance(node, dict):
            if "label" in node:
                raw = node["label"]
                node["label"] = cls.LABEL_MAP.get(raw, raw)
        elif isinstance(node, list):
            for child in node:
                cls._relabel(child)