|
import base64 |
|
import io |
|
import sys |
|
import torch |
|
import onnx |
|
import onnxruntime as rt |
|
from torchvision import transforms as T |
|
from tokenizer_base import Tokenizer |
|
from PIL import Image |
|
from huggingface_hub import hf_hub_download, try_to_load_from_cache |
|
|
|
|
|
class DocumentParserModel:
    """Recognize text in an image supplied as a base64 string.

    The instance bundles three pieces: a ``Tokenizer`` built from a fixed
    charset, a torchvision preprocessing pipeline, and an ONNX Runtime
    session for the ``captcha.onnx`` model fetched from the Hugging Face Hub.
    """

    def __init__(self):
        """Build the tokenizer, the image transform and the ONNX session.

        Creating an instance may download the model file if it is not
        already in the local Hugging Face cache.
        """
        # Raw string: ``\"`` and ``\\`` are kept verbatim, so the charset
        # contains a backslash before the double quote and two consecutive
        # backslashes.
        # NOTE(review): confirm this matches the vocabulary the model was
        # exported with; the extra backslashes shift the position of every
        # symbol that follows "!".
        charset = r"0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~"
        # Passed to T.Resize, which interprets a pair as (height, width).
        img_size = (32, 128)

        self.tokenizer_base = Tokenizer(charset)
        self.transform = self.create_transform_pipeline(img_size)
        self.ort_session = self.initialize_onnx_model()

    def create_transform_pipeline(self, img_size):
        """Return the preprocessing pipeline applied to every input image.

        Args:
            img_size: Target size handed to ``T.Resize``.

        Returns:
            A ``T.Compose`` that resizes with bicubic interpolation, converts
            the image to a tensor and normalizes it with mean 0.5 / std 0.5.
        """
        transforms = [
            T.Resize(img_size, T.InterpolationMode.BICUBIC),
            T.ToTensor(),
            T.Normalize(0.5, 0.5),
        ]
        return T.Compose(transforms)

    def initialize_onnx_model(self):
        """Locate (or download) the ONNX model, validate it and open a session.

        Returns:
            An ``onnxruntime.InferenceSession`` for the model file.
        """
        repo_id = "stevenchang/captcha"
        filename = "captcha.onnx"

        # Only a string result is a usable cached path; anything else is
        # treated as a cache miss and triggers a download.
        cached_path = try_to_load_from_cache(repo_id, filename)
        if isinstance(cached_path, str):
            model_file = cached_path
        else:
            model_file = hf_hub_download(repo_id, filename)

        # Validate the model file before handing it to the runtime.
        onnx_model = onnx.load(model_file)
        onnx.checker.check_model(onnx_model)
        return rt.InferenceSession(model_file)

    def load_image_from_base64(self, base64_string):
        """Decode a base64 string and open it as a PIL image.

        Args:
            base64_string: Base64-encoded image file contents.

        Returns:
            The opened ``PIL.Image.Image``, or ``None`` (after printing an
            error message) when the string is not valid base64 or the decoded
            bytes cannot be opened as an image.
        """
        try:
            img_data = base64.b64decode(base64_string)
        except ValueError as exc:
            # binascii.Error (e.g. incorrect padding) is a ValueError subclass.
            print(f"Error: Cannot decode base64 image data: {exc}")
            return None

        try:
            return Image.open(io.BytesIO(img_data))
        except OSError as exc:
            # The message must not reference names outside this method; the
            # caught exception already describes what went wrong.
            print(f"Error: Cannot open image: {exc}")
            return None

    def predict_text(self, image_blob):
        """Run the model on a base64-encoded image and return the decoded text.

        Args:
            image_blob: Base64-encoded image file contents.

        Returns:
            The first prediction produced by the tokenizer's ``decode``, or
            ``None`` when the image could not be loaded.
        """
        image = self.load_image_from_base64(image_blob)
        if image is None:
            # The loader has already reported the problem.
            return None

        # The source image is only needed until the tensor has been built.
        with image as img_org:
            x = self.transform(img_org.convert("RGB")).unsqueeze(0)

        # Feed the batch of one to the session's first declared input.
        ort_inputs = {self.ort_session.get_inputs()[0].name: x.cpu().numpy()}
        logits = self.ort_session.run(None, ort_inputs)[0]
        probs = torch.tensor(logits).softmax(-1)
        preds, _ = self.tokenizer_base.decode(probs)
        return preds[0]
|
|
|
|
|
if __name__ == "__main__":
    import sys

    # Command-line entry point: the first argument is a base64-encoded image.
    # The model is constructed before the arguments are inspected.
    recognizer = DocumentParserModel()

    if len(sys.argv) <= 1:
        print("Please provide an image blob.")
    else:
        image_blob = sys.argv[1]
        print(recognizer.predict_text(image_blob))
|
|