|
import sys |
|
import torch |
|
import onnx |
|
import onnxruntime as rt |
|
from torchvision import transforms as T |
|
from tokenizer_base import Tokenizer |
|
from PIL import Image |
|
|
|
|
|
class DocumentParserModel:
    """Recognise the text in an image with an ONNX model and a tokenizer.

    Images are resized and normalised by a torchvision pipeline, run through
    an ONNX Runtime session, and the resulting class scores are decoded into
    text by a ``Tokenizer`` built from ``charset``.
    """

    def __init__(
        self,
        model_path,
        img_size,
        charset,
    ):
        """Build the tokenizer, preprocessing pipeline and inference session.

        Args:
            model_path: Path (``str`` or path-like) of the ``.onnx`` model; it
                is converted with ``str()`` before loading.
            img_size: Target size forwarded to ``T.Resize``; a ``(height,
                width)`` tuple such as ``(32, 128)`` resizes to exactly that.
            charset: Characters the tokenizer can emit. Presumably their order
                must match the model's output classes — confirm against the
                training configuration.
        """
        self.charset = charset
        self.tokenizer_base = Tokenizer(self.charset)
        self.transform = self.create_transform_pipeline(img_size)
        self.ort_session = self.initialize_onnx_model(str(model_path))

    def create_transform_pipeline(self, img_size):
        """Return the preprocessing applied to every image before inference.

        Bicubic resize to ``img_size``, conversion of the PIL image to a float
        tensor in ``[0, 1]``, then per-channel normalisation with mean 0.5 and
        std 0.5, which maps values to ``[-1, 1]``.
        """
        transforms = [
            T.Resize(img_size, T.InterpolationMode.BICUBIC),
            T.ToTensor(),
            T.Normalize(0.5, 0.5),
        ]
        return T.Compose(transforms)

    def initialize_onnx_model(self, model_path):
        """Validate the ONNX model file and open an inference session on it.

        ``onnx.checker.check_model`` raises if the model is invalid, so a bad
        file fails at construction time rather than at the first prediction.

        Args:
            model_path: Filesystem path of the ``.onnx`` file, as a ``str``.

        Returns:
            An ``onnxruntime.InferenceSession`` for the model.
        """
        # NOTE(review): the file is parsed twice (by onnx for the check and by
        # onnxruntime for the session). Fine for small models; per the onnx
        # docs check_model can also take the path directly if memory matters.
        onnx_model = onnx.load(model_path)
        onnx.checker.check_model(onnx_model)
        return rt.InferenceSession(model_path)

    def predict_text(self, image_path):
        """Return the text recognised in the image at ``image_path``.

        Args:
            image_path: Path of the image file; it is converted to RGB, so any
                mode PIL can open is accepted.

        Returns:
            The first decoded prediction from the tokenizer (the batch holds a
            single image; presumably a ``str``), or ``None`` if the image
            cannot be opened or decoded, in which case an error message is
            printed to stderr.
        """
        # Only image loading is guarded: an OSError raised later (e.g. by the
        # model or the tokenizer) must not be misreported as an unreadable image.
        try:
            with Image.open(image_path) as img_org:
                # PIL opens lazily; convert() forces the decode so truncated or
                # corrupt files are caught here, and the returned copy remains
                # valid after the file is closed.
                image = img_org.convert("RGB")
        except OSError:  # IOError is an alias of OSError in Python 3.
            print(f"Error: Cannot open image {image_path}", file=sys.stderr)
            return None

        # Add a leading batch dimension of size 1.
        x = self.transform(image).unsqueeze(0)
        input_name = self.ort_session.get_inputs()[0].name
        logits = self.ort_session.run(None, {input_name: x.cpu().numpy()})[0]
        # assumes logits are (batch, seq_len, num_classes) — TODO confirm;
        # softmax over the last axis gives per-position class probabilities.
        probs = torch.tensor(logits).softmax(-1)
        preds, _ = self.tokenizer_base.decode(probs)
        return preds[0]
|
|
|
|
|
if __name__ == "__main__":
    # Check the command line first: building the ONNX session loads and
    # validates the whole model, which is wasted work without an input image.
    if len(sys.argv) < 2:
        print("Please provide an image path.", file=sys.stderr)
        sys.exit(1)

    image_path = sys.argv[1]

    # Relative path: resolved against the current working directory.
    model_path = "captcha.onnx"
    # (height, width) the input image is resized to before inference.
    img_size = (32, 128)
    # NOTE(review): this is a raw string, so `\"` keeps its backslash and `\\`
    # stays two backslashes — the charset holds three backslash characters
    # (96 characters instead of the 94 printable ASCII symbols). If the model
    # was trained on the plain 94-character set and the tokenizer maps classes
    # by position, characters after "!" would decode shifted. Confirm against
    # the training config before changing this literal.
    charset = r"0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~"

    doc_parser = DocumentParserModel(
        model_path=model_path,
        img_size=img_size,
        charset=charset,
    )

    result = doc_parser.predict_text(image_path)
    if result is None:
        # predict_text has already reported why the image could not be read;
        # signal failure instead of printing "None" as if it were the text.
        sys.exit(1)
    print(result)
|
|