|
from transformers import Pipeline |
|
import torch |
|
|
|
class TBCP(Pipeline):
    """Text-classification pipeline.

    By default returns a dict with the best label, its softmax probability,
    and the raw logits.  When called with ``top_k=k`` it instead returns a
    list of the ``k`` highest-scoring ``{"label", "score"}`` dicts.
    """

    def _sanitize_parameters(self, **kwargs):
        """Route supported call-time kwargs to the three pipeline stages.

        Returns the (preprocess, forward, postprocess) kwarg dicts expected
        by the ``Pipeline`` base class.
        """
        postprocess_kwargs = {}
        # BUG FIX: the original tested `"text_pair" in kwargs` but then read
        # kwargs["top_k"] — a KeyError if text_pair was passed alone, and a
        # silently ignored top_k otherwise. Test for the key actually used.
        if "top_k" in kwargs:
            postprocess_kwargs["top_k"] = kwargs["top_k"]
        return {}, {}, postprocess_kwargs

    def preprocess(self, text):
        """Tokenize raw text into PyTorch tensors for the model."""
        return self.tokenizer(text, return_tensors="pt")

    def _forward(self, model_inputs):
        """Run the model on the tokenized inputs.

        No explicit ``torch.no_grad()`` here — the ``Pipeline`` base class
        wraps ``_forward`` in an inference context.
        """
        return self.model(**model_inputs)

    def postprocess(self, model_outputs, top_k=None):
        """Convert model logits into a labeled prediction.

        Args:
            model_outputs: model output object carrying a ``.logits`` tensor.
                NOTE(review): the ``squeeze()`` calls assume a single-example
                batch — confirm callers never batch here.
            top_k: if ``None`` (default), return
                ``{"label", "score", "logits"}`` for the best class; if an
                int, return a list of the ``top_k`` best
                ``{"label", "score"}`` dicts, highest score first.
        """
        logits = model_outputs.logits
        probabilities = torch.nn.functional.softmax(logits, dim=-1)
        probs = probabilities.squeeze()

        if top_k is not None:
            # Previously accepted but ignored; now returns the k best
            # classes. Clamp k so oversized requests don't raise.
            k = min(int(top_k), probs.shape[-1])
            top_probs, top_ids = probs.topk(k)
            return [
                {
                    "label": self.model.config.id2label[idx.item()],
                    "score": prob.item(),
                }
                for prob, idx in zip(top_probs, top_ids)
            ]

        best_class = probabilities.argmax().item()
        label = self.model.config.id2label[best_class]
        score = probs[best_class].item()
        return {"label": label, "score": score, "logits": logits.squeeze().tolist()}