File size: 1,882 Bytes
c825110
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import torch
import onnx
import onnxruntime as rt
from torchvision import transforms as T
from tokenizer_base import Tokenizer
import pathlib
import os
import sys
from PIL import Image

from huggingface_hub import Repository

# Directory containing this script. NOTE: despite the name, this is NOT the
# process's current working directory. The model checkout is placed next to
# the script so it matches the path the model is later loaded from,
# regardless of where the script is launched.
cwd = pathlib.Path(__file__).parent.resolve()

# Clone (or reuse) the model repository next to this script and pull its
# latest state. `token=True` makes huggingface_hub use the locally stored
# access token.
repo = Repository(
    local_dir=os.path.join(cwd, "secret_models"),
    repo_type="model",
    clone_from="docparser/captcha",
    token=True
)
repo.git_pull()

# Size handed to T.Resize in get_transform, i.e. (height, width).
img_size = (32, 128)
# Characters the tokenizer decodes model outputs into.
# NOTE(review): this is a *raw* string, so `\"` keeps its backslash and `\\`
# stays two backslashes. The literal therefore holds 96 characters (three of
# them backslashes), not the 94 printable ASCII symbols it appears to list.
# If the model was trained on the standard 94-character set, punctuation after
# '!' is likely decoded shifted by one or two positions. Confirm against the
# training charset before changing it.
charset = r"0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~"
tokenizer_base = Tokenizer(charset)


def get_transform(img_size):
    """Build the preprocessing pipeline applied to each input image.

    The pipeline does three things, in order:
    - resizes the PIL image to ``img_size`` with bicubic interpolation;
    - converts it to a float tensor;
    - normalizes it with mean 0.5 and std 0.5.
    """
    resize = T.Resize(img_size, interpolation=T.InterpolationMode.BICUBIC)
    return T.Compose([resize, T.ToTensor(), T.Normalize(0.5, 0.5)])


def to_numpy(tensor):
    """Return ``tensor`` as a NumPy array on the CPU.

    A tensor that requires grad is detached first, because ``.numpy()`` refuses
    to convert a tensor that is still part of an autograd graph.
    """
    if not tensor.requires_grad:
        return tensor.cpu().numpy()
    return tensor.detach().cpu().numpy()


def initialize_model(model_file, providers=None):
    """Validate an ONNX model and open an inference session for it.

    Args:
        model_file: Path to the ``.onnx`` file.
        providers: Optional ONNX Runtime execution providers, in priority
            order. When omitted, every provider available in the installed
            onnxruntime build (``rt.get_available_providers()``) is used.

    Returns:
        A ``(transform, ort_session)`` tuple:
        - ``transform`` is the pipeline from ``get_transform(img_size)``;
        - ``ort_session`` is an ``onnxruntime.InferenceSession`` for the model.

    Raises:
        onnx.checker.ValidationError: if the model fails the ONNX checker.
    """
    transform = get_transform(img_size)
    # Structural validation only; the loaded proto is not used afterwards.
    onnx_model = onnx.load(model_file)
    onnx.checker.check_model(onnx_model)
    if providers is None:
        # Since onnxruntime 1.9, builds with non-default execution providers
        # (e.g. onnxruntime-gpu) raise ValueError unless `providers` is passed
        # explicitly. The available list keeps CPU-only builds unchanged.
        providers = rt.get_available_providers()
    ort_session = rt.InferenceSession(model_file, providers=providers)
    return transform, ort_session


def get_text(image_path):
    """Recognize the text in a single captcha image.

    Uses the module-level ``transform``, ``ort_session`` and ``tokenizer_base``.

    Args:
        image_path: Anything ``PIL.Image.open`` accepts (a filesystem path or
            a binary file object).

    Returns:
        The decoded string for the image: the first (and only) entry of the
        batch returned by ``tokenizer_base.decode``.
    """
    # Open the image as a context manager so the underlying file handle is
    # released once the pixels have been converted, instead of lingering
    # until garbage collection.
    with Image.open(image_path) as img_org:
        # Preprocess. Model expects a batch of images with shape: (B, C, H, W)
        x = transform(img_org.convert('RGB')).unsqueeze(0)

    # compute ONNX Runtime output prediction
    ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(x)}
    logits = ort_session.run(None, ort_inputs)[0]
    # Softmax over the last axis, presumably the per-position class scores.
    probs = torch.tensor(logits).softmax(-1)
    preds, _ = tokenizer_base.decode(probs)
    return preds[0]


# Path to the ONNX model, resolved relative to this script's directory.
model_file = os.path.join(cwd, "secret_models", "captcha.onnx")
# Loaded once at import so every get_text() call reuses the same session.
transform, ort_session = initialize_model(model_file=model_file)

if __name__ == "__main__":
    if len(sys.argv) < 2:
        # Exit with a usage message (status 1) instead of an IndexError traceback.
        sys.exit(f"usage: {sys.argv[0]} IMAGE_PATH")
    image_path = sys.argv[1]
    res = get_text(image_path)
    print(res)