Steven C committed on
Commit
f24f2e7
1 Parent(s): c825110

First step refactoring

Browse files
Files changed (1) hide show
  1. app.py +83 -51
app.py CHANGED
@@ -1,69 +1,101 @@
 
1
  import torch
2
  import onnx
3
  import onnxruntime as rt
4
  from torchvision import transforms as T
5
  from tokenizer_base import Tokenizer
6
  import pathlib
7
- import os
8
- import sys
9
  from PIL import Image
10
-
11
  from huggingface_hub import Repository
12
 
13
- repo = Repository(
14
- local_dir="secret_models",
15
- repo_type="model",
16
- clone_from="docparser/captcha",
17
- token=True
18
- )
19
- repo.git_pull()
20
-
21
- cwd = pathlib.Path(__file__).parent.resolve()
22
- img_size = (32, 128)
23
- charset = r"0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~"
24
- tokenizer_base = Tokenizer(charset)
25
-
26
 
27
- def get_transform(img_size):
28
- transforms = []
29
- transforms.extend([
30
- T.Resize(img_size, T.InterpolationMode.BICUBIC),
31
- T.ToTensor(),
32
- T.Normalize(0.5, 0.5)
33
- ])
34
- return T.Compose(transforms)
 
 
 
 
 
 
 
 
 
35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
- def to_numpy(tensor):
38
- return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
 
 
 
 
 
39
 
 
 
 
 
40
 
41
- def initialize_model(model_file):
42
- transform = get_transform(img_size)
43
- onnx_model = onnx.load(model_file)
44
- onnx.checker.check_model(onnx_model)
45
- ort_session = rt.InferenceSession(model_file)
46
- return transform, ort_session
 
 
 
 
 
 
47
 
48
 
49
- def get_text(image_path):
50
- img_org = Image.open(image_path)
51
- # Preprocess. Model expects a batch of images with shape: (B, C, H, W)
52
- x = transform(img_org.convert('RGB')).unsqueeze(0)
53
-
54
- # compute ONNX Runtime output prediction
55
- ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(x)}
56
- logits = ort_session.run(None, ort_inputs)[0]
57
- probs = torch.tensor(logits).softmax(-1)
58
- preds, probs = tokenizer_base.decode(probs)
59
- preds = preds[0]
60
- return preds
61
-
62
 
63
- model_file = os.path.join(cwd, "secret_models", "captcha.onnx")
64
- transform, ort_session = initialize_model(model_file=model_file)
 
 
65
 
66
- if __name__ == "__main__":
67
- image_path = sys.argv[1]
68
- res = get_text(image_path)
69
- print(res)
 
 
 
 
 
 
 
 
 
1
+ import sys
2
  import torch
3
  import onnx
4
  import onnxruntime as rt
5
  from torchvision import transforms as T
6
  from tokenizer_base import Tokenizer
7
  import pathlib
 
 
8
  from PIL import Image
 
9
  from huggingface_hub import Repository
10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
+ class DocumentParserModel:
13
+ def __init__(
14
+ self,
15
+ repo_path,
16
+ model_subpath,
17
+ img_size,
18
+ charset,
19
+ repo_url="stevenchang/captcha",
20
+ token=None,
21
+ ):
22
+ self.repo_path = pathlib.Path(repo_path).resolve()
23
+ self.model_path = self.repo_path / model_subpath
24
+ self.charset = charset
25
+ self.tokenizer_base = Tokenizer(self.charset)
26
+ self.initialize_repository(repo_url, token)
27
+ self.transform = self.create_transform_pipeline(img_size)
28
+ self.ort_session = self.initialize_onnx_model(str(self.model_path))
29
 
30
+ def initialize_repository(self, repo_url, token):
31
+ if not self.model_path.exists():
32
+ if not self.repo_path.exists():
33
+ print(
34
+ f"Repository does not exist. Cloning from {repo_url} into {self.repo_path}"
35
+ )
36
+ repo = Repository(
37
+ local_dir=str(self.repo_path),
38
+ clone_from=repo_url,
39
+ use_auth_token=token if token else True,
40
+ )
41
+ else:
42
+ print(
43
+ f"Model does not exist, but repository is already cloned. Pulling latest changes in {self.repo_path}"
44
+ )
45
+ repo = Repository(
46
+ local_dir=str(self.repo_path),
47
+ use_auth_token=token if token else True,
48
+ )
49
+ repo.git_pull()
50
+ else:
51
+ print(
52
+ f"Model {self.model_path} already exists, skipping repository update."
53
+ )
54
 
55
+ def create_transform_pipeline(self, img_size):
56
+ transforms = [
57
+ T.Resize(img_size, T.InterpolationMode.BICUBIC),
58
+ T.ToTensor(),
59
+ T.Normalize(0.5, 0.5),
60
+ ]
61
+ return T.Compose(transforms)
62
 
63
+ def initialize_onnx_model(self, model_path):
64
+ onnx_model = onnx.load(model_path)
65
+ onnx.checker.check_model(onnx_model)
66
+ return rt.InferenceSession(model_path)
67
 
68
+ def predict_text(self, image_path):
69
+ try:
70
+ with Image.open(image_path) as img_org:
71
+ x = self.transform(img_org.convert("RGB")).unsqueeze(0)
72
+ ort_inputs = {self.ort_session.get_inputs()[0].name: x.cpu().numpy()}
73
+ logits = self.ort_session.run(None, ort_inputs)[0]
74
+ probs = torch.tensor(logits).softmax(-1)
75
+ preds, _ = self.tokenizer_base.decode(probs)
76
+ return preds[0]
77
+ except IOError:
78
+ print(f"Error: Cannot open image {image_path}")
79
+ return None
80
 
81
 
82
+ if __name__ == "__main__":
83
+ import sys
 
 
 
 
 
 
 
 
 
 
 
84
 
85
+ repo_path = "secret_models"
86
+ model_subpath = "captcha.onnx"
87
+ img_size = (32, 128)
88
+ charset = r"0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~"
89
 
90
+ doc_parser = DocumentParserModel(
91
+ repo_path=repo_path,
92
+ model_subpath=model_subpath,
93
+ img_size=img_size,
94
+ charset=charset,
95
+ )
96
+ if len(sys.argv) > 1:
97
+ image_path = sys.argv[1]
98
+ result = doc_parser.predict_text(image_path)
99
+ print(result)
100
+ else:
101
+ print("Please provide an image path.")