import fasttext
from transformers import pipeline
class FastTextClassifierPipeline:
    """Callable wrapper around a fastText supervised model.

    Calling an instance returns a list of ``{"label": ..., "score": ...}``
    dicts, one per input text, holding the top-1 prediction.

    The class deliberately does not inherit from ``transformers.pipeline``:
    that name is a factory function, not a class, so using it as a base
    raises ``TypeError`` when the class statement is executed.
    """

    def __init__(self, model_path: str):
        """Load the fastText model stored at ``model_path``."""
        self.model = fasttext.load_model(model_path)

    def __call__(self, texts):
        """Classify one text or an iterable of texts.

        Args:
            texts: A single string, or an iterable of strings.

        Returns:
            A list with one ``{"label": str, "score": float}`` dict per
            input text, in input order. An empty iterable yields ``[]``.
        """
        if isinstance(texts, str):
            texts = [texts]

        results = []
        for text in texts:
            # fastText predicts one line at a time and raises ValueError on
            # embedded newlines, so flatten multi-line input first.
            single_line = text.replace("\r", " ").replace("\n", " ")
            labels, scores = self.model.predict(single_line)
            results.append(
                {
                    "label": labels[0].replace("__label__", ""),
                    # Convert a possible NumPy scalar to a plain float so the
                    # result is JSON-serialisable.
                    "score": float(scores[0]),
                }
            )
        return results
# Module-level classifier instance; the model file is loaded at import time.
classifier = FastTextClassifierPipeline("fasttext_news_classifier.bin")