|
from typing import Dict, List, Any |
|
from setfit import SetFitModel |
|
import numpy as np |
|
|
|
|
|
class EndpointHandler:
    """Inference Endpoint handler for a multi-label SetFit text classifier.

    Each request's input text(s) are mapped to the list of exercise label
    names that the model predicts as positive.
    """

    # Label names, positionally aligned with the model's prediction columns.
    # NOTE(review): this order must match the label order used at training
    # time — confirm against the training script. The strings are returned
    # verbatim to clients, so apparent typos (e.g. "feeling on anger") and
    # inconsistent casing ("Inner strength") are intentionally left as-is:
    # changing them would change the endpoint's output.
    EXERCISE_LABELS = (
        "positive experience",
        "power posing",
        "worry vs rumination",
        "self-confidence",
        "negative emotions",
        "sleep",
        "loneliness",
        "imaginary friend",
        "perfectionism",
        "negative self-talk",
        "woop",
        "venting",
        "worry window",
        "act of kindness",
        "blowing balloons",
        "feeling on anger",
        "power of smile",
        "body scan",
        "stress enhancing thoughts",
        "anger ball of fire",
        "emotions",
        "lean against wall",
        "breathing",
        "crossed arms",
        "energy traffic light",
        "boundaries",
        "Inner strength",
    )

    def __init__(self, path=""):
        """Load the SetFit model stored at ``path``."""
        self.model = SetFitModel.from_pretrained(path)

    def __call__(self, data: Dict[str, Any]) -> List[List[str]]:
        """Classify the text input(s) carried by a request payload.

        Args:
            data: request payload. The ``"inputs"`` entry is popped from it
                (so ``data`` is mutated); when that key is missing, ``data``
                itself is passed to the model. A single ``str`` is wrapped
                into a one-element batch; otherwise the value is passed
                through as-is (presumably a list of strings — verify
                against callers).

        Returns:
            One list per input, holding the names from ``EXERCISE_LABELS``
            whose predicted value is truthy, in label order.
        """
        inputs = data.pop("inputs", data)
        if isinstance(inputs, str):
            # The model predicts on batches; wrap a lone string.
            inputs = [inputs]

        # Only hard predictions are needed for the response. Probabilities
        # were previously computed via a second, full predict_proba pass and
        # then discarded; that code also assumed one particular probability
        # layout and crashed on others, so it has been removed.
        preds = self.model.predict(inputs)

        # zip() stops at the shorter sequence: if the model emits a different
        # number of columns than there are label names, the extras on either
        # side are silently ignored (same as before).
        return [
            [name for name, flag in zip(self.EXERCISE_LABELS, row) if flag]
            for row in preds
        ]