|
from typing import Dict |
|
from transformers import pipeline |
|
import requests |
|
import jwt |
|
import base64 |
|
|
|
# Audio sample rate in Hz — presumably the 16 kHz rate the Whisper model expects.
# NOTE(review): not referenced anywhere in the visible code — confirm callers use it.
SAMPLE_RATE = 16000



# PEM-encoded RSA public key used to verify the RS512 signature of incoming
# request JWTs (see EndpointHandler.__call__). Only the matching private-key
# holder can mint accepted tokens.
PUBLIC_KEY = b"-----BEGIN PUBLIC KEY-----\nMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAu1SU1LfVLPHCozMxH2Mo\n4lgOEePzNm0tRgeLezV6ffAt0gunVTLw7onLRnrq0/IzW7yWR7QkrmBL7jTKEn5u\n+qKhbwKfBstIs+bMY2Zkp18gnTxKLxoS2tFczGkPLPgizskuemMghRniWaoLcyeh\nkd3qqGElvW/VDL5AaWTg0nLVkjRo9z+40RQzuVaE8AkAFmxZzow3x+VJYKdjykkJ\n0iT9wCS0DRTXu269V264Vf/3jvredZiKRkgwlL9xNAwxXFg0x/XFw005UWVRIkdg\ncKWTjpBP2dPwVZ4WWC+9aGVd+Gyn1o0CLelf4rEjGoXbAAEgAqeGUxrcIlbjXfbc\nmwIDAQAB\n-----END PUBLIC KEY-----"
|
|
|
|
|
class EndpointHandler:
    """Inference handler: JWT-authenticated automatic speech recognition.

    Verifies an RS512-signed token against the module-level ``PUBLIC_KEY``
    before running the request's audio through a Whisper ASR pipeline.
    """

    def __init__(self, path=""):
        # `path` is accepted for compatibility with the hosting contract,
        # but a fixed public checkpoint is loaded regardless of its value.
        self.pipeline = pipeline(
            "automatic-speech-recognition", model="openai/whisper-base"
        )

    def __call__(self, data: Dict[str, bytes]) -> Dict[str, str]:
        """Transcribe one authenticated audio request.

        Args:
            data (:obj:`dict`):
                ``"token"``      -- JWT signed (RS512) by the holder of the
                                    private key matching ``PUBLIC_KEY``;
                                    must carry a ``"jti"`` claim.
                ``"inputs"``     -- raw audio bytes, or a base64-encoded
                                    string of those bytes.
                ``"parameters"`` -- optional dict of extra keyword arguments
                                    forwarded to the ASR pipeline.

        Return:
            The dict produced by the ASR pipeline (the transcription result).

        Raises:
            RuntimeError: if no token is supplied.
            jwt.InvalidTokenError: if signature verification fails.
            ValueError: if no audio payload is supplied.
        """
        token = data.pop("token", None)
        if token is None:
            raise RuntimeError("missing token")
        # Raises on a bad signature, so unauthenticated requests never
        # reach the model.
        decoded = jwt.decode(token, PUBLIC_KEY, algorithms=["RS512"])
        print("received input from jti=", decoded["jti"])

        inputs = data.pop("inputs", None)
        if inputs is None:
            # Fail fast with a clear message instead of letting the
            # pipeline raise an opaque error on a None input.
            raise ValueError("missing inputs")
        if isinstance(inputs, str):
            # JSON payloads carry audio as base64 text; restore raw bytes.
            inputs = base64.b64decode(inputs)
        parameters = data.pop("parameters", {})

        return self.pipeline(inputs, **parameters)
|
|