from typing import Dict, List, Any

from transformers import AutoTokenizer
from gector import GECToR, predict, load_verb_dict


class EndpointHandler:
    def __init__(self, path=""):
        # Load the GECToR model and tokenizer from the repository path, plus the
        # verb-form vocabulary used to encode/decode verb transformations.
        self.model = GECToR.from_pretrained(path)
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.encode, self.decode = load_verb_dict("./data/verb-form-vocab.txt")
    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Process the input data and return the predicted results.

        Args:
            data (Dict[str, Any]): The input data dictionary containing the following keys:
                - "inputs" (List[str]): A list of input strings to be corrected.
                - "n_iterations" (int, optional): The number of iterative correction passes. Defaults to 5.
                - "batch_size" (int, optional): The batch size for prediction. Defaults to 2.
                - "keep_confidence" (float, optional): A positive bias added to the $KEEP tag, making the model more conservative about editing tokens. Defaults to 0.0.
                - "min_error_prob" (float, optional): The minimum sentence-level error probability required before corrections are applied. Defaults to 0.0.

        Returns:
            List[Dict[str, Any]]: The predicted results for each input string.
        """
        srcs = data["inputs"]

        # Extract optional parameters from data, with defaults
        n_iterations = data.get("n_iterations", 5)
        batch_size = data.get("batch_size", 2)
        keep_confidence = data.get("keep_confidence", 0.0)
        min_error_prob = data.get("min_error_prob", 0.0)

        return predict(
            model=self.model,
            tokenizer=self.tokenizer,
            srcs=srcs,
            encode=self.encode,
            decode=self.decode,
            keep_confidence=keep_confidence,
            min_error_prob=min_error_prob,
            n_iteration=n_iterations,
            batch_size=batch_size,
        )
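

# A minimal local smoke test, sketched under the assumption that this file is used
# as a Hugging Face Inference Endpoints custom handler and that the model weights
# plus ./data/verb-form-vocab.txt are available locally. The repo path "." and the
# example sentences are illustrative assumptions; the payload keys mirror those
# documented in __call__.
if __name__ == "__main__":
    handler = EndpointHandler(path=".")
    payload = {
        "inputs": [
            "He go to school every day.",
            "She have two cat.",
        ],
        "n_iterations": 5,
        "batch_size": 2,
        "keep_confidence": 0.0,
        "min_error_prob": 0.0,
    }
    # The handler returns whatever gector.predict produces for each input sentence.
    print(handler(payload))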