ketanmore committed on
Commit
2720487
1 Parent(s): e5f7812

Upload folder using huggingface_hub

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. README.md +53 -0
  2. benchmark.ipynb +0 -0
  3. requirements.txt +299 -0
  4. results/layout-benchmark-results-images-1.jpg +0 -0
  5. results/layout-benchmark-results-images-10.jpg +0 -0
  6. results/layout-benchmark-results-images-2.jpg +0 -0
  7. results/layout-benchmark-results-images-3.jpg +0 -0
  8. results/layout-benchmark-results-images-4.jpg +0 -0
  9. results/layout-benchmark-results-images-5.jpg +0 -0
  10. results/layout-benchmark-results-images-6.jpg +0 -0
  11. results/layout-benchmark-results-images-7.jpg +0 -0
  12. results/layout-benchmark-results-images-8.jpg +0 -0
  13. results/layout-benchmark-results-images-9.jpg +0 -0
  14. surya/__pycache__/detection.cpython-310.pyc +0 -0
  15. surya/__pycache__/layout.cpython-310.pyc +0 -0
  16. surya/__pycache__/ocr.cpython-310.pyc +0 -0
  17. surya/__pycache__/recognition.cpython-310.pyc +0 -0
  18. surya/__pycache__/schema.cpython-310.pyc +0 -0
  19. surya/__pycache__/settings.cpython-310.pyc +0 -0
  20. surya/benchmark/bbox.py +22 -0
  21. surya/benchmark/metrics.py +139 -0
  22. surya/benchmark/tesseract.py +179 -0
  23. surya/benchmark/util.py +31 -0
  24. surya/detection.py +139 -0
  25. surya/input/__pycache__/processing.cpython-310.pyc +0 -0
  26. surya/input/langs.py +19 -0
  27. surya/input/load.py +74 -0
  28. surya/input/processing.py +116 -0
  29. surya/languages.py +101 -0
  30. surya/layout.py +204 -0
  31. surya/model/detection/__pycache__/processor.cpython-310.pyc +0 -0
  32. surya/model/detection/__pycache__/segformer.cpython-310.pyc +0 -0
  33. surya/model/detection/processor.py +284 -0
  34. surya/model/detection/segformer.py +468 -0
  35. surya/model/ordering/config.py +8 -0
  36. surya/model/ordering/decoder.py +557 -0
  37. surya/model/ordering/encoder.py +83 -0
  38. surya/model/ordering/encoderdecoder.py +90 -0
  39. surya/model/ordering/model.py +34 -0
  40. surya/model/ordering/processor.py +156 -0
  41. surya/model/recognition/__pycache__/config.cpython-310.pyc +0 -0
  42. surya/model/recognition/__pycache__/decoder.cpython-310.pyc +0 -0
  43. surya/model/recognition/__pycache__/encoder.cpython-310.pyc +0 -0
  44. surya/model/recognition/__pycache__/model.cpython-310.pyc +0 -0
  45. surya/model/recognition/__pycache__/processor.cpython-310.pyc +0 -0
  46. surya/model/recognition/__pycache__/tokenizer.cpython-310.pyc +0 -0
  47. surya/model/recognition/config.py +111 -0
  48. surya/model/recognition/decoder.py +511 -0
  49. surya/model/recognition/encoder.py +469 -0
  50. surya/model/recognition/model.py +64 -0
README.md ADDED
@@ -0,0 +1,53 @@
+ ---
+ license: apache-2.0
+ ---
+
+ # Suryolo: Layout Model for Arabic Documents
+
+ Suryolo is a combination of the Surya layout model from Surya OCR (based on SegFormer) and YOLOv10 object detection.
+
+ ## Setup Instructions
+
+ ### Clone the Surya OCR GitHub Repository
+
+ ```bash
+ git clone https://github.com/vikp/surya.git
+ cd surya
+ ```
+
+ ### Switch to v0.4.14
+
+ ```bash
+ git checkout f7c6c04
+ ```
+
+ ### Install Dependencies
+
+ You can install the required dependencies with the following commands:
+
+ ```bash
+ pip install -r requirements.txt
+ pip install ultralytics
+ pip install supervision
+ ```
+
+ ### Suryolo Pipeline
+
+ Download the `surya_yolo_pipeline.py` file from the repository.
+
+ ```python
+ from PIL import Image
+
+ from surya_yolo_pipeline import suryolo
+ from surya.postprocessing.heatmap import draw_bboxes_on_image
+
+ image_path = "sample.jpg"
+ image = Image.open(image_path)
+ bboxes = suryolo(image_path)
+ plotted_image = draw_bboxes_on_image(bboxes, image)
+ ```
+
+ #### Refer to `benchmark.ipynb` for a comparison between the traditional Surya layout model and the Suryolo layout model.
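The `surya_yolo_pipeline.py` script used in the example above is not included in this diff view. As a rough, non-authoritative sketch of how such a Surya + YOLOv10 combination might be wired up (the YOLO weight path, the merge rule, and the `suryolo_sketch` name are illustrative assumptions; the surya loaders follow the usual v0.4.x API):

```python
# Hypothetical sketch only: the committed surya_yolo_pipeline.py is not shown here.
from PIL import Image
from ultralytics import YOLO

from surya.detection import batch_text_detection
from surya.layout import batch_layout_detection
from surya.model.detection.segformer import load_model, load_processor  # assumed surya v0.4.x loaders
from surya.settings import settings


def suryolo_sketch(image_path: str, yolo_weights: str = "yolov10_layout.pt"):
    image = Image.open(image_path).convert("RGB")

    # Surya layout stack (SegFormer-based detector + layout model)
    det_model, det_processor = load_model(), load_processor()
    layout_model = load_model(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT)
    layout_processor = load_processor(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT)

    line_preds = batch_text_detection([image], det_model, det_processor)
    layout_preds = batch_layout_detection([image], layout_model, layout_processor, line_preds)
    surya_boxes = [
        [min(p[0] for p in b.polygon), min(p[1] for p in b.polygon),
         max(p[0] for p in b.polygon), max(p[1] for p in b.polygon)]
        for b in layout_preds[0].bboxes
    ]

    # YOLOv10 layout detector via ultralytics; the weights path is a placeholder.
    yolo_boxes = YOLO(yolo_weights)(image)[0].boxes.xyxy.tolist()

    # Naive merge: concatenate both sets of boxes.
    return surya_boxes + yolo_boxes
```

The real pipeline presumably reconciles overlapping Surya and YOLO regions rather than simply concatenating them; this sketch only shows the plumbing.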
benchmark.ipynb ADDED
The diff for this file is too large to render.
 
requirements.txt ADDED
@@ -0,0 +1,299 @@
1
+ absl-py==2.1.0
2
+ accelerate==0.34.2
3
+ addict==2.4.0
4
+ aiofiles==23.2.1
5
+ aiohappyeyeballs==2.4.0
6
+ aiohttp==3.10.5
7
+ aiosignal==1.3.1
8
+ albucore==0.0.17
9
+ albumentations==1.4.18
10
+ altair==5.4.1
11
+ annotated-types==0.7.0
12
+ antlr4-python3-runtime==4.8
13
+ anyio==4.6.0
14
+ appdirs==1.4.4
15
+ astor==0.8.1
16
+ asttokens @ file:///home/conda/feedstock_root/build_artifacts/asttokens_1698341106958/work
17
+ async-timeout==4.0.3
18
+ attrs==24.2.0
19
+ av==13.1.0
20
+ babel==2.16.0
21
+ bce-python-sdk==0.9.23
22
+ bcrypt==4.2.0
23
+ beartype==0.19.0
24
+ beautifulsoup4==4.12.3
25
+ bitsandbytes==0.44.1
26
+ blinker==1.8.2
27
+ boto3==1.35.34
28
+ botocore==1.35.34
29
+ braceexpand==0.1.7
30
+ Brotli @ file:///croot/brotli-split_1714483155106/work
31
+ cachetools==5.5.0
32
+ certifi @ file:///croot/certifi_1725551672989/work/certifi
33
+ cffi==1.17.1
34
+ cfgv==3.4.0
35
+ charset-normalizer @ file:///croot/charset-normalizer_1721748349566/work
36
+ click==8.1.7
37
+ colossalai==0.4.0
38
+ comm @ file:///home/conda/feedstock_root/build_artifacts/comm_1710320294760/work
39
+ contexttimer==0.3.3
40
+ contourpy==1.3.0
41
+ cpm-kernels==1.0.11
42
+ cryptography==43.0.1
43
+ cycler==0.12.1
44
+ Cython==3.0.11
45
+ datasets==3.0.0
46
+ debugpy @ file:///croot/debugpy_1690905042057/work
47
+ decorator==4.4.2
48
+ decord==0.6.0
49
+ deepspeed==0.15.1
50
+ defusedxml==0.7.1
51
+ Deprecated==1.2.14
52
+ diffusers==0.30.3
53
+ dill==0.3.8
54
+ distlib==0.3.8
55
+ distro==1.9.0
56
+ docker-pycreds==0.4.0
57
+ doclayout_yolo==0.0.2
58
+ easydict==1.13
59
+ einops==0.7.0
60
+ entrypoints @ file:///home/conda/feedstock_root/build_artifacts/entrypoints_1643888246732/work
61
+ eval_type_backport==0.2.0
62
+ exceptiongroup @ file:///home/conda/feedstock_root/build_artifacts/exceptiongroup_1720869315914/work
63
+ executing @ file:///home/conda/feedstock_root/build_artifacts/executing_1725214404607/work
64
+ fabric==3.2.2
65
+ faiss-cpu==1.8.0.post1
66
+ fastapi==0.110.0
67
+ ffmpy==0.4.0
68
+ filelock @ file:///croot/filelock_1700591183607/work
69
+ fire==0.6.0
70
+ flash-attn==2.6.3
71
+ Flask==3.0.3
72
+ flask-babel==4.0.0
73
+ fonttools==4.54.1
74
+ frozenlist==1.4.1
75
+ fsspec==2024.6.1
76
+ ftfy==6.2.3
77
+ future==1.0.0
78
+ fvcore==0.1.5.post20221221
79
+ galore-torch==1.0
80
+ gast==0.3.3
81
+ gdown==5.1.0
82
+ gitdb==4.0.11
83
+ GitPython==3.1.43
84
+ gmpy2 @ file:///tmp/build/80754af9/gmpy2_1645455533097/work
85
+ google==3.0.0
86
+ google-auth==2.35.0
87
+ google-auth-oauthlib==1.0.0
88
+ gradio==4.44.1
89
+ gradio_client==1.3.0
90
+ grpcio==1.66.1
91
+ h11==0.14.0
92
+ h5py==3.10.0
93
+ hjson==3.1.0
94
+ httpcore==1.0.5
95
+ httpx==0.27.2
96
+ huggingface-hub==0.25.0
97
+ identify==2.6.1
98
+ idna==3.6
99
+ imageio==2.35.1
100
+ imageio-ffmpeg==0.5.1
101
+ imgaug==0.4.0
102
+ importlib_metadata==8.5.0
103
+ importlib_resources==6.4.5
104
+ invoke==2.2.0
105
+ iopath==0.1.10
106
+ ipykernel @ file:///home/conda/feedstock_root/build_artifacts/ipykernel_1719845459717/work
107
+ ipython @ file:///home/conda/feedstock_root/build_artifacts/ipython_1725050136642/work
108
+ ipywidgets==8.1.5
109
+ itsdangerous==2.2.0
110
+ jedi @ file:///home/conda/feedstock_root/build_artifacts/jedi_1696326070614/work
111
+ Jinja2 @ file:///croot/jinja2_1716993405101/work
112
+ jiter==0.5.0
113
+ jmespath==1.0.1
114
+ joblib==1.4.2
115
+ jsonschema==4.23.0
116
+ jsonschema-specifications==2023.12.1
117
+ jupyter-client @ file:///home/conda/feedstock_root/build_artifacts/jupyter_client_1654730843242/work
118
+ jupyter_core @ file:///home/conda/feedstock_root/build_artifacts/jupyter_core_1727163409502/work
119
+ jupyterlab_widgets==3.0.13
120
+ kiwisolver==1.4.7
121
+ lazy_loader==0.4
122
+ lightning-utilities==0.11.7
123
+ lmdb==1.5.1
124
+ lxml==5.3.0
125
+ Markdown==3.7
126
+ markdown-it-py==3.0.0
127
+ MarkupSafe @ file:///croot/markupsafe_1704205993651/work
128
+ matplotlib==3.7.5
129
+ matplotlib-inline @ file:///home/conda/feedstock_root/build_artifacts/matplotlib-inline_1713250518406/work
130
+ mdurl==0.1.2
131
+ mkl-service==2.4.0
132
+ mkl_fft @ file:///croot/mkl_fft_1725370245198/work
133
+ mkl_random @ file:///croot/mkl_random_1725370241878/work
134
+ mmengine==0.10.5
135
+ moviepy==1.0.3
136
+ mpmath @ file:///croot/mpmath_1690848262763/work
137
+ msgpack==1.1.0
138
+ multidict==6.1.0
139
+ multiprocess==0.70.16
140
+ narwhals==1.9.1
141
+ nest_asyncio @ file:///home/conda/feedstock_root/build_artifacts/nest-asyncio_1705850609492/work
142
+ networkx @ file:///croot/networkx_1717597493534/work
143
+ ninja==1.11.1.1
144
+ nodeenv==1.9.1
145
+ numpy==1.26.0
146
+ nvidia-cublas-cu12==12.1.3.1
147
+ nvidia-cuda-cupti-cu12==12.1.105
148
+ nvidia-cuda-nvrtc-cu12==12.1.105
149
+ nvidia-cuda-runtime-cu12==12.1.105
150
+ nvidia-cudnn-cu12==9.1.0.70
151
+ nvidia-cufft-cu12==11.0.2.54
152
+ nvidia-curand-cu12==10.3.2.106
153
+ nvidia-cusolver-cu12==11.4.5.107
154
+ nvidia-cusparse-cu12==12.1.0.106
155
+ nvidia-ml-py==12.560.30
156
+ nvidia-nccl-cu12==2.20.5
157
+ nvidia-nvjitlink-cu12==12.6.77
158
+ nvidia-nvtx-cu12==12.1.105
159
+ oauthlib==3.2.2
160
+ omegaconf==2.1.1
161
+ openai==1.51.0
162
+ opencv-contrib-python==4.10.0.84
163
+ opencv-python==4.9.0.80
164
+ opencv-python-headless==4.9.0.80
165
+ opensora @ file:///share/data/drive_3/ketan/t2v/Open-Sora
166
+ opt-einsum==3.3.0
167
+ orjson==3.10.7
168
+ packaging @ file:///home/conda/feedstock_root/build_artifacts/packaging_1718189413536/work
169
+ paddleclas==2.5.2
170
+ paddleocr==2.8.1
171
+ paddlepaddle==2.6.2
172
+ pandarallel==1.6.5
173
+ pandas==2.0.3
174
+ parameterized==0.9.0
175
+ paramiko==3.5.0
176
+ parso @ file:///home/conda/feedstock_root/build_artifacts/parso_1712320355065/work
177
+ peft==0.13.0
178
+ pexpect @ file:///home/conda/feedstock_root/build_artifacts/pexpect_1706113125309/work
179
+ pickleshare @ file:///home/conda/feedstock_root/build_artifacts/pickleshare_1602536217715/work
180
+ Pillow==9.5.0
181
+ platformdirs @ file:///home/conda/feedstock_root/build_artifacts/platformdirs_1726613481435/work
182
+ plumbum==1.9.0
183
+ portalocker==2.10.1
184
+ pre_commit==4.0.0
185
+ prettytable==3.11.0
186
+ proglog==0.1.10
187
+ prompt_toolkit @ file:///home/conda/feedstock_root/build_artifacts/prompt-toolkit_1718047967974/work
188
+ protobuf==4.25.5
189
+ psutil @ file:///opt/conda/conda-bld/psutil_1656431268089/work
190
+ ptyprocess @ file:///home/conda/feedstock_root/build_artifacts/ptyprocess_1609419310487/work/dist/ptyprocess-0.7.0-py2.py3-none-any.whl
191
+ pure_eval @ file:///home/conda/feedstock_root/build_artifacts/pure_eval_1721585709575/work
192
+ py-cpuinfo==9.0.0
193
+ pyarrow==17.0.0
194
+ pyasn1==0.6.1
195
+ pyasn1_modules==0.4.1
196
+ pyclipper==1.3.0.post5
197
+ pycparser==2.22
198
+ pycryptodome==3.20.0
199
+ pydantic==2.9.2
200
+ pydantic-settings==2.5.2
201
+ pydantic_core==2.23.4
202
+ pydub==0.25.1
203
+ Pygments @ file:///home/conda/feedstock_root/build_artifacts/pygments_1714846767233/work
204
+ PyNaCl==1.5.0
205
+ pyparsing==3.1.4
206
+ pypdfium2==4.30.0
207
+ PySocks @ file:///home/builder/ci_310/pysocks_1640793678128/work
208
+ python-dateutil @ file:///home/conda/feedstock_root/build_artifacts/python-dateutil_1709299778482/work
209
+ python-docx==1.1.2
210
+ python-dotenv==1.0.1
211
+ python-multipart==0.0.12
212
+ pytorch-lightning==2.2.1
213
+ pytorchvideo==0.1.5
214
+ pytz==2024.2
215
+ PyYAML @ file:///croot/pyyaml_1698096049011/work
216
+ pyzmq @ file:///croot/pyzmq_1705605076900/work
217
+ qudida==0.0.4
218
+ RapidFuzz==3.10.0
219
+ rarfile==4.2
220
+ ray==2.37.0
221
+ referencing==0.35.1
222
+ regex==2023.12.25
223
+ requests==2.32.3
224
+ requests-oauthlib==2.0.0
225
+ rich==13.9.2
226
+ rotary-embedding-torch==0.5.3
227
+ rpds-py==0.20.0
228
+ rpyc==6.0.0
229
+ rsa==4.9
230
+ ruff==0.6.9
231
+ s3transfer==0.10.2
232
+ safetensors==0.4.5
233
+ scikit-image==0.24.0
234
+ scikit-learn==1.3.2
235
+ scikit-video==1.1.11
236
+ scipy==1.10.1
237
+ seaborn==0.13.2
238
+ semantic-version==2.10.0
239
+ sentencepiece==0.2.0
240
+ sentry-sdk==2.15.0
241
+ setproctitle==1.3.3
242
+ shapely==2.0.6
243
+ shellingham==1.5.4
244
+ six @ file:///home/conda/feedstock_root/build_artifacts/six_1620240208055/work
245
+ smmap==5.0.1
246
+ sniffio==1.3.1
247
+ soupsieve==2.6
248
+ spaces==0.30.3
249
+ stack-data @ file:///home/conda/feedstock_root/build_artifacts/stack_data_1669632077133/work
250
+ starlette==0.36.3
251
+ supervision==0.23.0
252
+ SwissArmyTransformer==0.4.12
253
+ sympy @ file:///croot/sympy_1724938189289/work
254
+ tabulate==0.9.0
255
+ tensorboard==2.14.0
256
+ tensorboard-data-server==0.7.2
257
+ tensorboardX==2.6.2.2
258
+ termcolor==2.4.0
259
+ test_tube==0.7.5
260
+ thop==0.1.1.post2209072238
261
+ threadpoolctl==3.5.0
262
+ tifffile==2024.9.20
263
+ timm==0.9.16
264
+ tokenizers==0.20.0
265
+ tomli==2.0.2
266
+ tomlkit==0.12.0
267
+ torch==2.4.1
268
+ torch-lr-finder==0.2.2
269
+ torchaudio==2.4.1
270
+ torchdiffeq==0.2.3
271
+ torchmetrics==1.3.2
272
+ torchvision==0.19.1
273
+ tornado @ file:///home/conda/feedstock_root/build_artifacts/tornado_1648827254365/work
274
+ tqdm==4.66.5
275
+ traitlets @ file:///home/conda/feedstock_root/build_artifacts/traitlets_1713535121073/work
276
+ transformers==4.45.1
277
+ triton==3.0.0
278
+ typer==0.12.5
279
+ typing_extensions @ file:///croot/typing_extensions_1715268824938/work
280
+ tzdata==2024.1
281
+ ujson==5.10.0
282
+ ultralytics==8.3.1
283
+ ultralytics-thop==2.0.8
284
+ urllib3==2.2.1
285
+ uvicorn==0.29.0
286
+ virtualenv==20.26.6
287
+ visualdl==2.5.3
288
+ wandb==0.18.3
289
+ wcwidth @ file:///home/conda/feedstock_root/build_artifacts/wcwidth_1704731205417/work
290
+ webdataset==0.2.100
291
+ websockets==11.0.3
292
+ Werkzeug==3.0.4
293
+ widgetsnbextension==4.0.13
294
+ wrapt==1.16.0
295
+ xxhash==3.5.0
296
+ yacs==0.1.8
297
+ yapf==0.40.2
298
+ yarl==1.11.1
299
+ zipp==3.20.2
results/layout-benchmark-results-images-1.jpg ADDED
results/layout-benchmark-results-images-10.jpg ADDED
results/layout-benchmark-results-images-2.jpg ADDED
results/layout-benchmark-results-images-3.jpg ADDED
results/layout-benchmark-results-images-4.jpg ADDED
results/layout-benchmark-results-images-5.jpg ADDED
results/layout-benchmark-results-images-6.jpg ADDED
results/layout-benchmark-results-images-7.jpg ADDED
results/layout-benchmark-results-images-8.jpg ADDED
results/layout-benchmark-results-images-9.jpg ADDED
surya/__pycache__/detection.cpython-310.pyc ADDED
Binary file (5.06 kB).
 
surya/__pycache__/layout.cpython-310.pyc ADDED
Binary file (6.35 kB).
 
surya/__pycache__/ocr.cpython-310.pyc ADDED
Binary file (2.79 kB).
 
surya/__pycache__/recognition.cpython-310.pyc ADDED
Binary file (5.86 kB).
 
surya/__pycache__/schema.cpython-310.pyc ADDED
Binary file (6.41 kB).
 
surya/__pycache__/settings.cpython-310.pyc ADDED
Binary file (3.77 kB).
 
surya/benchmark/bbox.py ADDED
@@ -0,0 +1,22 @@
1
+ import fitz as pymupdf
2
+ from surya.postprocessing.util import rescale_bbox
3
+
4
+
5
+ def get_pdf_lines(pdf_path, img_sizes):
6
+ doc = pymupdf.open(pdf_path)
7
+ page_lines = []
8
+ for idx, img_size in enumerate(img_sizes):
9
+ page = doc[idx]
10
+ blocks = page.get_text("dict", sort=True, flags=pymupdf.TEXTFLAGS_DICT & ~pymupdf.TEXT_PRESERVE_LIGATURES & ~pymupdf.TEXT_PRESERVE_IMAGES)["blocks"]
11
+
12
+ line_boxes = []
13
+ for block_idx, block in enumerate(blocks):
14
+ for l in block["lines"]:
15
+ line_boxes.append(list(l["bbox"]))
16
+
17
+ page_box = page.bound()
18
+ pwidth, pheight = page_box[2] - page_box[0], page_box[3] - page_box[1]
19
+ line_boxes = [rescale_bbox(bbox, (pwidth, pheight), img_size) for bbox in line_boxes]
20
+ page_lines.append(line_boxes)
21
+
22
+ return page_lines
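
A minimal usage sketch for `get_pdf_lines` above, assuming a local `sample.pdf` and one target image size per page (both placeholders):

```python
from surya.benchmark.bbox import get_pdf_lines

# One (width, height) entry per page; PDF line boxes are rescaled from
# PDF points to these image dimensions.
img_sizes = [(1024, 1448)]
page_lines = get_pdf_lines("sample.pdf", img_sizes)
print(len(page_lines[0]), "line boxes on page 1")
```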
surya/benchmark/metrics.py ADDED
@@ -0,0 +1,139 @@
1
+ from functools import partial
2
+ from itertools import repeat
3
+
4
+ import numpy as np
5
+ from concurrent.futures import ProcessPoolExecutor
6
+
7
+ def intersection_area(box1, box2):
8
+ x_left = max(box1[0], box2[0])
9
+ y_top = max(box1[1], box2[1])
10
+ x_right = min(box1[2], box2[2])
11
+ y_bottom = min(box1[3], box2[3])
12
+
13
+ if x_right < x_left or y_bottom < y_top:
14
+ return 0.0
15
+
16
+ return (x_right - x_left) * (y_bottom - y_top)
17
+
18
+
19
+ def intersection_pixels(box1, box2):
20
+ x_left = max(box1[0], box2[0])
21
+ y_top = max(box1[1], box2[1])
22
+ x_right = min(box1[2], box2[2])
23
+ y_bottom = min(box1[3], box2[3])
24
+
25
+ if x_right < x_left or y_bottom < y_top:
26
+ return set()
27
+
28
+ x_left, x_right = int(x_left), int(x_right)
29
+ y_top, y_bottom = int(y_top), int(y_bottom)
30
+
31
+ coords = np.meshgrid(np.arange(x_left, x_right), np.arange(y_top, y_bottom))
32
+ pixels = set(zip(coords[0].flat, coords[1].flat))
33
+
34
+ return pixels
35
+
36
+
37
+ def calculate_coverage(box, other_boxes, penalize_double=False):
38
+ box_area = (box[2] - box[0]) * (box[3] - box[1])
39
+ if box_area == 0:
40
+ return 0
41
+
42
+ # find total coverage of the box
43
+ covered_pixels = set()
44
+ double_coverage = list()
45
+ for other_box in other_boxes:
46
+ ia = intersection_pixels(box, other_box)
47
+ double_coverage.append(list(covered_pixels.intersection(ia)))
48
+ covered_pixels = covered_pixels.union(ia)
49
+
50
+ # Penalize double coverage - having multiple bboxes overlapping the same pixels
51
+ double_coverage_penalty = len(double_coverage)
52
+ if not penalize_double:
53
+ double_coverage_penalty = 0
54
+ covered_pixels_count = max(0, len(covered_pixels) - double_coverage_penalty)
55
+ return covered_pixels_count / box_area
56
+
57
+
58
+ def calculate_coverage_fast(box, other_boxes, penalize_double=False):
59
+ box_area = (box[2] - box[0]) * (box[3] - box[1])
60
+ if box_area == 0:
61
+ return 0
62
+
63
+ total_intersect = 0
64
+ for other_box in other_boxes:
65
+ total_intersect += intersection_area(box, other_box)
66
+
67
+ return min(1, total_intersect / box_area)
68
+
69
+
70
+ def precision_recall(preds, references, threshold=.5, workers=8, penalize_double=True):
71
+ if len(references) == 0:
72
+ return {
73
+ "precision": 1,
74
+ "recall": 1,
75
+ }
76
+
77
+ if len(preds) == 0:
78
+ return {
79
+ "precision": 0,
80
+ "recall": 0,
81
+ }
82
+
83
+ # If we're not penalizing double coverage, we can use a faster calculation
84
+ coverage_func = calculate_coverage_fast
85
+ if penalize_double:
86
+ coverage_func = calculate_coverage
87
+
88
+ with ProcessPoolExecutor(max_workers=workers) as executor:
89
+ precision_func = partial(coverage_func, penalize_double=penalize_double)
90
+ precision_iou = executor.map(precision_func, preds, repeat(references))
91
+ reference_iou = executor.map(coverage_func, references, repeat(preds))
92
+
93
+ precision_classes = [1 if i > threshold else 0 for i in precision_iou]
94
+ precision = sum(precision_classes) / len(precision_classes)
95
+
96
+ recall_classes = [1 if i > threshold else 0 for i in reference_iou]
97
+ recall = sum(recall_classes) / len(recall_classes)
98
+
99
+ return {
100
+ "precision": precision,
101
+ "recall": recall,
102
+ }
103
+
104
+
105
+ def mean_coverage(preds, references):
106
+ coverages = []
107
+
108
+ for box1 in references:
109
+ coverage = calculate_coverage(box1, preds)
110
+ coverages.append(coverage)
111
+
112
+ for box2 in preds:
113
+ coverage = calculate_coverage(box2, references)
114
+ coverages.append(coverage)
115
+
116
+ # Calculate the average coverage over all comparisons
117
+ if len(coverages) == 0:
118
+ return 0
119
+ coverage = sum(coverages) / len(coverages)
120
+ return {"coverage": coverage}
121
+
122
+
123
+ def rank_accuracy(preds, references):
124
+ # Preds and references need to be aligned so each position refers to the same bbox
125
+ pairs = []
126
+ for i, pred in enumerate(preds):
127
+ for j, pred2 in enumerate(preds):
128
+ if i == j:
129
+ continue
130
+ pairs.append((i, j, pred > pred2))
131
+
132
+ # Find how many of the prediction rankings are correct
133
+ correct = 0
134
+ for i, ref in enumerate(references):
135
+ for j, ref2 in enumerate(references):
136
+ if (i, j, ref > ref2) in pairs:
137
+ correct += 1
138
+
139
+ return correct / len(pairs)
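
A small worked example of the coverage-based precision/recall above, using toy boxes chosen for illustration:

```python
from surya.benchmark.metrics import precision_recall, mean_coverage

if __name__ == "__main__":  # ProcessPoolExecutor needs a main guard on spawn platforms
    references = [[0, 0, 10, 10]]
    preds = [[0, 0, 10, 10], [20, 20, 30, 30]]

    # The first prediction fully covers the reference, the second covers nothing,
    # so precision should be 0.5 and recall 1.0 at the default 0.5 threshold.
    print(precision_recall(preds, references, workers=2))
    print(mean_coverage(preds, references))
```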
surya/benchmark/tesseract.py ADDED
@@ -0,0 +1,179 @@
1
+ from typing import List, Optional
2
+
3
+ import numpy as np
4
+ import pytesseract
5
+ from pytesseract import Output
6
+ from tqdm import tqdm
7
+
8
+ from surya.input.processing import slice_bboxes_from_image
9
+ from surya.settings import settings
10
+ import os
11
+ from concurrent.futures import ProcessPoolExecutor
12
+ from surya.detection import get_batch_size as get_det_batch_size
13
+ from surya.recognition import get_batch_size as get_rec_batch_size
14
+ from surya.languages import CODE_TO_LANGUAGE
15
+
16
+
17
+ def surya_lang_to_tesseract(code: str) -> Optional[str]:
18
+ lang_str = CODE_TO_LANGUAGE[code]
19
+ try:
20
+ tess_lang = TESS_LANGUAGE_TO_CODE[lang_str]
21
+ except KeyError:
22
+ return None
23
+ return tess_lang
24
+
25
+
26
+ def tesseract_ocr(img, bboxes, lang: str):
27
+ line_imgs = slice_bboxes_from_image(img, bboxes)
28
+ config = f'--tessdata-dir "{settings.TESSDATA_PREFIX}"'
29
+ lines = []
30
+ for line_img in line_imgs:
31
+ line = pytesseract.image_to_string(line_img, lang=lang, config=config)
32
+ lines.append(line)
33
+ return lines
34
+
35
+
36
+ def tesseract_ocr_parallel(imgs, bboxes, langs: List[str], cpus=None):
37
+ tess_parallel_cores = min(len(imgs), get_rec_batch_size())
38
+ if not cpus:
39
+ cpus = os.cpu_count()
40
+ tess_parallel_cores = min(tess_parallel_cores, cpus)
41
+
42
+ # Tesseract uses up to 4 processes per instance
43
+ # Divide by 2 because tesseract doesn't seem to saturate all 4 cores with these small images
44
+ tess_parallel = max(tess_parallel_cores // 2, 1)
45
+
46
+ with ProcessPoolExecutor(max_workers=tess_parallel) as executor:
47
+ tess_text = tqdm(executor.map(tesseract_ocr, imgs, bboxes, langs), total=len(imgs), desc="Running tesseract OCR")
48
+ tess_text = list(tess_text)
49
+ return tess_text
50
+
51
+
52
+ def tesseract_bboxes(img):
53
+ arr_img = np.asarray(img, dtype=np.uint8)
54
+ ocr = pytesseract.image_to_data(arr_img, output_type=Output.DICT)
55
+
56
+ bboxes = []
57
+ n_boxes = len(ocr['level'])
58
+ for i in range(n_boxes):
59
+ # It is possible to merge by line here with line number, but it gives bad results.
60
+ _, x, y, w, h = ocr['text'][i], ocr['left'][i], ocr['top'][i], ocr['width'][i], ocr['height'][i]
61
+ bbox = (x, y, x + w, y + h)
62
+ bboxes.append(bbox)
63
+
64
+ return bboxes
65
+
66
+
67
+ def tesseract_parallel(imgs):
68
+ # Tesseract uses 4 threads per instance
69
+ tess_parallel_cores = min(len(imgs), get_det_batch_size())
70
+ cpus = os.cpu_count()
71
+ tess_parallel_cores = min(tess_parallel_cores, cpus)
72
+
73
+ # Tesseract uses 4 threads per instance
74
+ tess_parallel = max(tess_parallel_cores // 4, 1)
75
+
76
+ with ProcessPoolExecutor(max_workers=tess_parallel) as executor:
77
+ tess_bboxes = tqdm(executor.map(tesseract_bboxes, imgs), total=len(imgs), desc="Running tesseract bbox detection")
78
+ tess_bboxes = list(tess_bboxes)
79
+ return tess_bboxes
80
+
81
+
82
+ TESS_CODE_TO_LANGUAGE = {
83
+ "afr": "Afrikaans",
84
+ "amh": "Amharic",
85
+ "ara": "Arabic",
86
+ "asm": "Assamese",
87
+ "aze": "Azerbaijani",
88
+ "bel": "Belarusian",
89
+ "ben": "Bengali",
90
+ "bod": "Tibetan",
91
+ "bos": "Bosnian",
92
+ "bre": "Breton",
93
+ "bul": "Bulgarian",
94
+ "cat": "Catalan",
95
+ "ceb": "Cebuano",
96
+ "ces": "Czech",
97
+ "chi_sim": "Chinese",
98
+ "chr": "Cherokee",
99
+ "cym": "Welsh",
100
+ "dan": "Danish",
101
+ "deu": "German",
102
+ "dzo": "Dzongkha",
103
+ "ell": "Greek",
104
+ "eng": "English",
105
+ "epo": "Esperanto",
106
+ "est": "Estonian",
107
+ "eus": "Basque",
108
+ "fas": "Persian",
109
+ "fin": "Finnish",
110
+ "fra": "French",
111
+ "fry": "Western Frisian",
112
+ "guj": "Gujarati",
113
+ "gla": "Scottish Gaelic",
114
+ "gle": "Irish",
115
+ "glg": "Galician",
116
+ "heb": "Hebrew",
117
+ "hin": "Hindi",
118
+ "hrv": "Croatian",
119
+ "hun": "Hungarian",
120
+ "hye": "Armenian",
121
+ "iku": "Inuktitut",
122
+ "ind": "Indonesian",
123
+ "isl": "Icelandic",
124
+ "ita": "Italian",
125
+ "jav": "Javanese",
126
+ "jpn": "Japanese",
127
+ "kan": "Kannada",
128
+ "kat": "Georgian",
129
+ "kaz": "Kazakh",
130
+ "khm": "Khmer",
131
+ "kir": "Kyrgyz",
132
+ "kor": "Korean",
133
+ "lao": "Lao",
134
+ "lat": "Latin",
135
+ "lav": "Latvian",
136
+ "lit": "Lithuanian",
137
+ "mal": "Malayalam",
138
+ "mar": "Marathi",
139
+ "mkd": "Macedonian",
140
+ "mlt": "Maltese",
141
+ "mon": "Mongolian",
142
+ "msa": "Malay",
143
+ "mya": "Burmese",
144
+ "nep": "Nepali",
145
+ "nld": "Dutch",
146
+ "nor": "Norwegian",
147
+ "ori": "Oriya",
148
+ "pan": "Punjabi",
149
+ "pol": "Polish",
150
+ "por": "Portuguese",
151
+ "pus": "Pashto",
152
+ "ron": "Romanian",
153
+ "rus": "Russian",
154
+ "san": "Sanskrit",
155
+ "sin": "Sinhala",
156
+ "slk": "Slovak",
157
+ "slv": "Slovenian",
158
+ "snd": "Sindhi",
159
+ "spa": "Spanish",
160
+ "sqi": "Albanian",
161
+ "srp": "Serbian",
162
+ "swa": "Swahili",
163
+ "swe": "Swedish",
164
+ "syr": "Syriac",
165
+ "tam": "Tamil",
166
+ "tel": "Telugu",
167
+ "tgk": "Tajik",
168
+ "tha": "Thai",
169
+ "tir": "Tigrinya",
170
+ "tur": "Turkish",
171
+ "uig": "Uyghur",
172
+ "ukr": "Ukrainian",
173
+ "urd": "Urdu",
174
+ "uzb": "Uzbek",
175
+ "vie": "Vietnamese",
176
+ "yid": "Yiddish"
177
+ }
178
+
179
+ TESS_LANGUAGE_TO_CODE = {v:k for k,v in TESS_CODE_TO_LANGUAGE.items()}
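
A brief usage sketch for the Tesseract helpers above; it assumes the Tesseract binary and language data are installed locally and uses a placeholder image path:

```python
from PIL import Image
from surya.benchmark.tesseract import tesseract_bboxes, surya_lang_to_tesseract

image = Image.open("sample.jpg").convert("RGB")
boxes = tesseract_bboxes(image)            # word-level (x1, y1, x2, y2) boxes from image_to_data
tess_lang = surya_lang_to_tesseract("ar")  # maps Surya's "ar" to Tesseract's "ara"
print(len(boxes), tess_lang)
```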
surya/benchmark/util.py ADDED
@@ -0,0 +1,31 @@
1
+ def merge_boxes(box1, box2):
2
+ return (min(box1[0], box2[0]), min(box1[1], box2[1]), max(box1[2], box2[2]), max(box1[3], box2[3]))
3
+
4
+
5
+ def join_lines(bboxes, max_gap=5):
6
+ to_merge = {}
7
+ for i, box1 in bboxes:
8
+ for z, box2 in bboxes[i + 1:]:
9
+ j = i + z + 1
10
+ if box1 == box2:
11
+ continue
12
+
13
+ if box1[0] <= box2[0] and box1[2] >= box2[2]:
14
+ if abs(box1[1] - box2[3]) <= max_gap:
15
+ if i not in to_merge:
16
+ to_merge[i] = []
17
+ to_merge[i].append(j)
18
+
19
+ merged_boxes = set()
20
+ merged = []
21
+ for i, box in bboxes:
22
+ if i in merged_boxes:
23
+ continue
24
+
25
+ if i in to_merge:
26
+ for j in to_merge[i]:
27
+ box = merge_boxes(box, bboxes[j][1])
28
+ merged_boxes.add(j)
29
+
30
+ merged.append(box)
31
+ return merged
surya/detection.py ADDED
@@ -0,0 +1,139 @@
1
+ from typing import List, Tuple
2
+
3
+ import torch
4
+ import numpy as np
5
+ from PIL import Image
6
+
7
+ from surya.model.detection.segformer import SegformerForRegressionMask
8
+ from surya.postprocessing.heatmap import get_and_clean_boxes
9
+ from surya.postprocessing.affinity import get_vertical_lines
10
+ from surya.input.processing import prepare_image_detection, split_image, get_total_splits, convert_if_not_rgb
11
+ from surya.schema import TextDetectionResult
12
+ from surya.settings import settings
13
+ from tqdm import tqdm
14
+ from concurrent.futures import ProcessPoolExecutor
15
+ import torch.nn.functional as F
16
+
17
+
18
+ def get_batch_size():
19
+ batch_size = settings.DETECTOR_BATCH_SIZE
20
+ if batch_size is None:
21
+ batch_size = 6
22
+ if settings.TORCH_DEVICE_MODEL == "cuda":
23
+ batch_size = 24
24
+ return batch_size
25
+
26
+
27
+ def batch_detection(images: List, model: SegformerForRegressionMask, processor, batch_size=None) -> Tuple[List[List[np.ndarray]], List[Tuple[int, int]]]:
28
+ assert all([isinstance(image, Image.Image) for image in images])
29
+ if batch_size is None:
30
+ batch_size = get_batch_size()
31
+ heatmap_count = model.config.num_labels
32
+
33
+ images = [image.convert("RGB") for image in images] # also copies the images
34
+
35
+ orig_sizes = [image.size for image in images]
36
+ splits_per_image = [get_total_splits(size, processor) for size in orig_sizes]
37
+
38
+ batches = []
39
+ current_batch_size = 0
40
+ current_batch = []
41
+ for i in range(len(images)):
42
+ if current_batch_size + splits_per_image[i] > batch_size:
43
+ if len(current_batch) > 0:
44
+ batches.append(current_batch)
45
+ current_batch = []
46
+ current_batch_size = 0
47
+ current_batch.append(i)
48
+ current_batch_size += splits_per_image[i]
49
+
50
+ if len(current_batch) > 0:
51
+ batches.append(current_batch)
52
+
53
+ all_preds = []
54
+ for batch_idx in tqdm(range(len(batches)), desc="Detecting bboxes"):
55
+ batch_image_idxs = batches[batch_idx]
56
+ batch_images = convert_if_not_rgb([images[j] for j in batch_image_idxs])
57
+
58
+ split_index = []
59
+ split_heights = []
60
+ image_splits = []
61
+ for image_idx, image in enumerate(batch_images):
62
+ image_parts, split_height = split_image(image, processor)
63
+ image_splits.extend(image_parts)
64
+ split_index.extend([image_idx] * len(image_parts))
65
+ split_heights.extend(split_height)
66
+
67
+ image_splits = [prepare_image_detection(image, processor) for image in image_splits]
68
+ # Batch images in dim 0
69
+ batch = torch.stack(image_splits, dim=0).to(model.dtype).to(model.device)
70
+
71
+ with torch.inference_mode():
72
+ pred = model(pixel_values=batch)
73
+
74
+ logits = pred.logits
75
+ correct_shape = [processor.size["height"], processor.size["width"]]
76
+ current_shape = list(logits.shape[2:])
77
+ if current_shape != correct_shape:
78
+ logits = F.interpolate(logits, size=correct_shape, mode='bilinear', align_corners=False)
79
+
80
+ logits = logits.cpu().detach().numpy().astype(np.float32)
81
+ preds = []
82
+ for i, (idx, height) in enumerate(zip(split_index, split_heights)):
83
+ # If our current prediction length is below the image idx, that means we have a new image
84
+ # Otherwise, we need to add to the current image
85
+ if len(preds) <= idx:
86
+ preds.append([logits[i][k] for k in range(heatmap_count)])
87
+ else:
88
+ heatmaps = preds[idx]
89
+ pred_heatmaps = [logits[i][k] for k in range(heatmap_count)]
90
+
91
+ if height < processor.size["height"]:
92
+ # Cut off padding to get original height
93
+ pred_heatmaps = [pred_heatmap[:height, :] for pred_heatmap in pred_heatmaps]
94
+
95
+ for k in range(heatmap_count):
96
+ heatmaps[k] = np.vstack([heatmaps[k], pred_heatmaps[k]])
97
+ preds[idx] = heatmaps
98
+
99
+ all_preds.extend(preds)
100
+
101
+ assert len(all_preds) == len(images)
102
+ assert all([len(pred) == heatmap_count for pred in all_preds])
103
+ return all_preds, orig_sizes
104
+
105
+
106
+ def parallel_get_lines(preds, orig_sizes):
107
+ heatmap, affinity_map = preds
108
+ heat_img = Image.fromarray((heatmap * 255).astype(np.uint8))
109
+ aff_img = Image.fromarray((affinity_map * 255).astype(np.uint8))
110
+ affinity_size = list(reversed(affinity_map.shape))
111
+ heatmap_size = list(reversed(heatmap.shape))
112
+ bboxes = get_and_clean_boxes(heatmap, heatmap_size, orig_sizes)
113
+ vertical_lines = get_vertical_lines(affinity_map, affinity_size, orig_sizes)
114
+
115
+ result = TextDetectionResult(
116
+ bboxes=bboxes,
117
+ vertical_lines=vertical_lines,
118
+ heatmap=heat_img,
119
+ affinity_map=aff_img,
120
+ image_bbox=[0, 0, orig_sizes[0], orig_sizes[1]]
121
+ )
122
+ return result
123
+
124
+
125
+ def batch_text_detection(images: List, model, processor, batch_size=None) -> List[TextDetectionResult]:
126
+ preds, orig_sizes = batch_detection(images, model, processor, batch_size=batch_size)
127
+ results = []
128
+ if settings.IN_STREAMLIT or len(images) < settings.DETECTOR_MIN_PARALLEL_THRESH: # Ensures we don't parallelize with streamlit, or with very few images
129
+ for i in range(len(images)):
130
+ result = parallel_get_lines(preds[i], orig_sizes[i])
131
+ results.append(result)
132
+ else:
133
+ max_workers = min(settings.DETECTOR_POSTPROCESSING_CPU_WORKERS, len(images))
134
+ with ProcessPoolExecutor(max_workers=max_workers) as executor:
135
+ results = list(executor.map(parallel_get_lines, preds, orig_sizes))
136
+
137
+ return results
138
+
139
+
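A minimal sketch of running `batch_text_detection`, assuming the standard surya v0.4.x `load_model`/`load_processor` helpers from the `segformer` module added in this commit (their bodies are not shown in this view):

```python
from PIL import Image
from surya.detection import batch_text_detection
from surya.model.detection.segformer import load_model, load_processor  # assumed v0.4.x loaders

image = Image.open("sample.jpg").convert("RGB")
det_model, det_processor = load_model(), load_processor()

results = batch_text_detection([image], det_model, det_processor)
print(len(results[0].bboxes), "text lines detected")
```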
surya/input/__pycache__/processing.cpython-310.pyc ADDED
Binary file (4.05 kB). View file
 
surya/input/langs.py ADDED
@@ -0,0 +1,19 @@
1
+ from typing import List
2
+ from surya.languages import LANGUAGE_TO_CODE, CODE_TO_LANGUAGE
3
+
4
+
5
+ def replace_lang_with_code(langs: List[str]):
6
+ for i in range(len(langs)):
7
+ if langs[i].title() in LANGUAGE_TO_CODE:
8
+ langs[i] = LANGUAGE_TO_CODE[langs[i].title()]
9
+ if langs[i] not in CODE_TO_LANGUAGE:
10
+ raise ValueError(f"Language code {langs[i]} not found.")
11
+
12
+
13
+ def get_unique_langs(langs: List[List[str]]):
14
+ uniques = []
15
+ for lang_list in langs:
16
+ for lang in lang_list:
17
+ if lang not in uniques:
18
+ uniques.append(lang)
19
+ return uniques
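
A quick example of the language helpers above (note that `replace_lang_with_code` mutates its argument in place):

```python
from surya.input.langs import replace_lang_with_code, get_unique_langs

langs = ["English", "ar"]
replace_lang_with_code(langs)  # becomes ["en", "ar"]
print(langs, get_unique_langs([["en", "ar"], ["ar"]]))
```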
surya/input/load.py ADDED
@@ -0,0 +1,74 @@
1
+ import PIL
2
+
3
+ from surya.input.processing import open_pdf, get_page_images
4
+ import os
5
+ import filetype
6
+ from PIL import Image
7
+ import json
8
+
9
+
10
+ def get_name_from_path(path):
11
+ return os.path.basename(path).split(".")[0]
12
+
13
+
14
+ def load_pdf(pdf_path, max_pages=None, start_page=None):
15
+ doc = open_pdf(pdf_path)
16
+ last_page = len(doc)
17
+
18
+ if start_page:
19
+ assert start_page < last_page and start_page >= 0, f"Start page must be between 0 and {last_page}"
20
+ else:
21
+ start_page = 0
22
+
23
+ if max_pages:
24
+ assert max_pages >= 0, f"Max pages must be greater than 0"
25
+ last_page = min(start_page + max_pages, last_page)
26
+
27
+ page_indices = list(range(start_page, last_page))
28
+ images = get_page_images(doc, page_indices)
29
+ doc.close()
30
+ names = [get_name_from_path(pdf_path) for _ in page_indices]
31
+ return images, names
32
+
33
+
34
+ def load_image(image_path):
35
+ image = Image.open(image_path).convert("RGB")
36
+ name = get_name_from_path(image_path)
37
+ return [image], [name]
38
+
39
+
40
+ def load_from_file(input_path, max_pages=None, start_page=None):
41
+ input_type = filetype.guess(input_path)
42
+ if input_type.extension == "pdf":
43
+ return load_pdf(input_path, max_pages, start_page)
44
+ else:
45
+ return load_image(input_path)
46
+
47
+
48
+ def load_from_folder(folder_path, max_pages=None, start_page=None):
49
+ image_paths = [os.path.join(folder_path, image_name) for image_name in os.listdir(folder_path) if not image_name.startswith(".")]
50
+ image_paths = [ip for ip in image_paths if not os.path.isdir(ip)]
51
+
52
+ images = []
53
+ names = []
54
+ for path in image_paths:
55
+ extension = filetype.guess(path)
56
+ if extension and extension.extension == "pdf":
57
+ image, name = load_pdf(path, max_pages, start_page)
58
+ images.extend(image)
59
+ names.extend(name)
60
+ else:
61
+ try:
62
+ image, name = load_image(path)
63
+ images.extend(image)
64
+ names.extend(name)
65
+ except PIL.UnidentifiedImageError:
66
+ print(f"Could not load image {path}")
67
+ continue
68
+ return images, names
69
+
70
+
71
+ def load_lang_file(lang_path, names):
72
+ with open(lang_path, "r") as f:
73
+ lang_dict = json.load(f)
74
+ return [lang_dict[name].copy() for name in names]
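
A short usage sketch of the loaders above; the paths are placeholders:

```python
from surya.input.load import load_from_file, load_from_folder

# Works for a single PDF or image; PDFs are rendered page by page.
images, names = load_from_file("sample.pdf", max_pages=2)
folder_images, folder_names = load_from_folder("./docs")
print(len(images), names[:1], len(folder_images))
```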
surya/input/processing.py ADDED
@@ -0,0 +1,116 @@
1
+ from typing import List
2
+
3
+ import cv2
4
+ import numpy as np
5
+ import math
6
+ import pypdfium2
7
+ from PIL import Image, ImageOps, ImageDraw
8
+ import torch
9
+ from surya.settings import settings
10
+
11
+
12
+ def convert_if_not_rgb(images: List[Image.Image]) -> List[Image.Image]:
13
+ new_images = []
14
+ for image in images:
15
+ if image.mode != "RGB":
16
+ image = image.convert("RGB")
17
+ new_images.append(image)
18
+ return new_images
19
+
20
+
21
+ def get_total_splits(image_size, processor):
22
+ img_height = list(image_size)[1]
23
+ max_height = settings.DETECTOR_IMAGE_CHUNK_HEIGHT
24
+ processor_height = processor.size["height"]
25
+ if img_height > max_height:
26
+ num_splits = math.ceil(img_height / processor_height)
27
+ return num_splits
28
+ return 1
29
+
30
+
31
+ def split_image(img, processor):
32
+ # This will not modify/return the original image - it will either crop, or copy the image
33
+ img_height = list(img.size)[1]
34
+ max_height = settings.DETECTOR_IMAGE_CHUNK_HEIGHT
35
+ processor_height = processor.size["height"]
36
+ if img_height > max_height:
37
+ num_splits = math.ceil(img_height / processor_height)
38
+ splits = []
39
+ split_heights = []
40
+ for i in range(num_splits):
41
+ top = i * processor_height
42
+ bottom = (i + 1) * processor_height
43
+ if bottom > img_height:
44
+ bottom = img_height
45
+ cropped = img.crop((0, top, img.size[0], bottom))
46
+ height = bottom - top
47
+ if height < processor_height:
48
+ cropped = ImageOps.pad(cropped, (img.size[0], processor_height), color=255, centering=(0, 0))
49
+ splits.append(cropped)
50
+ split_heights.append(height)
51
+ return splits, split_heights
52
+ return [img.copy()], [img_height]
53
+
54
+
55
+ def prepare_image_detection(img, processor):
56
+ new_size = (processor.size["width"], processor.size["height"])
57
+
58
+ # This double resize actually necessary for downstream accuracy
59
+ img.thumbnail(new_size, Image.Resampling.LANCZOS)
60
+ img = img.resize(new_size, Image.Resampling.LANCZOS) # Stretch smaller dimension to fit new size
61
+
62
+ img = np.asarray(img, dtype=np.uint8)
63
+ img = processor(img)["pixel_values"][0]
64
+ img = torch.from_numpy(img)
65
+ return img
66
+
67
+
68
+ def open_pdf(pdf_filepath):
69
+ return pypdfium2.PdfDocument(pdf_filepath)
70
+
71
+
72
+ def get_page_images(doc, indices: List, dpi=settings.IMAGE_DPI):
73
+ renderer = doc.render(
74
+ pypdfium2.PdfBitmap.to_pil,
75
+ page_indices=indices,
76
+ scale=dpi / 72,
77
+ )
78
+ images = list(renderer)
79
+ images = [image.convert("RGB") for image in images]
80
+ return images
81
+
82
+
83
+ def slice_bboxes_from_image(image: Image.Image, bboxes):
84
+ lines = []
85
+ for bbox in bboxes:
86
+ line = image.crop((bbox[0], bbox[1], bbox[2], bbox[3]))
87
+ lines.append(line)
88
+ return lines
89
+
90
+
91
+ def slice_polys_from_image(image: Image.Image, polys):
92
+ image_array = np.array(image, dtype=np.uint8)
93
+ lines = []
94
+ for idx, poly in enumerate(polys):
95
+ lines.append(slice_and_pad_poly(image_array, poly))
96
+ return lines
97
+
98
+
99
+ def slice_and_pad_poly(image_array: np.array, coordinates):
100
+ # Draw polygon onto mask
101
+ coordinates = [(corner[0], corner[1]) for corner in coordinates]
102
+ bbox = [min([x[0] for x in coordinates]), min([x[1] for x in coordinates]), max([x[0] for x in coordinates]), max([x[1] for x in coordinates])]
103
+
104
+ # We mask out anything not in the polygon
105
+ cropped_polygon = image_array[bbox[1]:bbox[3], bbox[0]:bbox[2]].copy()
106
+ coordinates = [(x - bbox[0], y - bbox[1]) for x, y in coordinates]
107
+
108
+ # Pad the area outside the polygon with the pad value
109
+ mask = np.zeros(cropped_polygon.shape[:2], dtype=np.uint8)
110
+ cv2.fillPoly(mask, [np.int32(coordinates)], 1)
111
+ mask = np.stack([mask] * 3, axis=-1)
112
+
113
+ cropped_polygon[mask == 0] = settings.RECOGNITION_PAD_VALUE
114
+ rectangle_image = Image.fromarray(cropped_polygon)
115
+
116
+ return rectangle_image
surya/languages.py ADDED
@@ -0,0 +1,101 @@
1
+ CODE_TO_LANGUAGE = {
2
+ 'af': 'Afrikaans',
3
+ 'am': 'Amharic',
4
+ 'ar': 'Arabic',
5
+ 'as': 'Assamese',
6
+ 'az': 'Azerbaijani',
7
+ 'be': 'Belarusian',
8
+ 'bg': 'Bulgarian',
9
+ 'bn': 'Bengali',
10
+ 'br': 'Breton',
11
+ 'bs': 'Bosnian',
12
+ 'ca': 'Catalan',
13
+ 'cs': 'Czech',
14
+ 'cy': 'Welsh',
15
+ 'da': 'Danish',
16
+ 'de': 'German',
17
+ 'el': 'Greek',
18
+ 'en': 'English',
19
+ 'eo': 'Esperanto',
20
+ 'es': 'Spanish',
21
+ 'et': 'Estonian',
22
+ 'eu': 'Basque',
23
+ 'fa': 'Persian',
24
+ 'fi': 'Finnish',
25
+ 'fr': 'French',
26
+ 'fy': 'Western Frisian',
27
+ 'ga': 'Irish',
28
+ 'gd': 'Scottish Gaelic',
29
+ 'gl': 'Galician',
30
+ 'gu': 'Gujarati',
31
+ 'ha': 'Hausa',
32
+ 'he': 'Hebrew',
33
+ 'hi': 'Hindi',
34
+ 'hr': 'Croatian',
35
+ 'hu': 'Hungarian',
36
+ 'hy': 'Armenian',
37
+ 'id': 'Indonesian',
38
+ 'is': 'Icelandic',
39
+ 'it': 'Italian',
40
+ 'ja': 'Japanese',
41
+ 'jv': 'Javanese',
42
+ 'ka': 'Georgian',
43
+ 'kk': 'Kazakh',
44
+ 'km': 'Khmer',
45
+ 'kn': 'Kannada',
46
+ 'ko': 'Korean',
47
+ 'ku': 'Kurdish',
48
+ 'ky': 'Kyrgyz',
49
+ 'la': 'Latin',
50
+ 'lo': 'Lao',
51
+ 'lt': 'Lithuanian',
52
+ 'lv': 'Latvian',
53
+ 'mg': 'Malagasy',
54
+ 'mk': 'Macedonian',
55
+ 'ml': 'Malayalam',
56
+ 'mn': 'Mongolian',
57
+ 'mr': 'Marathi',
58
+ 'ms': 'Malay',
59
+ 'my': 'Burmese',
60
+ 'ne': 'Nepali',
61
+ 'nl': 'Dutch',
62
+ 'no': 'Norwegian',
63
+ 'om': 'Oromo',
64
+ 'or': 'Oriya',
65
+ 'pa': 'Punjabi',
66
+ 'pl': 'Polish',
67
+ 'ps': 'Pashto',
68
+ 'pt': 'Portuguese',
69
+ 'ro': 'Romanian',
70
+ 'ru': 'Russian',
71
+ 'sa': 'Sanskrit',
72
+ 'sd': 'Sindhi',
73
+ 'si': 'Sinhala',
74
+ 'sk': 'Slovak',
75
+ 'sl': 'Slovenian',
76
+ 'so': 'Somali',
77
+ 'sq': 'Albanian',
78
+ 'sr': 'Serbian',
79
+ 'su': 'Sundanese',
80
+ 'sv': 'Swedish',
81
+ 'sw': 'Swahili',
82
+ 'ta': 'Tamil',
83
+ 'te': 'Telugu',
84
+ 'th': 'Thai',
85
+ 'tl': 'Tagalog',
86
+ 'tr': 'Turkish',
87
+ 'ug': 'Uyghur',
88
+ 'uk': 'Ukrainian',
89
+ 'ur': 'Urdu',
90
+ 'uz': 'Uzbek',
91
+ 'vi': 'Vietnamese',
92
+ 'xh': 'Xhosa',
93
+ 'yi': 'Yiddish',
94
+ 'zh': 'Chinese',
95
+ }
96
+
97
+ LANGUAGE_TO_CODE = {v: k for k, v in CODE_TO_LANGUAGE.items()}
98
+
99
+
100
+ def is_arabic(lang_code):
101
+ return lang_code in ["ar", "fa", "ps", "ug", "ur"]
surya/layout.py ADDED
@@ -0,0 +1,204 @@
1
+ from collections import defaultdict
2
+ from concurrent.futures import ProcessPoolExecutor
3
+ from typing import List, Optional
4
+ from PIL import Image
5
+ import numpy as np
6
+
7
+ from surya.detection import batch_detection
8
+ from surya.postprocessing.heatmap import keep_largest_boxes, get_and_clean_boxes, get_detected_boxes
9
+ from surya.schema import LayoutResult, LayoutBox, TextDetectionResult
10
+ from surya.settings import settings
11
+
12
+
13
+ def get_regions_from_detection_result(detection_result: TextDetectionResult, heatmaps: List[np.ndarray], orig_size, id2label, segment_assignment, vertical_line_width=20) -> List[LayoutBox]:
14
+ logits = np.stack(heatmaps, axis=0)
15
+ vertical_line_bboxes = [line for line in detection_result.vertical_lines]
16
+ line_bboxes = detection_result.bboxes
17
+
18
+ # Scale back to processor size
19
+ for line in vertical_line_bboxes:
20
+ line.rescale_bbox(orig_size, list(reversed(heatmaps[0].shape)))
21
+
22
+ for line in line_bboxes:
23
+ line.rescale(orig_size, list(reversed(heatmaps[0].shape)))
24
+
25
+ for bbox in vertical_line_bboxes:
26
+ # Give some width to the vertical lines
27
+ vert_bbox = list(bbox.bbox)
28
+ vert_bbox[2] = min(heatmaps[0].shape[0], vert_bbox[2] + vertical_line_width)
29
+
30
+ logits[:, vert_bbox[1]:vert_bbox[3], vert_bbox[0]:vert_bbox[2]] = 0 # zero out where the column lines are
31
+
32
+ logits[:, logits[0] >= .5] = 0 # zero out where blanks are
33
+
34
+ # Zero out where other segments are
35
+ for i in range(logits.shape[0]):
36
+ logits[i, segment_assignment != i] = 0
37
+
38
+ detected_boxes = []
39
+ for heatmap_idx in range(1, len(id2label)): # Skip the blank class
40
+ heatmap = logits[heatmap_idx]
41
+ bboxes = get_detected_boxes(heatmap)
42
+ bboxes = [bbox for bbox in bboxes if bbox.area > 25]
43
+ for bb in bboxes:
44
+ bb.fit_to_bounds([0, 0, heatmap.shape[1] - 1, heatmap.shape[0] - 1])
45
+
46
+ for bbox in bboxes:
47
+ detected_boxes.append(LayoutBox(polygon=bbox.polygon, label=id2label[heatmap_idx], confidence=1))
48
+
49
+ detected_boxes = sorted(detected_boxes, key=lambda x: x.confidence, reverse=True)
50
+ # Expand bbox to cover intersecting lines
51
+ box_lines = defaultdict(list)
52
+ used_lines = set()
53
+
54
+ # We try 2 rounds of identifying the correct lines to snap to
55
+ # First round is majority intersection, second lowers the threshold
56
+ for thresh in [.5, .4]:
57
+ for bbox_idx, bbox in enumerate(detected_boxes):
58
+ for line_idx, line_bbox in enumerate(line_bboxes):
59
+ if line_bbox.intersection_pct(bbox) > thresh and line_idx not in used_lines:
60
+ box_lines[bbox_idx].append(line_bbox.bbox)
61
+ used_lines.add(line_idx)
62
+
63
+ new_boxes = []
64
+ for bbox_idx, bbox in enumerate(detected_boxes):
65
+ if bbox.label == "Picture" and bbox.area < 200: # Remove very small figures
66
+ continue
67
+
68
+ # Skip if we didn't find any lines to snap to, except for Pictures and Formulas
69
+ if bbox_idx not in box_lines and bbox.label not in ["Picture", "Formula"]:
70
+ continue
71
+
72
+ covered_lines = box_lines[bbox_idx]
73
+ # Snap non-picture layout boxes to correct text boundaries
74
+ if len(covered_lines) > 0 and bbox.label not in ["Picture"]:
75
+ min_x = min([line[0] for line in covered_lines])
76
+ min_y = min([line[1] for line in covered_lines])
77
+ max_x = max([line[2] for line in covered_lines])
78
+ max_y = max([line[3] for line in covered_lines])
79
+
80
+ # Tables and formulas can contain text, but text isn't the whole area
81
+ if bbox.label in ["Table", "Formula"]:
82
+ min_x_box = min([b[0] for b in bbox.polygon])
83
+ min_y_box = min([b[1] for b in bbox.polygon])
84
+ max_x_box = max([b[0] for b in bbox.polygon])
85
+ max_y_box = max([b[1] for b in bbox.polygon])
86
+
87
+ min_x = min(min_x, min_x_box)
88
+ min_y = min(min_y, min_y_box)
89
+ max_x = max(max_x, max_x_box)
90
+ max_y = max(max_y, max_y_box)
91
+
92
+ bbox.polygon[0][0] = min_x
93
+ bbox.polygon[0][1] = min_y
94
+ bbox.polygon[1][0] = max_x
95
+ bbox.polygon[1][1] = min_y
96
+ bbox.polygon[2][0] = max_x
97
+ bbox.polygon[2][1] = max_y
98
+ bbox.polygon[3][0] = min_x
99
+ bbox.polygon[3][1] = max_y
100
+
101
+ if bbox_idx in box_lines and bbox.label in ["Picture"]:
102
+ bbox.label = "Figure"
103
+
104
+ new_boxes.append(bbox)
105
+
106
+ # Merge tables together (sometimes one column is detected as a separate table)
107
+ for i in range(5): # Up to 5 rounds of merging
108
+ to_remove = set()
109
+ for bbox_idx, bbox in enumerate(new_boxes):
110
+ if bbox.label != "Table" or bbox_idx in to_remove:
111
+ continue
112
+
113
+ for bbox_idx2, bbox2 in enumerate(new_boxes):
114
+ if bbox2.label != "Table" or bbox_idx2 in to_remove or bbox_idx == bbox_idx2:
115
+ continue
116
+
117
+ if bbox.intersection_pct(bbox2) > 0:
118
+ bbox.merge(bbox2)
119
+ to_remove.add(bbox_idx2)
120
+
121
+ new_boxes = [bbox for idx, bbox in enumerate(new_boxes) if idx not in to_remove]
122
+
123
+ # Ensure we account for all text lines in the layout
124
+ unused_lines = [line for idx, line in enumerate(line_bboxes) if idx not in used_lines]
125
+ for bbox in unused_lines:
126
+ new_boxes.append(LayoutBox(polygon=bbox.polygon, label="Text", confidence=.5))
127
+
128
+ for bbox in new_boxes:
129
+ bbox.rescale(list(reversed(heatmaps[0].shape)), orig_size)
130
+
131
+ detected_boxes = [bbox for bbox in new_boxes if bbox.area > 16]
132
+
133
+ # Remove bboxes contained inside others, unless they're captions
134
+ contained_bbox = []
135
+ for i, bbox in enumerate(detected_boxes):
136
+ for j, bbox2 in enumerate(detected_boxes):
137
+ if i == j:
138
+ continue
139
+
140
+ if bbox2.intersection_pct(bbox) >= .95 and bbox2.label not in ["Caption"]:
141
+ contained_bbox.append(j)
142
+
143
+ detected_boxes = [bbox for idx, bbox in enumerate(detected_boxes) if idx not in contained_bbox]
144
+
145
+ return detected_boxes
146
+
147
+
148
+ def get_regions(heatmaps: List[np.ndarray], orig_size, id2label, segment_assignment) -> List[LayoutBox]:
149
+ bboxes = []
150
+ for i in range(1, len(id2label)): # Skip the blank class
151
+ heatmap = heatmaps[i]
152
+ assert heatmap.shape == segment_assignment.shape
153
+ heatmap[segment_assignment != i] = 0 # zero out where another segment is
154
+ bbox = get_and_clean_boxes(heatmap, list(reversed(heatmap.shape)), orig_size)
155
+ for bb in bbox:
156
+ bboxes.append(LayoutBox(polygon=bb.polygon, label=id2label[i]))
157
+ heatmaps.append(heatmap)
158
+
159
+ bboxes = keep_largest_boxes(bboxes)
160
+ return bboxes
161
+
162
+
163
+ def parallel_get_regions(heatmaps: List[np.ndarray], orig_size, id2label, detection_results=None) -> LayoutResult:
164
+ logits = np.stack(heatmaps, axis=0)
165
+ segment_assignment = logits.argmax(axis=0)
166
+ if detection_results is not None:
167
+ bboxes = get_regions_from_detection_result(detection_results, heatmaps, orig_size, id2label,
168
+ segment_assignment)
169
+ else:
170
+ bboxes = get_regions(heatmaps, orig_size, id2label, segment_assignment)
171
+
172
+ segmentation_img = Image.fromarray(segment_assignment.astype(np.uint8))
173
+
174
+ result = LayoutResult(
175
+ bboxes=bboxes,
176
+ segmentation_map=segmentation_img,
177
+ heatmaps=heatmaps,
178
+ image_bbox=[0, 0, orig_size[0], orig_size[1]]
179
+ )
180
+
181
+ return result
182
+
183
+
184
+ def batch_layout_detection(images: List, model, processor, detection_results: Optional[List[TextDetectionResult]] = None, batch_size=None) -> List[LayoutResult]:
185
+ preds, orig_sizes = batch_detection(images, model, processor, batch_size=batch_size)
186
+ id2label = model.config.id2label
187
+
188
+ results = []
189
+ if settings.IN_STREAMLIT or len(images) < settings.DETECTOR_MIN_PARALLEL_THRESH: # Ensures we don't parallelize with streamlit or too few images
190
+ for i in range(len(images)):
191
+ result = parallel_get_regions(preds[i], orig_sizes[i], id2label, detection_results[i] if detection_results else None)
192
+ results.append(result)
193
+ else:
194
+ futures = []
195
+ max_workers = min(settings.DETECTOR_POSTPROCESSING_CPU_WORKERS, len(images))
196
+ with ProcessPoolExecutor(max_workers=max_workers) as executor:
197
+ for i in range(len(images)):
198
+ future = executor.submit(parallel_get_regions, preds[i], orig_sizes[i], id2label, detection_results[i] if detection_results else None)
199
+ futures.append(future)
200
+
201
+ for future in futures:
202
+ results.append(future.result())
203
+
204
+ return results
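
A usage sketch for `batch_layout_detection`, again assuming the surya v0.4.x loaders and checkpoint settings (not shown in this view); text-line detections are passed in so layout regions can snap to line boundaries:

```python
from PIL import Image
from surya.detection import batch_text_detection
from surya.layout import batch_layout_detection
from surya.model.detection.segformer import load_model, load_processor  # assumed v0.4.x loaders
from surya.settings import settings

image = Image.open("sample.jpg").convert("RGB")

det_model, det_processor = load_model(), load_processor()
layout_model = load_model(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT)
layout_processor = load_processor(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT)

line_results = batch_text_detection([image], det_model, det_processor)
layout_results = batch_layout_detection([image], layout_model, layout_processor, line_results)
for box in layout_results[0].bboxes:
    print(box.label, box.polygon)
```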
surya/model/detection/__pycache__/processor.cpython-310.pyc ADDED
Binary file (11.6 kB).
 
surya/model/detection/__pycache__/segformer.cpython-310.pyc ADDED
Binary file (14.5 kB).
 
surya/model/detection/processor.py ADDED
@@ -0,0 +1,284 @@
1
+ import warnings
2
+ from typing import Any, Dict, List, Optional, Union
3
+
4
+ import numpy as np
5
+
6
+ from transformers.image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
7
+ from transformers.image_transforms import to_channel_dimension_format
8
+ from transformers.image_utils import (
9
+ IMAGENET_DEFAULT_MEAN,
10
+ IMAGENET_DEFAULT_STD,
11
+ ChannelDimension,
12
+ ImageInput,
13
+ PILImageResampling,
14
+ infer_channel_dimension_format,
15
+ make_list_of_images,
16
+ )
17
+ from transformers.utils import TensorType
18
+
19
+
20
+ import PIL.Image
21
+ import torch
22
+
23
+
24
+ class SegformerImageProcessor(BaseImageProcessor):
25
+ r"""
26
+ Constructs a Segformer image processor.
27
+
28
+ Args:
29
+ do_resize (`bool`, *optional*, defaults to `True`):
30
+ Whether to resize the image's (height, width) dimensions to the specified `(size["height"],
31
+ size["width"])`. Can be overridden by the `do_resize` parameter in the `preprocess` method.
32
+ size (`Dict[str, int]` *optional*, defaults to `{"height": 512, "width": 512}`):
33
+ Size of the output image after resizing. Can be overridden by the `size` parameter in the `preprocess`
34
+ method.
35
+ resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):
36
+ Resampling filter to use if resizing the image. Can be overridden by the `resample` parameter in the
37
+ `preprocess` method.
38
+ do_rescale (`bool`, *optional*, defaults to `True`):
39
+ Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale`
40
+ parameter in the `preprocess` method.
41
+ rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
42
+ Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
43
+ method.
44
+ do_normalize (`bool`, *optional*, defaults to `True`):
45
+ Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
46
+ method.
47
+ image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
48
+ Mean to use if normalizing the image. This is a float or list of floats the length of the number of
49
+ channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
50
+ image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
51
+ Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
52
+ number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
53
+ do_reduce_labels (`bool`, *optional*, defaults to `False`):
54
+ Whether or not to reduce all label values of segmentation maps by 1. Usually used for datasets where 0 is
55
+ used for background, and background itself is not included in all classes of a dataset (e.g. ADE20k). The
56
+ background label will be replaced by 255. Can be overridden by the `do_reduce_labels` parameter in the
57
+ `preprocess` method.
58
+ """
59
+
60
+ model_input_names = ["pixel_values"]
61
+
62
+ def __init__(
63
+ self,
64
+ do_resize: bool = True,
65
+ size: Dict[str, int] = None,
66
+ resample: PILImageResampling = PILImageResampling.BILINEAR,
67
+ do_rescale: bool = True,
68
+ rescale_factor: Union[int, float] = 1 / 255,
69
+ do_normalize: bool = True,
70
+ image_mean: Optional[Union[float, List[float]]] = None,
71
+ image_std: Optional[Union[float, List[float]]] = None,
72
+ do_reduce_labels: bool = False,
73
+ **kwargs,
74
+ ) -> None:
75
+ if "reduce_labels" in kwargs:
76
+ warnings.warn(
77
+ "The `reduce_labels` parameter is deprecated and will be removed in a future version. Please use "
78
+ "`do_reduce_labels` instead.",
79
+ FutureWarning,
80
+ )
81
+ do_reduce_labels = kwargs.pop("reduce_labels")
82
+
83
+ super().__init__(**kwargs)
84
+ size = size if size is not None else {"height": 512, "width": 512}
85
+ size = get_size_dict(size)
86
+ self.do_resize = do_resize
87
+ self.size = size
88
+ self.resample = resample
89
+ self.do_rescale = do_rescale
90
+ self.rescale_factor = rescale_factor
91
+ self.do_normalize = do_normalize
92
+ self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN
93
+ self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD
94
+ self.do_reduce_labels = do_reduce_labels
95
+ self._valid_processor_keys = [
96
+ "images",
97
+ "segmentation_maps",
98
+ "do_resize",
99
+ "size",
100
+ "resample",
101
+ "do_rescale",
102
+ "rescale_factor",
103
+ "do_normalize",
104
+ "image_mean",
105
+ "image_std",
106
+ "do_reduce_labels",
107
+ "return_tensors",
108
+ "data_format",
109
+ "input_data_format",
110
+ ]
111
+
112
+ @classmethod
113
+ def from_dict(cls, image_processor_dict: Dict[str, Any], **kwargs):
114
+ """
115
+ Overrides the `from_dict` method from the base class to make sure `do_reduce_labels` is updated if image
116
+ processor is created using from_dict and kwargs e.g. `SegformerImageProcessor.from_pretrained(checkpoint,
117
+ reduce_labels=True)`
118
+ """
119
+ image_processor_dict = image_processor_dict.copy()
120
+ if "reduce_labels" in kwargs:
121
+ image_processor_dict["reduce_labels"] = kwargs.pop("reduce_labels")
122
+ return super().from_dict(image_processor_dict, **kwargs)
123
+
124
+ def _preprocess(
125
+ self,
126
+ image: ImageInput,
127
+ do_resize: bool,
128
+ do_rescale: bool,
129
+ do_normalize: bool,
130
+ size: Optional[Dict[str, int]] = None,
131
+ resample: PILImageResampling = None,
132
+ rescale_factor: Optional[float] = None,
133
+ image_mean: Optional[Union[float, List[float]]] = None,
134
+ image_std: Optional[Union[float, List[float]]] = None,
135
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
136
+ ):
137
+
138
+ if do_rescale:
139
+ image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
140
+
141
+ if do_normalize:
142
+ image = self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
143
+
144
+ return image
145
+
146
+ def _preprocess_image(
147
+ self,
148
+ image: ImageInput,
149
+ do_resize: bool = None,
150
+ size: Dict[str, int] = None,
151
+ resample: PILImageResampling = None,
152
+ do_rescale: bool = None,
153
+ rescale_factor: float = None,
154
+ do_normalize: bool = None,
155
+ image_mean: Optional[Union[float, List[float]]] = None,
156
+ image_std: Optional[Union[float, List[float]]] = None,
157
+ data_format: Optional[Union[str, ChannelDimension]] = None,
158
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
159
+ ) -> np.ndarray:
160
+ """Preprocesses a single image."""
161
+ # All transformations expect numpy arrays.
162
+ if input_data_format is None:
163
+ input_data_format = infer_channel_dimension_format(image)
164
+
165
+ image = self._preprocess(
166
+ image=image,
167
+ do_resize=do_resize,
168
+ size=size,
169
+ resample=resample,
170
+ do_rescale=do_rescale,
171
+ rescale_factor=rescale_factor,
172
+ do_normalize=do_normalize,
173
+ image_mean=image_mean,
174
+ image_std=image_std,
175
+ input_data_format=input_data_format,
176
+ )
177
+ if data_format is not None:
178
+ image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
179
+ return image
180
+
181
+ def __call__(self, images, segmentation_maps=None, **kwargs):
182
+ """
183
+ Preprocesses a batch of images and optionally segmentation maps.
184
+
185
+ Overrides the `__call__` method of the `Preprocessor` class so that both images and segmentation maps can be
186
+ passed in as positional arguments.
187
+ """
188
+ return super().__call__(images, segmentation_maps=segmentation_maps, **kwargs)
189
+
190
+ def preprocess(
191
+ self,
192
+ images: ImageInput,
193
+ segmentation_maps: Optional[ImageInput] = None,
194
+ do_resize: Optional[bool] = None,
195
+ size: Optional[Dict[str, int]] = None,
196
+ resample: PILImageResampling = None,
197
+ do_rescale: Optional[bool] = None,
198
+ rescale_factor: Optional[float] = None,
199
+ do_normalize: Optional[bool] = None,
200
+ image_mean: Optional[Union[float, List[float]]] = None,
201
+ image_std: Optional[Union[float, List[float]]] = None,
202
+ do_reduce_labels: Optional[bool] = None,
203
+ return_tensors: Optional[Union[str, TensorType]] = None,
204
+ data_format: ChannelDimension = ChannelDimension.FIRST,
205
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
206
+ **kwargs,
207
+ ) -> PIL.Image.Image:
208
+ """
209
+ Preprocess an image or batch of images.
210
+
211
+ Args:
212
+ images (`ImageInput`):
213
+ Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
214
+ passing in images with pixel values between 0 and 1, set `do_rescale=False`.
215
+ segmentation_maps (`ImageInput`, *optional*):
216
+ Segmentation map to preprocess.
217
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
218
+ Whether to resize the image.
219
+ size (`Dict[str, int]`, *optional*, defaults to `self.size`):
220
+ Size of the image after `resize` is applied.
221
+ resample (`int`, *optional*, defaults to `self.resample`):
222
+ Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
223
+ has an effect if `do_resize` is set to `True`.
224
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
225
+ Whether to rescale the image values to the range [0, 1].
226
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
227
+ Rescale factor to rescale the image by if `do_rescale` is set to `True`.
228
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
229
+ Whether to normalize the image.
230
+ image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
231
+ Image mean.
232
+ image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
233
+ Image standard deviation.
234
+ do_reduce_labels (`bool`, *optional*, defaults to `self.do_reduce_labels`):
235
+ Whether or not to reduce all label values of segmentation maps by 1. Usually used for datasets where 0
236
+ is used for background, and background itself is not included in all classes of a dataset (e.g.
237
+ ADE20k). The background label will be replaced by 255.
238
+ return_tensors (`str` or `TensorType`, *optional*):
239
+ The type of tensors to return. Can be one of:
240
+ - Unset: Return a list of `np.ndarray`.
241
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
242
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
243
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
244
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
245
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
246
+ The channel dimension format for the output image. Can be one of:
247
+ - `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
248
+ - `ChannelDimension.LAST`: image in (height, width, num_channels) format.
249
+ input_data_format (`ChannelDimension` or `str`, *optional*):
250
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
251
+ from the input image. Can be one of:
252
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
253
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
254
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
255
+ """
256
+ do_resize = do_resize if do_resize is not None else self.do_resize
257
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
258
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
259
+ resample = resample if resample is not None else self.resample
260
+ size = size if size is not None else self.size
261
+ rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
262
+ image_mean = image_mean if image_mean is not None else self.image_mean
263
+ image_std = image_std if image_std is not None else self.image_std
264
+
265
+ images = make_list_of_images(images)
266
+ images = [
267
+ self._preprocess_image(
268
+ image=img,
269
+ do_resize=do_resize,
270
+ resample=resample,
271
+ size=size,
272
+ do_rescale=do_rescale,
273
+ rescale_factor=rescale_factor,
274
+ do_normalize=do_normalize,
275
+ image_mean=image_mean,
276
+ image_std=image_std,
277
+ data_format=data_format,
278
+ input_data_format=input_data_format,
279
+ )
280
+ for img in images
281
+ ]
282
+
283
+ data = {"pixel_values": images}
284
+ return BatchFeature(data=data, tensor_type=return_tensors)
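For reference, a minimal sketch of how this processor might be driven (the image path is a placeholder, and the explicit resize mirrors what surya's input pipeline does before calling the processor, since `_preprocess` above only rescales and normalizes):

```python
import numpy as np
from PIL import Image

from surya.model.detection.processor import SegformerImageProcessor

processor = SegformerImageProcessor()  # defaults: rescale by 1/255, normalize with ImageNet mean/std
img = Image.open("page.png").convert("RGB")  # placeholder path
# Resize outside the processor: _preprocess above does not apply do_resize/size.
img = img.resize((processor.size["width"], processor.size["height"]))

batch = processor(np.asarray(img, dtype=np.uint8), return_tensors="pt")
pixel_values = batch["pixel_values"]  # (1, 3, size["height"], size["width"]), channels-first float tensor
```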
surya/model/detection/segformer.py ADDED
@@ -0,0 +1,468 @@
1
+ import gc
2
+ import warnings
3
+
4
+ from transformers.activations import ACT2FN
5
+ from transformers.pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
6
+
7
+ warnings.filterwarnings("ignore", message="torch.utils._pytree._register_pytree_node is deprecated")
8
+
9
+ import math
10
+ from typing import Optional, Tuple, Union
11
+
12
+ from transformers import SegformerConfig, SegformerForSemanticSegmentation, SegformerDecodeHead, \
13
+ SegformerPreTrainedModel
14
+ from surya.model.detection.processor import SegformerImageProcessor
15
+ import torch
16
+ from torch import nn
17
+
18
+ from transformers.modeling_outputs import SemanticSegmenterOutput, BaseModelOutput
19
+ from surya.settings import settings
20
+
21
+
22
+ def load_model(checkpoint=settings.DETECTOR_MODEL_CHECKPOINT, device=settings.TORCH_DEVICE_DETECTION, dtype=settings.MODEL_DTYPE_DETECTION):
23
+ config = SegformerConfig.from_pretrained(checkpoint)
24
+ model = SegformerForRegressionMask.from_pretrained(checkpoint, torch_dtype=dtype, config=config)
25
+ if "mps" in device:
26
+ print("Warning: MPS may have poor results. This is a bug with MPS, see here - https://github.com/pytorch/pytorch/issues/84936")
27
+ model = model.to(device)
28
+ model = model.eval()
29
+ print(f"Loaded detection model {checkpoint} on device {device} with dtype {dtype}")
30
+ return model
31
+
32
+
33
+ def load_processor(checkpoint=settings.DETECTOR_MODEL_CHECKPOINT):
34
+ processor = SegformerImageProcessor.from_pretrained(checkpoint)
35
+ return processor
36
+
37
+
38
+ class SegformerForMaskMLP(nn.Module):
39
+ def __init__(self, config: SegformerConfig, input_dim, output_dim):
40
+ super().__init__()
41
+ self.proj = nn.Linear(input_dim, output_dim)
42
+
43
+ def forward(self, hidden_states: torch.Tensor):
44
+ hidden_states = hidden_states.flatten(2).transpose(1, 2)
45
+ hidden_states = self.proj(hidden_states)
46
+ return hidden_states
47
+
48
+
49
+ class SegformerForMaskDecodeHead(SegformerDecodeHead):
50
+ def __init__(self, config):
51
+ super().__init__(config)
52
+ decoder_layer_hidden_size = getattr(config, "decoder_layer_hidden_size", config.decoder_hidden_size)
53
+
54
+ # linear layers which will unify the channel dimension of each of the encoder blocks to the same config.decoder_hidden_size
55
+ mlps = []
56
+ for i in range(config.num_encoder_blocks):
57
+ mlp = SegformerForMaskMLP(config, input_dim=config.hidden_sizes[i], output_dim=decoder_layer_hidden_size)
58
+ mlps.append(mlp)
59
+ self.linear_c = nn.ModuleList(mlps)
60
+
61
+ # the following 3 layers implement the ConvModule of the original implementation
62
+ self.linear_fuse = nn.Conv2d(
63
+ in_channels=decoder_layer_hidden_size * config.num_encoder_blocks,
64
+ out_channels=config.decoder_hidden_size,
65
+ kernel_size=1,
66
+ bias=False,
67
+ )
68
+ self.batch_norm = nn.BatchNorm2d(config.decoder_hidden_size)
69
+ self.activation = nn.ReLU()
70
+
71
+ self.classifier = nn.Conv2d(config.decoder_hidden_size, config.num_labels, kernel_size=1)
72
+
73
+ self.config = config
74
+
75
+ def forward(self, encoder_hidden_states: torch.FloatTensor) -> torch.Tensor:
76
+ batch_size = encoder_hidden_states[-1].shape[0]
77
+
78
+ all_hidden_states = ()
79
+ for encoder_hidden_state, mlp in zip(encoder_hidden_states, self.linear_c):
80
+ if self.config.reshape_last_stage is False and encoder_hidden_state.ndim == 3:
81
+ height = width = int(math.sqrt(encoder_hidden_state.shape[-1]))
82
+ encoder_hidden_state = (
83
+ encoder_hidden_state.reshape(batch_size, height, width, -1).permute(0, 3, 1, 2).contiguous()
84
+ )
85
+
86
+ # unify channel dimension
87
+ height, width = encoder_hidden_state.shape[2], encoder_hidden_state.shape[3]
88
+ encoder_hidden_state = mlp(encoder_hidden_state)
89
+ encoder_hidden_state = encoder_hidden_state.permute(0, 2, 1)
90
+ encoder_hidden_state = encoder_hidden_state.reshape(batch_size, -1, height, width)
91
+ # upsample
92
+ encoder_hidden_state = encoder_hidden_state.contiguous()
93
+ encoder_hidden_state = nn.functional.interpolate(
94
+ encoder_hidden_state, size=encoder_hidden_states[0].size()[2:], mode="bilinear", align_corners=False
95
+ )
96
+ all_hidden_states += (encoder_hidden_state,)
97
+
98
+ hidden_states = self.linear_fuse(torch.cat(all_hidden_states[::-1], dim=1))
99
+ hidden_states = self.batch_norm(hidden_states)
100
+ hidden_states = self.activation(hidden_states)
101
+
102
+ # logits are of shape (batch_size, num_labels, height/4, width/4)
103
+ logits = self.classifier(hidden_states)
104
+
105
+ return logits
106
+
107
+
108
+ class SegformerOverlapPatchEmbeddings(nn.Module):
109
+ """Construct the overlapping patch embeddings."""
110
+
111
+ def __init__(self, patch_size, stride, num_channels, hidden_size):
112
+ super().__init__()
113
+ self.proj = nn.Conv2d(
114
+ num_channels,
115
+ hidden_size,
116
+ kernel_size=patch_size,
117
+ stride=stride,
118
+ padding=patch_size // 2,
119
+ )
120
+
121
+ self.layer_norm = nn.LayerNorm(hidden_size)
122
+
123
+ def forward(self, pixel_values):
124
+ embeddings = self.proj(pixel_values)
125
+ _, _, height, width = embeddings.shape
126
+ # (batch_size, num_channels, height, width) -> (batch_size, num_channels, height*width) -> (batch_size, height*width, num_channels)
127
+ # this can be fed to a Transformer layer
128
+ embeddings = embeddings.flatten(2).transpose(1, 2)
129
+ embeddings = self.layer_norm(embeddings)
130
+ return embeddings, height, width
131
+
132
+
133
+ class SegformerEfficientSelfAttention(nn.Module):
134
+ """SegFormer's efficient self-attention mechanism. Employs the sequence reduction process introduced in the [PvT
135
+ paper](https://arxiv.org/abs/2102.12122)."""
136
+
137
+ def __init__(self, config, hidden_size, num_attention_heads, sequence_reduction_ratio):
138
+ super().__init__()
139
+ self.hidden_size = hidden_size
140
+ self.num_attention_heads = num_attention_heads
141
+
142
+ if self.hidden_size % self.num_attention_heads != 0:
143
+ raise ValueError(
144
+ f"The hidden size ({self.hidden_size}) is not a multiple of the number of attention "
145
+ f"heads ({self.num_attention_heads})"
146
+ )
147
+
148
+ self.attention_head_size = int(self.hidden_size / self.num_attention_heads)
149
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
150
+
151
+ self.query = nn.Linear(self.hidden_size, self.all_head_size)
152
+ self.key = nn.Linear(self.hidden_size, self.all_head_size)
153
+ self.value = nn.Linear(self.hidden_size, self.all_head_size)
154
+
155
+ self.sr_ratio = sequence_reduction_ratio
156
+ if sequence_reduction_ratio > 1:
157
+ self.sr = nn.Conv2d(
158
+ hidden_size, hidden_size, kernel_size=sequence_reduction_ratio, stride=sequence_reduction_ratio
159
+ )
160
+ self.layer_norm = nn.LayerNorm(hidden_size)
161
+
162
+ def transpose_for_scores(self, hidden_states):
163
+ new_shape = hidden_states.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
164
+ hidden_states = hidden_states.view(new_shape)
165
+ return hidden_states.permute(0, 2, 1, 3)
166
+
167
+ def forward(
168
+ self,
169
+ hidden_states,
170
+ height,
171
+ width,
172
+ output_attentions=False,
173
+ ):
174
+ query_layer = self.transpose_for_scores(self.query(hidden_states))
175
+
176
+ if self.sr_ratio > 1:
177
+ batch_size, seq_len, num_channels = hidden_states.shape
178
+ # Reshape to (batch_size, num_channels, height, width)
179
+ hidden_states = hidden_states.permute(0, 2, 1).reshape(batch_size, num_channels, height, width)
180
+ # Apply sequence reduction
181
+ hidden_states = self.sr(hidden_states)
182
+ # Reshape back to (batch_size, seq_len, num_channels)
183
+ hidden_states = hidden_states.reshape(batch_size, num_channels, -1).permute(0, 2, 1)
184
+ hidden_states = self.layer_norm(hidden_states)
185
+
186
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
187
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
188
+
189
+ # Take the dot product between "query" and "key" to get the raw attention scores.
190
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
191
+
192
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
193
+
194
+ # Normalize the attention scores to probabilities.
195
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1)
196
+
197
+ context_layer = torch.matmul(attention_probs, value_layer)
198
+
199
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
200
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
201
+ context_layer = context_layer.view(new_context_layer_shape)
202
+
203
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
204
+
205
+ return outputs
206
+
207
+ class SegformerEncoder(nn.Module):
208
+ def __init__(self, config):
209
+ super().__init__()
210
+ self.config = config
211
+
212
+ # patch embeddings
213
+ embeddings = []
214
+ for i in range(config.num_encoder_blocks):
215
+ embeddings.append(
216
+ SegformerOverlapPatchEmbeddings(
217
+ patch_size=config.patch_sizes[i],
218
+ stride=config.strides[i],
219
+ num_channels=config.num_channels if i == 0 else config.hidden_sizes[i - 1],
220
+ hidden_size=config.hidden_sizes[i],
221
+ )
222
+ )
223
+ self.patch_embeddings = nn.ModuleList(embeddings)
224
+
225
+ # Transformer blocks
226
+ blocks = []
227
+ cur = 0
228
+ for i in range(config.num_encoder_blocks):
229
+ # each block consists of layers
230
+ layers = []
231
+ if i != 0:
232
+ cur += config.depths[i - 1]
233
+ for j in range(config.depths[i]):
234
+ layers.append(
235
+ SegformerLayer(
236
+ config,
237
+ hidden_size=config.hidden_sizes[i],
238
+ num_attention_heads=config.num_attention_heads[i],
239
+ sequence_reduction_ratio=config.sr_ratios[i],
240
+ mlp_ratio=config.mlp_ratios[i],
241
+ )
242
+ )
243
+ blocks.append(nn.ModuleList(layers))
244
+
245
+ self.block = nn.ModuleList(blocks)
246
+
247
+ # Layer norms
248
+ self.layer_norm = nn.ModuleList(
249
+ [nn.LayerNorm(config.hidden_sizes[i]) for i in range(config.num_encoder_blocks)]
250
+ )
251
+
252
+ def forward(
253
+ self,
254
+ pixel_values: torch.FloatTensor,
255
+ output_attentions: Optional[bool] = False,
256
+ output_hidden_states: Optional[bool] = False,
257
+ return_dict: Optional[bool] = True,
258
+ ) -> Union[Tuple, BaseModelOutput]:
259
+ all_hidden_states = () if output_hidden_states else None
260
+
261
+ batch_size = pixel_values.shape[0]
262
+
263
+ hidden_states = pixel_values
264
+ for idx, x in enumerate(zip(self.patch_embeddings, self.block, self.layer_norm)):
265
+ embedding_layer, block_layer, norm_layer = x
266
+ # first, obtain patch embeddings
267
+ hidden_states, height, width = embedding_layer(hidden_states)
268
+ # second, send embeddings through blocks
269
+ for i, blk in enumerate(block_layer):
270
+ layer_outputs = blk(hidden_states, height, width, output_attentions)
271
+ hidden_states = layer_outputs[0]
272
+ # third, apply layer norm
273
+ hidden_states = norm_layer(hidden_states)
274
+ # fourth, optionally reshape back to (batch_size, num_channels, height, width)
275
+ if idx != len(self.patch_embeddings) - 1 or (
276
+ idx == len(self.patch_embeddings) - 1 and self.config.reshape_last_stage
277
+ ):
278
+ hidden_states = hidden_states.reshape(batch_size, height, width, -1).permute(0, 3, 1, 2).contiguous()
279
+ all_hidden_states = all_hidden_states + (hidden_states,)
280
+
281
+ return all_hidden_states
282
+
283
+ class SegformerSelfOutput(nn.Module):
284
+ def __init__(self, config, hidden_size):
285
+ super().__init__()
286
+ self.dense = nn.Linear(hidden_size, hidden_size)
287
+
288
+ def forward(self, hidden_states, input_tensor):
289
+ hidden_states = self.dense(hidden_states)
290
+ return hidden_states
291
+
292
+
293
+ class SegformerAttention(nn.Module):
294
+ def __init__(self, config, hidden_size, num_attention_heads, sequence_reduction_ratio):
295
+ super().__init__()
296
+ self.self = SegformerEfficientSelfAttention(
297
+ config=config,
298
+ hidden_size=hidden_size,
299
+ num_attention_heads=num_attention_heads,
300
+ sequence_reduction_ratio=sequence_reduction_ratio,
301
+ )
302
+ self.output = SegformerSelfOutput(config, hidden_size=hidden_size)
303
+ self.pruned_heads = set()
304
+
305
+ def prune_heads(self, heads):
306
+ if len(heads) == 0:
307
+ return
308
+ heads, index = find_pruneable_heads_and_indices(
309
+ heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
310
+ )
311
+
312
+ # Prune linear layers
313
+ self.self.query = prune_linear_layer(self.self.query, index)
314
+ self.self.key = prune_linear_layer(self.self.key, index)
315
+ self.self.value = prune_linear_layer(self.self.value, index)
316
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
317
+
318
+ # Update hyper params and store pruned heads
319
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
320
+ self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
321
+ self.pruned_heads = self.pruned_heads.union(heads)
322
+
323
+ def forward(self, hidden_states, height, width, output_attentions=False):
324
+ self_outputs = self.self(hidden_states, height, width, output_attentions)
325
+
326
+ attention_output = self.output(self_outputs[0], hidden_states)
327
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
328
+ return outputs
329
+
330
+ class SegformerDWConv(nn.Module):
331
+ def __init__(self, dim=768):
332
+ super().__init__()
333
+ self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)
334
+
335
+ def forward(self, hidden_states, height, width):
336
+ batch_size, seq_len, num_channels = hidden_states.shape
337
+ hidden_states = hidden_states.transpose(1, 2).view(batch_size, num_channels, height, width)
338
+ hidden_states = self.dwconv(hidden_states)
339
+ hidden_states = hidden_states.flatten(2).transpose(1, 2)
340
+
341
+ return hidden_states
342
+
343
+
344
+ class SegformerMixFFN(nn.Module):
345
+ def __init__(self, config, in_features, hidden_features=None, out_features=None):
346
+ super().__init__()
347
+ out_features = out_features or in_features
348
+ self.dense1 = nn.Linear(in_features, hidden_features)
349
+ self.dwconv = SegformerDWConv(hidden_features)
350
+ if isinstance(config.hidden_act, str):
351
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
352
+ else:
353
+ self.intermediate_act_fn = config.hidden_act
354
+ self.dense2 = nn.Linear(hidden_features, out_features)
355
+
356
+ def forward(self, hidden_states, height, width):
357
+ hidden_states = self.dense1(hidden_states)
358
+ hidden_states = self.dwconv(hidden_states, height, width)
359
+ hidden_states = self.intermediate_act_fn(hidden_states)
360
+ hidden_states = self.dense2(hidden_states)
361
+ return hidden_states
362
+
363
+
364
+ class SegformerLayer(nn.Module):
365
+ """This corresponds to the Block class in the original implementation."""
366
+
367
+ def __init__(self, config, hidden_size, num_attention_heads, sequence_reduction_ratio, mlp_ratio):
368
+ super().__init__()
369
+ self.layer_norm_1 = nn.LayerNorm(hidden_size)
370
+ self.attention = SegformerAttention(
371
+ config,
372
+ hidden_size=hidden_size,
373
+ num_attention_heads=num_attention_heads,
374
+ sequence_reduction_ratio=sequence_reduction_ratio,
375
+ )
376
+ self.layer_norm_2 = nn.LayerNorm(hidden_size)
377
+ mlp_hidden_size = int(hidden_size * mlp_ratio)
378
+ self.mlp = SegformerMixFFN(config, in_features=hidden_size, hidden_features=mlp_hidden_size)
379
+
380
+ def forward(self, hidden_states, height, width, output_attentions=False):
381
+ self_attention_outputs = self.attention(
382
+ self.layer_norm_1(hidden_states), # in Segformer, layernorm is applied before self-attention
383
+ height,
384
+ width,
385
+ output_attentions=output_attentions,
386
+ )
387
+
388
+ attention_output = self_attention_outputs[0]
389
+ outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
390
+
391
+ # first residual connection (with stochastic depth)
392
+ hidden_states = attention_output + hidden_states
393
+
394
+ mlp_output = self.mlp(self.layer_norm_2(hidden_states), height, width)
395
+
396
+ # second residual connection (with stochastic depth)
397
+ layer_output = mlp_output + hidden_states
398
+
399
+ outputs = (layer_output,) + outputs
400
+
401
+ return outputs
402
+
403
+ class SegformerModel(SegformerPreTrainedModel):
404
+ def __init__(self, config):
405
+ super().__init__(config)
406
+ self.config = config
407
+
408
+ # hierarchical Transformer encoder
409
+ self.encoder = SegformerEncoder(config)
410
+
411
+ # Initialize weights and apply final processing
412
+ self.post_init()
413
+
414
+ def _prune_heads(self, heads_to_prune):
415
+ """
416
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
417
+ class PreTrainedModel
418
+ """
419
+ for layer, heads in heads_to_prune.items():
420
+ self.encoder.layer[layer].attention.prune_heads(heads)
421
+
422
+ def forward(
423
+ self,
424
+ pixel_values: torch.FloatTensor,
425
+ output_attentions: Optional[bool] = None,
426
+ output_hidden_states: Optional[bool] = None,
427
+ return_dict: Optional[bool] = None,
428
+ ) -> Union[Tuple, BaseModelOutput]:
429
+ encoder_outputs = self.encoder(
430
+ pixel_values,
431
+ output_attentions=output_attentions,
432
+ output_hidden_states=output_hidden_states,
433
+ return_dict=return_dict,
434
+ )
435
+ return encoder_outputs
436
+
437
+ class SegformerForRegressionMask(SegformerForSemanticSegmentation):
438
+ def __init__(self, config, **kwargs):
439
+ super().__init__(config)
440
+ self.segformer = SegformerModel(config)
441
+ self.decode_head = SegformerForMaskDecodeHead(config)
442
+
443
+ # Initialize weights and apply final processing
444
+ self.post_init()
445
+
446
+ def forward(
447
+ self,
448
+ pixel_values: torch.FloatTensor,
449
+ **kwargs
450
+ ) -> Union[Tuple, SemanticSegmenterOutput]:
451
+
452
+ encoder_hidden_states = self.segformer(
453
+ pixel_values,
454
+ output_attentions=False,
455
+ output_hidden_states=True, # we need the intermediate hidden states
456
+ return_dict=False,
457
+ )
458
+
459
+ logits = self.decode_head(encoder_hidden_states)
460
+ # Apply sigmoid to get 0-1 output
461
+ sigmoid_logits = torch.special.expit(logits)
462
+
463
+ return SemanticSegmenterOutput(
464
+ loss=None,
465
+ logits=sigmoid_logits,
466
+ hidden_states=None,
467
+ attentions=None,
468
+ )
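A hedged sketch of loading the detection model and processor defined in this file and running one forward pass; the zero tensor is a stand-in for real `pixel_values` produced by the processor sketch above:

```python
import torch

from surya.model.detection.segformer import load_model, load_processor

model = load_model()        # defaults to settings.DETECTOR_MODEL_CHECKPOINT
processor = load_processor()

# Dummy batch with the processor's configured size; in practice this comes from the image processor.
pixel_values = torch.zeros(
    (1, 3, processor.size["height"], processor.size["width"]),
    dtype=model.dtype, device=model.device,
)

with torch.inference_mode():
    outputs = model(pixel_values=pixel_values)

# SegformerForRegressionMask applies a sigmoid, so logits are per-pixel scores in [0, 1],
# shaped (batch, num_labels, height / 4, width / 4).
masks = outputs.logits
```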
surya/model/ordering/config.py ADDED
@@ -0,0 +1,8 @@
1
+ from transformers import MBartConfig, DonutSwinConfig
2
+
3
+
4
+ class MBartOrderConfig(MBartConfig):
5
+ pass
6
+
7
+ class VariableDonutSwinConfig(DonutSwinConfig):
8
+ pass
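These subclasses mainly let the ordering checkpoint carry extra fields that the decoder below consumes (`kv_heads`, `max_width`, `max_height`). A hedged sketch of constructing such a config by hand; every value shown is illustrative, and the real ones come from the checkpoint's `config.json`:

```python
from surya.model.ordering.config import MBartOrderConfig

# Illustrative values only (not the shipped checkpoint's configuration).
config = MBartOrderConfig(
    d_model=1024,
    decoder_layers=3,
    decoder_attention_heads=16,
    kv_heads=4,        # grouped-query attention: query heads share a smaller set of kv heads
    max_width=1026,    # bbox coordinate vocabulary: coords scaled 0-1000 plus special tokens
    max_height=1026,
    max_position_embeddings=1024,
)
```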
surya/model/ordering/decoder.py ADDED
@@ -0,0 +1,557 @@
1
+ import copy
2
+ from typing import Optional, List, Union, Tuple
3
+
4
+ from transformers import MBartForCausalLM, MBartConfig
5
+ from torch import nn
6
+ from transformers.activations import ACT2FN
7
+ from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_attention_mask
8
+ from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions, BaseModelOutputWithPastAndCrossAttentions
9
+ from transformers.models.mbart.modeling_mbart import MBartPreTrainedModel, MBartDecoder, MBartLearnedPositionalEmbedding, MBartDecoderLayer
10
+ from surya.model.ordering.config import MBartOrderConfig
11
+ import torch
12
+ import math
13
+
14
+
15
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
16
+ """
17
+ From llama
18
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
19
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
20
+ """
21
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
22
+ if n_rep == 1:
23
+ return hidden_states
24
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
25
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
26
+
27
+
28
+ class MBartGQAttention(nn.Module):
29
+ def __init__(
30
+ self,
31
+ embed_dim: int,
32
+ num_heads: int,
33
+ num_kv_heads: int,
34
+ dropout: float = 0.0,
35
+ is_decoder: bool = False,
36
+ bias: bool = True,
37
+ is_causal: bool = False,
38
+ config: Optional[MBartConfig] = None,
39
+ ):
40
+ super().__init__()
41
+ self.embed_dim = embed_dim
42
+ self.num_heads = num_heads
43
+ self.num_kv_heads = num_kv_heads
44
+ self.num_kv_groups = self.num_heads // self.num_kv_heads
45
+
46
+ assert self.num_heads % self.num_kv_heads == 0, f"num_heads ({self.num_heads}) must be divisible by num_kv_heads ({self.num_kv_heads})"
47
+ assert embed_dim % self.num_kv_heads == 0, f"embed_dim ({self.embed_dim}) must be divisible by num_kv_heads ({self.num_kv_heads})"
48
+
49
+ self.dropout = dropout
50
+ self.head_dim = embed_dim // num_heads
51
+ self.config = config
52
+
53
+ if (self.head_dim * num_heads) != self.embed_dim:
54
+ raise ValueError(
55
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
56
+ f" and `num_heads`: {num_heads})."
57
+ )
58
+ self.scaling = self.head_dim**-0.5
59
+ self.is_decoder = is_decoder
60
+ self.is_causal = is_causal
61
+
62
+ self.k_proj = nn.Linear(embed_dim, self.num_kv_heads * self.head_dim, bias=bias)
63
+ self.v_proj = nn.Linear(embed_dim, self.num_kv_heads * self.head_dim, bias=bias)
64
+ self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
65
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
66
+
67
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
68
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
69
+
70
+ def _shape_key_value(self, tensor: torch.Tensor, seq_len: int, bsz: int):
71
+ return tensor.view(bsz, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2).contiguous()
72
+
73
+ def forward(
74
+ self,
75
+ hidden_states: torch.Tensor,
76
+ key_value_states: Optional[torch.Tensor] = None,
77
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
78
+ attention_mask: Optional[torch.Tensor] = None,
79
+ layer_head_mask: Optional[torch.Tensor] = None,
80
+ output_attentions: bool = False,
81
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
82
+ """Input shape: Batch x Time x Channel"""
83
+
84
+ # if key_value_states are provided this layer is used as a cross-attention layer
85
+ # for the decoder
86
+ is_cross_attention = key_value_states is not None
87
+
88
+ bsz, tgt_len, _ = hidden_states.size()
89
+
90
+ # get query proj
91
+ query_states = self.q_proj(hidden_states) * self.scaling
92
+ # get key, value proj
93
+ # `past_key_value[0].shape[2] == key_value_states.shape[1]`
94
+ # is checking that the `sequence_length` of the `past_key_value` is the same as
95
+ # the provided `key_value_states` to support prefix tuning
96
+ if (
97
+ is_cross_attention
98
+ and past_key_value is not None
99
+ and past_key_value[0].shape[2] == key_value_states.shape[1]
100
+ ):
101
+ # reuse k,v, cross_attentions
102
+ key_states = past_key_value[0]
103
+ value_states = past_key_value[1]
104
+ elif is_cross_attention:
105
+ # cross_attentions
106
+ key_states = self._shape_key_value(self.k_proj(key_value_states), -1, bsz)
107
+ value_states = self._shape_key_value(self.v_proj(key_value_states), -1, bsz)
108
+ elif past_key_value is not None:
109
+ # reuse k, v, self_attention
110
+ key_states = self._shape_key_value(self.k_proj(hidden_states), -1, bsz)
111
+ value_states = self._shape_key_value(self.v_proj(hidden_states), -1, bsz)
112
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
113
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
114
+ else:
115
+ # self_attention
116
+ key_states = self._shape_key_value(self.k_proj(hidden_states), -1, bsz)
117
+ value_states = self._shape_key_value(self.v_proj(hidden_states), -1, bsz)
118
+
119
+ if self.is_decoder:
120
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
121
+ # Further calls to cross_attention layer can then reuse all cross-attention
122
+ # key/value_states (first "if" case)
123
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
124
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
125
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
126
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
127
+ past_key_value = (key_states, value_states)
128
+
129
+ proj_shape = (bsz * self.num_heads, -1, self.head_dim)
130
+ query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
131
+
132
+ # Expand kv heads, then match query shape
133
+ key_states = repeat_kv(key_states, self.num_kv_groups)
134
+ value_states = repeat_kv(value_states, self.num_kv_groups)
135
+ key_states = key_states.reshape(*proj_shape)
136
+ value_states = value_states.reshape(*proj_shape)
137
+
138
+ src_len = key_states.size(1)
139
+ attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
140
+
141
+ if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
142
+ raise ValueError(
143
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
144
+ f" {attn_weights.size()}"
145
+ )
146
+
147
+ if attention_mask is not None:
148
+ if attention_mask.size() != (bsz, 1, tgt_len, src_len):
149
+ raise ValueError(
150
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
151
+ )
152
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
153
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
154
+
155
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
156
+
157
+ if layer_head_mask is not None:
158
+ if layer_head_mask.size() != (self.num_heads,):
159
+ raise ValueError(
160
+ f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
161
+ f" {layer_head_mask.size()}"
162
+ )
163
+ attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
164
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
165
+
166
+ if output_attentions:
167
+ # this operation is a bit awkward, but it's required to
168
+ # make sure that attn_weights keeps its gradient.
169
+ # In order to do so, attn_weights have to be reshaped
170
+ # twice and have to be reused in the following
171
+ attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
172
+ attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
173
+ else:
174
+ attn_weights_reshaped = None
175
+
176
+ attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
177
+
178
+ attn_output = torch.bmm(attn_probs, value_states)
179
+
180
+ if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
181
+ raise ValueError(
182
+ f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
183
+ f" {attn_output.size()}"
184
+ )
185
+
186
+ attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
187
+ attn_output = attn_output.transpose(1, 2)
188
+
189
+ # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
190
+ # partitioned across GPUs when using tensor-parallelism.
191
+ attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
192
+
193
+ attn_output = self.out_proj(attn_output)
194
+
195
+ return attn_output, attn_weights_reshaped, past_key_value
196
+
197
+
198
+ MBART_ATTENTION_CLASSES = {
199
+ "eager": MBartGQAttention,
200
+ "flash_attention_2": None
201
+ }
202
+
203
+
204
+ class MBartOrderDecoderLayer(MBartDecoderLayer):
205
+ def __init__(self, config: MBartConfig):
206
+ nn.Module.__init__(self)
207
+ self.embed_dim = config.d_model
208
+
209
+ self.self_attn = MBART_ATTENTION_CLASSES[config._attn_implementation](
210
+ embed_dim=self.embed_dim,
211
+ num_heads=config.decoder_attention_heads,
212
+ num_kv_heads=config.kv_heads,
213
+ dropout=config.attention_dropout,
214
+ is_decoder=True,
215
+ is_causal=True,
216
+ config=config,
217
+ )
218
+ self.dropout = config.dropout
219
+ self.activation_fn = ACT2FN[config.activation_function]
220
+ self.activation_dropout = config.activation_dropout
221
+
222
+ self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
223
+ self.encoder_attn = MBART_ATTENTION_CLASSES[config._attn_implementation](
224
+ self.embed_dim,
225
+ config.decoder_attention_heads,
226
+ num_kv_heads=config.kv_heads,
227
+ dropout=config.attention_dropout,
228
+ is_decoder=True,
229
+ config=config,
230
+ )
231
+ self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
232
+ self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
233
+ self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
234
+ self.final_layer_norm = nn.LayerNorm(self.embed_dim)
235
+
236
+
237
+ class BboxEmbedding(nn.Module):
238
+ def __init__(self, config):
239
+ super().__init__()
240
+ self.x1_embed = nn.Embedding(config.max_width, config.d_model)
241
+ self.y1_embed = nn.Embedding(config.max_height, config.d_model)
242
+ self.x2_embed = nn.Embedding(config.max_width, config.d_model)
243
+ self.y2_embed = nn.Embedding(config.max_height, config.d_model)
244
+ self.w_embed = nn.Embedding(config.max_width, config.d_model)
245
+ self.h_embed = nn.Embedding(config.max_height, config.d_model)
246
+ self.cx_embed = nn.Embedding(config.max_width, config.d_model)
247
+ self.cy_embed = nn.Embedding(config.max_height, config.d_model)
248
+ self.box_pos_embed = nn.Embedding(config.max_position_embeddings, config.d_model)
249
+
250
+ def forward(self, boxes: torch.LongTensor, input_box_counts: torch.LongTensor, past_key_values_length: int):
251
+ x1, y1, x2, y2 = boxes.unbind(dim=-1)
252
+ # Shape is (batch_size, num_boxes/seq len, d_model)
253
+ w = x2 - x1
254
+ h = y2 - y1
255
+ # Center x and y in torch long tensors
256
+ cx = (x1 + x2) / 2
257
+ cy = (y1 + y2) / 2
258
+ cx = cx.long()
259
+ cy = cy.long()
260
+
261
+ coord_embeds = self.x1_embed(x1) + self.y1_embed(y1) + self.x2_embed(x2) + self.y2_embed(y2)
262
+ embedded = coord_embeds + self.w_embed(w) + self.h_embed(h) + self.cx_embed(cx) + self.cy_embed(cy)
263
+
264
+ # Add in positional embeddings for the boxes
265
+ if past_key_values_length == 0:
266
+ for j in range(embedded.shape[0]):
267
+ box_start = input_box_counts[j, 0]
268
+ box_end = input_box_counts[j, 1] - 1 # Skip the sep token
269
+ box_count = box_end - box_start
270
+ embedded[j, box_start:box_end] = embedded[j, box_start:box_end] + self.box_pos_embed.weight[:box_count]
271
+
272
+ return embedded
273
+
274
+
275
+ class MBartOrderDecoder(MBartDecoder):
276
+ def __init__(self, config: MBartConfig, embed_tokens: Optional[nn.Embedding] = None):
277
+ MBartPreTrainedModel.__init__(self, config)
278
+ self.dropout = config.dropout
279
+ self.layerdrop = config.decoder_layerdrop
280
+ self.padding_idx = config.pad_token_id
281
+ self.max_target_positions = config.max_position_embeddings
282
+ self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
283
+
284
+ self.embed_tokens = BboxEmbedding(config) if embed_tokens is None else embed_tokens
285
+
286
+ if embed_tokens is not None:
287
+ self.embed_tokens.weight = embed_tokens.weight
288
+
289
+ self.embed_positions = MBartLearnedPositionalEmbedding(
290
+ config.max_position_embeddings,
291
+ config.d_model,
292
+ )
293
+ # Language-specific MoE goes at second and second-to-last layer
294
+ self.layers = nn.ModuleList([MBartOrderDecoderLayer(config) for _ in range(config.decoder_layers)])
295
+ self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
296
+ self.layernorm_embedding = nn.LayerNorm(config.d_model)
297
+ self.layer_norm = nn.LayerNorm(config.d_model)
298
+
299
+ self.gradient_checkpointing = False
300
+ # Initialize weights and apply final processing
301
+ self.post_init()
302
+
303
+ def forward(
304
+ self,
305
+ input_boxes: torch.LongTensor = None,
306
+ input_boxes_mask: Optional[torch.Tensor] = None,
307
+ input_boxes_counts: Optional[torch.Tensor] = None,
308
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
309
+ encoder_attention_mask: Optional[torch.LongTensor] = None,
310
+ head_mask: Optional[torch.Tensor] = None,
311
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
312
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
313
+ inputs_embeds: Optional[torch.FloatTensor] = None,
314
+ use_cache: Optional[bool] = None,
315
+ output_attentions: Optional[bool] = None,
316
+ output_hidden_states: Optional[bool] = None,
317
+ return_dict: Optional[bool] = None,
318
+ ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
319
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
320
+ output_hidden_states = (
321
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
322
+ )
323
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
324
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
325
+
326
+ # retrieve input_ids and inputs_embeds
327
+ if input_boxes is not None and inputs_embeds is not None:
328
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
329
+ elif input_boxes is not None:
330
+ input = input_boxes
331
+ input_shape = input_boxes.size()[:-1] # Shape (batch_size, num_boxes)
332
+ elif inputs_embeds is not None:
333
+ input_shape = inputs_embeds.size()[:-1]
334
+ input = inputs_embeds[:, :, -1]
335
+ else:
336
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
337
+
338
+ # past_key_values_length
339
+ past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
340
+
341
+ if inputs_embeds is None:
342
+ inputs_embeds = self.embed_tokens(input_boxes, input_boxes_counts, past_key_values_length) * self.embed_scale
343
+
344
+ if self._use_flash_attention_2:
345
+ # 2d mask is passed through the layers
346
+ attention_mask = input_boxes_mask if (input_boxes_mask is not None and 0 in input_boxes_mask) else None
347
+ else:
348
+ # 4d mask is passed through the layers
349
+ attention_mask = _prepare_4d_causal_attention_mask(
350
+ input_boxes_mask, input_shape, inputs_embeds, past_key_values_length
351
+ )
352
+
353
+ if past_key_values_length == 0:
354
+ box_ends = input_boxes_counts[:, 1]
355
+ box_starts = input_boxes_counts[:, 0]
356
+ input_shape_arranged = torch.arange(input_shape[1], device=attention_mask.device)[None, :]
357
+ # Enable all boxes to attend to each other (before the sep token)
358
+ # Ensure that the boxes are not attending to the padding tokens
359
+ boxes_end_mask = input_shape_arranged < box_ends[:, None]
360
+ boxes_start_mask = input_shape_arranged >= box_starts[:, None]
361
+ boxes_mask = boxes_end_mask & boxes_start_mask
362
+ boxes_mask = boxes_mask.unsqueeze(1).unsqueeze(1) # Enable proper broadcasting
363
+ attention_mask = attention_mask.masked_fill(boxes_mask, 0)
364
+
365
+ # expand encoder attention mask
366
+ if encoder_hidden_states is not None and encoder_attention_mask is not None:
367
+ if self._use_flash_attention_2:
368
+ encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
369
+ else:
370
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
371
+ encoder_attention_mask = _prepare_4d_attention_mask(
372
+ encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
373
+ )
374
+
375
+ # embed positions
376
+ positions = self.embed_positions(input, past_key_values_length)
377
+
378
+ hidden_states = inputs_embeds + positions.to(inputs_embeds.device)
379
+ hidden_states = self.layernorm_embedding(hidden_states)
380
+
381
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
382
+
383
+ if self.gradient_checkpointing and self.training:
384
+ if use_cache:
385
+ use_cache = False
386
+
387
+ # decoder layers
388
+ all_hidden_states = () if output_hidden_states else None
389
+ all_self_attns = () if output_attentions else None
390
+ all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
391
+ next_decoder_cache = () if use_cache else None
392
+
393
+ # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
394
+ for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
395
+ if attn_mask is not None:
396
+ if attn_mask.size()[0] != len(self.layers):
397
+ raise ValueError(
398
+ f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
399
+ f" {attn_mask.size()[0]}."
400
+ )
401
+ for idx, decoder_layer in enumerate(self.layers):
402
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
403
+ if output_hidden_states:
404
+ all_hidden_states += (hidden_states,)
405
+ if self.training:
406
+ dropout_probability = torch.rand([])
407
+ if dropout_probability < self.layerdrop:
408
+ continue
409
+
410
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
411
+
412
+ if self.gradient_checkpointing and self.training:
413
+ layer_outputs = self._gradient_checkpointing_func(
414
+ decoder_layer.__call__,
415
+ hidden_states,
416
+ attention_mask,
417
+ encoder_hidden_states,
418
+ encoder_attention_mask,
419
+ head_mask[idx] if head_mask is not None else None,
420
+ cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
421
+ None,
422
+ output_attentions,
423
+ use_cache,
424
+ )
425
+ else:
426
+ layer_outputs = decoder_layer(
427
+ hidden_states,
428
+ attention_mask=attention_mask,
429
+ encoder_hidden_states=encoder_hidden_states,
430
+ encoder_attention_mask=encoder_attention_mask,
431
+ layer_head_mask=(head_mask[idx] if head_mask is not None else None),
432
+ cross_attn_layer_head_mask=(
433
+ cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
434
+ ),
435
+ past_key_value=past_key_value,
436
+ output_attentions=output_attentions,
437
+ use_cache=use_cache,
438
+ )
439
+ hidden_states = layer_outputs[0]
440
+
441
+ if use_cache:
442
+ next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)
443
+
444
+ if output_attentions:
445
+ all_self_attns += (layer_outputs[1],)
446
+
447
+ if encoder_hidden_states is not None:
448
+ all_cross_attentions += (layer_outputs[2],)
449
+
450
+ hidden_states = self.layer_norm(hidden_states)
451
+
452
+ # add hidden states from the last decoder layer
453
+ if output_hidden_states:
454
+ all_hidden_states += (hidden_states,)
455
+
456
+ next_cache = next_decoder_cache if use_cache else None
457
+ if not return_dict:
458
+ return tuple(
459
+ v
460
+ for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions]
461
+ if v is not None
462
+ )
463
+ return BaseModelOutputWithPastAndCrossAttentions(
464
+ last_hidden_state=hidden_states,
465
+ past_key_values=next_cache,
466
+ hidden_states=all_hidden_states,
467
+ attentions=all_self_attns,
468
+ cross_attentions=all_cross_attentions,
469
+ )
470
+
471
+
472
+ class MBartOrderDecoderWrapper(MBartPreTrainedModel):
473
+ """
474
+ This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is
475
+ used in combination with the [`EncoderDecoderModel`] framework.
476
+ """
477
+
478
+ def __init__(self, config):
479
+ super().__init__(config)
480
+ self.decoder = MBartOrderDecoder(config)
481
+
482
+ def forward(self, *args, **kwargs):
483
+ return self.decoder(*args, **kwargs)
484
+
485
+
486
+ class MBartOrder(MBartForCausalLM):
487
+ config_class = MBartOrderConfig
488
+ _tied_weights_keys = []
489
+
490
+ def __init__(self, config, **kwargs):
491
+ config = copy.deepcopy(config)
492
+ config.is_decoder = True
493
+ config.is_encoder_decoder = False
494
+ MBartPreTrainedModel.__init__(self, config)
495
+ self.model = MBartOrderDecoderWrapper(config)
496
+
497
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
498
+
499
+ # Initialize weights and apply final processing
500
+ self.post_init()
501
+
502
+ def forward(
503
+ self,
504
+ input_boxes: torch.LongTensor = None,
505
+ input_boxes_mask: Optional[torch.Tensor] = None,
506
+ input_boxes_counts: Optional[torch.Tensor] = None,
507
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
508
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
509
+ head_mask: Optional[torch.Tensor] = None,
510
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
511
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
512
+ inputs_embeds: Optional[torch.FloatTensor] = None,
513
+ labels: Optional[torch.LongTensor] = None,
514
+ use_cache: Optional[bool] = None,
515
+ output_attentions: Optional[bool] = None,
516
+ output_hidden_states: Optional[bool] = None,
517
+ return_dict: Optional[bool] = None,
518
+ **kwargs
519
+ ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
520
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
521
+ output_hidden_states = (
522
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
523
+ )
524
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
525
+
526
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
527
+ outputs = self.model.decoder(
528
+ input_boxes=input_boxes,
529
+ input_boxes_mask=input_boxes_mask,
530
+ input_boxes_counts=input_boxes_counts,
531
+ encoder_hidden_states=encoder_hidden_states,
532
+ encoder_attention_mask=encoder_attention_mask,
533
+ head_mask=head_mask,
534
+ cross_attn_head_mask=cross_attn_head_mask,
535
+ past_key_values=past_key_values,
536
+ inputs_embeds=inputs_embeds,
537
+ use_cache=use_cache,
538
+ output_attentions=output_attentions,
539
+ output_hidden_states=output_hidden_states,
540
+ return_dict=return_dict,
541
+ )
542
+
543
+ logits = self.lm_head(outputs[0])
544
+
545
+ loss = None
546
+ if not return_dict:
547
+ output = (logits,) + outputs[1:]
548
+ return (loss,) + output if loss is not None else output
549
+
550
+ return CausalLMOutputWithCrossAttentions(
551
+ loss=loss,
552
+ logits=logits,
553
+ past_key_values=outputs.past_key_values,
554
+ hidden_states=outputs.hidden_states,
555
+ attentions=outputs.attentions,
556
+ cross_attentions=outputs.cross_attentions,
557
+ )
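As a quick illustration of the grouped-query attention plumbing above, `repeat_kv` expands the key/value heads so they line up with the query heads; the tensor sizes below are illustrative only:

```python
import torch

from surya.model.ordering.decoder import repeat_kv

batch, num_kv_heads, seq_len, head_dim = 2, 4, 10, 64
num_heads = 16                     # query heads
n_rep = num_heads // num_kv_heads  # each kv head is shared by 4 query heads

kv = torch.randn(batch, num_kv_heads, seq_len, head_dim)
expanded = repeat_kv(kv, n_rep)
assert expanded.shape == (batch, num_heads, seq_len, head_dim)
```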
surya/model/ordering/encoder.py ADDED
@@ -0,0 +1,83 @@
1
+ from torch import nn
2
+ import torch
3
+ from typing import Optional, Tuple, Union
4
+ import collections
5
+ import math
6
+
7
+ from transformers import DonutSwinPreTrainedModel
8
+ from transformers.models.donut.modeling_donut_swin import DonutSwinPatchEmbeddings, DonutSwinEmbeddings, DonutSwinModel, \
9
+ DonutSwinEncoder
10
+
11
+ from surya.model.ordering.config import VariableDonutSwinConfig
12
+
13
+ class VariableDonutSwinEmbeddings(DonutSwinEmbeddings):
14
+ """
15
+ Construct the patch and position embeddings. Optionally, also the mask token.
16
+ """
17
+
18
+ def __init__(self, config, use_mask_token=False, **kwargs):
19
+ super().__init__(config, use_mask_token)
20
+
21
+ self.patch_embeddings = DonutSwinPatchEmbeddings(config)
22
+ num_patches = self.patch_embeddings.num_patches
23
+ self.patch_grid = self.patch_embeddings.grid_size
24
+ self.mask_token = nn.Parameter(torch.zeros(1, 1, config.embed_dim)) if use_mask_token else None
25
+ self.position_embeddings = None
26
+
27
+ if config.use_absolute_embeddings:
28
+ self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.embed_dim))
29
+
30
+ self.row_embeddings = None
31
+ self.column_embeddings = None
32
+ if config.use_2d_embeddings:
33
+ self.row_embeddings = nn.Parameter(torch.zeros(1, self.patch_grid[0] + 1, config.embed_dim))
34
+ self.column_embeddings = nn.Parameter(torch.zeros(1, self.patch_grid[1] + 1, config.embed_dim))
35
+
36
+ self.norm = nn.LayerNorm(config.embed_dim)
37
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
38
+
39
+ def forward(
40
+ self, pixel_values: Optional[torch.FloatTensor], bool_masked_pos: Optional[torch.BoolTensor] = None, **kwargs
41
+ ) -> Tuple[torch.Tensor]:
42
+
43
+ embeddings, output_dimensions = self.patch_embeddings(pixel_values)
44
+ # Layernorm across the last dimension (each patch is a single row)
45
+ embeddings = self.norm(embeddings)
46
+ batch_size, seq_len, embed_dim = embeddings.size()
47
+
48
+ if bool_masked_pos is not None:
49
+ mask_tokens = self.mask_token.expand(batch_size, seq_len, -1)
50
+ # replace the masked visual tokens by mask_tokens
51
+ mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
52
+ embeddings = embeddings * (1.0 - mask) + mask_tokens * mask
53
+
54
+ if self.position_embeddings is not None:
55
+ embeddings = embeddings + self.position_embeddings[:, :seq_len, :]
56
+
57
+ if self.row_embeddings is not None and self.column_embeddings is not None:
58
+ # Repeat the x position embeddings across the y axis like 0, 1, 2, 3, 0, 1, 2, 3, ...
59
+ row_embeddings = self.row_embeddings[:, :output_dimensions[0], :].repeat_interleave(output_dimensions[1], dim=1)
60
+ column_embeddings = self.column_embeddings[:, :output_dimensions[1], :].repeat(1, output_dimensions[0], 1)
61
+
62
+ embeddings = embeddings + row_embeddings + column_embeddings
63
+
64
+ embeddings = self.dropout(embeddings)
65
+
66
+ return embeddings, output_dimensions
67
+
68
+
69
+ class VariableDonutSwinModel(DonutSwinModel):
70
+ config_class = VariableDonutSwinConfig
71
+ def __init__(self, config, add_pooling_layer=True, use_mask_token=False, **kwargs):
72
+ super().__init__(config)
73
+ self.config = config
74
+ self.num_layers = len(config.depths)
75
+ self.num_features = int(config.embed_dim * 2 ** (self.num_layers - 1))
76
+
77
+ self.embeddings = VariableDonutSwinEmbeddings(config, use_mask_token=use_mask_token)
78
+ self.encoder = DonutSwinEncoder(config, self.embeddings.patch_grid)
79
+
80
+ self.pooler = nn.AdaptiveAvgPool1d(1) if add_pooling_layer else None
81
+
82
+ # Initialize weights and apply final processing
83
+ self.post_init()
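
The encoder above extends DonutSwin mainly through the optional 2D row/column position embeddings. As a minimal sketch (toy grid and embedding sizes, not the model's real dimensions), this is the tiling pattern `VariableDonutSwinEmbeddings.forward` relies on to give every patch token a row term plus a column term:

```python
import torch

# Toy illustration: a 2x3 patch grid with embed_dim=4.
grid_h, grid_w, embed_dim = 2, 3, 4
row_emb = torch.arange(grid_h).view(1, grid_h, 1).expand(1, grid_h, embed_dim).float()
col_emb = torch.arange(grid_w).view(1, grid_w, 1).expand(1, grid_w, embed_dim).float() * 10

# Row embeddings are repeated along the width axis: rows 0,0,0, 1,1,1
tiled_rows = row_emb.repeat_interleave(grid_w, dim=1)
# Column embeddings are tiled across the rows: cols 0,10,20, 0,10,20
tiled_cols = col_emb.repeat(1, grid_h, 1)

# Each of the grid_h*grid_w patch tokens gets a (row + column) positional term,
# matching the addition performed in VariableDonutSwinEmbeddings.forward.
pos = tiled_rows + tiled_cols
print(pos.shape)  # torch.Size([1, 6, 4])
```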
surya/model/ordering/encoderdecoder.py ADDED
@@ -0,0 +1,90 @@
1
+ from typing import Optional, Union, Tuple, List
2
+
3
+ import torch
4
+ from transformers import VisionEncoderDecoderModel
5
+ from transformers.modeling_outputs import Seq2SeqLMOutput, BaseModelOutput
6
+
7
+
8
+ class OrderVisionEncoderDecoderModel(VisionEncoderDecoderModel):
9
+ def forward(
10
+ self,
11
+ pixel_values: Optional[torch.FloatTensor] = None,
12
+ decoder_input_boxes: torch.LongTensor = None,
13
+ # Shape (batch_size, num_boxes, 4); coords scaled into the processor's 0-1024 box space, with sep/pad token ids appended by the processor
14
+ decoder_input_boxes_mask: torch.LongTensor = None, # Shape (batch_size, num_boxes), 0 if padding, 1 otherwise
15
+ decoder_input_boxes_counts: torch.LongTensor = None, # Shape (batch_size), number of boxes in each image
16
+ encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None,
17
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
18
+ decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
19
+ labels: Optional[List[List[int]]] = None,
20
+ use_cache: Optional[bool] = None,
21
+ output_attentions: Optional[bool] = None,
22
+ output_hidden_states: Optional[bool] = None,
23
+ return_dict: Optional[bool] = None,
24
+ **kwargs,
25
+ ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
26
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
27
+
28
+ kwargs_encoder = {argument: value for argument, value in kwargs.items() if not argument.startswith("decoder_")}
29
+
30
+ kwargs_decoder = {
31
+ argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
32
+ }
33
+
34
+ if encoder_outputs is None:
35
+ if pixel_values is None:
36
+ raise ValueError("You have to specify pixel_values")
37
+
38
+ encoder_outputs = self.encoder(
39
+ pixel_values=pixel_values,
40
+ output_attentions=output_attentions,
41
+ output_hidden_states=output_hidden_states,
42
+ return_dict=return_dict,
43
+ **kwargs_encoder,
44
+ )
45
+ elif isinstance(encoder_outputs, tuple):
46
+ encoder_outputs = BaseModelOutput(*encoder_outputs)
47
+
48
+ encoder_hidden_states = encoder_outputs[0]
49
+
50
+ # optionally project encoder_hidden_states
51
+ if (
52
+ self.encoder.config.hidden_size != self.decoder.config.hidden_size
53
+ and self.decoder.config.cross_attention_hidden_size is None
54
+ ):
55
+ encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)
56
+
57
+ # No attention mask is applied over the image encoder outputs
58
+ encoder_attention_mask = None
59
+
60
+ # Decode
61
+ decoder_outputs = self.decoder(
62
+ input_boxes=decoder_input_boxes,
63
+ input_boxes_mask=decoder_input_boxes_mask,
64
+ input_boxes_counts=decoder_input_boxes_counts,
65
+ encoder_hidden_states=encoder_hidden_states,
66
+ encoder_attention_mask=encoder_attention_mask,
67
+ inputs_embeds=decoder_inputs_embeds,
68
+ output_attentions=output_attentions,
69
+ output_hidden_states=output_hidden_states,
70
+ use_cache=use_cache,
71
+ past_key_values=past_key_values,
72
+ return_dict=return_dict,
73
+ labels=labels,
74
+ **kwargs_decoder,
75
+ )
76
+
77
+ if not return_dict:
78
+ return decoder_outputs + encoder_outputs
79
+
80
+ return Seq2SeqLMOutput(
81
+ loss=decoder_outputs.loss,
82
+ logits=decoder_outputs.logits,
83
+ past_key_values=decoder_outputs.past_key_values,
84
+ decoder_hidden_states=decoder_outputs.hidden_states,
85
+ decoder_attentions=decoder_outputs.attentions,
86
+ cross_attentions=decoder_outputs.cross_attentions,
87
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
88
+ encoder_hidden_states=encoder_outputs.hidden_states,
89
+ encoder_attentions=encoder_outputs.attentions,
90
+ )
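
The `forward` above routes extra keyword arguments to the encoder or decoder purely by the `decoder_` prefix. A self-contained sketch of that split, with toy keys rather than real model arguments:

```python
# Toy kwargs to illustrate the prefix-based routing used in OrderVisionEncoderDecoderModel.forward
kwargs = {"some_encoder_flag": True, "decoder_some_flag": False}

kwargs_encoder = {k: v for k, v in kwargs.items() if not k.startswith("decoder_")}
kwargs_decoder = {k[len("decoder_"):]: v for k, v in kwargs.items() if k.startswith("decoder_")}

print(kwargs_encoder)  # {'some_encoder_flag': True}
print(kwargs_decoder)  # {'some_flag': False}
```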
surya/model/ordering/model.py ADDED
@@ -0,0 +1,34 @@
1
+ from transformers import DetrConfig, BeitConfig, DetrImageProcessor, VisionEncoderDecoderConfig, AutoModelForCausalLM, \
2
+ AutoModel
3
+ from surya.model.ordering.config import MBartOrderConfig, VariableDonutSwinConfig
4
+ from surya.model.ordering.decoder import MBartOrder
5
+ from surya.model.ordering.encoder import VariableDonutSwinModel
6
+ from surya.model.ordering.encoderdecoder import OrderVisionEncoderDecoderModel
7
+ from surya.model.ordering.processor import OrderImageProcessor
8
+ from surya.settings import settings
9
+
10
+
11
+ def load_model(checkpoint=settings.ORDER_MODEL_CHECKPOINT, device=settings.TORCH_DEVICE_MODEL, dtype=settings.MODEL_DTYPE):
12
+ config = VisionEncoderDecoderConfig.from_pretrained(checkpoint)
13
+
14
+ decoder_config = vars(config.decoder)
15
+ decoder = MBartOrderConfig(**decoder_config)
16
+ config.decoder = decoder
17
+
18
+ encoder_config = vars(config.encoder)
19
+ encoder = VariableDonutSwinConfig(**encoder_config)
20
+ config.encoder = encoder
21
+
22
+ # Get transformers to load custom model
23
+ AutoModel.register(MBartOrderConfig, MBartOrder)
24
+ AutoModelForCausalLM.register(MBartOrderConfig, MBartOrder)
25
+ AutoModel.register(VariableDonutSwinConfig, VariableDonutSwinModel)
26
+
27
+ model = OrderVisionEncoderDecoderModel.from_pretrained(checkpoint, config=config, torch_dtype=dtype)
28
+ assert isinstance(model.decoder, MBartOrder)
29
+ assert isinstance(model.encoder, VariableDonutSwinModel)
30
+
31
+ model = model.to(device)
32
+ model = model.eval()
33
+ print(f"Loaded reading order model {checkpoint} on device {device} with dtype {dtype}")
34
+ return model
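
For context, a minimal usage sketch of the loader above together with the ordering processor defined in the next file. The image path and box coordinates are hypothetical, and the public surya scripts wrap this in a higher-level ordering pipeline; this only shows how the pieces fit together:

```python
from PIL import Image

from surya.model.ordering.model import load_model
from surya.model.ordering.processor import load_processor

model = load_model()          # uses settings.ORDER_MODEL_CHECKPOINT by default
processor = load_processor()  # same checkpoint family

image = Image.open("page.png").convert("RGB")    # hypothetical input page
boxes = [[10, 10, 300, 60], [10, 80, 300, 400]]  # x1, y1, x2, y2 in pixels (hypothetical)

inputs = processor(images=[image], boxes=[boxes], return_tensors="pt")
# inputs holds pixel_values, input_boxes, input_boxes_mask and input_boxes_counts,
# which map onto the decoder_* arguments of OrderVisionEncoderDecoderModel.forward.
```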
surya/model/ordering/processor.py ADDED
@@ -0,0 +1,156 @@
1
+ from copy import deepcopy
2
+ from typing import Dict, Union, Optional, List, Tuple
3
+
4
+ import torch
5
+ from torch import TensorType
6
+ from transformers import DonutImageProcessor, DonutProcessor
7
+ from transformers.image_processing_utils import BatchFeature
8
+ from transformers.image_utils import PILImageResampling, ImageInput, ChannelDimension, make_list_of_images, \
9
+ valid_images, to_numpy_array
10
+ import numpy as np
11
+ from PIL import Image
12
+ import PIL
13
+ from surya.settings import settings
14
+
15
+
16
+ def load_processor(checkpoint=settings.ORDER_MODEL_CHECKPOINT):
17
+ processor = OrderImageProcessor.from_pretrained(checkpoint)
18
+ processor.size = settings.ORDER_IMAGE_SIZE
19
+ box_size = 1024
20
+ max_tokens = 256
21
+ processor.token_sep_id = max_tokens + box_size + 1
22
+ processor.token_pad_id = max_tokens + box_size + 2
23
+ processor.max_boxes = settings.ORDER_MAX_BOXES - 1
24
+ processor.box_size = {"height": box_size, "width": box_size}
25
+ return processor
26
+
27
+
28
+ class OrderImageProcessor(DonutImageProcessor):
29
+ def __init__(self, *args, **kwargs):
30
+ super().__init__(*args, **kwargs)
31
+
32
+ self.patch_size = kwargs.get("patch_size", (4, 4))
33
+
34
+ def process_inner(self, images: List[np.ndarray]):
35
+ images = [img.transpose(2, 0, 1) for img in images] # convert to CHW format
36
+
37
+ assert images[0].shape[0] == 3 # RGB input images, channel dim first after the transpose above
38
+
39
+ # Convert to float32 for rescale/normalize
40
+ images = [img.astype(np.float32) for img in images]
41
+
42
+ # Rescale and normalize
43
+ images = [
44
+ self.rescale(img, scale=self.rescale_factor, input_data_format=ChannelDimension.FIRST)
45
+ for img in images
46
+ ]
47
+ images = [
48
+ self.normalize(img, mean=self.image_mean, std=self.image_std, input_data_format=ChannelDimension.FIRST)
49
+ for img in images
50
+ ]
51
+
52
+ return images
53
+
54
+ def process_boxes(self, boxes):
55
+ padded_boxes = []
56
+ box_masks = []
57
+ box_counts = []
58
+ for b in boxes:
59
+ # Left pad for generation
60
+ padded_b = deepcopy(b)
61
+ padded_b.append([self.token_sep_id] * 4) # Sep token to indicate start of label predictions
62
+ padded_boxes.append(padded_b)
63
+
64
+ max_boxes = max(len(b) for b in padded_boxes)
65
+ for i in range(len(padded_boxes)):
66
+ pad_len = max_boxes - len(padded_boxes[i])
67
+ box_len = len(padded_boxes[i])
68
+ box_mask = [0] * pad_len + [1] * box_len
69
+ padded_box = [[self.token_pad_id] * 4] * pad_len + padded_boxes[i]
70
+ padded_boxes[i] = padded_box
71
+ box_masks.append(box_mask)
72
+ box_counts.append([pad_len, max_boxes])
73
+
74
+ return padded_boxes, box_masks, box_counts
75
+
76
+ def resize_img_and_boxes(self, img, boxes):
77
+ orig_dim = img.size
78
+ new_size = (self.size["width"], self.size["height"])
79
+ img.thumbnail(new_size, Image.Resampling.LANCZOS) # Shrink largest dimension to fit new size
80
+ img = img.resize(new_size, Image.Resampling.LANCZOS) # Stretch smaller dimension to fit new size
81
+
82
+ img = np.asarray(img, dtype=np.uint8)
83
+
84
+ width, height = orig_dim
85
+ box_width, box_height = self.box_size["width"], self.box_size["height"]
86
+ for box in boxes:
87
+ # Rescale to 0-1024
88
+ box[0] = box[0] / width * box_width
89
+ box[1] = box[1] / height * box_height
90
+ box[2] = box[2] / width * box_width
91
+ box[3] = box[3] / height * box_height
92
+
93
+ if box[0] < 0:
94
+ box[0] = 0
95
+ if box[1] < 0:
96
+ box[1] = 0
97
+ if box[2] > box_width:
98
+ box[2] = box_width
99
+ if box[3] > box_height:
100
+ box[3] = box_height
101
+
102
+ return img, boxes
103
+
104
+ def preprocess(
105
+ self,
106
+ images: ImageInput,
107
+ boxes: List[List[int]],
108
+ do_resize: bool = None,
109
+ size: Dict[str, int] = None,
110
+ resample: PILImageResampling = None,
111
+ do_thumbnail: bool = None,
112
+ do_align_long_axis: bool = None,
113
+ do_pad: bool = None,
114
+ random_padding: bool = False,
115
+ do_rescale: bool = None,
116
+ rescale_factor: float = None,
117
+ do_normalize: bool = None,
118
+ image_mean: Optional[Union[float, List[float]]] = None,
119
+ image_std: Optional[Union[float, List[float]]] = None,
120
+ return_tensors: Optional[Union[str, TensorType]] = None,
121
+ data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
122
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
123
+ **kwargs,
124
+ ) -> PIL.Image.Image:
125
+ images = make_list_of_images(images)
126
+
127
+ if not valid_images(images):
128
+ raise ValueError(
129
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
130
+ "torch.Tensor, tf.Tensor or jax.ndarray."
131
+ )
132
+
133
+ new_images = []
134
+ new_boxes = []
135
+ for img, box in zip(images, boxes):
136
+ if len(box) > self.max_boxes:
137
+ raise ValueError(f"Too many boxes, max is {self.max_boxes}")
138
+ img, box = self.resize_img_and_boxes(img, box)
139
+ new_images.append(img)
140
+ new_boxes.append(box)
141
+
142
+ images = new_images
143
+ boxes = new_boxes
144
+
145
+ # Convert to numpy for later processing steps
146
+ images = [np.array(image) for image in images]
147
+
148
+ images = self.process_inner(images)
149
+ boxes, box_mask, box_counts = self.process_boxes(boxes)
150
+ data = {
151
+ "pixel_values": images,
152
+ "input_boxes": boxes,
153
+ "input_boxes_mask": box_mask,
154
+ "input_boxes_counts": box_counts,
155
+ }
156
+ return BatchFeature(data=data, tensor_type=return_tensors)
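
A small sketch of what `process_boxes` above produces: a separator "box" is appended, the sequences are left-padded to a common length, and the mask/count lists record where the real boxes sit. The token ids below are the ones derived in `load_processor` (box_size 1024, max_tokens 256):

```python
from copy import deepcopy

token_sep_id, token_pad_id = 1281, 1282  # max_tokens + box_size + 1 and + 2, as in load_processor

boxes = [
    [[10, 10, 300, 60]],                       # image with 1 box
    [[10, 10, 300, 60], [10, 80, 300, 400]],   # image with 2 boxes
]

padded_boxes, box_masks, box_counts = [], [], []
for b in boxes:
    padded_boxes.append(deepcopy(b) + [[token_sep_id] * 4])  # sep "box" marks start of predictions

max_boxes = max(len(b) for b in padded_boxes)
for i, b in enumerate(padded_boxes):
    pad_len = max_boxes - len(b)
    box_masks.append([0] * pad_len + [1] * len(b))        # 0 over padding, 1 over real boxes + sep
    padded_boxes[i] = [[token_pad_id] * 4] * pad_len + b  # pad on the left for generation
    box_counts.append([pad_len, max_boxes])

print(box_masks)   # [[0, 1, 1], [1, 1, 1]]
print(box_counts)  # [[1, 3], [0, 3]]
```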
surya/model/recognition/__pycache__/config.cpython-310.pyc ADDED
Binary file (2.41 kB).
 
surya/model/recognition/__pycache__/decoder.cpython-310.pyc ADDED
Binary file (14.2 kB).
 
surya/model/recognition/__pycache__/encoder.cpython-310.pyc ADDED
Binary file (13.2 kB).
 
surya/model/recognition/__pycache__/model.cpython-310.pyc ADDED
Binary file (2.56 kB).
 
surya/model/recognition/__pycache__/processor.cpython-310.pyc ADDED
Binary file (7.14 kB).
 
surya/model/recognition/__pycache__/tokenizer.cpython-310.pyc ADDED
Binary file (3.5 kB).
 
surya/model/recognition/config.py ADDED
@@ -0,0 +1,111 @@
1
+ from transformers import T5Config, MBartConfig, DonutSwinConfig
2
+
3
+
4
+ class MBartMoEConfig(MBartConfig):
5
+ pass
6
+
7
+
8
+ class VariableDonutSwinConfig(DonutSwinConfig):
9
+ pass
10
+
11
+
12
+ # Config specific to the model, needed for the tokenizer
13
+ TOTAL_TOKENS = 65536
14
+ TOKEN_OFFSET = 3 # Pad, eos, bos
15
+ SPECIAL_TOKENS = 253
16
+ TOTAL_VOCAB_SIZE = TOTAL_TOKENS + TOKEN_OFFSET + SPECIAL_TOKENS
17
+ LANGUAGE_MAP = {
18
+ 'af': 0,
19
+ 'am': 1,
20
+ 'ar': 2,
21
+ 'as': 3,
22
+ 'az': 4,
23
+ 'be': 5,
24
+ 'bg': 6,
25
+ 'bn': 7,
26
+ 'br': 8,
27
+ 'bs': 9,
28
+ 'ca': 10,
29
+ 'cs': 11,
30
+ 'cy': 12,
31
+ 'da': 13,
32
+ 'de': 14,
33
+ 'el': 15,
34
+ 'en': 16,
35
+ 'eo': 17,
36
+ 'es': 18,
37
+ 'et': 19,
38
+ 'eu': 20,
39
+ 'fa': 21,
40
+ 'fi': 22,
41
+ 'fr': 23,
42
+ 'fy': 24,
43
+ 'ga': 25,
44
+ 'gd': 26,
45
+ 'gl': 27,
46
+ 'gu': 28,
47
+ 'ha': 29,
48
+ 'he': 30,
49
+ 'hi': 31,
50
+ 'hr': 32,
51
+ 'hu': 33,
52
+ 'hy': 34,
53
+ 'id': 35,
54
+ 'is': 36,
55
+ 'it': 37,
56
+ 'ja': 38,
57
+ 'jv': 39,
58
+ 'ka': 40,
59
+ 'kk': 41,
60
+ 'km': 42,
61
+ 'kn': 43,
62
+ 'ko': 44,
63
+ 'ku': 45,
64
+ 'ky': 46,
65
+ 'la': 47,
66
+ 'lo': 48,
67
+ 'lt': 49,
68
+ 'lv': 50,
69
+ 'mg': 51,
70
+ 'mk': 52,
71
+ 'ml': 53,
72
+ 'mn': 54,
73
+ 'mr': 55,
74
+ 'ms': 56,
75
+ 'my': 57,
76
+ 'ne': 58,
77
+ 'nl': 59,
78
+ 'no': 60,
79
+ 'om': 61,
80
+ 'or': 62,
81
+ 'pa': 63,
82
+ 'pl': 64,
83
+ 'ps': 65,
84
+ 'pt': 66,
85
+ 'ro': 67,
86
+ 'ru': 68,
87
+ 'sa': 69,
88
+ 'sd': 70,
89
+ 'si': 71,
90
+ 'sk': 72,
91
+ 'sl': 73,
92
+ 'so': 74,
93
+ 'sq': 75,
94
+ 'sr': 76,
95
+ 'su': 77,
96
+ 'sv': 78,
97
+ 'sw': 79,
98
+ 'ta': 80,
99
+ 'te': 81,
100
+ 'th': 82,
101
+ 'tl': 83,
102
+ 'tr': 84,
103
+ 'ug': 85,
104
+ 'uk': 86,
105
+ 'ur': 87,
106
+ 'uz': 88,
107
+ 'vi': 89,
108
+ 'xh': 90,
109
+ 'yi': 91,
110
+ 'zh': 92
111
+ }
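
A quick arithmetic check of the constants above (assumed to run in the same module so the names resolve):

```python
# Character tokens + pad/eos/bos offset + special (language etc.) tokens
assert TOTAL_VOCAB_SIZE == 65536 + 3 + 253 == 65792
assert len(LANGUAGE_MAP) == 93          # language ids 0-92
assert LANGUAGE_MAP["ar"] == 2          # Arabic, the language Suryolo targets
```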
surya/model/recognition/decoder.py ADDED
@@ -0,0 +1,511 @@
1
+ import copy
2
+ from typing import Optional, List, Union, Tuple
3
+
4
+ from transformers import MBartForCausalLM, MBartConfig
5
+ from torch import nn
6
+ from transformers.activations import ACT2FN
7
+ from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions, BaseModelOutputWithPastAndCrossAttentions
8
+ from transformers.models.mbart.modeling_mbart import MBartPreTrainedModel, MBartDecoder
9
+ from .config import MBartMoEConfig
10
+ import torch
11
+ import math
12
+
13
+
14
+ class MBartLearnedPositionalEmbedding(nn.Embedding):
15
+ """
16
+ This module learns positional embeddings up to a fixed maximum size.
17
+ """
18
+
19
+ def __init__(self, num_embeddings: int, embedding_dim: int):
20
+ # MBart is set up so that if padding_idx is specified then offset the embedding ids by 2
21
+ # and adjust num_embeddings appropriately. Other models don't have this hack
22
+ self.offset = 2
23
+ super().__init__(num_embeddings + self.offset, embedding_dim)
24
+
25
+ def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0):
26
+ """`input_ids' shape is expected to be [bsz x seqlen]."""
27
+
28
+ bsz, seq_len = input_ids.shape[:2]
29
+ positions = torch.arange(
30
+ past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
31
+ ).expand(bsz, -1)
32
+
33
+ return super().forward(positions + self.offset)
34
+
35
+
36
+ class MBartExpertMLP(nn.Module):
37
+ def __init__(self, config: MBartConfig, is_lg=False, is_xl=False):
38
+ super().__init__()
39
+ self.ffn_dim = config.d_expert
40
+ if is_lg:
41
+ self.ffn_dim = config.d_expert_lg
42
+ if is_xl:
43
+ self.ffn_dim = config.d_expert_xl
44
+ self.hidden_dim = config.d_model
45
+
46
+ self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
47
+ self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)
48
+ self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
49
+ self.dropout = nn.Dropout(config.activation_dropout)
50
+
51
+ self.act_fn = ACT2FN[config.activation_function]
52
+
53
+ def forward(self, hidden_states):
54
+ current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
55
+ current_hidden_states = self.w2(current_hidden_states)
56
+ return current_hidden_states
57
+
58
+
59
+ class MBartExpertLayer(nn.Module):
60
+ # From mixtral, with modifications
61
+ def __init__(self, config):
62
+ super().__init__()
63
+ self.dropout = nn.Dropout(config.activation_dropout)
64
+
65
+ self.hidden_dim = config.d_model
66
+
67
+ self.lg_lang_codes = sorted(config.lg_langs.values()) if hasattr(config, "lg_langs") else []
68
+ self.xl_lang_codes = sorted(config.xl_langs.values()) if hasattr(config, "xl_langs") else []
69
+
70
+ self.lang_codes = sorted(config.langs.values())
71
+ self.num_experts = len(self.lang_codes)
72
+
73
+ self.experts = nn.ModuleDict({str(lang): MBartExpertMLP(config, is_lg=(lang in self.lg_lang_codes), is_xl=(lang in self.xl_lang_codes)) for lang in self.lang_codes})
74
+
75
+ def forward(self, hidden_states: torch.Tensor, langs: torch.LongTensor) -> torch.Tensor:
76
+ batch_size, sequence_length, hidden_dim = hidden_states.shape
77
+
78
+ final_hidden_states = torch.zeros(
79
+ (batch_size, sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
80
+ )
81
+
82
+ # Weight experts based on how many languages in the input
83
+ routing_weights = 1 / ((langs > 3).sum(axis=-1))
84
+ # Set weights to 1 if zero experts activated
85
+ routing_weights[torch.isinf(routing_weights)] = 1
86
+
87
+ unique_langs = langs.unique(dim=None, sorted=True)
88
+ unique_langs = unique_langs[unique_langs > 3] # Remove start token
89
+
90
+ # Loop over all available experts in the model and perform the computation on each expert
91
+ for expert_lang in unique_langs:
92
+ # Check which samples match with this expert
93
+ lang_match = (langs == expert_lang).any(dim=-1)
94
+ idx = torch.nonzero(lang_match, as_tuple=True)[0]
95
+
96
+ if idx.shape[0] == 0:
97
+ continue
98
+
99
+ expert_layer = self.experts[str(expert_lang.item())]
100
+
101
+ current_state = hidden_states[idx]
102
+ current_hidden_states = expert_layer(current_state.view(-1, hidden_dim))
103
+ current_hidden_states = current_hidden_states.view(-1, sequence_length, hidden_dim)
104
+
105
+ # Weight by number of languages in the input
106
+ selected_routing_weights = routing_weights[idx].view(-1, 1, 1)
107
+ current_hidden_states *= selected_routing_weights
108
+
109
+ final_hidden_states.index_add_(0, idx, current_hidden_states)
110
+
111
+ return final_hidden_states
112
+
113
+
114
+ class MBartGQAttention(nn.Module):
115
+ def __init__(
116
+ self,
117
+ embed_dim: int,
118
+ num_heads: int,
119
+ num_kv_heads: int,
120
+ dropout: float = 0.0,
121
+ is_decoder: bool = False,
122
+ bias: bool = True,
123
+ is_causal: bool = False,
124
+ config: Optional[MBartConfig] = None,
125
+ ):
126
+ super().__init__()
127
+ self.embed_dim = embed_dim
128
+ self.num_heads = num_heads
129
+ self.num_kv_heads = num_kv_heads
130
+ self.num_kv_groups = self.num_heads // self.num_kv_heads
131
+
132
+ self.dropout = dropout
133
+ self.head_dim = embed_dim // num_heads
134
+ self.config = config
135
+ self.scaling = self.head_dim**-0.5
136
+ self.is_decoder = is_decoder
137
+ self.is_causal = is_causal
138
+
139
+ self.k_proj = nn.Linear(embed_dim, self.num_kv_heads * self.head_dim, bias=bias)
140
+ self.v_proj = nn.Linear(embed_dim, self.num_kv_heads * self.head_dim, bias=bias)
141
+ self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
142
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
143
+
144
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
145
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
146
+
147
+ def _shape_key_value(self, tensor: torch.Tensor, seq_len: int, bsz: int):
148
+ return tensor.view(bsz, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2).contiguous()
149
+
150
+ def forward(
151
+ self,
152
+ hidden_states: torch.Tensor,
153
+ key_value_states: Optional[torch.Tensor] = None,
154
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
155
+ is_prefill: Optional[bool] = False,
156
+ attention_mask: Optional[torch.Tensor] = None,
157
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
158
+ """Input shape: Batch x Time x Channel"""
159
+
160
+ # if key_value_states are provided this layer is used as a cross-attention layer
161
+ # for the decoder
162
+ is_cross_attention = key_value_states is not None
163
+
164
+ bsz, tgt_len, _ = hidden_states.size()
165
+
166
+ # get query proj
167
+ query_states = self.q_proj(hidden_states) * self.scaling
168
+ # get key, value proj
169
+ # `past_key_value[0].shape[2] == key_value_states.shape[1]`
170
+ # is checking that the `sequence_length` of the `past_key_value` is the same as
171
+ # the provided `key_value_states` to support prefix tuning
172
+ if is_cross_attention:
173
+ if is_prefill:
174
+ # cross_attentions
175
+ key_states = self._shape_key_value(self.k_proj(key_value_states), -1, bsz)
176
+ value_states = self._shape_key_value(self.v_proj(key_value_states), -1, bsz)
177
+ past_key_value = torch.cat([key_states.unsqueeze(0), value_states.unsqueeze(0)], dim=0)
178
+ else:
179
+ # reuse k,v, cross_attentions
180
+ key_states = past_key_value[0]
181
+ value_states = past_key_value[1]
182
+ past_key_value = None
183
+ # Self-attention
184
+ else:
185
+ if is_prefill:
186
+ # initial prompt
187
+ key_states = self._shape_key_value(self.k_proj(hidden_states), -1, bsz)
188
+ value_states = self._shape_key_value(self.v_proj(hidden_states), -1, bsz)
189
+ past_key_value = torch.cat([key_states[:, :, -tgt_len:].unsqueeze(0), value_states[:, :, -tgt_len:].unsqueeze(0)], dim=0)
190
+ else:
191
+ # reuse k, v, self_attention
192
+ key_states = self._shape_key_value(self.k_proj(hidden_states), -1, bsz)
193
+ value_states = self._shape_key_value(self.v_proj(hidden_states), -1, bsz)
194
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
195
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
196
+ past_key_value = torch.cat([key_states[:, :, -tgt_len:].unsqueeze(0), value_states[:, :, -tgt_len:].unsqueeze(0)], dim=0)
197
+
198
+ proj_shape = (bsz * self.num_heads, -1, self.head_dim)
199
+ query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
200
+
201
+ # Expand kv heads, then match query shape
202
+ key_states = key_states.repeat_interleave(self.num_kv_groups, dim=1).reshape(*proj_shape)
203
+ value_states = value_states.repeat_interleave(self.num_kv_groups, dim=1).reshape(*proj_shape)
204
+
205
+ src_len = key_states.size(1)
206
+ attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
207
+
208
+ if not is_cross_attention:
209
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
210
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
211
+
212
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
213
+
214
+ attn_output = torch.bmm(attn_weights, value_states).view(bsz, self.num_heads, tgt_len, self.head_dim).transpose(1,2)
215
+
216
+ # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
217
+ # partitioned across GPUs when using tensor-parallelism.
218
+ attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
219
+ attn_output = self.out_proj(attn_output)
220
+
221
+ return attn_output, past_key_value
222
+
223
+
224
+ class MBartMoEDecoderLayer(nn.Module):
225
+ def __init__(self, config: MBartConfig, has_moe=False):
226
+ super().__init__()
227
+ self.embed_dim = config.d_model
228
+
229
+ self.self_attn = MBartGQAttention(
230
+ embed_dim=self.embed_dim,
231
+ num_heads=config.decoder_attention_heads,
232
+ num_kv_heads=config.kv_heads,
233
+ dropout=config.attention_dropout,
234
+ is_decoder=True,
235
+ is_causal=True,
236
+ config=config,
237
+ )
238
+ self.dropout = config.dropout
239
+ self.activation_fn = ACT2FN[config.activation_function]
240
+ self.activation_dropout = config.activation_dropout
241
+
242
+ self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
243
+ self.encoder_attn = MBartGQAttention(
244
+ self.embed_dim,
245
+ config.decoder_attention_heads,
246
+ num_kv_heads=config.kv_heads,
247
+ dropout=config.attention_dropout,
248
+ is_decoder=True,
249
+ config=config,
250
+ )
251
+ self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
252
+ self.has_moe = has_moe
253
+ if has_moe:
254
+ self.moe = MBartExpertLayer(config)
255
+ else:
256
+ self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
257
+ self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
258
+ self.final_layer_norm = nn.LayerNorm(self.embed_dim)
259
+
260
+ def forward(
261
+ self,
262
+ hidden_states: torch.Tensor,
263
+ attention_mask: Optional[torch.Tensor] = None,
264
+ langs: Optional[torch.LongTensor] = None,
265
+ self_kv_cache: Optional[torch.Tensor] = None,
266
+ cross_kv_cache: Optional[torch.Tensor] = None,
267
+ is_prefill: Optional[bool] = False,
268
+ encoder_hidden_states: Optional[torch.Tensor] = None,
269
+ encoder_attention_mask: Optional[torch.Tensor] = None,
270
+ use_cache: Optional[bool] = True,
271
+ ) -> torch.Tensor:
272
+ residual = hidden_states
273
+ hidden_states = self.self_attn_layer_norm(hidden_states)
274
+
275
+ # Self Attention
276
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
277
+ # add present self-attn cache to positions 1,2 of present_key_value tuple
278
+ hidden_states, present_key_value = self.self_attn(
279
+ hidden_states=hidden_states,
280
+ past_key_value=self_kv_cache,
281
+ is_prefill=is_prefill,
282
+ attention_mask=attention_mask,
283
+ )
284
+ hidden_states = residual + hidden_states
285
+
286
+ # Cross-Attention Block
287
+ if encoder_hidden_states is not None:
288
+ residual = hidden_states
289
+ hidden_states = self.encoder_attn_layer_norm(hidden_states)
290
+
291
+ # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
292
+ hidden_states, cross_attn_present_key_value = self.encoder_attn(
293
+ hidden_states=hidden_states,
294
+ key_value_states=encoder_hidden_states,
295
+ is_prefill=is_prefill,
296
+ attention_mask=encoder_attention_mask,
297
+ past_key_value=cross_kv_cache,
298
+ )
299
+ hidden_states = residual + hidden_states
300
+
301
+ # add cross-attn to positions 3,4 of present_key_value tuple
302
+ present_key_value = (present_key_value, cross_attn_present_key_value)
303
+
304
+ # Fully Connected
305
+ residual = hidden_states
306
+ hidden_states = self.final_layer_norm(hidden_states)
307
+ if self.has_moe:
308
+ hidden_states = self.moe(hidden_states, langs)
309
+ else:
310
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
311
+ hidden_states = self.fc2(hidden_states)
312
+
313
+ hidden_states = residual + hidden_states
314
+
315
+ outputs = (hidden_states,)
316
+
317
+ if use_cache:
318
+ outputs += (present_key_value,)
319
+
320
+ return outputs
321
+
322
+
323
+ class MBartMoEDecoder(MBartDecoder):
324
+ def __init__(self, config: MBartConfig, embed_tokens: Optional[nn.Embedding] = None):
325
+ MBartPreTrainedModel.__init__(self, config)
326
+ self.dropout = config.dropout
327
+ self.layerdrop = config.decoder_layerdrop
328
+ self.padding_idx = config.pad_token_id
329
+ self.max_target_positions = config.max_position_embeddings
330
+ self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
331
+
332
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
333
+
334
+ if embed_tokens is not None:
335
+ self.embed_tokens.weight = embed_tokens.weight
336
+
337
+ self.embed_positions = MBartLearnedPositionalEmbedding(
338
+ config.max_position_embeddings,
339
+ config.d_model,
340
+ )
341
+ # Language-specific MoE goes at second and second-to-last layer
342
+ self.layers = nn.ModuleList([MBartMoEDecoderLayer(config, has_moe=(i in config.moe_layers) and config.use_moe) for i in range(config.decoder_layers)])
343
+ self.layernorm_embedding = nn.LayerNorm(config.d_model)
344
+ self.layer_norm = nn.LayerNorm(config.d_model)
345
+
346
+ self.gradient_checkpointing = False
347
+ # Initialize weights and apply final processing
348
+ self.post_init()
349
+
350
+ def forward(
351
+ self,
352
+ input_ids: torch.LongTensor = None,
353
+ attention_mask: Optional[torch.Tensor] = None,
354
+ self_kv_cache: Optional[torch.Tensor] = None,
355
+ cross_kv_cache: Optional[torch.Tensor] = None,
356
+ past_token_count: Optional[int] = None,
357
+ langs: Optional[torch.LongTensor] = None,
358
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
359
+ ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
360
+ use_cache = True
361
+ return_dict = True
362
+
363
+ input = input_ids
364
+ input_shape = input.size()
365
+ input_ids = input_ids.view(-1, input_shape[-1])
366
+
367
+ # past_key_values_length
368
+ past_key_values_length = past_token_count if self_kv_cache is not None else 0
369
+ inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
370
+
371
+ # embed positions
372
+ positions = self.embed_positions(input, past_key_values_length)
373
+
374
+ hidden_states = inputs_embeds + positions
375
+ hidden_states = self.layernorm_embedding(hidden_states)
376
+
377
+ # decoder layers
378
+ all_hidden_states = None
379
+ all_self_attns = None
380
+ all_cross_attentions = None
381
+ next_decoder_cache = () if use_cache else None
382
+
383
+ for idx, decoder_layer in enumerate(self.layers):
384
+ is_prefill = past_token_count == 0
385
+ layer_self_kv_cache = self_kv_cache[idx] if self_kv_cache is not None else None
386
+ layer_cross_kv_cache = cross_kv_cache[idx] if cross_kv_cache is not None else None
387
+ layer_outputs = decoder_layer(
388
+ hidden_states,
389
+ attention_mask=attention_mask,
390
+ langs=langs,
391
+ self_kv_cache=layer_self_kv_cache,
392
+ cross_kv_cache=layer_cross_kv_cache,
393
+ is_prefill=is_prefill,
394
+ encoder_hidden_states=encoder_hidden_states,
395
+ encoder_attention_mask=None,
396
+ use_cache=use_cache,
397
+ )
398
+ hidden_states = layer_outputs[0]
399
+
400
+ if use_cache:
401
+ next_decoder_cache += (layer_outputs[1],)
402
+
403
+ hidden_states = self.layer_norm(hidden_states)
404
+
405
+ next_cache = next_decoder_cache if use_cache else None
406
+ if not return_dict:
407
+ return tuple(
408
+ v
409
+ for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions]
410
+ if v is not None
411
+ )
412
+ return BaseModelOutputWithPastAndCrossAttentions(
413
+ last_hidden_state=hidden_states,
414
+ past_key_values=next_cache,
415
+ hidden_states=all_hidden_states,
416
+ attentions=all_self_attns,
417
+ cross_attentions=all_cross_attentions,
418
+ )
419
+
420
+
421
+ class MBartMoEDecoderWrapper(MBartPreTrainedModel):
422
+ """
423
+ This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is
424
+ used in combination with the [`EncoderDecoderModel`] framework.
425
+ """
426
+
427
+ def __init__(self, config):
428
+ super().__init__(config)
429
+ self.decoder = MBartMoEDecoder(config)
430
+
431
+ def forward(self, *args, **kwargs):
432
+ return self.decoder(*args, **kwargs)
433
+
434
+
435
+ class MBartMoE(MBartForCausalLM):
436
+ config_class = MBartMoEConfig
437
+ _tied_weights_keys = ["lm_head.weight"]
438
+
439
+ def __init__(self, config, **kwargs):
440
+ config = copy.deepcopy(config)
441
+ config.is_decoder = True
442
+ config.is_encoder_decoder = False
443
+ MBartPreTrainedModel.__init__(self, config)
444
+ self.model = MBartMoEDecoderWrapper(config)
445
+
446
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
447
+
448
+ # Initialize weights and apply final processing
449
+ self.post_init()
450
+
451
+ def forward(
452
+ self,
453
+ input_ids: torch.LongTensor = None,
454
+ attention_mask: Optional[torch.Tensor] = None,
455
+ self_kv_cache: Optional[torch.FloatTensor] = None,
456
+ cross_kv_cache: Optional[torch.FloatTensor] = None,
457
+ past_token_count: Optional[int] = None,
458
+ langs: Optional[torch.LongTensor] = None,
459
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
460
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
461
+ head_mask: Optional[torch.Tensor] = None,
462
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
463
+ past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
464
+ inputs_embeds: Optional[torch.FloatTensor] = None,
465
+ labels: Optional[torch.LongTensor] = None,
466
+ use_cache: Optional[bool] = None,
467
+ output_attentions: Optional[bool] = None,
468
+ output_hidden_states: Optional[bool] = None,
469
+ return_dict: Optional[bool] = None,
470
+ **kwargs
471
+ ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
472
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
473
+
474
+ # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
475
+ outputs = self.model.decoder(
476
+ input_ids=input_ids,
477
+ attention_mask=attention_mask,
478
+ self_kv_cache=self_kv_cache,
479
+ cross_kv_cache=cross_kv_cache,
480
+ past_token_count=past_token_count,
481
+ langs=langs,
482
+ encoder_hidden_states=encoder_hidden_states,
483
+ )
484
+
485
+ logits = self.lm_head(outputs[0])
486
+
487
+ if not return_dict:
488
+ output = (logits,) + outputs[1:]
489
+ return output
490
+
491
+ return CausalLMOutputWithCrossAttentions(
492
+ loss=None,
493
+ logits=logits,
494
+ past_key_values=outputs.past_key_values,
495
+ hidden_states=outputs.hidden_states,
496
+ attentions=outputs.attentions,
497
+ cross_attentions=outputs.cross_attentions,
498
+ )
499
+
500
+ def prune_moe_experts(self, keep_keys: List[int]):
501
+ # Remove experts not specified in keep_keys
502
+ str_keep_keys = [str(key) for key in keep_keys]
503
+ for layer in self.model.decoder.layers:
504
+ if not layer.has_moe:
505
+ continue
506
+
507
+ lang_keys = list(layer.moe.experts.keys())
508
+ for lang in lang_keys:
509
+ if lang not in str_keep_keys:
510
+ layer.moe.experts.pop(lang)
511
+ layer.lang_codes = keep_keys
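
The expert routing in `MBartExpertLayer` above weights each sample by the inverse of how many language tokens (ids greater than 3) appear in its prompt, falling back to 1 when none do. A minimal sketch of that weighting on a toy `langs` tensor (toy ids, not real language tokens):

```python
import torch

# Each row is one sample's language prompt; ids <= 3 are special tokens.
langs = torch.tensor([
    [1, 5, 0],    # one language token  -> weight 1.0
    [1, 5, 7],    # two language tokens -> weight 0.5
    [1, 0, 0],    # no language tokens  -> inf, reset to 1.0
])

routing_weights = 1 / (langs > 3).sum(dim=-1)
routing_weights[torch.isinf(routing_weights)] = 1
print(routing_weights)  # tensor([1.0000, 0.5000, 1.0000])
```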
surya/model/recognition/encoder.py ADDED
@@ -0,0 +1,469 @@
1
+ from torch import nn
2
+ import torch
3
+ from typing import Optional, Tuple, Union
4
+
5
+ from transformers.models.donut.modeling_donut_swin import DonutSwinPatchEmbeddings, DonutSwinEmbeddings, DonutSwinModel, \
6
+ DonutSwinEncoder, DonutSwinModelOutput, DonutSwinEncoderOutput, DonutSwinAttention, DonutSwinDropPath, \
7
+ DonutSwinIntermediate, DonutSwinOutput, window_partition, window_reverse
8
+
9
+
10
+
11
+ from .config import VariableDonutSwinConfig
12
+
13
+
14
+ class VariableDonutSwinEmbeddings(DonutSwinEmbeddings):
15
+ """
16
+ Construct the patch and position embeddings. Optionally, also the mask token.
17
+ """
18
+
19
+ def __init__(self, config, use_mask_token=False):
20
+ super().__init__(config, use_mask_token)
21
+
22
+ self.patch_embeddings = DonutSwinPatchEmbeddings(config)
23
+ num_patches = self.patch_embeddings.num_patches
24
+ self.patch_grid = self.patch_embeddings.grid_size
25
+ self.mask_token = nn.Parameter(torch.zeros(1, 1, config.embed_dim)) if use_mask_token else None
26
+ self.position_embeddings = None
27
+
28
+ if config.use_absolute_embeddings:
29
+ self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.embed_dim))
30
+
31
+ self.norm = nn.LayerNorm(config.embed_dim)
32
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
33
+
34
+ def forward(
35
+ self, pixel_values: Optional[torch.FloatTensor], bool_masked_pos: Optional[torch.BoolTensor] = None
36
+ ) -> Tuple[torch.Tensor]:
37
+
38
+ embeddings, output_dimensions = self.patch_embeddings(pixel_values)
39
+ # Layernorm across the last dimension (each patch is a single row)
40
+ embeddings = self.norm(embeddings)
41
+ batch_size, seq_len, embed_dim = embeddings.size()
42
+
43
+ if bool_masked_pos is not None:
44
+ mask_tokens = self.mask_token.expand(batch_size, seq_len, -1)
45
+ # replace the masked visual tokens by mask_tokens
46
+ mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
47
+ embeddings = embeddings * (1.0 - mask) + mask_tokens * mask
48
+
49
+ if self.position_embeddings is not None:
50
+ embeddings = embeddings + self.position_embeddings[:, :seq_len, :]
51
+
52
+ embeddings = self.dropout(embeddings)
53
+
54
+ return embeddings, output_dimensions
55
+
56
+
57
+ class VariableDonutSwinPatchMerging(nn.Module):
58
+ """
59
+ Patch Merging Layer.
60
+
61
+ Args:
62
+ input_resolution (`Tuple[int]`):
63
+ Resolution of input feature.
64
+ dim (`int`):
65
+ Number of input channels.
66
+ norm_layer (`nn.Module`, *optional*, defaults to `nn.LayerNorm`):
67
+ Normalization layer class.
68
+ """
69
+
70
+ def __init__(self, input_resolution: Tuple[int], dim: int, norm_layer: nn.Module = nn.LayerNorm) -> None:
71
+ super().__init__()
72
+ self.input_resolution = input_resolution
73
+ self.dim = dim
74
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
75
+ self.norm = norm_layer(4 * dim)
76
+
77
+ def maybe_pad(self, input_feature, height, width):
78
+ should_pad = (height % 2 == 1) or (width % 2 == 1)
79
+ if should_pad:
80
+ pad_values = (0, 0, 0, width % 2, 0, height % 2)
81
+ input_feature = nn.functional.pad(input_feature, pad_values)
82
+
83
+ return input_feature
84
+
85
+ def forward(self, input_feature: torch.Tensor, input_dimensions: Tuple[int, int]) -> torch.Tensor:
86
+ height, width = input_dimensions
87
+ # `dim` is height * width
88
+ batch_size, dim, num_channels = input_feature.shape
89
+
90
+ input_feature = input_feature.view(batch_size, height, width, num_channels)
91
+ # pad input so that height and width are divisible by 2, if needed
92
+ input_feature = self.maybe_pad(input_feature, height, width)
93
+ # [batch_size, height/2, width/2, num_channels]
94
+ input_feature_0 = input_feature[:, 0::2, 0::2, :]
95
+ # [batch_size, height/2, width/2, num_channels]
96
+ input_feature_1 = input_feature[:, 1::2, 0::2, :]
97
+ # [batch_size, height/2, width/2, num_channels]
98
+ input_feature_2 = input_feature[:, 0::2, 1::2, :]
99
+ # [batch_size, height/2, width/2, num_channels]
100
+ input_feature_3 = input_feature[:, 1::2, 1::2, :]
101
+ # batch_size height/2 width/2 4*num_channels
102
+ input_feature = torch.cat([input_feature_0, input_feature_1, input_feature_2, input_feature_3], -1)
103
+ input_feature = input_feature.view(batch_size, -1, 4 * num_channels) # batch_size height/2*width/2 4*C
104
+
105
+ input_feature = self.norm(input_feature)
106
+ input_feature = self.reduction(input_feature)
107
+
108
+ return input_feature
109
+
110
+
111
+ class VariableDonutSwinLayer(nn.Module):
112
+ def __init__(self, config, dim, input_resolution, num_heads, shift_size=0):
113
+ super().__init__()
114
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
115
+ self.shift_size = shift_size
116
+ self.window_size = config.window_size
117
+ self.input_resolution = input_resolution
118
+ self.layernorm_before = nn.LayerNorm(dim, eps=config.layer_norm_eps)
119
+ self.attention = DonutSwinAttention(config, dim, num_heads, window_size=self.window_size)
120
+ self.drop_path = DonutSwinDropPath(config.drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity()
121
+ self.layernorm_after = nn.LayerNorm(dim, eps=config.layer_norm_eps)
122
+ self.intermediate = DonutSwinIntermediate(config, dim)
123
+ self.output = DonutSwinOutput(config, dim)
124
+
125
+ def set_shift_and_window_size(self, input_resolution):
126
+ if min(input_resolution) <= self.window_size:
127
+ # if window size is larger than input resolution, we don't partition windows
128
+ self.shift_size = 0
129
+ self.window_size = min(input_resolution)
130
+
131
+ def get_attn_mask(self, height, width, dtype):
132
+ if self.shift_size > 0:
133
+ # calculate attention mask for SW-MSA
134
+ img_mask = torch.zeros((1, height, width, 1), dtype=dtype)
135
+ height_slices = (
136
+ slice(0, -self.window_size),
137
+ slice(-self.window_size, -self.shift_size),
138
+ slice(-self.shift_size, None),
139
+ )
140
+ width_slices = (
141
+ slice(0, -self.window_size),
142
+ slice(-self.window_size, -self.shift_size),
143
+ slice(-self.shift_size, None),
144
+ )
145
+ count = 0
146
+ for height_slice in height_slices:
147
+ for width_slice in width_slices:
148
+ img_mask[:, height_slice, width_slice, :] = count
149
+ count += 1
150
+
151
+ mask_windows = window_partition(img_mask, self.window_size)
152
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
153
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
154
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
155
+ else:
156
+ attn_mask = None
157
+ return attn_mask
158
+
159
+ def maybe_pad(self, hidden_states, height, width):
160
+ pad_right = (self.window_size - width % self.window_size) % self.window_size
161
+ pad_bottom = (self.window_size - height % self.window_size) % self.window_size
162
+ pad_values = (0, 0, 0, pad_right, 0, pad_bottom)
163
+ hidden_states = nn.functional.pad(hidden_states, pad_values)
164
+ return hidden_states, pad_values
165
+
166
+ def forward(
167
+ self,
168
+ hidden_states: torch.Tensor,
169
+ input_dimensions: Tuple[int, int],
170
+ head_mask: Optional[torch.FloatTensor] = None,
171
+ output_attentions: Optional[bool] = False,
172
+ always_partition: Optional[bool] = False,
173
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
174
+ if not always_partition:
175
+ self.set_shift_and_window_size(input_dimensions)
176
+ else:
177
+ pass
178
+ height, width = input_dimensions
179
+ batch_size, _, channels = hidden_states.size()
180
+ shortcut = hidden_states
181
+
182
+ hidden_states = self.layernorm_before(hidden_states)
183
+
184
+ hidden_states = hidden_states.view(batch_size, height, width, channels)
185
+
186
+ # pad hidden_states to multiples of window size
187
+ hidden_states, pad_values = self.maybe_pad(hidden_states, height, width)
188
+
189
+ _, height_pad, width_pad, _ = hidden_states.shape
190
+ # cyclic shift
191
+ if self.shift_size > 0:
192
+ shifted_hidden_states = torch.roll(hidden_states, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
193
+ else:
194
+ shifted_hidden_states = hidden_states
195
+
196
+ # partition windows
197
+ hidden_states_windows = window_partition(shifted_hidden_states, self.window_size)
198
+ hidden_states_windows = hidden_states_windows.view(-1, self.window_size * self.window_size, channels)
199
+ attn_mask = self.get_attn_mask(height_pad, width_pad, dtype=hidden_states.dtype)
200
+ if attn_mask is not None:
201
+ attn_mask = attn_mask.to(hidden_states_windows.device)
202
+
203
+ attention_outputs = self.attention(
204
+ hidden_states_windows, attn_mask, head_mask, output_attentions=output_attentions
205
+ )
206
+
207
+ attention_output = attention_outputs[0]
208
+
209
+ attention_windows = attention_output.view(-1, self.window_size, self.window_size, channels)
210
+ shifted_windows = window_reverse(attention_windows, self.window_size, height_pad, width_pad)
211
+
212
+ # reverse cyclic shift
213
+ if self.shift_size > 0:
214
+ attention_windows = torch.roll(shifted_windows, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
215
+ else:
216
+ attention_windows = shifted_windows
217
+
218
+ was_padded = pad_values[3] > 0 or pad_values[5] > 0
219
+ if was_padded:
220
+ attention_windows = attention_windows[:, :height, :width, :].contiguous()
221
+
222
+ attention_windows = attention_windows.view(batch_size, height * width, channels)
223
+
224
+ hidden_states = shortcut + self.drop_path(attention_windows)
225
+
226
+ layer_output = self.layernorm_after(hidden_states)
227
+ layer_output = self.intermediate(layer_output)
228
+ layer_output = hidden_states + self.output(layer_output)
229
+
230
+ layer_outputs = (layer_output, attention_outputs[1]) if output_attentions else (layer_output,)
231
+ return layer_outputs
232
+
233
+
234
+ class VariableDonutSwinStage(nn.Module):
235
+ def __init__(self, config, dim, input_resolution, depth, num_heads, drop_path, downsample):
236
+ super().__init__()
237
+ self.config = config
238
+ self.dim = dim
239
+ self.blocks = nn.ModuleList(
240
+ [
241
+ VariableDonutSwinLayer(
242
+ config=config,
243
+ dim=dim,
244
+ input_resolution=input_resolution,
245
+ num_heads=num_heads,
246
+ shift_size=0 if (i % 2 == 0) else int(config.window_size // 2),
247
+ )
248
+ for i in range(depth)
249
+ ]
250
+ )
251
+
252
+ # patch merging layer
253
+ if downsample is not None:
254
+ self.downsample = downsample(input_resolution, dim=dim, norm_layer=nn.LayerNorm)
255
+ else:
256
+ self.downsample = None
257
+
258
+ self.pointing = False
259
+
260
+ def forward(
261
+ self,
262
+ hidden_states: torch.Tensor,
263
+ input_dimensions: Tuple[int, int],
264
+ head_mask: Optional[torch.FloatTensor] = None,
265
+ output_attentions: Optional[bool] = False,
266
+ always_partition: Optional[bool] = False,
267
+ ) -> Tuple[torch.Tensor]:
268
+ height, width = input_dimensions
269
+ for i, layer_module in enumerate(self.blocks):
270
+ layer_head_mask = head_mask[i] if head_mask is not None else None
271
+
272
+ layer_outputs = layer_module(
273
+ hidden_states, input_dimensions, layer_head_mask, output_attentions, always_partition
274
+ )
275
+
276
+ hidden_states = layer_outputs[0]
277
+
278
+ hidden_states_before_downsampling = hidden_states
279
+ if self.downsample is not None:
280
+ height_downsampled, width_downsampled = (height + 1) // 2, (width + 1) // 2
281
+ output_dimensions = (height, width, height_downsampled, width_downsampled)
282
+ hidden_states = self.downsample(hidden_states_before_downsampling, input_dimensions)
283
+ else:
284
+ output_dimensions = (height, width, height, width)
285
+
286
+ stage_outputs = (hidden_states, hidden_states_before_downsampling, output_dimensions)
287
+
288
+ if output_attentions:
289
+ stage_outputs += layer_outputs[1:]
290
+ return stage_outputs
291
+
292
+
293
+ class VariableDonutSwinEncoder(nn.Module):
294
+ def __init__(self, config, grid_size):
295
+ super().__init__()
296
+ self.num_layers = len(config.depths)
297
+ self.config = config
298
+ dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths))]
299
+ self.layers = nn.ModuleList(
300
+ [
301
+ VariableDonutSwinStage(
302
+ config=config,
303
+ dim=int(config.embed_dim * 2**i_layer),
304
+ input_resolution=(grid_size[0] // (2**i_layer), grid_size[1] // (2**i_layer)),
305
+ depth=config.depths[i_layer],
306
+ num_heads=config.num_heads[i_layer],
307
+ drop_path=dpr[sum(config.depths[:i_layer]) : sum(config.depths[: i_layer + 1])],
308
+ downsample=VariableDonutSwinPatchMerging if (i_layer < self.num_layers - 1) else None,
309
+ )
310
+ for i_layer in range(self.num_layers)
311
+ ]
312
+ )
313
+
314
+ self.gradient_checkpointing = False
315
+
316
+ def forward(
317
+ self,
318
+ hidden_states: torch.Tensor,
319
+ input_dimensions: Tuple[int, int],
320
+ head_mask: Optional[torch.FloatTensor] = None,
321
+ output_attentions: Optional[bool] = False,
322
+ output_hidden_states: Optional[bool] = False,
323
+ output_hidden_states_before_downsampling: Optional[bool] = False,
324
+ always_partition: Optional[bool] = False,
325
+ return_dict: Optional[bool] = True,
326
+ ) -> Union[Tuple, DonutSwinEncoderOutput]:
327
+ all_hidden_states = () if output_hidden_states else None
328
+ all_reshaped_hidden_states = () if output_hidden_states else None
329
+ all_self_attentions = () if output_attentions else None
330
+
331
+ if output_hidden_states:
332
+ batch_size, _, hidden_size = hidden_states.shape
333
+ # rearrange b (h w) c -> b c h w
334
+ reshaped_hidden_state = hidden_states.view(batch_size, *input_dimensions, hidden_size)
335
+ reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
336
+ all_hidden_states += (hidden_states,)
337
+ all_reshaped_hidden_states += (reshaped_hidden_state,)
338
+
339
+ for i, layer_module in enumerate(self.layers):
340
+ layer_head_mask = head_mask[i] if head_mask is not None else None
341
+
342
+ if self.gradient_checkpointing and self.training:
343
+ layer_outputs = self._gradient_checkpointing_func(
344
+ layer_module.__call__,
345
+ hidden_states,
346
+ input_dimensions,
347
+ layer_head_mask,
348
+ output_attentions,
349
+ always_partition,
350
+ )
351
+ else:
352
+ layer_outputs = layer_module(
353
+ hidden_states, input_dimensions, layer_head_mask, output_attentions, always_partition
354
+ )
355
+
356
+ hidden_states = layer_outputs[0]
357
+ hidden_states_before_downsampling = layer_outputs[1]
358
+ output_dimensions = layer_outputs[2]
359
+
360
+ input_dimensions = (output_dimensions[-2], output_dimensions[-1])
361
+
362
+ if output_hidden_states and output_hidden_states_before_downsampling:
363
+ batch_size, _, hidden_size = hidden_states_before_downsampling.shape
364
+ # rearrange b (h w) c -> b c h w
365
+ # here we use the original (not downsampled) height and width
366
+                 reshaped_hidden_state = hidden_states_before_downsampling.view(
+                     batch_size, *(output_dimensions[0], output_dimensions[1]), hidden_size
+                 )
+                 reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
+                 all_hidden_states += (hidden_states_before_downsampling,)
+                 all_reshaped_hidden_states += (reshaped_hidden_state,)
+             elif output_hidden_states and not output_hidden_states_before_downsampling:
+                 batch_size, _, hidden_size = hidden_states.shape
+                 # rearrange b (h w) c -> b c h w
+                 reshaped_hidden_state = hidden_states.view(batch_size, *input_dimensions, hidden_size)
+                 reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
+                 all_hidden_states += (hidden_states,)
+                 all_reshaped_hidden_states += (reshaped_hidden_state,)
+
+             if output_attentions:
+                 all_self_attentions += layer_outputs[3:]
+
+         if not return_dict:
+             return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
+
+         return DonutSwinEncoderOutput(
+             last_hidden_state=hidden_states,
+             hidden_states=all_hidden_states,
+             attentions=all_self_attentions,
+             reshaped_hidden_states=all_reshaped_hidden_states,
+         )
+
+
+ class VariableDonutSwinModel(DonutSwinModel):
+     config_class = VariableDonutSwinConfig
+     def __init__(self, config, add_pooling_layer=True, use_mask_token=False, **kwargs):
+         super().__init__(config)
+         self.config = config
+         self.num_layers = len(config.depths)
+         self.num_features = int(config.embed_dim * 2 ** (self.num_layers - 1))
+
+         self.embeddings = VariableDonutSwinEmbeddings(config, use_mask_token=use_mask_token)
+         self.encoder = VariableDonutSwinEncoder(config, self.embeddings.patch_grid)
+
+         self.pooler = nn.AdaptiveAvgPool1d(1) if add_pooling_layer else None
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     def forward(
+         self,
+         pixel_values: Optional[torch.FloatTensor] = None,
+         bool_masked_pos: Optional[torch.BoolTensor] = None,
+         head_mask: Optional[torch.FloatTensor] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+         **kwargs
+     ) -> Union[Tuple, DonutSwinModelOutput]:
+         r"""
+         bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`):
+             Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
+         """
+         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+         )
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         if pixel_values is None:
+             raise ValueError("You have to specify pixel_values")
+
+         # Prepare head mask if needed
+         # 1.0 in head_mask indicate we keep the head
+         # attention_probs has shape bsz x n_heads x N x N
+         # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+         # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+         head_mask = self.get_head_mask(head_mask, len(self.config.depths))
+
+         embedding_output, input_dimensions = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)
+
+         encoder_outputs = self.encoder(
+             embedding_output,
+             input_dimensions,
+             head_mask=head_mask,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+         sequence_output = encoder_outputs[0]
+
+         pooled_output = None
+         if self.pooler is not None:
+             pooled_output = self.pooler(sequence_output.transpose(1, 2))
+             pooled_output = torch.flatten(pooled_output, 1)
+
+         if not return_dict:
+             output = (sequence_output, pooled_output) + encoder_outputs[1:]
+
+             return output
+
+         return DonutSwinModelOutput(
+             last_hidden_state=sequence_output,
+             pooler_output=pooled_output,
+             hidden_states=encoder_outputs.hidden_states,
+             attentions=encoder_outputs.attentions,
+             reshaped_hidden_states=encoder_outputs.reshaped_hidden_states,
+         )
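
The `reshaped_hidden_states` returned above are the per-stage Swin feature maps converted back from token form to image layout. The sketch below is illustrative only (not part of the commit) and uses dummy sizes; it just mirrors the `view`/`permute` step performed in the encoder code:

```python
# Swin-style encoders keep tokens as (batch, height * width, channels); the encoder's
# "reshaped" outputs recover the (batch, channels, height, width) layout. Sizes are dummy values.
import torch

batch_size, height, width, hidden_size = 2, 16, 64, 256
tokens = torch.randn(batch_size, height * width, hidden_size)

# rearrange b (h w) c -> b c h w, as in the code above
feature_map = tokens.view(batch_size, height, width, hidden_size).permute(0, 3, 1, 2)
assert feature_map.shape == (batch_size, hidden_size, height, width)
```
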
surya/model/recognition/model.py ADDED
@@ -0,0 +1,64 @@
+ import warnings
+
+ import torch
+
+ warnings.filterwarnings("ignore", message="torch.utils._pytree._register_pytree_node is deprecated")
+
+ import logging
+ logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR)
+
+ from typing import List, Optional, Tuple
+ from transformers import VisionEncoderDecoderModel, VisionEncoderDecoderConfig, AutoModel, AutoModelForCausalLM
+ from surya.model.recognition.config import MBartMoEConfig, VariableDonutSwinConfig
+ from surya.model.recognition.encoder import VariableDonutSwinModel
+ from surya.model.recognition.decoder import MBartMoE
+ from surya.settings import settings
+
+
+ def load_model(checkpoint=settings.RECOGNITION_MODEL_CHECKPOINT, device=settings.TORCH_DEVICE_MODEL, dtype=settings.MODEL_DTYPE, langs: Optional[List[int]] = None):
+     config = VisionEncoderDecoderConfig.from_pretrained(checkpoint)
+
+     # Prune moe experts that are not needed before loading the model
+     if langs is not None:
+         config.decoder.langs = {lang_iso : lang_int for lang_iso, lang_int in config.decoder.langs.items() if lang_int in langs}
+
+     decoder_config = vars(config.decoder)
+     decoder = MBartMoEConfig(**decoder_config)
+     config.decoder = decoder
+
+     encoder_config = vars(config.encoder)
+     encoder = VariableDonutSwinConfig(**encoder_config)
+     config.encoder = encoder
+
+     # Get transformers to load custom encoder/decoder
+     AutoModel.register(MBartMoEConfig, MBartMoE)
+     AutoModelForCausalLM.register(MBartMoEConfig, MBartMoE)
+     AutoModel.register(VariableDonutSwinConfig, VariableDonutSwinModel)
+
+     model = LangVisionEncoderDecoderModel.from_pretrained(checkpoint, config=config, torch_dtype=dtype)
+     assert isinstance(model.decoder, MBartMoE)
+     assert isinstance(model.encoder, VariableDonutSwinModel)
+
+     model = model.to(device)
+     model = model.eval()
+     print(f"Loaded recognition model {checkpoint} on device {device} with dtype {dtype}")
+     return model
+
+
+ class LangVisionEncoderDecoderModel(VisionEncoderDecoderModel):
+     def prepare_inputs_for_generation(
+         self, input_ids, decoder_langs=None, past_key_values=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
+     ):
+         decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids, langs=decoder_langs, past_key_values=past_key_values)
+         decoder_attention_mask = decoder_inputs["attention_mask"] if "attention_mask" in decoder_inputs else None
+         input_dict = {
+             "attention_mask": attention_mask,
+             "decoder_attention_mask": decoder_attention_mask,
+             "decoder_input_ids": decoder_inputs["input_ids"],
+             "encoder_outputs": encoder_outputs,
+             "past_key_values": decoder_inputs["past_key_values"],
+             "use_cache": use_cache,
+             "decoder_langs": decoder_inputs["langs"],
+         }
+         return input_dict
+
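
For reference, a minimal usage sketch of `load_model` (illustrative, not part of the commit): it assumes this repository's `surya` package is importable and that the default checkpoint configured in `surya.settings` is reachable.

```python
# Load the recognition model with the checkpoint, device and dtype defaults from surya.settings.
# Passing langs=None keeps every MoE decoder expert; passing a list of the integer language ids
# from config.decoder.langs prunes the unused experts before the weights are loaded.
from surya.model.recognition.model import load_model

model = load_model()

print(type(model.encoder).__name__)  # VariableDonutSwinModel
print(type(model.decoder).__name__)  # MBartMoE
```
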