|
import torch |
|
import transformers |
|
from transformers import AutoModelForSequenceClassification, AutoTokenizer |
|
|
|
|
|
|
|
class EndpointHandler:
    """Custom inference handler for a text-classification model.

    Loads a sequence-classification model and its tokenizer from ``path``,
    wraps them in a ``transformers`` "text-classification" pipeline, and
    renames the model's generic class labels to human-readable names.
    """

    # Generic labels produced by the model -> names returned to the client.
    # Labels not listed here are passed through unchanged instead of being
    # silently reported as one of the two classes.
    LABEL_NAMES = {
        "LABEL_0": "human-written",
        "LABEL_1": "AI-generated",
    }

    def __init__(self, path="", max_length=512):
        """Load the model and tokenizer and build the pipeline.

        Args:
            path: Model directory or identifier passed to ``from_pretrained``.
            max_length: Maximum number of tokens per input; longer inputs are
                truncated to this length.
        """
        tokenizer = AutoTokenizer.from_pretrained(path)
        model = AutoModelForSequenceClassification.from_pretrained(path)
        # Inference only: make sure training-time behaviour (e.g. dropout) is off.
        model.eval()
        self.max_length = max_length
        self.pipeline = transformers.pipeline(
            "text-classification", model=model, tokenizer=tokenizer
        )

    def __call__(self, data):
        """Classify the request payload.

        Args:
            data: Request payload dict. If it has an ``"inputs"`` key, that
                value is classified (presumably a string or list of strings;
                confirm with callers). Otherwise the payload itself is handed
                to the pipeline.

        Returns:
            The pipeline output, with every ``"label"`` renamed according to
            ``LABEL_NAMES``. The shape (dict, list, or nested list) follows
            whatever the pipeline returned.
        """
        # ``get`` rather than ``pop`` so the caller's payload is not mutated.
        inputs = data.get("inputs", data)
        result = self.pipeline(
            inputs,
            truncation=True,
            padding=False,
            max_length=self.max_length,
        )
        return self._relabel(result)

    @classmethod
    def _relabel(cls, result):
        """Rename labels in place in a pipeline result and return it.

        Accepts a single dict, a list of dicts, or nested lists of dicts,
        because the pipeline's output shape depends on the input type and
        options (e.g. a text/text_pair dict input yields a bare dict).
        """
        if isinstance(result, dict):
            if "label" in result:
                label = result["label"]
                result["label"] = cls.LABEL_NAMES.get(label, label)
        elif isinstance(result, list):
            for item in result:
                cls._relabel(item)
        return result