diff --git a/yolo-world-with-efficientvit-sam/.DS_Store b/yolo-world-with-efficientvit-sam/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..8243cd1ca3fae5a918b263d12156dbfbcbfbbf96
Binary files /dev/null and b/yolo-world-with-efficientvit-sam/.DS_Store differ
diff --git a/yolo-world-with-efficientvit-sam/.gitignore b/yolo-world-with-efficientvit-sam/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..ede81d51393c8c004388841a8015bca47f367f10
--- /dev/null
+++ b/yolo-world-with-efficientvit-sam/.gitignore
@@ -0,0 +1,171 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# poetry
+# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+# in version control.
+# https://pdm.fming.dev/#use-with-ide
+.pdm.toml
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+# and can be added to the global gitignore or merged into this file. For a more nuclear
+# option (not recommended) you can uncomment the following to ignore the entire idea folder.
+#.idea/
+
+# Model Weights
+*.pth
+*.pt
+
+# Yolo-World
+work_dirs
+src
+
+# Etc
+.DS_Store
diff --git a/yolo-world-with-efficientvit-sam/LICENSE b/yolo-world-with-efficientvit-sam/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..ba4d8aa9ea77437c367794bdef61e11712b6d28f
--- /dev/null
+++ b/yolo-world-with-efficientvit-sam/LICENSE
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/yolo-world-with-efficientvit-sam/Makefile b/yolo-world-with-efficientvit-sam/Makefile
new file mode 100644
index 0000000000000000000000000000000000000000..a3969b9b36b6cb6d6e2e4af86b19f5d6bbf26717
--- /dev/null
+++ b/yolo-world-with-efficientvit-sam/Makefile
@@ -0,0 +1,18 @@
+EFFICIENTVIT_SAM_URL := https://huggingface.co/han-cai/efficientvit-sam/resolve/main
+EFFICIENTVIT_SAM_MODEL := xl1.pt
+
+
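+# download(url, file): fetch <file> from <url> with wget only if it does not already exist locally.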
+define download
+ @if [ ! -f $(2) ]; then \
+ echo "Download $(2)..."; \
+ wget "$(1)/$(2)"; \
+ fi
+endef
+
+
+setup:
+ pip install -r requirements.txt
+
+
+model:
+ $(call download,$(EFFICIENTVIT_SAM_URL),$(EFFICIENTVIT_SAM_MODEL))
diff --git a/yolo-world-with-efficientvit-sam/README.md b/yolo-world-with-efficientvit-sam/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..df870d071eee19aeb4de82fe34a8ee05474a4b92
--- /dev/null
+++ b/yolo-world-with-efficientvit-sam/README.md
@@ -0,0 +1,68 @@
+# YOLO-World + EfficientViT SAM
+
+🤗 [HuggingFace Space](https://huggingface.co/spaces/curt-park/yolo-world-with-efficientvit-sam)
+
+![example_0](https://github.com/Curt-Park/yolo-world-with-efficientvit-sam/assets/14961526/326bde19-d535-4be5-829e-782fce0c1d00)
+
+## Prerequisites
+This project was developed and tested with Python 3.10.
+
+```bash
+# Create and activate a python 3.10 environment.
+conda create -n yolo-world-with-efficientvit-sam python=3.10 -y
+conda activate yolo-world-with-efficientvit-sam
+# Setup packages.
+make setup
+```
+
+## How to Run
+```bash
+python app.py
+```
+
+Open http://127.0.0.1:7860/ in your web browser.
+
+![example_1](https://github.com/Curt-Park/yolo-world-with-efficientvit-sam/assets/14961526/9388e4ee-6f71-4428-b17c-d218fd059949)
+
+## Core Components
+
+### YOLO-World
+[YOLO-World](https://github.com/AILab-CVC/YOLO-World) is an open-vocabulary object detection model with high efficiency.
+On the challenging LVIS dataset, YOLO-World achieves 35.4 AP with 52.0 FPS on V100,
+which outperforms many state-of-the-art methods in terms of both accuracy and speed.
+![image](https://github.com/Curt-Park/yolo-world-with-efficientvit-sam/assets/14961526/8a4a17bd-918d-478a-8451-f58e4a2dce79)
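+
+A minimal sketch of how this demo queries YOLO-World through Roboflow's [inference](https://github.com/roboflow/inference) package, mirroring `app.py` (`image` stands for any RGB numpy array; the categories and thresholds are only illustrative):
+
+```python
+import supervision as sv
+from inference.models import YOLOWorld
+
+model = YOLOWorld(model_id="yolo_world/l")
+model.set_classes(["dog", "sofa", "plant"])  # open-vocabulary text prompt
+results = model.infer(image, confidence=0.3)  # image: RGB numpy array
+detections = sv.Etections if False else sv.Detections.from_inference(results).with_nms(
+    class_agnostic=True, threshold=0.5
+)
+```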
+
+
+### EfficientViT SAM
+[EfficientViT SAM](https://github.com/mit-han-lab/efficientvit) is a new family of accelerated Segment Anything models.
+Thanks to its lightweight, hardware-efficient core building block,
+it delivers a 48.9× measured TensorRT speedup over SAM-ViT-H on an A100 GPU without sacrificing performance.
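+
+A minimal sketch of box-prompted segmentation with EfficientViT SAM as wired up in `app.py` (assumes the `xl1.pt` weights fetched by `make model`; `image` and the box coordinates are placeholders):
+
+```python
+import numpy as np
+import torch
+
+from efficientvit.models.efficientvit.sam import EfficientViTSamPredictor
+from efficientvit.sam_model_zoo import create_sam_model
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+sam = EfficientViTSamPredictor(
+    create_sam_model(name="xl1", weight_url="xl1.pt").to(device).eval()
+)
+sam.set_image(image, image_format="RGB")  # image: RGB numpy array
+box = np.array([100, 100, 400, 300])  # [x1, y1, x2, y2] prompt, e.g. from a detector
+mask, _, _ = sam.predict(box=box, multimask_output=False)  # mask: (1, H, W) array
+```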
+
+
+## Powered By
+```
+@misc{zhang2024efficientvitsam,
+ title={EfficientViT-SAM: Accelerated Segment Anything Model Without Performance Loss},
+ author={Zhuoyang Zhang and Han Cai and Song Han},
+ year={2024},
+ eprint={2402.05008},
+ archivePrefix={arXiv},
+ primaryClass={cs.CV}
+}
+
+@article{cheng2024yolow,
+ title={YOLO-World: Real-Time Open-Vocabulary Object Detection},
+ author={Cheng, Tianheng and Song, Lin and Ge, Yixiao and Liu, Wenyu and Wang, Xinggang and Shan, Ying},
+ journal={arXiv preprint arXiv:2401.17270},
+ year={2024}
+}
+
+@article{cai2022efficientvit,
+ title={Efficientvit: Enhanced linear attention for high-resolution low-computation visual recognition},
+ author={Cai, Han and Gan, Chuang and Han, Song},
+ journal={arXiv preprint arXiv:2205.14756},
+ year={2022}
+}
+```
diff --git a/yolo-world-with-efficientvit-sam/app.py b/yolo-world-with-efficientvit-sam/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..47769bad44bd701d889184e9707c5cf99d3ed5ca
--- /dev/null
+++ b/yolo-world-with-efficientvit-sam/app.py
@@ -0,0 +1,132 @@
+"""Fast text to segmentation with yolo-world and efficient-vit sam."""
+import os
+
+import cv2
+import gradio as gr
+import numpy as np
+import supervision as sv
+import torch
+from inference.models import YOLOWorld
+
+from efficientvit.models.efficientvit.sam import EfficientViTSamPredictor
+from efficientvit.sam_model_zoo import create_sam_model
+
+
+# Download model weights.
+os.system("make model")
+
+# Load models.
+yolo_world = YOLOWorld(model_id="yolo_world/l")
+#yolo_world = YOLOWorld("/Users/tounsi/Desktop/DOCTORIA/Doctoria\ Full\ Software/Doctoria\ CXR/Doctoria\ CXR\ Thoracic\ Abnormalities/YOLOv8/CXR\ YOLOv8l.pt")
+device = "cuda" if torch.cuda.is_available() else "cpu"
+sam = EfficientViTSamPredictor(
+ create_sam_model(name="xl1", weight_url="xl1.pt").to(device).eval()
+)
+
+
+# Load annotators.
+BOUNDING_BOX_ANNOTATOR = sv.BoundingBoxAnnotator()
+MASK_ANNOTATOR = sv.MaskAnnotator()
+LABEL_ANNOTATOR = sv.LabelAnnotator()
+
+
+def detect(
+ image: np.ndarray,
+ query: str,
+ confidence_threshold: float,
+ nms_threshold: float,
+) -> np.ndarray:
+ # Preparation.
+ categories = [category.strip() for category in query.split(",")]
+ yolo_world.set_classes(categories)
+ print("categories:", categories)
+
+ # Object detection.
+ results = yolo_world.infer(image, confidence=confidence_threshold)
+ detections = sv.Detections.from_inference(results).with_nms(
+ class_agnostic=True, threshold=nms_threshold
+ )
+ print("detected:", detections)
+
+ # Segmentation.
+ sam.set_image(image, image_format="RGB")
+ masks = []
+ for xyxy in detections.xyxy:
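+        # With multimask_output=False the predictor returns a single mask of shape (1, H, W); squeeze to (H, W).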
+ mask, _, _ = sam.predict(box=xyxy, multimask_output=False)
+ masks.append(mask.squeeze())
+ detections.mask = np.array(masks)
+ print("masks shaped as", detections.mask.shape)
+
+ # Annotation.
+ output_image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
+ labels = [
+ f"{categories[class_id]}: {confidence:.2f}"
+ for class_id, confidence in zip(detections.class_id, detections.confidence)
+ ]
+ output_image = MASK_ANNOTATOR.annotate(output_image, detections)
+ output_image = BOUNDING_BOX_ANNOTATOR.annotate(output_image, detections)
+ output_image = LABEL_ANNOTATOR.annotate(output_image, detections, labels=labels)
+ return cv2.cvtColor(output_image, cv2.COLOR_BGR2RGB)
+
+
+app = gr.Interface(
+ fn=detect,
+ inputs=[
+ gr.Image(type="numpy", label="input image"),
+        gr.Text(info="Enter one or more category names separated by commas (,)."),
+ gr.Slider(
+ minimum=0,
+ maximum=1,
+ value=0.3,
+ step=0.01,
+ interactive=True,
+ label="Confidence Threshold",
+ ),
+ gr.Slider(
+ minimum=0,
+ maximum=1,
+ value=0.5,
+ step=0.01,
+ interactive=True,
+ label="NMS Threshold",
+ ),
+ ],
+ outputs=gr.Image(type="numpy", label="output image"),
+ allow_flagging="never",
+ title="Fast Text to Segmentation with YOLO-World + EfficientViT SAM",
+ description="""
+ ## Core components
+ ### YOLO-World
+ [YOLO-World](https://github.com/AILab-CVC/YOLO-World) is an open-vocabulary object detection model with high efficiency.
+ On the challenging LVIS dataset, YOLO-World achieves 35.4 AP with 52.0 FPS on V100,
+ which outperforms many state-of-the-art methods in terms of both accuracy and speed.
+
+ ### EfficientViT SAM
+    [EfficientViT SAM](https://github.com/mit-han-lab/efficientvit) is a new family of accelerated Segment Anything models.
+    Thanks to its lightweight, hardware-efficient core building block,
+    it delivers a 48.9× measured TensorRT speedup over SAM-ViT-H on an A100 GPU without sacrificing performance.
+
+ ## Demo especially powered by
+ Roboflow's [inference](https://github.com/roboflow/inference) and [supervision](https://github.com/roboflow/supervision).
+
+ ## Example images came from
+ [Segment Anything Demo](https://segment-anything.com/demo) and [Unsplash](https://unsplash.com/).
+ """,
+ examples=[
+ [
+ os.path.join(os.path.dirname(__file__), "examples/livingroom.jpg"),
+ "table, lamp, dog, sofa, plant, clock, carpet, frame on the wall",
+ 0.05,
+ 0.5
+ ],
+ [
+ os.path.join(os.path.dirname(__file__), "examples/cat_and_dogs.jpg"),
+ "cat, dog",
+ 0.2,
+ 0.5
+ ],
+ ],
+)
+
+
+app.launch(server_name="0.0.0.0")
diff --git a/yolo-world-with-efficientvit-sam/efficientvit/__init__.py b/yolo-world-with-efficientvit-sam/efficientvit/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/yolo-world-with-efficientvit-sam/efficientvit/__pycache__/__init__.cpython-311.pyc b/yolo-world-with-efficientvit-sam/efficientvit/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..625272eb0649e118b3f38f38809ffe46be266b92
Binary files /dev/null and b/yolo-world-with-efficientvit-sam/efficientvit/__pycache__/__init__.cpython-311.pyc differ
diff --git a/yolo-world-with-efficientvit-sam/efficientvit/__pycache__/sam_model_zoo.cpython-311.pyc b/yolo-world-with-efficientvit-sam/efficientvit/__pycache__/sam_model_zoo.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..44b6993f87c2009936f02fa551f0bbe974832c68
Binary files /dev/null and b/yolo-world-with-efficientvit-sam/efficientvit/__pycache__/sam_model_zoo.cpython-311.pyc differ
diff --git a/yolo-world-with-efficientvit-sam/efficientvit/apps/__init__.py b/yolo-world-with-efficientvit-sam/efficientvit/apps/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/yolo-world-with-efficientvit-sam/efficientvit/apps/__pycache__/__init__.cpython-311.pyc b/yolo-world-with-efficientvit-sam/efficientvit/apps/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..239f7bcc35d99670331784a9e511837f6543b48f
Binary files /dev/null and b/yolo-world-with-efficientvit-sam/efficientvit/apps/__pycache__/__init__.cpython-311.pyc differ
diff --git a/yolo-world-with-efficientvit-sam/efficientvit/apps/data_provider/__init__.py b/yolo-world-with-efficientvit-sam/efficientvit/apps/data_provider/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2c9a5dfa34097fdf24730a203a9f24c5c4ac0a74
--- /dev/null
+++ b/yolo-world-with-efficientvit-sam/efficientvit/apps/data_provider/__init__.py
@@ -0,0 +1,7 @@
+# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
+# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
+# International Conference on Computer Vision (ICCV), 2023
+
+from .augment import *
+from .base import *
+from .random_resolution import *
diff --git a/yolo-world-with-efficientvit-sam/efficientvit/apps/data_provider/__pycache__/__init__.cpython-311.pyc b/yolo-world-with-efficientvit-sam/efficientvit/apps/data_provider/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9e6d2a857724e869cee97991ab4000a925497090
Binary files /dev/null and b/yolo-world-with-efficientvit-sam/efficientvit/apps/data_provider/__pycache__/__init__.cpython-311.pyc differ
diff --git a/yolo-world-with-efficientvit-sam/efficientvit/apps/data_provider/__pycache__/base.cpython-311.pyc b/yolo-world-with-efficientvit-sam/efficientvit/apps/data_provider/__pycache__/base.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2a58603f3bc7181260b737b67772759e91c56400
Binary files /dev/null and b/yolo-world-with-efficientvit-sam/efficientvit/apps/data_provider/__pycache__/base.cpython-311.pyc differ
diff --git a/yolo-world-with-efficientvit-sam/efficientvit/apps/data_provider/augment/__init__.py b/yolo-world-with-efficientvit-sam/efficientvit/apps/data_provider/augment/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b9ea4d65f7f5a471cc433fbd68a58d4853b217d2
--- /dev/null
+++ b/yolo-world-with-efficientvit-sam/efficientvit/apps/data_provider/augment/__init__.py
@@ -0,0 +1,6 @@
+# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
+# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
+# International Conference on Computer Vision (ICCV), 2023
+
+from .bbox import *
+from .color_aug import *
diff --git a/yolo-world-with-efficientvit-sam/efficientvit/apps/data_provider/augment/__pycache__/__init__.cpython-311.pyc b/yolo-world-with-efficientvit-sam/efficientvit/apps/data_provider/augment/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f84965c9b174db250d9443d9b022509c998a2af0
Binary files /dev/null and b/yolo-world-with-efficientvit-sam/efficientvit/apps/data_provider/augment/__pycache__/__init__.cpython-311.pyc differ
diff --git a/yolo-world-with-efficientvit-sam/efficientvit/apps/data_provider/augment/__pycache__/bbox.cpython-311.pyc b/yolo-world-with-efficientvit-sam/efficientvit/apps/data_provider/augment/__pycache__/bbox.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e9f5522b513fcf82fb55fcecfd5d7f22767c370b
Binary files /dev/null and b/yolo-world-with-efficientvit-sam/efficientvit/apps/data_provider/augment/__pycache__/bbox.cpython-311.pyc differ
diff --git a/yolo-world-with-efficientvit-sam/efficientvit/apps/data_provider/augment/__pycache__/color_aug.cpython-311.pyc b/yolo-world-with-efficientvit-sam/efficientvit/apps/data_provider/augment/__pycache__/color_aug.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..568188dc259c3936d3f9e43c9be0b652e66267d8
Binary files /dev/null and b/yolo-world-with-efficientvit-sam/efficientvit/apps/data_provider/augment/__pycache__/color_aug.cpython-311.pyc differ
diff --git a/yolo-world-with-efficientvit-sam/efficientvit/apps/data_provider/augment/bbox.py b/yolo-world-with-efficientvit-sam/efficientvit/apps/data_provider/augment/bbox.py
new file mode 100644
index 0000000000000000000000000000000000000000..b9f089a3f70881313a5ce4308d1f74fbf1fa0c31
--- /dev/null
+++ b/yolo-world-with-efficientvit-sam/efficientvit/apps/data_provider/augment/bbox.py
@@ -0,0 +1,30 @@
+# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
+# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
+# International Conference on Computer Vision (ICCV), 2023
+
+import numpy as np
+
+__all__ = ["rand_bbox"]
+
+
+def rand_bbox(
+ h: int,
+ w: int,
+ lam: float,
+ rand_func: callable = np.random.uniform,
+) -> tuple[int, int, int, int]:
+ """randomly sample bbox, used in cutmix"""
+ cut_rat = np.sqrt(1.0 - lam)
+ cut_w = w * cut_rat
+ cut_h = h * cut_rat
+
+ # uniform
+ cx = rand_func(0, w)
+ cy = rand_func(0, h)
+
+ bbx1 = int(np.clip(cx - cut_w / 2, 0, w))
+ bby1 = int(np.clip(cy - cut_h / 2, 0, h))
+ bbx2 = int(np.clip(cx + cut_w / 2, 0, w))
+ bby2 = int(np.clip(cy + cut_h / 2, 0, h))
+
+ return bbx1, bby1, bbx2, bby2
diff --git a/yolo-world-with-efficientvit-sam/efficientvit/apps/data_provider/augment/color_aug.py b/yolo-world-with-efficientvit-sam/efficientvit/apps/data_provider/augment/color_aug.py
new file mode 100644
index 0000000000000000000000000000000000000000..c5e1dcc6998374738c300414b06e4fdb2ed8af95
--- /dev/null
+++ b/yolo-world-with-efficientvit-sam/efficientvit/apps/data_provider/augment/color_aug.py
@@ -0,0 +1,84 @@
+# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
+# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
+# International Conference on Computer Vision (ICCV), 2023
+
+import numpy as np
+import torchvision.transforms as transforms
+from PIL import Image
+from timm.data.auto_augment import rand_augment_transform
+
+__all__ = ["ColorAug", "RandAug"]
+
+
+class ImageAug:
+ def aug_image(self, image: Image.Image) -> Image.Image:
+ raise NotImplementedError
+
+ def __call__(
+ self, feed_dict: dict or np.ndarray or Image.Image
+ ) -> dict or np.ndarray or Image.Image:
+ if isinstance(feed_dict, dict):
+ output_dict = feed_dict
+ image = feed_dict[self.key]
+ else:
+ output_dict = None
+ image = feed_dict
+ is_ndarray = isinstance(image, np.ndarray)
+ if is_ndarray:
+ image = Image.fromarray(image)
+
+ image = self.aug_image(image)
+
+ if is_ndarray:
+ image = np.array(image)
+
+ if output_dict is None:
+ return image
+ else:
+ output_dict[self.key] = image
+ return output_dict
+
+
+class ColorAug(transforms.ColorJitter, ImageAug):
+ def __init__(self, brightness=0, contrast=0, saturation=0, hue=0, key="data"):
+ super().__init__(
+ brightness=brightness,
+ contrast=contrast,
+ saturation=saturation,
+ hue=hue,
+ )
+ self.key = key
+
+ def aug_image(self, image: Image.Image) -> Image.Image:
+ return transforms.ColorJitter.forward(self, image)
+
+ def forward(
+ self, feed_dict: dict or np.ndarray or Image.Image
+ ) -> dict or np.ndarray or Image.Image:
+ return ImageAug.__call__(self, feed_dict)
+
+
+class RandAug(ImageAug):
+ def __init__(
+ self, config: dict[str, any], mean: tuple[float, float, float], key="data"
+ ):
+ n = config.get("n", 2)
+ m = config.get("m", 9)
+ mstd = config.get("mstd", 1.0)
+ inc = config.get("inc", 1)
+ tpct = config.get("tpct", 0.45)
+ config_str = f"rand-n{n}-m{m}-mstd{mstd}-inc{inc}"
+
+ aa_params = dict(
+ translate_pct=tpct,
+ img_mean=tuple([min(255, round(255 * x)) for x in mean]),
+ interpolation=Image.BICUBIC,
+ )
+ self.aug_op = rand_augment_transform(config_str, aa_params)
+ self.key = key
+
+ def aug_image(self, image: Image.Image) -> Image.Image:
+ return self.aug_op(image)
+
+ def __repr__(self):
+ return self.aug_op.__repr__()
diff --git a/yolo-world-with-efficientvit-sam/efficientvit/apps/data_provider/base.py b/yolo-world-with-efficientvit-sam/efficientvit/apps/data_provider/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..dea9e6f4c2caee2d78c4921422b387fb787d22bd
--- /dev/null
+++ b/yolo-world-with-efficientvit-sam/efficientvit/apps/data_provider/base.py
@@ -0,0 +1,223 @@
+# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
+# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
+# International Conference on Computer Vision (ICCV), 2023
+
+import copy
+import warnings
+
+import torch.utils.data
+from torch.utils.data.distributed import DistributedSampler
+
+from efficientvit.apps.data_provider.random_resolution import RRSController
+from efficientvit.models.utils import val2tuple
+
+__all__ = ["parse_image_size", "random_drop_data", "DataProvider"]
+
+
+def parse_image_size(size: int or str) -> tuple[int, int]:
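+    # e.g., "160-224" -> (160, 224); 224 -> (224, 224)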
+ if isinstance(size, str):
+ size = [int(val) for val in size.split("-")]
+ return size[0], size[1]
+ else:
+ return val2tuple(size, 2)
+
+
+def random_drop_data(dataset, drop_size: int, seed: int, keys=("samples",)):
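+    # Deterministically (given `seed`) split off `drop_size` random samples into a deep-copied
+    # dataset; returns (remaining_dataset, dropped_dataset).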
+ g = torch.Generator()
+ g.manual_seed(seed) # set random seed before sampling validation set
+ rand_indexes = torch.randperm(len(dataset), generator=g).tolist()
+
+ dropped_indexes = rand_indexes[:drop_size]
+ remaining_indexes = rand_indexes[drop_size:]
+
+ dropped_dataset = copy.deepcopy(dataset)
+ for key in keys:
+ setattr(
+ dropped_dataset,
+ key,
+ [getattr(dropped_dataset, key)[idx] for idx in dropped_indexes],
+ )
+ setattr(dataset, key, [getattr(dataset, key)[idx] for idx in remaining_indexes])
+ return dataset, dropped_dataset
+
+
+class DataProvider:
+ data_keys = ("samples",)
+ mean_std = {"mean": [0.485, 0.456, 0.406], "std": [0.229, 0.224, 0.225]}
+ SUB_SEED = 937162211 # random seed for sampling subset
+ VALID_SEED = 2147483647 # random seed for the validation set
+
+ name: str
+
+ def __init__(
+ self,
+ train_batch_size: int,
+ test_batch_size: int or None,
+ valid_size: int or float or None,
+ n_worker: int,
+ image_size: int or list[int] or str or list[str],
+ num_replicas: int or None = None,
+ rank: int or None = None,
+ train_ratio: float or None = None,
+ drop_last: bool = False,
+ ):
+ warnings.filterwarnings("ignore")
+ super().__init__()
+
+ # batch_size & valid_size
+ self.train_batch_size = train_batch_size
+ self.test_batch_size = test_batch_size or self.train_batch_size
+ self.valid_size = valid_size
+
+ # image size
+ if isinstance(image_size, list):
+ self.image_size = [parse_image_size(size) for size in image_size]
+            self.image_size.sort()  # ascending, e.g., (160, 160) ... (224, 224); the largest size becomes the active size
+ RRSController.IMAGE_SIZE_LIST = copy.deepcopy(self.image_size)
+ self.active_image_size = RRSController.ACTIVE_SIZE = self.image_size[-1]
+ else:
+ self.image_size = parse_image_size(image_size)
+ RRSController.IMAGE_SIZE_LIST = [self.image_size]
+ self.active_image_size = RRSController.ACTIVE_SIZE = self.image_size
+
+ # distributed configs
+ self.num_replicas = num_replicas
+ self.rank = rank
+
+ # build datasets
+ train_dataset, val_dataset, test_dataset = self.build_datasets()
+
+ if train_ratio is not None and train_ratio < 1.0:
+ assert 0 < train_ratio < 1
+ _, train_dataset = random_drop_data(
+ train_dataset,
+ int(train_ratio * len(train_dataset)),
+ self.SUB_SEED,
+ self.data_keys,
+ )
+
+ # build data loader
+ self.train = self.build_dataloader(
+ train_dataset, train_batch_size, n_worker, drop_last=drop_last, train=True
+ )
+ self.valid = self.build_dataloader(
+ val_dataset, test_batch_size, n_worker, drop_last=False, train=False
+ )
+ self.test = self.build_dataloader(
+ test_dataset, test_batch_size, n_worker, drop_last=False, train=False
+ )
+ if self.valid is None:
+ self.valid = self.test
+ self.sub_train = None
+
+ @property
+ def data_shape(self) -> tuple[int, ...]:
+ return 3, self.active_image_size[0], self.active_image_size[1]
+
+ def build_valid_transform(self, image_size: tuple[int, int] or None = None) -> any:
+ raise NotImplementedError
+
+ def build_train_transform(self, image_size: tuple[int, int] or None = None) -> any:
+ raise NotImplementedError
+
+ def build_datasets(self) -> tuple[any, any, any]:
+ raise NotImplementedError
+
+ def build_dataloader(
+ self,
+ dataset: any or None,
+ batch_size: int,
+ n_worker: int,
+ drop_last: bool,
+ train: bool,
+ ):
+ if dataset is None:
+ return None
+ if isinstance(self.image_size, list) and train:
+ from efficientvit.apps.data_provider.random_resolution._data_loader import \
+ RRSDataLoader
+
+ dataloader_class = RRSDataLoader
+ else:
+ dataloader_class = torch.utils.data.DataLoader
+ if self.num_replicas is None:
+ return dataloader_class(
+ dataset=dataset,
+ batch_size=batch_size,
+ shuffle=True,
+ num_workers=n_worker,
+ pin_memory=True,
+ drop_last=drop_last,
+ )
+ else:
+ sampler = DistributedSampler(dataset, self.num_replicas, self.rank)
+ return dataloader_class(
+ dataset=dataset,
+ batch_size=batch_size,
+ sampler=sampler,
+ num_workers=n_worker,
+ pin_memory=True,
+ drop_last=drop_last,
+ )
+
+ def set_epoch(self, epoch: int) -> None:
+ RRSController.set_epoch(epoch, len(self.train))
+ if isinstance(self.train.sampler, DistributedSampler):
+ self.train.sampler.set_epoch(epoch)
+
+ def assign_active_image_size(self, new_size: int or tuple[int, int]) -> None:
+ self.active_image_size = val2tuple(new_size, 2)
+ new_transform = self.build_valid_transform(self.active_image_size)
+ # change the transform of the valid and test set
+ self.valid.dataset.transform = self.test.dataset.transform = new_transform
+
+ def sample_val_dataset(self, train_dataset, valid_transform) -> tuple[any, any]:
+ if self.valid_size is not None:
+ if 0 < self.valid_size < 1:
+ valid_size = int(self.valid_size * len(train_dataset))
+ else:
+ assert self.valid_size >= 1
+ valid_size = int(self.valid_size)
+ train_dataset, val_dataset = random_drop_data(
+ train_dataset,
+ valid_size,
+ self.VALID_SEED,
+ self.data_keys,
+ )
+ val_dataset.transform = valid_transform
+ else:
+ val_dataset = None
+ return train_dataset, val_dataset
+
+ def build_sub_train_loader(self, n_samples: int, batch_size: int) -> any:
+ # used for resetting BN running statistics
+ if self.sub_train is None:
+ self.sub_train = {}
+ if self.active_image_size in self.sub_train:
+ return self.sub_train[self.active_image_size]
+
+ # construct dataset and dataloader
+ train_dataset = copy.deepcopy(self.train.dataset)
+ if n_samples < len(train_dataset):
+ _, train_dataset = random_drop_data(
+ train_dataset,
+ n_samples,
+ self.SUB_SEED,
+ self.data_keys,
+ )
+ RRSController.ACTIVE_SIZE = self.active_image_size
+ train_dataset.transform = self.build_train_transform(
+ image_size=self.active_image_size
+ )
+ data_loader = self.build_dataloader(
+ train_dataset, batch_size, self.train.num_workers, True, False
+ )
+
+ # pre-fetch data
+ self.sub_train[self.active_image_size] = [
+ data
+ for data in data_loader
+ for _ in range(max(1, n_samples // len(train_dataset)))
+ ]
+
+ return self.sub_train[self.active_image_size]
diff --git a/yolo-world-with-efficientvit-sam/efficientvit/apps/data_provider/random_resolution/__init__.py b/yolo-world-with-efficientvit-sam/efficientvit/apps/data_provider/random_resolution/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b831fa9d3e933e76cf78120947143e8a19133ea2
--- /dev/null
+++ b/yolo-world-with-efficientvit-sam/efficientvit/apps/data_provider/random_resolution/__init__.py
@@ -0,0 +1,7 @@
+"""Random resolution data loader compatible with multi-processing and distributed training.
+
+Replace PyTorch's DataLoader with RRSDataLoader to support random input resolutions
+at training time; resolution sampling is controlled by RRSController.
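+
+A minimal usage sketch (an illustrative assumption, not project code; `dataset` is any
+map-style dataset and `num_epochs` is user-defined):
+
+    from efficientvit.apps.data_provider.random_resolution import RRSController
+    from efficientvit.apps.data_provider.random_resolution._data_loader import RRSDataLoader
+
+    loader = RRSDataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)
+    for epoch in range(num_epochs):
+        RRSController.set_epoch(epoch, len(loader))  # re-sample the resolution schedule for this epoch
+        for batch in loader:
+            ...  # each batch is produced at the resolution chosen for that step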
+"""
+
+from .controller import *
diff --git a/yolo-world-with-efficientvit-sam/efficientvit/apps/data_provider/random_resolution/__pycache__/__init__.cpython-311.pyc b/yolo-world-with-efficientvit-sam/efficientvit/apps/data_provider/random_resolution/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3ed312e1d7eaf5a4a4e8d68be94c80ddacd7f614
Binary files /dev/null and b/yolo-world-with-efficientvit-sam/efficientvit/apps/data_provider/random_resolution/__pycache__/__init__.cpython-311.pyc differ
diff --git a/yolo-world-with-efficientvit-sam/efficientvit/apps/data_provider/random_resolution/__pycache__/controller.cpython-311.pyc b/yolo-world-with-efficientvit-sam/efficientvit/apps/data_provider/random_resolution/__pycache__/controller.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..137869dc7b3eb91812d34b1240ac62f225e84836
Binary files /dev/null and b/yolo-world-with-efficientvit-sam/efficientvit/apps/data_provider/random_resolution/__pycache__/controller.cpython-311.pyc differ
diff --git a/yolo-world-with-efficientvit-sam/efficientvit/apps/data_provider/random_resolution/_data_loader.py b/yolo-world-with-efficientvit-sam/efficientvit/apps/data_provider/random_resolution/_data_loader.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b92ecfa012f0c1804e60fc126d5d3b8ee404efd
--- /dev/null
+++ b/yolo-world-with-efficientvit-sam/efficientvit/apps/data_provider/random_resolution/_data_loader.py
@@ -0,0 +1,1603 @@
+r"""This file is based on torch/utils/data/data_loader.py
+
+Definition of the DataLoader and associated iterators that subclass _BaseDataLoaderIter
+
+To support these two classes, in `./_utils` we define many utility methods and
+functions to be run in multiprocessing. E.g., the data loading worker loop is
+in `./_utils/worker.py`.
+"""
+
+import functools
+import itertools
+import logging
+import multiprocessing as python_multiprocessing
+import os
+import queue
+import threading
+import warnings
+from typing import (Any, Callable, Generic, Iterable, List, Optional, Sequence,
+ TypeVar, Union)
+
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as multiprocessing
+import torch.utils.data.graph_settings
+from torch._utils import ExceptionWrapper
+from torch.utils.data import (BatchSampler, Dataset, IterableDataset,
+ IterDataPipe, MapDataPipe, RandomSampler,
+ Sampler, SequentialSampler, _utils)
+from torch.utils.data.datapipes.datapipe import (
+ _IterDataPipeSerializationWrapper, _MapDataPipeSerializationWrapper)
+
+from ._data_worker import _worker_loop
+
+__all__ = ["RRSDataLoader"]
+
+T_co = TypeVar("T_co", covariant=True)
+T = TypeVar("T")
+_worker_init_fn_t = Callable[[int], None]
+
+# Ideally we would parameterize `DataLoader` by the return type of `collate_fn`, but there is currently no way to have that
+# type parameter set to a default value if the user doesn't pass in a custom 'collate_fn'.
+# See https://github.com/python/mypy/issues/3737.
+_collate_fn_t = Callable[[List[T]], Any]
+
+
+# These functions used to be defined in this file. However, it was moved to
+# _utils/collate.py. Although it is rather hard to access this from user land
+# (one has to explicitly directly `import torch.utils.data.dataloader`), there
+# probably is user code out there using it. This aliasing maintains BC in this
+# aspect.
+default_collate: _collate_fn_t = _utils.collate.default_collate
+default_convert = _utils.collate.default_convert
+
+get_worker_info = _utils.worker.get_worker_info
+
+logger = logging.getLogger(__name__)
+
+
+class _DatasetKind:
+ Map = 0
+ Iterable = 1
+
+ @staticmethod
+ def create_fetcher(kind, dataset, auto_collation, collate_fn, drop_last):
+ if kind == _DatasetKind.Map:
+ return _utils.fetch._MapDatasetFetcher(
+ dataset, auto_collation, collate_fn, drop_last
+ )
+ else:
+ return _utils.fetch._IterableDatasetFetcher(
+ dataset, auto_collation, collate_fn, drop_last
+ )
+
+
+class _InfiniteConstantSampler(Sampler):
+ r"""Analogous to ``itertools.repeat(None, None)``.
+ Used as sampler for :class:`~torch.utils.data.IterableDataset`.
+
+ Args:
+ data_source (Dataset): dataset to sample from
+ """
+
+ def __init__(self):
+ super().__init__(None)
+
+ def __iter__(self):
+ while True:
+ yield None
+
+
+def _get_distributed_settings():
+ if dist.is_available() and dist.is_initialized():
+ return dist.get_world_size(), dist.get_rank()
+ else:
+ return 1, 0
+
+
+def _sharding_worker_init_fn(worker_init_fn, world_size, rank_id, worker_id):
+ global_worker_id = worker_id
+ info = torch.utils.data.get_worker_info()
+ assert info is not None
+ total_workers = info.num_workers
+ datapipe = info.dataset
+ assert isinstance(datapipe, (IterDataPipe, MapDataPipe))
+    # To distribute elements across distributed processes evenly, we should shard data on distributed
+ # processes first then shard on worker processes
+ total_workers *= world_size
+ global_worker_id = global_worker_id * world_size + rank_id
+ # For BC, use default SHARDING_PRIORITIES
+ torch.utils.data.graph_settings.apply_sharding(
+ datapipe, total_workers, global_worker_id
+ )
+ if worker_init_fn is not None:
+ worker_init_fn(worker_id)
+
+
+def _share_dist_seed(generator, pg):
+ _shared_seed = torch.empty((), dtype=torch.int64).random_(generator=generator)
+ if isinstance(pg, dist.ProcessGroup):
+ dist.broadcast(_shared_seed, src=0, group=pg)
+ return _shared_seed.item()
+
+
+class RRSDataLoader(Generic[T_co]):
+ r"""
+ Data loader. Combines a dataset and a sampler, and provides an iterable over
+ the given dataset.
+
+ The :class:`~torch.utils.data.DataLoader` supports both map-style and
+ iterable-style datasets with single- or multi-process loading, customizing
+ loading order and optional automatic batching (collation) and memory pinning.
+
+ See :py:mod:`torch.utils.data` documentation page for more details.
+
+ Args:
+ dataset (Dataset): dataset from which to load the data.
+ batch_size (int, optional): how many samples per batch to load
+ (default: ``1``).
+ shuffle (bool, optional): set to ``True`` to have the data reshuffled
+ at every epoch (default: ``False``).
+ sampler (Sampler or Iterable, optional): defines the strategy to draw
+ samples from the dataset. Can be any ``Iterable`` with ``__len__``
+ implemented. If specified, :attr:`shuffle` must not be specified.
+ batch_sampler (Sampler or Iterable, optional): like :attr:`sampler`, but
+ returns a batch of indices at a time. Mutually exclusive with
+ :attr:`batch_size`, :attr:`shuffle`, :attr:`sampler`,
+ and :attr:`drop_last`.
+ num_workers (int, optional): how many subprocesses to use for data
+ loading. ``0`` means that the data will be loaded in the main process.
+ (default: ``0``)
+ collate_fn (Callable, optional): merges a list of samples to form a
+ mini-batch of Tensor(s). Used when using batched loading from a
+ map-style dataset.
+ pin_memory (bool, optional): If ``True``, the data loader will copy Tensors
+ into device/CUDA pinned memory before returning them. If your data elements
+ are a custom type, or your :attr:`collate_fn` returns a batch that is a custom type,
+ see the example below.
+ drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
+ if the dataset size is not divisible by the batch size. If ``False`` and
+ the size of dataset is not divisible by the batch size, then the last batch
+ will be smaller. (default: ``False``)
+ timeout (numeric, optional): if positive, the timeout value for collecting a batch
+ from workers. Should always be non-negative. (default: ``0``)
+ worker_init_fn (Callable, optional): If not ``None``, this will be called on each
+ worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as
+ input, after seeding and before data loading. (default: ``None``)
+ generator (torch.Generator, optional): If not ``None``, this RNG will be used
+ by RandomSampler to generate random indexes and multiprocessing to generate
+ `base_seed` for workers. (default: ``None``)
+ prefetch_factor (int, optional, keyword-only arg): Number of batches loaded
+ in advance by each worker. ``2`` means there will be a total of
+ 2 * num_workers batches prefetched across all workers. (default value depends
+ on the set value for num_workers. If value of num_workers=0 default is ``None``.
+ Otherwise if value of num_workers>0 default is ``2``).
+        persistent_workers (bool, optional): If ``True``, the data loader will not shut down
+            the worker processes after a dataset has been consumed once. This keeps
+            the workers' `Dataset` instances alive. (default: ``False``)
+ pin_memory_device (str, optional): the data loader will copy Tensors
+ into device pinned memory before returning them if pin_memory is set to true.
+
+
+ .. warning:: If the ``spawn`` start method is used, :attr:`worker_init_fn`
+ cannot be an unpicklable object, e.g., a lambda function. See
+ :ref:`multiprocessing-best-practices` on more details related
+ to multiprocessing in PyTorch.
+
+ .. warning:: ``len(dataloader)`` heuristic is based on the length of the sampler used.
+ When :attr:`dataset` is an :class:`~torch.utils.data.IterableDataset`,
+ it instead returns an estimate based on ``len(dataset) / batch_size``, with proper
+ rounding depending on :attr:`drop_last`, regardless of multi-process loading
+ configurations. This represents the best guess PyTorch can make because PyTorch
+ trusts user :attr:`dataset` code in correctly handling multi-process
+ loading to avoid duplicate data.
+
+ However, if sharding results in multiple workers having incomplete last batches,
+ this estimate can still be inaccurate, because (1) an otherwise complete batch can
+ be broken into multiple ones and (2) more than one batch worth of samples can be
+ dropped when :attr:`drop_last` is set. Unfortunately, PyTorch can not detect such
+ cases in general.
+
+ See `Dataset Types`_ for more details on these two types of datasets and how
+ :class:`~torch.utils.data.IterableDataset` interacts with
+ `Multi-process data loading`_.
+
+ .. warning:: See :ref:`reproducibility`, and :ref:`dataloader-workers-random-seed`, and
+ :ref:`data-loading-randomness` notes for random seed related questions.
+ """
+
+ dataset: Dataset[T_co]
+ batch_size: Optional[int]
+ num_workers: int
+ pin_memory: bool
+ drop_last: bool
+ timeout: float
+ sampler: Union[Sampler, Iterable]
+ pin_memory_device: str
+ prefetch_factor: Optional[int]
+ _iterator: Optional["_BaseDataLoaderIter"]
+ __initialized = False
+
+ def __init__(
+ self,
+ dataset: Dataset[T_co],
+ batch_size: Optional[int] = 1,
+ shuffle: Optional[bool] = None,
+ sampler: Union[Sampler, Iterable, None] = None,
+ batch_sampler: Union[Sampler[Sequence], Iterable[Sequence], None] = None,
+ num_workers: int = 0,
+ collate_fn: Optional[_collate_fn_t] = None,
+ pin_memory: bool = False,
+ drop_last: bool = False,
+ timeout: float = 0,
+ worker_init_fn: Optional[_worker_init_fn_t] = None,
+ multiprocessing_context=None,
+ generator=None,
+ *,
+ prefetch_factor: Optional[int] = None,
+ persistent_workers: bool = False,
+ pin_memory_device: str = ""
+ ):
+ torch._C._log_api_usage_once("python.data_loader")
+
+ if num_workers < 0:
+ raise ValueError(
+ "num_workers option should be non-negative; "
+ "use num_workers=0 to disable multiprocessing."
+ )
+
+ if timeout < 0:
+ raise ValueError("timeout option should be non-negative")
+
+ if num_workers == 0 and prefetch_factor is not None:
+ raise ValueError(
+ "prefetch_factor option could only be specified in multiprocessing."
+ "let num_workers > 0 to enable multiprocessing, otherwise set prefetch_factor to None."
+ )
+ elif num_workers > 0 and prefetch_factor is None:
+ prefetch_factor = 2
+ elif prefetch_factor is not None and prefetch_factor < 0:
+ raise ValueError("prefetch_factor option should be non-negative")
+
+ if persistent_workers and num_workers == 0:
+ raise ValueError("persistent_workers option needs num_workers > 0")
+
+ self.dataset = dataset
+ self.num_workers = num_workers
+ self.prefetch_factor = prefetch_factor
+ self.pin_memory = pin_memory
+ self.pin_memory_device = pin_memory_device
+ self.timeout = timeout
+ self.worker_init_fn = worker_init_fn
+ self.multiprocessing_context = multiprocessing_context
+
+ # Adds forward compatibilities so classic DataLoader can work with DataPipes:
+ # _DataPipeSerializationWrapper container makes it easier to serialize without redefining pickler
+ if isinstance(self.dataset, IterDataPipe):
+ self.dataset = _IterDataPipeSerializationWrapper(self.dataset)
+ elif isinstance(self.dataset, MapDataPipe):
+ self.dataset = _MapDataPipeSerializationWrapper(self.dataset)
+
+ # Arg-check dataset related before checking samplers because we want to
+ # tell users that iterable-style datasets are incompatible with custom
+ # samplers first, so that they don't learn that this combo doesn't work
+ # after spending time fixing the custom sampler errors.
+ if isinstance(dataset, IterableDataset):
+ self._dataset_kind = _DatasetKind.Iterable
+ # NOTE [ Custom Samplers and IterableDataset ]
+ #
+ # `IterableDataset` does not support custom `batch_sampler` or
+ # `sampler` since the key is irrelevant (unless we support
+ # generator-style dataset one day...).
+ #
+ # For `sampler`, we always create a dummy sampler. This is an
+ # infinite sampler even when the dataset may have an implemented
+ # finite `__len__` because in multi-process data loading, naive
+ # settings will return duplicated data (which may be desired), and
+ # thus using a sampler with length matching that of dataset will
+            # cause data loss (you may have duplicates of the first couple of
+            # batches, but never see anything afterwards). Therefore,
+            # `IterableDataset` always uses an infinite sampler, an instance of
+ # `_InfiniteConstantSampler` defined above.
+ #
+ # A custom `batch_sampler` essentially only controls the batch size.
+ # However, it is unclear how useful it would be since an iterable-style
+ # dataset can handle that within itself. Moreover, it is pointless
+ # in multi-process data loading as the assignment order of batches
+ # to workers is an implementation detail so users can not control
+ # how to batchify each worker's iterable. Thus, we disable this
+ # option. If this turns out to be useful in future, we can re-enable
+ # this, and support custom samplers that specify the assignments to
+ # specific workers.
+ if isinstance(dataset, IterDataPipe):
+ if shuffle is not None:
+ dataset = torch.utils.data.graph_settings.apply_shuffle_settings(
+ dataset, shuffle=shuffle
+ )
+ # We cannot check `shuffle is not None` here, since previously `shuffle=False` was the default.
+ elif shuffle not in {False, None}:
+ raise ValueError(
+ "DataLoader with IterableDataset: expected unspecified "
+ "shuffle option, but got shuffle={}".format(shuffle)
+ )
+
+ if sampler is not None:
+ # See NOTE [ Custom Samplers and IterableDataset ]
+ raise ValueError(
+ "DataLoader with IterableDataset: expected unspecified "
+ "sampler option, but got sampler={}".format(sampler)
+ )
+ elif batch_sampler is not None:
+ # See NOTE [ Custom Samplers and IterableDataset ]
+ raise ValueError(
+ "DataLoader with IterableDataset: expected unspecified "
+ "batch_sampler option, but got batch_sampler={}".format(
+ batch_sampler
+ )
+ )
+ else:
+ shuffle = bool(shuffle)
+ self._dataset_kind = _DatasetKind.Map
+
+ if sampler is not None and shuffle:
+ raise ValueError("sampler option is mutually exclusive with " "shuffle")
+
+ if batch_sampler is not None:
+ # auto_collation with custom batch_sampler
+ if batch_size != 1 or shuffle or sampler is not None or drop_last:
+ raise ValueError(
+ "batch_sampler option is mutually exclusive "
+ "with batch_size, shuffle, sampler, and "
+ "drop_last"
+ )
+ batch_size = None
+ drop_last = False
+ elif batch_size is None:
+ # no auto_collation
+ if drop_last:
+ raise ValueError(
+ "batch_size=None option disables auto-batching "
+ "and is mutually exclusive with drop_last"
+ )
+
+ if sampler is None: # give default samplers
+ if self._dataset_kind == _DatasetKind.Iterable:
+ # See NOTE [ Custom Samplers and IterableDataset ]
+ sampler = _InfiniteConstantSampler()
+ else: # map-style
+ if shuffle:
+ sampler = RandomSampler(dataset, generator=generator) # type: ignore[arg-type]
+ else:
+ sampler = SequentialSampler(dataset) # type: ignore[arg-type]
+
+ if batch_size is not None and batch_sampler is None:
+ # auto_collation without custom batch_sampler
+ batch_sampler = BatchSampler(sampler, batch_size, drop_last)
+
+ self.batch_size = batch_size
+ self.drop_last = drop_last
+ self.sampler = sampler
+ self.batch_sampler = batch_sampler
+ self.generator = generator
+
+ if collate_fn is None:
+ if self._auto_collation:
+ collate_fn = _utils.collate.default_collate
+ else:
+ collate_fn = _utils.collate.default_convert
+
+ self.collate_fn = collate_fn
+ self.persistent_workers = persistent_workers
+
+ self.__initialized = True
+ self._IterableDataset_len_called = (
+ None # See NOTE [ IterableDataset and __len__ ]
+ )
+
+ self._iterator = None
+
+ self.check_worker_number_rationality()
+
+ torch.set_vital("Dataloader", "enabled", "True") # type: ignore[attr-defined]
+
+ def _get_iterator(self) -> "_BaseDataLoaderIter":
+ if self.num_workers == 0:
+ return _SingleProcessDataLoaderIter(self)
+ else:
+ self.check_worker_number_rationality()
+ return _MultiProcessingDataLoaderIter(self)
+
+ @property
+ def multiprocessing_context(self):
+ return self.__multiprocessing_context
+
+ @multiprocessing_context.setter
+ def multiprocessing_context(self, multiprocessing_context):
+ if multiprocessing_context is not None:
+ if self.num_workers > 0:
+ if isinstance(multiprocessing_context, str):
+ valid_start_methods = multiprocessing.get_all_start_methods()
+ if multiprocessing_context not in valid_start_methods:
+ raise ValueError(
+ (
+ "multiprocessing_context option "
+ "should specify a valid start method in {!r}, but got "
+ "multiprocessing_context={!r}"
+ ).format(valid_start_methods, multiprocessing_context)
+ )
+ multiprocessing_context = multiprocessing.get_context(
+ multiprocessing_context
+ )
+
+ if not isinstance(
+ multiprocessing_context, python_multiprocessing.context.BaseContext
+ ):
+ raise TypeError(
+ (
+ "multiprocessing_context option should be a valid context "
+ "object or a string specifying the start method, but got "
+ "multiprocessing_context={}"
+ ).format(multiprocessing_context)
+ )
+ else:
+ raise ValueError(
+ (
+ "multiprocessing_context can only be used with "
+ "multi-process loading (num_workers > 0), but got "
+ "num_workers={}"
+ ).format(self.num_workers)
+ )
+
+ self.__multiprocessing_context = multiprocessing_context
+
+ def __setattr__(self, attr, val):
+ if self.__initialized and attr in (
+ "batch_size",
+ "batch_sampler",
+ "sampler",
+ "drop_last",
+ "dataset",
+ "persistent_workers",
+ ):
+ raise ValueError(
+ "{} attribute should not be set after {} is "
+ "initialized".format(attr, self.__class__.__name__)
+ )
+
+ super().__setattr__(attr, val)
+
+ # We quote '_BaseDataLoaderIter' since it isn't defined yet and the definition can't be moved up
+ # since '_BaseDataLoaderIter' references 'DataLoader'.
+ def __iter__(self) -> "_BaseDataLoaderIter":
+        # When using a single worker the returned iterator should be
+        # created every time to avoid resetting its state.
+        # However, in the case of a multi-worker iterator,
+        # the iterator is only created once in the lifetime of the
+        # DataLoader object so that the workers can be reused.
+ if self.persistent_workers and self.num_workers > 0:
+ if self._iterator is None:
+ self._iterator = self._get_iterator()
+ else:
+ self._iterator._reset(self)
+ return self._iterator
+ else:
+ return self._get_iterator()
+
+ @property
+ def _auto_collation(self):
+ return self.batch_sampler is not None
+
+ @property
+ def _index_sampler(self):
+ # The actual sampler used for generating indices for `_DatasetFetcher`
+ # (see _utils/fetch.py) to read data at each time. This would be
+ # `.batch_sampler` if in auto-collation mode, and `.sampler` otherwise.
+ # We can't change `.sampler` and `.batch_sampler` attributes for BC
+ # reasons.
+ if self._auto_collation:
+ return self.batch_sampler
+ else:
+ return self.sampler
+
+ def __len__(self) -> int:
+ if self._dataset_kind == _DatasetKind.Iterable:
+ # NOTE [ IterableDataset and __len__ ]
+ #
+ # For `IterableDataset`, `__len__` could be inaccurate when one naively
+ # does multi-processing data loading, since the samples will be duplicated.
+ # However, no real use case should be actually using that behavior, so
+ # it should count as a user error. We should generally trust user
+ # code to do the proper thing (e.g., configure each replica differently
+ # in `__iter__`), and give us the correct `__len__` if they choose to
+ # implement it (this will still throw if the dataset does not implement
+ # a `__len__`).
+ #
+ # To provide a further warning, we track if `__len__` was called on the
+ # `DataLoader`, save the returned value in `self._len_called`, and warn
+ # if the iterator ends up yielding more than this number of samples.
+
+ # Cannot statically verify that dataset is Sized
+ length = self._IterableDataset_len_called = len(self.dataset) # type: ignore[assignment, arg-type]
+            # IterableDataset doesn't allow a custom sampler or batch_sampler
+            if self.batch_size is not None:
+ from math import ceil
+
+ if self.drop_last:
+ length = length // self.batch_size
+ else:
+ length = ceil(length / self.batch_size)
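+            # (illustration only: a reported dataset length of 10 with
+            # batch_size == 4 gives 10 // 4 == 2 when drop_last=True and
+            # ceil(10 / 4) == 3 otherwise)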
+ return length
+ else:
+ return len(self._index_sampler)
+
+ def check_worker_number_rationality(self):
+        # This function checks whether the DataLoader's worker number is reasonable based on
+        # the current system's resources. The current rule is that if the number of workers this
+        # DataLoader will create is bigger than the number of logical cpus that it is allowed to
+        # use, then we pop up a warning to let the user pay attention.
+        #
+        # eg. If the current system has 2 physical CPUs with 16 cores each, and each core supports 2
+        #     threads, then the total number of logical cpus here is 2 * 16 * 2 = 64. Let's say the current
+        #     DataLoader process can use half of them, which is 32; then the reasonable max number of
+        #     workers that can be initiated from this process is 32.
+        #     Now, let's say the created DataLoader has num_workers = 40, which is bigger than 32.
+        #     So the warning message is triggered to notify the user to lower the worker number if
+        #     necessary.
+        #
+        #
+        # [Note] Please note that this function respects `cpuset` only when os.sched_getaffinity is
+        #        available (which it is on most Linux systems, but not on macOS or Windows).
+        #        When os.sched_getaffinity is not available, os.cpu_count() is called instead, but
+        #        it doesn't respect cpuset.
+        #        We don't take hyper-threading into account since each worker process is single-threaded
+        #        at this time.
+        #
+        #        We don't set any threading flags (eg. OMP_NUM_THREADS, MKL_NUM_THREADS, etc.)
+        #        other than `torch.set_num_threads` to 1 in the worker process. If the passed-in
+        #        functions use 3rd party modules that rely on those threading flags to determine
+        #        how many threads to create (eg. numpy, etc.), then it is the caller's responsibility to
+        #        set those flags correctly.
+ def _create_warning_msg(num_worker_suggest, num_worker_created, cpuset_checked):
+
+ suggested_max_worker_msg = (
+ (
+ (
+                        "Our suggested max number of workers in current system is {}{}, which is smaller "
+ "than what this DataLoader is going to create."
+ ).format(
+ num_worker_suggest,
+ (
+ ""
+ if cpuset_checked
+ else " (`cpuset` is not taken into account)"
+ ),
+ )
+ )
+ if num_worker_suggest is not None
+ else (
+                    "DataLoader is not able to compute a suggested max number of workers in current system."
+ )
+ )
+
+ warn_msg = (
+ "This DataLoader will create {} worker processes in total. {} "
+                "Please be aware that excessive worker creation might make the DataLoader run slowly or even freeze; "
+ "lower the worker number to avoid potential slowness/freeze if necessary."
+ ).format(num_worker_created, suggested_max_worker_msg)
+ return warn_msg
+
+ if not self.num_workers or self.num_workers == 0:
+ return
+
+ # try to compute a suggested max number of worker based on system's resource
+ max_num_worker_suggest = None
+ cpuset_checked = False
+ if hasattr(os, "sched_getaffinity"):
+ try:
+ max_num_worker_suggest = len(os.sched_getaffinity(0))
+ cpuset_checked = True
+ except Exception:
+ pass
+ if max_num_worker_suggest is None:
+ # os.cpu_count() could return Optional[int]
+            # get cpu count first and check None in order to satisfy mypy check
+ cpu_count = os.cpu_count()
+ if cpu_count is not None:
+ max_num_worker_suggest = cpu_count
+
+ if max_num_worker_suggest is None:
+ warnings.warn(
+ _create_warning_msg(
+ max_num_worker_suggest, self.num_workers, cpuset_checked
+ )
+ )
+ return
+
+ if self.num_workers > max_num_worker_suggest:
+ warnings.warn(
+ _create_warning_msg(
+ max_num_worker_suggest, self.num_workers, cpuset_checked
+ )
+ )
+
+
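+# A minimal usage sketch of RRSDataLoader (illustration only; `MyMapDataset` is
+# a hypothetical map-style dataset implementing `__getitem__` and `__len__`):
+#
+#     loader = RRSDataLoader(MyMapDataset(), batch_size=8, shuffle=True,
+#                            num_workers=4, pin_memory=True)
+#     for batch in loader:
+#         ...
+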
+class _BaseDataLoaderIter:
+ def __init__(self, loader: RRSDataLoader) -> None:
+ self._dataset = loader.dataset
+ self._shared_seed = None
+ self._pg = None
+ if isinstance(self._dataset, IterDataPipe):
+ if dist.is_available() and dist.is_initialized():
+ self._pg = dist.new_group(backend="gloo")
+ self._shared_seed = _share_dist_seed(loader.generator, self._pg)
+ shared_rng = torch.Generator()
+ shared_rng.manual_seed(self._shared_seed)
+ self._dataset = torch.utils.data.graph_settings.apply_random_seed(
+ self._dataset, shared_rng
+ )
+ self._dataset_kind = loader._dataset_kind
+ self._IterableDataset_len_called = loader._IterableDataset_len_called
+ self._auto_collation = loader._auto_collation
+ self._drop_last = loader.drop_last
+ self._index_sampler = loader._index_sampler
+ self._num_workers = loader.num_workers
+ ws, rank = _get_distributed_settings()
+ self._world_size = ws
+ self._rank = rank
+        # for other backends, pin_memory_device needs to be set. If it is not set,
+        # the default behaviour is to use the CUDA device. If pin_memory_device is set
+        # but pin_memory is not, the default behaviour is False.
+ if len(loader.pin_memory_device) == 0:
+ self._pin_memory = loader.pin_memory and torch.cuda.is_available()
+ self._pin_memory_device = None
+ else:
+ if not loader.pin_memory:
+ warn_msg = (
+                    "pin memory device is set and pin_memory flag is not used, so device pinned memory won't be used. "
+                    "Please set pin_memory to true if you need to use the device pin memory."
+ )
+ warnings.warn(warn_msg)
+
+ self._pin_memory = loader.pin_memory
+ self._pin_memory_device = loader.pin_memory_device
+ self._timeout = loader.timeout
+ self._collate_fn = loader.collate_fn
+ self._sampler_iter = iter(self._index_sampler)
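+        # _base_seed is a fresh 64-bit value drawn from the loader's generator;
+        # each worker later derives its own seed as `base_seed + worker_id`
+        # (see `_worker_loop` in the companion worker module).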
+ self._base_seed = (
+ torch.empty((), dtype=torch.int64)
+ .random_(generator=loader.generator)
+ .item()
+ )
+ self._persistent_workers = loader.persistent_workers
+ self._num_yielded = 0
+ self._profile_name = "enumerate(DataLoader)#{}.__next__".format(
+ self.__class__.__name__
+ )
+
+ def __iter__(self) -> "_BaseDataLoaderIter":
+ return self
+
+ def _reset(self, loader, first_iter=False):
+ self._sampler_iter = iter(self._index_sampler)
+ self._num_yielded = 0
+ self._IterableDataset_len_called = loader._IterableDataset_len_called
+ if isinstance(self._dataset, IterDataPipe):
+ self._shared_seed = _share_dist_seed(loader.generator, self._pg)
+ shared_rng = torch.Generator()
+ shared_rng.manual_seed(self._shared_seed)
+ self._dataset = torch.utils.data.graph_settings.apply_random_seed(
+ self._dataset, shared_rng
+ )
+
+ def _next_index(self):
+ return next(self._sampler_iter) # may raise StopIteration
+
+ def _next_data(self):
+ raise NotImplementedError
+
+ def __next__(self) -> Any:
+ with torch.autograd.profiler.record_function(self._profile_name):
+ if self._sampler_iter is None:
+ # TODO(https://github.com/pytorch/pytorch/issues/76750)
+ self._reset() # type: ignore[call-arg]
+ data = self._next_data()
+ self._num_yielded += 1
+ if (
+ self._dataset_kind == _DatasetKind.Iterable
+ and self._IterableDataset_len_called is not None
+ and self._num_yielded > self._IterableDataset_len_called
+ ):
+ warn_msg = (
+ "Length of IterableDataset {} was reported to be {} (when accessing len(dataloader)), but {} "
+ "samples have been fetched. "
+ ).format(
+ self._dataset, self._IterableDataset_len_called, self._num_yielded
+ )
+ if self._num_workers > 0:
+ warn_msg += (
+ "For multiprocessing data-loading, this could be caused by not properly configuring the "
+ "IterableDataset replica at each worker. Please see "
+ "https://pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset for examples."
+ )
+ warnings.warn(warn_msg)
+ return data
+
+ def __len__(self) -> int:
+ return len(self._index_sampler)
+
+ def __getstate__(self):
+ # TODO: add limited pickling support for sharing an iterator
+ # across multiple threads for HOGWILD.
+ # Probably the best way to do this is by moving the sample pushing
+ # to a separate thread and then just sharing the data queue
+ # but signalling the end is tricky without a non-blocking API
+        raise NotImplementedError(
+            "{} cannot be pickled".format(self.__class__.__name__)
+        )
+
+
+class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
+ def __init__(self, loader):
+ super().__init__(loader)
+ assert self._timeout == 0
+ assert self._num_workers == 0
+
+ # Adds forward compatibilities so classic DataLoader can work with DataPipes:
+ # Taking care of distributed sharding
+ if isinstance(self._dataset, (IterDataPipe, MapDataPipe)):
+ # For BC, use default SHARDING_PRIORITIES
+ torch.utils.data.graph_settings.apply_sharding(
+ self._dataset, self._world_size, self._rank
+ )
+
+ self._dataset_fetcher = _DatasetKind.create_fetcher(
+ self._dataset_kind,
+ self._dataset,
+ self._auto_collation,
+ self._collate_fn,
+ self._drop_last,
+ )
+
+ def _next_data(self):
+ index = self._next_index() # may raise StopIteration
+ data = self._dataset_fetcher.fetch(index) # may raise StopIteration
+ if self._pin_memory:
+ data = _utils.pin_memory.pin_memory(data, self._pin_memory_device)
+ return data
+
+
+class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
+ r"""Iterates once over the DataLoader's dataset, as specified by the sampler"""
+
+ # NOTE [ Data Loader Multiprocessing Shutdown Logic ]
+ #
+ # Preliminary:
+ #
+ # Our data model looks like this (queues are indicated with curly brackets):
+ #
+ # main process ||
+ # | ||
+ # {index_queue} ||
+ # | ||
+ # worker processes || DATA
+ # | ||
+ # {worker_result_queue} || FLOW
+ # | ||
+ # pin_memory_thread of main process || DIRECTION
+ # | ||
+ # {data_queue} ||
+ # | ||
+ # data output \/
+ #
+ # P.S. `worker_result_queue` and `pin_memory_thread` part may be omitted if
+ # `pin_memory=False`.
+ #
+ #
+ # Terminating multiprocessing logic requires very careful design. In
+ # particular, we need to make sure that
+ #
+ # 1. The iterator gracefully exits the workers when its last reference is
+ # gone or it is depleted.
+ #
+ # In this case, the workers should be gracefully exited because the
+ # main process may still need to continue to run, and we want cleaning
+ # up code in the workers to be executed (e.g., releasing GPU memory).
+ # Naturally, we implement the shutdown logic in `__del__` of
+ # DataLoaderIterator.
+ #
+ # We delay the discussion on the logic in this case until later.
+ #
+ # 2. The iterator exits the workers when the loader process and/or worker
+ # processes exits normally or with error.
+ #
+ # We set all workers and `pin_memory_thread` to have `daemon=True`.
+ #
+ # You may ask, why can't we make the workers non-daemonic, and
+ # gracefully exit using the same logic as we have in `__del__` when the
+ # iterator gets deleted (see 1 above)?
+ #
+ # First of all, `__del__` is **not** guaranteed to be called when
+ # interpreter exits. Even if it is called, by the time it executes,
+    #      many Python core library resources may already be freed, and even
+ # simple things like acquiring an internal lock of a queue may hang.
+ # Therefore, in this case, we actually need to prevent `__del__` from
+ # being executed, and rely on the automatic termination of daemonic
+ # children.
+ #
+ # Thus, we register an `atexit` hook that sets a global flag
+ # `_utils.python_exit_status`. Since `atexit` hooks are executed in the
+ # reverse order of registration, we are guaranteed that this flag is
+ # set before library resources we use are freed (which, at least in
+ # CPython, is done via an `atexit` handler defined in
+ # `multiprocessing/util.py`
+ # https://github.com/python/cpython/blob/c606624af8d4cb3b4a052fb263bb983b3f87585b/Lib/multiprocessing/util.py#L320-L362
+ # registered when an object requiring this mechanism is first
+ # created, e.g., `mp.Queue`
+ # https://github.com/python/cpython/blob/c606624af8d4cb3b4a052fb263bb983b3f87585b/Lib/multiprocessing/context.py#L100-L103
+ # https://github.com/python/cpython/blob/c606624af8d4cb3b4a052fb263bb983b3f87585b/Lib/multiprocessing/queues.py#L29
+ # )
+ #
+ # So in `__del__`, we check if `_utils.python_exit_status` is set or
+ # `None` (freed), and perform no-op if so.
+ #
+ # However, simply letting library clean-up codes run can also be bad,
+ # because such codes (i.e., `multiprocessing.util._exit_function()`)
+ # include join putting threads for `mp.Queue`, which can be blocking.
+ # Hence, the main process putting threads are called with
+ # `cancel_join_thread` at creation. See later section
+ # [ 3b. A process won't hang when putting into a queue; ]
+ # for more details.
+ #
+ # Here are two example cases where library clean-up codes can run
+ # before `__del__` is called:
+ #
+ # 1. If we hold onto a reference to the iterator, it more often
+ # than not tries to do `multiprocessing` library cleaning before
+ # clearing the alive referenced objects (https://github.com/pytorch/pytorch/issues/48666)
+    #           and thus prevents our cleaning-up code from running first.
+ #
+    #        2. A similar issue arises when a `DataLoader` is used in a subprocess.
+    #           When a process ends, it shuts all its daemonic children
+    #           down with a SIGTERM (instead of joining them without a timeout).
+    #           Similarly for threads, but by a different mechanism. This fact,
+ # together with a few implementation details of multiprocessing, forces
+ # us to make workers daemonic. All of our problems arise when a
+ # DataLoader is used in a subprocess, and are caused by multiprocessing
+ # code which looks more or less like this:
+ #
+ # try:
+ # your_function_using_a_dataloader()
+ # finally:
+ # multiprocessing.util._exit_function()
+ #
+ # The joining/termination mentioned above happens inside
+ # `_exit_function()`. Now, if `your_function_using_a_dataloader()`
+ # throws, the stack trace stored in the exception will prevent the
+ # frame which uses `DataLoaderIter` to be freed. If the frame has any
+ # reference to the `DataLoaderIter` (e.g., in a method of the iter),
+ # its `__del__`, which starts the shutdown procedure, will not be
+ # called. That, in turn, means that workers aren't notified. Attempting
+ # to join in `_exit_function` will then result in a hang.
+ #
+ # For context, `_exit_function` is also registered as an `atexit` call.
+ # So it is unclear to me (@ssnl) why this is needed in a finally block.
+ # The code dates back to 2008 and there is no comment on the original
+ # PEP 371 or patch https://bugs.python.org/issue3050 (containing both
+ # the finally block and the `atexit` registration) that explains this.
+ #
+ #
+ # Finally, another choice is to just shutdown workers with logic in 1
+ # above whenever we see an error in `next`. This isn't ideal because
+ # a. It prevents users from using try-catch to resume data loading.
+ # b. It doesn't prevent hanging if users have references to the
+ # iterator.
+ #
+ # 3. All processes exit if any of them die unexpectedly by fatal signals.
+ #
+ # As shown above, the workers are set as daemonic children of the main
+ # process. However, automatic cleaning-up of such child processes only
+ # happens if the parent process exits gracefully (e.g., not via fatal
+ # signals like SIGKILL). So we must ensure that each process will exit
+    #      even if the process that should send/receive data to/from it was
+ # killed, i.e.,
+ #
+ # a. A process won't hang when getting from a queue.
+ #
+ # Even with carefully designed data dependencies (i.e., a `put()`
+ # always corresponding to a `get()`), hanging on `get()` can still
+ # happen when data in queue is corrupted (e.g., due to
+ # `cancel_join_thread` or unexpected exit).
+ #
+ # For child exit, we set a timeout whenever we try to get data
+ # from `data_queue`, and check the workers' status on each timeout
+ # and error.
+ # See `_DataLoaderiter._get_batch()` and
+ # `_DataLoaderiter._try_get_data()` for details.
+ #
+ # Additionally, for child exit on non-Windows platforms, we also
+    #           register a SIGCHLD handler (which is not supported on Windows) on
+ # the main process, which checks if any of the workers fail in the
+ # (Python) handler. This is more efficient and faster in detecting
+ # worker failures, compared to only using the above mechanism.
+ # See `DataLoader.cpp` and `_utils/signal_handling.py` for details.
+ #
+ # For `.get()` calls where the sender(s) is not the workers, we
+ # guard them with timeouts, and check the status of the sender
+ # when timeout happens:
+ # + in the workers, the `_utils.worker.ManagerWatchdog` class
+ # checks the status of the main process.
+ # + if `pin_memory=True`, when getting from `pin_memory_thread`,
+ # check `pin_memory_thread` status periodically until `.get()`
+ # returns or see that `pin_memory_thread` died.
+ #
+ # b. A process won't hang when putting into a queue;
+ #
+ # We use `mp.Queue` which has a separate background thread to put
+ # objects from an unbounded buffer array. The background thread is
+ # daemonic and usually automatically joined when the process
+ # *exits*.
+ #
+ # In case that the receiver has ended abruptly while
+ # reading from the pipe, the join will hang forever. The usual
+ # solution for this in Python is calling `q.cancel_join_thread`,
+ # which prevents automatically joining it when finalizing
+ # (exiting).
+ #
+ # Nonetheless, `cancel_join_thread` must only be called when the
+    #           queue is **not** going to be read from or written to by another
+ # process, because it may hold onto a lock or leave corrupted data
+ # in the queue, leading other readers/writers to hang.
+ #
+ # Hence,
+ # + For worker processes, we only do so (for their output
+ # queues, i.e., `worker_result_queue`) before exiting.
+ # + For `pin_memory_thread`, its output queue `data_queue` is a
+ # `queue.Queue` that does blocking `put` if the queue is full.
+ # So there is no above problem, but as a result, in
+ # `_pin_memory_loop`, we do need to wrap the `put` in a loop
+ # that breaks not only upon success, but also when the main
+ # process stops reading, i.e., is shutting down.
+ # + For loader process, we `cancel_join_thread()` for all
+ # `_index_queues` because the whole purpose of workers and
+ # `pin_memory_thread` is to serve the loader process. If
+ # loader process is already exiting, we don't really care if
+ # the queues are corrupted.
+ #
+ #
+ # Now let's get back to 1:
+ # how we gracefully exit the workers when the last reference to the
+ # iterator is gone.
+ #
+ # To achieve this, we implement the following logic along with the design
+ # choices mentioned above:
+ #
+ # `workers_done_event`:
+ # A `multiprocessing.Event` shared among the main process and all worker
+ # processes. This is used to signal the workers that the iterator is
+ # shutting down. After it is set, they will not send processed data to
+ # queues anymore, and only wait for the final `None` before exiting.
+ # `done_event` isn't strictly needed. I.e., we can just check for `None`
+ # from the input queue, but it allows us to skip wasting resources
+ # processing data if we are already shutting down.
+ #
+ # `pin_memory_thread_done_event`:
+ # A `threading.Event` for a similar purpose to that of
+ # `workers_done_event`, but is for the `pin_memory_thread`. The reason
+ # that separate events are needed is that `pin_memory_thread` reads from
+ # the output queue of the workers. But the workers, upon seeing that
+    #   `workers_done_event` is set, only want to see the final `None`, and are
+    #   not required to flush all data in the output queue (e.g., they may call
+    #   `cancel_join_thread` on that queue if their `IterableDataset` iterator
+ # happens to exhaust coincidentally, which is out of the control of the
+ # main process). Thus, since we will exit `pin_memory_thread` before the
+    #   workers (see below), two separate events are used.
+ #
+ # NOTE: In short, the protocol is that the main process will set these
+    #       `done_event`s and then send the corresponding processes/threads a `None`,
+ # and that they may exit at any time after receiving the `None`.
+ #
+ # NOTE: Using `None` as the final signal is valid, since normal data will
+ # always be a 2-tuple with the 1st element being the index of the data
+ # transferred (different from dataset index/key), and the 2nd being
+ # either the dataset key or the data sample (depending on which part
+ # of the data model the queue is at).
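+    #       For illustration only: a typical item on `worker_result_queue` /
+    #       `data_queue` looks like `(idx, batch_or_ExceptionWrapper)`, e.g.
+    #       `(3, tensor_batch)`, so a bare `None` cannot be mistaken for data.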
+ #
+ # [ worker processes ]
+ # While loader process is alive:
+ # Get from `index_queue`.
+ # If get anything else,
+ # Check `workers_done_event`.
+ # If set, continue to next iteration
+ # i.e., keep getting until see the `None`, then exit.
+ # Otherwise, process data:
+ # If is fetching from an `IterableDataset` and the iterator
+ # is exhausted, send an `_IterableDatasetStopIteration`
+ # object to signal iteration end. The main process, upon
+ # receiving such an object, will send `None` to this
+ # worker and not use the corresponding `index_queue`
+ # anymore.
+ # If timed out,
+    #       whether or not `workers_done_event` is set (we still need to see `None`),
+    #       we must continue to the next iteration.
+ # (outside loop)
+ # If `workers_done_event` is set, (this can be False with `IterableDataset`)
+ # `data_queue.cancel_join_thread()`. (Everything is ending here:
+ # main process won't read from it;
+ # other workers will also call
+ # `cancel_join_thread`.)
+ #
+ # [ pin_memory_thread ]
+ # # No need to check main thread. If this thread is alive, the main loader
+ # # thread must be alive, because this thread is set as daemonic.
+ # While `pin_memory_thread_done_event` is not set:
+    #     Get from `worker_result_queue`.
+ # If timed out, continue to get in the next iteration.
+ # Otherwise, process data.
+ # While `pin_memory_thread_done_event` is not set:
+ # Put processed data to `data_queue` (a `queue.Queue` with blocking put)
+ # If timed out, continue to put in the next iteration.
+    #         Otherwise, break, i.e., continuing to the outer loop.
+ #
+ # NOTE: we don't check the status of the main thread because
+ # 1. if the process is killed by fatal signal, `pin_memory_thread`
+ # ends.
+ # 2. in other cases, either the cleaning-up in __del__ or the
+ # automatic exit of daemonic thread will take care of it.
+ # This won't busy-wait either because `.get(timeout)` does not
+ # busy-wait.
+ #
+ # [ main process ]
+ # In the DataLoader Iter's `__del__`
+ # b. Exit `pin_memory_thread`
+ # i. Set `pin_memory_thread_done_event`.
+    #          ii.  Put `None` in `worker_result_queue`.
+ # iii. Join the `pin_memory_thread`.
+ # iv. `worker_result_queue.cancel_join_thread()`.
+ #
+ # c. Exit the workers.
+ # i. Set `workers_done_event`.
+ # ii. Put `None` in each worker's `index_queue`.
+ # iii. Join the workers.
+ # iv. Call `.cancel_join_thread()` on each worker's `index_queue`.
+ #
+ # NOTE: (c) is better placed after (b) because it may leave corrupted
+ # data in `worker_result_queue`, which `pin_memory_thread`
+    #              reads from, in which case the `pin_memory_thread` exit can
+    #              only happen after timing out, which is slow. Nonetheless, the same thing
+    #              happens if a worker is killed by a signal at an unfortunate time,
+ # but in other cases, we are better off having a non-corrupted
+ # `worker_result_queue` for `pin_memory_thread`.
+ #
+ # NOTE: If `pin_memory=False`, there is no `pin_memory_thread` and (b)
+ # can be omitted
+ #
+    # NB: The `done_event`s aren't strictly needed. E.g., we can just check for
+ # `None` from `index_queue`, but it allows us to skip wasting resources
+ # processing indices already in `index_queue` if we are already shutting
+ # down.
+
+ def __init__(self, loader):
+ super().__init__(loader)
+
+ self._prefetch_factor = loader.prefetch_factor
+
+ assert self._num_workers > 0
+ assert self._prefetch_factor > 0
+
+ if loader.multiprocessing_context is None:
+ multiprocessing_context = multiprocessing
+ else:
+ multiprocessing_context = loader.multiprocessing_context
+
+ self._worker_init_fn = loader.worker_init_fn
+
+ # Adds forward compatibilities so classic DataLoader can work with DataPipes:
+ # Additional worker init function will take care of sharding in MP and Distributed
+ if isinstance(self._dataset, (IterDataPipe, MapDataPipe)):
+ self._worker_init_fn = functools.partial(
+ _sharding_worker_init_fn,
+ self._worker_init_fn,
+ self._world_size,
+ self._rank,
+ )
+
+ # No certainty which module multiprocessing_context is
+ self._worker_result_queue = multiprocessing_context.Queue() # type: ignore[var-annotated]
+ self._worker_pids_set = False
+ self._shutdown = False
+ self._workers_done_event = multiprocessing_context.Event()
+
+ self._index_queues = []
+ self._workers = []
+ for i in range(self._num_workers):
+ # No certainty which module multiprocessing_context is
+ index_queue = multiprocessing_context.Queue() # type: ignore[var-annotated]
+ # Need to `cancel_join_thread` here!
+ # See sections (2) and (3b) above.
+ index_queue.cancel_join_thread()
+ w = multiprocessing_context.Process(
+ target=_worker_loop,
+ args=(
+ self._dataset_kind,
+ self._dataset,
+ index_queue,
+ self._worker_result_queue,
+ self._workers_done_event,
+ self._auto_collation,
+ self._collate_fn,
+ self._drop_last,
+ self._base_seed,
+ self._worker_init_fn,
+ i,
+ self._num_workers,
+ self._persistent_workers,
+ self._shared_seed,
+ ),
+ )
+ w.daemon = True
+            # NB: Process.start() actually takes some time as it needs to
+            #     start a process and pass the arguments over via a pipe.
+            #     Therefore, we only add a worker to the self._workers list after
+            #     it has started, so that we do not call .join() if the program dies
+            #     before it starts, and __del__ tries to join but will get:
+            #     AssertionError: can only join a started process.
+ w.start()
+ self._index_queues.append(index_queue)
+ self._workers.append(w)
+
+ if self._pin_memory:
+ self._pin_memory_thread_done_event = threading.Event()
+
+ # Queue is not type-annotated
+ self._data_queue = queue.Queue() # type: ignore[var-annotated]
+ if self._pin_memory_device == "xpu":
+ current_device = torch.xpu.current_device() # type: ignore[attr-defined]
+ else:
+ current_device = torch.cuda.current_device() # choose cuda for default
+ pin_memory_thread = threading.Thread(
+ target=_utils.pin_memory._pin_memory_loop,
+ args=(
+ self._worker_result_queue,
+ self._data_queue,
+ current_device,
+ self._pin_memory_thread_done_event,
+ self._pin_memory_device,
+ ),
+ )
+ pin_memory_thread.daemon = True
+ pin_memory_thread.start()
+ # Similar to workers (see comment above), we only register
+ # pin_memory_thread once it is started.
+ self._pin_memory_thread = pin_memory_thread
+ else:
+ self._data_queue = self._worker_result_queue
+
+        # In some rare cases, persistent workers (daemonic processes)
+        # would be terminated before `__del__` of the iterator is invoked
+        # when the main process exits.
+        # This would cause a failure when pin_memory_thread tries to read
+        # corrupted data from worker_result_queue.
+        # atexit is used to shut down the thread and child processes in the
+        # right sequence before the main process exits.
+ if self._persistent_workers and self._pin_memory:
+ import atexit
+
+ for w in self._workers:
+ atexit.register(_MultiProcessingDataLoaderIter._clean_up_worker, w)
+
+ # .pid can be None only before process is spawned (not the case, so ignore)
+ _utils.signal_handling._set_worker_pids(id(self), tuple(w.pid for w in self._workers)) # type: ignore[misc]
+ _utils.signal_handling._set_SIGCHLD_handler()
+ self._worker_pids_set = True
+ self._reset(loader, first_iter=True)
+
+ def _reset(self, loader, first_iter=False):
+ super()._reset(loader, first_iter)
+ self._send_idx = 0 # idx of the next task to be sent to workers
+ self._rcvd_idx = 0 # idx of the next task to be returned in __next__
+ # information about data not yet yielded, i.e., tasks w/ indices in range [rcvd_idx, send_idx).
+ # map: task idx => - (worker_id,) if data isn't fetched (outstanding)
+ # \ (worker_id, data) if data is already fetched (out-of-order)
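+        #      e.g. {4: (1,), 5: (0, batch5)} means task 4 is still outstanding on
+        #      worker 1 while task 5 already arrived out-of-order from worker 0
+        #      (illustration only; `batch5` stands for the fetched data).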
+ self._task_info = {}
+        # always equal to count(v for v in task_info.values() if len(v) == 1)
+        self._tasks_outstanding = 0
+ # A list of booleans representing whether each worker still has work to
+ # do, i.e., not having exhausted its iterable dataset object. It always
+ # contains all `True`s if not using an iterable-style dataset
+ # (i.e., if kind != Iterable).
+        # Note that this indicates that a worker still has work to do *for this epoch*.
+ # It does not mean that a worker is dead. In case of `_persistent_workers`,
+ # the worker will be reset to available in the next epoch.
+ self._workers_status = [True for i in range(self._num_workers)]
+ # Reset the worker queue cycle so it resumes next epoch at worker 0
+ self._worker_queue_idx_cycle = itertools.cycle(range(self._num_workers))
+ # We resume the prefetching in case it was enabled
+ if not first_iter:
+ for idx in range(self._num_workers):
+ self._index_queues[idx].put(
+ _utils.worker._ResumeIteration(self._shared_seed)
+ )
+ resume_iteration_cnt = self._num_workers
+ while resume_iteration_cnt > 0:
+ return_idx, return_data = self._get_data()
+ if isinstance(return_idx, _utils.worker._ResumeIteration):
+ assert return_data is None
+ resume_iteration_cnt -= 1
+ # prime the prefetch loop
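+        # (illustration only: with num_workers=2 and prefetch_factor=2 this
+        # sends 2 * 2 = 4 index batches before the first `__next__` call)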
+ for _ in range(self._prefetch_factor * self._num_workers):
+ self._try_put_index()
+
+ def _try_get_data(self, timeout=_utils.MP_STATUS_CHECK_INTERVAL):
+ # Tries to fetch data from `self._data_queue` once for a given timeout.
+ # This can also be used as inner loop of fetching without timeout, with
+ # the sender status as the loop condition.
+ #
+        # This raises a `RuntimeError` if any worker died unexpectedly. This error
+ # can come from either the SIGCHLD handler in `_utils/signal_handling.py`
+ # (only for non-Windows platforms), or the manual check below on errors
+ # and timeouts.
+ #
+ # Returns a 2-tuple:
+ # (bool: whether successfully get data, any: data if successful else None)
+ try:
+ data = self._data_queue.get(timeout=timeout)
+ return (True, data)
+ except Exception as e:
+ # At timeout and error, we manually check whether any worker has
+ # failed. Note that this is the only mechanism for Windows to detect
+ # worker failures.
+ failed_workers = []
+ for worker_id, w in enumerate(self._workers):
+ if self._workers_status[worker_id] and not w.is_alive():
+ failed_workers.append(w)
+ self._mark_worker_as_unavailable(worker_id)
+ if len(failed_workers) > 0:
+ pids_str = ", ".join(str(w.pid) for w in failed_workers)
+ raise RuntimeError(
+ "DataLoader worker (pid(s) {}) exited unexpectedly".format(pids_str)
+ ) from e
+ if isinstance(e, queue.Empty):
+ return (False, None)
+ import errno
+ import tempfile
+
+ try:
+ # Raise an exception if we are this close to the FDs limit.
+ # Apparently, trying to open only one file is not a sufficient
+ # test.
+ # See NOTE [ DataLoader on Linux and open files limit ]
+ fds_limit_margin = 10
+ fs = [tempfile.NamedTemporaryFile() for i in range(fds_limit_margin)]
+ except OSError as e:
+ if e.errno == errno.EMFILE:
+ raise RuntimeError(
+ "Too many open files. Communication with the"
+ " workers is no longer possible. Please increase the"
+ " limit using `ulimit -n` in the shell or change the"
+ " sharing strategy by calling"
+ " `torch.multiprocessing.set_sharing_strategy('file_system')`"
+ " at the beginning of your code"
+ ) from None
+ raise
+
+ # NOTE [ DataLoader on Linux and open files limit ]
+ #
+ # On Linux when DataLoader is used with multiprocessing we pass the data between
+ # the root process and the workers through SHM files. We remove those files from
+ # the filesystem as soon as they are created and keep them alive by
+ # passing around their file descriptors through AF_UNIX sockets. (See
+ # docs/source/multiprocessing.rst and 'Multiprocessing Technical Notes` in
+ # the wiki (https://github.com/pytorch/pytorch/wiki).)
+ #
+ # This sometimes leads us to exceeding the open files limit. When that happens,
+ # and the offending file descriptor is coming over a socket, the `socket` Python
+ # package silently strips the file descriptor from the message, setting only the
+ # `MSG_CTRUNC` flag (which might be a bit misleading since the manpage says that
+ # it _indicates that some control data were discarded due to lack of space in
+ # the buffer for ancillary data_). This might reflect the C implementation of
+ # AF_UNIX sockets.
+ #
+ # This behaviour can be reproduced with the script and instructions at the
+ # bottom of this note.
+ #
+ # When that happens, the standard Python `multiprocessing` (and not
+ # `torch.multiprocessing`) raises a `RuntimeError: received 0 items of ancdata`
+ #
+ # Sometimes, instead of the FD being stripped, you may get an `OSError:
+ # Too many open files`, both in the script below and in DataLoader. However,
+ # this is rare and seems to be nondeterministic.
+ #
+ #
+ # #!/usr/bin/env python3
+ # import sys
+ # import socket
+ # import os
+ # import array
+ # import shutil
+ # import socket
+ #
+ #
+ # if len(sys.argv) != 4:
+ # print("Usage: ", sys.argv[0], " tmp_dirname iteration (send|recv)")
+ # sys.exit(1)
+ #
+ # if __name__ == '__main__':
+ # dirname = sys.argv[1]
+ # sock_path = dirname + "/sock"
+ # iterations = int(sys.argv[2])
+ # def dummy_path(i):
+ # return dirname + "/" + str(i) + ".dummy"
+ #
+ #
+ # if sys.argv[3] == 'send':
+ # while not os.path.exists(sock_path):
+ # pass
+ # client = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM)
+ # client.connect(sock_path)
+ # for i in range(iterations):
+ # fd = os.open(dummy_path(i), os.O_WRONLY | os.O_CREAT)
+ # ancdata = array.array('i', [fd])
+ # msg = bytes([i % 256])
+ # print("Sending fd ", fd, " (iteration #", i, ")")
+ # client.sendmsg([msg], [(socket.SOL_SOCKET, socket.SCM_RIGHTS, ancdata)])
+ #
+ #
+ # else:
+ # assert sys.argv[3] == 'recv'
+ #
+ # if os.path.exists(dirname):
+ # raise Exception("Directory exists")
+ #
+ # os.mkdir(dirname)
+ #
+ # print("Opening socket...")
+ # server = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM)
+ # server.bind(sock_path)
+ #
+ # print("Listening...")
+ # for i in range(iterations):
+ # a = array.array('i')
+ # msg, ancdata, flags, addr = server.recvmsg(1, socket.CMSG_SPACE(a.itemsize))
+ # assert(len(ancdata) == 1)
+ # cmsg_level, cmsg_type, cmsg_data = ancdata[0]
+ # a.frombytes(cmsg_data)
+ # print("Received fd ", a[0], " (iteration #", i, ")")
+ #
+ # shutil.rmtree(dirname)
+ #
+ # Steps to reproduce:
+ #
+ # 1. Run two shells and set lower file descriptor limit in the receiving one:
+ # (shell1) ulimit -n 1020
+ # (shell2) ulimit -n 1022
+ #
+ # 2. Run the script above with the `recv` option in the first shell
+ # (shell1) ./test_socket.py sock_tmp 1017 recv
+ #
+ # 3. Run the script with the `send` option in the second shell:
+ # (shell2) ./test_socket.py sock_tmp 1017 send
+
+ def _get_data(self):
+ # Fetches data from `self._data_queue`.
+ #
+ # We check workers' status every `MP_STATUS_CHECK_INTERVAL` seconds,
+ # which we achieve by running `self._try_get_data(timeout=MP_STATUS_CHECK_INTERVAL)`
+ # in a loop. This is the only mechanism to detect worker failures for
+ # Windows. For other platforms, a SIGCHLD handler is also used for
+ # worker failure detection.
+ #
+        # If `pin_memory=True`, we also need to check if `pin_memory_thread` has
+        # died at timeouts.
+ if self._timeout > 0:
+ success, data = self._try_get_data(self._timeout)
+ if success:
+ return data
+ else:
+ raise RuntimeError(
+ "DataLoader timed out after {} seconds".format(self._timeout)
+ )
+ elif self._pin_memory:
+ while self._pin_memory_thread.is_alive():
+ success, data = self._try_get_data()
+ if success:
+ return data
+ else:
+ # while condition is false, i.e., pin_memory_thread died.
+ raise RuntimeError("Pin memory thread exited unexpectedly")
+            # In this case, `self._data_queue` is a `queue.Queue`. But we don't
+ # need to call `.task_done()` because we don't use `.join()`.
+ else:
+ while True:
+ success, data = self._try_get_data()
+ if success:
+ return data
+
+ def _next_data(self):
+ while True:
+ # If the worker responsible for `self._rcvd_idx` has already ended
+ # and was unable to fulfill this task (due to exhausting an `IterableDataset`),
+ # we try to advance `self._rcvd_idx` to find the next valid index.
+ #
+ # This part needs to run in the loop because both the `self._get_data()`
+ # call and `_IterableDatasetStopIteration` check below can mark
+ # extra worker(s) as dead.
+ while self._rcvd_idx < self._send_idx:
+ info = self._task_info[self._rcvd_idx]
+ worker_id = info[0]
+ if (
+ len(info) == 2 or self._workers_status[worker_id]
+ ): # has data or is still active
+ break
+ del self._task_info[self._rcvd_idx]
+ self._rcvd_idx += 1
+ else:
+ # no valid `self._rcvd_idx` is found (i.e., didn't break)
+ if not self._persistent_workers:
+ self._shutdown_workers()
+ raise StopIteration
+
+ # Now `self._rcvd_idx` is the batch index we want to fetch
+
+ # Check if the next sample has already been generated
+ if len(self._task_info[self._rcvd_idx]) == 2:
+ data = self._task_info.pop(self._rcvd_idx)[1]
+ return self._process_data(data)
+
+ assert not self._shutdown and self._tasks_outstanding > 0
+ idx, data = self._get_data()
+ self._tasks_outstanding -= 1
+ if self._dataset_kind == _DatasetKind.Iterable:
+ # Check for _IterableDatasetStopIteration
+ if isinstance(data, _utils.worker._IterableDatasetStopIteration):
+ if self._persistent_workers:
+ self._workers_status[data.worker_id] = False
+ else:
+ self._mark_worker_as_unavailable(data.worker_id)
+ self._try_put_index()
+ continue
+
+ if idx != self._rcvd_idx:
+ # store out-of-order samples
+ self._task_info[idx] += (data,)
+ else:
+ del self._task_info[idx]
+ return self._process_data(data)
+
+ def _try_put_index(self):
+ assert self._tasks_outstanding < self._prefetch_factor * self._num_workers
+
+ try:
+ index = self._next_index()
+ except StopIteration:
+ return
+ for _ in range(self._num_workers): # find the next active worker, if any
+ worker_queue_idx = next(self._worker_queue_idx_cycle)
+ if self._workers_status[worker_queue_idx]:
+ break
+ else:
+ # not found (i.e., didn't break)
+ return
+
+ self._index_queues[worker_queue_idx].put((self._send_idx, index))
+ self._task_info[self._send_idx] = (worker_queue_idx,)
+ self._tasks_outstanding += 1
+ self._send_idx += 1
+
+ def _process_data(self, data):
+ self._rcvd_idx += 1
+ self._try_put_index()
+ if isinstance(data, ExceptionWrapper):
+ data.reraise()
+ return data
+
+ def _mark_worker_as_unavailable(self, worker_id, shutdown=False):
+ # Mark a worker as having finished its work e.g., due to
+ # exhausting an `IterableDataset`. This should be used only when this
+ # `_MultiProcessingDataLoaderIter` is going to continue running.
+
+ assert self._workers_status[worker_id] or (
+ self._persistent_workers and shutdown
+ )
+
+ # Signal termination to that specific worker.
+ q = self._index_queues[worker_id]
+ # Indicate that no more data will be put on this queue by the current
+ # process.
+ q.put(None)
+
+ # Note that we don't actually join the worker here, nor do we remove the
+ # worker's pid from C side struct because (1) joining may be slow, and
+ # (2) since we don't join, the worker may still raise error, and we
+ # prefer capturing those, rather than ignoring them, even though they
+ # are raised after the worker has finished its job.
+        # Joining is deferred to `_shutdown_workers`, which is called when
+ # all workers finish their jobs (e.g., `IterableDataset` replicas) or
+ # when this iterator is garbage collected.
+
+ self._workers_status[worker_id] = False
+
+ assert self._workers_done_event.is_set() == shutdown
+
+ def _shutdown_workers(self):
+ # Called when shutting down this `_MultiProcessingDataLoaderIter`.
+ # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on
+ # the logic of this function.
+ if (
+ _utils is None
+ or _utils.python_exit_status is True
+ or _utils.python_exit_status is None
+ ):
+ # See (2) of the note. If Python is shutting down, do no-op.
+ return
+ # Normal exit when last reference is gone / iterator is depleted.
+ # See (1) and the second half of the note.
+ if not self._shutdown:
+ self._shutdown = True
+ try:
+ # Normal exit when last reference is gone / iterator is depleted.
+ # See (1) and the second half of the note.
+
+ # Exit `pin_memory_thread` first because exiting workers may leave
+ # corrupted data in `worker_result_queue` which `pin_memory_thread`
+ # reads from.
+ if hasattr(self, "_pin_memory_thread"):
+ # Use hasattr in case error happens before we set the attribute.
+ self._pin_memory_thread_done_event.set()
+ # Send something to pin_memory_thread in case it is waiting
+ # so that it can wake up and check `pin_memory_thread_done_event`
+ self._worker_result_queue.put((None, None))
+ self._pin_memory_thread.join()
+ self._worker_result_queue.cancel_join_thread()
+ self._worker_result_queue.close()
+
+ # Exit workers now.
+ self._workers_done_event.set()
+ for worker_id in range(len(self._workers)):
+ # Get number of workers from `len(self._workers)` instead of
+ # `self._num_workers` in case we error before starting all
+ # workers.
+ # If we are using workers_status with persistent_workers
+ # we have to shut it down because the worker is paused
+ if self._persistent_workers or self._workers_status[worker_id]:
+ self._mark_worker_as_unavailable(worker_id, shutdown=True)
+ for w in self._workers:
+ # We should be able to join here, but in case anything went
+ # wrong, we set a timeout and if the workers fail to join,
+ # they are killed in the `finally` block.
+ w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
+ for q in self._index_queues:
+ q.cancel_join_thread()
+ q.close()
+ finally:
+ # Even though all this function does is putting into queues that
+ # we have called `cancel_join_thread` on, weird things can
+ # happen when a worker is killed by a signal, e.g., hanging in
+ # `Event.set()`. So we need to guard this with SIGCHLD handler,
+ # and remove pids from the C side data structure only at the
+ # end.
+ #
+ # FIXME: Unfortunately, for Windows, we are missing a worker
+ # error detection mechanism here in this function, as it
+ # doesn't provide a SIGCHLD handler.
+ if self._worker_pids_set:
+ _utils.signal_handling._remove_worker_pids(id(self))
+ self._worker_pids_set = False
+ for w in self._workers:
+ if w.is_alive():
+ # Existing mechanisms try to make the workers exit
+ # peacefully, but in case that we unfortunately reach
+ # here, which we shouldn't, (e.g., pytorch/pytorch#39570),
+ # we kill the worker.
+ w.terminate()
+
+ # staticmethod is used to remove reference to `_MultiProcessingDataLoaderIter`
+ @staticmethod
+ def _clean_up_worker(w):
+ try:
+ w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
+ finally:
+ if w.is_alive():
+ w.terminate()
+
+ def __del__(self):
+ self._shutdown_workers()
diff --git a/yolo-world-with-efficientvit-sam/efficientvit/apps/data_provider/random_resolution/_data_worker.py b/yolo-world-with-efficientvit-sam/efficientvit/apps/data_provider/random_resolution/_data_worker.py
new file mode 100644
index 0000000000000000000000000000000000000000..5abb4bca83842259606fc67d06f685a10421057d
--- /dev/null
+++ b/yolo-world-with-efficientvit-sam/efficientvit/apps/data_provider/random_resolution/_data_worker.py
@@ -0,0 +1,378 @@
+r""""This file is based on torch/utils/data/_utils/worker.py
+
+Contains definitions of the methods used by the _BaseDataLoaderIter workers.
+These **needs** to be in global scope since Py2 doesn't support serializing
+static methods.
+"""
+
+import os
+import queue
+import random
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Optional, Union
+
+import torch
+from torch._utils import ExceptionWrapper
+from torch.utils.data._utils import (HAS_NUMPY, IS_WINDOWS,
+ MP_STATUS_CHECK_INTERVAL, signal_handling)
+
+if TYPE_CHECKING:
+ from torch.utils.data import Dataset
+
+from .controller import RRSController
+
+if IS_WINDOWS:
+ import ctypes
+ from ctypes.wintypes import BOOL, DWORD, HANDLE
+
+    # On Windows, the parent ID of the worker process remains unchanged when the manager process
+    # is gone, and the only way to check it through the OS is to let the worker have a process handle
+    # of the manager and ask if the process status has changed.
+ class ManagerWatchdog:
+ def __init__(self):
+ self.manager_pid = os.getppid()
+
+ # mypy cannot detect this code is windows only
+ self.kernel32 = ctypes.WinDLL("kernel32", use_last_error=True) # type: ignore[attr-defined]
+ self.kernel32.OpenProcess.argtypes = (DWORD, BOOL, DWORD)
+ self.kernel32.OpenProcess.restype = HANDLE
+ self.kernel32.WaitForSingleObject.argtypes = (HANDLE, DWORD)
+ self.kernel32.WaitForSingleObject.restype = DWORD
+
+ # Value obtained from https://msdn.microsoft.com/en-us/library/ms684880.aspx
+ SYNCHRONIZE = 0x00100000
+ self.manager_handle = self.kernel32.OpenProcess(
+ SYNCHRONIZE, 0, self.manager_pid
+ )
+
+ if not self.manager_handle:
+ raise ctypes.WinError(ctypes.get_last_error()) # type: ignore[attr-defined]
+
+ self.manager_dead = False
+
+ def is_alive(self):
+ if not self.manager_dead:
+ # Value obtained from https://msdn.microsoft.com/en-us/library/windows/desktop/ms687032.aspx
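+                # (a return value of 0, i.e. WAIT_OBJECT_0, means the manager
+                # process handle is signaled because that process has exited)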
+ self.manager_dead = (
+ self.kernel32.WaitForSingleObject(self.manager_handle, 0) == 0
+ )
+ return not self.manager_dead
+
+else:
+
+ class ManagerWatchdog: # type: ignore[no-redef]
+ def __init__(self):
+ self.manager_pid = os.getppid()
+ self.manager_dead = False
+
+ def is_alive(self):
+ if not self.manager_dead:
+ self.manager_dead = os.getppid() != self.manager_pid
+ return not self.manager_dead
+
+
+_worker_info = None
+
+
+class WorkerInfo:
+ id: int
+ num_workers: int
+ seed: int
+ dataset: "Dataset"
+ __initialized = False
+
+ def __init__(self, **kwargs):
+ for k, v in kwargs.items():
+ setattr(self, k, v)
+ self.__keys = tuple(kwargs.keys())
+ self.__initialized = True
+
+ def __setattr__(self, key, val):
+ if self.__initialized:
+ raise RuntimeError(
+ "Cannot assign attributes to {} objects".format(self.__class__.__name__)
+ )
+ return super().__setattr__(key, val)
+
+ def __repr__(self):
+ items = []
+ for k in self.__keys:
+ items.append("{}={}".format(k, getattr(self, k)))
+ return "{}({})".format(self.__class__.__name__, ", ".join(items))
+
+
+def get_worker_info() -> Optional[WorkerInfo]:
+ r"""Returns the information about the current
+ :class:`~torch.utils.data.DataLoader` iterator worker process.
+
+ When called in a worker, this returns an object guaranteed to have the
+ following attributes:
+
+ * :attr:`id`: the current worker id.
+ * :attr:`num_workers`: the total number of workers.
+ * :attr:`seed`: the random seed set for the current worker. This value is
+ determined by main process RNG and the worker id. See
+ :class:`~torch.utils.data.DataLoader`'s documentation for more details.
+ * :attr:`dataset`: the copy of the dataset object in **this** process. Note
+ that this will be a different object in a different process than the one
+ in the main process.
+
+ When called in the main process, this returns ``None``.
+
+ .. note::
+ When used in a :attr:`worker_init_fn` passed over to
+ :class:`~torch.utils.data.DataLoader`, this method can be useful to
+ set up each worker process differently, for instance, using ``worker_id``
+ to configure the ``dataset`` object to only read a specific fraction of a
+ sharded dataset, or use ``seed`` to seed other libraries used in dataset
+ code.
+ """
+ return _worker_info
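+
+# A minimal sketch of the per-worker sharding pattern mentioned above
+# (illustration only; `ds` is a hypothetical iterable-style dataset that
+# iterates over `range(ds.start, ds.end)`):
+#
+#     def _shard_init_fn(worker_id):
+#         info = get_worker_info()
+#         per_worker = -(-(info.dataset.end - info.dataset.start) // info.num_workers)
+#         start = info.dataset.start + worker_id * per_worker
+#         info.dataset.end = min(start + per_worker, info.dataset.end)
+#         info.dataset.start = start
+#
+# passed to the DataLoader as `worker_init_fn=_shard_init_fn`.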
+
+
+r"""Dummy class used to signal the end of an IterableDataset"""
+
+
+@dataclass(frozen=True)
+class _IterableDatasetStopIteration:
+ worker_id: int
+
+
+r"""Dummy class used to resume the fetching when worker reuse is enabled"""
+
+
+@dataclass(frozen=True)
+class _ResumeIteration:
+ seed: Optional[int] = None
+
+
+# The function `_generate_state` is adapted from `numpy.random.SeedSequence`
+# from https://github.com/numpy/numpy/blob/main/numpy/random/bit_generator.pyx
+# It's MIT licensed, here is the copyright:
+
+# Copyright (c) 2015 Melissa E. O'Neill
+# Copyright (c) 2019 NumPy Developers
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+
+# This function generates an array of int32 as the seed for
+# `numpy.random`, in order to prevent state collision due to same
+# seed and algorithm for `numpy.random` and `random` modules.
+# TODO: Implement `SeedSequence` like object for `torch.random`
+def _generate_state(base_seed, worker_id):
+ INIT_A = 0x43B0D7E5
+ MULT_A = 0x931E8875
+ INIT_B = 0x8B51F9DD
+ MULT_B = 0x58F38DED
+ MIX_MULT_L = 0xCA01F9DD
+ MIX_MULT_R = 0x4973F715
+ XSHIFT = 4 * 8 // 2
+ MASK32 = 0xFFFFFFFF
+
+ entropy = [worker_id, base_seed & MASK32, base_seed >> 32, 0]
+ pool = [0] * 4
+
+ hash_const_A = INIT_A
+
+ def hash(value):
+ nonlocal hash_const_A
+ value = (value ^ hash_const_A) & MASK32
+ hash_const_A = (hash_const_A * MULT_A) & MASK32
+ value = (value * hash_const_A) & MASK32
+ value = (value ^ (value >> XSHIFT)) & MASK32
+ return value
+
+ def mix(x, y):
+ result_x = (MIX_MULT_L * x) & MASK32
+ result_y = (MIX_MULT_R * y) & MASK32
+ result = (result_x - result_y) & MASK32
+ result = (result ^ (result >> XSHIFT)) & MASK32
+ return result
+
+ # Add in the entropy to the pool.
+ for i in range(len(pool)):
+ pool[i] = hash(entropy[i])
+
+ # Mix all bits together so late bits can affect earlier bits.
+ for i_src in range(len(pool)):
+ for i_dst in range(len(pool)):
+ if i_src != i_dst:
+ pool[i_dst] = mix(pool[i_dst], hash(pool[i_src]))
+
+ hash_const_B = INIT_B
+ state = []
+ for i_dst in range(4):
+ data_val = pool[i_dst]
+ data_val = (data_val ^ hash_const_B) & MASK32
+ hash_const_B = (hash_const_B * MULT_B) & MASK32
+ data_val = (data_val * hash_const_B) & MASK32
+ data_val = (data_val ^ (data_val >> XSHIFT)) & MASK32
+ state.append(data_val)
+ return state
+
+
+def _worker_loop(
+ dataset_kind,
+ dataset,
+ index_queue,
+ data_queue,
+ done_event,
+ auto_collation,
+ collate_fn,
+ drop_last,
+ base_seed,
+ init_fn,
+ worker_id,
+ num_workers,
+ persistent_workers,
+ shared_seed,
+):
+ # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the
+ # logic of this function.
+
+ try:
+ # Initialize C side signal handlers for SIGBUS and SIGSEGV. Python signal
+ # module's handlers are executed after Python returns from C low-level
+ # handlers, likely when the same fatal signal had already happened
+ # again.
+ # https://docs.python.org/3/library/signal.html#execution-of-python-signal-handlers
+ signal_handling._set_worker_signal_handlers()
+
+ torch.set_num_threads(1)
+ seed = base_seed + worker_id
+ random.seed(seed)
+ torch.manual_seed(seed)
+ if HAS_NUMPY:
+ np_seed = _generate_state(base_seed, worker_id)
+ import numpy as np
+
+ np.random.seed(np_seed)
+
+ from torch.utils.data import IterDataPipe
+ from torch.utils.data.graph_settings import apply_random_seed
+
+ shared_rng = torch.Generator()
+ if isinstance(dataset, IterDataPipe):
+ assert shared_seed is not None
+ shared_rng.manual_seed(shared_seed)
+ dataset = apply_random_seed(dataset, shared_rng)
+
+ global _worker_info
+ _worker_info = WorkerInfo(
+ id=worker_id, num_workers=num_workers, seed=seed, dataset=dataset
+ )
+
+ from torch.utils.data import _DatasetKind
+
+ init_exception = None
+
+ try:
+ if init_fn is not None:
+ init_fn(worker_id)
+
+ fetcher = _DatasetKind.create_fetcher(
+ dataset_kind, dataset, auto_collation, collate_fn, drop_last
+ )
+ except Exception:
+ init_exception = ExceptionWrapper(
+ where="in DataLoader worker process {}".format(worker_id)
+ )
+
+        # When using Iterable mode, some workers can exit earlier than others due
+ # to the IterableDataset behaving differently for different workers.
+        # When such a thing happens, an `_IterableDatasetStopIteration` object is
+ # sent over to the main process with the ID of this worker, so that the
+ # main process won't send more tasks to this worker, and will send
+ # `None` to this worker to properly exit it.
+ #
+ # Note that we cannot set `done_event` from a worker as it is shared
+ # among all processes. Instead, we set the `iteration_end` flag to
+ # signify that the iterator is exhausted. When either `done_event` or
+ # `iteration_end` is set, we skip all processing steps and just wait for
+ # `None`.
+ iteration_end = False
+
+ watchdog = ManagerWatchdog()
+
+ while watchdog.is_alive():
+ try:
+ r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
+ except queue.Empty:
+ continue
+ if isinstance(r, _ResumeIteration):
+ # Acknowledge the main process
+ data_queue.put((r, None))
+ iteration_end = False
+
+ if isinstance(dataset, IterDataPipe):
+ assert r.seed is not None
+ shared_rng.manual_seed(r.seed)
+ dataset = apply_random_seed(dataset, shared_rng)
+
+ # Recreate the fetcher for worker-reuse policy
+ fetcher = _DatasetKind.create_fetcher(
+ dataset_kind, dataset, auto_collation, collate_fn, drop_last
+ )
+ continue
+ elif r is None:
+ # Received the final signal
+ assert done_event.is_set() or iteration_end
+ break
+ elif done_event.is_set() or iteration_end:
+ # `done_event` is set. But I haven't received the final signal
+ # (None) yet. I will keep continuing until I get it, and skip the
+ # processing steps.
+ continue
+ idx, index = r
+ """ Added """
+ RRSController.sample_resolution(batch_id=idx)
+ """ Added """
+ data: Union[_IterableDatasetStopIteration, ExceptionWrapper]
+ if init_exception is not None:
+ data = init_exception
+ init_exception = None
+ else:
+ try:
+ data = fetcher.fetch(index)
+ except Exception as e:
+ if (
+ isinstance(e, StopIteration)
+ and dataset_kind == _DatasetKind.Iterable
+ ):
+ data = _IterableDatasetStopIteration(worker_id)
+ # Set `iteration_end`
+ # (1) to save future `next(...)` calls, and
+ # (2) to avoid sending multiple `_IterableDatasetStopIteration`s.
+ iteration_end = True
+ else:
+ # It is important that we don't store exc_info in a variable.
+ # `ExceptionWrapper` does the correct thing.
+ # See NOTE [ Python Traceback Reference Cycle Problem ]
+ data = ExceptionWrapper(
+ where="in DataLoader worker process {}".format(worker_id)
+ )
+ data_queue.put((idx, data))
+ del data, idx, index, r # save memory
+ except KeyboardInterrupt:
+ # Main process will raise KeyboardInterrupt anyways.
+ pass
+ if done_event.is_set():
+ data_queue.cancel_join_thread()
+ data_queue.close()
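
The `_generate_state` helper above reproduces a SeedSequence-style entropy pool so each worker derives a decorrelated NumPy seed from `(worker_id, base_seed)`. A minimal sketch of the same idea using NumPy's own SeedSequence (an analogue for illustration only; the exact values differ from the hand-rolled pool):

import numpy as np

base_seed = 12345  # hypothetical base seed handed out by the main process
for worker_id in range(4):
    # SeedSequence mixes (worker_id, base_seed) into well-distributed 32-bit words,
    # which is what _generate_state approximates in pure Python.
    ss = np.random.SeedSequence([worker_id, base_seed])
    state = ss.generate_state(4)  # four uint32 words, like _generate_state's output
    np.random.seed(state)         # np.random.seed accepts an array of ints
    print(worker_id, state)
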
diff --git a/yolo-world-with-efficientvit-sam/efficientvit/apps/data_provider/random_resolution/controller.py b/yolo-world-with-efficientvit-sam/efficientvit/apps/data_provider/random_resolution/controller.py
new file mode 100644
index 0000000000000000000000000000000000000000..05d5c3643d379da055879bfaed5fbd5e1609eba6
--- /dev/null
+++ b/yolo-world-with-efficientvit-sam/efficientvit/apps/data_provider/random_resolution/controller.py
@@ -0,0 +1,94 @@
+# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
+# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
+# International Conference on Computer Vision (ICCV), 2023
+
+import copy
+
+import torch
+import torchvision.transforms as transforms
+import torchvision.transforms.functional as F
+
+from efficientvit.models.utils import torch_random_choices
+
+__all__ = [
+ "RRSController",
+ "get_interpolate",
+ "MyRandomResizedCrop",
+]
+
+
+class RRSController:
+ ACTIVE_SIZE = (224, 224)
+ IMAGE_SIZE_LIST = [(224, 224)]
+
+ CHOICE_LIST = None
+
+ @staticmethod
+ def get_candidates() -> list[tuple[int, int]]:
+ return copy.deepcopy(RRSController.IMAGE_SIZE_LIST)
+
+ @staticmethod
+ def sample_resolution(batch_id: int) -> None:
+ RRSController.ACTIVE_SIZE = RRSController.CHOICE_LIST[batch_id]
+
+ @staticmethod
+ def set_epoch(epoch: int, batch_per_epoch: int) -> None:
+ g = torch.Generator()
+ g.manual_seed(epoch)
+ RRSController.CHOICE_LIST = torch_random_choices(
+ RRSController.get_candidates(),
+ g,
+ batch_per_epoch,
+ )
+
+
+def get_interpolate(name: str) -> F.InterpolationMode:
+ mapping = {
+ "nearest": F.InterpolationMode.NEAREST,
+ "bilinear": F.InterpolationMode.BILINEAR,
+ "bicubic": F.InterpolationMode.BICUBIC,
+ "box": F.InterpolationMode.BOX,
+ "hamming": F.InterpolationMode.HAMMING,
+ "lanczos": F.InterpolationMode.LANCZOS,
+ }
+ if name in mapping:
+ return mapping[name]
+ elif name == "random":
+ return torch_random_choices(
+ [
+ F.InterpolationMode.NEAREST,
+ F.InterpolationMode.BILINEAR,
+ F.InterpolationMode.BICUBIC,
+ F.InterpolationMode.BOX,
+ F.InterpolationMode.HAMMING,
+ F.InterpolationMode.LANCZOS,
+ ],
+ )
+ else:
+ raise NotImplementedError
+
+
+class MyRandomResizedCrop(transforms.RandomResizedCrop):
+ def __init__(
+ self,
+ scale=(0.08, 1.0),
+ ratio=(3.0 / 4.0, 4.0 / 3.0),
+ interpolation: str = "random",
+ ):
+ super(MyRandomResizedCrop, self).__init__(224, scale, ratio)
+ self.interpolation = interpolation
+
+ def forward(self, img: torch.Tensor) -> torch.Tensor:
+ i, j, h, w = self.get_params(img, list(self.scale), list(self.ratio))
+ target_size = RRSController.ACTIVE_SIZE
+ return F.resized_crop(
+ img, i, j, h, w, list(target_size), get_interpolate(self.interpolation)
+ )
+
+ def __repr__(self) -> str:
+ format_string = self.__class__.__name__
+ format_string += f"(\n\tsize={RRSController.get_candidates()},\n"
+ format_string += f"\tscale={tuple(round(s, 4) for s in self.scale)},\n"
+ format_string += f"\tratio={tuple(round(r, 4) for r in self.ratio)},\n"
+ format_string += f"\tinterpolation={self.interpolation})"
+ return format_string
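
A minimal sketch of how the controller is meant to be driven, assuming the efficientvit package is importable; the resolution list, epoch, and batch count are illustrative only. `set_epoch` pre-samples one resolution per batch (deterministic per epoch), the patched worker loop calls `sample_resolution(batch_id)`, and `MyRandomResizedCrop` crops to whatever `ACTIVE_SIZE` is at that moment:

import torch

from efficientvit.apps.data_provider.random_resolution.controller import (
    MyRandomResizedCrop, RRSController)

RRSController.IMAGE_SIZE_LIST = [(160, 160), (192, 192), (224, 224)]
RRSController.set_epoch(epoch=0, batch_per_epoch=10)  # one size sampled per batch

crop = MyRandomResizedCrop(interpolation="bicubic")    # fixed mode, tensor-friendly
img = torch.rand(3, 256, 256)
for batch_id in range(3):
    RRSController.sample_resolution(batch_id)          # normally done inside _worker_loop
    out = crop(img)
    print(batch_id, RRSController.ACTIVE_SIZE, tuple(out.shape))
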
diff --git a/yolo-world-with-efficientvit-sam/efficientvit/apps/setup.py b/yolo-world-with-efficientvit-sam/efficientvit/apps/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..d304ced6f128897cfd3564fbf12cddd317d1dbd1
--- /dev/null
+++ b/yolo-world-with-efficientvit-sam/efficientvit/apps/setup.py
@@ -0,0 +1,141 @@
+# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
+# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
+# International Conference on Computer Vision (ICCV), 2023
+
+import os
+import time
+from copy import deepcopy
+
+import torch.backends.cudnn
+import torch.distributed
+import torch.nn as nn
+
+from efficientvit.apps.data_provider import DataProvider
+from efficientvit.apps.trainer.run_config import RunConfig
+from efficientvit.apps.utils import (dist_init, dump_config,
+ get_dist_local_rank, get_dist_rank,
+ get_dist_size, init_modules, is_master,
+ load_config, partial_update_config,
+ zero_last_gamma)
+from efficientvit.models.utils import (build_kwargs_from_config,
+ load_state_dict_from_file)
+
+__all__ = [
+ "save_exp_config",
+ "setup_dist_env",
+ "setup_seed",
+ "setup_exp_config",
+ "setup_data_provider",
+ "setup_run_config",
+ "init_model",
+]
+
+
+def save_exp_config(exp_config: dict, path: str, name="config.yaml") -> None:
+ if not is_master():
+ return
+ dump_config(exp_config, os.path.join(path, name))
+
+
+def setup_dist_env(gpu: str or None = None) -> None:
+ if gpu is not None:
+ os.environ["CUDA_VISIBLE_DEVICES"] = gpu
+ if not torch.distributed.is_initialized():
+ dist_init()
+ torch.backends.cudnn.benchmark = True
+ torch.cuda.set_device(get_dist_local_rank())
+
+
+def setup_seed(manual_seed: int, resume: bool) -> None:
+ if resume:
+ manual_seed = int(time.time())
+ manual_seed = get_dist_rank() + manual_seed
+ torch.manual_seed(manual_seed)
+ torch.cuda.manual_seed_all(manual_seed)
+
+
+def setup_exp_config(
+ config_path: str, recursive=True, opt_args: dict or None = None
+) -> dict:
+ # load config
+ if not os.path.isfile(config_path):
+ raise ValueError(config_path)
+
+ fpaths = [config_path]
+ if recursive:
+ extension = os.path.splitext(config_path)[1]
+ while os.path.dirname(config_path) != config_path:
+ config_path = os.path.dirname(config_path)
+ fpath = os.path.join(config_path, "default" + extension)
+ if os.path.isfile(fpath):
+ fpaths.append(fpath)
+ fpaths = fpaths[::-1]
+
+ default_config = load_config(fpaths[0])
+ exp_config = deepcopy(default_config)
+ for fpath in fpaths[1:]:
+ partial_update_config(exp_config, load_config(fpath))
+ # update config via args
+ if opt_args is not None:
+ partial_update_config(exp_config, opt_args)
+
+ return exp_config
+
+
+def setup_data_provider(
+ exp_config: dict,
+ data_provider_classes: list[type[DataProvider]],
+ is_distributed: bool = True,
+) -> DataProvider:
+ dp_config = exp_config["data_provider"]
+ dp_config["num_replicas"] = get_dist_size() if is_distributed else None
+ dp_config["rank"] = get_dist_rank() if is_distributed else None
+ dp_config["test_batch_size"] = (
+ dp_config.get("test_batch_size", None) or dp_config["base_batch_size"] * 2
+ )
+ dp_config["batch_size"] = dp_config["train_batch_size"] = dp_config[
+ "base_batch_size"
+ ]
+
+ data_provider_lookup = {
+ provider.name: provider for provider in data_provider_classes
+ }
+ data_provider_class = data_provider_lookup[dp_config["dataset"]]
+
+ data_provider_kwargs = build_kwargs_from_config(dp_config, data_provider_class)
+ data_provider = data_provider_class(**data_provider_kwargs)
+ return data_provider
+
+
+def setup_run_config(exp_config: dict, run_config_cls: type[RunConfig]) -> RunConfig:
+ exp_config["run_config"]["init_lr"] = (
+ exp_config["run_config"]["base_lr"] * get_dist_size()
+ )
+
+ run_config = run_config_cls(**exp_config["run_config"])
+
+ return run_config
+
+
+def init_model(
+ network: nn.Module,
+ init_from: str or None = None,
+ backbone_init_from: str or None = None,
+ rand_init="trunc_normal",
+ last_gamma=None,
+) -> None:
+ # initialization
+ init_modules(network, init_type=rand_init)
+ # zero gamma of last bn in each block
+ if last_gamma is not None:
+ zero_last_gamma(network, last_gamma)
+
+ # load weight
+ if init_from is not None and os.path.isfile(init_from):
+ network.load_state_dict(load_state_dict_from_file(init_from))
+ print(f"Loaded init from {init_from}")
+ elif backbone_init_from is not None and os.path.isfile(backbone_init_from):
+ network.backbone.load_state_dict(load_state_dict_from_file(backbone_init_from))
+ print(f"Loaded backbone init from {backbone_init_from}")
+ else:
+ print(f"Random init ({rand_init}) with last gamma {last_gamma}")
diff --git a/yolo-world-with-efficientvit-sam/efficientvit/apps/trainer/__init__.py b/yolo-world-with-efficientvit-sam/efficientvit/apps/trainer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..6b9219c0c05c23e46926de0988c658b79b72388b
--- /dev/null
+++ b/yolo-world-with-efficientvit-sam/efficientvit/apps/trainer/__init__.py
@@ -0,0 +1,6 @@
+# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
+# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
+# International Conference on Computer Vision (ICCV), 2023
+
+from .base import *
+from .run_config import *
diff --git a/yolo-world-with-efficientvit-sam/efficientvit/apps/trainer/__pycache__/__init__.cpython-311.pyc b/yolo-world-with-efficientvit-sam/efficientvit/apps/trainer/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..eba79e9c40f140e0983c1680d23f034e8bf0e867
Binary files /dev/null and b/yolo-world-with-efficientvit-sam/efficientvit/apps/trainer/__pycache__/__init__.cpython-311.pyc differ
diff --git a/yolo-world-with-efficientvit-sam/efficientvit/apps/trainer/__pycache__/base.cpython-311.pyc b/yolo-world-with-efficientvit-sam/efficientvit/apps/trainer/__pycache__/base.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2842492674e076ce5f44479725f407dade85b7e8
Binary files /dev/null and b/yolo-world-with-efficientvit-sam/efficientvit/apps/trainer/__pycache__/base.cpython-311.pyc differ
diff --git a/yolo-world-with-efficientvit-sam/efficientvit/apps/trainer/__pycache__/run_config.cpython-311.pyc b/yolo-world-with-efficientvit-sam/efficientvit/apps/trainer/__pycache__/run_config.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5a1846bbca5a629341de43f10fbf0412802b082a
Binary files /dev/null and b/yolo-world-with-efficientvit-sam/efficientvit/apps/trainer/__pycache__/run_config.cpython-311.pyc differ
diff --git a/yolo-world-with-efficientvit-sam/efficientvit/apps/trainer/base.py b/yolo-world-with-efficientvit-sam/efficientvit/apps/trainer/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..bfae1f3887583fa4562c47af4485b723c09f06c5
--- /dev/null
+++ b/yolo-world-with-efficientvit-sam/efficientvit/apps/trainer/base.py
@@ -0,0 +1,297 @@
+# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
+# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
+# International Conference on Computer Vision (ICCV), 2023
+
+import os
+
+import torch
+import torch.nn as nn
+
+from efficientvit.apps.data_provider import DataProvider, parse_image_size
+from efficientvit.apps.trainer.run_config import RunConfig
+from efficientvit.apps.utils import (EMA, dist_barrier, get_dist_local_rank,
+ is_master)
+from efficientvit.models.nn.norm import reset_bn
+from efficientvit.models.utils import is_parallel, load_state_dict_from_file
+
+__all__ = ["Trainer"]
+
+
+class Trainer:
+ def __init__(self, path: str, model: nn.Module, data_provider: DataProvider):
+ self.path = os.path.realpath(os.path.expanduser(path))
+ self.model = model.cuda()
+ self.data_provider = data_provider
+
+ self.ema = None
+
+ self.checkpoint_path = os.path.join(self.path, "checkpoint")
+ self.logs_path = os.path.join(self.path, "logs")
+ for path in [self.path, self.checkpoint_path, self.logs_path]:
+ os.makedirs(path, exist_ok=True)
+
+ self.best_val = 0.0
+ self.start_epoch = 0
+
+ @property
+ def network(self) -> nn.Module:
+ return self.model.module if is_parallel(self.model) else self.model
+
+ @property
+ def eval_network(self) -> nn.Module:
+ if self.ema is None:
+ model = self.model
+ else:
+ model = self.ema.shadows
+ model = model.module if is_parallel(model) else model
+ return model
+
+ def write_log(self, log_str, prefix="valid", print_log=True, mode="a") -> None:
+ if is_master():
+ fout = open(os.path.join(self.logs_path, f"{prefix}.log"), mode)
+ fout.write(log_str + "\n")
+ fout.flush()
+ fout.close()
+ if print_log:
+ print(log_str)
+
+ def save_model(
+ self,
+ checkpoint=None,
+ only_state_dict=True,
+ epoch=0,
+ model_name=None,
+ ) -> None:
+ if is_master():
+ if checkpoint is None:
+ if only_state_dict:
+ checkpoint = {"state_dict": self.network.state_dict()}
+ else:
+ checkpoint = {
+ "state_dict": self.network.state_dict(),
+ "epoch": epoch,
+ "best_val": self.best_val,
+ "optimizer": self.optimizer.state_dict(),
+ "lr_scheduler": self.lr_scheduler.state_dict(),
+ "ema": self.ema.state_dict() if self.ema is not None else None,
+ "scaler": self.scaler.state_dict() if self.fp16 else None,
+ }
+
+ model_name = model_name or "checkpoint.pt"
+
+ latest_fname = os.path.join(self.checkpoint_path, "latest.txt")
+ model_path = os.path.join(self.checkpoint_path, model_name)
+ with open(latest_fname, "w") as _fout:
+ _fout.write(model_path + "\n")
+ torch.save(checkpoint, model_path)
+
+ def load_model(self, model_fname=None) -> None:
+ latest_fname = os.path.join(self.checkpoint_path, "latest.txt")
+ if model_fname is None and os.path.exists(latest_fname):
+ with open(latest_fname, "r") as fin:
+ model_fname = fin.readline()
+ if len(model_fname) > 0 and model_fname[-1] == "\n":
+ model_fname = model_fname[:-1]
+ try:
+ if model_fname is None:
+ model_fname = f"{self.checkpoint_path}/checkpoint.pt"
+ elif not os.path.exists(model_fname):
+ model_fname = f"{self.checkpoint_path}/{os.path.basename(model_fname)}"
+ if not os.path.exists(model_fname):
+ model_fname = f"{self.checkpoint_path}/checkpoint.pt"
+ print(f"=> loading checkpoint {model_fname}")
+ checkpoint = load_state_dict_from_file(model_fname, False)
+ except Exception:
+ self.write_log(f"fail to load checkpoint from {self.checkpoint_path}")
+ return
+
+ # load checkpoint
+ self.network.load_state_dict(checkpoint["state_dict"], strict=False)
+ log = []
+ if "epoch" in checkpoint:
+ self.start_epoch = checkpoint["epoch"] + 1
+ self.run_config.update_global_step(self.start_epoch)
+ log.append(f"epoch={self.start_epoch - 1}")
+ if "best_val" in checkpoint:
+ self.best_val = checkpoint["best_val"]
+ log.append(f"best_val={self.best_val:.2f}")
+ if "optimizer" in checkpoint:
+ self.optimizer.load_state_dict(checkpoint["optimizer"])
+ log.append("optimizer")
+ if "lr_scheduler" in checkpoint:
+ self.lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
+ log.append("lr_scheduler")
+ if "ema" in checkpoint and self.ema is not None:
+ self.ema.load_state_dict(checkpoint["ema"])
+ log.append("ema")
+ if "scaler" in checkpoint and self.fp16:
+ self.scaler.load_state_dict(checkpoint["scaler"])
+ log.append("scaler")
+ self.write_log("Loaded: " + ", ".join(log))
+
+ """ validate """
+
+ def reset_bn(
+ self,
+ network: nn.Module or None = None,
+ subset_size: int = 16000,
+ subset_batch_size: int = 100,
+ data_loader=None,
+ progress_bar=False,
+ ) -> None:
+ network = network or self.network
+ if data_loader is None:
+ data_loader = []
+ for data in self.data_provider.build_sub_train_loader(
+ subset_size, subset_batch_size
+ ):
+ if isinstance(data, list):
+ data_loader.append(data[0])
+ elif isinstance(data, dict):
+ data_loader.append(data["data"])
+ elif isinstance(data, torch.Tensor):
+ data_loader.append(data)
+ else:
+ raise NotImplementedError
+
+ network.eval()
+ reset_bn(
+ network,
+ data_loader,
+ sync=True,
+ progress_bar=progress_bar,
+ )
+
+ def _validate(self, model, data_loader, epoch) -> dict[str, any]:
+ raise NotImplementedError
+
+ def validate(
+ self, model=None, data_loader=None, is_test=True, epoch=0
+ ) -> dict[str, any]:
+ model = model or self.eval_network
+ if data_loader is None:
+ if is_test:
+ data_loader = self.data_provider.test
+ else:
+ data_loader = self.data_provider.valid
+
+ model.eval()
+ return self._validate(model, data_loader, epoch)
+
+ def multires_validate(
+ self,
+ model=None,
+ data_loader=None,
+ is_test=True,
+ epoch=0,
+ eval_image_size=None,
+ ) -> dict[str, dict[str, any]]:
+ eval_image_size = eval_image_size or self.run_config.eval_image_size
+ eval_image_size = eval_image_size or self.data_provider.image_size
+ model = model or self.eval_network
+
+ if not isinstance(eval_image_size, list):
+ eval_image_size = [eval_image_size]
+
+ output_dict = {}
+ for r in eval_image_size:
+ self.data_provider.assign_active_image_size(parse_image_size(r))
+ if self.run_config.reset_bn:
+ self.reset_bn(
+ network=model,
+ subset_size=self.run_config.reset_bn_size,
+ subset_batch_size=self.run_config.reset_bn_batch_size,
+ progress_bar=True,
+ )
+ output_dict[f"r{r}"] = self.validate(model, data_loader, is_test, epoch)
+ return output_dict
+
+ """ training """
+
+ def prep_for_training(
+ self, run_config: RunConfig, ema_decay: float or None = None, fp16=False
+ ) -> None:
+ self.run_config = run_config
+ self.model = nn.parallel.DistributedDataParallel(
+ self.model.cuda(),
+ device_ids=[get_dist_local_rank()],
+ static_graph=True,
+ )
+
+ self.run_config.global_step = 0
+ self.run_config.batch_per_epoch = len(self.data_provider.train)
+ assert self.run_config.batch_per_epoch > 0, "Training set is empty"
+
+ # build optimizer
+ self.optimizer, self.lr_scheduler = self.run_config.build_optimizer(self.model)
+
+ if ema_decay is not None:
+ self.ema = EMA(self.network, ema_decay)
+
+ # fp16
+ self.fp16 = fp16
+ self.scaler = torch.cuda.amp.GradScaler(enabled=self.fp16)
+
+ def sync_model(self):
+ print("Sync model")
+ self.save_model(model_name="sync.pt")
+ dist_barrier()
+ checkpoint = torch.load(
+ os.path.join(self.checkpoint_path, "sync.pt"), map_location="cpu"
+ )
+ dist_barrier()
+ if is_master():
+ os.remove(os.path.join(self.checkpoint_path, "sync.pt"))
+ dist_barrier()
+
+ # load checkpoint
+ self.network.load_state_dict(checkpoint["state_dict"], strict=False)
+ if "optimizer" in checkpoint:
+ self.optimizer.load_state_dict(checkpoint["optimizer"])
+ if "lr_scheduler" in checkpoint:
+ self.lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
+ if "ema" in checkpoint and self.ema is not None:
+ self.ema.load_state_dict(checkpoint["ema"])
+ if "scaler" in checkpoint and self.fp16:
+ self.scaler.load_state_dict(checkpoint["scaler"])
+
+ def before_step(self, feed_dict: dict[str, any]) -> dict[str, any]:
+ for key in feed_dict:
+ if isinstance(feed_dict[key], torch.Tensor):
+ feed_dict[key] = feed_dict[key].cuda()
+ return feed_dict
+
+ def run_step(self, feed_dict: dict[str, any]) -> dict[str, any]:
+ raise NotImplementedError
+
+ def after_step(self) -> None:
+ self.scaler.unscale_(self.optimizer)
+ # gradient clip
+ if self.run_config.grad_clip is not None:
+ torch.nn.utils.clip_grad_value_(
+ self.model.parameters(), self.run_config.grad_clip
+ )
+ # update
+ self.scaler.step(self.optimizer)
+ self.scaler.update()
+
+ self.lr_scheduler.step()
+ self.run_config.step()
+ # update ema
+ if self.ema is not None:
+ self.ema.step(self.network, self.run_config.global_step)
+
+ def _train_one_epoch(self, epoch: int) -> dict[str, any]:
+ raise NotImplementedError
+
+ def train_one_epoch(self, epoch: int) -> dict[str, any]:
+ self.model.train()
+
+ self.data_provider.set_epoch(epoch)
+
+ train_info_dict = self._train_one_epoch(epoch)
+
+ return train_info_dict
+
+ def train(self) -> None:
+ raise NotImplementedError
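
`Trainer` leaves `run_step` and `_train_one_epoch` abstract: the scaled backward pass belongs in `run_step`, while `after_step` unscales, clips, steps the optimizer, scheduler, and EMA. A minimal subclass sketch; the loss and feed-dict keys are hypothetical, not the repo's actual training code:

import torch
import torch.nn.functional as F

from efficientvit.apps.trainer import Trainer


class ToyTrainer(Trainer):
    def run_step(self, feed_dict):
        with torch.autocast("cuda", enabled=self.fp16):
            output = self.model(feed_dict["data"])
            loss = F.cross_entropy(output, feed_dict["label"])
        self.optimizer.zero_grad()
        self.scaler.scale(loss).backward()  # unscale/clip/step happen in after_step()
        return {"loss": loss.item()}

    def _train_one_epoch(self, epoch):
        info = {}
        for feed_dict in self.data_provider.train:
            feed_dict = self.before_step(feed_dict)  # move tensors to the GPU
            info = self.run_step(feed_dict)
            self.after_step()                        # clip, optimizer/lr step, EMA update
        return info
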
diff --git a/yolo-world-with-efficientvit-sam/efficientvit/apps/trainer/run_config.py b/yolo-world-with-efficientvit-sam/efficientvit/apps/trainer/run_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..da381d9cc9edd120eaa58680c0abd9aa45de8718
--- /dev/null
+++ b/yolo-world-with-efficientvit-sam/efficientvit/apps/trainer/run_config.py
@@ -0,0 +1,121 @@
+# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
+# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
+# International Conference on Computer Vision (ICCV), 2023
+
+import json
+
+import numpy as np
+import torch.nn as nn
+
+from efficientvit.apps.utils import CosineLRwithWarmup, build_optimizer
+
+__all__ = ["Scheduler", "RunConfig"]
+
+
+class Scheduler:
+ PROGRESS = 0
+
+
+class RunConfig:
+ n_epochs: int
+ init_lr: float
+ warmup_epochs: int
+ warmup_lr: float
+ lr_schedule_name: str
+ lr_schedule_param: dict
+ optimizer_name: str
+ optimizer_params: dict
+ weight_decay: float
+ no_wd_keys: list
+ grad_clip: float # allow none to turn off grad clipping
+ reset_bn: bool
+ reset_bn_size: int
+ reset_bn_batch_size: int
+ eval_image_size: list # allow none to use image_size in data_provider
+
+ @property
+ def none_allowed(self):
+ return ["grad_clip", "eval_image_size"]
+
+ def __init__(self, **kwargs): # arguments must be passed as kwargs
+ for k, val in kwargs.items():
+ setattr(self, k, val)
+
+ # check that all relevant configs are there
+ annotations = {}
+ for clas in type(self).mro():
+ if hasattr(clas, "__annotations__"):
+ annotations.update(clas.__annotations__)
+ for k, k_type in annotations.items():
+ assert hasattr(
+ self, k
+ ), f"Key {k} with type {k_type} required for initialization."
+ attr = getattr(self, k)
+ if k in self.none_allowed:
+ k_type = (k_type, type(None))
+ assert isinstance(
+ attr, k_type
+ ), f"Key {k} must be type {k_type}, provided={attr}."
+
+ self.global_step = 0
+ self.batch_per_epoch = 1
+
+ def build_optimizer(self, network: nn.Module) -> tuple[any, any]:
+ r"""require setting 'batch_per_epoch' before building optimizer & lr_scheduler"""
+ param_dict = {}
+ for name, param in network.named_parameters():
+ if param.requires_grad:
+ opt_config = [self.weight_decay, self.init_lr]
+ if self.no_wd_keys is not None and len(self.no_wd_keys) > 0:
+ if np.any([key in name for key in self.no_wd_keys]):
+ opt_config[0] = 0
+ opt_key = json.dumps(opt_config)
+ param_dict[opt_key] = param_dict.get(opt_key, []) + [param]
+
+ net_params = []
+ for opt_key, param_list in param_dict.items():
+ wd, lr = json.loads(opt_key)
+ net_params.append({"params": param_list, "weight_decay": wd, "lr": lr})
+
+ optimizer = build_optimizer(
+ net_params, self.optimizer_name, self.optimizer_params, self.init_lr
+ )
+ # build lr scheduler
+ if self.lr_schedule_name == "cosine":
+ decay_steps = []
+ for epoch in self.lr_schedule_param.get("step", []):
+ decay_steps.append(epoch * self.batch_per_epoch)
+ decay_steps.append(self.n_epochs * self.batch_per_epoch)
+ decay_steps.sort()
+ lr_scheduler = CosineLRwithWarmup(
+ optimizer,
+ self.warmup_epochs * self.batch_per_epoch,
+ self.warmup_lr,
+ decay_steps,
+ )
+ else:
+ raise NotImplementedError
+ return optimizer, lr_scheduler
+
+ def update_global_step(self, epoch, batch_id=0) -> None:
+ self.global_step = epoch * self.batch_per_epoch + batch_id
+ Scheduler.PROGRESS = self.progress
+
+ @property
+ def progress(self) -> float:
+ warmup_steps = self.warmup_epochs * self.batch_per_epoch
+ steps = max(0, self.global_step - warmup_steps)
+ return steps / (self.n_epochs * self.batch_per_epoch)
+
+ def step(self) -> None:
+ self.global_step += 1
+ Scheduler.PROGRESS = self.progress
+
+ def get_remaining_epoch(self, epoch, post=True) -> int:
+ return self.n_epochs + self.warmup_epochs - epoch - int(post)
+
+ def epoch_format(self, epoch: int) -> str:
+ epoch_format = f"%.{len(str(self.n_epochs))}d"
+ epoch_format = f"[{epoch_format}/{epoch_format}]"
+ epoch_format = epoch_format % (epoch + 1 - self.warmup_epochs, self.n_epochs)
+ return epoch_format
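
`RunConfig` checks that every annotated field is supplied as a keyword (only `grad_clip` and `eval_image_size` may be None) and expects `batch_per_epoch` to be set before `build_optimizer`. A hypothetical instantiation:

from efficientvit.apps.trainer.run_config import RunConfig

run_config = RunConfig(
    n_epochs=300, init_lr=1e-3, warmup_epochs=5, warmup_lr=1e-5,
    lr_schedule_name="cosine", lr_schedule_param={},
    optimizer_name="adamw", optimizer_params={"betas": (0.9, 0.999)},
    weight_decay=0.05, no_wd_keys=["bias", "norm"],
    grad_clip=None, reset_bn=False, reset_bn_size=16000, reset_bn_batch_size=100,
    eval_image_size=None,
)
run_config.batch_per_epoch = 1000  # must be set before build_optimizer
# optimizer, lr_scheduler = run_config.build_optimizer(net)  # net: an nn.Module

`build_optimizer` then groups parameters by their (weight_decay, lr) pair, zeroing weight decay for any parameter whose name contains one of `no_wd_keys`.
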
diff --git a/yolo-world-with-efficientvit-sam/efficientvit/apps/utils/__init__.py b/yolo-world-with-efficientvit-sam/efficientvit/apps/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c826a22544285746c588741f3f20fbe3802ccd50
--- /dev/null
+++ b/yolo-world-with-efficientvit-sam/efficientvit/apps/utils/__init__.py
@@ -0,0 +1,12 @@
+# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
+# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
+# International Conference on Computer Vision (ICCV), 2023
+
+from .dist import *
+from .ema import *
+from .export import *
+from .init import *
+from .lr import *
+from .metric import *
+from .misc import *
+from .opt import *
diff --git a/yolo-world-with-efficientvit-sam/efficientvit/apps/utils/__pycache__/__init__.cpython-311.pyc b/yolo-world-with-efficientvit-sam/efficientvit/apps/utils/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..34751a0fde007867c4b663071746aa0df1824789
Binary files /dev/null and b/yolo-world-with-efficientvit-sam/efficientvit/apps/utils/__pycache__/__init__.cpython-311.pyc differ
diff --git a/yolo-world-with-efficientvit-sam/efficientvit/apps/utils/__pycache__/dist.cpython-311.pyc b/yolo-world-with-efficientvit-sam/efficientvit/apps/utils/__pycache__/dist.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d74b80e3cb73eadc3a08d383cc1aa29916f2e5cf
Binary files /dev/null and b/yolo-world-with-efficientvit-sam/efficientvit/apps/utils/__pycache__/dist.cpython-311.pyc differ
diff --git a/yolo-world-with-efficientvit-sam/efficientvit/apps/utils/__pycache__/ema.cpython-311.pyc b/yolo-world-with-efficientvit-sam/efficientvit/apps/utils/__pycache__/ema.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c17971e7c6639a1427a61d58f4b6d5ff60329366
Binary files /dev/null and b/yolo-world-with-efficientvit-sam/efficientvit/apps/utils/__pycache__/ema.cpython-311.pyc differ
diff --git a/yolo-world-with-efficientvit-sam/efficientvit/apps/utils/__pycache__/export.cpython-311.pyc b/yolo-world-with-efficientvit-sam/efficientvit/apps/utils/__pycache__/export.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d3c7378dc7b7935264ed5c13282eb62ae698e6c7
Binary files /dev/null and b/yolo-world-with-efficientvit-sam/efficientvit/apps/utils/__pycache__/export.cpython-311.pyc differ
diff --git a/yolo-world-with-efficientvit-sam/efficientvit/apps/utils/__pycache__/init.cpython-311.pyc b/yolo-world-with-efficientvit-sam/efficientvit/apps/utils/__pycache__/init.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b11dd2a00e699f7697f1c814ebe64fc759cf1abd
Binary files /dev/null and b/yolo-world-with-efficientvit-sam/efficientvit/apps/utils/__pycache__/init.cpython-311.pyc differ
diff --git a/yolo-world-with-efficientvit-sam/efficientvit/apps/utils/__pycache__/lr.cpython-311.pyc b/yolo-world-with-efficientvit-sam/efficientvit/apps/utils/__pycache__/lr.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0371462513b48b5e6d8e7bd4204e0ebb7356c946
Binary files /dev/null and b/yolo-world-with-efficientvit-sam/efficientvit/apps/utils/__pycache__/lr.cpython-311.pyc differ
diff --git a/yolo-world-with-efficientvit-sam/efficientvit/apps/utils/__pycache__/metric.cpython-311.pyc b/yolo-world-with-efficientvit-sam/efficientvit/apps/utils/__pycache__/metric.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..df347956f9e4a4dd0f68b9433c50520f1a0c8903
Binary files /dev/null and b/yolo-world-with-efficientvit-sam/efficientvit/apps/utils/__pycache__/metric.cpython-311.pyc differ
diff --git a/yolo-world-with-efficientvit-sam/efficientvit/apps/utils/__pycache__/misc.cpython-311.pyc b/yolo-world-with-efficientvit-sam/efficientvit/apps/utils/__pycache__/misc.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bd23c50365587432d40a998fdd225381d3e2a2c6
Binary files /dev/null and b/yolo-world-with-efficientvit-sam/efficientvit/apps/utils/__pycache__/misc.cpython-311.pyc differ
diff --git a/yolo-world-with-efficientvit-sam/efficientvit/apps/utils/__pycache__/opt.cpython-311.pyc b/yolo-world-with-efficientvit-sam/efficientvit/apps/utils/__pycache__/opt.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c0382e8cce3154d30f27a8bf7be7bc1920f0e3d9
Binary files /dev/null and b/yolo-world-with-efficientvit-sam/efficientvit/apps/utils/__pycache__/opt.cpython-311.pyc differ
diff --git a/yolo-world-with-efficientvit-sam/efficientvit/apps/utils/dist.py b/yolo-world-with-efficientvit-sam/efficientvit/apps/utils/dist.py
new file mode 100644
index 0000000000000000000000000000000000000000..796feb88bd8721db90ecbc35f63aeb21991d8df0
--- /dev/null
+++ b/yolo-world-with-efficientvit-sam/efficientvit/apps/utils/dist.py
@@ -0,0 +1,73 @@
+# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
+# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
+# International Conference on Computer Vision (ICCV), 2023
+
+import os
+
+import torch
+import torch.distributed
+
+from efficientvit.models.utils.list import list_mean, list_sum
+
+__all__ = [
+ "dist_init",
+ "get_dist_rank",
+ "get_dist_size",
+ "is_master",
+ "dist_barrier",
+ "get_dist_local_rank",
+ "sync_tensor",
+]
+
+
+def dist_init() -> None:
+ try:
+ torch.distributed.init_process_group(backend="nccl")
+ assert torch.distributed.is_initialized()
+ except Exception:
+ # use torchpack
+ from torchpack import distributed as dist
+
+ dist.init()
+ os.environ["RANK"] = f"{dist.rank()}"
+ os.environ["WORLD_SIZE"] = f"{dist.size()}"
+ os.environ["LOCAL_RANK"] = f"{dist.local_rank()}"
+
+
+def get_dist_rank() -> int:
+ return int(os.environ["RANK"])
+
+
+def get_dist_size() -> int:
+ return int(os.environ["WORLD_SIZE"])
+
+
+def is_master() -> bool:
+ return get_dist_rank() == 0
+
+
+def dist_barrier() -> None:
+ torch.distributed.barrier()
+
+
+def get_dist_local_rank() -> int:
+ return int(os.environ["LOCAL_RANK"])
+
+
+def sync_tensor(
+ tensor: torch.Tensor or float, reduce="mean"
+) -> torch.Tensor or list[torch.Tensor]:
+ if not isinstance(tensor, torch.Tensor):
+ tensor = torch.Tensor(1).fill_(tensor).cuda()
+ tensor_list = [torch.empty_like(tensor) for _ in range(get_dist_size())]
+ torch.distributed.all_gather(tensor_list, tensor.contiguous(), async_op=False)
+ if reduce == "mean":
+ return list_mean(tensor_list)
+ elif reduce == "sum":
+ return list_sum(tensor_list)
+ elif reduce == "cat":
+ return torch.cat(tensor_list, dim=0)
+ elif reduce == "root":
+ return tensor_list[0]
+ else:
+ return tensor_list
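
A sketch of `sync_tensor` averaging a per-rank scalar; it assumes the process group is up (e.g. launched via `torchrun --nproc_per_node=2`) and one CUDA device per rank:

import torch

from efficientvit.apps.utils.dist import (dist_init, get_dist_local_rank,
                                           get_dist_rank, sync_tensor)

dist_init()
torch.cuda.set_device(get_dist_local_rank())
local_metric = float(get_dist_rank())           # stand-in for a per-rank loss/accuracy
avg = sync_tensor(local_metric, reduce="mean")  # all_gather, then mean over ranks
print(get_dist_rank(), avg.item())
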
diff --git a/yolo-world-with-efficientvit-sam/efficientvit/apps/utils/ema.py b/yolo-world-with-efficientvit-sam/efficientvit/apps/utils/ema.py
new file mode 100644
index 0000000000000000000000000000000000000000..95471ed5c27d7e895c880ab1d980685fc0676413
--- /dev/null
+++ b/yolo-world-with-efficientvit-sam/efficientvit/apps/utils/ema.py
@@ -0,0 +1,50 @@
+# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
+# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
+# International Conference on Computer Vision (ICCV), 2023
+
+import copy
+import math
+
+import torch
+import torch.nn as nn
+
+from efficientvit.models.utils import is_parallel
+
+__all__ = ["EMA"]
+
+
+def update_ema(
+ ema: nn.Module, new_state_dict: dict[str, torch.Tensor], decay: float
+) -> None:
+ for k, v in ema.state_dict().items():
+ if v.dtype.is_floating_point:
+ v -= (1.0 - decay) * (v - new_state_dict[k].detach())
+
+
+class EMA:
+ def __init__(self, model: nn.Module, decay: float, warmup_steps=2000):
+ self.shadows = copy.deepcopy(
+ model.module if is_parallel(model) else model
+ ).eval()
+ self.decay = decay
+ self.warmup_steps = warmup_steps
+
+ for p in self.shadows.parameters():
+ p.requires_grad = False
+
+ def step(self, model: nn.Module, global_step: int) -> None:
+ with torch.no_grad():
+ msd = (model.module if is_parallel(model) else model).state_dict()
+ update_ema(
+ self.shadows,
+ msd,
+ self.decay * (1 - math.exp(-global_step / self.warmup_steps)),
+ )
+
+ def state_dict(self) -> dict[float, dict[str, torch.Tensor]]:
+ return {self.decay: self.shadows.state_dict()}
+
+ def load_state_dict(self, state_dict: dict[float, dict[str, torch.Tensor]]) -> None:
+ for decay in state_dict:
+ if decay == self.decay:
+ self.shadows.load_state_dict(state_dict[decay])
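
The effective EMA decay is `decay * (1 - exp(-global_step / warmup_steps))`, so early updates track the raw weights closely and the decay approaches its configured value after a few warmup windows. A quick standalone check of that schedule, with illustrative values:

import math

decay, warmup_steps = 0.9998, 2000
for step in (1, 100, 2000, 10000):
    effective = decay * (1 - math.exp(-step / warmup_steps))
    print(step, f"{effective:.6f}")  # roughly 0.0005, 0.0488, 0.6320, 0.9931
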
diff --git a/yolo-world-with-efficientvit-sam/efficientvit/apps/utils/export.py b/yolo-world-with-efficientvit-sam/efficientvit/apps/utils/export.py
new file mode 100644
index 0000000000000000000000000000000000000000..d611f957a6ff22b98210d611e7344426e091d3df
--- /dev/null
+++ b/yolo-world-with-efficientvit-sam/efficientvit/apps/utils/export.py
@@ -0,0 +1,47 @@
+# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
+# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
+# International Conference on Computer Vision (ICCV), 2023
+
+import io
+import os
+
+import onnx
+import torch
+import torch.nn as nn
+from onnxsim import simplify as simplify_func
+
+__all__ = ["export_onnx"]
+
+
+def export_onnx(
+ model: nn.Module, export_path: str, sample_inputs: any, simplify=True, opset=11
+) -> None:
+ """Export a model to a platform-specific onnx format.
+
+ Args:
+ model: a torch.nn.Module object.
+ export_path: export location.
+ sample_inputs: Any.
+ simplify: a flag to turn on onnx-simplifier
+ opset: int
+ """
+ model.eval()
+
+ buffer = io.BytesIO()
+ with torch.no_grad():
+ torch.onnx.export(model, sample_inputs, buffer, opset_version=opset)
+ buffer.seek(0, 0)
+ if simplify:
+ onnx_model = onnx.load_model(buffer)
+ onnx_model, success = simplify_func(onnx_model)
+ assert success
+ new_buffer = io.BytesIO()
+ onnx.save(onnx_model, new_buffer)
+ buffer = new_buffer
+ buffer.seek(0, 0)
+
+ if buffer.getbuffer().nbytes > 0:
+ save_dir = os.path.dirname(export_path)
+ os.makedirs(save_dir, exist_ok=True)
+ with open(export_path, "wb") as f:
+ f.write(buffer.read())
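
A minimal sketch of exporting a toy module, assuming `onnx` and `onnxsim` are installed; the output path and module are arbitrary:

import torch
import torch.nn as nn

from efficientvit.apps.utils.export import export_onnx

toy = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU())
export_onnx(toy, "exports/toy.onnx", torch.randn(1, 3, 64, 64), simplify=True, opset=11)
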
diff --git a/yolo-world-with-efficientvit-sam/efficientvit/apps/utils/init.py b/yolo-world-with-efficientvit-sam/efficientvit/apps/utils/init.py
new file mode 100644
index 0000000000000000000000000000000000000000..2d2ebe26ff45a7ee1de614a39e0db24198097152
--- /dev/null
+++ b/yolo-world-with-efficientvit-sam/efficientvit/apps/utils/init.py
@@ -0,0 +1,68 @@
+# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
+# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
+# International Conference on Computer Vision (ICCV), 2023
+
+import torch
+import torch.nn as nn
+from torch.nn.modules.batchnorm import _BatchNorm
+
+__all__ = ["init_modules", "zero_last_gamma"]
+
+
+def init_modules(model: nn.Module or list[nn.Module], init_type="trunc_normal") -> None:
+ _DEFAULT_INIT_PARAM = {"trunc_normal": 0.02}
+
+ if isinstance(model, list):
+ for sub_module in model:
+ init_modules(sub_module, init_type)
+ else:
+ init_params = init_type.split("@")
+ init_params = float(init_params[1]) if len(init_params) > 1 else None
+
+ if init_type.startswith("trunc_normal"):
+ init_func = lambda param: nn.init.trunc_normal_(
+ param, std=(init_params or _DEFAULT_INIT_PARAM["trunc_normal"])
+ )
+ else:
+ raise NotImplementedError
+
+ for m in model.modules():
+ if isinstance(m, (nn.Conv2d, nn.Linear, nn.ConvTranspose2d)):
+ init_func(m.weight)
+ if m.bias is not None:
+ m.bias.data.zero_()
+ elif isinstance(m, nn.Embedding):
+ init_func(m.weight)
+ elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)):
+ m.weight.data.fill_(1)
+ m.bias.data.zero_()
+ else:
+ weight = getattr(m, "weight", None)
+ bias = getattr(m, "bias", None)
+ if isinstance(weight, torch.nn.Parameter):
+ init_func(weight)
+ if isinstance(bias, torch.nn.Parameter):
+ bias.data.zero_()
+
+
+def zero_last_gamma(model: nn.Module, init_val=0) -> None:
+ import efficientvit.models.nn.ops as ops
+
+ for m in model.modules():
+ if isinstance(m, ops.ResidualBlock) and isinstance(
+ m.shortcut, ops.IdentityLayer
+ ):
+ if isinstance(m.main, (ops.DSConv, ops.MBConv, ops.FusedMBConv)):
+ parent_module = m.main.point_conv
+ elif isinstance(m.main, ops.ResBlock):
+ parent_module = m.main.conv2
+ elif isinstance(m.main, ops.ConvLayer):
+ parent_module = m.main
+ elif isinstance(m.main, (ops.LiteMLA)):
+ parent_module = m.main.proj
+ else:
+ parent_module = None
+ if parent_module is not None:
+ norm = getattr(parent_module, "norm", None)
+ if norm is not None:
+ nn.init.constant_(norm.weight, init_val)
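
`init_type` may carry the truncated-normal std after an `@`, e.g. `"trunc_normal@0.005"`; without it, the default std of 0.02 is used. A small sketch with toy layers (illustrative only):

import torch.nn as nn

from efficientvit.apps.utils.init import init_modules

toy = nn.ModuleList([nn.Conv2d(3, 16, 3), nn.BatchNorm2d(16), nn.Linear(64, 10)])
init_modules(toy, init_type="trunc_normal@0.005")
print(float(toy[0].weight.std()))  # roughly 0.005 (truncated normal), bias zeroed
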
diff --git a/yolo-world-with-efficientvit-sam/efficientvit/apps/utils/lr.py b/yolo-world-with-efficientvit-sam/efficientvit/apps/utils/lr.py
new file mode 100644
index 0000000000000000000000000000000000000000..d1fc360307b11d8a9cde74ab35672fcf7272dfb1
--- /dev/null
+++ b/yolo-world-with-efficientvit-sam/efficientvit/apps/utils/lr.py
@@ -0,0 +1,48 @@
+# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
+# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
+# International Conference on Computer Vision (ICCV), 2023
+
+import math
+
+import torch
+
+from efficientvit.models.utils.list import val2list
+
+__all__ = ["CosineLRwithWarmup"]
+
+
+class CosineLRwithWarmup(torch.optim.lr_scheduler._LRScheduler):
+ def __init__(
+ self,
+ optimizer: torch.optim.Optimizer,
+ warmup_steps: int,
+ warmup_lr: float,
+ decay_steps: int or list[int],
+ last_epoch: int = -1,
+ ) -> None:
+ self.warmup_steps = warmup_steps
+ self.warmup_lr = warmup_lr
+ self.decay_steps = val2list(decay_steps)
+ super().__init__(optimizer, last_epoch)
+
+ def get_lr(self) -> list[float]:
+ if self.last_epoch < self.warmup_steps:
+ return [
+ (base_lr - self.warmup_lr) * (self.last_epoch + 1) / self.warmup_steps
+ + self.warmup_lr
+ for base_lr in self.base_lrs
+ ]
+ else:
+ current_steps = self.last_epoch - self.warmup_steps
+ decay_steps = [0] + self.decay_steps
+ idx = len(decay_steps) - 2
+ for i, decay_step in enumerate(decay_steps[:-1]):
+ if decay_step <= current_steps < decay_steps[i + 1]:
+ idx = i
+ break
+ current_steps -= decay_steps[idx]
+ decay_step = decay_steps[idx + 1] - decay_steps[idx]
+ return [
+ 0.5 * base_lr * (1 + math.cos(math.pi * current_steps / decay_step))
+ for base_lr in self.base_lrs
+ ]
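
The scheduler ramps linearly from `warmup_lr` to each group's base lr over `warmup_steps`, then applies cosine decay inside each segment delimited by `decay_steps`. A quick sketch with a single 1000-step decay segment after a 100-step warmup; the toy parameter and values are illustrative:

import torch

from efficientvit.apps.utils.lr import CosineLRwithWarmup

param = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.SGD([param], lr=0.1)
sched = CosineLRwithWarmup(opt, warmup_steps=100, warmup_lr=1e-4, decay_steps=1000)

lrs = []
for _ in range(1100):
    opt.step()
    sched.step()
    lrs.append(sched.get_last_lr()[0])
print(lrs[0], lrs[99], lrs[549], lrs[-1])  # warmup ramp, peak ~0.1, mid-decay, ~0.0
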
diff --git a/yolo-world-with-efficientvit-sam/efficientvit/apps/utils/metric.py b/yolo-world-with-efficientvit-sam/efficientvit/apps/utils/metric.py
new file mode 100644
index 0000000000000000000000000000000000000000..f20fae35cb8c99fdaf1b1e3bcdda02e9dd9d39e8
--- /dev/null
+++ b/yolo-world-with-efficientvit-sam/efficientvit/apps/utils/metric.py
@@ -0,0 +1,37 @@
+# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
+# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
+# International Conference on Computer Vision (ICCV), 2023
+
+import torch
+
+from efficientvit.apps.utils.dist import sync_tensor
+
+__all__ = ["AverageMeter"]
+
+
+class AverageMeter:
+ """Computes and stores the average and current value."""
+
+ def __init__(self, is_distributed=True):
+ self.is_distributed = is_distributed
+ self.sum = 0
+ self.count = 0
+
+ def _sync(self, val: torch.Tensor or int or float) -> torch.Tensor or int or float:
+ return sync_tensor(val, reduce="sum") if self.is_distributed else val
+
+ def update(self, val: torch.Tensor or int or float, delta_n=1):
+ self.count += self._sync(delta_n)
+ self.sum += self._sync(val * delta_n)
+
+ def get_count(self) -> torch.Tensor or int or float:
+ return (
+ self.count.item()
+ if isinstance(self.count, torch.Tensor) and self.count.numel() == 1
+ else self.count
+ )
+
+ @property
+ def avg(self):
+ avg = -1 if self.count == 0 else self.sum / self.count
+ return avg.item() if isinstance(avg, torch.Tensor) and avg.numel() == 1 else avg
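
`update` accumulates a weighted sum and count (summed across ranks via `sync_tensor` when distributed), and `avg` returns the size-weighted mean. A single-process sketch with made-up batch losses:

from efficientvit.apps.utils.metric import AverageMeter

meter = AverageMeter(is_distributed=False)
for batch_loss, batch_size in [(0.9, 32), (0.7, 32), (0.5, 16)]:
    meter.update(batch_loss, delta_n=batch_size)
print(meter.get_count(), meter.avg)  # 80 samples, size-weighted mean loss (0.74)
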
diff --git a/yolo-world-with-efficientvit-sam/efficientvit/apps/utils/misc.py b/yolo-world-with-efficientvit-sam/efficientvit/apps/utils/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..111b6618ab20bd02b5b6d8785091122c82fc8a24
--- /dev/null
+++ b/yolo-world-with-efficientvit-sam/efficientvit/apps/utils/misc.py
@@ -0,0 +1,111 @@
+# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
+# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
+# International Conference on Computer Vision (ICCV), 2023
+
+import os
+
+import yaml
+
+__all__ = [
+ "parse_with_yaml",
+ "parse_unknown_args",
+ "partial_update_config",
+ "resolve_and_load_config",
+ "load_config",
+ "dump_config",
+]
+
+
+def parse_with_yaml(config_str: str) -> str or dict:
+ try:
+ # add space manually for dict
+ if "{" in config_str and "}" in config_str and ":" in config_str:
+ out_str = config_str.replace(":", ": ")
+ else:
+ out_str = config_str
+ return yaml.safe_load(out_str)
+ except (ValueError, yaml.YAMLError):
+ # return raw string if parsing fails
+ return config_str
+
+
+def parse_unknown_args(unknown: list) -> dict:
+ """Parse unknown args."""
+ index = 0
+ parsed_dict = {}
+ while index < len(unknown):
+ key, val = unknown[index], unknown[index + 1]
+ index += 2
+ if not key.startswith("--"):
+ continue
+ key = key[2:]
+
+ # try parsing with either dot notation or full yaml notation
+ # Note that the vanilla case "--key value" will be parsed the same
+ if "." in key:
+ # key == a.b.c, val == val --> parsed_dict[a][b][c] = val
+ keys = key.split(".")
+ dict_to_update = parsed_dict
+ for key in keys[:-1]:
+ if not (
+ key in dict_to_update and isinstance(dict_to_update[key], dict)
+ ):
+ dict_to_update[key] = {}
+ dict_to_update = dict_to_update[key]
+ dict_to_update[keys[-1]] = parse_with_yaml(
+ val
+ ) # so we can parse lists, bools, etc...
+ else:
+ parsed_dict[key] = parse_with_yaml(val)
+ return parsed_dict
+
+
+def partial_update_config(config: dict, partial_config: dict) -> dict:
+ for key in partial_config:
+ if (
+ key in config
+ and isinstance(partial_config[key], dict)
+ and isinstance(config[key], dict)
+ ):
+ partial_update_config(config[key], partial_config[key])
+ else:
+ config[key] = partial_config[key]
+ return config
+
+
+def resolve_and_load_config(path: str, config_name="config.yaml") -> dict:
+ path = os.path.realpath(os.path.expanduser(path))
+ if os.path.isdir(path):
+ config_path = os.path.join(path, config_name)
+ else:
+ config_path = path
+ if os.path.isfile(config_path):
+ pass
+ else:
+ raise Exception(f"Cannot find a valid config at {path}")
+ config = load_config(config_path)
+ return config
+
+
+class SafeLoaderWithTuple(yaml.SafeLoader):
+ """A yaml safe loader with python tuple loading capabilities."""
+
+ def construct_python_tuple(self, node):
+ return tuple(self.construct_sequence(node))
+
+
+SafeLoaderWithTuple.add_constructor(
+ "tag:yaml.org,2002:python/tuple", SafeLoaderWithTuple.construct_python_tuple
+)
+
+
+def load_config(filename: str) -> dict:
+ """Load a yaml file."""
+ filename = os.path.realpath(os.path.expanduser(filename))
+ return yaml.load(open(filename), Loader=SafeLoaderWithTuple)
+
+
+def dump_config(config: dict, filename: str) -> None:
+ """Dump a config file"""
+ filename = os.path.realpath(os.path.expanduser(filename))
+ yaml.dump(config, open(filename, "w"), sort_keys=False)
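
Extra CLI arguments in `--a.b.c value` form are folded into a nested dict, with values parsed through YAML so lists, numbers, and booleans come out typed; `partial_update_config` then merges them into an existing config without clobbering sibling keys. A short sketch with made-up keys:

from efficientvit.apps.utils.misc import parse_unknown_args, partial_update_config

unknown = ["--run_config.base_lr", "0.003", "--data_provider.image_size", "[224,288]"]
opt_args = parse_unknown_args(unknown)
print(opt_args)  # {'run_config': {'base_lr': 0.003}, 'data_provider': {'image_size': [224, 288]}}

config = {"run_config": {"base_lr": 0.001, "n_epochs": 300}}
partial_update_config(config, opt_args)
print(config["run_config"])  # base_lr overridden, n_epochs preserved
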
diff --git a/yolo-world-with-efficientvit-sam/efficientvit/apps/utils/opt.py b/yolo-world-with-efficientvit-sam/efficientvit/apps/utils/opt.py
new file mode 100644
index 0000000000000000000000000000000000000000..79a03507c8b0aa8ad6e7210657630d5af6555521
--- /dev/null
+++ b/yolo-world-with-efficientvit-sam/efficientvit/apps/utils/opt.py
@@ -0,0 +1,31 @@
+# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
+# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
+# International Conference on Computer Vision (ICCV), 2023
+
+import torch
+
+__all__ = ["REGISTERED_OPTIMIZER_DICT", "build_optimizer"]
+
+# register optimizer here
+# name: optimizer, kwargs with default values
+REGISTERED_OPTIMIZER_DICT: dict[str, tuple[type, dict[str, any]]] = {
+ "sgd": (torch.optim.SGD, {"momentum": 0.9, "nesterov": True}),
+ "adam": (torch.optim.Adam, {"betas": (0.9, 0.999), "eps": 1e-8, "amsgrad": False}),
+ "adamw": (
+ torch.optim.AdamW,
+ {"betas": (0.9, 0.999), "eps": 1e-8, "amsgrad": False},
+ ),
+}
+
+
+def build_optimizer(
+ net_params, optimizer_name: str, optimizer_params: dict or None, init_lr: float
+) -> torch.optim.Optimizer:
+ optimizer_class, default_params = REGISTERED_OPTIMIZER_DICT[optimizer_name]
+ optimizer_params = optimizer_params or {}
+
+ for key in default_params:
+ if key in optimizer_params:
+ default_params[key] = optimizer_params[key]
+ optimizer = optimizer_class(net_params, init_lr, **default_params)
+ return optimizer
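
A sketch of the optimizer registry used in isolation; in practice `RunConfig.build_optimizer` supplies per-group lr/weight-decay entries rather than raw `parameters()`:

import torch.nn as nn

from efficientvit.apps.utils.opt import build_optimizer

net = nn.Linear(16, 10)
optimizer = build_optimizer(net.parameters(), "adamw", {"betas": (0.9, 0.99)}, init_lr=1e-3)
print(optimizer)

Note that user-supplied `optimizer_params` are written back into the module-level `REGISTERED_OPTIMIZER_DICT` entry, so overrides persist for later calls in the same process.
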
diff --git a/yolo-world-with-efficientvit-sam/efficientvit/models/__init__.py b/yolo-world-with-efficientvit-sam/efficientvit/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/yolo-world-with-efficientvit-sam/efficientvit/models/__pycache__/__init__.cpython-311.pyc b/yolo-world-with-efficientvit-sam/efficientvit/models/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e87a9358b5b57d233e4a025e2269ac90a1bbaff3
Binary files /dev/null and b/yolo-world-with-efficientvit-sam/efficientvit/models/__pycache__/__init__.cpython-311.pyc differ
diff --git a/yolo-world-with-efficientvit-sam/efficientvit/models/efficientvit/__init__.py b/yolo-world-with-efficientvit-sam/efficientvit/models/efficientvit/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..cea677f24763b605249c05ea37483b579c507cbc
--- /dev/null
+++ b/yolo-world-with-efficientvit-sam/efficientvit/models/efficientvit/__init__.py
@@ -0,0 +1,8 @@
+# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
+# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
+# International Conference on Computer Vision (ICCV), 2023
+
+from .backbone import *
+from .cls import *
+from .sam import *
+from .seg import *
diff --git a/yolo-world-with-efficientvit-sam/efficientvit/models/efficientvit/__pycache__/__init__.cpython-311.pyc b/yolo-world-with-efficientvit-sam/efficientvit/models/efficientvit/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..794ceed6efd5979648c852fd05db86db7093f8c2
Binary files /dev/null and b/yolo-world-with-efficientvit-sam/efficientvit/models/efficientvit/__pycache__/__init__.cpython-311.pyc differ
diff --git a/yolo-world-with-efficientvit-sam/efficientvit/models/efficientvit/__pycache__/backbone.cpython-311.pyc b/yolo-world-with-efficientvit-sam/efficientvit/models/efficientvit/__pycache__/backbone.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ea23c550f9e793b3eb1c20683aa1048bc605d3c0
Binary files /dev/null and b/yolo-world-with-efficientvit-sam/efficientvit/models/efficientvit/__pycache__/backbone.cpython-311.pyc differ
diff --git a/yolo-world-with-efficientvit-sam/efficientvit/models/efficientvit/__pycache__/cls.cpython-311.pyc b/yolo-world-with-efficientvit-sam/efficientvit/models/efficientvit/__pycache__/cls.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d12a7a4009a70b1345a09a1be8200a8bed01a907
Binary files /dev/null and b/yolo-world-with-efficientvit-sam/efficientvit/models/efficientvit/__pycache__/cls.cpython-311.pyc differ
diff --git a/yolo-world-with-efficientvit-sam/efficientvit/models/efficientvit/__pycache__/sam.cpython-311.pyc b/yolo-world-with-efficientvit-sam/efficientvit/models/efficientvit/__pycache__/sam.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1e75061479e34f5b9a957dc66038893b394dd3ef
Binary files /dev/null and b/yolo-world-with-efficientvit-sam/efficientvit/models/efficientvit/__pycache__/sam.cpython-311.pyc differ
diff --git a/yolo-world-with-efficientvit-sam/efficientvit/models/efficientvit/__pycache__/seg.cpython-311.pyc b/yolo-world-with-efficientvit-sam/efficientvit/models/efficientvit/__pycache__/seg.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a203d18d9644b8f489f193d4db0db94b11f52a9d
Binary files /dev/null and b/yolo-world-with-efficientvit-sam/efficientvit/models/efficientvit/__pycache__/seg.cpython-311.pyc differ
diff --git a/yolo-world-with-efficientvit-sam/efficientvit/models/efficientvit/backbone.py b/yolo-world-with-efficientvit-sam/efficientvit/models/efficientvit/backbone.py
new file mode 100644
index 0000000000000000000000000000000000000000..40a8052ca53b1a8378c934a5140953e5fe2cf1ca
--- /dev/null
+++ b/yolo-world-with-efficientvit-sam/efficientvit/models/efficientvit/backbone.py
@@ -0,0 +1,372 @@
+# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
+# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
+# International Conference on Computer Vision (ICCV), 2023
+
+import torch
+import torch.nn as nn
+
+from efficientvit.models.nn import (ConvLayer, DSConv, EfficientViTBlock,
+ FusedMBConv, IdentityLayer, MBConv,
+ OpSequential, ResBlock, ResidualBlock)
+from efficientvit.models.utils import build_kwargs_from_config
+
+__all__ = [
+ "EfficientViTBackbone",
+ "efficientvit_backbone_b0",
+ "efficientvit_backbone_b1",
+ "efficientvit_backbone_b2",
+ "efficientvit_backbone_b3",
+ "EfficientViTLargeBackbone",
+ "efficientvit_backbone_l0",
+ "efficientvit_backbone_l1",
+ "efficientvit_backbone_l2",
+ "efficientvit_backbone_l3",
+]
+
+
+class EfficientViTBackbone(nn.Module):
+ def __init__(
+ self,
+ width_list: list[int],
+ depth_list: list[int],
+ in_channels=3,
+ dim=32,
+ expand_ratio=4,
+ norm="bn2d",
+ act_func="hswish",
+ ) -> None:
+ super().__init__()
+
+ self.width_list = []
+ # input stem
+ self.input_stem = [
+ ConvLayer(
+ in_channels=3,
+ out_channels=width_list[0],
+ stride=2,
+ norm=norm,
+ act_func=act_func,
+ )
+ ]
+ for _ in range(depth_list[0]):
+ block = self.build_local_block(
+ in_channels=width_list[0],
+ out_channels=width_list[0],
+ stride=1,
+ expand_ratio=1,
+ norm=norm,
+ act_func=act_func,
+ )
+ self.input_stem.append(ResidualBlock(block, IdentityLayer()))
+ in_channels = width_list[0]
+ self.input_stem = OpSequential(self.input_stem)
+ self.width_list.append(in_channels)
+
+ # stages
+ self.stages = []
+ for w, d in zip(width_list[1:3], depth_list[1:3]):
+ stage = []
+ for i in range(d):
+ stride = 2 if i == 0 else 1
+ block = self.build_local_block(
+ in_channels=in_channels,
+ out_channels=w,
+ stride=stride,
+ expand_ratio=expand_ratio,
+ norm=norm,
+ act_func=act_func,
+ )
+ block = ResidualBlock(block, IdentityLayer() if stride == 1 else None)
+ stage.append(block)
+ in_channels = w
+ self.stages.append(OpSequential(stage))
+ self.width_list.append(in_channels)
+
+ for w, d in zip(width_list[3:], depth_list[3:]):
+ stage = []
+ block = self.build_local_block(
+ in_channels=in_channels,
+ out_channels=w,
+ stride=2,
+ expand_ratio=expand_ratio,
+ norm=norm,
+ act_func=act_func,
+ fewer_norm=True,
+ )
+ stage.append(ResidualBlock(block, None))
+ in_channels = w
+
+ for _ in range(d):
+ stage.append(
+ EfficientViTBlock(
+ in_channels=in_channels,
+ dim=dim,
+ expand_ratio=expand_ratio,
+ norm=norm,
+ act_func=act_func,
+ )
+ )
+ self.stages.append(OpSequential(stage))
+ self.width_list.append(in_channels)
+ self.stages = nn.ModuleList(self.stages)
+
+ @staticmethod
+ def build_local_block(
+ in_channels: int,
+ out_channels: int,
+ stride: int,
+ expand_ratio: float,
+ norm: str,
+ act_func: str,
+ fewer_norm: bool = False,
+ ) -> nn.Module:
+ if expand_ratio == 1:
+ block = DSConv(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ stride=stride,
+ use_bias=(True, False) if fewer_norm else False,
+ norm=(None, norm) if fewer_norm else norm,
+ act_func=(act_func, None),
+ )
+ else:
+ block = MBConv(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ stride=stride,
+ expand_ratio=expand_ratio,
+ use_bias=(True, True, False) if fewer_norm else False,
+ norm=(None, None, norm) if fewer_norm else norm,
+ act_func=(act_func, act_func, None),
+ )
+ return block
+
+ def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
+ output_dict = {"input": x}
+ output_dict["stage0"] = x = self.input_stem(x)
+ for stage_id, stage in enumerate(self.stages, 1):
+ output_dict["stage%d" % stage_id] = x = stage(x)
+ output_dict["stage_final"] = x
+ return output_dict
+
+
+def efficientvit_backbone_b0(**kwargs) -> EfficientViTBackbone:
+ backbone = EfficientViTBackbone(
+ width_list=[8, 16, 32, 64, 128],
+ depth_list=[1, 2, 2, 2, 2],
+ dim=16,
+ **build_kwargs_from_config(kwargs, EfficientViTBackbone),
+ )
+ return backbone
+
+
+def efficientvit_backbone_b1(**kwargs) -> EfficientViTBackbone:
+ backbone = EfficientViTBackbone(
+ width_list=[16, 32, 64, 128, 256],
+ depth_list=[1, 2, 3, 3, 4],
+ dim=16,
+ **build_kwargs_from_config(kwargs, EfficientViTBackbone),
+ )
+ return backbone
+
+
+def efficientvit_backbone_b2(**kwargs) -> EfficientViTBackbone:
+ backbone = EfficientViTBackbone(
+ width_list=[24, 48, 96, 192, 384],
+ depth_list=[1, 3, 4, 4, 6],
+ dim=32,
+ **build_kwargs_from_config(kwargs, EfficientViTBackbone),
+ )
+ return backbone
+
+
+def efficientvit_backbone_b3(**kwargs) -> EfficientViTBackbone:
+ backbone = EfficientViTBackbone(
+ width_list=[32, 64, 128, 256, 512],
+ depth_list=[1, 4, 6, 6, 9],
+ dim=32,
+ **build_kwargs_from_config(kwargs, EfficientViTBackbone),
+ )
+ return backbone
+
+
+class EfficientViTLargeBackbone(nn.Module):
+ def __init__(
+ self,
+ width_list: list[int],
+ depth_list: list[int],
+ block_list: list[str] or None = None,
+ expand_list: list[float] or None = None,
+ fewer_norm_list: list[bool] or None = None,
+ in_channels=3,
+ qkv_dim=32,
+ norm="bn2d",
+ act_func="gelu",
+ ) -> None:
+ super().__init__()
+ block_list = block_list or ["res", "fmb", "fmb", "mb", "att"]
+ expand_list = expand_list or [1, 4, 4, 4, 6]
+ fewer_norm_list = fewer_norm_list or [False, False, False, True, True]
+
+ self.width_list = []
+ self.stages = []
+ # stage 0
+ stage0 = [
+ ConvLayer(
+ in_channels=3,
+ out_channels=width_list[0],
+ stride=2,
+ norm=norm,
+ act_func=act_func,
+ )
+ ]
+ for _ in range(depth_list[0]):
+ block = self.build_local_block(
+ block=block_list[0],
+ in_channels=width_list[0],
+ out_channels=width_list[0],
+ stride=1,
+ expand_ratio=expand_list[0],
+ norm=norm,
+ act_func=act_func,
+ fewer_norm=fewer_norm_list[0],
+ )
+ stage0.append(ResidualBlock(block, IdentityLayer()))
+ in_channels = width_list[0]
+ self.stages.append(OpSequential(stage0))
+ self.width_list.append(in_channels)
+
+ for stage_id, (w, d) in enumerate(zip(width_list[1:], depth_list[1:]), start=1):
+ stage = []
+ block = self.build_local_block(
+ block=(
+ "mb"
+ if block_list[stage_id] not in ["mb", "fmb"]
+ else block_list[stage_id]
+ ),
+ in_channels=in_channels,
+ out_channels=w,
+ stride=2,
+ expand_ratio=expand_list[stage_id] * 4,
+ norm=norm,
+ act_func=act_func,
+ fewer_norm=fewer_norm_list[stage_id],
+ )
+ stage.append(ResidualBlock(block, None))
+ in_channels = w
+
+ for _ in range(d):
+ if block_list[stage_id].startswith("att"):
+ stage.append(
+ EfficientViTBlock(
+ in_channels=in_channels,
+ dim=qkv_dim,
+ expand_ratio=expand_list[stage_id],
+ scales=(3,) if block_list[stage_id] == "att@3" else (5,),
+ norm=norm,
+ act_func=act_func,
+ )
+ )
+ else:
+ block = self.build_local_block(
+ block=block_list[stage_id],
+ in_channels=in_channels,
+ out_channels=in_channels,
+ stride=1,
+ expand_ratio=expand_list[stage_id],
+ norm=norm,
+ act_func=act_func,
+ fewer_norm=fewer_norm_list[stage_id],
+ )
+ block = ResidualBlock(block, IdentityLayer())
+ stage.append(block)
+ self.stages.append(OpSequential(stage))
+ self.width_list.append(in_channels)
+ self.stages = nn.ModuleList(self.stages)
+
+ @staticmethod
+ def build_local_block(
+ block: str,
+ in_channels: int,
+ out_channels: int,
+ stride: int,
+ expand_ratio: float,
+ norm: str,
+ act_func: str,
+ fewer_norm: bool = False,
+ ) -> nn.Module:
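+        # With fewer_norm, only the final conv of the block keeps a norm layer; the
+        # preceding convs use a bias instead.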
+ if block == "res":
+ block = ResBlock(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ stride=stride,
+ use_bias=(True, False) if fewer_norm else False,
+ norm=(None, norm) if fewer_norm else norm,
+ act_func=(act_func, None),
+ )
+ elif block == "fmb":
+ block = FusedMBConv(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ stride=stride,
+ expand_ratio=expand_ratio,
+ use_bias=(True, False) if fewer_norm else False,
+ norm=(None, norm) if fewer_norm else norm,
+ act_func=(act_func, None),
+ )
+ elif block == "mb":
+ block = MBConv(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ stride=stride,
+ expand_ratio=expand_ratio,
+ use_bias=(True, True, False) if fewer_norm else False,
+ norm=(None, None, norm) if fewer_norm else norm,
+ act_func=(act_func, act_func, None),
+ )
+ else:
+ raise ValueError(block)
+ return block
+
+ def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
+ output_dict = {"input": x}
+ for stage_id, stage in enumerate(self.stages):
+ output_dict["stage%d" % stage_id] = x = stage(x)
+ output_dict["stage_final"] = x
+ return output_dict
+
+
+def efficientvit_backbone_l0(**kwargs) -> EfficientViTLargeBackbone:
+ backbone = EfficientViTLargeBackbone(
+ width_list=[32, 64, 128, 256, 512],
+ depth_list=[1, 1, 1, 4, 4],
+ **build_kwargs_from_config(kwargs, EfficientViTLargeBackbone),
+ )
+ return backbone
+
+
+def efficientvit_backbone_l1(**kwargs) -> EfficientViTLargeBackbone:
+ backbone = EfficientViTLargeBackbone(
+ width_list=[32, 64, 128, 256, 512],
+ depth_list=[1, 1, 1, 6, 6],
+ **build_kwargs_from_config(kwargs, EfficientViTLargeBackbone),
+ )
+ return backbone
+
+
+def efficientvit_backbone_l2(**kwargs) -> EfficientViTLargeBackbone:
+ backbone = EfficientViTLargeBackbone(
+ width_list=[32, 64, 128, 256, 512],
+ depth_list=[1, 2, 2, 8, 8],
+ **build_kwargs_from_config(kwargs, EfficientViTLargeBackbone),
+ )
+ return backbone
+
+
+def efficientvit_backbone_l3(**kwargs) -> EfficientViTLargeBackbone:
+ backbone = EfficientViTLargeBackbone(
+ width_list=[64, 128, 256, 512, 1024],
+ depth_list=[1, 2, 2, 8, 8],
+ **build_kwargs_from_config(kwargs, EfficientViTLargeBackbone),
+ )
+ return backbone
diff --git a/yolo-world-with-efficientvit-sam/efficientvit/models/efficientvit/cls.py b/yolo-world-with-efficientvit-sam/efficientvit/models/efficientvit/cls.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d73a10f10ce7d9ed39ce430e64a43dfc975a013
--- /dev/null
+++ b/yolo-world-with-efficientvit-sam/efficientvit/models/efficientvit/cls.py
@@ -0,0 +1,174 @@
+# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
+# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
+# International Conference on Computer Vision (ICCV), 2023
+
+import torch
+import torch.nn as nn
+
+from efficientvit.models.efficientvit.backbone import (
+ EfficientViTBackbone, EfficientViTLargeBackbone)
+from efficientvit.models.nn import ConvLayer, LinearLayer, OpSequential
+from efficientvit.models.utils import build_kwargs_from_config
+
+__all__ = [
+ "EfficientViTCls",
+ ######################
+ "efficientvit_cls_b0",
+ "efficientvit_cls_b1",
+ "efficientvit_cls_b2",
+ "efficientvit_cls_b3",
+ ######################
+ "efficientvit_cls_l1",
+ "efficientvit_cls_l2",
+ "efficientvit_cls_l3",
+]
+
+
+class ClsHead(OpSequential):
+ def __init__(
+ self,
+ in_channels: int,
+ width_list: list[int],
+ n_classes=1000,
+ dropout=0.0,
+ norm="bn2d",
+ act_func="hswish",
+ fid="stage_final",
+ ):
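+        # Classification head: 1x1 conv, global average pooling, then two linear
+        # layers, applied to the backbone feature selected by `fid`
+        # ("stage_final" by default).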
+ ops = [
+ ConvLayer(in_channels, width_list[0], 1, norm=norm, act_func=act_func),
+ nn.AdaptiveAvgPool2d(output_size=1),
+ LinearLayer(
+ width_list[0], width_list[1], False, norm="ln", act_func=act_func
+ ),
+ LinearLayer(width_list[1], n_classes, True, dropout, None, None),
+ ]
+ super().__init__(ops)
+
+ self.fid = fid
+
+ def forward(self, feed_dict: dict[str, torch.Tensor]) -> torch.Tensor:
+ x = feed_dict[self.fid]
+ return OpSequential.forward(self, x)
+
+
+class EfficientViTCls(nn.Module):
+ def __init__(
+        self, backbone: EfficientViTBackbone | EfficientViTLargeBackbone, head: ClsHead
+ ) -> None:
+ super().__init__()
+ self.backbone = backbone
+ self.head = head
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ feed_dict = self.backbone(x)
+ output = self.head(feed_dict)
+ return output
+
+
+def efficientvit_cls_b0(**kwargs) -> EfficientViTCls:
+ from efficientvit.models.efficientvit.backbone import \
+ efficientvit_backbone_b0
+
+ backbone = efficientvit_backbone_b0(**kwargs)
+
+ head = ClsHead(
+ in_channels=128,
+ width_list=[1024, 1280],
+ **build_kwargs_from_config(kwargs, ClsHead),
+ )
+ model = EfficientViTCls(backbone, head)
+ return model
+
+
+def efficientvit_cls_b1(**kwargs) -> EfficientViTCls:
+ from efficientvit.models.efficientvit.backbone import \
+ efficientvit_backbone_b1
+
+ backbone = efficientvit_backbone_b1(**kwargs)
+
+ head = ClsHead(
+ in_channels=256,
+ width_list=[1536, 1600],
+ **build_kwargs_from_config(kwargs, ClsHead),
+ )
+ model = EfficientViTCls(backbone, head)
+ return model
+
+
+def efficientvit_cls_b2(**kwargs) -> EfficientViTCls:
+ from efficientvit.models.efficientvit.backbone import \
+ efficientvit_backbone_b2
+
+ backbone = efficientvit_backbone_b2(**kwargs)
+
+ head = ClsHead(
+ in_channels=384,
+ width_list=[2304, 2560],
+ **build_kwargs_from_config(kwargs, ClsHead),
+ )
+ model = EfficientViTCls(backbone, head)
+ return model
+
+
+def efficientvit_cls_b3(**kwargs) -> EfficientViTCls:
+ from efficientvit.models.efficientvit.backbone import \
+ efficientvit_backbone_b3
+
+ backbone = efficientvit_backbone_b3(**kwargs)
+
+ head = ClsHead(
+ in_channels=512,
+ width_list=[2304, 2560],
+ **build_kwargs_from_config(kwargs, ClsHead),
+ )
+ model = EfficientViTCls(backbone, head)
+ return model
+
+
+def efficientvit_cls_l1(**kwargs) -> EfficientViTCls:
+ from efficientvit.models.efficientvit.backbone import \
+ efficientvit_backbone_l1
+
+ backbone = efficientvit_backbone_l1(**kwargs)
+
+ head = ClsHead(
+ in_channels=512,
+ width_list=[3072, 3200],
+ act_func="gelu",
+ **build_kwargs_from_config(kwargs, ClsHead),
+ )
+ model = EfficientViTCls(backbone, head)
+ return model
+
+
+def efficientvit_cls_l2(**kwargs) -> EfficientViTCls:
+ from efficientvit.models.efficientvit.backbone import \
+ efficientvit_backbone_l2
+
+ backbone = efficientvit_backbone_l2(**kwargs)
+
+ head = ClsHead(
+ in_channels=512,
+ width_list=[3072, 3200],
+ act_func="gelu",
+ **build_kwargs_from_config(kwargs, ClsHead),
+ )
+ model = EfficientViTCls(backbone, head)
+ return model
+
+
+def efficientvit_cls_l3(**kwargs) -> EfficientViTCls:
+ from efficientvit.models.efficientvit.backbone import \
+ efficientvit_backbone_l3
+
+ backbone = efficientvit_backbone_l3(**kwargs)
+
+ head = ClsHead(
+ in_channels=1024,
+ width_list=[6144, 6400],
+ act_func="gelu",
+ **build_kwargs_from_config(kwargs, ClsHead),
+ )
+ model = EfficientViTCls(backbone, head)
+ return model
diff --git a/yolo-world-with-efficientvit-sam/efficientvit/models/efficientvit/sam.py b/yolo-world-with-efficientvit-sam/efficientvit/models/efficientvit/sam.py
new file mode 100644
index 0000000000000000000000000000000000000000..c6736c5d7dde284b208c164f0ac1cec926bbb692
--- /dev/null
+++ b/yolo-world-with-efficientvit-sam/efficientvit/models/efficientvit/sam.py
@@ -0,0 +1,653 @@
+# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
+# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
+# International Conference on Computer Vision (ICCV), 2023
+
+import copy
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torchvision.transforms as transforms
+from segment_anything import SamAutomaticMaskGenerator
+from segment_anything.modeling import (MaskDecoder, PromptEncoder,
+                                        TwoWayTransformer)
+from segment_anything.utils.amg import build_all_layer_point_grids
+from segment_anything.utils.transforms import ResizeLongestSide
+from torchvision.transforms.functional import resize, to_pil_image
+
+from efficientvit.models.efficientvit.backbone import (
+ EfficientViTBackbone, EfficientViTLargeBackbone)
+from efficientvit.models.nn import (ConvLayer, DAGBlock, FusedMBConv,
+ IdentityLayer, MBConv, OpSequential,
+ ResBlock, ResidualBlock, UpSampleLayer,
+ build_norm)
+from efficientvit.models.utils import build_kwargs_from_config, get_device
+
+__all__ = [
+ "SamPad",
+ "SamResize",
+ "SamNeck",
+ "EfficientViTSamImageEncoder",
+ "EfficientViTSam",
+ "EfficientViTSamPredictor",
+ "EfficientViTSamAutomaticMaskGenerator",
+ "efficientvit_sam_l0",
+ "efficientvit_sam_l1",
+ "efficientvit_sam_l2",
+ "efficientvit_sam_xl0",
+ "efficientvit_sam_xl1",
+]
+
+
+class SamPad:
+ def __init__(self, size: int, fill: float = 0, pad_mode="corner") -> None:
+ self.size = size
+ self.fill = fill
+ self.pad_mode = pad_mode
+
+ def __call__(self, image: torch.Tensor) -> torch.Tensor:
+ h, w = image.shape[-2:]
+ th, tw = self.size, self.size
+ assert th >= h and tw >= w
+ if self.pad_mode == "corner":
+ image = F.pad(image, (0, tw - w, 0, th - h), value=self.fill)
+ else:
+ raise NotImplementedError
+ return image
+
+ def __repr__(self) -> str:
+ return f"{type(self).__name__}(size={self.size},mode={self.pad_mode},fill={self.fill})"
+
+
+class SamResize:
+ def __init__(self, size: int) -> None:
+ self.size = size
+
+ def __call__(self, image: np.ndarray) -> np.ndarray:
+ h, w, _ = image.shape
+ long_side = max(h, w)
+ if long_side != self.size:
+ return self.apply_image(image)
+ else:
+ return image
+
+ def apply_image(self, image: np.ndarray) -> np.ndarray:
+ """
+ Expects a numpy array with shape HxWxC in uint8 format.
+ """
+ target_size = self.get_preprocess_shape(
+ image.shape[0], image.shape[1], self.size
+ )
+ return np.array(resize(to_pil_image(image), target_size))
+
+ @staticmethod
+ def get_preprocess_shape(
+ oldh: int, oldw: int, long_side_length: int
+ ) -> tuple[int, int]:
+ """
+ Compute the output size given input size and target long side length.
+ """
+ scale = long_side_length * 1.0 / max(oldh, oldw)
+ newh, neww = oldh * scale, oldw * scale
+ neww = int(neww + 0.5)
+ newh = int(newh + 0.5)
+ return (newh, neww)
+
+ def __repr__(self) -> str:
+ return f"{type(self).__name__}(size={self.size})"
+
+
+class SamNeck(DAGBlock):
+ def __init__(
+ self,
+ fid_list: list[str],
+ in_channel_list: list[int],
+ head_width: int,
+ head_depth: int,
+ expand_ratio: float,
+ middle_op: str,
+ out_dim: int = 256,
+ norm="bn2d",
+ act_func="gelu",
+ ):
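+        # Fuse the selected backbone stages: project each input to head_width,
+        # upsample to 64x64 and sum, refine with `head_depth` middle blocks, then
+        # project to out_dim to form the image embedding consumed by the SAM
+        # mask decoder (output key "sam_encoder").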
+ inputs = {}
+ for fid, in_channel in zip(fid_list, in_channel_list):
+ inputs[fid] = OpSequential(
+ [
+ ConvLayer(in_channel, head_width, 1, norm=norm, act_func=None),
+ UpSampleLayer(size=(64, 64)),
+ ]
+ )
+
+ middle = []
+ for _ in range(head_depth):
+ if middle_op == "mb":
+ block = MBConv(
+ head_width,
+ head_width,
+ expand_ratio=expand_ratio,
+ norm=norm,
+ act_func=(act_func, act_func, None),
+ )
+ elif middle_op == "fmb":
+ block = FusedMBConv(
+ head_width,
+ head_width,
+ expand_ratio=expand_ratio,
+ norm=norm,
+ act_func=(act_func, None),
+ )
+ elif middle_op == "res":
+ block = ResBlock(
+ head_width,
+ head_width,
+ expand_ratio=expand_ratio,
+ norm=norm,
+ act_func=(act_func, None),
+ )
+ else:
+ raise NotImplementedError
+ middle.append(ResidualBlock(block, IdentityLayer()))
+ middle = OpSequential(middle)
+
+ outputs = {
+ "sam_encoder": OpSequential(
+ [
+ ConvLayer(
+ head_width,
+ out_dim,
+ 1,
+ use_bias=True,
+ norm=None,
+ act_func=None,
+ ),
+ ]
+ )
+ }
+
+ super(SamNeck, self).__init__(
+ inputs, "add", None, middle=middle, outputs=outputs
+ )
+
+
+class EfficientViTSamImageEncoder(nn.Module):
+ def __init__(
+        self, backbone: EfficientViTBackbone | EfficientViTLargeBackbone, neck: SamNeck
+ ):
+ super().__init__()
+ self.backbone = backbone
+ self.neck = neck
+
+ self.norm = build_norm("ln2d", 256)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ feed_dict = self.backbone(x)
+ feed_dict = self.neck(feed_dict)
+
+ output = feed_dict["sam_encoder"]
+ output = self.norm(output)
+ return output
+
+
+class EfficientViTSam(nn.Module):
+ mask_threshold: float = 0.0
+ image_format: str = "RGB"
+
+ def __init__(
+ self,
+ image_encoder: EfficientViTSamImageEncoder,
+ prompt_encoder: PromptEncoder,
+ mask_decoder: MaskDecoder,
+ image_size: tuple[int, int] = (1024, 512),
+ ) -> None:
+ super().__init__()
+ self.image_encoder = image_encoder
+ self.prompt_encoder = prompt_encoder
+ self.mask_decoder = mask_decoder
+
+ self.image_size = image_size
+
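+        # Preprocessing: resize the longer side to image_size[1], normalize with
+        # ImageNet statistics, and corner-pad to a square image_size[1] input.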
+ self.transform = transforms.Compose(
+ [
+ SamResize(self.image_size[1]),
+ transforms.ToTensor(),
+ transforms.Normalize(
+ mean=[123.675 / 255, 116.28 / 255, 103.53 / 255],
+ std=[58.395 / 255, 57.12 / 255, 57.375 / 255],
+ ),
+ SamPad(self.image_size[1]),
+ ]
+ )
+
+ def postprocess_masks(
+ self,
+ masks: torch.Tensor,
+ input_size: tuple[int, ...],
+ original_size: tuple[int, ...],
+ ) -> torch.Tensor:
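+        # Upsample mask logits to image_size[0] x image_size[0], crop to the valid
+        # (un-padded) input_size region, then resize to the original image size.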
+ masks = F.interpolate(
+ masks,
+ (self.image_size[0], self.image_size[0]),
+ mode="bilinear",
+ align_corners=False,
+ )
+ masks = masks[..., : input_size[0], : input_size[1]]
+ masks = F.interpolate(
+ masks, original_size, mode="bilinear", align_corners=False
+ )
+ return masks
+
+
+class EfficientViTSamPredictor:
+ def __init__(self, sam_model: EfficientViTSam) -> None:
+ self.model = sam_model
+ self.reset_image()
+
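+    # SAM's predictor API exposes a coordinate transform as `.transform`; this
+    # predictor implements apply_coords/apply_boxes itself, so the property simply
+    # returns the predictor instance.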
+ @property
+ def transform(self):
+ return self
+
+ @property
+ def device(self):
+ return get_device(self.model)
+
+ def reset_image(self) -> None:
+ self.is_image_set = False
+ self.features = None
+ self.original_size = None
+ self.input_size = None
+
+ def apply_coords(self, coords: np.ndarray, im_size=None) -> np.ndarray:
+ old_h, old_w = self.original_size
+ new_h, new_w = self.input_size
+ coords = copy.deepcopy(coords).astype(float)
+ coords[..., 0] = coords[..., 0] * (new_w / old_w)
+ coords[..., 1] = coords[..., 1] * (new_h / old_h)
+ return coords
+
+ def apply_boxes(self, boxes: np.ndarray, im_size=None) -> np.ndarray:
+ boxes = self.apply_coords(boxes.reshape(-1, 2, 2))
+ return boxes.reshape(-1, 4)
+
+ @torch.inference_mode()
+ def set_image(self, image: np.ndarray, image_format: str = "RGB") -> None:
+ assert image_format in [
+ "RGB",
+ "BGR",
+ ], f"image_format must be in ['RGB', 'BGR'], is {image_format}."
+ if image_format != self.model.image_format:
+ image = image[..., ::-1]
+
+ self.reset_image()
+
+ self.original_size = image.shape[:2]
+ self.input_size = ResizeLongestSide.get_preprocess_shape(
+ *self.original_size, long_side_length=self.model.image_size[0]
+ )
+
+ torch_data = (
+ self.model.transform(image).unsqueeze(dim=0).to(get_device(self.model))
+ )
+ self.features = self.model.image_encoder(torch_data)
+ self.is_image_set = True
+
+ def predict(
+ self,
+        point_coords: np.ndarray | None = None,
+        point_labels: np.ndarray | None = None,
+        box: np.ndarray | None = None,
+        mask_input: np.ndarray | None = None,
+ multimask_output: bool = True,
+ return_logits: bool = False,
+    ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
+ """
+ Predict masks for the given input prompts, using the currently set image.
+
+ Arguments:
+ point_coords (np.ndarray or None): A Nx2 array of point prompts to the
+ model. Each point is in (X,Y) in pixels.
+ point_labels (np.ndarray or None): A length N array of labels for the
+ point prompts. 1 indicates a foreground point and 0 indicates a
+ background point.
+ box (np.ndarray or None): A length 4 array given a box prompt to the
+ model, in XYXY format.
+ mask_input (np.ndarray): A low resolution mask input to the model, typically
+ coming from a previous prediction iteration. Has form 1xHxW, where
+ for SAM, H=W=256.
+ multimask_output (bool): If true, the model will return three masks.
+ For ambiguous input prompts (such as a single click), this will often
+ produce better masks than a single prediction. If only a single
+ mask is needed, the model's predicted quality score can be used
+ to select the best mask. For non-ambiguous prompts, such as multiple
+ input prompts, multimask_output=False can give better results.
+ return_logits (bool): If true, returns un-thresholded masks logits
+ instead of a binary mask.
+
+ Returns:
+ (np.ndarray): The output masks in CxHxW format, where C is the
+ number of masks, and (H, W) is the original image size.
+ (np.ndarray): An array of length C containing the model's
+ predictions for the quality of each mask.
+ (np.ndarray): An array of shape CxHxW, where C is the number
+ of masks and H=W=256. These low resolution logits can be passed to
+ a subsequent iteration as mask input.
+ """
+ if not self.is_image_set:
+ raise RuntimeError(
+ "An image must be set with .set_image(...) before mask prediction."
+ )
+
+ device = get_device(self.model)
+ # Transform input prompts
+ coords_torch, labels_torch, box_torch, mask_input_torch = None, None, None, None
+ if point_coords is not None:
+ assert (
+ point_labels is not None
+ ), "point_labels must be supplied if point_coords is supplied."
+ point_coords = self.apply_coords(point_coords)
+ coords_torch = torch.as_tensor(
+ point_coords, dtype=torch.float, device=device
+ )
+ labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=device)
+ coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :]
+ if box is not None:
+ box = self.apply_boxes(box)
+ box_torch = torch.as_tensor(box, dtype=torch.float, device=device)
+ box_torch = box_torch[None, :]
+ if mask_input is not None:
+ mask_input_torch = torch.as_tensor(
+ mask_input, dtype=torch.float, device=device
+ )
+ mask_input_torch = mask_input_torch[None, :, :, :]
+
+ masks, iou_predictions, low_res_masks = self.predict_torch(
+ coords_torch,
+ labels_torch,
+ box_torch,
+ mask_input_torch,
+ multimask_output,
+ return_logits=return_logits,
+ )
+
+ masks = masks[0].detach().cpu().numpy()
+ iou_predictions = iou_predictions[0].detach().cpu().numpy()
+ low_res_masks = low_res_masks[0].detach().cpu().numpy()
+ return masks, iou_predictions, low_res_masks
+
+ @torch.inference_mode()
+ def predict_torch(
+ self,
+        point_coords: torch.Tensor | None = None,
+        point_labels: torch.Tensor | None = None,
+        boxes: torch.Tensor | None = None,
+        mask_input: torch.Tensor | None = None,
+ multimask_output: bool = True,
+ return_logits: bool = False,
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """
+ Predict masks for the given input prompts, using the currently set image.
+ Input prompts are batched torch tensors and are expected to already be
+ transformed to the input frame using ResizeLongestSide.
+
+ Arguments:
+ point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the
+ model. Each point is in (X,Y) in pixels.
+ point_labels (torch.Tensor or None): A BxN array of labels for the
+ point prompts. 1 indicates a foreground point and 0 indicates a
+ background point.
+            boxes (torch.Tensor or None): A Bx4 array given a box prompt to the
+                model, in XYXY format.
+            mask_input (torch.Tensor or None): A low resolution mask input to the model, typically
+                coming from a previous prediction iteration. Has form Bx1xHxW, where
+ for SAM, H=W=256. Masks returned by a previous iteration of the
+ predict method do not need further transformation.
+ multimask_output (bool): If true, the model will return three masks.
+ For ambiguous input prompts (such as a single click), this will often
+ produce better masks than a single prediction. If only a single
+ mask is needed, the model's predicted quality score can be used
+ to select the best mask. For non-ambiguous prompts, such as multiple
+ input prompts, multimask_output=False can give better results.
+ return_logits (bool): If true, returns un-thresholded masks logits
+ instead of a binary mask.
+
+ Returns:
+ (torch.Tensor): The output masks in BxCxHxW format, where C is the
+ number of masks, and (H, W) is the original image size.
+ (torch.Tensor): An array of shape BxC containing the model's
+ predictions for the quality of each mask.
+ (torch.Tensor): An array of shape BxCxHxW, where C is the number
+ of masks and H=W=256. These low res logits can be passed to
+ a subsequent iteration as mask input.
+ """
+ if not self.is_image_set:
+ raise RuntimeError(
+ "An image must be set with .set_image(...) before mask prediction."
+ )
+
+ if point_coords is not None:
+ points = (point_coords, point_labels)
+ else:
+ points = None
+
+ # Embed prompts
+ sparse_embeddings, dense_embeddings = self.model.prompt_encoder(
+ points=points,
+ boxes=boxes,
+ masks=mask_input,
+ )
+
+ # Predict masks
+ low_res_masks, iou_predictions = self.model.mask_decoder(
+ image_embeddings=self.features,
+ image_pe=self.model.prompt_encoder.get_dense_pe(),
+ sparse_prompt_embeddings=sparse_embeddings,
+ dense_prompt_embeddings=dense_embeddings,
+ multimask_output=multimask_output,
+ )
+
+ # Upscale the masks to the original image resolution
+ masks = self.model.postprocess_masks(
+ low_res_masks, self.input_size, self.original_size
+ )
+
+ if not return_logits:
+ masks = masks > self.model.mask_threshold
+
+ return masks, iou_predictions, low_res_masks
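+
+# Example usage of EfficientViTSamPredictor (a minimal sketch; building the model
+# with pretrained weights is assumed to be handled by the caller and is not shown):
+#
+#   sam = efficientvit_sam_l0()
+#   predictor = EfficientViTSamPredictor(sam)
+#   predictor.set_image(image)                   # image: HxWx3 uint8 RGB numpy array
+#   masks, scores, low_res_logits = predictor.predict(
+#       point_coords=np.array([[320, 240]]),     # one foreground click, (x, y) pixels
+#       point_labels=np.array([1]),
+#       multimask_output=True,
+#   )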
+
+
+class EfficientViTSamAutomaticMaskGenerator(SamAutomaticMaskGenerator):
+ def __init__(
+ self,
+ model: EfficientViTSam,
+        points_per_side: int | None = 32,
+ points_per_batch: int = 64,
+ pred_iou_thresh: float = 0.88,
+ stability_score_thresh: float = 0.95,
+ stability_score_offset: float = 1.0,
+ box_nms_thresh: float = 0.7,
+ crop_n_layers: int = 0,
+ crop_nms_thresh: float = 0.7,
+ crop_overlap_ratio: float = 512 / 1500,
+ crop_n_points_downscale_factor: int = 1,
+        point_grids: list[np.ndarray] | None = None,
+ min_mask_region_area: int = 0,
+ output_mode: str = "binary_mask",
+ ) -> None:
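+        # Mirrors SamAutomaticMaskGenerator.__init__ but deliberately skips
+        # super().__init__() so that an EfficientViTSamPredictor is used instead of
+        # the default SAM predictor; the inherited generate() logic is reused as-is.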
+ assert (points_per_side is None) != (
+ point_grids is None
+ ), "Exactly one of points_per_side or point_grid must be provided."
+ if points_per_side is not None:
+ self.point_grids = build_all_layer_point_grids(
+ points_per_side,
+ crop_n_layers,
+ crop_n_points_downscale_factor,
+ )
+ elif point_grids is not None:
+ self.point_grids = point_grids
+ else:
+ raise ValueError("Can't have both points_per_side and point_grid be None.")
+
+ assert output_mode in [
+ "binary_mask",
+ "uncompressed_rle",
+ "coco_rle",
+ ], f"Unknown output_mode {output_mode}."
+ if output_mode == "coco_rle":
+ from pycocotools import \
+ mask as mask_utils # type: ignore # noqa: F401
+
+ if min_mask_region_area > 0:
+ import cv2 # type: ignore # noqa: F401
+
+ self.predictor = EfficientViTSamPredictor(model)
+ self.points_per_batch = points_per_batch
+ self.pred_iou_thresh = pred_iou_thresh
+ self.stability_score_thresh = stability_score_thresh
+ self.stability_score_offset = stability_score_offset
+ self.box_nms_thresh = box_nms_thresh
+ self.crop_n_layers = crop_n_layers
+ self.crop_nms_thresh = crop_nms_thresh
+ self.crop_overlap_ratio = crop_overlap_ratio
+ self.crop_n_points_downscale_factor = crop_n_points_downscale_factor
+ self.min_mask_region_area = min_mask_region_area
+ self.output_mode = output_mode
+
+
+def build_efficientvit_sam(
+ image_encoder: EfficientViTSamImageEncoder, image_size: int
+) -> EfficientViTSam:
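+    # The prompt encoder and mask decoder keep SAM's 1024x1024 coordinate space,
+    # while the EfficientViT image encoder runs at `image_size` (512 for the l*
+    # variants, 1024 for xl*); EfficientViTSam stores both as (1024, image_size).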
+ return EfficientViTSam(
+ image_encoder=image_encoder,
+ prompt_encoder=PromptEncoder(
+ embed_dim=256,
+ image_embedding_size=(64, 64),
+ input_image_size=(1024, 1024),
+ mask_in_chans=16,
+ ),
+ mask_decoder=MaskDecoder(
+ num_multimask_outputs=3,
+ transformer=TwoWayTransformer(
+ depth=2,
+ embedding_dim=256,
+ mlp_dim=2048,
+ num_heads=8,
+ ),
+ transformer_dim=256,
+ iou_head_depth=3,
+ iou_head_hidden_dim=256,
+ ),
+ image_size=(1024, image_size),
+ )
+
+
+def efficientvit_sam_l0(image_size: int = 512, **kwargs) -> EfficientViTSam:
+ from efficientvit.models.efficientvit.backbone import \
+ efficientvit_backbone_l0
+
+ backbone = efficientvit_backbone_l0(**kwargs)
+
+ neck = SamNeck(
+ fid_list=["stage4", "stage3", "stage2"],
+ in_channel_list=[512, 256, 128],
+ head_width=256,
+ head_depth=4,
+ expand_ratio=1,
+ middle_op="fmb",
+ )
+
+ image_encoder = EfficientViTSamImageEncoder(backbone, neck)
+ return build_efficientvit_sam(image_encoder, image_size)
+
+
+def efficientvit_sam_l1(image_size: int = 512, **kwargs) -> EfficientViTSam:
+ from efficientvit.models.efficientvit.backbone import \
+ efficientvit_backbone_l1
+
+ backbone = efficientvit_backbone_l1(**kwargs)
+
+ neck = SamNeck(
+ fid_list=["stage4", "stage3", "stage2"],
+ in_channel_list=[512, 256, 128],
+ head_width=256,
+ head_depth=8,
+ expand_ratio=1,
+ middle_op="fmb",
+ )
+
+ image_encoder = EfficientViTSamImageEncoder(backbone, neck)
+ return build_efficientvit_sam(image_encoder, image_size)
+
+
+def efficientvit_sam_l2(image_size: int = 512, **kwargs) -> EfficientViTSam:
+ from efficientvit.models.efficientvit.backbone import \
+ efficientvit_backbone_l2
+
+ backbone = efficientvit_backbone_l2(**kwargs)
+
+ neck = SamNeck(
+ fid_list=["stage4", "stage3", "stage2"],
+ in_channel_list=[512, 256, 128],
+ head_width=256,
+ head_depth=12,
+ expand_ratio=1,
+ middle_op="fmb",
+ )
+
+ image_encoder = EfficientViTSamImageEncoder(backbone, neck)
+ return build_efficientvit_sam(image_encoder, image_size)
+
+
+def efficientvit_sam_xl0(image_size: int = 1024, **kwargs) -> EfficientViTSam:
+ from efficientvit.models.efficientvit.backbone import \
+ EfficientViTLargeBackbone
+
+ backbone = EfficientViTLargeBackbone(
+ width_list=[32, 64, 128, 256, 512, 1024],
+ depth_list=[0, 1, 1, 2, 3, 3],
+ block_list=["res", "fmb", "fmb", "fmb", "att@3", "att@3"],
+ expand_list=[1, 4, 4, 4, 4, 6],
+ fewer_norm_list=[False, False, False, False, True, True],
+ **build_kwargs_from_config(kwargs, EfficientViTLargeBackbone),
+ )
+
+ neck = SamNeck(
+ fid_list=["stage5", "stage4", "stage3"],
+ in_channel_list=[1024, 512, 256],
+ head_width=256,
+ head_depth=6,
+ expand_ratio=4,
+ middle_op="fmb",
+ )
+
+ image_encoder = EfficientViTSamImageEncoder(backbone, neck)
+ return build_efficientvit_sam(image_encoder, image_size)
+
+
+def efficientvit_sam_xl1(image_size: int = 1024, **kwargs) -> EfficientViTSam:
+ from efficientvit.models.efficientvit.backbone import \
+ EfficientViTLargeBackbone
+
+ backbone = EfficientViTLargeBackbone(
+ width_list=[32, 64, 128, 256, 512, 1024],
+ depth_list=[1, 2, 2, 4, 6, 6],
+ block_list=["res", "fmb", "fmb", "fmb", "att@3", "att@3"],
+ expand_list=[1, 4, 4, 4, 4, 6],
+ fewer_norm_list=[False, False, False, False, True, True],
+ **build_kwargs_from_config(kwargs, EfficientViTLargeBackbone),
+ )
+
+ neck = SamNeck(
+ fid_list=["stage5", "stage4", "stage3"],
+ in_channel_list=[1024, 512, 256],
+ head_width=256,
+ head_depth=12,
+ expand_ratio=4,
+ middle_op="fmb",
+ )
+
+ image_encoder = EfficientViTSamImageEncoder(backbone, neck)
+ return build_efficientvit_sam(image_encoder, image_size)
diff --git a/yolo-world-with-efficientvit-sam/efficientvit/models/efficientvit/seg.py b/yolo-world-with-efficientvit-sam/efficientvit/models/efficientvit/seg.py
new file mode 100644
index 0000000000000000000000000000000000000000..1bdb2817b201aae21cb112ae787adddcfd498775
--- /dev/null
+++ b/yolo-world-with-efficientvit-sam/efficientvit/models/efficientvit/seg.py
@@ -0,0 +1,355 @@
+# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
+# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
+# International Conference on Computer Vision (ICCV), 2023
+
+import torch
+import torch.nn as nn
+
+from efficientvit.models.efficientvit.backbone import (
+ EfficientViTBackbone, EfficientViTLargeBackbone)
+from efficientvit.models.nn import (ConvLayer, DAGBlock, FusedMBConv,
+ IdentityLayer, MBConv, OpSequential,
+ ResidualBlock, UpSampleLayer)
+from efficientvit.models.utils import build_kwargs_from_config
+
+__all__ = [
+ "EfficientViTSeg",
+ "efficientvit_seg_b0",
+ "efficientvit_seg_b1",
+ "efficientvit_seg_b2",
+ "efficientvit_seg_b3",
+ "efficientvit_seg_l1",
+ "efficientvit_seg_l2",
+]
+
+
+class SegHead(DAGBlock):
+ def __init__(
+ self,
+ fid_list: list[str],
+ in_channel_list: list[int],
+ stride_list: list[int],
+ head_stride: int,
+ head_width: int,
+ head_depth: int,
+ expand_ratio: float,
+ middle_op: str,
+        final_expand: float | None,
+ n_classes: int,
+ dropout=0,
+ norm="bn2d",
+ act_func="hswish",
+ ):
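+        # Segmentation head: project each selected backbone stage to head_width,
+        # upsample to the head_stride resolution and sum, refine with middle blocks,
+        # then predict per-pixel class logits under the "segout" key.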
+ inputs = {}
+ for fid, in_channel, stride in zip(fid_list, in_channel_list, stride_list):
+ factor = stride // head_stride
+ if factor == 1:
+ inputs[fid] = ConvLayer(
+ in_channel, head_width, 1, norm=norm, act_func=None
+ )
+ else:
+ inputs[fid] = OpSequential(
+ [
+ ConvLayer(in_channel, head_width, 1, norm=norm, act_func=None),
+ UpSampleLayer(factor=factor),
+ ]
+ )
+
+ middle = []
+ for _ in range(head_depth):
+ if middle_op == "mbconv":
+ block = MBConv(
+ head_width,
+ head_width,
+ expand_ratio=expand_ratio,
+ norm=norm,
+ act_func=(act_func, act_func, None),
+ )
+ elif middle_op == "fmbconv":
+ block = FusedMBConv(
+ head_width,
+ head_width,
+ expand_ratio=expand_ratio,
+ norm=norm,
+ act_func=(act_func, None),
+ )
+ else:
+ raise NotImplementedError
+ middle.append(ResidualBlock(block, IdentityLayer()))
+ middle = OpSequential(middle)
+
+ outputs = {
+ "segout": OpSequential(
+ [
+ (
+ None
+ if final_expand is None
+ else ConvLayer(
+ head_width,
+ head_width * final_expand,
+ 1,
+ norm=norm,
+ act_func=act_func,
+ )
+ ),
+ ConvLayer(
+ head_width * (final_expand or 1),
+ n_classes,
+ 1,
+ use_bias=True,
+ dropout=dropout,
+ norm=None,
+ act_func=None,
+ ),
+ ]
+ )
+ }
+
+ super(SegHead, self).__init__(
+ inputs, "add", None, middle=middle, outputs=outputs
+ )
+
+
+class EfficientViTSeg(nn.Module):
+ def __init__(
+        self, backbone: EfficientViTBackbone | EfficientViTLargeBackbone, head: SegHead
+ ) -> None:
+ super().__init__()
+ self.backbone = backbone
+ self.head = head
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ feed_dict = self.backbone(x)
+ feed_dict = self.head(feed_dict)
+
+ return feed_dict["segout"]
+
+
+def efficientvit_seg_b0(dataset: str, **kwargs) -> EfficientViTSeg:
+ from efficientvit.models.efficientvit.backbone import \
+ efficientvit_backbone_b0
+
+ backbone = efficientvit_backbone_b0(**kwargs)
+
+ if dataset == "cityscapes":
+ head = SegHead(
+ fid_list=["stage4", "stage3", "stage2"],
+ in_channel_list=[128, 64, 32],
+ stride_list=[32, 16, 8],
+ head_stride=8,
+ head_width=32,
+ head_depth=1,
+ expand_ratio=4,
+ middle_op="mbconv",
+ final_expand=4,
+ n_classes=19,
+ **build_kwargs_from_config(kwargs, SegHead),
+ )
+ else:
+ raise NotImplementedError
+ model = EfficientViTSeg(backbone, head)
+ return model
+
+
+def efficientvit_seg_b1(dataset: str, **kwargs) -> EfficientViTSeg:
+ from efficientvit.models.efficientvit.backbone import \
+ efficientvit_backbone_b1
+
+ backbone = efficientvit_backbone_b1(**kwargs)
+
+ if dataset == "cityscapes":
+ head = SegHead(
+ fid_list=["stage4", "stage3", "stage2"],
+ in_channel_list=[256, 128, 64],
+ stride_list=[32, 16, 8],
+ head_stride=8,
+ head_width=64,
+ head_depth=3,
+ expand_ratio=4,
+ middle_op="mbconv",
+ final_expand=4,
+ n_classes=19,
+ **build_kwargs_from_config(kwargs, SegHead),
+ )
+ elif dataset == "ade20k":
+ head = SegHead(
+ fid_list=["stage4", "stage3", "stage2"],
+ in_channel_list=[256, 128, 64],
+ stride_list=[32, 16, 8],
+ head_stride=8,
+ head_width=64,
+ head_depth=3,
+ expand_ratio=4,
+ middle_op="mbconv",
+ final_expand=None,
+ n_classes=150,
+ **build_kwargs_from_config(kwargs, SegHead),
+ )
+ else:
+ raise NotImplementedError
+ model = EfficientViTSeg(backbone, head)
+ return model
+
+
+def efficientvit_seg_b2(dataset: str, **kwargs) -> EfficientViTSeg:
+ from efficientvit.models.efficientvit.backbone import \
+ efficientvit_backbone_b2
+
+ backbone = efficientvit_backbone_b2(**kwargs)
+
+ if dataset == "cityscapes":
+ head = SegHead(
+ fid_list=["stage4", "stage3", "stage2"],
+ in_channel_list=[384, 192, 96],
+ stride_list=[32, 16, 8],
+ head_stride=8,
+ head_width=96,
+ head_depth=3,
+ expand_ratio=4,
+ middle_op="mbconv",
+ final_expand=4,
+ n_classes=19,
+ **build_kwargs_from_config(kwargs, SegHead),
+ )
+ elif dataset == "ade20k":
+ head = SegHead(
+ fid_list=["stage4", "stage3", "stage2"],
+ in_channel_list=[384, 192, 96],
+ stride_list=[32, 16, 8],
+ head_stride=8,
+ head_width=96,
+ head_depth=3,
+ expand_ratio=4,
+ middle_op="mbconv",
+ final_expand=None,
+ n_classes=150,
+ **build_kwargs_from_config(kwargs, SegHead),
+ )
+ else:
+ raise NotImplementedError
+ model = EfficientViTSeg(backbone, head)
+ return model
+
+
+def efficientvit_seg_b3(dataset: str, **kwargs) -> EfficientViTSeg:
+ from efficientvit.models.efficientvit.backbone import \
+ efficientvit_backbone_b3
+
+ backbone = efficientvit_backbone_b3(**kwargs)
+
+ if dataset == "cityscapes":
+ head = SegHead(
+ fid_list=["stage4", "stage3", "stage2"],
+ in_channel_list=[512, 256, 128],
+ stride_list=[32, 16, 8],
+ head_stride=8,
+ head_width=128,
+ head_depth=3,
+ expand_ratio=4,
+ middle_op="mbconv",
+ final_expand=4,
+ n_classes=19,
+ **build_kwargs_from_config(kwargs, SegHead),
+ )
+ elif dataset == "ade20k":
+ head = SegHead(
+ fid_list=["stage4", "stage3", "stage2"],
+ in_channel_list=[512, 256, 128],
+ stride_list=[32, 16, 8],
+ head_stride=8,
+ head_width=128,
+ head_depth=3,
+ expand_ratio=4,
+ middle_op="mbconv",
+ final_expand=None,
+ n_classes=150,
+ **build_kwargs_from_config(kwargs, SegHead),
+ )
+ else:
+ raise NotImplementedError
+ model = EfficientViTSeg(backbone, head)
+ return model
+
+
+def efficientvit_seg_l1(dataset: str, **kwargs) -> EfficientViTSeg:
+ from efficientvit.models.efficientvit.backbone import \
+ efficientvit_backbone_l1
+
+ backbone = efficientvit_backbone_l1(**kwargs)
+
+ if dataset == "cityscapes":
+ head = SegHead(
+ fid_list=["stage4", "stage3", "stage2"],
+ in_channel_list=[512, 256, 128],
+ stride_list=[32, 16, 8],
+ head_stride=8,
+ head_width=256,
+ head_depth=3,
+ expand_ratio=1,
+ middle_op="fmbconv",
+ final_expand=None,
+ n_classes=19,
+ act_func="gelu",
+ **build_kwargs_from_config(kwargs, SegHead),
+ )
+ elif dataset == "ade20k":
+ head = SegHead(
+ fid_list=["stage4", "stage3", "stage2"],
+ in_channel_list=[512, 256, 128],
+ stride_list=[32, 16, 8],
+ head_stride=8,
+ head_width=128,
+ head_depth=3,
+ expand_ratio=4,
+ middle_op="fmbconv",
+ final_expand=8,
+ n_classes=150,
+ act_func="gelu",
+ **build_kwargs_from_config(kwargs, SegHead),
+ )
+ else:
+ raise NotImplementedError
+ model = EfficientViTSeg(backbone, head)
+ return model
+
+
+def efficientvit_seg_l2(dataset: str, **kwargs) -> EfficientViTSeg:
+ from efficientvit.models.efficientvit.backbone import \
+ efficientvit_backbone_l2
+
+ backbone = efficientvit_backbone_l2(**kwargs)
+
+ if dataset == "cityscapes":
+ head = SegHead(
+ fid_list=["stage4", "stage3", "stage2"],
+ in_channel_list=[512, 256, 128],
+ stride_list=[32, 16, 8],
+ head_stride=8,
+ head_width=256,
+ head_depth=5,
+ expand_ratio=1,
+ middle_op="fmbconv",
+ final_expand=None,
+ n_classes=19,
+ act_func="gelu",
+ **build_kwargs_from_config(kwargs, SegHead),
+ )
+ elif dataset == "ade20k":
+ head = SegHead(
+ fid_list=["stage4", "stage3", "stage2"],
+ in_channel_list=[512, 256, 128],
+ stride_list=[32, 16, 8],
+ head_stride=8,
+ head_width=128,
+ head_depth=3,
+ expand_ratio=4,
+ middle_op="fmbconv",
+ final_expand=8,
+ n_classes=150,
+ act_func="gelu",
+ **build_kwargs_from_config(kwargs, SegHead),
+ )
+ else:
+ raise NotImplementedError
+ model = EfficientViTSeg(backbone, head)
+ return model
diff --git a/yolo-world-with-efficientvit-sam/efficientvit/models/nn/__init__.py b/yolo-world-with-efficientvit-sam/efficientvit/models/nn/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d6152158a1a8a0b4d2fc53622bdf338fbf34809d
--- /dev/null
+++ b/yolo-world-with-efficientvit-sam/efficientvit/models/nn/__init__.py
@@ -0,0 +1,8 @@
+# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
+# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
+# International Conference on Computer Vision (ICCV), 2023
+
+from .act import *
+from .drop import *
+from .norm import *
+from .ops import *
diff --git a/yolo-world-with-efficientvit-sam/efficientvit/models/nn/act.py b/yolo-world-with-efficientvit-sam/efficientvit/models/nn/act.py
new file mode 100644
index 0000000000000000000000000000000000000000..1b9c84c875f3f2483b50a846532480f1916d3afc
--- /dev/null
+++ b/yolo-world-with-efficientvit-sam/efficientvit/models/nn/act.py
@@ -0,0 +1,30 @@
+# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
+# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
+# International Conference on Computer Vision (ICCV), 2023
+
+from functools import partial
+
+import torch.nn as nn
+
+from efficientvit.models.utils import build_kwargs_from_config
+
+__all__ = ["build_act"]
+
+
+# register activation function here
+REGISTERED_ACT_DICT: dict[str, type] = {
+ "relu": nn.ReLU,
+ "relu6": nn.ReLU6,
+ "hswish": nn.Hardswish,
+ "silu": nn.SiLU,
+ "gelu": partial(nn.GELU, approximate="tanh"),
+}
+
+
+def build_act(name: str, **kwargs) -> nn.Module | None:
+ if name in REGISTERED_ACT_DICT:
+ act_cls = REGISTERED_ACT_DICT[name]
+ args = build_kwargs_from_config(kwargs, act_cls)
+ return act_cls(**args)
+ else:
+ return None
diff --git a/yolo-world-with-efficientvit-sam/efficientvit/models/nn/drop.py b/yolo-world-with-efficientvit-sam/efficientvit/models/nn/drop.py
new file mode 100644
index 0000000000000000000000000000000000000000..a1e05fe300b52db1e5e654477697d13c54db8827
--- /dev/null
+++ b/yolo-world-with-efficientvit-sam/efficientvit/models/nn/drop.py
@@ -0,0 +1,98 @@
+# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
+# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
+# International Conference on Computer Vision (ICCV), 2023
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+from efficientvit.apps.trainer.run_config import Scheduler
+from efficientvit.models.nn.ops import IdentityLayer, ResidualBlock
+from efficientvit.models.utils import build_kwargs_from_config
+
+__all__ = ["apply_drop_func"]
+
+
+def apply_drop_func(network: nn.Module, drop_config: dict[str, any] | None) -> None:
+ if drop_config is None:
+ return
+
+ drop_lookup_table = {
+ "droppath": apply_droppath,
+ }
+
+ drop_func = drop_lookup_table[drop_config["name"]]
+ drop_kwargs = build_kwargs_from_config(drop_config, drop_func)
+
+ drop_func(network, **drop_kwargs)
+
+
+def apply_droppath(
+ network: nn.Module,
+ drop_prob: float,
+ linear_decay=True,
+ scheduled=True,
+ skip=0,
+) -> None:
+ all_valid_blocks = []
+ for m in network.modules():
+ for name, sub_module in m.named_children():
+ if isinstance(sub_module, ResidualBlock) and isinstance(
+ sub_module.shortcut, IdentityLayer
+ ):
+ all_valid_blocks.append((m, name, sub_module))
+ all_valid_blocks = all_valid_blocks[skip:]
+ for i, (m, name, sub_module) in enumerate(all_valid_blocks):
+ prob = (
+ drop_prob * (i + 1) / len(all_valid_blocks) if linear_decay else drop_prob
+ )
+ new_module = DropPathResidualBlock(
+ sub_module.main,
+ sub_module.shortcut,
+ sub_module.post_act,
+ sub_module.pre_norm,
+ prob,
+ scheduled,
+ )
+ m._modules[name] = new_module
+
+
+class DropPathResidualBlock(ResidualBlock):
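+    # Residual block with stochastic depth: during training, the main branch is
+    # dropped per sample with probability drop_prob (optionally scaled by training
+    # progress) and rescaled by 1 / keep_prob when kept.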
+ def __init__(
+ self,
+ main: nn.Module,
+        shortcut: nn.Module | None,
+        post_act=None,
+        pre_norm: nn.Module | None = None,
+ ######################################
+ drop_prob: float = 0,
+ scheduled=True,
+ ):
+ super().__init__(main, shortcut, post_act, pre_norm)
+
+ self.drop_prob = drop_prob
+ self.scheduled = scheduled
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ if (
+ not self.training
+ or self.drop_prob == 0
+ or not isinstance(self.shortcut, IdentityLayer)
+ ):
+ return ResidualBlock.forward(self, x)
+ else:
+ drop_prob = self.drop_prob
+ if self.scheduled:
+ drop_prob *= np.clip(Scheduler.PROGRESS, 0, 1)
+ keep_prob = 1 - drop_prob
+
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1)
+ random_tensor = keep_prob + torch.rand(
+ shape, dtype=x.dtype, device=x.device
+ )
+ random_tensor.floor_() # binarize
+
+ res = self.forward_main(x) / keep_prob * random_tensor + self.shortcut(x)
+ if self.post_act:
+ res = self.post_act(res)
+ return res
diff --git a/yolo-world-with-efficientvit-sam/efficientvit/models/nn/norm.py b/yolo-world-with-efficientvit-sam/efficientvit/models/nn/norm.py
new file mode 100644
index 0000000000000000000000000000000000000000..9aca166be32bbdd4f3475508a5b15241fe454697
--- /dev/null
+++ b/yolo-world-with-efficientvit-sam/efficientvit/models/nn/norm.py
@@ -0,0 +1,157 @@
+# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
+# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
+# International Conference on Computer Vision (ICCV), 2023
+
+import torch
+import torch.nn as nn
+from torch.nn.modules.batchnorm import _BatchNorm
+
+from efficientvit.models.utils import build_kwargs_from_config
+
+__all__ = ["LayerNorm2d", "build_norm", "reset_bn", "set_norm_eps"]
+
+
+class LayerNorm2d(nn.LayerNorm):
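+    # LayerNorm applied over the channel dimension (dim=1) of NCHW feature maps.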
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ out = x - torch.mean(x, dim=1, keepdim=True)
+ out = out / torch.sqrt(torch.square(out).mean(dim=1, keepdim=True) + self.eps)
+ if self.elementwise_affine:
+ out = out * self.weight.view(1, -1, 1, 1) + self.bias.view(1, -1, 1, 1)
+ return out
+
+
+# register normalization function here
+REGISTERED_NORM_DICT: dict[str, type] = {
+ "bn2d": nn.BatchNorm2d,
+ "ln": nn.LayerNorm,
+ "ln2d": LayerNorm2d,
+}
+
+
+def build_norm(name="bn2d", num_features=None, **kwargs) -> nn.Module | None:
+ if name in ["ln", "ln2d"]:
+ kwargs["normalized_shape"] = num_features
+ else:
+ kwargs["num_features"] = num_features
+ if name in REGISTERED_NORM_DICT:
+ norm_cls = REGISTERED_NORM_DICT[name]
+ args = build_kwargs_from_config(kwargs, norm_cls)
+ return norm_cls(**args)
+ else:
+ return None
+
+
+def reset_bn(
+ model: nn.Module,
+ data_loader: list,
+ sync=True,
+ progress_bar=False,
+) -> None:
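+    # Re-estimate BatchNorm running statistics: forward the calibration batches
+    # through a copy of the model, average each layer's batch mean/variance
+    # (optionally synced across workers), then copy the averages back into the
+    # original model's BN buffers.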
+ import copy
+
+ import torch.nn.functional as F
+ from tqdm import tqdm
+
+ from efficientvit.apps.utils import AverageMeter, is_master, sync_tensor
+ from efficientvit.models.utils import get_device, list_join
+
+ bn_mean = {}
+ bn_var = {}
+
+ tmp_model = copy.deepcopy(model)
+ for name, m in tmp_model.named_modules():
+ if isinstance(m, _BatchNorm):
+ bn_mean[name] = AverageMeter(is_distributed=False)
+ bn_var[name] = AverageMeter(is_distributed=False)
+
+ def new_forward(bn, mean_est, var_est):
+ def lambda_forward(x):
+ x = x.contiguous()
+ if sync:
+ batch_mean = (
+ x.mean(0, keepdim=True)
+ .mean(2, keepdim=True)
+ .mean(3, keepdim=True)
+ ) # 1, C, 1, 1
+ batch_mean = sync_tensor(batch_mean, reduce="cat")
+ batch_mean = torch.mean(batch_mean, dim=0, keepdim=True)
+
+ batch_var = (x - batch_mean) * (x - batch_mean)
+ batch_var = (
+ batch_var.mean(0, keepdim=True)
+ .mean(2, keepdim=True)
+ .mean(3, keepdim=True)
+ )
+ batch_var = sync_tensor(batch_var, reduce="cat")
+ batch_var = torch.mean(batch_var, dim=0, keepdim=True)
+ else:
+ batch_mean = (
+ x.mean(0, keepdim=True)
+ .mean(2, keepdim=True)
+ .mean(3, keepdim=True)
+ ) # 1, C, 1, 1
+ batch_var = (x - batch_mean) * (x - batch_mean)
+ batch_var = (
+ batch_var.mean(0, keepdim=True)
+ .mean(2, keepdim=True)
+ .mean(3, keepdim=True)
+ )
+
+ batch_mean = torch.squeeze(batch_mean)
+ batch_var = torch.squeeze(batch_var)
+
+ mean_est.update(batch_mean.data, x.size(0))
+ var_est.update(batch_var.data, x.size(0))
+
+ # bn forward using calculated mean & var
+ _feature_dim = batch_mean.shape[0]
+ return F.batch_norm(
+ x,
+ batch_mean,
+ batch_var,
+ bn.weight[:_feature_dim],
+ bn.bias[:_feature_dim],
+ False,
+ 0.0,
+ bn.eps,
+ )
+
+ return lambda_forward
+
+ m.forward = new_forward(m, bn_mean[name], bn_var[name])
+
+ # skip if there is no batch normalization layers in the network
+ if len(bn_mean) == 0:
+ return
+
+ tmp_model.eval()
+ with torch.no_grad():
+ with tqdm(
+ total=len(data_loader),
+ desc="reset bn",
+ disable=not progress_bar or not is_master(),
+ ) as t:
+ for images in data_loader:
+ images = images.to(get_device(tmp_model))
+ tmp_model(images)
+ t.set_postfix(
+ {
+ "bs": images.size(0),
+ "res": list_join(images.shape[-2:], "x"),
+ }
+ )
+ t.update()
+
+ for name, m in model.named_modules():
+ if name in bn_mean and bn_mean[name].count > 0:
+ feature_dim = bn_mean[name].avg.size(0)
+ assert isinstance(m, _BatchNorm)
+ m.running_mean.data[:feature_dim].copy_(bn_mean[name].avg)
+ m.running_var.data[:feature_dim].copy_(bn_var[name].avg)
+
+
+def set_norm_eps(model: nn.Module, eps: float | None = None) -> None:
+ for m in model.modules():
+ if isinstance(m, (nn.GroupNorm, nn.LayerNorm, _BatchNorm)):
+ if eps is not None:
+ m.eps = eps
diff --git a/yolo-world-with-efficientvit-sam/efficientvit/models/nn/ops.py b/yolo-world-with-efficientvit-sam/efficientvit/models/nn/ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..f2821758f6c29fb5fe38804dd376d485fdbaa8d7
--- /dev/null
+++ b/yolo-world-with-efficientvit-sam/efficientvit/models/nn/ops.py
@@ -0,0 +1,585 @@
+# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
+# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
+# International Conference on Computer Vision (ICCV), 2023
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.cuda.amp import autocast
+
+from efficientvit.models.nn.act import build_act
+from efficientvit.models.nn.norm import build_norm
+from efficientvit.models.utils import (get_same_padding, list_sum, resize,
+ val2list, val2tuple)
+
+__all__ = [
+ "ConvLayer",
+ "UpSampleLayer",
+ "LinearLayer",
+ "IdentityLayer",
+ "DSConv",
+ "MBConv",
+ "FusedMBConv",
+ "ResBlock",
+ "LiteMLA",
+ "EfficientViTBlock",
+ "ResidualBlock",
+ "DAGBlock",
+ "OpSequential",
+]
+
+
+#################################################################################
+# Basic Layers #
+#################################################################################
+
+
+class ConvLayer(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size=3,
+ stride=1,
+ dilation=1,
+ groups=1,
+ use_bias=False,
+ dropout=0,
+ norm="bn2d",
+ act_func="relu",
+ ):
+ super(ConvLayer, self).__init__()
+
+ padding = get_same_padding(kernel_size)
+ padding *= dilation
+
+ self.dropout = nn.Dropout2d(dropout, inplace=False) if dropout > 0 else None
+ self.conv = nn.Conv2d(
+ in_channels,
+ out_channels,
+ kernel_size=(kernel_size, kernel_size),
+ stride=(stride, stride),
+ padding=padding,
+ dilation=(dilation, dilation),
+ groups=groups,
+ bias=use_bias,
+ )
+ self.norm = build_norm(norm, num_features=out_channels)
+ self.act = build_act(act_func)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ if self.dropout is not None:
+ x = self.dropout(x)
+ x = self.conv(x)
+ if self.norm:
+ x = self.norm(x)
+ if self.act:
+ x = self.act(x)
+ return x
+
+
+class UpSampleLayer(nn.Module):
+ def __init__(
+ self,
+ mode="bicubic",
+        size: int | tuple[int, int] | list[int] | None = None,
+ factor=2,
+ align_corners=False,
+ ):
+ super(UpSampleLayer, self).__init__()
+ self.mode = mode
+ self.size = val2list(size, 2) if size is not None else None
+ self.factor = None if self.size is not None else factor
+ self.align_corners = align_corners
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ if (
+ self.size is not None and tuple(x.shape[-2:]) == self.size
+ ) or self.factor == 1:
+ return x
+ return resize(x, self.size, self.factor, self.mode, self.align_corners)
+
+
+class LinearLayer(nn.Module):
+ def __init__(
+ self,
+ in_features: int,
+ out_features: int,
+ use_bias=True,
+ dropout=0,
+ norm=None,
+ act_func=None,
+ ):
+ super(LinearLayer, self).__init__()
+
+ self.dropout = nn.Dropout(dropout, inplace=False) if dropout > 0 else None
+ self.linear = nn.Linear(in_features, out_features, use_bias)
+ self.norm = build_norm(norm, num_features=out_features)
+ self.act = build_act(act_func)
+
+ def _try_squeeze(self, x: torch.Tensor) -> torch.Tensor:
+ if x.dim() > 2:
+ x = torch.flatten(x, start_dim=1)
+ return x
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = self._try_squeeze(x)
+ if self.dropout:
+ x = self.dropout(x)
+ x = self.linear(x)
+ if self.norm:
+ x = self.norm(x)
+ if self.act:
+ x = self.act(x)
+ return x
+
+
+class IdentityLayer(nn.Module):
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return x
+
+
+#################################################################################
+# Basic Blocks #
+#################################################################################
+
+
+class DSConv(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size=3,
+ stride=1,
+ use_bias=False,
+ norm=("bn2d", "bn2d"),
+ act_func=("relu6", None),
+ ):
+ super(DSConv, self).__init__()
+
+ use_bias = val2tuple(use_bias, 2)
+ norm = val2tuple(norm, 2)
+ act_func = val2tuple(act_func, 2)
+
+ self.depth_conv = ConvLayer(
+ in_channels,
+ in_channels,
+ kernel_size,
+ stride,
+ groups=in_channels,
+ norm=norm[0],
+ act_func=act_func[0],
+ use_bias=use_bias[0],
+ )
+ self.point_conv = ConvLayer(
+ in_channels,
+ out_channels,
+ 1,
+ norm=norm[1],
+ act_func=act_func[1],
+ use_bias=use_bias[1],
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.depth_conv(x)
+ x = self.point_conv(x)
+ return x
+
+
+class MBConv(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size=3,
+ stride=1,
+ mid_channels=None,
+ expand_ratio=6,
+ use_bias=False,
+ norm=("bn2d", "bn2d", "bn2d"),
+ act_func=("relu6", "relu6", None),
+ ):
+ super(MBConv, self).__init__()
+
+ use_bias = val2tuple(use_bias, 3)
+ norm = val2tuple(norm, 3)
+ act_func = val2tuple(act_func, 3)
+ mid_channels = mid_channels or round(in_channels * expand_ratio)
+
+ self.inverted_conv = ConvLayer(
+ in_channels,
+ mid_channels,
+ 1,
+ stride=1,
+ norm=norm[0],
+ act_func=act_func[0],
+ use_bias=use_bias[0],
+ )
+ self.depth_conv = ConvLayer(
+ mid_channels,
+ mid_channels,
+ kernel_size,
+ stride=stride,
+ groups=mid_channels,
+ norm=norm[1],
+ act_func=act_func[1],
+ use_bias=use_bias[1],
+ )
+ self.point_conv = ConvLayer(
+ mid_channels,
+ out_channels,
+ 1,
+ norm=norm[2],
+ act_func=act_func[2],
+ use_bias=use_bias[2],
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.inverted_conv(x)
+ x = self.depth_conv(x)
+ x = self.point_conv(x)
+ return x
+
+
+class FusedMBConv(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size=3,
+ stride=1,
+ mid_channels=None,
+ expand_ratio=6,
+ groups=1,
+ use_bias=False,
+ norm=("bn2d", "bn2d"),
+ act_func=("relu6", None),
+ ):
+ super().__init__()
+ use_bias = val2tuple(use_bias, 2)
+ norm = val2tuple(norm, 2)
+ act_func = val2tuple(act_func, 2)
+
+ mid_channels = mid_channels or round(in_channels * expand_ratio)
+
+ self.spatial_conv = ConvLayer(
+ in_channels,
+ mid_channels,
+ kernel_size,
+ stride,
+ groups=groups,
+ use_bias=use_bias[0],
+ norm=norm[0],
+ act_func=act_func[0],
+ )
+ self.point_conv = ConvLayer(
+ mid_channels,
+ out_channels,
+ 1,
+ use_bias=use_bias[1],
+ norm=norm[1],
+ act_func=act_func[1],
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.spatial_conv(x)
+ x = self.point_conv(x)
+ return x
+
+
+class ResBlock(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size=3,
+ stride=1,
+ mid_channels=None,
+ expand_ratio=1,
+ use_bias=False,
+ norm=("bn2d", "bn2d"),
+ act_func=("relu6", None),
+ ):
+ super().__init__()
+ use_bias = val2tuple(use_bias, 2)
+ norm = val2tuple(norm, 2)
+ act_func = val2tuple(act_func, 2)
+
+ mid_channels = mid_channels or round(in_channels * expand_ratio)
+
+ self.conv1 = ConvLayer(
+ in_channels,
+ mid_channels,
+ kernel_size,
+ stride,
+ use_bias=use_bias[0],
+ norm=norm[0],
+ act_func=act_func[0],
+ )
+ self.conv2 = ConvLayer(
+ mid_channels,
+ out_channels,
+ kernel_size,
+ 1,
+ use_bias=use_bias[1],
+ norm=norm[1],
+ act_func=act_func[1],
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.conv1(x)
+ x = self.conv2(x)
+ return x
+
+
+class LiteMLA(nn.Module):
+ r"""Lightweight multi-scale linear attention"""
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+        heads: int | None = None,
+ heads_ratio: float = 1.0,
+ dim=8,
+ use_bias=False,
+ norm=(None, "bn2d"),
+ act_func=(None, None),
+ kernel_func="relu",
+ scales: tuple[int, ...] = (5,),
+ eps=1.0e-15,
+ ):
+ super(LiteMLA, self).__init__()
+ self.eps = eps
+ heads = heads or int(in_channels // dim * heads_ratio)
+
+ total_dim = heads * dim
+
+ use_bias = val2tuple(use_bias, 2)
+ norm = val2tuple(norm, 2)
+ act_func = val2tuple(act_func, 2)
+
+ self.dim = dim
+ self.qkv = ConvLayer(
+ in_channels,
+ 3 * total_dim,
+ 1,
+ use_bias=use_bias[0],
+ norm=norm[0],
+ act_func=act_func[0],
+ )
+ self.aggreg = nn.ModuleList(
+ [
+ nn.Sequential(
+ nn.Conv2d(
+ 3 * total_dim,
+ 3 * total_dim,
+ scale,
+ padding=get_same_padding(scale),
+ groups=3 * total_dim,
+ bias=use_bias[0],
+ ),
+ nn.Conv2d(
+ 3 * total_dim,
+ 3 * total_dim,
+ 1,
+ groups=3 * heads,
+ bias=use_bias[0],
+ ),
+ )
+ for scale in scales
+ ]
+ )
+ self.kernel_func = build_act(kernel_func, inplace=False)
+
+ self.proj = ConvLayer(
+ total_dim * (1 + len(scales)),
+ out_channels,
+ 1,
+ use_bias=use_bias[1],
+ norm=norm[1],
+ act_func=act_func[1],
+ )
+
+ @autocast(enabled=False)
+ def relu_linear_att(self, qkv: torch.Tensor) -> torch.Tensor:
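+        # ReLU linear attention: with non-negative q and k, compute q @ (k^T v)
+        # instead of (q k^T) v, which is linear in the number of tokens. v is padded
+        # with a constant channel of ones so the same matmul also produces the
+        # normalization denominator.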
+ B, _, H, W = list(qkv.size())
+
+ if qkv.dtype == torch.float16:
+ qkv = qkv.float()
+
+ qkv = torch.reshape(
+ qkv,
+ (
+ B,
+ -1,
+ 3 * self.dim,
+ H * W,
+ ),
+ )
+ qkv = torch.transpose(qkv, -1, -2)
+ q, k, v = (
+ qkv[..., 0 : self.dim],
+ qkv[..., self.dim : 2 * self.dim],
+ qkv[..., 2 * self.dim :],
+ )
+
+ # lightweight linear attention
+ q = self.kernel_func(q)
+ k = self.kernel_func(k)
+
+ # linear matmul
+ trans_k = k.transpose(-1, -2)
+
+ v = F.pad(v, (0, 1), mode="constant", value=1)
+ kv = torch.matmul(trans_k, v)
+ out = torch.matmul(q, kv)
+ out = out[..., :-1] / (out[..., -1:] + self.eps)
+
+ out = torch.transpose(out, -1, -2)
+ out = torch.reshape(out, (B, -1, H, W))
+ return out
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ # generate multi-scale q, k, v
+ qkv = self.qkv(x)
+ multi_scale_qkv = [qkv]
+ for op in self.aggreg:
+ multi_scale_qkv.append(op(qkv))
+ multi_scale_qkv = torch.cat(multi_scale_qkv, dim=1)
+
+ out = self.relu_linear_att(multi_scale_qkv)
+ out = self.proj(out)
+
+ return out
+
+
+class EfficientViTBlock(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ heads_ratio: float = 1.0,
+ dim=32,
+ expand_ratio: float = 4,
+ scales=(5,),
+ norm="bn2d",
+ act_func="hswish",
+ ):
+ super(EfficientViTBlock, self).__init__()
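+ # pairs a global context module (LiteMLA) with a local MBConv module, each
+ # wrapped in an identity-shortcut ResidualBlock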
+ self.context_module = ResidualBlock(
+ LiteMLA(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ heads_ratio=heads_ratio,
+ dim=dim,
+ norm=(None, norm),
+ scales=scales,
+ ),
+ IdentityLayer(),
+ )
+ local_module = MBConv(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ expand_ratio=expand_ratio,
+ use_bias=(True, True, False),
+ norm=(None, None, norm),
+ act_func=(act_func, act_func, None),
+ )
+ self.local_module = ResidualBlock(local_module, IdentityLayer())
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.context_module(x)
+ x = self.local_module(x)
+ return x
+
+
+#################################################################################
+# Functional Blocks #
+#################################################################################
+
+
+class ResidualBlock(nn.Module):
+ def __init__(
+ self,
+ main: nn.Module or None,
+ shortcut: nn.Module or None,
+ post_act=None,
+ pre_norm: nn.Module or None = None,
+ ):
+ super(ResidualBlock, self).__init__()
+
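+ # computes main(x) + shortcut(x) with an optional pre-norm and post-activation;
+ # if main is None the block acts as an identity, and if shortcut is None the
+ # skip connection is dropped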
+ self.pre_norm = pre_norm
+ self.main = main
+ self.shortcut = shortcut
+ self.post_act = build_act(post_act)
+
+ def forward_main(self, x: torch.Tensor) -> torch.Tensor:
+ if self.pre_norm is None:
+ return self.main(x)
+ else:
+ return self.main(self.pre_norm(x))
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ if self.main is None:
+ res = x
+ elif self.shortcut is None:
+ res = self.forward_main(x)
+ else:
+ res = self.forward_main(x) + self.shortcut(x)
+ if self.post_act:
+ res = self.post_act(res)
+ return res
+
+
+class DAGBlock(nn.Module):
+ def __init__(
+ self,
+ inputs: dict[str, nn.Module],
+ merge: str,
+ post_input: nn.Module or None,
+ middle: nn.Module,
+ outputs: dict[str, nn.Module],
+ ):
+ super(DAGBlock, self).__init__()
+
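+ # applies one op per input key, merges the results ("add" or "cat"), runs the
+ # middle module, and writes one output per key back into the feature dict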
+ self.input_keys = list(inputs.keys())
+ self.input_ops = nn.ModuleList(list(inputs.values()))
+ self.merge = merge
+ self.post_input = post_input
+
+ self.middle = middle
+
+ self.output_keys = list(outputs.keys())
+ self.output_ops = nn.ModuleList(list(outputs.values()))
+
+ def forward(self, feature_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
+ feat = [
+ op(feature_dict[key]) for key, op in zip(self.input_keys, self.input_ops)
+ ]
+ if self.merge == "add":
+ feat = list_sum(feat)
+ elif self.merge == "cat":
+ feat = torch.concat(feat, dim=1)
+ else:
+ raise NotImplementedError
+ if self.post_input is not None:
+ feat = self.post_input(feat)
+ feat = self.middle(feat)
+ for key, op in zip(self.output_keys, self.output_ops):
+ feature_dict[key] = op(feat)
+ return feature_dict
+
+
+class OpSequential(nn.Module):
+ def __init__(self, op_list: list[nn.Module or None]):
+ super(OpSequential, self).__init__()
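+ # behaves like nn.Sequential but silently skips None entries in op_list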
+ valid_op_list = []
+ for op in op_list:
+ if op is not None:
+ valid_op_list.append(op)
+ self.op_list = nn.ModuleList(valid_op_list)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ for op in self.op_list:
+ x = op(x)
+ return x
diff --git a/yolo-world-with-efficientvit-sam/efficientvit/models/utils/__init__.py b/yolo-world-with-efficientvit-sam/efficientvit/models/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0aab6b0a576b33e1e72029210f7b4232c9b7b8b6
--- /dev/null
+++ b/yolo-world-with-efficientvit-sam/efficientvit/models/utils/__init__.py
@@ -0,0 +1,7 @@
+# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
+# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
+# International Conference on Computer Vision (ICCV), 2023
+
+from .list import *
+from .network import *
+from .random import *
diff --git a/yolo-world-with-efficientvit-sam/efficientvit/models/utils/list.py b/yolo-world-with-efficientvit-sam/efficientvit/models/utils/list.py
new file mode 100644
index 0000000000000000000000000000000000000000..0a2c3291a88ab1d3cc77f7bc7d5eb475e9670a28
--- /dev/null
+++ b/yolo-world-with-efficientvit-sam/efficientvit/models/utils/list.py
@@ -0,0 +1,57 @@
+# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
+# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
+# International Conference on Computer Vision (ICCV), 2023
+
+__all__ = [
+ "list_sum",
+ "list_mean",
+ "weighted_list_sum",
+ "list_join",
+ "val2list",
+ "val2tuple",
+ "squeeze_list",
+]
+
+
+def list_sum(x: list) -> any:
+ return x[0] if len(x) == 1 else x[0] + list_sum(x[1:])
+
+
+def list_mean(x: list) -> any:
+ return list_sum(x) / len(x)
+
+
+def weighted_list_sum(x: list, weights: list) -> any:
+ assert len(x) == len(weights)
+ return (
+ x[0] * weights[0]
+ if len(x) == 1
+ else x[0] * weights[0] + weighted_list_sum(x[1:], weights[1:])
+ )
+
+
+def list_join(x: list, sep="\t", format_str="%s") -> str:
+ return sep.join([format_str % val for val in x])
+
+
+def val2list(x: list or tuple or any, repeat_time=1) -> list:
+ if isinstance(x, (list, tuple)):
+ return list(x)
+ return [x for _ in range(repeat_time)]
+
+
+def val2tuple(x: list or tuple or any, min_len: int = 1, idx_repeat: int = -1) -> tuple:
+ x = val2list(x)
+
+ # pad the list to min_len by repeating the element at idx_repeat
+ if len(x) > 0:
+ x[idx_repeat:idx_repeat] = [x[idx_repeat] for _ in range(min_len - len(x))]
+
+ return tuple(x)
+
+
+def squeeze_list(x: list or None) -> list or any:
+ if x is not None and len(x) == 1:
+ return x[0]
+ else:
+ return x
diff --git a/yolo-world-with-efficientvit-sam/efficientvit/models/utils/network.py b/yolo-world-with-efficientvit-sam/efficientvit/models/utils/network.py
new file mode 100644
index 0000000000000000000000000000000000000000..6ba96ec255dc7543be2a7995fed58f7d139d2c75
--- /dev/null
+++ b/yolo-world-with-efficientvit-sam/efficientvit/models/utils/network.py
@@ -0,0 +1,77 @@
+# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
+# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
+# International Conference on Computer Vision (ICCV), 2023
+
+import os
+from inspect import signature
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+__all__ = [
+ "is_parallel",
+ "get_device",
+ "get_same_padding",
+ "resize",
+ "build_kwargs_from_config",
+ "load_state_dict_from_file",
+]
+
+
+def is_parallel(model: nn.Module) -> bool:
+ return isinstance(
+ model, (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
+ )
+
+
+def get_device(model: nn.Module) -> torch.device:
+ return model.parameters().__next__().device
+
+
+def get_same_padding(kernel_size: int or tuple[int, ...]) -> int or tuple[int, ...]:
+ if isinstance(kernel_size, tuple):
+ return tuple([get_same_padding(ks) for ks in kernel_size])
+ else:
+ assert kernel_size % 2 > 0, "kernel size should be an odd number"
+ return kernel_size // 2
+
+
+def resize(
+ x: torch.Tensor,
+ size: any or None = None,
+ scale_factor: list[float] or None = None,
+ mode: str = "bicubic",
+ align_corners: bool or None = False,
+) -> torch.Tensor:
+ if mode in {"bilinear", "bicubic"}:
+ return F.interpolate(
+ x,
+ size=size,
+ scale_factor=scale_factor,
+ mode=mode,
+ align_corners=align_corners,
+ )
+ elif mode in {"nearest", "area"}:
+ return F.interpolate(x, size=size, scale_factor=scale_factor, mode=mode)
+ else:
+ raise NotImplementedError(f"resize(mode={mode}) not implemented.")
+
+
+def build_kwargs_from_config(config: dict, target_func: callable) -> dict[str, any]:
+ valid_keys = list(signature(target_func).parameters)
+ kwargs = {}
+ for key in config:
+ if key in valid_keys:
+ kwargs[key] = config[key]
+ return kwargs
+
+
+def load_state_dict_from_file(
+ file: str, only_state_dict=True
+) -> dict[str, torch.Tensor]:
+ file = os.path.realpath(os.path.expanduser(file))
+ checkpoint = torch.load(file, map_location="cpu")
+ if only_state_dict and "state_dict" in checkpoint:
+ checkpoint = checkpoint["state_dict"]
+ return checkpoint
diff --git a/yolo-world-with-efficientvit-sam/efficientvit/models/utils/random.py b/yolo-world-with-efficientvit-sam/efficientvit/models/utils/random.py
new file mode 100644
index 0000000000000000000000000000000000000000..0257f7ab93a3781c159a917823c36d8ada976292
--- /dev/null
+++ b/yolo-world-with-efficientvit-sam/efficientvit/models/utils/random.py
@@ -0,0 +1,73 @@
+# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
+# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
+# International Conference on Computer Vision (ICCV), 2023
+
+import numpy as np
+import torch
+
+__all__ = [
+ "torch_randint",
+ "torch_random",
+ "torch_shuffle",
+ "torch_uniform",
+ "torch_random_choices",
+]
+
+
+def torch_randint(
+ low: int, high: int, generator: torch.Generator or None = None
+) -> int:
+ """uniform: [low, high)"""
+ if low == high:
+ return low
+ else:
+ assert low < high
+ return int(torch.randint(low=low, high=high, generator=generator, size=(1,)))
+
+
+def torch_random(generator: torch.Generator or None = None) -> float:
+ """uniform distribution on the interval [0, 1)"""
+ return float(torch.rand(1, generator=generator))
+
+
+def torch_shuffle(
+ src_list: list[any], generator: torch.Generator or None = None
+) -> list[any]:
+ rand_indexes = torch.randperm(len(src_list), generator=generator).tolist()
+ return [src_list[i] for i in rand_indexes]
+
+
+def torch_uniform(
+ low: float, high: float, generator: torch.Generator or None = None
+) -> float:
+ """uniform distribution on the interval [low, high)"""
+ rand_val = torch_random(generator)
+ return (high - low) * rand_val + low
+
+
+def torch_random_choices(
+ src_list: list[any],
+ generator: torch.Generator or None = None,
+ k=1,
+ weight_list: list[float] or None = None,
+) -> any or list:
+ if weight_list is None:
+ rand_idx = torch.randint(
+ low=0, high=len(src_list), generator=generator, size=(k,)
+ )
+ out_list = [src_list[i] for i in rand_idx]
+ else:
+ assert len(weight_list) == len(src_list)
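+ # weighted sampling via the inverse CDF: draw a uniform value in
+ # [0, total_weight) and pick the first index whose cumulative weight exceeds it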
+ accumulate_weight_list = np.cumsum(weight_list)
+
+ out_list = []
+ for _ in range(k):
+ val = torch_uniform(0, accumulate_weight_list[-1], generator)
+ active_id = 0
+ for i, weight_val in enumerate(accumulate_weight_list):
+ active_id = i
+ if weight_val > val:
+ break
+ out_list.append(src_list[active_id])
+
+ return out_list[0] if k == 1 else out_list
diff --git a/yolo-world-with-efficientvit-sam/efficientvit/sam_model_zoo.py b/yolo-world-with-efficientvit-sam/efficientvit/sam_model_zoo.py
new file mode 100644
index 0000000000000000000000000000000000000000..52efa39e23047406ca9fe5e6464140ac2164d8e2
--- /dev/null
+++ b/yolo-world-with-efficientvit-sam/efficientvit/sam_model_zoo.py
@@ -0,0 +1,53 @@
+# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
+# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
+# International Conference on Computer Vision (ICCV), 2023
+
+from efficientvit.models.efficientvit import (EfficientViTSam,
+ efficientvit_sam_l0,
+ efficientvit_sam_l1,
+ efficientvit_sam_l2,
+ efficientvit_sam_xl0,
+ efficientvit_sam_xl1)
+from efficientvit.models.nn.norm import set_norm_eps
+from efficientvit.models.utils import load_state_dict_from_file
+
+__all__ = ["create_sam_model"]
+
+
+REGISTERED_SAM_MODEL: dict[str, str] = {
+ "l0": "assets/checkpoints/sam/l0.pt",
+ "l1": "assets/checkpoints/sam/l1.pt",
+ "l2": "assets/checkpoints/sam/l2.pt",
+ "xl0": "assets/checkpoints/sam/xl0.pt",
+ "xl1": "assets/checkpoints/sam/xl1.pt",
+}
+
+
+def create_sam_model(
+ name: str, pretrained=True, weight_url: str or None = None, **kwargs
+) -> EfficientViTSam:
+ model_dict = {
+ "l0": efficientvit_sam_l0,
+ "l1": efficientvit_sam_l1,
+ "l2": efficientvit_sam_l2,
+ "xl0": efficientvit_sam_xl0,
+ "xl1": efficientvit_sam_xl1,
+ }
+
+ model_id = name.split("-")[0]
+ if model_id not in model_dict:
+ raise ValueError(
+ f"Do not find {name} in the model zoo. List of models: {list(model_dict.keys())}"
+ )
+ else:
+ model = model_dict[model_id](**kwargs)
+ set_norm_eps(model, 1e-6)
+
+ if pretrained:
+ weight_url = weight_url or REGISTERED_SAM_MODEL.get(name, None)
+ if weight_url is None:
+ raise ValueError(f"Do not find the pretrained weight of {name}.")
+ else:
+ weight = load_state_dict_from_file(weight_url)
+ model.load_state_dict(weight)
+ return model
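+
+
+# Example usage (a sketch; assumes the l0 checkpoint exists at assets/checkpoints/sam/l0.pt):
+#   sam = create_sam_model("l0", pretrained=True).eval()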
diff --git a/yolo-world-with-efficientvit-sam/examples/cat_and_dogs.jpg b/yolo-world-with-efficientvit-sam/examples/cat_and_dogs.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..1b0c9f953e6425d93489913d69e948fffde15747
Binary files /dev/null and b/yolo-world-with-efficientvit-sam/examples/cat_and_dogs.jpg differ
diff --git a/yolo-world-with-efficientvit-sam/examples/livingroom.jpg b/yolo-world-with-efficientvit-sam/examples/livingroom.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..f05ceaa1b35c6b74e196d979efadc5f4b79b6170
Binary files /dev/null and b/yolo-world-with-efficientvit-sam/examples/livingroom.jpg differ
diff --git a/yolo-world-with-efficientvit-sam/requirements.txt b/yolo-world-with-efficientvit-sam/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..a1e5a6decf62946f04669440f05c5f83086a340e
--- /dev/null
+++ b/yolo-world-with-efficientvit-sam/requirements.txt
@@ -0,0 +1,7 @@
+inference[yolo-world]==0.9.13
+supervision==0.18.0
+gradio==4.18.0
+timm==0.9.12
+onnx==1.15.0
+onnxsim==0.4.35
+git+https://github.com/facebookresearch/segment-anything.git