"""Predict text from an image using an ONNX model and a character tokenizer."""

import sys

import torch
import onnx
import onnxruntime as rt
from torchvision import transforms as T
from PIL import Image

from tokenizer_base import Tokenizer


class DocumentParserModel:
    """Wrap an ONNX text-recognition model together with its preprocessing.

    On construction the instance builds a ``Tokenizer`` from ``charset``, a
    torchvision preprocessing pipeline for ``img_size`` and an ONNX Runtime
    inference session for ``model_path``.
    """

    def __init__(
        self,
        model_path,
        img_size,
        charset,
    ) -> None:
        """Load the model and prepare the tokenizer and transforms.

        Args:
            model_path: Location of the ``.onnx`` file. It is converted with
                ``str()`` before use, so path-like objects are accepted.
            img_size: Target size handed to ``torchvision.transforms.Resize``.
            charset: Characters handed to ``Tokenizer``.
        """
        self.charset = charset
        self.tokenizer_base = Tokenizer(self.charset)
        self.transform = self.create_transform_pipeline(img_size)
        self.ort_session = self.initialize_onnx_model(str(model_path))

    def create_transform_pipeline(self, img_size) -> T.Compose:
        """Return the preprocessing pipeline applied to every input image.

        The pipeline resizes to ``img_size`` with bicubic interpolation,
        converts the image to a tensor and normalizes it with mean 0.5 and
        standard deviation 0.5.
        """
        return T.Compose(
            [
                T.Resize(img_size, T.InterpolationMode.BICUBIC),
                T.ToTensor(),
                T.Normalize(0.5, 0.5),
            ]
        )

    def initialize_onnx_model(self, model_path) -> rt.InferenceSession:
        """Validate the ONNX file and open an inference session on it.

        The model is first loaded with ``onnx`` and passed through
        ``onnx.checker.check_model`` so that a malformed file is rejected
        before ONNX Runtime is asked to load it.
        """
        onnx_model = onnx.load(model_path)
        onnx.checker.check_model(onnx_model)
        return rt.InferenceSession(model_path)

    # TODO: test with image blob
    def predict_text(self, image_path):
        """Run the model on one image and return the first decoded prediction.

        Args:
            image_path: Image location, passed straight to ``PIL.Image.open``.

        Returns:
            The first element of the predictions returned by
            ``Tokenizer.decode`` (presumably the recognized string - confirm
            against ``tokenizer_base``), or ``None`` when the image cannot be
            opened or decoded.
        """
        # Only the image loading is guarded: an OSError raised later (for
        # example during inference) is not an "unreadable image" problem and
        # must not be reported or swallowed as one. ``convert`` is kept inside
        # the guard because PIL decodes pixel data lazily, so a truncated or
        # corrupt file may only fail at this point.
        try:
            with Image.open(image_path) as img_org:
                rgb_image = img_org.convert("RGB")
        except OSError:
            print(f"Error: Cannot open image {image_path}")
            return None

        # unsqueeze(0) adds the leading batch dimension for a single image.
        x = self.transform(rgb_image).unsqueeze(0)
        input_name = self.ort_session.get_inputs()[0].name
        ort_inputs = {input_name: x.cpu().numpy()}
        # Passing None as the output list requests every model output; only
        # the first one is used.
        logits = self.ort_session.run(None, ort_inputs)[0]
        probs = torch.tensor(logits).softmax(-1)
        preds, _ = self.tokenizer_base.decode(probs)
        return preds[0]


def main() -> None:
    """Command-line entry point: print the prediction for ``sys.argv[1]``."""
    # Check the arguments before loading the model so that a missing image
    # path yields the usage hint immediately instead of first paying for (or
    # failing on) the model load.
    if len(sys.argv) <= 1:
        print("Please provide an image path.")
        return

    image_path = sys.argv[1]

    model_path = "captcha.onnx"
    img_size = (32, 128)
    # NOTE(review): this is a raw string, so the backslashes are kept as
    # literal characters (`\"` is a backslash plus a quote, `\\` is two
    # backslashes). Confirm this matches the charset the model was trained
    # with before changing it.
    charset = r"0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~"

    doc_parser = DocumentParserModel(
        model_path=model_path,
        img_size=img_size,
        charset=charset,
    )
    result = doc_parser.predict_text(image_path)
    print(result)


if __name__ == "__main__":
    main()