Steven C commited on
Commit
9772e97
1 Parent(s): 58bc514

Download model from HuggingFace to tmp folder for running on Lambda

Browse files
Files changed (6) hide show
  1. README.md +1 -1
  2. app.py +36 -29
  3. captcha.onnx +0 -3
  4. handler.py +2 -9
  5. requirements.txt +1 -0
  6. serverless.yml +1 -1
README.md CHANGED
@@ -26,5 +26,5 @@ Before running this project, make sure you have the following prerequisites inst
26
  To use this project, run the following command:
27
 
28
  ```bash
29
- python3 app.py path/to/your_img
30
  ```
 
26
  To use this project, run the following command:
27
 
28
  ```bash
29
+ python3 app.py BASE_64_IMAGE_BLOB_STRING
30
  ```
app.py CHANGED
@@ -7,14 +7,17 @@ import onnxruntime as rt
7
  from torchvision import transforms as T
8
  from tokenizer_base import Tokenizer
9
  from PIL import Image
 
10
 
11
 
12
  class DocumentParserModel:
13
- def __init__(self, model_path, img_size, charset):
14
- self.charset = charset
15
- self.tokenizer_base = Tokenizer(self.charset)
 
 
16
  self.transform = self.create_transform_pipeline(img_size)
17
- self.ort_session = self.initialize_onnx_model(str(model_path))
18
 
19
  def create_transform_pipeline(self, img_size):
20
  transforms = [
@@ -24,46 +27,50 @@ class DocumentParserModel:
24
  ]
25
  return T.Compose(transforms)
26
 
27
- def initialize_onnx_model(self, model_path):
28
- onnx_model = onnx.load(model_path)
 
 
 
 
 
 
 
 
 
 
29
  onnx.checker.check_model(onnx_model)
30
- return rt.InferenceSession(model_path)
31
 
32
  def load_image_from_base64(self, base64_string):
33
  img_data = base64.b64decode(base64_string)
34
  image_buffer = io.BytesIO(img_data)
35
- image = Image.open(image_buffer)
36
- return image
37
 
38
- def predict_text(self, image_path):
39
  try:
40
- with self.load_image_from_base64(image_path) as img_org:
41
- x = self.transform(img_org.convert("RGB")).unsqueeze(0)
42
- ort_inputs = {self.ort_session.get_inputs()[0].name: x.cpu().numpy()}
43
- logits = self.ort_session.run(None, ort_inputs)[0]
44
- probs = torch.tensor(logits).softmax(-1)
45
- preds, _ = self.tokenizer_base.decode(probs)
46
- return preds[0]
47
  except IOError:
48
- print(f"Error: Cannot open image {image_path}")
49
  return None
50
 
 
 
 
 
 
 
 
 
 
51
 
52
  if __name__ == "__main__":
53
  import sys
54
 
55
- model_path = "captcha.onnx"
56
- img_size = (32, 128)
57
- charset = r"0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~"
58
 
59
- doc_parser = DocumentParserModel(
60
- model_path=model_path,
61
- img_size=img_size,
62
- charset=charset,
63
- )
64
  if len(sys.argv) > 1:
65
- image_path = sys.argv[1]
66
- result = doc_parser.predict_text(image_path)
67
  print(result)
68
  else:
69
- print("Please provide an image path.")
 
7
  from torchvision import transforms as T
8
  from tokenizer_base import Tokenizer
9
  from PIL import Image
10
+ from huggingface_hub import hf_hub_download, try_to_load_from_cache
11
 
12
 
13
  class DocumentParserModel:
14
+ def __init__(self):
15
+ charset = r"0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~"
16
+ img_size = (32, 128)
17
+
18
+ self.tokenizer_base = Tokenizer(charset)
19
  self.transform = self.create_transform_pipeline(img_size)
20
+ self.ort_session = self.initialize_onnx_model()
21
 
22
  def create_transform_pipeline(self, img_size):
23
  transforms = [
 
27
  ]
28
  return T.Compose(transforms)
29
 
30
+ def initialize_onnx_model(self):
31
+ repo_id = "stevenchang/captcha"
32
+ filename = "captcha.onnx"
33
+
34
+ filepath = try_to_load_from_cache(repo_id, filename)
35
+
36
+ if isinstance(filepath, str):
37
+ model_file = filepath
38
+ else:
39
+ model_file = result = hf_hub_download(repo_id, filename)
40
+
41
+ onnx_model = onnx.load(model_file)
42
  onnx.checker.check_model(onnx_model)
