# Author: Steven C
# Last commit 9772e97 (unverified): "Download model from HuggingFace to tmp folder for running on Lambda"
import base64
import io
import sys
import torch
import onnx
import onnxruntime as rt
from torchvision import transforms as T
from tokenizer_base import Tokenizer
from PIL import Image
from huggingface_hub import hf_hub_download, try_to_load_from_cache
class DocumentParserModel:
    """Recognize text in base64-encoded images with an ONNX model.

    The model file ``captcha.onnx`` is taken from the Hugging Face Hub
    repository ``stevenchang/captcha``. A cached copy is reused when one is
    present; otherwise the file is downloaded. The model is validated with the
    ONNX checker and served through an ONNX Runtime inference session.
    """

    def __init__(self, cache_dir=None):
        """Build the tokenizer, preprocessing pipeline and inference session.

        Args:
            cache_dir: Optional Hugging Face cache directory used to look up
                and download the model. ``None`` keeps the library default.
                Pass a writable path (e.g. one under ``/tmp``) when the
                default cache location is not writable, as is presumably the
                case on serverless runtimes.
        """
        # Output alphabet for the project-local Tokenizer. It presumably has
        # to match the character set the model was trained with.
        # NOTE(review): inside a raw string, \" and \\ keep their backslashes,
        # so '\' occurs more than once here. Do not "fix" this without
        # checking the model's output size, because token indices depend on
        # this exact string.
        charset = r"0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~"
        # (height, width) that every input image is resized to.
        img_size = (32, 128)
        self.tokenizer_base = Tokenizer(charset)
        self.transform = self.create_transform_pipeline(img_size)
        self.ort_session = self.initialize_onnx_model(cache_dir)

    def create_transform_pipeline(self, img_size):
        """Return the torchvision preprocessing pipeline for model inputs.

        Steps:
            - Resize images to ``img_size`` with bicubic interpolation.
            - Convert them to a float tensor.
            - Normalize with mean 0.5 and std 0.5.
        """
        transforms = [
            T.Resize(img_size, T.InterpolationMode.BICUBIC),
            T.ToTensor(),
            T.Normalize(0.5, 0.5),
        ]
        return T.Compose(transforms)

    def initialize_onnx_model(self, cache_dir=None):
        """Locate (or download) the ONNX model and open an inference session.

        Args:
            cache_dir: Optional Hugging Face cache directory. ``None`` uses
                the library default.

        Returns:
            An ``onnxruntime.InferenceSession`` for the model file.

        Raises:
            Whatever ``onnx.checker.check_model`` raises for an invalid model,
            or ``hf_hub_download`` raises when the download fails.
        """
        repo_id = "stevenchang/captcha"
        filename = "captcha.onnx"
        # try_to_load_from_cache yields a path string only on a cache hit.
        # Any other result means the file still has to be downloaded.
        cached = try_to_load_from_cache(repo_id, filename, cache_dir=cache_dir)
        if isinstance(cached, str):
            model_file = cached
        else:
            model_file = hf_hub_download(repo_id, filename, cache_dir=cache_dir)
        onnx_model = onnx.load(model_file)
        onnx.checker.check_model(onnx_model)
        return rt.InferenceSession(model_file)

    def load_image_from_base64(self, base64_string):
        """Decode a base64 string into a PIL image.

        Args:
            base64_string: Base64-encoded image bytes (``str`` or ``bytes``).

        Returns:
            The opened PIL image, or ``None`` if the data is not valid base64
            or cannot be opened as an image. In that case an error message is
            written to stderr.
        """
        try:
            img_data = base64.b64decode(base64_string)
            return Image.open(io.BytesIO(img_data))
        except (ValueError, OSError) as err:
            # binascii.Error (malformed base64) subclasses ValueError.
            # PIL's UnidentifiedImageError (unknown format) subclasses OSError.
            # The blob itself is not echoed because it may be very large.
            print(f"Error: Cannot open image: {err}", file=sys.stderr)
            return None

    def predict_text(self, image_blob):
        """Return the text the model reads from a base64-encoded image.

        Args:
            image_blob: Base64-encoded image data.

        Returns:
            ``preds[0]`` from ``Tokenizer.decode``, presumably the recognized
            string for the single image in the batch (confirm against
            Tokenizer.decode).

        Raises:
            ValueError: if ``image_blob`` cannot be decoded into an image.
        """
        image = self.load_image_from_base64(image_blob)
        if image is None:
            raise ValueError("Cannot open image from the provided base64 data")
        with image as img_org:
            # Force 3 channels, then add a leading batch dimension of size 1.
            x = self.transform(img_org.convert("RGB")).unsqueeze(0)
        ort_inputs = {self.ort_session.get_inputs()[0].name: x.cpu().numpy()}
        logits = self.ort_session.run(None, ort_inputs)[0]
        # Softmax over the last axis (presumably the character-class axis,
        # TODO confirm) gives the probabilities the tokenizer decodes.
        probs = torch.tensor(logits).softmax(-1)
        preds, _ = self.tokenizer_base.decode(probs)
        return preds[0]
if __name__ == "__main__":
    # Command-line entry point: python <script> <base64-image>
    # Validate arguments before constructing the model. Construction may
    # download the ONNX file, which would be wasted work just to print usage.
    if len(sys.argv) < 2:
        # sys.exit with a string prints it to stderr and exits with status 1.
        sys.exit("Please provide an image blob.")
    doc_parser = DocumentParserModel()
    print(doc_parser.predict_text(sys.argv[1]))