"""Predict text from an image with an ONNX model fetched from a Hugging Face repo.

On first use the model repository is cloned (or pulled) into a local
directory. The ONNX file is validated with ``onnx.checker`` and executed with
ONNX Runtime. The output logits are decoded to text by
``tokenizer_base.Tokenizer``.

Usage::

    python <this script> IMAGE_PATH
"""

import pathlib
import sys

import torch
import onnx
import onnxruntime as rt
from torchvision import transforms as T
from PIL import Image
from huggingface_hub import Repository

from tokenizer_base import Tokenizer


class DocumentParserModel:
    """Wrap preprocessing, ONNX inference and decoding for one model file.

    Construction ensures the model file is present locally, cloning or
    pulling the repository if needed. It then builds the image transform and
    creates an ONNX Runtime session for the model.
    """

    def __init__(
        self,
        repo_path,
        model_subpath,
        img_size,
        charset,
        repo_url="stevenchang/captcha",
        token=None,
        providers=None,
    ):
        """Prepare the model for inference.

        Args:
            repo_path: Local directory the model repository lives in (or is
                cloned into).
            model_subpath: Path of the ONNX model file relative to
                ``repo_path``.
            img_size: Target size passed to ``torchvision.transforms.Resize``.
            charset: Characters passed to ``Tokenizer`` for decoding.
                Presumably this must match the charset the model was trained
                with (TODO confirm).
            repo_url: Hugging Face repository cloned when ``repo_path`` does
                not exist yet.
            token: Hugging Face auth token. When falsy, ``True`` is passed to
                the hub client instead.
            providers: Optional ONNX Runtime execution providers. ``None``
                keeps ONNX Runtime's default selection.

        Raises:
            FileNotFoundError: If the model file is still missing after the
                repository has been cloned or pulled.
        """
        self.repo_path = pathlib.Path(repo_path).resolve()
        self.model_path = self.repo_path / model_subpath
        self.charset = charset
        self.tokenizer_base = Tokenizer(self.charset)
        self.initialize_repository(repo_url, token)
        self.transform = self.create_transform_pipeline(img_size)
        self.ort_session = self.initialize_onnx_model(
            str(self.model_path), providers=providers
        )

    def initialize_repository(self, repo_url, token):
        """Make sure the model file exists locally.

        - Model file already present: nothing is fetched.
        - ``self.repo_path`` missing: ``repo_url`` is cloned into it.
        - Directory exists but model missing: the latest changes are pulled.
        """
        if self.model_path.exists():
            print(
                f"Model {self.model_path} already exists, skipping repository update."
            )
            return

        # ``True`` tells the hub client to fall back to its own stored
        # credentials when no explicit token is given.
        # NOTE(review): ``Repository`` (git-based) is deprecated in recent
        # huggingface_hub releases in favour of the HTTP API (e.g.
        # ``hf_hub_download``), and ``use_auth_token`` in favour of ``token``;
        # consider migrating once the pinned version is confirmed.
        auth = token or True
        if not self.repo_path.exists():
            print(
                f"Repository does not exist. Cloning from {repo_url} into {self.repo_path}"
            )
            Repository(
                local_dir=str(self.repo_path),
                clone_from=repo_url,
                use_auth_token=auth,
            )
        else:
            print(
                f"Model does not exist, but repository is already cloned. Pulling latest changes in {self.repo_path}"
            )
            repo = Repository(
                local_dir=str(self.repo_path),
                use_auth_token=auth,
            )
            repo.git_pull()

    def create_transform_pipeline(self, img_size):
        """Build the image preprocessing pipeline.

        Steps: bicubic resize to ``img_size``, conversion to a tensor, then
        normalization with mean 0.5 and std 0.5.
        """
        transforms = [
            T.Resize(img_size, T.InterpolationMode.BICUBIC),
            T.ToTensor(),
            T.Normalize(0.5, 0.5),
        ]
        return T.Compose(transforms)

    def initialize_onnx_model(self, model_path, providers=None):
        """Validate the ONNX file at ``model_path`` and open an inference session.

        Args:
            model_path: Filesystem path of the ``.onnx`` model.
            providers: Optional execution providers forwarded to
                ``onnxruntime.InferenceSession``. ``None`` keeps its default.

        Raises:
            FileNotFoundError: If ``model_path`` is not an existing file. This
                happens, for example, when the repository does not contain
                ``model_subpath``.
        """
        if not pathlib.Path(model_path).is_file():
            raise FileNotFoundError(
                f"ONNX model not found at {model_path}; "
                "check repo_url and model_subpath."
            )
        onnx_model = onnx.load(model_path)
        onnx.checker.check_model(onnx_model)
        # The parsed proto is only used for validation; ONNX Runtime loads the
        # model again from the path.
        return rt.InferenceSession(model_path, providers=providers)

    def predict_text(self, image_path):
        """Return the text predicted for the image at ``image_path``.

        The image is converted to RGB, preprocessed with ``self.transform``
        and given a batch dimension of 1. It is then run through the ONNX
        session, and the softmaxed logits are decoded with the tokenizer.

        Returns:
            The first decoded prediction. Returns ``None`` if the image cannot
            be opened or decoded; the error is reported on stderr.
        """
        try:
            with Image.open(image_path) as img_org:
                # PIL decodes lazily; convert() forces the decode inside the
                # try, so corrupt or truncated files are also reported here.
                x = self.transform(img_org.convert("RGB")).unsqueeze(0)
        except OSError:
            print(f"Error: Cannot open image {image_path}", file=sys.stderr)
            return None

        ort_inputs = {self.ort_session.get_inputs()[0].name: x.cpu().numpy()}
        logits = self.ort_session.run(None, ort_inputs)[0]
        probs = torch.tensor(logits).softmax(-1)
        preds, _ = self.tokenizer_base.decode(probs)
        return preds[0]


def main(argv=None):
    """Run the CLI: print the text predicted for the image path in ``argv[0]``.

    Args:
        argv: Command-line arguments without the program name. Defaults to
            ``sys.argv[1:]``.

    Returns:
        Process exit status: 0 on success, 1 when no image path is given or
        the image cannot be read.
    """
    argv = sys.argv[1:] if argv is None else argv
    if not argv:
        # Checked before building the model so a bare invocation does not
        # trigger a repository clone/pull.
        print("Please provide an image path.", file=sys.stderr)
        return 1

    # NOTE(review): in this raw string, ``\"`` and ``[\\]`` keep their
    # backslashes, so the charset has 96 characters with '\' repeated rather
    # than the 94 printable ASCII characters. If the model was trained on the
    # 94-character set, decoded punctuation after '!' would be shifted.
    # Confirm against the training charset before changing it.
    charset = r"0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~"
    doc_parser = DocumentParserModel(
        repo_path="secret_models",
        model_subpath="captcha.onnx",
        img_size=(32, 128),
        charset=charset,
    )

    result = doc_parser.predict_text(argv[0])
    if result is None:
        return 1
    print(result)
    return 0


if __name__ == "__main__":
    sys.exit(main())