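"""Decode text from a base64-encoded captcha image with a pretrained ONNX recognizer.

The model is pulled from the Hugging Face Hub repo "stevenchang/captcha"
and executed locally with ONNX Runtime.
"""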
import base64
import io
import sys
import torch
import onnx
import onnxruntime as rt
from torchvision import transforms as T
from tokenizer_base import Tokenizer
from PIL import Image
from huggingface_hub import hf_hub_download, try_to_load_from_cache


class DocumentParserModel:
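    """Recognize text in base64-encoded images using an ONNX model from the Hugging Face Hub."""
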
    def __init__(self):
        charset = r"0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~"
        img_size = (32, 128)

        self.tokenizer_base = Tokenizer(charset)
        self.transform = self.create_transform_pipeline(img_size)
        self.ort_session = self.initialize_onnx_model()

    def create_transform_pipeline(self, img_size):
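        """Resize to the model's input size, convert to a tensor, and normalize to [-1, 1]."""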
        transforms = [
            T.Resize(img_size, T.InterpolationMode.BICUBIC),
            T.ToTensor(),
            T.Normalize(0.5, 0.5),
        ]
        return T.Compose(transforms)

    def initialize_onnx_model(self):
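        """Fetch the ONNX model (from the local cache or the Hub), validate it, and create an inference session."""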
        repo_id = "stevenchang/captcha"
        filename = "captcha.onnx"

        # Reuse a locally cached copy of the model if one exists; otherwise download it from the Hub.
        filepath = try_to_load_from_cache(repo_id, filename)

        if isinstance(filepath, str):
            model_file = filepath
        else:
            model_file = hf_hub_download(repo_id, filename)

        onnx_model = onnx.load(model_file)
        onnx.checker.check_model(onnx_model)
        return rt.InferenceSession(model_file)

    def load_image_from_base64(self, base64_string):
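        """Decode a base64 string into a PIL image, returning None if the bytes are not a readable image."""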
        img_data = base64.b64decode(base64_string)
        image_buffer = io.BytesIO(img_data)

        try:
            image = Image.open(image_buffer)
            return image
        except IOError:
            print("Error: cannot open an image from the provided base64 data", file=sys.stderr)
            return None

    def predict_text(self, image_blob):
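        """Run the recognizer on a base64-encoded image and return the decoded text (or None on failure)."""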
        img_org = self.load_image_from_base64(image_blob)
        if img_org is None:
            return None

        with img_org:
            # Preprocess to a (1, C, H, W) float tensor and run it through the ONNX session.
            x = self.transform(img_org.convert("RGB")).unsqueeze(0)
            ort_inputs = {self.ort_session.get_inputs()[0].name: x.cpu().numpy()}
            logits = self.ort_session.run(None, ort_inputs)[0]
            # Convert logits to probabilities and decode them with the tokenizer.
            probs = torch.tensor(logits).softmax(-1)
            preds, _ = self.tokenizer_base.decode(probs)
            return preds[0]


if __name__ == "__main__":
    doc_parser = DocumentParserModel()

    if len(sys.argv) > 1:
        image_blob = sys.argv[1]
        result = doc_parser.predict_text(image_blob)
        print(result)
    else:
        print("Usage: pass a base64-encoded image string as the first argument.", file=sys.stderr)
        sys.exit(1)