File size: 2,408 Bytes
58bc514 f24f2e7 c825110 9772e97 c825110 f24f2e7 9772e97 f24f2e7 9772e97 c825110 f24f2e7 c825110 9772e97 f24f2e7 9772e97 c825110 58bc514 f24f2e7 9772e97 f24f2e7 9772e97 f24f2e7 c825110 9772e97 c825110 f24f2e7 c825110 9772e97 c825110 f24f2e7 9772e97 f24f2e7 9772e97 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 |
import base64
import io
import sys
import torch
import onnx
import onnxruntime as rt
from torchvision import transforms as T
from tokenizer_base import Tokenizer
from PIL import Image
from huggingface_hub import hf_hub_download, try_to_load_from_cache
class DocumentParserModel:
    """Recognize text in an image with an ONNX model run via onnxruntime.

    Construction builds the tokenizer, the image preprocessing pipeline and
    the inference session; ``predict_text`` then maps a base64-encoded image
    to the decoded string.
    """

    def __init__(self):
        """Set up the tokenizer, preprocessing transform and ONNX session."""
        # NOTE(review): this is a *raw* string, so the backslash before the
        # double quote and the doubled backslash inside the brackets end up
        # as literal backslash characters in the charset. Confirm this
        # matches the charset the model was trained with before changing it.
        charset = r"0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~"
        img_size = (32, 128)  # target size handed to T.Resize
        self.tokenizer_base = Tokenizer(charset)
        self.transform = self.create_transform_pipeline(img_size)
        self.ort_session = self.initialize_onnx_model()

    def create_transform_pipeline(self, img_size):
        """Build the preprocessing pipeline applied to each input image.

        Args:
            img_size: Target size passed to ``T.Resize``.

        Returns:
            A ``T.Compose`` that resizes with bicubic interpolation, converts
            the image to a tensor and normalizes it with mean 0.5 / std 0.5.
        """
        return T.Compose(
            [
                T.Resize(img_size, T.InterpolationMode.BICUBIC),
                T.ToTensor(),
                T.Normalize(0.5, 0.5),
            ]
        )

    def initialize_onnx_model(self):
        """Locate the ONNX model file, validate it and open a session.

        The local Hugging Face cache is consulted first; the file is
        downloaded from the hub only when the cache lookup does not yield a
        path.

        Returns:
            An onnxruntime ``InferenceSession`` created from the model file.
        """
        repo_id = "stevenchang/captcha"
        filename = "captcha.onnx"
        cached_path = try_to_load_from_cache(repo_id, filename)
        # Only a str result is treated as a usable cached file path; any
        # other result falls back to downloading.
        if isinstance(cached_path, str):
            model_file = cached_path
        else:
            model_file = hf_hub_download(repo_id, filename)
        # Run the ONNX checker on the file before creating the session so an
        # invalid model is reported at start-up.
        onnx.checker.check_model(onnx.load(model_file))
        return rt.InferenceSession(model_file)

    def load_image_from_base64(self, base64_string):
        """Decode a base64 string and open the bytes as a PIL image.

        Args:
            base64_string: Base64-encoded image bytes.

        Returns:
            The opened image, or ``None`` (after printing an error message)
            when the decoded bytes cannot be opened as an image.

        Raises:
            binascii.Error: Propagated from ``base64.b64decode`` when the
                input is not valid base64 (a ``ValueError`` subclass).
        """
        img_data = base64.b64decode(base64_string)
        try:
            return Image.open(io.BytesIO(img_data))
        except OSError as err:  # IOError is an alias of OSError in Python 3
            # Report the underlying error; the previous message referenced a
            # name that does not exist in this method's scope.
            print(f"Error: Cannot open image: {err}")
            return None

    def predict_text(self, image_blob):
        """Run the model on a base64-encoded image and return the text.

        Args:
            image_blob: Base64-encoded image bytes.

        Returns:
            The first decoded prediction produced by the tokenizer.

        Raises:
            ValueError: If ``image_blob`` is not valid base64 or the decoded
                bytes cannot be opened as an image.
        """
        image = self.load_image_from_base64(image_blob)
        if image is None:
            # Fail with a clear message instead of entering ``with None``.
            raise ValueError("image_blob could not be opened as an image")
        with image as img_org:
            # unsqueeze(0) adds the leading batch dimension.
            x = self.transform(img_org.convert("RGB")).unsqueeze(0)
        input_name = self.ort_session.get_inputs()[0].name
        logits = self.ort_session.run(None, {input_name: x.cpu().numpy()})[0]
        # Softmax over the last axis, then let the tokenizer decode.
        probs = torch.tensor(logits).softmax(-1)
        preds, _ = self.tokenizer_base.decode(probs)
        return preds[0]
if __name__ == "__main__":
    import sys

    # Command-line entry point: the first argument is the base64 image blob.
    # The model is constructed before the arguments are inspected.
    doc_parser = DocumentParserModel()
    cli_args = sys.argv[1:]
    if not cli_args:
        print("Please provide an image blob.")
    else:
        image_blob = cli_args[0]
        result = doc_parser.predict_text(image_blob)
        print(result)
|