|
from transformers import ( |
|
AutoModelForTokenClassification, |
|
AutoTokenizer, |
|
TokenClassificationPipeline, |
|
) |
|
from transformers.pipelines import AggregationStrategy |
|
import numpy as np |
|
|
|
|
|
class KeyphraseExtractionPipeline(TokenClassificationPipeline):
    """Token-classification pipeline that returns de-duplicated keyphrases.

    The constructor loads both the model and the tokenizer from a single
    ``model`` argument, and ``postprocess`` reduces the aggregated entities
    produced by the parent pipeline to the unique, whitespace-stripped
    ``"word"`` strings.
    """

    def __init__(self, model, *args, **kwargs):
        """Load the model and tokenizer for ``model`` and build the pipeline.

        Args:
            model: Value handed to both
                ``AutoModelForTokenClassification.from_pretrained`` and
                ``AutoTokenizer.from_pretrained`` (presumably a model name
                or a local path -- confirm against callers).
            *args: Extra positional arguments forwarded unchanged to the
                parent constructor.
            **kwargs: Extra keyword arguments forwarded unchanged to the
                parent constructor.
        """
        super().__init__(
            *args,
            model=AutoModelForTokenClassification.from_pretrained(model),
            tokenizer=AutoTokenizer.from_pretrained(model),
            **kwargs,
        )

    def postprocess(self, model_outputs):
        """Turn raw model outputs into an array of unique keyphrases.

        Args:
            model_outputs: Outputs handed over by the pipeline machinery;
                forwarded untouched to the parent ``postprocess``.

        Returns:
            The result of ``np.unique`` over the stripped ``"word"`` value of
            every aggregated entity, i.e. a sorted NumPy array without
            duplicates.
        """
        # SIMPLE aggregation is used for "roberta" model types, FIRST for
        # every other model type.
        if self.model.config.model_type == "roberta":
            strategy = AggregationStrategy.SIMPLE
        else:
            strategy = AggregationStrategy.FIRST

        # The outputs are passed positionally on purpose: the name of the
        # parent's first parameter differs between transformers releases
        # (``model_outputs`` vs ``all_outputs``), so a keyword call raises
        # TypeError on the releases that use the other name.
        entities = super().postprocess(
            model_outputs,
            aggregation_strategy=strategy,
        )
        return np.unique([entity["word"].strip() for entity in entities])
|
|