shamaayan commited on
Commit
ccdca60
1 Parent(s): f3c3179

model_length

Browse files
Files changed (1) hide show
  1. hive_token_classification.py +20 -0
hive_token_classification.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict
2
+ from transformers import Pipeline, AutoModel, AutoTokenizer
3
+ from transformers.pipelines.base import GenericTensor, ModelOutput
4
+
5
+
6
+ class HiveTokenClassification(Pipeline):
7
+ def _sanitize_parameters(self, **kwargs):
8
+ forward_parameters = {}
9
+ if "output_style" in kwargs:
10
+ forward_parameters["output_style"] = kwargs["output_style"]
11
+ return {}, forward_parameters, {}
12
+
13
+ def preprocess(self, input_: Any, **preprocess_parameters: Dict) -> Dict[str, GenericTensor]:
14
+ return input_
15
+
16
+ def _forward(self, input_tensors: Dict[str, GenericTensor], **forward_parameters: Dict) -> ModelOutput:
17
+ return self.model.predict(input_tensors, self.tokenizer, **forward_parameters)
18
+
19
+ def postprocess(self, model_outputs: ModelOutput, **postprocess_parameters: Dict) -> Any:
20
+ return {"output": model_outputs, "model_length": len(model_outputs)}