from typing import Dict, List, Any | |
from punctuators.models.punc_cap_seg_model import PunctCapSegConfigONNX, PunctCapSegModelONNX | |
class PreTrainedPipeline:
    """Text-in / text-out wrapper around a ``PunctCapSegModelONNX`` model.

    The model artifacts are loaded once at construction time; each call runs
    the model on a single input string.
    """

    def __init__(self, path: str):
        """Load the ONNX model whose artifacts live in the directory ``path``.

        Args:
            path: Directory passed to ``PunctCapSegConfigONNX``. It is expected
                to hold ``sp.model``, ``model.onnx`` and ``config.yaml``.
        """
        config = PunctCapSegConfigONNX(
            directory=path,
            spe_filename="sp.model",
            model_filename="model.onnx",
            config_filename="config.yaml",
        )
        self._punctuator: PunctCapSegModelONNX = PunctCapSegModelONNX(config)

    def __call__(self, data: str) -> List[Dict]:
        """Run the model on one input text.

        Args:
            data: Raw input text.

        Returns:
            A one-element list ``[{"generated_text": ...}]``. The value is the
            list of strings the model produced for ``data``, joined with a
            space, a literal backslash-n pair and a space.
        """
        # ``infer`` works on batches, so submit a batch holding only ``data``
        # and keep the first (only) entry of the result.
        batch_predictions: List[List[str]] = self._punctuator.infer([data])
        segments = batch_predictions[0]
        # The text-generation widget does not seem to render multi-line output,
        # so segments are separated by a visible literal "\n" marker instead of
        # a real newline.
        joined = " \\n ".join(segments)
        return [{"generated_text": joined}]