|
import torch |
|
import onnx |
|
import onnxruntime as rt |
|
from torchvision import transforms as T |
|
from tokenizer_base import Tokenizer |
|
import pathlib |
|
import os |
|
import sys |
|
from PIL import Image |
|
|
|
from huggingface_hub import Repository |
|
|
|
# Directory containing this file. Every on-disk resource is anchored here so
# the script behaves the same regardless of the process working directory.
cwd = pathlib.Path(__file__).parent.resolve()

# Clone (or reuse) the model repository and bring it up to date.
# The checkout is placed next to this file because the model is later loaded
# from `<this dir>/secret_models`; a path relative to the working directory
# would put the clone somewhere else when the script is launched from another
# directory.
# `token=True` makes huggingface_hub use the locally stored access token.
# NOTE(review): `Repository` is deprecated in recent huggingface_hub releases
# in favour of the HTTP download helpers -- consider migrating.
repo = Repository(
    local_dir=str(cwd / "secret_models"),
    repo_type="model",
    clone_from="docparser/captcha",
    token=True,
)
repo.git_pull()

# Size handed to the resize step of the preprocessing pipeline.
img_size = (32, 128)

# Characters the tokenizer can emit.
# NOTE(review): this is a raw string, so `\"` and `\\` each keep their
# backslashes literally (three backslash characters in total). Confirm this
# matches the charset the model was exported with.
charset = r"0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~"
tokenizer_base = Tokenizer(charset)
|
|
|
|
|
def get_transform(img_size):
    """Build the image preprocessing pipeline.

    Args:
        img_size: Target size passed to ``T.Resize``.

    Returns:
        A ``T.Compose`` that resizes with bicubic interpolation, converts the
        image to a tensor, and normalises it with mean 0.5 and std 0.5.
    """
    return T.Compose([
        T.Resize(img_size, T.InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(0.5, 0.5),
    ])
|
|
|
|
|
def to_numpy(tensor):
    """Return ``tensor`` as a NumPy array located on the CPU.

    Tensors that require gradients are detached before the conversion;
    all other tensors are converted directly.
    """
    if tensor.requires_grad:
        tensor = tensor.detach()
    return tensor.cpu().numpy()
|
|
|
|
|
def initialize_model(model_file):
    """Validate an ONNX model file and open an inference session for it.

    Args:
        model_file: Path of the ONNX model to load.

    Returns:
        A ``(transform, ort_session)`` tuple: the preprocessing pipeline
        built from the module-level ``img_size`` and the onnxruntime
        ``InferenceSession`` created from ``model_file``.
    """
    preprocess = get_transform(img_size)
    # Run the ONNX checker on the loaded model before creating the session.
    onnx.checker.check_model(onnx.load(model_file))
    session = rt.InferenceSession(model_file)
    return preprocess, session
|
|
|
|
|
def get_text(image_path):
    """Recognise the text in the image at ``image_path``.

    Args:
        image_path: Path handed to ``PIL.Image.open``.

    Returns:
        The first prediction produced by ``tokenizer_base.decode`` for the
        single-image batch.
    """
    # Open through a context manager so the image file is closed even when
    # preprocessing raises; the tensor is fully built inside the block.
    with Image.open(image_path) as img_org:
        # unsqueeze(0) adds the leading batch dimension.
        x = transform(img_org.convert('RGB')).unsqueeze(0)

    # Feed the array to the model's first declared input.
    ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(x)}
    logits = ort_session.run(None, ort_inputs)[0]
    probs = torch.tensor(logits).softmax(-1)

    # decode() returns two values; only the predictions are needed here.
    preds, _ = tokenizer_base.decode(probs)
    return preds[0]
|
|
|
|
|
# Location of the ONNX model inside the repository checkout next to this file.
model_file = os.path.join(cwd, "secret_models", "captcha.onnx")

# Build the preprocessing pipeline and inference session once, at import time;
# get_text() reads both as module-level globals.
transform, ort_session = initialize_model(model_file=model_file)
|
|
|
if __name__ == "__main__":
    # Command-line entry point: recognise the image named by the first argument.
    if len(sys.argv) < 2:
        # Exit with a usage hint rather than an IndexError traceback.
        sys.exit(f"usage: {sys.argv[0]} IMAGE_PATH")
    image_path = sys.argv[1]
    res = get_text(image_path)
    print(res)
|
|