Spaces:
Running
on
Zero
Running
on
Zero
Upload 1110 files
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- stf/.DS_Store +0 -0
- stf/089.npz +3 -0
- stf/089.pth +3 -0
- stf/stf-api-alternative/.gitignore +160 -0
- stf/stf-api-alternative/.ipynb_checkpoints/README-checkpoint.md +1 -0
- stf/stf-api-alternative/.ipynb_checkpoints/poetry-checkpoint.lock +0 -0
- stf/stf-api-alternative/.ipynb_checkpoints/pyproject-checkpoint.toml +35 -0
- stf/stf-api-alternative/README.md +1 -0
- stf/stf-api-alternative/poetry.lock +0 -0
- stf/stf-api-alternative/pyproject.toml +35 -0
- stf/stf-api-alternative/pytriton/.flake8 +19 -0
- stf/stf-api-alternative/pytriton/.github/ISSUE_TEMPLATE/bug_report.md +83 -0
- stf/stf-api-alternative/pytriton/.github/ISSUE_TEMPLATE/feature_request.md +20 -0
- stf/stf-api-alternative/pytriton/.github/workflows/stale.yaml +35 -0
- stf/stf-api-alternative/pytriton/.gitignore +330 -0
- stf/stf-api-alternative/pytriton/.pre-commit-config.yaml +76 -0
- stf/stf-api-alternative/pytriton/CHANGELOG.md +239 -0
- stf/stf-api-alternative/pytriton/CONTRIBUTING.md +203 -0
- stf/stf-api-alternative/pytriton/COPYRIGHT +13 -0
- stf/stf-api-alternative/pytriton/LICENSE +174 -0
- stf/stf-api-alternative/pytriton/Makefile +124 -0
- stf/stf-api-alternative/pytriton/README.md +343 -0
- stf/stf-api-alternative/pytriton/build/lib/pytriton/__init__.py +27 -0
- stf/stf-api-alternative/pytriton/build/lib/pytriton/__main__.py +218 -0
- stf/stf-api-alternative/pytriton/build/lib/pytriton/check/__init__.py +14 -0
- stf/stf-api-alternative/pytriton/build/lib/pytriton/check/add_sub.py +139 -0
- stf/stf-api-alternative/pytriton/build/lib/pytriton/check/env_checks.py +201 -0
- stf/stf-api-alternative/pytriton/build/lib/pytriton/check/utils.py +555 -0
- stf/stf-api-alternative/pytriton/build/lib/pytriton/client/__init__.py +22 -0
- stf/stf-api-alternative/pytriton/build/lib/pytriton/client/asyncio_utils.py +308 -0
- stf/stf-api-alternative/pytriton/build/lib/pytriton/client/client.py +2033 -0
- stf/stf-api-alternative/pytriton/build/lib/pytriton/client/exceptions.py +92 -0
- stf/stf-api-alternative/pytriton/build/lib/pytriton/client/utils.py +384 -0
- stf/stf-api-alternative/pytriton/build/lib/pytriton/client/warnings.py +26 -0
- stf/stf-api-alternative/pytriton/build/lib/pytriton/constants.py +31 -0
- stf/stf-api-alternative/pytriton/build/lib/pytriton/decorators.py +678 -0
- stf/stf-api-alternative/pytriton/build/lib/pytriton/exceptions.py +80 -0
- stf/stf-api-alternative/pytriton/build/lib/pytriton/model_config/__init__.py +17 -0
- stf/stf-api-alternative/pytriton/build/lib/pytriton/model_config/common.py +93 -0
- stf/stf-api-alternative/pytriton/build/lib/pytriton/model_config/generator.py +284 -0
- stf/stf-api-alternative/pytriton/build/lib/pytriton/model_config/model_config.py +43 -0
- stf/stf-api-alternative/pytriton/build/lib/pytriton/model_config/parser.py +258 -0
- stf/stf-api-alternative/pytriton/build/lib/pytriton/model_config/tensor.py +57 -0
- stf/stf-api-alternative/pytriton/build/lib/pytriton/model_config/triton_model_config.py +68 -0
- stf/stf-api-alternative/pytriton/build/lib/pytriton/models/__init__.py +14 -0
- stf/stf-api-alternative/pytriton/build/lib/pytriton/models/manager.py +147 -0
- stf/stf-api-alternative/pytriton/build/lib/pytriton/models/model.py +335 -0
- stf/stf-api-alternative/pytriton/build/lib/pytriton/proxy/__init__.py +14 -0
- stf/stf-api-alternative/pytriton/build/lib/pytriton/proxy/communication.py +555 -0
- stf/stf-api-alternative/pytriton/build/lib/pytriton/proxy/data.py +1133 -0
stf/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
stf/089.npz
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:9ce3fb07d8d15495eab879b47413c6b86bce114ca9ecd375b45b54777cf0e175
|
3 |
+
size 522605028
|
stf/089.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:ba4eb3437019d77abed141d60bcb5489b664f494cf965eec0bccf304c3d79b2a
|
3 |
+
size 1567401123
|
stf/stf-api-alternative/.gitignore
ADDED
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Byte-compiled / optimized / DLL files
|
2 |
+
__pycache__/
|
3 |
+
*.py[cod]
|
4 |
+
*$py.class
|
5 |
+
|
6 |
+
# C extensions
|
7 |
+
*.so
|
8 |
+
|
9 |
+
# Distribution / packaging
|
10 |
+
.Python
|
11 |
+
build/
|
12 |
+
develop-eggs/
|
13 |
+
dist/
|
14 |
+
downloads/
|
15 |
+
eggs/
|
16 |
+
.eggs/
|
17 |
+
lib/
|
18 |
+
lib64/
|
19 |
+
parts/
|
20 |
+
sdist/
|
21 |
+
var/
|
22 |
+
wheels/
|
23 |
+
share/python-wheels/
|
24 |
+
*.egg-info/
|
25 |
+
.installed.cfg
|
26 |
+
*.egg
|
27 |
+
MANIFEST
|
28 |
+
|
29 |
+
# PyInstaller
|
30 |
+
# Usually these files are written by a python script from a template
|
31 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
32 |
+
*.manifest
|
33 |
+
*.spec
|
34 |
+
|
35 |
+
# Installer logs
|
36 |
+
pip-log.txt
|
37 |
+
pip-delete-this-directory.txt
|
38 |
+
|
39 |
+
# Unit test / coverage reports
|
40 |
+
htmlcov/
|
41 |
+
.tox/
|
42 |
+
.nox/
|
43 |
+
.coverage
|
44 |
+
.coverage.*
|
45 |
+
.cache
|
46 |
+
nosetests.xml
|
47 |
+
coverage.xml
|
48 |
+
*.cover
|
49 |
+
*.py,cover
|
50 |
+
.hypothesis/
|
51 |
+
.pytest_cache/
|
52 |
+
cover/
|
53 |
+
|
54 |
+
# Translations
|
55 |
+
*.mo
|
56 |
+
*.pot
|
57 |
+
|
58 |
+
# Django stuff:
|
59 |
+
*.log
|
60 |
+
local_settings.py
|
61 |
+
db.sqlite3
|
62 |
+
db.sqlite3-journal
|
63 |
+
|
64 |
+
# Flask stuff:
|
65 |
+
instance/
|
66 |
+
.webassets-cache
|
67 |
+
|
68 |
+
# Scrapy stuff:
|
69 |
+
.scrapy
|
70 |
+
|
71 |
+
# Sphinx documentation
|
72 |
+
docs/_build/
|
73 |
+
|
74 |
+
# PyBuilder
|
75 |
+
.pybuilder/
|
76 |
+
target/
|
77 |
+
|
78 |
+
# Jupyter Notebook
|
79 |
+
.ipynb_checkpoints
|
80 |
+
|
81 |
+
# IPython
|
82 |
+
profile_default/
|
83 |
+
ipython_config.py
|
84 |
+
|
85 |
+
# pyenv
|
86 |
+
# For a library or package, you might want to ignore these files since the code is
|
87 |
+
# intended to run in multiple environments; otherwise, check them in:
|
88 |
+
# .python-version
|
89 |
+
|
90 |
+
# pipenv
|
91 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
92 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
93 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
94 |
+
# install all needed dependencies.
|
95 |
+
#Pipfile.lock
|
96 |
+
|
97 |
+
# poetry
|
98 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
99 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
100 |
+
# commonly ignored for libraries.
|
101 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
102 |
+
#poetry.lock
|
103 |
+
|
104 |
+
# pdm
|
105 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
106 |
+
#pdm.lock
|
107 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
108 |
+
# in version control.
|
109 |
+
# https://pdm.fming.dev/#use-with-ide
|
110 |
+
.pdm.toml
|
111 |
+
|
112 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
113 |
+
__pypackages__/
|
114 |
+
|
115 |
+
# Celery stuff
|
116 |
+
celerybeat-schedule
|
117 |
+
celerybeat.pid
|
118 |
+
|
119 |
+
# SageMath parsed files
|
120 |
+
*.sage.py
|
121 |
+
|
122 |
+
# Environments
|
123 |
+
.env
|
124 |
+
.venv
|
125 |
+
env/
|
126 |
+
venv/
|
127 |
+
ENV/
|
128 |
+
env.bak/
|
129 |
+
venv.bak/
|
130 |
+
|
131 |
+
# Spyder project settings
|
132 |
+
.spyderproject
|
133 |
+
.spyproject
|
134 |
+
|
135 |
+
# Rope project settings
|
136 |
+
.ropeproject
|
137 |
+
|
138 |
+
# mkdocs documentation
|
139 |
+
/site
|
140 |
+
|
141 |
+
# mypy
|
142 |
+
.mypy_cache/
|
143 |
+
.dmypy.json
|
144 |
+
dmypy.json
|
145 |
+
|
146 |
+
# Pyre type checker
|
147 |
+
.pyre/
|
148 |
+
|
149 |
+
# pytype static type analyzer
|
150 |
+
.pytype/
|
151 |
+
|
152 |
+
# Cython debug symbols
|
153 |
+
cython_debug/
|
154 |
+
|
155 |
+
# PyCharm
|
156 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
157 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
158 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
159 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
160 |
+
#.idea/
|
stf/stf-api-alternative/.ipynb_checkpoints/README-checkpoint.md
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
stf_api와 동일한 기능을 수행하는 라이브러리
|
stf/stf-api-alternative/.ipynb_checkpoints/poetry-checkpoint.lock
ADDED
The diff for this file is too large to render.
See raw diff
|
|
stf/stf-api-alternative/.ipynb_checkpoints/pyproject-checkpoint.toml
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[tool.poetry]
|
2 |
+
name = "stf-alternative"
|
3 |
+
version = "0.1.0"
|
4 |
+
description = "alternative version of stf-api"
|
5 |
+
authors = ["Kim Minjong <make.dirty.code@gmail.com>"]
|
6 |
+
readme = "README.md"
|
7 |
+
packages = [
|
8 |
+
{include = "stf_alternative", from="src"}
|
9 |
+
]
|
10 |
+
|
11 |
+
[tool.poetry.dependencies]
|
12 |
+
python = "^3.10"
|
13 |
+
librosa = "0.8.1"
|
14 |
+
imageio = "2.13.5"
|
15 |
+
imageio-ffmpeg = "0.4.5"
|
16 |
+
Pillow = "9.1.0"
|
17 |
+
tqdm = "4.64.0"
|
18 |
+
numpy = "1.22.4"
|
19 |
+
addict = "2.4.0"
|
20 |
+
scipy = "1.12.0"
|
21 |
+
pandas = "1.3.5"
|
22 |
+
face_alignment = "1.3.5"
|
23 |
+
moviepy = "1.0.3"
|
24 |
+
transformers = "4.29.2"
|
25 |
+
facenet_pytorch = "2.5.2"
|
26 |
+
ffmpeg-python = "^0.2"
|
27 |
+
pydub = "^0.25"
|
28 |
+
av = "^11.0.0"
|
29 |
+
nvidia-pytriton = {extras = ["client"], version = "^0.4.2"}
|
30 |
+
asyncstdlib = "^3.10.9"
|
31 |
+
|
32 |
+
|
33 |
+
[build-system]
|
34 |
+
requires = ["poetry-core"]
|
35 |
+
build-backend = "poetry.core.masonry.api"
|
stf/stf-api-alternative/README.md
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
stf_api와 동일한 기능을 수행하는 라이브러리
|
stf/stf-api-alternative/poetry.lock
ADDED
The diff for this file is too large to render.
See raw diff
|
|
stf/stf-api-alternative/pyproject.toml
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[tool.poetry]
|
2 |
+
name = "stf-alternative"
|
3 |
+
version = "0.1.0"
|
4 |
+
description = "alternative version of stf-api"
|
5 |
+
authors = ["Kim Minjong <make.dirty.code@gmail.com>"]
|
6 |
+
readme = "README.md"
|
7 |
+
packages = [
|
8 |
+
{include = "stf_alternative", from="src"}
|
9 |
+
]
|
10 |
+
|
11 |
+
[tool.poetry.dependencies]
|
12 |
+
python = "^3.10"
|
13 |
+
librosa = "0.8.1"
|
14 |
+
imageio = "2.13.5"
|
15 |
+
imageio-ffmpeg = "0.4.5"
|
16 |
+
Pillow = "9.1.0"
|
17 |
+
tqdm = "4.64.0"
|
18 |
+
numpy = "1.24.4"
|
19 |
+
addict = "2.4.0"
|
20 |
+
scipy = "1.12.0"
|
21 |
+
pandas = "1.3.5"
|
22 |
+
face_alignment = "1.3.5"
|
23 |
+
moviepy = "1.0.3"
|
24 |
+
transformers = "4.29.2"
|
25 |
+
facenet_pytorch = "2.5.2"
|
26 |
+
ffmpeg-python = "^0.2"
|
27 |
+
pydub = "^0.25"
|
28 |
+
av = "^11.0.0"
|
29 |
+
nvidia-pytriton = {extras = ["client"], version = "^0.4.2"}
|
30 |
+
asyncstdlib = "^3.10.9"
|
31 |
+
|
32 |
+
|
33 |
+
[build-system]
|
34 |
+
requires = ["poetry-core"]
|
35 |
+
build-backend = "poetry.core.masonry.api"
|
stf/stf-api-alternative/pytriton/.flake8
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
[flake8]
|
15 |
+
exclude = docs,experiments,blueprints,pytriton/tritonserver,sandbox
|
16 |
+
ignore = E203, E266, E501, W503
|
17 |
+
max-line-length = 120
|
18 |
+
max-complexity = 18
|
19 |
+
select = B,C,D,E,F,W,T,N
|
stf/stf-api-alternative/pytriton/.github/ISSUE_TEMPLATE/bug_report.md
ADDED
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
name: Bug report
|
3 |
+
about: Create a report to help us improve
|
4 |
+
title: ''
|
5 |
+
labels: ''
|
6 |
+
assignees: ''
|
7 |
+
|
8 |
+
---
|
9 |
+
|
10 |
+
**Description**
|
11 |
+
|
12 |
+
A clear and concise description of the bug.
|
13 |
+
|
14 |
+
**To reproduce**
|
15 |
+
|
16 |
+
If relevant, add a minimal example so that we can reproduce the error, if necessary, by running the code. For example:
|
17 |
+
|
18 |
+
```python
|
19 |
+
# server
|
20 |
+
from pytriton.decorators import batch
|
21 |
+
from pytriton.model_config import ModelConfig, Tensor
|
22 |
+
from pytriton.triton import Triton
|
23 |
+
|
24 |
+
@batch
|
25 |
+
def _infer_fn(**inputs):
|
26 |
+
...
|
27 |
+
results_dict = model(**inputs) # ex note: the bug is here, we expect to receive ...
|
28 |
+
...
|
29 |
+
# note: observing results_dict as dictionary of numpy arrays
|
30 |
+
return results_dict
|
31 |
+
|
32 |
+
|
33 |
+
with Triton() as triton:
|
34 |
+
triton.bind(
|
35 |
+
model_name="MyModel",
|
36 |
+
infer_func=_infer_fn,
|
37 |
+
inputs=[
|
38 |
+
Tensor(name="in1", dtype=np.float32, shape=(-1,)),
|
39 |
+
Tensor(name="in2", dtype=np.float32, shape=(-1,)),
|
40 |
+
],
|
41 |
+
outputs=[
|
42 |
+
Tensor(name="out1", dtype=np.float32, shape=(-1,)),
|
43 |
+
Tensor(name="out2", dtype=np.float32, shape=(-1,)),
|
44 |
+
],
|
45 |
+
config=ModelConfig(max_batch_size=128),
|
46 |
+
)
|
47 |
+
triton.serve()
|
48 |
+
```
|
49 |
+
|
50 |
+
```python
|
51 |
+
# client
|
52 |
+
import numpy as np
|
53 |
+
from pytriton.client import ModelClient
|
54 |
+
|
55 |
+
batch_size = 2
|
56 |
+
in1_batch = np.ones((batch_size, 1), dtype=np.float32)
|
57 |
+
in2_batch = np.ones((batch_size, 1), dtype=np.float32)
|
58 |
+
|
59 |
+
with ModelClient("localhost", "MyModel") as client:
|
60 |
+
result_batch = client.infer_batch(in1_batch, in2_batch)
|
61 |
+
```
|
62 |
+
|
63 |
+
**Observed results and expected behavior**
|
64 |
+
|
65 |
+
Please describe the observed results as well as the expected results.
|
66 |
+
If possible, attach relevant log output to help analyze your problem.
|
67 |
+
If an error is raised, please paste the full traceback of the exception.
|
68 |
+
|
69 |
+
```
|
70 |
+
|
71 |
+
```
|
72 |
+
|
73 |
+
**Environment**
|
74 |
+
|
75 |
+
- OS/container version: [e.g., container nvcr.io/nvidia/pytorch:23.02-py3 / virtual machine with Ubuntu 22.04]
|
76 |
+
- glibc version: [e.g., 2.31; can be checked with `ldd --version`]
|
77 |
+
- Python interpreter distribution and version: [e.g., CPython 3.8 / conda 4.7.12 with Python 3.8 environment]
|
78 |
+
- pip version: [e.g., 23.1.2]
|
79 |
+
- PyTriton version: [e.g., 0.1.4 / custom build from source at commit ______]
|
80 |
+
- Deployment details: [e.g., multi-node multi-GPU setup on GKE / multi-GPU single-node setup in Jupyter Notebook]
|
81 |
+
|
82 |
+
**Additional context**
|
83 |
+
Add any other context about the problem here.
|
stf/stf-api-alternative/pytriton/.github/ISSUE_TEMPLATE/feature_request.md
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
name: Feature request
|
3 |
+
about: Suggest an idea for this project
|
4 |
+
title: ''
|
5 |
+
labels: ''
|
6 |
+
assignees: ''
|
7 |
+
|
8 |
+
---
|
9 |
+
|
10 |
+
**Is your feature request related to a problem? Please describe.**
|
11 |
+
A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
|
12 |
+
|
13 |
+
**Describe the solution you'd like**
|
14 |
+
A clear and concise description of what you want to happen.
|
15 |
+
|
16 |
+
**Describe alternatives you've considered**
|
17 |
+
A clear and concise description of any alternative solutions or features you've considered.
|
18 |
+
|
19 |
+
**Additional context**
|
20 |
+
Add any other context or screenshots about the feature request here.
|
stf/stf-api-alternative/pytriton/.github/workflows/stale.yaml
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
name: 'Close stale issues and PRs'
|
15 |
+
on:
|
16 |
+
schedule:
|
17 |
+
- cron: "30 1 * * *"
|
18 |
+
jobs:
|
19 |
+
stale:
|
20 |
+
if: github.repository_owner == 'triton-inference-server'
|
21 |
+
runs-on: ubuntu-latest
|
22 |
+
permissions:
|
23 |
+
issues: write
|
24 |
+
pull-requests: write
|
25 |
+
steps:
|
26 |
+
- uses: actions/stale@v8
|
27 |
+
with:
|
28 |
+
days-before-stale: 21
|
29 |
+
days-before-close: 7
|
30 |
+
stale-issue-message: 'This issue is stale because it has been open 21 days with no activity. Remove stale label or comment or this will be closed in 7 days.'
|
31 |
+
stale-pr-message: 'This PR is stale because it has been open 21 days with no activity. Remove stale label or comment or this will be closed in 7 days.'
|
32 |
+
close-issue-message: 'This issue was closed because it has been stalled for 7 days with no activity.'
|
33 |
+
close-pr-message: 'This PR was closed because it has been stalled for 7 days with no activity.'
|
34 |
+
exempt-issue-labels: 'non-stale'
|
35 |
+
exempt-pr-labels: 'non-stale'
|
stf/stf-api-alternative/pytriton/.gitignore
ADDED
@@ -0,0 +1,330 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# Created by https://www.toptal.com/developers/gitignore/api/pycharm+all,visualstudiocode,python,direnv,vim
|
15 |
+
# Edit at https://www.toptal.com/developers/gitignore?templates=pycharm+all,visualstudiocode,python,direnv,vim
|
16 |
+
|
17 |
+
### direnv ###
|
18 |
+
.direnv
|
19 |
+
.envrc
|
20 |
+
|
21 |
+
### PyCharm+all ###
|
22 |
+
# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider
|
23 |
+
# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
|
24 |
+
|
25 |
+
# User-specific stuff
|
26 |
+
.idea/**/workspace.xml
|
27 |
+
.idea/**/tasks.xml
|
28 |
+
.idea/**/usage.statistics.xml
|
29 |
+
.idea/**/dictionaries
|
30 |
+
.idea/**/shelf
|
31 |
+
|
32 |
+
# AWS User-specific
|
33 |
+
.idea/**/aws.xml
|
34 |
+
|
35 |
+
# Generated files
|
36 |
+
.idea/**/contentModel.xml
|
37 |
+
|
38 |
+
# Sensitive or high-churn files
|
39 |
+
.idea/**/dataSources/
|
40 |
+
.idea/**/dataSources.ids
|
41 |
+
.idea/**/dataSources.local.xml
|
42 |
+
.idea/**/sqlDataSources.xml
|
43 |
+
.idea/**/dynamic.xml
|
44 |
+
.idea/**/uiDesigner.xml
|
45 |
+
.idea/**/dbnavigator.xml
|
46 |
+
|
47 |
+
# Gradle
|
48 |
+
.idea/**/gradle.xml
|
49 |
+
.idea/**/libraries
|
50 |
+
|
51 |
+
# Gradle and Maven with auto-import
|
52 |
+
# When using Gradle or Maven with auto-import, you should exclude module files,
|
53 |
+
# since they will be recreated, and may cause churn. Uncomment if using
|
54 |
+
# auto-import.
|
55 |
+
# .idea/artifacts
|
56 |
+
# .idea/compiler.xml
|
57 |
+
# .idea/jarRepositories.xml
|
58 |
+
# .idea/modules.xml
|
59 |
+
# .idea/*.iml
|
60 |
+
# .idea/modules
|
61 |
+
# *.iml
|
62 |
+
# *.ipr
|
63 |
+
|
64 |
+
# CMake
|
65 |
+
cmake-build-*/
|
66 |
+
|
67 |
+
# Mongo Explorer plugin
|
68 |
+
.idea/**/mongoSettings.xml
|
69 |
+
|
70 |
+
# File-based project format
|
71 |
+
*.iws
|
72 |
+
|
73 |
+
# IntelliJ
|
74 |
+
out/
|
75 |
+
|
76 |
+
# mpeltonen/sbt-idea plugin
|
77 |
+
.idea_modules/
|
78 |
+
|
79 |
+
# JIRA plugin
|
80 |
+
atlassian-ide-plugin.xml
|
81 |
+
|
82 |
+
# Cursive Clojure plugin
|
83 |
+
.idea/replstate.xml
|
84 |
+
|
85 |
+
# SonarLint plugin
|
86 |
+
.idea/sonarlint/
|
87 |
+
|
88 |
+
# Crashlytics plugin (for Android Studio and IntelliJ)
|
89 |
+
com_crashlytics_export_strings.xml
|
90 |
+
crashlytics.properties
|
91 |
+
crashlytics-build.properties
|
92 |
+
fabric.properties
|
93 |
+
|
94 |
+
# Editor-based Rest Client
|
95 |
+
.idea/httpRequests
|
96 |
+
|
97 |
+
# Android studio 3.1+ serialized cache file
|
98 |
+
.idea/caches/build_file_checksums.ser
|
99 |
+
|
100 |
+
### PyCharm+all Patch ###
|
101 |
+
# Ignore everything but code style settings and run configurations
|
102 |
+
# that are supposed to be shared within teams.
|
103 |
+
|
104 |
+
.idea/*
|
105 |
+
|
106 |
+
!.idea/codeStyles
|
107 |
+
!.idea/runConfigurations
|
108 |
+
|
109 |
+
### Python ###
|
110 |
+
# Byte-compiled / optimized / DLL files
|
111 |
+
__pycache__/
|
112 |
+
*.py[cod]
|
113 |
+
*$py.class
|
114 |
+
|
115 |
+
# C extensions
|
116 |
+
*.so
|
117 |
+
|
118 |
+
# Distribution / packaging
|
119 |
+
.Python
|
120 |
+
build/
|
121 |
+
develop-eggs/
|
122 |
+
dist/
|
123 |
+
downloads/
|
124 |
+
eggs/
|
125 |
+
.eggs/
|
126 |
+
lib/
|
127 |
+
lib64/
|
128 |
+
parts/
|
129 |
+
sdist/
|
130 |
+
var/
|
131 |
+
wheels/
|
132 |
+
share/python-wheels/
|
133 |
+
*.egg-info/
|
134 |
+
.installed.cfg
|
135 |
+
*.egg
|
136 |
+
MANIFEST
|
137 |
+
|
138 |
+
# PyInstaller
|
139 |
+
# Usually these files are written by a python script from a template
|
140 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
141 |
+
*.manifest
|
142 |
+
*.spec
|
143 |
+
|
144 |
+
# Installer logs
|
145 |
+
pip-log.txt
|
146 |
+
pip-delete-this-directory.txt
|
147 |
+
|
148 |
+
# Unit test / coverage reports
|
149 |
+
htmlcov/
|
150 |
+
.tox/
|
151 |
+
.nox/
|
152 |
+
.coverage
|
153 |
+
.coverage.*
|
154 |
+
.cache
|
155 |
+
nosetests.xml
|
156 |
+
coverage.xml
|
157 |
+
*.cover
|
158 |
+
*.py,cover
|
159 |
+
.hypothesis/
|
160 |
+
.pytest_cache/
|
161 |
+
cover/
|
162 |
+
|
163 |
+
# Translations
|
164 |
+
*.mo
|
165 |
+
*.pot
|
166 |
+
|
167 |
+
# Django stuff:
|
168 |
+
*.log
|
169 |
+
local_settings.py
|
170 |
+
db.sqlite3
|
171 |
+
db.sqlite3-journal
|
172 |
+
|
173 |
+
# Flask stuff:
|
174 |
+
instance/
|
175 |
+
.webassets-cache
|
176 |
+
|
177 |
+
# Scrapy stuff:
|
178 |
+
.scrapy
|
179 |
+
|
180 |
+
# Sphinx documentation
|
181 |
+
docs/_build/
|
182 |
+
|
183 |
+
# PyBuilder
|
184 |
+
.pybuilder/
|
185 |
+
target/
|
186 |
+
|
187 |
+
# Jupyter Notebook
|
188 |
+
.ipynb_checkpoints
|
189 |
+
|
190 |
+
# IPython
|
191 |
+
profile_default/
|
192 |
+
ipython_config.py
|
193 |
+
|
194 |
+
# pyenv
|
195 |
+
# For a library or package, you might want to ignore these files since the code is
|
196 |
+
# intended to run in multiple environments; otherwise, check them in:
|
197 |
+
# .python-version
|
198 |
+
|
199 |
+
# pipenv
|
200 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
201 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
202 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
203 |
+
# install all needed dependencies.
|
204 |
+
#Pipfile.lock
|
205 |
+
|
206 |
+
# poetry
|
207 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
208 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
209 |
+
# commonly ignored for libraries.
|
210 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
211 |
+
#poetry.lock
|
212 |
+
|
213 |
+
# pdm
|
214 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
215 |
+
#pdm.lock
|
216 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
217 |
+
# in version control.
|
218 |
+
# https://pdm.fming.dev/#use-with-ide
|
219 |
+
.pdm.toml
|
220 |
+
|
221 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
222 |
+
__pypackages__/
|
223 |
+
|
224 |
+
# Celery stuff
|
225 |
+
celerybeat-schedule
|
226 |
+
celerybeat.pid
|
227 |
+
|
228 |
+
# SageMath parsed files
|
229 |
+
*.sage.py
|
230 |
+
|
231 |
+
# Environments
|
232 |
+
.env
|
233 |
+
.venv
|
234 |
+
env/
|
235 |
+
venv/
|
236 |
+
ENV/
|
237 |
+
env.bak/
|
238 |
+
venv.bak/
|
239 |
+
|
240 |
+
# Spyder project settings
|
241 |
+
.spyderproject
|
242 |
+
.spyproject
|
243 |
+
|
244 |
+
# Rope project settings
|
245 |
+
.ropeproject
|
246 |
+
|
247 |
+
# mkdocs documentation
|
248 |
+
/site
|
249 |
+
|
250 |
+
# mypy
|
251 |
+
.mypy_cache/
|
252 |
+
.dmypy.json
|
253 |
+
dmypy.json
|
254 |
+
|
255 |
+
# Pyre type checker
|
256 |
+
.pyre/
|
257 |
+
|
258 |
+
# pytype static type analyzer
|
259 |
+
.pytype/
|
260 |
+
|
261 |
+
# Cython debug symbols
|
262 |
+
cython_debug/
|
263 |
+
|
264 |
+
# PyCharm
|
265 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
266 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
267 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
268 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
269 |
+
#.idea/
|
270 |
+
|
271 |
+
### Python Patch ###
|
272 |
+
# Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration
|
273 |
+
poetry.toml
|
274 |
+
|
275 |
+
# ruff
|
276 |
+
.ruff_cache/
|
277 |
+
|
278 |
+
# LSP config files
|
279 |
+
pyrightconfig.json
|
280 |
+
|
281 |
+
### Vim ###
|
282 |
+
# Swap
|
283 |
+
[._]*.s[a-v][a-z]
|
284 |
+
!*.svg # comment out if you don't need vector files
|
285 |
+
[._]*.sw[a-p]
|
286 |
+
[._]s[a-rt-v][a-z]
|
287 |
+
[._]ss[a-gi-z]
|
288 |
+
[._]sw[a-p]
|
289 |
+
|
290 |
+
# Session
|
291 |
+
Session.vim
|
292 |
+
Sessionx.vim
|
293 |
+
|
294 |
+
# Temporary
|
295 |
+
.netrwhist
|
296 |
+
*~
|
297 |
+
# Auto-generated tag files
|
298 |
+
tags
|
299 |
+
# Persistent undo
|
300 |
+
[._]*.un~
|
301 |
+
|
302 |
+
### VisualStudioCode ###
|
303 |
+
.vscode/*
|
304 |
+
!.vscode/settings.json
|
305 |
+
!.vscode/tasks.json
|
306 |
+
!.vscode/launch.json
|
307 |
+
!.vscode/extensions.json
|
308 |
+
!.vscode/*.code-snippets
|
309 |
+
|
310 |
+
# Local History for Visual Studio Code
|
311 |
+
.history/
|
312 |
+
|
313 |
+
# Built Visual Studio Code Extensions
|
314 |
+
*.vsix
|
315 |
+
|
316 |
+
### VisualStudioCode Patch ###
|
317 |
+
# Ignore all local history of files
|
318 |
+
.history
|
319 |
+
.ionide
|
320 |
+
|
321 |
+
# End of https://www.toptal.com/developers/gitignore/api/pycharm+all,visualstudiocode,python,direnv,vim
|
322 |
+
pytriton/tritonserver
|
323 |
+
docs/CHANGELOG.md
|
324 |
+
docs/CONTRIBUTING.md
|
325 |
+
docs/LICENSE.md
|
326 |
+
docs/examples.md
|
327 |
+
|
328 |
+
### VisualStudioCode+all ##
|
329 |
+
.vscode
|
330 |
+
.devcontainer
|
stf/stf-api-alternative/pytriton/.pre-commit-config.yaml
ADDED
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
exclude: kubernetes
|
15 |
+
repos:
|
16 |
+
- repo: https://github.com/ambv/black
|
17 |
+
rev: 23.11.0
|
18 |
+
hooks:
|
19 |
+
- id: black
|
20 |
+
- repo: https://github.com/pycqa/isort
|
21 |
+
rev: 5.12.0
|
22 |
+
hooks:
|
23 |
+
- id: isort
|
24 |
+
name: isort (python)
|
25 |
+
- repo: https://github.com/pre-commit/pre-commit-hooks
|
26 |
+
rev: v4.5.0
|
27 |
+
hooks:
|
28 |
+
- id: check-docstring-first
|
29 |
+
- id: check-executables-have-shebangs
|
30 |
+
- id: check-json
|
31 |
+
- id: check-merge-conflict
|
32 |
+
- id: detect-private-key
|
33 |
+
- id: check-shebang-scripts-are-executable
|
34 |
+
- id: check-toml
|
35 |
+
- id: check-yaml
|
36 |
+
- id: debug-statements
|
37 |
+
- id: end-of-file-fixer
|
38 |
+
types: [python]
|
39 |
+
- id: fix-byte-order-marker
|
40 |
+
- id: no-commit-to-branch
|
41 |
+
- id: requirements-txt-fixer
|
42 |
+
- id: trailing-whitespace
|
43 |
+
exclude: setup.cfg
|
44 |
+
- id: mixed-line-ending
|
45 |
+
args: [--fix=lf]
|
46 |
+
- repo: https://github.com/asottile/pyupgrade
|
47 |
+
rev: v3.15.0
|
48 |
+
hooks:
|
49 |
+
- id: pyupgrade
|
50 |
+
args: [--py36-plus]
|
51 |
+
- repo: https://github.com/pycqa/flake8
|
52 |
+
rev: 6.1.0
|
53 |
+
hooks:
|
54 |
+
- id: flake8
|
55 |
+
additional_dependencies:
|
56 |
+
- flake8-bugbear
|
57 |
+
- flake8-comprehensions
|
58 |
+
- flake8-print
|
59 |
+
- mccabe
|
60 |
+
- pep8-naming
|
61 |
+
- pycodestyle
|
62 |
+
- pyflakes
|
63 |
+
- repo: https://github.com/pycqa/pydocstyle
|
64 |
+
rev: 6.3.0
|
65 |
+
hooks:
|
66 |
+
- id: pydocstyle
|
67 |
+
name: Run pydocstyle
|
68 |
+
args:
|
69 |
+
- --convention=google
|
70 |
+
exclude: '(?:tests|examples)\/.*'
|
71 |
+
additional_dependencies: ['toml']
|
72 |
+
- repo: https://github.com/thlorenz/doctoc
|
73 |
+
rev: v2.2.0
|
74 |
+
hooks:
|
75 |
+
- id: doctoc
|
76 |
+
args: [ --github, --update-only ]
|
stf/stf-api-alternative/pytriton/CHANGELOG.md
ADDED
@@ -0,0 +1,239 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<!--
|
2 |
+
Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
|
3 |
+
|
4 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
you may not use this file except in compliance with the License.
|
6 |
+
You may obtain a copy of the License at
|
7 |
+
|
8 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
|
10 |
+
Unless required by applicable law or agreed to in writing, software
|
11 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
See the License for the specific language governing permissions and
|
14 |
+
limitations under the License.
|
15 |
+
-->
|
16 |
+
|
17 |
+
# Changelog
|
18 |
+
|
19 |
+
## 0.4.2 (2023-12-05)
|
20 |
+
|
21 |
+
- New: You can create client from existing client instance or model configuration to avoid loading model configuration from server.
|
22 |
+
- New: Introduced warning system using the `warnings` module.
|
23 |
+
- Fix: Experimental client for decoupled models prevents sending another request, when responses from previous request are not consumed, blocks close until stream is stopped.
|
24 |
+
- Fix: Leak of ModelClient during Triton creation
|
25 |
+
- Fix: Fixed non-declared project dependencies (removed from use in code or added to package dependencies)
|
26 |
+
- Fix: Remote model is being unloaded from Triton when RemoteTriton is closed.
|
27 |
+
|
28 |
+
[//]: <> (put here on external component update with short summary what change or link to changelog)
|
29 |
+
|
30 |
+
- Version of [Triton Inference Server](https://github.com/triton-inference-server/) embedded in wheel: [2.39.0](https://github.com/triton-inference-server/server/releases/tag/v2.39.0)
|
31 |
+
|
32 |
+
## 0.4.1 (2023-11-09)
|
33 |
+
|
34 |
+
- New: Place where workspaces with temporary Triton model repositories and communication file sockets can be configured by `$PYTRITON_HOME` environment variable
|
35 |
+
- Fix: Recover handling `KeyboardInterrupt` in `triton.serve()`
|
36 |
+
- Fix: Remove limit for handling bytes dtype tensors
|
37 |
+
- Build scripts update
|
38 |
+
- Added support for arm64 platform builds
|
39 |
+
|
40 |
+
[//]: <> (put here on external component update with short summary what change or link to changelog)
|
41 |
+
|
42 |
+
- Version of [Triton Inference Server](https://github.com/triton-inference-server/) embedded in wheel: [2.39.0](https://github.com/triton-inference-server/server/releases/tag/v2.39.0)
|
43 |
+
|
44 |
+
## 0.4.0 (2023-10-20)
|
45 |
+
|
46 |
+
- New: Remote Mode - PyTriton can be used to connect to a remote Triton Inference Server
|
47 |
+
- Introduced RemoteTriton class which can be used to connect to a remote Triton Inference Server
|
48 |
+
running on the same machine, by passing triton url.
|
49 |
+
- Changed Triton lifecycle - now the Triton Inference Server is started while entering the context.
|
50 |
+
This allows to load models dynamically to the running server while calling the bind method.
|
51 |
+
It is still allowed to create Triton instance without entering the context and bind models before starting
|
52 |
+
the server (in this case the models are lazy loaded when calling run or serve method like it worked before).
|
53 |
+
- In RemoteTriton class, calling __enter__ or connect method connects to triton server, so we can safely load models
|
54 |
+
while binding inference functions (if RemoteTriton is used without context manager, models are lazy loaded
|
55 |
+
when calling connect or serve method).
|
56 |
+
- Change: `@batch` decorator raises a `ValueError` if any of the outputs have a different batch size than expected.
|
57 |
+
- fix: gevent resources leak in ``FuturesModelClient``
|
58 |
+
|
59 |
+
[//]: <> (put here on external component update with short summary what change or link to changelog)
|
60 |
+
|
61 |
+
- Version of [Triton Inference Server](https://github.com/triton-inference-server/) embedded in wheel: [2.36.0](https://github.com/triton-inference-server/server/releases/tag/v2.36.0)
|
62 |
+
|
63 |
+
## 0.3.1 (2023-09-26)
|
64 |
+
|
65 |
+
- Change: `KeyboardInterrupt` is now handled in `triton.serve()`. PyTriton hosting scripts return an exit code of 0 instead of 130 when they receive a SIGINT signal.
|
66 |
+
- Fix: Addressed potential instability in shared memory management.
|
67 |
+
|
68 |
+
[//]: <> (put here on external component update with short summary what change or link to changelog)
|
69 |
+
|
70 |
+
- Version of [Triton Inference Server](https://github.com/triton-inference-server/) embedded in wheel: [2.36.0](https://github.com/triton-inference-server/server/releases/tag/v2.36.0)
|
71 |
+
|
72 |
+
## 0.3.0 (2023-09-05)
|
73 |
+
|
74 |
+
- new: Support for multiple Python versions starting from 3.8+
|
75 |
+
- new: Added support for [decoupled models](https://github.com/triton-inference-server/server/blob/main/docs/user_guide/decoupled_models.md) enabling to support streaming models (alpha state)
|
76 |
+
- change: Upgraded Triton Inference Server binaries to version 2.36.0. Note that this Triton Inference Server requires glibc 2.35+ or a more recent version.
|
77 |
+
|
78 |
+
[//]: <> (put here on external component update with short summary what change or link to changelog)
|
79 |
+
|
80 |
+
- Version of [Triton Inference Server](https://github.com/triton-inference-server/) embedded in wheel: [2.36.0](https://github.com/triton-inference-server/server/releases/tag/v2.36.0)
|
81 |
+
|
82 |
+
|
83 |
+
## 0.2.5 (2023-08-24)
|
84 |
+
|
85 |
+
- new: Allow to execute multiple PyTriton instances in the same process and/or host
|
86 |
+
- fix: Invalid flags for Proxy Backend configuration passed to Triton
|
87 |
+
|
88 |
+
|
89 |
+
[//]: <> (put here on external component update with short summary what change or link to changelog)
|
90 |
+
|
91 |
+
- Version of [Triton Inference Server](https://github.com/triton-inference-server/) embedded in wheel: [2.33.0](https://github.com/triton-inference-server/server/releases/tag/v2.33.0)
|
92 |
+
|
93 |
+
## 0.2.4 (2023-08-10)
|
94 |
+
|
95 |
+
- new: Introduced `strict` flag in `Triton.bind` which enables data types and shapes validation of inference callable outputs
|
96 |
+
against model config
|
97 |
+
- new: `AsyncioModelClient` which works in FastAPI and other async frameworks
|
98 |
+
- fix: `FuturesModelClient` do not raise `gevent.exceptions.InvalidThreadUseError`
|
99 |
+
- fix: Do not throw TimeoutError if could not connect to server during model verification
|
100 |
+
|
101 |
+
[//]: <> (put here on external component update with short summary what change or link to changelog)
|
102 |
+
|
103 |
+
- Version of [Triton Inference Server](https://github.com/triton-inference-server/) embedded in wheel: [2.33.0](https://github.com/triton-inference-server/server/releases/tag/v2.33.0)
|
104 |
+
|
105 |
+
## 0.2.3 (2023-07-21)
|
106 |
+
|
107 |
+
- Improved verification of Proxy Backend environment when running under same Python interpreter
|
108 |
+
- Fixed pytriton.__version__ to represent currently installed version
|
109 |
+
|
110 |
+
[//]: <> (put here on external component update with short summary what change or link to changelog)
|
111 |
+
|
112 |
+
- Version of [Triton Inference Server](https://github.com/triton-inference-server/) embedded in wheel: [2.33.0](https://github.com/triton-inference-server/server/releases/tag/v2.33.0)
|
113 |
+
|
114 |
+
## 0.2.2 (2023-07-19)
|
115 |
+
|
116 |
+
- Added `inference_timeout_s` parameters to client classes
|
117 |
+
- Renamed `PyTritonClientUrlParseError` to `PyTritonClientInvalidUrlError`
|
118 |
+
- `ModelClient` and `FuturesModelClient` methods raise `PyTritonClientClosedError` when used after client is closed
|
119 |
+
- Pinned tritonclient dependency due to issues with tritonclient >= 2.34 on systems with glibc version lower than 2.34
|
120 |
+
- Added warning after Triton Server setup and teardown while using too verbose logging level as it may cause a significant performance drop in model inference
|
121 |
+
|
122 |
+
[//]: <> (put here on external component update with short summary what change or link to changelog)
|
123 |
+
|
124 |
+
- Version of [Triton Inference Server](https://github.com/triton-inference-server/) embedded in wheel: [2.33.0](https://github.com/triton-inference-server/server/releases/tag/v2.33.0)
|
125 |
+
|
126 |
+
## 0.2.1 (2023-06-28)
|
127 |
+
|
128 |
+
- Fixed handling `TritonConfig.cache_directory` option - the directory was always overwritten with the default value.
|
129 |
+
- Fixed tritonclient dependency - PyTriton need tritonclient supporting http headers and parameters
|
130 |
+
- Improved shared memory usage to match 64MB limit (default value for Docker, Kubernetes) reducing the initial size for PyTriton Proxy Backend.
|
131 |
+
|
132 |
+
[//]: <> (put here on external component update with short summary what change or link to changelog)
|
133 |
+
|
134 |
+
- Version of [Triton Inference Server](https://github.com/triton-inference-server/) embedded in wheel: [2.33.0](https://github.com/triton-inference-server/server/releases/tag/v2.33.0)
|
135 |
+
|
136 |
+
## 0.2.0 (2023-05-30)
|
137 |
+
|
138 |
+
- Added support for using custom HTTP/gRPC request headers and parameters.
|
139 |
+
|
140 |
+
This change breaks backward compatibility of the inference function signature.
|
141 |
+
The undecorated inference function now accepts a list of `Request` instances instead
|
142 |
+
of a list of dictionaries. The `Request` class contains data for inputs and parameters
|
143 |
+
for combined parameters and headers.
|
144 |
+
|
145 |
+
See [docs/custom_params.md](docs/custom_params.md) for further information
|
146 |
+
|
147 |
+
- Added `FuturesModelClient` which enables sending inference requests in a parallel manner.
|
148 |
+
- Added displaying documentation link after models are loaded.
|
149 |
+
|
150 |
+
[//]: <> (put here on external component update with short summary what change or link to changelog)
|
151 |
+
|
152 |
+
- Version of [Triton Inference Server](https://github.com/triton-inference-server/) embedded in wheel: [2.33.0](https://github.com/triton-inference-server/server/releases/tag/v2.33.0)
|
153 |
+
|
154 |
+
## 0.1.5 (2023-05-12)
|
155 |
+
|
156 |
+
- Improved `pytriton.decorators.group_by_values` function
|
157 |
+
- Modified the function to avoid calling the inference callable on each individual sample when grouping by string/bytes input
|
158 |
+
- Added `pad_fn` argument for easy padding and combining of the inference results
|
159 |
+
- Fixed Triton binaries search
|
160 |
+
- Improved Workspace management (remove workspace on shutdown)
|
161 |
+
|
162 |
+
[//]: <> (put here on external component update with short summary what change or link to changelog)
|
163 |
+
|
164 |
+
- Version of external components used during testing:
|
165 |
+
- [Triton Inference Server](https://github.com/triton-inference-server/): 2.29.0
|
166 |
+
- Other component versions depend on the used framework and Triton Inference Server containers versions.
|
167 |
+
Refer to its [support matrix](https://docs.nvidia.com/deeplearning/frameworks/support-matrix/index.html)
|
168 |
+
for a detailed summary.
|
169 |
+
|
170 |
+
## 0.1.4 (2023-03-16)
|
171 |
+
|
172 |
+
- Add validation of the model name passed to Triton bind method.
|
173 |
+
- Add monkey patching of `InferenceServerClient.__del__` method to prevent unhandled exceptions.
|
174 |
+
|
175 |
+
[//]: <> (put here on external component update with short summary what change or link to changelog)
|
176 |
+
|
177 |
+
- Version of external components used during testing:
|
178 |
+
- [Triton Inference Server](https://github.com/triton-inference-server/): 2.29.0
|
179 |
+
- Other component versions depend on the used framework and Triton Inference Server containers versions.
|
180 |
+
Refer to its [support matrix](https://docs.nvidia.com/deeplearning/frameworks/support-matrix/index.html)
|
181 |
+
for a detailed summary.
|
182 |
+
|
183 |
+
## 0.1.3 (2023-02-20)
|
184 |
+
|
185 |
+
- Fixed getting model config in `fill_optionals` decorator.
|
186 |
+
|
187 |
+
[//]: <> (put here on external component update with short summary what change or link to changelog)
|
188 |
+
|
189 |
+
- Version of external components used during testing:
|
190 |
+
- [Triton Inference Server](https://github.com/triton-inference-server/): 2.29.0
|
191 |
+
- Other component versions depend on the used framework and Triton Inference Server containers versions.
|
192 |
+
Refer to its [support matrix](https://docs.nvidia.com/deeplearning/frameworks/support-matrix/index.html)
|
193 |
+
for a detailed summary.
|
194 |
+
|
195 |
+
## 0.1.2 (2023-02-14)
|
196 |
+
|
197 |
+
- Fixed wheel build to support installations on operating systems with glibc version 2.31 or higher.
|
198 |
+
- Updated the documentation on custom builds of the package.
|
199 |
+
- Change: TritonContext instance is shared across bound models and contains model_configs dictionary.
|
200 |
+
- Fixed support of binding multiple models that uses methods of the same class.
|
201 |
+
|
202 |
+
[//]: <> (put here on external component update with short summary what change or link to changelog)
|
203 |
+
|
204 |
+
- Version of external components used during testing:
|
205 |
+
- [Triton Inference Server](https://github.com/triton-inference-server/): 2.29.0
|
206 |
+
- Other component versions depend on the used framework and Triton Inference Server containers versions.
|
207 |
+
Refer to its [support matrix](https://docs.nvidia.com/deeplearning/frameworks/support-matrix/index.html)
|
208 |
+
for a detailed summary.
|
209 |
+
|
210 |
+
## 0.1.1 (2023-01-31)
|
211 |
+
|
212 |
+
- Change: The `@first_value` decorator has been updated with new features:
|
213 |
+
- Renamed from `@first_values` to `@first_value`
|
214 |
+
- Added a `strict` flag to toggle the checking of equality of values on a single selected input of the request. Default is True
|
215 |
+
- Added a `squeeze_single_values` flag to toggle the squeezing of single value ND arrays to scalars. Default is True
|
216 |
+
- Fix: `@fill_optionals` now supports non-batching models
|
217 |
+
- Fix: `@first_value` fixed to work with optional inputs
|
218 |
+
- Fix: `@group_by_values` fixed to work with string inputs
|
219 |
+
- Fix: `@group_by_values` fixed to work per sample-wise
|
220 |
+
|
221 |
+
[//]: <> (put here on external component update with short summary what change or link to changelog)
|
222 |
+
|
223 |
+
- Version of external components used during testing:
|
224 |
+
- [Triton Inference Server](https://github.com/triton-inference-server/): 2.29.0
|
225 |
+
- Other component versions depend on the used framework and Triton Inference Server containers versions.
|
226 |
+
Refer to its [support matrix](https://docs.nvidia.com/deeplearning/frameworks/support-matrix/index.html)
|
227 |
+
for a detailed summary.
|
228 |
+
|
229 |
+
## 0.1.0 (2023-01-12)
|
230 |
+
|
231 |
+
- Initial release of PyTriton
|
232 |
+
|
233 |
+
[//]: <> (put here on external component update with short summary what change or link to changelog)
|
234 |
+
|
235 |
+
- Version of external components used during testing:
|
236 |
+
- [Triton Inference Server](https://github.com/triton-inference-server/): 2.29.0
|
237 |
+
- Other component versions depend on the used framework and Triton Inference Server containers versions.
|
238 |
+
Refer to its [support matrix](https://docs.nvidia.com/deeplearning/frameworks/support-matrix/index.html)
|
239 |
+
for a detailed summary.
|
stf/stf-api-alternative/pytriton/CONTRIBUTING.md
ADDED
@@ -0,0 +1,203 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<!--
|
2 |
+
Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
|
3 |
+
|
4 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
you may not use this file except in compliance with the License.
|
6 |
+
You may obtain a copy of the License at
|
7 |
+
|
8 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
|
10 |
+
Unless required by applicable law or agreed to in writing, software
|
11 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
See the License for the specific language governing permissions and
|
14 |
+
limitations under the License.
|
15 |
+
-->
|
16 |
+
|
17 |
+
# Contributing
|
18 |
+
|
19 |
+
Contributions are welcome, and they are much appreciated! Every little
|
20 |
+
helps, and we will always give credit.
|
21 |
+
|
22 |
+
## Types of Contributions
|
23 |
+
|
24 |
+
### Report Bugs
|
25 |
+
|
26 |
+
Report bugs at [https://github.com/triton-inference-server/pytriton/issues](https://github.com/triton-inference-server/pytriton/issues).
|
27 |
+
|
28 |
+
When reporting a bug, please include the following information:
|
29 |
+
|
30 |
+
* Your operating system name and version.
|
31 |
+
* Any details about your local setup that might be helpful in troubleshooting.
|
32 |
+
* Detailed steps to reproduce the bug.
|
33 |
+
|
34 |
+
### Fix Bugs
|
35 |
+
|
36 |
+
Look through the GitHub issues for bugs. Anything tagged with "bug" and "help
|
37 |
+
wanted" is open to whoever wants to implement it.
|
38 |
+
|
39 |
+
### Implement Features
|
40 |
+
|
41 |
+
Browse through the GitHub issues for features. Anything tagged with "enhancement" and "help wanted" is open to whoever wants to implement it.
|
42 |
+
|
43 |
+
### Write Documentation
|
44 |
+
|
45 |
+
The PyTriton could always use more documentation, whether as part of
|
46 |
+
the official PyTriton docs, in docstrings, or even on the web in blog posts,
|
47 |
+
articles, and such.
|
48 |
+
|
49 |
+
### Submit Feedback
|
50 |
+
|
51 |
+
The best way to send feedback is to file an issue at [https://github.com/triton-inference-server/pytriton/issues](https://github.com/triton-inference-server/pytriton/issues).
|
52 |
+
|
53 |
+
If you are proposing a feature:
|
54 |
+
|
55 |
+
* Explain in detail how it would work.
|
56 |
+
* Keep the scope as narrow as possible to make it easier to implement.
|
57 |
+
|
58 |
+
## Sign your Work
|
59 |
+
|
60 |
+
We require that all contributors "sign-off" on their commits. This certifies that
|
61 |
+
the contribution is your original work, or you have the rights to submit it under
|
62 |
+
the same license or a compatible license.
|
63 |
+
|
64 |
+
Any contribution which contains commits that are not Signed-Off will not be accepted.
|
65 |
+
|
66 |
+
To sign off on a commit, simply use the `--signoff` (or `-s`) option when committing your changes:
|
67 |
+
|
68 |
+
```shell
|
69 |
+
$ git commit -s -m "Add a cool feature."
|
70 |
+
```
|
71 |
+
|
72 |
+
This will append the following to your commit message:
|
73 |
+
|
74 |
+
```
|
75 |
+
Signed-off-by: Your Name <your@email.com>
|
76 |
+
```
|
77 |
+
|
78 |
+
By doing this, you certify the following:
|
79 |
+
|
80 |
+
```
|
81 |
+
Developer Certificate of Origin
|
82 |
+
Version 1.1
|
83 |
+
|
84 |
+
Copyright (C) 2004, 2006 The Linux Foundation and its contributors.
|
85 |
+
1 Letterman Drive
|
86 |
+
Suite D4700
|
87 |
+
San Francisco, CA, 94129
|
88 |
+
|
89 |
+
Everyone is permitted to copy and distribute verbatim copies of this license document, but changing it is not allowed.
|
90 |
+
|
91 |
+
|
92 |
+
Developer's Certificate of Origin 1.1
|
93 |
+
|
94 |
+
By making a contribution to this project, I certify that:
|
95 |
+
|
96 |
+
(a) The contribution was created in whole or in part by me and I have the right to submit it under the open source license indicated in the file; or
|
97 |
+
|
98 |
+
(b) The contribution is based upon previous work that, to the best of my knowledge, is covered under an appropriate open source license and I have the right under that license to submit that work with modifications, whether created in whole or in part by me, under the same open source license (unless I am permitted to submit under a different license), as indicated in the file; or
|
99 |
+
|
100 |
+
(c) The contribution was provided directly to me by some other person who certified (a), (b) or (c) and I have not modified it.
|
101 |
+
|
102 |
+
(d) I understand and agree that this project and the contribution are public and that a record of the contribution (including all personal information I submit with it, including my sign-off) is maintained indefinitely and may be redistributed consistent with this project or the open source license(s) involved.
|
103 |
+
```
|
104 |
+
|
105 |
+
## Get Started!
|
106 |
+
|
107 |
+
### Local Development
|
108 |
+
|
109 |
+
Ready to contribute? Here's how to set up the `PyTriton` for local development.
|
110 |
+
|
111 |
+
1. Fork the `PyTriton` repo on GitHub.
|
112 |
+
2. Clone your fork locally:
|
113 |
+
|
114 |
+
```shell
|
115 |
+
$ git clone git@github.com:your_name_here/pytriton.git
|
116 |
+
```
|
117 |
+
|
118 |
+
3. Install your local copy into a virtualenv. Assuming you have virtualenvwrapper installed, here's how you set up your fork for local development:
|
119 |
+
|
120 |
+
```shell
|
121 |
+
$ mkvirtualenv pytriton
|
122 |
+
$ cd pytriton/
|
123 |
+
```
|
124 |
+
|
125 |
+
If you do not use the virtualenvwrapper package, you can initialize a virtual environment using the pure Python command:
|
126 |
+
|
127 |
+
```shell
|
128 |
+
$ python -m venv pytriton
|
129 |
+
$ cd pytriton/
|
130 |
+
$ source bin/activate
|
131 |
+
```
|
132 |
+
|
133 |
+
Once the virtualenv is activated, install the development dependencies:
|
134 |
+
|
135 |
+
```shell
|
136 |
+
$ make install-dev
|
137 |
+
```
|
138 |
+
|
139 |
+
4. Extract Triton Server to your environment so you can debug PyTriton while serving some models on Triton:
|
140 |
+
|
141 |
+
```shell
|
142 |
+
$ make extract-triton
|
143 |
+
```
|
144 |
+
|
145 |
+
5. Install pre-commit hooks:
|
146 |
+
|
147 |
+
```shell
|
148 |
+
$ pre-commit install
|
149 |
+
```
|
150 |
+
|
151 |
+
6. Create a branch for local development:
|
152 |
+
|
153 |
+
```shell
|
154 |
+
$ git checkout -b name-of-your-bugfix-or-feature
|
155 |
+
```
|
156 |
+
|
157 |
+
Now you can make your changes locally.
|
158 |
+
|
159 |
+
7. When you're done making changes, check that your changes pass linters and the
|
160 |
+
tests, including testing other Python versions with tox:
|
161 |
+
|
162 |
+
```shell
|
163 |
+
$ make lint # will run, among others, flake8 and pytype linters
|
164 |
+
$ make test # will run a test on your current virtualenv
|
165 |
+
```
|
166 |
+
|
167 |
+
To run a subset of tests:
|
168 |
+
|
169 |
+
```shell
|
170 |
+
$ pytest tests.test_subset
|
171 |
+
```
|
172 |
+
|
173 |
+
8. Commit your changes and push your branch to GitHub:
|
174 |
+
|
175 |
+
```shell
|
176 |
+
$ git add .
|
177 |
+
$ git commit -s -m "Your detailed description of your changes."
|
178 |
+
$ git push origin name-of-your-bugfix-or-feature
|
179 |
+
```
|
180 |
+
|
181 |
+
9. Submit a pull request through the GitHub website.
|
182 |
+
|
183 |
+
### Pull Request Guidelines
|
184 |
+
|
185 |
+
Before you submit a pull request, check that it meets these guidelines:
|
186 |
+
|
187 |
+
1. The pull request should include tests.
|
188 |
+
2. If the pull request adds functionality, you should update the docs. Put your new functionality into a function with a docstring and add the feature to the list in README.md.
|
189 |
+
|
190 |
+
|
191 |
+
## Documentation
|
192 |
+
|
193 |
+
Add/update docstrings as defined in [Google Style Guide](https://github.com/google/styleguide/blob/gh-pages/pyguide.md#38-comments-and-docstrings).
|
194 |
+
|
195 |
+
## Contributor License Agreement (CLA)
|
196 |
+
|
197 |
+
PyTriton requires that all contributors (or their corporate entity) send
|
198 |
+
a signed copy of the [Contributor License
|
199 |
+
Agreement](https://github.com/NVIDIA/triton-inference-server/blob/master/Triton-CCLA-v1.pdf)
|
200 |
+
to triton-cla@nvidia.com.
|
201 |
+
|
202 |
+
*NOTE*: Contributors with no company affiliation can fill `N/A` in the
|
203 |
+
`Corporation Name` and `Corporation Address` fields.
|
stf/stf-api-alternative/pytriton/COPYRIGHT
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
|
3 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
you may not use this file except in compliance with the License.
|
5 |
+
You may obtain a copy of the License at
|
6 |
+
|
7 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
|
9 |
+
Unless required by applicable law or agreed to in writing, software
|
10 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
See the License for the specific language governing permissions and
|
13 |
+
limitations under the License.
|
stf/stf-api-alternative/pytriton/LICENSE
ADDED
@@ -0,0 +1,174 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Apache License
|
2 |
+
Version 2.0, January 2004
|
3 |
+
http://www.apache.org/licenses/
|
4 |
+
|
5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
6 |
+
|
7 |
+
1. Definitions.
|
8 |
+
|
9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
11 |
+
|
12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
13 |
+
the copyright owner that is granting the License.
|
14 |
+
|
15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
16 |
+
other entities that control, are controlled by, or are under common
|
17 |
+
control with that entity. For the purposes of this definition,
|
18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
19 |
+
direction or management of such entity, whether by contract or
|
20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
22 |
+
|
23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
24 |
+
exercising permissions granted by this License.
|
25 |
+
|
26 |
+
"Source" form shall mean the preferred form for making modifications,
|
27 |
+
including but not limited to software source code, documentation
|
28 |
+
source, and configuration files.
|
29 |
+
|
30 |
+
"Object" form shall mean any form resulting from mechanical
|
31 |
+
transformation or translation of a Source form, including but
|
32 |
+
not limited to compiled object code, generated documentation,
|
33 |
+
and conversions to other media types.
|
34 |
+
|
35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
36 |
+
Object form, made available under the License, as indicated by a
|
37 |
+
copyright notice that is included in or attached to the work
|
38 |
+
(an example is provided in the Appendix below).
|
39 |
+
|
40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
41 |
+
form, that is based on (or derived from) the Work and for which the
|
42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
44 |
+
of this License, Derivative Works shall not include works that remain
|
45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
46 |
+
the Work and Derivative Works thereof.
|
47 |
+
|
48 |
+
"Contribution" shall mean any work of authorship, including
|
49 |
+
the original version of the Work and any modifications or additions
|
50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
54 |
+
means any form of electronic, verbal, or written communication sent
|
55 |
+
to the Licensor or its representatives, including but not limited to
|
56 |
+
communication on electronic mailing lists, source code control systems,
|
57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
59 |
+
excluding communication that is conspicuously marked or otherwise
|
60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
61 |
+
|
62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
64 |
+
subsequently incorporated within the Work.
|
65 |
+
|
66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
71 |
+
Work and such Derivative Works in Source or Object form.
|
72 |
+
|
73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
76 |
+
(except as stated in this section) patent license to make, have made,
|
77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
78 |
+
where such license applies only to those patent claims licensable
|
79 |
+
by such Contributor that are necessarily infringed by their
|
80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
82 |
+
institute patent litigation against any entity (including a
|
83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
84 |
+
or a Contribution incorporated within the Work constitutes direct
|
85 |
+
or contributory patent infringement, then any patent licenses
|
86 |
+
granted to You under this License for that Work shall terminate
|
87 |
+
as of the date such litigation is filed.
|
88 |
+
|
89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
90 |
+
Work or Derivative Works thereof in any medium, with or without
|
91 |
+
modifications, and in Source or Object form, provided that You
|
92 |
+
meet the following conditions:
|
93 |
+
|
94 |
+
(a) You must give any other recipients of the Work or
|
95 |
+
Derivative Works a copy of this License; and
|
96 |
+
|
97 |
+
(b) You must cause any modified files to carry prominent notices
|
98 |
+
stating that You changed the files; and
|
99 |
+
|
100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
101 |
+
that You distribute, all copyright, patent, trademark, and
|
102 |
+
attribution notices from the Source form of the Work,
|
103 |
+
excluding those notices that do not pertain to any part of
|
104 |
+
the Derivative Works; and
|
105 |
+
|
106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
107 |
+
distribution, then any Derivative Works that You distribute must
|
108 |
+
include a readable copy of the attribution notices contained
|
109 |
+
within such NOTICE file, excluding those notices that do not
|
110 |
+
pertain to any part of the Derivative Works, in at least one
|
111 |
+
of the following places: within a NOTICE text file distributed
|
112 |
+
as part of the Derivative Works; within the Source form or
|
113 |
+
documentation, if provided along with the Derivative Works; or,
|
114 |
+
within a display generated by the Derivative Works, if and
|
115 |
+
wherever such third-party notices normally appear. The contents
|
116 |
+
of the NOTICE file are for informational purposes only and
|
117 |
+
do not modify the License. You may add Your own attribution
|
118 |
+
notices within Derivative Works that You distribute, alongside
|
119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
120 |
+
that such additional attribution notices cannot be construed
|
121 |
+
as modifying the License.
|
122 |
+
|
123 |
+
You may add Your own copyright statement to Your modifications and
|
124 |
+
may provide additional or different license terms and conditions
|
125 |
+
for use, reproduction, or distribution of Your modifications, or
|
126 |
+
for any such Derivative Works as a whole, provided Your use,
|
127 |
+
reproduction, and distribution of the Work otherwise complies with
|
128 |
+
the conditions stated in this License.
|
129 |
+
|
130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
132 |
+
by You to the Licensor shall be under the terms and conditions of
|
133 |
+
this License, without any additional terms or conditions.
|
134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
135 |
+
the terms of any separate license agreement you may have executed
|
136 |
+
with Licensor regarding such Contributions.
|
137 |
+
|
138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
140 |
+
except as required for reasonable and customary use in describing the
|
141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
142 |
+
|
143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
144 |
+
agreed to in writing, Licensor provides the Work (and each
|
145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
147 |
+
implied, including, without limitation, any warranties or conditions
|
148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
150 |
+
appropriateness of using or redistributing the Work and assume any
|
151 |
+
risks associated with Your exercise of permissions under this License.
|
152 |
+
|
153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
154 |
+
whether in tort (including negligence), contract, or otherwise,
|
155 |
+
unless required by applicable law (such as deliberate and grossly
|
156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
157 |
+
liable to You for damages, including any direct, indirect, special,
|
158 |
+
incidental, or consequential damages of any character arising as a
|
159 |
+
result of this License or out of the use or inability to use the
|
160 |
+
Work (including but not limited to damages for loss of goodwill,
|
161 |
+
work stoppage, computer failure or malfunction, or any and all
|
162 |
+
other commercial damages or losses), even if such Contributor
|
163 |
+
has been advised of the possibility of such damages.
|
164 |
+
|
165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
168 |
+
or other liability obligations and/or rights consistent with this
|
169 |
+
License. However, in accepting such obligations, You may act only
|
170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
171 |
+
of any other Contributor, and only if You agree to indemnify,
|
172 |
+
defend, and hold each Contributor harmless for any liability
|
173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
174 |
+
of your accepting any such warranty or additional liability.
|
stf/stf-api-alternative/pytriton/Makefile
ADDED
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
.PHONY: clean clean-build clean-tritonserver clean-pyc clean-docs clean-test docs lint test coverage release dist build-triton extract-triton install install-dev help
|
15 |
+
.DEFAULT_GOAL := help
|
16 |
+
|
17 |
+
define BROWSER_PYSCRIPT
|
18 |
+
import os, webbrowser, sys
|
19 |
+
|
20 |
+
from urllib.request import pathname2url
|
21 |
+
|
22 |
+
webbrowser.open("file://" + pathname2url(os.path.abspath(sys.argv[1])))
|
23 |
+
endef
|
24 |
+
export BROWSER_PYSCRIPT
|
25 |
+
|
26 |
+
define PRINT_HELP_PYSCRIPT
|
27 |
+
import re, sys
|
28 |
+
|
29 |
+
for line in sys.stdin:
|
30 |
+
match = re.match(r'^([a-zA-Z_-]+):.*?## (.*)$$', line)
|
31 |
+
if match:
|
32 |
+
target, help = match.groups()
|
33 |
+
print("%-20s %s" % (target, help))
|
34 |
+
endef
|
35 |
+
export PRINT_HELP_PYSCRIPT
|
36 |
+
|
37 |
+
BROWSER := python -c "$$BROWSER_PYSCRIPT"
|
38 |
+
PIP_INSTALL := pip install --extra-index-url https://pypi.ngc.nvidia.com
|
39 |
+
TRITONSERVER_IMAGE_VERSION = 23.10
|
40 |
+
TRITONSERVER_IMAGE_NAME = nvcr.io/nvidia/tritonserver:$(TRITONSERVER_IMAGE_VERSION)-pyt-python-py3
|
41 |
+
TRITONSERVER_OUTPUT_DIR = ${PWD}/pytriton/tritonserver
|
42 |
+
TRITONSERVER_BASENAME = pytriton
|
43 |
+
PYTRITON_IMAGE_NAME = $(TRITONSERVER_BASENAME):$(TRITONSERVER_IMAGE_VERSION)
|
44 |
+
# to set PLATFORM from outside, use: make PLATFORM=linux/arm64;
|
45 |
+
# correct values are: linux/amd64 (default), linux/arm64
|
46 |
+
PLATFORM=linux/amd64
|
47 |
+
|
48 |
+
help:
|
49 |
+
@python -c "$$PRINT_HELP_PYSCRIPT" < $(MAKEFILE_LIST)
|
50 |
+
|
51 |
+
clean: clean-build clean-pyc clean-test clean-tritonserver clean-docs ## remove all build, tritonserver, test, docs, coverage and Python artifacts
|
52 |
+
|
53 |
+
clean-build: ## remove build artifacts
|
54 |
+
rm -fr build/
|
55 |
+
rm -fr dist/
|
56 |
+
rm -fr .eggs/
|
57 |
+
find . -name '*.egg-info' -exec rm -fr {} +
|
58 |
+
find . -name '*.egg' -exec rm -f {} +
|
59 |
+
|
60 |
+
clean-tritonserver:
|
61 |
+
rm -fr pytriton/tritonserver
|
62 |
+
|
63 |
+
clean-pyc: ## remove Python file artifacts
|
64 |
+
find . -name '*.pyc' -exec rm -f {} +
|
65 |
+
find . -name '*.pyo' -exec rm -f {} +
|
66 |
+
find . -name '*~' -exec rm -f {} +
|
67 |
+
find . -name '__pycache__' -exec rm -fr {} +
|
68 |
+
|
69 |
+
clean-docs: ## remove test and coverage artifacts
|
70 |
+
rm -rf site
|
71 |
+
|
72 |
+
clean-test: ## remove test and coverage artifacts
|
73 |
+
rm -fr .tox/
|
74 |
+
rm -f .coverage
|
75 |
+
rm -fr htmlcov/
|
76 |
+
rm -fr .pytest_cache
|
77 |
+
rm -fr .pytype/
|
78 |
+
|
79 |
+
docs: clean-docs ## generate site
|
80 |
+
cp CHANGELOG.md docs
|
81 |
+
cp CONTRIBUTING.md docs
|
82 |
+
cp LICENSE docs/LICENSE.md
|
83 |
+
cp examples/README.md docs/examples.md
|
84 |
+
mkdocs build --clean
|
85 |
+
|
86 |
+
docs-serve: docs
|
87 |
+
mkdocs serve
|
88 |
+
|
89 |
+
lint: ## check style with pre-commit and pytype
|
90 |
+
tox -e pytype,pre-commit --develop
|
91 |
+
|
92 |
+
test: ## run tests on every Python version with tox
|
93 |
+
tox --develop --skip-missing-interpreters
|
94 |
+
|
95 |
+
coverage: ## check code coverage quickly with the default Python
|
96 |
+
coverage run --source pytriton -m pytest
|
97 |
+
coverage report -m
|
98 |
+
coverage html
|
99 |
+
$(BROWSER) htmlcov/index.html
|
100 |
+
|
101 |
+
dist: clean build-triton extract-triton ## builds source and wheel package
|
102 |
+
bash ./scripts/build_wheel.sh $(PLATFORM)
|
103 |
+
ls -lh dist
|
104 |
+
find ./dist -iname *-linux*.whl -type f -exec bash ./scripts/add_libs_to_wheel.sh $(PYTRITON_IMAGE_NAME) $(TRITONSERVER_OUTPUT_DIR) {} $(PLATFORM) \;
|
105 |
+
find ./dist -iname *-linux*.whl -type f -delete
|
106 |
+
ls -lh dist
|
107 |
+
twine check dist/*
|
108 |
+
|
109 |
+
build-triton: ## build Triton with Python Stubs
|
110 |
+
bash ./scripts/build_triton.sh $(TRITONSERVER_IMAGE_NAME) $(PYTRITON_IMAGE_NAME) $(PLATFORM)
|
111 |
+
echo "export PYTRITON_IMAGE_NAME=$(PYTRITON_IMAGE_NAME)" > .env
|
112 |
+
|
113 |
+
extract-triton: build-triton ## extract Triton binaries and libraries
|
114 |
+
# changing dst path, change also in clean-build and pyproject.toml
|
115 |
+
bash ./scripts/extract_triton.sh $(PYTRITON_IMAGE_NAME) $(TRITONSERVER_OUTPUT_DIR) $(PLATFORM)
|
116 |
+
|
117 |
+
|
118 |
+
install: clean extract-triton ## install the package to the active Python's site-packages
|
119 |
+
$(PIP_INSTALL) --upgrade pip
|
120 |
+
$(PIP_INSTALL) .
|
121 |
+
|
122 |
+
install-dev: clean-build clean-pyc
|
123 |
+
$(PIP_INSTALL) --upgrade pip
|
124 |
+
$(PIP_INSTALL) -e .[dev]
|
stf/stf-api-alternative/pytriton/README.md
ADDED
@@ -0,0 +1,343 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<!--
|
2 |
+
Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
3 |
+
|
4 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
you may not use this file except in compliance with the License.
|
6 |
+
You may obtain a copy of the License at
|
7 |
+
|
8 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
|
10 |
+
Unless required by applicable law or agreed to in writing, software
|
11 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
See the License for the specific language governing permissions and
|
14 |
+
limitations under the License.
|
15 |
+
-->
|
16 |
+
|
17 |
+
# PyTriton
|
18 |
+
|
19 |
+
PyTriton is a Flask/FastAPI-like interface that simplifies Triton's deployment in Python environments.
|
20 |
+
The library allows serving Machine Learning models directly from Python through
|
21 |
+
NVIDIA's [Triton Inference Server](https://github.com/triton-inference-server).
|
22 |
+
|
23 |
+
<!-- START doctoc generated TOC please keep comment here to allow auto update -->
|
24 |
+
<!-- DON'T EDIT THIS SECTION, INSTEAD RE-RUN doctoc TO UPDATE -->
|
25 |
+
|
26 |
+
- [Documentation](#documentation)
|
27 |
+
- [Feature matrix](#feature-matrix)
|
28 |
+
- [How it works?](#how-it-works)
|
29 |
+
- [Installation](#installation)
|
30 |
+
- [Prerequisites](#prerequisites)
|
31 |
+
- [Install from `pypi`](#install-from-pypi)
|
32 |
+
- [Setting Up Python Environment](#setting-up-python-environment)
|
33 |
+
- [Building binaries from source](#building-binaries-from-source)
|
34 |
+
- [Quick Start](#quick-start)
|
35 |
+
- [Architecture](#architecture)
|
36 |
+
- [Examples](#examples)
|
37 |
+
- [Streaming (alpha)](#streaming-alpha)
|
38 |
+
- [Profiling model](#profiling-model)
|
39 |
+
- [Version management](#version-management)
|
40 |
+
- [Useful Links](#useful-links)
|
41 |
+
|
42 |
+
<!-- END doctoc generated TOC please keep comment here to allow auto update -->
|
43 |
+
|
44 |
+
## Documentation
|
45 |
+
|
46 |
+
Read how to customize the Triton Inference Server, load models, deploy on clusters, and the API reference
|
47 |
+
can be found in the [documentation](https://triton-inference-server.github.io/pytriton). The below sections provide
|
48 |
+
brief information about the product and quick start guide.
|
49 |
+
|
50 |
+
## Feature matrix
|
51 |
+
|
52 |
+
| Feature | Description |
|
53 |
+
| ------- | ----------- |
|
54 |
+
| Native Python support | You can create any Python function and expose it as an HTTP/gRPC API. |
|
55 |
+
| Framework-agnostic | You can run any Python code with any framework of your choice, such as: PyTorch, TensorFlow, or JAX. |
|
56 |
+
| Performance optimization | You can benefit from dynamic batching, response cache, model pipelining, and GPU/CPU inference. |
|
57 |
+
| Easy installation and setup | You can use a simple and familiar interface based on Flask/FastAPI for easy installation and setup. |
|
58 |
+
| Model clients | You can access high-level model clients for HTTP/gRPC requests with configurable options and both synchronous and asynchronous API. |
|
59 |
+
| Streaming (alpha) | You can stream partial responses from a model by serving it in a decoupled mode. |
|
60 |
+
|
61 |
+
## How it works?
|
62 |
+
|
63 |
+
In PyTriton, like in Flask or FastAPI, you can define any Python function that executes a Machine Learning model prediction and exposes
|
64 |
+
it through an HTTP/gRPC API. PyTriton installs Triton Inference Server in your environment and uses it for handling
|
65 |
+
HTTP/gRPC requests and responses. Our library provides a Python API that allows you to attach a Python function to Triton
|
66 |
+
and a communication layer to send/receive data between Triton and the function. The solution enables using the
|
67 |
+
performance features of Triton Inference Server, such as dynamic batching or response cache, without changing your model
|
68 |
+
environment. Thus, it improves the performance of running inference on GPU for models implemented in Python. The solution is
|
69 |
+
framework-agnostic and can be used along with frameworks like PyTorch, TensorFlow, or JAX.
|
70 |
+
|
71 |
+
## Installation
|
72 |
+
|
73 |
+
We assume that you are comfortable with the Python programming language and familiar with Machine Learning models.
|
74 |
+
Using [Docker](https://www.docker.com/) is an option, but not mandatory.
|
75 |
+
|
76 |
+
The library can be installed in:
|
77 |
+
|
78 |
+
- system environment
|
79 |
+
- virtualenv
|
80 |
+
- [Docker](https://www.docker.com/) image
|
81 |
+
|
82 |
+
NVIDIA optimized Docker images for Python frameworks can be obtained from the [NVIDIA NGC Catalog](https://catalog.ngc.nvidia.com/containers).
|
83 |
+
|
84 |
+
If you want to use the Docker runtime, we recommend that you install [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/overview.html) to
|
85 |
+
enable running model inference on NVIDIA GPU.
|
86 |
+
|
87 |
+
### Prerequisites
|
88 |
+
|
89 |
+
Before installing the library, ensure that you meet the following requirements:
|
90 |
+
|
91 |
+
- An operating system with glibc >= `2.35`.
|
92 |
+
- Triton Inference Server and PyTriton have **only** been rigorously tested on Ubuntu 22.04.
|
93 |
+
- Other supported operating systems include Ubuntu Debian 11+, Rocky Linux 9+, and Red Hat Universal Base Image 9+.
|
94 |
+
- To check your glibc version, run `ldd --version`
|
95 |
+
- Python version >= `3.8`
|
96 |
+
- Use `pip >= `20.3`
|
97 |
+
- Install `libpython3.*.so` in the operating system (appropriate for Python version).
|
98 |
+
|
99 |
+
### Install from `pypi`
|
100 |
+
|
101 |
+
The PyTriton can be installed from [pypi.org](https://pypi.org/project/nvidia-pytriton/) by running the following command:
|
102 |
+
|
103 |
+
```shell
|
104 |
+
pip install -U nvidia-pytriton
|
105 |
+
```
|
106 |
+
|
107 |
+
**Important**: The Triton Inference Server binary is installed as part of the PyTriton package.
|
108 |
+
|
109 |
+
More details about installation can be found in the [documentation](https://triton-inference-server.github.io/pytriton/latest/installation/).
|
110 |
+
|
111 |
+
|
112 |
+
### Setting Up Python Environment
|
113 |
+
|
114 |
+
The PyTriton requires installation and linking `libpython3.*.so`. Read more in "[Setting Up Python Environment](https://triton-inference-server.github.io/pytriton/latest/installation#setting-up-python-environment)"
|
115 |
+
for additional information how to configure system for different Python versions.
|
116 |
+
|
117 |
+
### Building binaries from source
|
118 |
+
|
119 |
+
The binary package can be built from the source, allowing access to unreleased hotfixes, the ability to modify the PyTriton code, and compatibility with various Triton Inference Server versions, including custom server builds.
|
120 |
+
For further information on building the PyTriton binary, refer to the [Building](https://triton-inference-server.github.io/pytriton/latest/building/) page of documentation.
|
121 |
+
|
122 |
+
## Quick Start
|
123 |
+
|
124 |
+
The quick start presents how to run Python model in Triton Inference Server without need to change the current working
|
125 |
+
environment. In the example we are using a simple `Linear` PyTorch model.
|
126 |
+
|
127 |
+
The requirement for the example is to have installed PyTorch in your environment. You can do it running:
|
128 |
+
|
129 |
+
```shell
|
130 |
+
pip install torch
|
131 |
+
```
|
132 |
+
|
133 |
+
The integration of model requires to provide following elements:
|
134 |
+
|
135 |
+
- The model - framework or Python model or function that handle inference requests
|
136 |
+
- Inference callback - a lambda or function which handle the input data coming from Triton and return the result
|
137 |
+
- Python function connection with Triton Inference Server - a binding for communication between Triton and Python
|
138 |
+
callback
|
139 |
+
|
140 |
+
In the next step define the `Linear` model:
|
141 |
+
|
142 |
+
```python
|
143 |
+
import torch
|
144 |
+
|
145 |
+
model = torch.nn.Linear(2, 3).to("cuda").eval()
|
146 |
+
```
|
147 |
+
|
148 |
+
In the second step, create an inference callable as a function. The function obtains the HTTP/gRPC request data as an argument, which should be in the form of a NumPy array. The expected return object should also be a NumPy array. You can define an inference callable as a function that uses the `@batch` decorator from PyTriton. This decorator converts the input request into a more suitable format that can be directly passed to the model. You can read more about [decorators here](docs/decorators.md).
|
149 |
+
|
150 |
+
Example implementation:
|
151 |
+
|
152 |
+
<!--pytest-codeblocks:cont-->
|
153 |
+
|
154 |
+
```python
|
155 |
+
import numpy as np
|
156 |
+
from pytriton.decorators import batch
|
157 |
+
|
158 |
+
|
159 |
+
@batch
|
160 |
+
def infer_fn(**inputs: np.ndarray):
|
161 |
+
(input1_batch,) = inputs.values()
|
162 |
+
input1_batch_tensor = torch.from_numpy(input1_batch).to("cuda")
|
163 |
+
output1_batch_tensor = model(input1_batch_tensor) # Calling the Python model inference
|
164 |
+
output1_batch = output1_batch_tensor.cpu().detach().numpy()
|
165 |
+
return [output1_batch]
|
166 |
+
```
|
167 |
+
|
168 |
+
In the next step, you can create the binding between the inference callable and Triton Inference Server using the `bind` method from pyTriton. This method takes the model name, the inference callable, the inputs and outputs tensors, and an optional model configuration object.
|
169 |
+
|
170 |
+
<!--pytest-codeblocks:cont-->
|
171 |
+
|
172 |
+
```python
|
173 |
+
from pytriton.model_config import ModelConfig, Tensor
|
174 |
+
from pytriton.triton import Triton
|
175 |
+
|
176 |
+
# Connecting inference callable with Triton Inference Server
|
177 |
+
with Triton() as triton:
|
178 |
+
# Load model into Triton Inference Server
|
179 |
+
triton.bind(
|
180 |
+
model_name="Linear",
|
181 |
+
infer_func=infer_fn,
|
182 |
+
inputs=[
|
183 |
+
Tensor(dtype=np.float32, shape=(-1,)),
|
184 |
+
],
|
185 |
+
outputs=[
|
186 |
+
Tensor(dtype=np.float32, shape=(-1,)),
|
187 |
+
],
|
188 |
+
config=ModelConfig(max_batch_size=128)
|
189 |
+
)
|
190 |
+
...
|
191 |
+
```
|
192 |
+
|
193 |
+
Finally, serve the model with the Triton Inference Server:
|
194 |
+
|
195 |
+
<!--pytest.mark.skip-->
|
196 |
+
|
197 |
+
```python
|
198 |
+
from pytriton.triton import Triton
|
199 |
+
|
200 |
+
with Triton() as triton:
|
201 |
+
... # Load models here
|
202 |
+
triton.serve()
|
203 |
+
```
|
204 |
+
|
205 |
+
The `bind` method creates a connection between the Triton Inference Server and the `infer_fn`, which handles
|
206 |
+
the inference queries. The `inputs` and `outputs` describe the model inputs and outputs that are exposed in
|
207 |
+
Triton. The config field allows more parameters for model deployment.
|
208 |
+
|
209 |
+
The `serve` method is blocking, and at this point, the application waits for incoming HTTP/gRPC requests. From that
|
210 |
+
moment, the model is available under the name `Linear` in the Triton server. The inference queries can be sent to
|
211 |
+
`localhost:8000/v2/models/Linear/infer`, which are passed to the `infer_fn` function.
|
212 |
+
|
213 |
+
If you would like to use Triton in the background mode, use `run`. More about that can be found
|
214 |
+
in the [Deploying Models](https://triton-inference-server.github.io/pytriton/latest/initialization/) page.
|
215 |
+
|
216 |
+
Once the `serve` or `run` method is called on the `Triton` object, the server status can be obtained using:
|
217 |
+
|
218 |
+
<!--pytest.mark.skip-->
|
219 |
+
|
220 |
+
```shell
|
221 |
+
curl -v localhost:8000/v2/health/live
|
222 |
+
```
|
223 |
+
|
224 |
+
The model is loaded right after the server starts, and its status can be queried using:
|
225 |
+
|
226 |
+
<!--pytest.mark.skip-->
|
227 |
+
|
228 |
+
```shell
|
229 |
+
curl -v localhost:8000/v2/models/Linear/ready
|
230 |
+
```
|
231 |
+
|
232 |
+
Finally, you can send an inference query to the model:
|
233 |
+
|
234 |
+
<!--pytest.mark.skip-->
|
235 |
+
|
236 |
+
```shell
|
237 |
+
curl -X POST \
|
238 |
+
-H "Content-Type: application/json" \
|
239 |
+
-d @input.json \
|
240 |
+
localhost:8000/v2/models/Linear/infer
|
241 |
+
```
|
242 |
+
|
243 |
+
The `input.json` with sample query:
|
244 |
+
|
245 |
+
```json
|
246 |
+
{
|
247 |
+
"id": "0",
|
248 |
+
"inputs": [
|
249 |
+
{
|
250 |
+
"name": "INPUT_1",
|
251 |
+
"shape": [1, 2],
|
252 |
+
"datatype": "FP32",
|
253 |
+
"parameters": {},
|
254 |
+
"data": [[-0.04281254857778549, 0.6738349795341492]]
|
255 |
+
}
|
256 |
+
]
|
257 |
+
}
|
258 |
+
```
|
259 |
+
|
260 |
+
Read more about the HTTP/gRPC interface in the Triton Inference Server
|
261 |
+
[documentation](https://github.com/triton-inference-server/server/blob/main/docs/customization_guide/inference_protocols.md#httprest-and-grpc-protocols).
|
262 |
+
|
263 |
+
You can also validate the deployed model using a simple client that can perform inference requests:
|
264 |
+
|
265 |
+
<!--pytest.mark.skip-->
|
266 |
+
|
267 |
+
```python
|
268 |
+
import torch
|
269 |
+
from pytriton.client import ModelClient
|
270 |
+
|
271 |
+
input1_data = torch.randn(128, 2).cpu().detach().numpy()
|
272 |
+
|
273 |
+
with ModelClient("localhost:8000", "Linear") as client:
|
274 |
+
result_dict = client.infer_batch(input1_data)
|
275 |
+
|
276 |
+
print(result_dict)
|
277 |
+
```
|
278 |
+
|
279 |
+
The full example code can be found in [examples/linear_random_pytorch](examples/linear_random_pytorch).
|
280 |
+
|
281 |
+
You can learn more about client usage in the [Clients](https://triton-inference-server.github.io/pytriton/latest/clients/) document.
|
282 |
+
|
283 |
+
More information about running the server and models can be found
|
284 |
+
in [Deploying Models](https://triton-inference-server.github.io/pytriton/latest/initialization/) page of documentation.
|
285 |
+
|
286 |
+
## Architecture
|
287 |
+
|
288 |
+
The diagram below presents the schema of how the Python models are served through Triton Inference Server using
|
289 |
+
PyTriton. The solution consists of two main components:
|
290 |
+
|
291 |
+
- Triton Inference Server: for exposing the HTTP/gRPC API and benefiting from performance features like dynamic batching
|
292 |
+
or response cache.
|
293 |
+
- Python Model Environment: your environment where the Python model is executed.
|
294 |
+
|
295 |
+
The Triton Inference Server binaries are provided as part of the PyTriton installation. The Triton Server is
|
296 |
+
installed in your current environment (system or container). The PyTriton controls the Triton Server process
|
297 |
+
through the `Triton Controller`.
|
298 |
+
|
299 |
+
Exposing the model through PyTriton requires the definition of an `Inference Callable` - a Python function that is
|
300 |
+
connected to Triton Inference Server and executes the model or ensemble for predictions. The integration layer binds
|
301 |
+
the `Inference Callable` to Triton Server and exposes it through the Triton HTTP/gRPC API under a provided `<model name>`. Once
|
302 |
+
the integration is done, the defined `Inference Callable` receives data sent to the HTTP/gRPC API endpoint
|
303 |
+
`v2/models/<model name>/infer`. Read more about HTTP/gRPC interface in Triton Inference Server
|
304 |
+
[documentation](https://github.com/triton-inference-server/server/blob/main/docs/customization_guide/inference_protocols.md#httprest-and-grpc-protocols).
|
305 |
+
|
306 |
+
The HTTP/gRPC requests sent to `v2/models/<model name>/infer` are handled by Triton
|
307 |
+
Inference Server. The server batches requests and passes them to the `Proxy Backend`, which sends the batched requests to the appropriate
|
308 |
+
`Inference Callable`. The data is sent as a `numpy` array. Once the `Inference Callable` finishes execution of
|
309 |
+
the model prediction, the result is returned to the `Proxy Backend`, and a response is created by Triton Server.
|
310 |
+
|
311 |
+
![High Level Design](docs/assets/hld.svg)
|
312 |
+
|
313 |
+
|
314 |
+
|
315 |
+
|
316 |
+
## Examples
|
317 |
+
|
318 |
+
The [examples](examples) page presents various cases of serving models using PyTriton. You can find simple examples of
|
319 |
+
running PyTorch, TensorFlow2, JAX, and simple Python models. Additionally, we have prepared more advanced scenarios like online
|
320 |
+
learning, multi-node models, or deployment on Kubernetes using PyTriton. Each example contains instructions describing
|
321 |
+
how to build and run the example. Learn more about how to use PyTriton by reviewing our [examples](examples).
|
322 |
+
|
323 |
+
### Streaming (alpha)
|
324 |
+
|
325 |
+
We introduced new alpha feature to PyTriton that allows to stream partial responses from a model. It is based on NVIDIA Triton Inference deocoupled models feature. Look at example in [examples/huggingface_dialogpt_streaming_pytorch](examples/huggingface_dialogpt_streaming_pytorch).
|
326 |
+
|
327 |
+
### Profiling model
|
328 |
+
|
329 |
+
The [Perf Analyzer](https://github.com/triton-inference-server/client/blob/main/src/c++/perf_analyzer/README.md) can be
|
330 |
+
used to profile models served through PyTriton. We have prepared an example of
|
331 |
+
using the Perf Analyzer to profile the BART PyTorch model. The example code can be found
|
332 |
+
in [examples/perf_analyzer](examples/perf_analyzer).
|
333 |
+
|
334 |
+
## Version management
|
335 |
+
|
336 |
+
PyTriton follows the [Semantic Versioning](https://semver.org/) scheme for versioning. Official releases can be found on [PyPI](https://pypi.org/project/nvidia-pytriton/) and [GitHub releases](https://github.com/triton-inference-server/pytriton/releases). The most up-to-date development version is available on the `main` branch, which may include hotfixes that have not yet been released through the standard channels. To install the latest development version, refer to the instructions in the
|
337 |
+
[building binaries from source](#building-binaries-from-source) section.
|
338 |
+
|
339 |
+
## Useful Links
|
340 |
+
|
341 |
+
- [Changelog](CHANGELOG.md)
|
342 |
+
- [Known Issues](https://triton-inference-server.github.io/pytriton/latest/known_issues)
|
343 |
+
- [Contributing](CONTRIBUTING.md)
|
stf/stf-api-alternative/pytriton/build/lib/pytriton/__init__.py
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# noqa: D104
|
15 |
+
from importlib.metadata import PackageNotFoundError, version
|
16 |
+
|
17 |
+
try:
|
18 |
+
__version__ = version("nvidia-pytriton")
|
19 |
+
except PackageNotFoundError:
|
20 |
+
# package is not installed
|
21 |
+
pass
|
22 |
+
|
23 |
+
from pytriton import (
|
24 |
+
client, # noqa: F401
|
25 |
+
model_config, # noqa: F401
|
26 |
+
triton, # noqa: F401
|
27 |
+
)
|
stf/stf-api-alternative/pytriton/build/lib/pytriton/__main__.py
ADDED
@@ -0,0 +1,218 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
"""Pytriton check module."""
|
15 |
+
|
16 |
+
import logging
|
17 |
+
import os
|
18 |
+
import pathlib
|
19 |
+
import shutil
|
20 |
+
import tempfile
|
21 |
+
from typing import Optional
|
22 |
+
|
23 |
+
import typer
|
24 |
+
from typing_extensions import Annotated
|
25 |
+
|
26 |
+
from pytriton.check.add_sub import add_sub_example, add_sub_example_thread
|
27 |
+
from pytriton.check.env_checks import env_checks
|
28 |
+
|
29 |
+
warning_message = """
|
30 |
+
+---------------------------------------------------------------+
|
31 |
+
| WARNING |
|
32 |
+
+---------------------------------------------------------------+
|
33 |
+
| Command may collect sensitive information, please review the |
|
34 |
+
| log and the ZIP before sharing. |
|
35 |
+
+---------------------------------------------------------------+
|
36 |
+
"""
|
37 |
+
|
38 |
+
|
39 |
+
app = typer.Typer(help="Pytriton check tool.\n\nThis tool is used to check the environment and run examples.")
|
40 |
+
|
41 |
+
|
42 |
+
class CheckEnvironment:
|
43 |
+
"""Check environment class.
|
44 |
+
|
45 |
+
Args:
|
46 |
+
workspace_path: Path to workspace
|
47 |
+
name: Name of the sub_workspace
|
48 |
+
zip_results: Flag if results should be zipped
|
49 |
+
check_workspace_exist: Flag if workspace should be checked if exists
|
50 |
+
"""
|
51 |
+
|
52 |
+
def __init__(
|
53 |
+
self,
|
54 |
+
workspace_path: Optional[pathlib.Path],
|
55 |
+
name: str,
|
56 |
+
zip_results: bool = True,
|
57 |
+
check_workspace_exist: bool = True,
|
58 |
+
):
|
59 |
+
"""Initialize class."""
|
60 |
+
self.name = name
|
61 |
+
self._zip_results = zip_results
|
62 |
+
self._temp_workspace = None
|
63 |
+
|
64 |
+
self.logger = logging.getLogger(name)
|
65 |
+
if check_workspace_exist and workspace_path is not None and workspace_path.exists():
|
66 |
+
self.logger.error(f"Workspace path {workspace_path} already exists")
|
67 |
+
raise typer.Exit(code=1)
|
68 |
+
if workspace_path is None:
|
69 |
+
self._temp_workspace = tempfile.TemporaryDirectory(prefix="pytriton_workspace_")
|
70 |
+
workspace_path = pathlib.Path(self._temp_workspace.name)
|
71 |
+
else:
|
72 |
+
workspace_path.mkdir(parents=True, exist_ok=True)
|
73 |
+
logging.basicConfig(level=logging.DEBUG, format="%(asctime)s - %(levelname)s - %(name)s: %(message)s")
|
74 |
+
self.logger.addHandler(logging.FileHandler(workspace_path / (name + "_log.txt")))
|
75 |
+
self.workspace_path = workspace_path
|
76 |
+
self.sub_workspace = workspace_path / name
|
77 |
+
|
78 |
+
def __enter__(self):
|
79 |
+
"""Enter method."""
|
80 |
+
return self
|
81 |
+
|
82 |
+
def __exit__(self, exc_type, exc_val, exc_tb):
|
83 |
+
"""Exit method zips results if required."""
|
84 |
+
self.zip_results()
|
85 |
+
|
86 |
+
def zip_results(self):
|
87 |
+
"""Zip results."""
|
88 |
+
if self._zip_results:
|
89 |
+
if self.workspace_path.exists():
|
90 |
+
if self._temp_workspace is not None:
|
91 |
+
output_file_base = pathlib.Path(os.getcwd()) / self.workspace_path.name
|
92 |
+
else:
|
93 |
+
output_file_base = self.workspace_path
|
94 |
+
self.logger.info(f"Zipping {self.workspace_path} to {output_file_base}.zip")
|
95 |
+
shutil.make_archive(str(output_file_base.resolve()), "zip", str(self.workspace_path.resolve()))
|
96 |
+
else:
|
97 |
+
self.logger.error(f"Workspace path {self.workspace_path} does not exist")
|
98 |
+
|
99 |
+
|
100 |
+
@app.command("example-add-sub-script")
|
101 |
+
def example_add_sub_script(
|
102 |
+
workspace: Annotated[Optional[pathlib.Path], typer.Option("--workspace", "-w")] = None,
|
103 |
+
zip_results: Annotated[bool, typer.Option("--zip")] = True,
|
104 |
+
):
|
105 |
+
"""Run example using external script.
|
106 |
+
|
107 |
+
Args:
|
108 |
+
workspace: Workspace path that will be created to store testing output (should not exist)
|
109 |
+
zip_results: flag if output should be zipped
|
110 |
+
"""
|
111 |
+
with CheckEnvironment(workspace, "example_add_sub_script", zip_results) as ce:
|
112 |
+
try:
|
113 |
+
add_sub_example_thread(ce.sub_workspace, ce.logger)
|
114 |
+
except Exception as e:
|
115 |
+
ce.logger.error(f"Error occurred in command: {e}")
|
116 |
+
|
117 |
+
|
118 |
+
@app.command("example-add-sub")
|
119 |
+
def example_add_sub(
|
120 |
+
workspace: Annotated[Optional[pathlib.Path], typer.Option("--workspace", "-w")] = None,
|
121 |
+
zip_results: Annotated[bool, typer.Option("--zip")] = True,
|
122 |
+
):
|
123 |
+
"""Run example.
|
124 |
+
|
125 |
+
Args:
|
126 |
+
workspace: Workspace path that will be created to store testing output (should not exist)
|
127 |
+
zip_results: flag if output should be zipped
|
128 |
+
"""
|
129 |
+
with CheckEnvironment(workspace, "example_add_sub", zip_results) as ce:
|
130 |
+
try:
|
131 |
+
add_sub_example(ce.sub_workspace, ce.logger)
|
132 |
+
except Exception as e:
|
133 |
+
ce.logger.error(f"Error occurred in command: {e}")
|
134 |
+
|
135 |
+
|
136 |
+
@app.command("examples")
|
137 |
+
def examples(
|
138 |
+
workspace: Annotated[Optional[pathlib.Path], typer.Option("--workspace", "-w")] = None,
|
139 |
+
zip_results: Annotated[bool, typer.Option("--zip")] = True,
|
140 |
+
):
|
141 |
+
"""Run example in the same process.
|
142 |
+
|
143 |
+
Args:
|
144 |
+
workspace: Workspace path that will be created to store testing output (should not exist)
|
145 |
+
zip_results: flag if output should be zipped
|
146 |
+
"""
|
147 |
+
with CheckEnvironment(workspace, "example_add_sub", zip_results) as ce:
|
148 |
+
try:
|
149 |
+
add_sub_example(ce.sub_workspace, ce.logger)
|
150 |
+
except Exception as e:
|
151 |
+
ce.logger.error(f"Error occurred in command: {e}")
|
152 |
+
|
153 |
+
with CheckEnvironment(workspace, "example_add_sub_script", zip_results, check_workspace_exist=False) as ce:
|
154 |
+
try:
|
155 |
+
add_sub_example_thread(ce.sub_workspace, ce.logger)
|
156 |
+
except Exception as e:
|
157 |
+
ce.logger.error(f"Error occurred in command: {e}")
|
158 |
+
|
159 |
+
|
160 |
+
@app.command("env")
|
161 |
+
def env_check(
|
162 |
+
workspace: Annotated[Optional[pathlib.Path], typer.Option("--workspace", "-w")] = None,
|
163 |
+
zip_results: Annotated[bool, typer.Option("--zip")] = True,
|
164 |
+
):
|
165 |
+
"""Run all environment checks.
|
166 |
+
|
167 |
+
It may collect sensitive system information in the log. Please review the log before sharing.
|
168 |
+
|
169 |
+
Args:
|
170 |
+
workspace: Workspace path that will be created to store testing output (should not exist)
|
171 |
+
zip_results: flag if output should be zipped
|
172 |
+
"""
|
173 |
+
with CheckEnvironment(workspace, "env_checks", zip_results) as ce:
|
174 |
+
try:
|
175 |
+
env_checks(ce.logger)
|
176 |
+
except Exception as e:
|
177 |
+
ce.logger.error(f"Error occurred in command: {e}")
|
178 |
+
|
179 |
+
|
180 |
+
@app.command("check")
|
181 |
+
def check(
|
182 |
+
workspace: Annotated[Optional[pathlib.Path], typer.Option("--workspace", "-w")] = None,
|
183 |
+
zip_results: Annotated[bool, typer.Option("--zip")] = True,
|
184 |
+
):
|
185 |
+
"""Run all checks.
|
186 |
+
|
187 |
+
Args:
|
188 |
+
workspace: Workspace path that will be created to store testing output (should not exist)
|
189 |
+
zip_results: flag if output should be zipped
|
190 |
+
"""
|
191 |
+
with CheckEnvironment(workspace, "all_checks", zip_results) as ce:
|
192 |
+
try:
|
193 |
+
ce.logger.info("Running all common checks")
|
194 |
+
env_check(ce.workspace_path / "env", False)
|
195 |
+
examples(ce.workspace_path / "examples", False)
|
196 |
+
except Exception as e:
|
197 |
+
ce.logger.error(f"Error occurred in command: {e}")
|
198 |
+
|
199 |
+
|
200 |
+
@app.callback(invoke_without_command=True)
|
201 |
+
def default_command(ctx: typer.Context):
|
202 |
+
"""Default command."""
|
203 |
+
if ctx.invoked_subcommand is None:
|
204 |
+
check()
|
205 |
+
|
206 |
+
|
207 |
+
def main():
|
208 |
+
"""Main function."""
|
209 |
+
logger = logging.getLogger("PyTriton-Check")
|
210 |
+
try:
|
211 |
+
logger.warning(warning_message)
|
212 |
+
app()
|
213 |
+
finally:
|
214 |
+
logger.warning(warning_message)
|
215 |
+
|
216 |
+
|
217 |
+
if __name__ == "__main__":
|
218 |
+
main()
|
stf/stf-api-alternative/pytriton/build/lib/pytriton/check/__init__.py
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# noqa: D104
|
stf/stf-api-alternative/pytriton/build/lib/pytriton/check/add_sub.py
ADDED
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
"""Add_sub example model for checking corectness of triton environment."""
|
16 |
+
|
17 |
+
import argparse
|
18 |
+
import logging
|
19 |
+
import pathlib
|
20 |
+
import signal
|
21 |
+
import sys
|
22 |
+
|
23 |
+
import numpy as np
|
24 |
+
|
25 |
+
from pytriton.check.utils import ScriptThread
|
26 |
+
from pytriton.client import ModelClient
|
27 |
+
from pytriton.decorators import batch
|
28 |
+
from pytriton.model_config import ModelConfig, Tensor
|
29 |
+
from pytriton.triton import Triton
|
30 |
+
|
31 |
+
logger = logging.getLogger("check.add_sub_example")
|
32 |
+
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(name)s: %(message)s")
|
33 |
+
add_script_path = [sys.executable, "pytriton/check/add_sub.py"]
|
34 |
+
|
35 |
+
|
36 |
+
@batch
|
37 |
+
def _add_sub(**inputs):
|
38 |
+
a_batch, b_batch = inputs.values()
|
39 |
+
add_batch = a_batch + b_batch
|
40 |
+
sub_batch = a_batch - b_batch
|
41 |
+
return {"add": add_batch, "sub": sub_batch}
|
42 |
+
|
43 |
+
|
44 |
+
def prepare_triton(workspace: pathlib.Path):
|
45 |
+
"""Prepare triton server with AddSub model."""
|
46 |
+
triton = Triton(workspace=str(workspace.resolve()))
|
47 |
+
triton.run()
|
48 |
+
logger.info("Loading AddSub model")
|
49 |
+
triton.bind(
|
50 |
+
model_name="AddSub",
|
51 |
+
infer_func=_add_sub,
|
52 |
+
inputs=[
|
53 |
+
Tensor(dtype=np.float32, shape=(-1,)),
|
54 |
+
Tensor(dtype=np.float32, shape=(-1,)),
|
55 |
+
],
|
56 |
+
outputs=[
|
57 |
+
Tensor(name="add", dtype=np.float32, shape=(-1,)),
|
58 |
+
Tensor(name="sub", dtype=np.float32, shape=(-1,)),
|
59 |
+
],
|
60 |
+
config=ModelConfig(max_batch_size=128),
|
61 |
+
strict=True,
|
62 |
+
)
|
63 |
+
return triton
|
64 |
+
|
65 |
+
|
66 |
+
def infer_add_sub_model():
|
67 |
+
"""Infer AddSub model."""
|
68 |
+
batch_size = 2
|
69 |
+
a_batch = np.ones((batch_size, 1), dtype=np.float32)
|
70 |
+
b_batch = np.ones((batch_size, 1), dtype=np.float32)
|
71 |
+
|
72 |
+
logger.info(f"a: {a_batch.tolist()}")
|
73 |
+
logger.info(f"b: {b_batch.tolist()}")
|
74 |
+
|
75 |
+
with ModelClient("localhost", "AddSub") as client:
|
76 |
+
logger.info("Sending inference request")
|
77 |
+
result_batch = client.infer_batch(a_batch, b_batch)
|
78 |
+
|
79 |
+
for output_name, data_batch in result_batch.items():
|
80 |
+
logger.info(f"{output_name}: {data_batch.tolist()}")
|
81 |
+
|
82 |
+
|
83 |
+
def serve_triton(workspace: pathlib.Path):
|
84 |
+
"""Serve triton server with AddSub model."""
|
85 |
+
triton = prepare_triton(workspace)
|
86 |
+
logger.info("Serving AddSub model")
|
87 |
+
triton.serve()
|
88 |
+
|
89 |
+
|
90 |
+
def add_sub_example_thread(workspace: pathlib.Path, logger: logging.Logger):
|
91 |
+
"""Run example using external script.
|
92 |
+
|
93 |
+
Args:
|
94 |
+
workspace: Workspace path that will be created to store testing output (should not exist)
|
95 |
+
logger: logger instance
|
96 |
+
"""
|
97 |
+
logger.info("Running example model using external script")
|
98 |
+
|
99 |
+
with ScriptThread(add_script_path + ["--workspace", str(workspace.resolve())], name="server") as server_thread:
|
100 |
+
import time
|
101 |
+
|
102 |
+
time.sleep(3)
|
103 |
+
infer_add_sub_model()
|
104 |
+
|
105 |
+
if server_thread.process:
|
106 |
+
server_thread.process.send_signal(signal.SIGINT)
|
107 |
+
|
108 |
+
server_thread.join()
|
109 |
+
logger.error(server_thread.output)
|
110 |
+
if server_thread.returncode not in [
|
111 |
+
0,
|
112 |
+
-2,
|
113 |
+
]:
|
114 |
+
logger.error(f"Server failed - return code {server_thread.returncode}")
|
115 |
+
|
116 |
+
|
117 |
+
def add_sub_example(workspace: pathlib.Path, logger: logging.Logger):
|
118 |
+
"""Run example in the same process.
|
119 |
+
|
120 |
+
Args:
|
121 |
+
workspace: Workspace path that will be created to store testing output (should not exist)
|
122 |
+
logger: logger instance
|
123 |
+
"""
|
124 |
+
logger.info("Running example model")
|
125 |
+
triton = prepare_triton(workspace)
|
126 |
+
infer_add_sub_model()
|
127 |
+
triton.stop()
|
128 |
+
|
129 |
+
|
130 |
+
if __name__ == "__main__":
|
131 |
+
parser = argparse.ArgumentParser()
|
132 |
+
parser.add_argument("--workspace", help="Workspace path", type=str)
|
133 |
+
parser.add_argument("--infer", default=False, help="Infer AddSub model", action="store_true")
|
134 |
+
args = parser.parse_args()
|
135 |
+
|
136 |
+
if args.infer:
|
137 |
+
infer_add_sub_model()
|
138 |
+
else:
|
139 |
+
serve_triton(pathlib.Path(args.workspace))
|
stf/stf-api-alternative/pytriton/build/lib/pytriton/check/env_checks.py
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
"""Environment checks."""
|
15 |
+
|
16 |
+
import logging
|
17 |
+
import os
|
18 |
+
import pathlib
|
19 |
+
import platform
|
20 |
+
import re
|
21 |
+
import sys
|
22 |
+
|
23 |
+
import psutil
|
24 |
+
|
25 |
+
from pytriton.check.utils import ScriptThread
|
26 |
+
|
27 |
+
|
28 |
+
def nvidia_smi(logger):
|
29 |
+
"""Run nvidia-smi.
|
30 |
+
|
31 |
+
Args:
|
32 |
+
logger: logger instance
|
33 |
+
"""
|
34 |
+
logger.info("Running nvidia-smi")
|
35 |
+
with ScriptThread(["nvidia-smi"], name="nvidia-smi") as nvidia_smi_thread:
|
36 |
+
nvidia_smi_thread.join()
|
37 |
+
logger.info(nvidia_smi_thread.output)
|
38 |
+
if nvidia_smi_thread.returncode != 0:
|
39 |
+
logger.error("nvidia-smi failed - possible cause: no GPU available or driver not installed")
|
40 |
+
logger.error(
|
41 |
+
"If running in WSL wit sudo, make sure to add nvidia-smi folder (e.g. /usr/lib/wsl/lib) to sudoers file!"
|
42 |
+
)
|
43 |
+
|
44 |
+
|
45 |
+
def get_platform_info(logger):
|
46 |
+
"""Get platform information (OS, python, etc.).
|
47 |
+
|
48 |
+
Args:
|
49 |
+
logger: logger instance
|
50 |
+
"""
|
51 |
+
logger.info("Checking OS version")
|
52 |
+
logger.info("Script is running in docker:" + str(pathlib.Path("/.dockerenv").exists()))
|
53 |
+
|
54 |
+
os_release_path = pathlib.Path("/etc/os-release")
|
55 |
+
if os_release_path.exists():
|
56 |
+
with os_release_path.open() as f:
|
57 |
+
os_release = f.read()
|
58 |
+
logger.info("OS release")
|
59 |
+
logger.info(os_release)
|
60 |
+
for line in os_release.split("\n"):
|
61 |
+
if "PRETTY_NAME" in line:
|
62 |
+
os_version = line.split("=")[1].strip()
|
63 |
+
logger.info(f"OS version: {os_version}")
|
64 |
+
else:
|
65 |
+
logger.warning("OS release file not found (not available on some systems")
|
66 |
+
|
67 |
+
logger.info("Get platform info")
|
68 |
+
logger.info(f"Platform: {platform.platform()}")
|
69 |
+
logger.info(f"System: {platform.system()}")
|
70 |
+
logger.info(f"Release: {platform.release()}")
|
71 |
+
logger.info(f"Version: {platform.version()}")
|
72 |
+
logger.info(f"Machine: {platform.machine()}")
|
73 |
+
logger.info(f"Processor: {platform.processor()}")
|
74 |
+
logger.info(f"Python version: {platform.python_version()}")
|
75 |
+
logger.info(f"Python implementation: {platform.python_implementation()}")
|
76 |
+
logger.info(f"Python compiler: {platform.python_compiler()}")
|
77 |
+
logger.info(f"Python build: {platform.python_build()}")
|
78 |
+
logger.info(f"libc_ver: {platform.libc_ver()}")
|
79 |
+
|
80 |
+
|
81 |
+
def check_psutil_stats(logger):
|
82 |
+
"""Check psutil stats.
|
83 |
+
|
84 |
+
Args:
|
85 |
+
logger: logger instance
|
86 |
+
"""
|
87 |
+
logger.info("Checking psutil stats")
|
88 |
+
logger.info("Memory stats")
|
89 |
+
logger.info(psutil.virtual_memory())
|
90 |
+
logger.info("Swap stats")
|
91 |
+
logger.info(psutil.swap_memory())
|
92 |
+
logger.info("Disk stats")
|
93 |
+
logger.info(psutil.disk_usage("/"))
|
94 |
+
logger.info("Disk io countwers")
|
95 |
+
logger.info(psutil.disk_io_counters())
|
96 |
+
logger.info("CPU stats")
|
97 |
+
logger.info(psutil.cpu_times())
|
98 |
+
logger.info("Network stats")
|
99 |
+
logger.info(psutil.net_io_counters())
|
100 |
+
|
101 |
+
|
102 |
+
def get_listening_processes(logger):
|
103 |
+
"""Get listening processes.
|
104 |
+
|
105 |
+
Args:
|
106 |
+
logger: logger instance
|
107 |
+
"""
|
108 |
+
logger.info("Listening processes")
|
109 |
+
processes = {proc.pid: proc.name for proc in psutil.process_iter(["pid", "name"])}
|
110 |
+
connections = psutil.net_connections()
|
111 |
+
listening_sockets = [conn for conn in connections if conn.status == "LISTEN"]
|
112 |
+
|
113 |
+
for listening_socket in listening_sockets:
|
114 |
+
process_name = None
|
115 |
+
if listening_socket.pid is not None and listening_socket.pid in processes:
|
116 |
+
process_name = processes[listening_socket.pid]
|
117 |
+
logger.info(
|
118 |
+
f"Process ID: {listening_socket.pid}, Name: {process_name}, Local Address: {listening_socket.laddr}, Remote Address: {listening_socket.raddr}, Status: {listening_socket.status}"
|
119 |
+
)
|
120 |
+
|
121 |
+
|
122 |
+
def installed_packages(logger):
|
123 |
+
"""Get installed packages.
|
124 |
+
|
125 |
+
Args:
|
126 |
+
logger: logger instance
|
127 |
+
"""
|
128 |
+
logger.info("Checking installed packages")
|
129 |
+
import importlib_metadata
|
130 |
+
|
131 |
+
packages = importlib_metadata.distributions()
|
132 |
+
|
133 |
+
installed_pkg = sorted([f"{package.metadata['Name']}=={package.version} ({package._path})" for package in packages])
|
134 |
+
installed_pkg_str = "\n[\n\t" + ",\n\t".join(installed_pkg) + "\n]"
|
135 |
+
logger.info(installed_pkg_str)
|
136 |
+
|
137 |
+
|
138 |
+
def check_compiler_and_clib(logger):
|
139 |
+
"""Check compiler and C libraries.
|
140 |
+
|
141 |
+
Args:
|
142 |
+
logger: logger instance
|
143 |
+
"""
|
144 |
+
logger.info("Checking compiler and C libraries")
|
145 |
+
with ScriptThread(["gcc", "--version"], name="gcc_version") as gcc_version_thread:
|
146 |
+
gcc_version_thread.join()
|
147 |
+
logger.info("GCC version:")
|
148 |
+
logger.info(gcc_version_thread.output)
|
149 |
+
if gcc_version_thread.returncode != 0:
|
150 |
+
logger.error("gcc failed")
|
151 |
+
|
152 |
+
logger.info("Python version:")
|
153 |
+
logger.info(sys.version)
|
154 |
+
|
155 |
+
try:
|
156 |
+
logger.info(os.confstr("CS_GNU_LIBC_VERSION"))
|
157 |
+
except AttributeError as e:
|
158 |
+
logger.error(f"Failed to get glibc version {e}")
|
159 |
+
|
160 |
+
|
161 |
+
def log_env_variables(logger):
|
162 |
+
"""Log environment variables.
|
163 |
+
|
164 |
+
Args:
|
165 |
+
logger: logger instance
|
166 |
+
"""
|
167 |
+
logger.info("Environment variables")
|
168 |
+
|
169 |
+
env_vars = os.environ.items()
|
170 |
+
blacklist_patterns = [
|
171 |
+
r".*token.*",
|
172 |
+
r".*secret.*",
|
173 |
+
r".*key.*",
|
174 |
+
r".*password.*",
|
175 |
+
]
|
176 |
+
|
177 |
+
patterns = [re.compile(pattern, re.IGNORECASE) for pattern in blacklist_patterns]
|
178 |
+
filtered_env_vars = [
|
179 |
+
f"{key}={value}"
|
180 |
+
for key, value in env_vars
|
181 |
+
if not any(pattern.search(key) or pattern.search(value) for pattern in patterns)
|
182 |
+
]
|
183 |
+
|
184 |
+
env_vars_str = "\n".join(filtered_env_vars)
|
185 |
+
logger.info(env_vars_str)
|
186 |
+
|
187 |
+
|
188 |
+
def env_checks(logger: logging.Logger):
|
189 |
+
"""Run all environment checks.
|
190 |
+
|
191 |
+
Args:
|
192 |
+
logger: logger instance
|
193 |
+
"""
|
194 |
+
logger.info("Running all environment checks")
|
195 |
+
get_platform_info(logger)
|
196 |
+
nvidia_smi(logger)
|
197 |
+
installed_packages(logger)
|
198 |
+
check_psutil_stats(logger)
|
199 |
+
get_listening_processes(logger)
|
200 |
+
check_compiler_and_clib(logger)
|
201 |
+
log_env_variables(logger)
|
stf/stf-api-alternative/pytriton/build/lib/pytriton/check/utils.py
ADDED
@@ -0,0 +1,555 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
"""Utils."""
|
15 |
+
|
16 |
+
import contextlib
|
17 |
+
import fcntl
|
18 |
+
import logging
|
19 |
+
import os
|
20 |
+
import pathlib
|
21 |
+
import re
|
22 |
+
import select
|
23 |
+
import socket
|
24 |
+
import subprocess
|
25 |
+
import threading
|
26 |
+
import typing
|
27 |
+
|
28 |
+
LOGGER = logging.getLogger(__name__)
|
29 |
+
DEFAULT_LOG_FORMAT = "%(asctime)s - %(levelname)8s - %(process)8d - %(threadName)s - %(name)s: %(message)s"
|
30 |
+
|
31 |
+
|
32 |
+
def _read_outputs(_process, _logger, _outputs):
|
33 |
+
# Set stdout and stderr file descriptors to non-blocking mode
|
34 |
+
try:
|
35 |
+
fcntl.fcntl(_process.stdout, fcntl.F_SETFL, os.O_NONBLOCK)
|
36 |
+
fcntl.fcntl(_process.stderr, fcntl.F_SETFL, os.O_NONBLOCK)
|
37 |
+
except ValueError: # when selecting on closed files
|
38 |
+
return
|
39 |
+
|
40 |
+
buffers = {_process.stdout: "", _process.stderr: ""}
|
41 |
+
rds = [_process.stdout, _process.stderr]
|
42 |
+
while rds:
|
43 |
+
try:
|
44 |
+
readable, _, _ = select.select(rds, [], [], 1)
|
45 |
+
except ValueError: # when selecting on closed files
|
46 |
+
break
|
47 |
+
|
48 |
+
for rd in readable:
|
49 |
+
try:
|
50 |
+
data = os.read(rd.fileno(), 4096)
|
51 |
+
if not data:
|
52 |
+
rds.remove(rd)
|
53 |
+
continue
|
54 |
+
|
55 |
+
decoded_data = data.decode("utf-8")
|
56 |
+
buffers[rd] += decoded_data
|
57 |
+
lines = buffers[rd].splitlines(keepends=True)
|
58 |
+
|
59 |
+
if buffers[rd].endswith("\n"):
|
60 |
+
complete_lines = lines
|
61 |
+
buffers[rd] = ""
|
62 |
+
else:
|
63 |
+
complete_lines = lines[:-1]
|
64 |
+
buffers[rd] = lines[-1]
|
65 |
+
|
66 |
+
for line in complete_lines:
|
67 |
+
line = line.rstrip()
|
68 |
+
_logger.info(line)
|
69 |
+
_outputs.append(line)
|
70 |
+
except OSError: # Reading from an empty non-blocking file
|
71 |
+
pass
|
72 |
+
|
73 |
+
|
74 |
+
class ScriptThread(threading.Thread):
|
75 |
+
"""A class that runs external script in a separate thread."""
|
76 |
+
|
77 |
+
def __init__(self, cmd, workdir=None, group=None, target=None, name=None, args=(), kwargs=None) -> None:
|
78 |
+
"""Initializes the ScriptThread object."""
|
79 |
+
super().__init__(group, target, name, args, kwargs, daemon=True)
|
80 |
+
self.cmd = cmd
|
81 |
+
self.workdir = workdir
|
82 |
+
self._process_spawned_or_spawn_error_flag = None
|
83 |
+
self.active = False
|
84 |
+
self._process = None
|
85 |
+
self.returncode = None
|
86 |
+
self._output = []
|
87 |
+
self._logger = logging.getLogger(self.name)
|
88 |
+
|
89 |
+
def __enter__(self):
|
90 |
+
"""Starts the script thread."""
|
91 |
+
self.start(threading.Event())
|
92 |
+
self._process_spawned_or_spawn_error_flag.wait()
|
93 |
+
return self
|
94 |
+
|
95 |
+
def __exit__(self, *args):
|
96 |
+
"""Stops the script thread and waits for it to join."""
|
97 |
+
self.stop()
|
98 |
+
self.join()
|
99 |
+
self._process_spawned_or_spawn_error_flag = None
|
100 |
+
|
101 |
+
def start(self, flag: typing.Optional[threading.Event] = None) -> None:
|
102 |
+
"""Starts the script thread."""
|
103 |
+
if flag is None:
|
104 |
+
flag = threading.Event()
|
105 |
+
self._logger.info(f"Starting {self.name} script with \"{' '.join(self.cmd)}\" cmd")
|
106 |
+
self._process_spawned_or_spawn_error_flag = flag
|
107 |
+
super().start()
|
108 |
+
|
109 |
+
def stop(self):
|
110 |
+
"""Sets the active flag to False to stop the script thread."""
|
111 |
+
self._logger.info(f"Stopping {self.name} script")
|
112 |
+
self.active = False
|
113 |
+
|
114 |
+
def run(self):
|
115 |
+
"""Runs the script in a separate process."""
|
116 |
+
import psutil
|
117 |
+
|
118 |
+
self.returncode = None
|
119 |
+
self._output = []
|
120 |
+
self._process = None
|
121 |
+
|
122 |
+
os.environ.setdefault("PYTHONUNBUFFERED", "1") # to not buffer logs
|
123 |
+
try:
|
124 |
+
with psutil.Popen(
|
125 |
+
self.cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, bufsize=0, cwd=self.workdir
|
126 |
+
) as process:
|
127 |
+
self._process = process
|
128 |
+
self.active = True
|
129 |
+
if self._process_spawned_or_spawn_error_flag:
|
130 |
+
self._process_spawned_or_spawn_error_flag.set()
|
131 |
+
while self.active and process.poll() is None and process.returncode is None:
|
132 |
+
try:
|
133 |
+
_read_outputs(process, self._logger, self._output)
|
134 |
+
except KeyboardInterrupt:
|
135 |
+
self.stop()
|
136 |
+
|
137 |
+
finally:
|
138 |
+
if self._process_spawned_or_spawn_error_flag:
|
139 |
+
self._process_spawned_or_spawn_error_flag.set()
|
140 |
+
if self.process:
|
141 |
+
while self.process.poll() is None:
|
142 |
+
_read_outputs(self.process, self._logger, self._output)
|
143 |
+
_read_outputs(self.process, self._logger, self._output)
|
144 |
+
self.returncode = process.wait() # pytype: disable=name-error
|
145 |
+
self._logger.info(f"{self.name} process finished with {self.returncode}")
|
146 |
+
|
147 |
+
self.active = False
|
148 |
+
self._process = None
|
149 |
+
|
150 |
+
@property
|
151 |
+
def output(self):
|
152 |
+
"""Return process stream output."""
|
153 |
+
return "\n".join(self._output)
|
154 |
+
|
155 |
+
@property
|
156 |
+
def process(self):
|
157 |
+
"""Return process object."""
|
158 |
+
return self._process
|
159 |
+
|
160 |
+
|
161 |
+
def find_free_port() -> int:
|
162 |
+
"""Finds a free port on the local machine."""
|
163 |
+
with contextlib.closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
|
164 |
+
s.bind(("", 0))
|
165 |
+
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
166 |
+
return s.getsockname()[1]
|
167 |
+
|
168 |
+
|
169 |
+
class ProcessMonitoring:
|
170 |
+
"""A class that dumps the state of a process and its children.
|
171 |
+
|
172 |
+
This class uses the py-spy tool to dump the stack trace of a process and its
|
173 |
+
children recursively. It also dumps the process information such as the parent
|
174 |
+
and the command line. It allows registering custom monitors that can perform
|
175 |
+
additional actions on the process.
|
176 |
+
|
177 |
+
Attributes:
|
178 |
+
_logger (logging.Logger): The logger object to write messages.
|
179 |
+
_process (psutil.Process): The process object to monitor.
|
180 |
+
_children_processes (list[psutil.Process]): The list of child processes to monitor.
|
181 |
+
_log (logging.Logger.method): The logging method to use for messages.
|
182 |
+
_remove_color (bool): Whether to remove ANSI escape sequences from the output.
|
183 |
+
_ansi_escape (re.Pattern): The regular expression object to match ANSI escape sequences.
|
184 |
+
_custom_monitors (list[typing.Callable[[int], None]]): The list of custom monitor functions to execute on each dump cycle.
|
185 |
+
"""
|
186 |
+
|
187 |
+
def __init__(
|
188 |
+
self,
|
189 |
+
pid: int,
|
190 |
+
logger: typing.Optional[logging.Logger] = None,
|
191 |
+
loglevel: int = logging.INFO,
|
192 |
+
remove_color: bool = False,
|
193 |
+
):
|
194 |
+
"""Initializes the ProcessMonitoring object.
|
195 |
+
|
196 |
+
Args:
|
197 |
+
pid (int): The process ID of the process to monitor.
|
198 |
+
logger (typing.Optional[logging.Logger], optional): The logger object to write messages. Defaults to None.
|
199 |
+
loglevel (int, optional): The logging level to use for messages. Defaults to logging.INFO.
|
200 |
+
remove_color (bool, optional): Whether to remove ANSI escape sequences from the output. Defaults to False.
|
201 |
+
"""
|
202 |
+
import re
|
203 |
+
|
204 |
+
import psutil
|
205 |
+
|
206 |
+
self._logger = logger or logging.getLogger("monitoring")
|
207 |
+
self._process = psutil.Process(pid)
|
208 |
+
self._children_processes = list(self._process.children(recursive=True))
|
209 |
+
self._log = {
|
210 |
+
logging.DEBUG: self._logger.debug,
|
211 |
+
logging.INFO: self._logger.info,
|
212 |
+
logging.WARNING: self._logger.warning,
|
213 |
+
logging.ERROR: self._logger.error,
|
214 |
+
}[loglevel]
|
215 |
+
self._log(f"Initial list of children processes: {self._children_processes}")
|
216 |
+
self._remove_color = remove_color
|
217 |
+
pattern = r"\x1b\[.*?m"
|
218 |
+
self._ansi_escape = re.compile(pattern)
|
219 |
+
self._custom_monitors = []
|
220 |
+
|
221 |
+
def register_custom_monitor(self, custom_monitor: typing.Callable[[int], None]) -> None:
|
222 |
+
"""Registers a custom monitor for the process.
|
223 |
+
|
224 |
+
This method adds a custom monitor function to the list of monitors that are
|
225 |
+
executed on each dump cycle. A custom monitor function should take an integer
|
226 |
+
as an argument (the process ID) and return None.
|
227 |
+
|
228 |
+
Args:
|
229 |
+
custom_monitor (typing.Callable[[int], None]): The custom monitor function to register.
|
230 |
+
"""
|
231 |
+
self._custom_monitors.append(custom_monitor)
|
232 |
+
|
233 |
+
def dump_state(self) -> None:
|
234 |
+
"""Dumps the state of the process and its children.
|
235 |
+
|
236 |
+
This method calls the _dump_processes_stacktrace and _dump_child_processes
|
237 |
+
methods to dump the stack trace and the process information of the process
|
238 |
+
and its children recursively.
|
239 |
+
"""
|
240 |
+
self._dump_processes_stacktrace()
|
241 |
+
self._dump_child_processes()
|
242 |
+
|
243 |
+
def _dump_processes_stacktrace(self):
|
244 |
+
import psutil
|
245 |
+
import sh
|
246 |
+
|
247 |
+
self._log("==== Dump process stacktrace")
|
248 |
+
pyspy_cmd = sh.Command("py-spy")
|
249 |
+
|
250 |
+
for process in [self._process] + self.children:
|
251 |
+
try:
|
252 |
+
result = pyspy_cmd("dump", "-ll", "--nonblocking", "-p", str(process.pid))
|
253 |
+
if self._remove_color:
|
254 |
+
result = self._ansi_escape.sub("", str(result))
|
255 |
+
self._log(f"Dump stack trace for process (pid={process.pid}) with cmd {process.cmdline()}")
|
256 |
+
for custom_monitor in self._custom_monitors:
|
257 |
+
custom_monitor(process.pid)
|
258 |
+
self._log(result)
|
259 |
+
except psutil.NoSuchProcess as e:
|
260 |
+
self._log(f"Error during handling process: {e}")
|
261 |
+
except sh.ErrorReturnCode_1 as e:
|
262 |
+
self._log(f"Error during calling py-spy process: {e}")
|
263 |
+
|
264 |
+
def _dump_child_processes(self):
|
265 |
+
import psutil
|
266 |
+
|
267 |
+
self._log("==== Dump process info (with its children)")
|
268 |
+
for process in [self._process] + self.children:
|
269 |
+
try:
|
270 |
+
self._log(f"{process} parent={process.parent()} ")
|
271 |
+
except psutil.NoSuchProcess:
|
272 |
+
self._log(f"{process} is missing in process table")
|
273 |
+
|
274 |
+
@property
|
275 |
+
def children(self):
|
276 |
+
"""Returns the list of child processes to monitor.
|
277 |
+
|
278 |
+
This property returns the list of child processes to monitor, and updates it
|
279 |
+
with any new children that are created by the process.
|
280 |
+
|
281 |
+
Returns:
|
282 |
+
list[psutil.Process]: The list of child processes to monitor.
|
283 |
+
"""
|
284 |
+
import psutil
|
285 |
+
|
286 |
+
try:
|
287 |
+
children = list(self._process.children(recursive=True))
|
288 |
+
self._children_processes = list(set(self._children_processes + children))
|
289 |
+
except psutil.NoSuchProcess:
|
290 |
+
pass
|
291 |
+
return self._children_processes
|
292 |
+
|
293 |
+
|
294 |
+
def get_current_container_version():
|
295 |
+
"""Returns the version of the current container."""
|
296 |
+
container_version = os.environ.get("NVIDIA_PYTORCH_VERSION") or os.environ.get("NVIDIA_TENSORFLOW_VERSION")
|
297 |
+
if container_version and "-" in container_version:
|
298 |
+
container_version = container_version.split("-")[0] # TF version has format <year_month_version>-<tf_version>
|
299 |
+
return container_version
|
300 |
+
|
301 |
+
|
302 |
+
def verify_docker_image_in_readme_same_as_tested(readme_path, image_name_with_version):
|
303 |
+
"""Verify that the docker image is the same as described in the readme file."""
|
304 |
+
image_name, _image_version = image_name_with_version.split(":")
|
305 |
+
framework_name = image_name.split("/")[-1]
|
306 |
+
readme_payload = pathlib.Path(readme_path).read_text()
|
307 |
+
match_iterator = re.finditer(
|
308 |
+
rf"(?P<container_registry>[\w/.\-:]+)/{framework_name}:(?P<image_version_with_python_version>[\w.-]+)",
|
309 |
+
readme_payload,
|
310 |
+
)
|
311 |
+
for entry in match_iterator:
|
312 |
+
assert entry.group() == image_name_with_version, f"{entry.group()} != {image_name_with_version}"
|
313 |
+
|
314 |
+
|
315 |
+
def search_warning_on_too_verbose_log_level(logs: str):
|
316 |
+
"""Search warnings."""
|
317 |
+
pattern = r"Triton Inference Server is running with enabled verbose logs.*It may affect inference performance."
|
318 |
+
return re.search(pattern, logs)
|
319 |
+
|
320 |
+
|
321 |
+
class ProcessMonitoringThread:
|
322 |
+
"""A class that creates a thread to monitor a process.
|
323 |
+
|
324 |
+
This class uses the ProcessMonitoring class to dump the state of a process
|
325 |
+
and its children periodically. It also allows registering custom monitors
|
326 |
+
that can perform additional actions on the process.
|
327 |
+
|
328 |
+
Attributes:
|
329 |
+
_monitoring (ProcessMonitoring): The ProcessMonitoring object that handles the dumping logic.
|
330 |
+
_stop_event (threading.Event): The event object that signals the thread to stop its loop.
|
331 |
+
_thread (threading.Thread): The thread object that runs the _run method in a loop.
|
332 |
+
_interval (float): The interval in seconds between each dump cycle.
|
333 |
+
"""
|
334 |
+
|
335 |
+
def __init__(self, monitoring: ProcessMonitoring, interval: float = 60):
|
336 |
+
"""Initializes the ProcessMonitoringThread object.
|
337 |
+
|
338 |
+
Args:
|
339 |
+
monitoring (ProcessMonitoring): The ProcessMonitoring object that handles the dumping logic.
|
340 |
+
interval (float, optional): The interval in seconds between each dump cycle. Defaults to 60.
|
341 |
+
"""
|
342 |
+
self._monitoring = monitoring
|
343 |
+
self._interval = interval
|
344 |
+
|
345 |
+
def start(self) -> None:
|
346 |
+
"""Starts the monitoring thread.
|
347 |
+
|
348 |
+
This method creates a new thread that runs the _run method in a loop until
|
349 |
+
the stop method is called or an exception occurs. It also sets the stop event
|
350 |
+
object that can be used to signal the thread to stop gracefully.
|
351 |
+
"""
|
352 |
+
self._stop_event = threading.Event()
|
353 |
+
self._thread = threading.Thread(target=self._run, daemon=True)
|
354 |
+
self._thread.start()
|
355 |
+
|
356 |
+
def stop(self) -> None:
|
357 |
+
"""Stops the monitoring thread.
|
358 |
+
|
359 |
+
This method sets the stop event object that signals the thread to stop its loop.
|
360 |
+
It also waits for the thread to join before returning.
|
361 |
+
"""
|
362 |
+
self._stop_event.set()
|
363 |
+
self._thread.join()
|
364 |
+
|
365 |
+
def __enter__(self):
|
366 |
+
"""Enters the context manager for the monitoring thread."""
|
367 |
+
self.start()
|
368 |
+
return self
|
369 |
+
|
370 |
+
def __exit__(self, *args):
|
371 |
+
"""Exits the context manager for the monitoring thread."""
|
372 |
+
self.stop()
|
373 |
+
|
374 |
+
def _run(self):
|
375 |
+
logging.info("Monitoring process")
|
376 |
+
self._monitoring.dump_state()
|
377 |
+
while not self._stop_event.wait(self._interval):
|
378 |
+
logging.info("Monitoring process")
|
379 |
+
self._monitoring.dump_state()
|
380 |
+
|
381 |
+
|
382 |
+
class TestMonitoringContext:
|
383 |
+
"""A context manager that monitors test processes.
|
384 |
+
|
385 |
+
This context manager creates threads to monitor the test processes and dumps
|
386 |
+
their state periodically. It can extend argparse args with additional arguments.
|
387 |
+
It supports splitting log into different files. The standard output log can have one level
|
388 |
+
and the file log can have another level. It uses log rotation.
|
389 |
+
"""
|
390 |
+
|
391 |
+
@staticmethod
|
392 |
+
def extend_args(parser):
|
393 |
+
"""Extends argparse args with additional arguments."""
|
394 |
+
parser.add_argument(
|
395 |
+
"--verbose",
|
396 |
+
action="store_true",
|
397 |
+
help="Provide verbose logs",
|
398 |
+
)
|
399 |
+
parser.add_argument(
|
400 |
+
"--log-path",
|
401 |
+
type=str,
|
402 |
+
default=None,
|
403 |
+
help="Provide the path of external log for rotation",
|
404 |
+
)
|
405 |
+
parser.add_argument(
|
406 |
+
"--compress-logs",
|
407 |
+
action="store_true",
|
408 |
+
help="Enable logs compression",
|
409 |
+
)
|
410 |
+
parser.add_argument(
|
411 |
+
"--maximum-log-file",
|
412 |
+
type=int,
|
413 |
+
default=10 * 1024 * 1024,
|
414 |
+
help="Maximum logfile size before rotation is started",
|
415 |
+
required=False,
|
416 |
+
)
|
417 |
+
parser.add_argument(
|
418 |
+
"--enable-fault-handler",
|
419 |
+
action="store_true",
|
420 |
+
help="Enable faulthandler",
|
421 |
+
)
|
422 |
+
parser.add_argument(
|
423 |
+
"--faulthandler-interval",
|
424 |
+
type=float,
|
425 |
+
default=None,
|
426 |
+
help="Enable faulthandler after specified number of seconds with repeat",
|
427 |
+
required=False,
|
428 |
+
)
|
429 |
+
parser.add_argument(
|
430 |
+
"--process-monitoring-interval",
|
431 |
+
type=float,
|
432 |
+
default=None,
|
433 |
+
help="Enable process monitoring after specified number of seconds with repeat",
|
434 |
+
required=False,
|
435 |
+
)
|
436 |
+
|
437 |
+
def __init__(self, args):
|
438 |
+
"""Initializes the TestMonitoringContext object.
|
439 |
+
|
440 |
+
Args:
|
441 |
+
args (argparse.Namespace): The argparse args object to extend with additional arguments.
|
442 |
+
"""
|
443 |
+
self._args = args
|
444 |
+
|
445 |
+
def __enter__(self):
|
446 |
+
"""Enters the context manager for the test monitoring."""
|
447 |
+
import faulthandler
|
448 |
+
import logging.handlers
|
449 |
+
|
450 |
+
args = self._args
|
451 |
+
self._loglevel = log_level = logging.DEBUG if args.verbose else logging.INFO
|
452 |
+
logging.basicConfig(level=logging.DEBUG, format=DEFAULT_LOG_FORMAT)
|
453 |
+
logger = logging.getLogger()
|
454 |
+
|
455 |
+
if args.log_path is not None:
|
456 |
+
# Create a rotating file handler for the file output logger
|
457 |
+
# The file name is based on the log path argument, the maximum size is 10 MB, and the maximum number of files is 500
|
458 |
+
file_handler = logging.handlers.RotatingFileHandler(
|
459 |
+
args.log_path, maxBytes=args.maximum_log_file, backupCount=500
|
460 |
+
)
|
461 |
+
file_handler.setFormatter(logging.Formatter(DEFAULT_LOG_FORMAT))
|
462 |
+
file_handler.setLevel(logging.DEBUG)
|
463 |
+
if args.compress_logs:
|
464 |
+
file_handler.namer = lambda name: name + ".gz"
|
465 |
+
|
466 |
+
def gzip_rotation(source, dest):
|
467 |
+
import gzip
|
468 |
+
import os
|
469 |
+
|
470 |
+
with open(source, "rb") as f_in:
|
471 |
+
with gzip.open(dest, "wb") as f_out:
|
472 |
+
f_out.writelines(f_in)
|
473 |
+
os.remove(source)
|
474 |
+
|
475 |
+
file_handler.rotator = gzip_rotation
|
476 |
+
|
477 |
+
# Add the file handler to the default logger
|
478 |
+
logger.addHandler(file_handler)
|
479 |
+
# Get the stream handler that was created by basicConfig
|
480 |
+
|
481 |
+
# Get the stream handler that was created by basicConfig
|
482 |
+
stream_handler = logger.handlers[0]
|
483 |
+
# Set the stream handler's level to match the log level argument
|
484 |
+
stream_handler.setLevel(log_level)
|
485 |
+
|
486 |
+
if args.enable_fault_handler:
|
487 |
+
faulthandler.enable()
|
488 |
+
|
489 |
+
if args.faulthandler_interval is not None:
|
490 |
+
faulthandler.dump_traceback_later(args.faulthandler_interval, repeat=True, exit=False)
|
491 |
+
|
492 |
+
custom_monitors = []
|
493 |
+
|
494 |
+
import os
|
495 |
+
|
496 |
+
import psutil
|
497 |
+
|
498 |
+
def monitor_ram_usage(pid=None):
|
499 |
+
if pid is None:
|
500 |
+
pid = os.getpid()
|
501 |
+
|
502 |
+
process = psutil.Process(pid)
|
503 |
+
logger.debug(f"MONITOR RAM USAGE ({pid}): {process.memory_info()}")
|
504 |
+
|
505 |
+
custom_monitors.append(monitor_ram_usage)
|
506 |
+
|
507 |
+
def monitor_file_descriptors(pid=None):
|
508 |
+
if pid is None:
|
509 |
+
pid = os.getpid()
|
510 |
+
|
511 |
+
process = psutil.Process(pid)
|
512 |
+
logger.debug(f"MONITOR FILE DESCRIPTORS ({pid}): {process.num_fds()}")
|
513 |
+
|
514 |
+
custom_monitors.append(monitor_file_descriptors)
|
515 |
+
|
516 |
+
def monitor_cpu_usage(pid=None):
|
517 |
+
if pid is None:
|
518 |
+
pid = os.getpid()
|
519 |
+
|
520 |
+
process = psutil.Process(pid)
|
521 |
+
logger.debug(f"MONITOR CPU USAGE ({pid}): {process.cpu_percent()}")
|
522 |
+
|
523 |
+
custom_monitors.append(monitor_cpu_usage)
|
524 |
+
|
525 |
+
def monitor_threads(pid=None):
|
526 |
+
if pid is None:
|
527 |
+
pid = os.getpid()
|
528 |
+
|
529 |
+
process = psutil.Process(pid)
|
530 |
+
logger.debug(f"MONITOR THREADS ({pid}): {process.num_threads()}")
|
531 |
+
|
532 |
+
custom_monitors.append(monitor_threads)
|
533 |
+
|
534 |
+
def monitor_process_dict(pid=None):
|
535 |
+
if pid is None:
|
536 |
+
pid = os.getpid()
|
537 |
+
|
538 |
+
process = psutil.Process(pid)
|
539 |
+
logger.debug(f"MONITOR PROCESS DICT ({pid}): {process.as_dict()}")
|
540 |
+
|
541 |
+
custom_monitors.append(monitor_process_dict)
|
542 |
+
if args.process_monitoring_interval is not None:
|
543 |
+
monitoring = ProcessMonitoring(os.getpid(), logger, loglevel=logging.DEBUG, remove_color=True)
|
544 |
+
for monitor in custom_monitors:
|
545 |
+
monitoring.register_custom_monitor(monitor)
|
546 |
+
|
547 |
+
self._monitor = ProcessMonitoringThread(monitoring, interval=args.process_monitoring_interval)
|
548 |
+
self._monitor.start()
|
549 |
+
return self
|
550 |
+
|
551 |
+
def __exit__(self, *args):
|
552 |
+
"""Stops the monitor thread."""
|
553 |
+
if hasattr(self, "_monitor"):
|
554 |
+
self._monitor.stop()
|
555 |
+
self._monitor = None
|
stf/stf-api-alternative/pytriton/build/lib/pytriton/client/__init__.py
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# noqa: D104
|
15 |
+
|
16 |
+
from .client import (
|
17 |
+
AsyncioDecoupledModelClient, # noqa: F401
|
18 |
+
AsyncioModelClient, # noqa: F401
|
19 |
+
DecoupledModelClient, # noqa: F401
|
20 |
+
FuturesModelClient, # noqa: F401
|
21 |
+
ModelClient, # noqa: F401
|
22 |
+
)
|
stf/stf-api-alternative/pytriton/build/lib/pytriton/client/asyncio_utils.py
ADDED
@@ -0,0 +1,308 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
"""Utility module supporting model clients."""
|
15 |
+
|
16 |
+
import asyncio
|
17 |
+
import logging
|
18 |
+
import time
|
19 |
+
from typing import Optional, Union
|
20 |
+
|
21 |
+
import aiohttp
|
22 |
+
import grpc
|
23 |
+
import tritonclient.grpc
|
24 |
+
import tritonclient.http
|
25 |
+
|
26 |
+
from pytriton.client.exceptions import PyTritonClientModelUnavailableError, PyTritonClientTimeoutError
|
27 |
+
from pytriton.client.utils import LATEST_MODEL_VERSION, ModelState, parse_grpc_response, parse_http_response
|
28 |
+
from pytriton.model_config.parser import ModelConfigParser
|
29 |
+
|
30 |
+
aio_clients = Union[tritonclient.grpc.aio.InferenceServerClient, tritonclient.http.aio.InferenceServerClient]
|
31 |
+
|
32 |
+
_LOGGER = logging.getLogger(__name__)
|
33 |
+
|
34 |
+
_DEFAULT_WAIT_FOR_MODEL_TIMEOUT_S = 60.0 # 60 seconds
|
35 |
+
_DEFAULT_ASYNC_SLEEP_FACTOR_S = 0.1 # 10% of timeout
|
36 |
+
|
37 |
+
|
38 |
+
async def asyncio_get_model_state(
|
39 |
+
client: aio_clients,
|
40 |
+
model_name: str,
|
41 |
+
model_version: Optional[str] = None,
|
42 |
+
) -> ModelState:
|
43 |
+
"""Obtains state of the model deployed in Triton Inference Server.
|
44 |
+
|
45 |
+
Typical use:
|
46 |
+
|
47 |
+
>>> import tritonclient.http.aio
|
48 |
+
... client = tritonclient.http.aio.InferenceServerClient("localhost:8000")
|
49 |
+
... model_state = await get_model_state(client, "MyModel", "1")
|
50 |
+
|
51 |
+
Args:
|
52 |
+
client: Triton Inference Server client to use for communication
|
53 |
+
model_name: name of the model which state we're requesting.
|
54 |
+
model_version:
|
55 |
+
version of the model which state we're requesting.
|
56 |
+
If model_version is None state of latest model is returned.
|
57 |
+
The latest versions of the model are the numerically greatest version numbers.
|
58 |
+
|
59 |
+
Returns:
|
60 |
+
Model state. ModelState.UNAVAILABLE is returned in case if model with given name and version is not found.
|
61 |
+
|
62 |
+
"""
|
63 |
+
_LOGGER.debug(f"Obtaining model {model_name} state")
|
64 |
+
repository_index = await client.get_model_repository_index()
|
65 |
+
_LOGGER.debug("Model repository index obtained")
|
66 |
+
if isinstance(repository_index, list):
|
67 |
+
models_states = parse_http_response(models=repository_index)
|
68 |
+
else:
|
69 |
+
models_states = parse_grpc_response(models=repository_index.models)
|
70 |
+
|
71 |
+
if model_version is None:
|
72 |
+
requested_model_states = {
|
73 |
+
version: state for (name, version), state in models_states.items() if name == model_name
|
74 |
+
}
|
75 |
+
if not requested_model_states:
|
76 |
+
return ModelState.UNAVAILABLE
|
77 |
+
else:
|
78 |
+
requested_model_states = sorted(requested_model_states.items(), key=lambda item: int(item[0]))
|
79 |
+
latest_version, latest_version_state = requested_model_states[-1]
|
80 |
+
_LOGGER.debug(f"Model {model_name} latest version: {latest_version} state: {latest_version_state}")
|
81 |
+
return latest_version_state
|
82 |
+
else:
|
83 |
+
key = (model_name, model_version)
|
84 |
+
if key not in models_states:
|
85 |
+
return ModelState.UNAVAILABLE
|
86 |
+
else:
|
87 |
+
model_state = models_states[key]
|
88 |
+
_LOGGER.debug(f"Model {model_name} version {model_version} state: {model_state}")
|
89 |
+
return model_state
|
90 |
+
|
91 |
+
|
92 |
+
async def asyncio_get_model_config(
|
93 |
+
client: aio_clients,
|
94 |
+
model_name: str,
|
95 |
+
model_version: Optional[str] = None,
|
96 |
+
timeout_s: float = _DEFAULT_WAIT_FOR_MODEL_TIMEOUT_S,
|
97 |
+
):
|
98 |
+
"""Obtain configuration of model deployed on the Triton Inference Server.
|
99 |
+
|
100 |
+
Function waits for server readiness.
|
101 |
+
|
102 |
+
Args:
|
103 |
+
client: Triton Inference Server client to use for communication
|
104 |
+
model_name: name of the model which configuration we're requesting.
|
105 |
+
model_version:
|
106 |
+
version of the model which configuration we're requesting.
|
107 |
+
If model_version is None configuration of the latest model is returned.
|
108 |
+
The latest versions of the model are the numerically greatest version numbers.
|
109 |
+
timeout_s: timeout to finish model configuration obtain.
|
110 |
+
|
111 |
+
Returns:
|
112 |
+
Configuration of requested model.
|
113 |
+
|
114 |
+
Raises:
|
115 |
+
PyTritonClientTimeoutError: If obtain of model configuration didn't finish before given timeout.
|
116 |
+
PyTritonClientModelUnavailableError: If model with given name (and version) is unavailable.
|
117 |
+
"""
|
118 |
+
should_finish_before = time.time() + timeout_s
|
119 |
+
_LOGGER.debug(f"Obtaining model {model_name} config (timeout={timeout_s:0.2f})")
|
120 |
+
try:
|
121 |
+
_LOGGER.debug(f"Waiting for model {model_name} to be ready")
|
122 |
+
await asyncio.wait_for(
|
123 |
+
asyncio_wait_for_model_ready(
|
124 |
+
client, model_name=model_name, model_version=model_version, timeout_s=timeout_s
|
125 |
+
),
|
126 |
+
timeout_s,
|
127 |
+
)
|
128 |
+
|
129 |
+
model_version = model_version or ""
|
130 |
+
|
131 |
+
timeout_s = max(0, should_finish_before - time.time())
|
132 |
+
if isinstance(client, tritonclient.grpc.aio.InferenceServerClient):
|
133 |
+
_LOGGER.debug(f"Obtaining model {model_name} config as_json=True")
|
134 |
+
response = await asyncio.wait_for(
|
135 |
+
client.get_model_config(model_name, model_version, as_json=True), timeout_s
|
136 |
+
)
|
137 |
+
model_config = response["config"]
|
138 |
+
else:
|
139 |
+
_LOGGER.debug(f"Obtaining model {model_name} config")
|
140 |
+
model_config = await asyncio.wait_for(client.get_model_config(model_name, model_version), timeout_s)
|
141 |
+
_LOGGER.debug("Model config obtained")
|
142 |
+
model_config = ModelConfigParser.from_dict(model_config)
|
143 |
+
_LOGGER.debug(f"Model config: {model_config}")
|
144 |
+
return model_config
|
145 |
+
except asyncio.TimeoutError as e:
|
146 |
+
message = f"Timeout while waiting for model {model_name} config (timeout={timeout_s:0.2f})"
|
147 |
+
_LOGGER.error(message)
|
148 |
+
raise PyTritonClientTimeoutError(message) from e
|
149 |
+
|
150 |
+
|
151 |
+
async def asyncio_wait_for_server_ready(
|
152 |
+
asyncio_client: aio_clients,
|
153 |
+
sleep_time_s: float,
|
154 |
+
):
|
155 |
+
"""Wait for Triton Inference Server readiness.
|
156 |
+
|
157 |
+
There are two functions, which check server status:
|
158 |
+
* asyncio_client.is_server_ready()
|
159 |
+
* asyncio_client.is_server_live()
|
160 |
+
Both must return true to consider server accessible to read model status.
|
161 |
+
|
162 |
+
Function contains while loop with sleep to check server status periodically.
|
163 |
+
|
164 |
+
Args:
|
165 |
+
asyncio_client: Triton Inference Server client to use for communication
|
166 |
+
sleep_time_s: time to sleep between server status checks
|
167 |
+
|
168 |
+
Raises:
|
169 |
+
PyTritonClientTimeoutError: If obtain of model configuration didn't finish before given timeout.
|
170 |
+
"""
|
171 |
+
_LOGGER.debug("Waiting for server to be ready")
|
172 |
+
try:
|
173 |
+
while True:
|
174 |
+
try:
|
175 |
+
_LOGGER.debug("Waiting for server to be ready")
|
176 |
+
server_ready = await asyncio_client.is_server_ready()
|
177 |
+
_LOGGER.debug("Waiting for server to be live")
|
178 |
+
server_live = await asyncio_client.is_server_live()
|
179 |
+
except tritonclient.utils.InferenceServerException:
|
180 |
+
# Raised by tritonclient/grpc/__init__.py:75
|
181 |
+
server_live = False
|
182 |
+
server_ready = False
|
183 |
+
except aiohttp.client_exceptions.ClientConnectorError:
|
184 |
+
# This exception is raised by aiohttp/connector.py:901 in _create_direct_connection
|
185 |
+
# and it is not translated to any other error by tritonclient/http/aio/__init__.py:132 in _get method.
|
186 |
+
# res = await self._stub.get(url=req_url,
|
187 |
+
# and tritonclient/http/aio/__init__.py:242 in is_server_ready method.
|
188 |
+
# response = await self._get(request_uri=request_uri,
|
189 |
+
server_live = False
|
190 |
+
server_ready = False
|
191 |
+
except RuntimeError:
|
192 |
+
# This exception is raised by aiohttp/client.py:400 in _request
|
193 |
+
# and it is not translated to any other error by tritonclient/grpc/aio/__init__.py:151: in is_server_ready method.
|
194 |
+
# response = await self._client_stub.ServerReady(request=request,
|
195 |
+
server_live = False
|
196 |
+
server_ready = False
|
197 |
+
except grpc._cython.cygrpc.UsageError:
|
198 |
+
# This exception is raised by grpcio/grpc/_cython/_cygrpc/aio/channel.pyx.pxi:124
|
199 |
+
# and it is not translated to any other error by tritonclient/grpc/aio/__init__.py", line 151, in is_server_ready
|
200 |
+
# response = await self._client_stub.ServerReady(request=request,
|
201 |
+
server_live = False
|
202 |
+
server_ready = False
|
203 |
+
if server_ready and server_live:
|
204 |
+
break
|
205 |
+
_LOGGER.debug(f"Sleeping for {sleep_time_s:0.2f} seconds")
|
206 |
+
await asyncio.sleep(sleep_time_s)
|
207 |
+
except asyncio.TimeoutError as e:
|
208 |
+
# This error is caused by our timeout, not by Triton Inference Server client.
|
209 |
+
message = "Timeout while waiting for model"
|
210 |
+
_LOGGER.error(message)
|
211 |
+
raise PyTritonClientTimeoutError(message) from e
|
212 |
+
_LOGGER.debug("Server is ready")
|
213 |
+
|
214 |
+
|
215 |
+
async def asyncio_wait_for_model_status_loaded(
|
216 |
+
asyncio_client: aio_clients,
|
217 |
+
model_name: str,
|
218 |
+
sleep_time_s: float,
|
219 |
+
model_version: Optional[str] = None,
|
220 |
+
):
|
221 |
+
"""Wait for model status loaded.
|
222 |
+
|
223 |
+
Function runs the following async function to check model status:
|
224 |
+
```python
|
225 |
+
asyncio_get_model_state(asyncio_client, model_name, model_version)
|
226 |
+
```
|
227 |
+
If it return _ModelState.READY, then another async function can check if model is really ready:
|
228 |
+
```python
|
229 |
+
asyncio_client.is_model_ready(model_name)
|
230 |
+
```
|
231 |
+
This function uses the above functions to check if model is ready together
|
232 |
+
with asyncio.wait_for(...) to limit the time of waiting.
|
233 |
+
|
234 |
+
Function contains while loop with sleep to check model status periodically.
|
235 |
+
|
236 |
+
Args:
|
237 |
+
asyncio_client: Triton Inference Server client to use for communication
|
238 |
+
model_name: name of the model which configuration we're requesting.
|
239 |
+
model_version:
|
240 |
+
version of the model which configuration we're requesting.
|
241 |
+
If model_version is None configuration of the latest model is returned.
|
242 |
+
The latest versions of the model are the numerically greatest version numbers.
|
243 |
+
sleep_time_s: time interval, in seconds, between successive checks to determine if the model configuration has been completed.
|
244 |
+
|
245 |
+
Raises:
|
246 |
+
PyTritonClientTimeoutError: If obtain of model configuration didn't finish before given timeout.
|
247 |
+
"""
|
248 |
+
model_version = model_version or ""
|
249 |
+
model_version_msg = model_version or LATEST_MODEL_VERSION
|
250 |
+
_LOGGER.debug(f"Waiting for model {model_name}, {model_version_msg} to be ready")
|
251 |
+
try:
|
252 |
+
while True:
|
253 |
+
_LOGGER.debug(f"Checking if model {model_name} is ready")
|
254 |
+
is_model_ready = await asyncio_client.is_model_ready(model_name, model_version)
|
255 |
+
if is_model_ready:
|
256 |
+
break
|
257 |
+
_LOGGER.debug(f"Sleeping for {sleep_time_s} seconds")
|
258 |
+
await asyncio.sleep(sleep_time_s)
|
259 |
+
except asyncio.TimeoutError as e:
|
260 |
+
message = f"Timeout while waiting for model {model_name} state (timeout={sleep_time_s:0.2f})"
|
261 |
+
_LOGGER.error(message)
|
262 |
+
raise PyTritonClientTimeoutError(message) from e
|
263 |
+
_LOGGER.debug(f"Model {model_name}, {model_version_msg} is ready")
|
264 |
+
|
265 |
+
|
266 |
+
async def asyncio_wait_for_model_ready(
|
267 |
+
asyncio_client: aio_clients,
|
268 |
+
model_name: str,
|
269 |
+
model_version: Optional[str] = None,
|
270 |
+
timeout_s: float = _DEFAULT_WAIT_FOR_MODEL_TIMEOUT_S,
|
271 |
+
):
|
272 |
+
"""Wait for Triton Inference Server and deployed on it model readiness.
|
273 |
+
|
274 |
+
Args:
|
275 |
+
asyncio_client: Triton Inference Server client to use for communication
|
276 |
+
model_name: name of the model which configuration we're requesting.
|
277 |
+
model_version:
|
278 |
+
version of the model which configuration we're requesting.
|
279 |
+
If model_version is None configuration of the latest model is returned.
|
280 |
+
The latest versions of the model are the numerically greatest version numbers.
|
281 |
+
timeout_s: timeout to finish model configuration obtain.
|
282 |
+
|
283 |
+
Raises:
|
284 |
+
PyTritonClientTimeoutError: If obtain of model configuration didn't finish before given timeout.
|
285 |
+
|
286 |
+
"""
|
287 |
+
_LOGGER.debug(f"Waiting for model {model_name} to be ready (timeout={timeout_s:0.2f})")
|
288 |
+
sleep_time_s = timeout_s * _DEFAULT_ASYNC_SLEEP_FACTOR_S
|
289 |
+
try:
|
290 |
+
should_finish_before = time.time() + timeout_s
|
291 |
+
await asyncio.wait_for(asyncio_wait_for_server_ready(asyncio_client, sleep_time_s), timeout_s)
|
292 |
+
_LOGGER.debug(f"Waiting for model {model_name} to be ready")
|
293 |
+
timeout_s = max(0, should_finish_before - time.time())
|
294 |
+
await asyncio.wait_for(
|
295 |
+
asyncio_wait_for_model_status_loaded(
|
296 |
+
asyncio_client, model_name=model_name, model_version=model_version, sleep_time_s=sleep_time_s
|
297 |
+
),
|
298 |
+
timeout_s,
|
299 |
+
)
|
300 |
+
except PyTritonClientModelUnavailableError as e:
|
301 |
+
_LOGGER.error(f"Failed to obtain model {model_name} config error {e}")
|
302 |
+
raise e
|
303 |
+
except asyncio.TimeoutError as e:
|
304 |
+
_LOGGER.error(f"Failed to obtain model {model_name} config error {e}")
|
305 |
+
raise PyTritonClientTimeoutError(
|
306 |
+
f"Timeout while waiting for model {model_name} to be ready (timeout={timeout_s:0.2f})"
|
307 |
+
) from e
|
308 |
+
_LOGGER.debug(f"Model {model_name} is ready")
|
stf/stf-api-alternative/pytriton/build/lib/pytriton/client/client.py
ADDED
@@ -0,0 +1,2033 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
"""Clients for easy interaction with models deployed on the Triton Inference Server.
|
16 |
+
|
17 |
+
Typical usage example:
|
18 |
+
|
19 |
+
```python
|
20 |
+
client = ModelClient("localhost", "MyModel")
|
21 |
+
result_dict = client.infer_sample(input_a=a, input_b=b)
|
22 |
+
client.close()
|
23 |
+
```
|
24 |
+
|
25 |
+
Inference inputs can be provided either as positional or keyword arguments:
|
26 |
+
|
27 |
+
```python
|
28 |
+
result_dict = client.infer_sample(input1, input2)
|
29 |
+
result_dict = client.infer_sample(a=input1, b=input2)
|
30 |
+
```
|
31 |
+
|
32 |
+
Mixing of argument passing conventions is not supported and will raise PyTritonClientValueError.
|
33 |
+
"""
|
34 |
+
|
35 |
+
import asyncio
|
36 |
+
import contextlib
|
37 |
+
import itertools
|
38 |
+
import logging
|
39 |
+
import socket
|
40 |
+
import time
|
41 |
+
import warnings
|
42 |
+
from concurrent.futures import Future
|
43 |
+
from queue import Empty, Full, Queue
|
44 |
+
from threading import Lock, Thread
|
45 |
+
from typing import Any, Dict, Optional, Tuple, Union
|
46 |
+
|
47 |
+
import gevent
|
48 |
+
import numpy as np
|
49 |
+
import tritonclient.grpc
|
50 |
+
import tritonclient.grpc.aio
|
51 |
+
import tritonclient.http
|
52 |
+
import tritonclient.http.aio
|
53 |
+
import tritonclient.utils
|
54 |
+
|
55 |
+
from pytriton.client.asyncio_utils import asyncio_get_model_config, asyncio_wait_for_model_ready
|
56 |
+
from pytriton.client.exceptions import (
|
57 |
+
PyTritonClientClosedError,
|
58 |
+
PyTritonClientInferenceServerError,
|
59 |
+
PyTritonClientModelDoesntSupportBatchingError,
|
60 |
+
PyTritonClientQueueFullError,
|
61 |
+
PyTritonClientTimeoutError,
|
62 |
+
PyTritonClientValueError,
|
63 |
+
)
|
64 |
+
from pytriton.client.utils import (
|
65 |
+
_DEFAULT_NETWORK_TIMEOUT_S,
|
66 |
+
_DEFAULT_WAIT_FOR_MODEL_TIMEOUT_S,
|
67 |
+
TritonUrl,
|
68 |
+
get_model_config,
|
69 |
+
wait_for_model_ready,
|
70 |
+
wait_for_server_ready,
|
71 |
+
)
|
72 |
+
from pytriton.client.warnings import NotSupportedTimeoutWarning
|
73 |
+
from pytriton.model_config.triton_model_config import TritonModelConfig
|
74 |
+
|
75 |
+
_LOGGER = logging.getLogger(__name__)
|
76 |
+
|
77 |
+
_DEFAULT_SYNC_INIT_TIMEOUT_S = _DEFAULT_WAIT_FOR_MODEL_TIMEOUT_S
|
78 |
+
_DEFAULT_FUTURES_INIT_TIMEOUT_S = _DEFAULT_WAIT_FOR_MODEL_TIMEOUT_S
|
79 |
+
DEFAULT_INFERENCE_TIMEOUT_S = 60.0
|
80 |
+
|
81 |
+
|
82 |
+
_IOType = Union[Tuple[np.ndarray, ...], Dict[str, np.ndarray]]
|
83 |
+
|
84 |
+
|
85 |
+
def _verify_inputs_args(inputs, named_inputs):
|
86 |
+
if not inputs and not named_inputs:
|
87 |
+
raise PyTritonClientValueError("Provide input data")
|
88 |
+
if not bool(inputs) ^ bool(named_inputs):
|
89 |
+
raise PyTritonClientValueError("Use either positional either keyword method arguments convention")
|
90 |
+
|
91 |
+
|
92 |
+
def _verify_parameters(parameters_or_headers: Optional[Dict[str, Union[str, int, bool]]] = None):
|
93 |
+
if parameters_or_headers is None:
|
94 |
+
return
|
95 |
+
if not isinstance(parameters_or_headers, dict):
|
96 |
+
raise PyTritonClientValueError("Parameters and headers must be a dictionary")
|
97 |
+
for key, value in parameters_or_headers.items():
|
98 |
+
if not isinstance(key, str):
|
99 |
+
raise PyTritonClientValueError("Parameter/header key must be a string")
|
100 |
+
if not isinstance(value, (str, int, bool)):
|
101 |
+
raise PyTritonClientValueError("Parameter/header value must be a string, integer or boolean")
|
102 |
+
|
103 |
+
|
104 |
+
class BaseModelClient:
|
105 |
+
"""Base client for model deployed on the Triton Inference Server."""
|
106 |
+
|
107 |
+
def __init__(
|
108 |
+
self,
|
109 |
+
url: str,
|
110 |
+
model_name: str,
|
111 |
+
model_version: Optional[str] = None,
|
112 |
+
*,
|
113 |
+
lazy_init: bool = True,
|
114 |
+
init_timeout_s: Optional[float] = None,
|
115 |
+
inference_timeout_s: Optional[float] = None,
|
116 |
+
model_config: Optional[TritonModelConfig] = None,
|
117 |
+
ensure_model_is_ready: bool = True,
|
118 |
+
):
|
119 |
+
"""Inits BaseModelClient for given model deployed on the Triton Inference Server.
|
120 |
+
|
121 |
+
Common usage:
|
122 |
+
|
123 |
+
```python
|
124 |
+
client = ModelClient("localhost", "BERT")
|
125 |
+
result_dict = client.infer_sample(input1_sample, input2_sample)
|
126 |
+
client.close()
|
127 |
+
```
|
128 |
+
|
129 |
+
Args:
|
130 |
+
url: The Triton Inference Server url, e.g. `grpc://localhost:8001`.
|
131 |
+
In case no scheme is provided http scheme will be used as default.
|
132 |
+
In case no port is provided default port for given scheme will be used -
|
133 |
+
8001 for grpc scheme, 8000 for http scheme.
|
134 |
+
model_name: name of the model to interact with.
|
135 |
+
model_version: version of the model to interact with.
|
136 |
+
If model_version is None inference on latest model will be performed.
|
137 |
+
The latest versions of the model are numerically the greatest version numbers.
|
138 |
+
lazy_init: if initialization should be performed just before sending first request to inference server.
|
139 |
+
init_timeout_s: timeout in seconds for the server and model to be ready. If not passed, the default timeout of 300 seconds will be used.
|
140 |
+
inference_timeout_s: timeout in seconds for a single model inference request. If not passed, the default timeout of 60 seconds will be used.
|
141 |
+
model_config: model configuration. If not passed, it will be read from inference server during initialization.
|
142 |
+
ensure_model_is_ready: if model should be checked if it is ready before first inference request.
|
143 |
+
|
144 |
+
Raises:
|
145 |
+
PyTritonClientModelUnavailableError: If model with given name (and version) is unavailable.
|
146 |
+
PyTritonClientTimeoutError:
|
147 |
+
if `lazy_init` argument is False and wait time for server and model being ready exceeds `init_timeout_s`.
|
148 |
+
PyTritonClientInvalidUrlError: If provided Triton Inference Server url is invalid.
|
149 |
+
"""
|
150 |
+
self._init_timeout_s = _DEFAULT_SYNC_INIT_TIMEOUT_S if init_timeout_s is None else init_timeout_s
|
151 |
+
self._inference_timeout_s = DEFAULT_INFERENCE_TIMEOUT_S if inference_timeout_s is None else inference_timeout_s
|
152 |
+
self._network_timeout_s = min(_DEFAULT_NETWORK_TIMEOUT_S, self._init_timeout_s)
|
153 |
+
|
154 |
+
self._general_client = self.create_client_from_url(url, network_timeout_s=self._network_timeout_s)
|
155 |
+
self._infer_client = self.create_client_from_url(url, network_timeout_s=self._inference_timeout_s)
|
156 |
+
|
157 |
+
self._model_name = model_name
|
158 |
+
self._model_version = model_version
|
159 |
+
|
160 |
+
self._request_id_generator = itertools.count(0)
|
161 |
+
|
162 |
+
# Monkey patch __del__ method from client to catch error in client when instance is garbage collected.
|
163 |
+
# This is needed because we are closing client in __exit__ method or in close method.
|
164 |
+
# (InferenceClient uses gevent library which does not support closing twice from different threads)
|
165 |
+
self._monkey_patch_client()
|
166 |
+
|
167 |
+
if model_config is not None:
|
168 |
+
self._model_config = model_config
|
169 |
+
self._model_ready = None if ensure_model_is_ready else True
|
170 |
+
|
171 |
+
else:
|
172 |
+
self._model_config = None
|
173 |
+
self._model_ready = None
|
174 |
+
self._lazy_init: bool = lazy_init
|
175 |
+
|
176 |
+
self._handle_lazy_init()
|
177 |
+
|
178 |
+
@classmethod
|
179 |
+
def from_existing_client(cls, existing_client: "BaseModelClient"):
|
180 |
+
"""Create a new instance from an existing client using the same class.
|
181 |
+
|
182 |
+
Common usage:
|
183 |
+
```python
|
184 |
+
client = BaseModelClient.from_existing_client(existing_client)
|
185 |
+
```
|
186 |
+
|
187 |
+
Args:
|
188 |
+
existing_client: An instance of an already initialized subclass.
|
189 |
+
|
190 |
+
Returns:
|
191 |
+
A new instance of the same subclass with shared configuration and readiness state.
|
192 |
+
"""
|
193 |
+
kwargs = {}
|
194 |
+
# Copy model configuration and readiness state if present
|
195 |
+
if hasattr(existing_client, "_model_config"):
|
196 |
+
kwargs["model_config"] = existing_client._model_config
|
197 |
+
kwargs["ensure_model_is_ready"] = False
|
198 |
+
|
199 |
+
new_client = cls(
|
200 |
+
url=existing_client._url,
|
201 |
+
model_name=existing_client._model_name,
|
202 |
+
model_version=existing_client._model_version,
|
203 |
+
init_timeout_s=existing_client._init_timeout_s,
|
204 |
+
inference_timeout_s=existing_client._inference_timeout_s,
|
205 |
+
**kwargs,
|
206 |
+
)
|
207 |
+
|
208 |
+
return new_client
|
209 |
+
|
210 |
+
def create_client_from_url(self, url: str, network_timeout_s: Optional[float] = None):
|
211 |
+
"""Create Triton Inference Server client.
|
212 |
+
|
213 |
+
Args:
|
214 |
+
url: url of the server to connect to.
|
215 |
+
If url doesn't contain scheme (e.g. "localhost:8001") http scheme is added.
|
216 |
+
If url doesn't contain port (e.g. "localhost") default port for given scheme is added.
|
217 |
+
network_timeout_s: timeout for client commands. Default value is 60.0 s.
|
218 |
+
|
219 |
+
Returns:
|
220 |
+
Triton Inference Server client.
|
221 |
+
|
222 |
+
Raises:
|
223 |
+
PyTritonClientInvalidUrlError: If provided Triton Inference Server url is invalid.
|
224 |
+
"""
|
225 |
+
self._triton_url = TritonUrl.from_url(url)
|
226 |
+
self._url = self._triton_url.without_scheme
|
227 |
+
self._triton_client_lib = self.get_lib()
|
228 |
+
self._monkey_patch_client()
|
229 |
+
|
230 |
+
if self._triton_url.scheme == "grpc":
|
231 |
+
# by default grpc client has very large number of timeout, thus we want to make it equal to http client timeout
|
232 |
+
network_timeout_s = _DEFAULT_NETWORK_TIMEOUT_S if network_timeout_s is None else network_timeout_s
|
233 |
+
warnings.warn(
|
234 |
+
f"tritonclient.grpc doesn't support timeout for other commands than infer. Ignoring network_timeout: {network_timeout_s}.",
|
235 |
+
NotSupportedTimeoutWarning,
|
236 |
+
stacklevel=1,
|
237 |
+
)
|
238 |
+
|
239 |
+
triton_client_init_kwargs = self._get_init_extra_args()
|
240 |
+
|
241 |
+
_LOGGER.debug(
|
242 |
+
f"Creating InferenceServerClient for {self._triton_url.with_scheme} with {triton_client_init_kwargs}"
|
243 |
+
)
|
244 |
+
return self._triton_client_lib.InferenceServerClient(self._url, **triton_client_init_kwargs)
|
245 |
+
|
246 |
+
def get_lib(self):
|
247 |
+
"""Returns tritonclient library for given scheme."""
|
248 |
+
raise NotImplementedError
|
249 |
+
|
250 |
+
@property
|
251 |
+
def _next_request_id(self) -> str:
|
252 |
+
# pytype complained about creating generator in __init__ method
|
253 |
+
# so we create it lazily
|
254 |
+
if getattr(self, "_request_id_generator", None) is None:
|
255 |
+
self._request_id_generator = itertools.count(0)
|
256 |
+
return str(next(self._request_id_generator))
|
257 |
+
|
258 |
+
def _get_init_extra_args(self):
|
259 |
+
timeout = self._inference_timeout_s # pytype: disable=attribute-error
|
260 |
+
# The inference timeout is used for both the HTTP and the GRPC protocols. However,
|
261 |
+
# the way the timeout is passed to the client differs depending on the protocol.
|
262 |
+
# For the HTTP protocol, the timeout is set in the ``__init__`` method as ``network_timeout``
|
263 |
+
# and ``connection_timeout``. For the GRPC protocol, the timeout
|
264 |
+
# is passed to the infer method as ``client_timeout``.
|
265 |
+
# Both protocols support timeouts correctly and will raise an exception
|
266 |
+
# if the network request or the inference process takes longer than the timeout.
|
267 |
+
# This is a design choice of the underlying tritonclient library.
|
268 |
+
|
269 |
+
if self._triton_url.scheme != "http":
|
270 |
+
return {}
|
271 |
+
|
272 |
+
kwargs = {
|
273 |
+
# This value sets the maximum time allowed for each network request in both model loading and inference process
|
274 |
+
"network_timeout": timeout,
|
275 |
+
# This value sets the maximum time allowed for establishing a connection to the server.
|
276 |
+
# We use the inference timeout here instead of the init timeout because the init timeout
|
277 |
+
# is meant for waiting for the model to be ready. The connection timeout should be shorter
|
278 |
+
# than the init timeout because it only checks if connection is established (e.g. correct port)
|
279 |
+
"connection_timeout": timeout,
|
280 |
+
}
|
281 |
+
return kwargs
|
282 |
+
|
283 |
+
def _monkey_patch_client(self):
|
284 |
+
pass
|
285 |
+
|
286 |
+
def _get_model_config_extra_args(self):
|
287 |
+
# For the GRPC protocol, the timeout must be passed to the each request as client_timeout
|
288 |
+
# model_config doesn't yet support timeout but it is planned for the future
|
289 |
+
# grpc_network_timeout_s will be used for model_config
|
290 |
+
return {}
|
291 |
+
|
292 |
+
def _handle_lazy_init(self):
|
293 |
+
raise NotImplementedError
|
294 |
+
|
295 |
+
|
296 |
+
def _run_once_per_lib(f):
|
297 |
+
def wrapper(_self):
|
298 |
+
if _self._triton_client_lib not in wrapper.patched:
|
299 |
+
wrapper.patched.add(_self._triton_client_lib)
|
300 |
+
return f(_self)
|
301 |
+
|
302 |
+
wrapper.patched = set()
|
303 |
+
return wrapper
|
304 |
+
|
305 |
+
|
306 |
+
class ModelClient(BaseModelClient):
|
307 |
+
"""Synchronous client for model deployed on the Triton Inference Server."""
|
308 |
+
|
309 |
+
def __init__(
|
310 |
+
self,
|
311 |
+
url: str,
|
312 |
+
model_name: str,
|
313 |
+
model_version: Optional[str] = None,
|
314 |
+
*,
|
315 |
+
lazy_init: bool = True,
|
316 |
+
init_timeout_s: Optional[float] = None,
|
317 |
+
inference_timeout_s: Optional[float] = None,
|
318 |
+
model_config: Optional[TritonModelConfig] = None,
|
319 |
+
ensure_model_is_ready: bool = True,
|
320 |
+
):
|
321 |
+
"""Inits ModelClient for given model deployed on the Triton Inference Server.
|
322 |
+
|
323 |
+
If `lazy_init` argument is False, model configuration will be read
|
324 |
+
from inference server during initialization.
|
325 |
+
|
326 |
+
Common usage:
|
327 |
+
|
328 |
+
```python
|
329 |
+
client = ModelClient("localhost", "BERT")
|
330 |
+
result_dict = client.infer_sample(input1_sample, input2_sample)
|
331 |
+
client.close()
|
332 |
+
```
|
333 |
+
|
334 |
+
Client supports also context manager protocol:
|
335 |
+
|
336 |
+
```python
|
337 |
+
with ModelClient("localhost", "BERT") as client:
|
338 |
+
result_dict = client.infer_sample(input1_sample, input2_sample)
|
339 |
+
```
|
340 |
+
|
341 |
+
The creation of client requires connection to the server and downloading model configuration. You can create client from existing client using the same class:
|
342 |
+
|
343 |
+
```python
|
344 |
+
client = ModelClient.from_existing_client(existing_client)
|
345 |
+
```
|
346 |
+
|
347 |
+
Args:
|
348 |
+
url: The Triton Inference Server url, e.g. 'grpc://localhost:8001'.
|
349 |
+
In case no scheme is provided http scheme will be used as default.
|
350 |
+
In case no port is provided default port for given scheme will be used -
|
351 |
+
8001 for grpc scheme, 8000 for http scheme.
|
352 |
+
model_name: name of the model to interact with.
|
353 |
+
model_version: version of the model to interact with.
|
354 |
+
If model_version is None inference on latest model will be performed.
|
355 |
+
The latest versions of the model are numerically the greatest version numbers.
|
356 |
+
lazy_init: if initialization should be performed just before sending first request to inference server.
|
357 |
+
init_timeout_s: timeout for maximum waiting time in loop, which sends retry requests ask if model is ready. It is applied at initialization time only when `lazy_init` argument is False. Default is to do retry loop at first inference.
|
358 |
+
inference_timeout_s: timeout in seconds for the model inference process.
|
359 |
+
If non passed default 60 seconds timeout will be used.
|
360 |
+
For HTTP client it is not only inference timeout but any client request timeout
|
361 |
+
- get model config, is model loaded. For GRPC client it is only inference timeout.
|
362 |
+
model_config: model configuration. If not passed, it will be read from inference server during initialization.
|
363 |
+
ensure_model_is_ready: if model should be checked if it is ready before first inference request.
|
364 |
+
|
365 |
+
Raises:
|
366 |
+
PyTritonClientModelUnavailableError: If model with given name (and version) is unavailable.
|
367 |
+
PyTritonClientTimeoutError:
|
368 |
+
if `lazy_init` argument is False and wait time for server and model being ready exceeds `init_timeout_s`.
|
369 |
+
PyTritonClientUrlParseError: In case of problems with parsing url.
|
370 |
+
"""
|
371 |
+
super().__init__(
|
372 |
+
url=url,
|
373 |
+
model_name=model_name,
|
374 |
+
model_version=model_version,
|
375 |
+
lazy_init=lazy_init,
|
376 |
+
init_timeout_s=init_timeout_s,
|
377 |
+
inference_timeout_s=inference_timeout_s,
|
378 |
+
model_config=model_config,
|
379 |
+
ensure_model_is_ready=ensure_model_is_ready,
|
380 |
+
)
|
381 |
+
|
382 |
+
def get_lib(self):
|
383 |
+
"""Returns tritonclient library for given scheme."""
|
384 |
+
return {"grpc": tritonclient.grpc, "http": tritonclient.http}[self._triton_url.scheme.lower()]
|
385 |
+
|
386 |
+
def __enter__(self):
|
387 |
+
"""Create context for using ModelClient as a context manager."""
|
388 |
+
return self
|
389 |
+
|
390 |
+
def __exit__(self, *_):
|
391 |
+
"""Close resources used by ModelClient instance when exiting from the context."""
|
392 |
+
self.close()
|
393 |
+
|
394 |
+
def load_model(self, config: Optional[str] = None, files: Optional[dict] = None):
|
395 |
+
"""Load model on the Triton Inference Server.
|
396 |
+
|
397 |
+
Args:
|
398 |
+
config: str - Optional JSON representation of a model config provided for
|
399 |
+
the load request, if provided, this config will be used for
|
400 |
+
loading the model.
|
401 |
+
files: dict - Optional dictionary specifying file path (with "file:" prefix) in
|
402 |
+
the override model directory to the file content as bytes.
|
403 |
+
The files will form the model directory that the model will be
|
404 |
+
loaded from. If specified, 'config' must be provided to be
|
405 |
+
the model configuration of the override model directory.
|
406 |
+
"""
|
407 |
+
self._general_client.load_model(self._model_name, config=config, files=files)
|
408 |
+
|
409 |
+
def unload_model(self):
|
410 |
+
"""Unload model from the Triton Inference Server."""
|
411 |
+
self._general_client.unload_model(self._model_name)
|
412 |
+
|
413 |
+
def close(self):
|
414 |
+
"""Close resources used by ModelClient.
|
415 |
+
|
416 |
+
This method closes the resources used by the ModelClient instance,
|
417 |
+
including the Triton Inference Server connections.
|
418 |
+
Once this method is called, the ModelClient instance should not be used again.
|
419 |
+
"""
|
420 |
+
_LOGGER.debug("Closing ModelClient")
|
421 |
+
try:
|
422 |
+
if self._general_client is not None:
|
423 |
+
self._general_client.close()
|
424 |
+
if self._infer_client is not None:
|
425 |
+
self._infer_client.close()
|
426 |
+
self._general_client = None
|
427 |
+
self._infer_client = None
|
428 |
+
except Exception as e:
|
429 |
+
_LOGGER.error(f"Error while closing ModelClient resources: {e}")
|
430 |
+
raise e
|
431 |
+
|
432 |
+
def wait_for_model(self, timeout_s: float):
|
433 |
+
"""Wait for the Triton Inference Server and the deployed model to be ready.
|
434 |
+
|
435 |
+
Args:
|
436 |
+
timeout_s: timeout in seconds to wait for the server and model to be ready.
|
437 |
+
|
438 |
+
Raises:
|
439 |
+
PyTritonClientTimeoutError: If the server and model are not ready before the given timeout.
|
440 |
+
PyTritonClientModelUnavailableError: If the model with the given name (and version) is unavailable.
|
441 |
+
KeyboardInterrupt: If the hosting process receives SIGINT.
|
442 |
+
PyTritonClientClosedError: If the ModelClient is closed.
|
443 |
+
"""
|
444 |
+
if self._general_client is None:
|
445 |
+
raise PyTritonClientClosedError("ModelClient is closed")
|
446 |
+
wait_for_model_ready(self._general_client, self._model_name, self._model_version, timeout_s=timeout_s)
|
447 |
+
|
448 |
+
@property
|
449 |
+
def is_batching_supported(self):
|
450 |
+
"""Checks if model supports batching.
|
451 |
+
|
452 |
+
Also waits for server to get into readiness state.
|
453 |
+
"""
|
454 |
+
return self.model_config.max_batch_size > 0
|
455 |
+
|
456 |
+
def wait_for_server(self, timeout_s: float):
|
457 |
+
"""Wait for Triton Inference Server readiness.
|
458 |
+
|
459 |
+
Args:
|
460 |
+
timeout_s: timeout to server get into readiness state.
|
461 |
+
|
462 |
+
Raises:
|
463 |
+
PyTritonClientTimeoutError: If server is not in readiness state before given timeout.
|
464 |
+
KeyboardInterrupt: If hosting process receives SIGINT
|
465 |
+
"""
|
466 |
+
wait_for_server_ready(self._general_client, timeout_s=timeout_s)
|
467 |
+
|
468 |
+
@property
|
469 |
+
def model_config(self) -> TritonModelConfig:
|
470 |
+
"""Obtain the configuration of the model deployed on the Triton Inference Server.
|
471 |
+
|
472 |
+
This method waits for the server to get into readiness state before obtaining the model configuration.
|
473 |
+
|
474 |
+
Returns:
|
475 |
+
TritonModelConfig: configuration of the model deployed on the Triton Inference Server.
|
476 |
+
|
477 |
+
Raises:
|
478 |
+
PyTritonClientTimeoutError: If the server and model are not in readiness state before the given timeout.
|
479 |
+
PyTritonClientModelUnavailableError: If the model with the given name (and version) is unavailable.
|
480 |
+
KeyboardInterrupt: If the hosting process receives SIGINT.
|
481 |
+
PyTritonClientClosedError: If the ModelClient is closed.
|
482 |
+
"""
|
483 |
+
if not self._model_config:
|
484 |
+
if self._general_client is None:
|
485 |
+
raise PyTritonClientClosedError("ModelClient is closed")
|
486 |
+
|
487 |
+
self._model_config = get_model_config(
|
488 |
+
self._general_client, self._model_name, self._model_version, timeout_s=self._init_timeout_s
|
489 |
+
)
|
490 |
+
return self._model_config
|
491 |
+
|
492 |
+
def infer_sample(
|
493 |
+
self,
|
494 |
+
*inputs,
|
495 |
+
parameters: Optional[Dict[str, Union[str, int, bool]]] = None,
|
496 |
+
headers: Optional[Dict[str, Union[str, int, bool]]] = None,
|
497 |
+
**named_inputs,
|
498 |
+
) -> Dict[str, np.ndarray]:
|
499 |
+
"""Run synchronous inference on a single data sample.
|
500 |
+
|
501 |
+
Typical usage:
|
502 |
+
|
503 |
+
```python
|
504 |
+
client = ModelClient("localhost", "MyModel")
|
505 |
+
result_dict = client.infer_sample(input1, input2)
|
506 |
+
client.close()
|
507 |
+
```
|
508 |
+
|
509 |
+
Inference inputs can be provided either as positional or keyword arguments:
|
510 |
+
|
511 |
+
```python
|
512 |
+
result_dict = client.infer_sample(input1, input2)
|
513 |
+
result_dict = client.infer_sample(a=input1, b=input2)
|
514 |
+
```
|
515 |
+
|
516 |
+
Args:
|
517 |
+
*inputs: Inference inputs provided as positional arguments.
|
518 |
+
parameters: Custom inference parameters.
|
519 |
+
headers: Custom inference headers.
|
520 |
+
**named_inputs: Inference inputs provided as named arguments.
|
521 |
+
|
522 |
+
Returns:
|
523 |
+
Dictionary with inference results, where dictionary keys are output names.
|
524 |
+
|
525 |
+
Raises:
|
526 |
+
PyTritonClientValueError: If mixing of positional and named arguments passing detected.
|
527 |
+
PyTritonClientTimeoutError: If the wait time for the server and model being ready exceeds `init_timeout_s` or
|
528 |
+
inference request time exceeds `inference_timeout_s`.
|
529 |
+
PyTritonClientModelUnavailableError: If the model with the given name (and version) is unavailable.
|
530 |
+
PyTritonClientInferenceServerError: If an error occurred on the inference callable or Triton Inference Server side.
|
531 |
+
"""
|
532 |
+
_verify_inputs_args(inputs, named_inputs)
|
533 |
+
_verify_parameters(parameters)
|
534 |
+
_verify_parameters(headers)
|
535 |
+
|
536 |
+
if self.is_batching_supported:
|
537 |
+
if inputs:
|
538 |
+
inputs = tuple(data[np.newaxis, ...] for data in inputs)
|
539 |
+
elif named_inputs:
|
540 |
+
named_inputs = {name: data[np.newaxis, ...] for name, data in named_inputs.items()}
|
541 |
+
|
542 |
+
result = self._infer(inputs or named_inputs, parameters, headers)
|
543 |
+
|
544 |
+
return self._debatch_result(result)
|
545 |
+
|
546 |
+
def infer_batch(
|
547 |
+
self,
|
548 |
+
*inputs,
|
549 |
+
parameters: Optional[Dict[str, Union[str, int, bool]]] = None,
|
550 |
+
headers: Optional[Dict[str, Union[str, int, bool]]] = None,
|
551 |
+
**named_inputs,
|
552 |
+
) -> Dict[str, np.ndarray]:
|
553 |
+
"""Run synchronous inference on batched data.
|
554 |
+
|
555 |
+
Typical usage:
|
556 |
+
|
557 |
+
```python
|
558 |
+
client = ModelClient("localhost", "MyModel")
|
559 |
+
result_dict = client.infer_batch(input1, input2)
|
560 |
+
client.close()
|
561 |
+
```
|
562 |
+
|
563 |
+
Inference inputs can be provided either as positional or keyword arguments:
|
564 |
+
|
565 |
+
```python
|
566 |
+
result_dict = client.infer_batch(input1, input2)
|
567 |
+
result_dict = client.infer_batch(a=input1, b=input2)
|
568 |
+
```
|
569 |
+
|
570 |
+
Args:
|
571 |
+
*inputs: Inference inputs provided as positional arguments.
|
572 |
+
parameters: Custom inference parameters.
|
573 |
+
headers: Custom inference headers.
|
574 |
+
**named_inputs: Inference inputs provided as named arguments.
|
575 |
+
|
576 |
+
Returns:
|
577 |
+
Dictionary with inference results, where dictionary keys are output names.
|
578 |
+
|
579 |
+
Raises:
|
580 |
+
PyTritonClientValueError: If mixing of positional and named arguments passing detected.
|
581 |
+
PyTritonClientTimeoutError: If the wait time for the server and model being ready exceeds `init_timeout_s` or
|
582 |
+
inference request time exceeds `inference_timeout_s`.
|
583 |
+
PyTritonClientModelUnavailableError: If the model with the given name (and version) is unavailable.
|
584 |
+
PyTritonClientInferenceServerError: If an error occurred on the inference callable or Triton Inference Server side.
|
585 |
+
PyTritonClientModelDoesntSupportBatchingError: If the model doesn't support batching.
|
586 |
+
PyTritonClientValueError: if mixing of positional and named arguments passing detected.
|
587 |
+
PyTritonClientTimeoutError:
|
588 |
+
in case of first method call, `lazy_init` argument is False
|
589 |
+
and wait time for server and model being ready exceeds `init_timeout_s` or
|
590 |
+
inference time exceeds `inference_timeout_s` passed to `__init__`.
|
591 |
+
PyTritonClientModelUnavailableError: If model with given name (and version) is unavailable.
|
592 |
+
PyTritonClientInferenceServerError: If error occurred on inference callable or Triton Inference Server side,
|
593 |
+
"""
|
594 |
+
_verify_inputs_args(inputs, named_inputs)
|
595 |
+
_verify_parameters(parameters)
|
596 |
+
_verify_parameters(headers)
|
597 |
+
|
598 |
+
if not self.is_batching_supported:
|
599 |
+
raise PyTritonClientModelDoesntSupportBatchingError(
|
600 |
+
f"Model {self.model_config.model_name} doesn't support batching - use infer_sample method instead"
|
601 |
+
)
|
602 |
+
|
603 |
+
return self._infer(inputs or named_inputs, parameters, headers)
|
604 |
+
|
605 |
+
def _wait_and_init_model_config(self, init_timeout_s: float):
|
606 |
+
if self._general_client is None:
|
607 |
+
raise PyTritonClientClosedError("ModelClient is closed")
|
608 |
+
|
609 |
+
should_finish_before_s = time.time() + init_timeout_s
|
610 |
+
self.wait_for_model(init_timeout_s)
|
611 |
+
self._model_ready = True
|
612 |
+
timeout_s = max(0.0, should_finish_before_s - time.time())
|
613 |
+
self._model_config = get_model_config(
|
614 |
+
self._general_client, self._model_name, self._model_version, timeout_s=timeout_s
|
615 |
+
)
|
616 |
+
|
617 |
+
def _create_request(self, inputs: _IOType):
|
618 |
+
if self._infer_client is None:
|
619 |
+
raise PyTritonClientClosedError("ModelClient is closed")
|
620 |
+
|
621 |
+
if not self._model_ready:
|
622 |
+
self._wait_and_init_model_config(self._init_timeout_s)
|
623 |
+
|
624 |
+
if isinstance(inputs, Tuple):
|
625 |
+
inputs = {input_spec.name: input_data for input_spec, input_data in zip(self.model_config.inputs, inputs)}
|
626 |
+
|
627 |
+
inputs_wrapped = []
|
628 |
+
|
629 |
+
# to help pytype to obtain variable type
|
630 |
+
inputs: Dict[str, np.ndarray]
|
631 |
+
|
632 |
+
for input_name, input_data in inputs.items():
|
633 |
+
if input_data.dtype == object and not isinstance(input_data.reshape(-1)[0], bytes):
|
634 |
+
raise RuntimeError(
|
635 |
+
f"Numpy array for {input_name!r} input with dtype=object should contain encoded strings \
|
636 |
+
\\(e.g. into utf-8\\). Element type: {type(input_data.reshape(-1)[0])}"
|
637 |
+
)
|
638 |
+
if input_data.dtype.type == np.str_:
|
639 |
+
raise RuntimeError(
|
640 |
+
"Unicode inputs are not supported. "
|
641 |
+
f"Encode numpy array for {input_name!r} input (ex. with np.char.encode(array, 'utf-8'))."
|
642 |
+
)
|
643 |
+
triton_dtype = tritonclient.utils.np_to_triton_dtype(input_data.dtype)
|
644 |
+
infer_input = self._triton_client_lib.InferInput(input_name, input_data.shape, triton_dtype)
|
645 |
+
infer_input.set_data_from_numpy(input_data)
|
646 |
+
inputs_wrapped.append(infer_input)
|
647 |
+
|
648 |
+
outputs_wrapped = [
|
649 |
+
self._triton_client_lib.InferRequestedOutput(output_spec.name) for output_spec in self.model_config.outputs
|
650 |
+
]
|
651 |
+
return inputs_wrapped, outputs_wrapped
|
652 |
+
|
653 |
+
def _infer(self, inputs: _IOType, parameters, headers) -> Dict[str, np.ndarray]:
|
654 |
+
if self.model_config.decoupled:
|
655 |
+
raise PyTritonClientInferenceServerError("Model config is decoupled. Use DecoupledModelClient instead.")
|
656 |
+
|
657 |
+
inputs_wrapped, outputs_wrapped = self._create_request(inputs)
|
658 |
+
|
659 |
+
try:
|
660 |
+
_LOGGER.debug("Sending inference request to Triton Inference Server")
|
661 |
+
response = self._infer_client.infer(
|
662 |
+
model_name=self._model_name,
|
663 |
+
model_version=self._model_version or "",
|
664 |
+
inputs=inputs_wrapped,
|
665 |
+
headers=headers,
|
666 |
+
outputs=outputs_wrapped,
|
667 |
+
request_id=self._next_request_id,
|
668 |
+
parameters=parameters,
|
669 |
+
**self._get_infer_extra_args(),
|
670 |
+
)
|
671 |
+
except tritonclient.utils.InferenceServerException as e:
|
672 |
+
# tritonclient.grpc raises execption with message containing "Deadline Exceeded" for timeout
|
673 |
+
if "Deadline Exceeded" in e.message():
|
674 |
+
raise PyTritonClientTimeoutError(
|
675 |
+
f"Timeout occurred during inference request. Timeout: {self._inference_timeout_s} s. Message: {e.message()}"
|
676 |
+
) from e
|
677 |
+
|
678 |
+
raise PyTritonClientInferenceServerError(
|
679 |
+
f"Error occurred during inference request. Message: {e.message()}"
|
680 |
+
) from e
|
681 |
+
except socket.timeout as e: # tritonclient.http raises socket.timeout for timeout
|
682 |
+
message = f"Timeout occurred during inference request. Timeout: {self._inference_timeout_s} s Message: {e}"
|
683 |
+
_LOGGER.error(message)
|
684 |
+
raise PyTritonClientTimeoutError(message) from e
|
685 |
+
except OSError as e: # tritonclient.http raises socket.error for connection error
|
686 |
+
message = f"Timeout occurred during inference request. Timeout: {self._inference_timeout_s} s Message: {e}"
|
687 |
+
_LOGGER.error(message)
|
688 |
+
raise PyTritonClientTimeoutError(message) from e
|
689 |
+
|
690 |
+
if isinstance(response, tritonclient.http.InferResult):
|
691 |
+
outputs = {
|
692 |
+
output["name"]: response.as_numpy(output["name"]) for output in response.get_response()["outputs"]
|
693 |
+
}
|
694 |
+
else:
|
695 |
+
outputs = {output.name: response.as_numpy(output.name) for output in response.get_response().outputs}
|
696 |
+
|
697 |
+
return outputs
|
698 |
+
|
699 |
+
def _get_numpy_result(self, result):
|
700 |
+
if isinstance(result, tritonclient.grpc.InferResult):
|
701 |
+
result = {output.name: result.as_numpy(output.name) for output in result.get_response().outputs}
|
702 |
+
else:
|
703 |
+
result = {output["name"]: result.as_numpy(output["name"]) for output in result.get_response()["outputs"]}
|
704 |
+
return result
|
705 |
+
|
706 |
+
def _debatch_result(self, result):
|
707 |
+
if self.is_batching_supported:
|
708 |
+
result = {name: data[0] for name, data in result.items()}
|
709 |
+
return result
|
710 |
+
|
711 |
+
def _handle_lazy_init(self):
|
712 |
+
if not self._lazy_init:
|
713 |
+
self._wait_and_init_model_config(self._init_timeout_s)
|
714 |
+
|
715 |
+
def _get_infer_extra_args(self):
|
716 |
+
if self._triton_url.scheme == "http":
|
717 |
+
return {}
|
718 |
+
# For the GRPC protocol, the timeout is passed to the infer method as client_timeout
|
719 |
+
# This timeout applies to the whole inference process and each network request
|
720 |
+
|
721 |
+
# The ``infer`` supports also timeout argument for both GRPC and HTTP.
|
722 |
+
# It is applied at server side and supported only for dynamic batching.
|
723 |
+
# However, it is not used here yet and planned for future release
|
724 |
+
kwargs = {"client_timeout": self._inference_timeout_s}
|
725 |
+
return kwargs
|
726 |
+
|
727 |
+
@_run_once_per_lib
|
728 |
+
def _monkey_patch_client(self):
|
729 |
+
"""Monkey patch InferenceServerClient to catch error in __del__."""
|
730 |
+
_LOGGER.info(f"Patch ModelClient {self._triton_url.scheme}")
|
731 |
+
if not hasattr(self._triton_client_lib.InferenceServerClient, "__del__"):
|
732 |
+
return
|
733 |
+
|
734 |
+
old_del = self._triton_client_lib.InferenceServerClient.__del__
|
735 |
+
|
736 |
+
def _monkey_patched_del(self):
|
737 |
+
"""Monkey patched del."""
|
738 |
+
try:
|
739 |
+
old_del(self)
|
740 |
+
except gevent.exceptions.InvalidThreadUseError:
|
741 |
+
_LOGGER.info("gevent.exceptions.InvalidThreadUseError in __del__ of InferenceServerClient")
|
742 |
+
except Exception as e:
|
743 |
+
_LOGGER.error("Exception in __del__ of InferenceServerClient: %s", e)
|
744 |
+
|
745 |
+
self._triton_client_lib.InferenceServerClient.__del__ = _monkey_patched_del
|
746 |
+
|
747 |
+
|
748 |
+
class DecoupledModelClient(ModelClient):
|
749 |
+
"""Synchronous client for decoupled model deployed on the Triton Inference Server."""
|
750 |
+
|
751 |
+
def __init__(
|
752 |
+
self,
|
753 |
+
url: str,
|
754 |
+
model_name: str,
|
755 |
+
model_version: Optional[str] = None,
|
756 |
+
*,
|
757 |
+
lazy_init: bool = True,
|
758 |
+
init_timeout_s: Optional[float] = None,
|
759 |
+
inference_timeout_s: Optional[float] = None,
|
760 |
+
model_config: Optional[TritonModelConfig] = None,
|
761 |
+
ensure_model_is_ready: bool = True,
|
762 |
+
):
|
763 |
+
"""Inits DecoupledModelClient for given decoupled model deployed on the Triton Inference Server.
|
764 |
+
|
765 |
+
Common usage:
|
766 |
+
|
767 |
+
```python
|
768 |
+
client = DecoupledModelClient("localhost", "BERT")
|
769 |
+
for response in client.infer_sample(input1_sample, input2_sample):
|
770 |
+
print(response)
|
771 |
+
client.close()
|
772 |
+
```
|
773 |
+
|
774 |
+
Args:
|
775 |
+
url: The Triton Inference Server url, e.g. `grpc://localhost:8001`.
|
776 |
+
In case no scheme is provided http scheme will be used as default.
|
777 |
+
In case no port is provided default port for given scheme will be used -
|
778 |
+
8001 for grpc scheme, 8000 for http scheme.
|
779 |
+
model_name: name of the model to interact with.
|
780 |
+
model_version: version of the model to interact with.
|
781 |
+
If model_version is None inference on latest model will be performed.
|
782 |
+
The latest versions of the model are numerically the greatest version numbers.
|
783 |
+
lazy_init: if initialization should be performed just before sending first request to inference server.
|
784 |
+
init_timeout_s: timeout in seconds for the server and model to be ready. If not passed, the default timeout of 300 seconds will be used.
|
785 |
+
inference_timeout_s: timeout in seconds for a single model inference request. If not passed, the default timeout of 60 seconds will be used.
|
786 |
+
model_config: model configuration. If not passed, it will be read from inference server during initialization.
|
787 |
+
ensure_model_is_ready: if model should be checked if it is ready before first inference request.
|
788 |
+
|
789 |
+
Raises:
|
790 |
+
PyTritonClientModelUnavailableError: If model with given name (and version) is unavailable.
|
791 |
+
PyTritonClientTimeoutError:
|
792 |
+
if `lazy_init` argument is False and wait time for server and model being ready exceeds `init_timeout_s`.
|
793 |
+
PyTritonClientInvalidUrlError: If provided Triton Inference Server url is invalid.
|
794 |
+
"""
|
795 |
+
super().__init__(
|
796 |
+
url,
|
797 |
+
model_name,
|
798 |
+
model_version,
|
799 |
+
lazy_init=lazy_init,
|
800 |
+
init_timeout_s=init_timeout_s,
|
801 |
+
inference_timeout_s=inference_timeout_s,
|
802 |
+
model_config=model_config,
|
803 |
+
ensure_model_is_ready=ensure_model_is_ready,
|
804 |
+
)
|
805 |
+
if self._triton_url.scheme == "http":
|
806 |
+
raise PyTritonClientValueError("DecoupledModelClient is only supported for grpc protocol")
|
807 |
+
self._queue = Queue()
|
808 |
+
self._lock = Lock()
|
809 |
+
|
810 |
+
def close(self):
|
811 |
+
"""Close resources used by DecoupledModelClient."""
|
812 |
+
_LOGGER.debug("Closing DecoupledModelClient")
|
813 |
+
if self._lock.acquire(blocking=False):
|
814 |
+
try:
|
815 |
+
super().close()
|
816 |
+
finally:
|
817 |
+
self._lock.release()
|
818 |
+
else:
|
819 |
+
_LOGGER.warning("DecoupledModelClient is stil streaming answers")
|
820 |
+
self._infer_client.stop_stream(False)
|
821 |
+
super().close()
|
822 |
+
|
823 |
+
def _infer(self, inputs: _IOType, parameters, headers):
|
824 |
+
if not self._lock.acquire(blocking=False):
|
825 |
+
raise PyTritonClientInferenceServerError("Inference is already in progress")
|
826 |
+
if not self.model_config.decoupled:
|
827 |
+
raise PyTritonClientInferenceServerError("Model config is coupled. Use ModelClient instead.")
|
828 |
+
|
829 |
+
inputs_wrapped, outputs_wrapped = self._create_request(inputs)
|
830 |
+
if parameters is not None:
|
831 |
+
raise PyTritonClientValueError("DecoupledModelClient does not support parameters")
|
832 |
+
if headers is not None:
|
833 |
+
raise PyTritonClientValueError("DecoupledModelClient does not support headers")
|
834 |
+
try:
|
835 |
+
_LOGGER.debug("Sending inference request to Triton Inference Server")
|
836 |
+
if self._infer_client._stream is None:
|
837 |
+
self._infer_client.start_stream(callback=lambda result, error: self._response_callback(result, error))
|
838 |
+
|
839 |
+
self._infer_client.async_stream_infer(
|
840 |
+
model_name=self._model_name,
|
841 |
+
model_version=self._model_version or "",
|
842 |
+
inputs=inputs_wrapped,
|
843 |
+
outputs=outputs_wrapped,
|
844 |
+
request_id=self._next_request_id,
|
845 |
+
enable_empty_final_response=True,
|
846 |
+
**self._get_infer_extra_args(),
|
847 |
+
)
|
848 |
+
except tritonclient.utils.InferenceServerException as e:
|
849 |
+
# tritonclient.grpc raises execption with message containing "Deadline Exceeded" for timeout
|
850 |
+
if "Deadline Exceeded" in e.message():
|
851 |
+
raise PyTritonClientTimeoutError(
|
852 |
+
f"Timeout occurred during inference request. Timeout: {self._inference_timeout_s} s. Message: {e.message()}"
|
853 |
+
) from e
|
854 |
+
|
855 |
+
raise PyTritonClientInferenceServerError(
|
856 |
+
f"Error occurred during inference request. Message: {e.message()}"
|
857 |
+
) from e
|
858 |
+
except socket.timeout as e: # tritonclient.http raises socket.timeout for timeout
|
859 |
+
message = f"Timeout occurred during inference request. Timeout: {self._inference_timeout_s} s Message: {e}"
|
860 |
+
_LOGGER.error(message)
|
861 |
+
raise PyTritonClientTimeoutError(message) from e
|
862 |
+
except OSError as e: # tritonclient.http raises socket.error for connection error
|
863 |
+
message = f"Timeout occurred during inference request. Timeout: {self._inference_timeout_s} s Message: {e}"
|
864 |
+
_LOGGER.error(message)
|
865 |
+
raise PyTritonClientTimeoutError(message) from e
|
866 |
+
_LOGGER.debug("Returning response iterator")
|
867 |
+
return self._create_response_iterator()
|
868 |
+
|
869 |
+
def _response_callback(self, response, error):
|
870 |
+
_LOGGER.debug(f"Received response from Triton Inference Server: {response}")
|
871 |
+
if error:
|
872 |
+
_LOGGER.error(f"Error occurred during inference request. Message: {error}")
|
873 |
+
self._queue.put(error)
|
874 |
+
else:
|
875 |
+
actual_response = response.get_response()
|
876 |
+
# Check if the object is not None
|
877 |
+
triton_final_response = actual_response.parameters.get("triton_final_response")
|
878 |
+
if triton_final_response and triton_final_response.bool_param:
|
879 |
+
self._queue.put(None)
|
880 |
+
else:
|
881 |
+
result = self._get_numpy_result(response)
|
882 |
+
self._queue.put(result)
|
883 |
+
|
884 |
+
def _create_response_iterator(self):
|
885 |
+
try:
|
886 |
+
while True:
|
887 |
+
try:
|
888 |
+
item = self._queue.get(self._inference_timeout_s)
|
889 |
+
except Empty as e:
|
890 |
+
message = f"Timeout occurred during inference request. Timeout: {self._inference_timeout_s} s"
|
891 |
+
_LOGGER.error(message)
|
892 |
+
raise PyTritonClientTimeoutError(message) from e
|
893 |
+
if isinstance(item, Exception):
|
894 |
+
message = f"Error occurred during inference request. Message: {item.message()}"
|
895 |
+
_LOGGER.error(message)
|
896 |
+
raise PyTritonClientInferenceServerError(message) from item
|
897 |
+
|
898 |
+
if item is None:
|
899 |
+
break
|
900 |
+
yield item
|
901 |
+
finally:
|
902 |
+
self._lock.release()
|
903 |
+
|
904 |
+
def _debatch_result(self, result):
|
905 |
+
if self.is_batching_supported:
|
906 |
+
result = ({name: data[0] for name, data in result_.items()} for result_ in result)
|
907 |
+
return result
|
908 |
+
|
909 |
+
def _get_infer_extra_args(self):
|
910 |
+
# kwargs = super()._get_infer_extra_args()
|
911 |
+
kwargs = {}
|
912 |
+
# kwargs["enable_empty_final_response"] = True
|
913 |
+
return kwargs
|
914 |
+
|
915 |
+
|
916 |
+
class AsyncioModelClient(BaseModelClient):
|
917 |
+
"""Asyncio client for model deployed on the Triton Inference Server.
|
918 |
+
|
919 |
+
This client is based on Triton Inference Server Python clients and GRPC library:
|
920 |
+
- ``tritonclient.http.aio.InferenceServerClient``
|
921 |
+
- ``tritonclient.grpc.aio.InferenceServerClient``
|
922 |
+
|
923 |
+
It can wait for server to be ready with model loaded and then perform inference on it.
|
924 |
+
``AsyncioModelClient`` supports asyncio context manager protocol.
|
925 |
+
|
926 |
+
Typical usage:
|
927 |
+
|
928 |
+
```python
|
929 |
+
from pytriton.client import AsyncioModelClient
|
930 |
+
import numpy as np
|
931 |
+
|
932 |
+
input1_sample = np.random.rand(1, 3, 224, 224).astype(np.float32)
|
933 |
+
input2_sample = np.random.rand(1, 3, 224, 224).astype(np.float32)
|
934 |
+
|
935 |
+
client = AsyncioModelClient("localhost", "MyModel")
|
936 |
+
result_dict = await client.infer_sample(input1_sample, input2_sample)
|
937 |
+
print(result_dict["output_name"])
|
938 |
+
await client.close()
|
939 |
+
```
|
940 |
+
"""
|
941 |
+
|
942 |
+
def __init__(
|
943 |
+
self,
|
944 |
+
url: str,
|
945 |
+
model_name: str,
|
946 |
+
model_version: Optional[str] = None,
|
947 |
+
*,
|
948 |
+
lazy_init: bool = True,
|
949 |
+
init_timeout_s: Optional[float] = None,
|
950 |
+
inference_timeout_s: Optional[float] = None,
|
951 |
+
model_config: Optional[TritonModelConfig] = None,
|
952 |
+
ensure_model_is_ready: bool = True,
|
953 |
+
):
|
954 |
+
"""Inits ModelClient for given model deployed on the Triton Inference Server.
|
955 |
+
|
956 |
+
If `lazy_init` argument is False, model configuration will be read
|
957 |
+
from inference server during initialization.
|
958 |
+
|
959 |
+
Args:
|
960 |
+
url: The Triton Inference Server url, e.g. 'grpc://localhost:8001'.
|
961 |
+
In case no scheme is provided http scheme will be used as default.
|
962 |
+
In case no port is provided default port for given scheme will be used -
|
963 |
+
8001 for grpc scheme, 8000 for http scheme.
|
964 |
+
model_name: name of the model to interact with.
|
965 |
+
model_version: version of the model to interact with.
|
966 |
+
If model_version is None inference on latest model will be performed.
|
967 |
+
The latest versions of the model are numerically the greatest version numbers.
|
968 |
+
lazy_init: if initialization should be performed just before sending first request to inference server.
|
969 |
+
init_timeout_s: timeout for server and model being ready.
|
970 |
+
inference_timeout_s: timeout in seconds for a single model inference request. If not passed, the default timeout of 60 seconds will be used.
|
971 |
+
model_config: model configuration. If not passed, it will be read from inference server during initialization.
|
972 |
+
ensure_model_is_ready: if model should be checked if it is ready before first inference request.
|
973 |
+
|
974 |
+
Raises:
|
975 |
+
PyTritonClientModelUnavailableError: If model with given name (and version) is unavailable.
|
976 |
+
PyTritonClientTimeoutError: if `lazy_init` argument is False and wait time for server and model being ready exceeds `init_timeout_s`.
|
977 |
+
PyTritonClientUrlParseError: In case of problems with parsing url.
|
978 |
+
"""
|
979 |
+
super().__init__(
|
980 |
+
url=url,
|
981 |
+
model_name=model_name,
|
982 |
+
model_version=model_version,
|
983 |
+
lazy_init=lazy_init,
|
984 |
+
init_timeout_s=init_timeout_s,
|
985 |
+
inference_timeout_s=inference_timeout_s,
|
986 |
+
model_config=model_config,
|
987 |
+
ensure_model_is_ready=ensure_model_is_ready,
|
988 |
+
)
|
989 |
+
|
990 |
+
def get_lib(self):
|
991 |
+
"""Get Triton Inference Server Python client library."""
|
992 |
+
return {"grpc": tritonclient.grpc.aio, "http": tritonclient.http.aio}[self._triton_url.scheme.lower()]
|
993 |
+
|
994 |
+
async def __aenter__(self):
|
995 |
+
"""Create context for use AsyncioModelClient as a context manager."""
|
996 |
+
_LOGGER.debug("Entering AsyncioModelClient context")
|
997 |
+
try:
|
998 |
+
if not self._lazy_init:
|
999 |
+
_LOGGER.debug("Waiting in AsyncioModelClient context for model to be ready")
|
1000 |
+
await self._wait_and_init_model_config(self._init_timeout_s)
|
1001 |
+
_LOGGER.debug("Model is ready in AsyncioModelClient context")
|
1002 |
+
return self
|
1003 |
+
except Exception as e:
|
1004 |
+
_LOGGER.error("Error occurred during AsyncioModelClient context initialization")
|
1005 |
+
await self.close()
|
1006 |
+
raise e
|
1007 |
+
|
1008 |
+
async def __aexit__(self, *_):
|
1009 |
+
"""Close resources used by AsyncioModelClient when exiting from context."""
|
1010 |
+
await self.close()
|
1011 |
+
_LOGGER.debug("Exiting AsyncioModelClient context")
|
1012 |
+
|
1013 |
+
async def close(self):
|
1014 |
+
"""Close resources used by _ModelClientBase."""
|
1015 |
+
_LOGGER.debug("Closing InferenceServerClient")
|
1016 |
+
await self._general_client.close()
|
1017 |
+
await self._infer_client.close()
|
1018 |
+
_LOGGER.debug("InferenceServerClient closed")
|
1019 |
+
|
1020 |
+
async def wait_for_model(self, timeout_s: float):
|
1021 |
+
"""Asynchronous wait for Triton Inference Server and deployed on it model readiness.
|
1022 |
+
|
1023 |
+
Args:
|
1024 |
+
timeout_s: timeout to server and model get into readiness state.
|
1025 |
+
|
1026 |
+
Raises:
|
1027 |
+
PyTritonClientTimeoutError: If server and model are not in readiness state before given timeout.
|
1028 |
+
PyTritonClientModelUnavailableError: If model with given name (and version) is unavailable.
|
1029 |
+
KeyboardInterrupt: If hosting process receives SIGINT
|
1030 |
+
"""
|
1031 |
+
_LOGGER.debug(f"Waiting for model {self._model_name} to be ready")
|
1032 |
+
try:
|
1033 |
+
await asyncio.wait_for(
|
1034 |
+
asyncio_wait_for_model_ready(
|
1035 |
+
self._general_client, self._model_name, self._model_version, timeout_s=timeout_s
|
1036 |
+
),
|
1037 |
+
self._init_timeout_s,
|
1038 |
+
)
|
1039 |
+
except asyncio.TimeoutError as e:
|
1040 |
+
message = f"Timeout while waiting for model {self._model_name} to be ready for {self._init_timeout_s}s"
|
1041 |
+
_LOGGER.error(message)
|
1042 |
+
raise PyTritonClientTimeoutError(message) from e
|
1043 |
+
|
1044 |
+
@property
|
1045 |
+
async def model_config(self):
|
1046 |
+
"""Obtain configuration of model deployed on the Triton Inference Server.
|
1047 |
+
|
1048 |
+
Also waits for server to get into readiness state.
|
1049 |
+
"""
|
1050 |
+
try:
|
1051 |
+
if not self._model_config:
|
1052 |
+
kwargs = self._get_model_config_extra_args()
|
1053 |
+
_LOGGER.debug(f"Obtaining model config for {self._model_name}")
|
1054 |
+
|
1055 |
+
self._model_config = await asyncio.wait_for(
|
1056 |
+
asyncio_get_model_config(
|
1057 |
+
self._general_client,
|
1058 |
+
self._model_name,
|
1059 |
+
self._model_version,
|
1060 |
+
timeout_s=self._init_timeout_s,
|
1061 |
+
**kwargs,
|
1062 |
+
),
|
1063 |
+
self._init_timeout_s,
|
1064 |
+
)
|
1065 |
+
_LOGGER.debug(f"Obtained model config for {self._model_name}")
|
1066 |
+
return self._model_config
|
1067 |
+
except asyncio.TimeoutError as e:
|
1068 |
+
message = f"Timeout while waiting for model {self._model_name} to be ready for {self._init_timeout_s}s"
|
1069 |
+
_LOGGER.error(message)
|
1070 |
+
raise PyTritonClientTimeoutError(message) from e
|
1071 |
+
|
1072 |
+
async def infer_sample(
|
1073 |
+
self,
|
1074 |
+
*inputs,
|
1075 |
+
parameters: Optional[Dict[str, Union[str, int, bool]]] = None,
|
1076 |
+
headers: Optional[Dict[str, Union[str, int, bool]]] = None,
|
1077 |
+
**named_inputs,
|
1078 |
+
):
|
1079 |
+
"""Run asynchronous inference on single data sample.
|
1080 |
+
|
1081 |
+
Typical usage:
|
1082 |
+
|
1083 |
+
```python
|
1084 |
+
client = AsyncioModelClient("localhost", "MyModel")
|
1085 |
+
result_dict = await client.infer_sample(input1, input2)
|
1086 |
+
await client.close()
|
1087 |
+
```
|
1088 |
+
|
1089 |
+
Inference inputs can be provided either as positional or keyword arguments:
|
1090 |
+
|
1091 |
+
```python
|
1092 |
+
result_dict = await client.infer_sample(input1, input2)
|
1093 |
+
result_dict = await client.infer_sample(a=input1, b=input2)
|
1094 |
+
```
|
1095 |
+
|
1096 |
+
Mixing of argument passing conventions is not supported and will raise PyTritonClientRuntimeError.
|
1097 |
+
|
1098 |
+
Args:
|
1099 |
+
*inputs: inference inputs provided as positional arguments.
|
1100 |
+
parameters: custom inference parameters.
|
1101 |
+
headers: custom inference headers.
|
1102 |
+
**named_inputs: inference inputs provided as named arguments.
|
1103 |
+
|
1104 |
+
Returns:
|
1105 |
+
dictionary with inference results, where dictionary keys are output names.
|
1106 |
+
|
1107 |
+
Raises:
|
1108 |
+
PyTritonClientValueError: if mixing of positional and named arguments passing detected.
|
1109 |
+
PyTritonClientTimeoutError:
|
1110 |
+
in case of first method call, `lazy_init` argument is False
|
1111 |
+
and wait time for server and model being ready exceeds `init_timeout_s`
|
1112 |
+
or inference time exceeds `timeout_s`.
|
1113 |
+
PyTritonClientModelUnavailableError: If model with given name (and version) is unavailable.
|
1114 |
+
PyTritonClientInferenceServerError: If error occurred on inference callable or Triton Inference Server side.
|
1115 |
+
"""
|
1116 |
+
_verify_inputs_args(inputs, named_inputs)
|
1117 |
+
_verify_parameters(parameters)
|
1118 |
+
_verify_parameters(headers)
|
1119 |
+
|
1120 |
+
_LOGGER.debug(f"Running inference for {self._model_name}")
|
1121 |
+
model_config = await self.model_config
|
1122 |
+
_LOGGER.debug(f"Model config for {self._model_name} obtained")
|
1123 |
+
|
1124 |
+
model_supports_batching = model_config.max_batch_size > 0
|
1125 |
+
if model_supports_batching:
|
1126 |
+
if inputs:
|
1127 |
+
inputs = tuple(data[np.newaxis, ...] for data in inputs)
|
1128 |
+
elif named_inputs:
|
1129 |
+
named_inputs = {name: data[np.newaxis, ...] for name, data in named_inputs.items()}
|
1130 |
+
|
1131 |
+
_LOGGER.debug(f"Running _infer for {self._model_name}")
|
1132 |
+
result = await self._infer(inputs or named_inputs, parameters, headers)
|
1133 |
+
_LOGGER.debug(f"_infer for {self._model_name} finished")
|
1134 |
+
if model_supports_batching:
|
1135 |
+
result = {name: data[0] for name, data in result.items()}
|
1136 |
+
|
1137 |
+
return result
|
1138 |
+
|
1139 |
+
async def infer_batch(
|
1140 |
+
self,
|
1141 |
+
*inputs,
|
1142 |
+
parameters: Optional[Dict[str, Union[str, int, bool]]] = None,
|
1143 |
+
headers: Optional[Dict[str, Union[str, int, bool]]] = None,
|
1144 |
+
**named_inputs,
|
1145 |
+
):
|
1146 |
+
"""Run asynchronous inference on batched data.
|
1147 |
+
|
1148 |
+
Typical usage:
|
1149 |
+
|
1150 |
+
```python
|
1151 |
+
client = AsyncioModelClient("localhost", "MyModel")
|
1152 |
+
result_dict = await client.infer_batch(input1, input2)
|
1153 |
+
await client.close()
|
1154 |
+
```
|
1155 |
+
|
1156 |
+
Inference inputs can be provided either as positional or keyword arguments:
|
1157 |
+
|
1158 |
+
```python
|
1159 |
+
result_dict = await client.infer_batch(input1, input2)
|
1160 |
+
result_dict = await client.infer_batch(a=input1, b=input2)
|
1161 |
+
```
|
1162 |
+
|
1163 |
+
Mixing of argument passing conventions is not supported and will raise PyTritonClientValueError.
|
1164 |
+
|
1165 |
+
Args:
|
1166 |
+
*inputs: inference inputs provided as positional arguments.
|
1167 |
+
parameters: custom inference parameters.
|
1168 |
+
headers: custom inference headers.
|
1169 |
+
**named_inputs: inference inputs provided as named arguments.
|
1170 |
+
|
1171 |
+
Returns:
|
1172 |
+
dictionary with inference results, where dictionary keys are output names.
|
1173 |
+
|
1174 |
+
Raises:
|
1175 |
+
PyTritonClientValueError: if mixing of positional and named arguments passing detected.
|
1176 |
+
PyTritonClientTimeoutError:
|
1177 |
+
in case of first method call, `lazy_init` argument is False
|
1178 |
+
and wait time for server and model being ready exceeds `init_timeout_s`
|
1179 |
+
or inference time exceeds `timeout_s`.
|
1180 |
+
PyTritonClientModelDoesntSupportBatchingError: if model doesn't support batching.
|
1181 |
+
PyTritonClientModelUnavailableError: If model with given name (and version) is unavailable.
|
1182 |
+
PyTritonClientInferenceServerError: If error occurred on inference callable or Triton Inference Server side.
|
1183 |
+
"""
|
1184 |
+
_verify_inputs_args(inputs, named_inputs)
|
1185 |
+
_verify_parameters(parameters)
|
1186 |
+
_verify_parameters(headers)
|
1187 |
+
|
1188 |
+
_LOGGER.debug(f"Running inference for {self._model_name}")
|
1189 |
+
model_config = await self.model_config
|
1190 |
+
_LOGGER.debug(f"Model config for {self._model_name} obtained")
|
1191 |
+
|
1192 |
+
model_supports_batching = model_config.max_batch_size > 0
|
1193 |
+
if not model_supports_batching:
|
1194 |
+
_LOGGER.error(f"Model {model_config.model_name} doesn't support batching")
|
1195 |
+
raise PyTritonClientModelDoesntSupportBatchingError(
|
1196 |
+
f"Model {model_config.model_name} doesn't support batching - use infer_sample method instead"
|
1197 |
+
)
|
1198 |
+
|
1199 |
+
_LOGGER.debug(f"Running _infer for {self._model_name}")
|
1200 |
+
result = await self._infer(inputs or named_inputs, parameters, headers)
|
1201 |
+
_LOGGER.debug(f"_infer for {self._model_name} finished")
|
1202 |
+
return result
|
1203 |
+
|
1204 |
+
async def _wait_and_init_model_config(self, init_timeout_s: float):
|
1205 |
+
"""Asynchronous wait for model and obtain model configuration.
|
1206 |
+
|
1207 |
+
Args:
|
1208 |
+
init_timeout_s: timeout for server and model being ready.
|
1209 |
+
|
1210 |
+
Raises:
|
1211 |
+
PyTritonClientTimeoutError: if wait time for server and model being ready exceeds `init_timeout_s`
|
1212 |
+
PyTritonClientModelUnavailableError: If model with given name (and version) is unavailable.
|
1213 |
+
"""
|
1214 |
+
try:
|
1215 |
+
should_finish_before_s = time.time() + init_timeout_s
|
1216 |
+
_LOGGER.debug(f"Waiting for model {self._model_name} to be ready")
|
1217 |
+
|
1218 |
+
await asyncio.wait_for(self.wait_for_model(init_timeout_s), init_timeout_s)
|
1219 |
+
_LOGGER.debug(f"Model {self._model_name} is ready")
|
1220 |
+
self._model_ready = True
|
1221 |
+
|
1222 |
+
timeout_s = max(0.0, should_finish_before_s - time.time())
|
1223 |
+
_LOGGER.debug(f"Obtaining model config for {self._model_name}")
|
1224 |
+
self._model_config = await asyncio.wait_for(
|
1225 |
+
asyncio_get_model_config(
|
1226 |
+
self._general_client, self._model_name, self._model_version, timeout_s=timeout_s
|
1227 |
+
),
|
1228 |
+
timeout_s,
|
1229 |
+
)
|
1230 |
+
_LOGGER.debug(f"Model config for {self._model_name} obtained")
|
1231 |
+
except asyncio.TimeoutError as e:
|
1232 |
+
_LOGGER.error(f"Timeout exceeded while waiting for model {self._model_name} to be ready")
|
1233 |
+
raise PyTritonClientTimeoutError(
|
1234 |
+
f"Timeout exceeded while waiting for model {self._model_name} to be ready"
|
1235 |
+
) from e
|
1236 |
+
|
1237 |
+
def _validate_input(self, input_name, input_data):
|
1238 |
+
if input_data.dtype == object and not isinstance(input_data.reshape(-1)[0], bytes):
|
1239 |
+
raise RuntimeError(
|
1240 |
+
f"Numpy array for {input_name!r} input with dtype=object should contain encoded strings \
|
1241 |
+
\\(e.g. into utf-8\\). Element type: {type(input_data.reshape(-1)[0])}"
|
1242 |
+
)
|
1243 |
+
if input_data.dtype.type == np.str_:
|
1244 |
+
raise RuntimeError(
|
1245 |
+
"Unicode inputs are not supported. "
|
1246 |
+
f"Encode numpy array for {input_name!r} input (ex. with np.char.encode(array, 'utf-8'))."
|
1247 |
+
)
|
1248 |
+
|
1249 |
+
async def _execute_infer(self, model_config, inputs_wrapped, outputs_wrapped, parameters, headers) -> Any:
|
1250 |
+
try:
|
1251 |
+
_LOGGER.debug(f"Sending InferRequest for {self._model_name}")
|
1252 |
+
kwargs = self._get_infer_extra_args()
|
1253 |
+
response = await self._infer_client.infer(
|
1254 |
+
model_name=self._model_name,
|
1255 |
+
model_version=self._model_version or "",
|
1256 |
+
inputs=inputs_wrapped,
|
1257 |
+
headers=headers,
|
1258 |
+
outputs=outputs_wrapped,
|
1259 |
+
request_id=self._next_request_id,
|
1260 |
+
parameters=parameters,
|
1261 |
+
**kwargs,
|
1262 |
+
)
|
1263 |
+
except asyncio.exceptions.TimeoutError as e:
|
1264 |
+
# HTTP aio client raises asyncio.exceptions.TimeoutError for timeout errors
|
1265 |
+
message = f"Timeout exceeded while running inference for {self._model_name}"
|
1266 |
+
_LOGGER.error(message)
|
1267 |
+
raise PyTritonClientTimeoutError(message) from e
|
1268 |
+
except tritonclient.utils.InferenceServerException as e:
|
1269 |
+
message = f"Error occurred on Triton Inference Server side:\n {e.message()}"
|
1270 |
+
_LOGGER.error(message)
|
1271 |
+
if "Deadline Exceeded" in e.message():
|
1272 |
+
# GRPC aio client raises InferenceServerException with message "Deadline Exceeded"
|
1273 |
+
# for timeout errors
|
1274 |
+
raise PyTritonClientTimeoutError(message) from e
|
1275 |
+
else:
|
1276 |
+
raise PyTritonClientInferenceServerError(message) from e
|
1277 |
+
_LOGGER.debug(f"Received InferResponse for {self._model_name}")
|
1278 |
+
outputs = {output_spec.name: response.as_numpy(output_spec.name) for output_spec in model_config.outputs}
|
1279 |
+
return outputs
|
1280 |
+
|
1281 |
+
async def _infer(self, inputs: _IOType, parameters, headers):
|
1282 |
+
if self._model_ready:
|
1283 |
+
_LOGGER.debug(f"Waiting for model {self._model_name} config")
|
1284 |
+
await self._wait_and_init_model_config(self._init_timeout_s)
|
1285 |
+
_LOGGER.debug(f"Model wait finished for {self._model_name}")
|
1286 |
+
|
1287 |
+
_LOGGER.debug(f"Obtaining config for {self._model_name}")
|
1288 |
+
model_config = await self.model_config
|
1289 |
+
_LOGGER.debug(f"Model config for {self._model_name} obtained")
|
1290 |
+
if model_config.decoupled:
|
1291 |
+
raise PyTritonClientInferenceServerError(
|
1292 |
+
"Model config is decoupled. Use DecouploedAsyncioModelClient instead."
|
1293 |
+
)
|
1294 |
+
|
1295 |
+
if isinstance(inputs, Tuple):
|
1296 |
+
inputs = {input_spec.name: input_data for input_spec, input_data in zip(model_config.inputs, inputs)}
|
1297 |
+
|
1298 |
+
inputs_wrapped = []
|
1299 |
+
for input_name, input_data in inputs.items():
|
1300 |
+
if isinstance(input_data, np.ndarray):
|
1301 |
+
self._validate_input(input_name, input_data)
|
1302 |
+
triton_dtype = tritonclient.utils.np_to_triton_dtype(input_data.dtype)
|
1303 |
+
infer_input = self._triton_client_lib.InferInput(input_name, input_data.shape, triton_dtype)
|
1304 |
+
infer_input.set_data_from_numpy(input_data)
|
1305 |
+
input_wrapped = infer_input
|
1306 |
+
inputs_wrapped.append(input_wrapped)
|
1307 |
+
else:
|
1308 |
+
raise PyTritonClientValueError(
|
1309 |
+
f"Input {input_name} is not a numpy array. Got {type(input_data)} instead."
|
1310 |
+
)
|
1311 |
+
|
1312 |
+
outputs_wrapped = [
|
1313 |
+
self._triton_client_lib.InferRequestedOutput(output_spec.name) for output_spec in model_config.outputs
|
1314 |
+
]
|
1315 |
+
return await self._execute_infer(model_config, inputs_wrapped, outputs_wrapped, parameters, headers)
|
1316 |
+
|
1317 |
+
def _handle_lazy_init(self):
|
1318 |
+
# Asynchronous lazy initialization is done in __aenter__ method
|
1319 |
+
pass
|
1320 |
+
|
1321 |
+
def _get_init_extra_args(self):
|
1322 |
+
# The inference timeout is used for both the HTTP and the GRPC protocols. However,
|
1323 |
+
# the way the timeout is passed to the client differs depending on the protocol.
|
1324 |
+
# For the HTTP protocol, the timeout is set in the ``__init__`` method as ``conn_timeout`` for both connection and request timeouts.
|
1325 |
+
# For the GRPC protocol, the timeout
|
1326 |
+
# is passed to the infer method as ``client_timeout``.
|
1327 |
+
# Both protocols support timeouts correctly and will raise an exception
|
1328 |
+
# if the network request or the inference process takes longer than the timeout.
|
1329 |
+
# This is a design choice of the underlying tritonclient library.
|
1330 |
+
|
1331 |
+
if self._triton_url.scheme != "http":
|
1332 |
+
return {}
|
1333 |
+
|
1334 |
+
kwargs = {
|
1335 |
+
# This value sets the maximum time allowed for both connection and network requests in both model loading and inference process
|
1336 |
+
"conn_timeout": self._inference_timeout_s,
|
1337 |
+
}
|
1338 |
+
return kwargs
|
1339 |
+
|
1340 |
+
def _get_infer_extra_args(self):
|
1341 |
+
if self._triton_url.scheme == "http":
|
1342 |
+
return {}
|
1343 |
+
# For the GRPC protocol, the timeout is passed to the infer method as client_timeout
|
1344 |
+
# This timeout applies to the whole inference process and each network request
|
1345 |
+
|
1346 |
+
# The ``infer`` supports also timeout argument for both GRPC and HTTP.
|
1347 |
+
# It is applied at server side and supported only for dynamic batching.
|
1348 |
+
# However, it is not used here yet and planned for future release
|
1349 |
+
kwargs = {"client_timeout": self._inference_timeout_s}
|
1350 |
+
return kwargs
|
1351 |
+
|
1352 |
+
|
1353 |
+
class AsyncioDecoupledModelClient(AsyncioModelClient):
|
1354 |
+
"""Asyncio client for model deployed on the Triton Inference Server.
|
1355 |
+
|
1356 |
+
This client is based on Triton Inference Server Python clients and GRPC library:
|
1357 |
+
* ``tritonclient.grpc.aio.InferenceServerClient``
|
1358 |
+
|
1359 |
+
It can wait for server to be ready with model loaded and then perform inference on it.
|
1360 |
+
``AsyncioDecoupledModelClient`` supports asyncio context manager protocol.
|
1361 |
+
|
1362 |
+
The client is intended to be used with decoupled models and will raise an error if model is coupled.
|
1363 |
+
|
1364 |
+
Typical usage:
|
1365 |
+
```python
|
1366 |
+
from pytriton.client import AsyncioDecoupledModelClient
|
1367 |
+
import numpy as np
|
1368 |
+
|
1369 |
+
input1_sample = np.random.rand(1, 3, 224, 224).astype(np.float32)
|
1370 |
+
input2_sample = np.random.rand(1, 3, 224, 224).astype(np.float32)
|
1371 |
+
|
1372 |
+
async with AsyncioDecoupledModelClient("grpc://localhost", "MyModel") as client:
|
1373 |
+
async for result_dict in client.infer_sample(input1_sample, input2_sample):
|
1374 |
+
print(result_dict["output_name"])
|
1375 |
+
```
|
1376 |
+
"""
|
1377 |
+
|
1378 |
+
async def infer_sample(
|
1379 |
+
self,
|
1380 |
+
*inputs,
|
1381 |
+
parameters: Optional[Dict[str, Union[str, int, bool]]] = None,
|
1382 |
+
headers: Optional[Dict[str, Union[str, int, bool]]] = None,
|
1383 |
+
**named_inputs,
|
1384 |
+
):
|
1385 |
+
"""Run asynchronous inference on single data sample.
|
1386 |
+
|
1387 |
+
Typical usage:
|
1388 |
+
|
1389 |
+
```python
|
1390 |
+
async with AsyncioDecoupledModelClient("grpc://localhost", "MyModel") as client:
|
1391 |
+
async for result_dict in client.infer_sample(input1_sample, input2_sample):
|
1392 |
+
print(result_dict["output_name"])
|
1393 |
+
```
|
1394 |
+
|
1395 |
+
Inference inputs can be provided either as positional or keyword arguments:
|
1396 |
+
|
1397 |
+
```python
|
1398 |
+
results_iterator = client.infer_sample(input1, input2)
|
1399 |
+
results_iterator = client.infer_sample(a=input1, b=input2)
|
1400 |
+
```
|
1401 |
+
|
1402 |
+
Mixing of argument passing conventions is not supported and will raise PyTritonClientRuntimeError.
|
1403 |
+
|
1404 |
+
Args:
|
1405 |
+
*inputs: inference inputs provided as positional arguments.
|
1406 |
+
parameters: custom inference parameters.
|
1407 |
+
headers: custom inference headers.
|
1408 |
+
**named_inputs: inference inputs provided as named arguments.
|
1409 |
+
|
1410 |
+
Returns:
|
1411 |
+
Asynchronous generator, which generates dictionaries with partial inference results, where dictionary keys are output names.
|
1412 |
+
|
1413 |
+
Raises:
|
1414 |
+
PyTritonClientValueError: if mixing of positional and named arguments passing detected.
|
1415 |
+
PyTritonClientTimeoutError:
|
1416 |
+
in case of first method call, `lazy_init` argument is False
|
1417 |
+
and wait time for server and model being ready exceeds `init_timeout_s`
|
1418 |
+
or inference time exceeds `timeout_s`.
|
1419 |
+
PyTritonClientModelUnavailableError: If model with given name (and version) is unavailable.
|
1420 |
+
PyTritonClientInferenceServerError: If error occurred on inference callable or Triton Inference Server side.
|
1421 |
+
"""
|
1422 |
+
_verify_inputs_args(inputs, named_inputs)
|
1423 |
+
_verify_parameters(parameters)
|
1424 |
+
_verify_parameters(headers)
|
1425 |
+
|
1426 |
+
_LOGGER.debug(f"Running inference for {self._model_name}")
|
1427 |
+
model_config = await self.model_config
|
1428 |
+
_LOGGER.debug(f"Model config for {self._model_name} obtained")
|
1429 |
+
|
1430 |
+
model_supports_batching = model_config.max_batch_size > 0
|
1431 |
+
if model_supports_batching:
|
1432 |
+
if inputs:
|
1433 |
+
inputs = tuple(data[np.newaxis, ...] for data in inputs)
|
1434 |
+
elif named_inputs:
|
1435 |
+
named_inputs = {name: data[np.newaxis, ...] for name, data in named_inputs.items()}
|
1436 |
+
|
1437 |
+
_LOGGER.debug(f"Running _infer for {self._model_name}")
|
1438 |
+
result = self._infer(inputs or named_inputs, parameters, headers)
|
1439 |
+
_LOGGER.debug(f"_infer for {self._model_name} finished")
|
1440 |
+
|
1441 |
+
async for item in result:
|
1442 |
+
if model_supports_batching:
|
1443 |
+
debatched_item = {name: data[0] for name, data in item.items()}
|
1444 |
+
yield debatched_item
|
1445 |
+
else:
|
1446 |
+
yield item
|
1447 |
+
|
1448 |
+
async def infer_batch(
|
1449 |
+
self,
|
1450 |
+
*inputs,
|
1451 |
+
parameters: Optional[Dict[str, Union[str, int, bool]]] = None,
|
1452 |
+
headers: Optional[Dict[str, Union[str, int, bool]]] = None,
|
1453 |
+
**named_inputs,
|
1454 |
+
):
|
1455 |
+
"""Run asynchronous inference on batched data.
|
1456 |
+
|
1457 |
+
Typical usage:
|
1458 |
+
|
1459 |
+
```python
|
1460 |
+
async with AsyncioDecoupledModelClient("grpc://localhost", "MyModel") as client:
|
1461 |
+
async for result_dict in client.infer_batch(input1_sample, input2_sample):
|
1462 |
+
print(result_dict["output_name"])
|
1463 |
+
```
|
1464 |
+
|
1465 |
+
Inference inputs can be provided either as positional or keyword arguments:
|
1466 |
+
|
1467 |
+
```python
|
1468 |
+
results_iterator = client.infer_batch(input1, input2)
|
1469 |
+
results_iterator = client.infer_batch(a=input1, b=input2)
|
1470 |
+
```
|
1471 |
+
|
1472 |
+
Mixing of argument passing conventions is not supported and will raise PyTritonClientRuntimeError.
|
1473 |
+
|
1474 |
+
Args:
|
1475 |
+
*inputs: inference inputs provided as positional arguments.
|
1476 |
+
parameters: custom inference parameters.
|
1477 |
+
headers: custom inference headers.
|
1478 |
+
**named_inputs: inference inputs provided as named arguments.
|
1479 |
+
|
1480 |
+
Returns:
|
1481 |
+
Asynchronous generator, which generates dictionaries with partial inference results, where dictionary keys are output names.
|
1482 |
+
|
1483 |
+
Raises:
|
1484 |
+
PyTritonClientValueError: if mixing of positional and named arguments passing detected.
|
1485 |
+
PyTritonClientTimeoutError:
|
1486 |
+
in case of first method call, `lazy_init` argument is False
|
1487 |
+
and wait time for server and model being ready exceeds `init_timeout_s`
|
1488 |
+
or inference time exceeds `timeout_s`.
|
1489 |
+
PyTritonClientModelDoesntSupportBatchingError: if model doesn't support batching.
|
1490 |
+
PyTritonClientModelUnavailableError: If model with given name (and version) is unavailable.
|
1491 |
+
PyTritonClientInferenceServerError: If error occurred on inference callable or Triton Inference Server side.
|
1492 |
+
"""
|
1493 |
+
_verify_inputs_args(inputs, named_inputs)
|
1494 |
+
_verify_parameters(parameters)
|
1495 |
+
_verify_parameters(headers)
|
1496 |
+
|
1497 |
+
_LOGGER.debug(f"Running inference for {self._model_name}")
|
1498 |
+
model_config = await self.model_config
|
1499 |
+
_LOGGER.debug(f"Model config for {self._model_name} obtained")
|
1500 |
+
|
1501 |
+
model_supports_batching = model_config.max_batch_size > 0
|
1502 |
+
if not model_supports_batching:
|
1503 |
+
_LOGGER.error(f"Model {model_config.model_name} doesn't support batching")
|
1504 |
+
raise PyTritonClientModelDoesntSupportBatchingError(
|
1505 |
+
f"Model {model_config.model_name} doesn't support batching - use infer_sample method instead"
|
1506 |
+
)
|
1507 |
+
|
1508 |
+
_LOGGER.debug(f"Running _infer for {self._model_name}")
|
1509 |
+
result = self._infer(inputs or named_inputs, parameters, headers)
|
1510 |
+
_LOGGER.debug(f"_infer for {self._model_name} finished")
|
1511 |
+
async for item in result:
|
1512 |
+
yield item
|
1513 |
+
|
1514 |
+
async def _execute_infer(self, model_config, inputs_wrapped, outputs_wrapped, parameters, headers) -> Any:
|
1515 |
+
# stream_infer siletly consumes all errors raised inside async_request_iterator and raises CancelledError
|
1516 |
+
error_raised_inside_async_request_iterator = set()
|
1517 |
+
try:
|
1518 |
+
_LOGGER.debug(f"Sending InferRequest for {self._model_name}")
|
1519 |
+
kwargs = self._get_infer_extra_args()
|
1520 |
+
|
1521 |
+
async def async_request_iterator(errors):
|
1522 |
+
_LOGGER.debug(f"Begin creating InferRequestHeader for {self._model_name}")
|
1523 |
+
try:
|
1524 |
+
yield {
|
1525 |
+
"model_name": self._model_name,
|
1526 |
+
"inputs": inputs_wrapped,
|
1527 |
+
"outputs": outputs_wrapped,
|
1528 |
+
"request_id": self._next_request_id,
|
1529 |
+
"sequence_id": 0,
|
1530 |
+
"sequence_start": True,
|
1531 |
+
"sequence_end": True,
|
1532 |
+
}
|
1533 |
+
except Exception as e:
|
1534 |
+
_LOGGER.error(f"Error occurred while creating InferRequestHeader for {self._model_name}")
|
1535 |
+
errors.add(e)
|
1536 |
+
raise e
|
1537 |
+
_LOGGER.debug(f"End creating InferRequestHeader for {self._model_name}")
|
1538 |
+
|
1539 |
+
response_iterator = self._infer_client.stream_infer(
|
1540 |
+
inputs_iterator=async_request_iterator(error_raised_inside_async_request_iterator),
|
1541 |
+
headers=headers,
|
1542 |
+
**kwargs,
|
1543 |
+
)
|
1544 |
+
_LOGGER.debug(f"End preparing InferRequest for {self._model_name}")
|
1545 |
+
while True:
|
1546 |
+
try:
|
1547 |
+
try:
|
1548 |
+
response = await asyncio.wait_for(
|
1549 |
+
response_iterator.__anext__(),
|
1550 |
+
self._inference_timeout_s,
|
1551 |
+
)
|
1552 |
+
except asyncio.TimeoutError as e:
|
1553 |
+
message = f"Timeout while waiting for model {self._model_name} to return next response {self._inference_timeout_s}s"
|
1554 |
+
_LOGGER.error(message)
|
1555 |
+
raise PyTritonClientTimeoutError(message) from e
|
1556 |
+
result, error = response
|
1557 |
+
_LOGGER.debug(f"Received InferResponse for {self._model_name}")
|
1558 |
+
if error is not None:
|
1559 |
+
raise error
|
1560 |
+
else:
|
1561 |
+
partial_output = {
|
1562 |
+
output_spec.name: result.as_numpy(output_spec.name) for output_spec in model_config.outputs
|
1563 |
+
}
|
1564 |
+
yield partial_output
|
1565 |
+
except StopAsyncIteration:
|
1566 |
+
break
|
1567 |
+
_LOGGER.debug(f"End receiving InferResponse for {self._model_name}")
|
1568 |
+
|
1569 |
+
except asyncio.exceptions.TimeoutError as e:
|
1570 |
+
# HTTP aio client raises asyncio.exceptions.TimeoutError for timeout errors
|
1571 |
+
message = f"Timeout exceeded while running inference for {self._model_name}"
|
1572 |
+
_LOGGER.error(message)
|
1573 |
+
raise PyTritonClientTimeoutError(message) from e
|
1574 |
+
except tritonclient.utils.InferenceServerException as e:
|
1575 |
+
message = f"Error occurred on Triton Inference Server side:\n {e.message()}"
|
1576 |
+
_LOGGER.error(message)
|
1577 |
+
if "Deadline Exceeded" in e.message():
|
1578 |
+
# GRPC aio client raises InferenceServerException with message "Deadline Exceeded"
|
1579 |
+
# for timeout errors
|
1580 |
+
raise PyTritonClientTimeoutError(message) from e
|
1581 |
+
else:
|
1582 |
+
raise PyTritonClientInferenceServerError(message) from e
|
1583 |
+
except asyncio.exceptions.CancelledError as e:
|
1584 |
+
_LOGGER.error(f"CancelledError occurred while streaming inference for {self._model_name}")
|
1585 |
+
# stream_infer siletly consumes all errors raised inside async_request_iterator and raises CancelledError
|
1586 |
+
if len(error_raised_inside_async_request_iterator) > 0:
|
1587 |
+
_LOGGER.error(f"Re-raising error raised inside async_request_iterator for {self._model_name} ")
|
1588 |
+
raise error_raised_inside_async_request_iterator.pop() from None
|
1589 |
+
else:
|
1590 |
+
raise e
|
1591 |
+
|
1592 |
+
async def _infer(self, inputs: _IOType, parameters, headers):
|
1593 |
+
if self._model_ready:
|
1594 |
+
_LOGGER.debug(f"Waiting for model {self._model_name} config")
|
1595 |
+
await self._wait_and_init_model_config(self._init_timeout_s)
|
1596 |
+
_LOGGER.debug(f"Model wait finished for {self._model_name}")
|
1597 |
+
|
1598 |
+
_LOGGER.debug(f"Obtaining config for {self._model_name}")
|
1599 |
+
model_config = await self.model_config
|
1600 |
+
_LOGGER.debug(f"Model config for {self._model_name} obtained")
|
1601 |
+
if not model_config.decoupled:
|
1602 |
+
raise PyTritonClientInferenceServerError("Model config is coupled. Use AsyncioModelClient instead.")
|
1603 |
+
|
1604 |
+
if isinstance(inputs, Tuple):
|
1605 |
+
inputs = {input_spec.name: input_data for input_spec, input_data in zip(model_config.inputs, inputs)}
|
1606 |
+
|
1607 |
+
inputs_wrapped = []
|
1608 |
+
for input_name, input_data in inputs.items():
|
1609 |
+
if isinstance(input_data, np.ndarray):
|
1610 |
+
self._validate_input(input_name, input_data)
|
1611 |
+
triton_dtype = tritonclient.utils.np_to_triton_dtype(input_data.dtype)
|
1612 |
+
infer_input = self._triton_client_lib.InferInput(input_name, input_data.shape, triton_dtype)
|
1613 |
+
infer_input.set_data_from_numpy(input_data)
|
1614 |
+
input_wrapped = infer_input
|
1615 |
+
inputs_wrapped.append(input_wrapped)
|
1616 |
+
else:
|
1617 |
+
raise PyTritonClientValueError(
|
1618 |
+
f"Input {input_name} is not a numpy array. Got {type(input_data)} instead."
|
1619 |
+
)
|
1620 |
+
|
1621 |
+
outputs_wrapped = [
|
1622 |
+
self._triton_client_lib.InferRequestedOutput(output_spec.name) for output_spec in model_config.outputs
|
1623 |
+
]
|
1624 |
+
result = self._execute_infer(model_config, inputs_wrapped, outputs_wrapped, parameters, headers)
|
1625 |
+
async for item in result:
|
1626 |
+
yield item
|
1627 |
+
|
1628 |
+
def _get_infer_extra_args(self):
|
1629 |
+
if self._triton_url.scheme == "http":
|
1630 |
+
raise PyTritonClientValueError("AsyncioDecoupledModelClient is only supported for grpc protocol")
|
1631 |
+
warnings.warn(
|
1632 |
+
f"tritonclient.aio.grpc doesn't support client_timeout parameter {self._inference_timeout_s} for infer_stream",
|
1633 |
+
NotSupportedTimeoutWarning,
|
1634 |
+
stacklevel=1,
|
1635 |
+
)
|
1636 |
+
return {}
|
1637 |
+
|
1638 |
+
|
1639 |
+
@contextlib.contextmanager
|
1640 |
+
def _hub_context():
|
1641 |
+
hub = gevent.get_hub()
|
1642 |
+
try:
|
1643 |
+
yield hub
|
1644 |
+
finally:
|
1645 |
+
hub.destroy()
|
1646 |
+
|
1647 |
+
|
1648 |
+
_INIT = "init"
|
1649 |
+
_WAIT_FOR_MODEL = "wait_for_model"
|
1650 |
+
_MODEL_CONFIG = "model_config"
|
1651 |
+
_INFER_BATCH = "infer_batch"
|
1652 |
+
_INFER_SAMPLE = "infer_sample"
|
1653 |
+
_CLOSE = "close"
|
1654 |
+
|
1655 |
+
|
1656 |
+
class FuturesModelClient:
|
1657 |
+
"""A client for interacting with a model deployed on the Triton Inference Server using concurrent.futures.
|
1658 |
+
|
1659 |
+
This client allows asynchronous inference requests using a thread pool executor. It can be used to perform inference
|
1660 |
+
on a model by providing input data and receiving the corresponding output data. The client can be used in a `with`
|
1661 |
+
statement to ensure proper resource management.
|
1662 |
+
|
1663 |
+
Example usage with context manager:
|
1664 |
+
|
1665 |
+
```python
|
1666 |
+
with FuturesModelClient("localhost", "MyModel") as client:
|
1667 |
+
result_future = client.infer_sample(input1=input1_data, input2=input2_data)
|
1668 |
+
# do something else
|
1669 |
+
print(result_future.result())
|
1670 |
+
```
|
1671 |
+
|
1672 |
+
Usage without context manager:
|
1673 |
+
|
1674 |
+
```python
|
1675 |
+
client = FuturesModelClient("localhost", "MyModel")
|
1676 |
+
result_future = client.infer_sample(input1=input1_data, input2=input2_data)
|
1677 |
+
# do something else
|
1678 |
+
print(result_future.result())
|
1679 |
+
client.close()
|
1680 |
+
```
|
1681 |
+
"""
|
1682 |
+
|
1683 |
+
def __init__(
|
1684 |
+
self,
|
1685 |
+
url: str,
|
1686 |
+
model_name: str,
|
1687 |
+
model_version: Optional[str] = None,
|
1688 |
+
*,
|
1689 |
+
max_workers: int = 128,
|
1690 |
+
max_queue_size: int = 128,
|
1691 |
+
non_blocking: bool = False,
|
1692 |
+
init_timeout_s: Optional[float] = None,
|
1693 |
+
inference_timeout_s: Optional[float] = None,
|
1694 |
+
):
|
1695 |
+
"""Initializes the FuturesModelClient for a given model.
|
1696 |
+
|
1697 |
+
Args:
|
1698 |
+
url: The Triton Inference Server url, e.g. `grpc://localhost:8001`.
|
1699 |
+
model_name: The name of the model to interact with.
|
1700 |
+
model_version: The version of the model to interact with. If None, the latest version will be used.
|
1701 |
+
max_workers: The maximum number of threads that can be used to execute the given calls. If None, there is not limit on the number of threads.
|
1702 |
+
max_queue_size: The maximum number of requests that can be queued. If None, there is not limit on the number of requests.
|
1703 |
+
non_blocking: If True, the client will raise a PyTritonClientQueueFullError if the queue is full. If False, the client will block until the queue is not full.
|
1704 |
+
init_timeout_s: Timeout in seconds for server and model being ready. If non passed default 60 seconds timeout will be used.
|
1705 |
+
inference_timeout_s: Timeout in seconds for the single model inference request. If non passed default 60 seconds timeout will be used.
|
1706 |
+
"""
|
1707 |
+
self._url = url
|
1708 |
+
self._model_name = model_name
|
1709 |
+
self._model_version = model_version
|
1710 |
+
self._threads = []
|
1711 |
+
self._max_workers = max_workers
|
1712 |
+
self._max_queue_size = max_queue_size
|
1713 |
+
self._non_blocking = non_blocking
|
1714 |
+
|
1715 |
+
if self._max_workers is not None and self._max_workers <= 0:
|
1716 |
+
raise ValueError("max_workers must be greater than 0")
|
1717 |
+
if self._max_queue_size is not None and self._max_queue_size <= 0:
|
1718 |
+
raise ValueError("max_queue_size must be greater than 0")
|
1719 |
+
|
1720 |
+
kwargs = {}
|
1721 |
+
if self._max_queue_size is not None:
|
1722 |
+
kwargs["maxsize"] = self._max_queue_size
|
1723 |
+
self._queue = Queue(**kwargs)
|
1724 |
+
self._queue.put((_INIT, None, None))
|
1725 |
+
self._init_timeout_s = _DEFAULT_FUTURES_INIT_TIMEOUT_S if init_timeout_s is None else init_timeout_s
|
1726 |
+
self._inference_timeout_s = inference_timeout_s
|
1727 |
+
self._closed = False
|
1728 |
+
self._lock = Lock()
|
1729 |
+
self._existing_client = None
|
1730 |
+
|
1731 |
+
def __enter__(self):
|
1732 |
+
"""Create context for using FuturesModelClient as a context manager."""
|
1733 |
+
return self
|
1734 |
+
|
1735 |
+
def __exit__(self, exc_type, exc_value, traceback):
|
1736 |
+
"""Close resources used by FuturesModelClient instance when exiting from the context."""
|
1737 |
+
self.close()
|
1738 |
+
|
1739 |
+
def close(self, wait=True):
|
1740 |
+
"""Close resources used by FuturesModelClient.
|
1741 |
+
|
1742 |
+
This method closes the resources used by the FuturesModelClient instance, including the Triton Inference Server connections.
|
1743 |
+
Once this method is called, the FuturesModelClient instance should not be used again.
|
1744 |
+
|
1745 |
+
Args:
|
1746 |
+
wait: If True, then shutdown will not return until all running futures have finished executing.
|
1747 |
+
"""
|
1748 |
+
if self._closed:
|
1749 |
+
return
|
1750 |
+
_LOGGER.debug("Closing FuturesModelClient.")
|
1751 |
+
|
1752 |
+
self._closed = True
|
1753 |
+
for _ in range(len(self._threads)):
|
1754 |
+
self._queue.put((_CLOSE, None, None))
|
1755 |
+
|
1756 |
+
if wait:
|
1757 |
+
_LOGGER.debug("Waiting for futures to finish.")
|
1758 |
+
for thread in self._threads:
|
1759 |
+
thread.join()
|
1760 |
+
|
1761 |
+
def wait_for_model(self, timeout_s: float) -> Future:
|
1762 |
+
"""Returns a Future object which result will be None when the model is ready.
|
1763 |
+
|
1764 |
+
Typical usage:
|
1765 |
+
|
1766 |
+
```python
|
1767 |
+
with FuturesModelClient("localhost", "BERT") as client
|
1768 |
+
future = client.wait_for_model(300.)
|
1769 |
+
# do something else
|
1770 |
+
future.result() # wait rest of timeout_s time
|
1771 |
+
# till return None if model is ready
|
1772 |
+
# or raise PyTritonClientTimeutError
|
1773 |
+
```
|
1774 |
+
|
1775 |
+
Args:
|
1776 |
+
timeout_s: The maximum amount of time to wait for the model to be ready, in seconds.
|
1777 |
+
|
1778 |
+
Returns:
|
1779 |
+
A Future object which result is None when the model is ready.
|
1780 |
+
"""
|
1781 |
+
return self._execute(
|
1782 |
+
name=_WAIT_FOR_MODEL,
|
1783 |
+
request=timeout_s,
|
1784 |
+
)
|
1785 |
+
|
1786 |
+
def model_config(self) -> Future:
|
1787 |
+
"""Obtain the configuration of the model deployed on the Triton Inference Server.
|
1788 |
+
|
1789 |
+
This method returns a Future object that will contain the TritonModelConfig object when it is ready.
|
1790 |
+
Client will wait init_timeout_s for the server to get into readiness state before obtaining the model configuration.
|
1791 |
+
|
1792 |
+
Returns:
|
1793 |
+
A Future object that will contain the TritonModelConfig object when it is ready.
|
1794 |
+
|
1795 |
+
Raises:
|
1796 |
+
PyTritonClientClosedError: If the FuturesModelClient is closed.
|
1797 |
+
"""
|
1798 |
+
return self._execute(name=_MODEL_CONFIG)
|
1799 |
+
|
1800 |
+
def infer_sample(
|
1801 |
+
self,
|
1802 |
+
*inputs,
|
1803 |
+
parameters: Optional[Dict[str, Union[str, int, bool]]] = None,
|
1804 |
+
headers: Optional[Dict[str, Union[str, int, bool]]] = None,
|
1805 |
+
**named_inputs,
|
1806 |
+
) -> Future:
|
1807 |
+
"""Run asynchronous inference on a single data sample and return a Future object.
|
1808 |
+
|
1809 |
+
This method allows the user to perform inference on a single data sample by providing input data and receiving the
|
1810 |
+
corresponding output data. The method returns a Future object that wraps a dictionary of inference results, where dictionary keys are output names.
|
1811 |
+
|
1812 |
+
Example usage:
|
1813 |
+
|
1814 |
+
```python
|
1815 |
+
with FuturesModelClient("localhost", "BERT") as client:
|
1816 |
+
result_future = client.infer_sample(input1=input1_data, input2=input2_data)
|
1817 |
+
# do something else
|
1818 |
+
print(result_future.result())
|
1819 |
+
```
|
1820 |
+
|
1821 |
+
Inference inputs can be provided either as positional or keyword arguments:
|
1822 |
+
|
1823 |
+
```python
|
1824 |
+
future = client.infer_sample(input1, input2)
|
1825 |
+
future = client.infer_sample(a=input1, b=input2)
|
1826 |
+
```
|
1827 |
+
|
1828 |
+
Args:
|
1829 |
+
*inputs: Inference inputs provided as positional arguments.
|
1830 |
+
parameters: Optional dictionary of inference parameters.
|
1831 |
+
headers: Optional dictionary of HTTP headers for the inference request.
|
1832 |
+
**named_inputs: Inference inputs provided as named arguments.
|
1833 |
+
|
1834 |
+
Returns:
|
1835 |
+
A Future object wrapping a dictionary of inference results, where dictionary keys are output names.
|
1836 |
+
|
1837 |
+
Raises:
|
1838 |
+
PyTritonClientClosedError: If the FuturesModelClient is closed.
|
1839 |
+
"""
|
1840 |
+
return self._execute(
|
1841 |
+
name=_INFER_SAMPLE,
|
1842 |
+
request=(inputs, parameters, headers, named_inputs),
|
1843 |
+
)
|
1844 |
+
|
1845 |
+
def infer_batch(
|
1846 |
+
self,
|
1847 |
+
*inputs,
|
1848 |
+
parameters: Optional[Dict[str, Union[str, int, bool]]] = None,
|
1849 |
+
headers: Optional[Dict[str, Union[str, int, bool]]] = None,
|
1850 |
+
**named_inputs,
|
1851 |
+
) -> Future:
|
1852 |
+
"""Run asynchronous inference on batched data and return a Future object.
|
1853 |
+
|
1854 |
+
This method allows the user to perform inference on batched data by providing input data and receiving the corresponding output data.
|
1855 |
+
The method returns a Future object that wraps a dictionary of inference results, where dictionary keys are output names.
|
1856 |
+
|
1857 |
+
Example usage:
|
1858 |
+
|
1859 |
+
```python
|
1860 |
+
with FuturesModelClient("localhost", "BERT") as client:
|
1861 |
+
future = client.infer_batch(input1_sample, input2_sample)
|
1862 |
+
# do something else
|
1863 |
+
print(future.result())
|
1864 |
+
```
|
1865 |
+
|
1866 |
+
Inference inputs can be provided either as positional or keyword arguments:
|
1867 |
+
|
1868 |
+
```python
|
1869 |
+
future = client.infer_batch(input1, input2)
|
1870 |
+
future = client.infer_batch(a=input1, b=input2)
|
1871 |
+
```
|
1872 |
+
|
1873 |
+
Mixing of argument passing conventions is not supported and will raise PyTritonClientValueError.
|
1874 |
+
|
1875 |
+
Args:
|
1876 |
+
*inputs: Inference inputs provided as positional arguments.
|
1877 |
+
parameters: Optional dictionary of inference parameters.
|
1878 |
+
headers: Optional dictionary of HTTP headers for the inference request.
|
1879 |
+
**named_inputs: Inference inputs provided as named arguments.
|
1880 |
+
|
1881 |
+
Returns:
|
1882 |
+
A Future object wrapping a dictionary of inference results, where dictionary keys are output names.
|
1883 |
+
|
1884 |
+
Raises:
|
1885 |
+
PyTritonClientClosedError: If the FuturesModelClient is closed.
|
1886 |
+
"""
|
1887 |
+
return self._execute(name=_INFER_BATCH, request=(inputs, parameters, headers, named_inputs))
|
1888 |
+
|
1889 |
+
def _execute(self, name, request=None):
|
1890 |
+
if self._closed:
|
1891 |
+
raise PyTritonClientClosedError("FutureModelClient is already closed")
|
1892 |
+
self._extend_thread_pool()
|
1893 |
+
future = Future()
|
1894 |
+
if self._non_blocking:
|
1895 |
+
try:
|
1896 |
+
self._queue.put_nowait((future, request, name))
|
1897 |
+
except Full as e:
|
1898 |
+
raise PyTritonClientQueueFullError("Queue is full") from e
|
1899 |
+
else:
|
1900 |
+
kwargs = {}
|
1901 |
+
if self._inference_timeout_s is not None:
|
1902 |
+
kwargs["timeout"] = self._inference_timeout_s
|
1903 |
+
try:
|
1904 |
+
self._queue.put((future, request, name), **kwargs)
|
1905 |
+
except Full as e:
|
1906 |
+
raise PyTritonClientQueueFullError("Queue is full") from e
|
1907 |
+
return future
|
1908 |
+
|
1909 |
+
def _extend_thread_pool(self):
|
1910 |
+
if self._closed:
|
1911 |
+
return
|
1912 |
+
|
1913 |
+
with self._lock:
|
1914 |
+
if not self._queue.empty() and (self._max_workers is None or len(self._threads) < self._max_workers):
|
1915 |
+
_LOGGER.debug("Create new thread")
|
1916 |
+
thread = Thread(target=self._worker)
|
1917 |
+
self._threads.append(thread)
|
1918 |
+
thread.start()
|
1919 |
+
else:
|
1920 |
+
_LOGGER.debug("No need to create new thread")
|
1921 |
+
|
1922 |
+
def _client_request_executor(self, client, request, name):
|
1923 |
+
_LOGGER.debug(f"Running {name} for {self._model_name}")
|
1924 |
+
if name == _INFER_SAMPLE:
|
1925 |
+
inputs, parameters, headers, named_inputs = request
|
1926 |
+
result = client.infer_sample(
|
1927 |
+
*inputs,
|
1928 |
+
parameters=parameters,
|
1929 |
+
headers=headers,
|
1930 |
+
**named_inputs,
|
1931 |
+
)
|
1932 |
+
elif name == _INFER_BATCH:
|
1933 |
+
inputs, parameters, headers, named_inputs = request
|
1934 |
+
result = client.infer_batch(
|
1935 |
+
*inputs,
|
1936 |
+
parameters=parameters,
|
1937 |
+
headers=headers,
|
1938 |
+
**named_inputs,
|
1939 |
+
)
|
1940 |
+
elif name == _MODEL_CONFIG:
|
1941 |
+
result = client.model_config
|
1942 |
+
elif name == _WAIT_FOR_MODEL:
|
1943 |
+
timeout_s = request
|
1944 |
+
result = client.wait_for_model(timeout_s)
|
1945 |
+
else:
|
1946 |
+
raise PyTritonClientValueError(f"Unknown request name {name}")
|
1947 |
+
self._set_existing_client(client)
|
1948 |
+
return result
|
1949 |
+
|
1950 |
+
def _create_client(self, lazy_init):
|
1951 |
+
_LOGGER.debug(f"Creating ModelClient lazy_init={lazy_init}")
|
1952 |
+
return ModelClient(
|
1953 |
+
self._url,
|
1954 |
+
self._model_name,
|
1955 |
+
self._model_version,
|
1956 |
+
lazy_init=lazy_init,
|
1957 |
+
init_timeout_s=self._init_timeout_s,
|
1958 |
+
inference_timeout_s=self._inference_timeout_s,
|
1959 |
+
)
|
1960 |
+
|
1961 |
+
def _set_existing_client(self, client):
|
1962 |
+
if client._model_config is not None:
|
1963 |
+
with self._lock:
|
1964 |
+
if self._existing_client is None:
|
1965 |
+
_LOGGER.debug("Setting existing client")
|
1966 |
+
self._existing_client = client
|
1967 |
+
|
1968 |
+
def _remove_existing_client(self, client):
|
1969 |
+
if client is not None:
|
1970 |
+
with self._lock:
|
1971 |
+
if self._existing_client is not None:
|
1972 |
+
if self._existing_client is client:
|
1973 |
+
_LOGGER.debug("Resetting existing client")
|
1974 |
+
self._existing_client = None
|
1975 |
+
|
1976 |
+
def _worker(self):
|
1977 |
+
_LOGGER.debug("Starting worker thread")
|
1978 |
+
client = None
|
1979 |
+
# Work around for AttributeError: '_Threadlocal' object has no attribute 'hub'
|
1980 |
+
# gevent/_hub_local.py", line 77, in gevent._gevent_c_hub_local.get_hub_noargs
|
1981 |
+
with _hub_context():
|
1982 |
+
while True:
|
1983 |
+
future, request, name = self._queue.get()
|
1984 |
+
if future == _CLOSE:
|
1985 |
+
_LOGGER.debug("Closing thread")
|
1986 |
+
self._queue.task_done()
|
1987 |
+
break
|
1988 |
+
if future == _INIT:
|
1989 |
+
with self._lock:
|
1990 |
+
if self._existing_client is None:
|
1991 |
+
try:
|
1992 |
+
_LOGGER.debug("Initial client creation")
|
1993 |
+
client = self._create_client(False)
|
1994 |
+
_LOGGER.debug("Setting existing client")
|
1995 |
+
self._existing_client = client
|
1996 |
+
except Exception as e:
|
1997 |
+
_LOGGER.warning(f"Error {e} occurred during init for {self._model_name}")
|
1998 |
+
continue
|
1999 |
+
try:
|
2000 |
+
if client is None:
|
2001 |
+
with self._lock:
|
2002 |
+
if self._existing_client is not None:
|
2003 |
+
_LOGGER.debug("Creating new client from existing client")
|
2004 |
+
client = ModelClient.from_existing_client(self._existing_client)
|
2005 |
+
if client is None:
|
2006 |
+
_LOGGER.debug("Creating new client")
|
2007 |
+
client = self._create_client(name == _WAIT_FOR_MODEL)
|
2008 |
+
with client:
|
2009 |
+
self._set_existing_client(client)
|
2010 |
+
while True:
|
2011 |
+
try:
|
2012 |
+
result = self._client_request_executor(client, request, name)
|
2013 |
+
_LOGGER.debug(f"Finished {name} for {self._model_name}")
|
2014 |
+
future.set_result(result)
|
2015 |
+
self._queue.task_done()
|
2016 |
+
except Exception as e:
|
2017 |
+
_LOGGER.error(f"Error {e} occurred during {name} for {self._model_name}")
|
2018 |
+
future.set_exception(e)
|
2019 |
+
self._queue.task_done()
|
2020 |
+
break
|
2021 |
+
future, request, name = self._queue.get()
|
2022 |
+
if future == _CLOSE:
|
2023 |
+
_LOGGER.debug("Closing thread")
|
2024 |
+
self._queue.task_done()
|
2025 |
+
return
|
2026 |
+
except Exception as e:
|
2027 |
+
_LOGGER.error(f"Error {e} occurred during {name} for {self._model_name}")
|
2028 |
+
future.set_exception(e)
|
2029 |
+
self._queue.task_done()
|
2030 |
+
finally:
|
2031 |
+
self._remove_existing_client(client)
|
2032 |
+
client = None
|
2033 |
+
_LOGGER.debug("Finishing worker thread")
|
stf/stf-api-alternative/pytriton/build/lib/pytriton/client/exceptions.py
ADDED
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
"""Exceptions thrown in pytriton.client module."""
|
15 |
+
|
16 |
+
|
17 |
+
class PyTritonClientError(Exception):
|
18 |
+
"""Generic pytriton client exception."""
|
19 |
+
|
20 |
+
def __init__(self, message: str):
|
21 |
+
"""Initialize exception with message.
|
22 |
+
|
23 |
+
Args:
|
24 |
+
message: Error message
|
25 |
+
"""
|
26 |
+
self._message = message
|
27 |
+
|
28 |
+
def __str__(self) -> str:
|
29 |
+
"""String representation of error.
|
30 |
+
|
31 |
+
Returns:
|
32 |
+
Message content
|
33 |
+
"""
|
34 |
+
return self._message
|
35 |
+
|
36 |
+
@property
|
37 |
+
def message(self):
|
38 |
+
"""Get the exception message.
|
39 |
+
|
40 |
+
Returns:
|
41 |
+
The message associated with this exception, or None if no message.
|
42 |
+
|
43 |
+
"""
|
44 |
+
return self._message
|
45 |
+
|
46 |
+
|
47 |
+
class PyTritonClientValueError(PyTritonClientError):
|
48 |
+
"""Generic error raised in case of incorrect values are provided into API."""
|
49 |
+
|
50 |
+
pass
|
51 |
+
|
52 |
+
|
53 |
+
class PyTritonClientInvalidUrlError(PyTritonClientValueError):
|
54 |
+
"""Error raised when provided Triton Inference Server url is invalid."""
|
55 |
+
|
56 |
+
pass
|
57 |
+
|
58 |
+
|
59 |
+
class PyTritonClientTimeoutError(PyTritonClientError):
|
60 |
+
"""Timeout occurred during communication with the Triton Inference Server."""
|
61 |
+
|
62 |
+
pass
|
63 |
+
|
64 |
+
|
65 |
+
class PyTritonClientModelUnavailableError(PyTritonClientError):
|
66 |
+
"""Model with given name and version is unavailable on the given Triton Inference Server."""
|
67 |
+
|
68 |
+
pass
|
69 |
+
|
70 |
+
|
71 |
+
class PyTritonClientClosedError(PyTritonClientError):
|
72 |
+
"""Error raised in case of trying to use closed client."""
|
73 |
+
|
74 |
+
pass
|
75 |
+
|
76 |
+
|
77 |
+
class PyTritonClientModelDoesntSupportBatchingError(PyTritonClientError):
|
78 |
+
"""Error raised in case of trying to infer batch on model not supporting batching."""
|
79 |
+
|
80 |
+
pass
|
81 |
+
|
82 |
+
|
83 |
+
class PyTritonClientInferenceServerError(PyTritonClientError):
|
84 |
+
"""Error raised in case of error on inference callable or Triton Inference Server side."""
|
85 |
+
|
86 |
+
pass
|
87 |
+
|
88 |
+
|
89 |
+
class PyTritonClientQueueFullError(PyTritonClientError):
|
90 |
+
"""Error raised in case of trying to push request to full queue."""
|
91 |
+
|
92 |
+
pass
|
stf/stf-api-alternative/pytriton/build/lib/pytriton/client/utils.py
ADDED
@@ -0,0 +1,384 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
"""Utility module supporting model clients."""
|
15 |
+
|
16 |
+
import dataclasses
|
17 |
+
import enum
|
18 |
+
import logging
|
19 |
+
import socket
|
20 |
+
import sys
|
21 |
+
import time
|
22 |
+
import urllib
|
23 |
+
import warnings
|
24 |
+
from typing import Optional, Union
|
25 |
+
|
26 |
+
import tritonclient.grpc
|
27 |
+
import tritonclient.http
|
28 |
+
import tritonclient.http.aio
|
29 |
+
from grpc import RpcError
|
30 |
+
from tritonclient.utils import InferenceServerException
|
31 |
+
|
32 |
+
from pytriton.client.exceptions import PyTritonClientInvalidUrlError, PyTritonClientTimeoutError
|
33 |
+
from pytriton.client.warnings import NotSupportedTimeoutWarning
|
34 |
+
from pytriton.constants import DEFAULT_GRPC_PORT, DEFAULT_HTTP_PORT
|
35 |
+
from pytriton.model_config.parser import ModelConfigParser
|
36 |
+
|
37 |
+
_LOGGER = logging.getLogger(__name__)
|
38 |
+
|
39 |
+
_TritonSyncClientType = Union[tritonclient.grpc.InferenceServerClient, tritonclient.http.InferenceServerClient]
|
40 |
+
|
41 |
+
_DEFAULT_NETWORK_TIMEOUT_S = 60.0 # 1min
|
42 |
+
_DEFAULT_WAIT_FOR_SERVER_READY_TIMEOUT_S = 60.0 # 1min
|
43 |
+
_DEFAULT_WAIT_FOR_MODEL_TIMEOUT_S = 300.0 # 5min
|
44 |
+
|
45 |
+
LATEST_MODEL_VERSION = "<latest>"
|
46 |
+
|
47 |
+
|
48 |
+
# Special value for model_version argument. If model_version is None, the latest version of the model is returned.
|
49 |
+
|
50 |
+
|
51 |
+
class ModelState(enum.Enum):
|
52 |
+
"""Describe model state in Triton.
|
53 |
+
|
54 |
+
Attributes:
|
55 |
+
LOADING: Loading of model
|
56 |
+
UNLOADING: Unloading of model
|
57 |
+
UNAVAILABLE: Model is missing or could not be loaded
|
58 |
+
READY: Model is ready for inference
|
59 |
+
"""
|
60 |
+
|
61 |
+
LOADING = "LOADING"
|
62 |
+
UNLOADING = "UNLOADING"
|
63 |
+
UNAVAILABLE = "UNAVAILABLE"
|
64 |
+
READY = "READY"
|
65 |
+
|
66 |
+
|
67 |
+
def parse_http_response(models):
|
68 |
+
"""Parse model repository index response from Triton Inference Server for HTTP."""
|
69 |
+
models_states = {}
|
70 |
+
_LOGGER.debug("Parsing model repository index entries:")
|
71 |
+
for model in models:
|
72 |
+
_LOGGER.debug(f" name={model.get('name')} version={model.get('version')} state={model.get('state')}")
|
73 |
+
if not model.get("version"):
|
74 |
+
continue
|
75 |
+
|
76 |
+
model_state = ModelState(model["state"]) if model.get("state") else ModelState.LOADING
|
77 |
+
models_states[(model["name"], model["version"])] = model_state
|
78 |
+
|
79 |
+
return models_states
|
80 |
+
|
81 |
+
|
82 |
+
def parse_grpc_response(models):
|
83 |
+
"""Parse model repository index response from Triton Inference Server for GRCP."""
|
84 |
+
models_states = {}
|
85 |
+
_LOGGER.debug("Parsing model repository index entries:")
|
86 |
+
for model in models:
|
87 |
+
_LOGGER.debug(f" name={model.name} version={model.version} state={model.state}")
|
88 |
+
if not model.version:
|
89 |
+
continue
|
90 |
+
|
91 |
+
model_state = ModelState(model.state) if model.state else ModelState.LOADING
|
92 |
+
models_states[(model.name, model.version)] = model_state
|
93 |
+
|
94 |
+
return models_states
|
95 |
+
|
96 |
+
|
97 |
+
def get_model_state(
|
98 |
+
client: _TritonSyncClientType,
|
99 |
+
model_name: str,
|
100 |
+
model_version: Optional[str] = None,
|
101 |
+
) -> ModelState:
|
102 |
+
"""Obtains state of the model deployed in Triton Inference Server.
|
103 |
+
|
104 |
+
Args:
|
105 |
+
client: Triton Inference Server client to use for communication
|
106 |
+
model_name: name of the model which state we're requesting.
|
107 |
+
model_version:
|
108 |
+
version of the model which state we're requesting.
|
109 |
+
If model_version is None state of latest model is returned.
|
110 |
+
The latest versions of the model are the numerically greatest version numbers.
|
111 |
+
|
112 |
+
Returns:
|
113 |
+
Model state. _ModelState.UNAVAILABLE is returned in case if model with given name and version is not found.
|
114 |
+
|
115 |
+
"""
|
116 |
+
repository_index = client.get_model_repository_index()
|
117 |
+
if isinstance(repository_index, list):
|
118 |
+
models_states = parse_http_response(models=repository_index)
|
119 |
+
else:
|
120 |
+
models_states = parse_grpc_response(models=repository_index.models)
|
121 |
+
|
122 |
+
if model_version is None:
|
123 |
+
requested_model_states = {
|
124 |
+
version: state for (name, version), state in models_states.items() if name == model_name
|
125 |
+
}
|
126 |
+
if not requested_model_states:
|
127 |
+
return ModelState.UNAVAILABLE
|
128 |
+
else:
|
129 |
+
requested_model_states = sorted(requested_model_states.items(), key=lambda item: int(item[0]))
|
130 |
+
_latest_version, latest_version_state = requested_model_states[-1]
|
131 |
+
return latest_version_state
|
132 |
+
else:
|
133 |
+
state = models_states.get((model_name, model_version), ModelState.UNAVAILABLE)
|
134 |
+
return state
|
135 |
+
|
136 |
+
|
137 |
+
def get_model_config(
|
138 |
+
client: _TritonSyncClientType,
|
139 |
+
model_name: str,
|
140 |
+
model_version: Optional[str] = None,
|
141 |
+
timeout_s: Optional[float] = None,
|
142 |
+
):
|
143 |
+
"""Obtain configuration of model deployed on the Triton Inference Server.
|
144 |
+
|
145 |
+
Function waits for server readiness.
|
146 |
+
|
147 |
+
Typical use:
|
148 |
+
|
149 |
+
client = tritonclient.grpc.Client("localhost:8001")
|
150 |
+
model_config = get_model_config(client, "MyModel", "1", 60.0)
|
151 |
+
model_config = get_model_config(client, "MyModel")
|
152 |
+
|
153 |
+
Args:
|
154 |
+
client: Triton Inference Server client to use for communication
|
155 |
+
model_name: name of the model which configuration we're requesting.
|
156 |
+
model_version:
|
157 |
+
version of the model which configuration we're requesting.
|
158 |
+
If model_version is None configuration of the latest model is returned.
|
159 |
+
The latest versions of the model are the numerically greatest version numbers.
|
160 |
+
timeout_s: timeout to finish model configuration obtain. Default value is 300.0 s.
|
161 |
+
|
162 |
+
Returns:
|
163 |
+
Configuration of requested model.
|
164 |
+
|
165 |
+
Raises:
|
166 |
+
PyTritonClientTimeoutError: If obtain of model configuration didn't finish before given timeout.
|
167 |
+
PyTritonClientModelUnavailableError: If model with given name (and version) is unavailable.
|
168 |
+
"""
|
169 |
+
wait_for_model_ready(client, model_name=model_name, model_version=model_version, timeout_s=timeout_s)
|
170 |
+
|
171 |
+
model_version = model_version or ""
|
172 |
+
|
173 |
+
_LOGGER.debug(f"Obtaining model {model_name} config")
|
174 |
+
if isinstance(client, tritonclient.grpc.InferenceServerClient):
|
175 |
+
response = client.get_model_config(model_name, model_version, as_json=True)
|
176 |
+
model_config = response["config"]
|
177 |
+
else:
|
178 |
+
model_config = client.get_model_config(model_name, model_version)
|
179 |
+
model_config = ModelConfigParser.from_dict(model_config)
|
180 |
+
_LOGGER.debug(f"Model config: {model_config}")
|
181 |
+
return model_config
|
182 |
+
|
183 |
+
|
184 |
+
def _warn_on_too_big_network_timeout(client: _TritonSyncClientType, timeout_s: float):
|
185 |
+
if isinstance(client, tritonclient.http.InferenceServerClient):
|
186 |
+
connection_pool = client._client_stub._connection_pool
|
187 |
+
network_reldiff_s = (connection_pool.network_timeout - timeout_s) / timeout_s
|
188 |
+
connection_reldiff_s = (connection_pool.connection_timeout - timeout_s) / timeout_s
|
189 |
+
rtol = 0.001
|
190 |
+
if network_reldiff_s > rtol or connection_reldiff_s > rtol:
|
191 |
+
warnings.warn(
|
192 |
+
"Client network and/or connection timeout is smaller than requested timeout_s. This may cause unexpected behavior. "
|
193 |
+
f"network_timeout={connection_pool.network_timeout} "
|
194 |
+
f"connection_timeout={connection_pool.connection_timeout} "
|
195 |
+
f"timeout_s={timeout_s}",
|
196 |
+
NotSupportedTimeoutWarning,
|
197 |
+
stacklevel=1,
|
198 |
+
)
|
199 |
+
|
200 |
+
|
201 |
+
def wait_for_server_ready(
|
202 |
+
client: _TritonSyncClientType,
|
203 |
+
timeout_s: Optional[float] = None,
|
204 |
+
):
|
205 |
+
"""Waits for Triton Inference Server to be ready.
|
206 |
+
|
207 |
+
Typical use:
|
208 |
+
|
209 |
+
client = tritonclient.http.Client("localhost:8001")
|
210 |
+
wait_for_server_ready(client, timeout_s=600.0)
|
211 |
+
|
212 |
+
Args:
|
213 |
+
client: Triton Inference Server client to use for communication
|
214 |
+
timeout_s: timeout to server get into readiness state. Default value is 60.0 s.
|
215 |
+
|
216 |
+
Raises:
|
217 |
+
PyTritonClientTimeoutError: If obtain of model configuration didn't finish before given timeout.
|
218 |
+
"""
|
219 |
+
timeout_s = timeout_s if timeout_s is not None else _DEFAULT_WAIT_FOR_SERVER_READY_TIMEOUT_S
|
220 |
+
should_finish_before_s = time.time() + timeout_s
|
221 |
+
_warn_on_too_big_network_timeout(client, timeout_s)
|
222 |
+
|
223 |
+
def _is_server_ready():
|
224 |
+
try:
|
225 |
+
return client.is_server_ready() and client.is_server_live()
|
226 |
+
except InferenceServerException:
|
227 |
+
return False
|
228 |
+
except (RpcError, ConnectionError, socket.gaierror): # GRPC and HTTP clients raises these errors
|
229 |
+
return False
|
230 |
+
except Exception as e:
|
231 |
+
_LOGGER.exception(f"Exception while checking server readiness: {e}")
|
232 |
+
raise e
|
233 |
+
|
234 |
+
timeout_s = max(0.0, should_finish_before_s - time.time())
|
235 |
+
_LOGGER.debug(f"Waiting for server to be ready (timeout={timeout_s})")
|
236 |
+
is_server_ready = _is_server_ready()
|
237 |
+
while not is_server_ready:
|
238 |
+
time.sleep(min(1.0, timeout_s))
|
239 |
+
is_server_ready = _is_server_ready()
|
240 |
+
if not is_server_ready and time.time() >= should_finish_before_s:
|
241 |
+
raise PyTritonClientTimeoutError("Waiting for server to be ready timed out.")
|
242 |
+
|
243 |
+
|
244 |
+
def wait_for_model_ready(
|
245 |
+
client: _TritonSyncClientType,
|
246 |
+
model_name: str,
|
247 |
+
model_version: Optional[str] = None,
|
248 |
+
timeout_s: Optional[float] = None,
|
249 |
+
):
|
250 |
+
"""Wait for Triton Inference Server to be ready.
|
251 |
+
|
252 |
+
Args:
|
253 |
+
client: Triton Inference Server client to use for communication.
|
254 |
+
model_name: name of the model to wait for readiness.
|
255 |
+
model_version:
|
256 |
+
version of the model to wait for readiness.
|
257 |
+
If model_version is None waiting for latest version of the model.
|
258 |
+
The latest versions of the model are the numerically greatest version numbers.
|
259 |
+
timeout_s: timeout to server and model get into readiness state. Default value is 300.0 s.
|
260 |
+
|
261 |
+
Raises:
|
262 |
+
PyTritonClientTimeoutError: If server readiness didn't finish before given timeout.
|
263 |
+
"""
|
264 |
+
model_version = model_version or ""
|
265 |
+
model_version_msg = model_version or LATEST_MODEL_VERSION
|
266 |
+
timeout_s = timeout_s if timeout_s is not None else _DEFAULT_WAIT_FOR_MODEL_TIMEOUT_S
|
267 |
+
should_finish_before_s = time.time() + timeout_s
|
268 |
+
|
269 |
+
wait_for_server_ready(client, timeout_s=timeout_s)
|
270 |
+
timeout_s = max(0.0, should_finish_before_s - time.time())
|
271 |
+
_LOGGER.debug(f"Waiting for model {model_name}/{model_version_msg} to be ready (timeout={timeout_s})")
|
272 |
+
is_model_ready = client.is_model_ready(model_name, model_version)
|
273 |
+
while not is_model_ready:
|
274 |
+
time.sleep(min(1.0, timeout_s))
|
275 |
+
is_model_ready = client.is_model_ready(model_name, model_version)
|
276 |
+
|
277 |
+
if not is_model_ready and time.time() >= should_finish_before_s:
|
278 |
+
raise PyTritonClientTimeoutError(
|
279 |
+
f"Waiting for model {model_name}/{model_version_msg} to be ready timed out."
|
280 |
+
)
|
281 |
+
|
282 |
+
|
283 |
+
def create_client_from_url(url: str, network_timeout_s: Optional[float] = None) -> _TritonSyncClientType: # type: ignore
|
284 |
+
"""Create Triton Inference Server client.
|
285 |
+
|
286 |
+
Args:
|
287 |
+
url: url of the server to connect to.
|
288 |
+
If url doesn't contain scheme (e.g. "localhost:8001") http scheme is added.
|
289 |
+
If url doesn't contain port (e.g. "localhost") default port for given scheme is added.
|
290 |
+
network_timeout_s: timeout for client commands. Default value is 60.0 s.
|
291 |
+
|
292 |
+
Returns:
|
293 |
+
Triton Inference Server client.
|
294 |
+
|
295 |
+
Raises:
|
296 |
+
PyTritonClientInvalidUrlError: If provided Triton Inference Server url is invalid.
|
297 |
+
"""
|
298 |
+
url = TritonUrl.from_url(url)
|
299 |
+
triton_client_lib = {"grpc": tritonclient.grpc, "http": tritonclient.http}[url.scheme]
|
300 |
+
|
301 |
+
if url.scheme == "grpc":
|
302 |
+
# by default grpc client has very large number of timeout, thus we want to make it equal to http client timeout
|
303 |
+
network_timeout_s = _DEFAULT_NETWORK_TIMEOUT_S if network_timeout_s is None else network_timeout_s
|
304 |
+
warnings.warn(
|
305 |
+
f"tritonclient.grpc doesn't support timeout for other commands than infer. Ignoring network_timeout: {network_timeout_s}.",
|
306 |
+
NotSupportedTimeoutWarning,
|
307 |
+
stacklevel=1,
|
308 |
+
)
|
309 |
+
|
310 |
+
triton_client_init_kwargs = {}
|
311 |
+
if network_timeout_s is not None:
|
312 |
+
triton_client_init_kwargs.update(
|
313 |
+
**{
|
314 |
+
"grpc": {},
|
315 |
+
"http": {"connection_timeout": network_timeout_s, "network_timeout": network_timeout_s},
|
316 |
+
}[url.scheme]
|
317 |
+
)
|
318 |
+
|
319 |
+
_LOGGER.debug(f"Creating InferenceServerClient for {url.with_scheme} with {triton_client_init_kwargs}")
|
320 |
+
return triton_client_lib.InferenceServerClient(url.without_scheme, **triton_client_init_kwargs)
|
321 |
+
|
322 |
+
|
323 |
+
@dataclasses.dataclass
|
324 |
+
class TritonUrl:
|
325 |
+
"""TritonUrl class for parsing Triton Inference Server url.
|
326 |
+
|
327 |
+
Attributes:
|
328 |
+
scheme: scheme of the url (http or grpc)
|
329 |
+
hostname: hostname of the url
|
330 |
+
port: port of the url
|
331 |
+
|
332 |
+
Examples:
|
333 |
+
triton_url = TritonUrl.from_url("localhost:8000")
|
334 |
+
triton_url.with_scheme
|
335 |
+
>>> "http://localhost:8000"
|
336 |
+
triton_url.without_scheme
|
337 |
+
>>> "localhost:8000"
|
338 |
+
triton_url.scheme, triton_url.hostname, triton_url.port
|
339 |
+
>>> ("http", "localhost", 8000)
|
340 |
+
"""
|
341 |
+
|
342 |
+
scheme: str
|
343 |
+
hostname: str
|
344 |
+
port: int
|
345 |
+
|
346 |
+
@classmethod
|
347 |
+
def from_url(cls, url):
|
348 |
+
"""Parse triton url and create TritonUrl instance.
|
349 |
+
|
350 |
+
Returns:
|
351 |
+
TritonUrl object with scheme, hostname and port.
|
352 |
+
"""
|
353 |
+
if not isinstance(url, str):
|
354 |
+
raise PyTritonClientInvalidUrlError(f"Invalid url {url}. Url must be a string.")
|
355 |
+
try:
|
356 |
+
parsed_url = urllib.parse.urlparse(url)
|
357 |
+
# change in py3.9+
|
358 |
+
# https://github.com/python/cpython/commit/5a88d50ff013a64fbdb25b877c87644a9034c969
|
359 |
+
if sys.version_info < (3, 9) and not parsed_url.scheme and "://" in parsed_url.path:
|
360 |
+
raise ValueError(f"Invalid url {url}. Only grpc and http are supported.")
|
361 |
+
if (not parsed_url.scheme and "://" not in parsed_url.path) or (
|
362 |
+
sys.version_info >= (3, 9) and parsed_url.scheme and not parsed_url.netloc
|
363 |
+
):
|
364 |
+
_LOGGER.debug(f"Adding http scheme to {url}")
|
365 |
+
parsed_url = urllib.parse.urlparse(f"http://{url}")
|
366 |
+
|
367 |
+
scheme = parsed_url.scheme.lower()
|
368 |
+
if scheme not in ["grpc", "http"]:
|
369 |
+
raise ValueError(f"Invalid scheme {scheme}. Only grpc and http are supported.")
|
370 |
+
|
371 |
+
port = parsed_url.port or {"grpc": DEFAULT_GRPC_PORT, "http": DEFAULT_HTTP_PORT}[scheme]
|
372 |
+
except ValueError as e:
|
373 |
+
raise PyTritonClientInvalidUrlError(f"Invalid url {url}") from e
|
374 |
+
return cls(scheme, parsed_url.hostname, port)
|
375 |
+
|
376 |
+
@property
|
377 |
+
def with_scheme(self):
|
378 |
+
"""Get Triton Inference Server url with scheme."""
|
379 |
+
return f"{self.scheme}://{self.hostname}:{self.port}"
|
380 |
+
|
381 |
+
@property
|
382 |
+
def without_scheme(self):
|
383 |
+
"""Get Triton Inference Server url without scheme."""
|
384 |
+
return f"{self.hostname}:{self.port}"
|
stf/stf-api-alternative/pytriton/build/lib/pytriton/client/warnings.py
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
"""Warnings for pytriton module."""
|
15 |
+
|
16 |
+
|
17 |
+
class PyTritonWarning(UserWarning):
|
18 |
+
"""Base warning for pytriton module."""
|
19 |
+
|
20 |
+
pass
|
21 |
+
|
22 |
+
|
23 |
+
class NotSupportedTimeoutWarning(PyTritonWarning):
|
24 |
+
"""A warning for client, which doesn't support timeout configuration."""
|
25 |
+
|
26 |
+
pass
|
stf/stf-api-alternative/pytriton/build/lib/pytriton/constants.py
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# noqa: D104
|
15 |
+
"""Constants for pytriton."""
|
16 |
+
|
17 |
+
import os
|
18 |
+
import pathlib
|
19 |
+
|
20 |
+
DEFAULT_HTTP_PORT = 8000
|
21 |
+
DEFAULT_GRPC_PORT = 8001
|
22 |
+
DEFAULT_METRICS_PORT = 8002
|
23 |
+
TRITON_LOCAL_IP = "127.0.0.1"
|
24 |
+
TRITON_CONTEXT_FIELD_NAME = "triton_context"
|
25 |
+
TRITON_PYTHON_BACKEND_INTERPRETER_DIRNAME = "python_backend_interpreter"
|
26 |
+
DEFAULT_TRITON_STARTUP_TIMEOUT_S = 120
|
27 |
+
CREATE_TRITON_CLIENT_TIMEOUT_S = 10
|
28 |
+
|
29 |
+
__DEFAULT_PYTRITON_HOME = os.path.join(os.getenv("XDG_CACHE_HOME", "$HOME/.cache"), "pytriton")
|
30 |
+
__PYTRITON_HOME = os.path.expanduser(os.path.expandvars(os.getenv("PYTRITON_HOME", __DEFAULT_PYTRITON_HOME)))
|
31 |
+
PYTRITON_HOME = pathlib.Path(__PYTRITON_HOME).resolve().absolute()
|
stf/stf-api-alternative/pytriton/build/lib/pytriton/decorators.py
ADDED
@@ -0,0 +1,678 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
"""Inference callable decorators."""
|
15 |
+
|
16 |
+
import collections
|
17 |
+
import dataclasses
|
18 |
+
import inspect
|
19 |
+
import itertools
|
20 |
+
import operator
|
21 |
+
import typing
|
22 |
+
from bisect import bisect_left
|
23 |
+
from collections.abc import MutableMapping
|
24 |
+
from typing import Callable, Dict, List, NamedTuple, Optional, Tuple, Union
|
25 |
+
|
26 |
+
import numpy as np
|
27 |
+
import wrapt
|
28 |
+
|
29 |
+
from pytriton.constants import TRITON_CONTEXT_FIELD_NAME
|
30 |
+
from pytriton.exceptions import PyTritonBadParameterError, PyTritonRuntimeError, PyTritonValidationError
|
31 |
+
from pytriton.model_config.triton_model_config import TritonModelConfig
|
32 |
+
from pytriton.proxy.data import _serialize_byte_tensor
|
33 |
+
from pytriton.proxy.telemetry import start_span_from_span
|
34 |
+
|
35 |
+
|
36 |
+
class _WrappedWithWrapper(NamedTuple):
|
37 |
+
wrapped: Optional[Callable]
|
38 |
+
wrapper: Optional[Callable]
|
39 |
+
|
40 |
+
|
41 |
+
InputNames = typing.List[str]
|
42 |
+
InferenceRequest = typing.Dict[str, np.ndarray]
|
43 |
+
InferenceRequests = typing.Union[typing.List[InferenceRequest], typing.Tuple[InferenceRequest, ...]]
|
44 |
+
InferenceResult = typing.Dict[str, np.ndarray]
|
45 |
+
InferenceResults = typing.Union[typing.List[InferenceResult], typing.Tuple[InferenceResult, ...]]
|
46 |
+
|
47 |
+
|
48 |
+
def get_inference_request_batch_size(inference_request: InferenceRequest) -> int:
|
49 |
+
"""Get batch size from triton request.
|
50 |
+
|
51 |
+
Args:
|
52 |
+
inference_request (InferenceRequest): Triton request.
|
53 |
+
|
54 |
+
Returns:
|
55 |
+
int: Batch size.
|
56 |
+
"""
|
57 |
+
first_input_value = next(iter(inference_request.values()))
|
58 |
+
batch_size, *_dims = first_input_value.shape
|
59 |
+
return batch_size
|
60 |
+
|
61 |
+
|
62 |
+
def _get_wrapt_stack(wrapped) -> List[_WrappedWithWrapper]:
|
63 |
+
"""Returns stack of wrapped functions with wrappers applied to inference callable."""
|
64 |
+
stack = []
|
65 |
+
infer_callable = wrapped
|
66 |
+
while infer_callable is not None:
|
67 |
+
stack.append(_WrappedWithWrapper(infer_callable, getattr(infer_callable, "_self_wrapper", None)))
|
68 |
+
infer_callable = getattr(infer_callable, "__wrapped__", None)
|
69 |
+
|
70 |
+
return stack
|
71 |
+
|
72 |
+
|
73 |
+
class ModelConfigDict(MutableMapping):
|
74 |
+
"""Dictionary for storing model configs for inference callable."""
|
75 |
+
|
76 |
+
def __init__(self):
|
77 |
+
"""Create ModelConfigDict object."""
|
78 |
+
self._data: Dict[str, TritonModelConfig] = {}
|
79 |
+
self._keys: List[Callable] = []
|
80 |
+
|
81 |
+
def __getitem__(self, infer_callable: Callable) -> TritonModelConfig:
|
82 |
+
"""Get model config for inference callable."""
|
83 |
+
key = self._get_model_config_key(infer_callable)
|
84 |
+
return self._data[key]
|
85 |
+
|
86 |
+
def __setitem__(self, infer_callable: Callable, item: TritonModelConfig):
|
87 |
+
"""Set model config for inference callable."""
|
88 |
+
self._keys.append(infer_callable)
|
89 |
+
key = self._get_model_config_key(infer_callable)
|
90 |
+
self._data[key] = item
|
91 |
+
|
92 |
+
def __delitem__(self, infer_callable: Callable):
|
93 |
+
"""Delete model config for inference callable."""
|
94 |
+
key = self._get_model_config_key(infer_callable)
|
95 |
+
del self._data[key]
|
96 |
+
|
97 |
+
def __len__(self):
|
98 |
+
"""Get number of inference callable keys."""
|
99 |
+
return len(self._data)
|
100 |
+
|
101 |
+
def __iter__(self):
|
102 |
+
"""Iterate over inference callable keys."""
|
103 |
+
return iter(self._keys)
|
104 |
+
|
105 |
+
@staticmethod
|
106 |
+
def _get_model_config_key(infer_callable: Callable) -> str:
|
107 |
+
"""Prepares TritonModelConfig dictionary key for function/callable."""
|
108 |
+
dict_key = infer_callable
|
109 |
+
if inspect.ismethod(dict_key) and dict_key.__name__ == "__call__":
|
110 |
+
dict_key = dict_key.__self__
|
111 |
+
return str(dict_key)
|
112 |
+
|
113 |
+
|
114 |
+
@dataclasses.dataclass
|
115 |
+
class TritonContext:
|
116 |
+
"""Triton context definition class."""
|
117 |
+
|
118 |
+
model_configs: ModelConfigDict = dataclasses.field(default_factory=ModelConfigDict)
|
119 |
+
|
120 |
+
|
121 |
+
def get_triton_context(wrapped, instance) -> TritonContext:
|
122 |
+
"""Retrieves triton context from callable.
|
123 |
+
|
124 |
+
It is used in @triton_context to get triton context registered by triton binding in inference callable.
|
125 |
+
If you use @triton_context decorator you do not need this function.
|
126 |
+
"""
|
127 |
+
caller = instance or wrapped
|
128 |
+
if not hasattr(caller, "__triton_context__"):
|
129 |
+
raise PyTritonValidationError("Wrapped function or object must bound with triton to get __triton_context__")
|
130 |
+
return caller.__triton_context__
|
131 |
+
|
132 |
+
|
133 |
+
def get_model_config(wrapped, instance) -> TritonModelConfig:
|
134 |
+
"""Retrieves instance of TritonModelConfig from callable.
|
135 |
+
|
136 |
+
It is internally used in convert_output function to get output list from model.
|
137 |
+
You can use this in custom decorators if you need access to model_config information.
|
138 |
+
If you use @triton_context decorator you do not need this function (you can get model_config directly
|
139 |
+
from triton_context passing function/callable to dictionary getter).
|
140 |
+
"""
|
141 |
+
return get_triton_context(wrapped, instance).model_configs[wrapped]
|
142 |
+
|
143 |
+
|
144 |
+
def convert_output(
|
145 |
+
outputs: Union[Dict, List, Tuple], wrapped=None, instance=None, model_config: Optional[TritonModelConfig] = None
|
146 |
+
):
|
147 |
+
"""Converts output from tuple ot list to dictionary.
|
148 |
+
|
149 |
+
It is utility function useful for mapping output list into dictionary of outputs.
|
150 |
+
Currently, it is used in @sample and @batch decorators (we assume that user can return list or tuple of outputs
|
151 |
+
instead of dictionary if this list matches output list in model config (size and order).
|
152 |
+
"""
|
153 |
+
if isinstance(outputs, dict):
|
154 |
+
return outputs
|
155 |
+
elif isinstance(outputs, (list, tuple)):
|
156 |
+
if model_config is None:
|
157 |
+
model_config = get_model_config(wrapped, instance)
|
158 |
+
if len(outputs) != len(model_config.outputs):
|
159 |
+
raise PyTritonValidationError("Outputs length different than config outputs length")
|
160 |
+
outputs = {config_output.name: output for config_output, output in zip(model_config.outputs, outputs)}
|
161 |
+
return outputs
|
162 |
+
else:
|
163 |
+
raise PyTritonValidationError(f"Unsupported output type {type(outputs)}.")
|
164 |
+
|
165 |
+
|
166 |
+
@wrapt.decorator
|
167 |
+
def sample(wrapped, instance, args, kwargs):
|
168 |
+
"""Decorator is used for non-batched inputs to convert from one element list of requests to request kwargs.
|
169 |
+
|
170 |
+
Decorator takes first request and convert it into named inputs.
|
171 |
+
Useful with non-batching models - instead of one element list of request, we will get named inputs - `kwargs`.
|
172 |
+
"""
|
173 |
+
kwargs.update(args[0][0])
|
174 |
+
outputs = wrapped(*args[1:], **kwargs)
|
175 |
+
outputs = convert_output(outputs, wrapped, instance)
|
176 |
+
return [outputs]
|
177 |
+
|
178 |
+
|
179 |
+
@wrapt.decorator
|
180 |
+
def batch(wrapped, instance, args, kwargs):
|
181 |
+
"""Decorator for converting list of request dicts to dict of input batches.
|
182 |
+
|
183 |
+
Converts list of request dicts to dict of input batches.
|
184 |
+
It passes **kwargs to inference callable where each named input contains numpy array with batch of requests
|
185 |
+
received by Triton server.
|
186 |
+
We assume that each request has the same set of keys (you can use group_by_keys decorator before
|
187 |
+
using @batch decorator if your requests may have different set of keys).
|
188 |
+
|
189 |
+
Raises:
|
190 |
+
PyTritonValidationError: If the requests have different set of keys.
|
191 |
+
ValueError: If the output tensors have different than expected batch sizes. Expected batch size is
|
192 |
+
calculated as a sum of batch sizes of all requests.
|
193 |
+
"""
|
194 |
+
telemetry_name = "pytriton-batch-decorator-span"
|
195 |
+
|
196 |
+
req_list = args[0]
|
197 |
+
input_names = req_list[0].keys()
|
198 |
+
|
199 |
+
for req_dict2 in req_list[1:]:
|
200 |
+
if input_names != req_dict2.keys():
|
201 |
+
raise PyTritonValidationError("Cannot batch requests with different set of inputs keys")
|
202 |
+
|
203 |
+
inputs = {}
|
204 |
+
for model_input in input_names:
|
205 |
+
concatenated_input_data = np.concatenate([req[model_input] for req in req_list])
|
206 |
+
inputs[model_input] = concatenated_input_data
|
207 |
+
|
208 |
+
args = args[1:]
|
209 |
+
new_kwargs = dict(kwargs)
|
210 |
+
new_kwargs.update(inputs)
|
211 |
+
spans = [start_span_from_span(request.span, telemetry_name) for request in req_list if request.span is not None]
|
212 |
+
try:
|
213 |
+
outputs = wrapped(*args, **new_kwargs)
|
214 |
+
finally:
|
215 |
+
for span in spans:
|
216 |
+
span.end()
|
217 |
+
|
218 |
+
def _split_result(_result):
|
219 |
+
outputs = convert_output(_result, wrapped, instance)
|
220 |
+
output_names = outputs.keys()
|
221 |
+
|
222 |
+
requests_total_batch_size = sum(get_inference_request_batch_size(req) for req in req_list)
|
223 |
+
not_matching_tensors_shapes = {
|
224 |
+
output_name: output_tensor.shape
|
225 |
+
for output_name, output_tensor in outputs.items()
|
226 |
+
if output_tensor.shape[0] != requests_total_batch_size
|
227 |
+
}
|
228 |
+
if not_matching_tensors_shapes:
|
229 |
+
raise ValueError(
|
230 |
+
f"Received output tensors with different batch sizes: {', '.join(': '.join(map(str, item)) for item in not_matching_tensors_shapes.items())}. "
|
231 |
+
f"Expected batch size: {requests_total_batch_size}. "
|
232 |
+
)
|
233 |
+
|
234 |
+
out_list = []
|
235 |
+
start_idx = 0
|
236 |
+
for request in req_list:
|
237 |
+
# get batch_size of first input for each request - assume that all inputs have same batch_size
|
238 |
+
request_batch_size = get_inference_request_batch_size(request)
|
239 |
+
req_output_dict = {}
|
240 |
+
for _output_ind, output_name in enumerate(output_names):
|
241 |
+
req_output = outputs[output_name][start_idx : start_idx + request_batch_size, ...]
|
242 |
+
req_output_dict[output_name] = req_output
|
243 |
+
out_list.append(req_output_dict)
|
244 |
+
start_idx += request_batch_size
|
245 |
+
return out_list
|
246 |
+
|
247 |
+
if inspect.isgenerator(outputs):
|
248 |
+
return (_split_result(_result) for _result in outputs)
|
249 |
+
else:
|
250 |
+
return _split_result(outputs)
|
251 |
+
|
252 |
+
|
253 |
+
def group_by_values(*keys, pad_fn: typing.Optional[typing.Callable[[InferenceRequests], InferenceRequests]] = None):
|
254 |
+
"""Decorator for grouping requests by values of selected keys.
|
255 |
+
|
256 |
+
This function splits a batch into multiple sub-batches based on the specified keys values and
|
257 |
+
calls the decorated function with each sub-batch. This is particularly useful when working with models
|
258 |
+
that require dynamic parameters sent by the user.
|
259 |
+
|
260 |
+
For example, given an input of the form:
|
261 |
+
|
262 |
+
```python
|
263 |
+
{"sentences": [b"Sentence1", b"Sentence2", b"Sentence3"], "param1": [1, 1, 2], "param2": [1, 1, 1]}
|
264 |
+
```
|
265 |
+
|
266 |
+
Using @group_by_values("param1", "param2") will split the batch into two sub-batches:
|
267 |
+
|
268 |
+
```python
|
269 |
+
[
|
270 |
+
{"sentences": [b"Sentence1", b"Sentence2"], "param1": [1, 1], "param2": [1, 1]},
|
271 |
+
{"sentences": [b"Sentence3"], "param1": [2], "param2": [1]}
|
272 |
+
]
|
273 |
+
```
|
274 |
+
|
275 |
+
This decorator should be used after the @batch decorator.
|
276 |
+
|
277 |
+
Example usage:
|
278 |
+
```python
|
279 |
+
@batch
|
280 |
+
@group_by_values("param1", "param2")
|
281 |
+
def infer_fun(**inputs):
|
282 |
+
...
|
283 |
+
return outputs
|
284 |
+
```
|
285 |
+
|
286 |
+
Args:
|
287 |
+
*keys: List of keys to group by.
|
288 |
+
pad_fn: Optional function to pad the batch to the same size before merging again to a single batch.
|
289 |
+
|
290 |
+
Returns:
|
291 |
+
The decorator function.
|
292 |
+
"""
|
293 |
+
|
294 |
+
def value_to_key(value):
|
295 |
+
if isinstance(value, np.ndarray):
|
296 |
+
if value.dtype == np.object_ or value.dtype.type == np.bytes_:
|
297 |
+
return _serialize_byte_tensor(value)
|
298 |
+
else:
|
299 |
+
return value.tobytes()
|
300 |
+
return value
|
301 |
+
|
302 |
+
def _get_sort_key_for_sample(_request, _sample_idx: int):
|
303 |
+
return tuple(value_to_key(_request[_key][_sample_idx]) for _key in keys)
|
304 |
+
|
305 |
+
def _group_request(_request: InferenceRequest, _batch_size: int):
|
306 |
+
idx_inputs = [(sample_idx, _get_sort_key_for_sample(_request, sample_idx)) for sample_idx in range(_batch_size)]
|
307 |
+
idx_inputs.sort(key=operator.itemgetter(1))
|
308 |
+
for _, group in itertools.groupby(idx_inputs, key=operator.itemgetter(1)):
|
309 |
+
_samples_idxes, _ = zip(*group)
|
310 |
+
grouped_request = {input_name: value[_samples_idxes, ...] for input_name, value in _request.items()}
|
311 |
+
yield _samples_idxes, grouped_request
|
312 |
+
|
313 |
+
@wrapt.decorator
|
314 |
+
def _wrapper(wrapped, instance, args, kwargs):
|
315 |
+
wrappers_stack = [
|
316 |
+
callable_with_wrapper.wrapper
|
317 |
+
for callable_with_wrapper in _get_wrapt_stack(wrapped)
|
318 |
+
if callable_with_wrapper.wrapper is not None
|
319 |
+
]
|
320 |
+
if batch in wrappers_stack:
|
321 |
+
raise PyTritonRuntimeError("The @group_by_values decorator must be used after the @batch decorator.")
|
322 |
+
|
323 |
+
request = {k: v for k, v in kwargs.items() if k not in _SPECIAL_KEYS}
|
324 |
+
other_kwargs = {k: v for k, v in kwargs.items() if k in _SPECIAL_KEYS}
|
325 |
+
|
326 |
+
batch_size = get_inference_request_batch_size(request)
|
327 |
+
sample_indices_with_interim_result = []
|
328 |
+
for sample_indices, _grouped_sub_request in _group_request(request, batch_size):
|
329 |
+
interim_result = wrapped(*args, **_grouped_sub_request, **other_kwargs)
|
330 |
+
sample_indices_with_interim_result.append((sample_indices, interim_result))
|
331 |
+
|
332 |
+
if pad_fn is not None:
|
333 |
+
indices, results = tuple(map(tuple, zip(*sample_indices_with_interim_result)))
|
334 |
+
results = pad_fn(results)
|
335 |
+
sample_indices_with_interim_result = tuple(zip(indices, results))
|
336 |
+
|
337 |
+
_, first_result_data = sample_indices_with_interim_result[0]
|
338 |
+
result = {
|
339 |
+
output_name: np.zeros((batch_size,) + data.shape[1:], dtype=data.dtype)
|
340 |
+
for output_name, data in first_result_data.items()
|
341 |
+
}
|
342 |
+
for indices, results in sample_indices_with_interim_result:
|
343 |
+
for output_name, data in results.items():
|
344 |
+
result[output_name][indices, ...] = data
|
345 |
+
|
346 |
+
return result
|
347 |
+
|
348 |
+
return _wrapper
|
349 |
+
|
350 |
+
|
351 |
+
class ConstantPadder:
|
352 |
+
"""Padder that pads the given batches with a constant value."""
|
353 |
+
|
354 |
+
def __init__(self, pad_value=0):
|
355 |
+
"""Initialize the padder.
|
356 |
+
|
357 |
+
Args:
|
358 |
+
pad_value (int, optional): Padding value. Defaults to 0.
|
359 |
+
"""
|
360 |
+
self.pad_value = pad_value
|
361 |
+
|
362 |
+
def __call__(self, batches_list: InferenceResults) -> InferenceResults:
|
363 |
+
"""Pad the given batches with the specified value to pad size enabling further batching to single arrays.
|
364 |
+
|
365 |
+
Args:
|
366 |
+
batches_list (List[Dict[str, np.ndarray]]): List of batches to pad.
|
367 |
+
|
368 |
+
Returns:
|
369 |
+
List[Dict[str, np.ndarray]]: List of padded batches.
|
370 |
+
|
371 |
+
Raises:
|
372 |
+
PyTritonRuntimeError: If the input arrays for a given input name have different dtypes.
|
373 |
+
"""
|
374 |
+
|
375 |
+
def _get_padded_shape(_batches: List[np.ndarray]) -> Tuple[int, ...]:
|
376 |
+
"""Get the shape of the padded array without batch axis."""
|
377 |
+
return tuple(np.max([batch.shape[1:] for batch in _batches if batch is not None], axis=0))
|
378 |
+
|
379 |
+
def _get_padded_dtype(_batches: List[np.ndarray]) -> np.dtype:
|
380 |
+
dtypes = [batch.dtype for batch in _batches if batch is not None]
|
381 |
+
result_dtype = dtypes[0]
|
382 |
+
|
383 |
+
if not all(dtype.kind == result_dtype.kind for dtype in dtypes):
|
384 |
+
raise PyTritonRuntimeError("All input arrays for given input name must have the same dtype.")
|
385 |
+
|
386 |
+
# for bytes (encoded string) or unicode string need to obtain the max length
|
387 |
+
if result_dtype.kind in "SU":
|
388 |
+
order_and_kind = result_dtype.str[:2]
|
389 |
+
max_len = max([int(dtype.str[2:]) for dtype in dtypes])
|
390 |
+
result_dtype = f"{order_and_kind}{max_len}"
|
391 |
+
else:
|
392 |
+
if not all(dtype == result_dtype for dtype in dtypes):
|
393 |
+
raise PyTritonRuntimeError("All input arrays for given input name must have the same dtype.")
|
394 |
+
|
395 |
+
return np.dtype(result_dtype)
|
396 |
+
|
397 |
+
input_names = list(
|
398 |
+
collections.OrderedDict.fromkeys(input_name for batch in batches_list for input_name in batch.keys())
|
399 |
+
)
|
400 |
+
batches_by_name = {input_name: [batch.get(input_name) for batch in batches_list] for input_name in input_names}
|
401 |
+
for input_batches in batches_by_name.values():
|
402 |
+
result_shape, result_dtype = _get_padded_shape(input_batches), _get_padded_dtype(input_batches)
|
403 |
+
for batch_idx, batch in enumerate(input_batches):
|
404 |
+
if batch is not None:
|
405 |
+
input_batches[batch_idx] = np.pad(
|
406 |
+
batch,
|
407 |
+
[(0, 0)] + [(0, b - a) for a, b in zip(batch.shape[1:], result_shape)],
|
408 |
+
mode="constant",
|
409 |
+
constant_values=self.pad_value if result_dtype.kind not in ["S", "U", "O"] else b"",
|
410 |
+
).astype(result_dtype)
|
411 |
+
|
412 |
+
return [
|
413 |
+
{name: batches[batch_idx] for name, batches in batches_by_name.items() if batches[batch_idx] is not None}
|
414 |
+
for batch_idx in range(len(batches_list))
|
415 |
+
]
|
416 |
+
|
417 |
+
|
418 |
+
@wrapt.decorator
|
419 |
+
def group_by_keys(wrapped, instance, args, kwargs):
|
420 |
+
"""Group by keys.
|
421 |
+
|
422 |
+
Decorator prepares groups of requests with the same set of keys and calls wrapped function
|
423 |
+
for each group separately (it is convenient to use this decorator before batching, because the batching decorator
|
424 |
+
requires consistent set of inputs as it stacks them into batches).
|
425 |
+
"""
|
426 |
+
inputs = args[0]
|
427 |
+
idx_inputs = [(idx, tuple(sorted(input.keys())), input) for idx, input in enumerate(inputs)]
|
428 |
+
idx_inputs.sort(key=operator.itemgetter(1))
|
429 |
+
idx_groups_res = []
|
430 |
+
for _, group in itertools.groupby(idx_inputs, key=operator.itemgetter(1)):
|
431 |
+
idx, _key, sample_list = zip(*group)
|
432 |
+
args = (list(sample_list),) + args[1:]
|
433 |
+
out = wrapped(*args, **kwargs)
|
434 |
+
idx_groups_res.extend(zip(idx, out))
|
435 |
+
|
436 |
+
idx_groups_res.sort(key=operator.itemgetter(0))
|
437 |
+
res_flat = [r[1] for r in idx_groups_res]
|
438 |
+
return res_flat
|
439 |
+
|
440 |
+
|
441 |
+
def fill_optionals(**defaults):
|
442 |
+
"""This decorator ensures that any missing inputs in requests are filled with default values specified by the user.
|
443 |
+
|
444 |
+
Default values should be NumPy arrays without batch axis.
|
445 |
+
|
446 |
+
If you plan to group requests ex. with
|
447 |
+
[@group_by_keys][pytriton.decorators.group_by_keys] or
|
448 |
+
[@group_by_vales][pytriton.decorators.group_by_values] decorators
|
449 |
+
provide default values for optional parameters at the beginning of decorators stack.
|
450 |
+
The other decorators can then group requests into bigger batches resulting in a better model performance.
|
451 |
+
|
452 |
+
Typical use:
|
453 |
+
```python
|
454 |
+
@fill_optionals()
|
455 |
+
@group_by_keys()
|
456 |
+
@batch
|
457 |
+
def infer_fun(**inputs):
|
458 |
+
...
|
459 |
+
return outputs
|
460 |
+
```
|
461 |
+
|
462 |
+
Args:
|
463 |
+
defaults: keyword arguments containing default values for missing inputs
|
464 |
+
|
465 |
+
|
466 |
+
If you have default values for some optional parameter it is good idea to provide them at the very beginning,
|
467 |
+
so the other decorators (e.g. @group_by_keys) can make bigger consistent groups.
|
468 |
+
"""
|
469 |
+
|
470 |
+
def _verify_defaults(model_config: TritonModelConfig):
|
471 |
+
inputs = {spec.name: spec for spec in model_config.inputs}
|
472 |
+
not_matching_default_names = sorted(set(defaults) - set(inputs))
|
473 |
+
if not_matching_default_names:
|
474 |
+
raise PyTritonBadParameterError(f"Could not found {', '.join(not_matching_default_names)} inputs")
|
475 |
+
|
476 |
+
non_numpy_items = {k: v for k, v in defaults.items() if not isinstance(v, np.ndarray)}
|
477 |
+
if non_numpy_items:
|
478 |
+
raise PyTritonBadParameterError(
|
479 |
+
f"Could not use {', '.join([f'{k}={v}' for k, v in non_numpy_items.items()])} defaults "
|
480 |
+
"as they are not NumPy arrays"
|
481 |
+
)
|
482 |
+
|
483 |
+
not_matching_dtypes = {k: (v.dtype, inputs[k].dtype) for k, v in defaults.items() if v.dtype != inputs[k].dtype}
|
484 |
+
if not_matching_dtypes:
|
485 |
+
non_matching_dtypes_str_list = [
|
486 |
+
f"{name}: dtype={have_dtype} expected_dtype={expected_dtype}"
|
487 |
+
for name, (have_dtype, expected_dtype) in not_matching_dtypes.items()
|
488 |
+
]
|
489 |
+
raise PyTritonBadParameterError(
|
490 |
+
f"Could not use {', '.join(non_matching_dtypes_str_list)} "
|
491 |
+
f"defaults as they have different than input signature dtypes"
|
492 |
+
)
|
493 |
+
|
494 |
+
def _shape_match(_have_shape, _expected_shape):
|
495 |
+
return len(_have_shape) == len(_expected_shape) and all(
|
496 |
+
e == -1 or h == e for h, e in zip(_have_shape, _expected_shape)
|
497 |
+
)
|
498 |
+
|
499 |
+
not_matching_shapes = {
|
500 |
+
k: (v.shape, inputs[k].shape) for k, v in defaults.items() if not _shape_match(v.shape, inputs[k].shape)
|
501 |
+
}
|
502 |
+
if not_matching_shapes:
|
503 |
+
non_matching_shapes_str_list = [
|
504 |
+
f"{name}: shape={have_shape} expected_shape={expected_shape}"
|
505 |
+
for name, (have_shape, expected_shape) in not_matching_shapes.items()
|
506 |
+
]
|
507 |
+
raise PyTritonBadParameterError(
|
508 |
+
f"Could not use {', '.join(non_matching_shapes_str_list)} "
|
509 |
+
f"defaults as they have different than input signature shapes"
|
510 |
+
)
|
511 |
+
|
512 |
+
@wrapt.decorator
|
513 |
+
def _wrapper(wrapped, instance, args, kwargs):
|
514 |
+
model_config = get_model_config(wrapped, instance)
|
515 |
+
_verify_defaults(model_config)
|
516 |
+
# verification if not after group wrappers is in group wrappers
|
517 |
+
|
518 |
+
(requests,) = args
|
519 |
+
|
520 |
+
model_supports_batching = model_config.batching
|
521 |
+
for request in requests:
|
522 |
+
batch_size = get_inference_request_batch_size(request) if model_supports_batching else None
|
523 |
+
for default_key, default_value in defaults.items():
|
524 |
+
if default_key in request:
|
525 |
+
continue
|
526 |
+
|
527 |
+
if model_supports_batching:
|
528 |
+
ones_reps = (1,) * default_value.ndim # repeat once default_value on each axis
|
529 |
+
axis_reps = (batch_size,) + ones_reps # ... except on batch axis. we repeat it batch_size times
|
530 |
+
default_value = np.tile(default_value, axis_reps)
|
531 |
+
|
532 |
+
request[default_key] = default_value
|
533 |
+
return wrapped(*args, **kwargs)
|
534 |
+
|
535 |
+
return _wrapper
|
536 |
+
|
537 |
+
|
538 |
+
@wrapt.decorator
|
539 |
+
def triton_context(wrapped, instance, args, kwargs):
|
540 |
+
"""Adds triton context.
|
541 |
+
|
542 |
+
It gives you additional argument passed to the function in **kwargs called 'triton_context'.
|
543 |
+
You can read model config from it and in the future possibly have some interaction with triton.
|
544 |
+
"""
|
545 |
+
kwargs[TRITON_CONTEXT_FIELD_NAME] = get_triton_context(wrapped, instance)
|
546 |
+
return wrapped(*args, **kwargs)
|
547 |
+
|
548 |
+
|
549 |
+
@wrapt.decorator
|
550 |
+
def pad_batch(wrapped, instance, args, kwargs):
|
551 |
+
"""Add padding to the inputs batches.
|
552 |
+
|
553 |
+
Decorator appends last rows to the inputs multiple times to get desired batch size (preferred batch size or
|
554 |
+
max batch size from model config whatever is closer to current input size).
|
555 |
+
"""
|
556 |
+
inputs = {k: v for k, v in kwargs.items() if k != "__triton_context__"}
|
557 |
+
first_input = next(iter(inputs.values()))
|
558 |
+
config = get_model_config(wrapped, instance)
|
559 |
+
batch_sizes = (
|
560 |
+
[]
|
561 |
+
if (config.batcher is None or config.batcher.preferred_batch_size is None)
|
562 |
+
else sorted(config.batcher.preferred_batch_size)
|
563 |
+
)
|
564 |
+
batch_sizes.append(config.max_batch_size)
|
565 |
+
batch_size = batch_sizes[bisect_left(batch_sizes, first_input.shape[0])]
|
566 |
+
|
567 |
+
new_inputs = {
|
568 |
+
input_name: np.repeat(
|
569 |
+
input_array,
|
570 |
+
np.concatenate([
|
571 |
+
np.ones(input_array.shape[0] - 1),
|
572 |
+
np.array([batch_size - input_array.shape[0] + 1]),
|
573 |
+
]).astype(np.int64),
|
574 |
+
axis=0,
|
575 |
+
)
|
576 |
+
for input_name, input_array in inputs.items()
|
577 |
+
}
|
578 |
+
|
579 |
+
kwargs.update(new_inputs)
|
580 |
+
return wrapped(*args, **kwargs)
|
581 |
+
|
582 |
+
|
583 |
+
_SPECIAL_KEYS = ["__triton_context__"]
|
584 |
+
|
585 |
+
|
586 |
+
def first_value(*keys: str, squeeze_single_values=True, strict: bool = True):
|
587 |
+
"""This decorator overwrites selected inputs with first element of the given input.
|
588 |
+
|
589 |
+
It can be used in two ways:
|
590 |
+
|
591 |
+
1. Wrapping a single request inference callable by chaining with @batch decorator:
|
592 |
+
```python
|
593 |
+
@batch
|
594 |
+
@first_value("temperature")
|
595 |
+
def infer_fn(**inputs):
|
596 |
+
...
|
597 |
+
return result
|
598 |
+
```
|
599 |
+
|
600 |
+
2. Wrapping a multiple requests inference callable:
|
601 |
+
```python
|
602 |
+
@first_value("temperature")
|
603 |
+
def infer_fn(requests):
|
604 |
+
...
|
605 |
+
return results
|
606 |
+
```
|
607 |
+
|
608 |
+
By default, the decorator squeezes single value arrays to scalars.
|
609 |
+
This behavior can be disabled by setting the `squeeze_single_values` flag to False.
|
610 |
+
|
611 |
+
By default, the decorator checks the equality of the values on selected values.
|
612 |
+
This behavior can be disabled by setting the `strict` flag to False.
|
613 |
+
|
614 |
+
Wrapper can only be used with models that support batching.
|
615 |
+
|
616 |
+
Args:
|
617 |
+
keys: The input keys selected for conversion.
|
618 |
+
squeeze_single_values: squeeze single value ND array to scalar values. Defaults to True.
|
619 |
+
strict: enable checking if all values on single selected input of request are equal. Defaults to True.
|
620 |
+
|
621 |
+
Raises:
|
622 |
+
PyTritonRuntimeError: if not all values on a single selected input of the request are equal
|
623 |
+
and the strict flag is set to True. Additionally, if the decorator is used with a model that doesn't support batching,
|
624 |
+
PyTritonBadParameterError: if any of the keys passed to the decorator are not allowed.
|
625 |
+
"""
|
626 |
+
if any(k in _SPECIAL_KEYS for k in keys):
|
627 |
+
not_allowed_keys = [key for key in keys if key in _SPECIAL_KEYS]
|
628 |
+
raise PyTritonBadParameterError(
|
629 |
+
f"The keys {', '.join(not_allowed_keys)} are not allowed as keys for @first_value wrapper. "
|
630 |
+
f"The set of not allowed keys are {', '.join(_SPECIAL_KEYS)}"
|
631 |
+
)
|
632 |
+
|
633 |
+
@wrapt.decorator
|
634 |
+
def wrapper(wrapped, instance, args, kwargs):
|
635 |
+
model_config = get_model_config(wrapped, instance)
|
636 |
+
if not model_config.batching:
|
637 |
+
raise PyTritonRuntimeError("The @first_value decorator can only be used with models that support batching.")
|
638 |
+
|
639 |
+
def _replace_inputs_with_first_value(_request):
|
640 |
+
for input_name in keys:
|
641 |
+
if input_name not in _request:
|
642 |
+
continue
|
643 |
+
|
644 |
+
values = _request[input_name]
|
645 |
+
if strict:
|
646 |
+
# do not set axis for arrays with strings (object) or models not supporting batching
|
647 |
+
axis_of_uniqueness = None if values.dtype == object else 0
|
648 |
+
unique_values = np.unique(values, axis=axis_of_uniqueness)
|
649 |
+
if len(unique_values) > 1:
|
650 |
+
raise PyTritonRuntimeError(
|
651 |
+
f"The values on the {input_name!r} input are not equal. "
|
652 |
+
"To proceed, either disable strict mode in @first_value wrapper "
|
653 |
+
"or ensure that the values always are consistent. "
|
654 |
+
f"The current values of {input_name!r} are {_request[input_name]!r}."
|
655 |
+
)
|
656 |
+
|
657 |
+
_first_value = values[0]
|
658 |
+
if (
|
659 |
+
squeeze_single_values
|
660 |
+
and not np.isscalar(_first_value)
|
661 |
+
and all(dim == 1 for dim in _first_value.shape)
|
662 |
+
):
|
663 |
+
_dim_0_array = np.squeeze(_first_value)
|
664 |
+
_first_value = _dim_0_array[()] # obtain scalar from 0-dim array with numpy type
|
665 |
+
|
666 |
+
_request[input_name] = _first_value
|
667 |
+
return _request
|
668 |
+
|
669 |
+
inputs_names = set(kwargs) - set(_SPECIAL_KEYS)
|
670 |
+
if inputs_names:
|
671 |
+
kwargs = _replace_inputs_with_first_value(kwargs)
|
672 |
+
return wrapped(*args, **kwargs)
|
673 |
+
else:
|
674 |
+
requests, *other_args = args
|
675 |
+
requests = [_replace_inputs_with_first_value(request) for request in requests]
|
676 |
+
return wrapped(requests, *other_args, **kwargs)
|
677 |
+
|
678 |
+
return wrapper
|
stf/stf-api-alternative/pytriton/build/lib/pytriton/exceptions.py
ADDED
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
"""PyTriton exceptions definition."""
|
15 |
+
|
16 |
+
|
17 |
+
class PyTritonError(Exception):
|
18 |
+
"""Generic PyTriton exception."""
|
19 |
+
|
20 |
+
def __init__(self, message: str):
|
21 |
+
"""Initialize exception with message.
|
22 |
+
|
23 |
+
Args:
|
24 |
+
message: Error message
|
25 |
+
"""
|
26 |
+
self._message = message
|
27 |
+
|
28 |
+
def __str__(self) -> str:
|
29 |
+
"""Return exception as a string.
|
30 |
+
|
31 |
+
Returns:
|
32 |
+
Message content
|
33 |
+
"""
|
34 |
+
return self._message
|
35 |
+
|
36 |
+
@property
|
37 |
+
def message(self):
|
38 |
+
"""Get the exception message.
|
39 |
+
|
40 |
+
Returns:
|
41 |
+
The message associated with this exception, or None if no message.
|
42 |
+
|
43 |
+
"""
|
44 |
+
return self._message
|
45 |
+
|
46 |
+
|
47 |
+
class PyTritonValidationError(PyTritonError):
|
48 |
+
"""PyTriton configuration validation exception."""
|
49 |
+
|
50 |
+
pass
|
51 |
+
|
52 |
+
|
53 |
+
class PyTritonInvalidOperationError(PyTritonError):
|
54 |
+
"""PyTriton invalid operation exception."""
|
55 |
+
|
56 |
+
pass
|
57 |
+
|
58 |
+
|
59 |
+
class PyTritonBadParameterError(PyTritonError):
|
60 |
+
"""PyTriton invalid parameter exception."""
|
61 |
+
|
62 |
+
pass
|
63 |
+
|
64 |
+
|
65 |
+
class PyTritonModelConfigError(PyTritonError):
|
66 |
+
"""PyTriton invalid model config exception."""
|
67 |
+
|
68 |
+
pass
|
69 |
+
|
70 |
+
|
71 |
+
class PyTritonUnrecoverableError(PyTritonError):
|
72 |
+
"""Unrecoverable error occurred in inference callable, thus no further inferences possible."""
|
73 |
+
|
74 |
+
pass
|
75 |
+
|
76 |
+
|
77 |
+
class PyTritonRuntimeError(PyTritonError):
|
78 |
+
"""Raised when an error is detected that doesn’t fall in any of the other categories."""
|
79 |
+
|
80 |
+
pass
|
stf/stf-api-alternative/pytriton/build/lib/pytriton/model_config/__init__.py
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# noqa: D104
|
15 |
+
from .common import DeviceKind, DynamicBatcher, QueuePolicy, TimeoutAction # noqa: F401
|
16 |
+
from .model_config import ModelConfig # noqa: F401
|
17 |
+
from .tensor import Tensor # noqa: F401
|
stf/stf-api-alternative/pytriton/build/lib/pytriton/model_config/common.py
ADDED
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
"""Common structures for internal and external ModelConfig."""
|
15 |
+
|
16 |
+
import dataclasses
|
17 |
+
import enum
|
18 |
+
from typing import Dict, Optional
|
19 |
+
|
20 |
+
|
21 |
+
class DeviceKind(enum.Enum):
|
22 |
+
"""Device kind for model deployment.
|
23 |
+
|
24 |
+
Args:
|
25 |
+
KIND_AUTO: Automatically select the device for model deployment.
|
26 |
+
KIND_CPU: Model is deployed on CPU.
|
27 |
+
KIND_GPU: Model is deployed on GPU.
|
28 |
+
"""
|
29 |
+
|
30 |
+
KIND_AUTO = "KIND_AUTO"
|
31 |
+
KIND_CPU = "KIND_CPU"
|
32 |
+
KIND_GPU = "KIND_GPU"
|
33 |
+
|
34 |
+
|
35 |
+
class TimeoutAction(enum.Enum):
|
36 |
+
"""Timeout action definition for timeout_action QueuePolicy field.
|
37 |
+
|
38 |
+
Args:
|
39 |
+
REJECT: Reject the request and return error message accordingly.
|
40 |
+
DELAY: Delay the request until all other requests at the same (or higher) priority levels
|
41 |
+
that have not reached their timeouts are processed.
|
42 |
+
"""
|
43 |
+
|
44 |
+
REJECT = "REJECT"
|
45 |
+
DELAY = "DELAY"
|
46 |
+
|
47 |
+
|
48 |
+
@dataclasses.dataclass
|
49 |
+
class QueuePolicy:
|
50 |
+
"""Model queue policy configuration.
|
51 |
+
|
52 |
+
More in Triton Inference Server [documentation]
|
53 |
+
[documentation]: https://github.com/triton-inference-server/common/blob/main/protobuf/model_config.proto#L1037
|
54 |
+
|
55 |
+
Args:
|
56 |
+
timeout_action: The action applied to timed-out request.
|
57 |
+
default_timeout_microseconds: The default timeout for every request, in microseconds.
|
58 |
+
allow_timeout_override: Whether individual request can override the default timeout value.
|
59 |
+
max_queue_size: The maximum queue size for holding requests.
|
60 |
+
"""
|
61 |
+
|
62 |
+
timeout_action: TimeoutAction = TimeoutAction.REJECT
|
63 |
+
default_timeout_microseconds: int = 0
|
64 |
+
allow_timeout_override: bool = False
|
65 |
+
max_queue_size: int = 0
|
66 |
+
|
67 |
+
|
68 |
+
@dataclasses.dataclass
|
69 |
+
class DynamicBatcher:
|
70 |
+
"""Dynamic batcher configuration.
|
71 |
+
|
72 |
+
More in Triton Inference Server [documentation]
|
73 |
+
[documentation]: https://github.com/triton-inference-server/common/blob/main/protobuf/model_config.proto#L1104
|
74 |
+
|
75 |
+
Args:
|
76 |
+
max_queue_delay_microseconds: The maximum time, in microseconds, a request will be delayed in
|
77 |
+
the scheduling queue to wait for additional requests for batching.
|
78 |
+
preferred_batch_size: Preferred batch sizes for dynamic batching.
|
79 |
+
preserve_ordering : Should the dynamic batcher preserve the ordering of responses to
|
80 |
+
match the order of requests received by the scheduler.
|
81 |
+
priority_levels: The number of priority levels to be enabled for the model.
|
82 |
+
default_priority_level: The priority level used for requests that don't specify their priority.
|
83 |
+
default_queue_policy: The default queue policy used for requests.
|
84 |
+
priority_queue_policy: Specify the queue policy for the priority level.
|
85 |
+
"""
|
86 |
+
|
87 |
+
max_queue_delay_microseconds: int = 0
|
88 |
+
preferred_batch_size: Optional[list] = None
|
89 |
+
preserve_ordering: bool = False
|
90 |
+
priority_levels: int = 0
|
91 |
+
default_priority_level: int = 0
|
92 |
+
default_queue_policy: Optional[QueuePolicy] = None
|
93 |
+
priority_queue_policy: Optional[Dict[int, QueuePolicy]] = None
|
stf/stf-api-alternative/pytriton/build/lib/pytriton/model_config/generator.py
ADDED
@@ -0,0 +1,284 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
"""Generator class for creating Triton model config.
|
15 |
+
|
16 |
+
The class consume the TritonModelConfig object as a constructor argument and produce the Triton model config in form of
|
17 |
+
dict or file.
|
18 |
+
|
19 |
+
Typical usage example:
|
20 |
+
|
21 |
+
model_config = TritonModelConfig(model_name="simple")
|
22 |
+
generator = ModelConfigGenerator(model_config)
|
23 |
+
generator.to_file("/path/to/config.pbtxt")
|
24 |
+
"""
|
25 |
+
|
26 |
+
import json
|
27 |
+
import logging
|
28 |
+
import pathlib
|
29 |
+
from typing import Dict, Union
|
30 |
+
|
31 |
+
import numpy as np
|
32 |
+
from google.protobuf import json_format, text_format # pytype: disable=pyi-error
|
33 |
+
|
34 |
+
from pytriton.exceptions import PyTritonBadParameterError
|
35 |
+
|
36 |
+
from .triton_model_config import DynamicBatcher, TensorSpec, TritonModelConfig
|
37 |
+
|
38 |
+
try:
|
39 |
+
import tritonclient.grpc as grpc_client
|
40 |
+
from tritonclient import utils as client_utils # noqa: F401
|
41 |
+
except ImportError:
|
42 |
+
try:
|
43 |
+
import tritonclientutils as client_utils # noqa: F401
|
44 |
+
import tritongrpcclient as grpc_client
|
45 |
+
except ImportError:
|
46 |
+
client_utils = None
|
47 |
+
grpc_client = None
|
48 |
+
|
49 |
+
LOGGER = logging.getLogger(__name__)
|
50 |
+
|
51 |
+
|
52 |
+
class ModelConfigGenerator:
|
53 |
+
"""Generate the protobuf config from ModelConfig object."""
|
54 |
+
|
55 |
+
def __init__(self, config: TritonModelConfig):
|
56 |
+
"""Initialize generator.
|
57 |
+
|
58 |
+
Args:
|
59 |
+
config: model config object
|
60 |
+
"""
|
61 |
+
self._config = config
|
62 |
+
|
63 |
+
def to_file(self, config_path: Union[str, pathlib.Path]) -> str:
|
64 |
+
"""Serialize ModelConfig to prototxt and save to config_path directory.
|
65 |
+
|
66 |
+
Args:
|
67 |
+
config_path: path to configuration file
|
68 |
+
|
69 |
+
Returns:
|
70 |
+
A string with generated model configuration
|
71 |
+
"""
|
72 |
+
from tritonclient.grpc import model_config_pb2 # pytype: disable=import-error
|
73 |
+
|
74 |
+
# https://github.com/triton-inference-server/common/blob/main/protobuf/model_config.proto
|
75 |
+
model_config = self.get_config()
|
76 |
+
LOGGER.debug(f"Generated Triton config:\n{json.dumps(model_config, indent=4)}")
|
77 |
+
|
78 |
+
config_payload = json_format.ParseDict(model_config, model_config_pb2.ModelConfig())
|
79 |
+
LOGGER.debug(f"Generated Triton config payload:\n{config_payload}")
|
80 |
+
|
81 |
+
config_path = pathlib.Path(config_path)
|
82 |
+
config_path.parent.mkdir(parents=True, exist_ok=True)
|
83 |
+
|
84 |
+
model_config_bytes = text_format.MessageToBytes(config_payload)
|
85 |
+
|
86 |
+
# WAR: triton requires max_batch_size = 0 to be explicit written
|
87 |
+
# while this is not stored in payload during MessageToBytes
|
88 |
+
if model_config["max_batch_size"] == 0:
|
89 |
+
model_config_bytes += b"max_batch_size: 0\n"
|
90 |
+
|
91 |
+
with config_path.open("wb") as cfg:
|
92 |
+
cfg.write(model_config_bytes)
|
93 |
+
|
94 |
+
LOGGER.debug(f"Generated config stored in {config_path}")
|
95 |
+
|
96 |
+
return config_payload
|
97 |
+
|
98 |
+
def get_config(self) -> Dict:
|
99 |
+
"""Create a Triton model config from ModelConfig object.
|
100 |
+
|
101 |
+
Returns:
|
102 |
+
Dict with model configuration data
|
103 |
+
"""
|
104 |
+
model_config = {"name": self._config.model_name, "backend": self._config.backend}
|
105 |
+
self._set_batching(model_config)
|
106 |
+
self._set_model_signature(model_config)
|
107 |
+
self._set_instance_group(model_config)
|
108 |
+
self._set_model_transaction_policy(model_config)
|
109 |
+
self._set_backend_parameters(model_config)
|
110 |
+
self._set_response_cache(model_config)
|
111 |
+
return model_config
|
112 |
+
|
113 |
+
def _set_batching(self, model_config: Dict) -> None:
|
114 |
+
"""Configure batching for model deployment on Triton Inference Server.
|
115 |
+
|
116 |
+
Args:
|
117 |
+
model_config: Dict with model config for Triton Inference Server
|
118 |
+
"""
|
119 |
+
if not self._config.batching:
|
120 |
+
model_config["max_batch_size"] = 0
|
121 |
+
LOGGER.debug("Batching for model is disabled. The `max_batch_size` field value set to 0.")
|
122 |
+
return
|
123 |
+
elif self._config.max_batch_size < 1:
|
124 |
+
raise PyTritonBadParameterError("The `max_batch_size` must be greater or equal to 1.")
|
125 |
+
|
126 |
+
model_config["max_batch_size"] = self._config.max_batch_size
|
127 |
+
if isinstance(self._config.batcher, DynamicBatcher):
|
128 |
+
dynamic_batching_config = {}
|
129 |
+
if self._config.batcher.max_queue_delay_microseconds > 0:
|
130 |
+
dynamic_batching_config["maxQueueDelayMicroseconds"] = int(
|
131 |
+
self._config.batcher.max_queue_delay_microseconds
|
132 |
+
)
|
133 |
+
|
134 |
+
if self._config.batcher.preferred_batch_size:
|
135 |
+
dynamic_batching_config["preferredBatchSize"] = [
|
136 |
+
int(bs) for bs in self._config.batcher.preferred_batch_size
|
137 |
+
]
|
138 |
+
|
139 |
+
if self._config.batcher.preserve_ordering:
|
140 |
+
dynamic_batching_config["preserveOrdering"] = self._config.batcher.preserve_ordering
|
141 |
+
|
142 |
+
if self._config.batcher.priority_levels:
|
143 |
+
dynamic_batching_config["priorityLevels"] = self._config.batcher.priority_levels
|
144 |
+
|
145 |
+
if self._config.batcher.default_priority_level:
|
146 |
+
if self._config.batcher.default_priority_level > self._config.batcher.priority_levels:
|
147 |
+
raise PyTritonBadParameterError(
|
148 |
+
"The `default_priority_level` must be between 1 and " f"{self._config.batcher.priority_levels}."
|
149 |
+
)
|
150 |
+
dynamic_batching_config["defaultPriorityLevel"] = self._config.batcher.default_priority_level
|
151 |
+
|
152 |
+
if self._config.batcher.default_queue_policy:
|
153 |
+
priority_queue_policy_config = {
|
154 |
+
"timeoutAction": self._config.batcher.default_queue_policy.timeout_action.value,
|
155 |
+
"defaultTimeoutMicroseconds": int(
|
156 |
+
self._config.batcher.default_queue_policy.default_timeout_microseconds
|
157 |
+
),
|
158 |
+
"allowTimeoutOverride": self._config.batcher.default_queue_policy.allow_timeout_override,
|
159 |
+
"maxQueueSize": int(self._config.batcher.default_queue_policy.max_queue_size),
|
160 |
+
}
|
161 |
+
dynamic_batching_config["defaultQueuePolicy"] = priority_queue_policy_config
|
162 |
+
|
163 |
+
if self._config.batcher.priority_queue_policy:
|
164 |
+
if not self._config.batcher.priority_levels:
|
165 |
+
raise PyTritonBadParameterError(
|
166 |
+
"Provide the `priority_levels` if you want to define `priority_queue_policy` "
|
167 |
+
"for Dynamic Batching."
|
168 |
+
)
|
169 |
+
|
170 |
+
priority_queue_policy_config = {}
|
171 |
+
for priority, queue_policy in self._config.batcher.priority_queue_policy.items():
|
172 |
+
if priority < 0 or priority > self._config.batcher.priority_levels:
|
173 |
+
raise PyTritonBadParameterError(
|
174 |
+
f"Invalid `priority`={priority} provided. The value must be between "
|
175 |
+
f"1 and {self._config.batcher.priority_levels}."
|
176 |
+
)
|
177 |
+
|
178 |
+
priority_queue_policy_config[priority] = {
|
179 |
+
"timeoutAction": queue_policy.timeout_action.value,
|
180 |
+
"defaultTimeoutMicroseconds": int(queue_policy.default_timeout_microseconds),
|
181 |
+
"allowTimeoutOverride": queue_policy.allow_timeout_override,
|
182 |
+
"maxQueueSize": int(queue_policy.max_queue_size),
|
183 |
+
}
|
184 |
+
|
185 |
+
dynamic_batching_config["priorityQueuePolicy"] = priority_queue_policy_config
|
186 |
+
|
187 |
+
model_config["dynamic_batching"] = dynamic_batching_config
|
188 |
+
else:
|
189 |
+
LOGGER.debug("Default batching used")
|
190 |
+
|
191 |
+
def _set_instance_group(self, model_config: Dict) -> None:
|
192 |
+
"""Configure instance group for model deployment on Triton Inference Server.
|
193 |
+
|
194 |
+
Args:
|
195 |
+
model_config: Dict with model config for Triton Inference Server
|
196 |
+
"""
|
197 |
+
instance_groups = []
|
198 |
+
for device_kind, count in self._config.instance_group.items():
|
199 |
+
instance_groups.append({
|
200 |
+
"count": count,
|
201 |
+
"kind": device_kind.value,
|
202 |
+
})
|
203 |
+
|
204 |
+
if instance_groups:
|
205 |
+
model_config["instance_group"] = instance_groups
|
206 |
+
|
207 |
+
def _set_model_transaction_policy(self, model_config: Dict) -> None:
|
208 |
+
"""Configure model transaction policy for model deployment on Triton Inference Server.
|
209 |
+
|
210 |
+
Args:
|
211 |
+
model_config: Dict with model config for Triton Inference Server
|
212 |
+
"""
|
213 |
+
if self._config.decoupled:
|
214 |
+
model_config["model_transaction_policy"] = {"decoupled": True}
|
215 |
+
|
216 |
+
def _set_backend_parameters(self, model_config: Dict) -> None:
|
217 |
+
"""Configure backend parameters for model deployment on Triton Inference Server.
|
218 |
+
|
219 |
+
Args:
|
220 |
+
model_config: Dict with model config for Triton Inference Server
|
221 |
+
"""
|
222 |
+
parameters = {}
|
223 |
+
for key, value in self._config.backend_parameters.items():
|
224 |
+
parameters[key] = {
|
225 |
+
"string_value": str(value),
|
226 |
+
}
|
227 |
+
|
228 |
+
if parameters:
|
229 |
+
model_config["parameters"] = parameters
|
230 |
+
|
231 |
+
def _set_model_signature(self, model_config: Dict) -> None:
|
232 |
+
"""Configure model signature for model deployment on Triton Inference Server.
|
233 |
+
|
234 |
+
Args:
|
235 |
+
model_config: Dict with model config for Triton Inference Server
|
236 |
+
|
237 |
+
"""
|
238 |
+
|
239 |
+
def _rewrite_io_spec(spec_: TensorSpec) -> Dict:
|
240 |
+
if spec_.dtype in [np.object_, object, bytes, np.bytes_]:
|
241 |
+
dtype = "TYPE_STRING"
|
242 |
+
else:
|
243 |
+
# pytype: disable=attribute-error
|
244 |
+
dtype = spec_.dtype().dtype
|
245 |
+
# pytype: enable=attribute-error
|
246 |
+
dtype = f"TYPE_{client_utils.np_to_triton_dtype(dtype)}"
|
247 |
+
|
248 |
+
dims = spec_.shape
|
249 |
+
|
250 |
+
item = {
|
251 |
+
"name": spec_.name,
|
252 |
+
"dims": list(dims),
|
253 |
+
"data_type": dtype,
|
254 |
+
}
|
255 |
+
|
256 |
+
if spec_.optional:
|
257 |
+
item["optional"] = True
|
258 |
+
|
259 |
+
return item
|
260 |
+
|
261 |
+
if self._config.inputs:
|
262 |
+
model_config["input"] = [_rewrite_io_spec(spec) for spec in self._config.inputs]
|
263 |
+
|
264 |
+
if self._config.outputs:
|
265 |
+
outputs = [_rewrite_io_spec(spec) for spec in self._config.outputs]
|
266 |
+
if outputs:
|
267 |
+
optional_outputs = [o for o in outputs if o.get("optional")]
|
268 |
+
if optional_outputs:
|
269 |
+
raise PyTritonBadParameterError(
|
270 |
+
"Optional flag for outputs is not supported. "
|
271 |
+
f"Outputs marked as optional: {', '.join([o['name'] for o in optional_outputs])}."
|
272 |
+
)
|
273 |
+
model_config["output"] = outputs
|
274 |
+
|
275 |
+
def _set_response_cache(self, model_config: Dict):
|
276 |
+
"""Configure response cache for model.
|
277 |
+
|
278 |
+
Args:
|
279 |
+
model_config: Dictionary where configuration is attached.
|
280 |
+
"""
|
281 |
+
if self._config.response_cache:
|
282 |
+
model_config["response_cache"] = {
|
283 |
+
"enable": self._config.response_cache.enable,
|
284 |
+
}
|
stf/stf-api-alternative/pytriton/build/lib/pytriton/model_config/model_config.py
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
"""Model configurations.
|
15 |
+
|
16 |
+
Dataclasses with specialized deployment paths for models on Triton. The purpose of this module is to provide clear options
|
17 |
+
to configure models of given types.
|
18 |
+
|
19 |
+
The dataclasses are exposed in the user API.
|
20 |
+
"""
|
21 |
+
|
22 |
+
import dataclasses
|
23 |
+
|
24 |
+
from pytriton.model_config import DynamicBatcher
|
25 |
+
|
26 |
+
|
27 |
+
@dataclasses.dataclass
|
28 |
+
class ModelConfig:
|
29 |
+
"""Additional model configuration for running model through Triton Inference Server.
|
30 |
+
|
31 |
+
Args:
|
32 |
+
batching: Flag to enable/disable batching for model.
|
33 |
+
max_batch_size: The maximal batch size that would be handled by model.
|
34 |
+
batcher: Configuration of Dynamic Batching for the model.
|
35 |
+
response_cache: Flag to enable/disable response cache for the model
|
36 |
+
decoupled: Flag to enable/disable decoupled from requests execution
|
37 |
+
"""
|
38 |
+
|
39 |
+
batching: bool = True
|
40 |
+
max_batch_size: int = 4
|
41 |
+
batcher: DynamicBatcher = dataclasses.field(default_factory=DynamicBatcher)
|
42 |
+
response_cache: bool = False
|
43 |
+
decoupled: bool = False
|
stf/stf-api-alternative/pytriton/build/lib/pytriton/model_config/parser.py
ADDED
@@ -0,0 +1,258 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
"""ModelConfigParser class definition.
|
15 |
+
|
16 |
+
Provide functionality to parse the Triton model configuration stored in file or form of dictionary into the object of
|
17 |
+
class ModelConfig.
|
18 |
+
|
19 |
+
Examples of use:
|
20 |
+
|
21 |
+
# Parse from dict
|
22 |
+
model_config = ModelConfigParser.from_dict(model_config_dict)
|
23 |
+
|
24 |
+
# Parse from file
|
25 |
+
model_config = ModelConfigParser.from_file("/path/to/config.pbtxt")
|
26 |
+
|
27 |
+
"""
|
28 |
+
|
29 |
+
import json
|
30 |
+
import logging
|
31 |
+
import pathlib
|
32 |
+
from typing import Dict
|
33 |
+
|
34 |
+
import numpy as np
|
35 |
+
from google.protobuf import json_format, text_format # pytype: disable=pyi-error
|
36 |
+
|
37 |
+
from pytriton.exceptions import PyTritonModelConfigError
|
38 |
+
|
39 |
+
from .common import QueuePolicy, TimeoutAction
|
40 |
+
from .triton_model_config import DeviceKind, DynamicBatcher, ResponseCache, TensorSpec, TritonModelConfig
|
41 |
+
|
42 |
+
try:
|
43 |
+
import tritonclient.grpc as grpc_client
|
44 |
+
from tritonclient import utils as client_utils # noqa: F401
|
45 |
+
except ImportError:
|
46 |
+
try:
|
47 |
+
import tritonclientutils as client_utils # noqa: F401
|
48 |
+
import tritongrpcclient as grpc_client
|
49 |
+
except ImportError:
|
50 |
+
client_utils = None
|
51 |
+
grpc_client = None
|
52 |
+
|
53 |
+
LOGGER = logging.getLogger(__name__)
|
54 |
+
|
55 |
+
|
56 |
+
class ModelConfigParser:
|
57 |
+
"""Provide functionality to parse dictionary or file to ModelConfig object."""
|
58 |
+
|
59 |
+
@classmethod
|
60 |
+
def from_dict(cls, model_config_dict: Dict) -> TritonModelConfig:
|
61 |
+
"""Create ModelConfig from configuration stored in dictionary.
|
62 |
+
|
63 |
+
Args:
|
64 |
+
model_config_dict: Dictionary with model config
|
65 |
+
|
66 |
+
Returns:
|
67 |
+
A ModelConfig object with data parsed from the dictionary
|
68 |
+
"""
|
69 |
+
LOGGER.debug(f"Parsing Triton config model from dict: \n{json.dumps(model_config_dict, indent=4)}")
|
70 |
+
|
71 |
+
if model_config_dict.get("max_batch_size", 0) > 0:
|
72 |
+
batching = True
|
73 |
+
else:
|
74 |
+
batching = False
|
75 |
+
|
76 |
+
dynamic_batcher_config = model_config_dict.get("dynamic_batching")
|
77 |
+
if dynamic_batcher_config is not None:
|
78 |
+
batcher = cls._parse_dynamic_batching(dynamic_batcher_config)
|
79 |
+
else:
|
80 |
+
batcher = None
|
81 |
+
|
82 |
+
instance_group = {
|
83 |
+
DeviceKind(entry["kind"]): entry.get("count") for entry in model_config_dict.get("instance_group", [])
|
84 |
+
}
|
85 |
+
|
86 |
+
decoupled = model_config_dict.get("model_transaction_policy", {}).get("decoupled", False)
|
87 |
+
|
88 |
+
backend_parameters_config = model_config_dict.get("parameters", [])
|
89 |
+
if isinstance(backend_parameters_config, list):
|
90 |
+
# If the backend_parameters_config is a list of strings, use them as keys with empty values
|
91 |
+
LOGGER.debug(f"backend_parameters_config is a list of strings: {backend_parameters_config}")
|
92 |
+
backend_parameters = {name: "" for name in backend_parameters_config}
|
93 |
+
elif isinstance(backend_parameters_config, dict):
|
94 |
+
# If the backend_parameters_config is a dictionary, use the key and "string_value" fields as key-value pairs
|
95 |
+
LOGGER.debug(f"backend_parameters_config is a dictionary: {backend_parameters_config}")
|
96 |
+
backend_parameters = {
|
97 |
+
name: backend_parameters_config[name]["string_value"] for name in backend_parameters_config
|
98 |
+
}
|
99 |
+
else:
|
100 |
+
# Otherwise, raise an error
|
101 |
+
LOGGER.error(
|
102 |
+
f"Invalid type {type(backend_parameters_config)} for backend_parameters_config: {backend_parameters_config}"
|
103 |
+
)
|
104 |
+
raise TypeError(f"Invalid type for backend_parameters_config: {type(backend_parameters_config)}")
|
105 |
+
|
106 |
+
inputs = [
|
107 |
+
cls.rewrite_io_spec(item, "input", idx) for idx, item in enumerate(model_config_dict.get("input", []))
|
108 |
+
] or None
|
109 |
+
outputs = [
|
110 |
+
cls.rewrite_io_spec(item, "output", idx) for idx, item in enumerate(model_config_dict.get("output", []))
|
111 |
+
] or None
|
112 |
+
|
113 |
+
response_cache_config = model_config_dict.get("response_cache")
|
114 |
+
if response_cache_config:
|
115 |
+
response_cache = cls._parse_response_cache(response_cache_config)
|
116 |
+
else:
|
117 |
+
response_cache = None
|
118 |
+
|
119 |
+
return TritonModelConfig(
|
120 |
+
model_name=model_config_dict["name"],
|
121 |
+
batching=batching,
|
122 |
+
max_batch_size=model_config_dict.get("max_batch_size", 0),
|
123 |
+
batcher=batcher,
|
124 |
+
inputs=inputs,
|
125 |
+
outputs=outputs,
|
126 |
+
instance_group=instance_group,
|
127 |
+
decoupled=decoupled,
|
128 |
+
backend_parameters=backend_parameters,
|
129 |
+
response_cache=response_cache,
|
130 |
+
)
|
131 |
+
|
132 |
+
@classmethod
|
133 |
+
def from_file(cls, *, config_path: pathlib.Path) -> TritonModelConfig:
|
134 |
+
"""Create ModelConfig from configuration stored in file.
|
135 |
+
|
136 |
+
Args:
|
137 |
+
config_path: location of file with model config
|
138 |
+
|
139 |
+
Returns:
|
140 |
+
A ModelConfig object with data parsed from the file
|
141 |
+
"""
|
142 |
+
from tritonclient.grpc import model_config_pb2 # pytype: disable=import-error
|
143 |
+
|
144 |
+
LOGGER.debug(f"Parsing Triton config model config_path={config_path}")
|
145 |
+
|
146 |
+
with config_path.open("r") as config_file:
|
147 |
+
payload = config_file.read()
|
148 |
+
model_config_proto = text_format.Parse(payload, model_config_pb2.ModelConfig())
|
149 |
+
|
150 |
+
model_config_dict = json_format.MessageToDict(model_config_proto, preserving_proto_field_name=True)
|
151 |
+
return ModelConfigParser.from_dict(model_config_dict=model_config_dict)
|
152 |
+
|
153 |
+
@classmethod
|
154 |
+
def rewrite_io_spec(cls, item: Dict, io_type: str, idx: int) -> TensorSpec:
|
155 |
+
"""Rewrite the IO Spec provided in form of dictionary to TensorSpec.
|
156 |
+
|
157 |
+
Args:
|
158 |
+
item: IO data for input
|
159 |
+
io_type: Type of the IO (input or output)
|
160 |
+
idx: Index of IO
|
161 |
+
|
162 |
+
Returns:
|
163 |
+
TensorSpec with input or output data
|
164 |
+
"""
|
165 |
+
name = item.get("name")
|
166 |
+
if not name:
|
167 |
+
raise PyTritonModelConfigError(f"Name for {io_type} at index {idx} not provided.")
|
168 |
+
|
169 |
+
data_type = item.get("data_type")
|
170 |
+
if not data_type:
|
171 |
+
raise PyTritonModelConfigError(f"Data type for {io_type} with name `{name}` not defined.")
|
172 |
+
|
173 |
+
data_type_val = data_type.split("_")
|
174 |
+
if len(data_type_val) != 2:
|
175 |
+
raise PyTritonModelConfigError(
|
176 |
+
f"Invalid data type `{data_type}` for {io_type} with name `{name}` not defined. "
|
177 |
+
"The expected name is TYPE_{type}."
|
178 |
+
)
|
179 |
+
|
180 |
+
data_type = data_type_val[1]
|
181 |
+
if data_type == "STRING":
|
182 |
+
dtype = np.bytes_
|
183 |
+
else:
|
184 |
+
dtype = client_utils.triton_to_np_dtype(data_type)
|
185 |
+
if dtype is None:
|
186 |
+
raise PyTritonModelConfigError(f"Unsupported data type `{data_type}` for {io_type} with name `{name}`")
|
187 |
+
|
188 |
+
dtype = np.dtype("bool") if dtype is bool else dtype
|
189 |
+
|
190 |
+
dims = item.get("dims", [])
|
191 |
+
if not dims:
|
192 |
+
raise PyTritonModelConfigError(f"Dimension for {io_type} with name `{name}` not defined.")
|
193 |
+
|
194 |
+
shape = tuple(int(s) for s in dims)
|
195 |
+
|
196 |
+
optional = item.get("optional", False)
|
197 |
+
return TensorSpec(name=item["name"], shape=shape, dtype=dtype, optional=optional)
|
198 |
+
|
199 |
+
@classmethod
|
200 |
+
def _parse_dynamic_batching(cls, dynamic_batching_config: Dict) -> DynamicBatcher:
|
201 |
+
"""Parse config to create DynamicBatcher object.
|
202 |
+
|
203 |
+
Args:
|
204 |
+
dynamic_batching_config: Configuration of dynamic batcher from config
|
205 |
+
|
206 |
+
Returns:
|
207 |
+
DynamicBatcher object with configuration
|
208 |
+
"""
|
209 |
+
default_queue_policy = None
|
210 |
+
default_queue_policy_config = dynamic_batching_config.get("default_queue_policy")
|
211 |
+
if default_queue_policy_config:
|
212 |
+
default_queue_policy = QueuePolicy(
|
213 |
+
timeout_action=TimeoutAction(
|
214 |
+
default_queue_policy_config.get("timeout_action", TimeoutAction.REJECT.value)
|
215 |
+
),
|
216 |
+
default_timeout_microseconds=int(default_queue_policy_config.get("default_timeout_microseconds", 0)),
|
217 |
+
allow_timeout_override=bool(default_queue_policy_config.get("allow_timeout_override", False)),
|
218 |
+
max_queue_size=int(default_queue_policy_config.get("max_queue_size", 0)),
|
219 |
+
)
|
220 |
+
|
221 |
+
priority_queue_policy = None
|
222 |
+
priority_queue_policy_config = dynamic_batching_config.get("priority_queue_policy")
|
223 |
+
if priority_queue_policy_config:
|
224 |
+
priority_queue_policy = {}
|
225 |
+
for priority, queue_policy_config in priority_queue_policy_config.items():
|
226 |
+
queue_policy = QueuePolicy(
|
227 |
+
timeout_action=TimeoutAction(queue_policy_config.get("timeout_action", TimeoutAction.REJECT.value)),
|
228 |
+
default_timeout_microseconds=int(queue_policy_config.get("default_timeout_microseconds", 0)),
|
229 |
+
allow_timeout_override=bool(queue_policy_config.get("allow_timeout_override", False)),
|
230 |
+
max_queue_size=int(queue_policy_config.get("max_queue_size", 0)),
|
231 |
+
)
|
232 |
+
priority_queue_policy[int(priority)] = queue_policy
|
233 |
+
|
234 |
+
batcher = DynamicBatcher(
|
235 |
+
preferred_batch_size=dynamic_batching_config.get("preferred_batch_size"),
|
236 |
+
max_queue_delay_microseconds=int(dynamic_batching_config.get("max_queue_delay_microseconds", 0)),
|
237 |
+
preserve_ordering=bool(dynamic_batching_config.get("preserve_ordering", False)),
|
238 |
+
priority_levels=int(dynamic_batching_config.get("priority_levels", 0)),
|
239 |
+
default_priority_level=int(dynamic_batching_config.get("default_priority_level", 0)),
|
240 |
+
default_queue_policy=default_queue_policy,
|
241 |
+
priority_queue_policy=priority_queue_policy,
|
242 |
+
)
|
243 |
+
return batcher
|
244 |
+
|
245 |
+
@classmethod
|
246 |
+
def _parse_response_cache(cls, response_cache_config: Dict) -> ResponseCache:
|
247 |
+
"""Parse config for response cache.
|
248 |
+
|
249 |
+
Args:
|
250 |
+
response_cache_config: response cache configuration
|
251 |
+
|
252 |
+
Returns:
|
253 |
+
ResponseCache object with configuration
|
254 |
+
"""
|
255 |
+
response_cache = ResponseCache(
|
256 |
+
enable=bool(response_cache_config["enable"]),
|
257 |
+
)
|
258 |
+
return response_cache
|
stf/stf-api-alternative/pytriton/build/lib/pytriton/model_config/tensor.py
ADDED
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
"""Tensor object definition.
|
15 |
+
|
16 |
+
Describe the model input or output.
|
17 |
+
|
18 |
+
Examples of use:
|
19 |
+
|
20 |
+
# Minimal constructors
|
21 |
+
tensor = Tensor(dtype=np.bytes_, shape=(-1,))
|
22 |
+
tensor = Tensor(dtype=np.float32, shape=(-1,))
|
23 |
+
|
24 |
+
# Type definition from existing object
|
25 |
+
a = np.array([1, 2, 3, 4])
|
26 |
+
tensor = Tensor(dtype=a.dtype, shape=(-1,))
|
27 |
+
|
28 |
+
# Custom name
|
29 |
+
tensor = Tensor(name="data", dtype=np.float32, shape=(16,))
|
30 |
+
"""
|
31 |
+
|
32 |
+
import dataclasses
|
33 |
+
from typing import Optional, Type, Union
|
34 |
+
|
35 |
+
import numpy as np
|
36 |
+
|
37 |
+
|
38 |
+
@dataclasses.dataclass(frozen=True)
|
39 |
+
class Tensor:
|
40 |
+
"""Model input and output definition for Triton deployment.
|
41 |
+
|
42 |
+
Args:
|
43 |
+
shape: Shape of the input/output tensor.
|
44 |
+
dtype: Data type of the input/output tensor.
|
45 |
+
name: Name of the input/output of model.
|
46 |
+
optional: Flag to mark if input is optional.
|
47 |
+
"""
|
48 |
+
|
49 |
+
shape: tuple
|
50 |
+
dtype: Union[np.dtype, Type[np.dtype], Type[object]]
|
51 |
+
name: Optional[str] = None
|
52 |
+
optional: Optional[bool] = False
|
53 |
+
|
54 |
+
def __post_init__(self):
|
55 |
+
"""Override object values on post init or field override."""
|
56 |
+
if isinstance(self.dtype, np.dtype):
|
57 |
+
object.__setattr__(self, "dtype", self.dtype.type) # pytype: disable=attribute-error
|
stf/stf-api-alternative/pytriton/build/lib/pytriton/model_config/triton_model_config.py
ADDED
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
"""ModelConfig related objects."""
|
15 |
+
|
16 |
+
import dataclasses
|
17 |
+
from typing import Dict, Optional, Sequence, Type, Union
|
18 |
+
|
19 |
+
import numpy as np
|
20 |
+
|
21 |
+
from .common import DeviceKind, DynamicBatcher
|
22 |
+
|
23 |
+
|
24 |
+
@dataclasses.dataclass
|
25 |
+
class ResponseCache:
|
26 |
+
"""Model response cache configuration.
|
27 |
+
|
28 |
+
More in Triton Inference Server [documentation]
|
29 |
+
[documentation]: https://github.com/triton-inference-server/common/blob/main/protobuf/model_config.proto#L1765
|
30 |
+
"""
|
31 |
+
|
32 |
+
enable: bool
|
33 |
+
|
34 |
+
|
35 |
+
@dataclasses.dataclass
|
36 |
+
class TensorSpec:
|
37 |
+
"""Stores specification of single tensor. This includes name, shape and dtype."""
|
38 |
+
|
39 |
+
name: str
|
40 |
+
shape: tuple
|
41 |
+
dtype: Union[Type[np.dtype], Type[object]]
|
42 |
+
optional: Optional[bool] = False
|
43 |
+
|
44 |
+
|
45 |
+
@dataclasses.dataclass
|
46 |
+
class TritonModelConfig:
|
47 |
+
"""Triton Model Config dataclass for simplification and specialization of protobuf config generation.
|
48 |
+
|
49 |
+
More in Triton Inference Server [documentation]
|
50 |
+
[documentation]: https://github.com/triton-inference-server/common/blob/main/protobuf/model_config.proto
|
51 |
+
"""
|
52 |
+
|
53 |
+
model_name: str
|
54 |
+
model_version: int = 1
|
55 |
+
max_batch_size: int = 4
|
56 |
+
batching: bool = True
|
57 |
+
batcher: Optional[DynamicBatcher] = None
|
58 |
+
instance_group: Dict[DeviceKind, Optional[int]] = dataclasses.field(default_factory=lambda: {})
|
59 |
+
decoupled: bool = False
|
60 |
+
backend_parameters: Dict[str, str] = dataclasses.field(default_factory=lambda: {})
|
61 |
+
inputs: Optional[Sequence[TensorSpec]] = None
|
62 |
+
outputs: Optional[Sequence[TensorSpec]] = None
|
63 |
+
response_cache: Optional[ResponseCache] = None
|
64 |
+
|
65 |
+
@property
|
66 |
+
def backend(self) -> str:
|
67 |
+
"""Return backend parameter."""
|
68 |
+
return "python"
|
stf/stf-api-alternative/pytriton/build/lib/pytriton/models/__init__.py
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# noqa: D104
|
stf/stf-api-alternative/pytriton/build/lib/pytriton/models/manager.py
ADDED
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
"""ModelManager class.
|
15 |
+
|
16 |
+
The ModelManager is responsible for maintaining the models that has to be server on Triton Inference Server.
|
17 |
+
|
18 |
+
Examples of use:
|
19 |
+
manager = ModelManager(model_repository)
|
20 |
+
manager.add_model(model)
|
21 |
+
|
22 |
+
manager.create_models()
|
23 |
+
"""
|
24 |
+
|
25 |
+
import contextlib
|
26 |
+
import json
|
27 |
+
import logging
|
28 |
+
import pathlib
|
29 |
+
import socket
|
30 |
+
from typing import Dict, Iterable, Optional, Tuple
|
31 |
+
|
32 |
+
from tritonclient.grpc import InferenceServerException
|
33 |
+
|
34 |
+
from pytriton.client import ModelClient
|
35 |
+
from pytriton.client.utils import create_client_from_url, wait_for_server_ready
|
36 |
+
from pytriton.constants import CREATE_TRITON_CLIENT_TIMEOUT_S, DEFAULT_TRITON_STARTUP_TIMEOUT_S
|
37 |
+
from pytriton.exceptions import PyTritonInvalidOperationError
|
38 |
+
from pytriton.models.model import Model
|
39 |
+
|
40 |
+
LOGGER = logging.getLogger(__name__)
|
41 |
+
|
42 |
+
|
43 |
+
class ModelManager:
|
44 |
+
"""ModelManager class for maintaining Triton models."""
|
45 |
+
|
46 |
+
def __init__(
|
47 |
+
self,
|
48 |
+
triton_url: str,
|
49 |
+
model_store_path: Optional[pathlib.Path] = None,
|
50 |
+
):
|
51 |
+
"""Create ModelManager object.
|
52 |
+
|
53 |
+
Args:
|
54 |
+
triton_url: Triton server URL
|
55 |
+
model_store_path: Path to local model store
|
56 |
+
"""
|
57 |
+
self._triton_url = triton_url
|
58 |
+
self._models: Dict[Tuple[str, int], Model] = {}
|
59 |
+
self._model_store_path = model_store_path
|
60 |
+
|
61 |
+
@property
|
62 |
+
def models(self) -> Iterable[Model]:
|
63 |
+
"""List models added to manage.
|
64 |
+
|
65 |
+
Returns:
|
66 |
+
List with models added to ModelManager.
|
67 |
+
"""
|
68 |
+
return self._models.values()
|
69 |
+
|
70 |
+
def add_model(self, model: Model, load_model: bool = False) -> None:
|
71 |
+
"""Add model to manage.
|
72 |
+
|
73 |
+
Args:
|
74 |
+
model: Model instance
|
75 |
+
load_model: If True, model will be loaded to Triton server.
|
76 |
+
"""
|
77 |
+
key = self._format_key(model)
|
78 |
+
if key in self._models:
|
79 |
+
raise PyTritonInvalidOperationError("Cannot add model with the same name twice.")
|
80 |
+
|
81 |
+
LOGGER.debug(f"Adding {model.model_name} ({model.model_version}) to registry under {key}.")
|
82 |
+
self._models[key] = model
|
83 |
+
|
84 |
+
_is_model_store_local = self._model_store_path is not None
|
85 |
+
if _is_model_store_local:
|
86 |
+
model.generate_model(self._model_store_path)
|
87 |
+
|
88 |
+
if load_model:
|
89 |
+
self._load_model(model, _is_model_store_local)
|
90 |
+
model.setup()
|
91 |
+
|
92 |
+
def load_models(self) -> None:
|
93 |
+
"""Load bound models to Triton server and setup loaded models."""
|
94 |
+
for model in self._models.values():
|
95 |
+
if not model.is_alive():
|
96 |
+
self._load_model(model)
|
97 |
+
model.setup()
|
98 |
+
|
99 |
+
def setup_models(self) -> None:
|
100 |
+
"""Setup loaded models."""
|
101 |
+
for model in self._models.values():
|
102 |
+
if not model.is_alive():
|
103 |
+
model.setup()
|
104 |
+
|
105 |
+
def clean(self) -> None:
|
106 |
+
"""Clean the model and internal registry."""
|
107 |
+
with contextlib.closing(
|
108 |
+
create_client_from_url(self._triton_url, network_timeout_s=CREATE_TRITON_CLIENT_TIMEOUT_S)
|
109 |
+
) as client:
|
110 |
+
server_live = False
|
111 |
+
try:
|
112 |
+
server_live = client.is_server_live()
|
113 |
+
# TimeoutError and ConnectionRefusedError are derived from OSError so they are redundant here
|
114 |
+
# OSError is raised from gevent/_socketcommon.py:590 sometimes, when server is not ready
|
115 |
+
except (socket.timeout, OSError, InferenceServerException):
|
116 |
+
pass
|
117 |
+
except Exception as ex:
|
118 |
+
LOGGER.error(f"Unexpected exception during server live check: {ex}")
|
119 |
+
raise ex
|
120 |
+
|
121 |
+
for name, model in self._models.items():
|
122 |
+
LOGGER.debug(f"Clean model {name}.")
|
123 |
+
model.clean()
|
124 |
+
if server_live:
|
125 |
+
client.unload_model(model.model_name)
|
126 |
+
|
127 |
+
if server_live:
|
128 |
+
# after unload there is a short period of time when server is not ready
|
129 |
+
wait_for_server_ready(client, timeout_s=DEFAULT_TRITON_STARTUP_TIMEOUT_S)
|
130 |
+
|
131 |
+
self._models.clear()
|
132 |
+
|
133 |
+
def _format_key(self, model: Model) -> Tuple[str, int]:
|
134 |
+
key = (model.model_name.lower(), model.model_version)
|
135 |
+
return key
|
136 |
+
|
137 |
+
def _load_model(self, model: Model, local_model_store=False):
|
138 |
+
"""Prepare model config and required files dict and load model to Triton server."""
|
139 |
+
LOGGER.debug(f"Creating model {model.model_name} with version {model.model_version}.")
|
140 |
+
config = None if local_model_store else json.dumps(model.get_model_config())
|
141 |
+
files = None if local_model_store else model.get_proxy_model_files()
|
142 |
+
with ModelClient(
|
143 |
+
url=self._triton_url, model_name=model.model_name, model_version=str(model.model_version)
|
144 |
+
) as client:
|
145 |
+
client.wait_for_server(timeout_s=DEFAULT_TRITON_STARTUP_TIMEOUT_S)
|
146 |
+
client.load_model(config=config, files=files)
|
147 |
+
LOGGER.debug("Done.")
|
stf/stf-api-alternative/pytriton/build/lib/pytriton/models/model.py
ADDED
@@ -0,0 +1,335 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
"""Model base class."""
|
15 |
+
|
16 |
+
import base64
|
17 |
+
import copy
|
18 |
+
import enum
|
19 |
+
import json
|
20 |
+
import logging
|
21 |
+
import os
|
22 |
+
import pathlib
|
23 |
+
import shutil
|
24 |
+
import threading
|
25 |
+
import typing
|
26 |
+
from typing import Callable, List, Optional, Sequence, Union
|
27 |
+
|
28 |
+
from pytriton.decorators import TritonContext
|
29 |
+
from pytriton.exceptions import PyTritonValidationError
|
30 |
+
from pytriton.model_config.generator import ModelConfigGenerator
|
31 |
+
from pytriton.model_config.model_config import ModelConfig
|
32 |
+
from pytriton.model_config.tensor import Tensor
|
33 |
+
from pytriton.model_config.triton_model_config import DeviceKind, ResponseCache, TensorSpec, TritonModelConfig
|
34 |
+
from pytriton.proxy.communication import get_config_from_handshake_server
|
35 |
+
from pytriton.proxy.data import Base64SerializerDeserializer, TensorStoreSerializerDeserializer
|
36 |
+
from pytriton.proxy.inference import InferenceHandler, InferenceHandlerEvent, RequestsResponsesConnector
|
37 |
+
from pytriton.proxy.validators import TritonResultsValidator
|
38 |
+
from pytriton.utils.workspace import Workspace
|
39 |
+
|
40 |
+
LOGGER = logging.getLogger(__name__)
|
41 |
+
|
42 |
+
|
43 |
+
class ModelEvent(enum.Enum):
|
44 |
+
"""Represents model event."""
|
45 |
+
|
46 |
+
RUNTIME_TERMINATING = "runtime-terminating"
|
47 |
+
RUNTIME_TERMINATED = "runtime-terminated"
|
48 |
+
|
49 |
+
|
50 |
+
ModelEventsHandler = typing.Callable[["Model", ModelEvent, typing.Optional[typing.Any]], None]
|
51 |
+
|
52 |
+
|
53 |
+
def _inject_triton_context(triton_context: TritonContext, model_callable: Callable) -> Callable:
|
54 |
+
"""Inject triton context into callable.
|
55 |
+
|
56 |
+
Args:
|
57 |
+
triton_context: Triton context
|
58 |
+
model_callable: Callable to inject triton context
|
59 |
+
|
60 |
+
Returns:
|
61 |
+
Callable with injected triton context
|
62 |
+
"""
|
63 |
+
if hasattr(model_callable, "__self__"):
|
64 |
+
model_callable.__self__.__triton_context__ = triton_context
|
65 |
+
else:
|
66 |
+
model_callable.__triton_context__ = triton_context
|
67 |
+
return model_callable
|
68 |
+
|
69 |
+
|
70 |
+
class Model:
|
71 |
+
"""Model definition."""
|
72 |
+
|
73 |
+
SCRIPT_FILES_TO_COPY = ["communication.py", "data.py", "model.py", "types.py", "telemetry.py"]
|
74 |
+
|
75 |
+
def __init__(
|
76 |
+
self,
|
77 |
+
model_name: str,
|
78 |
+
model_version: int,
|
79 |
+
inference_fn: Union[Callable, Sequence[Callable]],
|
80 |
+
inputs: Sequence[Tensor],
|
81 |
+
outputs: Sequence[Tensor],
|
82 |
+
config: ModelConfig,
|
83 |
+
workspace: Workspace,
|
84 |
+
triton_context: TritonContext,
|
85 |
+
strict: bool,
|
86 |
+
trace_config: Optional[List[str]] = None,
|
87 |
+
):
|
88 |
+
"""Create Python model with required data.
|
89 |
+
|
90 |
+
Args:
|
91 |
+
model_name: Model name
|
92 |
+
model_version: Model version
|
93 |
+
inference_fn: Inference handler (function or lambda)
|
94 |
+
inputs: Model inputs definition
|
95 |
+
outputs: Model outputs definition
|
96 |
+
config: model configuration parameters
|
97 |
+
workspace: workspace for storing artifacts
|
98 |
+
triton_context: Triton context
|
99 |
+
strict: Enable strict validation of model outputs
|
100 |
+
trace_config: List of trace config parameters
|
101 |
+
|
102 |
+
Raises:
|
103 |
+
PyTritonValidationError if one or more of provided values are incorrect.
|
104 |
+
"""
|
105 |
+
self.triton_context = triton_context
|
106 |
+
self.model_name = model_name
|
107 |
+
self.model_version = model_version
|
108 |
+
self._inference_handlers_lock = threading.Lock()
|
109 |
+
self._inference_handlers = []
|
110 |
+
self._requests_respones_connectors = []
|
111 |
+
self._observers_lock = threading.Lock()
|
112 |
+
self._strict = strict
|
113 |
+
self._trace_config = trace_config
|
114 |
+
|
115 |
+
self.infer_functions = [inference_fn] if isinstance(inference_fn, Callable) else inference_fn
|
116 |
+
if not isinstance(self.infer_functions, (Sequence, Callable)):
|
117 |
+
raise PyTritonValidationError("inference_fn has to be either callable or sequence of callables")
|
118 |
+
|
119 |
+
self.inputs = inputs
|
120 |
+
self.outputs = outputs
|
121 |
+
|
122 |
+
if any(output.optional for output in self.outputs):
|
123 |
+
raise PyTritonValidationError("Output tensors cannot be optional.")
|
124 |
+
|
125 |
+
self.config = config
|
126 |
+
self._workspace = workspace
|
127 |
+
if os.environ.get("PYTRITON_NO_TENSORSTORE"):
|
128 |
+
self._serializer_deserializer = Base64SerializerDeserializer()
|
129 |
+
else:
|
130 |
+
self._serializer_deserializer = TensorStoreSerializerDeserializer()
|
131 |
+
self._triton_model_config: Optional[TritonModelConfig] = None
|
132 |
+
self._model_events_observers: typing.List[ModelEventsHandler] = []
|
133 |
+
|
134 |
+
def get_model_config(self) -> dict:
|
135 |
+
"""Get model config.
|
136 |
+
|
137 |
+
Returns:
|
138 |
+
Dictionary with model config
|
139 |
+
"""
|
140 |
+
triton_model_config = self._get_triton_model_config()
|
141 |
+
generator = ModelConfigGenerator(config=triton_model_config)
|
142 |
+
return generator.get_config()
|
143 |
+
|
144 |
+
def get_proxy_model_files(self) -> typing.Dict[str, bytes]:
|
145 |
+
"""Get proxy model files.
|
146 |
+
|
147 |
+
Returns:
|
148 |
+
Dictionary with model files to be copied to Triton model store on server side:
|
149 |
+
key: file path in following format - 'file:{model_version}/{file_name}'
|
150 |
+
value: file content as bytes
|
151 |
+
"""
|
152 |
+
proxy_model_files_dict = {}
|
153 |
+
proxy_path = pathlib.Path(__file__).parent.parent / "proxy"
|
154 |
+
for file_to_copy in self.SCRIPT_FILES_TO_COPY:
|
155 |
+
src_file_path = proxy_path / file_to_copy
|
156 |
+
with open(src_file_path, "rb") as f:
|
157 |
+
src_file = f.read()
|
158 |
+
proxy_model_files_dict[f"file:{self.model_version}/{file_to_copy}"] = src_file
|
159 |
+
|
160 |
+
return proxy_model_files_dict
|
161 |
+
|
162 |
+
def generate_model(self, model_repository: pathlib.Path) -> None:
|
163 |
+
"""Generate model and its config in the model repository.
|
164 |
+
|
165 |
+
Args:
|
166 |
+
model_repository: Path to Triton model repository
|
167 |
+
|
168 |
+
Raises:
|
169 |
+
OSError: when model repository not exists
|
170 |
+
"""
|
171 |
+
LOGGER.debug(
|
172 |
+
f"Generating model and config for {self.model_name} and {self.model_version} to {model_repository}"
|
173 |
+
)
|
174 |
+
|
175 |
+
model_catalog = model_repository / self.model_name
|
176 |
+
|
177 |
+
config_file_path = model_catalog / "config.pbtxt"
|
178 |
+
if config_file_path.exists():
|
179 |
+
LOGGER.warning(f"The config file {config_file_path} is going to be overridden.")
|
180 |
+
|
181 |
+
triton_model_config = self._get_triton_model_config()
|
182 |
+
generator = ModelConfigGenerator(config=triton_model_config)
|
183 |
+
generator.to_file(config_file_path)
|
184 |
+
|
185 |
+
model_version_catalog = model_catalog / str(self.model_version)
|
186 |
+
model_version_catalog.mkdir(exist_ok=True, parents=True)
|
187 |
+
|
188 |
+
proxy_path = pathlib.Path(__file__).parent.parent / "proxy"
|
189 |
+
|
190 |
+
for script_file in self.SCRIPT_FILES_TO_COPY:
|
191 |
+
src_file_path = proxy_path / script_file
|
192 |
+
dst_file_path = model_version_catalog / script_file
|
193 |
+
shutil.copy(src_file_path, dst_file_path)
|
194 |
+
|
195 |
+
def setup(self) -> None:
|
196 |
+
"""Create deployments and bindings to Triton Inference Server."""
|
197 |
+
with self._inference_handlers_lock:
|
198 |
+
if not self._inference_handlers:
|
199 |
+
triton_model_config = self._get_triton_model_config()
|
200 |
+
workspace_path = pathlib.Path(triton_model_config.backend_parameters["workspace-path"])
|
201 |
+
validator = TritonResultsValidator(triton_model_config, self._strict)
|
202 |
+
|
203 |
+
inference_handler_config_path = workspace_path / f"{self.model_name}-config.sock"
|
204 |
+
inference_handler_config = get_config_from_handshake_server(inference_handler_config_path)
|
205 |
+
|
206 |
+
data_socket = pathlib.Path(inference_handler_config["data_socket"])
|
207 |
+
authkey = base64.decodebytes(inference_handler_config["authkey"].encode("ascii"))
|
208 |
+
self._serializer_deserializer.connect(data_socket.as_posix(), authkey)
|
209 |
+
|
210 |
+
for i, infer_function in enumerate(self.infer_functions):
|
211 |
+
self.triton_context.model_configs[infer_function] = copy.deepcopy(triton_model_config)
|
212 |
+
_inject_triton_context(self.triton_context, infer_function)
|
213 |
+
|
214 |
+
request_server_socket = workspace_path / f"{self.model_name}_0_{i}-server.sock"
|
215 |
+
request_server_socket = f"ipc://{request_server_socket.as_posix()}"
|
216 |
+
|
217 |
+
requests_respones_connector = RequestsResponsesConnector(
|
218 |
+
url=request_server_socket,
|
219 |
+
serializer_deserializer=self._serializer_deserializer,
|
220 |
+
)
|
221 |
+
requests_respones_connector.start()
|
222 |
+
self._requests_respones_connectors.append(requests_respones_connector)
|
223 |
+
inference_handler = InferenceHandler(
|
224 |
+
model_callable=infer_function,
|
225 |
+
requests_responses_connector=requests_respones_connector,
|
226 |
+
validator=validator,
|
227 |
+
name=f"inference_handler-{i}",
|
228 |
+
)
|
229 |
+
inference_handler.on_inference_handler_event(self._on_inference_handler_event)
|
230 |
+
inference_handler.start()
|
231 |
+
self._inference_handlers.append(inference_handler)
|
232 |
+
|
233 |
+
def clean(self) -> None:
|
234 |
+
"""Post unload actions to perform on model."""
|
235 |
+
with self._observers_lock:
|
236 |
+
LOGGER.debug("Clearing model events observers")
|
237 |
+
self._model_events_observers.clear()
|
238 |
+
LOGGER.debug("Socket closed. Waiting for inference handler and communication threads to shut down")
|
239 |
+
with self._inference_handlers_lock:
|
240 |
+
for inference_handler in self._inference_handlers:
|
241 |
+
inference_handler.stop()
|
242 |
+
for inference_handler in self._inference_handlers:
|
243 |
+
inference_handler.join()
|
244 |
+
self._inference_handlers.clear()
|
245 |
+
for requests_responses_connector in self._requests_respones_connectors:
|
246 |
+
requests_responses_connector.close()
|
247 |
+
for requests_responses_connector in self._requests_respones_connectors:
|
248 |
+
requests_responses_connector.join()
|
249 |
+
self._requests_respones_connectors.clear()
|
250 |
+
self._serializer_deserializer.close()
|
251 |
+
|
252 |
+
def is_alive(self) -> bool:
|
253 |
+
"""Validate if model is working on Triton.
|
254 |
+
|
255 |
+
If model is fully loaded by Triton, return True. Otherwise, perform a custom verification.
|
256 |
+
|
257 |
+
Returns:
|
258 |
+
True if model is working, False otherwise
|
259 |
+
"""
|
260 |
+
with self._inference_handlers_lock:
|
261 |
+
return (
|
262 |
+
bool(self._inference_handlers)
|
263 |
+
and bool(self._requests_respones_connectors)
|
264 |
+
and all(inference_handler.is_alive() for inference_handler in self._inference_handlers)
|
265 |
+
and all(
|
266 |
+
requests_responses_connector.is_alive()
|
267 |
+
for requests_responses_connector in self._requests_respones_connectors
|
268 |
+
)
|
269 |
+
)
|
270 |
+
|
271 |
+
def _get_triton_model_config(self) -> TritonModelConfig:
|
272 |
+
"""Generate ModelConfig from descriptor and custom arguments for Python model.
|
273 |
+
|
274 |
+
Returns:
|
275 |
+
ModelConfig object with configuration for Python model deployment
|
276 |
+
"""
|
277 |
+
if not self._triton_model_config:
|
278 |
+
backend_parameters = {"workspace-path": self._workspace.path.as_posix()}
|
279 |
+
if self._trace_config:
|
280 |
+
backend_parameters["trace-config"] = base64.b64encode(json.dumps(self._trace_config).encode()).decode()
|
281 |
+
triton_model_config = TritonModelConfig(
|
282 |
+
model_name=self.model_name,
|
283 |
+
model_version=self.model_version,
|
284 |
+
batching=self.config.batching,
|
285 |
+
batcher=self.config.batcher,
|
286 |
+
max_batch_size=self.config.max_batch_size,
|
287 |
+
decoupled=self.config.decoupled,
|
288 |
+
backend_parameters=backend_parameters,
|
289 |
+
instance_group={DeviceKind.KIND_CPU: len(self.infer_functions)},
|
290 |
+
)
|
291 |
+
inputs = []
|
292 |
+
for idx, input_spec in enumerate(self.inputs, start=1):
|
293 |
+
input_name = input_spec.name if input_spec.name else f"INPUT_{idx}"
|
294 |
+
tensor = TensorSpec(
|
295 |
+
name=input_name, dtype=input_spec.dtype, shape=input_spec.shape, optional=input_spec.optional
|
296 |
+
)
|
297 |
+
inputs.append(tensor)
|
298 |
+
|
299 |
+
outputs = []
|
300 |
+
for idx, output_spec in enumerate(self.outputs, start=1):
|
301 |
+
output_name = output_spec.name if output_spec.name else f"OUTPUT_{idx}"
|
302 |
+
tensor = TensorSpec(name=output_name, dtype=output_spec.dtype, shape=output_spec.shape)
|
303 |
+
outputs.append(tensor)
|
304 |
+
|
305 |
+
triton_model_config.inputs = inputs
|
306 |
+
triton_model_config.outputs = outputs
|
307 |
+
|
308 |
+
if self.config.response_cache:
|
309 |
+
triton_model_config.response_cache = ResponseCache(enable=True)
|
310 |
+
|
311 |
+
self._triton_model_config = triton_model_config
|
312 |
+
|
313 |
+
return self._triton_model_config
|
314 |
+
|
315 |
+
def on_model_event(self, model_event_handle_fn: ModelEventsHandler):
|
316 |
+
"""Register ModelEventsHandler callable.
|
317 |
+
|
318 |
+
Args:
|
319 |
+
model_event_handle_fn: function to be called when model events arises
|
320 |
+
"""
|
321 |
+
with self._observers_lock:
|
322 |
+
self._model_events_observers.append(model_event_handle_fn)
|
323 |
+
|
324 |
+
def _notify_model_events_observers(self, event: ModelEvent, context: typing.Any):
|
325 |
+
with self._observers_lock:
|
326 |
+
for model_event_handle_fn in self._model_events_observers:
|
327 |
+
model_event_handle_fn(self, event, context)
|
328 |
+
|
329 |
+
def _on_inference_handler_event(
|
330 |
+
self, proxy_backend: InferenceHandler, event: InferenceHandlerEvent, context: typing.Optional[typing.Any] = None
|
331 |
+
):
|
332 |
+
if event in [InferenceHandlerEvent.CLOSING, InferenceHandlerEvent.UNRECOVERABLE_ERROR]:
|
333 |
+
self._notify_model_events_observers(ModelEvent.RUNTIME_TERMINATING, context)
|
334 |
+
elif event == InferenceHandlerEvent.CLOSED:
|
335 |
+
self._notify_model_events_observers(ModelEvent.RUNTIME_TERMINATED, context)
|
stf/stf-api-alternative/pytriton/build/lib/pytriton/proxy/__init__.py
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# noqa: D104
|
stf/stf-api-alternative/pytriton/build/lib/pytriton/proxy/communication.py
ADDED
@@ -0,0 +1,555 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
"""Module handling communication between RequestsServer and RequestsServerClients."""
|
15 |
+
|
16 |
+
import asyncio
|
17 |
+
import enum
|
18 |
+
import functools
|
19 |
+
import json
|
20 |
+
import logging
|
21 |
+
import pathlib
|
22 |
+
import socket
|
23 |
+
import threading
|
24 |
+
import time
|
25 |
+
import traceback
|
26 |
+
import typing
|
27 |
+
import uuid
|
28 |
+
from concurrent.futures import Future as ConcurrentFuture
|
29 |
+
|
30 |
+
import zmq # pytype: disable=import-error
|
31 |
+
import zmq.asyncio # pytype: disable=import-error
|
32 |
+
|
33 |
+
LOGGER = logging.getLogger(__name__)
|
34 |
+
SERVER_LOGGER = LOGGER.getChild("server")
|
35 |
+
CLIENT_LOGGER = LOGGER.getChild("client")
|
36 |
+
|
37 |
+
_STARTUP_TIMEOUT_S = 1.0
|
38 |
+
|
39 |
+
|
40 |
+
class PyTritonResponseFlags(enum.IntFlag):
|
41 |
+
"""Response flags for PyTritonInferenceHandler."""
|
42 |
+
|
43 |
+
EOS = enum.auto() # End Of Stream
|
44 |
+
ERROR = enum.auto()
|
45 |
+
|
46 |
+
|
47 |
+
class _RequestsServerState(enum.Enum):
|
48 |
+
STOPPED = enum.auto()
|
49 |
+
STARTING = enum.auto()
|
50 |
+
STARTED = enum.auto()
|
51 |
+
STOPPING = enum.auto()
|
52 |
+
|
53 |
+
|
54 |
+
def _set_current_task_name(name: str):
|
55 |
+
current_task = asyncio.current_task()
|
56 |
+
if current_task is not None:
|
57 |
+
current_task.set_name(name)
|
58 |
+
|
59 |
+
|
60 |
+
_RequestScope = typing.Dict[str, typing.Any]
|
61 |
+
_HandleRequestsCoro = typing.Callable[[_RequestScope, bytes, zmq.asyncio.Socket], typing.Awaitable[typing.Any]]
|
62 |
+
HandleResponsesCoro = typing.Callable[[_RequestScope, asyncio.Queue, ConcurrentFuture], typing.Awaitable[typing.Any]]
|
63 |
+
|
64 |
+
|
65 |
+
class RequestsServer:
|
66 |
+
"""Class for serving available inference requests and passing inference responses."""
|
67 |
+
|
68 |
+
def __init__(self, url: str, handle_responses_fn: HandleResponsesCoro):
|
69 |
+
"""Initialize RequestsServer.
|
70 |
+
|
71 |
+
Args:
|
72 |
+
url: url to bind socket
|
73 |
+
handle_responses_fn: couroutine handling responses from InferenceHandler
|
74 |
+
"""
|
75 |
+
self._url = url
|
76 |
+
self._handle_responses_fn = handle_responses_fn
|
77 |
+
self._state = _RequestsServerState.STOPPED
|
78 |
+
self._state_condition = threading.Condition()
|
79 |
+
self._shutdown_event = asyncio.Event() # TODO: is it still required having condition?
|
80 |
+
self._server_loop = None
|
81 |
+
|
82 |
+
# requests_id -> results asyncio.Queue map
|
83 |
+
self._responses_queues: typing.Dict[bytes, asyncio.Queue] = {}
|
84 |
+
self._handle_responses_tasks: typing.Dict[bytes, asyncio.Task] = {}
|
85 |
+
|
86 |
+
def run(self):
|
87 |
+
"""Run RequestsServer.
|
88 |
+
|
89 |
+
It stops when handle_messages coroutine finishes.
|
90 |
+
|
91 |
+
Raises:
|
92 |
+
RuntimeError: if RequestsServer is already running
|
93 |
+
"""
|
94 |
+
with self._state_condition:
|
95 |
+
if self._state != _RequestsServerState.STOPPED:
|
96 |
+
raise RuntimeError(f"Cannot run {type(self).__name__} as it is already running")
|
97 |
+
|
98 |
+
self._state = _RequestsServerState.STARTING
|
99 |
+
self._state_condition.notify_all()
|
100 |
+
|
101 |
+
assert len(self._responses_queues) == 0
|
102 |
+
assert len(self._handle_responses_tasks) == 0
|
103 |
+
|
104 |
+
asyncio.run(self.handle_messages())
|
105 |
+
|
106 |
+
@property
|
107 |
+
def server_loop(self) -> typing.Optional[asyncio.AbstractEventLoop]:
|
108 |
+
"""Get asyncio loop for RequestsServer.
|
109 |
+
|
110 |
+
Returns:
|
111 |
+
asyncio.AbstractEventLoop: asyncio loop for RequestsServer or None if server is not started yet
|
112 |
+
"""
|
113 |
+
return self._server_loop
|
114 |
+
|
115 |
+
def wait_till_running(self):
|
116 |
+
"""Wait till RequestsServer is running.
|
117 |
+
|
118 |
+
Raises:
|
119 |
+
RuntimeError: if RequestsServer is shutting down or not launched yet
|
120 |
+
"""
|
121 |
+
with self._state_condition:
|
122 |
+
if self._state == _RequestsServerState.STARTING:
|
123 |
+
self._state_condition.wait_for(
|
124 |
+
lambda: self._state == _RequestsServerState.STARTED, timeout=_STARTUP_TIMEOUT_S
|
125 |
+
)
|
126 |
+
elif self._state == _RequestsServerState.STOPPED:
|
127 |
+
raise RuntimeError("Cannot push requests before RequestsServer is started")
|
128 |
+
elif self._state == _RequestsServerState.STOPPING:
|
129 |
+
raise RuntimeError(f"Cannot push requests while {type(self).__name__} is shutting down")
|
130 |
+
|
131 |
+
async def handle_messages(self):
|
132 |
+
"""Coroutine for handling messages from InferenceHandler."""
|
133 |
+
self._server_loop = asyncio.get_running_loop()
|
134 |
+
try:
|
135 |
+
SERVER_LOGGER.debug(f"Binding socket to url='{self._url}'")
|
136 |
+
self._zmq_context = zmq.asyncio.Context()
|
137 |
+
self._socket = self._zmq_context.socket(zmq.DEALER)
|
138 |
+
self._socket.bind(self._url)
|
139 |
+
except (TypeError, zmq.error.ZMQError) as e:
|
140 |
+
raise ValueError(
|
141 |
+
f"Error occurred during binding socket to url='{self._url}' (e: {e})." "RequestsServer will be closed."
|
142 |
+
) from e
|
143 |
+
|
144 |
+
_set_current_task_name("handle_messages")
|
145 |
+
|
146 |
+
with self._state_condition:
|
147 |
+
if self._state != _RequestsServerState.STARTING:
|
148 |
+
self._state = _RequestsServerState.STOPPED
|
149 |
+
self._state_condition.notify_all()
|
150 |
+
raise RuntimeError(f"Cannot start {type(self).__name__} as it is not in STARTING state")
|
151 |
+
|
152 |
+
self._state = _RequestsServerState.STARTED
|
153 |
+
self._state_condition.notify_all()
|
154 |
+
|
155 |
+
def _all_responses_processed():
|
156 |
+
return not any([self._handle_responses_tasks, self._responses_queues])
|
157 |
+
|
158 |
+
try:
|
159 |
+
flag_check_interval_s = 1.0
|
160 |
+
# have to receive mssages untill all requestss to be processed, despite shutdown event is set
|
161 |
+
while not self._shutdown_event.is_set() or not _all_responses_processed():
|
162 |
+
requests_id = b"<unknown>"
|
163 |
+
try:
|
164 |
+
requests_id, flags, responses_payload = await asyncio.wait_for(
|
165 |
+
self._socket.recv_multipart(), flag_check_interval_s
|
166 |
+
)
|
167 |
+
flags = int.from_bytes(flags, byteorder="big")
|
168 |
+
responses_queue = self._responses_queues[requests_id]
|
169 |
+
responses_queue.put_nowait((flags, responses_payload)) # queue have no max_size
|
170 |
+
except asyncio.TimeoutError:
|
171 |
+
continue
|
172 |
+
except KeyError:
|
173 |
+
SERVER_LOGGER.warning(f"Received response for unknown requests {requests_id.hex()}. Ignoring it.")
|
174 |
+
except asyncio.CancelledError:
|
175 |
+
SERVER_LOGGER.info("Received CancelledError")
|
176 |
+
self._shutdown_event.set()
|
177 |
+
finally:
|
178 |
+
# Received all responses, close socket
|
179 |
+
SERVER_LOGGER.debug("Closing socket")
|
180 |
+
try:
|
181 |
+
if self._socket is not None:
|
182 |
+
self._socket.close(linger=0)
|
183 |
+
self._socket = None
|
184 |
+
except zmq.error.ZMQError as e:
|
185 |
+
SERVER_LOGGER.error(f"Error occurred during closing socket (e: {e}).")
|
186 |
+
|
187 |
+
try:
|
188 |
+
if self._zmq_context is not None:
|
189 |
+
self._zmq_context.term()
|
190 |
+
self._zmq_context = None
|
191 |
+
except zmq.error.ZMQError as e:
|
192 |
+
SERVER_LOGGER.error(f"Error occurred during closing zmq context (e: {e}).")
|
193 |
+
|
194 |
+
self._server_loop = None
|
195 |
+
|
196 |
+
with self._state_condition:
|
197 |
+
self._state = _RequestsServerState.STOPPED
|
198 |
+
self._state_condition.notify_all()
|
199 |
+
|
200 |
+
SERVER_LOGGER.debug("Socket for handle_messages task closed")
|
201 |
+
self._shutdown_event.clear()
|
202 |
+
SERVER_LOGGER.debug(f"Leaving handle_messages task from {type(self).__name__}")
|
203 |
+
|
204 |
+
def shutdown(self):
|
205 |
+
"""Close RequestsServer.
|
206 |
+
|
207 |
+
Don't wait for handle_messages coroutine to finish.
|
208 |
+
"""
|
209 |
+
SERVER_LOGGER.debug("Closing RequestsServer")
|
210 |
+
with self._state_condition:
|
211 |
+
self._state = _RequestsServerState.STOPPING
|
212 |
+
self._state_condition.notify_all()
|
213 |
+
self._shutdown_event.set()
|
214 |
+
|
215 |
+
async def send_requests(
|
216 |
+
self, requests_id: bytes, requests_payload: bytes, responses_future: ConcurrentFuture
|
217 |
+
) -> asyncio.Task:
|
218 |
+
"""Send requests to InferenceHandler.
|
219 |
+
|
220 |
+
Args:
|
221 |
+
requests_id: id of requests
|
222 |
+
requests_payload: payload of requests
|
223 |
+
responses_future: future for waiting in another thread
|
224 |
+
|
225 |
+
Returns:
|
226 |
+
asyncio.Task: task handling responses from InferenceHandler
|
227 |
+
|
228 |
+
Raises:
|
229 |
+
RuntimeError: if RequestsServer is shutting down or requests_id is already pending
|
230 |
+
"""
|
231 |
+
if self._shutdown_event.is_set():
|
232 |
+
SERVER_LOGGER.debug(f"Cannot send requests while {type(self).__name__} is {self._state.name}")
|
233 |
+
raise RuntimeError(f"Cannot send requests while {type(self).__name__} is {self._state.name}")
|
234 |
+
|
235 |
+
if requests_id in self._responses_queues or requests_id in self._handle_responses_tasks:
|
236 |
+
SERVER_LOGGER.debug(f"Cannot send requests with id {requests_id.hex()} as such id is already pending")
|
237 |
+
raise RuntimeError(f"Cannot send requests with id {requests_id.hex()} as such id is already pending")
|
238 |
+
|
239 |
+
_set_current_task_name(f"send_requests-{requests_id.hex()}")
|
240 |
+
|
241 |
+
self._responses_queues[requests_id] = asyncio.Queue()
|
242 |
+
scope = {"requests_id": requests_id}
|
243 |
+
handle_responses_task = self._server_loop.create_task(
|
244 |
+
self._handle_responses(scope, self._responses_queues[requests_id], responses_future),
|
245 |
+
name=f"handle_responses-{requests_id.hex()}",
|
246 |
+
)
|
247 |
+
self._handle_responses_tasks[requests_id] = handle_responses_task
|
248 |
+
|
249 |
+
# FIXME: check if can not copy buffers; in case copy=False send_multipart returns MessageTracker
|
250 |
+
# https://pyzmq.readthedocs.io/en/latest/api/zmq.html#zmq.Socket.send_multipart
|
251 |
+
# consider send_pyobject|send_serialized (but it is not multipart)
|
252 |
+
|
253 |
+
# sending in same loop, thus thread as handle_messages
|
254 |
+
# send_multipart doesn't return anything, as it copies requests_payload
|
255 |
+
await self._socket.send_multipart([requests_id, requests_payload])
|
256 |
+
|
257 |
+
return handle_responses_task
|
258 |
+
|
259 |
+
async def _handle_responses(self, scope, responses_queue: asyncio.Queue, responses_future: ConcurrentFuture):
|
260 |
+
"""Handle responses from InferenceHandler.
|
261 |
+
|
262 |
+
Args:
|
263 |
+
scope: scope for handling responses
|
264 |
+
responses_queue: queue with responses payload from InferenceHandler
|
265 |
+
responses_future: future for waiting in another thread
|
266 |
+
"""
|
267 |
+
requests_id = scope["requests_id"]
|
268 |
+
try:
|
269 |
+
return await self._handle_responses_fn(scope, responses_queue, responses_future)
|
270 |
+
finally:
|
271 |
+
self._responses_queues.pop(requests_id)
|
272 |
+
self._handle_responses_tasks.pop(requests_id)
|
273 |
+
|
274 |
+
|
275 |
+
class RequestsServerClient:
|
276 |
+
"""RequestsServer client for handling requests from RequestsServer and sending back responses."""
|
277 |
+
|
278 |
+
def __init__(self, url: str, handle_requests_fn: _HandleRequestsCoro, name: typing.Optional[str] = None):
|
279 |
+
"""Initialize RequestsServerClient.
|
280 |
+
|
281 |
+
Args:
|
282 |
+
url: url to connect socket
|
283 |
+
handle_requests_fn: couroutine handling requests from InferenceHandler
|
284 |
+
name: name of RequestsServerClient
|
285 |
+
"""
|
286 |
+
self._shutdown_event = asyncio.Event()
|
287 |
+
self._url = url
|
288 |
+
self._handle_requests_fn = handle_requests_fn
|
289 |
+
self._handle_requests_tasks: typing.Dict[bytes, asyncio.Task] = {}
|
290 |
+
self._handle_requests_tasks_condition = asyncio.Condition()
|
291 |
+
self._name = name or f"requests_server_client-{uuid.uuid4().hex[-4:]}"
|
292 |
+
self._loop = None
|
293 |
+
|
294 |
+
def run(self):
|
295 |
+
"""Run RequestsServerClient.
|
296 |
+
|
297 |
+
It stops when handle_requests coroutine finishes.
|
298 |
+
"""
|
299 |
+
asyncio.run(self.handle_requests())
|
300 |
+
|
301 |
+
def shutdown(self) -> None:
|
302 |
+
"""Close RequestsServerClient.
|
303 |
+
|
304 |
+
Don't wait for handle_requests coroutine to finish.
|
305 |
+
"""
|
306 |
+
CLIENT_LOGGER.debug(f"Closing {type(self).__name__} {self._name}")
|
307 |
+
self._shutdown_event.set()
|
308 |
+
|
309 |
+
async def handle_requests(self):
|
310 |
+
"""Coroutine for handling requests from RequestsServer."""
|
311 |
+
name = self._name
|
312 |
+
_set_current_task_name(name)
|
313 |
+
|
314 |
+
zmq_context = None
|
315 |
+
socket = None
|
316 |
+
self._loop = asyncio.get_running_loop()
|
317 |
+
try:
|
318 |
+
CLIENT_LOGGER.debug(f"Connecting {name} to server listening on {self._url}")
|
319 |
+
zmq_context = zmq.asyncio.Context()
|
320 |
+
socket = zmq_context.socket(zmq.DEALER)
|
321 |
+
socket.connect(self._url)
|
322 |
+
|
323 |
+
send = functools.partial(self._send, socket)
|
324 |
+
|
325 |
+
flag_check_interval_s = 1.0
|
326 |
+
while True:
|
327 |
+
try:
|
328 |
+
requests_id, requests_payloads = await asyncio.wait_for(
|
329 |
+
socket.recv_multipart(), flag_check_interval_s
|
330 |
+
)
|
331 |
+
scope = {"requests_id": requests_id}
|
332 |
+
CLIENT_LOGGER.debug(f"{requests_id.hex()} received requests")
|
333 |
+
handle_requests_task = self._loop.create_task(self._handle_requests(scope, requests_payloads, send))
|
334 |
+
self._handle_requests_tasks[requests_id] = handle_requests_task
|
335 |
+
handle_requests_task.set_name(f"handle_requests-{requests_id.hex()}")
|
336 |
+
except asyncio.TimeoutError:
|
337 |
+
if self._shutdown_event.is_set():
|
338 |
+
break
|
339 |
+
continue
|
340 |
+
|
341 |
+
CLIENT_LOGGER.debug("Waiting for handle_requests tasks to finish")
|
342 |
+
async with self._handle_requests_tasks_condition:
|
343 |
+
await self._handle_requests_tasks_condition.wait_for(lambda: len(self._handle_requests_tasks) == 0)
|
344 |
+
CLIENT_LOGGER.debug("All handle_requests tasks finished")
|
345 |
+
|
346 |
+
except zmq.error.ZMQError:
|
347 |
+
CLIENT_LOGGER.exception(
|
348 |
+
"Connection error occurred during reading requests. " f"{type(self).__name__} will be closed."
|
349 |
+
)
|
350 |
+
self._shutdown_event.set()
|
351 |
+
except Exception:
|
352 |
+
CLIENT_LOGGER.exception(f"Internal {type(self).__name__}. " f"{type(self).__name__} will be closed.")
|
353 |
+
self._shutdown_event.set()
|
354 |
+
finally:
|
355 |
+
try:
|
356 |
+
socket_close_timeout_ms = 0 # immediate close (drop not sent messages)
|
357 |
+
if socket is not None:
|
358 |
+
socket.close(linger=socket_close_timeout_ms)
|
359 |
+
except zmq.error.ZMQError as e:
|
360 |
+
CLIENT_LOGGER.error(f"Error occurred during closing socket (e: {e}).")
|
361 |
+
|
362 |
+
try:
|
363 |
+
if zmq_context is not None:
|
364 |
+
zmq_context.term()
|
365 |
+
except zmq.error.ZMQError as e:
|
366 |
+
CLIENT_LOGGER.error(f"Error occurred during closing zmq context (e: {e}).")
|
367 |
+
|
368 |
+
CLIENT_LOGGER.debug(f"Socket for {name} closed")
|
369 |
+
self._shutdown_event.clear()
|
370 |
+
self._loop = None
|
371 |
+
CLIENT_LOGGER.debug(f"Leaving {name}")
|
372 |
+
|
373 |
+
@property
|
374 |
+
def name(self) -> str:
|
375 |
+
"""Get name of RequestsServerClient.
|
376 |
+
|
377 |
+
Returns:
|
378 |
+
name of RequestsServerClient
|
379 |
+
"""
|
380 |
+
return self._name
|
381 |
+
|
382 |
+
@property
|
383 |
+
def loop(self) -> asyncio.AbstractEventLoop:
|
384 |
+
"""Get asyncio loop for RequestsServerClient.
|
385 |
+
|
386 |
+
Returns:
|
387 |
+
asyncio.AbstractEventLoop: asyncio loop for RequestsServerClient
|
388 |
+
"""
|
389 |
+
return self._loop
|
390 |
+
|
391 |
+
async def _handle_requests(self, scope, requests_payload, send):
|
392 |
+
try:
|
393 |
+
await self._handle_requests_fn(scope, requests_payload, send)
|
394 |
+
# except PyTritonUnrecoverableError:
|
395 |
+
# error = traceback.format_exc()
|
396 |
+
# responses = InferenceHandlerResponses(error=error)
|
397 |
+
# CLIENT_LOGGER.error(
|
398 |
+
# "Unrecoverable error thrown during calling model callable. "
|
399 |
+
# "Shutting down Triton Inference Server. "
|
400 |
+
# f"{error}"
|
401 |
+
# )
|
402 |
+
# self.stopped = True
|
403 |
+
# self._notify_proxy_backend_observers(InferenceHandlerEvent.UNRECOVERABLE_ERROR, error)
|
404 |
+
# CLIENT_LOGGER.debug(f"Send response to proxy model for {model_name}.")
|
405 |
+
# send(responses.as_bytes())
|
406 |
+
except Exception:
|
407 |
+
error = traceback.format_exc()
|
408 |
+
flags = PyTritonResponseFlags.ERROR | PyTritonResponseFlags.EOS
|
409 |
+
await send(scope, flags, error.encode())
|
410 |
+
CLIENT_LOGGER.error(f"Error occurred during handling requests {scope['requests_id'].hex()}\n{error}")
|
411 |
+
finally:
|
412 |
+
async with self._handle_requests_tasks_condition:
|
413 |
+
self._handle_requests_tasks.pop(scope["requests_id"], None)
|
414 |
+
self._handle_requests_tasks_condition.notify()
|
415 |
+
CLIENT_LOGGER.debug(f"Finished handling requests {scope['requests_id'].hex()}")
|
416 |
+
|
417 |
+
async def _send(self, socket, scope, flags, requests_payload):
|
418 |
+
"""Send requests to RequestsServer.
|
419 |
+
|
420 |
+
Args:
|
421 |
+
socket: socket for sending requests
|
422 |
+
scope: scope for sending requests
|
423 |
+
flags: flags for sending requests
|
424 |
+
requests_payload: payload of requests
|
425 |
+
"""
|
426 |
+
flags = flags.to_bytes(1, "big")
|
427 |
+
await socket.send_multipart([scope["requests_id"], flags, requests_payload])
|
428 |
+
|
429 |
+
|
430 |
+
class HandshakeServer(threading.Thread):
|
431 |
+
"""Handshake server for passing config."""
|
432 |
+
|
433 |
+
def __init__(self, socket_path: pathlib.Path, inference_handler_config) -> None:
|
434 |
+
"""Initialize HandshakeServer.
|
435 |
+
|
436 |
+
Args:
|
437 |
+
socket_path: path to socket
|
438 |
+
inference_handler_config: config for InferenceHandler
|
439 |
+
"""
|
440 |
+
super().__init__(daemon=True, name="handshake-server")
|
441 |
+
self._socket_path = socket_path
|
442 |
+
try:
|
443 |
+
self._config_payload = json.dumps(inference_handler_config).encode()
|
444 |
+
except TypeError:
|
445 |
+
raise ValueError(f"InferenceHandler config is not serializable: {inference_handler_config}") from None
|
446 |
+
|
447 |
+
self._server = None
|
448 |
+
self._error_from_thread = None
|
449 |
+
|
450 |
+
def start(self):
|
451 |
+
"""Start HandshakeServer.
|
452 |
+
|
453 |
+
Raises:
|
454 |
+
RuntimeError: if HandshakeServer is already running or error occurred during starting
|
455 |
+
"""
|
456 |
+
if self._server:
|
457 |
+
raise RuntimeError("HandshakeServer is already running")
|
458 |
+
|
459 |
+
super().start()
|
460 |
+
while self._server is None and not self._error_from_thread:
|
461 |
+
time.sleep(0.001)
|
462 |
+
if self._error_from_thread is not None:
|
463 |
+
raise self._error_from_thread
|
464 |
+
|
465 |
+
def run(self):
|
466 |
+
"""Run HandshakeServer."""
|
467 |
+
asyncio.run(self._run())
|
468 |
+
|
469 |
+
async def _run(self):
|
470 |
+
try:
|
471 |
+
self._server = await asyncio.start_unix_server(self._handle_request, self._socket_path)
|
472 |
+
async with self._server:
|
473 |
+
try:
|
474 |
+
await self._server.serve_forever()
|
475 |
+
except asyncio.CancelledError:
|
476 |
+
pass
|
477 |
+
except Exception as e:
|
478 |
+
SERVER_LOGGER.error(f"Error occurred during running handshake server (e: {e})")
|
479 |
+
self._error_from_thread = e
|
480 |
+
|
481 |
+
def close(self):
|
482 |
+
"""Close HandshakeServer."""
|
483 |
+
loop = self._server.get_loop()
|
484 |
+
loop_tasks = asyncio.all_tasks(loop=loop)
|
485 |
+
for task in loop_tasks:
|
486 |
+
loop.call_soon_threadsafe(task.cancel)
|
487 |
+
|
488 |
+
self.join()
|
489 |
+
SERVER_LOGGER.debug("Closed handshake server")
|
490 |
+
|
491 |
+
async def _handle_request(self, reader, writer):
|
492 |
+
peername = writer.get_extra_info("peername")
|
493 |
+
try:
|
494 |
+
request_name = await asyncio.wait_for(reader.readuntil(b"\n"), timeout=1.0)
|
495 |
+
|
496 |
+
if request_name == b"get_config\n":
|
497 |
+
writer.write(len(self._config_payload).to_bytes(4, "big"))
|
498 |
+
writer.write(self._config_payload)
|
499 |
+
await writer.drain()
|
500 |
+
else:
|
501 |
+
SERVER_LOGGER.warning(f"Unknown request {request_name} from {peername}")
|
502 |
+
|
503 |
+
except asyncio.TimeoutError:
|
504 |
+
SERVER_LOGGER.debug(f"Timeout occurred during handling request from {peername}")
|
505 |
+
except Exception as e:
|
506 |
+
SERVER_LOGGER.error(f"Error occurred during handling request from {peername} (e: {e})")
|
507 |
+
finally:
|
508 |
+
writer.close()
|
509 |
+
await writer.wait_closed()
|
510 |
+
|
511 |
+
|
512 |
+
def get_config_from_handshake_server(socket_path: pathlib.Path, timeout_s: float = 1.0) -> dict:
|
513 |
+
"""Get config from handshake server.
|
514 |
+
|
515 |
+
Args:
|
516 |
+
socket_path: path to socket
|
517 |
+
timeout_s: timeout for waiting for the response
|
518 |
+
|
519 |
+
Returns:
|
520 |
+
config from handshake server
|
521 |
+
|
522 |
+
Raises:
|
523 |
+
TimeoutError: if timeout occurred while waiting for the response
|
524 |
+
ValueError: if invalid JSON response from the server
|
525 |
+
"""
|
526 |
+
should_stop_before_s = time.time() + timeout_s
|
527 |
+
sock = None
|
528 |
+
try:
|
529 |
+
LOGGER.debug(f"Waiting for config file {socket_path}")
|
530 |
+
while not socket_path.exists() and time.time() < should_stop_before_s:
|
531 |
+
time.sleep(0.001)
|
532 |
+
|
533 |
+
if not socket_path.exists():
|
534 |
+
raise TimeoutError(f"Timeout occurred while waiting for config file {socket_path}")
|
535 |
+
|
536 |
+
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
|
537 |
+
sock.settimeout(max(0.0, should_stop_before_s - time.time()))
|
538 |
+
sock.connect(socket_path.as_posix())
|
539 |
+
sock.sendall(b"get_config\n")
|
540 |
+
|
541 |
+
sock.settimeout(max(0.0, should_stop_before_s - time.time()))
|
542 |
+
payload_size = sock.recv(4)
|
543 |
+
payload_size = int.from_bytes(payload_size, "big")
|
544 |
+
|
545 |
+
sock.settimeout(max(0.0, should_stop_before_s - time.time()))
|
546 |
+
config_payload = sock.recv(payload_size)
|
547 |
+
config = json.loads(config_payload)
|
548 |
+
return config
|
549 |
+
except socket.timeout as e:
|
550 |
+
raise TimeoutError(f"Timeout occurred while waiting for config file {socket_path}") from e
|
551 |
+
except json.JSONDecodeError as e:
|
552 |
+
raise ValueError("Invalid JSON response from the server.") from e
|
553 |
+
finally:
|
554 |
+
if sock is not None:
|
555 |
+
sock.close()
|
stf/stf-api-alternative/pytriton/build/lib/pytriton/proxy/data.py
ADDED
@@ -0,0 +1,1133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
"""Communication utility module.
|
15 |
+
|
16 |
+
It is used for interaction between model and proxy_backend.
|
17 |
+
"""
|
18 |
+
|
19 |
+
import abc
|
20 |
+
import atexit
|
21 |
+
import base64
|
22 |
+
import ctypes
|
23 |
+
import ctypes.util
|
24 |
+
import dataclasses
|
25 |
+
import fcntl
|
26 |
+
import gc
|
27 |
+
import json
|
28 |
+
import logging
|
29 |
+
import math
|
30 |
+
import multiprocessing.managers
|
31 |
+
import multiprocessing.popen_spawn_posix
|
32 |
+
import multiprocessing.shared_memory
|
33 |
+
import os
|
34 |
+
import pathlib
|
35 |
+
import signal
|
36 |
+
import struct
|
37 |
+
import threading
|
38 |
+
import time
|
39 |
+
import uuid
|
40 |
+
import weakref
|
41 |
+
from typing import Dict, List, Literal, Optional, Sequence, Tuple, Union
|
42 |
+
|
43 |
+
import numpy as np
|
44 |
+
|
45 |
+
from .telemetry import get_span_dict, start_span_from_remote
|
46 |
+
from .types import Request, Requests, Response, Responses
|
47 |
+
|
48 |
+
LOGGER = logging.getLogger(__name__)
|
49 |
+
|
50 |
+
PROTOCOL_VERSION = "3"
|
51 |
+
|
52 |
+
|
53 |
+
# copy from
|
54 |
+
# https://github.com/triton-inference-server/python_backend/blob/main/src/resources/triton_python_backend_utils.py
|
55 |
+
|
56 |
+
|
57 |
+
def _serialize_byte_tensor(tensor) -> bytes:
|
58 |
+
"""Serializes a bytes tensor into a flat numpy array of length prepended bytes.
|
59 |
+
|
60 |
+
The numpy array should use dtype of np.object_. For np.bytes_,
|
61 |
+
numpy will remove trailing zeros at the end of byte sequence and because
|
62 |
+
of this it should be avoided.
|
63 |
+
|
64 |
+
Args:
|
65 |
+
tensor: The bytes tensor to serialize.
|
66 |
+
|
67 |
+
Returns:
|
68 |
+
serialized array as bytes buffer.
|
69 |
+
|
70 |
+
Raises:
|
71 |
+
UnicodeEncodeErrors: raised when try to cast to string of non-bytes items fails
|
72 |
+
"""
|
73 |
+
if tensor.size == 0:
|
74 |
+
return b""
|
75 |
+
|
76 |
+
# If the input is a tensor of string/bytes objects, then must flatten those
|
77 |
+
# into a 1-dimensional array containing the 4-byte byte size followed by the
|
78 |
+
# actual element bytes. All elements are concatenated together in "C" order.
|
79 |
+
assert (tensor.dtype == np.object_) or (tensor.dtype.type == np.bytes_)
|
80 |
+
flattened_ls = []
|
81 |
+
total_len = 0
|
82 |
+
for obj in np.nditer(tensor, flags=["refs_ok"], order="C"):
|
83 |
+
# If directly passing bytes to BYTES type,
|
84 |
+
# don't convert it to str as Python will encode the
|
85 |
+
# bytes which may distort the meaning
|
86 |
+
if tensor.dtype == np.object_ and not isinstance(obj.item(), bytes):
|
87 |
+
s = str(obj.item()).encode("utf-8")
|
88 |
+
else:
|
89 |
+
s = obj.item()
|
90 |
+
item_len = len(s)
|
91 |
+
flattened_ls.append(struct.pack("<I", item_len))
|
92 |
+
flattened_ls.append(s)
|
93 |
+
total_len += struct.calcsize("<I") + item_len
|
94 |
+
flattened_ls.insert(0, struct.pack("<I", total_len))
|
95 |
+
flattened = b"".join(flattened_ls)
|
96 |
+
return flattened
|
97 |
+
|
98 |
+
|
99 |
+
# copy from
|
100 |
+
# https://github.com/triton-inference-server/python_backend/blob/main/src/resources/triton_python_backend_utils.py
|
101 |
+
def _deserialize_bytes_tensor(encoded_tensor, dtype, order: Literal["C", "F"] = "C") -> np.ndarray:
|
102 |
+
"""Deserializes an encoded bytes tensor into an numpy array of dtype of python objects.
|
103 |
+
|
104 |
+
Args:
|
105 |
+
encoded_tensor : The encoded bytes tensor where each element has its length in
|
106 |
+
first 4 bytes followed by the content
|
107 |
+
dtype: The dtype of the numpy array to deserialize to.
|
108 |
+
order: The order of the numpy array to deserialize to.
|
109 |
+
|
110 |
+
Returns:
|
111 |
+
The 1-D numpy array of type object containing the deserialized bytes in 'C' order.
|
112 |
+
"""
|
113 |
+
strs = []
|
114 |
+
offset = 0
|
115 |
+
val_buf = encoded_tensor
|
116 |
+
val_len = struct.unpack_from("<I", val_buf, offset)[0] + 4
|
117 |
+
offset += 4
|
118 |
+
while offset < val_len:
|
119 |
+
item_length = struct.unpack_from("<I", val_buf, offset)[0]
|
120 |
+
offset += 4
|
121 |
+
item = struct.unpack_from(f"<{item_length}s", val_buf, offset)[0]
|
122 |
+
offset += item_length
|
123 |
+
strs.append(item)
|
124 |
+
return np.array(strs, dtype=dtype, order=order)
|
125 |
+
|
126 |
+
|
127 |
+
_MAX_DTYPE_DESCR = 16 # up to 16 chars in dtype descr; |S2147483647 (2^31-1) with margin
|
128 |
+
_PARTIAL_HEADER_FORMAT = f"<{_MAX_DTYPE_DESCR}scH"
|
129 |
+
|
130 |
+
|
131 |
+
def _pack_header(shape: Tuple[int, ...], dtype: np.dtype, order: Literal["C", "F"] = "C") -> bytes:
|
132 |
+
header_format = _PARTIAL_HEADER_FORMAT + "Q" * len(shape)
|
133 |
+
dtype_descr = np.lib.format.dtype_to_descr(dtype)
|
134 |
+
assert (
|
135 |
+
len(dtype_descr) <= _MAX_DTYPE_DESCR
|
136 |
+
), f"dtype descr is too long; dtype_descr={dtype_descr} max={_MAX_DTYPE_DESCR}"
|
137 |
+
return struct.pack(header_format, dtype_descr.encode("utf-8"), order.encode("ascii"), len(shape), *shape)
|
138 |
+
|
139 |
+
|
140 |
+
def _unpack_header(header: bytes) -> Tuple[Tuple[int, ...], np.dtype, Literal["C", "F"]]:
|
141 |
+
shape_offset = struct.calcsize(_PARTIAL_HEADER_FORMAT)
|
142 |
+
dtype_descr, order, ndim = struct.unpack_from(_PARTIAL_HEADER_FORMAT, header, offset=0)
|
143 |
+
shape = struct.unpack_from("Q" * ndim, header, offset=shape_offset)
|
144 |
+
dtype = np.lib.format.descr_to_dtype(dtype_descr.decode("utf-8").rstrip("\x00"))
|
145 |
+
order = order.decode("ascii")
|
146 |
+
return shape, dtype, order
|
147 |
+
|
148 |
+
|
149 |
+
def serialize_numpy_with_struct_header(tensor: np.ndarray) -> List[Union[bytes, memoryview]]:
|
150 |
+
"""Serialize numpy array to list of bytes and memoryviews.
|
151 |
+
|
152 |
+
Args:
|
153 |
+
tensor: numpy array to serialize
|
154 |
+
|
155 |
+
Returns:
|
156 |
+
List of data frames in form of bytes and memoryviews
|
157 |
+
"""
|
158 |
+
if tensor.dtype.hasobject:
|
159 |
+
data = _serialize_byte_tensor(tensor.ravel())
|
160 |
+
order = "C" # as _serialize_byte_tensor returns C-ordered array
|
161 |
+
else:
|
162 |
+
if not tensor.data.contiguous:
|
163 |
+
tensor = np.ascontiguousarray(tensor)
|
164 |
+
data = tensor.data
|
165 |
+
order = "C" if tensor.flags.c_contiguous else "F"
|
166 |
+
|
167 |
+
header = _pack_header(tensor.shape, tensor.dtype, order)
|
168 |
+
frames = [header, data]
|
169 |
+
return frames
|
170 |
+
|
171 |
+
|
172 |
+
def deserialize_numpy_with_struct_header(frames: List[Union[bytes, memoryview]]) -> np.ndarray:
|
173 |
+
"""Deserialize numpy array from list of bytes and memoryviews.
|
174 |
+
|
175 |
+
Args:
|
176 |
+
frames: List of data frames in form of bytes and memoryviews
|
177 |
+
|
178 |
+
Returns:
|
179 |
+
numpy array
|
180 |
+
"""
|
181 |
+
header, data = frames
|
182 |
+
shape, dtype, order = _unpack_header(header)
|
183 |
+
if dtype.hasobject:
|
184 |
+
tensor = _deserialize_bytes_tensor(data, dtype).reshape(shape)
|
185 |
+
else:
|
186 |
+
tensor = np.ndarray(shape, dtype=dtype, buffer=data, order=order)
|
187 |
+
return tensor
|
188 |
+
|
189 |
+
|
190 |
+
def calc_serialized_size_of_numpy_with_struct_header(tensor: np.ndarray) -> List[int]:
|
191 |
+
"""Calculate size of serialized numpy array.
|
192 |
+
|
193 |
+
Args:
|
194 |
+
tensor: numpy array to serialize
|
195 |
+
|
196 |
+
Returns:
|
197 |
+
List of sizes of data frames
|
198 |
+
"""
|
199 |
+
header_size = struct.calcsize(_PARTIAL_HEADER_FORMAT) + struct.calcsize("Q") * len(tensor.shape)
|
200 |
+
if tensor.dtype.hasobject:
|
201 |
+
items_sizes = []
|
202 |
+
order = "C" if tensor.flags.c_contiguous else "F"
|
203 |
+
for obj in np.nditer(tensor, flags=["refs_ok"], order=order):
|
204 |
+
if tensor.dtype == np.object_ and not isinstance(obj.item(), bytes):
|
205 |
+
s = str(obj.item()).encode("utf-8")
|
206 |
+
else:
|
207 |
+
s = obj.item()
|
208 |
+
items_sizes.append(len(s))
|
209 |
+
|
210 |
+
# total_size + for size of each item + each item
|
211 |
+
data_size = struct.calcsize("<I") + struct.calcsize("<I") * len(items_sizes) + sum(items_sizes)
|
212 |
+
else:
|
213 |
+
data_size = tensor.nbytes
|
214 |
+
|
215 |
+
return [header_size, data_size]
|
216 |
+
|
217 |
+
|
218 |
+
@dataclasses.dataclass
|
219 |
+
class BlockDescriptor:
|
220 |
+
"""Descriptor of block in shared memory."""
|
221 |
+
|
222 |
+
shm_name: str
|
223 |
+
offset: int
|
224 |
+
size: Optional[int] = None
|
225 |
+
|
226 |
+
def __post_init__(self):
|
227 |
+
"""Initialize other attributes."""
|
228 |
+
self.id = f"{self.shm_name}:{self.offset}"
|
229 |
+
|
230 |
+
@classmethod
|
231 |
+
def from_id(cls, tensor_id: str):
|
232 |
+
"""Create BlockDescriptor from dict."""
|
233 |
+
shm_name, offset = tensor_id.split(":")
|
234 |
+
return cls(shm_name, int(offset))
|
235 |
+
|
236 |
+
|
237 |
+
class _SharedMemorySegment:
|
238 |
+
def __init__(self, size):
|
239 |
+
self.shared_memory = multiprocessing.shared_memory.SharedMemory(create=True, size=size)
|
240 |
+
multiprocessing.util.debug(f"Created {self.shared_memory.name} of size {self.shared_memory.size}")
|
241 |
+
self.used_blocks: List[BlockDescriptor] = []
|
242 |
+
self.used_blocks_lock = threading.RLock()
|
243 |
+
self.free_blocks = [BlockDescriptor(self.shared_memory.name, offset=0, size=size)]
|
244 |
+
self.max_free_block_size = size
|
245 |
+
|
246 |
+
def _update_free_blocks(self):
|
247 |
+
total_size = self.shared_memory.size
|
248 |
+
free_blocks = []
|
249 |
+
offset = 0
|
250 |
+
|
251 |
+
with self.used_blocks_lock:
|
252 |
+
# find holes between used blocks
|
253 |
+
for used_block in self.used_blocks:
|
254 |
+
if used_block.offset > offset:
|
255 |
+
free_blocks.append(
|
256 |
+
BlockDescriptor(self.shared_memory.name, offset=offset, size=used_block.offset - offset)
|
257 |
+
)
|
258 |
+
offset = used_block.offset + used_block.size
|
259 |
+
# if tail is free
|
260 |
+
if offset < total_size:
|
261 |
+
free_blocks.append(BlockDescriptor(self.shared_memory.name, offset=offset, size=total_size - offset))
|
262 |
+
|
263 |
+
self.free_blocks = free_blocks
|
264 |
+
self.max_free_block_size = max(block.size for block in self.free_blocks) if self.free_blocks else 0
|
265 |
+
|
266 |
+
def __contains__(self, block_id: str) -> bool:
|
267 |
+
with self.used_blocks_lock:
|
268 |
+
return any(block_id == block.id for block in self.used_blocks) # pytype: disable=attribute-error
|
269 |
+
|
270 |
+
def __getitem__(self, block_id: str) -> BlockDescriptor:
|
271 |
+
with self.used_blocks_lock:
|
272 |
+
for block in self.used_blocks:
|
273 |
+
if block.id == block_id: # pytype: disable=attribute-error
|
274 |
+
return block
|
275 |
+
raise KeyError(f"Block with id {block_id} not found in segment {self.shared_memory.name}")
|
276 |
+
|
277 |
+
def allocate(self, offset, byte_size):
|
278 |
+
block = BlockDescriptor(self.shared_memory.name, offset=offset, size=byte_size)
|
279 |
+
with self.used_blocks_lock:
|
280 |
+
self.used_blocks.append(block)
|
281 |
+
self.used_blocks.sort(key=lambda block: block.offset)
|
282 |
+
self._update_free_blocks()
|
283 |
+
return block
|
284 |
+
|
285 |
+
def release(self, block: BlockDescriptor):
|
286 |
+
with self.used_blocks_lock:
|
287 |
+
self.used_blocks.remove(block)
|
288 |
+
self._update_free_blocks()
|
289 |
+
|
290 |
+
|
291 |
+
class _DataBlocksServer:
|
292 |
+
_instance = None
|
293 |
+
_cnt = 0
|
294 |
+
_minimal_segment_size = 4096 # 4KB
|
295 |
+
|
296 |
+
def __new__(cls):
|
297 |
+
if cls._instance is None:
|
298 |
+
cls._instance = super().__new__(cls)
|
299 |
+
return cls._instance
|
300 |
+
|
301 |
+
def __init__(self):
|
302 |
+
# WAR: for some reason, the __init__ is called on each create of proxy object
|
303 |
+
if self._cnt == 1:
|
304 |
+
return
|
305 |
+
self._cnt += 1
|
306 |
+
self._id = uuid.uuid4() # to verify that it is singleton across processes
|
307 |
+
self._segments = []
|
308 |
+
self._segments_lock = threading.RLock()
|
309 |
+
atexit.register(self.close)
|
310 |
+
|
311 |
+
def get_free_blocks(self, bytes_sizes: Sequence[int]) -> Sequence[str]:
|
312 |
+
tensors_ids = []
|
313 |
+
with self._segments_lock:
|
314 |
+
for byte_size in bytes_sizes:
|
315 |
+
for segment in self._segments:
|
316 |
+
if segment.max_free_block_size >= byte_size:
|
317 |
+
for free_block in segment.free_blocks:
|
318 |
+
if free_block.size >= byte_size:
|
319 |
+
block = self._allocate_block(segment, free_block.offset, byte_size)
|
320 |
+
tensors_ids.append(block.id) # pytype: disable=attribute-error
|
321 |
+
break
|
322 |
+
else:
|
323 |
+
continue # If no suitable block was found, try the next segment
|
324 |
+
break # If a suitable block was found, don't try any more segments
|
325 |
+
else: # If no suitable block was found in any segment
|
326 |
+
new_segment_size = int(
|
327 |
+
max(self._minimal_segment_size, math.pow(2, math.ceil(math.log2(byte_size))))
|
328 |
+
)
|
329 |
+
block = self._allocate_block(
|
330 |
+
self._create_new_segment(new_segment_size), offset=0, byte_size=byte_size
|
331 |
+
)
|
332 |
+
tensors_ids.append(block.id) # pytype: disable=attribute-error
|
333 |
+
return tensors_ids
|
334 |
+
|
335 |
+
def release_block(self, block_id: str):
|
336 |
+
with self._segments_lock:
|
337 |
+
for segment in self._segments:
|
338 |
+
try:
|
339 |
+
block = segment[block_id]
|
340 |
+
segment.release(block)
|
341 |
+
return
|
342 |
+
except KeyError:
|
343 |
+
pass
|
344 |
+
raise KeyError(f"Block with id {block_id} not found in server")
|
345 |
+
|
346 |
+
def _allocate_block(self, segment: _SharedMemorySegment, offset: int, byte_size: int) -> BlockDescriptor:
|
347 |
+
return segment.allocate(offset, byte_size)
|
348 |
+
|
349 |
+
def _create_new_segment(self, segment_size):
|
350 |
+
segment = _SharedMemorySegment(segment_size)
|
351 |
+
self._segments.append(segment)
|
352 |
+
return segment
|
353 |
+
|
354 |
+
def get_debug_status(self):
|
355 |
+
return {
|
356 |
+
"server_id": str(self._id),
|
357 |
+
"host_pid": multiprocessing.current_process().pid,
|
358 |
+
"segments": [
|
359 |
+
{
|
360 |
+
"shared_memory": segment.shared_memory.name,
|
361 |
+
"used_blocks": [str(block) for block in segment.used_blocks],
|
362 |
+
}
|
363 |
+
for segment in self._segments
|
364 |
+
],
|
365 |
+
}
|
366 |
+
|
367 |
+
def close(self):
|
368 |
+
multiprocessing.util.debug(f"Closing server {self._id}")
|
369 |
+
with self._segments_lock:
|
370 |
+
while self._segments:
|
371 |
+
segment = self._segments.pop()
|
372 |
+
multiprocessing.util.debug(f"Closing and delete segment {segment.shared_memory.name}")
|
373 |
+
segment.shared_memory.close()
|
374 |
+
segment.shared_memory.unlink()
|
375 |
+
|
376 |
+
|
377 |
+
class BlocksStoreManager(multiprocessing.managers.BaseManager):
|
378 |
+
"""Remote block store for storing and retrieving numpy arrays in/from shared memory."""
|
379 |
+
|
380 |
+
@classmethod
|
381 |
+
def _run_server(cls, registry, address, authkey, serializer, writer, initializer=None, initargs=()):
|
382 |
+
PR_SET_PDEATHSIG = 1 # noqa
|
383 |
+
libc = ctypes.CDLL(ctypes.util.find_library("c"), use_errno=True)
|
384 |
+
libc.prctl(PR_SET_PDEATHSIG, signal.SIGTERM) # terminate process when parent **thread** dies
|
385 |
+
|
386 |
+
if bool(os.environ.get("PYTRITON_VIZTRACER")):
|
387 |
+
from viztracer import VizTracer # type: ignore # pytype: disable=import-error
|
388 |
+
|
389 |
+
cls._tracer = VizTracer(log_async=True, log_gc=True, tracer_entries=10000000, pid_suffix=True)
|
390 |
+
cls._tracer.register_exit()
|
391 |
+
cls._tracer.start()
|
392 |
+
|
393 |
+
super()._run_server(
|
394 |
+
registry, address, authkey, serializer, writer, initializer, initargs
|
395 |
+
) # pytype: disable=attribute-error
|
396 |
+
|
397 |
+
|
398 |
+
class _DataBlocksServerProxy(multiprocessing.managers.BaseProxy):
|
399 |
+
def release_block(self, /, *args, **kwargs):
|
400 |
+
return self._callmethod("release_block", args, kwargs)
|
401 |
+
|
402 |
+
def get_free_blocks(self, /, *args, **kwargs):
|
403 |
+
return self._callmethod("get_free_blocks", args, kwargs)
|
404 |
+
|
405 |
+
def _get_debug_status(self, /, *args, **kwargs):
|
406 |
+
return self._callmethod("get_debug_status", args, kwargs)
|
407 |
+
|
408 |
+
def close(self, /, *args, **kwargs):
|
409 |
+
return self._callmethod("close", args, kwargs)
|
410 |
+
|
411 |
+
|
412 |
+
BlocksStoreManager.register("blocks", _DataBlocksServer, proxytype=_DataBlocksServerProxy)
|
413 |
+
|
414 |
+
|
415 |
+
class _FileLock:
|
416 |
+
_locks = {}
|
417 |
+
|
418 |
+
def __new__(cls, file_path):
|
419 |
+
if file_path not in cls._locks:
|
420 |
+
cls._locks[file_path] = super().__new__(cls)
|
421 |
+
return cls._locks[file_path]
|
422 |
+
|
423 |
+
def __init__(self, file_path):
|
424 |
+
if hasattr(self, "_file_path"):
|
425 |
+
return
|
426 |
+
self._file_path = pathlib.Path(file_path)
|
427 |
+
self._file_lock = None
|
428 |
+
self._lock = threading.RLock()
|
429 |
+
atexit.register(self._clean)
|
430 |
+
|
431 |
+
def __enter__(self):
|
432 |
+
self._file_lock = self._file_path.open("a")
|
433 |
+
fcntl.flock(self._file_lock.fileno(), fcntl.LOCK_EX)
|
434 |
+
self._lock.acquire()
|
435 |
+
|
436 |
+
def __exit__(self, exc_type, exc_value, traceback):
|
437 |
+
fcntl.flock(self._file_lock.fileno(), fcntl.LOCK_UN)
|
438 |
+
self._lock.release()
|
439 |
+
|
440 |
+
def _clean(self):
|
441 |
+
if self._file_lock is not None:
|
442 |
+
self._file_lock.close()
|
443 |
+
try:
|
444 |
+
self._file_path.unlink(missing_ok=True)
|
445 |
+
except OSError as e:
|
446 |
+
LOGGER.warning(f"Could not remove lock file {self._file_path}; {e}")
|
447 |
+
|
448 |
+
|
449 |
+
class _Popen(multiprocessing.popen_spawn_posix.Popen):
|
450 |
+
def _launch(self, process_obj):
|
451 |
+
# Modified version of multiprocessing.popen_spawn_posix.Popen._launch
|
452 |
+
import io
|
453 |
+
import os
|
454 |
+
from multiprocessing import context, resource_tracker, spawn, util
|
455 |
+
|
456 |
+
tracker_fd = resource_tracker.getfd()
|
457 |
+
self._fds.append(tracker_fd) # pytype: disable=attribute-error
|
458 |
+
|
459 |
+
# get prep_data + remove init_main_from* as they are not required for TensorStore process
|
460 |
+
prep_data = spawn.get_preparation_data(process_obj._name)
|
461 |
+
prep_data.pop("init_main_from_module", None)
|
462 |
+
prep_data.pop("init_main_from_path", None)
|
463 |
+
|
464 |
+
fp = io.BytesIO()
|
465 |
+
context.set_spawning_popen(self)
|
466 |
+
try:
|
467 |
+
context.reduction.dump(prep_data, fp) # pytype: disable=module-attr
|
468 |
+
context.reduction.dump(process_obj, fp) # pytype: disable=module-attr
|
469 |
+
finally:
|
470 |
+
context.set_spawning_popen(None)
|
471 |
+
|
472 |
+
parent_r = child_w = child_r = parent_w = None
|
473 |
+
try:
|
474 |
+
parent_r, child_w = os.pipe()
|
475 |
+
child_r, parent_w = os.pipe()
|
476 |
+
cmd = spawn.get_command_line(tracker_fd=tracker_fd, pipe_handle=child_r)
|
477 |
+
self._fds.extend([child_r, child_w]) # pytype: disable=attribute-error
|
478 |
+
self.pid = util.spawnv_passfds(
|
479 |
+
spawn.get_executable(),
|
480 |
+
cmd,
|
481 |
+
self._fds, # pytype: disable=attribute-error,wrong-arg-types
|
482 |
+
)
|
483 |
+
self.sentinel = parent_r
|
484 |
+
with open(parent_w, "wb", closefd=False) as f:
|
485 |
+
f.write(fp.getbuffer())
|
486 |
+
finally:
|
487 |
+
fds_to_close = []
|
488 |
+
for fd in (parent_r, parent_w):
|
489 |
+
if fd is not None:
|
490 |
+
fds_to_close.append(fd)
|
491 |
+
self.finalizer = util.Finalize(self, util.close_fds, fds_to_close) # pytype: disable=module-attr
|
492 |
+
|
493 |
+
for fd in (child_r, child_w):
|
494 |
+
if fd is not None:
|
495 |
+
os.close(fd)
|
496 |
+
|
497 |
+
|
498 |
+
class _SpawnProcess(multiprocessing.process.BaseProcess):
|
499 |
+
_start_method = "spawn"
|
500 |
+
|
501 |
+
@staticmethod
|
502 |
+
def _Popen(process_obj): # noqa N802
|
503 |
+
return _Popen(process_obj)
|
504 |
+
|
505 |
+
|
506 |
+
class _SpawnContext(multiprocessing.context.BaseContext):
|
507 |
+
_name = "spawn"
|
508 |
+
Process = _SpawnProcess
|
509 |
+
|
510 |
+
|
511 |
+
class TensorStore:
|
512 |
+
"""Tensor store for storing and retrieving numpy arrays in/from shared memory."""
|
513 |
+
|
514 |
+
_SOCKET_EXISTANCE_CHECK_INTERVAL_S = 0.1
|
515 |
+
_instances = {}
|
516 |
+
|
517 |
+
def __new__(cls, *args, **kwargs):
|
518 |
+
"""Create TensorStore object. If object with given address already exists, return it."""
|
519 |
+
if args:
|
520 |
+
address = args[0]
|
521 |
+
elif "address" in kwargs:
|
522 |
+
address = kwargs["address"]
|
523 |
+
else:
|
524 |
+
raise TypeError("TensorStore() missing 1 required positional argument: 'address'")
|
525 |
+
|
526 |
+
address = address.as_posix() if isinstance(address, pathlib.Path) else address
|
527 |
+
|
528 |
+
if address not in cls._instances:
|
529 |
+
cls._instances[address] = super().__new__(cls)
|
530 |
+
|
531 |
+
return cls._instances[address]
|
532 |
+
|
533 |
+
def __init__(self, address: Union[str, pathlib.Path], auth_key: Optional[bytes] = None):
|
534 |
+
"""Initialize TensorStore object.
|
535 |
+
|
536 |
+
Args:
|
537 |
+
address: address of data store
|
538 |
+
auth_key: authentication key required to setup connection. If not provided, current process authkey will be used
|
539 |
+
"""
|
540 |
+
if not hasattr(self, "_remote_blocks_store_manager"):
|
541 |
+
address = address.as_posix() if isinstance(address, pathlib.Path) else address
|
542 |
+
self._remote_blocks_store_manager = BlocksStoreManager(address, authkey=auth_key, ctx=_SpawnContext())
|
543 |
+
self._remote_blocks_store = None
|
544 |
+
self._manager_start_stop_filelock = _FileLock(f"{address}.lock")
|
545 |
+
|
546 |
+
# container for keeping map between tensor_id and numpy array weak ref
|
547 |
+
self._handled_blocks: Dict[str, weakref.ReferenceType] = {}
|
548 |
+
self._handled_blocks_lock = threading.RLock()
|
549 |
+
|
550 |
+
self._shm_segments: Dict[str, multiprocessing.shared_memory.SharedMemory] = {}
|
551 |
+
self._shm_segments_lock = threading.RLock()
|
552 |
+
|
553 |
+
self.serialize = serialize_numpy_with_struct_header
|
554 |
+
self.deserialize = deserialize_numpy_with_struct_header
|
555 |
+
self._calc_serialized_tensor_size = calc_serialized_size_of_numpy_with_struct_header
|
556 |
+
|
557 |
+
@property
|
558 |
+
def address(self) -> str:
|
559 |
+
"""Return address of remote block store."""
|
560 |
+
return self._remote_blocks_store_manager.address
|
561 |
+
|
562 |
+
def start(self):
|
563 |
+
"""Start remote block store."""
|
564 |
+
with self._manager_start_stop_filelock:
|
565 |
+
if self._remote_blocks_store is not None:
|
566 |
+
raise RuntimeError("Remote block store is already started/connected")
|
567 |
+
|
568 |
+
self._remote_blocks_store_manager.start()
|
569 |
+
self._remote_blocks_store = self._remote_blocks_store_manager.blocks() # pytype: disable=attribute-error
|
570 |
+
|
571 |
+
address = pathlib.Path(self._remote_blocks_store_manager.address)
|
572 |
+
self._wait_for_address(address)
|
573 |
+
LOGGER.debug(
|
574 |
+
f"Started remote block store at {address} (pid={self._remote_blocks_store_manager._process.pid})" # pytype: disable=attribute-error
|
575 |
+
)
|
576 |
+
|
577 |
+
def connect(self, timeout_s: Optional[float] = None):
|
578 |
+
"""Connect to remote block store."""
|
579 |
+
if self._remote_blocks_store is None:
|
580 |
+
address = pathlib.Path(self._remote_blocks_store_manager.address)
|
581 |
+
|
582 |
+
self._wait_for_address(address, timeout_s)
|
583 |
+
self._remote_blocks_store_manager.connect()
|
584 |
+
self._remote_blocks_store = self._remote_blocks_store_manager.blocks() # pytype: disable=attribute-error
|
585 |
+
LOGGER.debug(f"Connected to remote block store at {address})")
|
586 |
+
else:
|
587 |
+
LOGGER.debug(f"Already connectd to remote block store at {self.address}")
|
588 |
+
|
589 |
+
def _wait_for_address(self, address, timeout_s: Optional[float] = None):
|
590 |
+
should_stop_at = time.time() + timeout_s if timeout_s is not None else None
|
591 |
+
if timeout_s is not None and self._SOCKET_EXISTANCE_CHECK_INTERVAL_S > timeout_s:
|
592 |
+
socket_existance_check_interval = timeout_s
|
593 |
+
else:
|
594 |
+
socket_existance_check_interval = self._SOCKET_EXISTANCE_CHECK_INTERVAL_S
|
595 |
+
|
596 |
+
while not address.exists():
|
597 |
+
if should_stop_at is not None and time.time() >= should_stop_at:
|
598 |
+
raise TimeoutError(f"Timeout while waiting for {address} to be created")
|
599 |
+
time.sleep(socket_existance_check_interval)
|
600 |
+
|
601 |
+
def _calc_serialized_size(self, tensor: np.ndarray) -> int:
|
602 |
+
# frames payload sum + total size + frames sizes
|
603 |
+
# assume 2 frames: header with tensor description + data
|
604 |
+
return sum(self._calc_serialized_tensor_size(tensor)) + struct.calcsize("<I") + 2 * struct.calcsize("<I")
|
605 |
+
|
606 |
+
def put(self, tensors: Sequence[np.ndarray]) -> Sequence[str]:
|
607 |
+
"""Append tensor to shared memory buffer.
|
608 |
+
|
609 |
+
Args:
|
610 |
+
tensors: numpy arrays to store
|
611 |
+
|
612 |
+
Returns:
|
613 |
+
List of ids of stored tensors
|
614 |
+
"""
|
615 |
+
byte_size_of_frames_containers = [self._calc_serialized_size(tensor) for tensor in tensors]
|
616 |
+
tensors_ids = self._remote_blocks_store.get_free_blocks(byte_size_of_frames_containers)
|
617 |
+
blocks = [BlockDescriptor.from_id(tensor_id) for tensor_id in tensors_ids]
|
618 |
+
|
619 |
+
for tensor, block in zip(tensors, blocks):
|
620 |
+
with self._shm_segments_lock:
|
621 |
+
shm = self._shm_segments.get(block.shm_name)
|
622 |
+
if shm is None:
|
623 |
+
shm = multiprocessing.shared_memory.SharedMemory(block.shm_name, create=False)
|
624 |
+
self._shm_segments[block.shm_name] = shm
|
625 |
+
|
626 |
+
frames = self.serialize(tensor)
|
627 |
+
self._copy_frames(frames, shm, block.offset)
|
628 |
+
|
629 |
+
return tensors_ids
|
630 |
+
|
631 |
+
def get(self, tensor_id: str) -> np.ndarray:
|
632 |
+
"""Get numpy array from tensor store.
|
633 |
+
|
634 |
+
Args:
|
635 |
+
tensor_id: id of of tenosr to get
|
636 |
+
|
637 |
+
Returns:
|
638 |
+
numpy array
|
639 |
+
"""
|
640 |
+
tensor = None
|
641 |
+
# try to handle already handled tensor from weakref
|
642 |
+
with self._handled_blocks_lock:
|
643 |
+
tensor_ref = self._handled_blocks.get(tensor_id)
|
644 |
+
if tensor_ref is not None:
|
645 |
+
tensor = tensor_ref()
|
646 |
+
|
647 |
+
if tensor is None: # if tensor was not handled yet or weakref is already empty
|
648 |
+
block = BlockDescriptor.from_id(tensor_id)
|
649 |
+
|
650 |
+
# check if shm segment is already opened
|
651 |
+
with self._shm_segments_lock:
|
652 |
+
shm = self._shm_segments.get(block.shm_name)
|
653 |
+
|
654 |
+
# if not open it and put into cache
|
655 |
+
if shm is None:
|
656 |
+
shm = multiprocessing.shared_memory.SharedMemory(block.shm_name, create=False)
|
657 |
+
with self._shm_segments_lock:
|
658 |
+
shm = self._shm_segments.setdefault(block.shm_name, shm) # in meantime other thread could create it
|
659 |
+
|
660 |
+
frames = self._handle_frames(shm, block.offset)
|
661 |
+
tensor = self.deserialize(frames)
|
662 |
+
|
663 |
+
# store tensor in weakref to be able to release shared memory when tensor will be garbage collected
|
664 |
+
with self._handled_blocks_lock:
|
665 |
+
tensor_ref = self._handled_blocks.setdefault(tensor_id, weakref.ref(tensor))
|
666 |
+
tensor = tensor_ref()
|
667 |
+
|
668 |
+
return tensor # pytype: disable=bad-return-type
|
669 |
+
|
670 |
+
def release_block(self, tensor_id: str):
|
671 |
+
"""Release shared memory block.
|
672 |
+
|
673 |
+
Args:
|
674 |
+
tensor_id: id of tensor to release
|
675 |
+
"""
|
676 |
+
tensor_ref = None
|
677 |
+
with self._handled_blocks_lock:
|
678 |
+
tensor_ref = self._handled_blocks.pop(tensor_id, None)
|
679 |
+
|
680 |
+
try:
|
681 |
+
if tensor_ref is not None:
|
682 |
+
self._remote_blocks_store.release_block(tensor_id)
|
683 |
+
except OSError: # thrown when remote process is already closed
|
684 |
+
LOGGER.warning(
|
685 |
+
f"Failed to release block {tensor_id} on remote process at {self.address}. Probably remote process is already closed"
|
686 |
+
)
|
687 |
+
|
688 |
+
def _copy_frames(
|
689 |
+
self,
|
690 |
+
frames: List[Union[bytes, memoryview]],
|
691 |
+
shm: multiprocessing.shared_memory.SharedMemory,
|
692 |
+
offset: int,
|
693 |
+
) -> int:
|
694 |
+
total_size = struct.calcsize("<I") # start after total_size; max 4GB for all frames
|
695 |
+
for frame in frames:
|
696 |
+
if isinstance(frame, bytes):
|
697 |
+
frame = memoryview(frame)
|
698 |
+
|
699 |
+
assert frame.contiguous, "Only contiguous arrays are supported"
|
700 |
+
struct.pack_into("<I", shm.buf, offset + total_size, frame.nbytes) # pytype: disable=wrong-arg-types
|
701 |
+
total_size += struct.calcsize("<I")
|
702 |
+
shm.buf[offset + total_size : offset + total_size + frame.nbytes] = frame.cast("B")
|
703 |
+
|
704 |
+
total_size += frame.nbytes
|
705 |
+
|
706 |
+
struct.pack_into("<I", shm.buf, offset, total_size) # pytype: disable=wrong-arg-types
|
707 |
+
return total_size
|
708 |
+
|
709 |
+
def _handle_frames(self, shm: multiprocessing.shared_memory.SharedMemory, block_offset: int) -> List[memoryview]:
|
710 |
+
frames = []
|
711 |
+
(total_size,) = struct.unpack_from("<I", shm.buf, block_offset) # pytype: disable=wrong-arg-types
|
712 |
+
offset = struct.calcsize("<I")
|
713 |
+
while offset < total_size:
|
714 |
+
(frame_size,) = struct.unpack_from("<I", shm.buf, block_offset + offset) # pytype: disable=wrong-arg-types
|
715 |
+
offset += struct.calcsize("<I")
|
716 |
+
frame = shm.buf[block_offset + offset : block_offset + offset + frame_size]
|
717 |
+
offset += frame_size
|
718 |
+
frames.append(frame)
|
719 |
+
return frames
|
720 |
+
|
721 |
+
def close(self):
|
722 |
+
"""Free resources used by TensorStore object."""
|
723 |
+
from multiprocessing.resource_tracker import register, unregister
|
724 |
+
|
725 |
+
LOGGER.debug(f"TensorStore is being closed (is_started={self.is_started()})")
|
726 |
+
|
727 |
+
gc.collect()
|
728 |
+
with self._handled_blocks_lock:
|
729 |
+
tensors_ids = list(self._handled_blocks)
|
730 |
+
for tensor_id in tensors_ids:
|
731 |
+
self.release_block(tensor_id)
|
732 |
+
|
733 |
+
with self._shm_segments_lock:
|
734 |
+
while self._shm_segments:
|
735 |
+
_, shm = self._shm_segments.popitem()
|
736 |
+
LOGGER.debug(f"Closing shared memory {shm.name}")
|
737 |
+
try:
|
738 |
+
shm.close()
|
739 |
+
except Exception as e:
|
740 |
+
LOGGER.warning(f"Failed to close shared memory {shm.name}: {e}")
|
741 |
+
finally:
|
742 |
+
if not self.is_started():
|
743 |
+
register(shm._name, "shared_memory") # pytype: disable=attribute-error
|
744 |
+
unregister(shm._name, "shared_memory") # pytype: disable=attribute-error
|
745 |
+
|
746 |
+
if self.is_started():
|
747 |
+
if self._remote_blocks_store is not None:
|
748 |
+
LOGGER.debug(f"Releasing all resources on remote process at {self.address}")
|
749 |
+
try:
|
750 |
+
self._remote_blocks_store.close()
|
751 |
+
except FileNotFoundError: # thrown when remote process is already closed
|
752 |
+
pass
|
753 |
+
self._remote_blocks_store = None
|
754 |
+
LOGGER.debug(f"Shutting down side process of data store at {self.address}")
|
755 |
+
self._remote_blocks_store_manager.shutdown()
|
756 |
+
LOGGER.debug(f"TensorStore at {self.address} closed")
|
757 |
+
|
758 |
+
def is_started(self) -> bool:
|
759 |
+
"""Check if remote block store was started by this instance.
|
760 |
+
|
761 |
+
Returns:
|
762 |
+
True if remote block store was started by this instance, False otherwise
|
763 |
+
"""
|
764 |
+
return hasattr(self._remote_blocks_store_manager, "shutdown")
|
765 |
+
|
766 |
+
|
767 |
+
def get_debug_status(tensor_store: TensorStore) -> dict:
|
768 |
+
"""Get debug status of remote block store.
|
769 |
+
|
770 |
+
Args:
|
771 |
+
tensor_store: TensorStore object
|
772 |
+
|
773 |
+
Returns:
|
774 |
+
Debug status of remote block store
|
775 |
+
"""
|
776 |
+
if tensor_store._remote_blocks_store is None:
|
777 |
+
raise RuntimeError("Remote block store is not initialized")
|
778 |
+
|
779 |
+
return tensor_store._remote_blocks_store._get_debug_status()
|
780 |
+
|
781 |
+
|
782 |
+
class BaseRequestsResponsesSerializerDeserializer(abc.ABC):
|
783 |
+
"""Base class for requests/responses serializer/deserializer."""
|
784 |
+
|
785 |
+
@abc.abstractmethod
|
786 |
+
def serialize_requests(self, requests: Requests) -> bytes:
|
787 |
+
"""Serialize requests.
|
788 |
+
|
789 |
+
Args:
|
790 |
+
requests: list of requests to serialize
|
791 |
+
|
792 |
+
Returns:
|
793 |
+
Serialized requests
|
794 |
+
"""
|
795 |
+
pass
|
796 |
+
|
797 |
+
@abc.abstractmethod
|
798 |
+
def deserialize_requests(self, requests_payload: bytes) -> Requests:
|
799 |
+
"""Deserialize requests.
|
800 |
+
|
801 |
+
Args:
|
802 |
+
requests_payload: serialized requests
|
803 |
+
|
804 |
+
Returns:
|
805 |
+
List of deserialized requests
|
806 |
+
"""
|
807 |
+
pass
|
808 |
+
|
809 |
+
@abc.abstractmethod
|
810 |
+
def free_requests_resources(self, requests_payload: bytes):
|
811 |
+
"""Free resources used by requests."""
|
812 |
+
pass
|
813 |
+
|
814 |
+
@abc.abstractmethod
|
815 |
+
def serialize_responses(self, responses: Responses) -> bytes:
|
816 |
+
"""Serialize responses.
|
817 |
+
|
818 |
+
Args:
|
819 |
+
responses: list of responses to serialize
|
820 |
+
|
821 |
+
Returns:
|
822 |
+
Serialized responses
|
823 |
+
"""
|
824 |
+
pass
|
825 |
+
|
826 |
+
@abc.abstractmethod
|
827 |
+
def deserialize_responses(self, responses_payload: bytes) -> Responses:
|
828 |
+
"""Deserialize responses.
|
829 |
+
|
830 |
+
Args:
|
831 |
+
responses_payload: serialized responses
|
832 |
+
|
833 |
+
Returns:
|
834 |
+
List of deserialized responses
|
835 |
+
"""
|
836 |
+
pass
|
837 |
+
|
838 |
+
@abc.abstractmethod
|
839 |
+
def free_responses_resources(self, responses_payload: bytes):
|
840 |
+
"""Free resources used by responses."""
|
841 |
+
pass
|
842 |
+
|
843 |
+
|
844 |
+
class Base64SerializerDeserializer(BaseRequestsResponsesSerializerDeserializer):
|
845 |
+
"""Serializer/deserializer for requests/responses using base64 implementation."""
|
846 |
+
|
847 |
+
def serialize_requests(self, requests: Requests) -> bytes:
|
848 |
+
"""Serialize requests.
|
849 |
+
|
850 |
+
Args:
|
851 |
+
requests: list of requests to serialize
|
852 |
+
|
853 |
+
Returns:
|
854 |
+
Serialized requests
|
855 |
+
"""
|
856 |
+
serialized_requests = self._serialize_named_tensors_lists(requests)
|
857 |
+
requests_list = []
|
858 |
+
for request, serialized_request in zip(requests, serialized_requests):
|
859 |
+
serialized_request = {"data": serialized_request, "parameters": request.parameters}
|
860 |
+
if request.span is not None:
|
861 |
+
serialized_request["span"] = get_span_dict(request.span)
|
862 |
+
requests_list.append(serialized_request)
|
863 |
+
|
864 |
+
requests = {"requests": requests_list}
|
865 |
+
requests = json.dumps(requests).encode("utf-8")
|
866 |
+
return requests
|
867 |
+
|
868 |
+
def deserialize_requests(self, requests_payload: bytes) -> Requests:
|
869 |
+
"""Deserialize requests.
|
870 |
+
|
871 |
+
Args:
|
872 |
+
requests_payload: serialized requests
|
873 |
+
|
874 |
+
Returns:
|
875 |
+
List of deserialized requests
|
876 |
+
"""
|
877 |
+
requests = json.loads(requests_payload)
|
878 |
+
requests_data = [request["data"] for request in requests["requests"]]
|
879 |
+
requests_data = self._deserialized_named_tensors_lists(requests_data)
|
880 |
+
|
881 |
+
deserialized_requests = []
|
882 |
+
for request, request_data in zip(requests["requests"], requests_data):
|
883 |
+
kwargs = {"data": request_data, "parameters": request.get("parameters")}
|
884 |
+
# FIXME: move span creation above just after json.loads
|
885 |
+
if "span" in request:
|
886 |
+
span_dict = request["span"]
|
887 |
+
span = start_span_from_remote(span_dict, "proxy_inference_callable")
|
888 |
+
kwargs["span"] = span
|
889 |
+
request_wrapped = Request(**kwargs)
|
890 |
+
deserialized_requests.append(request_wrapped)
|
891 |
+
|
892 |
+
return deserialized_requests
|
893 |
+
|
894 |
+
def free_requests_resources(self, requests_payload: bytes):
|
895 |
+
"""Free resources used by requests."""
|
896 |
+
pass
|
897 |
+
|
898 |
+
def serialize_responses(self, responses: Responses) -> bytes:
|
899 |
+
"""Serialize responses.
|
900 |
+
|
901 |
+
Args:
|
902 |
+
responses: list of responses to serialize
|
903 |
+
|
904 |
+
Returns:
|
905 |
+
Serialized responses
|
906 |
+
"""
|
907 |
+
responses = self._serialize_named_tensors_lists(responses)
|
908 |
+
responses = {"responses": [{"data": response} for response in responses]}
|
909 |
+
return json.dumps(responses).encode("utf-8")
|
910 |
+
|
911 |
+
def deserialize_responses(self, responses_payload: bytes) -> Responses:
|
912 |
+
"""Deserialize responses.
|
913 |
+
|
914 |
+
Args:
|
915 |
+
responses_payload: serialized responses
|
916 |
+
|
917 |
+
Returns:
|
918 |
+
List of deserialized responses
|
919 |
+
"""
|
920 |
+
if responses_payload:
|
921 |
+
responses = json.loads(responses_payload)
|
922 |
+
responses = [response["data"] for response in responses["responses"]]
|
923 |
+
responses = self._deserialized_named_tensors_lists(responses)
|
924 |
+
return [Response(data=response) for response in responses]
|
925 |
+
else:
|
926 |
+
return []
|
927 |
+
|
928 |
+
def free_responses_resources(self, responses_payload: bytes):
|
929 |
+
"""Free resources used by responses."""
|
930 |
+
pass
|
931 |
+
|
932 |
+
def _serialize_named_tensors_lists(self, named_tensors_lists):
|
933 |
+
def _encode(_tensor):
|
934 |
+
frames = serialize_numpy_with_struct_header(_tensor)
|
935 |
+
return [base64.b64encode(frame).decode("utf-8") for frame in frames]
|
936 |
+
|
937 |
+
return [
|
938 |
+
{tensor_name: _encode(tensor) for tensor_name, tensor in tensors.items()} for tensors in named_tensors_lists
|
939 |
+
]
|
940 |
+
|
941 |
+
def _deserialized_named_tensors_lists(self, named_tensors_lists):
|
942 |
+
def _decode(decoded_tensor):
|
943 |
+
frames = [base64.b64decode(frame.encode("utf-8")) for frame in decoded_tensor]
|
944 |
+
return deserialize_numpy_with_struct_header(frames)
|
945 |
+
|
946 |
+
return [
|
947 |
+
{tensor_name: _decode(encoded_tensor) for tensor_name, encoded_tensor in tensors.items()}
|
948 |
+
for tensors in named_tensors_lists
|
949 |
+
]
|
950 |
+
|
951 |
+
def start(self, url: Union[str, pathlib.Path], authkey: Optional[bytes] = None):
|
952 |
+
"""Start Dummy implementation.
|
953 |
+
|
954 |
+
Args:
|
955 |
+
url: address of data store
|
956 |
+
authkey: authentication key required to setup connection. If not provided, current process authkey will be used
|
957 |
+
"""
|
958 |
+
pass
|
959 |
+
|
960 |
+
def connect(self, url: Union[str, pathlib.Path], authkey: Optional[bytes] = None):
|
961 |
+
"""Connect to Dummy implementation.
|
962 |
+
|
963 |
+
Args:
|
964 |
+
url: address of data store
|
965 |
+
authkey: authentication key required to setup connection. If not provided, current process authkey will be used
|
966 |
+
"""
|
967 |
+
pass
|
968 |
+
|
969 |
+
def close(self):
|
970 |
+
"""Close Dummy implementation."""
|
971 |
+
pass
|
972 |
+
|
973 |
+
|
974 |
+
class TensorStoreSerializerDeserializer(BaseRequestsResponsesSerializerDeserializer):
|
975 |
+
"""Serializer/deserializer for requests/responses using TensorStore."""
|
976 |
+
|
977 |
+
def __init__(self):
|
978 |
+
"""Initialize TensorStoreSerializerDeserializer object."""
|
979 |
+
self._tensor_store = None
|
980 |
+
|
981 |
+
def serialize_requests(self, requests: Requests) -> bytes:
|
982 |
+
"""Serialize requests.
|
983 |
+
|
984 |
+
Args:
|
985 |
+
requests: list of requests to serialize
|
986 |
+
|
987 |
+
Returns:
|
988 |
+
Serialized requests
|
989 |
+
"""
|
990 |
+
serialized_requests = self._serialize_named_tensors_lists(requests)
|
991 |
+
requests_list = []
|
992 |
+
for request, serialized_request in zip(requests, serialized_requests):
|
993 |
+
serialized_request = {"data": serialized_request, "parameters": request.parameters}
|
994 |
+
if request.span is not None:
|
995 |
+
serialized_request["span"] = get_span_dict(request.span)
|
996 |
+
requests_list.append(serialized_request)
|
997 |
+
|
998 |
+
requests = {"requests": requests_list}
|
999 |
+
return json.dumps(requests).encode("utf-8")
|
1000 |
+
|
1001 |
+
def deserialize_requests(self, requests_payload: bytes) -> Requests:
|
1002 |
+
"""Deserialize requests.
|
1003 |
+
|
1004 |
+
Args:
|
1005 |
+
requests_payload: serialized requests
|
1006 |
+
|
1007 |
+
Returns:
|
1008 |
+
List of deserialized requests
|
1009 |
+
"""
|
1010 |
+
requests = json.loads(requests_payload)
|
1011 |
+
deserialized_requests = []
|
1012 |
+
for request in requests["requests"]:
|
1013 |
+
kwargs = {}
|
1014 |
+
if "span" in request:
|
1015 |
+
span_dict = request["span"]
|
1016 |
+
span = start_span_from_remote(span_dict, "proxy_inference_callable")
|
1017 |
+
kwargs["span"] = span
|
1018 |
+
request_data = {
|
1019 |
+
input_name: self._tensor_store.get(tensor_id)
|
1020 |
+
for input_name, tensor_id in request.get("data", {}).items()
|
1021 |
+
}
|
1022 |
+
kwargs["data"] = request_data
|
1023 |
+
kwargs["parameters"] = request.get("parameters")
|
1024 |
+
request_wrapped = Request(**kwargs)
|
1025 |
+
deserialized_requests.append(request_wrapped)
|
1026 |
+
|
1027 |
+
return deserialized_requests
|
1028 |
+
|
1029 |
+
def free_requests_resources(self, requests_payload: bytes):
|
1030 |
+
"""Free resources used by requests."""
|
1031 |
+
if requests_payload:
|
1032 |
+
requests = json.loads(requests_payload)
|
1033 |
+
for response in requests["requests"]:
|
1034 |
+
for _, tensor_id in response.get("data", {}).items():
|
1035 |
+
self._tensor_store.release_block(tensor_id)
|
1036 |
+
|
1037 |
+
def serialize_responses(self, responses: Responses) -> bytes:
|
1038 |
+
"""Serialize responses.
|
1039 |
+
|
1040 |
+
Args:
|
1041 |
+
responses: list of responses to serialize
|
1042 |
+
|
1043 |
+
Returns:
|
1044 |
+
Serialized responses
|
1045 |
+
"""
|
1046 |
+
responses = self._serialize_named_tensors_lists(responses)
|
1047 |
+
responses = {"responses": [{"data": response} for response in responses]}
|
1048 |
+
return json.dumps(responses).encode("utf-8")
|
1049 |
+
|
1050 |
+
def deserialize_responses(self, responses_payload: bytes) -> Responses:
|
1051 |
+
"""Deserialize responses.
|
1052 |
+
|
1053 |
+
Args:
|
1054 |
+
responses_payload: serialized responses
|
1055 |
+
|
1056 |
+
Returns:
|
1057 |
+
List of deserialized responses
|
1058 |
+
"""
|
1059 |
+
if responses_payload:
|
1060 |
+
responses = json.loads(responses_payload)
|
1061 |
+
return [
|
1062 |
+
Response(
|
1063 |
+
data={
|
1064 |
+
input_name: self._tensor_store.get(tensor_id)
|
1065 |
+
for input_name, tensor_id in response.get("data", {}).items()
|
1066 |
+
}
|
1067 |
+
)
|
1068 |
+
for response in responses["responses"]
|
1069 |
+
]
|
1070 |
+
else:
|
1071 |
+
return []
|
1072 |
+
|
1073 |
+
def free_responses_resources(self, responses_payload: bytes):
|
1074 |
+
"""Free resources used by responses."""
|
1075 |
+
if responses_payload:
|
1076 |
+
responses = json.loads(responses_payload)
|
1077 |
+
for response in responses["responses"]:
|
1078 |
+
for _, tensor_id in response.get("data", {}).items():
|
1079 |
+
self._tensor_store.release_block(tensor_id)
|
1080 |
+
|
1081 |
+
def _serialize_named_tensors_lists(self, named_tensors_lists):
|
1082 |
+
values_with_coords = [
|
1083 |
+
(idx, tensor_name, tensor)
|
1084 |
+
for idx, tensors in enumerate(named_tensors_lists)
|
1085 |
+
for tensor_name, tensor in tensors.items()
|
1086 |
+
]
|
1087 |
+
tensor_ids = self._tensor_store.put([tensor for _, _, tensor in values_with_coords])
|
1088 |
+
named_tensors_lists = [{} for _ in range(len(named_tensors_lists))]
|
1089 |
+
for (idx, tensor_name, _), tensor_id in zip(values_with_coords, tensor_ids):
|
1090 |
+
named_tensors_lists[idx][tensor_name] = tensor_id
|
1091 |
+
|
1092 |
+
return named_tensors_lists
|
1093 |
+
|
1094 |
+
def start(self, url: Union[str, pathlib.Path], authkey: Optional[bytes] = None):
|
1095 |
+
"""Start TensorStore.
|
1096 |
+
|
1097 |
+
Args:
|
1098 |
+
url: address of data store
|
1099 |
+
authkey: authentication key required to setup connection. If not provided, current process authkey will be used
|
1100 |
+
"""
|
1101 |
+
self._tensor_store = self._create(url, authkey)
|
1102 |
+
self._tensor_store.start()
|
1103 |
+
|
1104 |
+
def connect(self, url: Union[str, pathlib.Path], authkey: Optional[bytes] = None):
|
1105 |
+
"""Connect to TensorStore.
|
1106 |
+
|
1107 |
+
Args:
|
1108 |
+
url: address of data store
|
1109 |
+
authkey: authentication key required to setup connection. If not provided, current process authkey will be used
|
1110 |
+
"""
|
1111 |
+
self._tensor_store = self._create(url, authkey)
|
1112 |
+
self._tensor_store.connect()
|
1113 |
+
|
1114 |
+
def _create(self, url: Union[str, pathlib.Path], authkey: Optional[bytes] = None):
|
1115 |
+
authkey = authkey or multiprocessing.current_process().authkey
|
1116 |
+
return TensorStore(url, authkey)
|
1117 |
+
|
1118 |
+
def close(self):
|
1119 |
+
"""Close TensorStore."""
|
1120 |
+
if self._tensor_store:
|
1121 |
+
# check if run by this serializer/deserializer
|
1122 |
+
if self._tensor_store.is_started():
|
1123 |
+
debug_status = get_debug_status(self._tensor_store)
|
1124 |
+
used_blocks = [block for segment in debug_status["segments"] for block in segment["used_blocks"]]
|
1125 |
+
if used_blocks:
|
1126 |
+
LOGGER.debug(f"TensorStore used blocks while closing: {used_blocks}")
|
1127 |
+
# raise RuntimeError(
|
1128 |
+
# f"TensorStore at {self._tensor_store.address} is still running. Used blocks: {used_blocks}"
|
1129 |
+
# )
|
1130 |
+
LOGGER.debug(f"Closing TensorStore process at {self._tensor_store.address}")
|
1131 |
+
|
1132 |
+
self._tensor_store.close()
|
1133 |
+
self._tensor_store = None
|