|
import sys |
|
import torch |
|
import onnx |
|
import onnxruntime as rt |
|
from torchvision import transforms as T |
|
from tokenizer_base import Tokenizer |
|
import pathlib |
|
from PIL import Image |
|
from huggingface_hub import Repository |
|
|
|
|
|
class DocumentParserModel:
    """Image-to-text recognizer that runs an ONNX model with onnxruntime.

    On construction the model file ``repo_path / model_subpath`` is fetched
    from a Hugging Face repository if it is not already on disk. The ONNX
    graph is then validated with ``onnx.checker`` and wrapped in an
    ``onnxruntime.InferenceSession``. ``predict_text`` preprocesses an image,
    runs the session and decodes the output with ``Tokenizer``.
    """

    def __init__(
        self,
        repo_path,
        model_subpath,
        img_size,
        charset,
        repo_url="stevenchang/captcha",
        token=None,
    ):
        """Prepare the model for inference.

        Args:
            repo_path: Local directory the model repository lives in (or is
                cloned into). It is resolved to an absolute path.
            model_subpath: Path of the ``.onnx`` file relative to
                ``repo_path``.
            img_size: Target size handed to ``torchvision.transforms.Resize``.
                The script below passes ``(32, 128)``.
            charset: Characters handed to ``Tokenizer`` for decoding model
                output.
            repo_url: Repository id or URL to clone from when the model is
                missing.
            token: Hugging Face auth token. When falsy, ``True`` is passed to
                ``huggingface_hub`` instead.

        Raises:
            FileNotFoundError: The model file is still missing after the
                repository was cloned or pulled.
        """
        self.repo_path = pathlib.Path(repo_path).resolve()
        self.model_path = self.repo_path / model_subpath
        self.charset = charset
        self.tokenizer_base = Tokenizer(self.charset)
        self.initialize_repository(repo_url, token)
        self.transform = self.create_transform_pipeline(img_size)
        self.ort_session = self.initialize_onnx_model(str(self.model_path))

    def initialize_repository(self, repo_url, token):
        """Ensure ``self.model_path`` exists locally, fetching it if needed.

        If the model file already exists nothing is downloaded. Otherwise the
        repository is cloned into ``self.repo_path`` when that directory is
        absent, or the existing checkout is updated with ``git pull``.

        Raises:
            FileNotFoundError: The model file is still missing after the sync,
                e.g. because ``model_subpath`` is not a file in the repository.
        """
        if self.model_path.exists():
            print(
                f"Model {self.model_path} already exists, skipping repository update."
            )
            return

        # A falsy token becomes True. Per the huggingface_hub docs, True means
        # "use the token saved locally by a previous login".
        auth = token if token else True

        # NOTE(review): huggingface_hub documents the git-based ``Repository``
        # class as deprecated in favour of the HTTP-based helpers
        # (``hf_hub_download`` / ``snapshot_download``). It is kept here to
        # preserve the existing git-checkout layout under ``repo_path``.
        if not self.repo_path.exists():
            print(
                f"Repository does not exist. Cloning from {repo_url} into {self.repo_path}"
            )
            Repository(
                local_dir=str(self.repo_path),
                clone_from=repo_url,
                use_auth_token=auth,
            )
        else:
            print(
                f"Model does not exist, but repository is already cloned. Pulling latest changes in {self.repo_path}"
            )
            repo = Repository(
                local_dir=str(self.repo_path),
                use_auth_token=auth,
            )
            repo.git_pull()

        # Fail here with a clear message rather than later inside onnx.load.
        if not self.model_path.exists():
            raise FileNotFoundError(
                f"Model file {self.model_path} not found after syncing {repo_url}"
            )

    def create_transform_pipeline(self, img_size):
        """Return the preprocessing pipeline applied to every input image.

        The steps are:
        1. Bicubic resize to ``img_size``.
        2. Conversion to a float tensor in ``[0, 1]``.
        3. Normalisation with mean 0.5 and std 0.5, mapping values to
           ``[-1, 1]``.
        """
        transforms = [
            T.Resize(img_size, T.InterpolationMode.BICUBIC),
            T.ToTensor(),
            T.Normalize(0.5, 0.5),
        ]
        return T.Compose(transforms)

    def initialize_onnx_model(self, model_path):
        """Validate the ONNX file at ``model_path`` and open an inference session.

        Raises whatever ``onnx.load`` / ``onnx.checker.check_model`` raise for
        an unreadable or invalid model.
        """
        onnx_model = onnx.load(model_path)
        onnx.checker.check_model(onnx_model)
        return rt.InferenceSession(model_path)

    def predict_text(self, image_path):
        """Recognise the text in the image at ``image_path``.

        Returns:
            The first decoded string produced by ``Tokenizer.decode``, or
            ``None`` if the image cannot be opened or decoded. In that case an
            error message is written to stderr.
        """
        try:
            with Image.open(image_path) as img_org:
                # Image.open is lazy. convert() forces the pixel data to be
                # read, so errors from truncated or corrupt files surface
                # inside this try rather than later.
                img = img_org.convert("RGB")
        except OSError:
            # stderr keeps the error separate from predictions printed on stdout.
            print(f"Error: Cannot open image {image_path}", file=sys.stderr)
            return None

        # unsqueeze(0) adds a leading batch dimension of size 1.
        x = self.transform(img).unsqueeze(0)
        input_name = self.ort_session.get_inputs()[0].name
        logits = self.ort_session.run(None, {input_name: x.cpu().numpy()})[0]
        # NOTE(review): assumes the last axis of ``logits`` is the class axis
        # expected by Tokenizer.decode; confirm against the exported model.
        probs = torch.tensor(logits).softmax(-1)
        preds, _ = self.tokenizer_base.decode(probs)
        return preds[0]
|
|
|
|
|
if __name__ == "__main__":
    import string

    # Validate arguments before building the model: construction may clone
    # the model repository over the network.
    if len(sys.argv) < 2:
        print("Please provide an image path.", file=sys.stderr)
        sys.exit(1)

    repo_path = "secret_models"
    model_subpath = "captcha.onnx"
    img_size = (32, 128)

    # The 94 printable ASCII characters: digits, lowercase, uppercase, then
    # punctuation. The previous raw-string literal r"...!\"#...[\\]..." kept
    # its backslashes, adding two extra "\" entries. That shifted the index of
    # every character after "!" and broke punctuation decoding.
    # NOTE(review): assumes the ONNX model was exported with this 94-character
    # charset; confirm against the model's output width.
    charset = (
        string.digits
        + string.ascii_lowercase
        + string.ascii_uppercase
        + string.punctuation
    )

    doc_parser = DocumentParserModel(
        repo_path=repo_path,
        model_subpath=model_subpath,
        img_size=img_size,
        charset=charset,
    )

    image_path = sys.argv[1]
    result = doc_parser.predict_text(image_path)
    print(result)
|
|