model_length
hive_token_classification.py +20 -0
hive_token_classification.py
ADDED
@@ -0,0 +1,20 @@
from typing import Any, Dict

from transformers import Pipeline, AutoModel, AutoTokenizer
from transformers.pipelines.base import GenericTensor, ModelOutput


class HiveTokenClassification(Pipeline):
    def _sanitize_parameters(self, **kwargs):
        # Split call-time kwargs into (preprocess, forward, postprocess) parameter dicts;
        # only the optional `output_style` kwarg is routed on to `_forward`.
        forward_parameters = {}
        if "output_style" in kwargs:
            forward_parameters["output_style"] = kwargs["output_style"]
        return {}, forward_parameters, {}

    def preprocess(self, input_: Any, **preprocess_parameters: Dict) -> Dict[str, GenericTensor]:
        # Pass the raw input through unchanged; tokenization is handled inside the model's predict().
        return input_

    def _forward(self, input_tensors: Dict[str, GenericTensor], **forward_parameters: Dict) -> ModelOutput:
        # Delegate inference to the model's custom predict(), handing it the tokenizer and forwarded kwargs.
        return self.model.predict(input_tensors, self.tokenizer, **forward_parameters)

    def postprocess(self, model_outputs: ModelOutput, **postprocess_parameters: Dict) -> Any:
        # Wrap the predictions together with their count.
        return {"output": model_outputs, "model_length": len(model_outputs)}
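
A minimal usage sketch, not part of this commit: the task name "hive-token-classification", the repo id, and the "json" output_style value below are placeholders, and it assumes the model loaded for this pipeline implements the custom predict(inputs, tokenizer, ...) method that _forward delegates to.

from transformers import AutoModel, AutoTokenizer, pipeline
from transformers.pipelines import PIPELINE_REGISTRY

from hive_token_classification import HiveTokenClassification

# Register the custom task so pipeline() can resolve it (placeholder task name).
PIPELINE_REGISTRY.register_pipeline(
    "hive-token-classification",
    pipeline_class=HiveTokenClassification,
    pt_model=AutoModel,
)

model_id = "hive/some-model"  # placeholder repo id
pipe = pipeline(
    "hive-token-classification",
    model=AutoModel.from_pretrained(model_id, trust_remote_code=True),
    tokenizer=AutoTokenizer.from_pretrained(model_id),
)

# Extra call kwargs pass through _sanitize_parameters; output_style is routed to _forward.
result = pipe("Some input text", output_style="json")  # "json" is a placeholder value
print(result["output"], result["model_length"])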