data-silence
Add FastText model and inference code
851dcd9
raw
history blame
636 Bytes
import fasttext
from transformers import pipeline
class FastTextClassifierPipeline:
    """Callable wrapper that classifies text with a fastText model.

    Instances are called like a function and return, for every input text,
    a ``{"label": ..., "score": ...}`` dict holding the first prediction
    reported by the model.

    NOTE: this class intentionally has no base class. ``transformers.pipeline``
    is a factory *function*, so listing it as a base raises ``TypeError`` as
    soon as the ``class`` statement is executed.
    """

    # Marker stripped from the front of the labels returned by the model.
    _LABEL_PREFIX = "__label__"

    def __init__(self, model_path):
        """Load the fastText model stored at *model_path*."""
        self.model = fasttext.load_model(model_path)

    def __call__(self, texts):
        """Classify a single text or an iterable of texts.

        Args:
            texts: One string, or an iterable of strings.

        Returns:
            A list with one ``{"label": str, "score": float}`` dict per
            input text, in input order. A single string yields a
            one-element list; an empty iterable yields ``[]``.
        """
        if isinstance(texts, str):
            texts = [texts]
        return [self._classify(text) for text in texts]

    def _classify(self, text):
        """Return the first predicted label and its score for one text."""
        # fastText's predict() handles one line at a time and raises
        # ValueError when the text contains a newline, so flatten it first.
        labels, scores = self.model.predict(text.replace("\n", " "))[:2]
        label = labels[0]
        # Strip the marker only where it is a prefix, so a label that happens
        # to contain the same substring further in is left intact.
        if label.startswith(self._LABEL_PREFIX):
            label = label[len(self._LABEL_PREFIX):]
        # Scores come back as NumPy scalars; convert to a built-in float so
        # the result can be serialised (e.g. with json.dumps).
        return {"label": label, "score": float(scores[0])}
# Shared module-level classifier, built once from the model file named below.
classifier = FastTextClassifierPipeline(
    model_path="fasttext_news_classifier.bin",
)