File size: 2,707 Bytes
fb4253c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 |
from huggingface_hub import Repository
from typing import List, Union
from transformers import pipeline
from transformers.pipelines import PIPELINE_REGISTRY
from transformers import AutoModelForSequenceClassification, Pipeline
import torch
# from loguru import logger
class MyPipeline(Pipeline):
    """Premise/hypothesis entailment pipeline for NLI sequence classifiers.

    Each input sequence (the premise) is tokenized together with an optional
    hypothesis, passed through the model, and the first row of logits is
    converted into a ``{label: percentage}`` dict.
    """

    # Output label for each logit index, in order. Kept at class level so a
    # subclass can override it for a model whose label order differs.
    label_names = ("entailment", "neutral", "contradiction")

    def _sanitize_parameters(self, **kwargs):
        """Route the ``hypothesis`` keyword to ``preprocess``.

        Returns:
            The (preprocess, forward, postprocess) kwargs triple expected by
            ``Pipeline``; only ``preprocess`` receives parameters here.
        """
        preprocess_kwargs = {}
        if "hypothesis" in kwargs:
            preprocess_kwargs["hypothesis"] = kwargs["hypothesis"]
        return preprocess_kwargs, {}, {}

    def __call__(
        self,
        sequences: Union[str, List[str]],
        *args,
        **kwargs,
    ):
        """Classify ``sequences`` (premises) against a hypothesis.

        The hypothesis may be supplied either as the single extra positional
        argument or via the ``hypothesis`` keyword, but not both.

        Raises:
            ValueError: if more than one extra positional argument is given,
                or a positional hypothesis is combined with ``hypothesis=``.
        """
        if args:
            if len(args) != 1 or "hypothesis" in kwargs:
                raise ValueError(f"Unable to understand extra arguments {args}")
            kwargs["hypothesis"] = args[0]
        return super().__call__(sequences, **kwargs)

    def preprocess(self, premise, hypothesis=None):
        """Tokenize ``premise``, paired with ``hypothesis`` when one is given.

        The complete tokenizer encoding is returned (not just ``input_ids``),
        so any attention mask or segment ids the tokenizer produces reach the
        model in ``_forward``. Pairs that are too long are truncated.
        """
        return self.tokenizer(
            premise,
            hypothesis,
            truncation=True,
            return_tensors="pt",
        )

    def _forward(self, model_inputs):
        """Run the model on the full encoding produced by ``preprocess``."""
        return self.model(**model_inputs)

    def postprocess(self, model_outputs):
        """Convert the first row of logits into ``{label: percent}``.

        Logits are softmax-normalised over the last axis, and each
        probability is reported as a percentage rounded to one decimal place.
        Labels come from ``label_names``, paired by logit index.
        """
        probabilities = torch.softmax(model_outputs["logits"][0], -1).tolist()
        return {
            name: round(float(prob) * 100, 1)
            for name, prob in zip(self.label_names, probabilities)
        }
# PIPELINE_REGISTRY.register_pipeline(
# "test",
# pipeline_class=MyPipeline,
# pt_model=AutoModelForSequenceClassification,
# # default={"pt": ("MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli", "retina")},
# # type="text",
# )
# classifier = pipeline("test",
# model="MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli",
# # tokenizer="MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli"
# )
# output = classifier(
# "Angela Merkel is a politician in Germany and leader of the CDU",
# hypothesis="this is a test"
# )
# # logger.info(output)
# # repo = Repository("entailment-classifier",
# # clone_from="Tverous/entailment-classifier")
# classifier.save_pretrained("entailment-classifier")
# # repo.push_to_hub()
# logger.info("Finished")
|