43
+ return rt.InferenceSession(model_file)
44
 
45
  def load_image_from_base64(self, base64_string):
46
  img_data = base64.b64decode(base64_string)
47
  image_buffer = io.BytesIO(img_data)
 
 
48
 
 
49
  try:
50
+ image = Image.open(image_buffer)
51
+ return image
 
 
 
 
 
52
  except IOError:
53
+ print(f"Error: Cannot open image {image_blob}")
54
  return None
55
 
56
+ def predict_text(self, image_blob):
57
+ with self.load_image_from_base64(image_blob) as img_org:
58
+ x = self.transform(img_org.convert("RGB")).unsqueeze(0)
59
+ ort_inputs = {self.ort_session.get_inputs()[0].name: x.cpu().numpy()}
60
+ logits = self.ort_session.run(None, ort_inputs)[0]
61
+ probs = torch.tensor(logits).softmax(-1)
62
+ preds, _ = self.tokenizer_base.decode(probs)
63
+ return preds[0]
64
+
65
 
66
  if __name__ == "__main__":
67
  import sys
68
 
69
+ doc_parser = DocumentParserModel()
 
 
70
 
 
 
 
 
 
71
  if len(sys.argv) > 1:
72
+ image_blob = sys.argv[1]
73
+ result = doc_parser.predict_text(image_blob)
74
  print(result)
75
  else:
76
+ print("Please provide an image blob.")
captcha.onnx DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:2a672587ee82eb010dbef54dd0a38e99625293608ed4068c4bd20ebe467fede4
3
- size 95401304
 
 
 
 
handler.py CHANGED
@@ -6,14 +6,7 @@ from app import DocumentParserModel
6
  LOGGER = logging.getLogger()
7
  LOGGER.setLevel(logging.INFO)
8
 
9
- model_path = "captcha.onnx"
10
- img_size = (32, 128)
11
- charset = r"0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~"
12
- model = DocumentParserModel(
13
- model_path=model_path,
14
- img_size=img_size,
15
- charset=charset,
16
- )
17
 
18
 
19
  def lambda_handle(event, context):
@@ -22,4 +15,4 @@ def lambda_handle(event, context):
22
  LOGGER.info("No ML work to do. Just staying warm...")
23
  return "Keeping Lambda warm"
24
 
25
- return {"statusCode": 200, "vc": model.predict_text(image_path=event["image_path"])}
 
6
  LOGGER = logging.getLogger()
7
  LOGGER.setLevel(logging.INFO)
8
 
9
+ model = DocumentParserModel()
 
 
 
 
 
 
 
10
 
11
 
12
  def lambda_handle(event, context):
 
15
  LOGGER.info("No ML work to do. Just staying warm...")
16
  return "Keeping Lambda warm"
17
 
18
+ return {"statusCode": 200, "result": model.predict_text(image_blob=event["image_blob"])}
requirements.txt CHANGED
@@ -3,3 +3,4 @@ torchvision==0.12.0
3
  onnx==1.16.0
4
  onnxruntime==1.16.*
5
  Pillow==10.0.0
 
 
3
  onnx==1.16.0
4
  onnxruntime==1.16.*
5
  Pillow==10.0.0
6
+ huggingface_hub==0.21.4
serverless.yml CHANGED
@@ -29,7 +29,7 @@ functions:
29
  environment:
30
  # On Lambda, the default location is not writable. Only the "/tmp" folder is writable. Therefore, we need to set the cache location inside "/tmp".
31
  TORCH_HOME: /tmp/.ml_cache
32
- # TRANSFORMERS_CACHE: /tmp/.ml_cache/huggingface
33
  custom:
34
  warmup:
35
  MLModelWarmer:
 
29
  environment:
30
  # On Lambda, the default location is not writable. Only the "/tmp" folder is writable. Therefore, we need to set the cache location inside "/tmp".
31
  TORCH_HOME: /tmp/.ml_cache
32
+ HF_HUB_CACHE: /tmp/.ml_cache/huggingface
33
  custom:
34
  warmup:
35
  MLModelWarmer: