Steven C
committed on
Commit
•
9772e97
1
Parent(s):
58bc514
Download model from HuggingFace to tmp folder for running on Lambda
Browse files
- README.md +1 -1
- app.py +36 -29
- captcha.onnx +0 -3
- handler.py +2 -9
- requirements.txt +1 -0
- serverless.yml +1 -1
README.md
CHANGED
@@ -26,5 +26,5 @@ Before running this project, make sure you have the following prerequisites inst
|
|
26 |
To use this project, run the following command:
|
27 |
|
28 |
```bash
|
29 |
-
python3 app.py
|
30 |
```
|
|
|
26 |
To use this project, run the following command:
|
27 |
|
28 |
```bash
|
29 |
+
python3 app.py BASE_64_IMAGE_BLOB_STRING
|
30 |
```
|
app.py
CHANGED
@@ -7,14 +7,17 @@ import onnxruntime as rt
|
|
7 |
from torchvision import transforms as T
|
8 |
from tokenizer_base import Tokenizer
|
9 |
from PIL import Image
|
|
|
10 |
|
11 |
|
12 |
class DocumentParserModel:
|
13 |
-
def __init__(self
|
14 |
-
|
15 |
-
|
|
|
|
|
16 |
self.transform = self.create_transform_pipeline(img_size)
|
17 |
-
self.ort_session = self.initialize_onnx_model(
|
18 |
|
19 |
def create_transform_pipeline(self, img_size):
|
20 |
transforms = [
|
@@ -24,46 +27,50 @@ class DocumentParserModel:
|
|
24 |
]
|
25 |
return T.Compose(transforms)
|
26 |
|
27 |
-
def initialize_onnx_model(self
|
28 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
onnx.checker.check_model(onnx_model)
|
30 |
-
return rt.InferenceSession(
|
31 |
|
32 |
def load_image_from_base64(self, base64_string):
|
33 |
img_data = base64.b64decode(base64_string)
|
34 |
image_buffer = io.BytesIO(img_data)
|
35 |
-
image = Image.open(image_buffer)
|
36 |
-
return image
|
37 |
|
38 |
-
def predict_text(self, image_path):
|
39 |
try:
|
40 |
-
|
41 |
-
|
42 |
-
ort_inputs = {self.ort_session.get_inputs()[0].name: x.cpu().numpy()}
|
43 |
-
logits = self.ort_session.run(None, ort_inputs)[0]
|
44 |
-
probs = torch.tensor(logits).softmax(-1)
|
45 |
-
preds, _ = self.tokenizer_base.decode(probs)
|
46 |
-
return preds[0]
|
47 |
except IOError:
|
48 |
-
print(f"Error: Cannot open image {
|
49 |
return None
|
50 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
51 |
|
52 |
if __name__ == "__main__":
|
53 |
import sys
|
54 |
|
55 |
-
|
56 |
-
img_size = (32, 128)
|
57 |
-
charset = r"0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~"
|
58 |
|
59 |
-
doc_parser = DocumentParserModel(
|
60 |
-
model_path=model_path,
|
61 |
-
img_size=img_size,
|
62 |
-
charset=charset,
|
63 |
-
)
|
64 |
if len(sys.argv) > 1:
|
65 |
-
|
66 |
-
result = doc_parser.predict_text(
|
67 |
print(result)
|
68 |
else:
|
69 |
-
print("Please provide an image
|
|
|
7 |
from torchvision import transforms as T
|
8 |
from tokenizer_base import Tokenizer
|
9 |
from PIL import Image
|
10 |
+
from huggingface_hub import hf_hub_download, try_to_load_from_cache
|
11 |
|
12 |
|
13 |
class DocumentParserModel:
|
14 |
+
def __init__(self):
|
15 |
+
charset = r"0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~"
|
16 |
+
img_size = (32, 128)
|
17 |
+
|
18 |
+
self.tokenizer_base = Tokenizer(charset)
|
19 |
self.transform = self.create_transform_pipeline(img_size)
|
20 |
+
self.ort_session = self.initialize_onnx_model()
|
21 |
|
22 |
def create_transform_pipeline(self, img_size):
|
23 |
transforms = [
|
|
|
27 |
]
|
28 |
return T.Compose(transforms)
|
29 |
|
30 |
+
def initialize_onnx_model(self):
|
31 |
+
repo_id = "stevenchang/captcha"
|
32 |
+
filename = "captcha.onnx"
|
33 |
+
|
34 |
+
filepath = try_to_load_from_cache(repo_id, filename)
|
35 |
+
|
36 |
+
if isinstance(filepath, str):
|
37 |
+
model_file = filepath
|
38 |
+
else:
|
39 |
+
model_file = result = hf_hub_download(repo_id, filename)
|
40 |
+
|
41 |
+
onnx_model = onnx.load(model_file)
|
42 |
onnx.checker.check_model(onnx_model)
|
43 |
+
return rt.InferenceSession(model_file)
|
44 |
|
45 |
def load_image_from_base64(self, base64_string):
|
46 |
img_data = base64.b64decode(base64_string)
|
47 |
image_buffer = io.BytesIO(img_data)
|
|
|
|
|
48 |
|
|
|
49 |
try:
|
50 |
+
image = Image.open(image_buffer)
|
51 |
+
return image
|
|
|
|
|
|
|
|
|
|
|
52 |
except IOError:
|
53 |
+
print(f"Error: Cannot open image {image_blob}")
|
54 |
return None
|
55 |
|
56 |
+
def predict_text(self, image_blob):
|
57 |
+
with self.load_image_from_base64(image_blob) as img_org:
|
58 |
+
x = self.transform(img_org.convert("RGB")).unsqueeze(0)
|
59 |
+
ort_inputs = {self.ort_session.get_inputs()[0].name: x.cpu().numpy()}
|
60 |
+
logits = self.ort_session.run(None, ort_inputs)[0]
|
61 |
+
probs = torch.tensor(logits).softmax(-1)
|
62 |
+
preds, _ = self.tokenizer_base.decode(probs)
|
63 |
+
return preds[0]
|
64 |
+
|
65 |
|
66 |
if __name__ == "__main__":
|
67 |
import sys
|
68 |
|
69 |
+
doc_parser = DocumentParserModel()
|
|
|
|
|
70 |
|
|
|
|
|
|
|
|
|
|
|
71 |
if len(sys.argv) > 1:
|
72 |
+
image_blob = sys.argv[1]
|
73 |
+
result = doc_parser.predict_text(image_blob)
|
74 |
print(result)
|
75 |
else:
|
76 |
+
print("Please provide an image blob.")
|
captcha.onnx
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:2a672587ee82eb010dbef54dd0a38e99625293608ed4068c4bd20ebe467fede4
|
3 |
-
size 95401304
|
|
|
|
|
|
|
|
handler.py
CHANGED
@@ -6,14 +6,7 @@ from app import DocumentParserModel
|
|
6 |
LOGGER = logging.getLogger()
|
7 |
LOGGER.setLevel(logging.INFO)
|
8 |
|
9 |
-
|
10 |
-
img_size = (32, 128)
|
11 |
-
charset = r"0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~"
|
12 |
-
model = DocumentParserModel(
|
13 |
-
model_path=model_path,
|
14 |
-
img_size=img_size,
|
15 |
-
charset=charset,
|
16 |
-
)
|
17 |
|
18 |
|
19 |
def lambda_handle(event, context):
|
@@ -22,4 +15,4 @@ def lambda_handle(event, context):
|
|
22 |
LOGGER.info("No ML work to do. Just staying warm...")
|
23 |
return "Keeping Lambda warm"
|
24 |
|
25 |
-
return {"statusCode": 200, "
|
|
|
6 |
LOGGER = logging.getLogger()
|
7 |
LOGGER.setLevel(logging.INFO)
|
8 |
|
9 |
+
model = DocumentParserModel()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
|
11 |
|
12 |
def lambda_handle(event, context):
|
|
|
15 |
LOGGER.info("No ML work to do. Just staying warm...")
|
16 |
return "Keeping Lambda warm"
|
17 |
|
18 |
+
return {"statusCode": 200, "result": model.predict_text(image_blob=event["image_blob"])}
|
requirements.txt
CHANGED
@@ -3,3 +3,4 @@ torchvision==0.12.0
|
|
3 |
onnx==1.16.0
|
4 |
onnxruntime==1.16.*
|
5 |
Pillow==10.0.0
|
|
|
|
3 |
onnx==1.16.0
|
4 |
onnxruntime==1.16.*
|
5 |
Pillow==10.0.0
|
6 |
+
huggingface_hub==0.21.4
|
serverless.yml
CHANGED
@@ -29,7 +29,7 @@ functions:
|
|
29 |
environment:
|
30 |
# On Lambda, the default location is not writable. Only the "/tmp" folder is writable. Therefore, we need to set the cache location inside "/tmp".
|
31 |
TORCH_HOME: /tmp/.ml_cache
|
32 |
-
|
33 |
custom:
|
34 |
warmup:
|
35 |
MLModelWarmer:
|
|
|
29 |
environment:
|
30 |
# On Lambda, the default location is not writable. Only the "/tmp" folder is writable. Therefore, we need to set the cache location inside "/tmp".
|
31 |
TORCH_HOME: /tmp/.ml_cache
|
32 |
+
HF_HUB_CACHE: /tmp/.ml_cache/huggingface
|
33 |
custom:
|
34 |
warmup:
|
35 |
MLModelWarmer:
|